Skip to content

Commit 3ad1067

Browse files
authored
add torch.compile test for Float8BlockwiseLinear (#4187)
1 parent f26135c commit 3ad1067

1 file changed

Lines changed: 115 additions & 31 deletions

File tree

test/prototype/blockwise_fp8_training/test_blockwise_linear.py

Lines changed: 115 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88

99
import pytest
1010
import torch
11+
from torch._dynamo.testing import CompileCounterWithBackend
1112

1213
from torchao.utils import is_sm_at_least_90
1314

@@ -22,16 +23,13 @@
2223
torch.random.manual_seed(0)
2324

2425

25-
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
26-
@pytest.mark.parametrize("in_features", [4096])
27-
@pytest.mark.parametrize("out_features", [128256])
28-
@pytest.mark.parametrize("batch_size", [1, 8])
29-
@pytest.mark.parametrize("block_size", [128])
30-
def test_blockwise_quant_linear_fwd_bwd(
26+
def _run_blockwise_quant_linear_fwd_bwd(
3127
in_features,
3228
out_features,
3329
batch_size,
3430
block_size,
31+
*,
32+
compile_mode: bool = False,
3533
):
3634
if in_features % block_size != 0 or out_features % block_size != 0:
3735
pytest.skip(f"Dimensions must be divisible by block_size={block_size}")
@@ -41,33 +39,119 @@ def test_blockwise_quant_linear_fwd_bwd(
4139
out_features=out_features,
4240
bias=False,
4341
).cuda()
44-
4542
layer_test = Float8BlockwiseLinear.from_float(copy.deepcopy(layer_ref))
43+
compiled_frame_counter = None
44+
compiled_step = None
45+
46+
if compile_mode:
47+
with torch._dynamo.config.patch(trace_autograd_ops=True):
48+
torch._dynamo.reset()
49+
compiled_frame_counter = CompileCounterWithBackend("inductor")
50+
51+
def step(x):
52+
y = layer_test(x)
53+
x_grad, weight_grad = torch.autograd.grad(
54+
y.sum(),
55+
(x, layer_test.weight),
56+
)
57+
return y.detach(), x_grad, weight_grad
58+
59+
compiled_step = torch.compile(
60+
step,
61+
backend=compiled_frame_counter,
62+
fullgraph=True,
63+
)
64+
65+
def run_once(x_test, x_ref):
66+
if compile_mode:
67+
assert compiled_step is not None
68+
with torch._dynamo.config.patch(trace_autograd_ops=True):
69+
y_test, x_grad_test, weight_grad_test = compiled_step(x_test)
70+
else:
71+
y_test = layer_test(x_test)
72+
73+
y_ref = layer_ref(x_ref)
74+
75+
sqnr = compute_error(y_ref, y_test)
76+
assert not y_test.isnan().any(), "Output must not contain NaNs"
77+
assert sqnr >= 25.0, f"SQNR: {sqnr.item()} must be >= 25.0"
78+
assert not sqnr.isinf().any(), "SQNR must not be inf"
79+
80+
if compile_mode:
81+
x_grad_ref, weight_grad_ref = torch.autograd.grad(
82+
y_ref.sum(),
83+
(x_ref, layer_ref.weight),
84+
)
85+
else:
86+
y_test.sum().backward()
87+
y_ref.sum().backward()
88+
x_grad_test = x_test.grad
89+
weight_grad_test = layer_test.weight.grad
90+
x_grad_ref = x_ref.grad
91+
weight_grad_ref = layer_ref.weight.grad
92+
93+
sqnr = compute_error(x_grad_ref, x_grad_test)
94+
assert not x_grad_test.isnan().any(), "Input grad must not contain NaNs"
95+
assert sqnr >= 30.0, f"SQNR: {sqnr} must be >= 30.0"
96+
97+
sqnr = compute_error(weight_grad_ref, weight_grad_test)
98+
assert not weight_grad_test.isnan().any(), "Weight grad must not contain NaNs"
99+
assert sqnr >= 30.0, f"SQNR: {sqnr} must be >= 30.0"
46100

47-
# Create input tensor
48101
x_test = torch.randn(batch_size, 256, in_features).cuda().requires_grad_(True)
49102
x_ref = x_test.clone().detach().requires_grad_(True)
103+
run_once(x_test, x_ref)
104+
105+
if compile_mode:
106+
assert compiled_frame_counter is not None
107+
assert compiled_frame_counter.frame_count == 1, (
108+
"Compiled blockwise linear should run in a single frame"
109+
)
110+
111+
x_test = torch.randn(batch_size, 256, in_features).cuda().requires_grad_(True)
112+
x_ref = x_test.clone().detach().requires_grad_(True)
113+
run_once(x_test, x_ref)
114+
115+
assert compiled_frame_counter.frame_count == 1, (
116+
"Compiled blockwise linear should not recompile for repeated calls "
117+
"with the same shapes"
118+
)
119+
120+
121+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
122+
@pytest.mark.parametrize("in_features", [4096])
123+
@pytest.mark.parametrize("out_features", [128256])
124+
@pytest.mark.parametrize("batch_size", [1, 8])
125+
@pytest.mark.parametrize("block_size", [128])
126+
def test_blockwise_quant_linear_fwd_bwd(
127+
in_features,
128+
out_features,
129+
batch_size,
130+
block_size,
131+
):
132+
_run_blockwise_quant_linear_fwd_bwd(
133+
in_features,
134+
out_features,
135+
batch_size,
136+
block_size,
137+
)
138+
50139

51-
# Forward pass
52-
y_test = layer_test(x_test)
53-
y_ref = layer_ref(x_ref)
54-
55-
# Compare outputs
56-
sqnr = compute_error(y_ref, y_test)
57-
assert not y_test.isnan().any(), "Output must not contain NaNs"
58-
assert sqnr >= 25.0, f"SQNR: {sqnr.item()} must be >= 25.0"
59-
assert not sqnr.isinf().any(), "SQNR must not be inf"
60-
61-
# Backward pass
62-
y_test.sum().backward()
63-
y_ref.sum().backward()
64-
65-
# Compare input grads
66-
sqnr = compute_error(x_ref.grad, x_test.grad)
67-
assert not x_test.grad.isnan().any(), "Input grad must not contain NaNs"
68-
assert sqnr >= 30.0, f"SQNR: {sqnr} must be >= 25.0"
69-
70-
# Compare weight grads
71-
sqnr = compute_error(layer_ref.weight, layer_test.weight)
72-
assert not layer_test.weight.grad.isnan().any(), "Weight grad must not contain NaNs"
73-
assert sqnr >= 30.0, f"SQNR: {sqnr} must be >= 25.0"
140+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
141+
@pytest.mark.parametrize("in_features", [4096])
142+
@pytest.mark.parametrize("out_features", [4096])
143+
@pytest.mark.parametrize("batch_size", [1])
144+
@pytest.mark.parametrize("block_size", [128])
145+
def test_blockwise_quant_linear_compile_fullgraph_fwd_bwd(
146+
in_features,
147+
out_features,
148+
batch_size,
149+
block_size,
150+
):
151+
_run_blockwise_quant_linear_fwd_bwd(
152+
in_features,
153+
out_features,
154+
batch_size,
155+
block_size,
156+
compile_mode=True,
157+
)

0 commit comments

Comments (0)