Skip to content

Commit 43b0e61

Browse files
drisspgcrcrpar
authored andcommitted
Apply suggestions from code review
Co-authored-by: Masaki <mkozuki@nvidia.com>
1 parent 78d88ae commit 43b0e61

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

torchao/prototype/mx_formats/kernels.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -711,7 +711,7 @@ def triton_quantize_nvfp4(
711711
return scales, xq.view(torch.uint8)
712712

713713
@triton_quantize_nvfp4.register_fake
714-
def _(x, per_tensor_scale=None, rounding_mode=0, seed=None):
714+
def _(x, per_tensor_scale=None, rounding_mode=RoundingMode.RN, seed=None):
715715
M, N = x.shape
716716
num_scales = N // 16
717717
n_row_blocks = triton.cdiv(M, 128)
@@ -749,7 +749,7 @@ def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
749749
def triton_quantize_nvfp4(
750750
x: torch.Tensor,
751751
tensor_scale: Optional[torch.Tensor] = None,
752-
rounding_mode: int = 0,
752+
rounding_mode: RoundingMode = RoundingMode.RN,
753753
) -> Tuple[torch.Tensor, torch.Tensor]:
754754
raise AssertionError("needs torch version 2.8+ and triton")
755755

0 commit comments

Comments
 (0)