@@ -83,7 +83,6 @@ def _to_mxfp8_then_scaled_grouped_mm(
     A: torch.Tensor,
     B_t: torch.Tensor,
     offs: Optional[torch.Tensor] = None,
-    block_size: Optional[int] = None,
     out_dtype: Optional[torch.dtype] = torch.bfloat16,
     kernel_preference: KernelPreference = KernelPreference.AUTO,
     wgrad_with_hp: bool = False,
@@ -103,7 +102,6 @@ def _to_mxfp8_then_scaled_grouped_mm(
             which must be 3D, which must be shape (G, K, N)
             and in "per group column-major memory" layout (i.e., strides of (N*K, 1, N)).
         offs (int32 torch.Tensor): The offsets to use to mark the end index of each group along the dim0 of the A tensor.
-        block_size (int): Block size for MXFP8 quantization. Must be 32 (the only supported value). This parameter exists for backward compatibility but is ignored.
         out_dtype (torch.dtype): Output dtype for the result. Defaults to torch.bfloat16.
         kernel_preference (KernelPreference): Kernel preference (AUTO uses CUDA/Triton, EMULATED uses to_mx). Defaults to KernelPreference.AUTO.
         wgrad_with_hp (bool): Whether to compute weight gradient in high precision. Defaults to False.
@@ -120,7 +118,6 @@ def _to_mxfp8_then_scaled_grouped_mm(
         A,
         B_t,
         offs,
-        block_size,
         out_dtype,
         kernel_preference,
         wgrad_with_hp,
@@ -144,7 +141,6 @@ def forward(
        input_act: torch.Tensor,
        weight_t: torch.Tensor,
        group_end_offsets: Optional[torch.Tensor] = None,
-        block_size: int = 32,
        out_dtype: Optional[torch.dtype] = torch.bfloat16,
        kernel_preference: KernelPreference = KernelPreference.AUTO,
        wgrad_with_hp: bool = False,
@@ -158,15 +154,18 @@ def forward(
            input_act: Input activations, shape (M, K) - may be MXTensor or high-precision
            weight_t: Expert weights transposed, shape (E, K, N) - always high-precision
            group_end_offsets: End index of each token group, shape (E,)
-            block_size: Block size for MXFP8 quantization (must be 32)
            out_dtype: Output dtype (bfloat16 or float32)
            kernel_preference: Kernel preference (AUTO uses CUDA/Triton, EMULATED uses to_mx)
            wgrad_with_hp: Compute weight gradient in high precision
            scale_calculation_mode: Mode for scale calculation (RCEIL, FLOOR, etc.)
+            pad_token_groups_for_grouped_mm: Whether to pad token groups to the next multiple of 32

        Returns:
            Output tensor, shape (M, N)
        """
+        # block_size is always 32 for MXFP8
+        block_size = 32
+
        assert kernel_preference in (
            KernelPreference.AUTO,
            KernelPreference.EMULATED,
@@ -182,7 +181,6 @@ def forward(
        # Input validation
        assert input_act.ndim == 2, "input_act must be 2D"
        assert weight_t.ndim == 3, "weight_t must be 3D"
-        assert block_size == 32, "Only block_size=32 is supported"
        assert group_end_offsets is not None, (
            "group_end_offsets must be provided for 2d-3d grouped mm"
        )
@@ -247,7 +245,6 @@ def forward(
            padded_group_start_offsets,
            padded_group_end_offsets,
        )
-        ctx.block_size = block_size
        ctx.out_dtype = out_dtype
        ctx.kernel_preference = kernel_preference
        ctx.wgrad_with_hp = wgrad_with_hp
@@ -279,7 +276,8 @@ def backward(ctx, grad_output: torch.Tensor):
            padded_group_end_offsets,
        ) = ctx.saved_tensors

-        block_size = ctx.block_size
+        # block_size is always 32 for MXFP8
+        block_size = 32
        out_dtype = ctx.out_dtype
        kernel_preference = ctx.kernel_preference
        wgrad_with_hp = ctx.wgrad_with_hp
@@ -338,13 +336,12 @@ def backward(ctx, grad_output: torch.Tensor):
        return (
            grad_input,
            grad_weight_t,
-            None,
-            None,
-            None,
-            None,
-            None,
-            None,
-            None,
+            None,  # group_end_offsets
+            None,  # out_dtype
+            None,  # kernel_preference
+            None,  # wgrad_with_hp
+            None,  # scale_calculation_mode
+            None,  # pad_token_groups_for_grouped_mm
        )
349346
350347
0 commit comments