
Commit da281fd: enable fp8 qlinear patterns for AOT Inductor
1 parent 637c4ac

1 file changed: torchao/quantization/pt2e/inductor_passes/x86.py (89 additions, 38 deletions)
@@ -148,14 +148,15 @@ def _unary_fusion_pattern(unary_fusion, call_fn, users, is_bf16):
 
 
 def get_dequantize_per_tensor_activation_pattern(
-    is_tensor_overload=False, is_fp8=False
+    is_tensor_overload=False, is_fp8=False, users=1
 ):
     if is_fp8:
         dequantize_per_tensor_activation_pattern = CallFunction(
             torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
             KeywordArg("x"),
             KeywordArg("x_scale"),
             output_dtype=KeywordArg("x_dq_dtype"),
+            _users=users,
         )
     else:
         dequantize_per_tensor_activation_pattern = CallFunction(
@@ -168,6 +169,7 @@ def get_dequantize_per_tensor_activation_pattern(
             KeywordArg("x_quant_min"),
             KeywordArg("x_quant_max"),
             KeywordArg("x_dq_dtype"),
+            _users=users,
         )
     return dequantize_per_tensor_activation_pattern
 
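Note on the new users knob: Inductor's pattern matcher compares a node's actual consumer count against the pattern's _users field, so a dequantize that feeds two downstream linears only matches a pattern built with _users=2. A minimal sketch of the two variants this commit registers, using a hypothetical fp8_act_dq_pattern factory (assumes torchao is installed so the custom op is available):

import torch
import torchao  # noqa: F401  # registers torch.ops.torchao.* custom ops
from torch._inductor.pattern_matcher import CallFunction, KeywordArg


def fp8_act_dq_pattern(users: int):
    # Mirrors get_dequantize_per_tensor_activation_pattern(is_fp8=True);
    # the only difference between the users=1 and users=2 registrations
    # is the _users count the matcher checks.
    return CallFunction(
        torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
        KeywordArg("x"),
        KeywordArg("x_scale"),
        output_dtype=KeywordArg("x_dq_dtype"),
        _users=users,
    )


single_user_pattern = fp8_act_dq_pattern(1)  # dequant feeds one linear
multi_user_pattern = fp8_act_dq_pattern(2)   # dequant shared by two linears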
@@ -991,6 +993,26 @@ def _get_linear_dq_node(
     return dequant_node, act_reshape_node, activation_to_bf16_node, act_expand_node
 
 
+def _is_assert_tensor_metadata_node(node: Any) -> bool:
+    return (
+        isinstance(node, torch.fx.Node)
+        and node.op == "call_function"
+        and node.target == torch.ops.aten._assert_tensor_metadata.default
+    )
+
+
+def _get_non_assert_users(node: torch.fx.Node) -> list[torch.fx.Node]:
+    return [user for user in node.users if not _is_assert_tensor_metadata_node(user)]
+
+
+def _erase_assert_tensor_metadata_users(
+    graph: torch.fx.Graph, node: torch.fx.Node
+) -> None:
+    for user in list(node.users):
+        if _is_assert_tensor_metadata_node(user):
+            graph.erase_node(user)
+
+
 def _is_valid_dequant_linear_pattern(dtype, input_dim_exceeds_two, input_contiguous):
     def _inner(match):
         # Check dequant pattern has only 1 user.
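Context for these helpers: when a model goes through torch.export for AOT Inductor, aten._assert_tensor_metadata guard nodes are inserted that show up as extra users of the dequantize nodes without consuming any data. A toy FX graph showing the filtering (the guard's kwargs here are illustrative):

import torch
import torch.fx as fx

g = fx.Graph()
x = g.placeholder("x")
# Guard node of the kind torch.export inserts (kwargs illustrative).
guard = g.call_function(
    torch.ops.aten._assert_tensor_metadata.default, (x,), {"dtype": torch.float32}
)
y = g.call_function(torch.ops.aten.relu.default, (x,))
g.output(y)

# x has two users, but only relu is an "effective" (non-assert) user.
non_assert_users = [
    u
    for u in x.users
    if not (
        u.op == "call_function"
        and u.target == torch.ops.aten._assert_tensor_metadata.default
    )
]
assert len(x.users) == 2
assert [u.target for u in non_assert_users] == [torch.ops.aten.relu.default]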
@@ -1016,10 +1038,11 @@ def _inner(match):
             torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
         ]
 
-        if len(list(dequant_node.users)) != 1:
-            # Ensure the dequant pattern only has 1 user
-            # since we will delete the dequant pattern here
+        # Allow dequant to have multiple users, including _assert_tensor_metadata nodes introduced by AOT Inductor.
+        if len(_get_non_assert_users(dequant_node)) != 1:
             return False
+        # Ensure the dequant pattern only has 1 effective user
+        # since we will delete the dequant pattern here
 
         # Extra check for bmm pattern
         if input_dim_exceeds_two and not input_contiguous:
@@ -1202,6 +1225,10 @@ def qlinear_weight_prepack(match: Match, *args, **kwargs):
         linear_node.replace_all_uses_with(new_linear_node)
         new_linear_node.meta.update(linear_node.meta)
 
+        # Erase assert users first so the dequant nodes become erasable.
+        _erase_assert_tensor_metadata_users(graph, dequant_node)
+        _erase_assert_tensor_metadata_users(graph, dequant)
+
         # Erase the original linear node
         if input_dim_exceeds_two:
             if input_contiguous:
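The order matters here: fx.Graph.erase_node raises if the node still has users, so the assert guards hanging off a dequant must be erased before the dequant itself. A minimal demonstration (relu stands in for the dequant op):

import torch
import torch.fx as fx

g = fx.Graph()
a = g.placeholder("a")
dq = g.call_function(torch.ops.aten.relu.default, (a,))  # stand-in for dequant
guard = g.call_function(torch.ops.aten._assert_tensor_metadata.default, (dq,))
g.output(a)

try:
    g.erase_node(dq)  # fails: the guard still uses dq
except RuntimeError:
    g.erase_node(guard)  # erase the assert user first...
    g.erase_node(dq)     # ...then the dequant becomes erasable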
@@ -1237,6 +1264,7 @@ def _generate_dequant_linear_node_pattern(
     input_dim_exceeds_two=False,
     is_tensor_overload=False,
     is_fp8=False,
+    users=1,
 ):
     assert dtype in [torch.float32, torch.bfloat16]
     t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype)
@@ -1247,7 +1275,7 @@ def _generate_dequant_linear_node_pattern(
             _may_generate_pattern_with_reshape(
                 _may_generate_pattern_with_dtype_convert(
                     get_dequantize_per_tensor_activation_pattern(
-                        is_tensor_overload, is_fp8
+                        is_tensor_overload, is_fp8, users
                     ),
                     KeywordArg("autocast_act_dtype"),
                     dtype == torch.bfloat16,
@@ -1266,7 +1294,7 @@ def _generate_dequant_linear_node_pattern(
             _may_generate_pattern_with_reshape(
                 _may_generate_pattern_with_dtype_convert(
                     get_dequantize_per_tensor_activation_pattern(
-                        is_tensor_overload, is_fp8
+                        is_tensor_overload, is_fp8, users
                     ),
                     KeywordArg("autocast_act_dtype"),
                     dtype == torch.bfloat16,
@@ -1288,6 +1316,7 @@ def _generate_dequant_bmm_node_pattern(
     with_bias=False,
     is_tensor_overload=False,
     is_fp8=False,
+    users=1,
 ):
     # When activation of linear dim exceed 2 and not contiguous
     t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype)
@@ -1299,7 +1328,7 @@ def _generate_dequant_bmm_node_pattern(
         aten.expand.default,
         _may_generate_pattern_with_dtype_convert(
             get_dequantize_per_tensor_activation_pattern(
-                is_tensor_overload, is_fp8
+                is_tensor_overload, is_fp8, users
             ),
             KeywordArg("autocast_act_dtype"),
             dtype == torch.bfloat16,
@@ -1333,9 +1362,17 @@ def _generate_qlinear_weight_prepack_patterns(
     with_bias=False,
     is_tensor_overload=False,
     is_fp8=False,
+    users=1,
 ):
     if is_fp8:
-        dequant_wgt_pattern = dequantize_fp8_weight_pattern
+        # dequant_wgt_pattern = dequantize_fp8_weight_pattern
+        dequant_wgt_pattern = CallFunction(
+            torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
+            KeywordArg("q_weight"),
+            KeywordArg("w_scale"),
+            output_dtype=KeywordArg("w_dtype"),
+            _users=users,
+        )
     else:
         dequant_wgt_pattern = dequantize_per_channel_weight_pattern
     if input_dim_exceeds_two and not input_contiguous:
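The fp8 weight pattern is inlined here rather than reusing the shared dequantize_fp8_weight_pattern because that pattern carries no _users knob. A hypothetical consolidation (not part of this commit) would be a small factory like:

import torch
from torch._inductor.pattern_matcher import CallFunction, KeywordArg


def make_fp8_weight_dq_pattern(users: int = 1):
    # Same CallFunction as the inlined block above, parameterized on _users.
    return CallFunction(
        torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
        KeywordArg("q_weight"),
        KeywordArg("w_scale"),
        output_dtype=KeywordArg("w_dtype"),
        _users=users,
    )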
@@ -1345,6 +1382,7 @@ def _generate_qlinear_weight_prepack_patterns(
             with_bias,
             is_tensor_overload,
             is_fp8=is_fp8,
+            users=users,
         )
     else:
         return _generate_dequant_linear_node_pattern(
@@ -1353,6 +1391,7 @@ def _generate_qlinear_weight_prepack_patterns(
             input_dim_exceeds_two,
             is_tensor_overload,
             is_fp8=is_fp8,
+            users=users,
         )
 
 
@@ -1526,7 +1565,7 @@ def _register_qlinear_weight_prepack():
     # | OPT(add) |
 
     linear_weight_prepack_cases = itertools.product(
-        [torch.float32, torch.bfloat16], [True, False], [True, False], [True, False]
+        [torch.float32, torch.bfloat16], [True, False], [True, False], [True, False], [1, 2]
     )
 
     # Step 1: register patterns from mm and addmm
# Step 1: register patterns from mm and addmm
@@ -1535,6 +1574,7 @@ def _register_qlinear_weight_prepack():
15351574
input_dim_exceeds_two,
15361575
is_tensor_overload,
15371576
is_fp8,
1577+
users,
15381578
) in linear_weight_prepack_cases:
15391579
if is_fp8 and not is_tensor_overload:
15401580
continue
@@ -1543,6 +1583,7 @@ def _register_qlinear_weight_prepack():
             input_dim_exceeds_two,
             is_tensor_overload=is_tensor_overload,
             is_fp8=is_fp8,
+            users=users,
         )
         for weight_prepack_pattern in weight_prepack_patterns:
             # Register to pass_number 1, so we can do dequant promotion in pass_number 0.
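Hedged sketch of the registration mechanics the comment above refers to: Inductor runs freezing graph patterns in ascending pass_number, so weight prepack registered at pass_number=1 leaves pass 0 free for dequant promotion (which duplicates a multi-user dequant so each consumer sees its own copy). Illustrative only, since decorating a function this way mutates Inductor's global pass tables:

import torch
from torch._inductor.fx_passes.freezing_patterns import (
    register_freezing_graph_pattern,
)
from torch._inductor.pattern_matcher import CallFunction, KeywordArg, Match

# A stand-in pattern; the real passes use the qlinear patterns built above.
_demo_pattern = CallFunction(torch.ops.aten.relu.default, KeywordArg("x"))


@register_freezing_graph_pattern(_demo_pattern, pass_number=1)
def _demo_handler(match: Match, *args, **kwargs):
    pass  # a real handler rewrites match.graph here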
@@ -2458,6 +2499,40 @@ def _register_qconv_binary_fusion():
         )
 
 
+def _extract_const_float_from_node(v: Any) -> float | None:
+    if isinstance(v, (int, float)):
+        return float(v)
+
+    if isinstance(v, torch.fx.Node):
+        # case 1: aten.full([1], c)
+        if v.op == "call_function" and v.target is torch.ops.aten.full.default:
+            if len(v.args) >= 2 and isinstance(v.args[1], (int, float)):
+                return float(v.args[1])
+
+        # case 2: get_attr(lifted_tensor)
+        if v.op == "get_attr":
+            obj = v.graph.owning_module
+            for atom in str(v.target).split("."):
+                if not hasattr(obj, atom):
+                    obj = None
+                    break
+                obj = getattr(obj, atom)
+
+            if isinstance(obj, (int, float)):
+                return float(obj)
+            if isinstance(obj, torch.Tensor) and obj.numel() == 1:
+                return float(obj.item())
+
+        # case 3: meta val fallback
+        mv = v.meta.get("val", None)
+        if isinstance(mv, (int, float)):
+            return float(mv)
+        if isinstance(mv, torch.Tensor) and mv.numel() == 1:
+            return float(mv.item())
+
+    return None
+
+
 def _register_qlinear_post_op_fusion_pass(
     pattern,
     pass_number,
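A hedged usage sketch for the new helper, assuming this commit is applied (the function is module-private in torchao.quantization.pt2e.inductor_passes.x86). It exercises the plain-scalar fast path, the aten.full branch, and the None fallback; the get_attr branch additionally needs a graph with an owning module:

import torch
import torch.fx as fx
from torchao.quantization.pt2e.inductor_passes.x86 import (
    _extract_const_float_from_node,
)

# Plain Python scalar: returned directly.
assert _extract_const_float_from_node(0.02) == 0.02

# aten.full([1], c): the fill value is read off the node's args.
g = fx.Graph()
scale = g.call_function(torch.ops.aten.full.default, ([1], 0.02))
g.output(scale)
assert _extract_const_float_from_node(scale) == 0.02

# Anything unrecognized yields None rather than raising.
assert _extract_const_float_from_node("not a node") is None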
@@ -2495,10 +2570,10 @@ def qlinear_post_op_fusion(match: Match, *args, **kwargs):
 
         # Output QParams
         if output_dtype == torch.float8_e4m3fn:
-            # For float8, we assume the scale is from aten.full.default instead of
-            # a constant buffer to avoid constant folding of q/dq before fusion passes.
-            assert kwargs["o_inv_scale"].target is torch.ops.aten.full.default
-            o_inv_scale = kwargs["o_inv_scale"].args[1]
+            o_inv_scale = _extract_const_float_from_node(kwargs["o_inv_scale"])
+            assert o_inv_scale is not None, (
+                f"Unsupported fp8 o_inv_scale node: {kwargs['o_inv_scale']}"
+            )
         else:
             o_inv_scale = (
                 kwargs["o_inv_scale"]
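Background for this change: the old code hard-required the output scale to come from aten.full.default, but under torch.export / AOT Inductor a scale written as a constant tensor is lifted out of the graph and reappears as a get_attr (or carries only a meta["val"]), so the assert fired. A minimal sketch of the lifting behavior (the exact node form depends on the export and freezing pipeline):

import torch


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # A scale the user wrote as a constant tensor, not aten.full.
        self.scale = torch.tensor(0.02)

    def forward(self, x):
        return x * self.scale


ep = torch.export.export(M(), (torch.randn(4),))
# The constant is lifted: the graph references it as a constant input /
# get_attr rather than materializing it with aten.full.
print(ep.graph)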
@@ -2979,31 +3054,7 @@ def _normalize_dtype(dtype_or_enum: Any) -> torch.dtype | Any:
         normalized_o_dtype = _normalize_dtype(kwargs["o_dtype"])
         output_type = normalized_o_dtype
 
-        def _extract_const_float(val) -> float | None:
-            # Prefer extracting from python scalars and FX node structure
-            if isinstance(val, (int, float)):
-                return float(val)
-            if isinstance(val, torch.fx.Node):
-                meta_val = val.meta.get("val", None)
-                if isinstance(meta_val, (int, float)):
-                    return float(meta_val)
-                # Common pattern: aten.full([1], fill_value, dtype=float)
-                if val.target is torch.ops.aten.full.default and len(val.args) >= 2:
-                    fill_value = val.args[1]
-                    if isinstance(fill_value, (int, float)):
-                        return float(fill_value)
-                # Common pattern in user code: torch.tensor([scalar])
-                if val.target is torch.tensor and len(val.args) >= 1:
-                    data = val.args[0]
-                    if (
-                        isinstance(data, (list, tuple))
-                        and len(data) == 1
-                        and isinstance(data[0], (int, float))
-                    ):
-                        return float(data[0])
-            return None
-
-        o_scale = _extract_const_float(kwargs["o_inv_scale"])
+        o_scale = _extract_const_float_from_node(kwargs["o_inv_scale"])
         assert o_scale is not None, "Output scale is not a constant float."
 
         graph = match.graph
