119 changes: 119 additions & 0 deletions test/quantization/quantize_/workflows/int8/test_int8_tensor_cpu.py
@@ -0,0 +1,119 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import copy

import torch
from torch.testing._internal import common_utils

from torchao.quantization import (
Int8DynamicActivationInt8WeightConfig,
Int8StaticActivationInt8WeightConfig,
quantize_,
)
from torchao.quantization.granularity import PerRow, PerTensor
from torchao.quantization.quant_primitives import (
_DTYPE_TO_QVALUE_BOUNDS,
MappingType,
choose_qparams_affine,
)
from torchao.quantization.quantize_.workflows.int8.int8_tensor import (
Int8Tensor,
should_reduce_range,
)
from torchao.quantization.utils import compute_error, get_block_size
from torchao.testing.model_architectures import ToyTwoLinearModel
from torchao.testing.utils import TorchAOIntegrationTestCase
from torchao.utils import (
is_ROCM,
)


@common_utils.instantiate_parametrized_tests
class TestInt8TensorCPU(TorchAOIntegrationTestCase):
# Note: The reduce_range parameter can be manually set by users via the config.
# This UT only tests automatic reduce_range to avoid CI failures on CPUs without VNNI support.
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
@common_utils.parametrize("compile", [True, False])
@common_utils.parametrize("config_mode", ["dynamic", "static"])
@common_utils.parametrize(
"granularity",
[PerRow(), PerTensor(), (PerRow(), PerTensor()), (PerTensor(), PerRow())],
)
@common_utils.parametrize(
"act_mapping_type", [MappingType.SYMMETRIC, MappingType.ASYMMETRIC]
)
def test_int8_tensor_cpu(
self, act_mapping_type, granularity, config_mode, compile, dtype
):
device = "cpu"
if is_ROCM():
self.skipTest("Don't test CPU for ROCM version of torch")

torch.compiler.reset()

M, N, K = 64, 256, 256
input_tensor = torch.randn(M, K, dtype=dtype, device=device)
model = ToyTwoLinearModel(K, N, K, dtype=dtype, device=device).eval()
model_q = copy.deepcopy(model)
reduce_range = should_reduce_range(input_tensor.device)
Contributor:
What's the effect when reduce_range is not set correctly for CPU? Should we test that as well?

Contributor Author:
Thanks for the comments. On CPUs without VNNI, using torch.compile with int8 matmul can degrade accuracy relative to eager mode due to overflow in oneDNN qlinear; the reduce_range setting is designed to prevent this. We don't need to test incorrect settings because this UT focuses on the recommended path, and the helper function should_reduce_range() together with the UT note above already makes that clear.
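
For illustration, a minimal sketch of that recommended path, assembled from this test (the Sequential model and shapes here are placeholders):

```python
import torch

from torchao.quantization import Int8DynamicActivationInt8WeightConfig, quantize_
from torchao.quantization.granularity import PerRow
from torchao.quantization.quantize_.workflows.int8.int8_tensor import should_reduce_range

model = torch.nn.Sequential(torch.nn.Linear(256, 256)).eval()
example_input = torch.randn(64, 256)

# Let the helper decide whether reduced int8 ranges are needed
# (presumably True on CPUs without AVX512_VNNI, where oneDNN qlinear
# can otherwise overflow under torch.compile).
config = Int8DynamicActivationInt8WeightConfig(
    version=2,
    granularity=PerRow(),
    reduce_range=should_reduce_range(example_input.device),
)
quantize_(model, config)
model = torch.compile(model, fullgraph=True)
output = model(example_input)
```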


if config_mode == "dynamic":
config = Int8DynamicActivationInt8WeightConfig(
version=2,
granularity=granularity,
act_mapping_type=act_mapping_type,
reduce_range=reduce_range,
)
else:
act_granularity, _ = Int8Tensor._normalize_granularity(granularity)
Contributor:
nit: add assert for config_mode == "static"

quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[torch.int8]
if reduce_range:
quant_min, quant_max = quant_min // 2, quant_max // 2
block_size = get_block_size(input_tensor.shape, act_granularity)
act_quant_scale, act_quant_zero_point = choose_qparams_affine(
input=input_tensor,
mapping_type=act_mapping_type,
block_size=block_size,
target_dtype=torch.int8,
quant_min=quant_min,
quant_max=quant_max,
scale_dtype=dtype,
zero_point_dtype=torch.int8,
keepdim=True,
eps=torch.finfo(torch.float32).eps,
)
config = Int8StaticActivationInt8WeightConfig(
act_quant_scale=act_quant_scale,
act_quant_zero_point=act_quant_zero_point,
granularity=granularity,
act_mapping_type=act_mapping_type,
reduce_range=reduce_range,
)

quantize_(model_q, config)

_, weight_granularity = Int8Tensor._normalize_granularity(config.granularity)
if isinstance(weight_granularity, PerRow):
self.assertEqual(model_q.linear2.weight.scale.shape, (K, 1))
elif isinstance(weight_granularity, PerTensor):
self.assertEqual(model_q.linear2.weight.scale.shape, (1, 1))

self.assertEqual(model_q.linear2.weight.scale.ndim, 2)

if compile:
model_q = torch.compile(model_q, fullgraph=True)

output_fp = model(input_tensor)
output_quantized = model_q(input_tensor)

sqnr = compute_error(output_fp, output_quantized)
assert sqnr > 20, f"Quantization error is too high; got an SQNR of {sqnr}"


if __name__ == "__main__":
common_utils.run_tests()
19 changes: 5 additions & 14 deletions torchao/kernel/intmm.py
@@ -10,7 +10,11 @@
from torch._dynamo import is_compiling as dynamo_is_compiling
from torch._higher_order_ops.out_dtype import out_dtype

from torchao.utils import _is_device, torch_version_at_least
from torchao.utils import (
_cpu_is_vnni_supported,
_is_device,
torch_version_at_least,
)

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
@@ -118,19 +122,6 @@ def _cpu_is_amx_tile_supported() -> bool:
return False


def _cpu_is_vnni_supported() -> bool:
"""
Safely query AVX512_VNNI support, guarding against private API absence.
torch.cpu._is_vnni_supported / torch._C._cpu._is_vnni_supported are
private and may be missing in certain PyTorch builds or versions.
"""
if hasattr(torch._C._cpu, "_is_vnni_supported"):
return torch._C._cpu._is_vnni_supported()
elif hasattr(torch.cpu, "_is_vnni_supported"):
return torch.cpu._is_vnni_supported()
return False


def _int_scaled_matmul_cpu(
a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor
) -> torch.Tensor:
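
Since the VNNI check now lives in torchao.utils, the should_reduce_range() helper used by the new test is presumably a thin wrapper over it. A hedged sketch of that relationship, not the actual source:

```python
import torch

from torchao.utils import _cpu_is_vnni_supported


def should_reduce_range(device: torch.device) -> bool:
    # Sketch: reduced int8 ranges only matter on CPUs lacking AVX512_VNNI,
    # where oneDNN qlinear can overflow; other devices keep the full ranges.
    return device.type == "cpu" and not _cpu_is_vnni_supported()
```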
18 changes: 18 additions & 0 deletions torchao/quantization/quant_api.py
@@ -843,6 +843,9 @@ class Int8DynamicActivationInt8WeightConfig(AOBaseConfig):
SYMMETRIC and ASYMMETRIC are supported.
set_inductor_config: bool = True - If True, adjusts `torchinductor` settings to recommended values
for better performance with this quantization scheme.
version (int): the version of the config
reduce_range (Optional[bool] = False): If True, use reduced activation and weight quantization ranges
to avoid overflow on CPUs without VNNI. Users can call should_reduce_range() to determine the appropriate value.

Example:

@@ -857,6 +860,7 @@ class Int8DynamicActivationInt8WeightConfig(AOBaseConfig):
] = PerRow()
set_inductor_config: bool = True
version: int = 2
reduce_range: Optional[bool] = False

def __post_init__(self):
torch._C._log_api_usage_once(
@@ -881,6 +885,9 @@ def __post_init__(self):
"Please set it to MappingType.SYMMETRIC or "
"MappingType.ASYMMETRIC."
)
assert self.reduce_range in (True, False, None), (
"`reduce_range` must be True, False, or None"
)


def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config):
@@ -895,7 +902,9 @@ def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config):
act_quant_kwargs=QuantizeTensorToInt8Kwargs(
granularity=act_granularity,
mapping_type=config.act_mapping_type,
reduce_range=config.reduce_range,
),
reduce_range=config.reduce_range,
)

return quantized_weight
@@ -949,6 +958,8 @@ class Int8StaticActivationInt8WeightConfig(AOBaseConfig):
act_mapping_type (MappingType): The mapping type for activation quantization. SYMMETRIC and ASYMMETRIC are supported.
set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
version (int): the version of the config
reduce_range (Optional[bool] = False): If True, use reduced activation and weight quantization ranges
to avoid overflow on CPUs without VNNI. Users can call should_reduce_range() to determine the appropriate value.
"""

act_quant_scale: Optional[torch.Tensor] = None
@@ -959,6 +970,7 @@
act_mapping_type: Optional[MappingType] = MappingType.SYMMETRIC
set_inductor_config: bool = True
version: int = 1
reduce_range: Optional[bool] = False

def __post_init__(self):
torch._C._log_api_usage_once(
@@ -978,6 +990,9 @@ def __post_init__(self):
"Please set it to MappingType.SYMMETRIC or "
"MappingType.ASYMMETRIC."
)
assert self.reduce_range in (True, False, None), (
"`reduce_range` must be True, False, or None"
)

def get_act_quant_kwargs(self) -> QuantizeTensorToInt8Kwargs:
"""Get the activation quantization kwargs for static quantization.
@@ -992,6 +1007,7 @@ def get_act_quant_kwargs(self) -> QuantizeTensorToInt8Kwargs:
return QuantizeTensorToInt8Kwargs(
granularity=act_granularity,
mapping_type=self.act_mapping_type,
reduce_range=self.reduce_range,
)


@@ -1022,9 +1038,11 @@ def _int8_static_activation_int8_weight_transform(
act_quant_kwargs=QuantizeTensorToInt8Kwargs(
granularity=activation_granularity,
mapping_type=config.act_mapping_type,
reduce_range=config.reduce_range,
),
act_quant_scale=config.act_quant_scale.detach(),
act_quant_zero_point=act_quant_zero_point,
reduce_range=config.reduce_range,
)

setattr(
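
For the static path, the activation scale and zero point are precomputed from calibration data, exactly as the new test does. A condensed sketch (the calibration batch and model are placeholders; the qparams logic mirrors the test above):

```python
import torch

from torchao.quantization import Int8StaticActivationInt8WeightConfig, quantize_
from torchao.quantization.granularity import PerTensor
from torchao.quantization.quant_primitives import (
    _DTYPE_TO_QVALUE_BOUNDS,
    MappingType,
    choose_qparams_affine,
)
from torchao.quantization.quantize_.workflows.int8.int8_tensor import should_reduce_range
from torchao.quantization.utils import get_block_size

calib = torch.randn(64, 256)  # placeholder calibration batch
reduce_range = should_reduce_range(calib.device)

# Halve the int8 bounds when reduced ranges are requested, as the test does.
quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[torch.int8]
if reduce_range:
    quant_min, quant_max = quant_min // 2, quant_max // 2

act_quant_scale, act_quant_zero_point = choose_qparams_affine(
    input=calib,
    mapping_type=MappingType.SYMMETRIC,
    block_size=get_block_size(calib.shape, PerTensor()),
    target_dtype=torch.int8,
    quant_min=quant_min,
    quant_max=quant_max,
    scale_dtype=calib.dtype,
    zero_point_dtype=torch.int8,
    keepdim=True,
    eps=torch.finfo(torch.float32).eps,
)

config = Int8StaticActivationInt8WeightConfig(
    act_quant_scale=act_quant_scale,
    act_quant_zero_point=act_quant_zero_point,
    granularity=PerTensor(),
    reduce_range=reduce_range,
)
model = torch.nn.Sequential(torch.nn.Linear(256, 256)).eval()
quantize_(model, config)
```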
@@ -65,6 +65,7 @@ def _choose_quant_func_and_quantize_tensor(
mapping_type=quant_kwargs.mapping_type,
scale=scale,
zero_point=zero_point,
reduce_range=quant_kwargs.reduce_range,
)

raise NotImplementedError(f"Quant kwargs not supported: {quant_kwargs}")