# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
| 7 | +import math |
| 8 | +import random |
| 9 | +import unittest |
| 10 | + |
| 11 | +import torch |
| 12 | +import torch.distributed as dist |
| 13 | +from torch.distributed.tensor import distribute_tensor |
| 14 | +from torch.distributed.tensor.placement_types import Replicate, Shard |
| 15 | +from torch.testing._internal import common_utils |
| 16 | + |
| 17 | +from test.prototype.pat.test_common import ( |
| 18 | + DistributedTestMixin, |
| 19 | + TwoLayerMLP, |
| 20 | + make_prox_kwargs, |
| 21 | + optim_step, |
| 22 | +) |
| 23 | +from torchao.prototype.pat.group import Dim0Grouper, Dim1Grouper |
| 24 | +from torchao.prototype.pat.group.grouper import ElemGrouper |
| 25 | +from torchao.prototype.pat.optim import ProxGroupLasso, ProxLasso, PruneOptimizer |
| 26 | +from torchao.prototype.pat.optim.iterative_reweight import IterativeReweight |
| 27 | +from torchao.prototype.pat.utils import get_param_groups |
| 28 | + |
| 29 | + |
class TestIterativeReweight(common_utils.TestCase):
    """Unit tests exercising the IterativeReweight class in isolation."""

    def test_reweight_formula(self):
        """Output matches 1 / (group_norm / (sigma + eps) + eps) elementwise."""
        eps = 1e-3
        rw = IterativeReweight(reweight_freq=1, eps=eps)
        norms = torch.tensor([1.0, 2.0, 4.0])
        sigma = torch.tensor([2.0, 2.0, 2.0])

        # Pass clones: the callable may modify its inputs in place.
        actual = rw(norms.clone(), sigma.clone())
        expected = (norms / (sigma + eps) + eps).reciprocal()
        self.assertEqual(actual, expected)

    def test_small_norm_high_reweight(self):
        """Groups whose norm is far below sigma get a larger tau_reweight."""
        rw = IterativeReweight(reweight_freq=1, eps=1e-3)
        sigma = torch.ones(2)
        tiny_norm = torch.full((2,), 0.001)
        big_norm = torch.full((2,), 10.0)

        out_tiny = rw(tiny_norm.clone(), sigma.clone())
        out_big = rw(big_norm.clone(), sigma.clone())
        self.assertTrue((out_tiny > out_big).all())

    def test_eps_prevents_division_by_zero(self):
        """A zero group norm yields approximately 1/eps instead of inf."""
        eps = 1e-3
        rw = IterativeReweight(reweight_freq=1, eps=eps)

        out = rw(torch.tensor([0.0]), torch.tensor([1.0]))
        self.assertAlmostEqual(out.item(), 1.0 / eps, places=1)

    def test_should_update_at_end_step(self):
        """True exactly at end_step (which is on-frequency); False one later."""
        rw = IterativeReweight(reweight_freq=2, reweight_end_step=6)
        self.assertTrue(rw.should_update(6))
        self.assertFalse(rw.should_update(7))

    def test_should_update_past_end_step(self):
        """Every step through end_step updates; no step afterwards does."""
        rw = IterativeReweight(reweight_freq=1, reweight_end_step=3)
        for step in range(7):
            if step <= 3:
                self.assertTrue(rw.should_update(step), f"step={step}")
            else:
                self.assertFalse(rw.should_update(step), f"step={step}")

    def test_should_update_step_zero_with_freq_gt_one(self):
        """Step 0 is on-frequency for any freq, since 0 % freq == 0."""
        rw = IterativeReweight(reweight_freq=3, reweight_end_step=100)
        self.assertTrue(rw.should_update(0))
        self.assertFalse(rw.should_update(1))
        self.assertTrue(rw.should_update(3))
