
Implements rounding mode for NVFP4 tensor#3384

Open
syed-ahmed wants to merge 6 commits into pytorch:main from syed-ahmed:rounding-mode

Conversation

@syed-ahmed
Contributor

@syed-ahmed syed-ahmed commented Nov 24, 2025

Closes #3264.
Code Assistant Used: Claude Code Opus 4.6

Summary

Implements RS (stochastic rounding) support for NVFP4 as described in the RFC. RN remains the default.

Key implementation details

Triton kernel path: Uses hardware cvt.rs.satfinite.e2m1x4.f32 PTX inline asm. The RS instruction converts 4 floats and 1 random uint32 into 2 packed FP4 bytes, while the RN instruction converts 2 floats into 1 byte. Since the RN path uses pack=4 (8 floats per invocation), the RS path issues two cvt.rs calls to match. With pack=4, 4 rbits values are loaded but only 2 are consumed per invocation — the other 2 are wasted to keep the output layout identical to RN.
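The rounding semantics of the RS instruction can be illustrated with a hedged pure-Python sketch of stochastic rounding onto the E2M1 value grid (the function name is made up for illustration; this models the math only, not the PTX packing or the kernel's RNG):

```python
import bisect

# Positive values representable in FP4 E2M1 (sign handled separately).
E2M1_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def stochastic_round_fp4(x: float, rand_u32: int) -> float:
    """Round x onto the E2M1 grid, rounding up with probability equal to
    the fractional distance past the lower grid point."""
    sign = -1.0 if x < 0 else 1.0
    mag = min(abs(x), 6.0)  # satfinite: clamp to the largest finite value
    hi_idx = bisect.bisect_left(E2M1_GRID, mag)
    if E2M1_GRID[hi_idx] == mag:
        return sign * mag  # exactly representable, no rounding needed
    lo, hi = E2M1_GRID[hi_idx - 1], E2M1_GRID[hi_idx]
    frac = (mag - lo) / (hi - lo)
    # one uniform 32-bit random word decides the rounding direction
    round_up = (rand_u32 / 2**32) < frac
    return sign * (hi if round_up else lo)
```

In the hardware path, one such random word is consumed per cvt.rs group of 4 floats, which is why pack=4 loads 4 rbits values but only consumes 2.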

Seed determinism: Triton path takes an explicit seed parameter for tl.randint; PyTorch path uses torch.manual_seed. Same seed produces bitwise-identical results within each path. The two paths use different RNGs so they diverge for RS even with the same seed.

Validation: Both paths validate rounding_mode using not in RoundingMode, raising ValueError for invalid values.

RoundingMode enum uses int values (RN=0, RS=1) rather than string values from the RFC, since the Triton kernel needs an integer tl.constexpr.
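As a sketch, the integer-valued enum and the membership-style validation described above might look like this (values taken from this summary; they are illustrative, and a later commit in this thread switches to enum.auto()):

```python
from enum import IntEnum

class RoundingMode(IntEnum):
    RN = 0  # round-to-nearest-even (default)
    RS = 1  # stochastic rounding

def validate(rounding_mode: RoundingMode) -> None:
    # members can be checked against the enum directly
    if rounding_mode not in RoundingMode:
        raise ValueError(f"invalid rounding_mode: {rounding_mode!r}")

# an int value can be passed straight to a Triton tl.constexpr,
# which a string-valued enum could not
validate(RoundingMode.RS)
```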

Tests

Single parametrized test_f4_rounding in test_kernels.py covers rounding mode (RN/RS/invalid), kernel (PyTorch/Triton), seed determinism (same/different seeds), and value axes. Verifies RN is biased to nearest, RS is unbiased in expectation, same seed is deterministic, different seeds diverge, and invalid mode raises.
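The unbiasedness property being checked can be sketched with stdlib Python standing in for the kernel RNG (helper names are illustrative, not the test's actual code):

```python
import random

def stochastic_round(x, lo, hi, rng):
    """Round x in [lo, hi] to one endpoint, up with probability (x-lo)/(hi-lo)."""
    return hi if rng.random() < (x - lo) / (hi - lo) else lo

rng = random.Random(0)  # fixed seed -> deterministic run
x, lo, hi = 1.1, 1.0, 1.5
samples = [stochastic_round(x, lo, hi, rng) for _ in range(100_000)]
mean = sum(samples) / len(samples)
# RS is unbiased in expectation: the sample mean converges to x,
# whereas round-to-nearest would always return 1.0 here (biased)
assert abs(mean - x) < 0.01
```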

Existing test_nvfp4_tensor.py tests are parametrized over rounding_mode and _triton_kernel_params. Triton-vs-PyTorch equivalence uses SQNR threshold of 40 for RN and 8 for RS (different RNGs). End-to-end matmul uses SQNR threshold of 16 for RN and 6 for RS.
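For reference, SQNR here is the signal-to-quantization-noise ratio in dB; a minimal stdlib sketch of the metric behind these thresholds (not the torchao helper itself):

```python
import math

def sqnr_db(reference, quantized):
    """10*log10(signal power / quantization-noise power), in dB."""
    signal = sum(r * r for r in reference)
    noise = sum((r - q) ** 2 for r, q in zip(reference, quantized))
    return 10 * math.log10(signal / noise)

ref = [1.0, 2.0, 3.0, 4.0]
quant = [1.0, 2.1, 2.9, 4.0]  # small perturbation -> finite SQNR (~31.8 dB)
assert sqnr_db(ref, quant) > 20
```

The lower RS thresholds reflect that the two paths use different RNGs, so their outputs differ sample-by-sample even when both are correct.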

Test Environment

Collecting environment information...                                                                                     
PyTorch version: 2.12.0a0+gitbabda95                                                                                      
Is debug build: False                                                                                                     
CUDA used to build PyTorch: 13.2                                                                                          
ROCM used to build PyTorch: N/A                                                                                           
                                                                                                                          
OS: Ubuntu 24.04.4 LTS (aarch64)                                                                                          
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04.1) 13.3.0                                                                      
Clang version: Could not collect                                                                                          
CMake version: version 3.31.6                                                                                             
Libc version: glibc-2.39                                                                                                  
                                                                                                                          
Python version: 3.12.3 (main, Jan 22 2026, 20:57:42) [GCC 13.3.0] (64-bit runtime)                                        
Python platform: Linux-6.14.0-1008-nvidia-64k-aarch64-with-glibc2.39                                                      
Is CUDA available: True                                                                                                   
CUDA runtime version: 13.2.50                                                                                             
CUDA_MODULE_LOADING set to: LAZY                                                                                          
GPU models and configuration:                                                                                             
GPU 0: NVIDIA GB200                                                                                                       
GPU 1: NVIDIA GB200                                                                                                       
GPU 2: NVIDIA GB200                                                                                                       
GPU 3: NVIDIA GB200

Nvidia driver version: 580.105.08                                                                                                                                                                                                                                                          

Versions of relevant libraries:                                                                                     
[pip3] mypy==1.16.0                                       
[pip3] mypy_extensions==1.1.0                             
[pip3] numpy==1.26.4                                      
[pip3] nvidia-cudnn-frontend==1.18.0                                                                                
[pip3] onnx==1.20.0                                       
[pip3] onnx-ir==0.1.16                                    
[pip3] onnxscript==0.6.2                                  
[pip3] optree==0.13.0                                     
[pip3] pytorch-lightning==2.6.1                                                                                     
[pip3] torch==2.12.0a0+gitbabda95                                                                                   
[pip3] torchmetrics==1.8.2                                
[pip3] triton==3.6.0+git9844da95                                                                                    
[conda] Could not collect

@pytorch-bot

pytorch-bot bot commented Nov 24, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3384

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@syed-ahmed syed-ahmed marked this pull request as draft November 24, 2025 18:03
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Nov 24, 2025
@syed-ahmed syed-ahmed mentioned this pull request Feb 11, 2026
59 tasks
@syed-ahmed syed-ahmed marked this pull request as ready for review March 4, 2026 03:39
@syed-ahmed syed-ahmed changed the title [WIP] Implements rounding mode for NVFP4 tensor Implements rounding mode for NVFP4 tensor Mar 4, 2026
@syed-ahmed syed-ahmed moved this to In Progress in PyTorch + CUDA Mar 4, 2026
Contributor

@drisspg drisspg left a comment


We should have the lower-level ops all be functional, e.g. take in a seed analogous to what is done for the triton kernel, but the same for _f32_to_floatx_unpacked

And then for how users use this e.g.: _seed = torch.randint(2**31, (1,)).item()

in the nvfp4 tensors I am not entirely sure I like this..

@syed-ahmed
Contributor Author

syed-ahmed commented Mar 4, 2026

I think that should be fine, and I can make the seed handling more functional. As a side note, we may need to review TE's rng_state mechanism a bit and make sure we have full control over the determinism knobs.

@drisspg
Contributor

drisspg commented Mar 5, 2026

TBH this PR has sparked some convo on how we expect users to handle RNG for RS mode.

I am not sure I personally have an informed enough opinion on the difficulties and nuance here.

maybe @vkuzo has some better ideas here

My initial thought is the lower level is functional. And then the user sets a seed (philox state); ensure that we properly increment the philox offset before every invocation of RS, which will pull from there to generate the source of randomness. Also ensure that CUDA graphs work here. I bet inductor will not behave well here, though.
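One way to read the "caller owns the philox state" idea is the following hypothetical sketch (class and method names are made up; this is not a PyTorch API): the caller holds (seed, offset) and bumps the offset before every RS invocation so random streams never overlap.

```python
class PhiloxState:
    """Hypothetical caller-owned counter-based RNG state (seed + offset)."""

    def __init__(self, seed: int):
        self.seed = seed
        self.offset = 0

    def next_offset(self, n_random_words: int) -> int:
        """Reserve n_random_words of the counter stream; return the base
        offset the kernel should start drawing from."""
        base = self.offset
        self.offset += n_random_words
        return base

state = PhiloxState(seed=42)
first = state.next_offset(1024)   # one quantize call needing 1024 words
second = state.next_offset(1024)  # the next call gets a disjoint range
assert (first, second) == (0, 1024)
```

Because the state is explicit data rather than hidden global state, replaying the same (seed, offset) sequence reproduces results, which is also what CUDA graph capture needs.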

@jbschlosser just to throw some more ideas into the ring

@jbschlosser

I'm currently looking into the API design for a new set of "stateless RNG" APIs (i.e. pseudorandom number generation without PyTorch maintaining global PRNG state). I want to make sure this composes well with RS mode as implemented in this PR and, more generally, with any implementation of stochastic rounding throughout torch / torchao.

My initial thought is the lower level is functional. And then the user sets a seed (philox state); ensure that we properly increment the philox offset before every invocation of RS, which will pull from there to generate the source of randomness.

so I generally agree with Driss's statement here. I think I'd want to see the lower level ops accept a philox seed / offset and that should satisfy composability with a new stateless RNG API.

@syed-ahmed
Contributor Author

Thanks @drisspg and @jbschlosser for the input! I'll refactor the ops to be functional and only take a seed, and remove the .item(). For the eager path, should _f32_to_floatx_unpacked take a seed (and so consume the seed in a torch.Generator, with that Generator passed to rand_ints), or should we just have a rand_int tensor as input?

@syed-ahmed
Contributor Author

@pytorchbot label module: training

@pytorch-bot

pytorch-bot bot commented Mar 11, 2026

Didn't find following labels among repository labels: module:,training

@syed-ahmed
Contributor Author

@pytorchbot label training

@pytorch-bot

pytorch-bot bot commented Mar 11, 2026

Didn't find following labels among repository labels: training

@danielvegamyhre danielvegamyhre added the module: training quantize_ api training flow label Mar 11, 2026
@syed-ahmed syed-ahmed requested a review from drisspg March 11, 2026 04:14
@syed-ahmed
Contributor Author

syed-ahmed commented Mar 11, 2026

Ok, addressed the comments.

  • The .item() call is removed. The Triton kernel now loads a seed.
  • The eager path now expects a rand_bits tensor, keeping generator state outside.
  • to_nvfp4 expects the caller to pass rand_bits.
  • Added a CUDA graph composability test for both eager and torch.compile.
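The "generator state stays outside" eager pattern can be sketched like this (function and variable names are illustrative, not the torchao API): the caller draws rand_bits from its own generator and passes them in, so the quantize function itself is purely functional.

```python
import random

def quantize_rs(values, rand_bits):
    """Toy stand-in for a quantize op: round each value in [0, 1) to 0 or 1
    stochastically, using caller-supplied random 32-bit words (one per element)."""
    assert len(rand_bits) == len(values)
    return [1.0 if r / 2**32 < v else 0.0 for v, r in zip(values, rand_bits)]

gen = random.Random(1234)                       # caller-owned generator
bits = [gen.getrandbits(32) for _ in range(4)]  # drawn outside the op
out1 = quantize_rs([0.1, 0.5, 0.9, 0.3], bits)
out2 = quantize_rs([0.1, 0.5, 0.9, 0.3], bits)  # same bits -> identical output
assert out1 == out2
```

Since the op consumes no hidden RNG state, replaying it with the same inputs is bitwise reproducible, which is what makes CUDA graph capture straightforward.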

Comment thread torchao/prototype/custom_fp_utils.py Outdated
Contributor


purely cosmetic, but auto() can be used here, though it starts at 1, not 0 -- https://docs.python.org/3/library/enum.html#enum.auto
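For illustration, auto() does indeed start numbering at 1, which is exactly the pitfall a later commit in this thread fixes (a `ROUNDING_MODE == 0` check that never matched):

```python
from enum import IntEnum, auto

class Mode(IntEnum):
    RN = auto()  # == 1, not 0
    RS = auto()  # == 2

assert (Mode.RN.value, Mode.RS.value) == (1, 2)
```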



def f32_to_f4_unpacked(x):
def f32_to_f4_unpacked(x, rounding_mode=RoundingMode.RN, rand_bits=None):
Contributor


i feel a type annotation is good to have for this method

def f32_to_f4_unpacked(x: torch.Tensor, rounding_mode: RoundingMode = RoundingMode.RN, rand_bits: torch.Tensor | None = None):

Comment thread torchao/prototype/mx_formats/kernels.py Outdated
assert data_hp.is_contiguous(), "Only support contiguous data for now"
assert block_size == 16, "NVFP4 requires block_size=16"
if rounding_mode == RoundingMode.RS:
assert rand_bits is not None and rand_bits.numel() > 1, (
Contributor


this numel check is weird, it should just do a direct shape compare, right?

Contributor


hmm, I guess it needs to have two different forms: one for triton and one for the fallback?

Contributor Author


Yes. For triton, I only need a fresh seed, which tells the philox in triton where to start from, and the offsets ensure getting the correct random numbers. For the fallback, we generate all the randoms outside the quantize.

Comment on lines +606 to +608
offs_m = pid_m * 128 + tl.arange(0, 128)[:, None]
offs_n = pid_n * 32 + tl.arange(0, 32)[None, :]
out_offs = offs_m * (N // 2) + offs_n

@jbschlosser jbschlosser Mar 26, 2026


I'm worried about how these offsets are calculated wrt randomness below during stochastic rounding application. Passing a seed only to this kernel does not seem like enough to avoid undesirable RNG reuse. I think at the very least, we'd want to accept a base offset to start from. And even more ideally from a control perspective, we'd allow for a full set of (seed, offset) pairs controlling every single application of rounding (but I realize there will likely be perf issues with this).

Contributor Author


Thanks @jbschlosser for the comment.

Would it be better to write a fused kernel for now, similar to MSLK: https://github.com/meta-pytorch/MSLK/blob/main/mslk/quantize/triton/fp4_quantize.py#L723 when composing with random hadamard transform (#4040)?

It seems like there is some discussion on how to support stochastic rounding better here: pytorch/pytorch#175409, so I'm not sure anymore whether the approach in this PR makes sense for UX and performance.

drisspg and others added 3 commits April 17, 2026 14:25
Co-authored-by: Masaki <mkozuki@nvidia.com>
Co-authored-by: Masaki <mkozuki@nvidia.com>
- Remove incorrect assert per_tensor_scale is not None: both mslk_quantize_nvfp4
  and triton_quantize_nvfp4 accept None; the assert also blocked dynamo fullgraph=True
- Gate MSLK path on rounding_mode == RoundingMode.RN: MSLK ignores rand_bits/rounding_mode,
  so RS requests must route to triton_quantize_nvfp4 instead
- Use torch.ops.ao.triton_quantize_nvfp4 instead of bare name: dynamo cannot resolve
  names defined inside conditional scopes; the op registry handle works correctly
- Fix ROUNDING_MODE == 0 -> == 1 in kernels: RoundingMode.RN.value == 1 (auto() starts
  at 1), so == 0 always routed to the stochastic branch; fixed in both production kernel
  and test wrapper

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. module: training quantize_ api training flow

Projects

Status: In Progress

Development

Successfully merging this pull request may close these issues.

[RFC] NVFP4 Rounding Modes

6 participants