
Commit 7a34455

extend GPTQ coverage to grouped_mm
Summary: Extend GPTQ for grouped_mm. Punting the redefinition of counting batches vs. tokens to a future PR.

Test Plan:
```
pytest test/prototype/gptq/test_gptqv2.py -s
```

ghstack-source-id: d93a8c3
ghstack-comment-id: 4313533590
Pull-Request: #4328
1 parent e841b07 commit 7a34455

2 files changed: 109 additions & 0 deletions
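For orientation, here is a minimal sketch of the calibration flow this commit enables, assembled from the diffs below: wrap the 3D expert weight in the observer, then run `torch._grouped_mm` with cumulative offsets so each expert's Hessian statistics accumulate. The import path is inferred from the observer file's location, and the CUDA / `_grouped_mm` hardware requirements mirror the test's skip conditions, so treat this as an illustration rather than a verified snippet.

```python
import torch

# Inferred from torchao/prototype/gptq/observer.py in this PR; not an official API path.
from torchao.prototype.gptq.observer import GPTQObserverTensor

num_experts, n, k = 4, 16, 12
weight = torch.randn(num_experts, n, k, dtype=torch.float32, device="cuda")

# Wrap the 3D expert weight so matmuls against it record per-expert Hessian stats.
observer = GPTQObserverTensor.from_hp(weight)

# One calibration step: 24 tokens routed to the 4 experts via cumulative end offsets.
x = torch.randn(24, k, dtype=torch.float32, device="cuda")
offs = torch.tensor([1, 4, 8, 24], device="cuda", dtype=torch.int32)
torch._grouped_mm(x, observer.transpose(-2, -1), offs=offs)  # updates observer.hessian per expert
```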


test/prototype/gptq/test_gptqv2.py

Lines changed: 79 additions & 0 deletions
```diff
@@ -213,6 +213,85 @@ def test_bmm_operation_with_observer(self):
                 observer_3d.total_batches[e : e + 1], observers_2d[e].total_batches
             ), f"Expert {e} total_batches mismatch"
 
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
+    @pytest.mark.skipif(
+        not is_sm_at_least_100(),
+        reason="CUDA capability >= 10.0 required for _grouped_mm",
+    )
+    def test_grouped_mm_operation_with_observer(self):
+        """Test that torch._grouped_mm with GPTQObserverTensor updates per-expert Hessians correctly."""
+        num_experts = 4
+        n = 16
+        k = 12
+
+        weight = torch.randn(num_experts, n, k, dtype=torch.float32, device="cuda")
+
+        # 4 different per-expert token distributions. Several of these have
+        # experts that see 0 tokens, which exercises the empty-slice skip path.
+        m_per_group_list = [
+            [1, 3, 4, 16],  # all experts active
+            [0, 3, 4, 13],  # expert 0 sees 0 tokens
+            [5, 5, 0, 5],  # expert 2 sees 0 tokens
+            [2, 0, 6, 4],  # expert 1 sees 0 tokens
+        ]
+
+        offs_list = [
+            torch.tensor(
+                [sum(m_per_group[: i + 1]) for i in range(num_experts)],
+                device="cuda",
+                dtype=torch.int32,
+            )
+            for m_per_group in m_per_group_list
+        ]
+
+        inputs = [
+            torch.randn(sum(m_per_group), k, dtype=torch.float32, device="cuda")
+            for m_per_group in m_per_group_list
+        ]
+
+        # 3D path: single observer with _grouped_mm
+        observer_3d = GPTQObserverTensor.from_hp(weight)
+        for x, offs in zip(inputs, offs_list):
+            torch._grouped_mm(x, observer_3d.transpose(-2, -1), offs=offs)
+
+        # 2D path: per-expert observers with F.linear
+        observers_2d = [
+            GPTQObserverTensor.from_hp(weight[e]) for e in range(num_experts)
+        ]
+        for x, offs in zip(inputs, offs_list):
+            prev_end = 0
+            for e in range(num_experts):
+                end = offs[e].item()
+                if end > prev_end:
+                    F.linear(x[prev_end:end], observers_2d[e])
+                prev_end = end
+
+        # Verify per-expert Hessians match bitwise against computing each
+        # expert's Hessian individually
+        for e in range(num_experts):
+            assert torch.equal(observer_3d.hessian[e], observers_2d[e].hessian), (
+                f"Expert {e} hessian mismatch"
+            )
+            assert torch.equal(
+                observer_3d.total_batches[e : e + 1], observers_2d[e].total_batches
+            ), f"Expert {e} total_batches mismatch"
+
+        # Verify total_batches matches an independent count derived directly
+        # from the offsets: each non-empty forward pass contributes 1 per
+        # active expert (each expert's 2D slice has len(shape) == 2, so n=1).
+        expected_total_batches = torch.tensor(
+            [
+                sum(1 for m_per_group in m_per_group_list if m_per_group[e] > 0)
+                for e in range(num_experts)
+            ],
+            dtype=torch.int64,
+            device="cuda",
+        )
+        assert torch.equal(observer_3d.total_batches, expected_total_batches), (
+            f"total_batches {observer_3d.total_batches.tolist()} "
+            f"does not match expected {expected_total_batches.tolist()}"
+        )
+
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
     @pytest.mark.parametrize(
         "base_config",
```
torchao/prototype/gptq/observer.py

Lines changed: 30 additions & 0 deletions
```diff
@@ -30,6 +30,9 @@ def __init__(self, hp_data: torch.Tensor, total_batches, hessian=None):
         if isinstance(total_batches, torch.Tensor):
             self.total_batches = total_batches
         elif len(self.hp_data.shape) == 3:
+            # TODO(future PR): audit whether we need to change this
+            # from `total_batches` (current) to something like `total_tokens`,
+            # to ensure that each token is weighted equally in the 3d case.
             self.total_batches = torch.zeros(
                 self.hp_data.shape[0], dtype=torch.int64, device=self.hp_data.device
             )
@@ -98,6 +101,23 @@ def update_3d(self, input: torch.Tensor):
             total_batches = self.total_batches[e_idx : e_idx + 1]
             self._update_single_hessian(x_cur, h_cur, total_batches)
 
+    def update_3d_with_offs(self, input: torch.Tensor, offs: torch.Tensor):
+        x = input.float().to(self.hp_data.device)
+        # offs is cumulative end indices; expert e gets rows [prev_end : offs[e]]
+        # Pull offs to CPU once to avoid a GPU->CPU sync per expert.
+        # TODO(future PR): optimize if this is too slow
+        offs_cpu = offs.tolist()
+        prev_end = 0
+        for e_idx in range(self.hessian.shape[0]):
+            end = offs_cpu[e_idx]
+            if end == prev_end:
+                continue
+            x_cur = x[prev_end:end]
+            h_cur = self.hessian[e_idx]
+            total_batches = self.total_batches[e_idx : e_idx + 1]
+            self._update_single_hessian(x_cur, h_cur, total_batches)
+            prev_end = end
+
     @classmethod
     def from_hp(cls, hp_tensor):
         return GPTQObserverTensor(hp_tensor, 0, None)
@@ -145,3 +165,13 @@ def _(func, types, args, kwargs):
         )
     weight_tensor.update_3d(input_tensor.detach())
     return func(input_tensor, weight_tensor.hp_data)
+
+
+@implements([aten._grouped_mm.default])
+def _(func, types, args, kwargs):
+    mat_a, mat_b = args[0], args[1]
+    offs = args[2] if len(args) > 2 else kwargs.get("offs", None)
+    assert offs is not None, "offs is required for grouped_mm"
+    assert isinstance(mat_b, GPTQObserverTensor)
+    mat_b.update_3d_with_offs(mat_a.detach(), offs)
+    return func(mat_a, mat_b.hp_data, offs)
```
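The `@implements([aten._grouped_mm.default])` handler above is what routes a `torch._grouped_mm` call whose weight is a `GPTQObserverTensor` into `update_3d_with_offs` before delegating to the plain high-precision data. As a rough illustration of that registration pattern in general, here is a self-contained toy version using `__torch_function__` and invented names (`implements_like`, `RecordingTensor`), not torchao's actual decorator, tensor subclass, or op set:

```python
import torch

_HANDLERS = {}

def implements_like(ops):
    """Toy `implements` decorator: map each intercepted op to a handler."""
    def decorator(fn):
        for op in ops:
            _HANDLERS[op] = fn
        return fn
    return decorator

class RecordingTensor(torch.Tensor):
    """Wrapper around plain data that intercepts registered ops to record inputs."""

    seen_inputs = []  # class-level log, standing in for a Hessian update

    @staticmethod
    def __new__(cls, hp_data):
        return hp_data.as_subclass(cls)

    def __init__(self, hp_data):
        self.hp_data = hp_data  # keep a plain-tensor handle to delegate to

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in _HANDLERS:
            return _HANDLERS[func](func, args, kwargs)
        return super().__torch_function__(func, types, args, kwargs)

@implements_like([torch.matmul])
def _(func, args, kwargs):
    a, b = args[0], args[1]
    RecordingTensor.seen_inputs.append(a.detach())  # observer-style side effect
    return func(a, b.hp_data)                       # run the real op on plain data

w = RecordingTensor(torch.randn(8, 4))
y = torch.matmul(torch.randn(2, 8), w)  # triggers the registered handler
assert len(RecordingTensor.seen_inputs) == 1 and y.shape == (2, 4)
```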
