Commit e841b07

nvfp4 gptq for bmm
Summary:

Extend GPTQ coverage to bmm, formulating bmm as the 3D case of mm. This involves:

1. refactoring the 2D code so it extends easily to 3D
2. fixing the existing bmm logic, which was numerically incorrect (it used a single Hessian); it now uses E K-by-K Hessians for an `E, N, K` input shape and routes to the 2D Hessian logic once per expert. This is slow, but we can optimize it later.

We test numerical correctness by bitwise-matching E separate 2D Hessian calculations against the 3D one.

Test Plan:

```
torchao/prototype/gptq/gptq_nvfp4_llama3_2_1b_nonsequential_wikitext.sh
// gptq accuracy unchanged
```

ghstack-source-id: 85a5964
ghstack-comment-id: 4313324694
Pull-Request: #4327
1 parent 052725e commit e841b07
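
As a rough illustration of the approach described in the summary (a sketch only, not the torchao implementation; the helper name, shapes, and values below are made up for the example), each expert of an `E, N, K` activation gets its own `K x K` Hessian, and every observed batch is routed through the same 2D running-average update once per expert:

```python
import torch

def update_hessians(hessians: torch.Tensor, counts: torch.Tensor, x: torch.Tensor) -> None:
    """hessians: (E, K, K), counts: (E,), x: (E, N, K); all updated in-place."""
    for e in range(x.shape[0]):  # route each expert through the same 2D update
        x_e = x[e].reshape(-1, x.shape[-1]).float()  # (N, K) activations for expert e
        n = 1                                        # this call observes one batch
        tb = int(counts[e])
        if tb > 0:
            hessians[e] *= tb / (tb + n)             # rescale the running average
        counts[e] += n
        tb = int(counts[e])
        x_t = ((2 / tb) ** 0.5) * x_e.t()            # (K, N)
        hessians[e] += x_t @ x_t.t()                 # add (2 / tb) * x_e^T x_e

E, N, K = 4, 8, 12
hessians = torch.zeros(E, K, K)
counts = torch.zeros(E, dtype=torch.int64)
update_hessians(hessians, counts, torch.randn(E, N, K))
```

Looping over experts mirrors the per-expert routing this commit adds; batching the per-expert updates is left for a follow-up.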

3 files changed

Lines changed: 117 additions & 75 deletions

test/prototype/gptq/test_gptqv2.py

Lines changed: 38 additions & 46 deletions

```diff
@@ -108,31 +108,7 @@ def test_observer_tensor_creation(self):
         )
 
         # Check total_batches is initialized as 0
-        assert observer.total_batches == 0
-
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
-    def test_observer_tensor_attributes(self):
-        """Test GPTQObserverTensor attributes are correctly set."""
-        weight = torch.randn(16, 32, dtype=torch.bfloat16, device="cuda")
-        observer = GPTQObserverTensor.from_hp(weight)
-
-        # Test hp_data attribute
-        assert hasattr(observer, "hp_data")
-        assert isinstance(observer.hp_data, torch.Tensor)
-
-        # Test hessian attribute
-        assert hasattr(observer, "hessian")
-        assert torch.equal(
-            observer.hessian, torch.zeros(32, 32, dtype=torch.float32, device="cuda")
-        )
-
-        # Test total_batches attribute
-        assert hasattr(observer, "total_batches")
-        assert observer.total_batches == 0
-
-        # Test update method exists
-        assert hasattr(observer, "update")
-        assert callable(observer.update)
+        assert (observer.total_batches == 0).all()
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
     def test_linear_operation_with_observer(self):
@@ -161,7 +137,7 @@ def test_linear_operation_with_observer(self):
         # Check that Hessian was initialized and updated
         assert observer_weight.hessian is not None
         assert observer_weight.hessian.shape == (in_features, in_features)
-        assert observer_weight.total_batches == 1
+        assert (observer_weight.total_batches == 1).all()
 
         # Verify output is correct
         expected_output = F.linear(input_tensor, weight)
@@ -195,36 +171,47 @@ def test_multiple_observations(self):
         assert observer_weight.hessian.shape == (in_features, in_features)
 
         # Check total_batches matches total samples
-        assert observer_weight.total_batches == total_samples
+        assert (observer_weight.total_batches == total_samples).all()
 
-    @pytest.mark.skip(reason="bmm math is incorrect, will fix in next PR")
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
     def test_bmm_operation_with_observer(self):
         """Test torch.bmm with GPTQObserverTensor updates Hessian correctly."""
-        batch = 4
+        num_experts = 4
         m = 8
         n = 16
         k = 12
+        num_passes = 4
 
         # Create input and weight tensors
-        input_tensor = torch.randn(batch, m, k, dtype=torch.float32, device="cuda")
-        weight = torch.randn(batch, k, n, dtype=torch.float32, device="cuda")
-        observer_weight = GPTQObserverTensor.from_hp(weight)
+        weight = torch.randn(num_experts, n, k, dtype=torch.float32, device="cuda")
 
-        # Perform bmm operation
-        output = torch.bmm(input_tensor, observer_weight)
+        inputs = [
+            torch.randn(num_experts, m, k, dtype=torch.float32, device="cuda")
+            for _ in range(num_passes)
+        ]
 
-        # Check output shape
-        assert output.shape == (batch, m, n)
+        # 3D path: single observer with bmm
+        observer_3d = GPTQObserverTensor.from_hp(weight)
+        for x in inputs:
+            torch.bmm(x, observer_3d.transpose(-2, -1))
 
-        # Check Hessian was initialized and updated
-        assert observer_weight.hessian is not None
-        # For bmm with batch dimension, the Hessian is computed on the last dimension
-        assert observer_weight.total_batches == batch
-
-        # Verify output is correct
-        expected_output = torch.bmm(input_tensor, weight)
-        torch.testing.assert_close(output, expected_output)
+        # 2D path: per-expert observers with F.linear
+        observers_2d = [
+            GPTQObserverTensor.from_hp(weight[e]) for e in range(num_experts)
+        ]
+        for x in inputs:
+            for e in range(num_experts):
+                F.linear(x[e], observers_2d[e])
+
+        # Verify per-expert hessians match bitwise to calculating each expert's
+        # hessian individually
+        for e in range(num_experts):
+            assert torch.equal(observer_3d.hessian[e], observers_2d[e].hessian), (
+                f"Expert {e} hessian mismatch"
+            )
+            assert torch.equal(
+                observer_3d.total_batches[e : e + 1], observers_2d[e].total_batches
+            ), f"Expert {e} total_batches mismatch"
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
     @pytest.mark.parametrize(
@@ -260,15 +247,15 @@ def test_observer_config_transform(self, base_config):
             linear.weight.hessian,
             torch.zeros(64, 64, dtype=torch.float32, device="cuda"),
         )
-        assert linear.weight.total_batches == 0
+        assert (linear.weight.total_batches == 0).all()
 
         # Perform a forward pass
         input_tensor = torch.randn(4, 64, dtype=torch.float32, device="cuda")
        output = linear(input_tensor)
 
         # Check Hessian was initialized after forward pass
         assert linear.weight.hessian is not None
-        assert linear.weight.total_batches == 1
+        assert (linear.weight.total_batches == 1).all()
 
         # Check output shape
         assert output.shape == (4, 32)
@@ -384,6 +371,9 @@ def test_unified_config_two_phase(self, base_config):
     )
     def test_gptq_quantize_function(self, base_config):
         """Test gptq_quantize function with synthetic Hessian and weights."""
+        if isinstance(base_config, Int4WeightOnlyConfig) and is_sm_at_least_100():
+            pytest.skip("int4 kernels do not work on sm100")
+
         torch.manual_seed(42)
 
         # Create synthetic weight matrix
@@ -556,6 +546,8 @@ def test_gptq_sqnr(self, base_config):
             and not is_sm_at_least_100()
         ):
            pytest.skip("CUDA capability >= 10.0 required for nvfp4")
+        if isinstance(base_config, Int4WeightOnlyConfig) and is_sm_at_least_100():
+            pytest.skip("int4 kernels do not work on sm100")
 
         torch.manual_seed(43)
 
```
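The test above leans on the fact that a bmm against a weight transposed on its last two dims computes, per expert, the same thing as `F.linear` on that expert's slice, so both observation paths see identical per-expert activations. A small self-contained check of that identity (illustrative only; shapes are made up):

```python
import torch
import torch.nn.functional as F

E, M, N, K = 4, 8, 16, 12
x = torch.randn(E, M, K)   # activations, one slice per expert
w = torch.randn(E, N, K)   # weights, one (N, K) matrix per expert

# bmm against the weight transposed on its last two dims ...
bmm_out = torch.bmm(x, w.transpose(-2, -1))            # (E, M, N)
# ... matches a per-expert linear on each slice
for e in range(E):
    torch.testing.assert_close(bmm_out[e], F.linear(x[e], w[e]))
```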
torchao/prototype/gptq/api.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -141,7 +141,7 @@ def _gptq_config_transform(
     )
 
     # Validate that observations were recorded
-    if tensor.total_batches == 0:
+    if (tensor.total_batches == 0).any():
         raise ValueError(
             f"No observations recorded for {parameter_name}. "
             f"total_batches is 0. Did you run forward passes during the observe step?"
```
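For a quick sense of what the new check does: with the 3D observer, `total_batches` holds one counter per expert, so `(tensor.total_batches == 0).any()` raises if any single expert never saw calibration data. A toy illustration (values made up):

```python
import torch

# hypothetical per-expert counters after calibration; expert 1 saw no data
total_batches = torch.tensor([3, 0, 5, 2])
if (total_batches == 0).any():
    raise ValueError("at least one expert recorded no observations")
```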

torchao/prototype/gptq/observer.py

Lines changed: 78 additions & 28 deletions

```diff
@@ -11,54 +11,92 @@
 
 
 class GPTQObserverTensor(TorchAOBaseTensor):
-    tensor_data_names = ["hp_data"]
+    tensor_data_names = ["hp_data", "total_batches"]
     optional_tensor_data_names = ["hessian"]
-    tensor_attribute_names = ["total_batches"]
+    tensor_attribute_names = []
 
-    def __new__(cls, hp_data: torch.Tensor, total_batches: int, hessian=None):
+    def __new__(cls, hp_data: torch.Tensor, total_batches, hessian=None):
         shape = hp_data.shape
         kwargs = {}
         kwargs["device"] = hp_data.device
         kwargs["dtype"] = hp_data.dtype
         kwargs["requires_grad"] = False
         return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
 
-    def __init__(self, hp_data: torch.Tensor, total_batches: int, hessian=None):
+    def __init__(self, hp_data: torch.Tensor, total_batches, hessian=None):
         super().__init__()
         self.hp_data = hp_data
         self.hessian = hessian
-        self.total_batches = total_batches
+        if isinstance(total_batches, torch.Tensor):
+            self.total_batches = total_batches
+        elif len(self.hp_data.shape) == 3:
+            self.total_batches = torch.zeros(
+                self.hp_data.shape[0], dtype=torch.int64, device=self.hp_data.device
+            )
+        else:
+            self.total_batches = torch.zeros(
+                1, dtype=torch.int64, device=self.hp_data.device
+            )
 
         # initialize hessian
-        assert self.hp_data.is_contiguous()
         if self.hessian is None:
+            assert self.hp_data.is_contiguous()
             feature_dim = self.hp_data.shape[-1]
-            self.hessian = torch.zeros(
-                feature_dim,
-                feature_dim,
-                dtype=torch.float32,
-                device=self.hp_data.device,
-            )
-
-    def update(self, input: torch.Tensor):
-        """Incrementally update Hessian matrix from input activations."""
-        # Move input to same device as hp_data and convert to float
-        x = input.float().to(self.hp_data.device)
+            if len(self.hp_data.shape) == 2:
+                self.hessian = torch.zeros(
+                    feature_dim,
+                    feature_dim,
+                    dtype=torch.float32,
+                    device=self.hp_data.device,
+                )
+            else:
+                assert len(self.hp_data.shape) == 3, "unsupported"
+                expert_dim = self.hp_data.shape[0]
+                self.hessian = torch.zeros(
+                    expert_dim,
+                    feature_dim,
+                    feature_dim,
+                    dtype=torch.float32,
+                    device=self.hp_data.device,
+                )
+
+    @staticmethod
+    def _update_single_hessian(
+        x: torch.Tensor, hessian: torch.Tensor, total_batches: torch.Tensor
+    ):
+        """Update a single 2D Hessian and total_batches in-place."""
         shape = x.shape
-
-        # Calculate batch size
         n = 1 if len(shape) == 2 else shape[0]
         x = x.reshape(-1, shape[-1])
 
-        # Apply running average formula
-        if self.total_batches > 0:
-            self.hessian *= self.total_batches / (self.total_batches + n)
+        # cast to Python int64 for optimal type promotion semantics
+        # Note: there is definitely a better way to get ^, saving for
+        # a follow-up PR. For now, this preserves numerics.
+        tb = total_batches.item()
+        if tb > 0:
+            hessian *= tb / (tb + n)
 
-        self.total_batches += n
+        total_batches += n
+        # cast to Python int64 for optimal type promotion semantics
+        # Note: there is definitely a better way to get ^, saving for
+        # a follow-up PR. For now, this preserves numerics.
+        tb = total_batches.item()
 
-        # Update Hessian: x = ((2 / total_batches) ** (1 / 2)) * x.t()
-        x = ((2 / self.total_batches) ** (1 / 2)) * x.t()
-        self.hessian += x.matmul(x.t())
+        x = ((2 / tb) ** (1 / 2)) * x.t()
+        hessian += x.matmul(x.t())
+
+    def update_2d(self, input: torch.Tensor):
+        x = input.float().to(self.hp_data.device)
+        self._update_single_hessian(x, self.hessian, self.total_batches[0:1])
+
+    def update_3d(self, input: torch.Tensor):
+        x = input.float().to(self.hp_data.device)
+        # TODO(future PR): optimize if this is too slow
+        for e_idx in range(self.hessian.shape[0]):
+            x_cur = x[e_idx]
+            h_cur = self.hessian[e_idx]
+            total_batches = self.total_batches[e_idx : e_idx + 1]
+            self._update_single_hessian(x_cur, h_cur, total_batches)
 
     @classmethod
     def from_hp(cls, hp_tensor):
@@ -79,19 +117,31 @@ def _(func, types, args, kwargs):
         args[2] if len(args) > 2 else None,
     )
     if isinstance(weight_tensor, GPTQObserverTensor):
-        weight_tensor.update(input_tensor.detach())
+        weight_tensor.update_2d(input_tensor.detach())
         return F.linear(input_tensor, weight_tensor.hp_data, bias)
     else:
         raise ValueError(
             f"Expected weight_tensor to be GPTQObserverTensor, got: {type(weight_tensor)}"
         )
 
 
+@implements(aten.transpose.int)
+def _(func, types, args, kwargs):
+    self, dim0, dim1 = args[0], args[1], args[2]
+    assert {dim0, dim1} == {-2, -1} or {dim0, dim1} == {
+        self.hp_data.ndim - 2,
+        self.hp_data.ndim - 1,
+    }, f"only transpose of last two dims is supported, got dims {dim0}, {dim1}"
+    new_data = func(self.hp_data, dim0, dim1)
+    new_hessian = func(self.hessian, dim0, dim1)
+    return GPTQObserverTensor(new_data, self.total_batches, new_hessian)
+
+
 @implements(aten.bmm.default)
 def _(func, types, args, kwargs):
     input_tensor, weight_tensor = (
         args[0],
         args[1],
     )
-    weight_tensor.update(input_tensor.detach())
+    weight_tensor.update_3d(input_tensor.detach())
     return func(input_tensor, weight_tensor.hp_data)
```
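
For reference, a standalone numerical check (a sketch, not part of the PR) that the incremental rescale-then-add update in `_update_single_hessian` accumulates the same quantity as computing `H = (2 / num_batches) * sum_b X_b^T X_b` directly over all batches:

```python
import torch

torch.manual_seed(0)
K, M, num_batches = 12, 8, 5
batches = [torch.randn(M, K, dtype=torch.float64) for _ in range(num_batches)]

# incremental path, mirroring the running-average update above
hessian = torch.zeros(K, K, dtype=torch.float64)
tb = 0
for x in batches:
    n = 1
    if tb > 0:
        hessian *= tb / (tb + n)      # rescale previous contributions
    tb += n
    xt = ((2 / tb) ** 0.5) * x.t()
    hessian += xt @ xt.t()            # adds (2 / tb) * x^T x

# direct path
direct = (2 / num_batches) * sum(x.t() @ x for x in batches)
torch.testing.assert_close(hessian, direct)
```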
