Skip to content

Commit 01d3a2d

Browse files
[metal] Enable bias in metal lowbit-quantized linear kernels (#3745)
Enhances the quantization API in torchao/experimental/quant_api.py by adding support for quantized linear layers with bias and improving error handling for non-contiguous weights. The changes ensure that quantized layers can correctly handle bias terms and provide clearer error messages when input weights do not meet required conditions.
1 parent 90e7ca0 commit 01d3a2d

2 files changed

Lines changed: 28 additions & 13 deletions

File tree

torchao/experimental/ops/mps/test/test_quantizer.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,7 @@
1313

1414
# Need to import to load the ops
1515
import torchao.experimental.ops.mps # noqa: F401
16-
from torchao.experimental.quant_api import (
17-
UIntxWeightOnlyConfig,
18-
_linear_int_weight_mps_check,
19-
_quantize,
20-
)
16+
from torchao.experimental.quant_api import UIntxWeightOnlyConfig, _quantize
2117
from torchao.quantization.quant_api import quantize_
2218

2319

@@ -50,7 +46,7 @@ def _quantize_model(self, model, precision, nbit, group_size):
5046
)
5147
quantized_model = copy.deepcopy(model)
5248
quantized_model = quantized_model.to(device="mps", dtype=precision)
53-
quantize_(quantized_model, config, filter_fn=_linear_int_weight_mps_check)
49+
quantize_(quantized_model, config)
5450
return quantized_model
5551

5652
@parameterized.expand(BITWIDTHS)

torchao/experimental/quant_api.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,15 @@ def __init__(
7777
self,
7878
pack_weight_op,
7979
linear_op,
80+
bias: Optional[torch.Tensor] = None,
8081
):
8182
super().__init__()
8283
self._pack_weights_op = pack_weight_op
8384
self._linear_op = linear_op
85+
if bias is not None:
86+
self.bias = nn.Parameter(bias, requires_grad=False)
87+
else:
88+
self.register_parameter("bias", None)
8489

8590
def quantize_and_pack_weights(self, weights, nbit, group_size):
8691
self.nbit = nbit
@@ -100,24 +105,30 @@ def quantize_and_pack_weights(self, weights, nbit, group_size):
100105
def forward(self, x):
101106
assert x.dim() >= 2
102107
if x.dim() == 2:
103-
return self._linear_op(
108+
output = self._linear_op(
104109
x,
105110
self.packed_weights,
106111
self.group_size,
107112
self.weight_scales,
108113
self.weight_zeros,
109114
)
115+
if self.bias is not None:
116+
output = output + self.bias
117+
return output
110118

111119
lead_shape = x.shape[0:-1]
112120
k = x.shape[-1]
113121
n = self.weight_scales.shape[0]
114-
return self._linear_op(
122+
output = self._linear_op(
115123
x.reshape(-1, k),
116124
self.packed_weights,
117125
self.group_size,
118126
self.weight_scales,
119127
self.weight_zeros,
120128
).reshape(*lead_shape, n)
129+
if self.bias is not None:
130+
output = output + self.bias
131+
return output
121132

122133

123134
# TODO(mcandales): Consolidate with _replace_linear_with_quantized_linear
@@ -132,12 +143,17 @@ def _replace_linear_with_quantized_linear_mps(module: nn.Module, kwargs={}):
132143
if not isinstance(child, nn.Linear):
133144
_replace_linear_with_quantized_linear_mps(child, kwargs)
134145
else:
135-
assert child.bias is None
146+
if not child.weight.is_contiguous():
147+
raise ValueError(
148+
f"UIntxWeightOnlyQuantizedLinear requires contiguous weights for layer '{name}'. "
149+
"Please call .contiguous() on the weight tensor before quantization."
150+
)
136151
qlinear = UIntxWeightOnlyQuantizedLinear(
137152
pack_weight_op=getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit"),
138153
linear_op=getattr(
139154
torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight"
140155
),
156+
bias=child.bias,
141157
)
142158
setattr(module, name, qlinear)
143159
qlinear.quantize_and_pack_weights(child.weight, nbit, group_size)
@@ -232,20 +248,23 @@ def __post_init__(self):
232248
)
233249

234250

235-
def _linear_int_weight_mps_check(module: nn.Module, fqn: str) -> bool:
236-
return isinstance(module, nn.Linear) and module.bias is None
237-
238-
239251
@register_quantize_module_handler(UIntxWeightOnlyConfig)
240252
def _uintx_weight_only_mps_transform(
241253
module: torch.nn.Module, config: UIntxWeightOnlyConfig
242254
) -> torch.nn.Module:
243255
nbit = config.bitwidth
244256
group_size = config.group_size
245257

258+
if not module.weight.is_contiguous():
259+
raise ValueError(
260+
"UIntxWeightOnlyQuantizedLinear requires contiguous weights. "
261+
"Please call .contiguous() on the weight tensor before quantization."
262+
)
263+
246264
qlinear = UIntxWeightOnlyQuantizedLinear(
247265
pack_weight_op=getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit"),
248266
linear_op=getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight"),
267+
bias=module.bias,
249268
)
250269
qlinear.quantize_and_pack_weights(module.weight, nbit, group_size)
251270
return qlinear

0 commit comments

Comments (0)