Skip to content

Commit d1fa9a2

Browse files
authored
Add parameter_name support to _int4_weight_only_transform (#3901)
1 parent be10b2d commit d1fa9a2

2 files changed

Lines changed: 54 additions & 7 deletions

File tree

test/quantization/test_quant_api.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
Int4WeightOnlyConfig,
4141
Int8DynamicActivationInt8WeightConfig,
4242
Int8DynamicActivationIntxWeightConfig,
43+
Int8StaticActivationInt8WeightConfig,
4344
Int8WeightOnlyConfig,
4445
IntxWeightOnlyConfig,
4546
ModuleFqnToConfig,
@@ -1036,6 +1037,36 @@ def __init__(self):
10361037
assert isinstance(m.nested.linear.weight, AffineQuantizedTensor)
10371038
assert isinstance(m.linear1.weight, AffineQuantizedTensor)
10381039

1040+
def test_fqn_to_config_non_weight_param(self):
    """Check that FqnToConfig can target a parameter other than ``weight``.

    For each config under test, a single bfloat16 ``Linear`` on CUDA gets
    an extra ``custom_param`` registered, and only the fully-qualified
    name ``"0.custom_param"`` is mapped to the config. After
    ``quantize_`` the ``custom_param`` attribute must be a different
    object, while ``weight`` must be the very same object as before.
    """
    candidate_configs = (
        Int4WeightOnlyConfig(group_size=128),
        Int8WeightOnlyConfig(),
        Int8StaticActivationInt8WeightConfig(),
        Float8WeightOnlyConfig(),
        Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()),
    )
    for config in candidate_configs:
        config_name = type(config).__name__
        with self.subTest(config=config_name):
            # Build the layer first, then the extra parameter, so random
            # number consumption happens in the same order as before.
            linear = torch.nn.Linear(128, 128).to(torch.bfloat16).cuda()
            model = torch.nn.Sequential(linear)
            extra = torch.nn.Parameter(
                torch.randn(128, 128, dtype=torch.bfloat16, device="cuda")
            )
            model[0].register_parameter("custom_param", extra)

            # Keep references to the pre-quantization objects for the
            # identity checks below.
            custom_param_before = model[0].custom_param
            weight_before = model[0].weight

            quantize_(
                model,
                FqnToConfig({"0.custom_param": config}),
                filter_fn=None,
            )

            assert (
                model[0].custom_param is not custom_param_before
            ), f"custom_param should be quantized for {config_name}"
            assert (
                model[0].weight is weight_before
            ), f"weight should be unchanged for {config_name}"
1069+
10391070
def test_fqn_config_module_config_and_fqn_config_both_specified(self):
10401071
with self.assertRaises(ValueError):
10411072
FqnToConfig(

torchao/quantization/quant_api.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -821,18 +821,34 @@ def _int4_weight_only_quantize_tensor(weight, config):
821821

822822
@register_quantize_module_handler(Int4WeightOnlyConfig)
def _int4_weight_only_transform(
    module: torch.nn.Module,
    config: Int4WeightOnlyConfig,
    *,
    parameter_name: str = "weight",
) -> torch.nn.Module:
    """Replace one parameter of ``module`` with its int4 weight-only form.

    Args:
        module: Module whose attribute ``parameter_name`` is quantized
            in place.
        config: Int4 weight-only configuration. When
            ``config.set_inductor_config`` is truthy, the recommended
            inductor config setter is invoked before anything else.
        parameter_name: Name of the attribute on ``module`` to quantize.
            Defaults to ``"weight"``.

    Returns:
        The same ``module`` object, with the named attribute replaced by
        a ``requires_grad=False`` parameter and ``extra_repr`` rebound.

    Raises:
        AssertionError: If ``module`` has no attribute ``parameter_name``.
    """
    if config.set_inductor_config:
        torchao.quantization.utils.recommended_inductor_config_setter()

    assert hasattr(module, parameter_name), (
        f"applying int4 weight only quant requires module to have {parameter_name} attribute"
        f" but {module} does not have one"
    )

    source_tensor = getattr(module, parameter_name)
    quantized_param = torch.nn.Parameter(
        _int4_weight_only_quantize_tensor(source_tensor, config),
        requires_grad=False,
    )
    setattr(module, parameter_name, quantized_param)

    # Read the current extra_repr only after the parameter swap, then bind
    # a replacement that receives it together with the parameter name.
    previous_extra_repr = module.extra_repr
    repr_with_context = partial(
        _module_extra_repr,
        original_extra_repr=previous_extra_repr,
        parameter_name=parameter_name,
    )
    module.extra_repr = types.MethodType(repr_with_context, module)
    return module
837853

838854

0 commit comments

Comments
 (0)