
Commit 8388909

Author: Copilot
Merge branch 'main' into cpu-multi-isa-dispatch
2 parents: 3f781e1 + 9472d7d

18 files changed

Lines changed: 1045 additions & 254 deletions

File tree

test/prototype/gptq/test_gptqv2.py

Lines changed: 187 additions & 46 deletions
@@ -20,6 +20,7 @@
 from torchao.prototype.gptq import (
     GPTQConfig,
     gptq_quantize,
+    gptq_quantize_3d,
 )
 from torchao.prototype.gptq.observer import GPTQObserverTensor
 from torchao.prototype.mx_formats.inference_workflow import (
@@ -102,33 +103,13 @@ def test_observer_tensor_creation(self):
         # Check hp_data is stored correctly
         torch.testing.assert_close(observer.hp_data, weight)
 
-        # Check hessian is initialized as None
-        assert observer.hessian is None
+        # Check hessian is initialized as zeros
+        assert torch.equal(
+            observer.hessian, torch.zeros(64, 64, dtype=torch.float32, device="cuda")
+        )
 
         # Check total_batches is initialized as 0
-        assert observer.total_batches == 0
-
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
-    def test_observer_tensor_attributes(self):
-        """Test GPTQObserverTensor attributes are correctly set."""
-        weight = torch.randn(16, 32, dtype=torch.bfloat16, device="cuda")
-        observer = GPTQObserverTensor.from_hp(weight)
-
-        # Test hp_data attribute
-        assert hasattr(observer, "hp_data")
-        assert isinstance(observer.hp_data, torch.Tensor)
-
-        # Test hessian attribute
-        assert hasattr(observer, "hessian")
-        assert observer.hessian is None
-
-        # Test total_batches attribute
-        assert hasattr(observer, "total_batches")
-        assert observer.total_batches == 0
-
-        # Test update method exists
-        assert hasattr(observer, "update")
-        assert callable(observer.update)
+        assert (observer.total_batches == 0).all()
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
     def test_linear_operation_with_observer(self):
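
Note on the behavior change in the hunk above: GPTQObserverTensor now starts with a zero Hessian and a tensor-valued total_batches instead of None/0, so the assertions switch from identity checks to elementwise ones. For intuition, here is a minimal sketch of a GPTQ-style Hessian accumulator with the same zero-initialized state (illustrative only; the class name and the exact update rule are assumptions, not torchao's implementation):

    import torch

    class HessianAccumulator:
        """Accumulates H = 2 * E[x x^T] over calibration batches."""

        def __init__(self, in_features: int, device: str = "cpu"):
            # Zero-initialized, matching the assertion in the diff above.
            self.hessian = torch.zeros(
                in_features, in_features, dtype=torch.float32, device=device
            )
            self.total_batches = torch.zeros(1, dtype=torch.int64, device=device)

        def update(self, x: torch.Tensor) -> None:
            # x: (num_tokens, in_features) activations seen by the layer.
            x = x.to(torch.float32)
            n = self.total_batches.item()
            # Running average so early and late batches are weighted equally.
            self.hessian *= n / (n + 1)
            self.hessian += (2.0 / (n + 1)) * (x.t() @ x)
            self.total_batches += 1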
@@ -157,7 +138,7 @@ def test_linear_operation_with_observer(self):
         # Check that Hessian was initialized and updated
         assert observer_weight.hessian is not None
         assert observer_weight.hessian.shape == (in_features, in_features)
-        assert observer_weight.total_batches == 1
+        assert (observer_weight.total_batches == 1).all()
 
         # Verify output is correct
         expected_output = F.linear(input_tensor, weight)
@@ -191,35 +172,127 @@ def test_multiple_observations(self):
         assert observer_weight.hessian.shape == (in_features, in_features)
 
         # Check total_batches matches total samples
-        assert observer_weight.total_batches == total_samples
+        assert (observer_weight.total_batches == total_samples).all()
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
     def test_bmm_operation_with_observer(self):
         """Test torch.bmm with GPTQObserverTensor updates Hessian correctly."""
-        batch = 4
+        num_experts = 4
         m = 8
         n = 16
         k = 12
+        num_passes = 4
 
         # Create input and weight tensors
-        input_tensor = torch.randn(batch, m, k, dtype=torch.float32, device="cuda")
-        weight = torch.randn(batch, k, n, dtype=torch.float32, device="cuda")
-        observer_weight = GPTQObserverTensor.from_hp(weight)
+        weight = torch.randn(num_experts, n, k, dtype=torch.float32, device="cuda")
 
-        # Perform bmm operation
-        output = torch.bmm(input_tensor, observer_weight)
+        inputs = [
+            torch.randn(num_experts, m, k, dtype=torch.float32, device="cuda")
+            for _ in range(num_passes)
+        ]
 
-        # Check output shape
-        assert output.shape == (batch, m, n)
+        # 3D path: single observer with bmm
+        observer_3d = GPTQObserverTensor.from_hp(weight)
+        for x in inputs:
+            torch.bmm(x, observer_3d.transpose(-2, -1))
 
-        # Check Hessian was initialized and updated
-        assert observer_weight.hessian is not None
-        # For bmm with batch dimension, the Hessian is computed on the last dimension
-        assert observer_weight.total_batches == batch
+        # 2D path: per-expert observers with F.linear
+        observers_2d = [
+            GPTQObserverTensor.from_hp(weight[e]) for e in range(num_experts)
+        ]
+        for x in inputs:
+            for e in range(num_experts):
+                F.linear(x[e], observers_2d[e])
+
+        # Verify per-expert hessians match bitwise to calculating each expert's
+        # hessian individually
+        for e in range(num_experts):
+            assert torch.equal(observer_3d.hessian[e], observers_2d[e].hessian), (
+                f"Expert {e} hessian mismatch"
+            )
+            assert torch.equal(
+                observer_3d.total_batches[e : e + 1], observers_2d[e].total_batches
+            ), f"Expert {e} total_batches mismatch"
 
-        # Verify output is correct
-        expected_output = torch.bmm(input_tensor, weight)
-        torch.testing.assert_close(output, expected_output)
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
+    @pytest.mark.skipif(
+        not is_sm_at_least_100(),
+        reason="CUDA capability >= 10.0 required for _grouped_mm",
+    )
+    def test_grouped_mm_operation_with_observer(self):
+        """Test torch._grouped_mm with GPTQObserverTensor updates per-expert Hessians correctly."""
+        num_experts = 4
+        n = 16
+        k = 12
+
+        weight = torch.randn(num_experts, n, k, dtype=torch.float32, device="cuda")
+
+        # 5 different per-expert token distributions. Several of these have
+        # experts that see 0 tokens, which exercises the empty-slice skip path.
+        m_per_group_list = [
+            [1, 3, 4, 16],  # all experts active
+            [0, 3, 4, 13],  # expert 0 sees 0 tokens
+            [5, 5, 0, 5],  # expert 2 sees 0 tokens
+            [2, 0, 6, 4],  # expert 1 sees 0 tokens
+            [2, 3, 5, 0],  # expert 3 sees 0 tokens
+        ]
+
+        offs_list = [
+            torch.tensor(
+                [sum(m_per_group[: i + 1]) for i in range(num_experts)],
+                device="cuda",
+                dtype=torch.int32,
+            )
+            for m_per_group in m_per_group_list
+        ]
+
+        inputs = [
+            torch.randn(sum(m_per_group), k, dtype=torch.float32, device="cuda")
+            for m_per_group in m_per_group_list
+        ]
+
+        # 3D path: single observer with _grouped_mm
+        observer_3d = GPTQObserverTensor.from_hp(weight)
+        for x, offs in zip(inputs, offs_list):
+            torch._grouped_mm(x, observer_3d.transpose(-2, -1), offs=offs)
+
+        # 2D path: per-expert observers with F.linear
+        observers_2d = [
+            GPTQObserverTensor.from_hp(weight[e]) for e in range(num_experts)
+        ]
+        for x, offs in zip(inputs, offs_list):
+            prev_end = 0
+            for e in range(num_experts):
+                end = offs[e].item()
+                if end > prev_end:
+                    F.linear(x[prev_end:end], observers_2d[e])
+                prev_end = end
+
+        # Verify per-expert hessians match bitwise to calculating each expert's
+        # hessian individually
+        for e in range(num_experts):
+            assert torch.equal(observer_3d.hessian[e], observers_2d[e].hessian), (
+                f"Expert {e} hessian mismatch"
+            )
+            assert torch.equal(
+                observer_3d.total_batches[e : e + 1], observers_2d[e].total_batches
+            ), f"Expert {e} total_batches mismatch"
+
+        # Verify total_batches matches an independent count derived directly
+        # from the offsets: each non-empty forward pass contributes 1 per
+        # active expert (each expert's 2D slice has len(shape) == 2, so n=1).
+        expected_total_batches = torch.tensor(
+            [
+                sum(1 for m_per_group in m_per_group_list if m_per_group[e] > 0)
+                for e in range(num_experts)
+            ],
+            dtype=torch.int64,
+            device="cuda",
+        )
+        assert torch.equal(observer_3d.total_batches, expected_total_batches), (
+            f"total_batches {observer_3d.total_batches.tolist()} "
+            f"does not match expected {expected_total_batches.tolist()}"
+        )
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
     @pytest.mark.parametrize(
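
Both new tests in the hunk above rely on the same equivalence: running one batched op against a 3D (num_experts, n, k) weight must update each expert's Hessian exactly as running F.linear per expert on that expert's slice of tokens. For bmm every expert sees the same m rows; for _grouped_mm the offsets carve up the stacked token dimension. A pure-PyTorch reference of the grouped computation, for intuition only (not the kernel torchao actually dispatches to):

    import torch

    def grouped_mm_reference(
        x: torch.Tensor, w: torch.Tensor, offs: torch.Tensor
    ) -> torch.Tensor:
        # x: (total_tokens, k), w: (num_experts, n, k),
        # offs: (num_experts,) cumulative row ends, e.g. [1, 4, 8, 24].
        outs, prev = [], 0
        for e in range(w.shape[0]):
            end = int(offs[e])
            # An expert with 0 tokens yields an empty slice and an empty
            # output -- the "empty-slice skip path" the test exercises.
            outs.append(x[prev:end] @ w[e].t())
            prev = end
        return torch.cat(outs, dim=0)  # (total_tokens, n)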
@@ -250,17 +323,20 @@ def test_observer_config_transform(self, base_config):
         # Check hp_data matches original weight
         torch.testing.assert_close(linear.weight.hp_data, original_weight)
 
-        # Check hessian is None initially
-        assert linear.weight.hessian is None
-        assert linear.weight.total_batches == 0
+        # Check hessian is initialized as zeros
+        assert torch.equal(
+            linear.weight.hessian,
+            torch.zeros(64, 64, dtype=torch.float32, device="cuda"),
+        )
+        assert (linear.weight.total_batches == 0).all()
 
         # Perform a forward pass
         input_tensor = torch.randn(4, 64, dtype=torch.float32, device="cuda")
         output = linear(input_tensor)
 
         # Check Hessian was initialized after forward pass
         assert linear.weight.hessian is not None
-        assert linear.weight.total_batches == 1
+        assert (linear.weight.total_batches == 1).all()
 
         # Check output shape
         assert output.shape == (4, 32)
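
For context on what test_observer_config_transform (and test_unified_config_two_phase below) exercise: this prototype runs in two phases, with the observer inserted during a prepare step and the quantized weights produced during a convert step. A hedged sketch of that calibration loop (the step="convert" usage is taken from this diff; the step="prepare" name and the exact quantize_() call pattern are assumptions):

    import torch
    from torchao.quantization import quantize_
    from torchao.prototype.gptq import GPTQConfig

    def gptq_two_phase(model, base_config, calibration_batches):
        # Phase 1 (assumed step name): swap weights for GPTQObserverTensor
        # so forward passes accumulate per-layer Hessians.
        quantize_(model, GPTQConfig(step="prepare", base_config=base_config))
        with torch.no_grad():
            for batch in calibration_batches:
                model(batch)
        # Phase 2: use the accumulated Hessians to GPTQ-quantize the weights.
        quantize_(model, GPTQConfig(step="convert", base_config=base_config))
        return model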
@@ -376,6 +452,9 @@ def test_unified_config_two_phase(self, base_config):
     )
     def test_gptq_quantize_function(self, base_config):
         """Test gptq_quantize function with synthetic Hessian and weights."""
+        if isinstance(base_config, Int4WeightOnlyConfig) and is_sm_at_least_100():
+            pytest.skip("int4 kernels do not work on sm100")
+
         torch.manual_seed(42)
 
         # Create synthetic weight matrix
@@ -518,6 +597,66 @@ def test_gptq_quantize_better_than_naive(self, base_config):
         assert gptq_loss is not None
         assert naive_loss is not None
 
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
+    @pytest.mark.skipif(
+        not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for nvfp4"
+    )
+    def test_gptq_quantize_2d_matches_3d(self):
+        """Verify per-expert gptq_quantize and gptq_quantize_3d produce bitwise-identical outputs."""
+        torch.manual_seed(43)
+
+        E = 4
+        out_features = 64
+        in_features = 128
+        num_samples = 10
+
+        base_config = NVFP4DynamicActivationNVFP4WeightConfig(
+            use_dynamic_per_tensor_scale=True,
+            use_triton_kernel=True,
+        )
+        config = GPTQConfig(step="convert", base_config=base_config)
+
+        # Per-expert weights (E, N, K) and per-expert Hessians (E, K, K)
+        weight_3d = torch.randn(
+            E, out_features, in_features, dtype=torch.bfloat16, device="cuda"
+        )
+        hessians = []
+        for _ in range(E):
+            activations = [
+                torch.randn(4, in_features, dtype=torch.float32, device="cuda")
+                for _ in range(num_samples)
+            ]
+            hessians.append(_calculate_hessian(activations, device="cuda"))
+        hessian_3d = torch.stack(hessians, dim=0)
+
+        # gptq_quantize mutates its weight/Hessian arguments in place, so clone
+        # per-experiment to keep the two paths independent.
+        weight_a = weight_3d.clone()
+        weight_b = weight_3d.clone()
+        hessian_a = hessian_3d.clone()
+        hessian_b = hessian_3d.clone()
+
+        # Experiment A: E separate 2D gptq_quantize calls
+        per_expert_2d = [
+            gptq_quantize(hessian_a[e], weight_a[e], config) for e in range(E)
+        ]
+
+        # Experiment B: single 3D gptq_quantize_3d call
+        stacked_3d = gptq_quantize_3d(hessian_b, weight_b, config)
+
+        # Bitwise match per expert
+        for e in range(E):
+            assert torch.equal(per_expert_2d[e].qdata, stacked_3d.qdata[e]), (
+                f"Expert {e}: qdata mismatch"
+            )
+            assert torch.equal(per_expert_2d[e].scale, stacked_3d.scale[e]), (
+                f"Expert {e}: scale mismatch"
+            )
+            assert torch.equal(
+                per_expert_2d[e].per_tensor_scale.view(1, 1),
+                stacked_3d.per_tensor_scale[e],
+            ), f"Expert {e}: per_tensor_scale mismatch"
+
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
     @pytest.mark.parametrize(
         "base_config",
@@ -548,6 +687,8 @@ def test_gptq_sqnr(self, base_config):
             and not is_sm_at_least_100()
         ):
            pytest.skip("CUDA capability >= 10.0 required for nvfp4")
+        if isinstance(base_config, Int4WeightOnlyConfig) and is_sm_at_least_100():
+            pytest.skip("int4 kernels do not work on sm100")
 
         torch.manual_seed(43)
 