|
@@ -15,8 +15,8 @@
 from torchao.prototype.mx_formats.kernels import (
     f4_unpacked_to_f32,
     f32_to_f4_unpacked,
+    mslk_quantize_nvfp4,
     pack_uint4,
-    triton_quantize_nvfp4,
     unpack_uint4,
 )
 from torchao.prototype.mx_formats.mx_tensor import (
@@ -155,7 +155,10 @@ def to_nvfp4( |
             assert K % 16 == 0, (
                 f"Triton kernel requires K (dim -1) to be divisible by 16, got {K}"
             )
-            blockwise_scales, data_lp = triton_quantize_nvfp4(data_hp, per_tensor_scale)
+            assert per_tensor_scale is not None, (
+                "Triton kernel requires per_tensor_scale"
+            )
+            blockwise_scales, data_lp = mslk_quantize_nvfp4(data_hp, per_tensor_scale)
         else:
             blockwise_scales, data_lp = nvfp4_quantize(
                 data_hp, block_size, per_tensor_scale
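
For orientation, a hypothetical call-site sketch of the new requirement: the MSLK fast path needs a per-tensor scale computed up front, e.g. via `per_tensor_amax_to_scale` later in this file (the `use_triton_kernel` keyword is assumed from surrounding context not shown in this diff):

```python
import torch

x = torch.randn(128, 64, device="cuda", dtype=torch.bfloat16)
per_tensor_scale = per_tensor_amax_to_scale(x.abs().amax())
# Assumed signature; without per_tensor_scale this path now raises an AssertionError:
# to_nvfp4(x, per_tensor_scale=per_tensor_scale, use_triton_kernel=True)
```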
@@ -245,7 +248,7 @@ def get_hp_scales(self) -> torch.Tensor: |
         return (
             scale_e4m3.to(self.orig_dtype)
             if self.per_tensor_scale is None
-            else self.per_tensor_scale * scale_e4m3.to(self.orig_dtype)
+            else scale_e4m3.to(self.orig_dtype) / self.per_tensor_scale
         )

     @classmethod
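
Because the global scale now stores roughly `C / amax` rather than `amax / C` (with `C = F8E4M3_MAX * F4_E2M1_MAX`), dequantization divides by it instead of multiplying. A small sketch of the equivalence, with illustrative values:

```python
import torch

C = 448.0 * 6.0                      # F8E4M3_MAX * F4_E2M1_MAX
amax = torch.tensor(12.0)            # illustrative per-tensor amax
old_scale = amax / C                 # previous convention
new_scale = C / amax                 # MSLK convention (reciprocal of the old one)
scale_e4m3 = torch.tensor(448.0)     # a decoded blockwise E4M3 scale
# Old: per_tensor_scale * scale_e4m3  ==  New: scale_e4m3 / per_tensor_scale
assert torch.allclose(old_scale * scale_e4m3, scale_e4m3 / new_scale)
```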
@@ -465,7 +468,7 @@ def _addmm_nvfp4_dispatch( |
     # Merge double quant scales into 1 scale for Scale_In^D
     if a.per_tensor_scale is not None:
         assert b.per_tensor_scale is not None
-        scale_result = a.per_tensor_scale * b.per_tensor_scale
+        scale_result = 1.0 / (a.per_tensor_scale * b.per_tensor_scale)
     else:
         assert b.per_tensor_scale is None and a.per_tensor_scale is None
         scale_result = None
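
With reciprocal per-tensor scales, the merged epilogue scale for the matmul flips as well; the numeric result is unchanged. A sketch with hypothetical amax values:

```python
import torch

C = 448.0 * 6.0
amax_a, amax_b = torch.tensor(12.0), torch.tensor(3.0)  # hypothetical amaxes
g_a, g_b = C / amax_a, C / amax_b    # reciprocal-convention per-tensor scales
old_a, old_b = amax_a / C, amax_b / C  # previous convention
# The merged epilogue scale is numerically identical either way:
assert torch.allclose(old_a * old_b, 1.0 / (g_a * g_b))
```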
@@ -625,17 +628,17 @@ def nvfp4_addmm(func, types, args, kwargs): |
 def per_tensor_amax_to_scale(amax: torch.Tensor) -> torch.Tensor:
     """Convert per-tensor amax to per-tensor scale for NVFP4 quantization.

-    Divides by both F8E4M3_MAX and F4_E2M1_MAX to ensure block scales can utilize
-    the full FP8 E4M3 range (up to 448) when block_max equals tensor_max.
-    Without F4_E2M1_MAX, the maximum scale would only reach FP8_MAX / FP4_MAX.
+    Returns the global scale in MSLK convention: (F8E4M3_MAX * F4_E2M1_MAX) / amax.
+    This ensures block scales can utilize the full FP8 E4M3 range (up to 448)
+    when block_max equals tensor_max.

     Args:
         amax: Per-tensor absolute maximum value from calibration

     Returns:
         torch.Tensor: Per-tensor scale for two-level NVFP4 scaling
     """
-    return amax.to(torch.float32) / (F8E4M3_MAX * F4_E2M1_MAX)
+    return (F8E4M3_MAX * F4_E2M1_MAX) / amax.to(torch.float32)


 def nvfp4_quantize(
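
To make the docstring's claim concrete, a small worked example (numbers illustrative) showing that when block_max equals tensor_max, the quantized block scale saturates the E4M3 range exactly:

```python
import torch

F8E4M3_MAX, F4_E2M1_MAX = 448.0, 6.0
amax = torch.tensor(12.0)               # tensor-wide amax from calibration
g = (F8E4M3_MAX * F4_E2M1_MAX) / amax   # MSLK-convention global scale = 224.0

block_amax = torch.tensor(12.0)         # a block whose max equals the tensor max
block_scale = block_amax / F4_E2M1_MAX  # fp32 block scale = 2.0
scaled = block_scale * g                # 448.0: hits the top of the E4M3 range
assert scaled.item() == F8E4M3_MAX
```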
@@ -694,15 +697,15 @@ def nvfp4_quantize( |
     # we want the per_tensor_scale ~= amax of the block_scale_fp32
     block_scale_fp32 = block_scale.to(torch.float32)
     # Quantize the blockwise scales w/ the per_tensor_scale
-    scaled_block_scales = block_scale_fp32 / per_tensor_scale
+    scaled_block_scales = block_scale_fp32 * per_tensor_scale
     scaled_block_scales_fp8 = torch.clamp(
         scaled_block_scales, min=E4M3_EPS, max=F8E4M3_MAX
     ).to(torch.float8_e4m3fn)
     scaled_block_scales_fp32 = scaled_block_scales_fp8.to(torch.float32)
-    # We "temporarily" dequant the scaled_block_scales_fp32 to get the per_tensor_scale
-    # To apply to data
-    total_scale = per_tensor_scale * scaled_block_scales_fp32
-    data_scaled = data_hp / total_scale.unsqueeze(-1)
+    # Multiply by reciprocal of combined scale instead of dividing,
+    # to match the MSLK triton kernel numerics: x * (global_scale / fp8_scale)
+    reciprocal_scale = per_tensor_scale / scaled_block_scales_fp32
+    data_scaled = data_hp * reciprocal_scale.unsqueeze(-1)
     out_scales = scaled_block_scales_fp8

     data_scaled = torch.clamp(data_scaled, -F4_E2M1_MAX, F4_E2M1_MAX)
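
Putting the pieces together, a minimal round-trip sketch of the new scale algebra (fp4 cast elided; `E4M3_EPS` approximated via `torch.finfo` as a stand-in for the module constant), showing that the reciprocal multiply here and the `get_hp_scales` dequant invert each other:

```python
import torch

F8E4M3_MAX, F4_E2M1_MAX = 448.0, 6.0
E4M3_EPS = torch.finfo(torch.float8_e4m3fn).tiny  # assumed value of the constant

data_hp = torch.randn(4, 16, dtype=torch.float32)
per_tensor_scale = (F8E4M3_MAX * F4_E2M1_MAX) / data_hp.abs().amax()

block_scale = data_hp.abs().amax(dim=-1) / F4_E2M1_MAX  # fp32 per-block scales
scaled_fp8 = torch.clamp(
    block_scale * per_tensor_scale, min=E4M3_EPS, max=F8E4M3_MAX
).to(torch.float8_e4m3fn)
scaled_fp32 = scaled_fp8.to(torch.float32)

reciprocal_scale = per_tensor_scale / scaled_fp32       # as in the hunk above
data_scaled = data_hp * reciprocal_scale.unsqueeze(-1)  # magnitudes land in [-6, 6]
# Dequant undoes both levels, matching get_hp_scales(): fp8_scale / per_tensor_scale
data_rt = data_scaled * (scaled_fp32 / per_tensor_scale).unsqueeze(-1)
assert torch.allclose(data_rt, data_hp, rtol=1e-2, atol=1e-3)
```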
|