Skip to content

Commit 49453c8

Browse files
committed
fix format
1 parent 09a5f8c commit 49453c8

2 files changed

Lines changed: 28 additions & 8 deletions

File tree

test/quantization/pt2e/test_x86inductor_fusion.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,9 @@ def _test_code_common(
269269
package_path=package_path,
270270
)
271271

272-
cpp_paths = glob.glob(os.path.join(tmpdir, "**/*.wrapper.cpp"), recursive=True)
272+
cpp_paths = glob.glob(
273+
os.path.join(tmpdir, "**/*.wrapper.cpp"), recursive=True
274+
)
273275
assert cpp_paths, "Failed to find generated .wrapper.cpp"
274276
with open(cpp_paths[0], "r") as f:
275277
source_code = f.read()
@@ -1609,18 +1611,25 @@ def test_fp8_qlinear_cpu(self):
16091611
r"""
16101612
This testcase will quantize a single Linear Module.
16111613
"""
1612-
aoti_options = [False,]
1614+
aoti_options = [
1615+
False,
1616+
]
16131617
try:
16141618
import torch._inductor.constant_folding as cf
1619+
16151620
if hasattr(cf, "add_dont_constant_fold"):
1616-
cf.add_dont_constant_fold(torch.ops.torchao.dequantize_affine_float8_non_decomposed.default)
1621+
cf.add_dont_constant_fold(
1622+
torch.ops.torchao.dequantize_affine_float8_non_decomposed.default
1623+
)
16171624
aoti_options = [False, True]
16181625
finally:
16191626
pass
16201627

16211628
for is_aoti in aoti_options:
16221629
for bias in [True, False]:
1623-
self._qlinear_test_helper((torch.randn((2, 4)),), bias=bias, is_fp8=True, is_aoti=is_aoti)
1630+
self._qlinear_test_helper(
1631+
(torch.randn((2, 4)),), bias=bias, is_fp8=True, is_aoti=is_aoti
1632+
)
16241633

16251634
@skipIfNoDynamoSupport
16261635
@skipIfNoONEDNN
@@ -3170,7 +3179,9 @@ def test_fp8_q_attention_block(self):
31703179
annotate_matmul=annotate_matmul, is_fp8=True
31713180
)
31723181

3173-
def _test_scaled_embedding_bag_helper(self, dtype, with_output_quant=False, is_aoti=False):
3182+
def _test_scaled_embedding_bag_helper(
3183+
self, dtype, with_output_quant=False, is_aoti=False
3184+
):
31743185
class FP8QDQEmbeddingBag(torch.nn.Module):
31753186
def __init__(self):
31763187
super().__init__()
@@ -3276,11 +3287,16 @@ def matcher_check_fn():
32763287
reason="cpp kernels not built",
32773288
)
32783289
def test_fp8_scaled_embedding_bag(self):
3279-
aoti_options = [False,]
3290+
aoti_options = [
3291+
False,
3292+
]
32803293
try:
32813294
import torch._inductor.constant_folding as cf
3295+
32823296
if hasattr(cf, "add_dont_constant_fold"):
3283-
cf.add_dont_constant_fold(torch.ops.torchao.dequantize_affine_float8_non_decomposed.default)
3297+
cf.add_dont_constant_fold(
3298+
torch.ops.torchao.dequantize_affine_float8_non_decomposed.default
3299+
)
32843300
aoti_options = [False, True]
32853301
finally:
32863302
pass

torchao/quantization/pt2e/inductor_passes/x86.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1563,7 +1563,11 @@ def _register_qlinear_weight_prepack():
15631563
# | OPT(add) |
15641564

15651565
linear_weight_prepack_cases = itertools.product(
1566-
[torch.float32, torch.bfloat16], [True, False], [True, False], [True, False], [1, 2]
1566+
[torch.float32, torch.bfloat16],
1567+
[True, False],
1568+
[True, False],
1569+
[True, False],
1570+
[1, 2],
15671571
)
15681572

15691573
# Step 1: register patterns from mm and addmm

0 commit comments

Comments
 (0)