
Commit 920c502

Land 3894, 3887, 3884, 3883 into main from base branch (#3920)
* Remove Float8DynamicActivationFloat8SemiSparseWeightConfig

  This config was deprecated in favor of Float8DynamicActivationFloat8WeightConfig with packing_format=Float8PackingFormat.SPARSE_CUTLASS and granularity=PerRow() (see the migration sketch below). Remove the class definition, handler, and all references from imports, tests, and benchmarks.

  Co-authored-by: Cursor <cursoragent@cursor.com>
  [ghstack-poisoned]

* Remove Int8DynamicActivationInt4WeightConfig

  This config was deprecated in favor of Int8DynamicActivationIntxWeightConfig. Remove the class definition, handler, and all references from imports, tests, QAT code, benchmarks, and documentation. Update the QAT docs to reference Int4WeightOnlyConfig as the example base config.

  Co-authored-by: Cursor <cursoragent@cursor.com>
  [ghstack-poisoned]

* Remove GemliteUIntXWeightOnlyConfig

  This config was deprecated and scheduled for deletion. Remove the class definition, handler, and all references from imports, tests, benchmarks, and documentation.

  Co-authored-by: Cursor <cursoragent@cursor.com>
  [ghstack-poisoned]

* Remove Float8StaticActivationFloat8WeightConfig

  Remove the config class, its supporting classes (Float8ObservedLinear, Float8ObservedSoftmax, Float8QuantizedSoftmax), the handler function, and all references from imports and tests.

  Co-authored-by: Cursor <cursoragent@cursor.com>
  [ghstack-poisoned]

* Remove UIntXWeightOnlyConfig

  This config was deprecated and scheduled for deletion. Remove the class definition, handler, and all references from imports, tests, benchmarks, and the autoround eval script. This also removes the entire BC import block from quant_api.py, since all prototype configs have been removed.

  Co-authored-by: Cursor <cursoragent@cursor.com>
  [ghstack-poisoned]

* Remove CUSTOM_PARAM_QUANTIZATION_SUPPORTED_CONFIGS, inspect function sig (#3894)

  Summary: This PR removes CUSTOM_PARAM_QUANTIZATION_SUPPORTED_CONFIGS in favor of using `inspect.signature` to ensure that the given handler has a parameter_name kwarg we can use to pass in the param fqn (see the sketch below).

  Test Plan: `pytest test/quantization/test_quant_api -k fqn`

  [ghstack-poisoned]

(The squashed history also contains several "Update base for Update on ..." ghstack rebase commits that repeat the messages above verbatim.)
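For callers of the removed semi-sparse config, here is a minimal migration sketch, assuming only the replacement arguments named in the commit message above (packing_format=Float8PackingFormat.SPARSE_CUTLASS, granularity=PerRow()); the import location of Float8PackingFormat is an assumption and may differ across torchao versions:

```python
import torch
import torch.nn as nn

# Float8DynamicActivationFloat8WeightConfig, PerRow, and quantize_ are imported
# from torchao.quantization as in the benchmark diff below; the Float8PackingFormat
# import location is an assumption.
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    Float8PackingFormat,  # assumed import location
    PerRow,
    quantize_,
)

model = nn.Sequential(nn.Linear(2048, 2048)).to(torch.bfloat16).cuda()

# Before (removed):
#     quantize_(model, Float8DynamicActivationFloat8SemiSparseWeightConfig())
# After: same workflow, expressed through the general float8 config.
quantize_(
    model,
    Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
        packing_format=Float8PackingFormat.SPARSE_CUTLASS,
    ),
)
```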
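The CUSTOM_PARAM_QUANTIZATION_SUPPORTED_CONFIGS change replaces a hard-coded allowlist with a signature check. A minimal sketch of that idea using only the standard-library `inspect` module; the helper and example handler names are illustrative, not the landed torchao code:

```python
import inspect
from typing import Callable


def handler_accepts_parameter_name(handler: Callable) -> bool:
    """Return True if `handler` can receive the param fqn via a `parameter_name` kwarg."""
    params = inspect.signature(handler).parameters
    # Accept an explicit `parameter_name` parameter or a **kwargs catch-all.
    return "parameter_name" in params or any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )


# Hypothetical handler, used only to exercise the check.
def example_handler(module, config, *, parameter_name: str):
    print(f"quantizing parameter {parameter_name}")
    return module


assert handler_accepts_parameter_name(example_handler)
```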
1 parent 01b37b2 commit 920c502

25 files changed

Lines changed: 114 additions & 1091 deletions


benchmarks/benchmark_e2e_fp8_sparse_linear.py

Lines changed: 0 additions & 16 deletions
@@ -14,7 +14,6 @@
 )
 from torchao.prototype.sparsity.activation.utils import SquaredReLU
 from torchao.quantization import (
-    Float8DynamicActivationFloat8SemiSparseWeightConfig,
     Float8DynamicActivationFloat8WeightConfig,
     Float8MMConfig,
     PerRow,
@@ -84,20 +83,6 @@ def benchmark(num_tokens, hidden_size=8192, intermediate_size=8192):
     ffn_clone.forward = torch.compile(ffn_clone.forward, fullgraph=True)
     fp8_c_time = benchmark_microseconds(ffn_clone, input_tensor)

-    # fp8 sparse
-    ffn_clone = (
-        nn.Sequential(
-            nn.Linear(hidden_size, intermediate_size, bias=False),
-            SquaredReLU(),
-            nn.Linear(intermediate_size, hidden_size, bias=False),
-        )
-        .to(torch.bfloat16)
-        .cuda()
-    )
-    quantize_(ffn_clone, Float8DynamicActivationFloat8SemiSparseWeightConfig())
-    ffn_clone.forward = torch.compile(ffn_clone.forward, fullgraph=True)
-    fp8_c_sparse_time = benchmark_microseconds(ffn_clone, input_tensor)
-
     # activation fp8 sparse
     ffn_clone = (
         nn.Sequential(
@@ -127,7 +112,6 @@ def benchmark(num_tokens, hidden_size=8192, intermediate_size=8192):
         "bf16_latency (us)": fp16_time,
         "bf16_c_latency (us)": fp16_c_time,
         "fp8_c_time (us)": fp8_c_time,
-        "fp8_c_sparse_time (us)": fp8_c_sparse_time,
         "fp8_c_activation_sparse_time (us)": fp8_c_activation_sparse_time,
         "ao_fast_sparsification_time (us)": ao_fast_sparsification_time,
         "cusparselt_compress_time (us)": cusparselt_time,

benchmarks/microbenchmarks/test/test_utils.py

Lines changed: 0 additions & 7 deletions
@@ -13,7 +13,6 @@
     BenchmarkConfig,
     BenchmarkResult,
     BlockSparseWeightConfig,
-    Float8DynamicActivationFloat8SemiSparseWeightConfig,
     Int4WeightOnlyConfig,
     SemiSparseWeightConfig,
     clean_caches,
@@ -112,12 +111,6 @@ def test_string_to_config_sparsity(self):
         config = string_to_config("marlin", "semi-sparse")
         self.assertIsInstance(config, Int4WeightOnlyConfig)

-        # Test float8 with semi-sparse
-        config = string_to_config("float8dq", "semi-sparse")
-        self.assertIsInstance(
-            config, Float8DynamicActivationFloat8SemiSparseWeightConfig
-        )
-
     def test_block_sparsity_with_baseline_quantization(self):
         """Test that block sparsity with baseline quantization returns BlockSparseWeightConfig"""
         config = string_to_config("baseline", "block")

benchmarks/microbenchmarks/utils.py

Lines changed: 1 addition & 28 deletions
@@ -14,7 +14,6 @@

 from torchao.core.config import AOBaseConfig
 from torchao.quantization import (
-    Float8DynamicActivationFloat8SemiSparseWeightConfig,
     Float8DynamicActivationFloat8WeightConfig,
     Float8WeightOnlyConfig,
     GemliteUIntXWeightOnlyConfig,
@@ -23,7 +22,6 @@
     MappingType,
     PerRow,
     PerTensor,
-    UIntXWeightOnlyConfig,
 )
 from torchao.sparsity.sparse_api import BlockSparseWeightConfig, SemiSparseWeightConfig

@@ -192,30 +190,7 @@ def string_to_config(
             return Int8DynamicActivationInt8WeightConfig(weight_only_decode=True)
         else:
             return Int8DynamicActivationInt8WeightConfig()
-    if "uintx" in quantization:
-        # uintx-nbits-group_size, e.g. "uintx-2-64"
-        if "hqq" in quantization:
-            # uintx-nbits-group_size-hqq
-            use_hqq = True
-        else:
-            use_hqq = False
-        _quant_args = quantization.split("-")
-        nbits = int(_quant_args[1])
-        assert nbits >= 1 and nbits <= 8, "nbits must be 1 to 8"
-        _NBITS_TO_DTYPE = {
-            1: torch.uint1,
-            2: torch.uint2,
-            3: torch.uint3,
-            4: torch.uint4,
-            5: torch.uint5,
-            6: torch.uint6,
-            7: torch.uint7,
-            8: torch.uint8,
-        }
-        dtype = _NBITS_TO_DTYPE[nbits]
-        group_size = int(_quant_args[2])
-        return UIntXWeightOnlyConfig(dtype, group_size, use_hqq=use_hqq)
-    elif "int8_dynamic_activation_intx_weight" in quantization:
+    if "int8_dynamic_activation_intx_weight" in quantization:
         assert high_precision_dtype == torch.float32, (
             "int8_dynamic_activation_intx_weight requires using high_precision_dtype=torch.float32"
         )
@@ -242,8 +217,6 @@
     elif "float8wo" in quantization:
         return Float8WeightOnlyConfig()
     elif "float8dq" in quantization:
-        if sparsity and "semi" in sparsity:
-            return Float8DynamicActivationFloat8SemiSparseWeightConfig()
         granularity = str(quantization.split("-")[-1])
         if granularity == "tensor":
             granularity = PerTensor()

docs/source/eager_tutorials/finetuning.rst

Lines changed: 5 additions & 5 deletions
@@ -76,7 +76,7 @@ is optional:

    # Fine-tuning with QAT, by default:
    # activations are fake quantized to asymmetric per token int8
-   # weights are fake quantized to symmetric per group int4
+   # weights are fake quantized to symmetric per group int4
    # configurable through "quantizer._component_" in the command
    tune run --nnodes 1 --nproc_per_node 4 qat_distributed --config llama3_2/3B_qat_full batch_size=16

@@ -205,13 +205,13 @@ because we are not actually casting the fake quantized values.

 .. code:: py

-    from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
+    from torchao.quantization import quantize_, Int4WeightOnlyConfig
     from torchao.quantization.qat import QATConfig

     model = get_model()

     # prepare: swap `torch.nn.Linear` -> `FakeQuantizedLinear`
-    base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
+    base_config = Int4WeightOnlyConfig(group_size=32)
     quantize_(model, QATConfig(base_config, step="prepare"))

     # fine-tune
@@ -225,7 +225,7 @@ The next step is to actually quantize the model:

 .. code:: py

-    from torchao.quantization import Int8DynamicActivationInt4WeightConfig
+    from torchao.quantization import Int4WeightOnlyConfig

     # convert: swap `FakeQuantizedLinear` -> `torch.nn.Linear`, then quantize using `base_config`
     quantize_(model, QATConfig(base_config, step="convert"))
@@ -381,7 +381,7 @@ for fine-tuning Llama3.2-3B in float8:
    fp8_tensorwise            7222.198 (+11.074%)    30.010 (-0.266%)
    fp8_rowwise               6387.968 (-1.756%)     29.158 (-3.096%)
    fp8_rowwise_with_gw_hp    7573.698 (+16.480%)    29.516 (-1.908%)
-
+
    experiment_name           hellaswag_acc          wikitext_word_perplexity
    ----------------------    ---------------        --------------------------
    bf16                      0.533 (+0.000)         12.407 (+0.000)
Lines changed: 15 additions & 3 deletions
@@ -1,7 +1,19 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
 import torch.nn as nn

-from torchao.prototype.quantization import Int8DynamicActivationInt4WeightConfig
-from torchao.quantization import quantize_
+from torchao.quantization import Int8DynamicActivationIntxWeightConfig, quantize_
+from torchao.quantization.granularity import PerGroup

 model = nn.Sequential(nn.Linear(2048, 2048, device="cuda"))
-quantize_(model, Int8DynamicActivationInt4WeightConfig())
+quantize_(
+    model,
+    Int8DynamicActivationIntxWeightConfig(
+        weight_dtype=torch.int4, weight_granularity=PerGroup(32)
+    ),
+)

docs/source/workflows/qat.md

Lines changed: 3 additions & 4 deletions
@@ -78,19 +78,18 @@ the corresponding fake quantization configs to use.
 2. **Convert:** quantize the model using the base config provided

 Currently only the following PTQ base configs are supported:
-- [`Int8DynamicActivationInt4WeightConfig`](https://docs.pytorch.org/ao/main/generated/torchao.quantization.Int8DynamicActivationInt4WeightConfig.html)
 - [`Int4WeightOnlyConfig`](https://docs.pytorch.org/ao/main/generated/torchao.quantization.Int4WeightOnlyConfig.html)

 For example (most use cases):

 ```python
-from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
+from torchao.quantization import quantize_, Int4WeightOnlyConfig
 from torchao.quantization.qat import QATConfig

 model = get_model()

 # prepare: swap `torch.nn.Linear` -> `FakeQuantizedLinear`
-base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
+base_config = Int4WeightOnlyConfig(group_size=32)
 quantize_(model, QATConfig(base_config, step="prepare"))

 # train
@@ -109,7 +108,7 @@ and/or weights. For example, the following usage is numerically equivalent
 to the above:

 ```python
-from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
+from torchao.quantization import quantize_, Int4WeightOnlyConfig
 from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig

 model = get_model()

test/core/test_config.py

Lines changed: 1 addition & 12 deletions
@@ -34,11 +34,9 @@
     Float8WeightOnlyConfig,
     GemliteUIntXWeightOnlyConfig,
     Int4WeightOnlyConfig,
-    Int8DynamicActivationInt4WeightConfig,
     Int8DynamicActivationInt8WeightConfig,
     Int8WeightOnlyConfig,
     ModuleFqnToConfig,
-    UIntXWeightOnlyConfig,
     quantize_,
 )
 from torchao.quantization.quantize_.common.quantization_step import QuantizationStep
@@ -56,7 +54,6 @@
     Float8WeightOnlyConfig(
         weight_dtype=torch.float8_e4m3fn,
     ),
-    UIntXWeightOnlyConfig(dtype=torch.uint1),
     Float8DynamicActivationInt4WeightConfig(),
     Int4WeightOnlyConfig(
         group_size=32,
@@ -67,19 +64,11 @@
         int4_choose_qparams_algorithm="hqq",
         version=2,
     ),
-    Int8DynamicActivationInt4WeightConfig(
-        group_size=64,
-    ),
     Int8DynamicActivationInt8WeightConfig(),
     # Int8DynamicActivationInt8WeightConfig(layout=SemiSparseLayout()),
     Int8WeightOnlyConfig(
         group_size=128,
     ),
-    UIntXWeightOnlyConfig(
-        dtype=torch.uint3,
-        group_size=32,
-        use_hqq=True,
-    ),
     GemliteUIntXWeightOnlyConfig(
         group_size=128,  # Optional, has default of 64
         bit_width=8,  # Optional, has default of 4
@@ -92,7 +81,7 @@
     ModuleFqnToConfig(
         {
             "linear1": Int4WeightOnlyConfig(),
-            "linear2": Int8DynamicActivationInt4WeightConfig(),
+            "linear2": Int8DynamicActivationInt8WeightConfig(),
         }
     ),
     AWQConfig(

test/dtypes/test_affine_quantized.py

Lines changed: 0 additions & 2 deletions
@@ -24,7 +24,6 @@
 from torchao.quantization import (
     Float8WeightOnlyConfig,
     GemliteUIntXWeightOnlyConfig,
-    Int8DynamicActivationInt4WeightConfig,
     Int8DynamicActivationInt8WeightConfig,
     Int8WeightOnlyConfig,
     quantize_,
@@ -49,7 +48,6 @@ def get_quantization_functions(
 ):
     base_functions = [
         Int8WeightOnlyConfig(),
-        Int8DynamicActivationInt4WeightConfig(),
         Int8DynamicActivationInt8WeightConfig(),
         Int8DynamicActivationInt8WeightConfig(act_mapping_type=MappingType.ASYMMETRIC),
     ]

test/dtypes/test_uintx.py

Lines changed: 1 addition & 82 deletions
@@ -10,7 +10,7 @@
 import torch

 from torchao.prototype.dtypes.uintx.uintx_layout import to_uintx
-from torchao.quantization.quant_api import UIntXWeightOnlyConfig, quantize_
+from torchao.quantization.quant_api import quantize_  # noqa: F401
 from torchao.quantization.quant_primitives import (
     MappingType,
     choose_qparams_affine,
@@ -60,38 +60,6 @@ def forward(self, x):
         return self.net(x)


-@pytest.mark.parametrize("dtype", dtypes)
-@pytest.mark.parametrize("group_size", group_sizes)
-@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
-def test_uintx_quant_on_cpu_then_move_to_cuda(dtype, group_size):
-    scale = 512
-    fp16_mod_on_cpu = Linear16(scale, "cpu")
-    device = get_current_accelerator_device()
-    quantize_(fp16_mod_on_cpu, UIntXWeightOnlyConfig(dtype, group_size=group_size))
-    test_input_on_cpu = torch.randn(scale * 2, dtype=torch.float16, device="cpu")
-    output_on_cpu = fp16_mod_on_cpu(test_input_on_cpu)
-    fp16_mod_on_cuda = fp16_mod_on_cpu.to(device)
-    test_input_on_cuda = test_input_on_cpu.to(device)
-    output_on_cuda = fp16_mod_on_cuda(test_input_on_cuda)
-    assert torch.allclose(output_on_cpu, output_on_cuda.cpu(), atol=1.0e-3), (
-        "The output of the model on CPU and CUDA should be close"
-    )
-
-
-@pytest.mark.parametrize("dtype", dtypes)
-@pytest.mark.parametrize("group_size", group_sizes)
-@pytest.mark.parametrize("device", devices)
-@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
-def test_uintx_weight_only_model_quant(dtype, group_size, device):
-    scale = 512
-    fp16 = Linear16(scale, device)
-    quantize_(fp16, UIntXWeightOnlyConfig(dtype, group_size=group_size))
-    uintx = torch.compile(fp16, fullgraph=True)
-    test_input = torch.randn(scale * 2, dtype=torch.float16, device=device)
-    output = uintx.forward(test_input)
-    assert output is not None, "model quantization failed"
-
-
 @pytest.mark.parametrize("dtype", dtypes)
 @pytest.mark.parametrize("group_size", group_sizes)
 @pytest.mark.parametrize("device", devices)
@@ -128,55 +96,6 @@ def test_uintx_weight_only_quant(dtype, group_size, device):
     assert deqaunt is not None, "deqauntization failed"


-@pytest.mark.parametrize("dtype", dtypes)
-@pytest.mark.skipif(not torch.accelerator.is_available(), reason="Need GPU available")
-def test_uintx_target_dtype(dtype):
-    device = get_current_accelerator_device()
-    linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=device)
-    # make sure it runs
-    quantize_(linear, UIntXWeightOnlyConfig(dtype))
-    linear(torch.randn(1, 128, dtype=torch.bfloat16, device=device))
-
-
-@pytest.mark.parametrize("dtype", dtypes)
-@pytest.mark.skipif(not torch.accelerator.is_available(), reason="Need GPU available")
-def test_uintx_target_dtype_compile(dtype):
-    device = get_current_accelerator_device()
-    linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=device)
-    # make sure it runs
-    quantize_(linear, UIntXWeightOnlyConfig(dtype))
-    linear = torch.compile(linear)
-    linear(torch.randn(1, 128, dtype=torch.bfloat16, device=device))
-
-
-@pytest.mark.parametrize("dtype", dtypes)
-@pytest.mark.skipif(not torch.accelerator.is_available(), reason="Need GPU available")
-def test_uintx_model_size(dtype):
-    from torchao.utils import get_model_size_in_bytes
-
-    # scale size = 1/64 * 2 bytes = 1/32 bytes
-    # zero_point size = 1/64 * 4 bytes = 1/16 bytes
-    # dtype data size = 1 * bit_width/8 = bit_width/8 bytes
-    _dtype_to_ratio = {
-        torch.uint1: (1 / 8 + 1 / 16 + 1 / 32) / 2,
-        torch.uint2: (2 / 8 + 1 / 16 + 1 / 32) / 2,
-        torch.uint3: (3 / 8 + 1 / 16 + 1 / 32) / 2,
-        torch.uint4: (4 / 8 + 1 / 16 + 1 / 32) / 2,
-        torch.uint5: (5 / 8 + 1 / 16 + 1 / 32) / 2,
-        torch.uint6: (6 / 8 + 1 / 16 + 1 / 32) / 2,
-        torch.uint7: (7 / 8 + 1 / 16 + 1 / 32) / 2,
-    }
-    device = get_current_accelerator_device()
-    linear = torch.nn.Sequential(
-        torch.nn.Linear(128, 256, bias=False, dtype=torch.bfloat16, device=device)
-    )
-    bf16_size = get_model_size_in_bytes(linear)
-    # make sure it runs
-    quantize_(linear[0], UIntXWeightOnlyConfig(dtype))
-    quantized_size = get_model_size_in_bytes(linear)
-    assert bf16_size * _dtype_to_ratio[dtype] == quantized_size
-
-
 def test_uintx_api_deprecation():
     """
     Test that deprecated uintx APIs trigger deprecation warnings on import.

 (0)