Enable blockwise FP8 dense training kernels on ROCm #4036
brucechanglongxu wants to merge 3 commits into pytorch:main
Conversation
🔗 Helpful Links 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4036
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
@brucechanglongxu some merge conflicts to resolve
@brucechanglongxu do you still plan to land this?
force-pushed from a95b462 to 5b3a527
Yes! Just rebased on main and resolved the merge conflicts. Ready for review.
@brucechanglongxu the blockwise unit tests fail in CI on ROCm: #4104. Can we make sure this PR addresses this?
@danielvegamyhre @vkuzo This enables blockwise FP8 dense training kernels on ROCm, complementing #3996 (MoE path, already merged). Parameterizes FP8_MAX across the dense training Triton kernels and wrappers. Would appreciate a review when you get a chance.
@brucechanglongxu ROCm CI on this PR is failing, can you take a look?
@danielvegamyhre Investigated the ROCm CI failures and found a real bug in the quantization kernel. Root cause: `triton_fp8_blockwise_act_quant_lhs_kernel` wrote reciprocal scales without masking out-of-bounds rows, so when the autotuned `NUM_GROUPS` exceeded M the writes corrupted later K-block columns of the column-major scale tensor. Fix: added a mask on the reciprocal-scale store so out-of-bounds rows are never written. Separately, `test_blockwise_linear` is skipped on ROCm since `Float8BlockwiseLinear` calls `torch._scaled_mm` with 128x128 blockwise scales, which requires CUDA 12.9+. The remaining CI failures are pre-existing and unrelated to this PR.
force-pushed from 9297779 to aae12f4
Introduce rocm_device_capability() and is_rocm_gpu_at_least() to provide structured (major, minor) capability tags for AMD GPUs, analogous to is_sm_at_least_*() for NVIDIA. Parses GCN arch strings (gfx942 -> (9,4), gfx950 -> (9,5), gfx1200 -> (12,0)) following the convention established by vLLM (vllm/platforms/rocm.py). Also adds is_MI325X() and is_MI350() to __all__, fixes the is_Navi4() operator precedence bug, and adds type annotations to all ROCm helpers.
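For illustration, a minimal sketch of the arch-string parsing convention this commit describes; the function names come from the commit message, but the body below is an assumption rather than the merged implementation (letter-suffixed archs such as gfx90a are not handled here):

```python
import re
from typing import Optional, Tuple

import torch


def rocm_device_capability() -> Optional[Tuple[int, int]]:
    """Parse the GCN arch string into a (major, minor) capability tag.

    Illustrative sketch of the convention described above:
    gfx942 -> (9, 4), gfx950 -> (9, 5), gfx1200 -> (12, 0).
    Returns None when not running on a ROCm build.
    """
    if torch.version.hip is None or not torch.cuda.is_available():
        return None
    # e.g. "gfx942:sramecc+:xnack-" on MI300X
    arch = getattr(torch.cuda.get_device_properties(0), "gcnArchName", "")
    match = re.match(r"gfx(\d{3,})", arch)
    if match is None:
        return None
    digits = match.group(1)
    # Last digit is the stepping; the digit before it is the minor version.
    return int(digits[:-2]), int(digits[-2])


def is_rocm_gpu_at_least(major: int, minor: int = 0) -> bool:
    cap = rocm_device_capability()
    return cap is not None and cap >= (major, minor)
```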
The blockwise FP8 quantization kernels and GEMM ops in torchao/prototype/blockwise_fp8_training had hardcoded FP8 max values (448.0) and dtype references to float8_e4m3fn, which prevented them from running on AMD GPUs that use float8_e4m3fnuz. This commit parameterizes the FP8_MAX value in all 5 Triton quantization kernels, updates the 5 wrapper functions and 3 reference implementations to derive limits from torch.finfo(dtype), and defaults the dtype to e4m3_dtype (platform-aware: e4m3fn on NVIDIA, e4m3fnuz on ROCm). Test skip guards are updated to use a device-capability check that works on both NVIDIA SM >= 9.0 and ROCm, and the @skip_if_rocm decorators on the quantization kernel tests are removed. Addresses the blockwise_fp8_training entries in pytorch#3666.
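As a rough sketch of the change this commit describes (kernel and wrapper names below are illustrative, not the actual ones in kernels.py), the wrapper derives the limit from torch.finfo and passes it into the Triton kernel as a tl.constexpr instead of hardcoding 448.0:

```python
import torch
import triton
import triton.language as tl

# Platform-aware default, mirroring the e4m3_dtype convention described above:
# float8_e4m3fnuz on ROCm (max 240.0), float8_e4m3fn elsewhere (max 448.0).
FP8_DTYPE = torch.float8_e4m3fnuz if torch.version.hip else torch.float8_e4m3fn


@triton.jit
def quant_kernel_sketch(
    x_ptr, y_ptr, s_ptr, n_elem, BLOCK: tl.constexpr, FP8_MAX: tl.constexpr
):
    # Per-block absmax quantization with FP8_MAX as a compile-time constant.
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elem
    x = tl.load(x_ptr + offs, mask=mask, other=0.0).to(tl.float32)
    scale = tl.maximum(tl.max(tl.abs(x), axis=0), 1e-12) / FP8_MAX
    y = (x / scale).to(y_ptr.dtype.element_ty)
    tl.store(y_ptr + offs, y, mask=mask)
    tl.store(s_ptr + pid, scale)


def quantize_fp8_blockwise(x: torch.Tensor, dtype: torch.dtype = FP8_DTYPE):
    # Derive the representable maximum from the target dtype at call time.
    fp8_max = torch.finfo(dtype).max
    block = 128
    n_blocks = triton.cdiv(x.numel(), block)
    y = torch.empty_like(x, dtype=dtype)
    s = torch.empty(n_blocks, dtype=torch.float32, device=x.device)
    quant_kernel_sketch[(n_blocks,)](x, y, s, x.numel(), BLOCK=block, FP8_MAX=fp8_max)
    return y, s
```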
The triton_fp8_blockwise_act_quant_lhs_kernel wrote reciprocal scales without masking out-of-bounds rows. When the autotuned NUM_GROUPS exceeded M (number of rows), out-of-bounds writes to the column-major scale tensor corrupted later K-block columns, producing catastrophic SQNR (2-10 vs expected 28+). This manifested on ROCm where the autotuner picked NUM_GROUPS > M for small-M shapes. Also skip test_blockwise_linear on ROCm since Float8BlockwiseLinear calls torch._scaled_mm with 128x128 blockwise scales which requires CUDA 12.9+. Addresses pytorch#4104
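A reduced, hypothetical illustration of the class of bug and fix described here; the real kernel computes the reciprocal scales rather than using stand-in values, and its signature differs:

```python
import triton
import triton.language as tl


@triton.jit
def lhs_scale_store_sketch(s_ptr, M, k_block, NUM_GROUPS: tl.constexpr):
    # Heavily reduced illustration: each program writes NUM_GROUPS rows of one
    # column of the column-major (M x num_k_blocks) reciprocal-scale tensor.
    pid_m = tl.program_id(0)
    rows = pid_m * NUM_GROUPS + tl.arange(0, NUM_GROUPS)
    recip_scale = tl.full((NUM_GROUPS,), 1.0, dtype=tl.float32)  # stand-in values

    # Column-major layout: column k_block starts at s_ptr + k_block * M and its
    # rows are contiguous, so row r of that column lives at k_block * M + r.
    ptrs = s_ptr + k_block * M + rows

    # Buggy form (what the commit fixes): an unmasked store. When the autotuner
    # picks NUM_GROUPS > M, lanes with rows >= M spill into the storage of the
    # next K-block column and corrupt its scales:
    #   tl.store(ptrs, recip_scale)
    # Fixed form: mask out-of-bounds rows so they are never written.
    tl.store(ptrs, recip_scale, mask=rows < M)
```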
force-pushed from aae12f4 to f8b8a15
@vkuzo @danielvegamyhre This enables blockwise FP8 dense training kernels on ROCm. I found and fixed a real bug in the quantization kernel (unmasked out-of-bounds scale writes) during ROCm CI validation. Rebased on latest main, mergeable and ready for review.
btw I'm going to remove since
@jerryzh168 @danielvegamyhre @vkuzo Blockwise FP8 dense training on ROCm (includes kernel fix). Hoping to land by end of week — review or merge appreciated. |
The blockwise FP8 quantization kernels in `torchao/prototype/blockwise_fp8_training/kernels.py` had hardcoded FP8 max values (448.0) and dtype references to `float8_e4m3fn`, which prevented them from running on AMD GPUs that use `float8_e4m3fnuz`. The same class of issue was fixed for the MoE path in #3996.

This PR parameterizes `FP8_MAX` as a `tl.constexpr` in the 5 Triton quantization kernels, updates the wrapper functions and reference implementations to derive limits from `torch.finfo(dtype)`, and defaults the dtype to `e4m3_dtype` (platform-aware). The dtype assertions are widened to accept both `float8_e4m3fn` and `float8_e4m3fnuz`. Test skip guards are updated to use `get_device_capability()[0] >= 9`, which covers both NVIDIA SM >= 9.0 and ROCm GFX9xx+, and the `@skip_if_rocm` decorators on the quantization kernel tests are removed. The `Float8BlockwiseLinear` init guard is also updated.
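For illustration, a minimal sketch of such a skip guard, assuming plain torch.cuda.get_device_capability(); the helper actually used in the tests may differ:

```python
import pytest
import torch


def supports_blockwise_fp8() -> bool:
    # Illustrative helper: torch.cuda.get_device_capability() reports the SM
    # major version on NVIDIA (>= 9 means Hopper+) and the GFX major family on
    # ROCm builds (gfx942 -> 9), so one check covers both backends.
    return torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9


@pytest.mark.skipif(
    not supports_blockwise_fp8(), reason="requires SM >= 9.0 or ROCm gfx9xx+"
)
def test_quant_kernel_matches_reference():
    ...
```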
Addresses the `blockwise_fp8_training` entries in #3666.

Tested on MI300X (gfx942) with `float8_e4m3fnuz` (FP8 max = 240.0):

Quantization kernel correctness (Triton vs PyTorch reference, block_size=128):
GEMM correctness (FP8 Triton GEMM vs BF16 matmul, SQNR):
GEMM kernel performance (FP8 Triton vs BF16 hipBLAS):
The FP8 Triton GEMM doesn't outperform hipBLAS BF16 on raw throughput, since hipBLAS is heavily tuned for MI300X. The value here is that the kernels now work correctly on ROCm, and the FP8 path provides memory/bandwidth savings during training.
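For context, the SQNR metric quoted in the correctness results above can be computed in a few lines; a minimal sketch (the test suite's exact implementation may differ):

```python
import torch


def sqnr(reference: torch.Tensor, test: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB; higher means the quantized
    # result tracks the reference more closely (the masking bug above dropped
    # this from the expected ~28+ dB down to 2-10 dB).
    signal = reference.float().pow(2).mean()
    noise = (reference.float() - test.float()).pow(2).mean()
    return 10 * torch.log10(signal / noise)
```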