@@ -213,6 +213,85 @@ def test_bmm_operation_with_observer(self):
213213 observer_3d .total_batches [e : e + 1 ], observers_2d [e ].total_batches
214214 ), f"Expert { e } total_batches mismatch"
215215
216+ @pytest .mark .skipif (not torch .cuda .is_available (), reason = "Need CUDA available" )
217+ @pytest .mark .skipif (
218+ not is_sm_at_least_100 (),
219+ reason = "CUDA capability >= 10.0 required for _grouped_mm" ,
220+ )
221+ def test_grouped_mm_operation_with_observer (self ):
222+ """Test torch._grouped_mm with GPTQObserverTensor updates per-expert Hessians correctly."""
223+ num_experts = 4
224+ n = 16
225+ k = 12
226+
227+ weight = torch .randn (num_experts , n , k , dtype = torch .float32 , device = "cuda" )
228+
229+ # 4 different per-expert token distributions. Several of these have
230+ # experts that see 0 tokens, which exercises the empty-slice skip path.
231+ m_per_group_list = [
232+ [1 , 3 , 4 , 16 ], # all experts active
233+ [0 , 3 , 4 , 13 ], # expert 0 sees 0 tokens
234+ [5 , 5 , 0 , 5 ], # expert 2 sees 0 tokens
235+ [2 , 0 , 6 , 4 ], # expert 1 sees 0 tokens
236+ ]
237+
238+ offs_list = [
239+ torch .tensor (
240+ [sum (m_per_group [: i + 1 ]) for i in range (num_experts )],
241+ device = "cuda" ,
242+ dtype = torch .int32 ,
243+ )
244+ for m_per_group in m_per_group_list
245+ ]
246+
247+ inputs = [
248+ torch .randn (sum (m_per_group ), k , dtype = torch .float32 , device = "cuda" )
249+ for m_per_group in m_per_group_list
250+ ]
251+
252+ # 3D path: single observer with _grouped_mm
253+ observer_3d = GPTQObserverTensor .from_hp (weight )
254+ for x , offs in zip (inputs , offs_list ):
255+ torch ._grouped_mm (x , observer_3d .transpose (- 2 , - 1 ), offs = offs )
256+
257+ # 2D path: per-expert observers with F.linear
258+ observers_2d = [
259+ GPTQObserverTensor .from_hp (weight [e ]) for e in range (num_experts )
260+ ]
261+ for x , offs in zip (inputs , offs_list ):
262+ prev_end = 0
263+ for e in range (num_experts ):
264+ end = offs [e ].item ()
265+ if end > prev_end :
266+ F .linear (x [prev_end :end ], observers_2d [e ])
267+ prev_end = end
268+
269+ # Verify per-expert hessians match bitwise to calculating each expert's
270+ # hessian individually
271+ for e in range (num_experts ):
272+ assert torch .equal (observer_3d .hessian [e ], observers_2d [e ].hessian ), (
273+ f"Expert { e } hessian mismatch"
274+ )
275+ assert torch .equal (
276+ observer_3d .total_batches [e : e + 1 ], observers_2d [e ].total_batches
277+ ), f"Expert { e } total_batches mismatch"
278+
279+ # Verify total_batches matches an independent count derived directly
280+ # from the offsets: each non-empty forward pass contributes 1 per
281+ # active expert (each expert's 2D slice has len(shape) == 2, so n=1).
282+ expected_total_batches = torch .tensor (
283+ [
284+ sum (1 for m_per_group in m_per_group_list if m_per_group [e ] > 0 )
285+ for e in range (num_experts )
286+ ],
287+ dtype = torch .int64 ,
288+ device = "cuda" ,
289+ )
290+ assert torch .equal (observer_3d .total_batches , expected_total_batches ), (
291+ f"total_batches { observer_3d .total_batches .tolist ()} "
292+ f"does not match expected { expected_total_batches .tolist ()} "
293+ )
294+
216295 @pytest .mark .skipif (not torch .cuda .is_available (), reason = "Need CUDA available" )
217296 @pytest .mark .parametrize (
218297 "base_config" ,
0 commit comments