Flydsl mxfp8 quantize #4357

Draft
zstreet87 wants to merge 11 commits into pytorch:main from zstreet87:flydsl-mxfp8-quantize

Conversation

zstreet87 commented Apr 30, 2026

Summary

Adds the FlyDSL backend for MXFP8 quantize kernels on AMD CDNA, mirroring the existing
cuTeDSL surface so the dispatcher can route to either backend based on hardware. Three new kernels
(2d_1x32, 2d_32x1, 3d) ship with matching dispatcher entrypoints, custom-op registration,
tests, and benchmarks.

What's in this PR

  • 3 FlyDSL kernels
    (torchao/prototype/moe_training/kernels/mxfp8/flydsl_quantize_{2d_1x32,2d_32x1,3d}.py) —
    same I/O contract as the cutedsl peers, with FlyDSL-specific perf work already applied
    (multi-wave workgroups, buffer_load_dwordx2 widening, M-stacked WG layout).
  • Dispatcher integration — mxfp8_quantize_*_flydsl entrypoints in quant.py, registered as
    torch.library custom ops, gated on a _mxfp8_flydsl_kernels_available flag and wired into
    MXFP8Dim1CastKernelChoice (a registration sketch follows this list).
  • API surface parity — kernel-module signatures accept the same kwargs as the cutedsl
    wrappers (stage_count, blocked_scale_output, offs, scale_block_n); unsupported values raise
    NotImplementedError at both the dispatcher and kernel-module boundaries.
  • 3D blocked scale output + scale_block_k=32 implemented (matches cutedsl 3D's full option
    set).
  • Tests — numerics tests for all 3 kernels mirror cutedsl coverage where divisibility
    allows; dispatcher and kernel-module raise contracts both asserted; custom-op registration
    checked.
  • Benchmarks — new bench_flydsl_quantize_2d_{1x32,32x1}.py mirror the cutedsl bench files
    (FLOOR-only, baseline = triton_to_mxfp8_dim{0,1}); bench_quantize_3d.py extended with a
    flydsl column and gated so the script runs on either backend.
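
For orientation, a hedged sketch of the torch.library custom-op registration pattern referenced above; the op namespace/name, gating condition, and fake-tensor impl here are assumptions, not copied from quant.py:

```python
import torch

# Assumption: gate on a ROCm build; the real flag also checks that FlyDSL imports
# and that the GPU architecture is supported.
_mxfp8_flydsl_kernels_available = torch.version.hip is not None

if _mxfp8_flydsl_kernels_available:

    @torch.library.custom_op("torchao::mxfp8_quantize_2d_1x32_flydsl", mutates_args=())
    def mxfp8_quantize_2d_1x32_flydsl(x: torch.Tensor) -> torch.Tensor:
        # The real op launches the FlyDSL kernel and also produces e8m0 scales;
        # a trivial cast keeps this sketch runnable.
        return x.to(torch.float8_e4m3fn)

    @mxfp8_quantize_2d_1x32_flydsl.register_fake
    def _(x: torch.Tensor) -> torch.Tensor:
        return torch.empty_like(x, dtype=torch.float8_e4m3fn)
```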

Known limitations (follow-up PRs)

  • 1x32 requires K % 2048 == 0 (one workgroup-iteration consumes wave_size × block_size =
    64 × 32 = 2048 elements). Supporting K=7168 needs tail-handling in the kernel — see
    ROCm/FlyDSL PR #433's compute_compile_constants + make_pingpong_kloop for the
    compile-time-loop + runtime-tail-guard pattern.
  • scaling_mode="rceil" not implemented. The 3-op exponent-roundup pattern is straightforward
    to port from ROCm/FlyDSL PR Bug fix for TORCH_VERSION_AFTER_* importation #433's
    tests/kernels/blockscale_gemm_test_utils.py:fp32_to_e8m0: extract exponent byte, add 1 iff
    mantissa is nonzero, clamp to [1, 254]. Required for correctness — FLOOR systematically
    biases scales down and causes FP8 saturation.
  • 2D blocked_scale_output not implemented (tcgen05-specific scale layout).
  • offs (token-group offsets) not yet wired into the 2D kernels.
  • 3D scale_block_n != 32 not implemented.

All raise NotImplementedError with descriptive messages.
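
For reference, a minimal PyTorch sketch of that RCEIL exponent roundup (vectorized here for clarity; the FlyDSL port would apply the same three ops per lane):

```python
import torch

def fp32_to_e8m0_rceil(scale: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: round an fp32 scale up to the next power of two and
    return the biased-exponent byte (e8m0), clamped to [1, 254]."""
    bits = scale.float().abs().contiguous().view(torch.int32)
    exponent = (bits >> 23) & 0xFF                            # biased exponent byte (FLOOR stops here)
    mantissa = bits & 0x7FFFFF                                # 23-bit mantissa
    exponent = exponent + (mantissa != 0).to(torch.int32)     # +1 iff not already a power of two
    return exponent.clamp(1, 254).to(torch.uint8)
```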

Test plan

  • pytest test/prototype/moe_training/test_kernels.py -k "flydsl or amd_mx_3d_flydsl" — 458
    passing on gfx950
  • pytest test/prototype/moe_training/test_kernels.py::test_flydsl_dispatcher_rejects_unsupported_options
    — contract test passes
  • pytest test/prototype/moe_training/test_kernels.py::test_flydsl_kernel_module_rejects_unsupported_options
    — direct-import contract test passes
  • pytest test/prototype/moe_training/test_kernels.py::test_flydsl_custom_ops_registered —
    ops registered
  • python benchmarks/prototype/moe_training/mxfp8/bench_flydsl_quantize_2d_1x32.py — runs
    end-to-end on AMD
  • python benchmarks/prototype/moe_training/mxfp8/bench_flydsl_quantize_2d_32x1.py — runs
    end-to-end on AMD
  • python benchmarks/prototype/moe_training/mxfp8/bench_quantize_3d.py — runs on either
    backend; columns NaN'd on the absent one

Move FLOOR scale derivation, FP8 clamp constants, and chunk-quantize-pack
into flydsl_utils.py as module-level helpers (cutedsl pattern). Each
kernel now imports them at module level and calls them inline. Net
~256 lines (-22%) removed across the three kernel files.
…olidate tests

- FlyDSL dispatcher functions now accept the same kwargs as their cutedsl
  counterparts (stage_count, blocked_scale_output, offs). Unsupported values
  raise NotImplementedError; stage_count is accepted but ignored (no TMA on
  CDNA3). Allows callers to swap backends without changing call signatures
  (see the sketch after this list).
- Move FlyDSL tests from 4 separate test_flydsl_mxfp8_*.py files into
  test_kernels.py with @pytest.mark.skipif gates, matching the cutedsl
  test layout convention. Remove the _flydsl_test_utils.py helper module.
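
Roughly the wrapper shape this describes (a sketch only; the real entrypoints in quant.py also handle scale outputs and the 32x1/3D variants):

```python
import torch

def mxfp8_quantize_2d_1x32_flydsl(
    x: torch.Tensor,
    *,
    stage_count: int = 1,               # accepted for cutedsl signature parity, ignored (no TMA on CDNA3)
    blocked_scale_output: bool = False,
    offs: torch.Tensor | None = None,
):
    # Unsupported options fail loudly so a backend swap never silently changes semantics.
    if blocked_scale_output:
        raise NotImplementedError("FlyDSL 2D kernels do not support blocked_scale_output yet")
    if offs is not None:
        raise NotImplementedError("offs (token-group offsets) are not wired into the FlyDSL 2D kernels yet")
    # ... launch the FlyDSL kernel with the supported options ...
```
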
Add MXFP8Dim1CastKernelChoice.FLYDSL and dispatch branch in mx_formats
utils. Mirrors the CUTEDSL branch for AMD: callers can now pick FlyDSL
as the dim1 cast backend just like cutedsl on NVIDIA.

The FlyDSL backend currently supports FLOOR scale only; the dispatch
asserts on RCEIL with a clear message. All existing MXFP8TrainingRecipe
entries hardcode RCEIL, so end-to-end training-recipe coverage waits on
either a FLOOR recipe or a FlyDSL RCEIL implementation. Direct
MXFP8TrainingOpConfig with scale_calculation_mode=FLOOR exercises the
backend today.
Each workgroup now runs up to 4 waves, each handling its own K-tile,
mirroring triton_to_mxfp8_dim1's num_warps=4 strategy. The SIMD
scheduler can overlap memory latency across waves within a CU.

waves_per_block is picked at launch time based on K (largest power of
two such that K % (waves * K_TILE) == 0) so small-K shapes still work
via the original 1-wave path.
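
A minimal sketch of that launch-time pick (the helper name here is illustrative; the 4-wave cap is from the description above):

```python
def pick_waves_per_block(K: int, k_tile: int, max_waves: int = 4) -> int:
    """Largest power-of-two wave count whose combined K-tile divides K evenly;
    degrades to the original 1-wave path for small or awkward K."""
    waves = max_waves
    while waves > 1 and K % (waves * k_tile) != 0:
        waves //= 2
    return waves
```
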
…er launch

32x1 Phase-1 loads now issue dwordx2 (vec_width=VEC=4 bf16 per lane);
K_TILE grows to AMD_WAVE_SIZE * VEC = 256 and per-WG LDS is capped at
the 64 KB budget. Test K list updated.

All three FlyDSL wrappers pass `x` as a raw torch tensor so the JIT
runtime takes its bare-pointer fast path (~46 us/launch saved vs the
from_dlpack adapter). Outputs use as_strided over flat fp8 storage.
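
The output side of that, sketched (the shapes, strides, and fp8 dtype here are illustrative, not the exact layout the kernels emit):

```python
import torch

def alloc_flat_fp8_view(M: int, K: int, device: torch.device) -> torch.Tensor:
    # Flat fp8 storage is what the kernel sees as a bare pointer; callers get a
    # strided view over it with no extra copy and no DLPack round-trip.
    flat = torch.empty(M * K, dtype=torch.float8_e4m3fn, device=device)
    return flat.as_strided((M, K), (1, M))
```
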
…lookup

32x1 now stacks up to 4 MX blocks (M-direction) per workgroup, with 4
waves cooperating on a (128, 256) bf16 tile sharing one 64 KB WG-level
LDS region; each wave owns one of the 4 stacked 32-row blocks. Phase-1
keeps the multi-wave HBM hide; Phase-4 has 4 waves write 4 disjoint
32-byte fragments of the same 128 B HBM cache line concurrently from
4 SIMDs of one CU, letting the L2 controller coalesce into a single
line fill (no eviction, no RMW).

Closes the 16384^2 0.75x regression: PMC WriteSize drops from 396 MB
(+43% over ideal) to 338 MB (+22%, matching Triton exactly). MI355X
bf16 end-to-end: 16384^2 0.75x -> 1.19x; 8192^2 1.33x -> 1.50x. Layout
is adaptive via _pick_layout(M, K, ...), so M=32/64/128/... all still
compile via 1/2/4-wave configurations and existing small-shape numerics
tests pass unchanged.
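
A hedged sketch of what an adaptive pick like _pick_layout can reduce to (the real helper's signature and return value are not shown here, so treat this as illustrative):

```python
def _pick_layout(M: int, K: int, rows_per_block: int = 32, max_stack: int = 4) -> int:
    """Stack as many 32-row MX blocks (one wave each) per workgroup as M allows,
    in powers of two, so M=32/64/128/... map to 1/2/4-wave configurations."""
    stack = max_stack
    while stack > 1 and M % (stack * rows_per_block) != 0:
        stack //= 2
    return stack
```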

Stream fast-path: replace torch.cuda.current_stream() (2.6 us/call)
with fx.Stream(torch._C._cuda_getCurrentStream(idx)[0]) (0.2 us) in all
three wrappers via a shared current_stream_fast helper in flydsl_utils.
Saves ~2.7 us/call; brings 1024^2 and 4096^2 from ~0.89x to within 1%
of Triton parity.
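
The helper is roughly this shape (assuming `fx` is the FlyDSL module alias used in the expression above; `torch._C._cuda_getCurrentStream` is a private API, which is why it lives in one shared place):

```python
import torch
import flydsl as fx  # assumed import alias for the FlyDSL runtime

def current_stream_fast(device_index: int) -> "fx.Stream":
    # torch._C._cuda_getCurrentStream returns a small tuple whose first entry is
    # the raw stream handle; wrapping it directly skips the Python-side object
    # construction done by torch.cuda.current_stream() (~2.6 us -> ~0.2 us/call).
    raw_stream = torch._C._cuda_getCurrentStream(device_index)[0]
    return fx.Stream(raw_stream)
```
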
Mirrors the cutedsl MXFP8 quantize surface so the FlyDSL backend can be
swapped in without changing call sites. No new perf work — existing FlyDSL
optimizations (multi-wave, dwordx2, M-stacked WG layout) are preserved.

- Symmetrize all three FlyDSL kernel-module signatures (1x32, 32x1, 3D) to
  accept the same kwargs as their cutedsl peers; raise NotImplementedError
  at the kernel-module boundary (belt-and-suspenders with the dispatcher).
- Add 3D blocked_scale_output + scale_block_k=32 support so the 3D FlyDSL
  kernel matches cutedsl_quantize_3d's option surface (2D blocked output
  remains follow-up).
- Add bench_flydsl_quantize_2d_{1x32,32x1}.py mirroring the cutedsl bench
  files (FLOOR-only; baseline = triton_to_mxfp8_dim{0,1}).
- Extend bench_quantize_3d.py with a flydsl column and gate cuda/cutedsl
  baselines on _mxfp8_cuda_kernels_available so the script runs on either
  backend.
- Widen the FlyDSL 2D test param grids to mirror cutedsl coverage where
  divisibility allows; add test_flydsl_kernel_module_rejects_unsupported_options
  asserting NotImplementedError on direct kernel-module imports.

Known limitations (follow-up PRs):
- 1x32 requires K % 2048 == 0 (one wg-iter consumes wave_size * block_size);
  K=7168 needs tail-handling.
- scaling_mode="rceil" not implemented (needs sw cvt.rp.satfinite.ue8m0x2.f32).
- 2D blocked_scale_output not implemented (tcgen05-specific).
- offs (token-group offsets) not yet wired into the 2D kernels.

pytorch-bot Bot commented Apr 30, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4357

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

meta-cla Bot added the CLA Signed label Apr 30, 2026
zstreet87 marked this pull request as draft April 30, 2026 18:35