
Commit 83d1561

Add FP8-INT4 checkpoint upload code
Summary:
As titled; the support was added in #3714.
Checkpoint: https://huggingface.co/jerryzh168/Qwen3-8B-FP8-INT4

Test Plan:
```
sh release.sh --model_id $MODEL --push_to_hub --populate_model_card_template --quants FP8-INT4
```
Produced checkpoint: https://huggingface.co/jerryzh168/Qwen3-8B-FP8-INT4

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
1 parent 3075bb6 commit 83d1561
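
As a quick sanity check of the result (not part of the diff below), a checkpoint produced this way should load directly through transformers, since the torchao quantization config is serialized into the checkpoint. A minimal sketch, assuming `torchao` and `transformers` are installed and a CUDA device is available, using the checkpoint id from the summary:

```python
# Sketch: load and smoke-test the uploaded FP8-INT4 checkpoint.
# Assumes `torchao` and `transformers` are installed and a GPU is present.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "jerryzh168/Qwen3-8B-FP8-INT4"
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="cuda"
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```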

1 file changed: 31 additions & 7 deletions

.github/scripts/torchao_model_releases/quantize_and_upload.py

@@ -7,6 +7,7 @@
 import argparse
 from typing import List

+import huggingface_hub
 import torch
 import transformers
 from huggingface_hub import ModelCard, get_token, whoami
@@ -16,6 +17,8 @@
 if _transformers_version >= "5":
     from transformers.quantizers.auto import get_hf_quantizer

+_huggingface_hub_version = str(huggingface_hub.__version__)
+
 from torchao._models._eval import TransformerEvalWrapper
 from torchao.prototype.awq import (
     AWQConfig,
@@ -27,6 +30,7 @@
 from torchao.prototype.smoothquant import SmoothQuantConfig
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
+    Float8DynamicActivationInt4WeightConfig,
     Int4WeightOnlyConfig,
     Int8DynamicActivationInt8WeightConfig,
     Int8DynamicActivationIntxWeightConfig,
@@ -238,6 +242,14 @@ def _untie_weights_and_save_locally(model_id, device):
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 """

+_fp8_int4_quant_code = """
+from torchao.quantization import Float8DynamicActivationInt4WeightConfig
+quant_config = Float8DynamicActivationInt4WeightConfig(int4_packing_format="plain")
+quantization_config = TorchAoConfig(quant_type=quant_config)
+quantized_model = AutoModelForCausalLM.from_pretrained(model_to_quantize, device_map="{device}", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+"""
+
 _int8_int4_quant_code = """
 from torchao.quantization.quant_api import (
     IntxWeightOnlyConfig,
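
For context, `_fp8_int4_quant_code` is a model-card template string: `{device}` is substituted when the card is populated, and `model_to_quantize`/`model_id` are defined by the surrounding card text. A standalone sketch of the same flow with the placeholders filled in (the model id here is illustrative, not taken from this commit):

```python
# Sketch of the FP8-INT4 flow from the template above, placeholders filled in.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
from torchao.quantization import Float8DynamicActivationInt4WeightConfig

model_to_quantize = "Qwen/Qwen3-8B"  # hypothetical example id

# FP8 dynamic activation quantization with INT4 weights, plain packing format.
quant_config = Float8DynamicActivationInt4WeightConfig(int4_packing_format="plain")
quantization_config = TorchAoConfig(quant_type=quant_config)

quantized_model = AutoModelForCausalLM.from_pretrained(
    model_to_quantize,
    device_map="cuda",
    torch_dtype=torch.bfloat16,
    quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_to_quantize)
```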
@@ -687,6 +699,9 @@ def quantize_and_upload(

     quant_to_config = {
         "FP8": Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
+        "FP8-INT4": Float8DynamicActivationInt4WeightConfig(
+            int4_packing_format="plain"
+        ),
         "INT4": Int4WeightOnlyConfig(
             group_size=128,
             int4_packing_format="tile_packed_to_4d",
@@ -725,6 +740,7 @@ def quantize_and_upload(

     quant_to_quant_code = {
         "FP8": _fp8_quant_code,
+        "FP8-INT4": _fp8_int4_quant_code,
         "INT4": _int4_quant_code,
         "INT8-INT4": _int8_int4_quant_code,
         "INT8-INT4-HQQ": _int8_int4_hqq_quant_code,
@@ -908,16 +924,24 @@ def filter_fn_skip_lmhead(module, fqn):

     # Push to hub
     if push_to_hub:
-        quantized_model.push_to_hub(
-            quantized_model_id, safe_serialization=safe_serialization
-        )
+        if _huggingface_hub_version < "1.4.1":
+            quantized_model.push_to_hub(
+                quantized_model_id, safe_serialization=safe_serialization
+            )
+        else:
+            quantized_model.push_to_hub(quantized_model_id)
+
         tokenizer.push_to_hub(quantized_model_id)
         if populate_model_card_template:
             card.push_to_hub(quantized_model_id)
     else:
-        quantized_model.save_pretrained(
-            quantized_model_id, safe_serialization=safe_serialization
-        )
+        if _huggingface_hub_version < "1.4.1":
+            quantized_model.save_pretrained(
+                quantized_model_id, safe_serialization=safe_serialization
+            )
+        else:
+            quantized_model.save_pretrained(quantized_model_id)
+
         tokenizer.save_pretrained(quantized_model_id)

     # Manual Testing
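
One caveat worth flagging in this hunk: `_huggingface_hub_version < "1.4.1"` compares version strings lexicographically, so for example "1.10.0" would sort before "1.4.1" and take the old-API branch. If that ever becomes a problem, a parsed comparison is more robust; a sketch using `packaging` (an extra dependency, and an assumption rather than what this commit does):

```python
# Sketch: parse versions instead of comparing raw strings.
# `packaging` is an assumed extra dependency, not used by this commit.
from packaging.version import Version

import huggingface_hub

if Version(huggingface_hub.__version__) < Version("1.4.1"):
    # older hub API: safe_serialization is still accepted
    pass
else:
    # newer hub API: call without safe_serialization
    pass
```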
@@ -960,7 +984,7 @@ def filter_fn_skip_lmhead(module, fqn):
    parser.add_argument(
        "--quant",
        type=str,
-        help="Quantization method. Options are FP8, INT4, INT8-INT4, INT8-INT4-HQQ, AWQ-INT4, SmoothQuant-INT8-INT8, MXFP8, NVFP4",
+        help="Quantization method. Options are FP8, FP8-INT4, INT4, INT8-INT4, INT8-INT4-HQQ, AWQ-INT4, SmoothQuant-INT8-INT8, MXFP8, NVFP4",
    )
    parser.add_argument(
        "--tasks",
