def __init__(
    self,
    pack_weight_op,
    linear_op,
    bias: Optional[torch.Tensor] = None,
):
    """Wrap a pair of torchao MPS ops as a weight-only quantized linear layer.

    Args:
        pack_weight_op: op that packs a quantized weight tensor
            (used later by ``quantize_and_pack_weights``).
        linear_op: op computing the fp-activation / quantized-weight matmul.
        bias: optional bias tensor; stored as a frozen (non-trainable)
            parameter when given, otherwise registered as ``None`` so that
            ``self.bias`` is always a valid attribute.
    """
    super().__init__()
    self._pack_weights_op = pack_weight_op
    self._linear_op = linear_op
    if bias is None:
        # Keep the attribute present (and state_dict-consistent) even
        # when there is no bias.
        self.register_parameter("bias", None)
    else:
        self.bias = nn.Parameter(bias, requires_grad=False)
8590 def quantize_and_pack_weights (self , weights , nbit , group_size ):
8691 self .nbit = nbit
@@ -100,24 +105,30 @@ def quantize_and_pack_weights(self, weights, nbit, group_size):
def forward(self, x):
    """Apply the quantized linear op to ``x``, adding the bias if present.

    Accepts input of rank >= 2. All leading dimensions are flattened into a
    single batch dimension before calling the linear op, and restored on the
    result, so the 2-D and N-D cases share one code path (the original code
    duplicated the op call and bias addition across two branches; for 2-D
    input the flatten/unflatten below is a no-op view, so behavior is
    unchanged).

    Args:
        x: activation tensor of shape ``(*lead, k)``.

    Returns:
        Tensor of shape ``(*lead, n)`` where ``n`` is taken from
        ``self.weight_scales.shape[0]`` — assumes the linear op returns
        ``(batch, n)``, as the original N-D branch already did.
    """
    assert x.dim() >= 2
    lead_shape = x.shape[0:-1]
    k = x.shape[-1]
    n = self.weight_scales.shape[0]
    output = self._linear_op(
        x.reshape(-1, k),
        self.packed_weights,
        self.group_size,
        self.weight_scales,
        self.weight_zeros,
    ).reshape(*lead_shape, n)
    if self.bias is not None:
        output = output + self.bias
    return output
121132
122133
123134# TODO(mcandales): Consolidate with _replace_linear_with_quantized_linear
@@ -132,12 +143,17 @@ def _replace_linear_with_quantized_linear_mps(module: nn.Module, kwargs={}):
132143 if not isinstance (child , nn .Linear ):
133144 _replace_linear_with_quantized_linear_mps (child , kwargs )
134145 else :
135- assert child .bias is None
146+ if not child .weight .is_contiguous ():
147+ raise ValueError (
148+ f"UIntxWeightOnlyQuantizedLinear requires contiguous weights for layer '{ name } '. "
149+ "Please call .contiguous() on the weight tensor before quantization."
150+ )
136151 qlinear = UIntxWeightOnlyQuantizedLinear (
137152 pack_weight_op = getattr (torch .ops .torchao , f"_pack_weight_{ nbit } bit" ),
138153 linear_op = getattr (
139154 torch .ops .torchao , f"_linear_fp_act_{ nbit } bit_weight"
140155 ),
156+ bias = child .bias ,
141157 )
142158 setattr (module , name , qlinear )
143159 qlinear .quantize_and_pack_weights (child .weight , nbit , group_size )
@@ -232,20 +248,23 @@ def __post_init__(self):
232248 )
233249
234250
235- def _linear_int_weight_mps_check (module : nn .Module , fqn : str ) -> bool :
236- return isinstance (module , nn .Linear ) and module .bias is None
237-
238-
@register_quantize_module_handler(UIntxWeightOnlyConfig)
def _uintx_weight_only_mps_transform(
    module: torch.nn.Module, config: UIntxWeightOnlyConfig
) -> torch.nn.Module:
    """Replace ``module`` with a ``UIntxWeightOnlyQuantizedLinear``.

    Builds the quantized layer from the torchao pack/linear ops matching
    ``config.bitwidth``, carries over the module's bias, and quantizes and
    packs its weights at ``config.group_size`` granularity.

    Raises:
        ValueError: if the module's weight tensor is not contiguous, which
            the packing op requires.
    """
    nbit = config.bitwidth
    group_size = config.group_size

    if not module.weight.is_contiguous():
        raise ValueError(
            "UIntxWeightOnlyQuantizedLinear requires contiguous weights. "
            "Please call .contiguous() on the weight tensor before quantization."
        )

    # Resolve the bitwidth-specific torchao ops up front, then assemble.
    pack_op = getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")
    linear_op = getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
    qlinear = UIntxWeightOnlyQuantizedLinear(
        pack_weight_op=pack_op,
        linear_op=linear_op,
        bias=module.bias,
    )
    qlinear.quantize_and_pack_weights(module.weight, nbit, group_size)
    return qlinear
0 commit comments