86 changes: 85 additions & 1 deletion test/quantization/quantize_/workflows/int8/test_int8_tensor.py
@@ -12,18 +12,20 @@
from torch.testing import FileCheck
from torch.testing._internal import common_utils

from torchao.kernel.intmm import _cpu_is_vnni_supported
from torchao.quantization import (
Int8DynamicActivationInt8WeightConfig,
Int8StaticActivationInt8WeightConfig,
Int8WeightOnlyConfig,
quantize_,
)
from torchao.quantization.granularity import PerGroup, PerRow, PerTensor
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.quant_primitives import MappingType, choose_qparams_affine
from torchao.quantization.quantize_.common import (
_choose_quant_func_and_quantize_tensor,
)
from torchao.quantization.quantize_.workflows.int8.int8_tensor import (
Int8Tensor,
QuantizeTensorToInt8Kwargs,
)
from torchao.quantization.utils import compute_error, get_block_size
@@ -515,5 +517,87 @@ def test_int8_weight_only_v2_correct_eps(self, dtype, device):
self.assertGreater(sqnr, 40, f"SQNR too low: {sqnr}")


@common_utils.instantiate_parametrized_tests
class TestInt8TensorCPU(TorchAOIntegrationTestCase):
Contributor: nit: maybe create a new test_int8_tensor_cpu.py; it will be easier to separate tests by device for CI in the future if we want.

Contributor Author: Thanks for the comments, updated.

@common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
@common_utils.parametrize("compile", [True, False])
@common_utils.parametrize("config_mode", ["dynamic", "static"])
@common_utils.parametrize(
"granularity",
[PerRow(), PerTensor(), (PerRow(), PerTensor()), (PerTensor(), PerRow())],
)
@common_utils.parametrize(
"act_mapping_type", [MappingType.SYMMETRIC, MappingType.ASYMMETRIC]
)
@common_utils.parametrize("reduce_range", [False, True])
def test_int8_tensor_cpu(
self, act_mapping_type, granularity, reduce_range, config_mode, compile, dtype
):
device = "cpu"
if is_ROCM():
self.skipTest("Don't test CPU for ROCM version of torch")
if not reduce_range and not _cpu_is_vnni_supported():
self.skipTest(
"Only test reduce_range=True on CPUs without VNNI support to avoid int8 overflow."
)

torch.compiler.reset()

M, N, K = 64, 256, 256
input_tensor = torch.randn(M, K, dtype=dtype, device=device)
model = ToyTwoLinearModel(K, N, K, dtype=dtype, device=device).eval()
model_q = copy.deepcopy(model)

if config_mode == "dynamic":
config = Int8DynamicActivationInt8WeightConfig(
version=2,
granularity=granularity,
act_mapping_type=act_mapping_type,
reduce_range=reduce_range,
)
else:
act_granularity, _ = Int8Tensor._normalize_granularity(granularity)
quant_min, quant_max = (-64, 63) if reduce_range else (-128, 127)
block_size = get_block_size(input_tensor.shape, act_granularity)
act_quant_scale, act_quant_zero_point = choose_qparams_affine(
input=input_tensor,
mapping_type=act_mapping_type,
block_size=block_size,
target_dtype=torch.int8,
quant_min=quant_min,
quant_max=quant_max,
scale_dtype=dtype,
zero_point_dtype=torch.int8,
keepdim=True,
eps=torch.finfo(torch.float32).eps,
)
config = Int8StaticActivationInt8WeightConfig(
act_quant_scale=act_quant_scale,
act_quant_zero_point=act_quant_zero_point,
granularity=granularity,
act_mapping_type=act_mapping_type,
reduce_range=reduce_range,
)

quantize_(model_q, config)

_, weight_granularity = Int8Tensor._normalize_granularity(config.granularity)
if isinstance(weight_granularity, PerRow):
self.assertEqual(model_q.linear2.weight.scale.shape, (K, 1))
elif isinstance(weight_granularity, PerTensor):
self.assertEqual(model_q.linear2.weight.scale.shape, (1, 1))

self.assertEqual(model_q.linear2.weight.scale.ndim, 2)

if compile:
model_q = torch.compile(model_q, fullgraph=True)

output_fp = model(input_tensor)
output_quantized = model_q(input_tensor)
sqnr = compute_error(output_fp, output_quantized)
assert sqnr > 20, f"Quantization error is too high; got an SQNR of {sqnr}"


if __name__ == "__main__":
common_utils.run_tests()
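As background on why the reduced range matters in this test (an illustrative sketch, not part of the PR): on x86 CPUs without VNNI, the common int8 matmul kernels multiply u8 activations by s8 weights and accumulate adjacent pairs into int16 (pmaddubsw-style), which can overflow at the full quantization range; roughly halving the range keeps the pairwise sum in bounds.

INT16_MAX = 32767
# Full range: products of |u8| <= 255 and |s8| <= 128, two per int16 lane.
full_range_pair = 255 * 128 + 255 * 128  # 65280, overflows int16
# Reduced range [-64, 63] roughly halves one operand's magnitude.
reduced_pair = 255 * 64 + 255 * 64  # 32640, fits
print(full_range_pair > INT16_MAX, reduced_pair <= INT16_MAX)  # True True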
18 changes: 18 additions & 0 deletions torchao/quantization/quant_api.py
@@ -931,6 +931,9 @@ class Int8DynamicActivationInt8WeightConfig(AOBaseConfig):
SYMMETRIC and ASYMMETRIC are supported.
set_inductor_config: bool = True - If True, adjusts `torchinductor` settings to recommended values
for better performance with this quantization scheme.
version (int): the version of the config
reduce_range (Optional[bool] = False): If True, both activation and weight int8 quantization use reduced range
[-64, 63] instead of full range [-128, 127] to reduce overflow risk on platforms without VNNI instructions.

Example:

@@ -945,6 +948,7 @@ class Int8DynamicActivationInt8WeightConfig(AOBaseConfig):
] = PerRow()
set_inductor_config: bool = True
version: int = 2
reduce_range: Optional[bool] = False

def __post_init__(self):
torch._C._log_api_usage_once(
@@ -969,6 +973,9 @@ def __post_init__(self):
"Please set it to MappingType.SYMMETRIC or "
"MappingType.ASYMMETRIC."
)
assert self.reduce_range in (True, False), (
"`reduce_range` must be True or False"
)


def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config):
@@ -983,7 +990,9 @@ def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config):
act_quant_kwargs=QuantizeTensorToInt8Kwargs(
granularity=act_granularity,
mapping_type=config.act_mapping_type,
reduce_range=config.reduce_range,
),
reduce_range=config.reduce_range,
)

return quantized_weight
@@ -1037,6 +1046,8 @@ class Int8StaticActivationInt8WeightConfig(AOBaseConfig):
act_mapping_type (MappingType): The mapping type for activation quantization. SYMMETRIC and ASYMMETRIC are supported.
set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
version (int): the version of the config
reduce_range (Optional[bool] = False): If True, both activation and weight int8 quantization use reduced range
[-64, 63] instead of full range [-128, 127] to reduce overflow risk on platforms without VNNI instructions.
"""

act_quant_scale: Optional[torch.Tensor] = None
Expand All @@ -1047,6 +1058,7 @@ class Int8StaticActivationInt8WeightConfig(AOBaseConfig):
act_mapping_type: Optional[MappingType] = MappingType.SYMMETRIC
set_inductor_config: bool = True
version: int = 1
reduce_range: Optional[bool] = False

def __post_init__(self):
torch._C._log_api_usage_once(
@@ -1066,6 +1078,9 @@ def __post_init__(self):
"Please set it to MappingType.SYMMETRIC or "
"MappingType.ASYMMETRIC."
)
assert self.reduce_range in (True, False), (
"`reduce_range` must be True or False"
)

Contributor: same here

def get_act_quant_kwargs(self) -> QuantizeTensorToInt8Kwargs:
"""Get the activation quantization kwargs for static quantization.
Expand All @@ -1080,6 +1095,7 @@ def get_act_quant_kwargs(self) -> QuantizeTensorToInt8Kwargs:
return QuantizeTensorToInt8Kwargs(
granularity=act_granularity,
mapping_type=self.act_mapping_type,
reduce_range=self.reduce_range,
)


@@ -1110,9 +1126,11 @@ def _int8_static_activation_int8_weight_transform(
act_quant_kwargs=QuantizeTensorToInt8Kwargs(
granularity=activation_granularity,
mapping_type=config.act_mapping_type,
reduce_range=config.reduce_range,
),
act_quant_scale=config.act_quant_scale.detach(),
act_quant_zero_point=act_quant_zero_point,
reduce_range=config.reduce_range,
)

setattr(
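A minimal usage sketch of the flag wired through above (the toy model and shapes are assumptions; the config fields mirror this diff):

import torch
from torchao.quantization import Int8DynamicActivationInt8WeightConfig, quantize_

model = torch.nn.Sequential(torch.nn.Linear(256, 256)).eval()
quantize_(
    model,
    Int8DynamicActivationInt8WeightConfig(
        version=2,
        reduce_range=True,  # quantize both activations and weights to [-64, 63]
    ),
)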
(diff to _choose_quant_func_and_quantize_tensor in torchao.quantization.quantize_.common; file header not captured)
@@ -65,6 +65,7 @@ def _choose_quant_func_and_quantize_tensor(
mapping_type=quant_kwargs.mapping_type,
scale=scale,
zero_point=zero_point,
reduce_range=quant_kwargs.reduce_range,
)

raise NotImplementedError(f"Quant kwargs not supported: {quant_kwargs}")
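A sketch of the flow this helper sits in, assuming the (tensor, quant_kwargs) call shape implied by its use in the tests:

import torch

from torchao.quantization.granularity import PerRow
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.quantize_.common import (
    _choose_quant_func_and_quantize_tensor,
)
from torchao.quantization.quantize_.workflows.int8.int8_tensor import (
    QuantizeTensorToInt8Kwargs,
)

# Dynamic activation quantization builds these kwargs once (see the
# quant_api.py changes above) and re-quantizes each activation with them.
kwargs = QuantizeTensorToInt8Kwargs(
    granularity=PerRow(),
    mapping_type=MappingType.SYMMETRIC,
    reduce_range=True,  # forwarded to Int8Tensor.from_hp
)
x = torch.randn(4, 8)
qx = _choose_quant_func_and_quantize_tensor(x, kwargs)  # assumed signature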
35 changes: 30 additions & 5 deletions torchao/quantization/quantize_/workflows/int8/int8_tensor.py
@@ -43,10 +43,15 @@ class QuantizeTensorToInt8Kwargs(QuantizeTensorKwargs):
Args:
granularity: the granularity for the Tensor, currently either PerRow() or PerTensor()
mapping_type: whether to use symmetric or asymmetric quant
reduce_range: optional flag. If True, use reduced int8 range [-64, 63]
instead of full range [-128, 127] to reduce overflow risk on
platforms without VNNI instructions. Kept optional for backward
compatibility with older call sites and serialized configs.
"""

Contributor: nit: we should be consistent on using numbers vs. _REDUCED_QUANT_MIN etc. I think right now quant_api is using the numbers but here it's using the variables.

Contributor: OK to use variables, but it might also be good to mention what these values are in the docstring as well.

Contributor: as commented in #3784 (comment), I'm wondering if we can just set it automatically and print out a warning, instead of leaving this to the user?

Contributor Author: Thanks for the comments, updated.

granularity: Granularity
mapping_type: MappingType = MappingType.SYMMETRIC
reduce_range: Optional[bool] = False
Contributor: can you add a comment on why Optional, and whether we can remove it in the future? I don't think we need to guarantee BC for old checkpoints, btw.

Contributor Author: Thanks for the comments, removed Optional as it's not needed to guarantee BC for old checkpoints.



class Int8Tensor(TorchAOBaseTensor):
@@ -61,6 +66,7 @@ class Int8Tensor(TorchAOBaseTensor):
Non-Tensor Attributes:
granularity: the granularity for quantization (e.g., PerRow(), PerTensor())
act_quant_kwargs: flags for dynamic activation quantization
reduce_range: optional flag for reduced int8 quantization range
"""

tensor_data_names = ["qdata", "scale"]
Expand All @@ -73,6 +79,7 @@ class Int8Tensor(TorchAOBaseTensor):
tensor_attribute_names = ["block_size", "dtype"]
optional_tensor_attribute_names = [
"act_quant_kwargs",
"reduce_range",
]

def __new__(
@@ -86,6 +93,7 @@ def __new__(
act_quant_zero_point: Optional[torch.Tensor] = None,
act_pre_scale: Optional[torch.Tensor] = None,
act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
reduce_range: Optional[bool] = False,
):
kwargs = {
"device": qdata.device,
@@ -105,6 +113,7 @@ def __init__(
act_quant_zero_point: Optional[torch.Tensor] = None,
act_pre_scale: Optional[torch.Tensor] = None,
act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
reduce_range: Optional[bool] = False,
):
super().__init__()
self.qdata = qdata
Expand All @@ -116,6 +125,7 @@ def __init__(
self.act_quant_scale = act_quant_scale
self.act_quant_zero_point = act_quant_zero_point
self.act_pre_scale = act_pre_scale
self.reduce_range = reduce_range

def __repr__(self):
return (
@@ -130,7 +140,8 @@ def __repr__(self):
f"block_size={self.block_size}, "
f"shape={self.shape}, "
f"device={self.device}, "
f"dtype={self.dtype})"
f"dtype={self.dtype}, "
f"reduce_range={self.reduce_range})"
)

@classmethod
@@ -173,19 +184,21 @@ def from_hp(
act_quant_scale: Optional[torch.Tensor] = None,
act_quant_zero_point: Optional[torch.Tensor] = None,
act_pre_scale: Optional[torch.Tensor] = None,
reduce_range: Optional[bool] = False,
):
"""Create Int8Tensor from high-precision tensor"""
block_size = get_block_size(hp_tensor.shape, granularity)
block_size = list(block_size)

quant_min, quant_max = (-64, 63) if reduce_range else (-128, 127)
if scale is None:
scale, zero_point = choose_qparams_affine(
input=hp_tensor,
mapping_type=mapping_type,
block_size=block_size,
target_dtype=torch.int8,
quant_min=-128,
quant_max=127,
quant_min=quant_min,
quant_max=quant_max,
scale_dtype=hp_tensor.dtype,
zero_point_dtype=torch.int8,
keepdim=True,
@@ -208,6 +221,8 @@ def from_hp(
scale=scale,
zero_point=zero_point,
output_dtype=torch.int8,
quant_min=quant_min,
quant_max=quant_max,
)

if mapping_type == MappingType.ASYMMETRIC:
@@ -225,18 +240,20 @@
act_quant_zero_point=act_quant_zero_point,
act_pre_scale=act_pre_scale,
act_quant_kwargs=act_quant_kwargs,
reduce_range=reduce_range,
)

def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
"""Dequantize int8 tensor to floating point"""
quant_min, quant_max = (-64, 63) if self.reduce_range else (-128, 127)
return dequantize_affine(
input=self.qdata,
block_size=self.block_size,
scale=self.scale,
zero_point=self.zero_point,
input_dtype=torch.int8,
quant_min=-128,
quant_max=127,
quant_min=quant_min,
quant_max=quant_max,
output_dtype=output_dtype if output_dtype is not None else self.dtype,
)

@@ -399,6 +416,7 @@ def _(func, types, args, kwargs):
act_quant_scale=self.act_quant_scale,
act_quant_zero_point=self.act_quant_zero_point,
act_pre_scale=self.act_pre_scale,
reduce_range=self.reduce_range,
),
)

@@ -450,6 +468,10 @@ def _(func, types, args, kwargs):
if args[0].act_quant_zero_point is not None:
pinned_act_quant_zero_point = args[0].act_quant_zero_point.pin_memory()

pinned_act_pre_scale = None
if args[0].act_pre_scale is not None:
pinned_act_pre_scale = args[0].act_pre_scale.pin_memory()

return Int8Tensor(
pinned_qdata,
pinned_scale,
Expand All @@ -458,7 +480,9 @@ def _(func, types, args, kwargs):
zero_point=pinned_zero_point,
act_quant_scale=pinned_act_quant_scale,
act_quant_zero_point=pinned_act_quant_zero_point,
act_pre_scale=pinned_act_pre_scale,
act_quant_kwargs=args[0].act_quant_kwargs,
reduce_range=args[0].reduce_range,
)


@@ -487,6 +511,7 @@ def _(func, types, args, kwargs):
act_quant_zero_point=old_int8_tensor.act_quant_zero_point,
act_pre_scale=old_int8_tensor.act_pre_scale,
act_quant_kwargs=old_int8_tensor.act_quant_kwargs,
reduce_range=old_int8_tensor.reduce_range,
)
return return_and_correct_aliasing(func, args, kwargs, new_int8_tensor)

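A round-trip sketch for the from_hp/dequantize changes above; the positional (hp_tensor, granularity) call shape is an assumption, while the bound check follows directly from the quant_min/quant_max wiring in this diff:

import torch

from torchao.quantization.granularity import PerRow
from torchao.quantization.quantize_.workflows.int8.int8_tensor import Int8Tensor

w = torch.randn(16, 32)
# With reduce_range=True, from_hp chooses qparams for [-64, 63] and
# quantize_affine clamps qdata to the same bounds.
qw = Int8Tensor.from_hp(w, PerRow(), reduce_range=True)  # assumed call shape
assert qw.qdata.min() >= -64 and qw.qdata.max() <= 63
# dequantize reuses the reduced range, so the round trip stays consistent.
w_back = qw.dequantize()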