 from torchao.prototype.moe_training.kernels.mxfp8 import (
     mx_block_rearrange_2d_M_groups_cuda,
     torch_to_blocked_2d_M_groups,
-    torch_to_blocked_per_group_3d,
     triton_mx_block_rearrange_2d_K_groups,
     triton_mx_block_rearrange_per_group_3d,
 )
@@ -238,33 +237,27 @@ def compute_mxfp8_2d_2d_gemm_time(self, N, M, K):
         return time_s

     def compute_mxfp8_fwd_bwd_time(self, M, K, N, G):
-        """Compute time for MXFP8 forward + backward pass including scale rearrangement overhead"""
+        """Compute time for MXFP8 forward + backward pass."""
         block_size = 32

         # Forward: (M, K) @ (G, K, N)^T -> (M, N) [2D-3D]
         fwd_quant_time = self.compute_mxfp8_fwd_quant_time(M, K, G, N)
         # Forward scale rearrangement:
         # - Input scales (M, K//32) -> M-groups rearrangement
-        # - Weight scales (G, N, K//32) -> 3D per-group rearrangement
+        # - Weight scales are emitted directly in blocked layout by the 3D kernel
         fwd_input_scale_rearrange_time = self.compute_rearrange_2d_M_groups_time(
             M, K // block_size
         )
-        fwd_weight_scale_rearrange_time = self.compute_rearrange_3d_per_group_time(
-            G, N, K // block_size
-        )
         fwd_gemm_time = self.compute_mxfp8_2d_3d_gemm_time(M, K, N)

         # Backward input: (M, N) @ (G, N, K) -> (M, K) [2D-3D]
         bwd_input_quant_time = self.compute_mxfp8_bwd_input_quant_time(M, K, G, N)
         # Backward input scale rearrangement:
         # - grad_output scales (M, N//32) -> M-groups rearrangement
-        # - Weight scales (G, K, N//32) -> 3D per-group rearrangement (transposed weight)
+        # - Weight scales are emitted directly in blocked layout by the 3D kernel
         bwd_input_grad_scale_rearrange_time = self.compute_rearrange_2d_M_groups_time(
             M, N // block_size
         )
-        bwd_input_weight_scale_rearrange_time = (
-            self.compute_rearrange_3d_per_group_time(G, K, N // block_size)
-        )
         bwd_input_gemm_time = self.compute_mxfp8_2d_3d_gemm_time(M, N, K)

         # Backward weight: (N, M) @ (M, K) -> G separate (N, K) [2D-2D]
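The three stages named in the comments above follow the usual token-grouped MoE GEMM shapes. Below is a minimal, self-contained sketch (hypothetical sizes and group offsets; plain bf16 matmuls stand in for the MXFP8 grouped kernels) that checks the forward (M, K) x (G, K, N)^T, backward-input (M, N) x (G, N, K), and per-group backward-weight (N, M_g) x (M_g, K) bookkeeping:

```python
import torch

# Hypothetical sizes and cumulative group offsets along M (not from the benchmark).
G, M, K, N = 4, 1024, 512, 256
offs = torch.tensor([256, 512, 768, 1024])

x = torch.randn(M, K, dtype=torch.bfloat16)      # activations, grouped along M
w = torch.randn(G, N, K, dtype=torch.bfloat16)   # per-expert weights

group_sizes = torch.diff(offs, prepend=torch.tensor([0])).tolist()
outs, grad_ws = [], []
for g, x_g in enumerate(torch.split(x, group_sizes, dim=0)):
    out_g = x_g @ w[g].t()      # forward:         (M_g, K) @ (K, N) -> (M_g, N)
    # out_g is reused below as a stand-in grad_output, purely for shape checking
    grad_x_g = out_g @ w[g]     # backward input:  (M_g, N) @ (N, K) -> (M_g, K)
    grad_w_g = out_g.t() @ x_g  # backward weight: (N, M_g) @ (M_g, K) -> (N, K)
    outs.append(out_g)
    grad_ws.append(grad_w_g)

assert torch.cat(outs).shape == (M, N)
assert grad_ws[0].shape == (N, K)
```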
@@ -283,11 +276,9 @@ def compute_mxfp8_fwd_bwd_time(self, M, K, N, G):
         total_time = (
             fwd_quant_time
             + fwd_input_scale_rearrange_time
-            + fwd_weight_scale_rearrange_time
             + fwd_gemm_time
             + bwd_input_quant_time
             + bwd_input_grad_scale_rearrange_time
-            + bwd_input_weight_scale_rearrange_time
             + bwd_input_gemm_time
             + bwd_weight_quant_time
             + bwd_weight_grad_scale_rearrange_time
@@ -441,8 +432,6 @@ def benchmark_mxfp8_grouped_mm_fwd_bwd(x, w_t, offs, labels):
     x_clone = x.clone().requires_grad_(True)
     w_t_clone = w_t.clone().requires_grad_(True)

-    fn = torch.compile(_to_mxfp8_then_scaled_grouped_mm, fullgraph=True)
-
     # Set all parameters explicitly as variables for positional args
     A = x_clone
     B_t = w_t_clone
@@ -453,7 +442,7 @@ def benchmark_mxfp8_grouped_mm_fwd_bwd(x, w_t, offs, labels):
     scale_calculation_mode = MoEScaleCalculationMode.RCEIL

     def wrapper():
-        out = fn(
+        out = _to_mxfp8_then_scaled_grouped_mm(
             A,
             B_t,
             offs_arg,
@@ -492,7 +481,7 @@ def benchmark_to_mxfp8_dim1_cuda(tensor, block_size=32):


 def benchmark_mxfp8_quantize_cuda_3d(tensor, block_size=32):
-    """Benchmark mxfp8_quantize_cuda_3d kernel"""
+    """Benchmark the 3D 32x1 quantizer on its input tensor."""
     return benchmark_cuda_function_in_microseconds(
         lambda: mxfp8_quantize_cuda_3d(
             tensor,
@@ -715,7 +704,7 @@ def run(
     # 3. 3D Quantization Kernel Analysis
     # =============================================================================
     print("\n" + "=" * 80)
-    print("3D QUANTIZATION KERNELS (Backward Pass - Weight Quantization)")
+    print("3D QUANTIZATION KERNELS (Direct Transposed-Weight Quantization)")
     print("=" * 80)

     quant_3d_results = []
@@ -741,8 +730,10 @@ def run(

         print(f"\nBenchmarking {desc}...")

-        # Create test tensor
-        tensor = torch.randn(G_val, N_val, K_val, dtype=torch.bfloat16, device="cuda")
+        # Benchmark the direct grouped-GEMM weight contract: w_t has shape
+        # (G, K, N), and the existing 3D 32x1 kernel quantizes it directly.
+        weight = torch.randn(G_val, N_val, K_val, dtype=torch.bfloat16, device="cuda")
+        tensor = weight.transpose(-2, -1)

         # Benchmark mxfp8_quantize_cuda_3d
         cuda_3d_time_us = benchmark_mxfp8_quantize_cuda_3d(tensor)
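A quick sketch (assumed sizes) of what the benchmark now hands to the 3D quantizer: the (G, N, K) weight is passed as a transposed (G, K, N) view, so the kernel consumes the grouped-GEMM weight contract directly without an intermediate copy:

```python
import torch

# Assumed sizes for illustration; the benchmark sweeps its own G/N/K values.
G_val, N_val, K_val = 8, 8192, 5120

weight = torch.randn(G_val, N_val, K_val, dtype=torch.bfloat16)
tensor = weight.transpose(-2, -1)

print(tensor.shape)                            # torch.Size([8, 5120, 8192]) == (G, K, N)
print(tensor.is_contiguous())                  # False: a strided view, not a copy
print(tensor.data_ptr() == weight.data_ptr())  # True: same underlying storage
```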
@@ -1030,21 +1021,25 @@ def run(
10301021 f" BF16 Grouped GEMM: Roofline={ model .bf16_tflops :.1f} TFLOPS, Actual={ bf16_actual_tflops :.1f} TFLOPS, Efficiency={ result_dict ['bf16_tflops_efficiency_pct' ]:.1f} %"
10311022 )
10321023
1033- # Convert to MXFP8 format using triton_to_mxfp8_dim0 (blocks along dim0)
1024+ # Convert activations to MXFP8 format using triton_to_mxfp8_dim0
10341025 x_fp8 , x_scales = triton_to_mxfp8_dim0 (x , inner_block_size = 32 )
1035- w_fp8 , w_scales = triton_to_mxfp8_dim0 (
1036- w_t .transpose (- 2 , - 1 ), inner_block_size = 32
1026+ w_fp8 , w_scales_blocked = mxfp8_quantize_cuda_3d (
1027+ w_t ,
1028+ block_size = 32 ,
1029+ scale_block_n = 32 ,
1030+ scale_block_k = 1 ,
1031+ scaling_mode = "rceil" ,
10371032 )
10381033
1039- # Convert scales to blocked format
1034+ # Convert only activation scales to blocked format. Weight scales are
1035+ # already produced in blocked layout by mxfp8_quantize_cuda_3d.
10401036 x_scales_blocked , _ = torch_to_blocked_2d_M_groups (
10411037 x_scales , offs , block_size = 32
10421038 )
1043- w_scales_blocked = torch_to_blocked_per_group_3d (w_scales )
10441039
10451040 # Benchmark the MXFP8 grouped GEMM kernel
10461041 mxfp8_gemm_time_us = benchmark_mxfp8_grouped_gemm (
1047- x_fp8 , w_fp8 . transpose ( - 2 , - 1 ) , x_scales_blocked , w_scales_blocked , offs
1042+ x_fp8 , w_fp8 , x_scales_blocked , w_scales_blocked , offs
10481043 )
10491044
10501045 # Calculate MXFP8 actual TFLOPS
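For context on the `scaling_mode="rceil"` argument above, here is a hedged per-block sketch of MXFP8 scaling with a round-to-ceiling exponent: one power-of-two scale per 32 values, chosen so the scaled block stays within float8_e4m3 range. This illustrates the general scheme only, not the torchao CUDA kernel's implementation:

```python
import torch

F8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn value

def quantize_block_rceil(block: torch.Tensor):
    """Toy per-block MXFP8 quantization (ignores the amax == 0 edge case)."""
    amax = block.abs().amax()
    # "rceil": round the scale exponent up so no element overflows after scaling
    exp = torch.ceil(torch.log2(amax / F8_E4M3_MAX))
    scale = torch.exp2(exp)  # power-of-two scale, as in the E8M0 shared-exponent format
    return (block / scale).to(torch.float8_e4m3fn), scale

block = torch.randn(32, dtype=torch.float32) * 10
q, scale = quantize_block_rceil(block)
print(scale, q.dtype)  # e.g. tensor(0.1250) torch.float8_e4m3fn
```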
@@ -1068,7 +1063,7 @@ def run(
         grouped_gemm_results.append(result_dict)

         # Clean up tensors to free GPU memory
-        del x, w, w_t, offs, x_fp8, x_scales, w_fp8, w_scales
+        del x, w, w_t, offs, x_fp8, x_scales, w_fp8
         del x_scales_blocked, w_scales_blocked
         torch.cuda.empty_cache()

@@ -1436,7 +1431,8 @@ def run(
     # Input quantization: use triton_to_mxfp8_dim0 for (M, K)
     fwd_input_quant_ms = df_quant_2d.loc[idx_large, "triton_to_mxfp8_dim0_us"] / 1000

-    # Weight quantization: use mxfp8_quantize_cuda_3d for (G, N, K)
+    # Weight quantization: use mxfp8_quantize_cuda_3d directly on w_t, shape
+    # (G, K, N), with no separate 3D scale rearrangement step.
     idx_3d_large = df_quant_3d[df_quant_3d["description"] == f"M={M_large}"].index[0]
     fwd_weight_quant_ms = (
         df_quant_3d.loc[idx_3d_large, "mxfp8_quantize_cuda_3d_us"] / 1000
@@ -1450,14 +1446,8 @@ def run(
         df_rearrange.loc[idx_m_groups, "mx_block_rearrange_2d_M_groups_cuda_us"] / 1000
     )

-    # Weight scale rearrangement: 3D per-group for (G, N, K//32)
-    idx_3d_rearrange = df_rearrange_3d[df_rearrange_3d["M"] == M_large].index[0]
-    fwd_weight_scale_rearrange_ms = (
-        df_rearrange_3d.loc[
-            idx_3d_rearrange, "triton_mx_block_rearrange_per_group_3d_us"
-        ]
-        / 1000
-    )
+    # Weight scales are emitted directly in blocked layout by the 3D quantizer.
+    fwd_weight_scale_rearrange_ms = 0.0

     # GEMM: use actual MXFP8 2D/3D grouped GEMM time
     idx_gemm = df_grouped_gemm[df_grouped_gemm["M"] == M_large].index[0]