
Commit 811745f

Add UIntxBitPackedTensor, UIntxWeightOnlyConfig, and Int8DynamicActivationUIntxWeightConfig
Add v2 tensor subclass UIntxBitPackedTensor(TorchAOBaseTensor) using gemlite bit-packing and Triton GEMM kernels, replacing the old AQT-based GemliteUIntXWeightOnlyConfig path.

- UIntxBitPackedTensor: tensor subclass with from_hp(), dequantize(), and aten.linear/t/slice dispatch implementations
- UIntxWeightOnlyConfig: weight-only quantization (4-bit/8-bit)
- Int8DynamicActivationUIntxWeightConfig: int8 dynamic activation + uintx weight
- Tests for both configs covering 4-bit, 8-bit, slice, and non-standard shapes

ghstack-source-id: 5cdbbe1
Pull Request resolved: #4081
1 parent ab4a336 commit 811745f

4 files changed

Lines changed: 639 additions & 0 deletions
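For orientation, here is a minimal usage sketch of the weight-only path added in this commit, based on the new config and the tests below (it assumes a CUDA device with the gemlite package installed):

import torch

from torchao.prototype.quantization.quant_api import UIntxWeightOnlyConfig
from torchao.quantization import quantize_

# Quantize a single fp16 Linear layer to 4-bit weights, group size 64, packed into 32-bit words.
linear = torch.nn.Linear(512, 256, bias=False).to(device="cuda", dtype=torch.float16)
quantize_(linear, UIntxWeightOnlyConfig(group_size=64, bit_width=4, packing_bitwidth=32))

# The forward pass dispatches through the gemlite-backed UIntxBitPackedTensor weight.
x = torch.randn(2, 512, device="cuda", dtype=torch.float16)
y = linear(x)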


Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from torch.testing._internal.common_utils import TestCase, run_tests

from torchao.quantization import quantize_

try:
    import gemlite  # noqa: F401

    has_gemlite = True
except ModuleNotFoundError:
    has_gemlite = False


@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not has_gemlite, "gemlite not available")
class TestUIntxBitPackedTensor(TestCase):
    def _test_quantize_and_linear(self, bit_width, group_size, packing_bitwidth):
        """Helper: quantize a linear layer and verify forward pass produces valid output."""
        from torchao.prototype.quantization.quant_api import UIntxWeightOnlyConfig

        in_features = 512
        out_features = 256
        model = torch.nn.Linear(in_features, out_features, bias=False).to(
            device="cuda", dtype=torch.float16
        )

        config = UIntxWeightOnlyConfig(
            group_size=group_size,
            bit_width=bit_width,
            packing_bitwidth=packing_bitwidth,
        )
        quantize_(model, config)

        # Verify weight is now UIntxBitPackedTensor
        from torchao.prototype.quantization.uintx.uintx_bit_packed_tensor import (
            UIntxBitPackedTensor,
        )

        self.assertIsInstance(model.weight, UIntxBitPackedTensor)

        # Verify forward pass works
        x = torch.randn(2, in_features, device="cuda", dtype=torch.float16)
        out = model(x)
        self.assertEqual(out.shape, (2, out_features))
        self.assertFalse(torch.isnan(out).any())
        self.assertFalse(torch.isinf(out).any())

    def test_4bit_group64_pack32(self):
        self._test_quantize_and_linear(bit_width=4, group_size=64, packing_bitwidth=32)

    def test_4bit_group128_pack32(self):
        self._test_quantize_and_linear(bit_width=4, group_size=128, packing_bitwidth=32)

    def test_4bit_group64_pack8(self):
        self._test_quantize_and_linear(bit_width=4, group_size=64, packing_bitwidth=8)

    def test_8bit_perchannel_pack32(self):
        self._test_quantize_and_linear(
            bit_width=8, group_size=None, packing_bitwidth=32
        )

    def test_8bit_perchannel_pack8(self):
        self._test_quantize_and_linear(bit_width=8, group_size=None, packing_bitwidth=8)

    def test_4bit_dynamic_activation(self):
        """Test dynamic activation quantization via Int8DynamicActivationUIntxWeightConfig."""
        from torchao.prototype.quantization.quant_api import (
            Int8DynamicActivationUIntxWeightConfig,
        )

        model = torch.nn.Linear(512, 256, bias=False).to(
            device="cuda", dtype=torch.float16
        )
        config = Int8DynamicActivationUIntxWeightConfig(
            group_size=64, bit_width=4, packing_bitwidth=32
        )
        quantize_(model, config)

        from torchao.prototype.quantization.uintx.uintx_bit_packed_tensor import (
            UIntxBitPackedTensor,
        )

        self.assertIsInstance(model.weight, UIntxBitPackedTensor)

        x = torch.randn(2, 512, device="cuda", dtype=torch.float16)
        out = model(x)
        self.assertEqual(out.shape, (2, 256))
        self.assertFalse(torch.isnan(out).any())
        self.assertFalse(torch.isinf(out).any())

    def test_slice_dim0(self):
        """Test narrow/slice on dim 0 (out_features) for tensor parallelism."""
        from torchao.prototype.quantization.quant_api import UIntxWeightOnlyConfig

        model = torch.nn.Linear(512, 256, bias=False).to(
            device="cuda", dtype=torch.float16
        )
        quantize_(
            model,
            UIntxWeightOnlyConfig(group_size=64, bit_width=4, packing_bitwidth=32),
        )

        sliced = model.weight.narrow(0, 0, 64)
        self.assertEqual(sliced.shape[0], 64)

    def test_slice_dim1(self):
        """Test narrow/slice on dim 1 (in_features) for tensor parallelism."""
        from torchao.prototype.quantization.quant_api import UIntxWeightOnlyConfig

        model = torch.nn.Linear(512, 256, bias=False).to(
            device="cuda", dtype=torch.float16
        )
        quantize_(
            model,
            UIntxWeightOnlyConfig(group_size=64, bit_width=4, packing_bitwidth=32),
        )

        sliced = model.weight.narrow(1, 0, 128)
        self.assertEqual(sliced.shape[1], 128)

    def test_non_standard_shapes(self):
        """Test shapes not divisible by 128 but divisible by 32 (gemlite requirement)."""
        from torchao.prototype.quantization.quant_api import UIntxWeightOnlyConfig

        # gemlite requires in_features divisible by 32 or group_size
        model = torch.nn.Linear(1024, 1025, bias=False).to(
            device="cuda", dtype=torch.float16
        )
        config = UIntxWeightOnlyConfig(
            group_size=None, bit_width=4, packing_bitwidth=32
        )
        quantize_(model, config)

        x = torch.randn(1, 1024, device="cuda", dtype=torch.float16)
        out = model(x)
        self.assertEqual(out.shape, (1, 1025))

    def test_import_from_prototype_api(self):
        """Verify UIntxWeightOnlyConfig is available from the prototype API."""
        from torchao.prototype.quantization.quant_api import (
            UIntxWeightOnlyConfig,  # noqa: F401
        )


if __name__ == "__main__":
    run_tests()
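A short aside on the two slice tests above: narrow() along dim 0 (out_features) and dim 1 (in_features) is the operation tensor-parallel sharding applies to a Linear weight, which is why the commit adds slice dispatch to the tensor subclass. A minimal sketch using the same values as the tests (again assuming CUDA and gemlite):

import torch

from torchao.prototype.quantization.quant_api import UIntxWeightOnlyConfig
from torchao.quantization import quantize_

linear = torch.nn.Linear(512, 256, bias=False).to(device="cuda", dtype=torch.float16)
quantize_(linear, UIntxWeightOnlyConfig(group_size=64, bit_width=4, packing_bitwidth=32))

# Row shard (out_features) and column shard (in_features) of the packed weight.
row_shard = linear.weight.narrow(0, 0, 64)
col_shard = linear.weight.narrow(1, 0, 128)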

torchao/prototype/quantization/quant_api.py

Lines changed: 99 additions & 0 deletions
@@ -112,6 +112,105 @@ def _gemlite_uintx_weight_only_transform(
    return module


@dataclass
class UIntxWeightOnlyConfig(AOBaseConfig):
    """Weight-only uintx quantization using bit-packed format with gemlite Triton kernels.

    Supports 4-bit (asymmetric, grouped) and 8-bit (symmetric, per-channel) quantization.
    Uses gemlite library for efficient Triton-based GEMM.

    Args:
        group_size: quantization group size. Use None for per-channel (required for 8-bit).
            Valid values: 32, 64, 128, 256, 512, 1024, None. Default: 128.
        bit_width: quantization bit width, 4 or 8. Default: 4.
        packing_bitwidth: bit width for packing, 8/16/32/None (auto). Default: None.
        set_inductor_config: if True, set recommended torchinductor config. Default: True.
    """

    group_size: Optional[int] = 128
    bit_width: int = 4
    packing_bitwidth: Optional[int] = None
    set_inductor_config: bool = True

    def __post_init__(self):
        torch._C._log_api_usage_once("torchao.quantization.UIntxWeightOnlyConfig")


@register_quantize_module_handler(UIntxWeightOnlyConfig)
def _uintx_weight_only_transform(
    module: torch.nn.Module,
    config: UIntxWeightOnlyConfig,
) -> torch.nn.Module:
    from torchao.prototype.quantization.uintx.uintx_bit_packed_tensor import (
        UIntxBitPackedTensor,
    )

    if config.set_inductor_config:
        torchao.quantization.utils.recommended_inductor_config_setter()

    weight = module.weight
    quantized_weight = UIntxBitPackedTensor.from_hp(
        weight,
        bit_width=config.bit_width,
        group_size=config.group_size,
        packing_bitwidth=config.packing_bitwidth,
    )
    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
    module.extra_repr = types.MethodType(_linear_extra_repr, module)
    return module


@dataclass
class Int8DynamicActivationUIntxWeightConfig(AOBaseConfig):
    """Dynamic activation + uintx weight quantization using gemlite Triton kernels.

    Activations are quantized dynamically at runtime (int8). Weights use bit-packed
    uintx format. Supports 4-bit and 8-bit weight quantization.

    Args:
        group_size: quantization group size. Use None for per-channel (required for 8-bit).
            Valid values: 32, 64, 128, 256, 512, 1024, None. Default: 128.
        bit_width: weight quantization bit width, 4 or 8. Default: 4.
        packing_bitwidth: bit width for packing, 8/16/32/None (auto). Default: None.
        set_inductor_config: if True, set recommended torchinductor config. Default: True.
    """

    group_size: Optional[int] = 128
    bit_width: int = 4
    packing_bitwidth: Optional[int] = None
    set_inductor_config: bool = True

    def __post_init__(self):
        torch._C._log_api_usage_once(
            "torchao.quantization.Int8DynamicActivationUIntxWeightConfig"
        )


@register_quantize_module_handler(Int8DynamicActivationUIntxWeightConfig)
def _int8_dynamic_activation_uintx_weight_transform(
    module: torch.nn.Module,
    config: Int8DynamicActivationUIntxWeightConfig,
) -> torch.nn.Module:
    from torchao.prototype.quantization.uintx.uintx_bit_packed_tensor import (
        UIntxBitPackedTensor,
    )

    if config.set_inductor_config:
        torchao.quantization.utils.recommended_inductor_config_setter()

    weight = module.weight
    quantized_weight = UIntxBitPackedTensor.from_hp(
        weight,
        bit_width=config.bit_width,
        group_size=config.group_size,
        packing_bitwidth=config.packing_bitwidth,
        mode="dynamic",
    )
    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
    module.extra_repr = types.MethodType(_linear_extra_repr, module)
    return module


@dataclass
class Float8StaticActivationFloat8WeightConfig(AOBaseConfig):
    """
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
