@@ -529,6 +529,7 @@ def quantize_nvfp4_triton_kernel(
529529 M ,
530530 N ,
531531 seed_ptr ,
532+ offset_ptr ,
532533 USE_TENSOR_SCALE : tl .constexpr ,
533534 MASK_SCALES : tl .constexpr ,
534535 ROUNDING_MODE : tl .constexpr , # 0=RN, 1=RS
@@ -616,6 +617,9 @@ def quantize_nvfp4_triton_kernel(
616617 else :
617618 # Stochastic rounding (RS) via hardware cvt.rs.satfinite.e2m1x4.f32
618619 seed = tl .load (seed_ptr )
620+ offset_base = tl .load (offset_ptr )
621+ # Place local index in upper 32 bits and offset_base in lower 32 bits
622+ offset = (tl .cast (out_offs , tl .int64 ) << 32 ) | offset_base
619623 rbits = tl .randint (seed , out_offs )
620624 x_fp4x2 = convert_fp32_to_fp4_packed_rs (x_pairs , rbits )
621625 if MASK_SCALES :
@@ -629,7 +633,8 @@ def triton_quantize_nvfp4(
629633 x : torch .Tensor ,
630634 per_tensor_scale : Optional [torch .Tensor ] = None ,
631635 rounding_mode : int = 0 ,
632- seed : Optional [torch .Tensor ] = None ,
636+ seed : torch .Tensor | None = None ,
637+ offset : torch .Tensor | None = None ,
633638 ) -> Tuple [torch .Tensor , torch .Tensor ]:
634639 """Quantize a tensor to NVFP4 format.
635640
@@ -638,8 +643,12 @@ def triton_quantize_nvfp4(
638643 per_tensor_scale (Optional[torch.Tensor]): Per-tensor scale for two-level quantization.
639644 If None, uses single-level block-wise quantization only.
640645 rounding_mode (int): 0 for round-to-nearest, 1 for stochastic rounding.
641- seed (Optional[torch.Tensor]): Seed tensor for stochastic rounding RNG.
642- Should be a single-element int32 tensor on the same device as x.
646+ seed (torch.Tensor | None): Seed tensor for stochastic rounding RNG.
647+ Should be a single-element int64 tensor on the same device as x.
648+ When None, stochastic rounding uses a dummy seed (caller should
649+ only pass None when rounding_mode=0).
650+ offset (torch.Tensor | None): Seed tensor for stochastic rounding RNG.
651+ Should be a single-element int64 tensor on the same device as x.
643652 When None, stochastic rounding uses a dummy seed (caller should
644653 only pass None when rounding_mode=0).
645654
@@ -689,6 +698,7 @@ def triton_quantize_nvfp4(
689698 # For seed_ptr: if seed is None (RN mode), reuse x as dummy pointer
690699 # (kernel won't read it when ROUNDING_MODE=0)
691700 seed_ptr = seed if seed is not None else x
701+ offset_ptr = offset if offset is not None else x
692702
693703 quantize_nvfp4_triton_kernel [grid ](
694704 x ,
@@ -700,6 +710,7 @@ def triton_quantize_nvfp4(
700710 M ,
701711 N ,
702712 seed_ptr ,
713+ offset_ptr ,
703714 USE_TENSOR_SCALE = use_tensor_scale ,
704715 MASK_SCALES = MASK_SCALES ,
705716 ROUNDING_MODE = rounding_mode ,
@@ -712,7 +723,7 @@ def triton_quantize_nvfp4(
712723 return scales , xq .view (torch .uint8 )
713724
714725 @triton_quantize_nvfp4 .register_fake
715- def _ (x , per_tensor_scale = None , rounding_mode = RoundingMode .RN , seed = None ):
726+ def _ (x , per_tensor_scale = None , rounding_mode = RoundingMode .RN , seed = None , offset = None ):
716727 M , N = x .shape
717728 num_scales = N // 16
718729 n_row_blocks = triton .cdiv (M , 128 )
0 commit comments