6 changes: 5 additions & 1 deletion benchmarks/mx_formats/cast_bench.py
@@ -83,8 +83,12 @@ def to_nvfp4_reference(x_hp):


def to_nvfp4_reference_triton_swizzle(x_hp):
+    per_tensor_scale = torch.tensor(1.0, dtype=torch.float32, device=x_hp.device)
    nvfp4_tensor = NVFP4Tensor.to_nvfp4(
-        x_hp, use_triton_kernel=True, is_swizzled_scales=True
+        x_hp,
+        per_tensor_scale=per_tensor_scale,
+        use_triton_kernel=True,
+        is_swizzled_scales=True,
    )
    return nvfp4_tensor.qdata, nvfp4_tensor.scale
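Note: the benchmark pins the per-tensor scale to 1.0 so runs stay comparable; in real usage the scale is derived from the tensor's amax. A minimal sketch of that derivation, assuming the convention stated in this PR's `mslk_quantize_nvfp4` docstring (`per_tensor_scale = amax / (F8E4M3_MAX * F4_E2M1_MAX)`) and standalone constants in place of torchao's `per_tensor_amax_to_scale` helper:

```python
import torch

# Assumed constants: float8_e4m3fn max normal value is 448.0, fp4 e2m1 max is 6.0.
F8E4M3_MAX = 448.0
F4_E2M1_MAX = 6.0

def per_tensor_scale_from_amax(x_hp: torch.Tensor) -> torch.Tensor:
    """Hypothetical standalone equivalent of torchao's per_tensor_amax_to_scale."""
    amax = torch.amax(torch.abs(x_hp))
    return (amax / (F8E4M3_MAX * F4_E2M1_MAX)).to(torch.float32)
```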

31 changes: 16 additions & 15 deletions docs/source/workflows/inference.md
@@ -151,31 +151,32 @@ torchao version 0.17.0+git3075bb624
> python benchmarks/float8/float8_inference_roofline.py --recipe_name nvfp4 --enable_fusion_modeling True --skip_printing_detailed_metrics True
...
GPU NVIDIA B200
-torch version 2.12.0.dev20260218+cu130
-torchao version 0.17.0+git3075bb624
+torch version 2.12.0.dev20260312+cu130
+torchao version 0.17.0+gitbd7717d20
...
fwd_M fwd_K fwd_N r_fp8_gemm_and_ovhd_spdp b_fp8_e2e_spdp
-0 1024 1024 1024 1.00 0.28
-1 2048 2048 2048 2.36 0.52
-2 4096 4096 4096 2.89 0.90
-3 8192 8192 8192 3.32 1.41
-4 16384 16384 16384 3.62 2.14
+0 1024 1024 1024 1.00 0.46
+1 2048 2048 2048 2.36 0.76
+2 4096 4096 4096 2.89 1.37
+3 8192 8192 8192 3.32 1.97
+4 16384 16384 16384 3.62 2.77

#
# nvfp4 with static global scaling (user API in progress)
#
> python benchmarks/float8/float8_inference_roofline.py --recipe_name nvfp4_static --enable_fusion_modeling True --skip_printing_detailed_metrics True
...
GPU NVIDIA B200
-torch version 2.12.0.dev20260218+cu130
-torchao version 0.17.0+git3075bb624
+torch version 2.12.0.dev20260312+cu130
+torchao version 0.17.0+gitbd7717d20
...
fwd_M fwd_K fwd_N r_fp8_gemm_and_ovhd_spdp b_fp8_e2e_spdp
-0 1024 1024 1024 1.00 0.34
-1 2048 2048 2048 2.74 0.64
-2 4096 4096 4096 3.42 1.06
-3 8192 8192 8192 3.67 1.58
-4 16384 16384 16384 3.82 2.31
+0 1024 1024 1024 1.00 0.55
+1 2048 2048 2048 2.74 0.95
+2 4096 4096 4096 3.42 1.69
+3 8192 8192 8192 3.67 2.29
+4 16384 16384 16384 3.82 2.98

```

## e2e flux-1.schnell benchmarks
@@ -198,7 +199,7 @@ high level, and measure performance improvements.
| bfloat16 | 0 | 0.4178 | 1.00 | 1.4914 | 1.00 |
| float8_rowwise | 0.1236| 0.3455 | 1.21 | 1.1986 | 1.24 |
| mxfp8 | 0.1260 | 0.3673 | 1.14 | 1.2820 | 1.16 |
-| nvfp4 | 0.2694 | 0.3308 | 1.26 | 1.1334 | 1.32 |
+| nvfp4 | 0.2694 | 0.3203 | 1.30 | 1.0913 | 1.37 |

To reproduce, run:

2 changes: 1 addition & 1 deletion test/prototype/mx_formats/test_inference_workflow.py
@@ -217,7 +217,7 @@ def test_inference_workflow_nvfp4(
    y_ref = m(x)

    if use_triton_kernel and quant_type == "dynamic":
-        with cuda_kernel_profiler("quantize_nvfp4_triton_kernel") as result:
+        with cuda_kernel_profiler("triton_quantize_nvfp4_kernel") as result:
            y_mx = m_mx(x)
        assert result["found"], "Expected quantize_nvfp4 kernel to be found"
    else:
14 changes: 11 additions & 3 deletions test/prototype/mx_formats/test_nvfp4_tensor.py
@@ -369,6 +369,8 @@ def test_nvfp4_swizzled_scales_get_scales_method():
@torch.no_grad()
def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
    """Test that Triton and PyTorch NVFP4 quantization produce equivalent results."""
+    if not use_per_tensor_scale:
+        pytest.skip("MSLK triton kernel requires per_tensor_scale")

    torch.manual_seed(42)
    x = torch.randn(M, N, dtype=dtype, device="cuda")
@@ -392,8 +394,8 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
    )

    torch.testing.assert_close(nvfp4_pt.scale.flatten(), nvfp4_triton.scale.flatten())
-    pt_unpacked = unpack_uint4(nvfp4_pt.qdata)
-    triton_unpacked = unpack_uint4(nvfp4_triton.qdata)
+    pt_unpacked = unpack_uint4(nvfp4_pt.qdata.view(torch.uint8))
+    triton_unpacked = unpack_uint4(nvfp4_triton.qdata.view(torch.uint8))
    torch.testing.assert_close(
        pt_unpacked,
        triton_unpacked,
@@ -559,8 +561,14 @@ def test_scale_shape_matches_qdata(
    block_size = 16

    x_hp = torch.randn(*shape, device="cuda")
+
+    per_tensor_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(x_hp)))
+
    x = NVFP4Tensor.to_nvfp4(
-        x_hp, is_swizzled_scales=is_swizzled_scales, use_triton_kernel=use_triton_kernel
+        x_hp,
+        per_tensor_scale=per_tensor_scale,
+        is_swizzled_scales=is_swizzled_scales,
+        use_triton_kernel=use_triton_kernel,
    )

    if len(shape) == 2:
69 changes: 69 additions & 0 deletions torchao/prototype/mx_formats/kernels.py
@@ -1387,3 +1387,72 @@ def mxfp8_quantize_cuda(
    raise NotImplementedError(
        "`mxfp8_quantize_cuda` needs (1) torch 2.8+ and (2) torchao built from source on a machine with CUDA capability 10.0+. Please see https://github.com/pytorch/ao/issues/2932 for more details."
    )
+
+
+try:
+    from mslk.quantize.triton.fp4_quantize import (
+        triton_quantize_nvfp4 as _mslk_triton_quantize_nvfp4,
+    )
+
+    _mslk_available = True
+except ImportError:
+    _mslk_available = False
+
+
+def mslk_quantize_nvfp4(
+    x: torch.Tensor, per_tensor_scale: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Quantize a tensor to NVFP4 using the MSLK triton kernel.
+
+    Args:
+        x: Input tensor to quantize.
+        per_tensor_scale: Per-tensor scale (TorchAO convention: amax / (F8E4M3_MAX * F4_E2M1_MAX)).
+
+    Returns:
+        Tuple of (blockwise_scales, quantized_data_uint8) matching TorchAO's convention.
+    """
+    mslk_global_scale = 1.0 / per_tensor_scale
+    return _mslk_quantize_nvfp4_custom_op(x, mslk_global_scale)
+
+
+@torch.library.custom_op("ao::mslk_quantize_nvfp4", mutates_args=())
+def _mslk_quantize_nvfp4_custom_op(
+    x: torch.Tensor, global_scale: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Inner custom op for MSLK NVFP4 quantization.
+
+    Args:
+        x: Input tensor to quantize.
+        global_scale: Global scale in MSLK convention (1.0 / per_tensor_scale).
+
+    Returns:
+        Tuple of (blockwise_scales, quantized_data_uint8) matching TorchAO's convention.
+    """
+    assert _mslk_available, (
+        "mslk is required for NVFP4 triton quantization. "
+        "Install from https://github.com/pytorch/MSLK"
+    )
+    data_lp, blockwise_scales = _mslk_triton_quantize_nvfp4(x, global_scale)
+    return blockwise_scales, data_lp.view(torch.uint8)
+
+
+@_mslk_quantize_nvfp4_custom_op.register_fake
+def _(x, global_scale):
+    # Mirror the reshape logic from the real MSLK kernel
+    orig_leading_dims, orig_N = x.shape[:-2], x.shape[-1]
+    x_2d = x.reshape(-1, orig_N)
+    M, N = x_2d.shape
+
+    num_scales = N // 16
+    n_row_blocks = triton.cdiv(M, 128)
+    n_col_blocks = triton.cdiv(num_scales, 4)
+    padded_rows = n_row_blocks * 128
+    padded_cols = n_col_blocks * 4
+
+    scales = x.new_empty(padded_rows, padded_cols, dtype=torch.float8_e4m3fn)
+    xq = x.new_empty(M, N // 2, dtype=torch.uint8)
+
+    # Reshape back to match original leading dims
+    scales = scales.view(*orig_leading_dims, -1, padded_cols)
+    xq = xq.view(*orig_leading_dims, -1, N // 2)
+    return scales, xq
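For orientation, a usage sketch of the new entry point, not part of the diff; it assumes a CUDA device with mslk installed, and that the real kernel's output shapes match the swizzle padding mirrored by the fake impl above:

```python
import torch
from torchao.prototype.mx_formats.kernels import mslk_quantize_nvfp4

x = torch.randn(128, 64, dtype=torch.bfloat16, device="cuda")
# TorchAO convention: per_tensor_scale = amax / (F8E4M3_MAX * F4_E2M1_MAX);
# mslk_quantize_nvfp4 converts to MSLK's convention via 1.0 / per_tensor_scale.
per_tensor_scale = (torch.amax(torch.abs(x)) / (448.0 * 6.0)).to(torch.float32)

scales, qdata = mslk_quantize_nvfp4(x, per_tensor_scale)
# Scales are float8_e4m3fn, padded to (128-row, 4-scale-column) swizzle blocks;
# qdata packs two fp4 values per byte, so its last dim is N // 2.
assert scales.dtype == torch.float8_e4m3fn and scales.shape == (128, 4)
assert qdata.dtype == torch.uint8 and qdata.shape == (128, 32)
```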
15 changes: 9 additions & 6 deletions torchao/prototype/mx_formats/nvfp4_tensor.py
@@ -15,8 +15,8 @@
from torchao.prototype.mx_formats.kernels import (
    f4_unpacked_to_f32,
    f32_to_f4_unpacked,
+    mslk_quantize_nvfp4,
    pack_uint4,
-    triton_quantize_nvfp4,
    unpack_uint4,
)
from torchao.prototype.mx_formats.mx_tensor import (
@@ -155,7 +155,10 @@ def to_nvfp4(
            assert K % 16 == 0, (
                f"Triton kernel requires K (dim -1) to be divisible by 16, got {K}"
            )
-            blockwise_scales, data_lp = triton_quantize_nvfp4(data_hp, per_tensor_scale)
+            assert per_tensor_scale is not None, (
+                "Triton kernel requires per_tensor_scale"
+            )
+            blockwise_scales, data_lp = mslk_quantize_nvfp4(data_hp, per_tensor_scale)
        else:
            blockwise_scales, data_lp = nvfp4_quantize(
                data_hp, block_size, per_tensor_scale
@@ -699,10 +702,10 @@ def nvfp4_quantize(
            scaled_block_scales, min=E4M3_EPS, max=F8E4M3_MAX
        ).to(torch.float8_e4m3fn)
        scaled_block_scales_fp32 = scaled_block_scales_fp8.to(torch.float32)
-        # We "temporarily" dequant the scaled_block_scales_fp32 to get the per_tensor_scale
-        # To apply to data
-        total_scale = per_tensor_scale * scaled_block_scales_fp32
-        data_scaled = data_hp / total_scale.unsqueeze(-1)
+        # Multiply by reciprocal of combined scale instead of dividing,
+        # to match the MSLK triton kernel numerics: x * (global_scale / fp8_scale)
+        reciprocal_scale = (1.0 / per_tensor_scale) / scaled_block_scales_fp32
+        data_scaled = data_hp * reciprocal_scale.unsqueeze(-1)
        out_scales = scaled_block_scales_fp8

    data_scaled = torch.clamp(data_scaled, -F4_E2M1_MAX, F4_E2M1_MAX)
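The old and new formulations are algebraically identical, x / (s_pt * s_blk) == x * ((1 / s_pt) / s_blk), but float32 rounds the reciprocal and the division differently, so matching the MSLK kernel bit-for-bit requires the same order of operations. A quick self-contained check with plain torch and hypothetical values:

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 16)
per_tensor_scale = torch.tensor(0.0123)
fp8_scale = torch.rand(4, dtype=torch.float32) + 0.5  # stand-in for scaled_block_scales_fp32

old = x / (per_tensor_scale * fp8_scale).unsqueeze(-1)          # divide by product
new = x * ((1.0 / per_tensor_scale) / fp8_scale).unsqueeze(-1)  # multiply by reciprocal

print(torch.allclose(old, new))  # True: same values up to rounding
print((old - new).abs().max())   # typically nonzero at the last ulp
```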