|
20 | 20 | from torchao.prototype.gptq import ( |
21 | 21 | GPTQConfig, |
22 | 22 | gptq_quantize, |
| 23 | + gptq_quantize_3d, |
23 | 24 | ) |
24 | 25 | from torchao.prototype.gptq.observer import GPTQObserverTensor |
25 | 26 | from torchao.prototype.mx_formats.inference_workflow import ( |
@@ -102,33 +103,13 @@ def test_observer_tensor_creation(self): |
102 | 103 | # Check hp_data is stored correctly |
103 | 104 | torch.testing.assert_close(observer.hp_data, weight) |
104 | 105 |
|
105 | | - # Check hessian is initialized as None |
106 | | - assert observer.hessian is None |
| 106 | + # Check hessian is initialized as zeros |
| 107 | + assert torch.equal( |
| 108 | + observer.hessian, torch.zeros(64, 64, dtype=torch.float32, device="cuda") |
| 109 | + ) |
107 | 110 |
|
108 | 111 | # Check total_batches is initialized as 0 |
109 | | - assert observer.total_batches == 0 |
110 | | - |
111 | | - @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") |
112 | | - def test_observer_tensor_attributes(self): |
113 | | - """Test GPTQObserverTensor attributes are correctly set.""" |
114 | | - weight = torch.randn(16, 32, dtype=torch.bfloat16, device="cuda") |
115 | | - observer = GPTQObserverTensor.from_hp(weight) |
116 | | - |
117 | | - # Test hp_data attribute |
118 | | - assert hasattr(observer, "hp_data") |
119 | | - assert isinstance(observer.hp_data, torch.Tensor) |
120 | | - |
121 | | - # Test hessian attribute |
122 | | - assert hasattr(observer, "hessian") |
123 | | - assert observer.hessian is None |
124 | | - |
125 | | - # Test total_batches attribute |
126 | | - assert hasattr(observer, "total_batches") |
127 | | - assert observer.total_batches == 0 |
128 | | - |
129 | | - # Test update method exists |
130 | | - assert hasattr(observer, "update") |
131 | | - assert callable(observer.update) |
| 112 | + assert (observer.total_batches == 0).all() |
132 | 113 |
|
133 | 114 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") |
134 | 115 | def test_linear_operation_with_observer(self): |
@@ -157,7 +138,7 @@ def test_linear_operation_with_observer(self): |
157 | 138 | # Check that Hessian was initialized and updated |
158 | 139 | assert observer_weight.hessian is not None |
159 | 140 | assert observer_weight.hessian.shape == (in_features, in_features) |
160 | | - assert observer_weight.total_batches == 1 |
| 141 | + assert (observer_weight.total_batches == 1).all() |
161 | 142 |
|
162 | 143 | # Verify output is correct |
163 | 144 | expected_output = F.linear(input_tensor, weight) |
@@ -191,35 +172,127 @@ def test_multiple_observations(self): |
191 | 172 | assert observer_weight.hessian.shape == (in_features, in_features) |
192 | 173 |
|
193 | 174 | # Check total_batches matches total samples |
194 | | - assert observer_weight.total_batches == total_samples |
| 175 | + assert (observer_weight.total_batches == total_samples).all() |
195 | 176 |
|
196 | 177 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") |
197 | 178 | def test_bmm_operation_with_observer(self): |
198 | 179 | """Test torch.bmm with GPTQObserverTensor updates Hessian correctly.""" |
199 | | - batch = 4 |
| 180 | + num_experts = 4 |
200 | 181 | m = 8 |
201 | 182 | n = 16 |
202 | 183 | k = 12 |
| 184 | + num_passes = 4 |
203 | 185 |
|
204 | 186 | # Create input and weight tensors |
205 | | - input_tensor = torch.randn(batch, m, k, dtype=torch.float32, device="cuda") |
206 | | - weight = torch.randn(batch, k, n, dtype=torch.float32, device="cuda") |
207 | | - observer_weight = GPTQObserverTensor.from_hp(weight) |
| 187 | + weight = torch.randn(num_experts, n, k, dtype=torch.float32, device="cuda") |
208 | 188 |
|
209 | | - # Perform bmm operation |
210 | | - output = torch.bmm(input_tensor, observer_weight) |
| 189 | + inputs = [ |
| 190 | + torch.randn(num_experts, m, k, dtype=torch.float32, device="cuda") |
| 191 | + for _ in range(num_passes) |
| 192 | + ] |
211 | 193 |
|
212 | | - # Check output shape |
213 | | - assert output.shape == (batch, m, n) |
| 194 | + # 3D path: single observer with bmm |
| 195 | + observer_3d = GPTQObserverTensor.from_hp(weight) |
| 196 | + for x in inputs: |
| 197 | + torch.bmm(x, observer_3d.transpose(-2, -1)) |
214 | 198 |
|
215 | | - # Check Hessian was initialized and updated |
216 | | - assert observer_weight.hessian is not None |
217 | | - # For bmm with batch dimension, the Hessian is computed on the last dimension |
218 | | - assert observer_weight.total_batches == batch |
| 199 | + # 2D path: per-expert observers with F.linear |
| 200 | + observers_2d = [ |
| 201 | + GPTQObserverTensor.from_hp(weight[e]) for e in range(num_experts) |
| 202 | + ] |
| 203 | + for x in inputs: |
| 204 | + for e in range(num_experts): |
| 205 | + F.linear(x[e], observers_2d[e]) |
| 206 | + |
| 207 | + # Verify each expert's hessian from the 3D observer is bitwise-identical |
| 208 | + # to the one accumulated by that expert's own 2D observer |
| 209 | + for e in range(num_experts): |
| 210 | + assert torch.equal(observer_3d.hessian[e], observers_2d[e].hessian), ( |
| 211 | + f"Expert {e} hessian mismatch" |
| 212 | + ) |
| 213 | + assert torch.equal( |
| 214 | + observer_3d.total_batches[e : e + 1], observers_2d[e].total_batches |
| 215 | + ), f"Expert {e} total_batches mismatch" |
219 | 216 |
|
220 | | - # Verify output is correct |
221 | | - expected_output = torch.bmm(input_tensor, weight) |
222 | | - torch.testing.assert_close(output, expected_output) |
| 217 | + @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") |
| 218 | + @pytest.mark.skipif( |
| 219 | + not is_sm_at_least_100(), |
| 220 | + reason="CUDA capability >= 10.0 required for _grouped_mm", |
| 221 | + ) |
| 222 | + def test_grouped_mm_operation_with_observer(self): |
| 223 | + """Test torch._grouped_mm with GPTQObserverTensor updates per-expert Hessians correctly.""" |
| 224 | + num_experts = 4 |
| 225 | + n = 16 |
| 226 | + k = 12 |
| 227 | + |
| 228 | + weight = torch.randn(num_experts, n, k, dtype=torch.float32, device="cuda") |
| 229 | + |
| 230 | + # 5 different per-expert token distributions. Four of these have an |
| 231 | + # expert that sees 0 tokens, which exercises the empty-slice skip path. |
| 232 | + m_per_group_list = [ |
| 233 | + [1, 3, 4, 16], # all experts active |
| 234 | + [0, 3, 4, 13], # expert 0 sees 0 tokens |
| 235 | + [5, 5, 0, 5], # expert 2 sees 0 tokens |
| 236 | + [2, 0, 6, 4], # expert 1 sees 0 tokens |
| 237 | + [2, 3, 5, 0], # expert 3 sees 0 tokens |
| 238 | + ] |
| 239 | + |
| 240 | + offs_list = [ |
| 241 | + torch.tensor( |
| 242 | + [sum(m_per_group[: i + 1]) for i in range(num_experts)], |
| 243 | + device="cuda", |
| 244 | + dtype=torch.int32, |
| 245 | + ) |
| 246 | + for m_per_group in m_per_group_list |
| 247 | + ] |
| 248 | + |
| 249 | + inputs = [ |
| 250 | + torch.randn(sum(m_per_group), k, dtype=torch.float32, device="cuda") |
| 251 | + for m_per_group in m_per_group_list |
| 252 | + ] |
| 253 | + |
| 254 | + # 3D path: single observer with _grouped_mm |
| 255 | + observer_3d = GPTQObserverTensor.from_hp(weight) |
| 256 | + for x, offs in zip(inputs, offs_list): |
| 257 | + torch._grouped_mm(x, observer_3d.transpose(-2, -1), offs=offs) |
| 258 | + |
| 259 | + # 2D path: per-expert observers with F.linear |
| 260 | + observers_2d = [ |
| 261 | + GPTQObserverTensor.from_hp(weight[e]) for e in range(num_experts) |
| 262 | + ] |
| 263 | + for x, offs in zip(inputs, offs_list): |
| 264 | + prev_end = 0 |
| 265 | + for e in range(num_experts): |
| 266 | + end = offs[e].item() |
| 267 | + if end > prev_end: |
| 268 | + F.linear(x[prev_end:end], observers_2d[e]) |
| 269 | + prev_end = end |
| 270 | + |
| 271 | + # Verify each expert's hessian from the 3D observer is bitwise-identical |
| 272 | + # to the one accumulated by that expert's own 2D observer |
| 273 | + for e in range(num_experts): |
| 274 | + assert torch.equal(observer_3d.hessian[e], observers_2d[e].hessian), ( |
| 275 | + f"Expert {e} hessian mismatch" |
| 276 | + ) |
| 277 | + assert torch.equal( |
| 278 | + observer_3d.total_batches[e : e + 1], observers_2d[e].total_batches |
| 279 | + ), f"Expert {e} total_batches mismatch" |
| 280 | + |
| 281 | + # Verify total_batches matches an independent count derived directly |
| 282 | + # from the offsets: each non-empty forward pass contributes 1 per |
| 283 | + # active expert (each expert's 2D slice has len(shape) == 2, so n=1). |
| 284 | + expected_total_batches = torch.tensor( |
| 285 | + [ |
| 286 | + sum(1 for m_per_group in m_per_group_list if m_per_group[e] > 0) |
| 287 | + for e in range(num_experts) |
| 288 | + ], |
| 289 | + dtype=torch.int64, |
| 290 | + device="cuda", |
| 291 | + ) |
| 292 | + assert torch.equal(observer_3d.total_batches, expected_total_batches), ( |
| 293 | + f"total_batches {observer_3d.total_batches.tolist()} " |
| 294 | + f"does not match expected {expected_total_batches.tolist()}" |
| 295 | + ) |
223 | 296 |
|
224 | 297 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") |
225 | 298 | @pytest.mark.parametrize( |
@@ -250,17 +323,20 @@ def test_observer_config_transform(self, base_config): |
250 | 323 | # Check hp_data matches original weight |
251 | 324 | torch.testing.assert_close(linear.weight.hp_data, original_weight) |
252 | 325 |
|
253 | | - # Check hessian is None initially |
254 | | - assert linear.weight.hessian is None |
255 | | - assert linear.weight.total_batches == 0 |
| 326 | + # Check hessian is initialized as zeros |
| 327 | + assert torch.equal( |
| 328 | + linear.weight.hessian, |
| 329 | + torch.zeros(64, 64, dtype=torch.float32, device="cuda"), |
| 330 | + ) |
| 331 | + assert (linear.weight.total_batches == 0).all() |
256 | 332 |
|
257 | 333 | # Perform a forward pass |
258 | 334 | input_tensor = torch.randn(4, 64, dtype=torch.float32, device="cuda") |
259 | 335 | output = linear(input_tensor) |
260 | 336 |
|
261 | 337 | # Check Hessian is still set after the forward pass |
262 | 338 | assert linear.weight.hessian is not None |
263 | | - assert linear.weight.total_batches == 1 |
| 339 | + assert (linear.weight.total_batches == 1).all() |
264 | 340 |
|
265 | 341 | # Check output shape |
266 | 342 | assert output.shape == (4, 32) |
@@ -376,6 +452,9 @@ def test_unified_config_two_phase(self, base_config): |
376 | 452 | ) |
377 | 453 | def test_gptq_quantize_function(self, base_config): |
378 | 454 | """Test gptq_quantize function with synthetic Hessian and weights.""" |
| 455 | + if isinstance(base_config, Int4WeightOnlyConfig) and is_sm_at_least_100(): |
| 456 | + pytest.skip("int4 kernels do not work on sm100") |
| 457 | + |
379 | 458 | torch.manual_seed(42) |
380 | 459 |
|
381 | 460 | # Create synthetic weight matrix |
@@ -518,6 +597,66 @@ def test_gptq_quantize_better_than_naive(self, base_config): |
518 | 597 | assert gptq_loss is not None |
519 | 598 | assert naive_loss is not None |
520 | 599 |
|
| 600 | + @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") |
| 601 | + @pytest.mark.skipif( |
| 602 | + not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for nvfp4" |
| 603 | + ) |
| 604 | + def test_gptq_quantize_2d_matches_3d(self): |
| 605 | + """Verify per-expert gptq_quantize and gptq_quantize_3d produce bitwise-identical outputs.""" |
| 606 | + torch.manual_seed(43) |
| 607 | + |
| 608 | + E = 4 |
| 609 | + out_features = 64 |
| 610 | + in_features = 128 |
| 611 | + num_samples = 10 |
| 612 | + |
| 613 | + base_config = NVFP4DynamicActivationNVFP4WeightConfig( |
| 614 | + use_dynamic_per_tensor_scale=True, |
| 615 | + use_triton_kernel=True, |
| 616 | + ) |
| 617 | + config = GPTQConfig(step="convert", base_config=base_config) |
| 618 | + |
| 619 | + # Per-expert weights (E, N, K) and per-expert Hessians (E, K, K) |
| 620 | + weight_3d = torch.randn( |
| 621 | + E, out_features, in_features, dtype=torch.bfloat16, device="cuda" |
| 622 | + ) |
| 623 | + hessians = [] |
| 624 | + for _ in range(E): |
| 625 | + activations = [ |
| 626 | + torch.randn(4, in_features, dtype=torch.float32, device="cuda") |
| 627 | + for _ in range(num_samples) |
| 628 | + ] |
| 629 | + hessians.append(_calculate_hessian(activations, device="cuda")) |
| 630 | + hessian_3d = torch.stack(hessians, dim=0) |
| 631 | + |
| 632 | + # gptq_quantize mutates its weight/Hessian arguments in place, so clone |
| 633 | + # per-experiment to keep the two paths independent. |
| 634 | + weight_a = weight_3d.clone() |
| 635 | + weight_b = weight_3d.clone() |
| 636 | + hessian_a = hessian_3d.clone() |
| 637 | + hessian_b = hessian_3d.clone() |
| 638 | + |
| 639 | + # Experiment A: E separate 2D gptq_quantize calls |
| 640 | + per_expert_2d = [ |
| 641 | + gptq_quantize(hessian_a[e], weight_a[e], config) for e in range(E) |
| 642 | + ] |
| 643 | + |
| 644 | + # Experiment B: single 3D gptq_quantize_3d call |
| 645 | + stacked_3d = gptq_quantize_3d(hessian_b, weight_b, config) |
| 646 | + |
| 647 | + # Bitwise match per expert |
| 648 | + for e in range(E): |
| 649 | + assert torch.equal(per_expert_2d[e].qdata, stacked_3d.qdata[e]), ( |
| 650 | + f"Expert {e}: qdata mismatch" |
| 651 | + ) |
| 652 | + assert torch.equal(per_expert_2d[e].scale, stacked_3d.scale[e]), ( |
| 653 | + f"Expert {e}: scale mismatch" |
| 654 | + ) |
| 655 | + assert torch.equal( |
| 656 | + per_expert_2d[e].per_tensor_scale.view(1, 1), |
| 657 | + stacked_3d.per_tensor_scale[e], |
| 658 | + ), f"Expert {e}: per_tensor_scale mismatch" |
| 659 | + |
521 | 660 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") |
522 | 661 | @pytest.mark.parametrize( |
523 | 662 | "base_config", |
@@ -548,6 +687,8 @@ def test_gptq_sqnr(self, base_config): |
548 | 687 | and not is_sm_at_least_100() |
549 | 688 | ): |
550 | 689 | pytest.skip("CUDA capability >= 10.0 required for nvfp4") |
| 690 | + if isinstance(base_config, Int4WeightOnlyConfig) and is_sm_at_least_100(): |
| 691 | + pytest.skip("int4 kernels do not work on sm100") |
551 | 692 |
|
552 | 693 | torch.manual_seed(43) |
553 | 694 |
|
|
0 commit comments