@@ -1813,26 +1813,33 @@ def _intx_weight_only_transform(
     module: torch.nn.Module,
     config: IntxWeightOnlyConfig,
     *,
+    parameter_name: str = "weight",
     custom_scale: Optional[torch.Tensor] = None,
     custom_zero_point: Optional[torch.Tensor] = None,
 ) -> torch.nn.Module:
-    assert hasattr(module, "weight"), (
-        "applying intx weight only quant requires module to have weight attribute"
-        + " but {module} does not have one"
+    assert hasattr(module, parameter_name), (
+        f"applying intx weight only quant requires module to have {parameter_name} attribute"
+        + f" but {module} does not have one"
     )
     new_weight = _intx_weight_only_quantize_tensor(
-        module.weight,
+        getattr(module, parameter_name),
         config,
         custom_scale=custom_scale,
         custom_zero_point=custom_zero_point,
     )
-    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
-
-    if isinstance(module, nn.Linear):
-        module.extra_repr = types.MethodType(_linear_extra_repr, module)
-    elif isinstance(module, nn.Embedding):
-        module.extra_repr = types.MethodType(_embedding_extra_repr, module)
-
+    setattr(
+        module,
+        parameter_name,
+        torch.nn.Parameter(new_weight, requires_grad=False),
+    )
+    module.extra_repr = types.MethodType(
+        partial(
+            _module_extra_repr,
+            original_extra_repr=module.extra_repr,
+            parameter_name=parameter_name,
+        ),
+        module,
+    )
     return module
 
 
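This hunk replaces the per-module-type `_linear_extra_repr`/`_embedding_extra_repr` branches with a single shared `_module_extra_repr` helper, binding `original_extra_repr` and `parameter_name` into it via `functools.partial` before wrapping it in `types.MethodType`. The helper's definition is not part of this hunk; the following is a minimal sketch of what it plausibly does, inferred from the keyword arguments bound above (the body and the exact output format are assumptions, not the library's actual code):

```python
from functools import partial
import types

import torch


def _module_extra_repr(self, original_extra_repr, parameter_name):
    # Hypothetical sketch: the real helper lives outside this hunk.
    # original_extra_repr is the module's pre-patch bound extra_repr method,
    # so calling it preserves the stock summary (e.g. in_features/out_features
    # for nn.Linear) while we append a note about the quantized parameter.
    qparam = getattr(self, parameter_name)
    return (
        f"{original_extra_repr()}, "
        f"{parameter_name}={qparam.__class__.__name__}"
        f"(shape={tuple(qparam.shape)}, dtype={qparam.dtype})"
    )


# Usage mirroring the patching pattern in the diff, on a plain nn.Linear:
lin = torch.nn.Linear(4, 2)
lin.extra_repr = types.MethodType(
    partial(
        _module_extra_repr,
        original_extra_repr=lin.extra_repr,
        parameter_name="weight",
    ),
    lin,
)
print(lin)
# Linear(in_features=4, out_features=2, bias=True, weight=Parameter(shape=(2, 4), dtype=torch.float32))
```

Because `types.MethodType(partial(...), module)` binds the module as the first positional argument, calling `module.extra_repr()` dispatches to `_module_extra_repr(module, original_extra_repr=..., parameter_name=...)`, which is why the helper can take `self` positionally while the other two arguments arrive pre-bound.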