@@ -274,12 +274,27 @@ def _nvfp4_with_precalculated_scales_q(
274274 return data_lp_packed
275275
276276
def _nvfp4_gptq_inner_loop(w_t, nvfp4_global_scale, scale, Hinv_cur_k_k):
    """Return the scaled quantization error for one GPTQ step (NVFP4 path).

    The input is passed through ``_nvfp4_with_precalculated_scales_qdq`` and
    the difference between the input and that result is divided by
    ``Hinv_cur_k_k``.

    Args:
        w_t: Values to run through the NVFP4 quantize/dequantize helper.
        nvfp4_global_scale: Forwarded unchanged to the helper.
        scale: Precalculated scale; its trailing dimension is squeezed
            before it is handed to the helper.
        Hinv_cur_k_k: Divisor applied to the quantization residual
            (presumably the ``[k, k]`` diagonal entry of the current inverse
            Hessian block -- confirm against the caller).

    Returns:
        ``(w_t - dq) / Hinv_cur_k_k`` where ``dq`` is the helper's output.
    """
    # NOTE(review): both steps live in one function, presumably so that they
    # can be wrapped together by torch.compile -- confirm against the
    # module-level ``_use_torch_compile`` setup.
    squeezed_scale = scale.squeeze(-1)
    reconstructed = _nvfp4_with_precalculated_scales_qdq(
        w_t, nvfp4_global_scale, squeezed_scale
    )
    residual = w_t - reconstructed
    return residual / Hinv_cur_k_k
290+
291+
277292# Set to True to torch.compile the NVFP4 quantize/dequantize functions
278293# inside gptq_quantize. Gives ~3x speedup.
279294_use_torch_compile = True
280295
281296if _use_torch_compile :
282- _nvfp4_qdq_fn = torch .compile (_nvfp4_with_precalculated_scales_qdq )
297+ _nvfp4_gptq_inner_loop_fn = torch .compile (_nvfp4_gptq_inner_loop )
283298 _nvfp4_q_fn = torch .compile (_nvfp4_with_precalculated_scales_q )
284299
285300 if torch_version_at_least ("2.11.0" ):
@@ -304,7 +319,7 @@ def _nvfp4_with_precalculated_scales_q(
304319 "division rounding)."
305320 )
306321else :
307- _nvfp4_qdq_fn = _nvfp4_with_precalculated_scales_qdq
322+ _nvfp4_gptq_inner_loop_fn = _nvfp4_gptq_inner_loop
308323 _nvfp4_q_fn = _nvfp4_with_precalculated_scales_q
309324
310325
@@ -507,21 +522,22 @@ def gptq_quantize(H: torch.Tensor, W_t: torch.Tensor, config: GPTQConfig):
507522 w_t , scale , zero_point , group_size
508523 )
509524 dq = _int4_row_dequantize_zp (q , scale , zero_point , group_size )
525+ err1 = (w_t - dq ) / Hinv_cur [k , k ]
526+
510527 elif isinstance (base_config , Int8WeightOnlyConfig ):
511528 q = Int8Tensor .from_hp (
512529 w_t ,
513530 granularity = base_config .granularity ,
514531 scale = quantized_tensor .scale ,
515532 )
516533 dq = q .dequantize (output_dtype = torch .float )
534+ err1 = (w_t - dq ) / Hinv_cur [k , k ]
535+
517536 elif isinstance (base_config , NVFP4DynamicActivationNVFP4WeightConfig ):
518- dq = _nvfp4_qdq_fn (
519- w_t ,
520- nvfp4_global_scale ,
521- scale .squeeze (- 1 ),
537+ Hinv_cur_k_k = Hinv_cur [k , k ]
538+ err1 = _nvfp4_gptq_inner_loop_fn (
539+ w_t , nvfp4_global_scale , scale , Hinv_cur_k_k
522540 )
523-
524- err1 = (w_t - dq ) / Hinv_cur [k , k ]
525541 B_cur [:, k :] -= err1 .matmul (Hinv_cur [k , k :].unsqueeze (0 ))
526542 B_cur_Err1 [:, k ] = err1 .flatten ()
527543
0 commit comments