Merged (changes from 9 commits)
11 changes: 11 additions & 0 deletions README.md

@@ -110,6 +110,17 @@ pip install torchao
 
 Please see the [torchao compatibility table](https://github.com/pytorch/ao/issues/2919) for version requirements for dependencies.
 
+### Optional Dependencies
+
+[MSLK](https://github.com/pytorch/MSLK) is an optional runtime dependency that provides accelerated kernels for some of the workflows in torchao. Stable MSLK should be used with stable torchao, and nightly MSLK with nightly torchao.
+```bash
+# Stable
+pip install mslk-cuda==1.0.0
+
+# Nightly
+pip install --pre mslk --index-url https://download.pytorch.org/whl/nightly/cu128
+```
+
 ## 🔎 Inference
 
 TorchAO delivers substantial performance gains with minimal code changes:
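
Because MSLK is optional, torchao detects it at runtime rather than declaring a hard dependency; the kernels.py change later in this diff does exactly that via `importlib.util.find_spec`. A minimal standalone sketch of the detection pattern (`require_mslk` is a hypothetical helper, not part of torchao):

```python
import importlib.util

# True iff the optional mslk package is importable in this environment.
_mslk_available = importlib.util.find_spec("mslk") is not None


def require_mslk() -> None:
    # Hypothetical helper (not part of torchao): fail loudly with an
    # install hint when an MSLK-backed code path is requested.
    if not _mslk_available:
        raise ImportError(
            "mslk is required for this code path. "
            "Install from https://github.com/pytorch/MSLK"
        )
```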
6 changes: 5 additions & 1 deletion benchmarks/mx_formats/cast_bench.py

@@ -83,8 +83,12 @@ def to_nvfp4_reference(x_hp):
 
 
 def to_nvfp4_reference_triton_swizzle(x_hp):
+    per_tensor_scale = torch.tensor(1.0, dtype=torch.float32, device=x_hp.device)
     nvfp4_tensor = NVFP4Tensor.to_nvfp4(
-        x_hp, use_triton_kernel=True, is_swizzled_scales=True
+        x_hp,
+        per_tensor_scale=per_tensor_scale,
+        use_triton_kernel=True,
+        is_swizzled_scales=True,
     )
     return nvfp4_tensor.qdata, nvfp4_tensor.scale
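
The benchmark passes a fixed scale of 1.0, which is enough to exercise the cast path. In real workflows the per-tensor scale is derived from the tensor's absolute maximum; a sketch of that derivation, assuming the TorchAO convention amax / (F8E4M3_MAX * F4_E2M1_MAX) quoted in the kernels.py docstring below (torchao ships this as `per_tensor_amax_to_scale`, used in the tests; this is a hand-rolled stand-in):

```python
import torch

F8E4M3_MAX = 448.0  # largest finite value of torch.float8_e4m3fn
F4_E2M1_MAX = 6.0   # largest finite value of the fp4 e2m1 format


def per_tensor_amax_to_scale(amax: torch.Tensor) -> torch.Tensor:
    # TorchAO convention per the kernels.py docstring in this PR:
    # per_tensor_scale = amax / (F8E4M3_MAX * F4_E2M1_MAX)
    return amax.to(torch.float32) / (F8E4M3_MAX * F4_E2M1_MAX)


x_hp = torch.randn(128, 64)
per_tensor_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(x_hp)))
```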
13 changes: 13 additions & 0 deletions docs/source/index.rst

@@ -67,6 +67,19 @@ Other installation options:
 
 Please see the `torchao compatibility table <https://github.com/pytorch/ao/issues/2919>`__ for version requirements for dependencies.
 
+Optional Dependencies
+^^^^^^^^^^^^^^^^^^^^^
+
+`MSLK <https://github.com/pytorch/MSLK>`__ is an optional runtime dependency that provides accelerated kernels for some of the workflows in torchao. Stable MSLK should be used with stable torchao, and nightly MSLK with nightly torchao.
+
+.. code:: bash
+
+   # Stable
+   pip install mslk-cuda==1.0.0
+
+   # Nightly
+   pip install --pre mslk --index-url https://download.pytorch.org/whl/nightly/cu128
+
 .. toctree::
    :glob:
    :maxdepth: 1
31 changes: 16 additions & 15 deletions docs/source/workflows/inference.md

@@ -151,31 +151,32 @@ torchao version 0.17.0+git3075bb624
 > python benchmarks/float8/float8_inference_roofline.py --recipe_name nvfp4 --enable_fusion_modeling True --skip_printing_detailed_metrics True
 ...
 GPU NVIDIA B200
-torch version 2.12.0.dev20260218+cu130
-torchao version 0.17.0+git3075bb624
+torch version 2.12.0.dev20260312+cu130
+torchao version 0.17.0+gitbd7717d20
 ...
    fwd_M  fwd_K  fwd_N  r_fp8_gemm_and_ovhd_spdp  b_fp8_e2e_spdp
-0   1024   1024   1024                      1.00            0.28
-1   2048   2048   2048                      2.36            0.52
-2   4096   4096   4096                      2.89            0.90
-3   8192   8192   8192                      3.32            1.41
-4  16384  16384  16384                      3.62            2.14
+0   1024   1024   1024                      1.00            0.46
+1   2048   2048   2048                      2.36            0.76
+2   4096   4096   4096                      2.89            1.37
+3   8192   8192   8192                      3.32            1.97
+4  16384  16384  16384                      3.62            2.77
 
 #
 # nvfp4 with static global scaling (user API in progress)
 #
 > python benchmarks/float8/float8_inference_roofline.py --recipe_name nvfp4_static --enable_fusion_modeling True --skip_printing_detailed_metrics True
 ...
 GPU NVIDIA B200
-torch version 2.12.0.dev20260218+cu130
-torchao version 0.17.0+git3075bb624
+torch version 2.12.0.dev20260312+cu130
+torchao version 0.17.0+gitbd7717d20
 ...
    fwd_M  fwd_K  fwd_N  r_fp8_gemm_and_ovhd_spdp  b_fp8_e2e_spdp
-0   1024   1024   1024                      1.00            0.34
-1   2048   2048   2048                      2.74            0.64
-2   4096   4096   4096                      3.42            1.06
-3   8192   8192   8192                      3.67            1.58
-4  16384  16384  16384                      3.82            2.31
+0   1024   1024   1024                      1.00            0.55
+1   2048   2048   2048                      2.74            0.95
+2   4096   4096   4096                      3.42            1.69
+3   8192   8192   8192                      3.67            2.29
+4  16384  16384  16384                      3.82            2.98
 
 ```
 
 ## e2e flux-1.schnell benchmarks

@@ -198,7 +199,7 @@ high level, and measure performance improvements.
 | bfloat16 | 0 | 0.4178 | 1.00 | 1.4914 | 1.00 |
 | float8_rowwise | 0.1236 | 0.3455 | 1.21 | 1.1986 | 1.24 |
 | mxfp8 | 0.1260 | 0.3673 | 1.14 | 1.2820 | 1.16 |
-| nvfp4 | 0.2694 | 0.3308 | 1.26 | 1.1334 | 1.32 |
+| nvfp4 | 0.2694 | 0.3203 | 1.30 | 1.0913 | 1.37 |
 
 To reproduce, run:
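
The last two column pairs of each row appear to be a wall-clock time followed by its speedup over the bfloat16 baseline; a quick arithmetic check of the updated nvfp4 row, assuming that reading of the table:

```python
# bfloat16 baseline vs updated nvfp4 row from the table above
bf16_fwd_s, nvfp4_fwd_s = 0.4178, 0.3203
bf16_denoise_s, nvfp4_denoise_s = 1.4914, 1.0913

print(f"{bf16_fwd_s / nvfp4_fwd_s:.2f}")          # 1.30, matches column 4
print(f"{bf16_denoise_s / nvfp4_denoise_s:.2f}")  # 1.37, matches column 6
```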
4 changes: 3 additions & 1 deletion test/prototype/mx_formats/test_inference_workflow.py

@@ -181,6 +181,8 @@ def test_inference_workflow_nvfp4(
         pytest.skip("TODO: weight_only quant currently errors w/ compile")
     if quant_type == "weight_only" and use_triton_kernel:
         pytest.skip("unsupported configuration")
+    if use_triton_kernel and not use_dynamic_per_tensor_scale:
+        pytest.skip("unsupported configuration")
 
     if use_inference_mode and (
         shapes != (128, 64, 256) or inpt_dtype != torch.bfloat16 or use_triton_kernel

@@ -217,7 +219,7 @@ def test_inference_workflow_nvfp4(
     y_ref = m(x)
 
     if use_triton_kernel and quant_type == "dynamic":
-        with cuda_kernel_profiler("quantize_nvfp4_triton_kernel") as result:
+        with cuda_kernel_profiler("triton_quantize_nvfp4_kernel") as result:
            y_mx = m_mx(x)
        assert result["found"], "Expected quantize_nvfp4 kernel to be found"
    else:
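
`cuda_kernel_profiler` is a torchao test helper; the rename above just tracks the kernel name exported by MSLK. A minimal sketch of the helper's contract, assuming an implementation on top of `torch.profiler` (not torchao's actual code): it records CUDA activity inside the block and reports whether any event name contains the requested substring.

```python
from contextlib import contextmanager

from torch.profiler import ProfilerActivity, profile


@contextmanager
def cuda_kernel_profiler(kernel_pattern: str):
    # Sketch of the helper's contract, not torchao's implementation:
    # profile CUDA activity inside the block, then report whether any
    # recorded event name contains the requested substring.
    result = {"found": False, "kernel_names": []}
    with profile(activities=[ProfilerActivity.CUDA]) as prof:
        yield result
    names = [evt.name for evt in prof.events()]
    result["kernel_names"] = names
    result["found"] = any(kernel_pattern in name for name in names)
```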
14 changes: 11 additions & 3 deletions test/prototype/mx_formats/test_nvfp4_tensor.py

@@ -369,6 +369,8 @@ def test_nvfp4_swizzled_scales_get_scales_method():
 @torch.no_grad()
 def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
     """Test that Triton and PyTorch NVFP4 quantization produce equivalent results."""
+    if not use_per_tensor_scale:
+        pytest.skip("MSLK triton kernel requires per_tensor_scale")
 
     torch.manual_seed(42)
     x = torch.randn(M, N, dtype=dtype, device="cuda")

@@ -392,8 +394,8 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
     )
 
     torch.testing.assert_close(nvfp4_pt.scale.flatten(), nvfp4_triton.scale.flatten())
-    pt_unpacked = unpack_uint4(nvfp4_pt.qdata)
-    triton_unpacked = unpack_uint4(nvfp4_triton.qdata)
+    pt_unpacked = unpack_uint4(nvfp4_pt.qdata.view(torch.uint8))
+    triton_unpacked = unpack_uint4(nvfp4_triton.qdata.view(torch.uint8))
     torch.testing.assert_close(
         pt_unpacked,
         triton_unpacked,

@@ -559,8 +561,14 @@ def test_scale_shape_matches_qdata(
     block_size = 16
 
     x_hp = torch.randn(*shape, device="cuda")
+
+    per_tensor_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(x_hp)))
+
     x = NVFP4Tensor.to_nvfp4(
-        x_hp, is_swizzled_scales=is_swizzled_scales, use_triton_kernel=use_triton_kernel
+        x_hp,
+        per_tensor_scale=per_tensor_scale,
+        is_swizzled_scales=is_swizzled_scales,
+        use_triton_kernel=use_triton_kernel,
     )
 
     if len(shape) == 2:
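
The `.view(torch.uint8)` fix suggests `qdata` is now stored in a packed fp4 dtype whose raw bytes `unpack_uint4` expects as uint8. For intuition, a sketch of what uint4 unpacking does, assuming low-nibble-first packing (the real torchao `unpack_uint4` may order nibbles differently):

```python
import torch


def unpack_uint4_sketch(packed: torch.Tensor) -> torch.Tensor:
    # Assumed packing: two 4-bit codes per uint8 byte, low nibble first.
    assert packed.dtype == torch.uint8
    low = packed & 0x0F
    high = (packed >> 4) & 0x0F
    # Interleave so element order matches the unpacked layout.
    return torch.stack([low, high], dim=-1).flatten(-2)
```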
3 changes: 2 additions & 1 deletion torchao/prototype/mx_formats/inference_workflow.py

@@ -204,7 +204,8 @@ class NVFP4DynamicActivationNVFP4WeightConfig(AOBaseConfig):
        set to False.
 
     Configuration parameters:
-    - use_triton_kernel: bool, whether to use fused triton kernel for activation scaling (default: True)
+    - use_triton_kernel: bool, whether to use fused triton kernel for activation scaling (default: True).
+      Requires `MSLK <https://github.com/pytorch/MSLK>`__ to be installed.
     - use_dynamic_per_tensor_scale: bool, whether to dynamically compute per tensor scale (default: True)
     - step: Optional[QuantizationStep], the quantization step for observer-based flow
     - Data: float4_e2m1fn_x2
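
For context, a sketch of how this config is typically applied, assuming torchao's usual `quantize_` entry point and NVFP4-capable hardware (import paths match this PR's file layout but may shift between versions):

```python
import torch
import torch.nn as nn

from torchao.prototype.mx_formats.inference_workflow import (
    NVFP4DynamicActivationNVFP4WeightConfig,
)
from torchao.quantization import quantize_

m = nn.Sequential(nn.Linear(256, 256, dtype=torch.bfloat16, device="cuda"))
# use_triton_kernel=True now requires MSLK; set it to False to stay on
# the pure-PyTorch quantization path.
quantize_(m, NVFP4DynamicActivationNVFP4WeightConfig(use_triton_kernel=False))
y = m(torch.randn(32, 256, dtype=torch.bfloat16, device="cuda"))
```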
67 changes: 67 additions & 0 deletions torchao/prototype/mx_formats/kernels.py

@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import importlib
 import logging
 from typing import Optional, Tuple
 

@@ -1387,3 +1388,69 @@ def mxfp8_quantize_cuda(
     raise NotImplementedError(
         "`mxfp8_quantize_cuda` needs (1) torch 2.8+ and (2) torchao built from source on a machine with CUDA capability 10.0+. Please see https://github.com/pytorch/ao/issues/2932 for more details."
     )
+
+
+_mslk_available = importlib.util.find_spec("mslk") is not None
+
+
+def mslk_quantize_nvfp4(
+    x: torch.Tensor, per_tensor_scale: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Quantize a tensor to NVFP4 using the MSLK triton kernel.
+
+    Args:
+        x: Input tensor to quantize.
+        per_tensor_scale: Per-tensor scale (TorchAO convention: amax / (F8E4M3_MAX * F4_E2M1_MAX)).
+
+    Returns:
+        Tuple of (blockwise_scales, quantized_data_uint8) matching TorchAO's convention.
+    """
+    mslk_global_scale = per_tensor_scale.reciprocal()
+    return _mslk_quantize_nvfp4_custom_op(x, mslk_global_scale)
+
+
+@torch.library.custom_op("ao::mslk_quantize_nvfp4", mutates_args=())
+def _mslk_quantize_nvfp4_custom_op(
+    x: torch.Tensor, global_scale: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Inner custom op for MSLK NVFP4 quantization.
+
+    Args:
+        x: Input tensor to quantize.
+        global_scale: Global scale in MSLK convention (1.0 / per_tensor_scale).
+
+    Returns:
+        Tuple of (blockwise_scales, quantized_data_uint8) matching TorchAO's convention.
+    """
+    assert _mslk_available, (
+        "mslk is required for NVFP4 triton quantization. "
+        "Install from https://github.com/pytorch/MSLK"
+    )
+    from mslk.quantize.triton.fp4_quantize import (
+        triton_quantize_nvfp4 as _mslk_triton_quantize_nvfp4,
+    )
+
+    data_lp, blockwise_scales = _mslk_triton_quantize_nvfp4(x, global_scale)
+    return blockwise_scales, data_lp.view(torch.uint8)
+
+
+@_mslk_quantize_nvfp4_custom_op.register_fake
+def _(x, global_scale):
+    # Mirror the reshape logic from the real MSLK kernel
+    orig_leading_dims, orig_N = x.shape[:-2], x.shape[-1]
+    x_2d = x.reshape(-1, orig_N)
+    M, N = x_2d.shape
+
+    num_scales = N // 16
+    n_row_blocks = triton.cdiv(M, 128)
+    n_col_blocks = triton.cdiv(num_scales, 4)
+    padded_rows = n_row_blocks * 128
+    padded_cols = n_col_blocks * 4
+
+    scales = x.new_empty(padded_rows, padded_cols, dtype=torch.float8_e4m3fn)
+    xq = x.new_empty(M, N // 2, dtype=torch.uint8)
+
+    # Reshape back to match original leading dims
+    scales = scales.view(*orig_leading_dims, -1, padded_cols)
+    xq = xq.view(*orig_leading_dims, -1, N // 2)
+    return scales, xq
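
The wrapper's only real work is converting between the two scale conventions named in the docstrings. A worked check with a concrete amax, using the constants F8E4M3_MAX = 448 and F4_E2M1_MAX = 6 from the NVFP4 format:

```python
import torch

F8E4M3_MAX = 448.0
F4_E2M1_MAX = 6.0

amax = torch.tensor(100.0)
# TorchAO convention, from the wrapper's docstring:
per_tensor_scale = amax / (F8E4M3_MAX * F4_E2M1_MAX)  # ~0.0372
# MSLK convention, as computed in mslk_quantize_nvfp4:
mslk_global_scale = per_tensor_scale.reciprocal()     # ~26.88

# The two conventions are exact reciprocals of each other.
assert torch.allclose(mslk_global_scale.reciprocal(), per_tensor_scale)
```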
15 changes: 9 additions & 6 deletions torchao/prototype/mx_formats/nvfp4_tensor.py

@@ -15,8 +15,8 @@
 from torchao.prototype.mx_formats.kernels import (
     f4_unpacked_to_f32,
     f32_to_f4_unpacked,
+    mslk_quantize_nvfp4,
     pack_uint4,
-    triton_quantize_nvfp4,
     unpack_uint4,
 )
 from torchao.prototype.mx_formats.mx_tensor import (

@@ -155,7 +155,10 @@ def to_nvfp4(
             assert K % 16 == 0, (
                 f"Triton kernel requires K (dim -1) to be divisible by 16, got {K}"
             )
-            blockwise_scales, data_lp = triton_quantize_nvfp4(data_hp, per_tensor_scale)
+            assert per_tensor_scale is not None, (
+                "Triton kernel requires per_tensor_scale"
+            )
+            blockwise_scales, data_lp = mslk_quantize_nvfp4(data_hp, per_tensor_scale)
         else:
             blockwise_scales, data_lp = nvfp4_quantize(
                 data_hp, block_size, per_tensor_scale

@@ -699,10 +702,10 @@ def nvfp4_quantize(
         scaled_block_scales, min=E4M3_EPS, max=F8E4M3_MAX
     ).to(torch.float8_e4m3fn)
     scaled_block_scales_fp32 = scaled_block_scales_fp8.to(torch.float32)
-    # We "temporarily" dequant the scaled_block_scales_fp32 to get the per_tensor_scale
-    # To apply to data
-    total_scale = per_tensor_scale * scaled_block_scales_fp32
-    data_scaled = data_hp / total_scale.unsqueeze(-1)
+    # Multiply by reciprocal of combined scale instead of dividing,
+    # to match the MSLK triton kernel numerics: x * (global_scale / fp8_scale)
+    reciprocal_scale = (1.0 / per_tensor_scale) / scaled_block_scales_fp32
+    data_scaled = data_hp * reciprocal_scale.unsqueeze(-1)
     out_scales = scaled_block_scales_fp8
 
     data_scaled = torch.clamp(data_scaled, -F4_E2M1_MAX, F4_E2M1_MAX)
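
The new comment explains that the reference now multiplies by a reciprocal to bit-match the MSLK kernel: in float32, dividing by a product and multiplying by a chained reciprocal round differently in the last ulp. A small demonstration of the kind of discrepancy the change eliminates (stand-in values, not NVFP4 scales):

```python
import torch

torch.manual_seed(0)
x = torch.randn(1024, dtype=torch.float32)
a = torch.rand(1024, dtype=torch.float32) + 0.5  # stand-in per-tensor scale
b = torch.rand(1024, dtype=torch.float32) + 0.5  # stand-in blockwise scales

divided = x / (a * b)             # old reference ordering
recip_mult = x * ((1.0 / a) / b)  # new ordering, mirrors the kernel

# Intermediate roundings differ, so results typically disagree by an ulp.
print(torch.equal(divided, recip_mult))    # typically False
print((divided - recip_mult).abs().max())  # tiny but nonzero
```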