File tree Expand file tree Collapse file tree 1 file changed +2
-2
lines changed
torchao/prototype/mx_formats Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -711,7 +711,7 @@ def triton_quantize_nvfp4(
711711 return scales , xq .view (torch .uint8 )
712712
713713 @triton_quantize_nvfp4 .register_fake
714- def _ (x , per_tensor_scale = None , rounding_mode = 0 , seed = None ):
714+ def _ (x , per_tensor_scale = None , rounding_mode = RoundingMode . RN , seed = None ):
715715 M , N = x .shape
716716 num_scales = N // 16
717717 n_row_blocks = triton .cdiv (M , 128 )
@@ -749,7 +749,7 @@ def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
749749 def triton_quantize_nvfp4 (
750750 x : torch .Tensor ,
751751 tensor_scale : Optional [torch .Tensor ] = None ,
752- rounding_mode : int = 0 ,
752+ rounding_mode : RoundingMode = RoundingMode . RN ,
753753 ) -> Tuple [torch .Tensor , torch .Tensor ]:
754754 raise AssertionError ("needs torch version 2.8+ and triton" )
755755
You can’t perform that action at this time.
0 commit comments