From 3e418102a81d00816d778117052430d48846b7c0 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Mon, 9 Mar 2026 20:40:32 +0000 Subject: [PATCH 1/8] Update [ghstack-poisoned] --- .../float8/float8_inference_roofline.py | 10 ++++--- docs/source/workflows/inference.md | 26 +++++++++++++++---- torchao/testing/training/roofline_utils.py | 15 +++++++++++ 3 files changed, 43 insertions(+), 8 deletions(-) diff --git a/benchmarks/float8/float8_inference_roofline.py b/benchmarks/float8/float8_inference_roofline.py index 0dbcfab2f7..a786f6a626 100644 --- a/benchmarks/float8/float8_inference_roofline.py +++ b/benchmarks/float8/float8_inference_roofline.py @@ -112,7 +112,7 @@ def get_gemm_times( bf16_time_s = get_gpu_kernel_gemm_time_s(torch.mm, x_bf16, w_bf16) - if recipe_name in ("mxfp4_cutlass", "nvfp4"): + if recipe_name in ("mxfp4_cutlass", "nvfp4", "nvfp4_static"): d1, d2, d3 = torch.float4_e2m1fn_x2, torch.float4_e2m1fn_x2, torch.bfloat16 A = torch.randint(0, 255, (M, K // 2), device=device, dtype=torch.uint8).view( d1 @@ -151,7 +151,7 @@ def get_gemm_times( scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu) scale_a = to_blocked(scale_a) scale_b = to_blocked(scale_b) - elif recipe_name == "nvfp4": + elif recipe_name in ("nvfp4", "nvfp4_static"): scale_a = torch.ones(M, K // 16, device=device, dtype=torch.float8_e4m3fn) scale_b = torch.ones(N, K // 16, device=device, dtype=torch.float8_e4m3fn) scale_a = to_blocked(scale_a) @@ -177,7 +177,7 @@ def do_matmul(A, B): swizzle_b=SwizzleType.SWIZZLE_32_4_4, output_dtype=d3, ) - if recipe_name == "nvfp4": + if recipe_name in ("nvfp4", "nvfp4_static"): return torch._scaled_mm( A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=False ) @@ -794,6 +794,10 @@ def run( kernel_preference=KernelPreference.AUTO, ) elif recipe_name == "nvfp4": + config = NVFP4DynamicActivationNVFP4WeightConfig( + use_dynamic_per_tensor_scale=True, + ) + elif recipe_name == "nvfp4_static": config = NVFP4DynamicActivationNVFP4WeightConfig( use_dynamic_per_tensor_scale=False, ) diff --git a/docs/source/workflows/inference.md b/docs/source/workflows/inference.md index 5958db6acb..f0586ff91d 100644 --- a/docs/source/workflows/inference.md +++ b/docs/source/workflows/inference.md @@ -139,11 +139,11 @@ torch version 2.12.0.dev20260218+cu130 torchao version 0.17.0+git3075bb624 ... fwd_M fwd_K fwd_N r_fp8_gemm_and_ovhd_spdp b_fp8_e2e_spdp -0 1024 1024 1024 1.00 0.93 -1 2048 2048 2048 1.75 1.20 -2 4096 4096 4096 1.90 1.46 -3 8192 8192 8192 1.94 1.76 -4 16384 16384 16384 1.97 1.77 +0 1024 1024 1024 1.00 0.31 +1 2048 2048 2048 1.75 0.52 +2 4096 4096 4096 1.90 0.90 +3 8192 8192 8192 1.94 1.41 +4 16384 16384 16384 1.97 2.14 # # nvfp4 with dynamic global scaling @@ -160,6 +160,22 @@ torchao version 0.17.0+git3075bb624 2 4096 4096 4096 2.92 1.19 3 8192 8192 8192 3.34 1.80 4 16384 16384 16384 3.63 2.56 + +# +# nvfp4 with static global scaling (user API in progress) +# +> python benchmarks/float8/float8_inference_roofline.py --recipe_name nvfp4_static --enable_fusion_modeling True --skip_printing_detailed_metrics True +... +GPU NVIDIA B200 +torch version 2.12.0.dev20260218+cu130 +torchao version 0.17.0+git3075bb624 +... 
+ fwd_M fwd_K fwd_N r_fp8_gemm_and_ovhd_spdp b_fp8_e2e_spdp
+0 1024 1024 1024 1.00 0.38
+1 2048 2048 2048 2.39 0.73
+2 4096 4096 4096 2.92 1.19
+3 8192 8192 8192 3.34 1.80
+4 16384 16384 16384 3.63 2.56
 ```

 ## Other Available Quantization Techniques
diff --git a/torchao/testing/training/roofline_utils.py b/torchao/testing/training/roofline_utils.py
index 9601a2186e..74f549cc88 100644
--- a/torchao/testing/training/roofline_utils.py
+++ b/torchao/testing/training/roofline_utils.py
@@ -519,6 +519,21 @@ def get_inference_tensor_memory_traffic_ovhd_s(
             )
             res_bytes = [kernel_1_rw + kernel_3_rw]

+        case "nvfp4_static":
+            # nvfp4 with static global scaling
+            # x_bf16 = ...
+            # static_max_abs = ...
+            # kernel 1: x_bf16, static_max_abs -> to_nvfp4 -> x_nvfp4
+            kernel_1_rw = (
+                # read bf16
+                BYTES_PER_EL_BF16 * numel
+                # write fp4_x2 qdata
+                + BYTES_PER_EL_FLOAT4 * numel
+                # write e4m3 blockwise scale
+                + BYTES_PER_EL_FLOAT8 * dim0 * (dim1 // 16)
+            )
+            res_bytes = [kernel_1_rw]
+
         case _:
             raise ValueError(
                 f"Unknown recipe name: {recipe_name}. "

From 15ffe10e6697941babcb3612f8e48b3688252a53 Mon Sep 17 00:00:00 2001
From: Vasiliy Kuznetsov
Date: Mon, 9 Mar 2026 21:46:05 +0000
Subject: [PATCH 2/8] Update

[ghstack-poisoned]
---
 .../float8/float8_inference_roofline.py | 15 +++++++++-
 docs/source/workflows/inference.md      | 30 +++++++++----------
 2 files changed, 29 insertions(+), 16 deletions(-)

diff --git a/benchmarks/float8/float8_inference_roofline.py b/benchmarks/float8/float8_inference_roofline.py
index a786f6a626..5ada83453d 100644
--- a/benchmarks/float8/float8_inference_roofline.py
+++ b/benchmarks/float8/float8_inference_roofline.py
@@ -798,13 +798,26 @@ def run(
             use_dynamic_per_tensor_scale=True,
         )
     elif recipe_name == "nvfp4_static":
+        config_calib = NVFP4DynamicActivationNVFP4WeightConfig(
+            step="prepare",
+        )
         config = NVFP4DynamicActivationNVFP4WeightConfig(
-            use_dynamic_per_tensor_scale=False,
+            step="convert",
         )
     else:
         assert False, "unsupported"

     m_fp8_dyn = copy.deepcopy(m_orig)
+
+    if recipe_name == "nvfp4_static":
+        # calibrate with sample data
+        # this benchmark is performance-only, so a toy datum is fine
+        quantize_(m_fp8_dyn, config_calib)
+        toy_datum = torch.randn(
+            M_val, K_val, dtype=torch.bfloat16, device="cuda"
+        )
+        m_fp8_dyn(toy_datum)
+
     if op_name == "linear":
         quantize_(m_fp8_dyn, config)
     elif op_name == "conv2d":
diff --git a/docs/source/workflows/inference.md b/docs/source/workflows/inference.md
index f0586ff91d..1705503486 100644
--- a/docs/source/workflows/inference.md
+++ b/docs/source/workflows/inference.md
@@ -139,11 +139,11 @@ torch version 2.12.0.dev20260218+cu130
 torchao version 0.17.0+git3075bb624
 ...
  fwd_M fwd_K fwd_N r_fp8_gemm_and_ovhd_spdp b_fp8_e2e_spdp
-0 1024 1024 1024 1.00 0.31
-1 2048 2048 2048 1.75 0.52
-2 4096 4096 4096 1.90 0.90
-3 8192 8192 8192 1.94 1.41
-4 16384 16384 16384 1.97 2.14
+0 1024 1024 1024 1.00 0.93
+1 2048 2048 2048 1.75 1.20
+2 4096 4096 4096 1.90 1.46
+3 8192 8192 8192 1.94 1.76
+4 16384 16384 16384 1.97 1.77

 #
 # nvfp4 with dynamic global scaling
 #
@@ -155,11 +155,11 @@ torch version 2.12.0.dev20260218+cu130
 torchao version 0.17.0+git3075bb624
 ...
fwd_M fwd_K fwd_N r_fp8_gemm_and_ovhd_spdp b_fp8_e2e_spdp -0 1024 1024 1024 1.00 0.38 -1 2048 2048 2048 2.39 0.73 -2 4096 4096 4096 2.92 1.19 -3 8192 8192 8192 3.34 1.80 -4 16384 16384 16384 3.63 2.56 +0 1024 1024 1024 1.00 0.28 +1 2048 2048 2048 2.36 0.52 +2 4096 4096 4096 2.89 0.90 +3 8192 8192 8192 3.32 1.41 +4 16384 16384 16384 3.62 2.14 # # nvfp4 with static global scaling (user API in progress) @@ -171,11 +171,11 @@ torch version 2.12.0.dev20260218+cu130 torchao version 0.17.0+git3075bb624 ... fwd_M fwd_K fwd_N r_fp8_gemm_and_ovhd_spdp b_fp8_e2e_spdp -0 1024 1024 1024 1.00 0.38 -1 2048 2048 2048 2.39 0.73 -2 4096 4096 4096 2.92 1.19 -3 8192 8192 8192 3.34 1.80 -4 16384 16384 16384 3.63 2.56 +0 1024 1024 1024 1.00 0.34 +1 2048 2048 2048 2.74 0.64 +2 4096 4096 4096 3.42 1.06 +3 8192 8192 8192 3.67 1.58 +4 16384 16384 16384 3.82 2.31 ``` ## Other Available Quantization Techniques From d32b09b57e6f2434f1b20893cd8ff1f3f8fb0eee Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Mon, 9 Mar 2026 21:46:05 +0000 Subject: [PATCH 3/8] Update [ghstack-poisoned] --- benchmarks/mx_formats/cast_bench.py | 6 +- docs/source/workflows/inference.md | 20 +++---- .../prototype/mx_formats/test_nvfp4_tensor.py | 58 ++++++++++++++++++- torchao/prototype/mx_formats/kernels.py | 54 +++++++++++++++++ torchao/prototype/mx_formats/nvfp4_tensor.py | 15 +++-- 5 files changed, 135 insertions(+), 18 deletions(-) diff --git a/benchmarks/mx_formats/cast_bench.py b/benchmarks/mx_formats/cast_bench.py index 54cc76b617..7637c104b2 100644 --- a/benchmarks/mx_formats/cast_bench.py +++ b/benchmarks/mx_formats/cast_bench.py @@ -83,8 +83,12 @@ def to_nvfp4_reference(x_hp): def to_nvfp4_reference_triton_swizzle(x_hp): + per_tensor_scale = torch.tensor(1.0, dtype=torch.float32, device=x_hp.device) nvfp4_tensor = NVFP4Tensor.to_nvfp4( - x_hp, use_triton_kernel=True, is_swizzled_scales=True + x_hp, + per_tensor_scale=per_tensor_scale, + use_triton_kernel=True, + is_swizzled_scales=True, ) return nvfp4_tensor.qdata, nvfp4_tensor.scale diff --git a/docs/source/workflows/inference.md b/docs/source/workflows/inference.md index 1705503486..3137467f67 100644 --- a/docs/source/workflows/inference.md +++ b/docs/source/workflows/inference.md @@ -155,11 +155,11 @@ torch version 2.12.0.dev20260218+cu130 torchao version 0.17.0+git3075bb624 ... fwd_M fwd_K fwd_N r_fp8_gemm_and_ovhd_spdp b_fp8_e2e_spdp -0 1024 1024 1024 1.00 0.28 -1 2048 2048 2048 2.36 0.52 -2 4096 4096 4096 2.89 0.90 -3 8192 8192 8192 3.32 1.41 -4 16384 16384 16384 3.62 2.14 +0 1024 1024 1024 1.00 0.39 +1 2048 2048 2048 2.36 0.68 +2 4096 4096 4096 2.89 1.27 +3 8192 8192 8192 3.32 1.93 +4 16384 16384 16384 3.62 2.73 # # nvfp4 with static global scaling (user API in progress) @@ -171,11 +171,11 @@ torch version 2.12.0.dev20260218+cu130 torchao version 0.17.0+git3075bb624 ... 
fwd_M fwd_K fwd_N r_fp8_gemm_and_ovhd_spdp b_fp8_e2e_spdp -0 1024 1024 1024 1.00 0.34 -1 2048 2048 2048 2.74 0.64 -2 4096 4096 4096 3.42 1.06 -3 8192 8192 8192 3.67 1.58 -4 16384 16384 16384 3.82 2.31 +0 1024 1024 1024 1.00 0.48 +1 2048 2048 2048 2.74 0.88 +2 4096 4096 4096 3.42 1.62 +3 8192 8192 8192 3.67 2.27 +4 16384 16384 16384 3.82 2.98 ``` ## Other Available Quantization Techniques diff --git a/test/prototype/mx_formats/test_nvfp4_tensor.py b/test/prototype/mx_formats/test_nvfp4_tensor.py index ef017d7e71..eb254471cc 100644 --- a/test/prototype/mx_formats/test_nvfp4_tensor.py +++ b/test/prototype/mx_formats/test_nvfp4_tensor.py @@ -8,6 +8,9 @@ import pytest import torch import torch.nn.functional as F +from mslk.quantize.triton.fp4_quantize import ( + triton_quantize_nvfp4 as mslk_triton_quantize_nvfp4, +) from torchao.prototype.mx_formats.constants import ( F4_E2M1_MAX, @@ -369,6 +372,8 @@ def test_nvfp4_swizzled_scales_get_scales_method(): @torch.no_grad() def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype): """Test that Triton and PyTorch NVFP4 quantization produce equivalent results.""" + if not use_per_tensor_scale: + pytest.skip("MSLK triton kernel requires per_tensor_scale") torch.manual_seed(42) x = torch.randn(M, N, dtype=dtype, device="cuda") @@ -413,6 +418,51 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype): ) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize( + "M", [128, 256, 512, 1024, 100, 200, 384], ids=lambda m: f"M{m}" +) +@pytest.mark.parametrize("N", [64, 128, 256, 512, 32, 96, 160], ids=lambda n: f"N{n}") +@pytest.mark.skipif( + not is_sm_at_least_100(), reason="requires sm100+ for nvfp4 triton kernel" +) +@torch.no_grad() +def test_mslk_nvfp4_numerics(M, N): + """Test that MSLK triton and TorchAO PyTorch NVFP4 quantization produce bitwise equal results.""" + dtype = torch.bfloat16 + + torch.manual_seed(42) + x = torch.randn(M, N, dtype=dtype, device="cuda") + + per_tensor_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(x))) + # MSLK expects global_scale as the reciprocal of TorchAO's per_tensor_scale + mslk_global_scale = 1.0 / per_tensor_scale + + # Quantize with TorchAO PyTorch path + nvfp4_pt = NVFP4Tensor.to_nvfp4( + x.clone(), + per_tensor_scale=per_tensor_scale, + is_swizzled_scales=True, + use_triton_kernel=False, + ) + + # Quantize with MSLK triton kernel (returns qdata, scales in swizzled layout) + mslk_qdata, mslk_scales = mslk_triton_quantize_nvfp4(x.clone(), mslk_global_scale) + + # Compare swizzled scales (bitwise equality) + torch.testing.assert_close( + nvfp4_pt.scale.view(torch.uint8).flatten(), + mslk_scales.view(torch.uint8).flatten(), + atol=0, + rtol=0, + ) + + # Compare unpacked qdata (bitwise equality) + pt_unpacked = unpack_uint4(nvfp4_pt.qdata) + mslk_unpacked = unpack_uint4(mslk_qdata.view(torch.uint8)) + torch.testing.assert_close(pt_unpacked, mslk_unpacked, atol=0, rtol=0) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif( not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+" @@ -559,8 +609,14 @@ def test_scale_shape_matches_qdata( block_size = 16 x_hp = torch.randn(*shape, device="cuda") + + per_tensor_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(x_hp))) + x = NVFP4Tensor.to_nvfp4( - x_hp, is_swizzled_scales=is_swizzled_scales, use_triton_kernel=use_triton_kernel + x_hp, + per_tensor_scale=per_tensor_scale, + 
is_swizzled_scales=is_swizzled_scales, + use_triton_kernel=use_triton_kernel, ) if len(shape) == 2: diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py index c7196edc12..323b391454 100644 --- a/torchao/prototype/mx_formats/kernels.py +++ b/torchao/prototype/mx_formats/kernels.py @@ -1327,3 +1327,57 @@ def mxfp8_quantize_cuda( raise NotImplementedError( "`mxfp8_quantize_cuda` needs (1) torch 2.8+ and (2) torchao built from source on a machine with CUDA capability 10.0+. Please see https://github.com/pytorch/ao/issues/2932 for more details." ) + + +try: + from mslk.quantize.triton.fp4_quantize import ( + triton_quantize_nvfp4 as _mslk_triton_quantize_nvfp4, + ) + + _mslk_available = True +except ImportError: + _mslk_available = False + + +@torch.library.custom_op("ao::mslk_quantize_nvfp4", mutates_args=()) +def mslk_quantize_nvfp4( + x: torch.Tensor, per_tensor_scale: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """Quantize a tensor to NVFP4 using the MSLK triton kernel. + + Args: + x: Input tensor to quantize. + per_tensor_scale: Per-tensor scale (TorchAO convention: amax / (F8E4M3_MAX * F4_E2M1_MAX)). + + Returns: + Tuple of (blockwise_scales, quantized_data_uint8) matching TorchAO's convention. + """ + assert _mslk_available, ( + "mslk is required for NVFP4 triton quantization. " + "Install from https://github.com/pytorch/MSLK" + ) + mslk_global_scale = 1.0 / per_tensor_scale + data_lp, blockwise_scales = _mslk_triton_quantize_nvfp4(x, mslk_global_scale) + return blockwise_scales, data_lp.view(torch.uint8) + + +@mslk_quantize_nvfp4.register_fake +def _(x, per_tensor_scale): + # Mirror the reshape logic from the real MSLK kernel + orig_leading_dims, orig_N = x.shape[:-2], x.shape[-1] + x_2d = x.reshape(-1, orig_N) + M, N = x_2d.shape + + num_scales = N // 16 + n_row_blocks = triton.cdiv(M, 128) + n_col_blocks = triton.cdiv(num_scales, 4) + padded_rows = n_row_blocks * 128 + padded_cols = n_col_blocks * 4 + + scales = x.new_empty(padded_rows, padded_cols, dtype=torch.float8_e4m3fn) + xq = x.new_empty(M, N // 2, dtype=torch.uint8) + + # Reshape back to match original leading dims + scales = scales.view(*orig_leading_dims, -1, padded_cols) + xq = xq.view(*orig_leading_dims, -1, N // 2) + return scales, xq diff --git a/torchao/prototype/mx_formats/nvfp4_tensor.py b/torchao/prototype/mx_formats/nvfp4_tensor.py index e5fa34b435..f60b463104 100644 --- a/torchao/prototype/mx_formats/nvfp4_tensor.py +++ b/torchao/prototype/mx_formats/nvfp4_tensor.py @@ -15,8 +15,8 @@ from torchao.prototype.mx_formats.kernels import ( f4_unpacked_to_f32, f32_to_f4_unpacked, + mslk_quantize_nvfp4, pack_uint4, - triton_quantize_nvfp4, unpack_uint4, ) from torchao.prototype.mx_formats.mx_tensor import ( @@ -155,7 +155,10 @@ def to_nvfp4( assert K % 16 == 0, ( f"Triton kernel requires K (dim -1) to be divisible by 16, got {K}" ) - blockwise_scales, data_lp = triton_quantize_nvfp4(data_hp, per_tensor_scale) + assert per_tensor_scale is not None, ( + "Triton kernel requires per_tensor_scale" + ) + blockwise_scales, data_lp = mslk_quantize_nvfp4(data_hp, per_tensor_scale) else: blockwise_scales, data_lp = nvfp4_quantize( data_hp, block_size, per_tensor_scale @@ -699,10 +702,10 @@ def nvfp4_quantize( scaled_block_scales, min=E4M3_EPS, max=F8E4M3_MAX ).to(torch.float8_e4m3fn) scaled_block_scales_fp32 = scaled_block_scales_fp8.to(torch.float32) - # We "temporarily" dequant the scaled_block_scales_fp32 to get the per_tensor_scale - # To apply to data - total_scale = 
per_tensor_scale * scaled_block_scales_fp32 - data_scaled = data_hp / total_scale.unsqueeze(-1) + # Multiply by reciprocal of combined scale instead of dividing, + # to match the MSLK triton kernel numerics: x * (global_scale / fp8_scale) + reciprocal_scale = (1.0 / per_tensor_scale) / scaled_block_scales_fp32 + data_scaled = data_hp * reciprocal_scale.unsqueeze(-1) out_scales = scaled_block_scales_fp8 data_scaled = torch.clamp(data_scaled, -F4_E2M1_MAX, F4_E2M1_MAX) From 317362275fcc2f706de01bbeb48083384a01a170 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Tue, 10 Mar 2026 13:06:53 +0000 Subject: [PATCH 4/8] Update [ghstack-poisoned] --- .../mx_formats/test_inference_workflow.py | 2 +- torchao/prototype/mx_formats/kernels.py | 9 ++++----- torchao/prototype/mx_formats/nvfp4_tensor.py | 16 ++++++++-------- 3 files changed, 13 insertions(+), 14 deletions(-) diff --git a/test/prototype/mx_formats/test_inference_workflow.py b/test/prototype/mx_formats/test_inference_workflow.py index de41616451..cdc2b31aa5 100644 --- a/test/prototype/mx_formats/test_inference_workflow.py +++ b/test/prototype/mx_formats/test_inference_workflow.py @@ -217,7 +217,7 @@ def test_inference_workflow_nvfp4( y_ref = m(x) if use_triton_kernel and quant_type == "dynamic": - with cuda_kernel_profiler("quantize_nvfp4_triton_kernel") as result: + with cuda_kernel_profiler("triton_quantize_nvfp4_kernel") as result: y_mx = m_mx(x) assert result["found"], "Expected quantize_nvfp4 kernel to be found" else: diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py index 323b391454..e945f71852 100644 --- a/torchao/prototype/mx_formats/kernels.py +++ b/torchao/prototype/mx_formats/kernels.py @@ -1341,13 +1341,13 @@ def mxfp8_quantize_cuda( @torch.library.custom_op("ao::mslk_quantize_nvfp4", mutates_args=()) def mslk_quantize_nvfp4( - x: torch.Tensor, per_tensor_scale: torch.Tensor + x: torch.Tensor, global_scale: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: """Quantize a tensor to NVFP4 using the MSLK triton kernel. Args: x: Input tensor to quantize. - per_tensor_scale: Per-tensor scale (TorchAO convention: amax / (F8E4M3_MAX * F4_E2M1_MAX)). + global_scale: Global scale in MSLK convention (1.0 / per_tensor_scale). Returns: Tuple of (blockwise_scales, quantized_data_uint8) matching TorchAO's convention. @@ -1356,13 +1356,12 @@ def mslk_quantize_nvfp4( "mslk is required for NVFP4 triton quantization. 
" "Install from https://github.com/pytorch/MSLK" ) - mslk_global_scale = 1.0 / per_tensor_scale - data_lp, blockwise_scales = _mslk_triton_quantize_nvfp4(x, mslk_global_scale) + data_lp, blockwise_scales = _mslk_triton_quantize_nvfp4(x, global_scale) return blockwise_scales, data_lp.view(torch.uint8) @mslk_quantize_nvfp4.register_fake -def _(x, per_tensor_scale): +def _(x, global_scale): # Mirror the reshape logic from the real MSLK kernel orig_leading_dims, orig_N = x.shape[:-2], x.shape[-1] x_2d = x.reshape(-1, orig_N) diff --git a/torchao/prototype/mx_formats/nvfp4_tensor.py b/torchao/prototype/mx_formats/nvfp4_tensor.py index f60b463104..13105d9a42 100644 --- a/torchao/prototype/mx_formats/nvfp4_tensor.py +++ b/torchao/prototype/mx_formats/nvfp4_tensor.py @@ -248,7 +248,7 @@ def get_hp_scales(self) -> torch.Tensor: return ( scale_e4m3.to(self.orig_dtype) if self.per_tensor_scale is None - else self.per_tensor_scale * scale_e4m3.to(self.orig_dtype) + else scale_e4m3.to(self.orig_dtype) / self.per_tensor_scale ) @classmethod @@ -468,7 +468,7 @@ def _addmm_nvfp4_dispatch( # Merge double quant scales into 1 scale for Scale_In^D if a.per_tensor_scale is not None: assert b.per_tensor_scale is not None - scale_result = a.per_tensor_scale * b.per_tensor_scale + scale_result = 1.0 / (a.per_tensor_scale * b.per_tensor_scale) else: assert b.per_tensor_scale is None and a.per_tensor_scale is None scale_result = None @@ -628,9 +628,9 @@ def nvfp4_addmm(func, types, args, kwargs): def per_tensor_amax_to_scale(amax: torch.Tensor) -> torch.Tensor: """Convert per-tensor amax to per-tensor scale for NVFP4 quantization. - Divides by both F8E4M3_MAX and F4_E2M1_MAX to ensure block scales can utilize - the full FP8 E4M3 range (up to 448) when block_max equals tensor_max. - Without F4_E2M1_MAX, the maximum scale would only reach FP8_MAX / FP4_MAX. + Returns the global scale in MSLK convention: (F8E4M3_MAX * F4_E2M1_MAX) / amax. + This ensures block scales can utilize the full FP8 E4M3 range (up to 448) + when block_max equals tensor_max. 
Args: amax: Per-tensor absolute maximum value from calibration @@ -638,7 +638,7 @@ def per_tensor_amax_to_scale(amax: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: Per-tensor scale for two-level NVFP4 scaling """ - return amax.to(torch.float32) / (F8E4M3_MAX * F4_E2M1_MAX) + return (F8E4M3_MAX * F4_E2M1_MAX) / amax.to(torch.float32) def nvfp4_quantize( @@ -697,14 +697,14 @@ def nvfp4_quantize( # we want the per_tensor_scale ~= amax of the block_scale_fp32 block_scale_fp32 = block_scale.to(torch.float32) # Quantize the blockwise scales w/ the per_tensor_scale - scaled_block_scales = block_scale_fp32 / per_tensor_scale + scaled_block_scales = block_scale_fp32 * per_tensor_scale scaled_block_scales_fp8 = torch.clamp( scaled_block_scales, min=E4M3_EPS, max=F8E4M3_MAX ).to(torch.float8_e4m3fn) scaled_block_scales_fp32 = scaled_block_scales_fp8.to(torch.float32) # Multiply by reciprocal of combined scale instead of dividing, # to match the MSLK triton kernel numerics: x * (global_scale / fp8_scale) - reciprocal_scale = (1.0 / per_tensor_scale) / scaled_block_scales_fp32 + reciprocal_scale = per_tensor_scale / scaled_block_scales_fp32 data_scaled = data_hp * reciprocal_scale.unsqueeze(-1) out_scales = scaled_block_scales_fp8 From 6550c569d91ef4e2a71cfdde0ed7acfedca2a841 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Fri, 13 Mar 2026 18:12:21 +0000 Subject: [PATCH 5/8] Update [ghstack-poisoned] --- docs/source/workflows/inference.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/workflows/inference.md b/docs/source/workflows/inference.md index c8e37ded8e..fc6c7fdd12 100644 --- a/docs/source/workflows/inference.md +++ b/docs/source/workflows/inference.md @@ -199,7 +199,7 @@ high level, and measure performance improvements. | bfloat16 | 0 | 0.4178 | 1.00 | 1.4914 | 1.00 | | float8_rowwise | 0.1236| 0.3455 | 1.21 | 1.1986 | 1.24 | | mxfp8 | 0.1260 | 0.3673 | 1.14 | 1.2820 | 1.16 | -| nvfp4 | 0.2694 | 0.3308 | 1.26 | 1.1334 | 1.32 | +| nvfp4 | 0.2694 | 0.3203 | 1.30 | 1.0913 | 1.37 | To reproduce, run: From d20435e70109711abefc038a0aac09c1904d9a58 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Fri, 13 Mar 2026 18:39:13 +0000 Subject: [PATCH 6/8] Update [ghstack-poisoned] --- .../mx_formats/test_inference_workflow.py | 2 ++ torchao/prototype/mx_formats/kernels.py | 16 +++++++--------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/test/prototype/mx_formats/test_inference_workflow.py b/test/prototype/mx_formats/test_inference_workflow.py index cdc2b31aa5..fcc1960cb4 100644 --- a/test/prototype/mx_formats/test_inference_workflow.py +++ b/test/prototype/mx_formats/test_inference_workflow.py @@ -181,6 +181,8 @@ def test_inference_workflow_nvfp4( pytest.skip("TODO: weight_only quant currently errors w/ compile") if quant_type == "weight_only" and use_triton_kernel: pytest.skip("unsupported configuration") + if use_triton_kernel and not use_dynamic_per_tensor_scale: + pytest.skip("unsupported configuration") if use_inference_mode and ( shapes != (128, 64, 256) or inpt_dtype != torch.bfloat16 or use_triton_kernel diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py index fe5e7017b2..594f1151ad 100644 --- a/torchao/prototype/mx_formats/kernels.py +++ b/torchao/prototype/mx_formats/kernels.py @@ -4,6 +4,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
+import importlib import logging from typing import Optional, Tuple @@ -1389,14 +1390,7 @@ def mxfp8_quantize_cuda( ) -try: - from mslk.quantize.triton.fp4_quantize import ( - triton_quantize_nvfp4 as _mslk_triton_quantize_nvfp4, - ) - - _mslk_available = True -except ImportError: - _mslk_available = False +_mslk_available = importlib.util.find_spec("mslk") is not None def mslk_quantize_nvfp4( @@ -1411,7 +1405,7 @@ def mslk_quantize_nvfp4( Returns: Tuple of (blockwise_scales, quantized_data_uint8) matching TorchAO's convention. """ - mslk_global_scale = 1.0 / per_tensor_scale + mslk_global_scale = per_tensor_scale.reciprocal() return _mslk_quantize_nvfp4_custom_op(x, mslk_global_scale) @@ -1432,6 +1426,10 @@ def _mslk_quantize_nvfp4_custom_op( "mslk is required for NVFP4 triton quantization. " "Install from https://github.com/pytorch/MSLK" ) + from mslk.quantize.triton.fp4_quantize import ( + triton_quantize_nvfp4 as _mslk_triton_quantize_nvfp4, + ) + data_lp, blockwise_scales = _mslk_triton_quantize_nvfp4(x, global_scale) return blockwise_scales, data_lp.view(torch.uint8) From 262c06c8d0fa28bb32a098980a35b7debe0e3251 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Fri, 13 Mar 2026 18:57:17 +0000 Subject: [PATCH 7/8] Update [ghstack-poisoned] --- README.md | 11 +++++++++++ docs/source/index.rst | 13 +++++++++++++ torchao/prototype/mx_formats/inference_workflow.py | 3 ++- 3 files changed, 26 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index a0e21d7d39..5eb446cf19 100644 --- a/README.md +++ b/README.md @@ -110,6 +110,17 @@ pip install torchao Please see the [torchao compability table](https://github.com/pytorch/ao/issues/2919) for version requirements for dependencies. +### Optional Dependencies + +[MSLK](https://github.com/pytorch/MSLK) is an optional runtime dependency that provides accelerated kernels for some of the workflows in torchao. Stable MSLK should be used with stable torchao, and nightly MSLK with nightly torchao. +```bash +# Stable +pip install mslk-cuda==1.0.0 + +# Nightly +pip install --pre mslk --index-url https://download.pytorch.org/whl/nightly/cu128 +``` + ## 🔎 Inference TorchAO delivers substantial performance gains with minimal code changes: diff --git a/docs/source/index.rst b/docs/source/index.rst index 52aee55dcc..faec8c9b7f 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -67,6 +67,19 @@ Other installation options: Please see the `torchao compatibility table `__ for version requirements for dependencies. +Optional Dependencies +^^^^^^^^^^^^^^^^^^^^^ + +`MSLK `__ is an optional runtime dependency that provides accelerated kernels for some of the workflows in torchao. Stable MSLK should be used with stable torchao, and nightly MSLK with nightly torchao. + +.. code:: bash + + # Stable + pip install mslk-cuda==1.0.0 + + # Nightly + pip install --pre mslk --index-url https://download.pytorch.org/whl/nightly/cu128 + .. toctree:: :glob: :maxdepth: 1 diff --git a/torchao/prototype/mx_formats/inference_workflow.py b/torchao/prototype/mx_formats/inference_workflow.py index ce2ecbaae0..4c3b004ec2 100644 --- a/torchao/prototype/mx_formats/inference_workflow.py +++ b/torchao/prototype/mx_formats/inference_workflow.py @@ -204,7 +204,8 @@ class NVFP4DynamicActivationNVFP4WeightConfig(AOBaseConfig): set to False. 
Configuration parameters: - - use_triton_kernel: bool, whether to use fused triton kernel for activation scaling (default: True) + - use_triton_kernel: bool, whether to use fused triton kernel for activation scaling (default: True). + Requires `MSLK `__ to be installed. - use_dynamic_per_tensor_scale: bool, whether to dynamically compute per tensor scale (default: True) - step: Optional[QuantizationStep], the quantization step for observer-based flow - Data: float4_e2m1fn_x2 From 224f618504d84f330253a55bc85bd95fb5142d6f Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Fri, 13 Mar 2026 19:12:27 +0000 Subject: [PATCH 8/8] Update [ghstack-poisoned] --- torchao/prototype/mx_formats/kernels.py | 227 +----------------------- 1 file changed, 1 insertion(+), 226 deletions(-) diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py index 594f1151ad..74cf8397da 100644 --- a/torchao/prototype/mx_formats/kernels.py +++ b/torchao/prototype/mx_formats/kernels.py @@ -6,7 +6,7 @@ import importlib import logging -from typing import Optional, Tuple +from typing import Tuple import numpy as np import torch @@ -425,226 +425,6 @@ def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor: return out - @triton.jit - def convert_fp32_to_fp4_packed(x_pairs): - """Convert FP32 pairs to packed FP4 format. - - This function takes tensor where consecutive values along the last dimension - are packed together into single bytes. - - Args: - x_pairs: [Tensor, Tensor] both w/ shapes [..., 1] where zipped last dimension contains - interleaved pairs of FP32 values to be packed together. - - Returns: - Packed tensor with shape [...] (last dimension removed) where each - element is an int8 containing 2 FP4 values: - - First value of pair → low nibble (bits 0-3) - - Second value of pair → high nibble (bits 4-7) - - Example: - Input: [128, 32, 2] containing FP32 pairs - Output: [128, 32] containing packed FP4 bytes - - """ - - x_fp4x2 = tl.inline_asm_elementwise( - asm=""" - { - .reg .b8 byte0, byte1, byte2, byte3; - cvt.rn.satfinite.e2m1x2.f32 byte0, $5, $1; - cvt.rn.satfinite.e2m1x2.f32 byte1, $6, $2; - cvt.rn.satfinite.e2m1x2.f32 byte2, $7, $3; - cvt.rn.satfinite.e2m1x2.f32 byte3, $8, $4; - mov.b32 $0, {byte0, byte1, byte2, byte3}; - } - """, - constraints=("=r,r,r,r,r,r,r,r,r"), - args=x_pairs, - dtype=tl.uint8, - is_pure=True, - pack=4, - ) - - return x_fp4x2 - - # Sauce: https://github.com/gau-nernst/quantized-training - @triton.jit - def quantize_nvfp4_triton_kernel( - x_ptr, - tensor_scale_ptr, - q_ptr, - s_ptr, - stride_xm, - stride_xn, - M, - N, - USE_TENSOR_SCALE: tl.constexpr, - MASK_SCALES: tl.constexpr, - ): - F4_E2M1_MAX = 6.0 - F8E4M3_MAX = 448.0 - E4M3_EPS = 1.5258789e-05 - - pid_m = tl.program_id(1) - pid_n = tl.program_id(0) - - offs_m = pid_m * 128 + tl.arange(0, 128)[:, None] - offs_n = pid_n * 64 + tl.arange(0, 64)[None, :] - if MASK_SCALES: - mask = (offs_m < M) & (offs_n < N) - other = 0.0 - else: - mask = None - other = None - x = tl.load( - x_ptr + offs_m * stride_xm + offs_n * stride_xn, mask=mask, other=other - ) # [128, 64] - x_blocks = x.to(tl.float32).reshape(128, 4, 16) # [128, 4, 16] - - # Compute block-wise scales - block_amax = tl.max(x_blocks.abs(), axis=2) # [128, 4] - - if USE_TENSOR_SCALE: - # Two-level scaling: quantize block scales with per-tensor scale - tensor_scale = tl.load(tensor_scale_ptr) - - # First compute block scales - block_scale_f32 = (block_amax / F4_E2M1_MAX).to(tl.float32) - - # Quantize the block scales with 
per-tensor scale - scaled_block_scales = block_scale_f32 / tensor_scale - scaled_block_scales = tl.clamp(scaled_block_scales, E4M3_EPS, F8E4M3_MAX) - scales = scaled_block_scales.to(tl.float8e4nv) - - # Apply combined scale to data: per_tensor_scale * quantized_block_scale - total_scale = tensor_scale * scales.to(tl.float32)[:, :, None] - x_blocks = tl.div_rn(x_blocks, total_scale) - else: - # Single-level scaling: use block scales directly - scales_f32 = block_amax / F4_E2M1_MAX - scales_f32 = tl.clamp(scales_f32, E4M3_EPS, F8E4M3_MAX) - scales = scales_f32.to(tl.float8e4nv) - - # Apply block scale to data - total_scale = scales.to(tl.float32)[:, :, None] - x_blocks = tl.div_rn(x_blocks, total_scale) - - # NVIDIA layout for scales - if MASK_SCALES: - # Create offsets for the scale dimensions (4 blocks per row) - scale_offs_n = pid_n * 4 + tl.arange(0, 4)[None, :] - - # Mask out scales to 0 if we are not aligned to 128 x 64 - scales = tl.where( - (offs_m < M) & (scale_offs_n < N // 16), - scales, - 0.0, - ) - packed_scales = scales.reshape(4, 32, 4).permute(1, 0, 2).reshape(32, 16) - offs_m = tl.arange(0, 32)[:, None] - offs_n = tl.arange(0, 16)[None, :] - tl.store( - s_ptr - + (pid_m * tl.num_programs(0) + pid_n) * (32 * 16) - + offs_m * 16 - + offs_n, - packed_scales, - ) - - # Convert to FP4 - x_fp4x2 = convert_fp32_to_fp4_packed(x_blocks.reshape(128, 32, 2).split()) - offs_m = pid_m * 128 + tl.arange(0, 128)[:, None] - offs_n = pid_n * 32 + tl.arange(0, 32)[None, :] - if MASK_SCALES: - mask = (offs_m < M) & (offs_n < N // 2) - else: - mask = None - tl.store(q_ptr + offs_m * (N // 2) + offs_n, x_fp4x2, mask=mask) - - @torch.library.custom_op("ao::triton_quantize_nvfp4", mutates_args=()) - def triton_quantize_nvfp4( - x: torch.Tensor, per_tensor_scale: Optional[torch.Tensor] = None - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Quantize a tensor to NVFP4 format. - - Args: - x (torch.Tensor): Input tensor to be quantized. - tensor_scale (Optional[torch.Tensor]): Per-tensor scale for two-level quantization. - If None, uses single-level block-wise quantization only. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: Quantized tensor and scales tensor in swizzled layout. 
- - Note: - Since VLLM does not use dyanmo guards we need to make this a custom op - to avoid the triton kernel being invoked w/ the wrong use of `MASK_SCALES` - """ - # reshape to 2d - orig_leading_dims, _orig_M, orig_N = x.shape[:-2], x.shape[-2], x.shape[-1] - x = x.reshape(-1, orig_N) - - M, N = x.shape - # assert M % 128 == 0 and N % 64 == 0 - assert N % 16 == 0, "N must be divisible by 16 for NVFP4 quantization" - - # Calculate blocks needed - num_scales = N // 16 - n_row_blocks = triton.cdiv(M, 128) - n_col_blocks = triton.cdiv(num_scales, 4) - padded_rows = n_row_blocks * 128 - padded_cols = n_col_blocks * 4 - - # mask out scales to 0 if we are not aligned to 128 x 64 - MASK_SCALES = M % 128 != 0 or N % 64 != 0 - - xq = x.new_empty(M, N // 2, dtype=torch.uint8) - scales = x.new_empty(padded_rows, padded_cols, dtype=torch.float8_e4m3fn) - - grid = (triton.cdiv(N, 64), triton.cdiv(M, 128)) - - if per_tensor_scale is None: - # Don't allocate tensor, we just steal this since it won't be used in kernel - tensor_scale_ptr = x - use_tensor_scale = False - else: - tensor_scale_ptr = per_tensor_scale - use_tensor_scale = True - - quantize_nvfp4_triton_kernel[grid]( - x, - tensor_scale_ptr, - xq, - scales, - x.stride(0), - x.stride(1), - M, - N, - USE_TENSOR_SCALE=use_tensor_scale, - MASK_SCALES=MASK_SCALES, - ) - - # reshape back to original shape - scales = scales.view(*orig_leading_dims, -1, padded_cols) - xq = xq.view(*orig_leading_dims, -1, N // 2) - - return scales, xq.view(torch.uint8) - - @triton_quantize_nvfp4.register_fake - def _(x, per_tensor_scale=None): - M, N = x.shape - num_scales = N // 16 - n_row_blocks = triton.cdiv(M, 128) - n_col_blocks = triton.cdiv(num_scales, 4) - padded_rows = n_row_blocks * 128 - padded_cols = n_col_blocks * 4 - - scales = torch.empty( - padded_rows, padded_cols, device=x.device, dtype=torch.float8_e4m3fn - ) - xq = torch.empty(M, N // 2, device=x.device, dtype=torch.uint8) - return scales, xq - @triton_mx_block_rearrange.register_fake def _(scale_tensor): rows, cols = scale_tensor.shape @@ -666,11 +446,6 @@ def triton_to_mxfp8_dim1_reference( def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor: raise AssertionError("needs torch version 2.8+ and triton") - def triton_quantize_nvfp4( - x: torch.Tensor, tensor_scale: Optional[torch.Tensor] = None - ) -> Tuple[torch.Tensor, torch.Tensor]: - raise AssertionError("needs torch version 2.8+ and triton") - def triton_mxfp8_dequant_dim0( e4m3_data: torch.Tensor, e8m0_scales: torch.Tensor,
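
The `nvfp4_static` roofline case added in patch 1 models activation quantization as a single fused kernel: with a static global scale there is no separate amax-reduction pass, so the activation is read once and the quantized outputs are written once. A quick worked example of the modeled traffic, assuming the usual element sizes (2 B for bf16, 0.5 B for packed fp4, 1 B for an e4m3 scale); `kernel_1_rw` mirrors the expression in `roofline_utils.py`:

```python
# bytes moved by the single nvfp4_static quantization kernel,
# for a 16384 x 16384 bf16 activation with block size 16
numel = 16384 * 16384
kernel_1_rw = (
    2.0 * numel         # read x_bf16
    + 0.5 * numel       # write packed fp4_x2 qdata (two elements per byte)
    + 1.0 * numel / 16  # write one e4m3 blockwise scale per 16 elements
)
print(kernel_1_rw / numel)  # 2.5625 bytes per element, in one pass
```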
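Patch 2 exercises the new recipe through a prepare/calibrate/convert flow rather than a single `quantize_` call. Below is a minimal usage sketch of that flow, modeled on the benchmark changes; the `step` values are taken from the diff, while the import path, model, and calibration data are illustrative stand-ins for whatever the finalized user API documents:

```python
import torch
import torch.nn as nn

from torchao.prototype.mx_formats import NVFP4DynamicActivationNVFP4WeightConfig
from torchao.quantization import quantize_

m = nn.Sequential(nn.Linear(4096, 4096)).to(torch.bfloat16).to("cuda")

# step 1: attach observers that record activation statistics ("prepare")
quantize_(m, NVFP4DynamicActivationNVFP4WeightConfig(step="prepare"))

# step 2: run representative data through the model to calibrate
for _ in range(8):
    m(torch.randn(1024, 4096, dtype=torch.bfloat16, device="cuda"))

# step 3: freeze the observed amax into static NVFP4 scales ("convert")
quantize_(m, NVFP4DynamicActivationNVFP4WeightConfig(step="convert"))
```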
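Patches 3 and 4 flip the per-tensor scale to the MSLK reciprocal ("global scale") convention: `per_tensor_amax_to_scale` now returns `(F8E4M3_MAX * F4_E2M1_MAX) / amax`, quantization multiplies by `global_scale / fp8_block_scale`, and dequantization (`get_hp_scales`) divides the e4m3 block scales by the global scale. The sketch below walks through the two-level scaling numerics in this convention; it mirrors the math in `nvfp4_quantize` but omits the fp4 cast and packing, so it is illustrative rather than the shipped implementation:

```python
import torch

F4_E2M1_MAX = 6.0
F8E4M3_MAX = 448.0
E4M3_EPS = 1.5258789e-05

def nvfp4_scale_sketch(x: torch.Tensor, block_size: int = 16):
    # global scale, MSLK convention: reciprocal of the old torchao per-tensor scale
    amax = x.abs().amax().to(torch.float32)
    global_scale = (F8E4M3_MAX * F4_E2M1_MAX) / amax

    blocks = x.to(torch.float32).reshape(-1, block_size)
    block_amax = blocks.abs().amax(dim=1)

    # quantize the fp32 blockwise scales into e4m3 using the global scale
    block_scale_fp32 = block_amax / F4_E2M1_MAX
    scaled = torch.clamp(block_scale_fp32 * global_scale, E4M3_EPS, F8E4M3_MAX)
    block_scale_e4m3 = scaled.to(torch.float8_e4m3fn)

    # scale the data by multiplying with global_scale / fp8_block_scale,
    # matching the new nvfp4_quantize / MSLK triton kernel numerics
    recip = global_scale / block_scale_e4m3.to(torch.float32)
    q = torch.clamp(blocks * recip.unsqueeze(1), -F4_E2M1_MAX, F4_E2M1_MAX)

    # dequant mirrors NVFP4Tensor.get_hp_scales: e4m3 scale / global scale
    hp_scale = block_scale_e4m3.to(torch.float32) / global_scale
    x_round_trip = q * hp_scale.unsqueeze(1)
    return q, block_scale_e4m3, global_scale, x_round_trip

x = torch.randn(128, 64)
q, s, g, x_rt = nvfp4_scale_sketch(x)
# the scale factors cancel exactly at dequant, so with the fp4 cast omitted
# the only remaining error is the fp4-range clamp on a few block maxima
print((x_rt - x.reshape(-1, 16)).abs().max())
```

Note the symmetry with the dispatch change in patch 3: since both activation and weight now carry reciprocal global scales, `_addmm_nvfp4_dispatch` merges them as `1.0 / (a.per_tensor_scale * b.per_tensor_scale)`.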