@@ -108,31 +108,7 @@ def test_observer_tensor_creation(self):
         )
 
         # Check total_batches is initialized as 0
-        assert observer.total_batches == 0
-
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
-    def test_observer_tensor_attributes(self):
-        """Test GPTQObserverTensor attributes are correctly set."""
-        weight = torch.randn(16, 32, dtype=torch.bfloat16, device="cuda")
-        observer = GPTQObserverTensor.from_hp(weight)
-
-        # Test hp_data attribute
-        assert hasattr(observer, "hp_data")
-        assert isinstance(observer.hp_data, torch.Tensor)
-
-        # Test hessian attribute
-        assert hasattr(observer, "hessian")
-        assert torch.equal(
-            observer.hessian, torch.zeros(32, 32, dtype=torch.float32, device="cuda")
-        )
-
-        # Test total_batches attribute
-        assert hasattr(observer, "total_batches")
-        assert observer.total_batches == 0
-
-        # Test update method exists
-        assert hasattr(observer, "update")
-        assert callable(observer.update)
+        assert (observer.total_batches == 0).all()
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
     def test_linear_operation_with_observer(self):
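Note on the assertion change above: `total_batches` is now tracked as a tensor rather than a Python int (the per-expert indexing later in this diff suggests one counter per expert for 3D weights), so comparisons are elementwise and must be reduced with `.all()`. A minimal sketch of the semantics, with the shape assumed for illustration:

```python
import torch

# Assumed shape for illustration: one batch counter per expert.
total_batches = torch.zeros(4, dtype=torch.int64)

# Equality against an int is elementwise and returns a bool tensor,
# so the tests reduce with .all() instead of comparing a Python int.
assert (total_batches == 0).all()
```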
@@ -161,7 +137,7 @@ def test_linear_operation_with_observer(self):
         # Check that Hessian was initialized and updated
         assert observer_weight.hessian is not None
         assert observer_weight.hessian.shape == (in_features, in_features)
-        assert observer_weight.total_batches == 1
+        assert (observer_weight.total_batches == 1).all()
 
         # Verify output is correct
         expected_output = F.linear(input_tensor, weight)
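For context on the `(in_features, in_features)` shape assertion: GPTQ accumulates its Hessian from layer activations as H ∝ 2·XᵀX, so for a linear layer it is square in the input dimension. A minimal sketch of such an accumulation, assuming a plain running sum (`accumulate_hessian` is a hypothetical helper; the observer's actual update rule may differ, e.g. a running average over `total_batches`):

```python
import torch

def accumulate_hessian(hessian: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # x: (batch, in_features) activations. GPTQ builds H from 2 * X^T X,
    # giving the (in_features, in_features) shape the test asserts.
    x = x.float()
    return hessian + 2.0 * (x.T @ x)
```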
@@ -195,36 +171,47 @@ def test_multiple_observations(self):
         assert observer_weight.hessian.shape == (in_features, in_features)
 
         # Check total_batches matches total samples
-        assert observer_weight.total_batches == total_samples
+        assert (observer_weight.total_batches == total_samples).all()
 
-    @pytest.mark.skip(reason="bmm math is incorrect, will fix in next PR")
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
     def test_bmm_operation_with_observer(self):
         """Test torch.bmm with GPTQObserverTensor updates Hessian correctly."""
-        batch = 4
+        num_experts = 4
         m = 8
         n = 16
         k = 12
+        num_passes = 4
 
         # Create input and weight tensors
-        input_tensor = torch.randn(batch, m, k, dtype=torch.float32, device="cuda")
-        weight = torch.randn(batch, k, n, dtype=torch.float32, device="cuda")
-        observer_weight = GPTQObserverTensor.from_hp(weight)
+        weight = torch.randn(num_experts, n, k, dtype=torch.float32, device="cuda")
 
-        # Perform bmm operation
-        output = torch.bmm(input_tensor, observer_weight)
+        inputs = [
+            torch.randn(num_experts, m, k, dtype=torch.float32, device="cuda")
+            for _ in range(num_passes)
+        ]
 
-        # Check output shape
-        assert output.shape == (batch, m, n)
+        # 3D path: single observer with bmm
+        observer_3d = GPTQObserverTensor.from_hp(weight)
+        for x in inputs:
+            torch.bmm(x, observer_3d.transpose(-2, -1))
 
-        # Check Hessian was initialized and updated
-        assert observer_weight.hessian is not None
-        # For bmm with batch dimension, the Hessian is computed on the last dimension
-        assert observer_weight.total_batches == batch
-
-        # Verify output is correct
-        expected_output = torch.bmm(input_tensor, weight)
-        torch.testing.assert_close(output, expected_output)
+        # 2D path: per-expert observers with F.linear
+        observers_2d = [
+            GPTQObserverTensor.from_hp(weight[e]) for e in range(num_experts)
+        ]
+        for x in inputs:
+            for e in range(num_experts):
+                F.linear(x[e], observers_2d[e])
+
+        # Verify the 3D observer's per-expert Hessians bitwise-match those
+        # computed by observing each expert individually
+        for e in range(num_experts):
+            assert torch.equal(observer_3d.hessian[e], observers_2d[e].hessian), (
+                f"Expert {e} hessian mismatch"
+            )
+            assert torch.equal(
+                observer_3d.total_batches[e : e + 1], observers_2d[e].total_batches
+            ), f"Expert {e} total_batches mismatch"
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
     @pytest.mark.parametrize(
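The rewritten test compares two routes to the same per-expert Hessians. This works because for a 3D weight `w` of shape `(num_experts, n, k)`, `torch.bmm(x, w.transpose(-2, -1))[e]` and `F.linear(x[e], w[e])` compute the same product `x[e] @ w[e].T`, so each expert's observer sees identical activations. A standalone check of that identity on plain tensors:

```python
import torch
import torch.nn.functional as F

num_experts, m, n, k = 4, 8, 16, 12
w = torch.randn(num_experts, n, k)
x = torch.randn(num_experts, m, k)

# Batched path vs. per-expert path compute the same x[e] @ w[e].T
out_bmm = torch.bmm(x, w.transpose(-2, -1))
out_linear = torch.stack([F.linear(x[e], w[e]) for e in range(num_experts)])
torch.testing.assert_close(out_bmm, out_linear)
```

The test itself asserts bitwise equality of the accumulated Hessians, which is stronger than numerical closeness; presumably both observer paths reduce over the same dimensions in the same order.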
@@ -260,15 +247,15 @@ def test_observer_config_transform(self, base_config):
             linear.weight.hessian,
             torch.zeros(64, 64, dtype=torch.float32, device="cuda"),
         )
-        assert linear.weight.total_batches == 0
+        assert (linear.weight.total_batches == 0).all()
 
         # Perform a forward pass
         input_tensor = torch.randn(4, 64, dtype=torch.float32, device="cuda")
         output = linear(input_tensor)
 
         # Check Hessian was initialized after forward pass
         assert linear.weight.hessian is not None
-        assert linear.weight.total_batches == 1
+        assert (linear.weight.total_batches == 1).all()
 
         # Check output shape
         assert output.shape == (4, 32)
@@ -384,6 +371,9 @@ def test_unified_config_two_phase(self, base_config):
     )
     def test_gptq_quantize_function(self, base_config):
         """Test gptq_quantize function with synthetic Hessian and weights."""
+        if isinstance(base_config, Int4WeightOnlyConfig) and is_sm_at_least_100():
+            pytest.skip("int4 kernels do not work on sm100")
+
         torch.manual_seed(42)
 
         # Create synthetic weight matrix
@@ -556,6 +546,8 @@ def test_gptq_sqnr(self, base_config):
             and not is_sm_at_least_100()
         ):
             pytest.skip("CUDA capability >= 10.0 required for nvfp4")
+        if isinstance(base_config, Int4WeightOnlyConfig) and is_sm_at_least_100():
+            pytest.skip("int4 kernels do not work on sm100")
 
         torch.manual_seed(43)
 
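Both new skips gate on `is_sm_at_least_100()`, torchao's test-utility check of CUDA compute capability. A minimal sketch of such a gate, assuming the default CUDA device (`sm_at_least` is a hypothetical name; the real helper may differ):

```python
import torch

def sm_at_least(major: int, minor: int = 0) -> bool:
    # Compute capability of the default CUDA device; sm100 reports (10, 0).
    return (
        torch.cuda.is_available()
        and torch.cuda.get_device_capability() >= (major, minor)
    )

print(sm_at_least(10))  # True on an sm100 (Blackwell-class) GPU
```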