
Commit 9472d7d

[X86] Enable FP8 patterns for AOT Inductor (#4099)
* enable fp8 qlinear patterns for AOT Inductor
* update ut
* update comments
* fix import format
* fix format
* remove metadata_assertion
* fix conflicts
* add fp8 support for concat_dq_q_pattern
* update
* update ut
* update comments
* update
* skip UT when torch < 2.13.0
1 parent bbe615c commit 9472d7d

3 files changed: 259 additions & 68 deletions

File tree

test/quantization/pt2e/test_x86inductor_fusion.py (156 additions & 17 deletions)
@@ -8,11 +8,15 @@
 import contextlib
 import copy
 import functools
+import glob
 import itertools
+import os
+import tempfile
 import unittest
 from typing import NamedTuple

 import torch
+import torch._export.utils as eu
 from torch._dynamo import config as dynamo_config
 from torch._dynamo.testing import make_test_cls_with_patches
 from torch._dynamo.utils import counters
@@ -101,6 +105,37 @@
 )


+def _get_fp8_aoti_options():
+    """
+    Detect whether AOTI can exercise the FP8 fusion path.
+
+    AOTI constant-folds lifted tensors by default, which folds away the
+    dequant scale and breaks pattern matching. We need to call
+    ``add_dont_constant_fold`` on the FP8 dequant op so that the scale
+    survives into the inductor graph and the fusion pattern can still be
+    matched. If the API is not yet available we fall back to compile-only
+    testing.
+
+    Returns ``[False]`` when AOTI is not exercisable, ``[False, True]``
+    when it is. Float8 AOTI execution is additionally gated on torch >= 2.13.
+    """
+    aoti_options = [False]
+    if not torch_version_at_least("2.13.0"):
+        return aoti_options
+
+    try:
+        import torch._inductor.constant_folding as cf
+
+        if hasattr(cf, "add_dont_constant_fold"):
+            cf.add_dont_constant_fold(
+                torch.ops.torchao.dequantize_affine_float8_non_decomposed.default
+            )
+            aoti_options = [False, True]
+    except ImportError:
+        pass
+    return aoti_options
+
+
 def cal_conv_generated_kernel_number(mod, input, dtype, dim=4, device="cpu"):
     # this function is to decide how many kernels are generated
     # while testing conv2d/3d/deconv2d
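To make the docstring's failure mode concrete, here is a hedged sketch (not part of the commit) of the kind of graph AOTI would otherwise fold: after torch.export, both the FP8 weight and its scale are lifted constants, so the dequant output is itself compile-time constant, and without the `add_dont_constant_fold` registration the dq node the x86 patterns match would be folded away.

import torch


class DequantOnly(torch.nn.Module):
    # Illustrative module only; requires torchao's non-decomposed FP8 ops.
    def __init__(self):
        super().__init__()
        # Plain tensor attributes are lifted as constants by torch.export.
        self.scale = torch.tensor([0.5])
        self.weight = torch.randn(4, 4).to(torch.float8_e4m3fn)

    def forward(self, x):
        # dq(constant weight, constant scale) is fully constant-foldable,
        # which erases the node the fusion pattern needs to see.
        w = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
            tensor=self.weight,
            scale=self.scale,
            output_dtype=torch.float32,
        )
        return x @ w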
@@ -168,6 +203,7 @@ def _test_common(
         quantizer=None,
         compile_options={},  # noqa: B006
         is_fp8=False,
+        use_aoti=False,
         include_ops=None,
         exclude_ops=None,
         check_dynamic=None,
@@ -185,6 +221,33 @@ def _test_common(
         mod = mod.to(device=device)
         counters.clear()
         torch._dynamo.reset()
+
+        if use_aoti:
+
+            def aoti_compile(model, inputs, get_source_code=False):
+                with eu._disable_aten_to_metadata_assertions():
+                    exported = torch.export.export(model, inputs)
+                with tempfile.TemporaryDirectory() as tmpdir:
+                    package_path = os.path.join(tmpdir, "model.pt2")
+                    with config.patch({"aot_inductor.output_path": tmpdir}):
+                        torch._inductor.aoti_compile_and_package(
+                            exported,
+                            package_path=package_path,
+                        )
+
+                    if get_source_code:
+                        cpp_paths = glob.glob(
+                            os.path.join(tmpdir, "**/*.wrapper.cpp"), recursive=True
+                        )
+                        assert cpp_paths, "Failed to find generated .wrapper.cpp"
+                        with open(cpp_paths[0]) as f:
+                            source_code = f.read()
+
+                    compiled_mod = torch._inductor.aoti_load_package(package_path)
+                    if get_source_code:
+                        return compiled_mod, source_code
+                    return compiled_mod
+
         if check_autocast == torch.bfloat16 and (
             torch.ops.mkldnn._is_mkldnn_bf16_supported() or device == "xpu"
         ):
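For orientation, a stripped-down sketch of the export → package → load round trip that `aoti_compile` performs, minus the metadata-assertion workaround and wrapper-source capture; `TinyModel` is a placeholder, not from the commit.

import os
import tempfile

import torch


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)


model, inputs = TinyModel().eval(), (torch.randn(2, 4),)
exported = torch.export.export(model, inputs)

with tempfile.TemporaryDirectory() as tmpdir:
    # Compile the exported program into a self-contained .pt2 package,
    # then load it back and call it like a regular module.
    package_path = os.path.join(tmpdir, "model.pt2")
    torch._inductor.aoti_compile_and_package(exported, package_path=package_path)
    compiled = torch._inductor.aoti_load_package(package_path)
    out = compiled(*inputs)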
@@ -213,16 +276,29 @@ def _test_common(
             mod, inputs, is_qat, is_dynamic, quantizer, is_fp8
         )

+        # Dynamic aoti check is not supported for now.
+        # Support can be added in the future if needed (e.g. via dynamic_shapes in
+        # torch.export.export for the dynamic case).
+        assert not (
+            use_aoti and (check_dynamic is not None or compile_options.get("dynamic"))
+        ), "Dynamic AOTI check is not supported for now"
+
         with torch.no_grad(), maybe_autocast:
             if check_code:
                 expected = mod(*inputs)
-                code_compile_options = dict(compile_options)
-                if check_dynamic is not None:
-                    code_compile_options["dynamic"] = check_dynamic
-                actual, (source_code,) = run_and_get_code(
-                    torch.compile(mod, **code_compile_options),
-                    *inputs,
-                )
+                if use_aoti:
+                    compiled_mod, source_code = aoti_compile(
+                        mod, inputs, get_source_code=True
+                    )
+                    actual = compiled_mod(*inputs)
+                else:
+                    code_compile_options = dict(compile_options)
+                    if check_dynamic is not None:
+                        code_compile_options["dynamic"] = check_dynamic
+                    actual, (source_code,) = run_and_get_code(
+                        torch.compile(mod, **code_compile_options),
+                        *inputs,
+                    )
                 for op in include_ops:
                     self.assertIn(op, source_code)
                 if num_include_ops is not None:
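The assertion above rules dynamic shapes out for now; the adjacent comment points at `dynamic_shapes` in `torch.export.export` as the likely future route. A hedged sketch of that API, not wired into these tests:

import torch
from torch.export import Dim


class TinyLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


# Mark dim 0 of the input as symbolic so the exported program (and an AOTI
# package built from it) would accept a range of batch sizes.
batch = Dim("batch", min=1, max=64)
exported = torch.export.export(
    TinyLinear(), (torch.randn(2, 4),), dynamic_shapes={"x": {0: batch}}
)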
@@ -239,10 +315,17 @@ def _test_common(
                 # Skip due to reduce range setting for Quantization on preCI system.
                 torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
             elif check_quantization:
-                _ = torch.compile(mod, **compile_options)(*inputs)
+                if use_aoti:
+                    _ = aoti_compile(mod, inputs)(*inputs)
+                else:
+                    _ = torch.compile(mod, **compile_options)(*inputs)
             else:
                 expected = mod(*inputs)
-                actual = torch.compile(mod, **compile_options)(*inputs)
+                if use_aoti:
+                    compiled_mod = aoti_compile(mod, inputs)
+                    actual = compiled_mod(*inputs)
+                else:
+                    actual = torch.compile(mod, **compile_options)(*inputs)
                 torch.testing.assert_close(
                     actual.float(), expected.float(), atol=atol, rtol=rtol
                 )
@@ -1488,6 +1571,7 @@ def _qlinear_test_helper(
         is_dynamic=False,
         is_qat=False,
         is_fp8=False,
+        use_aoti=False,
     ):
         class M(torch.nn.Module):
             def __init__(self, use_bias, do_permute=False):
@@ -1527,6 +1611,7 @@ def _default_matcher_check_fn():
             is_qat=is_qat,
             is_dynamic=is_dynamic,
             is_fp8=is_fp8,
+            use_aoti=use_aoti,
             include_ops=[] if is_fp8 else None,
             # ensure quantize_affine_float8_non_decomposed is lowered
             exclude_ops=[
@@ -1552,8 +1637,10 @@ def test_fp8_qlinear_cpu(self):
         r"""
         This testcase will quantize a single Linear Module.
         """
-        for bias in [True, False]:
-            self._qlinear_test_helper((torch.randn((2, 4)),), bias=bias, is_fp8=True)
+        for use_aoti, bias in itertools.product(_get_fp8_aoti_options(), [True, False]):
+            self._qlinear_test_helper(
+                (torch.randn((2, 4)),), bias=bias, is_fp8=True, use_aoti=use_aoti
+            )

     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
@@ -3076,7 +3163,9 @@ def test_fp8_q_attention_block(self):
             annotate_matmul=annotate_matmul, is_fp8=True
         )

-    def _test_scaled_embedding_bag_helper(self, dtype, with_output_quant=False):
+    def _test_scaled_embedding_bag_helper(
+        self, dtype, with_output_quant=False, use_aoti=False
+    ):
         class FP8QDQEmbeddingBag(torch.nn.Module):
             def __init__(self):
                 super().__init__()
@@ -3086,13 +3175,13 @@ def __init__(self):
             def _dequantize(self, weight):
                 if dtype == torch.float8_e4m3fn:
                     res = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
-                        tensor=weight.data,
+                        tensor=weight,
                         scale=torch.tensor([self.weight_scale]),
                         output_dtype=torch.float,
                     )
                 else:
                     res = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
-                        weight.data,
+                        weight,
                         self.weight_scale,
                         0,
                         -128,
@@ -3171,6 +3260,7 @@ def matcher_check_fn():
             mod,
             (weight, indices, offsets),
             matcher_check_fn,
+            use_aoti=use_aoti,
         )

     @skipIfNoDynamoSupport
@@ -3181,7 +3271,10 @@ def matcher_check_fn():
         reason="cpp kernels not built",
     )
     def test_fp8_scaled_embedding_bag(self):
-        self._test_scaled_embedding_bag_helper(torch.float8_e4m3fn)
+        for use_aoti in _get_fp8_aoti_options():
+            self._test_scaled_embedding_bag_helper(
+                torch.float8_e4m3fn, use_aoti=use_aoti
+            )

     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
@@ -3252,15 +3345,61 @@ def matcher_check_fn():
         shape = (128, 3)

         mod = Mod()
-        for len in input_len_list:
-            inputs = [torch.randn(shape) for _ in range(len)]
+        for length in input_len_list:
+            inputs = [torch.randn(shape) for _ in range(length)]
             int8_inputs = [quant_input(x) for x in inputs]
             self._test_common(
                 mod,
                 (int8_inputs,),
                 matcher_check_fn,
             )

+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfNoFloat8Support
+    @unittest.skipIf(
+        "CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
+        reason="cpp kernels not built",
+    )
+    def test_fp8_concat_dequant_quant(self):
+        class Mod(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.scale = 0.5
+
+            def forward(self, fp8_inputs):
+                res = torch.cat(fp8_inputs, dim=1)
+                res = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
+                    tensor=res,
+                    scale=torch.tensor([self.scale]),
+                    output_dtype=torch.float32,
+                )
+                res = torch.ops.torchao.quantize_affine_float8_non_decomposed.default(
+                    tensor=res,
+                    scale=torch.tensor([self.scale]),
+                    float8_dtype=torch.float8_e4m3fn,
+                )
+                return res
+
+        def quant_input(x):
+            scale = x.abs().max() / 448.0
+            return (x / scale).to(torch.float8_e4m3fn)
+
+        def matcher_check_fn():
+            self.assertEqual(counters["inductor"]["concat_dq_q_matcher_count"], 1)
+
+        shape = (128, 3)
+        mod = Mod()
+        for use_aoti, length in itertools.product(_get_fp8_aoti_options(), [2, 3]):
+            inputs = [torch.randn(shape) for _ in range(length)]
+            fp8_inputs = [quant_input(x) for x in inputs]
+            self._test_common(
+                mod,
+                (fp8_inputs,),
+                matcher_check_fn,
+                use_aoti=use_aoti,
+            )
+

 @unittest.skipIf(not torch_version_at_least("2.8.0"), "Requires torch 2.8+")
 @unittest.skipIf(torch.version.hip is not None, "Not applicable to ROCm")
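One note on the `quant_input` helper in the hunk above: 448.0 is the largest finite value of `torch.float8_e4m3fn`, so dividing by `x.abs().max() / 448.0` stretches each input onto the full representable range. A quick sketch (not part of the commit) verifying the constant and the quantize/dequantize round trip:

import torch

# 448.0 is the e4m3 saturation point: torch.finfo(torch.float8_e4m3fn).max.
assert torch.finfo(torch.float8_e4m3fn).max == 448.0

x = torch.randn(128, 3)
scale = x.abs().max() / 448.0                # same per-tensor scale as quant_input
x_fp8 = (x / scale).to(torch.float8_e4m3fn)  # quantize
x_rt = x_fp8.to(torch.float32) * scale       # dequantize
print((x - x_rt).abs().max())                # small, dtype-limited rounding error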