| 86 | + |
class TestApplyProxReweight(common_utils.TestCase):
    """Checks that _apply_prox honors tau_reweight != 1.0 across branches."""

    @common_utils.parametrize(
        "grouper_cls,prox_cls,tau_reweight,disable_vmap",
        [
            (ElemGrouper, ProxLasso, 2.0, True),
            (Dim0Grouper, ProxLasso, 3.0, False),
            (Dim0Grouper, ProxGroupLasso, 2.0, False),
        ],
    )
    def test_tau_reweight_scales_threshold(
        self, grouper_cls, prox_cls, tau_reweight, disable_vmap
    ):
        """The pruning threshold is scaled by the tau_reweight factor."""
        torch.manual_seed(42)
        reg_lambda = 0.5
        gamma = 2.0

        p = torch.randn(4, 6)
        p_ref = p.clone()

        # Build the expected result by hand, mirroring each prox operator.
        if prox_cls is ProxGroupLasso:
            # Dim0Grouper: every row is one group of size numel // dim0.
            tau = math.sqrt(p.numel() // p.size(0))
            threshold = reg_lambda * tau * tau_reweight * gamma
            for row in p_ref:
                norm = torch.linalg.vector_norm(row)
                row.mul_(max(1 - threshold / norm.item(), 0))
        else:
            threshold = reg_lambda * gamma * tau_reweight
            p_ref.mul_((1 - threshold / p_ref.abs()).clamp(min=0))

        PruneOptimizer._apply_prox(
            grouper_cls(p),
            prox_cls(reg_lambda),
            p,
            tau_reweight=tau_reweight,
            **make_prox_kwargs(gamma, disable_vmap=disable_vmap),
        )

        self.assertEqual(p, p_ref)

    def test_reweight_monotonicity(self):
        """A larger tau_reweight zeros at least as many elements."""
        torch.manual_seed(42)
        reg_lambda = 0.5
        gamma = 1.0

        data = torch.randn(4, 6)

        def count_zeros(tw):
            # Run the prox step on a fresh copy and report its zero count.
            p = data.clone()
            n_zeros, _, _ = PruneOptimizer._apply_prox(
                Dim0Grouper(p),
                ProxGroupLasso(reg_lambda),
                p,
                tau_reweight=tw,
                **make_prox_kwargs(gamma),
            )
            return n_zeros

        zeros_low = count_zeros(0.1)
        zeros_mid = count_zeros(1.0)
        zeros_high = count_zeros(5.0)

        self.assertLessEqual(zeros_low, zeros_mid)
        self.assertGreaterEqual(zeros_high, zeros_mid)
| 149 | + |
| 150 | + |
@unittest.skipUnless(dist.is_available(), "torch.distributed not available")
class TestApplyProxReweightDTensor(DistributedTestMixin, common_utils.TestCase):
    """DTensor tests for _apply_prox with tau_reweight.

    Each test runs the same prox step on a plain tensor and on a DTensor
    distributed over ``self.mesh`` (presumably provided by
    DistributedTestMixin — confirm in test_common), then asserts the zero
    counts and the resulting parameter values agree.
    """

    @common_utils.parametrize(
        "GrouperCls,placements,prox_cls",
        [
            (Dim0Grouper, (Shard(0), Replicate()), ProxGroupLasso),
            (Dim1Grouper, (Shard(1), Replicate()), ProxGroupLasso),
            (Dim0Grouper, (Shard(0), Replicate()), ProxLasso),
            (Dim1Grouper, (Shard(1), Replicate()), ProxLasso),
        ],
    )
    def test_dtensor_matches_regular_with_reweight(
        self, GrouperCls, placements, prox_cls
    ):
        """DTensor vs regular tensor equivalence with a scalar tau_reweight."""
        torch.manual_seed(42)
        reg_lambda = 0.5
        gamma = 2.0
        tau_reweight = 2.5

        # Same underlying data on both paths; the DTensor copy is sharded
        # according to the parametrized placements.
        p_regular = torch.randn(4, 6)
        p_dtensor = distribute_tensor(
            p_regular.clone(), device_mesh=self.mesh, placements=placements
        )

        prox_kwargs = make_prox_kwargs(gamma)

        # Reference: prox applied to the plain tensor.
        grouper_reg = GrouperCls(p_regular)
        z_reg, _, _ = PruneOptimizer._apply_prox(
            grouper_reg,
            prox_cls(reg_lambda),
            p_regular,
            tau_reweight=tau_reweight,
            **prox_kwargs,
        )

        # Same prox applied to the sharded DTensor.
        grouper_dt = GrouperCls(p_dtensor)
        z_dt, _, _ = PruneOptimizer._apply_prox(
            grouper_dt,
            prox_cls(reg_lambda),
            p_dtensor,
            tau_reweight=tau_reweight,
            **prox_kwargs,
        )

        # Zero counts and full (gathered) values must match the reference.
        self.assertEqual(z_reg, z_dt)
        self.assertEqual(p_regular, p_dtensor.full_tensor())

    @common_utils.parametrize(
        "GrouperCls,placements,prox_cls",
        [
            (Dim0Grouper, (Shard(0), Replicate()), ProxGroupLasso),
            (Dim1Grouper, (Shard(1), Replicate()), ProxGroupLasso),
            (Dim0Grouper, (Shard(0), Replicate()), ProxLasso),
            (Dim1Grouper, (Shard(1), Replicate()), ProxLasso),
        ],
    )
    def test_dtensor_gamma_index_slope_with_tensor_reweight(
        self, GrouperCls, placements, prox_cls
    ):
        """DTensor equivalence with gamma_index_slope > 0 and a per-group
        (tensor-valued) tau_reweight instead of a scalar."""
        torch.manual_seed(42)
        reg_lambda = 0.5
        gamma = 2.0

        p_regular = torch.randn(4, 6)
        # One reweight factor per group; the group axis depends on the grouper.
        n_groups = p_regular.size(0) if GrouperCls == Dim0Grouper else p_regular.size(1)
        tau_reweight = torch.rand(n_groups) + 0.5  # [0.5, 1.5)

        p_dtensor = distribute_tensor(
            p_regular.clone(), device_mesh=self.mesh, placements=placements
        )

        prox_kwargs = make_prox_kwargs(gamma, gamma_index_slope=0.5)

        grouper_reg = GrouperCls(p_regular)
        z_reg, _, _ = PruneOptimizer._apply_prox(
            grouper_reg,
            prox_cls(reg_lambda),
            p_regular,
            tau_reweight=tau_reweight.clone(),
            **prox_kwargs,
        )

        grouper_dt = GrouperCls(p_dtensor)
        z_dt, _, _ = PruneOptimizer._apply_prox(
            grouper_dt,
            prox_cls(reg_lambda),
            p_dtensor,
            tau_reweight=tau_reweight.clone(),
            **prox_kwargs,
        )

        self.assertEqual(z_reg, z_dt)
        self.assertEqual(p_regular, p_dtensor.full_tensor())
| 248 | + |
| 249 | + |
class TestPruneOptimizerReweight(common_utils.TestCase):
    """End-to-end tests using PruneOptimizer with reweight_tau_freq > 0.

    These tests train a small MLP for a handful of steps and inspect the
    optimizer's per-parameter state ('sigma', 'tau_reweight') at various
    points of the reweighting schedule.
    """

    def test_sigma_initialized_at_warmup_end(self):
        """After warmup, state['sigma'] exists for regularized params."""
        torch.manual_seed(42)
        model = TwoLayerMLP(input_size=10, output_size=2)
        prune_config = model._linear_prune_config()
        param_groups = get_param_groups(model, prune_config, verbose=False)
        warmup = 2
        optimizer = PruneOptimizer(
            torch.optim.SGD(param_groups, lr=0.1),
            reg_lambda=1.0,
            warmup_steps=warmup,
            reweight_tau_freq=1,
        )

        dummy_input = torch.randn(10, 10)
        label = torch.randint(0, 2, (10,))
        for step in range(5):
            optim_step(model, optimizer, dummy_input, label, step)
            if step < warmup:
                # During warmup no sigma statistic should be tracked yet.
                for group in optimizer.regularized_param_groups():
                    for p in group["params"]:
                        self.assertNotIn("sigma", optimizer.state[p])
            elif step == warmup:
                # Exactly when warmup ends, sigma must appear in the state.
                for group in optimizer.regularized_param_groups():
                    for p in group["params"]:
                        self.assertIn("sigma", optimizer.state[p])

    def test_tau_reweight_updated_at_freq(self):
        """state['tau_reweight'] is updated every reweight_tau_freq steps."""
        torch.manual_seed(42)
        model = TwoLayerMLP(input_size=10, output_size=2)
        prune_config = model._linear_prune_config()
        param_groups = get_param_groups(model, prune_config, verbose=False)
        optimizer = PruneOptimizer(
            torch.optim.SGD(param_groups, lr=0.1),
            reg_lambda=1.0,
            warmup_steps=0,
            reweight_tau_freq=3,
        )

        dummy_input = torch.randn(10, 10)
        label = torch.randint(0, 2, (10,))
        # 10 steps with freq=3 guarantees several reweight updates occurred.
        for step in range(10):
            optim_step(model, optimizer, dummy_input, label, step)

        has_tau_reweight = any(
            "tau_reweight" in optimizer.state[p]
            for group in optimizer.regularized_param_groups()
            for p in group["params"]
        )
        self.assertTrue(has_tau_reweight)

    def test_no_reweight_when_freq_zero(self):
        """With reweight_tau_freq=0, no sigma/tau_reweight in state."""
        torch.manual_seed(42)
        model = TwoLayerMLP(input_size=10, output_size=2)
        prune_config = model._linear_prune_config()
        param_groups = get_param_groups(model, prune_config, verbose=False)
        optimizer = PruneOptimizer(
            torch.optim.SGD(param_groups, lr=0.1),
            reg_lambda=1.0,
            warmup_steps=0,
            reweight_tau_freq=0,  # reweighting disabled entirely
        )

        dummy_input = torch.randn(10, 10)
        label = torch.randint(0, 2, (10,))
        for step in range(10):
            optim_step(model, optimizer, dummy_input, label, step)

        # Reweight-related state must never have been created.
        for group in optimizer.regularized_param_groups():
            for p in group["params"]:
                self.assertNotIn("sigma", optimizer.state[p])
                self.assertNotIn("tau_reweight", optimizer.state[p])

    def test_tau_reweight_frozen_after_end_step(self):
        """tau_reweight stops updating after reweight_tau_end_step."""
        torch.manual_seed(42)
        model = TwoLayerMLP(input_size=10, output_size=2)
        prune_config = model._linear_prune_config()
        param_groups = get_param_groups(model, prune_config, verbose=False)
        end_step = 5
        optimizer = PruneOptimizer(
            torch.optim.SGD(param_groups, lr=0.1),
            reg_lambda=1.0,
            warmup_steps=0,
            reweight_tau_freq=1,
            reweight_tau_end_step=end_step,
        )

        dummy_input = torch.randn(20, 10)
        label = torch.randint(0, 2, (20,))
        # Run well past end_step so the schedule has definitely frozen.
        for step in range(10):
            optim_step(model, optimizer, dummy_input, label, step)

        # Capture tau_reweight after it should have frozen
        frozen = {
            id(p): optimizer.state[p]["tau_reweight"].clone()
            for group in optimizer.regularized_param_groups()
            for p in group["params"]
            if "tau_reweight" in optimizer.state[p]
        }
        self.assertTrue(len(frozen) > 0, "tau_reweight should exist")

        # Additional steps past end_step must not change the frozen values.
        for step in range(10, 15):
            optim_step(model, optimizer, dummy_input, label, step)

        for group in optimizer.regularized_param_groups():
            for p in group["params"]:
                self.assertEqual(optimizer.state[p]["tau_reweight"], frozen[id(p)])

    def test_reweight_with_group_lasso(self):
        """End-to-end with Dim0Grouper + ProxGroupLasso (hits vmap branch)."""
        torch.manual_seed(42)
        model = TwoLayerMLP(input_size=10, output_size=2)
        prune_config = model._group_lasso_prune_config()
        param_groups = get_param_groups(model, prune_config, verbose=False)
        optimizer = PruneOptimizer(
            torch.optim.SGD(param_groups, lr=0.1),
            reg_lambda=1.0,
            warmup_steps=0,
            reweight_tau_freq=2,
        )

        dummy_input = torch.randn(10, 10)
        label = torch.randint(0, 2, (10,))
        for step in range(10):
            optim_step(model, optimizer, dummy_input, label, step)

        for group in optimizer.regularized_param_groups():
            for p in group["params"]:
                state = optimizer.state[p]
                self.assertIn("sigma", state)
                self.assertIn("tau_reweight", state)
                n_groups = p.size(0)  # Dim0Grouper groups along dim 0
                # One reweight factor per group.
                self.assertEqual(state["tau_reweight"].numel(), n_groups)
| 389 | + |
| 390 | + |
# Expand the parametrized test cases into individual test methods.
for _parametrized_cls in (TestApplyProxReweight, TestApplyProxReweightDTensor):
    common_utils.instantiate_parametrized_tests(_parametrized_cls)

if __name__ == "__main__":
    # Seed both RNGs for reproducible direct invocations.
    random.seed(0)
    torch.manual_seed(0)
    unittest.main()