@@ -22,6 +22,7 @@ def _to_fp8_rowwise_then_scaled_grouped_mm(
     B_t: torch.Tensor,
     offs: torch.Tensor,
     out_dtype: Optional[torch.dtype] = torch.bfloat16,
+    float8_dtype: torch.dtype = torch.float8_e4m3fn,
 ) -> torch.Tensor:
     """
     Differentiable FP8 grouped matrix multiplication with dynamic FP8 rowwise quantization.
@@ -48,7 +49,7 @@ def _to_fp8_rowwise_then_scaled_grouped_mm(
     - Scales are computed per-row and rounded to powers of 2 for efficiency
     - This function is fully differentiable via custom autograd implementation
     """
-    return _Float8GroupedMM.apply(A, B_t, offs, out_dtype)
+    return _Float8GroupedMM.apply(A, B_t, offs, out_dtype, float8_dtype)


 class _Float8GroupedMM(torch.autograd.Function):
@@ -61,6 +62,7 @@ def forward(
         B_t: torch.Tensor,
         offs: Optional[torch.Tensor] = None,
         out_dtype: Optional[torch.dtype] = torch.bfloat16,
+        float8_dtype: torch.dtype = torch.float8_e4m3fn,
     ) -> torch.Tensor:
         # torchao _quantize_then_scaled_grouped_mm only supports A=2D|3D and B=3D.
         assert A.ndim == 2 or A.ndim == 3, "A must be 2D or 3D"
@@ -100,31 +102,32 @@ def forward(
         # A_scales shape: (M,1) or (B, M, 1)
         A_scales = tensor_to_scale(
             A,
-            torch.float8_e4m3fn,
+            float8_dtype,
             scaling_granularity=ScalingGranularity.AXISWISE,
             axiswise_dim=-1,
             round_scales_to_power_of_2=True,
         )
         A_scaled = A.to(torch.float32) * A_scales
-        A_data_row_major = to_fp8_saturated(A_scaled, torch.float8_e4m3fn)
+        A_data_row_major = to_fp8_saturated(A_scaled, float8_dtype)

         # Convert B to float8, column-major for right operand of grouped GEMM.
         # B_t shape: (E, K, N)
         # B_t scales must be computed rowwise keeping the outer/final dim, so:
         # B_t_scales shape: (E, 1, N)
         B_t_scales = tensor_to_scale(
             B_t,
-            torch.float8_e4m3fn,
+            float8_dtype,
             scaling_granularity=ScalingGranularity.AXISWISE,
             axiswise_dim=-2,
             round_scales_to_power_of_2=True,
         )
         B_t_scaled = B_t.to(torch.float32) * B_t_scales
-        B_t_data_col_major = to_fp8_saturated(B_t_scaled, torch.float8_e4m3fn)
+        B_t_data_col_major = to_fp8_saturated(B_t_scaled, float8_dtype)

         # Store what we need for backward.
         ctx.save_for_backward(A, B_t, offs)
         ctx.out_dtype = out_dtype
+        ctx.float8_dtype = float8_dtype

         # Perform scaled grouped GEMM and return result.
         # output shape: scaled grouped mm of (M,K) @ (B,K,N) = (M,N)
@@ -154,6 +157,7 @@ def forward(
     def backward(ctx, grad_output: torch.Tensor):
         A, B_t, offs = ctx.saved_tensors
         out_dtype = ctx.out_dtype
+        float8_dtype = ctx.float8_dtype

         # Convert grad_output to float8, row-major for left operand of grouped GEMM
         # needed for grad_A: grad_output @ B
@@ -162,21 +166,19 @@ def backward(ctx, grad_output: torch.Tensor):
         # grad_output_scale shape: (Mg, 1)
         grad_output_scales = tensor_to_scale(
             grad_output,
-            torch.float8_e4m3fn,
+            float8_dtype,
             scaling_granularity=ScalingGranularity.AXISWISE,
             axiswise_dim=-1,
             round_scales_to_power_of_2=True,
         )
         grad_output_scaled = grad_output.to(torch.float32) * grad_output_scales
-        grad_output_data_row_major = to_fp8_saturated(
-            grad_output_scaled, torch.float8_e4m3fn
-        )
+        grad_output_data_row_major = to_fp8_saturated(grad_output_scaled, float8_dtype)

         # Compute B fp8 column-major for right operand of grouped GEMM:
         # grad_A = grad_output @ B.
         B_data_col_major, B_scales = triton_fp8_rowwise_3d_transpose_rhs(
             B_t._data if hasattr(B_t, "_data") else B_t,
-            output_dtype=torch.float8_e4m3fn,
+            output_dtype=float8_dtype,
             round_scales_to_power_of_2=True,
         )

@@ -216,7 +218,7 @@ def backward(ctx, grad_output: torch.Tensor):
             .contiguous()
             .t(),  # Quantization is over 2x faster when input is col major, even with this transformation
             offs,
-            torch.float8_e4m3fn,
+            float8_dtype,
             round_scales_to_power_of_2=True,
         )
         grad_output_t_data_row_major = grad_out_data_colwise.t()
@@ -227,7 +229,7 @@ def backward(ctx, grad_output: torch.Tensor):
             .contiguous()
             .t(),  # Quantization is over 2x faster when input is col major, even with this transformation
             offs,
-            torch.float8_e4m3fn,
+            float8_dtype,
             round_scales_to_power_of_2=True,
         )

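Not part of the diff: a minimal usage sketch of the updated entry point with the new `float8_dtype` argument. The import path, tensor sizes, and `offs` layout are assumptions based on the docstring and shape comments above (A is token-major `(Mg, K)`, B_t is per-expert `(E, K, N)`, offs holds cumulative group end offsets); running it requires an FP8-capable GPU.

```python
import torch

# Hypothetical import path; adjust to wherever this function lives in your
# torchao checkout.
from torchao.prototype.moe_training.scaled_grouped_mm import (
    _to_fp8_rowwise_then_scaled_grouped_mm,
)

# Activations for all experts concatenated along dim 0: (Mg, K) = (256, 128).
A = torch.randn(256, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
# Per-expert weights, pre-transposed as the right operand: (E, K, N) = (4, 128, 64).
B_t = torch.randn(4, 128, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
# Cumulative end offsets of each expert's rows in A (4 groups of 64 tokens each).
offs = torch.tensor([64, 128, 192, 256], device="cuda", dtype=torch.int32)

# float8_dtype selects the FP8 format used for rowwise quantization; it defaults
# to torch.float8_e4m3fn, matching the previously hard-coded behavior.
out = _to_fp8_rowwise_then_scaled_grouped_mm(
    A, B_t, offs, out_dtype=torch.bfloat16, float8_dtype=torch.float8_e4m3fn
)
out.sum().backward()  # differentiable via the custom autograd.Function
```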