Implements rounding mode for NVFP4 tensor #3384
syed-ahmed wants to merge 6 commits into pytorch:main from
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3384
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from d8792dd to 70485c5
drisspg left a comment
We should have the lower-level ops all be functional, e.g. take in a seed analogous to what is done for the triton kernel, and the same for _f32_to_floatx_unpacked.
And then for how users use this, e.g.: _seed = torch.randint(2**31, (1,)).item()
In the nvfp4 tensors, I am not entirely sure I like this...
I think that should be fine, and I can make the seed handling more functional. As a side note, we may need to review TE's rng_state mechanism a little bit and make sure we have full control over the determinism knobs.
TBH this PR has sparked some convo on how we expect users to handle RNG for RS mode. I am not sure I personally have an informed enough opinion on the difficulties and nuance here; maybe @vkuzo has some better ideas. My initial thought is that the lower level is functional. Then the user sets a seed (philox state), and we ensure that we properly increment the philox offset before every invocation of RS, which will pull from there to generate the source of randomness. Also ensure that cudagraphs work here. I bet inductor will not behave well here, though. @jbschlosser just to throw some more ideas into the ring
I'm currently looking into the API design for a new set of "stateless RNG" APIs (i.e. pseudorandom number generation without PyTorch maintaining global PRNG state). My concern is ensuring that this composes well with RS mode as implemented in this PR and, more generally, with any implementations of stochastic rounding throughout torch / torchao.
So I generally agree with Driss's statement here. I think I'd want to see the lower-level ops accept a philox seed / offset; that should satisfy composability with a new stateless RNG API.
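A minimal sketch of what "lower-level ops accept a philox seed / offset" could look like (all names here are hypothetical, and `random.Random` is a stand-in for real philox state addressing):

```python
import random

def stochastic_round_step(values, seed, offset):
    # Hypothetical functional op: every random bit derives from (seed, offset),
    # so repeated calls with the same arguments are bitwise-identical and no
    # hidden global PRNG state is consumed.
    rng = random.Random(seed * (2**32) + offset)  # stand-in for philox addressing
    return [int(v) + (rng.random() < (v - int(v))) for v in values]

seed = 1234
out0 = stochastic_round_step([0.25, 0.75, 1.5], seed, offset=0)
# The caller advances the offset before the next RS invocation,
# so two launches with the same seed draw from disjoint stream positions.
out1 = stochastic_round_step([0.25, 0.75, 1.5], seed, offset=1)
# Same (seed, offset) reproduces the exact same rounding decisions.
assert stochastic_round_step([0.25, 0.75, 1.5], seed, 0) == out0
```

A stateless RNG API would then only need to hand out (seed, offset) pairs; the op itself stays a pure function of its arguments.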
Thanks @drisspg and @jbschlosser for the input! I'll refactor the ops to be functional and only take seed, and remove the .item(). For the eager path, should
@pytorchbot label module: training

Didn't find following labels among repository labels: module:,training

@pytorchbot label training

Didn't find following labels among repository labels: training

Ok, addressed the comments.
Purely cosmetic, but auto can be used; note it'd start with 1, not 0: https://docs.python.org/3/library/enum.html#enum.auto
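For concreteness, the `auto()` behavior referenced above (the values matter later, because kernel-side comparisons use the integer values):

```python
from enum import Enum, auto

class RoundingMode(Enum):
    RN = auto()  # auto() starts counting at 1, not 0
    RS = auto()

print(RoundingMode.RN.value, RoundingMode.RS.value)  # prints "1 2"
```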
- def f32_to_f4_unpacked(x):
+ def f32_to_f4_unpacked(x, rounding_mode=RoundingMode.RN, rand_bits=None):
I feel a type annotation would be good to have for this method:
def f32_to_f4_unpacked(x: torch.Tensor, rounding_mode: RoundingMode = RoundingMode.RN, rand_bits: torch.Tensor | None = None):

assert data_hp.is_contiguous(), "Only support contiguous data for now"
assert block_size == 16, "NVFP4 requires block_size=16"
if rounding_mode == RoundingMode.RS:
    assert rand_bits is not None and rand_bits.numel() > 1, (
This numel check is weird; it should just do a direct shape compare, right?

Hmm, I guess it needs to have two different forms: one for triton and one for the fallback?

Yes. For triton, I only need a fresh seed, which informs the philox in triton where to start from, and the offsets ensure getting the correct random number. For the fallback, we are generating all the randoms outside the quantize.
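A sketch of what the two-form validation discussed above could look like (helper name and exact shapes are hypothetical; shapes are shown as plain tuples rather than tensors to keep the sketch self-contained):

```python
def check_rand_bits(rand_bits_shape, n_elements, use_triton):
    # Hypothetical validator for the two rand_bits forms discussed above.
    if use_triton:
        # Triton path: a single fresh seed; philox offsets inside the kernel
        # address the per-element random numbers.
        assert rand_bits_shape == (1,), "Triton path expects a scalar seed"
    else:
        # Fallback path: randomness is pre-generated outside the quantize,
        # one random word per element to be rounded.
        assert rand_bits_shape == (n_elements,), (
            f"expected shape ({n_elements},), got {rand_bits_shape}"
        )

check_rand_bits((1,), n_elements=64, use_triton=True)
check_rand_bits((64,), n_elements=64, use_triton=False)
```

A direct shape compare like this is stricter than `numel() > 1` and makes the intended form of each path explicit.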
Force-pushed from 353d24a to 43b0e61
offs_m = pid_m * 128 + tl.arange(0, 128)[:, None]
offs_n = pid_n * 32 + tl.arange(0, 32)[None, :]
out_offs = offs_m * (N // 2) + offs_n
I'm worried about how these offsets are calculated wrt randomness below during stochastic rounding application. Passing a seed only to this kernel does not seem like enough to avoid undesirable RNG reuse. I think at the very least, we'd want to accept a base offset to start from. And even more ideally from a control perspective, we'd allow for a full set of (seed, offset) pairs controlling every single application of rounding (but I realize there will likely be perf issues with this).
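One way to realize the "base offset" suggestion on the host side (the class and its names are hypothetical): the caller tracks a running philox-style offset and reserves a disjoint range per launch, so two launches with the same seed never reuse the same random stream positions.

```python
class PhiloxOffsetTracker:
    # Hypothetical host-side tracker: hands out non-overlapping offset ranges
    # so successive stochastic-rounding launches never reuse random numbers.
    def __init__(self, seed: int):
        self.seed = seed
        self.offset = 0

    def next_launch(self, n_random_words: int) -> tuple[int, int]:
        base = self.offset
        self.offset += n_random_words  # reserve this launch's range
        return self.seed, base

tracker = PhiloxOffsetTracker(seed=1234)
seed0, base0 = tracker.next_launch(n_random_words=4096)
seed1, base1 = tracker.next_launch(n_random_words=4096)
# base1 starts where the first launch's range ended, so streams don't overlap.
```

The kernel would then compute its per-element philox offsets as `base + out_offs`, keeping every application of rounding addressable by an explicit (seed, offset) pair.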
Thanks @jbschlosser for the comment.
Would it be better to write a fused kernel for now, similar to MSLK: https://github.com/meta-pytorch/MSLK/blob/main/mslk/quantize/triton/fp4_quantize.py#L723 when composing with random hadamard transform (#4040)?
It seems like there is some discussion on how to support stochastic rounding better here: pytorch/pytorch#175409, so not sure anymore if the approach in this PR would make sense for UX and performance.
Co-authored-by: Masaki <mkozuki@nvidia.com>
- Remove incorrect assert per_tensor_scale is not None: both mslk_quantize_nvfp4 and triton_quantize_nvfp4 accept None; the assert also blocked dynamo fullgraph=True
- Gate MSLK path on rounding_mode == RoundingMode.RN: MSLK ignores rand_bits/rounding_mode, so RS requests must route to triton_quantize_nvfp4 instead
- Use torch.ops.ao.triton_quantize_nvfp4 instead of bare name: dynamo cannot resolve names defined inside conditional scopes; the op registry handle works correctly
- Fix ROUNDING_MODE == 0 -> == 1 in kernels: RoundingMode.RN.value == 1 (auto() starts at 1), so == 0 always routed to the stochastic branch; fixed in both production kernel and test wrapper

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
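The MSLK gating described in the commit message can be sketched as follows (the backend functions here are stand-ins for mslk_quantize_nvfp4 and torch.ops.ao.triton_quantize_nvfp4, not the real kernels):

```python
from enum import Enum, auto

class RoundingMode(Enum):
    RN = auto()
    RS = auto()

# Stand-ins for the real backends.
def mslk_backend(data):
    return ("mslk", data)

def triton_backend(data, rounding_mode, rand_bits):
    return ("triton", data)

def quantize_nvfp4(data, rounding_mode, rand_bits=None, use_mslk=True):
    # MSLK ignores rand_bits/rounding_mode, so only RN may route there;
    # RS requests fall through to the triton kernel.
    if use_mslk and rounding_mode == RoundingMode.RN:
        return mslk_backend(data)
    return triton_backend(data, rounding_mode, rand_bits)

assert quantize_nvfp4([1.0], RoundingMode.RN)[0] == "mslk"
assert quantize_nvfp4([1.0], RoundingMode.RS)[0] == "triton"
```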
Force-pushed from 43b0e61 to a569bd0
Closes #3264.
Code Assistant Used: Claude Code Opus 4.6
Summary
Implements RS (stochastic rounding) support for NVFP4 as described in the RFC. RN remains the default.
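For background, stochastic rounding (RS) rounds up or down with probability proportional to the distance to each neighbor, so it is unbiased in expectation, whereas round-to-nearest (RN) deterministically snaps to one side. A minimal pure-Python illustration on integer rounding (not the NVFP4 kernels):

```python
import random

def stochastic_round(x, rng):
    # Round down or up with probability equal to the fractional part,
    # so the expected value of the result equals x.
    return int(x) + (rng.random() < (x - int(x)))

rng = random.Random(0)
n = 100_000
rs_mean = sum(stochastic_round(0.3, rng) for _ in range(n)) / n
rn_mean = round(0.3)  # round-to-nearest: always 0
# rs_mean is close to 0.3; rn_mean is exactly 0 (bias toward the nearest value)
```

This unbiasedness is why RS is attractive for low-precision training, and it is the property the tests below check in expectation.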
Key implementation details
Triton kernel path: Uses the hardware cvt.rs.satfinite.e2m1x4.f32 PTX inline asm. The RS instruction converts 4 floats and 1 random uint32 into 2 packed FP4 bytes, while the RN instruction converts 2 floats into 1 byte. Since the RN path uses pack=4 (8 floats per invocation), the RS path issues two cvt.rs calls to match. With pack=4, 4 rbits values are loaded but only 2 are consumed per invocation; the other 2 are wasted to keep the output layout identical to RN.

Seed determinism: The Triton path takes an explicit seed parameter for tl.randint; the PyTorch path uses torch.manual_seed. The same seed produces bitwise-identical results within each path. The two paths use different RNGs, so they diverge for RS even with the same seed.

Validation: Both paths validate rounding_mode using not in RoundingMode, raising ValueError for invalid values. The RoundingMode enum uses int values (RN=1, RS=2, from auto()) rather than the string values from the RFC, since the Triton kernel needs an integer tl.constexpr.

Tests
A single parametrized test_f4_rounding in test_kernels.py covers rounding mode (RN/RS/invalid), kernel (PyTorch/Triton), seed determinism (same/different seeds), and value axes. It verifies that RN is biased to nearest, RS is unbiased in expectation, the same seed is deterministic, different seeds diverge, and an invalid mode raises.

Existing test_nvfp4_tensor.py tests are parametrized over rounding_mode and _triton_kernel_params. Triton-vs-PyTorch equivalence uses an SQNR threshold of 40 for RN and 8 for RS (different RNGs). End-to-end matmul uses an SQNR threshold of 16 for RN and 6 for RS.

Test Environment