import pytest
import torch
from torch._dynamo.testing import CompileCounterWithBackend

from torchao.utils import is_sm_at_least_90

# Deterministic RNG so the SQNR thresholds asserted below are stable run-to-run.
torch.random.manual_seed(0)
def _run_blockwise_quant_linear_fwd_bwd(
    in_features,
    out_features,
    batch_size,
    block_size,
    *,
    compile_mode: bool = False,
):
    """Shared forward/backward numerics check for ``Float8BlockwiseLinear``.

    Builds a bf16/fp32 reference ``nn.Linear`` and a blockwise-quantized copy,
    runs a forward and backward pass through both, and asserts the quantized
    outputs / input grads / weight grads are NaN-free and meet minimum SQNR
    thresholds versus the reference.

    Args:
        in_features: input width; must be divisible by ``block_size``.
        out_features: output width; must be divisible by ``block_size``.
        batch_size: leading batch dimension of the test input.
        block_size: quantization block size.
        compile_mode: when True, the quantized fwd+bwd step is run through
            ``torch.compile(fullgraph=True)`` with a frame-counting backend,
            and we additionally assert exactly one frame is compiled and that
            a second call with identical shapes does not recompile.
    """
    if in_features % block_size != 0 or out_features % block_size != 0:
        pytest.skip(f"Dimensions must be divisible by block_size={block_size}")

    layer_ref = torch.nn.Linear(
        in_features=in_features,
        out_features=out_features,
        bias=False,
    ).cuda()
    layer_test = Float8BlockwiseLinear.from_float(copy.deepcopy(layer_ref))

    compiled_frame_counter = None
    compiled_step = None

    if compile_mode:
        # NOTE(review): trace_autograd_ops is patched both here (at compile
        # setup) and around each compiled call in run_once — confirm both
        # patches are actually required.
        with torch._dynamo.config.patch(trace_autograd_ops=True):
            torch._dynamo.reset()
            compiled_frame_counter = CompileCounterWithBackend("inductor")

            def step(x):
                # Functional fwd+bwd in one graph: use torch.autograd.grad
                # instead of .backward() so the step has no side effects on
                # .grad attributes and can be captured fullgraph.
                y = layer_test(x)
                x_grad, weight_grad = torch.autograd.grad(
                    y.sum(),
                    (x, layer_test.weight),
                )
                return y.detach(), x_grad, weight_grad

            compiled_step = torch.compile(
                step,
                backend=compiled_frame_counter,
                fullgraph=True,
            )

    def run_once(x_test, x_ref):
        # One forward/backward pass; asserts output and gradient quality
        # against the unquantized reference layer.
        if compile_mode:
            assert compiled_step is not None
            with torch._dynamo.config.patch(trace_autograd_ops=True):
                y_test, x_grad_test, weight_grad_test = compiled_step(x_test)
        else:
            y_test = layer_test(x_test)

        y_ref = layer_ref(x_ref)

        sqnr = compute_error(y_ref, y_test)
        assert not y_test.isnan().any(), "Output must not contain NaNs"
        assert sqnr >= 25.0, f"SQNR: {sqnr.item()} must be >= 25.0"
        assert not sqnr.isinf().any(), "SQNR must not be inf"

        if compile_mode:
            # Test grads already came out of compiled_step; compute the
            # reference grads eagerly here.
            x_grad_ref, weight_grad_ref = torch.autograd.grad(
                y_ref.sum(),
                (x_ref, layer_ref.weight),
            )
        else:
            y_test.sum().backward()
            y_ref.sum().backward()
            x_grad_test = x_test.grad
            weight_grad_test = layer_test.weight.grad
            x_grad_ref = x_ref.grad
            weight_grad_ref = layer_ref.weight.grad

        sqnr = compute_error(x_grad_ref, x_grad_test)
        assert not x_grad_test.isnan().any(), "Input grad must not contain NaNs"
        assert sqnr >= 30.0, f"SQNR: {sqnr} must be >= 30.0"

        sqnr = compute_error(weight_grad_ref, weight_grad_test)
        assert not weight_grad_test.isnan().any(), "Weight grad must not contain NaNs"
        assert sqnr >= 30.0, f"SQNR: {sqnr} must be >= 30.0"

    x_test = torch.randn(batch_size, 256, in_features).cuda().requires_grad_(True)
    x_ref = x_test.clone().detach().requires_grad_(True)
    run_once(x_test, x_ref)

    if compile_mode:
        assert compiled_frame_counter is not None
        assert compiled_frame_counter.frame_count == 1, (
            "Compiled blockwise linear should run in a single frame"
        )

        # A second call with identical shapes must hit the compile cache.
        x_test = torch.randn(batch_size, 256, in_features).cuda().requires_grad_(True)
        x_ref = x_test.clone().detach().requires_grad_(True)
        run_once(x_test, x_ref)

        assert compiled_frame_counter.frame_count == 1, (
            "Compiled blockwise linear should not recompile for repeated calls "
            "with the same shapes"
        )
120+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("in_features", [4096])
@pytest.mark.parametrize("out_features", [128256])
@pytest.mark.parametrize("batch_size", [1, 8])
@pytest.mark.parametrize("block_size", [128])
def test_blockwise_quant_linear_fwd_bwd(
    in_features,
    out_features,
    batch_size,
    block_size,
):
    """Eager-mode forward/backward numerics for Float8BlockwiseLinear."""
    _run_blockwise_quant_linear_fwd_bwd(
        in_features,
        out_features,
        batch_size,
        block_size,
    )
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("in_features", [4096])
@pytest.mark.parametrize("out_features", [4096])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("block_size", [128])
def test_blockwise_quant_linear_compile_fullgraph_fwd_bwd(
    in_features,
    out_features,
    batch_size,
    block_size,
):
    """torch.compile(fullgraph=True) forward/backward numerics, plus
    single-frame and no-recompile assertions (see the shared helper)."""
    _run_blockwise_quant_linear_fwd_bwd(
        in_features,
        out_features,
        batch_size,
        block_size,
        compile_mode=True,
    )