Commit 1c17731

fix format
1 parent 7474bd0 commit 1c17731

2 files changed: 28 additions & 8 deletions

test/quantization/pt2e/test_x86inductor_fusion.py (23 additions & 7 deletions)
@@ -269,7 +269,9 @@ def _test_code_common(
             package_path=package_path,
         )

-        cpp_paths = glob.glob(os.path.join(tmpdir, "**/*.wrapper.cpp"), recursive=True)
+        cpp_paths = glob.glob(
+            os.path.join(tmpdir, "**/*.wrapper.cpp"), recursive=True
+        )
         assert cpp_paths, "Failed to find generated .wrapper.cpp"
         with open(cpp_paths[0], "r") as f:
             source_code = f.read()
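
As context for the hunk above: the reformatted call locates the C++ wrapper source that AOT Inductor writes during compilation. A minimal standalone sketch of the same lookup, with tmpdir standing in for the test's temporary compile directory (the path here is hypothetical):

import glob
import os

tmpdir = "/tmp/aoti_out"  # hypothetical: the AOTI compile-output directory

# "**" only descends into subdirectories when recursive=True is passed.
cpp_paths = glob.glob(os.path.join(tmpdir, "**/*.wrapper.cpp"), recursive=True)
assert cpp_paths, "Failed to find generated .wrapper.cpp"

# Read the first generated wrapper so its source can be inspected.
with open(cpp_paths[0], "r") as f:
    source_code = f.read()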
@@ -1609,18 +1611,25 @@ def test_fp8_qlinear_cpu(self):
         r"""
         This test case will quantize a single Linear Module.
         """
-        aoti_options = [False,]
+        aoti_options = [
+            False,
+        ]
         try:
             import torch._inductor.constant_folding as cf
+
             if hasattr(cf, "add_dont_constant_fold"):
-                cf.add_dont_constant_fold(torch.ops.torchao.dequantize_affine_float8_non_decomposed.default)
+                cf.add_dont_constant_fold(
+                    torch.ops.torchao.dequantize_affine_float8_non_decomposed.default
+                )
                 aoti_options = [False, True]
         finally:
             pass

         for is_aoti in aoti_options:
             for bias in [True, False]:
-                self._qlinear_test_helper((torch.randn((2, 4)),), bias=bias, is_fp8=True, is_aoti=is_aoti)
+                self._qlinear_test_helper(
+                    (torch.randn((2, 4)),), bias=bias, is_fp8=True, is_aoti=is_aoti
+                )

     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
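
Worth noting about the guarded block in this hunk: hasattr feature-detects add_dont_constant_fold, so the test still runs on torch builds that predate the hook, and the AOTI variant is only enabled after the fp8 dequant op has been excluded from Inductor constant folding. A minimal sketch of the same pattern, assuming torchao is installed (importing it is assumed to register the torch.ops.torchao namespace) and using except ImportError where the test itself uses try/finally:

import torch
import torchao  # assumption: importing torchao registers torch.ops.torchao.*

aoti_options = [False]
try:
    import torch._inductor.constant_folding as cf

    # Older torch builds do not expose this hook, so feature-detect it
    # instead of pinning a torch version.
    if hasattr(cf, "add_dont_constant_fold"):
        cf.add_dont_constant_fold(
            torch.ops.torchao.dequantize_affine_float8_non_decomposed.default
        )
        aoti_options = [False, True]  # safe to exercise the AOTI path now
except ImportError:
    pass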
@@ -3164,7 +3173,9 @@ def test_fp8_q_attention_block(self):
             annotate_matmul=annotate_matmul, is_fp8=True
         )

-    def _test_scaled_embedding_bag_helper(self, dtype, with_output_quant=False, is_aoti=False):
+    def _test_scaled_embedding_bag_helper(
+        self, dtype, with_output_quant=False, is_aoti=False
+    ):
         class FP8QDQEmbeddingBag(torch.nn.Module):
             def __init__(self):
                 super().__init__()
@@ -3270,11 +3281,16 @@ def matcher_check_fn():
         reason="cpp kernels not built",
     )
     def test_fp8_scaled_embedding_bag(self):
-        aoti_options = [False,]
+        aoti_options = [
+            False,
+        ]
         try:
             import torch._inductor.constant_folding as cf
+
             if hasattr(cf, "add_dont_constant_fold"):
-                cf.add_dont_constant_fold(torch.ops.torchao.dequantize_affine_float8_non_decomposed.default)
+                cf.add_dont_constant_fold(
+                    torch.ops.torchao.dequantize_affine_float8_non_decomposed.default
+                )
                 aoti_options = [False, True]
         finally:
             pass

torchao/quantization/pt2e/inductor_passes/x86.py (5 additions & 1 deletion)
@@ -1563,7 +1563,11 @@ def _register_qlinear_weight_prepack():
     # | OPT(add) |

     linear_weight_prepack_cases = itertools.product(
-        [torch.float32, torch.bfloat16], [True, False], [True, False], [True, False], [1, 2]
+        [torch.float32, torch.bfloat16],
+        [True, False],
+        [True, False],
+        [True, False],
+        [1, 2],
     )

     # Step 1: register patterns from mm and addmm
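
Splitting the product onto one factor per line makes the five axes easy to scan. As a quick sanity check (a sketch only; what each boolean axis controls is defined by the registration code, which this diff does not show), the product enumerates 2 * 2 * 2 * 2 * 2 = 32 prepack cases:

import itertools

import torch

linear_weight_prepack_cases = itertools.product(
    [torch.float32, torch.bfloat16],
    [True, False],
    [True, False],
    [True, False],
    [1, 2],
)
# Five two-element axes multiply out to 32 combinations.
print(sum(1 for _ in linear_weight_prepack_cases))  # 32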
