Skip to content

Commit 4efd83a

Browse files
authored
[pt2e] Skip linear+bn fusion when input is more than 2-D (#4242)
Linear always operates on the last dimension while BatchNorm1d normalizes along dim 1 (channels). These two coincide only for 2-D inputs (N, C). For higher-rank inputs like 3-D (N, C, L), fusing the BN parameters into Linear weights silently produces incorrect results because the scale/shift is applied along the wrong axis. Add an ndim check in _fuse_linear_bn_ that skips fusion and emits a warning when the linear input has more than 2 dimensions. Fixes #4116. Signed-off-by: Lidang-Jiang <lidangjiang@gmail.com>
1 parent 707bee8 commit 4efd83a

File tree

2 files changed

+91
-0
lines changed

2 files changed

+91
-0
lines changed

test/quantization/pt2e/test_quantize_pt2e.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import copy
1313
import unittest
14+
import warnings
1415

1516
import torch
1617
from torch import Tensor
@@ -212,6 +213,69 @@ def test_linear_bn_fusion(self):
212213
torch.ops.aten.batch_norm.default,
213214
)
214215

216+
def test_linear_bn_fusion_skipped_for_3d_input(self):
217+
"""Verify that Linear+BN fusion is skipped when input is >2-D.
218+
219+
When the linear input is 3-D (N, C, L), Linear operates on the last
220+
dim while BatchNorm1d normalizes along dim 1. Fusing them silently
221+
produces incorrect results. See https://github.com/pytorch/ao/issues/4116
222+
"""
223+
for bias in [True, False]:
224+
m = torch.nn.Sequential(
225+
torch.nn.Linear(3, 5, bias=bias),
226+
torch.nn.BatchNorm1d(5),
227+
)
228+
m.eval()
229+
# 3-D input: (batch=2, channels=5, length=3)
230+
example_inputs = (torch.randn(2, 5, 3),)
231+
ref_outputs = m(*example_inputs)
232+
traced_model = torch.export.export(m, example_inputs, strict=True).module()
233+
with warnings.catch_warnings(record=True) as w:
234+
warnings.simplefilter("always")
235+
prepared_model = prepare_pt2e(traced_model, XNNPACKQuantizer())
236+
# Should emit a warning about skipping fusion
237+
fusion_warnings = [
238+
x for x in w if "Not fusing linear+bn" in str(x.message)
239+
]
240+
self.assertGreater(
241+
len(fusion_warnings),
242+
0,
243+
"Expected a warning about skipping 3-D linear+bn fusion",
244+
)
245+
prepared_outputs = prepared_model(*example_inputs)
246+
# Outputs must match the reference (no silent corruption)
247+
torch.testing.assert_close(
248+
ref_outputs, prepared_outputs, atol=1e-5, rtol=1e-5
249+
)
250+
251+
def test_linear_bn_fusion_correct_for_2d_input(self):
252+
"""Verify that 2-D Linear+BN fusion still works and BN is removed."""
253+
for bias in [True, False]:
254+
for N, M in [(8, 16), (5, 5)]:
255+
m = torch.nn.Sequential(
256+
torch.nn.Linear(N, M, bias=bias),
257+
torch.nn.BatchNorm1d(M),
258+
)
259+
m.eval()
260+
example_inputs = (torch.randn(4, N),)
261+
ref_outputs = m(*example_inputs)
262+
traced_model = torch.export.export(
263+
m, example_inputs, strict=True
264+
).module()
265+
prepared_model = prepare_pt2e(traced_model, XNNPACKQuantizer())
266+
prepared_outputs = prepared_model(*example_inputs)
267+
torch.testing.assert_close(ref_outputs, prepared_outputs)
268+
# BN nodes should be removed after fusion
269+
for node in prepared_model.graph.nodes:
270+
self.assertNotEqual(
271+
node.target,
272+
torch.ops.aten._native_batch_norm_legit_no_training.default,
273+
)
274+
self.assertNotEqual(
275+
node.target,
276+
torch.ops.aten.batch_norm.default,
277+
)
278+
215279
def test_wo_annotate_conv_output_quantizer(self):
216280
# TODO: use OP_TO_ANNOTATOR
217281
class BackendAQuantizer(Quantizer):

torchao/quantization/pt2e/utils.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -970,6 +970,33 @@ def _fuse_linear_bn_(m: GraphModule) -> None:
970970
if not _is_linear_node(n):
971971
continue
972972
linear_node = n
973+
974+
# Linear+BN fusion is only valid when both layers operate on
975+
# the same dimension. Linear always acts on the last dim
976+
# while BatchNorm1d acts on the channel dim (dim 1). These
977+
# two coincide only when the linear input is 2-D (N, C).
978+
# For higher-rank inputs (e.g. 3-D (N, C, L)), BN normalizes
979+
# along dim 1 whereas Linear transforms the last dim, so
980+
# fusing would silently produce incorrect results.
981+
# See https://github.com/pytorch/ao/issues/4116
982+
linear_input_node = linear_node.args[0]
983+
if isinstance(linear_input_node, Node):
984+
linear_input_val = linear_input_node.meta.get("val")
985+
if (
986+
linear_input_val is not None
987+
and isinstance(linear_input_val, torch.Tensor)
988+
and linear_input_val.ndim > 2
989+
):
990+
warnings.warn(
991+
f"Not fusing linear+bn for node "
992+
f"'{linear_node.name}': the linear input "
993+
f"is {linear_input_val.ndim}-D so Linear "
994+
f"and BatchNorm operate on different "
995+
f"dimensions",
996+
stacklevel=1,
997+
)
998+
continue
999+
9731000
linear_weight_node = linear_node.args[1]
9741001
linear_bias_node = linear_node.args[2] if len(linear_node.args) > 2 else None
9751002
fold_bn_weights_into_linear_node(

0 commit comments

Comments
 (0)