Commit 5a5029d

Add parameter_name support to _nvfp4_inference_linear_transform (#3976)
* Add parameter_name support to _mx_inference_linear_transform

  Add `parameter_name` keyword-only argument to support FqnToConfig per-parameter quantization. Replace module.weight reads/writes with getattr/setattr, and use the _module_extra_repr partial pattern.

  [ghstack-poisoned]

* Add parameter_name support to _nvfp4_inference_linear_transform

  Add `parameter_name` keyword-only argument to the dynamic quantization branch (step=None) to support FqnToConfig per-parameter quantization. The PREPARE and CONVERT branches are left unchanged, as they implement the Linear-specific observer flow.

  [ghstack-poisoned]

* Update base for Update on "Add parameter_name support to _nvfp4_inference_linear_transform"

  Add `parameter_name` keyword-only argument to the dynamic quantization branch (step=None) to support FqnToConfig per-parameter quantization. The PREPARE and CONVERT branches are left unchanged, as they implement the Linear-specific observer flow.

  [ghstack-poisoned]
1 parent f97ec45 · commit 5a5029d

2 files changed: 24 additions & 4 deletions
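Before the per-file diffs, here is a distilled, standalone sketch of the refactoring pattern the commit message describes: read the target parameter with getattr, quantize it, write it back with setattr, and patch extra_repr through functools.partial so the repr reflects the replaced tensor. The fake_quantize helper and the _module_extra_repr body below are illustrative stand-ins, not the torchao implementations.

import types
from functools import partial

import torch


def fake_quantize(tensor, config=None):
    # Stand-in for a real quantizer; a plain cast keeps the sketch runnable on CPU.
    return tensor.to(torch.bfloat16)


def _module_extra_repr(module, original_extra_repr, parameter_name):
    # Hypothetical helper: keep the module's original extra_repr and note which
    # parameter was replaced.
    return f"{original_extra_repr()}, quantized_param={parameter_name}"


def example_transform(module, config=None, *, parameter_name="weight"):
    param = getattr(module, parameter_name)  # read by name instead of module.weight
    quantized = fake_quantize(param, config)
    setattr(  # write back under the same name instead of module.weight = ...
        module,
        parameter_name,
        torch.nn.Parameter(quantized, requires_grad=False),
    )
    module.extra_repr = types.MethodType(
        partial(
            _module_extra_repr,
            original_extra_repr=module.extra_repr,
            parameter_name=parameter_name,
        ),
        module,
    )
    return module


linear = torch.nn.Linear(16, 16)
example_transform(linear, parameter_name="weight")
print(linear)  # extra_repr now ends with ", quantized_param=weight"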


test/quantization/test_quant_api.py

Lines changed: 4 additions & 0 deletions
@@ -63,6 +63,7 @@
     is_sm_at_least_89,
     is_sm_at_least_90,
     is_sm_at_least_100,
+    torch_version_at_least,
     unwrap_tensor_subclass,
 )
 

@@ -75,6 +76,7 @@
 
 from torchao.prototype.mx_formats.inference_workflow import (
     MXDynamicActivationMXWeightConfig,
+    NVFP4DynamicActivationNVFP4WeightConfig,
 )
 
 

@@ -1057,6 +1059,8 @@ def test_fqn_to_config_non_weight_param(self):
         ]
         if is_sm_at_least_100():
             configs.append(MXDynamicActivationMXWeightConfig())
+        if is_sm_at_least_100() and torch_version_at_least("2.8.0"):
+            configs.append(NVFP4DynamicActivationNVFP4WeightConfig())
         for config in configs:
             with self.subTest(config=type(config).__name__):
                 model = torch.nn.Sequential(
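For context, test_fqn_to_config_non_weight_param drives these configs through FqnToConfig, which is what the new parameter_name plumbing exists to support. A rough usage sketch of that flow, assuming FqnToConfig is importable from torchao.quantization and accepts a dict from fully qualified names to configs (the toy model and the "0.weight" FQN are invented for illustration, while the test itself targets a parameter that is not named weight; running this needs an sm100+ GPU and PyTorch >= 2.8):

import torch

from torchao.prototype.mx_formats.inference_workflow import (
    NVFP4DynamicActivationNVFP4WeightConfig,
)
from torchao.quantization import FqnToConfig, quantize_

# Toy model; shapes kept divisible by 16 to satisfy the NVFP4 shape check.
model = torch.nn.Sequential(torch.nn.Linear(64, 64, dtype=torch.bfloat16)).cuda()

# Assumption: FqnToConfig maps FQN strings (here, one specific parameter) to configs.
config = FqnToConfig({"0.weight": NVFP4DynamicActivationNVFP4WeightConfig()})
quantize_(model, config)
print(model)  # the patched extra_repr should reflect the quantized parameter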

torchao/prototype/mx_formats/inference_workflow.py

Lines changed: 20 additions & 4 deletions
@@ -240,7 +240,10 @@ def __post_init__(self):
 
 @register_quantize_module_handler(NVFP4DynamicActivationNVFP4WeightConfig)
 def _nvfp4_inference_linear_transform(
-    module: torch.nn.Linear, config: NVFP4DynamicActivationNVFP4WeightConfig
+    module: torch.nn.Linear,
+    config: NVFP4DynamicActivationNVFP4WeightConfig,
+    *,
+    parameter_name: str = "weight",
 ):
     """Quantization handler for NVFP4DynamicActivationNVFP4WeightConfig
 
@@ -249,7 +252,7 @@ def _nvfp4_inference_linear_transform(
     - CONVERT: Extract amax from observer, compute static per_tensor_scale, quantize
     - None (default): Original dynamic quantization behavior
     """
-    weight = module.weight
+    weight = getattr(module, parameter_name)
     if weight.shape[-2] % 16 != 0 or weight.shape[-1] % 16 != 0:
         raise RuntimeError(
             f"NVFP4 only supports weight shape with last 2 dims divisible by 16, got {weight.shape}"
@@ -306,6 +309,8 @@ def _nvfp4_inference_linear_transform(
                 "NVFP4 DYNAMIC mode is only supported on sm100+ machines"
             )
 
+        weight = getattr(module, parameter_name)
+
         per_tensor_scale = None
         if config.use_dynamic_per_tensor_scale:
             tensor_amax = torch.max(torch.abs(weight))
@@ -325,8 +330,19 @@ def _nvfp4_inference_linear_transform(
             act_quant_kwargs=act_quant_kwargs,
         )
         quantized_weight.use_triton_kernel = config.use_triton_kernel
-        module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
-        module.extra_repr = types.MethodType(_linear_extra_repr, module)
+        setattr(
+            module,
+            parameter_name,
+            torch.nn.Parameter(quantized_weight, requires_grad=False),
+        )
+        module.extra_repr = types.MethodType(
+            partial(
+                _module_extra_repr,
+                original_extra_repr=module.extra_repr,
+                parameter_name=parameter_name,
+            ),
+            module,
+        )
         return module
 
     else:
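To make the new argument's effect in this file concrete, here is an illustrative direct call to the handler on a module whose quantizable parameter is not named weight. In normal use, torchao's quantize_ dispatch invokes the handler for you; the GateModule class and its gate parameter are invented for this example, and running it needs an sm100+ GPU and a recent PyTorch.

import torch

from torchao.prototype.mx_formats.inference_workflow import (
    NVFP4DynamicActivationNVFP4WeightConfig,
    _nvfp4_inference_linear_transform,
)


class GateModule(torch.nn.Module):
    def __init__(self, dim=64):
        super().__init__()
        # Quantizable parameter deliberately not named "weight"; the last two dims
        # must be divisible by 16 to pass the handler's shape check.
        self.gate = torch.nn.Parameter(torch.randn(dim, dim, dtype=torch.bfloat16))

    def forward(self, x):
        return torch.nn.functional.linear(x, self.gate)


m = GateModule().cuda()
# Before this commit the handler always read and wrote module.weight; the keyword-only
# argument lets it target any parameter by name (dynamic branch, step=None).
_nvfp4_inference_linear_transform(
    m, NVFP4DynamicActivationNVFP4WeightConfig(), parameter_name="gate"
)
# m.gate now holds the NVFP4-quantized parameter and m's extra_repr has been patched.
print(m)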
