Skip to content

Commit fa5aed0

Browse files
authored
gptq example: remove transformers version check (#4313)
* Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned]
1 parent b49d8cb commit fa5aed0

1 file changed

Lines changed: 20 additions & 8 deletions

File tree

torchao/prototype/gptq/gptq_example.py

Lines changed: 20 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -12,12 +12,11 @@
1212
from typing import Any, List, Optional
1313

1414
import torch
15-
import transformers
1615
from datasets import load_dataset
1716
from tqdm import tqdm
18-
from transformers import AutoModelForCausalLM, AutoTokenizer
17+
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
18+
from transformers.quantizers.quantizer_torchao import TorchAoHfQuantizer
1919

20-
from packaging.version import Version
2120
from torchao.prototype.gptq import GPTQConfig
2221
from torchao.prototype.gptq.observer import GPTQObserverTensor
2322
from torchao.prototype.mx_formats.inference_workflow import (
@@ -447,11 +446,24 @@ def skip_lm_head_o_proj(module, fqn):
447446
tokenizer.save_pretrained(output_dir)
448447
print(model)
449448

450-
# transformers 5.0.0 have a lot of errors with nvfp4 subclasses
451-
# TODO(before land): debug this further
452-
assert Version(transformers.__version__) < Version("5.0.0"), (
453-
f"transformers {transformers.__version__} is not supported, need < 5.0.0"
454-
)
449+
if "nvfp4" in args.quantization:
450+
import inspect
451+
452+
source = inspect.getsource(TorchAoHfQuantizer.get_weight_conversions)
453+
if "_weight_per_tensor_scale" not in source:
454+
raise RuntimeError(
455+
"Your version of `transformers` does not support NVFP4 serialization. "
456+
"Please install a version that includes "
457+
"https://github.com/huggingface/transformers/pull/45573"
458+
)
459+
460+
if args.quantization != "none":
461+
# Attach hf_quantizer so save_pretrained uses the flatten path for tensor
462+
# subclasses (e.g. NVFP4Tensor) that don't have a valid storage pointer.
463+
ao_config = base_config if "gptq" in args.quantization else config
464+
torchao_config = TorchAoConfig(quant_type=ao_config)
465+
model.config.quantization_config = torchao_config
466+
model.hf_quantizer = TorchAoHfQuantizer(torchao_config)
455467

456468
model.save_pretrained(output_dir, safe_serialization=False)
457469

0 commit comments

Comments (0)