v0.16.0
Highlights
We are excited to announce the 0.16.0 release of torchao! This release adds MXFP8 MoE building blocks for training with expert parallelism. It also deprecates older versions of some configs and less-used quantization options to keep torchao leaner. We also revamped our docs page and README, and made progress toward making torchao ABI stable.
MXFP8 MoE Building Blocks for Training with Expert Parallelism
This release includes the following differentiable building blocks for MXFP8 MoE Training with expert parallelism:
- a2a_dispatch_mxfp8_fwd_hp_bwd: All-to-all token dispatch (MXFP8 forward pass, BF16 backward pass)
- permute_mxfp8_fwd_hp_bwd: Permute and pad tokens for MXFP8 computation (MXFP8 forward pass, BF16 backward pass)
- _to_mxfp8_then_scaled_grouped_mm: MXFP8 grouped GEMM for routed expert computation (new: optionally accepts pre-quantized inputs). Produces bfloat16 output.
- unpermute_hp_fwd_mxfp8_bwd: Unpermute tokens back to original order (BF16 forward pass, MXFP8 backward pass)
- a2a_combine_hp_fwd_mxfp8_bwd: All-to-all token combine (BF16 forward pass, MXFP8 backward pass). Note that the actual combine/aggregation op does not happen here. The name only indicates that this op is intended for the all-to-all immediately preceding the aggregation.
These autograd functions can be chained together to implement efficient MoE training with expert parallel comms and grouped GEMMs in MXFP8.
This approach achieves a 10% to 25% tokens/second speedup for DeepSeekV3 16B training:
- +10% tokens/second on single node 8xB200 with NVLink intra-node networking for inter-device communication.
- +25% tokens/second on multi-node B200 cluster with IB inter-node networking and NVLink intra-node networking.
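The permute and unpermute steps in the building-block list above can be illustrated with a minimal pure-Python sketch. It is not torchao's implementation, and the helper names permute_and_pad and unpermute are hypothetical. For illustration, it assumes that each expert's token group is padded to a multiple of the 32-element MX block size, so that the grouped GEMM sees block-aligned groups.

```python
from typing import List, Tuple

# Hypothetical illustration only, not torchao's permute/unpermute kernels.
# Assumption: each expert's group is padded to a multiple of the MX block size.
BLOCK = 32


def permute_and_pad(
    expert_ids: List[int], num_experts: int, block: int = BLOCK
) -> Tuple[List[int], List[int]]:
    """Sort tokens by expert and pad each expert group to a multiple of `block`.

    Returns (padded_index, group_sizes):
      padded_index[j] is the original token index placed at row j of the
      expert-sorted buffer, or -1 for a padding row;
      group_sizes[e] is the padded row count for expert e.
    """
    buckets: List[List[int]] = [[] for _ in range(num_experts)]
    for tok, e in enumerate(expert_ids):
        buckets[e].append(tok)  # keeps original token order within each expert
    padded_index: List[int] = []
    group_sizes: List[int] = []
    for toks in buckets:
        padded = -(-len(toks) // block) * block  # round up to a multiple of block
        padded_index.extend(toks + [-1] * (padded - len(toks)))
        group_sizes.append(padded)
    return padded_index, group_sizes


def unpermute(rows: list, padded_index: List[int], num_tokens: int) -> list:
    """Scatter expert-sorted rows back to original token order, dropping padding."""
    out = [None] * num_tokens
    for row, tok in zip(rows, padded_index):
        if tok >= 0:
            out[tok] = row
    return out
```

For example, permute_and_pad([1, 0, 1, 1], num_experts=2, block=4) places token 1 in expert 0's group and tokens 0, 2, 3 in expert 1's group, each padded to 4 rows. Padding rows carry no token and are dropped by unpermute.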
Deprecations
- Deprecate v1 of Float8WeightOnlyConfig, Float8DynamicActivationFloat8WeightConfig, Int8DynamicActivationIntxWeightConfig, IntxWeightOnlyConfig, and Int4WeightOnlyConfig (#3510, #3511, #3512, #3513)
# v0.15.0 - version 1 was available.
from torchao.quantization import (
    Float8WeightOnlyConfig,
    Float8DynamicActivationFloat8WeightConfig,
    Int8DynamicActivationIntxWeightConfig,
    IntxWeightOnlyConfig,
    Int4WeightOnlyConfig,
)
config = Float8WeightOnlyConfig(version=1, ...)
config = Float8DynamicActivationFloat8WeightConfig(version=1, ...)
config = Int8DynamicActivationIntxWeightConfig(version=1, ...)
config = IntxWeightOnlyConfig(version=1, ...)
config = Int4WeightOnlyConfig(version=1, ...)
# v0.16.0 - use version 2 (default). Using version 1 is no longer supported.
config = Float8WeightOnlyConfig(version=2, ...)
config = Float8DynamicActivationFloat8WeightConfig(version=2, ...)
config = Int8DynamicActivationIntxWeightConfig(version=2, ...)
config = IntxWeightOnlyConfig(version=2, ...)
config = Int4WeightOnlyConfig(version=2, ...)
- Move Int8DynamicActivationInt4WeightConfig, Int4DynamicActivationInt4WeightConfig, GemliteUIntXWeightOnlyConfig, Float8StaticActivationFloat8WeightConfig, UIntXWeightOnlyConfig, and FPXWeightOnlyConfig to prototype (#3491)
# v0.15.0
from torchao.quantization import (
    Int8DynamicActivationInt4WeightConfig,
    Int4DynamicActivationInt4WeightConfig,
    GemliteUIntXWeightOnlyConfig,
    Float8StaticActivationFloat8WeightConfig,
    UIntXWeightOnlyConfig,
    # removed in this release
    FPXWeightOnlyConfig,
)
# v0.16.0 - these workflows may be deleted in a future version
from torchao.prototype.quantization.quant_api import (
    Int8DynamicActivationInt4WeightConfig,
    GemliteUIntXWeightOnlyConfig,
    Float8StaticActivationFloat8WeightConfig,
    UIntXWeightOnlyConfig,
    # removed in this release
    FPXWeightOnlyConfig,
    Int4DynamicActivationInt4WeightConfig,
)
- Remove Int4DynamicActivationInt4WeightConfig, FPXWeightOnlyConfig, Float8DynamicActivationFloat8SemiSparseWeightConfig, and SRELUFloat8SemiSparseDynamicActivationFloat8WeightConfig (#3723, #3520, #3744)
# v0.15.0
config = Int4DynamicActivationInt4WeightConfig()
config = FPXWeightOnlyConfig(3, 2)
config = fpx_weight_only(3, 2) # deprecated alias
config = Float8DynamicActivationFloat8SemiSparseWeightConfig()
config = SRELUFloat8SemiSparseDynamicActivationFloat8WeightConfig()
quantize_(model, config)
# v0.16.0
# The configs are dropped. Please use torchao <= 0.15.0 to use these configs.
- Remove CutlassInt4PackedLayout, MarlinQQQLayout, and CutlassSemiSparseLayout (#3723, #3612, #3613, #3744)
# v0.15.0
## CutlassInt4PackedLayout
config = Int8DynamicActivationInt4WeightConfig(
    group_size=None,
    mapping_type=MappingType.SYMMETRIC,
    act_mapping_type=MappingType.SYMMETRIC,
    layout=CutlassInt4PackedLayout(),
)
## MarlinQQQLayout
config = Int8DynamicActivationInt4WeightConfig(layout=MarlinQQQLayout())
config = Int4DynamicActivationInt4WeightConfig(layout=MarlinQQQLayout())
## MarlinSparseLayout
apply_fake_sparsity(model)
config = Int4WeightOnlyConfig(layout=MarlinSparseLayout(), version=1)
config = Int4WeightOnlyConfig(int4_packing_format="marlin_sparse", version=2)
## CutlassSemiSparseLayout
config = Float8DynamicActivationFloat8SemiSparseWeightConfig(layout=CutlassSemiSparseLayout())
## quantizing the model with the config
quantize_(model, config)
# v0.16.0
# The config and layout options are dropped. Please use torchao <= 0.15.0 to use these layouts.
- Remove Old GPTQ Implementation (#3720)
# v0.15.0
from torchao.quantization import MultiTensorInputRecorder, Int4WeightOnlyGPTQQuantizer
model = get_model()
input_recorder = MultiTensorInputRecorder()
for i in range(calibration_limit):
    args = get_next_input()
    input_recorder(*args)
quantizer = Int4WeightOnlyGPTQQuantizer()
args = input_recorder.get_recorded_inputs()
quantizer.quantize(model, *args)
args = get_next_input()
out = model(*args)
# v0.16.0 - this functionality is deleted. We are starting over
# for GPTQ in https://github.com/pytorch/ao/tree/main/torchao/prototype/gptq
- Delete Old SmoothQuant Implementation (#3495)
# v0.15.0
from torchao.quantization.smoothquant import (
    swap_linear_with_smooth_fq_linear,
    smooth_fq_linear_to_inference,
)
swap_linear_with_smooth_fq_linear(model_copy, alpha=0.75)
model_copy(**encoded_input)
smooth_fq_linear_to_inference(model_copy)
# v0.16.0 - the functionality above is deleted. Use `torchao.prototype.smoothquant`
# instead (https://github.com/pytorch/ao/tree/main/torchao/prototype/smoothquant)
- Add deprecation warning to torchao.autoquant (#3741)
This functionality will be removed in a future release of torchao. Please see #3739 for context.
New Features
- Add INT8 Static Quantization Workflow (#3442)
- [Prototype] Add support for MXFP8 and MXFP4 QAT (#3644)
- [Prototype] MXFP8 MoE training
- Add wgrad_with_hp option for mxfp8 moe training (#3508)
- MXFP8 a2a_dispatch autograd function (#3579)
- MXFP8 token permute autograd func + triton kernels (#3580)
- MXFP8 unpermute autograd function (#3581)
- MXFP8 a2a_combine autograd function (#3582)
- Handle pre-quantized inputs/grads in forward/backward of _MXFP8GroupedMM autograd func (#3583)
- Add benchmark for e2e mxfp8 EP pipeline (#3585)
- Export wgrad_with_hp recipe to quantize_ api (#3611)
- Fallback cuda kernel for when input doesn't meet 2d TMA constraints (#3708)
- Add emulated mode (#3724)
- Add support for MXFP8 all-gather (#3435)
- Integrate cuda kernel for 'groups along M scale blocked layout' (#3556)
- Default to triton kernel for dim0 cast (#3560)
- _to_mxfp8_then_scaled_grouped_mm wrapper that accepts keyword args (#3561)
- Auto-select chunk_width in cuda blocked layout kernel (#3658)
- scaled_grouped_mm: support gfx942 fp8 data type (#3540)
- Add custom sharding for triton dim0 quant kernel (#3812)
- Bug fix and test updates related to new triton_calculate_scale param (#3522)
- Update bench script to use new cuda wrapper (#3562)
- Fix torch ref impl of SF blocked layout per group along K (#3603)
- Only use torch._dynamo.nonstrict_trace if it exists in the torch version (#3650)
- cuda blocked layout kernel handling for skinnier scale tensors (#3656)
- Fix bench script bug for dim0 mxfp8 rceil (#3665)
- Register constant with pytree (#3667)
- Update tensor subclass emulated param name (#3811)
- Add
Improvement
- Fix NVFP4 QAT convert path (#3450)
- Adding Int8 weight only and dynamic configs to safetensors (#3474)
- DTensor support for bfloat16 stochastic rounding (#3266)
- Remove unused exception parameter from multifeed/recommendation_platform/corpus/backfill/BackfillMain.cpp (#3445)
- Enable use of dinov2 models for offload benchmark_low_bit_adam for intel GPU (#3191)
- Migrate Float8SemiSparseTensor off of AQT (#3361)
- Support NVFP4 FP32 bias (#3525)
- Remove permute op for FP8 conv (#3533)
- Refine from_recipe_name to support mxfp8 on ROCm (#3620)
- Enable smoothquant for int8 static tensor (#3468)
- Allow non-tensor kwargs in prepare_pt2e (#3642)
- Add GPTQ to prototype (#3517)
- Support per tensor quantization for FP8 mslk KernelPreference (#3715)
- Add keepdim support for AffineQuantizedMinMaxObserver (#3748)
- Reduce memory consumption during sparse_2x4_cutlass_float8_tensor::dequantize (#3738)
- Make rowwise_scaled_linear_sparse_cutlass ABI stable (#3725)
- Add HQQ option in UIntxWeightOnlyConfig (#3829)
- Support RCEIL in triton_to_mxfp8_dim0 kernel with inline PTX for mxfp8 (#3498)
- Add pinned memory support for Int8Tensor (#3489)
- Add pinned memory support for Float8Tensor (#3526)
- Add Linux aarch64 wheels (#3359)
- Support FP8 output for scaled_embedding_bag for CPU (#3755)
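The RCEIL item above (#3498) concerns how the shared MXFP8 scale is rounded. In MXFP8, each block of 32 values shares one power-of-two (E8M0) scale. The sketch below is a pure-Python illustration based on our reading of the OCP MX spec (FLOOR) and the NVIDIA-style RCEIL convention. It is not torchao's kernel code. The names scale_floor, scale_rceil, and scaled_block are hypothetical, and the actual FP8 cast is omitted: only the division by the scale and saturation to the E4M3 range are shown.

```python
import math
from typing import List, Tuple

F8E4M3_MAX = 448.0  # largest finite float8_e4m3fn value
F8E4M3_EMAX = 8     # exponent of the largest normal E4M3 value (448 = 1.75 * 2**8)


def scale_floor(amax: float) -> float:
    """OCP MX-style shared scale: 2**(floor(log2(amax)) - emax)."""
    if amax == 0.0:
        return 1.0  # all-zero block: any scale works
    return 2.0 ** (math.floor(math.log2(amax)) - F8E4M3_EMAX)


def scale_rceil(amax: float) -> float:
    """RCEIL-style scale (our assumption): round amax / 448 up to a power of two."""
    if amax == 0.0:
        return 1.0
    return 2.0 ** math.ceil(math.log2(amax / F8E4M3_MAX))


def scaled_block(block: List[float], mode: str = "floor") -> Tuple[float, List[float]]:
    """Divide one block by its shared scale and saturate to +/-448.

    A real MX block has 32 elements; this sketch accepts any length.
    """
    amax = max(abs(v) for v in block)
    if mode == "floor":
        scale = scale_floor(amax)
    elif mode == "rceil":
        scale = scale_rceil(amax)
    else:
        raise ValueError(f"unknown mode: {mode}")
    return scale, [max(-F8E4M3_MAX, min(F8E4M3_MAX, v / scale)) for v in block]
```

With FLOOR, a block whose amax has a large mantissa (for example 500) is scaled past 448 and clamped. RCEIL picks the next power of two up, so the scaled amax never exceeds 448.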
Bug Fixes
- Fix NVFP4 QAT backward typo (#3478)
- Use python version agnostic binding for mxfp8 cuda kernels (#3471)
- Make FqnToConfig handle module swap configs (#3492)
- Fix NVFP4 QAT mixed precision (#3501)
- Fix regression from distributed op changes in pytorch (#3548)
- Update version compatibility to allow torch 2.11.0.dev (#3545)
- Fix fake fusion for convolutions without bias in pt2e quant flow (#3633)
- Fix Int8Tensor v2 to use float32 eps for better bfloat16 precision (#3664)
- Make torchao import not initialize cuda (#3676)
- Fix torchao circular imports for BUCK (#3816)
Performance
- Roofline quantized conv3d/2d layer (#3419)
- Replace custom mxfp4 gemm kernel with F.scaled_mm (#3675)
Documentation
- Update TorchAO PT2E QAT Doc Hyperlink (#3480)
- Update docs to reflect safetensors update (#3527)
- Update readme (#3532)
- Update readme with new dim1 cast info (#3563)
- Update quantization README.md (#3551)
- Add deprecation warning for Float8DynamicActivationFloat8SemiSparseWeightConfig (#3595)
- Fix incorrect description of Accuracy benchmarks (#3626)
- Update quick start tutorial (#3629)
- Allow doc-preview on workflow_dispatch event (#3662)
- Add torchao example for xpu (#3577)
- Update TorchAOBaseTensor docs (#3652)
- Add FqnToConfig to API reference quantization page (#3709)
- mxfp8 moe training
- Add tokens/sec table and speedups for e2e training (#3531)
- Update roofline and bench script to use new cuda kernel (#3565)
- Add kernel microbenchmarks table to readme (#3566)
- Readme updates: update table of contents, organization; remove outdated content (#3567)
- Update readme benchmark plots (#3647)
- Remove outdated example (#3750)
- Add MXFP8 expert parallel example (#3751)
- TorchAO docs page and README revamp
- File movement (#3743)
- Separate quick start (#3760)
- Make docs landing page be like the main README (#3761)
- Move workflow overview section from README to docs (#3762)
- Initial clean up of quantization API reference section (#3763)
- Move QAT readme to docs (#3764)
- Float8 training api reference (#3765)
- Fix KernelPreference docstring (#3766)
- Remove empty kernel section (#3767)
- Rename "eager quantization tutorials" to "tutorials" (#3769)
- Rename "developer notes" to "contributing" (#3770)
- Rename pt2e section to "pt2e quantization" (#3771)
- Make QAT readme point to docs (#3787)
- Improve workflows page (#3790)
- Move float8 training README.md to docs (#3791)
- Move quantization README.md to docs (#3792)
- Better page names for training and inference (#3796)
- Quantized training (#3804)
- Quantized inference (#3805)
- Make non-prototype inference workflow configs have docblocks (#3821)
Developers
- Add B300 specs for roofline analysis (#3640)
- Update KleidiAI dependency to 63205aa9 (#3484)
- Remove None in get_current_accelerator_device (#3634)
- Fixing TorchAOBaseTensor subclassing behavior (#3401)
- Fix doc-preview S3 path for workflow_dispatch runs (#3688)
- Deleted tensor_core_tiled_layout.cu (#3722)
New Contributors
- @avizon-aws made their first contribution in #3435
- @EduardDurech made their first contribution in #3266
- @WilliamZhang20 made their first contribution in #3480
- @GregoryComer made their first contribution in #3484
- @Stonepia made their first contribution in #3332
- @guangyey made their first contribution in #3318
- @AryanBagade made their first contribution in #3447
- @RuibinCheung made their first contribution in #3620
- @ZhaoqiongZ made their first contribution in #3577
- @billmguo made their first contribution in #3679
- @cthi made their first contribution in #3622
- @Erik-Lundell made their first contribution in #3642
- @xiaobochen-amd made their first contribution in #3540
- @pantha704 made their first contribution in #3709
- @MagellaX made their first contribution in #3635
- @Yuxingwang-intel made their first contribution in #3685
- @puneetmatharu made their first contribution in #3359
- @ved1beta made their first contribution in #3644
- @mergennachin made their first contribution in #3829
Full Changelog: v0.15.0...v0.16.0-rc1