
Commit 8572ca8

Update on "Refactor use_triton_kernel to use nvfp4_quantize_kernel_choice"
Summary: This prepares for the addition of the flashinfer quantize kernel path in the next PR.
Test Plan: python test/prototype/mx_formats/test_inference_workflow.py
Reviewers:
Subscribers:
Tasks:
Tags:

[ghstack-poisoned]
2 parents 9681579 + 1815c88 commit 8572ca8

28 files changed: 131 additions & 1111 deletions

benchmarks/benchmark_e2e_fp8_sparse_linear.py

Lines changed: 0 additions & 16 deletions
@@ -14,7 +14,6 @@
 )
 from torchao.prototype.sparsity.activation.utils import SquaredReLU
 from torchao.quantization import (
-    Float8DynamicActivationFloat8SemiSparseWeightConfig,
     Float8DynamicActivationFloat8WeightConfig,
     Float8MMConfig,
     PerRow,
@@ -84,20 +83,6 @@ def benchmark(num_tokens, hidden_size=8192, intermediate_size=8192):
     ffn_clone.forward = torch.compile(ffn_clone.forward, fullgraph=True)
     fp8_c_time = benchmark_microseconds(ffn_clone, input_tensor)

-    # fp8 sparse
-    ffn_clone = (
-        nn.Sequential(
-            nn.Linear(hidden_size, intermediate_size, bias=False),
-            SquaredReLU(),
-            nn.Linear(intermediate_size, hidden_size, bias=False),
-        )
-        .to(torch.bfloat16)
-        .cuda()
-    )
-    quantize_(ffn_clone, Float8DynamicActivationFloat8SemiSparseWeightConfig())
-    ffn_clone.forward = torch.compile(ffn_clone.forward, fullgraph=True)
-    fp8_c_sparse_time = benchmark_microseconds(ffn_clone, input_tensor)
-
     # activation fp8 sparse
     ffn_clone = (
         nn.Sequential(
@@ -127,7 +112,6 @@ def benchmark(num_tokens, hidden_size=8192, intermediate_size=8192):
         "bf16_latency (us)": fp16_time,
         "bf16_c_latency (us)": fp16_c_time,
         "fp8_c_time (us)": fp8_c_time,
-        "fp8_c_sparse_time (us)": fp8_c_sparse_time,
         "fp8_c_activation_sparse_time (us)": fp8_c_activation_sparse_time,
         "ao_fast_sparsification_time (us)": ao_fast_sparsification_time,
         "cusparselt_compress_time (us)": cusparselt_time,

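For orientation, here is a minimal sketch of the dense fp8 path that this benchmark keeps (the removed block above measured the same squared-ReLU FFN with the semi-sparse config). The `PerRow` granularity and the exact `Float8DynamicActivationFloat8WeightConfig` arguments are assumptions, not read from the unchanged parts of the script; `benchmark_microseconds` is the script's own helper.

```python
import torch
import torch.nn as nn

from torchao.prototype.sparsity.activation.utils import SquaredReLU
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)

hidden_size, intermediate_size = 8192, 8192

# Same squared-ReLU FFN that the removed semi-sparse block built.
ffn_clone = (
    nn.Sequential(
        nn.Linear(hidden_size, intermediate_size, bias=False),
        SquaredReLU(),
        nn.Linear(intermediate_size, hidden_size, bias=False),
    )
    .to(torch.bfloat16)
    .cuda()
)

# Dense fp8 dynamic-activation quantization: the path the benchmark retains.
quantize_(ffn_clone, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
ffn_clone.forward = torch.compile(ffn_clone.forward, fullgraph=True)
```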
benchmarks/microbenchmarks/test/test_utils.py

Lines changed: 0 additions & 7 deletions
@@ -13,7 +13,6 @@
     BenchmarkConfig,
     BenchmarkResult,
     BlockSparseWeightConfig,
-    Float8DynamicActivationFloat8SemiSparseWeightConfig,
     Int4WeightOnlyConfig,
     SemiSparseWeightConfig,
     clean_caches,
@@ -112,12 +111,6 @@ def test_string_to_config_sparsity(self):
         config = string_to_config("marlin", "semi-sparse")
         self.assertIsInstance(config, Int4WeightOnlyConfig)

-        # Test float8 with semi-sparse
-        config = string_to_config("float8dq", "semi-sparse")
-        self.assertIsInstance(
-            config, Float8DynamicActivationFloat8SemiSparseWeightConfig
-        )
-
     def test_block_sparsity_with_baseline_quantization(self):
         """Test that block sparsity with baseline quantization returns BlockSparseWeightConfig"""
         config = string_to_config("baseline", "block")

benchmarks/microbenchmarks/utils.py

Lines changed: 1 addition & 28 deletions
@@ -14,7 +14,6 @@

 from torchao.core.config import AOBaseConfig
 from torchao.quantization import (
-    Float8DynamicActivationFloat8SemiSparseWeightConfig,
     Float8DynamicActivationFloat8WeightConfig,
     Float8WeightOnlyConfig,
     GemliteUIntXWeightOnlyConfig,
@@ -23,7 +22,6 @@
     MappingType,
     PerRow,
     PerTensor,
-    UIntXWeightOnlyConfig,
 )
 from torchao.sparsity.sparse_api import BlockSparseWeightConfig, SemiSparseWeightConfig

@@ -192,30 +190,7 @@ def string_to_config(
         return Int8DynamicActivationInt8WeightConfig(weight_only_decode=True)
     else:
         return Int8DynamicActivationInt8WeightConfig()
-    if "uintx" in quantization:
-        # uintx-nbits-group_size, e.g. "uintx-2-64"
-        if "hqq" in quantization:
-            # uintx-nbits-group_size-hqq
-            use_hqq = True
-        else:
-            use_hqq = False
-        _quant_args = quantization.split("-")
-        nbits = int(_quant_args[1])
-        assert nbits >= 1 and nbits <= 8, "nbits must be 1 to 8"
-        _NBITS_TO_DTYPE = {
-            1: torch.uint1,
-            2: torch.uint2,
-            3: torch.uint3,
-            4: torch.uint4,
-            5: torch.uint5,
-            6: torch.uint6,
-            7: torch.uint7,
-            8: torch.uint8,
-        }
-        dtype = _NBITS_TO_DTYPE[nbits]
-        group_size = int(_quant_args[2])
-        return UIntXWeightOnlyConfig(dtype, group_size, use_hqq=use_hqq)
-    elif "int8_dynamic_activation_intx_weight" in quantization:
+    if "int8_dynamic_activation_intx_weight" in quantization:
         assert high_precision_dtype == torch.float32, (
             "int8_dynamic_activation_intx_weight requires using high_precision_dtype=torch.float32"
         )
@@ -242,8 +217,6 @@ def string_to_config(
     elif "float8wo" in quantization:
         return Float8WeightOnlyConfig()
     elif "float8dq" in quantization:
-        if sparsity and "semi" in sparsity:
-            return Float8DynamicActivationFloat8SemiSparseWeightConfig()
         granularity = str(quantization.split("-")[-1])
         if granularity == "tensor":
             granularity = PerTensor()
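
To make the effect of the trimmed `string_to_config` concrete, here is a small sketch of how a quantization string is resolved after this change. The import path is assumed from the file layout, and the positional `(quantization, sparsity)` call shape follows the tests in the previous file; the `"float8dq-tensor"` string follows the granularity parsing visible in the last hunk.

```python
# Sketch only: resolve a quantization/sparsity string pair to an AOBaseConfig.
from benchmarks.microbenchmarks.utils import string_to_config

# Still supported: dense float8 dynamic-activation quantization, per-tensor granularity.
config = string_to_config("float8dq-tensor", None)

# No longer special-cased after this commit:
#   - "float8dq" + "semi-sparse" no longer returns
#     Float8DynamicActivationFloat8SemiSparseWeightConfig
#   - "uintx-<nbits>-<group_size>" strings are no longer recognized at all
print(type(config).__name__)  # expected: Float8DynamicActivationFloat8WeightConfig
```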

docs/source/eager_tutorials/finetuning.rst

Lines changed: 5 additions & 5 deletions
@@ -76,7 +76,7 @@ is optional:

    # Fine-tuning with QAT, by default:
    # activations are fake quantized to asymmetric per token int8
-   # weights are fake quantized to symmetric per group int4
+   # weights are fake quantized to symmetric per group int4
    # configurable through "quantizer._component_" in the command
    tune run --nnodes 1 --nproc_per_node 4 qat_distributed --config llama3_2/3B_qat_full batch_size=16

@@ -205,13 +205,13 @@ because we are not actually casting the fake quantized values.

 .. code:: py

-    from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
+    from torchao.quantization import quantize_, Int4WeightOnlyConfig
     from torchao.quantization.qat import QATConfig

     model = get_model()

     # prepare: swap `torch.nn.Linear` -> `FakeQuantizedLinear`
-    base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
+    base_config = Int4WeightOnlyConfig(group_size=32)
     quantize_(model, QATConfig(base_config, step="prepare"))

     # fine-tune
@@ -225,7 +225,7 @@ The next step is to actually quantize the model:

 .. code:: py

-    from torchao.quantization import Int8DynamicActivationInt4WeightConfig
+    from torchao.quantization import Int4WeightOnlyConfig

     # convert: swap `FakeQuantizedLinear` -> `torch.nn.Linear`, then quantize using `base_config`
     quantize_(model, QATConfig(base_config, step="convert"))
@@ -381,7 +381,7 @@ for fine-tuning Llama3.2-3B in float8:
     fp8_tensorwise            7222.198 (+11.074%)    30.010 (-0.266%)
     fp8_rowwise               6387.968 (-1.756%)     29.158 (-3.096%)
     fp8_rowwise_with_gw_hp    7573.698 (+16.480%)    29.516 (-1.908%)
-
+
     experiment_name           hellaswag_acc      wikitext_word_perplexity
     ----------------------    ---------------    --------------------------
     bf16                      0.533 (+0.000)     12.407 (+0.000)
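
Putting the two updated snippets together, the full prepare → fine-tune → convert flow described in this tutorial now reads as follows (`get_model` and `train_loop` are placeholders, as in the docs themselves):

```python
from torchao.quantization import Int4WeightOnlyConfig, quantize_
from torchao.quantization.qat import QATConfig

model = get_model()  # placeholder: any model with torch.nn.Linear layers

# prepare: swap `torch.nn.Linear` -> `FakeQuantizedLinear`
base_config = Int4WeightOnlyConfig(group_size=32)
quantize_(model, QATConfig(base_config, step="prepare"))

# fine-tune with fake quantization in the loop
train_loop(model)  # placeholder training loop

# convert: swap `FakeQuantizedLinear` -> `torch.nn.Linear`, then quantize using `base_config`
quantize_(model, QATConfig(base_config, step="convert"))
```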
Lines changed: 15 additions & 3 deletions
@@ -1,7 +1,19 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
 import torch.nn as nn

-from torchao.prototype.quantization import Int8DynamicActivationInt4WeightConfig
-from torchao.quantization import quantize_
+from torchao.quantization import Int8DynamicActivationIntxWeightConfig, quantize_
+from torchao.quantization.granularity import PerGroup

 model = nn.Sequential(nn.Linear(2048, 2048, device="cuda"))
-quantize_(model, Int8DynamicActivationInt4WeightConfig())
+quantize_(
+    model,
+    Int8DynamicActivationIntxWeightConfig(
+        weight_dtype=torch.int4, weight_granularity=PerGroup(32)
+    ),
+)
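
A quick usage sketch for the rewritten example script above. The input shape is arbitrary, and it assumes a CUDA device is available and that the int8-dynamic-activation/int4-weight path supports a plain forward on that device:

```python
import torch
import torch.nn as nn

from torchao.quantization import Int8DynamicActivationIntxWeightConfig, quantize_
from torchao.quantization.granularity import PerGroup

model = nn.Sequential(nn.Linear(2048, 2048, device="cuda"))
quantize_(
    model,
    Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4, weight_granularity=PerGroup(32)
    ),
)

# int8 dynamically-quantized activations, int4 per-group(32) weights
x = torch.randn(16, 2048, device="cuda")
with torch.no_grad():
    out = model(x)
print(out.shape)  # torch.Size([16, 2048])
```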

docs/source/workflows/qat.md

Lines changed: 3 additions & 4 deletions
@@ -78,19 +78,18 @@ the corresponding fake quantization configs to use.
 2. **Convert:** quantize the model using the base config provided

 Currently only the following PTQ base configs are supported:
-- [`Int8DynamicActivationInt4WeightConfig`](https://docs.pytorch.org/ao/main/generated/torchao.quantization.Int8DynamicActivationInt4WeightConfig.html)
 - [`Int4WeightOnlyConfig`](https://docs.pytorch.org/ao/main/generated/torchao.quantization.Int4WeightOnlyConfig.html)

 For example (most use cases):

 ```python
-from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
+from torchao.quantization import quantize_, Int4WeightOnlyConfig
 from torchao.quantization.qat import QATConfig

 model = get_model()

 # prepare: swap `torch.nn.Linear` -> `FakeQuantizedLinear`
-base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
+base_config = Int4WeightOnlyConfig(group_size=32)
 quantize_(model, QATConfig(base_config, step="prepare"))

 # train
@@ -109,7 +108,7 @@ and/or weights. For example, the following usage is numerically equivalent
 to the above:

 ```python
-from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
+from torchao.quantization import quantize_, Int4WeightOnlyConfig
 from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig

 model = get_model()
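
The hunk above only shows the import change for the customized flow in qat.md. For context, here is a hedged sketch of what that flow looks like when built from explicit fake-quantization configs instead of a PTQ base config. The keyword names (`weight_config`, `step`) and the `IntxFakeQuantizeConfig` arguments follow the torchao QAT documentation and are assumptions as far as this diff is concerned; the int4 group size of 32 mirrors `Int4WeightOnlyConfig(group_size=32)` above.

```python
import torch

from torchao.quantization import quantize_
from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig

model = get_model()  # placeholder from the docs above

# Weight-only int4 fake quantization with per-group size 32, specified
# directly rather than derived from a PTQ base config.
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
quantize_(model, QATConfig(weight_config=weight_config, step="prepare"))
```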

test/core/test_config.py

Lines changed: 1 addition & 12 deletions
@@ -34,11 +34,9 @@
     Float8WeightOnlyConfig,
     GemliteUIntXWeightOnlyConfig,
     Int4WeightOnlyConfig,
-    Int8DynamicActivationInt4WeightConfig,
     Int8DynamicActivationInt8WeightConfig,
     Int8WeightOnlyConfig,
     ModuleFqnToConfig,
-    UIntXWeightOnlyConfig,
     quantize_,
 )
 from torchao.quantization.quantize_.common.quantization_step import QuantizationStep
@@ -56,7 +54,6 @@
     Float8WeightOnlyConfig(
         weight_dtype=torch.float8_e4m3fn,
     ),
-    UIntXWeightOnlyConfig(dtype=torch.uint1),
     Float8DynamicActivationInt4WeightConfig(),
     Int4WeightOnlyConfig(
         group_size=32,
@@ -67,19 +64,11 @@
         int4_choose_qparams_algorithm="hqq",
         version=2,
     ),
-    Int8DynamicActivationInt4WeightConfig(
-        group_size=64,
-    ),
     Int8DynamicActivationInt8WeightConfig(),
     # Int8DynamicActivationInt8WeightConfig(layout=SemiSparseLayout()),
     Int8WeightOnlyConfig(
         group_size=128,
     ),
-    UIntXWeightOnlyConfig(
-        dtype=torch.uint3,
-        group_size=32,
-        use_hqq=True,
-    ),
     GemliteUIntXWeightOnlyConfig(
         group_size=128,  # Optional, has default of 64
         bit_width=8,  # Optional, has default of 4
@@ -92,7 +81,7 @@
     ModuleFqnToConfig(
         {
             "linear1": Int4WeightOnlyConfig(),
-            "linear2": Int8DynamicActivationInt4WeightConfig(),
+            "linear2": Int8DynamicActivationInt8WeightConfig(),
         }
     ),
     AWQConfig(
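
test/core/test_config.py drives each config in this list through torchao's config (de)serialization. As a minimal sketch of that round-trip for one of the configs that remains in the list; the helper names `config_to_dict`/`config_from_dict` come from `torchao.core.config` and are assumed here rather than shown in this diff:

```python
from torchao.core.config import config_from_dict, config_to_dict
from torchao.quantization import Int8DynamicActivationInt8WeightConfig

original = Int8DynamicActivationInt8WeightConfig()

# Serialize to a plain (JSON-friendly) dict and reconstruct the config from it.
as_dict = config_to_dict(original)
reloaded = config_from_dict(as_dict)

assert isinstance(reloaded, Int8DynamicActivationInt8WeightConfig)
```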

test/dtypes/test_affine_quantized.py

Lines changed: 0 additions & 2 deletions
@@ -24,7 +24,6 @@
 from torchao.quantization import (
     Float8WeightOnlyConfig,
     GemliteUIntXWeightOnlyConfig,
-    Int8DynamicActivationInt4WeightConfig,
     Int8DynamicActivationInt8WeightConfig,
     Int8WeightOnlyConfig,
     quantize_,
@@ -49,7 +48,6 @@ def get_quantization_functions(
 ):
     base_functions = [
         Int8WeightOnlyConfig(),
-        Int8DynamicActivationInt4WeightConfig(),
         Int8DynamicActivationInt8WeightConfig(),
         Int8DynamicActivationInt8WeightConfig(act_mapping_type=MappingType.ASYMMETRIC),
     ]
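
For reference, a sketch of how the trimmed `base_functions` list is typically consumed in this test file: each remaining config is applied to a fresh linear via `quantize_` and a forward pass is run as a smoke test. The shapes, dtype, and CUDA device below are assumptions, not taken from the diff.

```python
import torch
import torch.nn as nn

from torchao.quantization import (
    Int8DynamicActivationInt8WeightConfig,
    Int8WeightOnlyConfig,
    quantize_,
)

base_functions = [
    Int8WeightOnlyConfig(),
    Int8DynamicActivationInt8WeightConfig(),
]

for config in base_functions:
    linear = nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
    quantize_(linear, config)
    # smoke test: the quantized linear still runs
    linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))
```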

test/dtypes/test_uintx.py

Lines changed: 1 addition & 82 deletions
@@ -10,7 +10,7 @@
 import torch

 from torchao.prototype.dtypes.uintx.uintx_layout import to_uintx
-from torchao.quantization.quant_api import UIntXWeightOnlyConfig, quantize_
+from torchao.quantization.quant_api import quantize_  # noqa: F401
 from torchao.quantization.quant_primitives import (
     MappingType,
     choose_qparams_affine,
@@ -60,38 +60,6 @@ def forward(self, x):
         return self.net(x)


-@pytest.mark.parametrize("dtype", dtypes)
-@pytest.mark.parametrize("group_size", group_sizes)
-@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
-def test_uintx_quant_on_cpu_then_move_to_cuda(dtype, group_size):
-    scale = 512
-    fp16_mod_on_cpu = Linear16(scale, "cpu")
-    device = get_current_accelerator_device()
-    quantize_(fp16_mod_on_cpu, UIntXWeightOnlyConfig(dtype, group_size=group_size))
-    test_input_on_cpu = torch.randn(scale * 2, dtype=torch.float16, device="cpu")
-    output_on_cpu = fp16_mod_on_cpu(test_input_on_cpu)
-    fp16_mod_on_cuda = fp16_mod_on_cpu.to(device)
-    test_input_on_cuda = test_input_on_cpu.to(device)
-    output_on_cuda = fp16_mod_on_cuda(test_input_on_cuda)
-    assert torch.allclose(output_on_cpu, output_on_cuda.cpu(), atol=1.0e-3), (
-        "The output of the model on CPU and CUDA should be close"
-    )
-
-
-@pytest.mark.parametrize("dtype", dtypes)
-@pytest.mark.parametrize("group_size", group_sizes)
-@pytest.mark.parametrize("device", devices)
-@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
-def test_uintx_weight_only_model_quant(dtype, group_size, device):
-    scale = 512
-    fp16 = Linear16(scale, device)
-    quantize_(fp16, UIntXWeightOnlyConfig(dtype, group_size=group_size))
-    uintx = torch.compile(fp16, fullgraph=True)
-    test_input = torch.randn(scale * 2, dtype=torch.float16, device=device)
-    output = uintx.forward(test_input)
-    assert output is not None, "model quantization failed"
-
-
 @pytest.mark.parametrize("dtype", dtypes)
 @pytest.mark.parametrize("group_size", group_sizes)
 @pytest.mark.parametrize("device", devices)
@@ -128,55 +96,6 @@ def test_uintx_weight_only_quant(dtype, group_size, device):
     assert deqaunt is not None, "deqauntization failed"


-@pytest.mark.parametrize("dtype", dtypes)
-@pytest.mark.skipif(not torch.accelerator.is_available(), reason="Need GPU available")
-def test_uintx_target_dtype(dtype):
-    device = get_current_accelerator_device()
-    linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=device)
-    # make sure it runs
-    quantize_(linear, UIntXWeightOnlyConfig(dtype))
-    linear(torch.randn(1, 128, dtype=torch.bfloat16, device=device))
-
-
-@pytest.mark.parametrize("dtype", dtypes)
-@pytest.mark.skipif(not torch.accelerator.is_available(), reason="Need GPU available")
-def test_uintx_target_dtype_compile(dtype):
-    device = get_current_accelerator_device()
-    linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=device)
-    # make sure it runs
-    quantize_(linear, UIntXWeightOnlyConfig(dtype))
-    linear = torch.compile(linear)
-    linear(torch.randn(1, 128, dtype=torch.bfloat16, device=device))
-
-
-@pytest.mark.parametrize("dtype", dtypes)
-@pytest.mark.skipif(not torch.accelerator.is_available(), reason="Need GPU available")
-def test_uintx_model_size(dtype):
-    from torchao.utils import get_model_size_in_bytes
-
-    # scale size = 1/64 * 2 bytes = 1/32 bytes
-    # zero_point size = 1/64 * 4 bytes = 1/16 bytes
-    # dtype data size = 1 * bit_width/8 = bit_width/8 bytes
-    _dtype_to_ratio = {
-        torch.uint1: (1 / 8 + 1 / 16 + 1 / 32) / 2,
-        torch.uint2: (2 / 8 + 1 / 16 + 1 / 32) / 2,
-        torch.uint3: (3 / 8 + 1 / 16 + 1 / 32) / 2,
-        torch.uint4: (4 / 8 + 1 / 16 + 1 / 32) / 2,
-        torch.uint5: (5 / 8 + 1 / 16 + 1 / 32) / 2,
-        torch.uint6: (6 / 8 + 1 / 16 + 1 / 32) / 2,
-        torch.uint7: (7 / 8 + 1 / 16 + 1 / 32) / 2,
-    }
-    device = get_current_accelerator_device()
-    linear = torch.nn.Sequential(
-        torch.nn.Linear(128, 256, bias=False, dtype=torch.bfloat16, device=device)
-    )
-    bf16_size = get_model_size_in_bytes(linear)
-    # make sure it runs
-    quantize_(linear[0], UIntXWeightOnlyConfig(dtype))
-    quantized_size = get_model_size_in_bytes(linear)
-    assert bf16_size * _dtype_to_ratio[dtype] == quantized_size
-
-
 def test_uintx_api_deprecation():
     """
     Test that deprecated uintx APIs trigger deprecation warnings on import.
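
The only uintx-config test kept in this file is the deprecation check. Below is a hypothetical sketch of what such a check can look like; the exact import path, the warning category, and whether the warning fires on module import or on attribute access are assumptions, not taken from this diff.

```python
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Importing the legacy config is expected to warn rather than fail outright.
    from torchao.quantization.quant_api import UIntXWeightOnlyConfig  # noqa: F401

saw_deprecation = any(issubclass(w.category, DeprecationWarning) for w in caught)
print(f"deprecation warning observed: {saw_deprecation}")
```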
