Skip to content

Commit 25ca6b8

Browse files
authored
Add iterative reweighting to PruneOptimizer (#4283)
Add back iterative reweighting
1 parent 3dd137c commit 25ca6b8

9 files changed

Lines changed: 508 additions & 18 deletions

File tree

Lines changed: 397 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,397 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import math
8+
import random
9+
import unittest
10+
11+
import torch
12+
import torch.distributed as dist
13+
from torch.distributed.tensor import distribute_tensor
14+
from torch.distributed.tensor.placement_types import Replicate, Shard
15+
from torch.testing._internal import common_utils
16+
17+
from test.prototype.pat.test_common import (
18+
DistributedTestMixin,
19+
TwoLayerMLP,
20+
make_prox_kwargs,
21+
optim_step,
22+
)
23+
from torchao.prototype.pat.group import Dim0Grouper, Dim1Grouper
24+
from torchao.prototype.pat.group.grouper import ElemGrouper
25+
from torchao.prototype.pat.optim import ProxGroupLasso, ProxLasso, PruneOptimizer
26+
from torchao.prototype.pat.optim.iterative_reweight import IterativeReweight
27+
from torchao.prototype.pat.utils import get_param_groups
28+
29+
30+
class TestIterativeReweight(common_utils.TestCase):
    """Direct unit tests for the IterativeReweight callable."""

    def test_reweight_formula(self):
        """The reweight factor equals 1 / (norm / (sigma + eps) + eps)."""
        eps = 1e-3
        rw = IterativeReweight(reweight_freq=1, eps=eps)
        norms = torch.tensor([1.0, 2.0, 4.0])
        sigmas = torch.tensor([2.0, 2.0, 2.0])

        out = rw(norms.clone(), sigmas.clone())
        want = 1.0 / (torch.tensor([1.0, 2.0, 4.0]) / (sigmas + eps) + eps)
        self.assertEqual(out, want)

    def test_small_norm_high_reweight(self):
        """Norms far below sigma yield strictly larger reweight values."""
        rw = IterativeReweight(reweight_freq=1, eps=1e-3)
        sigmas = torch.tensor([1.0, 1.0])
        tiny = torch.tensor([0.001, 0.001])
        big = torch.tensor([10.0, 10.0])

        out_tiny = rw(tiny.clone(), sigmas.clone())
        out_big = rw(big.clone(), sigmas.clone())
        self.assertTrue((out_tiny > out_big).all())

    def test_eps_prevents_division_by_zero(self):
        """A zero group norm maps to roughly 1/eps rather than infinity."""
        eps = 1e-3
        rw = IterativeReweight(reweight_freq=1, eps=eps)
        out = rw(torch.tensor([0.0]), torch.tensor([1.0]))
        self.assertAlmostEqual(out.item(), 1.0 / eps, places=1)

    def test_should_update_at_end_step(self):
        """Update fires at end_step when on-frequency, not one step past it."""
        rw = IterativeReweight(reweight_freq=2, reweight_end_step=6)
        self.assertTrue(rw.should_update(6))
        self.assertFalse(rw.should_update(7))

    def test_should_update_past_end_step(self):
        """Every step through end_step updates; later steps never do."""
        rw = IterativeReweight(reweight_freq=1, reweight_end_step=3)
        for step in range(7):
            self.assertEqual(rw.should_update(step), step <= 3, f"step={step}")

    def test_should_update_step_zero_with_freq_gt_one(self):
        """Step 0 always lands on-frequency because 0 % freq == 0."""
        rw = IterativeReweight(reweight_freq=3, reweight_end_step=100)
        self.assertTrue(rw.should_update(0))
        self.assertFalse(rw.should_update(1))
        self.assertTrue(rw.should_update(3))
class TestApplyProxReweight(common_utils.TestCase):
    """Tests _apply_prox with tau_reweight != 1.0 across branches."""

    @common_utils.parametrize(
        "grouper_cls,prox_cls,tau_reweight,disable_vmap",
        [
            (ElemGrouper, ProxLasso, 2.0, True),
            (Dim0Grouper, ProxLasso, 3.0, False),
            (Dim0Grouper, ProxGroupLasso, 2.0, False),
        ],
    )
    def test_tau_reweight_scales_threshold(
        self, grouper_cls, prox_cls, tau_reweight, disable_vmap
    ):
        """tau_reweight multiplies into the pruning threshold correctly."""
        torch.manual_seed(42)
        reg_lambda = 0.5
        gamma = 2.0

        param = torch.randn(4, 6)
        expected = param.clone()

        # Build the reference result by applying the threshold by hand.
        if prox_cls is ProxGroupLasso:
            # For Dim0Grouper each row is one group of size numel / size(0).
            tau = math.sqrt(param.numel() // param.size(0))
            threshold = reg_lambda * tau * tau_reweight * gamma
            for row in expected:
                shrink = max(1 - threshold / torch.linalg.vector_norm(row).item(), 0)
                row.mul_(shrink)
        else:
            threshold = reg_lambda * gamma * tau_reweight
            expected.mul_((1 - threshold / expected.abs()).clamp(min=0))

        PruneOptimizer._apply_prox(
            grouper_cls(param),
            prox_cls(reg_lambda),
            param,
            tau_reweight=tau_reweight,
            **make_prox_kwargs(gamma, disable_vmap=disable_vmap),
        )

        self.assertEqual(param, expected)

    def test_reweight_monotonicity(self):
        """Higher tau_reweight zeros more elements; lower zeros fewer."""
        torch.manual_seed(42)
        reg_lambda = 0.5
        gamma = 1.0

        base = torch.randn(4, 6)
        counts = {}
        for scale in (0.1, 1.0, 5.0):
            param = base.clone()
            grouper = Dim0Grouper(param)
            kwargs = make_prox_kwargs(gamma)
            n_zeroed, _, _ = PruneOptimizer._apply_prox(
                grouper, ProxGroupLasso(reg_lambda), param, tau_reweight=scale, **kwargs
            )
            counts[scale] = n_zeroed

        self.assertLessEqual(counts[0.1], counts[1.0])
        self.assertGreaterEqual(counts[5.0], counts[1.0])
@unittest.skipUnless(dist.is_available(), "torch.distributed not available")
class TestApplyProxReweightDTensor(DistributedTestMixin, common_utils.TestCase):
    """DTensor tests for _apply_prox with tau_reweight."""

    def _assert_prox_parity(
        self,
        GrouperCls,
        prox_cls,
        reg_lambda,
        p_regular,
        p_dtensor,
        tau_reg,
        tau_dt,
        prox_kwargs,
    ):
        """Apply _apply_prox to both tensor copies and verify identical output."""
        zeros_plain, _, _ = PruneOptimizer._apply_prox(
            GrouperCls(p_regular),
            prox_cls(reg_lambda),
            p_regular,
            tau_reweight=tau_reg,
            **prox_kwargs,
        )
        zeros_dist, _, _ = PruneOptimizer._apply_prox(
            GrouperCls(p_dtensor),
            prox_cls(reg_lambda),
            p_dtensor,
            tau_reweight=tau_dt,
            **prox_kwargs,
        )
        self.assertEqual(zeros_plain, zeros_dist)
        self.assertEqual(p_regular, p_dtensor.full_tensor())

    @common_utils.parametrize(
        "GrouperCls,placements,prox_cls",
        [
            (Dim0Grouper, (Shard(0), Replicate()), ProxGroupLasso),
            (Dim1Grouper, (Shard(1), Replicate()), ProxGroupLasso),
            (Dim0Grouper, (Shard(0), Replicate()), ProxLasso),
            (Dim1Grouper, (Shard(1), Replicate()), ProxLasso),
        ],
    )
    def test_dtensor_matches_regular_with_reweight(
        self, GrouperCls, placements, prox_cls
    ):
        """DTensor vs regular tensor equivalence with a scalar tau_reweight."""
        torch.manual_seed(42)
        reg_lambda = 0.5
        tau_reweight = 2.5

        p_regular = torch.randn(4, 6)
        p_dtensor = distribute_tensor(
            p_regular.clone(), device_mesh=self.mesh, placements=placements
        )
        self._assert_prox_parity(
            GrouperCls,
            prox_cls,
            reg_lambda,
            p_regular,
            p_dtensor,
            tau_reweight,
            tau_reweight,
            make_prox_kwargs(2.0),
        )

    @common_utils.parametrize(
        "GrouperCls,placements,prox_cls",
        [
            (Dim0Grouper, (Shard(0), Replicate()), ProxGroupLasso),
            (Dim1Grouper, (Shard(1), Replicate()), ProxGroupLasso),
            (Dim0Grouper, (Shard(0), Replicate()), ProxLasso),
            (Dim1Grouper, (Shard(1), Replicate()), ProxLasso),
        ],
    )
    def test_dtensor_gamma_index_slope_with_tensor_reweight(
        self, GrouperCls, placements, prox_cls
    ):
        """DTensor with gamma_index_slope > 0 and a per-group tensor tau_reweight."""
        torch.manual_seed(42)
        reg_lambda = 0.5

        p_regular = torch.randn(4, 6)
        group_dim = 0 if GrouperCls == Dim0Grouper else 1
        tau_reweight = torch.rand(p_regular.size(group_dim)) + 0.5  # [0.5, 1.5)

        p_dtensor = distribute_tensor(
            p_regular.clone(), device_mesh=self.mesh, placements=placements
        )
        self._assert_prox_parity(
            GrouperCls,
            prox_cls,
            reg_lambda,
            p_regular,
            p_dtensor,
            tau_reweight.clone(),
            tau_reweight.clone(),
            make_prox_kwargs(2.0, gamma_index_slope=0.5),
        )
class TestPruneOptimizerReweight(common_utils.TestCase):
    """End-to-end tests using PruneOptimizer with reweight_tau_freq > 0."""

    def _build(self, group_lasso=False, **optim_kwargs):
        """Seed RNG, construct a TwoLayerMLP and a PruneOptimizer over SGD."""
        torch.manual_seed(42)
        model = TwoLayerMLP(input_size=10, output_size=2)
        config = (
            model._group_lasso_prune_config()
            if group_lasso
            else model._linear_prune_config()
        )
        param_groups = get_param_groups(model, config, verbose=False)
        optimizer = PruneOptimizer(
            torch.optim.SGD(param_groups, lr=0.1),
            reg_lambda=1.0,
            **optim_kwargs,
        )
        return model, optimizer

    @staticmethod
    def _reg_params(optimizer):
        """Yield every parameter in the optimizer's regularized groups."""
        for group in optimizer.regularized_param_groups():
            yield from group["params"]

    def test_sigma_initialized_at_warmup_end(self):
        """After warmup, state['sigma'] exists for regularized params."""
        warmup = 2
        model, opt = self._build(warmup_steps=warmup, reweight_tau_freq=1)

        x = torch.randn(10, 10)
        y = torch.randint(0, 2, (10,))
        for step in range(5):
            optim_step(model, opt, x, y, step)
            if step < warmup:
                for p in self._reg_params(opt):
                    self.assertNotIn("sigma", opt.state[p])
            elif step == warmup:
                for p in self._reg_params(opt):
                    self.assertIn("sigma", opt.state[p])

    def test_tau_reweight_updated_at_freq(self):
        """state['tau_reweight'] is updated every reweight_tau_freq steps."""
        model, opt = self._build(warmup_steps=0, reweight_tau_freq=3)

        x = torch.randn(10, 10)
        y = torch.randint(0, 2, (10,))
        for step in range(10):
            optim_step(model, opt, x, y, step)

        self.assertTrue(
            any("tau_reweight" in opt.state[p] for p in self._reg_params(opt))
        )

    def test_no_reweight_when_freq_zero(self):
        """With reweight_tau_freq=0, no sigma/tau_reweight in state."""
        model, opt = self._build(warmup_steps=0, reweight_tau_freq=0)

        x = torch.randn(10, 10)
        y = torch.randint(0, 2, (10,))
        for step in range(10):
            optim_step(model, opt, x, y, step)

        for p in self._reg_params(opt):
            self.assertNotIn("sigma", opt.state[p])
            self.assertNotIn("tau_reweight", opt.state[p])

    def test_tau_reweight_frozen_after_end_step(self):
        """tau_reweight stops updating after reweight_tau_end_step."""
        end_step = 5
        model, opt = self._build(
            warmup_steps=0, reweight_tau_freq=1, reweight_tau_end_step=end_step
        )

        x = torch.randn(20, 10)
        y = torch.randint(0, 2, (20,))
        for step in range(10):
            optim_step(model, opt, x, y, step)

        # Snapshot tau_reweight once it should no longer change.
        frozen = {
            id(p): opt.state[p]["tau_reweight"].clone()
            for p in self._reg_params(opt)
            if "tau_reweight" in opt.state[p]
        }
        self.assertTrue(len(frozen) > 0, "tau_reweight should exist")

        for step in range(10, 15):
            optim_step(model, opt, x, y, step)

        for p in self._reg_params(opt):
            self.assertEqual(opt.state[p]["tau_reweight"], frozen[id(p)])

    def test_reweight_with_group_lasso(self):
        """End-to-end with Dim0Grouper + ProxGroupLasso (hits vmap branch)."""
        model, opt = self._build(
            group_lasso=True, warmup_steps=0, reweight_tau_freq=2
        )

        x = torch.randn(10, 10)
        y = torch.randint(0, 2, (10,))
        for step in range(10):
            optim_step(model, opt, x, y, step)

        for p in self._reg_params(opt):
            state = opt.state[p]
            self.assertIn("sigma", state)
            self.assertIn("tau_reweight", state)
            # Dim0Grouper makes one group per row, so one tau per dim-0 slice.
            self.assertEqual(state["tau_reweight"].numel(), p.size(0))
# Materialize the parametrized variants of the two decorated test classes.
for _parametrized_cls in (TestApplyProxReweight, TestApplyProxReweightDTensor):
    common_utils.instantiate_parametrized_tests(_parametrized_cls)

if __name__ == "__main__":
    random.seed(0)
    torch.manual_seed(0)
    unittest.main()

0 commit comments

Comments
 (0)