Commit 12cd338

make gptq convert work for moe
Summary: Makes gptq + moe + nvfp4 work e2e, results as expected on tiny model + tiny dataset

Test Plan:

```
> TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1 torchao/prototype/gptq/gptq_nvfp4_olmoe_1b_7b_nonsequential_wikitext.sh

bf16

| Tasks  |Version|Filter|n-shot|    Metric     |   |Value |   |Stderr|
|--------|------:|------|-----:|---------------|---|-----:|---|------|
|wikitext|      2|none  |     0|bits_per_byte  |↓  |0.5895|±  |   N/A|
|        |       |none  |     0|byte_perplexity|↓  |1.5047|±  |   N/A|
|        |       |none  |     0|word_perplexity|↓  |8.8910|±  |   N/A|

real    0m59.219s
user    0m42.554s
sys     0m20.534s

nvfp4-rtn

| Tasks  |Version|Filter|n-shot|    Metric     |   |Value |   |Stderr|
|--------|------:|------|-----:|---------------|---|-----:|---|------|
|wikitext|      2|none  |     0|bits_per_byte  |↓  |0.6024|±  |   N/A|
|        |       |none  |     0|byte_perplexity|↓  |1.5183|±  |   N/A|
|        |       |none  |     0|word_perplexity|↓  |9.3277|±  |   N/A|

real    0m42.528s
user    0m41.217s
sys     0m12.817s

nvfp4-nonsequential with 4096 calibration samples on c4

| Tasks  |Version|Filter|n-shot|    Metric     |   |Value |   |Stderr|
|--------|------:|------|-----:|---------------|---|-----:|---|------|
|wikitext|      2|none  |     0|bits_per_byte  |↓  |0.6019|±  |   N/A|
|        |       |none  |     0|byte_perplexity|↓  |1.5177|±  |   N/A|
|        |       |none  |     0|word_perplexity|↓  |9.3087|±  |   N/A|

real    22m28.505s
user    22m36.008s
sys     0m13.872s
```

ghstack-source-id: 43538e0
ghstack-comment-id: 4315147581
Pull-Request: #4330
1 parent ace3d95 commit 12cd338

5 files changed

Lines changed: 124 additions & 10 deletions
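
For orientation before the diffs, here is a minimal sketch of the observe-then-convert GPTQ flow that this commit extends to 3D MoE expert weights. `GPTQConfig(step="convert", ...)` and the NVFP4 base config appear in the diffs below; the `"observe"` step name, the `NVFP4DynamicActivationNVFP4WeightConfig` import path, and the toy model/data are assumptions for illustration, not the confirmed API.

```python
import torch
from torch import nn

from torchao.prototype.gptq import GPTQConfig
from torchao.prototype.mx_formats.inference_workflow import (
    NVFP4DynamicActivationNVFP4WeightConfig,  # import path assumed from the test file
)
from torchao.quantization import quantize_

model = nn.Sequential(
    nn.Linear(128, 128, dtype=torch.bfloat16, device="cuda"),
)
base_config = NVFP4DynamicActivationNVFP4WeightConfig(
    use_dynamic_per_tensor_scale=True,
    use_triton_kernel=True,
)

# Step 1: swap weights for observers that accumulate Hessians during forward
# passes ("observe" step name is an assumption; the diffs only show "convert").
quantize_(model, GPTQConfig(step="observe", base_config=base_config))
for _ in range(8):
    model(torch.randn(4, 128, dtype=torch.bfloat16, device="cuda"))

# Step 2: convert each observed weight with GPTQ. With this commit, 3D
# (E, N, K) stacked MoE expert weights dispatch to the new gptq_quantize_3d.
quantize_(model, GPTQConfig(step="convert", base_config=base_config))
```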

test/prototype/gptq/test_gptqv2.py

Lines changed: 61 additions & 0 deletions
```diff
@@ -20,6 +20,7 @@
 from torchao.prototype.gptq import (
     GPTQConfig,
     gptq_quantize,
+    gptq_quantize_3d,
 )
 from torchao.prototype.gptq.observer import GPTQObserverTensor
 from torchao.prototype.mx_formats.inference_workflow import (
@@ -595,6 +596,66 @@ def test_gptq_quantize_better_than_naive(self, base_config):
         assert gptq_loss is not None
         assert naive_loss is not None
 
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
+    @pytest.mark.skipif(
+        not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for nvfp4"
+    )
+    def test_gptq_quantize_2d_matches_3d(self):
+        """Verify per-expert gptq_quantize and gptq_quantize_3d produce bitwise-identical outputs."""
+        torch.manual_seed(43)
+
+        E = 4
+        out_features = 64
+        in_features = 128
+        num_samples = 10
+
+        base_config = NVFP4DynamicActivationNVFP4WeightConfig(
+            use_dynamic_per_tensor_scale=True,
+            use_triton_kernel=True,
+        )
+        config = GPTQConfig(step="convert", base_config=base_config)
+
+        # Per-expert weights (E, N, K) and per-expert Hessians (E, K, K)
+        weight_3d = torch.randn(
+            E, out_features, in_features, dtype=torch.bfloat16, device="cuda"
+        )
+        hessians = []
+        for _ in range(E):
+            activations = [
+                torch.randn(4, in_features, dtype=torch.float32, device="cuda")
+                for _ in range(num_samples)
+            ]
+            hessians.append(_calculate_hessian(activations, device="cuda"))
+        hessian_3d = torch.stack(hessians, dim=0)
+
+        # gptq_quantize mutates its weight/Hessian arguments in place, so clone
+        # per-experiment to keep the two paths independent.
+        weight_a = weight_3d.clone()
+        weight_b = weight_3d.clone()
+        hessian_a = hessian_3d.clone()
+        hessian_b = hessian_3d.clone()
+
+        # Experiment A: E separate 2D gptq_quantize calls
+        per_expert_2d = [
+            gptq_quantize(hessian_a[e], weight_a[e], config) for e in range(E)
+        ]
+
+        # Experiment B: single 3D gptq_quantize_3d call
+        stacked_3d = gptq_quantize_3d(hessian_b, weight_b, config)
+
+        # Bitwise match per expert
+        for e in range(E):
+            assert torch.equal(per_expert_2d[e].qdata, stacked_3d.qdata[e]), (
+                f"Expert {e}: qdata mismatch"
+            )
+            assert torch.equal(per_expert_2d[e].scale, stacked_3d.scale[e]), (
+                f"Expert {e}: scale mismatch"
+            )
+            assert torch.equal(
+                per_expert_2d[e].per_tensor_scale.view(1, 1),
+                stacked_3d.per_tensor_scale[e],
+            ), f"Expert {e}: per_tensor_scale mismatch"
+
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
     @pytest.mark.parametrize(
         "base_config",
```

torchao/prototype/gptq/__init__.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -4,6 +4,6 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
-from .api import GPTQConfig, gptq_quantize
+from .api import GPTQConfig, gptq_quantize, gptq_quantize_3d
 
-__all__ = ["GPTQConfig", "gptq_quantize"]
+__all__ = ["GPTQConfig", "gptq_quantize", "gptq_quantize_3d"]
```

torchao/prototype/gptq/api.py

Lines changed: 50 additions & 1 deletion
```diff
@@ -149,7 +149,12 @@ def _gptq_config_transform(
 
     # Use pre-computed Hessian directly
     hessian = tensor.hessian
-    new_tensor = gptq_quantize(hessian, tensor.hp_data, config)
+    if len(tensor.shape) == 2:
+        new_tensor = gptq_quantize(hessian, tensor.hp_data, config)
+    else:
+        assert len(tensor.shape) == 3, "unsupported"
+        new_tensor = gptq_quantize_3d(hessian, tensor.hp_data, config)
+
     new_quantized_tensor = nn.Parameter(new_tensor, requires_grad=False)
     setattr(module, parameter_name, new_quantized_tensor)
     return module
@@ -592,7 +597,51 @@ def gptq_quantize(H: torch.Tensor, W_t: torch.Tensor, config: GPTQConfig):
     return result
 
 
+def gptq_quantize_3d(H: torch.Tensor, W_t: torch.Tensor, config: GPTQConfig):
+    """3D variant of gptq_quantize for MoE expert weights.
+
+    Args:
+        H: per-expert Hessian of shape (E, K, K)
+        W_t: stacked expert weights of shape (E, N, K)
+        config: GPTQ configuration (NVFP4 only)
+
+    Returns:
+        NVFP4Tensor of shape (E, N, K) assembled from per-expert 2D results.
+    """
+    assert H.dim() == 3 and W_t.dim() == 3
+    assert H.shape[0] == W_t.shape[0]
+    base_config = config.base_config
+    assert isinstance(base_config, NVFP4DynamicActivationNVFP4WeightConfig), (
+        "gptq_quantize_3d only supports NVFP4"
+    )
+
+    E = W_t.shape[0]
+    pieces = [gptq_quantize(H[e], W_t[e], config) for e in range(E)]
+
+    # Stack inner NVFP4Tensor fields along a new expert dim 0. These are plain
+    # tensors (uint8 / float8_e4m3fn / float32), so torch.stack goes through
+    # normal aten dispatch, not NVFP4Tensor.
+    qdata_3d = torch.stack([p.qdata for p in pieces], dim=0)
+    scale_3d = torch.stack([p.scale for p in pieces], dim=0)
+    per_tensor_scale_3d = torch.stack(
+        [p.per_tensor_scale.view(1, 1) for p in pieces], dim=0
+    )
+
+    return NVFP4Tensor(
+        qdata_3d,
+        scale_3d,
+        block_size=pieces[0].block_size,
+        orig_dtype=pieces[0].orig_dtype,
+        per_tensor_scale=per_tensor_scale_3d,
+        act_per_tensor_scale=None,
+        is_swizzled_scales=True,
+        use_triton_kernel=pieces[0].use_triton_kernel,
+        act_quant_kwargs=pieces[0].act_quant_kwargs,
+    )
+
+
 __all__ = [
     "GPTQConfig",
     "gptq_quantize",
+    "gptq_quantize_3d",
 ]
```
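
A hedged usage sketch of the new entry point. Shapes follow the docstring above and the test; the `NVFP4DynamicActivationNVFP4WeightConfig` import path is assumed from the test file, and per the test's skip conditions this requires a CUDA device with capability >= 10.0.

```python
import torch

from torchao.prototype.gptq import GPTQConfig, gptq_quantize_3d
from torchao.prototype.mx_formats.inference_workflow import (
    NVFP4DynamicActivationNVFP4WeightConfig,  # import path assumed
)

E, N, K = 4, 64, 128  # experts, out_features, in_features

# Synthetic PSD per-expert Hessians standing in for calibrated ones.
x = torch.randn(E, 256, K, dtype=torch.float32, device="cuda")
hessian = 2.0 * x.transpose(-1, -2) @ x  # (E, K, K)

weight = torch.randn(E, N, K, dtype=torch.bfloat16, device="cuda")

config = GPTQConfig(
    step="convert",
    base_config=NVFP4DynamicActivationNVFP4WeightConfig(
        use_dynamic_per_tensor_scale=True,
        use_triton_kernel=True,
    ),
)

# gptq_quantize mutates its weight/Hessian arguments in place (see the test
# above), so pass clones if the bf16 originals are still needed.
quantized = gptq_quantize_3d(hessian.clone(), weight.clone(), config)
assert quantized.qdata.shape[0] == E and quantized.scale.shape[0] == E
```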

torchao/prototype/gptq/gptq_example.py

Lines changed: 8 additions & 2 deletions
```diff
@@ -296,9 +296,13 @@ def main():
     args = parse_args()
 
     is_olmoe = args.model_id == OLMOE_MODEL_ID
-    if is_olmoe and args.quantization not in ("nvfp4-rtn", "nvfp4-gptq-nonsequential"):
+    if is_olmoe and args.quantization not in (
+        "none",
+        "nvfp4-rtn",
+        "nvfp4-gptq-nonsequential",
+    ):
         raise ValueError(
-            f"model {args.model_id} only supports 'nvfp4-rtn' or "
+            f"model {args.model_id} only supports 'none', 'nvfp4-rtn', or "
             f"'nvfp4-gptq-nonsequential', got '{args.quantization}'"
         )
 
@@ -403,6 +407,7 @@ def skip_lm_head_o_proj(module, fqn):
             _verify_olmoe_experts_quantized(model)
         else:
             quantize_(model, config, filter_fn=filter_fn_to_use)
+            print(model)
 
     elif args.quantization in [
         "int4-gptq-sequential",
@@ -449,6 +454,7 @@ def skip_lm_head_o_proj(module, fqn):
         )
         else:
             quantize_(model, observe_config, filter_fn=filter_fn_to_use)
+            print(model)
 
     # Prepare calibration dataset
     print(
```

torchao/prototype/gptq/gptq_nvfp4_olmoe_1b_7b_nonsequential_wikitext.sh

Lines changed: 3 additions & 5 deletions
```diff
@@ -8,18 +8,16 @@ COMMON_ARGS="--output-dir-prefix /home/dev/tmp/20260421 --model-id allenai/OLMoE
 
 # baseline (bf16)
 echo -e "\n\nbaseline (bf16)\n\n"
-# python -u torchao/prototype/gptq/gptq_example.py $COMMON_ARGS --quantization none
+time python -u torchao/prototype/gptq/gptq_example.py $COMMON_ARGS --quantization none
 echo -e "done"
 
 # nvfp4-rtn
 echo -e "\n\nnvfp4-rtn\n\n"
-# python -u torchao/prototype/gptq/gptq_example.py $COMMON_ARGS --quantization nvfp4-rtn
+time python -u torchao/prototype/gptq/gptq_example.py $COMMON_ARGS --quantization nvfp4-rtn
 echo -e "done"
 
 # nvfp4-gptq-nonsequential
 echo -e "\n\nnvfp4-gptq-nonsequential\n\n"
-# TODO(future PR): fix https://gist.github.com/vkuzo/51b2bfcee77fc193253faf007d99d694
-# and enable this
-# python -u torchao/prototype/gptq/gptq_example.py $COMMON_ARGS --quantization nvfp4-gptq-nonsequential --dataset-id c4 --dataset-split train
+time python -u torchao/prototype/gptq/gptq_example.py $COMMON_ARGS --quantization nvfp4-gptq-nonsequential --dataset-id c4 --dataset-split train --num-calibration-samples 4096
 echo -e "done"
```
