Highlights
We are excited to announce the 0.17 release of torchao! This release adds support for CuteDSL MXFP8 MoE kernels, per-head FP8 quantized low precision attention, ABI stability, and more!
CuteDSL MXFP8 MoE Kernels
We added a new CuteDSL MXFP8 quantization kernel for 3D expert weights that writes scale factors directly in the blocked layout required by tensor cores: #4090
- Used for scaling along dim1 in the backward pass of MoE training with grouped GEMMs.
- ~12% speedup over the previous two-kernel "quantize, then transform scale layout" approach!
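For reference, MXFP8 groups elements into blocks of 32 that share a power-of-two (E8M0) scale. Below is a minimal emulated sketch of blockwise quantization of 3D expert weights along dim1 (numpy, for illustration only; this is not the CuteDSL kernel, and the blocked scale layout for tensor cores is not modeled):

```python
import numpy as np

BLOCK = 32          # MX block size
F8_E4M3_MAX = 448.0 # max representable magnitude in float8 e4m3

def mxfp8_quantize_dim1(w: np.ndarray):
    """Emulated MXFP8 blockwise quantization of (experts, rows, cols)
    weights along dim1: one power-of-two scale per 32-element block."""
    e, r, c = w.shape
    assert r % BLOCK == 0
    blocks = w.reshape(e, r // BLOCK, BLOCK, c)
    amax = np.abs(blocks).max(axis=2, keepdims=True)
    # E8M0-style scale: round the per-block scale up to a power of two
    scale = 2.0 ** np.ceil(np.log2(np.maximum(amax, 1e-30) / F8_E4M3_MAX))
    q = np.clip(blocks / scale, -F8_E4M3_MAX, F8_E4M3_MAX)
    return q.reshape(e, r, c), scale.squeeze(2)

w = np.random.randn(4, 64, 16).astype(np.float32)
q, s = mxfp8_quantize_dim1(w)
print(q.shape, s.shape)  # (4, 64, 16) (4, 2, 16)
```

The real kernel fuses this quantization with writing the scales straight into the blocked tensor-core layout, which is where the single-kernel speedup comes from.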
Per-Head FP8 Quantized Low Precision Attention
We added a new API for per-head FP8 quantized attention with FA3 as the backend (#3959, #3857).
- Users can either use the elementary blocks as direct replacements for `F.scaled_dot_product_attention` or use the high-level wrapper, which replaces all `F.scaled_dot_product_attention` calls within a module with the low precision attention variant.
- Running torch.compile on a wrapped module enables RoPE fusion where appropriate.
- Results show a 1.84x speedup on Wan2.1-T2V-1.3B, a 1.23x speedup on LLaMA 3 prefill at long sequence length (131k), and a 1.07x speedup on flux.1-schnell at 2048x2048 image size.
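Conceptually, per-head FP8 quantization computes one scale per attention head from that head's amax and maps values into the float8 e4m3 range before attention. Here is a rough numpy emulation of the quantize step (an illustrative sketch under assumed shapes, not the torchao/FA3 implementation):

```python
import numpy as np

F8_E4M3_MAX = 448.0  # max representable magnitude in float8 e4m3

def quant_per_head(x: np.ndarray):
    """Emulated per-head fp8 quantization for (batch, heads, seq, dim):
    one scale per head, derived from that head's absolute maximum."""
    amax = np.abs(x).max(axis=(2, 3), keepdims=True)   # per-head amax
    scale = np.maximum(amax, 1e-12) / F8_E4M3_MAX
    q = np.clip(x / scale, -F8_E4M3_MAX, F8_E4M3_MAX)  # fp8-range values
    return q, scale

q_in = np.random.randn(2, 8, 128, 64).astype(np.float32)
q8, s = quant_per_head(q_in)
# dequantization recovers the input up to fp8 rounding (not modeled here)
assert np.allclose(q8 * s, q_in)
```

In the real path, q/k/v are quantized this way and the per-head scales are passed to the FA3 kernel, which consumes the fp8 tensors directly.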
Example Usage of Direct Replacement:
from torchao.prototype.attention.fp8_fa3 import fp8_fa3_sdpa, fp8_fa3_rope_sdpa
out = fp8_fa3_sdpa(q, k, v)
Example Usage of Wrapper:
from torchao.prototype.attention import (
AttentionBackend,
LowPrecisionAttentionConfig,
apply_low_precision_attention,
)
# Instantiate any nn.Module()
model = MyModel()
# Simple SDPA replacement
config = LowPrecisionAttentionConfig(backend=AttentionBackend.FP8_FA3)
model = apply_low_precision_attention(model, config)
# Flash activation is handled internally by the wrapper
output = model(inputs)
# Torch.compile will enable rope fusion
model = torch.compile(model)
PyTorch ABI stability
As of #3516, torchao is now ABI stable for all CUDA kernels! This means that users running torch 2.11+ can access torchao's CUDA kernels without upgrading torch for each new torchao version. This applies to the current and all future torchao releases (0.17.0+). Note that Python-only API compatibility is unchanged: we support the latest three torch minor versions.
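The compatibility rule can be sketched as a simple version check (a hypothetical helper for illustration, not a torchao API; assumes plain `major.minor.patch` version strings):

```python
# Hypothetical helper (not a torchao API): decide whether torchao's
# ABI-stable C++/CUDA extensions can be loaded with a given torch.
def cpp_extensions_compatible(torch_version: str, torchao_version: str) -> bool:
    t = tuple(int(p) for p in torch_version.split(".")[:2])
    a = tuple(int(p) for p in torchao_version.split(".")[:2])
    # Rule from this release: any torchao 0.17+ works with any torch 2.11+.
    return a >= (0, 17) and t >= (2, 11)

print(cpp_extensions_compatible("2.11.0", "0.17.0"))  # True
print(cpp_extensions_compatible("2.10.0", "0.16.0"))  # False
```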
Before:
# Compatible versions for cpp extensions:
torchao 0.16.0 + torch 2.10.0
torchao 0.15.0 + torch 2.9.1
torchao 0.14.1 + torch 2.9.0
# Compatible versions for python-only API
# 3 most recent torch versions:
torchao 0.16.0 + torch 2.10.0, 2.9.1, 2.8.0
torchao 0.15.0 + torch 2.9.1, 2.8.0, 2.7.1
torchao 0.14.0 + torch 2.9.0, 2.8.0, 2.7.0
After:
# Compatible versions for cpp extensions (all future releases):
torchao 0.17.0 onwards + any torch 2.11+
# Compatible versions for python-only API
# 3 most recent torch versions (same as before)
torchao 0.17.0 + torch 2.11.0, 2.10.0, 2.9.1
Deprecations
- Deprecate AQT and related classes (#4074)
We plan to remove these classes in a future release, so we are deprecating them in advance.
The following classes now emit deprecation warnings:
Base classes (torchao/dtypes/utils.py):
1. Layout
2. PlainLayout
3. AQTTensorImpl
torchao/dtypes/affine_quantized_tensor.py:
4. AffineQuantizedTensor
torchao/dtypes/floatx/:
5. Float8Layout
6. Float8AQTTensorImpl
7. CutlassSemiSparseLayout
8. CutlassSemiSparseTensorImpl
torchao/dtypes/uintx/:
9. TensorCoreTiledLayout
10. TensorCoreTiledAQTTensorImpl
11. SemiSparseLayout
12. SemiSparseAQTTensorImpl
13. Int4CPULayout
14. Int4CPUAQTTensorImpl
15. Int4XPULayout
16. Int4XPUAQTTensorImpl
17. QDQLayout
18. QDQTensorImpl
19. PlainAQTTensorImpl
20. PackedLinearInt8DynamicActivationIntxWeightLayout
21. PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl
- Deprecate TorchAODType (#4100)
This was added for torch 2.5 as a placeholder for torch.int4 and related sub-byte dtypes. All torch versions we now support include these dtypes natively, so this class is no longer needed and should not be used. We will remove it in the next release (0.18.0).
Before:
TorchAODType.INT1
TorchAODType.INT2
TorchAODType.INT3
TorchAODType.INT4
TorchAODType.INT5
TorchAODType.INT6
TorchAODType.INT7
After:
torch.int1
torch.int2
torch.int3
torch.int4
torch.int5
torch.int6
torch.int7
- Delete deprecated GemlitePackedLayout and GemliteUIntXWeightOnlyConfig (#4144)
GemliteUIntXWeightOnlyConfig is now replaced by torchao.prototype.UIntxWeightOnlyConfig and Int8DynamicActivationUIntxWeightConfig:
# 0.16:
GemliteUIntXWeightOnlyConfig(group_size, bit_width, packing_bitwidth, mode="weight_only")
GemliteUIntXWeightOnlyConfig(group_size, bit_width, packing_bitwidth, mode="dynamic")
# 0.17
UIntxWeightOnlyConfig(group_size, bit_width, packing_bitwidth)
Int8DynamicActivationUIntxWeightConfig(group_size, bit_width, packing_bitwidth)
Core
- Clean up arglist for choose_qparams_affine_with_min_max (#3808)
- Extend TorchAOBaseTensor docstring with subclassing and safetensors docs (#3846)
- Fix FqnToConfig module skipping for _default and regex modules (#3877)
- Add mxfp8 and nvfp4 support to safetensors (#3668)
- [Prototype] Add pruning-aware training in torchao.prototype.pat (#3429)
- Rocm: scaled_grouped_mm support gfx942 fp8 data type (#3955)
- Update b200 peak memory bandwidth (#4002)
- [Prototype] quant_logger tool for logging weights and activations (#3987)
- [ROCM] Float8 deepseekv3_671b IntOverflow in triton kernels during training (#4016)
- [ROCm][INT4] Configurable ntile size for TilePacked format (#3834)
- Fix Android ARM64 build for torchao lowbit kernels (#4029)
- [mxfp8] fix exp2f_rcp to handle special case of unbiased exponent of 127 (#4117)
- Remove version compatibility table (#4154)
- Suppress deprecation warnings from #4074 and #4100 (#4155)
- [xpu][test] Skip WIP config for Intel GPU in test_safetensors_support.py and test_x86inductor_fusion.py (#4049)
Training
- [mxfp8 training] 128b alignment for CUTensormap to support CUDA 13.0+ (#3837)
- [mxfp8 moe training][docs] add tutorial for training with MXFP8 expert parallel (#3752)
- Make _emulated_mxfp8_scaled_grouped_mm_2d_2d torch.compile compatible (#3906)
- [moe training] refactor configs, recipes; support converting linears + grouped gemms in a single quantize_() call (#3862)
- Use relaxed memory ordering for Triton atomics on AMDGPU. (#3945)
- Fix backward return count mismatch in _Float8GroupedMM (#3956)
- Expand Triton autotune configs for MoE FP8 kernels to improve AMD GPU performance (#3952)
- Optimize FP8 colwise scales kernel for AMD GPUs in MoE backward pass (#3972)
- TrainingWeightWrapperTensor base class; subclasses for FP8/MXFP8 with grouped_mm and linear overrides (#3968)
- [mx] Fix: pass scaling_mode parameter to triton_to_mxfp8_dim1 (#3686)
- Revert expanded MoE FP8 autotune configs that regress DeepSeek V3 shapes (#4024)
- [mxfp8 moe training] add cuda kernel for per group padding (#3998)
- Enable blockwise FP8 training kernels on AMD GPUs (MI300/MI350) (#3996)
- [mxfp8 training] cuda kernel for unpadding token groups (#4021)
- [ROCm] bring some mxfp8 quantization unit test back (#3628)
- [moe training] apply per group padding to fp8 grouped mm (#4045)
- [training] skip Dtensor/TP integration test pending solution (#4059)
- [mxfp8 training] clean up torch reference impl for token group padding (#4063)
- Add pre-quantized activation support to MXFP8 grouped GEMM (_to_mxfp8_then_scaled_grouped_mm) (#3961)
- [mxfp8 moe training] _permute_bf16 -> permute_and_pad (#4083)
- [mxfp8 moe training] update readme (#4084)
- [moe training] default pad_token_groups_for_grouped_mm=False (#4080)
- [TP] reorder MXFP8 wrapper over DTensor (#4010)
- [ROCm] Fix ROCm CI failures (#4061)
- Add CuTeDSL kernel for 3D tensor quantization to MXFP8 (#4090)
- Add bandwidth benchmarking script for fp8 quant blockwise (#4128)
- [mxfp8 moe training] remove unnecessary memory fences in cutedsl kernel; unify redundant conditionals (#4129)
- [moe training] Optimize FP8 MoE backward pass: fused colwise kernel + AMD tuning (#4069)
- [mxfp8 training] require torch nightly and cuda 12.8+ (#4141)
- Clean up unused rocm references in test_training.py (#4170)
- [mxfp8 training] add cutedsl kernel for mxfp8 quantization along dim0 (#4156)
Inference
- Add aten.select op support to int4_packed_to_4d_tensor (#3874)
- Separately control the activation quantization granularity (#3524)
- Avoid MSLK on NVIDIA B200/GB200 for per-tensor scaled weights (fallback to TORCH) (#3786)
- Add plain int4_packing_format support for Float8DynamicActivationInt4WeightConfig (#3714)
- Move mx documentation from README.md to docs site (#3915)
- Land 3894, 3887, 3884, 3883 into main from base branch (#3920)
- Fix up nvfp4 roofline sweep script (#3927)
- Add asymmetric support for Int8Tensor + SmoothQuant (#3900)
- Small fix for inference roofline model (#3990)
- Update inference.md with more context on roofline model (#3991)
- Add parameter_name support to _mx_inference_linear_transform (#3975)
- Add parameter_name support to _nvfp4_inference_linear_transform (#3976)
- Add FP8 FA3 low-precision attention with monkey-patch SDPA path (#3959)
- Added new API for low precision fp8 attention using FA3 (#3857)
- Fix nvfp4 static roofline and add nvfp4 dynamic roofline (#4030)
- Remove rope fusion option, do automatically on torch compile (#4055)
- [Prototype] Added prototype low precision attention API to the docs (#4056)
- Clean up flux-1.schnell benchmark and add to docs (#4072)
- Hook up mslk's to_nvfp4 kernel to torchao's inference nvfp4 workflows (#4031)
- Improve mslk docs (#4077)
- Delete old torchao to_nvfp4 triton kernel (#4078)
- Implement per-group quantization for Int8WeightOnlyConfig (#4018)
- [reland][xpu] INT8 quantization on Intel XPU (#3782)
PT2E Quantization
- Add support for torch._higher_order_ops.scan in pt2e quantization (#3882)
- Assign meta["val"] to new bias nodes in fold_bn_weights_into_conv_node (#3907)
- Add AffineFakeQuantize for groupwise quantization support (#3848)
- Add support for torch._higher_order_ops.while_loop in pt2e quantization (#3916)
- Add Linear+BN folding to fuse_conv_bn() (#3917)
- Fix deprecated TreeSpec CTOR (#4109)
New Contributors
- @mohammed-saalim made their first contribution in #3808
- @oelachqar made their first contribution in #3832
- @sxu made their first contribution in #3838
- @JacobSzwejbka made their first contribution in #3807
- @immerSIR made their first contribution in #3864
- @ArivunidhiA made their first contribution in #3869
- @lordaarush made their first contribution in #3786
- @lyprince made their first contribution in #3893
- @SS-JIA made their first contribution in #3917
- @wenchenvincent made their first contribution in #3945
- @brucechanglongxu made their first contribution in #3952
- @lizamd made their first contribution in #3972
- @ugolowic made their first contribution in #3686
- @alex-minooka made their first contribution in #4016
- @bbeckca made their first contribution in #4064
- @pianpwk made their first contribution in #4010
Full Changelog: v0.16.0...v0.17.0-rc1