@@ -269,7 +269,9 @@ def _test_code_common(
269269 package_path = package_path ,
270270 )
271271
272- cpp_paths = glob .glob (os .path .join (tmpdir , "**/*.wrapper.cpp" ), recursive = True )
272+ cpp_paths = glob .glob (
273+ os .path .join (tmpdir , "**/*.wrapper.cpp" ), recursive = True
274+ )
273275 assert cpp_paths , "Failed to find generated .wrapper.cpp"
274276 with open (cpp_paths [0 ], "r" ) as f :
275277 source_code = f .read ()
@@ -1609,18 +1611,25 @@ def test_fp8_qlinear_cpu(self):
16091611 r"""
16101612 This testcase will quantize a single Linear Moduel.
16111613 """
1612- aoti_options = [False ,]
1614+ aoti_options = [
1615+ False ,
1616+ ]
16131617 try :
16141618 import torch ._inductor .constant_folding as cf
1619+
16151620 if hasattr (cf , "add_dont_constant_fold" ):
1616- cf .add_dont_constant_fold (torch .ops .torchao .dequantize_affine_float8_non_decomposed .default )
1621+ cf .add_dont_constant_fold (
1622+ torch .ops .torchao .dequantize_affine_float8_non_decomposed .default
1623+ )
16171624 aoti_options = [False , True ]
16181625 finally :
16191626 pass
16201627
16211628 for is_aoti in aoti_options :
16221629 for bias in [True , False ]:
1623- self ._qlinear_test_helper ((torch .randn ((2 , 4 )),), bias = bias , is_fp8 = True , is_aoti = is_aoti )
1630+ self ._qlinear_test_helper (
1631+ (torch .randn ((2 , 4 )),), bias = bias , is_fp8 = True , is_aoti = is_aoti
1632+ )
16241633
16251634 @skipIfNoDynamoSupport
16261635 @skipIfNoONEDNN
@@ -3170,7 +3179,9 @@ def test_fp8_q_attention_block(self):
31703179 annotate_matmul = annotate_matmul , is_fp8 = True
31713180 )
31723181
3173- def _test_scaled_embedding_bag_helper (self , dtype , with_output_quant = False , is_aoti = False ):
3182+ def _test_scaled_embedding_bag_helper (
3183+ self , dtype , with_output_quant = False , is_aoti = False
3184+ ):
31743185 class FP8QDQEmbeddingBag (torch .nn .Module ):
31753186 def __init__ (self ):
31763187 super ().__init__ ()
@@ -3276,11 +3287,16 @@ def matcher_check_fn():
32763287 reason = "cpp kernels not built" ,
32773288 )
32783289 def test_fp8_scaled_embedding_bag (self ):
3279- aoti_options = [False ,]
3290+ aoti_options = [
3291+ False ,
3292+ ]
32803293 try :
32813294 import torch ._inductor .constant_folding as cf
3295+
32823296 if hasattr (cf , "add_dont_constant_fold" ):
3283- cf .add_dont_constant_fold (torch .ops .torchao .dequantize_affine_float8_non_decomposed .default )
3297+ cf .add_dont_constant_fold (
3298+ torch .ops .torchao .dequantize_affine_float8_non_decomposed .default
3299+ )
32843300 aoti_options = [False , True ]
32853301 finally :
32863302 pass
0 commit comments