Flydsl mxfp8 quantize #4357
Draft
zstreet87 wants to merge 11 commits into pytorch:main from …
Conversation
Move FLOOR scale derivation, FP8 clamp constants, and chunk-quantize-pack into flydsl_utils.py as module-level helpers (cutedsl pattern). Each kernel now imports them at module level and calls them inline. Net ~256 lines (-22%) removed across the three kernel files.
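For reference, a minimal numeric sketch of the FLOOR scale derivation and FP8 clamp these helpers cover (helper names here are illustrative, not the actual flydsl_utils API):

```python
import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0, the FP8 clamp bound
E8M0_BIAS = 127

def floor_scale_e8m0(block_amax: torch.Tensor) -> torch.Tensor:
    # FLOOR: power-of-two scale exponent is floor(log2(amax)) minus the
    # e4m3 max exponent (floor(log2(448)) == 8); zero-amax handling omitted.
    exp = torch.floor(torch.log2(block_amax)) - 8
    return torch.clamp(exp + E8M0_BIAS, 0, 255).to(torch.uint8)

def quantize_block(x: torch.Tensor, scale_e8m0: torch.Tensor) -> torch.Tensor:
    # Divide by 2^(e8m0 - bias), clamp into the e4m3 range, cast to fp8.
    scale = torch.exp2(scale_e8m0.to(torch.float32) - E8M0_BIAS)
    return torch.clamp(x / scale, -E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)
```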
…olidate tests
- FlyDSL dispatcher functions now accept the same kwargs as their cutedsl counterparts (stage_count, blocked_scale_output, offs). Unsupported values raise NotImplementedError; stage_count is accepted but ignored (no TMA on CDNA3). This lets callers swap backends without changing call signatures.
- Move the FlyDSL tests from four separate test_flydsl_mxfp8_*.py files into test_kernels.py with @pytest.mark.skipif gates, matching the cutedsl test layout convention. Remove the _flydsl_test_utils.py helper module.
Add MXFP8Dim1CastKernelChoice.FLYDSL and dispatch branch in mx_formats utils. Mirrors the CUTEDSL branch for AMD: callers can now pick FlyDSL as the dim1 cast backend just like cutedsl on NVIDIA. The FlyDSL backend currently supports FLOOR scale only; the dispatch asserts on RCEIL with a clear message. All existing MXFP8TrainingRecipe entries hardcode RCEIL, so end-to-end training-recipe coverage waits on either a FLOOR recipe or a FlyDSL RCEIL implementation. Direct MXFP8TrainingOpConfig with scale_calculation_mode=FLOOR exercises the backend today.
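Roughly, the new branch has this shape (a sketch only; the enum mirrors torchao's, but the surrounding call site and entrypoint name are simplified assumptions):

```python
from enum import Enum, auto

class MXFP8Dim1CastKernelChoice(Enum):  # stand-in for the torchao enum
    TRITON = auto()
    CUDA = auto()
    CUTEDSL = auto()
    FLYDSL = auto()  # new in this PR

def mxfp8_dim1_cast(x, kernel_choice, scale_calculation_mode="floor"):
    if kernel_choice is MXFP8Dim1CastKernelChoice.FLYDSL:
        # FLOOR-only for now; fail loudly on RCEIL rather than silently
        # producing differently-rounded scales.
        assert scale_calculation_mode == "floor", (
            "FlyDSL dim1 cast supports FLOOR scales only; RCEIL needs a "
            "software ue8m0 round-up (follow-up PR)"
        )
        return flydsl_quantize_2d_32x1(x)  # hypothetical entrypoint name
    raise NotImplementedError(kernel_choice)
```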
Each workgroup now runs up to 4 waves, each handling its own K-tile, mirroring triton_to_mxfp8_dim1's num_warps=4 strategy. The SIMD scheduler can overlap memory latency across waves within a CU. waves_per_block is picked at launch time based on K (largest power of two such that K % (waves * K_TILE) == 0) so small-K shapes still work via the original 1-wave path.
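The launch-time heuristic reduces to a few lines (a sketch; the function name is assumed and the k_tile default matches the later 256-element tile):

```python
def pick_waves_per_block(K: int, k_tile: int = 256, max_waves: int = 4) -> int:
    # Largest power-of-two wave count whose combined K footprint divides K;
    # falls back to the original 1-wave path for awkward K.
    waves = max_waves
    while waves > 1 and K % (waves * k_tile) != 0:
        waves //= 2
    return waves

assert pick_waves_per_block(4096) == 4  # 4096 % (4 * 256) == 0
assert pick_waves_per_block(1536) == 2  # 1536 % 1024 != 0, % 512 == 0
assert pick_waves_per_block(768) == 1   # neither 4- nor 2-wave tiles divide
```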
…er launch
32x1 Phase-1 loads now issue dwordx2 (VEC = 4 bf16 per lane); K_TILE grows to AMD_WAVE_SIZE * VEC = 256 and per-WG LDS stays within the 64 KB budget. The test K list is updated to match. All three FlyDSL wrappers pass `x` as a raw torch tensor so the JIT runtime takes its bare-pointer fast path (~46 us/launch saved vs the from_dlpack adapter). Outputs use as_strided views over flat fp8 storage.
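The output-construction trick looks roughly like this (illustrative shapes; the real wrappers also build the scale tensor):

```python
import torch

def make_fp8_output(M: int, K: int, device: str = "cuda") -> torch.Tensor:
    # Allocate flat fp8 storage and view it with as_strided so the wrapper
    # hands the JIT a plain contiguous base pointer (row-major (M, K) here).
    flat = torch.empty(M * K, dtype=torch.float8_e4m3fn, device=device)
    return flat.as_strided((M, K), (K, 1))
```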
…lookup
32x1 now stacks up to 4 MX blocks (M-direction) per workgroup, with 4 waves cooperating on a (128, 256) bf16 tile that shares one 64 KB WG-level LDS region; each wave owns one of the 4 stacked 32-row blocks. Phase 1 keeps the multi-wave HBM latency hiding; in Phase 4, the 4 waves concurrently write 4 disjoint 32-byte fragments of the same 128 B HBM cache line from the 4 SIMDs of one CU, letting the L2 controller coalesce them into a single line fill (no eviction, no read-modify-write).

This closes the 16384^2 0.75x regression: PMC WriteSize drops from 396 MB (+43% over ideal) to 338 MB (+22%, matching Triton exactly). MI355X bf16 end-to-end: 16384^2 goes from 0.75x to 1.19x and 8192^2 from 1.33x to 1.50x. The layout is adaptive via _pick_layout(M, K, ...), so M=32/64/128/... still compiles via 1/2/4-wave configurations and the existing small-shape numerics tests pass unchanged.

Stream fast path: replace torch.cuda.current_stream() (2.6 us/call) with fx.Stream(torch._C._cuda_getCurrentStream(idx)[0]) (0.2 us) in all three wrappers via a shared current_stream_fast helper in flydsl_utils. Saves ~2.7 us/call and brings 1024^2 and 4096^2 from ~0.89x to within 1% of Triton parity.
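The helper is essentially the one-liner quoted above (a sketch; the flydsl import alias is an assumption):

```python
import torch
import flydsl as fx  # assumed import alias for the FlyDSL runtime

def current_stream_fast(device_index: int = 0):
    # torch._C._cuda_getCurrentStream returns a (stream_id, device_index,
    # device_type) tuple; element 0 is the raw stream handle. Skipping the
    # torch.cuda.current_stream() wrapper avoids its per-call overhead.
    return fx.Stream(torch._C._cuda_getCurrentStream(device_index)[0])
```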
Mirrors the cutedsl MXFP8 quantize surface so the FlyDSL backend can be
swapped in without changing call sites. No new perf work — existing FlyDSL
optimizations (multi-wave, dwordx2, M-stacked WG layout) are preserved.
- Symmetrize all three FlyDSL kernel-module signatures (1x32, 32x1, 3D) to
accept the same kwargs as their cutedsl peers; raise NotImplementedError
at the kernel-module boundary (belt-and-suspenders with the dispatcher;
see the sketch after this list).
- Add 3D blocked_scale_output + scale_block_k=32 support so the 3D FlyDSL
kernel matches cutedsl_quantize_3d's option surface (2D blocked output
remains a follow-up).
- Add bench_flydsl_quantize_2d_{1x32,32x1}.py mirroring the cutedsl bench
files (FLOOR-only; baseline = triton_to_mxfp8_dim{0,1}).
- Extend bench_quantize_3d.py with a flydsl column and gate cuda/cutedsl
baselines on _mxfp8_cuda_kernels_available so the script runs on either
backend.
- Widen the FlyDSL 2D test param grids to mirror cutedsl coverage where
divisibility allows; add test_flydsl_kernel_module_rejects_unsupported_options
asserting NotImplementedError on direct kernel-module imports.
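As referenced above, the boundary guard has roughly this shape (a sketch; kwarg defaults and the entrypoint body are assumptions):

```python
def flydsl_quantize_2d_32x1_module(
    x,
    stage_count=None,            # accepted but ignored: no TMA on CDNA3
    blocked_scale_output=False,  # 2D blocked output is not implemented
    offs=None,                   # token-group offsets not wired in yet
):
    if blocked_scale_output:
        raise NotImplementedError(
            "2D blocked_scale_output is tcgen05-specific; unsupported in FlyDSL"
        )
    if offs is not None:
        raise NotImplementedError("offs is not yet wired into the 2D kernels")
    # ... derive layout, launch the FlyDSL kernel, return (out, scales) ...
```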
Known limitations (follow-up PRs):
- 1x32 requires K % 2048 == 0 (one wg-iter consumes wave_size * block_size =
64 * 32 elements; see the check after this list); K=7168 needs tail-handling.
- scaling_mode="rceil" not implemented (needs a software equivalent of cvt.rp.satfinite.ue8m0x2.f32).
- 2D blocked_scale_output not implemented (tcgen05-specific).
- offs (token-group offsets) not yet wired into the 2D kernels.
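The 1x32 constraint above reduces to a simple check (constants per the text: 64-lane waves, 32-element MX blocks):

```python
AMD_WAVE_SIZE = 64
MX_BLOCK_SIZE = 32

def supports_1x32(K: int) -> bool:
    # One wg-iter consumes wave_size * block_size = 2048 elements along K.
    return K % (AMD_WAVE_SIZE * MX_BLOCK_SIZE) == 0

assert supports_1x32(8192)       # 8192 = 4 * 2048
assert not supports_1x32(7168)   # 7168 % 2048 == 1024 -> needs tail-handling
```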
Summary
Adds the FlyDSL backend for MXFP8 quantize kernels on AMD CDNA, mirroring the existing
cuTeDSL surface so the dispatcher can route to either backend based on hardware. Three new kernels
(2d_1x32, 2d_32x1, 3d) ship with matching dispatcher entrypoints, custom-op registration,
tests, and benchmarks.
What's in this PR
- Three kernels (torchao/prototype/moe_training/kernels/mxfp8/flydsl_quantize_{2d_1x32,2d_32x1,3d}.py) —
same I/O contract as the cutedsl peers, with FlyDSL-specific perf work already applied
(multi-wave workgroups, buffer_load_dwordx2 widening, M-stacked WG layout).
- Dispatcher entrypoints registered as torch.library custom ops, gated on a
_mxfp8_flydsl_kernels_available flag and wired into MXFP8Dim1CastKernelChoice
(see the registration sketch after this list).
- Signature parity with the cutedsl wrappers (stage_count, blocked_scale_output, offs,
scale_block_n); unsupported values raise NotImplementedError at both the dispatcher
and kernel-module boundaries.
- FLOOR scale only for now; RCEIL asserts with a clear message (no end-to-end recipe
coverage yet, since every entry in the current MXFP8TrainingRecipe set hardcodes RCEIL).
- Tests widened to mirror cutedsl coverage where divisibility allows; dispatcher and
kernel-module raise contracts both asserted; custom-op registration checked.
- Benchmarks: bench_flydsl_quantize_2d_{1x32,32x1}.py (FLOOR-only, baseline =
triton_to_mxfp8_dim{0,1}); bench_quantize_3d.py extended with a
flydsl column and gated so the script runs on either backend.
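As referenced in the list above, the gating + registration pattern looks roughly like this (the op name and single-tensor signature are simplified assumptions; the flag name is from the PR):

```python
import torch
from torch.library import custom_op

try:
    import flydsl  # noqa: F401
    _mxfp8_flydsl_kernels_available = True
except ImportError:
    _mxfp8_flydsl_kernels_available = False

if _mxfp8_flydsl_kernels_available:
    @custom_op("torchao::mxfp8_quantize_dim1_flydsl", mutates_args=())
    def mxfp8_quantize_dim1_flydsl(x: torch.Tensor) -> torch.Tensor:
        # The real op also returns the e8m0 scales; simplified here.
        return _flydsl_quantize_dim1_impl(x)  # hypothetical wrapper
```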
Known limitations (follow-up PRs)
- Supporting K=7168 needs tail-handling in the kernel — see ROCm/FlyDSL PR #433's
compute_compile_constants + make_pingpong_kloop for the compile-time-loop +
runtime-tail-guard pattern.
- RCEIL (scaling_mode="rceil") to port from ROCm/FlyDSL PR #433's
tests/kernels/blockscale_gemm_test_utils.py:fp32_to_e8m0: extract the exponent byte,
add 1 iff the mantissa is nonzero, clamp to [1, 254] (sketched below). Required for
correctness — FLOOR systematically biases scales down and causes FP8 saturation.
All unsupported options raise NotImplementedError with descriptive messages.
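For reference, the e8m0 round-up described above, tensorized with torch (a sketch of the algorithm, not the FlyDSL port):

```python
import torch

def fp32_to_e8m0_rceil(x: torch.Tensor) -> torch.Tensor:
    bits = x.float().view(torch.int32)
    exponent = (bits >> 23) & 0xFF                 # biased fp32 exponent byte
    mantissa = bits & 0x7FFFFF
    exponent = exponent + (mantissa != 0).to(torch.int32)  # +1 iff inexact
    return exponent.clamp(1, 254).to(torch.uint8)  # clamp to [1, 254]
```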
Test plan
- FlyDSL tests passing on gfx950
- …unsupported_options (dispatcher) — contract test passes
- test_flydsl_kernel_module_rejects_unsupported_options — direct-import contract test passes
- custom ops registered
- bench_flydsl_quantize_2d_1x32.py runs end-to-end on AMD
- bench_flydsl_quantize_2d_32x1.py runs end-to-end on AMD
- bench_quantize_3d.py runs on either backend; columns NaN'd on the absent one