Commit 4efd83a
authored
[pt2e] Skip linear+bn fusion when input is higher than 2-D (#4242)
Linear always operates on the last dimension while BatchNorm1d
normalizes along dim 1 (channels). These two coincide only for 2-D
inputs (N, C). For higher-rank inputs like 3-D (N, C, L), fusing
the BN parameters into Linear weights silently produces incorrect
results because the scale/shift is applied along the wrong axis.
Add an ndim check in _fuse_linear_bn_ that skips fusion and emits
a warning when the linear input has more than 2 dimensions.
Fixes #4116
Signed-off-by: Lidang-Jiang <lidangjiang@gmail.com>1 parent 707bee8 commit 4efd83a
File tree
2 files changed
+91
-0
lines changed- test/quantization/pt2e
- torchao/quantization/pt2e
2 files changed
+91
-0
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
11 | 11 | | |
12 | 12 | | |
13 | 13 | | |
| 14 | + | |
14 | 15 | | |
15 | 16 | | |
16 | 17 | | |
| |||
212 | 213 | | |
213 | 214 | | |
214 | 215 | | |
| 216 | + | |
| 217 | + | |
| 218 | + | |
| 219 | + | |
| 220 | + | |
| 221 | + | |
| 222 | + | |
| 223 | + | |
| 224 | + | |
| 225 | + | |
| 226 | + | |
| 227 | + | |
| 228 | + | |
| 229 | + | |
| 230 | + | |
| 231 | + | |
| 232 | + | |
| 233 | + | |
| 234 | + | |
| 235 | + | |
| 236 | + | |
| 237 | + | |
| 238 | + | |
| 239 | + | |
| 240 | + | |
| 241 | + | |
| 242 | + | |
| 243 | + | |
| 244 | + | |
| 245 | + | |
| 246 | + | |
| 247 | + | |
| 248 | + | |
| 249 | + | |
| 250 | + | |
| 251 | + | |
| 252 | + | |
| 253 | + | |
| 254 | + | |
| 255 | + | |
| 256 | + | |
| 257 | + | |
| 258 | + | |
| 259 | + | |
| 260 | + | |
| 261 | + | |
| 262 | + | |
| 263 | + | |
| 264 | + | |
| 265 | + | |
| 266 | + | |
| 267 | + | |
| 268 | + | |
| 269 | + | |
| 270 | + | |
| 271 | + | |
| 272 | + | |
| 273 | + | |
| 274 | + | |
| 275 | + | |
| 276 | + | |
| 277 | + | |
| 278 | + | |
215 | 279 | | |
216 | 280 | | |
217 | 281 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
970 | 970 | | |
971 | 971 | | |
972 | 972 | | |
| 973 | + | |
| 974 | + | |
| 975 | + | |
| 976 | + | |
| 977 | + | |
| 978 | + | |
| 979 | + | |
| 980 | + | |
| 981 | + | |
| 982 | + | |
| 983 | + | |
| 984 | + | |
| 985 | + | |
| 986 | + | |
| 987 | + | |
| 988 | + | |
| 989 | + | |
| 990 | + | |
| 991 | + | |
| 992 | + | |
| 993 | + | |
| 994 | + | |
| 995 | + | |
| 996 | + | |
| 997 | + | |
| 998 | + | |
| 999 | + | |
973 | 1000 | | |
974 | 1001 | | |
975 | 1002 | | |
| |||
0 commit comments