Skip to content

Commit b09e974

Browse files
committed
[wip] perf improvements
Summary: (none provided)
Test Plan: (none provided)
ghstack-source-id: 7824ff7
ghstack-comment-id: 4316156896
Pull-Request: #4331
1 parent 12cd338 commit b09e974

2 files changed

Lines changed: 34 additions & 9 deletions

File tree

torchao/prototype/gptq/api.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -274,12 +274,27 @@ def _nvfp4_with_precalculated_scales_q(
274274
return data_lp_packed
275275

276276

277+
def _nvfp4_gptq_inner_loop(
278+
w_t,
279+
nvfp4_global_scale,
280+
scale,
281+
Hinv_cur_k_k,
282+
):
283+
dq = _nvfp4_with_precalculated_scales_qdq(
284+
w_t,
285+
nvfp4_global_scale,
286+
scale.squeeze(-1),
287+
)
288+
err1 = (w_t - dq) / Hinv_cur_k_k
289+
return err1
290+
291+
277292
# Set to True to torch.compile the NVFP4 quantize/dequantize functions
278293
# inside gptq_quantize. Gives ~3x speedup.
279294
_use_torch_compile = True
280295

281296
if _use_torch_compile:
282-
_nvfp4_qdq_fn = torch.compile(_nvfp4_with_precalculated_scales_qdq)
297+
_nvfp4_gptq_inner_loop_fn = torch.compile(_nvfp4_gptq_inner_loop)
283298
_nvfp4_q_fn = torch.compile(_nvfp4_with_precalculated_scales_q)
284299

285300
if torch_version_at_least("2.11.0"):
@@ -304,7 +319,7 @@ def _nvfp4_with_precalculated_scales_q(
304319
"division rounding)."
305320
)
306321
else:
307-
_nvfp4_qdq_fn = _nvfp4_with_precalculated_scales_qdq
322+
_nvfp4_gptq_inner_loop_fn = _nvfp4_gptq_inner_loop
308323
_nvfp4_q_fn = _nvfp4_with_precalculated_scales_q
309324

310325

@@ -507,21 +522,22 @@ def gptq_quantize(H: torch.Tensor, W_t: torch.Tensor, config: GPTQConfig):
507522
w_t, scale, zero_point, group_size
508523
)
509524
dq = _int4_row_dequantize_zp(q, scale, zero_point, group_size)
525+
err1 = (w_t - dq) / Hinv_cur[k, k]
526+
510527
elif isinstance(base_config, Int8WeightOnlyConfig):
511528
q = Int8Tensor.from_hp(
512529
w_t,
513530
granularity=base_config.granularity,
514531
scale=quantized_tensor.scale,
515532
)
516533
dq = q.dequantize(output_dtype=torch.float)
534+
err1 = (w_t - dq) / Hinv_cur[k, k]
535+
517536
elif isinstance(base_config, NVFP4DynamicActivationNVFP4WeightConfig):
518-
dq = _nvfp4_qdq_fn(
519-
w_t,
520-
nvfp4_global_scale,
521-
scale.squeeze(-1),
537+
Hinv_cur_k_k = Hinv_cur[k, k]
538+
err1 = _nvfp4_gptq_inner_loop_fn(
539+
w_t, nvfp4_global_scale, scale, Hinv_cur_k_k
522540
)
523-
524-
err1 = (w_t - dq) / Hinv_cur[k, k]
525541
B_cur[:, k:] -= err1.matmul(Hinv_cur[k, k:].unsqueeze(0))
526542
B_cur_Err1[:, k] = err1.flatten()
527543

torchao/prototype/mx_formats/mx_tensor.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,16 @@ def _to_mx_rceil(
204204
data_lp = data_hp * rcp_fp32
205205

206206
# Note: clamp preserves NaN values
207-
data_lp = torch.clamp(data_lp, min=-max_pos, max=max_pos)
207+
if not (torch.compiler.is_compiling() or is_fake(descale)):
208+
# As of 20250317, the Pytorch eager mode cast to `torch.float8_e4m3fn`
209+
# is unsaturated. This cast is saturated in triton. If we are compute bound,
210+
# we see a speedup if we remove this redundant clamp if we are compiling
211+
# to triton.
212+
# TODO(#1912): make the saturated cast work in eager mode and remove this
213+
# workaround.
214+
# TODO(future PR): unify this code between the FLOOR and RCEIL scaling
215+
# methods
216+
data_lp = torch.clamp(data_lp, min=-max_pos, max=max_pos)
208217

209218
return exponent, data_lp
210219

0 commit comments

Comments (0)