 import contextlib
 import copy
 import functools
+import glob
 import itertools
+import os
+import tempfile
 import unittest
 from typing import NamedTuple
 
 import torch
+import torch._export.utils as eu
 from torch._dynamo import config as dynamo_config
 from torch._dynamo.testing import make_test_cls_with_patches
 from torch._dynamo.utils import counters
 )
 
 
+def _get_fp8_aoti_options():
+    """
+    Detect whether AOTI can exercise the FP8 fusion path.
+
+    AOTI constant-folds lifted tensors by default, which folds away the
+    dequant scale and breaks pattern matching. We need to call
+    ``add_dont_constant_fold`` on the FP8 dequant op so that the scale
+    survives into the inductor graph and the fusion pattern can still be
+    matched. If the API is not yet available, we fall back to compile-only
+    testing.
+
+    Returns ``[False]`` when AOTI is not exercisable and ``[False, True]``
+    when it is. Float8 AOTI execution is additionally gated on torch >= 2.13.
+    """
+    aoti_options = [False]
+    if not torch_version_at_least("2.13.0"):
+        return aoti_options
+
+    try:
+        import torch._inductor.constant_folding as cf
+
+        if hasattr(cf, "add_dont_constant_fold"):
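+            # Registering the dequant op keeps its scale argument out of AOTI
+            # constant folding, so the fusion pattern can still match on it.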
+            cf.add_dont_constant_fold(
+                torch.ops.torchao.dequantize_affine_float8_non_decomposed.default
+            )
+            aoti_options = [False, True]
+    except ImportError:
+        pass
+    return aoti_options
+
+
 def cal_conv_generated_kernel_number(mod, input, dtype, dim=4, device="cpu"):
     # this function is to decide how many kernels are generated
     # while testing conv2d/3d/deconv2d
@@ -168,6 +203,7 @@ def _test_common(
         quantizer=None,
         compile_options={},  # noqa: B006
         is_fp8=False,
+        use_aoti=False,
         include_ops=None,
         exclude_ops=None,
         check_dynamic=None,
@@ -185,6 +221,33 @@ def _test_common(
         mod = mod.to(device=device)
         counters.clear()
         torch._dynamo.reset()
+
+        if use_aoti:
+
+            def aoti_compile(model, inputs, get_source_code=False):
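+                # Export the eager module and AOT-compile it into a ".pt2"
+                # package; this mirrors the torch.compile JIT path used below.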
+                with eu._disable_aten_to_metadata_assertions():
+                    exported = torch.export.export(model, inputs)
+                with tempfile.TemporaryDirectory() as tmpdir:
+                    package_path = os.path.join(tmpdir, "model.pt2")
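+                    # Redirect AOTI's generated sources into tmpdir so the
+                    # .wrapper.cpp lookup below can find them.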
+                    with config.patch({"aot_inductor.output_path": tmpdir}):
+                        torch._inductor.aoti_compile_and_package(
+                            exported,
+                            package_path=package_path,
+                        )
+
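+                    # The generated .wrapper.cpp stands in for the source code
+                    # returned by run_and_get_code in the torch.compile path;
+                    # the op include/exclude checks below grep it.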
+                    if get_source_code:
+                        cpp_paths = glob.glob(
+                            os.path.join(tmpdir, "**/*.wrapper.cpp"), recursive=True
+                        )
+                        assert cpp_paths, "Failed to find generated .wrapper.cpp"
+                        with open(cpp_paths[0]) as f:
+                            source_code = f.read()
+
+                    compiled_mod = torch._inductor.aoti_load_package(package_path)
+                    if get_source_code:
+                        return compiled_mod, source_code
+                    return compiled_mod
+
         if check_autocast == torch.bfloat16 and (
             torch.ops.mkldnn._is_mkldnn_bf16_supported() or device == "xpu"
         ):
@@ -213,16 +276,29 @@ def _test_common(
             mod, inputs, is_qat, is_dynamic, quantizer, is_fp8
         )
 
+        # Dynamic AOTI checking is not supported for now. Support can be added
+        # later if needed, e.g. by passing dynamic_shapes to torch.export.export
+        # for the dynamic case.
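+        # An untested sketch of how that might look (hypothetical; the dynamic
+        # case is deliberately rejected below):
+        #     batch = torch.export.Dim("batch")
+        #     exported = torch.export.export(
+        #         model, inputs, dynamic_shapes=({0: batch},)
+        #     )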
+        assert not (
+            use_aoti and (check_dynamic is not None or compile_options.get("dynamic"))
+        ), "Dynamic AOTI check is not supported for now"
+
         with torch.no_grad(), maybe_autocast:
             if check_code:
                 expected = mod(*inputs)
-                code_compile_options = dict(compile_options)
-                if check_dynamic is not None:
-                    code_compile_options["dynamic"] = check_dynamic
-                actual, (source_code,) = run_and_get_code(
-                    torch.compile(mod, **code_compile_options),
-                    *inputs,
-                )
+                if use_aoti:
+                    compiled_mod, source_code = aoti_compile(
+                        mod, inputs, get_source_code=True
+                    )
+                    actual = compiled_mod(*inputs)
+                else:
+                    code_compile_options = dict(compile_options)
+                    if check_dynamic is not None:
+                        code_compile_options["dynamic"] = check_dynamic
+                    actual, (source_code,) = run_and_get_code(
+                        torch.compile(mod, **code_compile_options),
+                        *inputs,
+                    )
                 for op in include_ops:
                     self.assertIn(op, source_code)
                 if num_include_ops is not None:
@@ -239,10 +315,17 @@ def _test_common(
                 # Skip due to reduce range setting for Quantization on preCI system.
                 torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
             elif check_quantization:
-                _ = torch.compile(mod, **compile_options)(*inputs)
+                if use_aoti:
+                    _ = aoti_compile(mod, inputs)(*inputs)
+                else:
+                    _ = torch.compile(mod, **compile_options)(*inputs)
             else:
                 expected = mod(*inputs)
-                actual = torch.compile(mod, **compile_options)(*inputs)
+                if use_aoti:
+                    compiled_mod = aoti_compile(mod, inputs)
+                    actual = compiled_mod(*inputs)
+                else:
+                    actual = torch.compile(mod, **compile_options)(*inputs)
                 torch.testing.assert_close(
                     actual.float(), expected.float(), atol=atol, rtol=rtol
                 )
@@ -1488,6 +1571,7 @@ def _qlinear_test_helper(
         is_dynamic=False,
         is_qat=False,
         is_fp8=False,
+        use_aoti=False,
     ):
         class M(torch.nn.Module):
             def __init__(self, use_bias, do_permute=False):
@@ -1527,6 +1611,7 @@ def _default_matcher_check_fn():
             is_qat=is_qat,
             is_dynamic=is_dynamic,
             is_fp8=is_fp8,
+            use_aoti=use_aoti,
             include_ops=[] if is_fp8 else None,
             # ensure quantize_affine_float8_non_decomposed is lowered
             exclude_ops=[
@@ -1552,8 +1637,10 @@ def test_fp8_qlinear_cpu(self):
         r"""
         This test case will quantize a single Linear module.
         """
-        for bias in [True, False]:
-            self._qlinear_test_helper((torch.randn((2, 4)),), bias=bias, is_fp8=True)
+        for use_aoti, bias in itertools.product(_get_fp8_aoti_options(), [True, False]):
+            self._qlinear_test_helper(
+                (torch.randn((2, 4)),), bias=bias, is_fp8=True, use_aoti=use_aoti
+            )
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
@@ -3076,7 +3163,9 @@ def test_fp8_q_attention_block(self):
             annotate_matmul=annotate_matmul, is_fp8=True
         )
 
-    def _test_scaled_embedding_bag_helper(self, dtype, with_output_quant=False):
+    def _test_scaled_embedding_bag_helper(
+        self, dtype, with_output_quant=False, use_aoti=False
+    ):
         class FP8QDQEmbeddingBag(torch.nn.Module):
             def __init__(self):
                 super().__init__()
@@ -3086,13 +3175,13 @@ def __init__(self):
             def _dequantize(self, weight):
                 if dtype == torch.float8_e4m3fn:
                     res = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
-                        tensor=weight.data,
+                        tensor=weight,
                         scale=torch.tensor([self.weight_scale]),
                         output_dtype=torch.float,
                     )
                 else:
                     res = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
-                        weight.data,
+                        weight,
                         self.weight_scale,
                         0,
                         -128,
@@ -3171,6 +3260,7 @@ def matcher_check_fn():
             mod,
             (weight, indices, offsets),
             matcher_check_fn,
+            use_aoti=use_aoti,
         )
 
     @skipIfNoDynamoSupport
@@ -3181,7 +3271,10 @@ def matcher_check_fn():
         reason="cpp kernels not built",
     )
     def test_fp8_scaled_embedding_bag(self):
-        self._test_scaled_embedding_bag_helper(torch.float8_e4m3fn)
+        for use_aoti in _get_fp8_aoti_options():
+            self._test_scaled_embedding_bag_helper(
+                torch.float8_e4m3fn, use_aoti=use_aoti
+            )
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
@@ -3252,15 +3345,61 @@ def matcher_check_fn():
         shape = (128, 3)
 
         mod = Mod()
-        for len in input_len_list:
-            inputs = [torch.randn(shape) for _ in range(len)]
+        for length in input_len_list:
+            inputs = [torch.randn(shape) for _ in range(length)]
             int8_inputs = [quant_input(x) for x in inputs]
             self._test_common(
                 mod,
                 (int8_inputs,),
                 matcher_check_fn,
             )
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfNoFloat8Support
+    @unittest.skipIf(
+        "CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
+        reason="cpp kernels not built",
+    )
+    def test_fp8_concat_dequant_quant(self):
+        class Mod(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.scale = 0.5
+
+            def forward(self, fp8_inputs):
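+                # cat -> dequantize -> quantize: this chain should be rewritten
+                # by the concat_dq_q pattern matcher (checked via the
+                # concat_dq_q_matcher_count counter below).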
+                res = torch.cat(fp8_inputs, dim=1)
+                res = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
+                    tensor=res,
+                    scale=torch.tensor([self.scale]),
+                    output_dtype=torch.float32,
+                )
+                res = torch.ops.torchao.quantize_affine_float8_non_decomposed.default(
+                    tensor=res,
+                    scale=torch.tensor([self.scale]),
+                    float8_dtype=torch.float8_e4m3fn,
+                )
+                return res
+
+        def quant_input(x):
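+            # 448.0 is the largest finite value representable in float8_e4m3fn.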
+            scale = x.abs().max() / 448.0
+            return (x / scale).to(torch.float8_e4m3fn)
+
+        def matcher_check_fn():
+            self.assertEqual(counters["inductor"]["concat_dq_q_matcher_count"], 1)
+
+        shape = (128, 3)
+        mod = Mod()
+        for use_aoti, length in itertools.product(_get_fp8_aoti_options(), [2, 3]):
+            inputs = [torch.randn(shape) for _ in range(length)]
+            fp8_inputs = [quant_input(x) for x in inputs]
+            self._test_common(
+                mod,
+                (fp8_inputs,),
+                matcher_check_fn,
+                use_aoti=use_aoti,
+            )
+
 
 @unittest.skipIf(not torch_version_at_least("2.8.0"), "Requires torch 2.8+")
 @unittest.skipIf(torch.version.hip is not None, "Not applicable to ROCm")