
Commit ace3d95

extend gptq example script with olmoe model
Summary:

Add the `allenai/OLMoE-1B-7B-0924` model to the GPTQ example script. For now,
GPTQ calibration succeeds and GPTQ convert fails with
https://gist.github.com/vkuzo/51b2bfcee77fc193253faf007d99d694; this will be
fixed in the next PR.

Test Plan:

```
torchao/prototype/gptq/gptq_nvfp4_olmoe_1b_7b_nonsequential_wikitext.sh
```

ghstack-source-id: 570d8b2
ghstack-comment-id: 4313776317
Pull-Request: #4329
1 parent: 7a34455

2 files changed
Lines changed: 118 additions & 7 deletions

File tree:

- torchao/prototype/gptq/gptq_example.py
- torchao/prototype/gptq/gptq_nvfp4_olmoe_1b_7b_nonsequential_wikitext.sh

torchao/prototype/gptq/gptq_example.py
Lines changed: 93 additions & 7 deletions
```diff
@@ -9,6 +9,7 @@
 import gc
 import subprocess
 import time
+from contextlib import nullcontext
 from typing import Any, List, Optional
 
 import torch
@@ -22,7 +23,13 @@
 from torchao.prototype.mx_formats.inference_workflow import (
     NVFP4DynamicActivationNVFP4WeightConfig,
 )
-from torchao.quantization import Int4WeightOnlyConfig, Int8WeightOnlyConfig, quantize_
+from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
+from torchao.quantization import (
+    FqnToConfig,
+    Int4WeightOnlyConfig,
+    Int8WeightOnlyConfig,
+    quantize_,
+)
 from torchao.quantization.granularity import PerRow
 
 """
@@ -264,9 +271,37 @@ def parse_args():
     return parser.parse_args()
 
 
+OLMOE_MODEL_ID = "allenai/OLMoE-1B-7B-0924"
+
+
+def _verify_olmoe_experts_quantized(model):
+    """Assert every OlmoeExperts module has NVFP4Tensor for both expert weights."""
+    from transformers.models.olmoe.modeling_olmoe import OlmoeExperts
+
+    found = 0
+    for name, mod in model.named_modules():
+        if not isinstance(mod, OlmoeExperts):
+            continue
+        for pname in ("gate_up_proj", "down_proj"):
+            param = getattr(mod, pname)
+            assert isinstance(param, NVFP4Tensor), (
+                f"{name}.{pname} is {type(param).__name__}, expected NVFP4Tensor"
+            )
+        found += 1
+    assert found > 0, "no OlmoeExperts modules found to verify"
+    print(f"Verified NVFP4 quantization on {found} OlmoeExperts modules")
+
+
 def main():
     args = parse_args()
 
+    is_olmoe = args.model_id == OLMOE_MODEL_ID
+    if is_olmoe and args.quantization not in ("nvfp4-rtn", "nvfp4-gptq-nonsequential"):
+        raise ValueError(
+            f"model {args.model_id} only supports 'nvfp4-rtn' or "
+            f"'nvfp4-gptq-nonsequential', got '{args.quantization}'"
+        )
+
     # lm_eval batch_size="auto" with nvfp4 gptq causes the error in
     # MSLK nvfp4 triton kernel, likely an unsupported shape:
     # https://gist.github.com/vkuzo/b71ca46365dee017d1602e9638d91603
@@ -284,10 +319,11 @@ def main():
     dtype = dtype_map.get("bfloat16", torch.bfloat16)
 
     print(f"Loading model {args.model_id}...")
+    from_pretrained_kwargs = dict(device_map="cuda:0", dtype=dtype)
+    if is_olmoe:
+        from_pretrained_kwargs["experts_implementation"] = "grouped_mm"
     model = AutoModelForCausalLM.from_pretrained(
-        args.model_id,
-        device_map="cuda:0",
-        dtype=dtype,
+        args.model_id, **from_pretrained_kwargs
     )
     tokenizer = AutoTokenizer.from_pretrained(args.model_id)
 
@@ -353,7 +389,20 @@ def skip_lm_head_o_proj(module, fqn):
             use_dynamic_per_tensor_scale=True,
             use_triton_kernel=True,
         )
-        quantize_(model, config, filter_fn=filter_fn_to_use)
+        if is_olmoe:
+            quantize_(
+                model,
+                FqnToConfig(
+                    {
+                        r"re:.*\.experts\.gate_up_proj": config,
+                        r"re:.*\.experts\.down_proj": config,
+                    }
+                ),
+                filter_fn=None,
+            )
+            _verify_olmoe_experts_quantized(model)
+        else:
+            quantize_(model, config, filter_fn=filter_fn_to_use)
 
     elif args.quantization in [
         "int4-gptq-sequential",
```
```diff
@@ -387,7 +436,19 @@ def skip_lm_head_o_proj(module, fqn):
             percdamp=args.percdamp,
             gptq_quantize_block_size=args.gptq_block_size,
         )
-        quantize_(model, observe_config, filter_fn=filter_fn_to_use)
+        if is_olmoe:
+            quantize_(
+                model,
+                FqnToConfig(
+                    {
+                        r"re:.*\.experts\.gate_up_proj": observe_config,
+                        r"re:.*\.experts\.down_proj": observe_config,
+                    }
+                ),
+                filter_fn=None,
+            )
+        else:
+            quantize_(model, observe_config, filter_fn=filter_fn_to_use)
 
         # Prepare calibration dataset
         print(
@@ -425,12 +486,31 @@ def skip_lm_head_o_proj(module, fqn):
                 num_gptq_weights += 1
         print(f"Total GPTQ weights to convert: {num_gptq_weights}")
         # Apply quantization
-        quantize_(model, convert_config, filter_fn=filter_fn_to_use)
+        if is_olmoe:
+            quantize_(
+                model,
+                FqnToConfig(
+                    {
+                        r"re:.*\.experts\.gate_up_proj": convert_config,
+                        r"re:.*\.experts\.down_proj": convert_config,
+                    }
+                ),
+                filter_fn=None,
+            )
+            _verify_olmoe_experts_quantized(model)
+        else:
+            quantize_(model, convert_config, filter_fn=filter_fn_to_use)
     else:  # sequential
         print(f"Applying {quant_type} GPTQ quantization (sequential)...")
         assert filter_fn_to_use == skip_lm_head, "unsupported"
         sequential_quantize(model, dataset, convert_config)
 
+    if is_olmoe:
+        # generate() switches to batched_mm for decoding, which doesn't support
+        # NVFP4Tensor (needs aten.index.Tensor). Override to keep grouped_mm.
+        # TODO(future): remove when NVFP4 MoE supports bmm-style decode
+        model._optimize_model_for_decode = nullcontext
+
     quantization_end_time = time.time()
     quantization_time = quantization_end_time - quantization_start_time
 
```
```diff
@@ -456,6 +536,12 @@ def skip_lm_head_o_proj(module, fqn):
             "Please install a version that includes "
             "https://github.com/huggingface/transformers/pull/45573"
         )
+    if is_olmoe and "gate_up_proj" not in source:
+        raise RuntimeError(
+            "Your version of `transformers` does not support NVFP4 MoE serialization. "
+            "Please install a version that includes "
+            "https://github.com/huggingface/transformers/pull/45609"
+        )
 
     if args.quantization != "none":
         # Attach hf_quantizer so save_pretrained uses the flatten path for tensor
```
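Note: the `"gate_up_proj" not in source` guard follows the same
feature-detection pattern as the PR 45573 check above it: scan the source text
of a transformers object for a string that only exists once the needed PR has
landed. A hedged sketch of that pattern; `source` is assigned outside this
hunk, so the module inspected below is a hypothetical stand-in:

```python
import inspect

from transformers import modeling_utils

# hypothetical target: the real script fetches `source` elsewhere in the file
source = inspect.getsource(modeling_utils)
if "gate_up_proj" not in source:
    raise RuntimeError(
        "transformers is too old for NVFP4 MoE serialization; install a "
        "version that includes huggingface/transformers#45609"
    )
```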
torchao/prototype/gptq/gptq_nvfp4_olmoe_1b_7b_nonsequential_wikitext.sh
Lines changed: 25 additions & 0 deletions

```diff
@@ -0,0 +1,25 @@
+#!/bin/bash
+
+#
+# A quick smoke test for non-sequential GPTQ on `allenai/OLMoE-1B-7B-0924`
+#
+
+COMMON_ARGS="--output-dir-prefix /home/dev/tmp/20260421 --model-id allenai/OLMoE-1B-7B-0924 --lm-eval-tasks wikitext --num-fewshot 0 --lm-eval-batch-size 16"
+
+# baseline (bf16)
+echo -e "\n\nbaseline (bf16)\n\n"
+# python -u torchao/prototype/gptq/gptq_example.py $COMMON_ARGS --quantization none
+echo -e "done"
+
+# nvfp4-rtn
+echo -e "\n\nnvfp4-rtn\n\n"
+# python -u torchao/prototype/gptq/gptq_example.py $COMMON_ARGS --quantization nvfp4-rtn
+echo -e "done"
+
+# nvfp4-gptq-nonsequential
+echo -e "\n\nnvfp4-gptq-nonsequential\n\n"
+# TODO(future PR): fix https://gist.github.com/vkuzo/51b2bfcee77fc193253faf007d99d694
+# and enable this
+# python -u torchao/prototype/gptq/gptq_example.py $COMMON_ARGS --quantization nvfp4-gptq-nonsequential --dataset-id c4 --dataset-split train
+echo -e "done"
+
```
