Enable blockwise FP8 dense training kernels on ROCm #4036

Open

brucechanglongxu wants to merge 3 commits into pytorch:main from brucechanglongxu:rocm-blockwise-fp8-dense-enablement

Conversation

@brucechanglongxu (Contributor) commented Mar 10, 2026

The blockwise FP8 quantization kernels in torchao/prototype/blockwise_fp8_training/kernels.py had hardcoded FP8 max values (448.0) and dtype references to float8_e4m3fn, which prevented them from running on AMD GPUs that use float8_e4m3fnuz. This is the same class of issue that was fixed for the MoE path in #3996.

This PR parameterizes FP8_MAX as a tl.constexpr in the 5 Triton quantization kernels, updates the wrapper functions and reference implementations to derive limits from torch.finfo(dtype), and defaults the dtype to e4m3_dtype (platform-aware). The dtype assertions are widened to accept both float8_e4m3fn and float8_e4m3fnuz. Test skip guards are updated to use get_device_capability()[0] >= 9, which covers both NVIDIA SM >= 9.0 and ROCm gfx9xx+, and the @skip_if_rocm decorators on the quantization kernel tests are removed. The Float8BlockwiseLinear init guard is also updated.
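For reference, the per-dtype limits come straight out of torch.finfo, so nothing needs to hardcode 448.0. A minimal sketch of the pattern (not the torchao code itself; e4m3_dtype below is a local stand-in for torchao's platform-aware helper):

```python
import torch

# Pick the platform-appropriate e4m3 variant: ROCm builds expose
# torch.version.hip, and AMD GPUs use the fnuz encoding.
e4m3_dtype = (
    torch.float8_e4m3fnuz if torch.version.hip is not None else torch.float8_e4m3fn
)

def fp8_max_for(dtype: torch.dtype = e4m3_dtype) -> float:
    # Widened assertion: accept both e4m3 variants.
    assert dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz)
    # This is the value that gets passed to the kernels as FP8_MAX: tl.constexpr.
    return torch.finfo(dtype).max

print(fp8_max_for(torch.float8_e4m3fn))    # 448.0 (NVIDIA)
print(fp8_max_for(torch.float8_e4m3fnuz))  # 240.0 (ROCm)
```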

Addresses blockwise_fp8_training entries in #3666.

Tested on MI300X (gfx942) with float8_e4m3fnuz (FP8 max = 240.0):

Quantization kernel correctness (Triton vs PyTorch reference, block_size=128):

  M= 4096 K= 1024: PASS  (act_quant_lhs=ok  act_quant_rhs=ok  weight_quant_rhs=ok)
  M= 4096 K= 4096: PASS  (act_quant_lhs=ok  act_quant_rhs=ok  weight_quant_rhs=ok)
  M= 2048 K= 8192: PASS  (act_quant_lhs=ok  act_quant_rhs=ok  weight_quant_rhs=ok)

GEMM correctness (FP8 Triton GEMM vs BF16 matmul, SQNR):

  M=    2 N=  512 K=  128: PASS  SQNR=28.8 dB
  M=    2 N= 5120 K= 1280: PASS  SQNR=28.6 dB
  M=    4 N= 3584 K=  640: PASS  SQNR=28.6 dB
  M=  128 N= 4096 K= 4096: PASS  SQNR=28.7 dB
  M=  512 N= 4096 K= 4096: PASS  SQNR=28.7 dB
  M= 2048 N= 4096 K= 4096: PASS  SQNR=28.7 dB
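SQNR is reported assuming the standard definition (reference signal power over quantization-error power, in dB); a minimal sketch of the metric, which may differ in detail from the test helper actually used:

```python
import torch

def sqnr_db(ref: torch.Tensor, test: torch.Tensor) -> float:
    # Signal-to-quantization-noise ratio: power of the BF16 reference
    # over the power of the FP8-vs-reference error, in decibels.
    ref, test = ref.float(), test.float()
    signal = torch.mean(ref ** 2)
    noise = torch.mean((ref - test) ** 2)
    return (10 * torch.log10(signal / noise)).item()
```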

GEMM kernel performance (FP8 Triton vs BF16 hipBLAS):

     M      N      K    FP8 (ms)  BF16 (ms)  FP8 TFLOPS  BF16 TFLOPS
   128   4096   4096       0.079      0.021       54.6       199.8
   512   4096   4096       0.078      0.043      220.3       396.2
  2048   4096   4096       0.234      0.134      294.1       513.7
  4096   4096   4096       0.459      0.243      299.3       565.8
  8192   4096   4096       0.898      0.460      306.2       597.7
  2048   8192   4096       0.303      0.260      453.4       529.5
  4096   8192   4096       0.584      0.463      470.5       594.0
  4096   5120   5120       0.409      0.361      525.0       594.8

The FP8 Triton GEMM doesn't outperform hipBLAS BF16 on raw throughput — hipBLAS is heavily tuned for MI300X. The value here is that the kernels now work correctly on ROCm, and the FP8 path provides memory/bandwidth savings during training.

pytorch-bot Bot commented Mar 10, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4036


meta-cla Bot added the CLA Signed label Mar 10, 2026
@brucechanglongxu (Contributor, Author)

cc @danielvegamyhre @BowenBao

@danielvegamyhre (Contributor)

@brucechanglongxu there are some merge conflicts to resolve

@danielvegamyhre (Contributor)

@brucechanglongxu do you still plan to land this?

danielvegamyhre self-requested a review on March 18, 2026 01:05
brucechanglongxu force-pushed the rocm-blockwise-fp8-dense-enablement branch from a95b462 to 5b3a527 on March 18, 2026 03:53
@brucechanglongxu (Contributor, Author)

Yes! Just rebased on main and resolved the merge conflicts. Ready for review.

@danielvegamyhre (Contributor)

@brucechanglongxu the blockwise unit tests fail in CI on ROCm (#4104); can we make sure this PR addresses that?

@brucechanglongxu (Contributor, Author)

@danielvegamyhre @vkuzo This enables blockwise FP8 dense training kernels on ROCm, complementing #3996 (MoE path, already merged). Parameterizes FP8_MAX across the dense training Triton kernels and wrappers. Would appreciate a review when you get a chance.

@danielvegamyhre (Contributor)

@brucechanglongxu the ROCm CI on this PR is failing; can you take a look?

pytorch-bot Bot removed the ciflow/rocm label Mar 18, 2026
@brucechanglongxu (Contributor, Author) commented Mar 18, 2026

@danielvegamyhre Investigated the ROCm CI failures and found a real bug in the quantization kernel.

Root cause: triton_fp8_blockwise_act_quant_lhs_kernel writes reciprocal scales without masking out-of-bounds rows. When the autotuned NUM_GROUPS exceeds M (the number of rows), the unmasked writes to the column-major scale tensor corrupt later K-block columns: in column-major layout, row index M aliases into row 0 of column 1, row M+1 into row 1 of column 1, and so on. This produces catastrophic SQNR (2-10 dB vs the expected 28+). The bug affects both NVIDIA and ROCm, but the ROCm autotuner tends to pick larger NUM_GROUPS for small-M shapes, so it surfaced there first.
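To make the aliasing concrete, a toy demonstration (assuming the (1, M) column-major strides described above, not the actual kernel code):

```python
import torch

# A column-major (M, num_k_blocks) scale tensor: strides (1, M) mean
# element (row, col) lives at linear index row + col * M in storage.
M, num_k_blocks = 4, 3
storage = torch.zeros(M * num_k_blocks)
scales = storage.as_strided((M, num_k_blocks), (1, M))

# An unmasked store at row index M (one past the last row) in column 0
# hits linear index M + 0 * M = M, which is exactly (row 0, col 1):
# the out-of-bounds row silently corrupts the next K-block column.
storage[M] = 99.0    # what the unmasked write effectively does
print(scales[0, 1])  # tensor(99.)
```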

Fix: added a mask (m_offs < M) to the scale store — a one-line change in kernels.py. All 28 blockwise kernel tests now pass on MI300X, including the previously failing test_triton_fp8_gemm_1x128_128x128 parametrizations.
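For illustration, the fix pattern looks roughly like this (a hedged sketch, not the actual torchao kernel; names like m_offs, NUM_GROUPS, and the stride arguments mirror the description above but are illustrative):

```python
import triton
import triton.language as tl

@triton.jit
def store_reciprocal_scales(
    amax_ptr, s_ptr, M, k_block,
    stride_sm, stride_sk,
    FP8_MAX: tl.constexpr, NUM_GROUPS: tl.constexpr,
):
    pid = tl.program_id(0)
    m_offs = pid * NUM_GROUPS + tl.arange(0, NUM_GROUPS)
    row_mask = m_offs < M
    # Per-row absolute maxima, loaded with the same row mask.
    amax = tl.load(amax_ptr + m_offs, mask=row_mask, other=1.0)
    scale = FP8_MAX / amax
    s_offs = m_offs * stride_sm + k_block * stride_sk
    # The one-line fix: without mask=row_mask, rows >= M alias into
    # later K-block columns of the column-major scale tensor.
    tl.store(s_ptr + s_offs, 1.0 / scale, mask=row_mask)
```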

Separately, test_blockwise_linear is skipped on ROCm because Float8BlockwiseLinear calls torch._scaled_mm with 128x128 blockwise scales, which requires CUDA 12.9+.
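The skip guard amounts to something like this (a sketch; the real test uses torchao's own helpers):

```python
import pytest
import torch

# torch._scaled_mm with 128x128 blockwise scales is gated on CUDA 12.9+,
# so the linear-module test is skipped on ROCm builds for now.
@pytest.mark.skipif(
    torch.version.hip is not None,
    reason="torch._scaled_mm with 128x128 blockwise scales requires CUDA 12.9+",
)
def test_blockwise_linear():
    ...
```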

Pre-existing failures unrelated to this PR:

  • test_sparse_api::test_sparse — hipSPARSELt not available
  • test_affine_quantized_tensor_parallel — float8_tensor.py IndexError (tracked in #4061, "[ROCm] Fix ROCm CI failures")

brucechanglongxu force-pushed the rocm-blockwise-fp8-dense-enablement branch from 9297779 to aae12f4 on March 18, 2026 21:11

Introduce rocm_device_capability() and is_rocm_gpu_at_least() to provide
structured (major, minor) capability tags for AMD GPUs, analogous to
is_sm_at_least_*() for NVIDIA. Parses GCN arch strings (gfx942 -> (9,4),
gfx950 -> (9,5), gfx1200 -> (12,0)) following the convention established
by vLLM (vllm/platforms/rocm.py).

Also adds is_MI325X() and is_MI350() to __all__, fixes the is_Navi4()
operator precedence bug, and adds type annotations to all ROCm helpers.
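A minimal sketch of that parsing convention (illustrative, not the exact helper; the real one reads the arch string from the ROCm device properties):

```python
import re
from typing import Optional, Tuple

def rocm_device_capability(gcn_arch: str) -> Optional[Tuple[int, int]]:
    # "gfx942" -> (9, 4), "gfx950" -> (9, 5), "gfx1200" -> (12, 0),
    # following vLLM's vllm/platforms/rocm.py convention: the last two
    # digits are minor/stepping, everything before is the major version.
    m = re.match(r"gfx(\d{3,})", gcn_arch)
    if m is None:
        return None
    digits = m.group(1)
    return int(digits[:-2]), int(digits[-2])

assert rocm_device_capability("gfx942") == (9, 4)
assert rocm_device_capability("gfx950") == (9, 5)
assert rocm_device_capability("gfx1200") == (12, 0)
```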
The blockwise FP8 quantization kernels and GEMM ops in
torchao/prototype/blockwise_fp8_training had hardcoded FP8 max values
(448.0) and dtype references to float8_e4m3fn, which prevented them
from running on AMD GPUs that use float8_e4m3fnuz.

This commit parameterizes the FP8_MAX value in all 5 Triton
quantization kernels, updates the 5 wrapper functions and 3 reference
implementations to derive limits from torch.finfo(dtype), and defaults
the dtype to e4m3_dtype (platform-aware: e4m3fn on NVIDIA, e4m3fnuz on
ROCm). Test skip guards are updated to use a device-capability check
that works on both NVIDIA SM >= 9.0 and ROCm, and the @skip_if_rocm
decorators on the quantization kernel tests are removed.

Addresses the blockwise_fp8_training entries in pytorch#3666.

The triton_fp8_blockwise_act_quant_lhs_kernel wrote reciprocal scales
without masking out-of-bounds rows. When the autotuned NUM_GROUPS
exceeded M (number of rows), out-of-bounds writes to the column-major
scale tensor corrupted later K-block columns, producing catastrophic
SQNR (2-10 vs expected 28+). This manifested on ROCm where the
autotuner picked NUM_GROUPS > M for small-M shapes.

Also skip test_blockwise_linear on ROCm since Float8BlockwiseLinear
calls torch._scaled_mm with 128x128 blockwise scales which requires
CUDA 12.9+.

Addresses pytorch#4104

brucechanglongxu force-pushed the rocm-blockwise-fp8-dense-enablement branch from aae12f4 to f8b8a15 on April 1, 2026 21:18
@brucechanglongxu (Contributor, Author)

@vkuzo @danielvegamyhre This enables blockwise FP8 dense training kernels on ROCm. I found and fixed a real bug in the quantization kernel (unmasked out-of-bounds scale writes) during ROCm CI validation. Rebased on latest main, mergeable and ready for review.

@jerryzh168 (Contributor)

btw I'm going to remove the module: rocm label; please use topic: rocm instead.

module: rocm is a release-related label, and we want to keep that list short for now: https://github.com/pytorch/ao/labels?q=module

@brucechanglongxu (Contributor, Author)

@jerryzh168 @danielvegamyhre @vkuzo Blockwise FP8 dense training on ROCm (includes the kernel fix). Hoping to land by end of week; a review or merge would be appreciated.

Labels: CLA Signed, float8, topic: rocm