Skip to content

Commit cd062f2

Browse files
authored
Add parameter_name support to _intx_weight_only_transform (#3905)
1 parent 1efc5d0 commit cd062f2

2 files changed

Lines changed: 19 additions & 11 deletions

File tree

test/quantization/test_quant_api.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -1046,6 +1046,7 @@ def test_fqn_to_config_non_weight_param(self):
             Int8DynamicActivationInt8WeightConfig(),
             Int8DynamicActivationIntxWeightConfig(),
             Int8StaticActivationInt8WeightConfig(),
+            IntxWeightOnlyConfig(),
             Float8WeightOnlyConfig(),
             Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()),
         ]
```

torchao/quantization/quant_api.py

Lines changed: 18 additions & 11 deletions
```diff
@@ -1813,26 +1813,33 @@ def _intx_weight_only_transform(
     module: torch.nn.Module,
     config: IntxWeightOnlyConfig,
     *,
+    parameter_name: str = "weight",
     custom_scale: Optional[torch.Tensor] = None,
     custom_zero_point: Optional[torch.Tensor] = None,
 ) -> torch.nn.Module:
-    assert hasattr(module, "weight"), (
-        "applying intx weight only quant requires module to have weight attribute"
-        + " but {module} does not have one"
+    assert hasattr(module, parameter_name), (
+        f"applying intx weight only quant requires module to have {parameter_name} attribute"
+        + f" but {module} does not have one"
     )
     new_weight = _intx_weight_only_quantize_tensor(
-        module.weight,
+        getattr(module, parameter_name),
         config,
         custom_scale=custom_scale,
         custom_zero_point=custom_zero_point,
     )
-    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
-
-    if isinstance(module, nn.Linear):
-        module.extra_repr = types.MethodType(_linear_extra_repr, module)
-    elif isinstance(module, nn.Embedding):
-        module.extra_repr = types.MethodType(_embedding_extra_repr, module)
-
+    setattr(
+        module,
+        parameter_name,
+        torch.nn.Parameter(new_weight, requires_grad=False),
+    )
+    module.extra_repr = types.MethodType(
+        partial(
+            _module_extra_repr,
+            original_extra_repr=module.extra_repr,
+            parameter_name=parameter_name,
+        ),
+        module,
+    )
     return module
```

18381845

0 commit comments

Comments
 (0)