
Commit 07a5f54

Add UIntxBitPackedTensor, UIntxWeightOnlyConfig, and Int8DynamicActivationUIntxWeightConfig
Add v2 tensor subclass UIntxBitPackedTensor(TorchAOBaseTensor) using gemlite bit-packing and Triton GEMM kernels, replacing the old AQT-based GemliteUIntXWeightOnlyConfig path.

- UIntxBitPackedTensor: tensor subclass with from_hp(), dequantize(), and aten.linear/t/slice dispatch implementations
- UIntxWeightOnlyConfig: weight-only quantization (4-bit/8-bit)
- Int8DynamicActivationUIntxWeightConfig: int8 dynamic activation + uintx weight
- Tests for both configs covering 4-bit, 8-bit, slice, and non-standard shapes

Test Plan:
- python test/prototype/test_uintx_bit_packed_tensor.py
- Tests cover UIntxWeightOnlyConfig: 4-bit (group64/128, pack32/8), 8-bit (perchannel, pack32/8)
- Tests cover Int8DynamicActivationUIntxWeightConfig: same bit_width/group_size/packing combos
- Tests cover slice dim0/dim1 for tensor parallelism
- Tests cover non-standard shapes (1024x1025)
- Verified backward compat: old GemliteUIntXWeightOnlyConfig still works

ghstack-source-id: de8cbe1
Pull Request resolved: #4082
1 parent ab4a336 commit 07a5f54

4 files changed

Lines changed: 699 additions & 0 deletions
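
For context, a minimal usage sketch of the two configs added in this commit, mirroring the pattern used in the new tests below; the import paths come from this diff and the layer shapes are illustrative. Requires CUDA and gemlite, as in the tests.

import torch

from torchao.prototype.quantization.quant_api import (
    Int8DynamicActivationUIntxWeightConfig,
    UIntxWeightOnlyConfig,
)
from torchao.quantization import quantize_

# Weight-only: 4-bit grouped weights, packed into 32-bit words.
m = torch.nn.Linear(512, 256, bias=False).to(device="cuda", dtype=torch.float16)
quantize_(m, UIntxWeightOnlyConfig(group_size=64, bit_width=4, packing_bitwidth=32))

# Int8 dynamic activations + 8-bit per-channel weights (group_size=None).
m2 = torch.nn.Linear(512, 256, bias=False).to(device="cuda", dtype=torch.float16)
quantize_(m2, Int8DynamicActivationUIntxWeightConfig(group_size=None, bit_width=8))

x = torch.randn(2, 512, device="cuda", dtype=torch.float16)
print(m(x).shape, m2(x).shape)  # torch.Size([2, 256]) torch.Size([2, 256])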


test/prototype/test_uintx_bit_packed_tensor.py

Lines changed: 207 additions & 0 deletions
@@ -0,0 +1,207 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from torch.testing._internal.common_utils import TestCase, run_tests

from torchao.quantization import quantize_

try:
    import gemlite  # noqa: F401

    has_gemlite = True
except ModuleNotFoundError:
    has_gemlite = False


@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not has_gemlite, "gemlite not available")
class TestUIntxBitPackedTensor(TestCase):
    def _test_quantize_and_linear(self, bit_width, group_size, packing_bitwidth):
        """Helper: quantize a linear layer and verify forward pass produces valid output."""
        from torchao.prototype.quantization.quant_api import UIntxWeightOnlyConfig

        in_features = 512
        out_features = 256
        model = torch.nn.Linear(in_features, out_features, bias=False).to(
            device="cuda", dtype=torch.float16
        )

        config = UIntxWeightOnlyConfig(
            group_size=group_size,
            bit_width=bit_width,
            packing_bitwidth=packing_bitwidth,
        )
        quantize_(model, config)

        # Verify weight is now UIntxBitPackedTensor
        from torchao.prototype.quantization.uintx.uintx_bit_packed_tensor import (
            UIntxBitPackedTensor,
        )

        self.assertIsInstance(model.weight, UIntxBitPackedTensor)

        # Verify forward pass works
        x = torch.randn(2, in_features, device="cuda", dtype=torch.float16)
        out = model(x)
        self.assertEqual(out.shape, (2, out_features))
        self.assertFalse(torch.isnan(out).any())
        self.assertFalse(torch.isinf(out).any())

    def test_4bit_group64_pack32(self):
        self._test_quantize_and_linear(bit_width=4, group_size=64, packing_bitwidth=32)

    def test_4bit_group128_pack32(self):
        self._test_quantize_and_linear(bit_width=4, group_size=128, packing_bitwidth=32)

    def test_4bit_group64_pack8(self):
        self._test_quantize_and_linear(bit_width=4, group_size=64, packing_bitwidth=8)

    def test_8bit_perchannel_pack32(self):
        self._test_quantize_and_linear(
            bit_width=8, group_size=None, packing_bitwidth=32
        )

    def test_8bit_perchannel_pack8(self):
        self._test_quantize_and_linear(bit_width=8, group_size=None, packing_bitwidth=8)

    def _test_dynamic_quantize_and_linear(
        self, bit_width, group_size, packing_bitwidth
    ):
        """Helper: quantize with dynamic activation and verify forward pass."""
        from torchao.prototype.quantization.quant_api import (
            Int8DynamicActivationUIntxWeightConfig,
        )

        in_features = 512
        out_features = 256
        model = torch.nn.Linear(in_features, out_features, bias=False).to(
            device="cuda", dtype=torch.float16
        )

        config = Int8DynamicActivationUIntxWeightConfig(
            group_size=group_size,
            bit_width=bit_width,
            packing_bitwidth=packing_bitwidth,
        )
        quantize_(model, config)

        from torchao.prototype.quantization.uintx.uintx_bit_packed_tensor import (
            UIntxBitPackedTensor,
        )

        self.assertIsInstance(model.weight, UIntxBitPackedTensor)

        x = torch.randn(2, in_features, device="cuda", dtype=torch.float16)
        out = model(x)
        self.assertEqual(out.shape, (2, out_features))
        self.assertFalse(torch.isnan(out).any())
        self.assertFalse(torch.isinf(out).any())

    def test_dynamic_4bit_group64_pack32(self):
        self._test_dynamic_quantize_and_linear(
            bit_width=4, group_size=64, packing_bitwidth=32
        )

    def test_dynamic_4bit_group128_pack32(self):
        self._test_dynamic_quantize_and_linear(
            bit_width=4, group_size=128, packing_bitwidth=32
        )

    def test_dynamic_4bit_group64_pack8(self):
        self._test_dynamic_quantize_and_linear(
            bit_width=4, group_size=64, packing_bitwidth=8
        )

    def test_dynamic_8bit_perchannel_pack32(self):
        self._test_dynamic_quantize_and_linear(
            bit_width=8, group_size=None, packing_bitwidth=32
        )

    def test_dynamic_8bit_perchannel_pack8(self):
        self._test_dynamic_quantize_and_linear(
            bit_width=8, group_size=None, packing_bitwidth=8
        )

    def test_slice_dim0(self):
        """Test narrow/slice on dim 0 (out_features) for tensor parallelism."""
        from torchao.prototype.quantization.quant_api import UIntxWeightOnlyConfig

        model = torch.nn.Linear(512, 256, bias=False).to(
            device="cuda", dtype=torch.float16
        )
        quantize_(
            model,
            UIntxWeightOnlyConfig(group_size=64, bit_width=4, packing_bitwidth=32),
        )

        weight = model.weight
        sliced = weight.narrow(0, 0, 64)
        self.assertEqual(sliced.shape[0], 64)

        # Verify internal tensors match direct slicing
        # Data is stored transposed (K x N), so logical dim 0 -> data dim 1
        self.assertEqual(
            sliced.packed_weight,
            weight.packed_weight.narrow(1, 0, 64),
        )
        self.assertEqual(
            sliced.scale,
            weight.scale.narrow(1, 0, 64),
        )

    def test_slice_dim1(self):
        """Test narrow/slice on dim 1 (in_features) for tensor parallelism."""
        from torchao.prototype.quantization.quant_api import UIntxWeightOnlyConfig

        model = torch.nn.Linear(512, 256, bias=False).to(
            device="cuda", dtype=torch.float16
        )
        quantize_(
            model,
            UIntxWeightOnlyConfig(group_size=64, bit_width=4, packing_bitwidth=32),
        )

        weight = model.weight
        sliced = weight.narrow(1, 0, 128)
        self.assertEqual(sliced.shape[1], 128)

        # Verify internal tensors match direct slicing
        # Data is stored transposed (K x N), so logical dim 1 -> data dim 0
        # packed_weight dim 0 is packed by elements_per_sample
        eps = weight.gemlite_kwargs["elements_per_sample"]
        self.assertEqual(
            sliced.packed_weight,
            weight.packed_weight.narrow(0, 0, 128 // eps),
        )
        # scale dim 0 corresponds to groups along in_features
        scale_ratio = 128 // 64  # in_features_slice / group_size
        self.assertEqual(
            sliced.scale,
            weight.scale.narrow(0, 0, scale_ratio),
        )

    def test_non_standard_shapes(self):
        """Test shapes not divisible by 128 but divisible by 32 (gemlite requirement)."""
        from torchao.prototype.quantization.quant_api import UIntxWeightOnlyConfig

        # gemlite requires in_features divisible by 32 or group_size
        model = torch.nn.Linear(1024, 1025, bias=False).to(
            device="cuda", dtype=torch.float16
        )
        config = UIntxWeightOnlyConfig(
            group_size=None, bit_width=4, packing_bitwidth=32
        )
        quantize_(model, config)

        x = torch.randn(1, 1024, device="cuda", dtype=torch.float16)
        out = model(x)
        self.assertEqual(out.shape, (1, 1025))


if __name__ == "__main__":
    run_tests()
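
A side note on the slice tests above: a small sketch of the dim-1 bookkeeping, assuming elements_per_sample works out to packing_bitwidth // bit_width (which is what the 4-bit / pack-32 case used in these tests implies); the numbers are illustrative, not an extra assertion from this diff.

# Hypothetical arithmetic behind test_slice_dim1 (illustrative values only).
bit_width, packing_bitwidth, group_size = 4, 32, 64
in_features_slice = 128

# Assumption: one packed word holds packing_bitwidth // bit_width quantized values.
elements_per_sample = packing_bitwidth // bit_width  # 8

# Weight data is stored transposed (K x N), so an in_features slice maps to dim 0.
packed_rows = in_features_slice // elements_per_sample  # 128 // 8 = 16 packed_weight rows
scale_groups = in_features_slice // group_size  # 128 // 64 = 2 scale rows

print(packed_rows, scale_groups)  # 16 2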

torchao/prototype/quantization/quant_api.py

Lines changed: 105 additions & 0 deletions
@@ -112,6 +112,111 @@ def _gemlite_uintx_weight_only_transform(
    return module


@dataclass
class UIntxWeightOnlyConfig(AOBaseConfig):
    """Weight-only uintx quantization using bit-packed format with gemlite Triton kernels.

    Supports 4-bit (asymmetric, grouped) and 8-bit (symmetric, per-channel) quantization.
    Uses gemlite library for efficient Triton-based GEMM.

    Args:
        group_size: quantization group size. Use None for per-channel (required for 8-bit).
            Valid values: 32, 64, 128, 256, 512, 1024, None. Default: 128.
        bit_width: quantization bit width, 4 or 8. Default: 4.
        packing_bitwidth: bit width for packing, 8/16/32/None (auto). Default: None.
        set_inductor_config: if True, set recommended torchinductor config. Default: True.
    """

    group_size: Optional[int] = 128
    bit_width: int = 4
    packing_bitwidth: Optional[int] = None
    set_inductor_config: bool = True

    def __post_init__(self):
        torch._C._log_api_usage_once("torchao.quantization.UIntxWeightOnlyConfig")
        assert self.bit_width in [4, 8], (
            f"bit_width must be 4 or 8, got {self.bit_width}"
        )


@register_quantize_module_handler(UIntxWeightOnlyConfig)
def _uintx_weight_only_transform(
    module: torch.nn.Module,
    config: UIntxWeightOnlyConfig,
) -> torch.nn.Module:
    from torchao.prototype.quantization.uintx.uintx_bit_packed_tensor import (
        UIntxBitPackedTensor,
    )

    if config.set_inductor_config:
        torchao.quantization.utils.recommended_inductor_config_setter()

    weight = module.weight
    quantized_weight = UIntxBitPackedTensor.from_hp(
        weight,
        bit_width=config.bit_width,
        group_size=config.group_size,
        packing_bitwidth=config.packing_bitwidth,
    )
    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
    module.extra_repr = types.MethodType(_linear_extra_repr, module)
    return module


@dataclass
class Int8DynamicActivationUIntxWeightConfig(AOBaseConfig):
    """Dynamic activation + uintx weight quantization using gemlite Triton kernels.

    Activations are quantized dynamically at runtime (int8). Weights use bit-packed
    uintx format. Supports 4-bit and 8-bit weight quantization.

    Args:
        group_size: quantization group size. Use None for per-channel (required for 8-bit).
            Valid values: 32, 64, 128, 256, 512, 1024, None. Default: 128.
        bit_width: weight quantization bit width, 4 or 8. Default: 4.
        packing_bitwidth: bit width for packing, 8/16/32/None (auto). Default: None.
        set_inductor_config: if True, set recommended torchinductor config. Default: True.
    """

    group_size: Optional[int] = 128
    bit_width: int = 4
    packing_bitwidth: Optional[int] = None
    set_inductor_config: bool = True

    def __post_init__(self):
        torch._C._log_api_usage_once(
            "torchao.quantization.Int8DynamicActivationUIntxWeightConfig"
        )
        assert self.bit_width in [4, 8], (
            f"bit_width must be 4 or 8, got {self.bit_width}"
        )


@register_quantize_module_handler(Int8DynamicActivationUIntxWeightConfig)
def _int8_dynamic_activation_uintx_weight_transform(
    module: torch.nn.Module,
    config: Int8DynamicActivationUIntxWeightConfig,
) -> torch.nn.Module:
    from torchao.prototype.quantization.uintx.uintx_bit_packed_tensor import (
        UIntxBitPackedTensor,
    )

    if config.set_inductor_config:
        torchao.quantization.utils.recommended_inductor_config_setter()

    weight = module.weight
    quantized_weight = UIntxBitPackedTensor.from_hp(
        weight,
        bit_width=config.bit_width,
        group_size=config.group_size,
        packing_bitwidth=config.packing_bitwidth,
        mode="dynamic",
    )
    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
    module.extra_repr = types.MethodType(_linear_extra_repr, module)
    return module


@dataclass
class Float8StaticActivationFloat8WeightConfig(AOBaseConfig):
    """
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
