
Commit 721d69c

[fp8 training] support linear op overrides in Float8TrainingWeightWrapperTensor (#4325)
stack-info: PR: #4325, branch: danielvegamyhre/stack/165
1 parent a9f24af commit 721d69c

4 files changed

Lines changed: 169 additions & 3 deletions
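In short: once a weight is wrapped in Float8TrainingWeightWrapperTensor with this config, calls to linear / mm / matmul / addmm against it are intercepted and routed through a dynamically quantized fp8 scaled matmul, returning a plain bf16 tensor. A minimal usage sketch distilled from the new test below (the shapes and the choice of the "rowwise" recipe are illustrative, not prescribed by the commit):

import torch
import torch.nn.functional as F

from torchao.prototype.moe_training.config import Float8TrainingOpConfig
from torchao.prototype.moe_training.tensor import Float8TrainingWeightWrapperTensor

# Wrap a bf16 weight; matmuls against it now take the fp8 path.
config = Float8TrainingOpConfig(float8_linear_recipe="rowwise")
weight = torch.randn(
    2048, 1024, dtype=torch.bfloat16, device="cuda", requires_grad=True
)
weight_fp8 = Float8TrainingWeightWrapperTensor(weight, config)

x = torch.randn(16, 1024, dtype=torch.bfloat16, device="cuda", requires_grad=True)
out = F.linear(x, weight_fp8)  # intercepted via __torch_function__; plain bf16 result
out.sum().backward()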


test/prototype/moe_training/test_tensor.py

Lines changed: 106 additions & 1 deletion
@@ -15,13 +15,17 @@
     pytest.skip("CUDA and PyTorch 2.7.0+ required", allow_module_level=True)
 
 from torchao.prototype.moe_training.config import (
+    Float8TrainingOpConfig,
     MXFP8TrainingOpConfig,
     MXFP8TrainingRecipe,
 )
 from torchao.prototype.moe_training.kernels.mxfp8.quant import (
     _mxfp8_cutedsl_kernels_available,
 )
-from torchao.prototype.moe_training.tensor import MXFP8TrainingWeightWrapperTensor
+from torchao.prototype.moe_training.tensor import (
+    Float8TrainingWeightWrapperTensor,
+    MXFP8TrainingWeightWrapperTensor,
+)
 from torchao.prototype.mx_formats.config import (
     MXFP8Dim1CastKernelChoice,
 )
@@ -183,3 +187,104 @@ def test_mxfp8_training_tensor_ops_preserve_subclass():
     assert isinstance(result, MXFP8TrainingWeightWrapperTensor), (
         "slice should preserve subclass"
     )
+
+
+@pytest.mark.parametrize("op_name", ["mm", "matmul", "linear"])
+@pytest.mark.parametrize("batch_size", [None, 2])
+@pytest.mark.parametrize(
+    "float8_linear_recipe", ["tensorwise", "rowwise", "rowwise_with_gw_hp"]
+)
+def test_float8_training_tensor_ops_fwd_bwd(op_name, batch_size, float8_linear_recipe):
+    # mm doesn't support batching
+    if op_name == "mm" and batch_size is not None:
+        pytest.skip("mm doesn't support batching")
+
+    # All FP8 linear recipes require SM89+ (torch._scaled_mm)
+    if torch.cuda.get_device_capability() < (8, 9):
+        pytest.skip("FP8 linear requires SM89+")
+
+    # rowwise and rowwise_with_gw_hp require SM90+ (CUTLASS axiswise kernels)
+    if float8_linear_recipe in (
+        "rowwise",
+        "rowwise_with_gw_hp",
+    ) and torch.cuda.get_device_capability() < (9, 0):
+        pytest.skip("Rowwise FP8 requires SM90+")
+
+    config = Float8TrainingOpConfig(float8_linear_recipe=float8_linear_recipe)
+
+    M, K, N = 1024, 1024, 2048
+    if batch_size is None:
+        A_shape = (M, K)
+    else:
+        A_shape = (batch_size, M, K)
+
+    A = torch.randn(*A_shape, dtype=torch.bfloat16, device="cuda", requires_grad=True)
+    B = torch.randn(N, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
+    bias = (
+        torch.randn(N, dtype=torch.bfloat16, device="cuda")
+        if op_name == "linear"
+        else None
+    )
+
+    # Reference computation with bf16
+    A_ref = A.clone().detach().requires_grad_(True)
+    B_ref = B.clone().detach().requires_grad_(True)
+
+    if op_name == "mm":
+        result_ref = torch.mm(A_ref, B_ref.t())
+    elif op_name == "matmul":
+        result_ref = torch.matmul(A_ref, B_ref.t())
+    elif op_name == "linear":
+        result_ref = F.linear(A_ref, B_ref, bias)
+
+    # FP8 computation
+    B_fp8 = Float8TrainingWeightWrapperTensor(B, config)
+
+    if op_name == "mm":
+        result_fp8 = torch.mm(A, B_fp8)
+    elif op_name == "matmul":
+        result_fp8 = torch.matmul(A, B_fp8)
+    elif op_name == "linear":
+        result_fp8 = F.linear(A, B_fp8, bias)
+
+    # Validate forward pass
+    assert result_fp8.shape == result_ref.shape, "Shape mismatch"
+    assert result_fp8.dtype == torch.bfloat16, "Dtype should be bfloat16"
+    assert not isinstance(result_fp8, Float8TrainingWeightWrapperTensor), (
+        "Result should be unwrapped"
+    )
+
+    # Check forward SQNR
+    sqnr_fwd = compute_error(result_ref, result_fp8)
+    min_sqnr_fwd = 25.0
+    assert sqnr_fwd >= min_sqnr_fwd, (
+        f"Forward SQNR {sqnr_fwd} is too low, must be >= {min_sqnr_fwd}"
+    )
+
+    # Backward pass
+    labels_ref = torch.ones_like(result_ref)
+    labels_fp8 = torch.ones_like(result_fp8)
+    loss_ref = F.mse_loss(result_ref, labels_ref)
+    loss_fp8 = F.mse_loss(result_fp8, labels_fp8)
+    loss_ref.backward()
+    loss_fp8.backward()
+
+    # Verify gradients exist
+    assert A.grad is not None, "A.grad should be computed"
+    assert A_ref.grad is not None, "A_ref.grad should be computed"
+    assert B_fp8.grad is not None, "B_fp8.grad should be computed"
+    assert B_ref.grad is not None, "B_ref.grad should be computed"
+
+    # Check input gradient SQNR
+    sqnr_input_grad = compute_error(A_ref.grad, A.grad)
+    min_sqnr_input_grad = 24.0
+    assert sqnr_input_grad >= min_sqnr_input_grad, (
+        f"Input grad SQNR {sqnr_input_grad} is too low, must be >= {min_sqnr_input_grad}"
+    )
+
+    # Check weight gradient SQNR
+    sqnr_weight_grad = compute_error(B_ref.grad, B_fp8.grad)
+    min_sqnr_weight_grad = 23.0
+    assert sqnr_weight_grad >= min_sqnr_weight_grad, (
+        f"Weight grad SQNR {sqnr_weight_grad} is too low, must be >= {min_sqnr_weight_grad}"
+    )
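The SQNR thresholds above (25 dB forward, 24 dB input grad, 23 dB weight grad) gate how closely the fp8 path tracks the bf16 reference. For orientation, a hedged sketch of what a compute_error-style SQNR helper measures, assuming the standard signal-to-quantization-noise definition (torchao's actual helper may differ in detail):

import torch

def sqnr_db(ref: torch.Tensor, test: torch.Tensor) -> torch.Tensor:
    # 20 * log10(||ref|| / ||ref - test||), in dB. At ~25 dB the error
    # norm is roughly 5-6% of the signal norm; higher means closer.
    return 20 * torch.log10(ref.norm() / (ref - test).norm())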

torchao/prototype/moe_training/config.py

Lines changed: 40 additions & 1 deletion
@@ -6,12 +6,14 @@
 
 from dataclasses import dataclass
 from enum import Enum
-from typing import Optional
+from typing import Literal, Optional
 
 import torch
 from torch import nn
 
 from torchao.core.config import AOBaseConfig
+from torchao.float8.config import Float8LinearConfig
+from torchao.float8.float8_training_tensor import LinearMMConfig, ScaledMMConfig
 from torchao.prototype.mx_formats.config import ScaleCalculationMode
 from torchao.quantization.quantize_.common import KernelPreference
 from torchao.quantization.transform_module import register_quantize_module_handler
@@ -64,6 +66,41 @@ class Float8TrainingOpConfig(TrainingOpBaseConfig):
     # causes a D2H sync that breaks torch.compile.
     pad_token_groups_for_grouped_mm: bool = False
 
+    # Recipe for the float8 linear op override ("tensorwise", "rowwise", or "rowwise_with_gw_hp").
+    float8_linear_recipe: Literal["tensorwise", "rowwise", "rowwise_with_gw_hp"] = (
+        "rowwise"
+    )
+
+    def __post_init__(self):
+        # Pre-build internal configs for the linear op override.
+        self._float8_linear_config = Float8LinearConfig.from_recipe_name(
+            self.float8_linear_recipe
+        )
+        c = self._float8_linear_config
+        self._linear_mm_config = LinearMMConfig(
+            # output
+            ScaledMMConfig(
+                c.emulate,
+                c.gemm_config_output.use_fast_accum,
+                False,
+                c.pad_inner_dim,
+            ),
+            # grad_input
+            ScaledMMConfig(
+                c.emulate,
+                c.gemm_config_grad_input.use_fast_accum,
+                False,
+                c.pad_inner_dim,
+            ),
+            # grad_weight
+            ScaledMMConfig(
+                c.emulate,
+                c.gemm_config_grad_weight.use_fast_accum,
+                False,
+                c.pad_inner_dim,
+            ),
+        )
+
     @classmethod
     def from_recipe(
         cls,
@@ -82,6 +119,7 @@ def __eq__(self, other):
             and self.out_dtype == other.out_dtype
             and self.pad_token_groups_for_grouped_mm
             == other.pad_token_groups_for_grouped_mm
+            and self.float8_linear_recipe == other.float8_linear_recipe
         )
     return NotImplemented
 
@@ -91,6 +129,7 @@ def __hash__(self):
             self.float8_dtype,
             self.out_dtype,
             self.pad_token_groups_for_grouped_mm,
+            self.float8_linear_recipe,
         )
     )
 
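Because __post_init__ runs at construction time, both derived configs exist before the first matmul, so the op override does no per-call setup. A quick illustration using the private attributes added above (a sketch only; these are internal names):

from torchao.float8.config import Float8LinearConfig
from torchao.float8.float8_training_tensor import LinearMMConfig
from torchao.prototype.moe_training.config import Float8TrainingOpConfig

config = Float8TrainingOpConfig(float8_linear_recipe="rowwise_with_gw_hp")
# Already materialized by __post_init__:
assert isinstance(config._float8_linear_config, Float8LinearConfig)
assert isinstance(config._linear_mm_config, LinearMMConfig)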

torchao/prototype/moe_training/tensor.py

Lines changed: 14 additions & 1 deletion
@@ -21,6 +21,7 @@
 )
 from torchao.prototype.moe_training.utils import (
     _quantize_then_scaled_grouped_mm,
+    _to_fp8_then_scaled_mm,
     unwrap_weight,
 )
 from torchao.prototype.mx_formats.mx_linear import _to_mxfp8_then_scaled_mm
@@ -250,7 +251,19 @@ def __torch_function__(cls, func, types, args, kwargs={}):
                 config=config,
             )
 
-        # TOOD: linear op override
+        # linear op override
+        elif func.__name__ in ("linear", "mm", "matmul", "addmm"):
+            A, B = args[0], args[1]
+            assert not isinstance(A, cls), f"A should not be a {cls.__name__}"
+            assert isinstance(B, cls), f"B should be a {cls.__name__}"
+            config = B.config
+            result = _to_fp8_then_scaled_mm(A, unwrap_weight(B), config)
+            # Handle bias for F.linear(input, weight, bias) calls
+            bias = args[2] if len(args) > 2 else kwargs.get("bias", None)
+            if bias is not None:
+                result = result + bias.to(result.dtype)
+            return result
+
         else:
             # Disable torch_function by hand because we don't want
             # the wrapping behavior of the super() impl, go directly to dispatch
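The new branch is an instance of the standard __torch_function__ interception pattern: the subclass sees linear/mm/matmul/addmm before dispatch, substitutes the fp8 computation, and returns a plain tensor so results do not stay wrapped. A minimal standalone sketch of that pattern with a toy subclass (not the torchao implementation):

import torch
import torch.nn.functional as F

class UnwrapOnLinear(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.linear:
            # Substitute custom behavior and return a plain tensor.
            weight = args[1].as_subclass(torch.Tensor)
            return F.linear(args[0], weight, *args[2:], **kwargs)
        # Fall through to default behavior for everything else.
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

w = torch.randn(4, 3).as_subclass(UnwrapOnLinear)
out = F.linear(torch.randn(2, 3), w)
assert not isinstance(out, UnwrapOnLinear)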

torchao/prototype/moe_training/utils.py

Lines changed: 9 additions & 0 deletions
@@ -512,3 +512,12 @@ def backward(ctx, grad_output):
 
 def unwrap_weight(wrapper_tensor):
     return _UnwrapWeight.apply(wrapper_tensor)
+
+
+def _to_fp8_then_scaled_mm(input, weight, config):
+    """Helper to perform FP8 linear via matmul_with_hp_or_float8_args."""
+    from torchao.float8.float8_linear import matmul_with_hp_or_float8_args
+
+    return matmul_with_hp_or_float8_args.apply(
+        input, weight.t(), config._linear_mm_config, config._float8_linear_config
+    )
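The weight.t() here follows the linear convention: with W stored as (out_features, in_features), F.linear(x, W) computes x @ W.t(), which lines up with the transposed weight being handed to the autograd function. A quick sanity check of that identity in plain bf16:

import torch
import torch.nn.functional as F

x = torch.randn(8, 16, dtype=torch.bfloat16)
w = torch.randn(32, 16, dtype=torch.bfloat16)  # (out_features, in_features)
torch.testing.assert_close(F.linear(x, w), x @ w.t())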
