@@ -148,14 +148,15 @@ def _unary_fusion_pattern(unary_fusion, call_fn, users, is_bf16):
 
 
 def get_dequantize_per_tensor_activation_pattern(
-    is_tensor_overload=False, is_fp8=False
+    is_tensor_overload=False, is_fp8=False, users=1
 ):
     if is_fp8:
         dequantize_per_tensor_activation_pattern = CallFunction(
             torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
             KeywordArg("x"),
             KeywordArg("x_scale"),
             output_dtype=KeywordArg("x_dq_dtype"),
+            _users=users,
         )
     else:
         dequantize_per_tensor_activation_pattern = CallFunction(
@@ -168,6 +169,7 @@ def get_dequantize_per_tensor_activation_pattern(
             KeywordArg("x_quant_min"),
             KeywordArg("x_quant_max"),
             KeywordArg("x_dq_dtype"),
+            _users=users,
         )
     return dequantize_per_tensor_activation_pattern
 
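For reference, a minimal standalone sketch of what the new `users` knob does, using the `CallFunction`/`KeywordArg` helpers from `torch._inductor.pattern_matcher` (already imported in this file). `_users=n` makes the pattern node match only a graph node with exactly `n` consumers; the `torch.ops.torchao` op requires torchao to be installed:

```python
# Hedged sketch, not part of the PR: a pattern equivalent to the fp8 branch
# above with users=2, i.e. it matches a dequant whose output is consumed by
# exactly two nodes (e.g. the matmul plus an assert node).
import torch
from torch._inductor.pattern_matcher import CallFunction, KeywordArg

shared_dequant_pattern = CallFunction(
    torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,  # needs torchao
    KeywordArg("x"),
    KeywordArg("x_scale"),
    output_dtype=KeywordArg("x_dq_dtype"),
    _users=2,  # match only when the dequantized activation has two users
)
```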
@@ -991,6 +993,26 @@ def _get_linear_dq_node(
     return dequant_node, act_reshape_node, activation_to_bf16_node, act_expand_node
 
 
+def _is_assert_tensor_metadata_node(node: Any) -> bool:
+    return (
+        isinstance(node, torch.fx.Node)
+        and node.op == "call_function"
+        and node.target == torch.ops.aten._assert_tensor_metadata.default
+    )
+
+
+def _get_non_assert_users(node: torch.fx.Node) -> list[torch.fx.Node]:
+    return [user for user in node.users if not _is_assert_tensor_metadata_node(user)]
+
+
+def _erase_assert_tensor_metadata_users(
+    graph: torch.fx.Graph, node: torch.fx.Node
+) -> None:
+    for user in list(node.users):
+        if _is_assert_tensor_metadata_node(user):
+            graph.erase_node(user)
+
+
 def _is_valid_dequant_linear_pattern(dtype, input_dim_exceeds_two, input_contiguous):
     def _inner(match):
         # Check dequant pattern has only 1 user.
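To make the intent concrete, here is a hedged toy example (standalone, not from the PR) of how an `aten._assert_tensor_metadata` guard shows up as an extra user in an AOT Inductor-exported graph, and what the helpers above report for it:

```python
# Toy FX graph: `x` is consumed by both an assert guard and the real op.
import torch

g = torch.fx.Graph()
x = g.placeholder("x")
g.call_function(
    torch.ops.aten._assert_tensor_metadata.default,
    args=(x,),
    kwargs={"dtype": torch.float32},
)
y = g.call_function(torch.ops.aten.relu.default, args=(x,))
g.output(y)

assert len(x.users) == 2  # the assert guard plus the real consumer
# _get_non_assert_users(x) -> [y]: exactly one effective user, so the
# single-user checks in the fusion passes still succeed.
```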
@@ -1016,10 +1038,11 @@ def _inner(match):
             torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
         ]
 
-        if len(list(dequant_node.users)) != 1:
-            # Ensure the dequant pattern only has 1 user
-            # since we will delete the dequant pattern here
+        # Ensure the dequant pattern has only one effective user, since we
+        # will delete it here. The dequant node may have additional
+        # _assert_tensor_metadata users introduced by AOT Inductor; those
+        # do not count.
+        if len(_get_non_assert_users(dequant_node)) != 1:
             return False
 
         # Extra check for bmm pattern
         if input_dim_exceeds_two and not input_contiguous:
@@ -1202,6 +1225,10 @@ def qlinear_weight_prepack(match: Match, *args, **kwargs):
         linear_node.replace_all_uses_with(new_linear_node)
         new_linear_node.meta.update(linear_node.meta)
 
+        # Erase assert users first so dequant nodes become erasable.
+        _erase_assert_tensor_metadata_users(graph, dequant_node)
+        _erase_assert_tensor_metadata_users(graph, dequant)
+
         # Erase the original linear node
         if input_dim_exceeds_two:
             if input_contiguous:
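A hedged aside on why the order matters here: `fx.Graph.erase_node` raises if the node still has users, so the assert guards have to go before their producer can be erased. Toy graph, not the PR's code:

```python
import torch

g = torch.fx.Graph()
x = g.placeholder("x")
guard = g.call_function(
    torch.ops.aten._assert_tensor_metadata.default, args=(x,)
)
g.output(None)

# g.erase_node(x) here would raise RuntimeError: `guard` still uses `x`.
g.erase_node(guard)  # the guard has no users of its own, so this succeeds
g.erase_node(x)      # now `x` is dead and can be erased as well
```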
@@ -1237,6 +1264,7 @@ def _generate_dequant_linear_node_pattern(
     input_dim_exceeds_two=False,
     is_tensor_overload=False,
     is_fp8=False,
+    users=1,
 ):
     assert dtype in [torch.float32, torch.bfloat16]
     t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype)
@@ -1247,7 +1275,7 @@ def _generate_dequant_linear_node_pattern(
         _may_generate_pattern_with_reshape(
             _may_generate_pattern_with_dtype_convert(
                 get_dequantize_per_tensor_activation_pattern(
-                    is_tensor_overload, is_fp8
+                    is_tensor_overload, is_fp8, users
                 ),
                 KeywordArg("autocast_act_dtype"),
                 dtype == torch.bfloat16,
@@ -1266,7 +1294,7 @@ def _generate_dequant_linear_node_pattern(
         _may_generate_pattern_with_reshape(
             _may_generate_pattern_with_dtype_convert(
                 get_dequantize_per_tensor_activation_pattern(
-                    is_tensor_overload, is_fp8
+                    is_tensor_overload, is_fp8, users
                 ),
                 KeywordArg("autocast_act_dtype"),
                 dtype == torch.bfloat16,
@@ -1288,6 +1316,7 @@ def _generate_dequant_bmm_node_pattern(
     with_bias=False,
     is_tensor_overload=False,
     is_fp8=False,
+    users=1,
 ):
     # When activation of linear dim exceed 2 and not contiguous
     t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype)
@@ -1299,7 +1328,7 @@ def _generate_dequant_bmm_node_pattern(
             aten.expand.default,
             _may_generate_pattern_with_dtype_convert(
                 get_dequantize_per_tensor_activation_pattern(
-                    is_tensor_overload, is_fp8
+                    is_tensor_overload, is_fp8, users
                 ),
                 KeywordArg("autocast_act_dtype"),
                 dtype == torch.bfloat16,
@@ -1333,9 +1362,17 @@ def _generate_qlinear_weight_prepack_patterns(
     with_bias=False,
     is_tensor_overload=False,
     is_fp8=False,
+    users=1,
 ):
     if is_fp8:
-        dequant_wgt_pattern = dequantize_fp8_weight_pattern
+        # Inline the fp8 weight dequant pattern so _users can be parameterized.
+        dequant_wgt_pattern = CallFunction(
+            torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
+            KeywordArg("q_weight"),
+            KeywordArg("w_scale"),
+            output_dtype=KeywordArg("w_dtype"),
+            _users=users,
+        )
     else:
         dequant_wgt_pattern = dequantize_per_channel_weight_pattern
     if input_dim_exceeds_two and not input_contiguous:
@@ -1345,6 +1382,7 @@ def _generate_qlinear_weight_prepack_patterns(
             with_bias,
             is_tensor_overload,
             is_fp8=is_fp8,
+            users=users,
         )
     else:
         return _generate_dequant_linear_node_pattern(
@@ -1353,6 +1391,7 @@ def _generate_qlinear_weight_prepack_patterns(
             input_dim_exceeds_two,
             is_tensor_overload,
             is_fp8=is_fp8,
+            users=users,
         )
 
 
@@ -1526,7 +1565,7 @@ def _register_qlinear_weight_prepack():
     # | OPT(add) |
 
     linear_weight_prepack_cases = itertools.product(
-        [torch.float32, torch.bfloat16], [True, False], [True, False], [True, False]
+        [torch.float32, torch.bfloat16], [True, False], [True, False], [True, False], [1, 2]
     )
 
     # Step 1: register patterns from mm and addmm
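As a quick standalone sanity check, the extra `[1, 2]` axis doubles the number of registered case tuples; each tuple now also carries the expected user count of the activation dequant:

```python
import itertools
import torch

cases = list(
    itertools.product(
        [torch.float32, torch.bfloat16],  # dtype
        [True, False],                    # input_dim_exceeds_two
        [True, False],                    # is_tensor_overload
        [True, False],                    # is_fp8
        [1, 2],                           # users of the activation dequant
    )
)
print(len(cases))  # 32; fp8-without-tensor-overload tuples are skipped below
```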
@@ -1535,6 +1574,7 @@ def _register_qlinear_weight_prepack():
         input_dim_exceeds_two,
         is_tensor_overload,
         is_fp8,
+        users,
     ) in linear_weight_prepack_cases:
         if is_fp8 and not is_tensor_overload:
             continue
@@ -1543,6 +1583,7 @@ def _register_qlinear_weight_prepack():
             input_dim_exceeds_two,
             is_tensor_overload=is_tensor_overload,
             is_fp8=is_fp8,
+            users=users,
         )
         for weight_prepack_pattern in weight_prepack_patterns:
             # Register to pass_number 1, so we can do dequant promotion in pass_number 0.
@@ -2458,6 +2499,40 @@ def _register_qconv_binary_fusion():
     )
 
 
+def _extract_const_float_from_node(v: Any) -> float | None:
+    if isinstance(v, (int, float)):
+        return float(v)
+
+    if isinstance(v, torch.fx.Node):
+        # case 1: aten.full([1], c)
+        if v.op == "call_function" and v.target is torch.ops.aten.full.default:
+            if len(v.args) >= 2 and isinstance(v.args[1], (int, float)):
+                return float(v.args[1])
+
+        # case 2: get_attr(lifted_tensor)
+        if v.op == "get_attr":
+            obj = v.graph.owning_module
+            for atom in str(v.target).split("."):
+                if not hasattr(obj, atom):
+                    obj = None
+                    break
+                obj = getattr(obj, atom)
+
+            if isinstance(obj, (int, float)):
+                return float(obj)
+            if isinstance(obj, torch.Tensor) and obj.numel() == 1:
+                return float(obj.item())
+
+        # case 3: meta val fallback
+        mv = v.meta.get("val", None)
+        if isinstance(mv, (int, float)):
+            return float(mv)
+        if isinstance(mv, torch.Tensor) and mv.numel() == 1:
+            return float(mv.item())
+
+    return None
+
+
 def _register_qlinear_post_op_fusion_pass(
     pattern,
     pass_number,
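A hedged usage sketch (standalone, not part of the PR) covering the shapes of input the helper accepts; only case 1 is exercised end to end here:

```python
import torch

g = torch.fx.Graph()
# case 1: the output scale materialized as aten.full([1], c)
scale_node = g.call_function(
    torch.ops.aten.full.default,
    args=([1], 0.02),
    kwargs={"dtype": torch.float32},
)
# _extract_const_float_from_node(scale_node) -> 0.02
# _extract_const_float_from_node(0.02)       -> 0.02  (plain Python scalar)
# get_attr nodes (case 2) resolve the attribute chain on graph.owning_module;
# case 3 falls back to a scalar node.meta["val"], e.g. from fake-tensor prop.
```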
@@ -2495,10 +2570,10 @@ def qlinear_post_op_fusion(match: Match, *args, **kwargs):
 
         # Output QParams
         if output_dtype == torch.float8_e4m3fn:
-            # For float8, we assume the scale is from aten.full.default instead of
-            # a constant buffer to avoid constant folding of q/dq before fusion passes.
-            assert kwargs["o_inv_scale"].target is torch.ops.aten.full.default
-            o_inv_scale = kwargs["o_inv_scale"].args[1]
+            o_inv_scale = _extract_const_float_from_node(kwargs["o_inv_scale"])
+            assert o_inv_scale is not None, (
+                f"Unsupported fp8 o_inv_scale node: {kwargs['o_inv_scale']}"
+            )
         else:
             o_inv_scale = (
                 kwargs["o_inv_scale"]
@@ -2979,31 +3054,7 @@ def _normalize_dtype(dtype_or_enum: Any) -> torch.dtype | Any:
         normalized_o_dtype = _normalize_dtype(kwargs["o_dtype"])
         output_type = normalized_o_dtype
 
-        def _extract_const_float(val) -> float | None:
-            # Prefer extracting from python scalars and FX node structure
-            if isinstance(val, (int, float)):
-                return float(val)
-            if isinstance(val, torch.fx.Node):
-                meta_val = val.meta.get("val", None)
-                if isinstance(meta_val, (int, float)):
-                    return float(meta_val)
-                # Common pattern: aten.full([1], fill_value, dtype=float)
-                if val.target is torch.ops.aten.full.default and len(val.args) >= 2:
-                    fill_value = val.args[1]
-                    if isinstance(fill_value, (int, float)):
-                        return float(fill_value)
-                # Common pattern in user code: torch.tensor([scalar])
-                if val.target is torch.tensor and len(val.args) >= 1:
-                    data = val.args[0]
-                    if (
-                        isinstance(data, (list, tuple))
-                        and len(data) == 1
-                        and isinstance(data[0], (int, float))
-                    ):
-                        return float(data[0])
-            return None
-
-        o_scale = _extract_const_float(kwargs["o_inv_scale"])
+        o_scale = _extract_const_float_from_node(kwargs["o_inv_scale"])
         assert o_scale is not None, "Output scale is not a constant float."
 
         graph = match.graph