|
7 | 7 | import argparse |
8 | 8 | from typing import List |
9 | 9 |
|
| 10 | +import huggingface_hub |
10 | 11 | import torch |
11 | 12 | import transformers |
12 | 13 | from huggingface_hub import ModelCard, get_token, whoami |
|
16 | 17 | if _transformers_version >= "5": |
17 | 18 | from transformers.quantizers.auto import get_hf_quantizer |
18 | 19 |
|
# huggingface_hub version captured once at import time, as a plain string,
# for the push_to_hub/save_pretrained compatibility branching later in this
# file (safe_serialization handling differs across hub versions).
# NOTE(review): downstream checks compare this lexicographically
# (e.g. `_huggingface_hub_version < "1.4.1"`), which misorders multi-digit
# components ("1.10.0" < "1.4.1" is True as strings) — consider parsing with
# packaging.version before comparing; confirm against supported hub versions.
_huggingface_hub_version = str(huggingface_hub.__version__)

19 | 22 | from torchao._models._eval import TransformerEvalWrapper |
20 | 23 | from torchao.prototype.awq import ( |
21 | 24 | AWQConfig, |
|
27 | 30 | from torchao.prototype.smoothquant import SmoothQuantConfig |
28 | 31 | from torchao.quantization import ( |
29 | 32 | Float8DynamicActivationFloat8WeightConfig, |
| 33 | + Float8DynamicActivationInt4WeightConfig, |
30 | 34 | Int4WeightOnlyConfig, |
31 | 35 | Int8DynamicActivationInt8WeightConfig, |
32 | 36 | Int8DynamicActivationIntxWeightConfig, |
@@ -238,6 +242,14 @@ def _untie_weights_and_save_locally(model_id, device): |
238 | 242 | tokenizer = AutoTokenizer.from_pretrained(model_id) |
239 | 243 | """ |
240 | 244 |
|
# Reproduction snippet for the FP8-INT4 scheme (float8 dynamic activation +
# int4 weight, "plain" int4 packing), embedded verbatim in the generated
# model card alongside the other `_*_quant_code` templates.
# The `{device}` placeholder is presumably substituted via str.format when
# the card is rendered — TODO confirm against the template-filling call site.
# NOTE(review): the snippet loads the model from `model_to_quantize` but the
# tokenizer from `model_id` — confirm this mirrors the other templates and is
# intentional rather than a copy-paste slip.
_fp8_int4_quant_code = """
from torchao.quantization import Float8DynamicActivationInt4WeightConfig
quant_config = Float8DynamicActivationInt4WeightConfig(int4_packing_format="plain")
quantization_config = TorchAoConfig(quant_type=quant_config)
quantized_model = AutoModelForCausalLM.from_pretrained(model_to_quantize, device_map="{device}", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained(model_id)
"""
| 252 | + |
241 | 253 | _int8_int4_quant_code = """ |
242 | 254 | from torchao.quantization.quant_api import ( |
243 | 255 | IntxWeightOnlyConfig, |
@@ -687,6 +699,9 @@ def quantize_and_upload( |
687 | 699 |
|
688 | 700 | quant_to_config = { |
689 | 701 | "FP8": Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()), |
| 702 | + "FP8-INT4": Float8DynamicActivationInt4WeightConfig( |
| 703 | + int4_packing_format="plain" |
| 704 | + ), |
690 | 705 | "INT4": Int4WeightOnlyConfig( |
691 | 706 | group_size=128, |
692 | 707 | int4_packing_format="tile_packed_to_4d", |
@@ -725,6 +740,7 @@ def quantize_and_upload( |
725 | 740 |
|
726 | 741 | quant_to_quant_code = { |
727 | 742 | "FP8": _fp8_quant_code, |
| 743 | + "FP8-INT4": _fp8_int4_quant_code, |
728 | 744 | "INT4": _int4_quant_code, |
729 | 745 | "INT8-INT4": _int8_int4_quant_code, |
730 | 746 | "INT8-INT4-HQQ": _int8_int4_hqq_quant_code, |
@@ -908,16 +924,24 @@ def filter_fn_skip_lmhead(module, fqn): |
908 | 924 |
|
909 | 925 | # Push to hub |
910 | 926 | if push_to_hub: |
911 | | - quantized_model.push_to_hub( |
912 | | - quantized_model_id, safe_serialization=safe_serialization |
913 | | - ) |
| 927 | + if _huggingface_hub_version < "1.4.1": |
| 928 | + quantized_model.push_to_hub( |
| 929 | + quantized_model_id, safe_serialization=safe_serialization |
| 930 | + ) |
| 931 | + else: |
| 932 | + quantized_model.push_to_hub(quantized_model_id) |
| 933 | + |
914 | 934 | tokenizer.push_to_hub(quantized_model_id) |
915 | 935 | if populate_model_card_template: |
916 | 936 | card.push_to_hub(quantized_model_id) |
917 | 937 | else: |
918 | | - quantized_model.save_pretrained( |
919 | | - quantized_model_id, safe_serialization=safe_serialization |
920 | | - ) |
| 938 | + if _huggingface_hub_version < "1.4.1": |
| 939 | + quantized_model.save_pretrained( |
| 940 | + quantized_model_id, safe_serialization=safe_serialization |
| 941 | + ) |
| 942 | + else: |
| 943 | + quantized_model.save_pretrained(quantized_model_id) |
| 944 | + |
921 | 945 | tokenizer.save_pretrained(quantized_model_id) |
922 | 946 |
|
923 | 947 | # Manual Testing |
@@ -960,7 +984,7 @@ def filter_fn_skip_lmhead(module, fqn): |
960 | 984 | parser.add_argument( |
961 | 985 | "--quant", |
962 | 986 | type=str, |
963 | | - help="Quantization method. Options are FP8, INT4, INT8-INT4, INT8-INT4-HQQ, AWQ-INT4, SmoothQuant-INT8-INT8, MXFP8, NVFP4", |
| 987 | + help="Quantization method. Options are FP8, FP8-INT4, INT4, INT8-INT4, INT8-INT4-HQQ, AWQ-INT4, SmoothQuant-INT8-INT8, MXFP8, NVFP4", |
964 | 988 | ) |
965 | 989 | parser.add_argument( |
966 | 990 | "--tasks", |
|
0 commit comments