Commit 2203ec9
Migrate GemliteUIntXWeightOnlyConfig to v2 design (UIntxWeightOnlyConfig)
Replace the old AQT-based GemliteUIntXWeightOnlyConfig with a new v2 UIntxWeightOnlyConfig backed by UIntxBitPackedTensor(TorchAOBaseTensor), removing the dependency on AffineQuantizedTensor.

- Add UIntxBitPackedTensor tensor subclass with from_hp(), dequantize(), and aten.linear/t/slice dispatch implementations
- Add UIntxWeightOnlyConfig and handler in prototype quant_api
- Export UIntxWeightOnlyConfig from public quantization API
- Add comprehensive tests (4-bit, 8-bit, dynamic, slice, non-standard shapes)
- Old GemliteUIntXWeightOnlyConfig remains as deprecated alias

ghstack-source-id: cce3e06
Pull Request resolved: #4079
1 parent ab4a336 · 6 files changed · 589 additions, 0 deletions
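
For orientation, a minimal usage sketch of the new config, assembled from the tests and handler in this diff (assumes a CUDA device and the gemlite package are available):

    import torch
    from torchao.quantization import quantize_
    from torchao.prototype.quantization.quant_api import UIntxWeightOnlyConfig

    # Quantize a linear layer to 4-bit, group size 64, weight-only mode
    # (packing_bitwidth defaults to None, i.e. auto).
    model = torch.nn.Linear(512, 256, bias=False).to(device="cuda", dtype=torch.float16)
    quantize_(model, UIntxWeightOnlyConfig(group_size=64, bit_width=4))

    # Forward pass now dispatches through the gemlite Triton GEMM.
    x = torch.randn(2, 512, device="cuda", dtype=torch.float16)
    out = model(x)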

Lines changed: 133 additions & 0 deletions

@@ -0,0 +1,133 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from torch.testing._internal.common_utils import TestCase, run_tests

from torchao.quantization import quantize_

try:
    import gemlite  # noqa: F401

    has_gemlite = True
except ModuleNotFoundError:
    has_gemlite = False


@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not has_gemlite, "gemlite not available")
class TestUIntxBitPackedTensor(TestCase):
    def _test_quantize_and_linear(
        self, bit_width, group_size, packing_bitwidth, mode="weight_only"
    ):
        """Helper: quantize a linear layer and verify forward pass produces valid output."""
        from torchao.prototype.quantization.quant_api import UIntxWeightOnlyConfig

        in_features = 512
        out_features = 256
        model = torch.nn.Linear(in_features, out_features, bias=False).to(
            device="cuda", dtype=torch.float16
        )

        config = UIntxWeightOnlyConfig(
            group_size=group_size,
            bit_width=bit_width,
            packing_bitwidth=packing_bitwidth,
            mode=mode,
        )
        quantize_(model, config)

        # Verify weight is now UIntxBitPackedTensor
        from torchao.prototype.quantization.uintx.uintx_bit_packed_tensor import (
            UIntxBitPackedTensor,
        )

        self.assertIsInstance(model.weight, UIntxBitPackedTensor)

        # Verify forward pass works
        x = torch.randn(2, in_features, device="cuda", dtype=torch.float16)
        out = model(x)
        self.assertEqual(out.shape, (2, out_features))
        self.assertFalse(torch.isnan(out).any())
        self.assertFalse(torch.isinf(out).any())

    def test_4bit_group64_pack32(self):
        self._test_quantize_and_linear(bit_width=4, group_size=64, packing_bitwidth=32)

    def test_4bit_group128_pack32(self):
        self._test_quantize_and_linear(bit_width=4, group_size=128, packing_bitwidth=32)

    def test_4bit_group64_pack8(self):
        self._test_quantize_and_linear(bit_width=4, group_size=64, packing_bitwidth=8)

    def test_8bit_perchannel_pack32(self):
        self._test_quantize_and_linear(
            bit_width=8, group_size=None, packing_bitwidth=32
        )

    def test_8bit_perchannel_pack8(self):
        self._test_quantize_and_linear(bit_width=8, group_size=None, packing_bitwidth=8)

    def test_4bit_dynamic_mode(self):
        self._test_quantize_and_linear(
            bit_width=4, group_size=64, packing_bitwidth=32, mode="dynamic"
        )

    def test_slice_dim0(self):
        """Test narrow/slice on dim 0 (out_features) for tensor parallelism."""
        from torchao.prototype.quantization.quant_api import UIntxWeightOnlyConfig

        model = torch.nn.Linear(512, 256, bias=False).to(
            device="cuda", dtype=torch.float16
        )
        quantize_(
            model,
            UIntxWeightOnlyConfig(group_size=64, bit_width=4, packing_bitwidth=32),
        )

        sliced = model.weight.narrow(0, 0, 64)
        self.assertEqual(sliced.shape[0], 64)

    def test_slice_dim1(self):
        """Test narrow/slice on dim 1 (in_features) for tensor parallelism."""
        from torchao.prototype.quantization.quant_api import UIntxWeightOnlyConfig

        model = torch.nn.Linear(512, 256, bias=False).to(
            device="cuda", dtype=torch.float16
        )
        quantize_(
            model,
            UIntxWeightOnlyConfig(group_size=64, bit_width=4, packing_bitwidth=32),
        )

        sliced = model.weight.narrow(1, 0, 128)
        self.assertEqual(sliced.shape[1], 128)

    def test_non_standard_shapes(self):
        """Test shapes not divisible by 128 but divisible by 32 (gemlite requirement)."""
        from torchao.prototype.quantization.quant_api import UIntxWeightOnlyConfig

        # gemlite requires in_features divisible by 32 or group_size
        model = torch.nn.Linear(1024, 1025, bias=False).to(
            device="cuda", dtype=torch.float16
        )
        config = UIntxWeightOnlyConfig(
            group_size=None, bit_width=4, packing_bitwidth=32
        )
        quantize_(model, config)

        x = torch.randn(1, 1024, device="cuda", dtype=torch.float16)
        out = model(x)
        self.assertEqual(out.shape, (1, 1025))

    def test_import_from_public_api(self):
        """Verify UIntxWeightOnlyConfig is available from the public API."""
        from torchao.quantization import UIntxWeightOnlyConfig  # noqa: F401


if __name__ == "__main__":
    run_tests()
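
As context for the slice tests above, a hedged sketch of the column-parallel sharding pattern they are meant to support; the world_size and rank values are hypothetical and not part of this commit:

    import torch
    from torchao.quantization import quantize_
    from torchao.prototype.quantization.quant_api import UIntxWeightOnlyConfig

    model = torch.nn.Linear(512, 256, bias=False).to(device="cuda", dtype=torch.float16)
    quantize_(
        model,
        UIntxWeightOnlyConfig(group_size=64, bit_width=4, packing_bitwidth=32),
    )

    # Column-parallel sharding: each rank keeps a slice of out_features (dim 0),
    # which is exactly the narrow() call exercised by test_slice_dim0.
    world_size, rank = 4, 0  # hypothetical values
    shard_rows = model.weight.shape[0] // world_size
    local_weight = model.weight.narrow(0, rank * shard_rows, shard_rows)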

torchao/prototype/quantization/quant_api.py

Lines changed: 51 additions & 0 deletions
@@ -112,6 +112,57 @@ def _gemlite_uintx_weight_only_transform(
    return module


@dataclass
class UIntxWeightOnlyConfig(AOBaseConfig):
    """Weight-only uintx quantization using bit-packed format with gemlite Triton kernels.

    Supports 4-bit (asymmetric, grouped) and 8-bit (symmetric, per-channel) quantization.
    Uses the gemlite library for efficient Triton-based GEMM.

    Args:
        group_size: quantization group size. Use None for per-channel (required for 8-bit).
            Valid values: 32, 64, 128, 256, 512, 1024, None. Default: 128.
        bit_width: quantization bit width, 4 or 8. Default: 4.
        packing_bitwidth: bit width for packing, 8/16/32/None (auto). Default: None.
        mode: "weight_only" (default) or "dynamic" (quantize activations at runtime).
        set_inductor_config: if True, set recommended torchinductor config. Default: True.
    """

    group_size: Optional[int] = 128
    bit_width: int = 4
    packing_bitwidth: Optional[int] = None
    mode: str = "weight_only"
    set_inductor_config: bool = True

    def __post_init__(self):
        torch._C._log_api_usage_once("torchao.quantization.UIntxWeightOnlyConfig")


@register_quantize_module_handler(UIntxWeightOnlyConfig)
def _uintx_weight_only_transform(
    module: torch.nn.Module,
    config: UIntxWeightOnlyConfig,
) -> torch.nn.Module:
    from torchao.prototype.quantization.uintx.uintx_bit_packed_tensor import (
        UIntxBitPackedTensor,
    )

    if config.set_inductor_config:
        torchao.quantization.utils.recommended_inductor_config_setter()

    weight = module.weight
    quantized_weight = UIntxBitPackedTensor.from_hp(
        weight,
        bit_width=config.bit_width,
        group_size=config.group_size,
        packing_bitwidth=config.packing_bitwidth,
        mode=config.mode,
    )
    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
    module.extra_repr = types.MethodType(_linear_extra_repr, module)
    return module


@dataclass
class Float8StaticActivationFloat8WeightConfig(AOBaseConfig):
    """
Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.