|
12 | 12 | from typing import Any, List, Optional |
13 | 13 |
|
14 | 14 | import torch |
15 | | -import transformers |
16 | 15 | from datasets import load_dataset |
17 | 16 | from tqdm import tqdm |
18 | | -from transformers import AutoModelForCausalLM, AutoTokenizer |
| 17 | +from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig |
| 18 | +from transformers.quantizers.quantizer_torchao import TorchAoHfQuantizer |
19 | 19 |
|
20 | | -from packaging.version import Version |
21 | 20 | from torchao.prototype.gptq import GPTQConfig |
22 | 21 | from torchao.prototype.gptq.observer import GPTQObserverTensor |
23 | 22 | from torchao.prototype.mx_formats.inference_workflow import ( |
@@ -447,11 +446,24 @@ def skip_lm_head_o_proj(module, fqn): |
447 | 446 | tokenizer.save_pretrained(output_dir) |
448 | 447 | print(model) |
449 | 448 |
|
450 | | - # transformers 5.0.0 have a lot of errors with nvfp4 subclasses |
451 | | - # TODO(before land): debug this further |
452 | | - assert Version(transformers.__version__) < Version("5.0.0"), ( |
453 | | - f"transformers {transformers.__version__} is not supported, need < 5.0.0" |
454 | | - ) |
| 449 | + if "nvfp4" in args.quantization: |
| 450 | + import inspect |
| 451 | + |
| 452 | + source = inspect.getsource(TorchAoHfQuantizer.get_weight_conversions) |
| 453 | + if "_weight_per_tensor_scale" not in source: |
| 454 | + raise RuntimeError( |
| 455 | + "Your version of `transformers` does not support NVFP4 serialization. " |
| 456 | + "Please install a version that includes " |
| 457 | + "https://github.com/huggingface/transformers/pull/45573" |
| 458 | + ) |
| 459 | + |
| 460 | + if args.quantization != "none": |
| 461 | + # Attach hf_quantizer so save_pretrained uses the flatten path for tensor |
| 462 | + # subclasses (e.g. NVFP4Tensor) that don't have a valid storage pointer. |
| 463 | + ao_config = base_config if "gptq" in args.quantization else config |
| 464 | + torchao_config = TorchAoConfig(quant_type=ao_config) |
| 465 | + model.config.quantization_config = torchao_config |
| 466 | + model.hf_quantizer = TorchAoHfQuantizer(torchao_config) |
455 | 467 |
|
456 | 468 | model.save_pretrained(output_dir, safe_serialization=False) |
457 | 469 |
|
|
0 commit comments