|
15 | 15 | pytest.skip("CUDA and PyTorch 2.7.0+ required", allow_module_level=True) |
16 | 16 |
|
17 | 17 | from torchao.prototype.moe_training.config import ( |
| 18 | + Float8TrainingOpConfig, |
18 | 19 | MXFP8TrainingOpConfig, |
19 | 20 | MXFP8TrainingRecipe, |
20 | 21 | ) |
21 | 22 | from torchao.prototype.moe_training.kernels.mxfp8.quant import ( |
22 | 23 | _mxfp8_cutedsl_kernels_available, |
23 | 24 | ) |
24 | | -from torchao.prototype.moe_training.tensor import MXFP8TrainingWeightWrapperTensor |
| 25 | +from torchao.prototype.moe_training.tensor import ( |
| 26 | + Float8TrainingWeightWrapperTensor, |
| 27 | + MXFP8TrainingWeightWrapperTensor, |
| 28 | +) |
25 | 29 | from torchao.prototype.mx_formats.config import ( |
26 | 30 | MXFP8Dim1CastKernelChoice, |
27 | 31 | ) |
@@ -183,3 +187,104 @@ def test_mxfp8_training_tensor_ops_preserve_subclass(): |
183 | 187 | assert isinstance(result, MXFP8TrainingWeightWrapperTensor), ( |
184 | 188 | "slice should preserve subclass" |
185 | 189 | ) |
| 190 | + |
| 191 | + |
@pytest.mark.parametrize("op_name", ["mm", "matmul", "linear"])
@pytest.mark.parametrize("batch_size", [None, 2])
@pytest.mark.parametrize(
    "float8_linear_recipe", ["tensorwise", "rowwise", "rowwise_with_gw_hp"]
)
def test_float8_training_tensor_ops_fwd_bwd(op_name, batch_size, float8_linear_recipe):
    """Forward + backward numerics test for Float8TrainingWeightWrapperTensor.

    For each (op, batch shape, FP8 recipe) combination, runs the op once with
    plain bf16 tensors (reference) and once with the weight wrapped in
    Float8TrainingWeightWrapperTensor, then checks that:

    * the FP8 result matches the reference shape, is bf16, and is a plain
      tensor (the wrapper must not leak into op outputs);
    * forward output, input gradient, and weight gradient each clear a fixed
      SQNR threshold versus the bf16 reference.
    """
    # mm doesn't support batching
    if op_name == "mm" and batch_size is not None:
        pytest.skip("mm doesn't support batching")

    # All FP8 linear recipes require SM89+ (torch._scaled_mm)
    if torch.cuda.get_device_capability() < (8, 9):
        pytest.skip("FP8 linear requires SM89+")

    # rowwise and rowwise_with_gw_hp require SM90+ (CUTLASS axiswise kernels)
    if float8_linear_recipe in (
        "rowwise",
        "rowwise_with_gw_hp",
    ) and torch.cuda.get_device_capability() < (9, 0):
        pytest.skip("Rowwise FP8 requires SM90+")

    config = Float8TrainingOpConfig(float8_linear_recipe=float8_linear_recipe)

    # Weight B is stored (N, K), linear-style; A is (M, K) or (batch, M, K).
    M, K, N = 1024, 1024, 2048
    if batch_size is None:
        A_shape = (M, K)
    else:
        A_shape = (batch_size, M, K)

    A = torch.randn(*A_shape, dtype=torch.bfloat16, device="cuda", requires_grad=True)
    B = torch.randn(N, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
    # Bias only applies to F.linear; mm/matmul take no bias.
    bias = (
        torch.randn(N, dtype=torch.bfloat16, device="cuda")
        if op_name == "linear"
        else None
    )

    # Reference computation with bf16
    # clone().detach().requires_grad_(True) gives independent leaves so the
    # reference backward does not accumulate into A.grad / B.grad.
    A_ref = A.clone().detach().requires_grad_(True)
    B_ref = B.clone().detach().requires_grad_(True)

    if op_name == "mm":
        result_ref = torch.mm(A_ref, B_ref.t())
    elif op_name == "matmul":
        result_ref = torch.matmul(A_ref, B_ref.t())
    elif op_name == "linear":
        result_ref = F.linear(A_ref, B_ref, bias)

    # FP8 computation
    B_fp8 = Float8TrainingWeightWrapperTensor(B, config)

    # NOTE(review): unlike the reference path, mm/matmul here receive B_fp8
    # WITHOUT an explicit .t() — presumably the wrapper's dispatch handles the
    # (N, K) weight orientation itself; confirm against the wrapper's op
    # overrides, since a plain (M, K) @ (N, K) mm would be shape-invalid.
    if op_name == "mm":
        result_fp8 = torch.mm(A, B_fp8)
    elif op_name == "matmul":
        result_fp8 = torch.matmul(A, B_fp8)
    elif op_name == "linear":
        result_fp8 = F.linear(A, B_fp8, bias)

    # Validate forward pass
    assert result_fp8.shape == result_ref.shape, "Shape mismatch"
    assert result_fp8.dtype == torch.bfloat16, "Dtype should be bfloat16"
    assert not isinstance(result_fp8, Float8TrainingWeightWrapperTensor), (
        "Result should be unwrapped"
    )

    # Check forward SQNR
    sqnr_fwd = compute_error(result_ref, result_fp8)
    min_sqnr_fwd = 25.0
    assert sqnr_fwd >= min_sqnr_fwd, (
        f"Forward SQNR {sqnr_fwd} is too low, must be >= {min_sqnr_fwd}"
    )

    # Backward pass
    # MSE against an all-ones target just provides a scalar loss to drive
    # autograd; the specific target value is irrelevant to the SQNR checks.
    labels_ref = torch.ones_like(result_ref)
    labels_fp8 = torch.ones_like(result_fp8)
    loss_ref = F.mse_loss(result_ref, labels_ref)
    loss_fp8 = F.mse_loss(result_fp8, labels_fp8)
    loss_ref.backward()
    loss_fp8.backward()

    # Verify gradients exist
    # NOTE(review): asserting on B_fp8.grad (not B.grad) assumes the wrapper
    # itself is the autograd leaf that accumulates the weight gradient —
    # verify against the wrapper's autograd integration.
    assert A.grad is not None, "A.grad should be computed"
    assert A_ref.grad is not None, "A_ref.grad should be computed"
    assert B_fp8.grad is not None, "B_fp8.grad should be computed"
    assert B_ref.grad is not None, "B_ref.grad should be computed"

    # Check input gradient SQNR
    sqnr_input_grad = compute_error(A_ref.grad, A.grad)
    min_sqnr_input_grad = 24.0
    assert sqnr_input_grad >= min_sqnr_input_grad, (
        f"Input grad SQNR {sqnr_input_grad} is too low, must be >= {min_sqnr_input_grad}"
    )

    # Check weight gradient SQNR
    sqnr_weight_grad = compute_error(B_ref.grad, B_fp8.grad)
    min_sqnr_weight_grad = 23.0
    assert sqnr_weight_grad >= min_sqnr_weight_grad, (
        f"Weight grad SQNR {sqnr_weight_grad} is too low, must be >= {min_sqnr_weight_grad}"
    )
0 commit comments