153 changes: 153 additions & 0 deletions test/prototype/test_uintx_bit_packed_tensor.py
@@ -0,0 +1,153 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from torch.testing._internal.common_utils import TestCase, run_tests

from torchao.quantization import quantize_

try:
    import gemlite  # noqa: F401

    has_gemlite = True
except ModuleNotFoundError:
    has_gemlite = False


@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not has_gemlite, "gemlite not available")
class TestUIntxBitPackedTensor(TestCase):
    def _test_quantize_and_linear(self, bit_width, group_size, packing_bitwidth):
        """Helper: quantize a linear layer and verify forward pass produces valid output."""
        from torchao.prototype.quantization.quant_api import UIntxWeightOnlyConfig

        in_features = 512
        out_features = 256
        model = torch.nn.Linear(in_features, out_features, bias=False).to(
            device="cuda", dtype=torch.float16
        )

        config = UIntxWeightOnlyConfig(
            group_size=group_size,
            bit_width=bit_width,
            packing_bitwidth=packing_bitwidth,
        )
        quantize_(model, config)

        # Verify weight is now UIntxBitPackedTensor
        from torchao.prototype.quantization.uintx.uintx_bit_packed_tensor import (
            UIntxBitPackedTensor,
        )

        self.assertIsInstance(model.weight, UIntxBitPackedTensor)

        # Verify forward pass works
        x = torch.randn(2, in_features, device="cuda", dtype=torch.float16)
        out = model(x)
        self.assertEqual(out.shape, (2, out_features))
        self.assertFalse(torch.isnan(out).any())
        self.assertFalse(torch.isinf(out).any())

    def test_4bit_group64_pack32(self):
        self._test_quantize_and_linear(bit_width=4, group_size=64, packing_bitwidth=32)

    def test_4bit_group128_pack32(self):
        self._test_quantize_and_linear(bit_width=4, group_size=128, packing_bitwidth=32)

    def test_4bit_group64_pack8(self):
        self._test_quantize_and_linear(bit_width=4, group_size=64, packing_bitwidth=8)

    def test_8bit_perchannel_pack32(self):
        self._test_quantize_and_linear(
            bit_width=8, group_size=None, packing_bitwidth=32
        )

    def test_8bit_perchannel_pack8(self):
        self._test_quantize_and_linear(bit_width=8, group_size=None, packing_bitwidth=8)

    def test_4bit_dynamic_activation(self):
        """Test dynamic activation quantization via Int8DynamicActivationUIntxWeightConfig."""
        from torchao.prototype.quantization.quant_api import (
            Int8DynamicActivationUIntxWeightConfig,
        )

        model = torch.nn.Linear(512, 256, bias=False).to(
            device="cuda", dtype=torch.float16
        )
        config = Int8DynamicActivationUIntxWeightConfig(
            group_size=64, bit_width=4, packing_bitwidth=32
        )
        quantize_(model, config)

        from torchao.prototype.quantization.uintx.uintx_bit_packed_tensor import (
            UIntxBitPackedTensor,
        )

        self.assertIsInstance(model.weight, UIntxBitPackedTensor)

        x = torch.randn(2, 512, device="cuda", dtype=torch.float16)
        out = model(x)
        self.assertEqual(out.shape, (2, 256))
        self.assertFalse(torch.isnan(out).any())
        self.assertFalse(torch.isinf(out).any())

    def test_slice_dim0(self):
        """Test narrow/slice on dim 0 (out_features) for tensor parallelism."""
        from torchao.prototype.quantization.quant_api import UIntxWeightOnlyConfig

        model = torch.nn.Linear(512, 256, bias=False).to(
            device="cuda", dtype=torch.float16
        )
        quantize_(
            model,
            UIntxWeightOnlyConfig(group_size=64, bit_width=4, packing_bitwidth=32),
        )

        sliced = model.weight.narrow(0, 0, 64)
        self.assertEqual(sliced.shape[0], 64)

    def test_slice_dim1(self):
        """Test narrow/slice on dim 1 (in_features) for tensor parallelism."""
        from torchao.prototype.quantization.quant_api import UIntxWeightOnlyConfig

        model = torch.nn.Linear(512, 256, bias=False).to(
            device="cuda", dtype=torch.float16
        )
        quantize_(
            model,
            UIntxWeightOnlyConfig(group_size=64, bit_width=4, packing_bitwidth=32),
        )

        sliced = model.weight.narrow(1, 0, 128)
        self.assertEqual(sliced.shape[1], 128)

    def test_non_standard_shapes(self):
        """Test a non-standard out_features (1025, not a multiple of 128 or 32) while
        in_features (1024) meets gemlite's divisibility requirement."""
        from torchao.prototype.quantization.quant_api import UIntxWeightOnlyConfig

        # gemlite requires in_features to be divisible by 32 or by the group size
        model = torch.nn.Linear(1024, 1025, bias=False).to(
            device="cuda", dtype=torch.float16
        )
        config = UIntxWeightOnlyConfig(
            group_size=None, bit_width=4, packing_bitwidth=32
        )
        quantize_(model, config)

        x = torch.randn(1, 1024, device="cuda", dtype=torch.float16)
        out = model(x)
        self.assertEqual(out.shape, (1, 1025))

    def test_import_from_prototype_api(self):
        """Verify UIntxWeightOnlyConfig is available from the prototype API."""
        from torchao.prototype.quantization.quant_api import (
            UIntxWeightOnlyConfig,  # noqa: F401
        )


if __name__ == "__main__":
    run_tests()
99 changes: 99 additions & 0 deletions torchao/prototype/quantization/quant_api.py
@@ -112,6 +112,105 @@ def _gemlite_uintx_weight_only_transform(
    return module


@dataclass
class UIntxWeightOnlyConfig(AOBaseConfig):
"""Weight-only uintx quantization using bit-packed format with gemlite Triton kernels.

Supports 4-bit (asymmetric, grouped) and 8-bit (symmetric, per-channel) quantization.
Uses gemlite library for efficient Triton-based GEMM.

Args:
group_size: quantization group size. Use None for per-channel (required for 8-bit).
Valid values: 32, 64, 128, 256, 512, 1024, None. Default: 128.
bit_width: quantization bit width, 4 or 8. Default: 4.
packing_bitwidth: bit width for packing, 8/16/32/None (auto). Default: None.
set_inductor_config: if True, set recommended torchinductor config. Default: True.
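
    Example (illustrative sketch; assumes a CUDA device and the gemlite package are
    available, mirroring the usage exercised in the accompanying tests):

        import torch
        from torchao.quantization import quantize_

        model = torch.nn.Linear(512, 256, bias=False).to(device="cuda", dtype=torch.float16)
        quantize_(model, UIntxWeightOnlyConfig(bit_width=4, group_size=64))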
"""

group_size: Optional[int] = 128
bit_width: int = 4
packing_bitwidth: Optional[int] = None
set_inductor_config: bool = True

    def __post_init__(self):
        torch._C._log_api_usage_once("torchao.quantization.UIntxWeightOnlyConfig")


@register_quantize_module_handler(UIntxWeightOnlyConfig)
def _uintx_weight_only_transform(
    module: torch.nn.Module,
    config: UIntxWeightOnlyConfig,
) -> torch.nn.Module:
    from torchao.prototype.quantization.uintx.uintx_bit_packed_tensor import (
        UIntxBitPackedTensor,
    )

    if config.set_inductor_config:
        torchao.quantization.utils.recommended_inductor_config_setter()

    weight = module.weight
    quantized_weight = UIntxBitPackedTensor.from_hp(
        weight,
        bit_width=config.bit_width,
        group_size=config.group_size,
        packing_bitwidth=config.packing_bitwidth,
    )
    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
    module.extra_repr = types.MethodType(_linear_extra_repr, module)
    return module


@dataclass
class Int8DynamicActivationUIntxWeightConfig(AOBaseConfig):
"""Dynamic activation + uintx weight quantization using gemlite Triton kernels.

Activations are quantized dynamically at runtime (int8). Weights use bit-packed
uintx format. Supports 4-bit and 8-bit weight quantization.

Args:
group_size: quantization group size. Use None for per-channel (required for 8-bit).
Valid values: 32, 64, 128, 256, 512, 1024, None. Default: 128.
bit_width: weight quantization bit width, 4 or 8. Default: 4.
packing_bitwidth: bit width for packing, 8/16/32/None (auto). Default: None.
set_inductor_config: if True, set recommended torchinductor config. Default: True.
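
    Example (illustrative sketch; assumes a CUDA device and the gemlite package are
    available, mirroring the usage exercised in the accompanying tests):

        import torch
        from torchao.quantization import quantize_

        model = torch.nn.Linear(512, 256, bias=False).to(device="cuda", dtype=torch.float16)
        quantize_(
            model,
            Int8DynamicActivationUIntxWeightConfig(bit_width=4, group_size=64),
        )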
"""

group_size: Optional[int] = 128
bit_width: int = 4
packing_bitwidth: Optional[int] = None
set_inductor_config: bool = True

    def __post_init__(self):
        torch._C._log_api_usage_once(
            "torchao.quantization.Int8DynamicActivationUIntxWeightConfig"
        )


@register_quantize_module_handler(Int8DynamicActivationUIntxWeightConfig)
def _int8_dynamic_activation_uintx_weight_transform(
    module: torch.nn.Module,
    config: Int8DynamicActivationUIntxWeightConfig,
) -> torch.nn.Module:
    from torchao.prototype.quantization.uintx.uintx_bit_packed_tensor import (
        UIntxBitPackedTensor,
    )

    if config.set_inductor_config:
        torchao.quantization.utils.recommended_inductor_config_setter()

    weight = module.weight
    quantized_weight = UIntxBitPackedTensor.from_hp(
        weight,
        bit_width=config.bit_width,
        group_size=config.group_size,
        packing_bitwidth=config.packing_bitwidth,
        mode="dynamic",
    )
    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
    module.extra_repr = types.MethodType(_linear_extra_repr, module)
    return module


@dataclass
class Float8StaticActivationFloat8WeightConfig(AOBaseConfig):
"""
4 changes: 4 additions & 0 deletions torchao/prototype/quantization/uintx/__init__.py
@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.