Skip to content
92 changes: 91 additions & 1 deletion test/quantization/quantize_/workflows/int8/test_int8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,24 @@
from torch.testing import FileCheck
from torch.testing._internal import common_utils

from torchao.kernel.intmm import _cpu_is_vnni_supported
Comment thread
Xia-Weiwen marked this conversation as resolved.
Outdated
from torchao.quantization import (
Int8DynamicActivationInt8WeightConfig,
Int8StaticActivationInt8WeightConfig,
Int8WeightOnlyConfig,
quantize_,
)
from torchao.quantization.granularity import PerGroup, PerRow, PerTensor
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.quant_primitives import MappingType, choose_qparams_affine
from torchao.quantization.quantize_.common import (
_choose_quant_func_and_quantize_tensor,
)
from torchao.quantization.quantize_.workflows.int8.int8_tensor import (
_FULL_QUANT_MAX,
_FULL_QUANT_MIN,
_REDUCED_QUANT_MAX,
_REDUCED_QUANT_MIN,
Int8Tensor,
QuantizeTensorToInt8Kwargs,
)
from torchao.quantization.utils import compute_error, get_block_size
Expand Down Expand Up @@ -515,5 +521,89 @@ def test_int8_weight_only_v2_correct_eps(self, dtype, device):
self.assertGreater(sqnr, 40, f"SQNR too low: {sqnr}")


@common_utils.instantiate_parametrized_tests
class TestInt8TensorCPU(TorchAOIntegrationTestCase):
    """CPU tests for v2 int8 quantization (dynamic and static activation)."""

    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
    @common_utils.parametrize("compile", [True, False])
    @common_utils.parametrize("config_mode", ["dynamic", "static"])
    @common_utils.parametrize(
        "granularity",
        [PerRow(), PerTensor(), (PerRow(), PerTensor()), (PerTensor(), PerRow())],
    )
    @common_utils.parametrize(
        "act_mapping_type", [MappingType.SYMMETRIC, MappingType.ASYMMETRIC]
    )
    @common_utils.parametrize("reduce_range", [False, True])
    def test_int8_tensor_cpu(
        self, act_mapping_type, granularity, reduce_range, config_mode, compile, dtype
    ):
        """Quantize a two-linear toy model on CPU and check weight-scale
        shapes and (where the platform allows) output accuracy.

        Covers both dynamic and static activation quantization, symmetric and
        asymmetric activation mapping, mixed act/weight granularities, eager
        and compiled execution, and the reduce_range flag.
        """
        device = "cpu"
        if is_ROCM():
            self.skipTest("Don't test CPU for ROCM version of torch")

        torch.compiler.reset()

        M, N, K = 64, 256, 256
        input_tensor = torch.randn(M, K, dtype=dtype, device=device)
        model = ToyTwoLinearModel(K, N, K, dtype=dtype, device=device).eval()
        model_q = copy.deepcopy(model)

        if config_mode == "dynamic":
            config = Int8DynamicActivationInt8WeightConfig(
                version=2,
                granularity=granularity,
                act_mapping_type=act_mapping_type,
                reduce_range=reduce_range,
            )
        else:
            # Static quantization needs activation qparams computed up front
            # from a representative input, using the same quant range the
            # config will use (reduced range avoids overflow without VNNI).
            act_granularity, _ = Int8Tensor._normalize_granularity(granularity)
            quant_min, quant_max = (
                (_REDUCED_QUANT_MIN, _REDUCED_QUANT_MAX)
                if reduce_range
                else (_FULL_QUANT_MIN, _FULL_QUANT_MAX)
            )
            block_size = get_block_size(input_tensor.shape, act_granularity)
            act_quant_scale, act_quant_zero_point = choose_qparams_affine(
                input=input_tensor,
                mapping_type=act_mapping_type,
                block_size=block_size,
                target_dtype=torch.int8,
                quant_min=quant_min,
                quant_max=quant_max,
                scale_dtype=dtype,
                zero_point_dtype=torch.int8,
                keepdim=True,
                eps=torch.finfo(torch.float32).eps,
            )
            config = Int8StaticActivationInt8WeightConfig(
                act_quant_scale=act_quant_scale,
                act_quant_zero_point=act_quant_zero_point,
                granularity=granularity,
                act_mapping_type=act_mapping_type,
                reduce_range=reduce_range,
            )

        quantize_(model_q, config)

        # The weight scale shape must follow the *weight* granularity:
        # per-row keeps one scale per output channel, per-tensor keeps one.
        _, weight_granularity = Int8Tensor._normalize_granularity(config.granularity)
        if isinstance(weight_granularity, PerRow):
            self.assertEqual(model_q.linear2.weight.scale.shape, (K, 1))
        elif isinstance(weight_granularity, PerTensor):
            self.assertEqual(model_q.linear2.weight.scale.shape, (1, 1))

        self.assertEqual(model_q.linear2.weight.scale.ndim, 2)

        if compile:
            model_q = torch.compile(model_q, fullgraph=True)

        output_fp = model(input_tensor)
        output_quantized = model_q(input_tensor)

        # Without VNNI, full-range int8 matmul can overflow on x86, so only
        # assert accuracy when reduce_range is set or VNNI is available.
        if reduce_range or _cpu_is_vnni_supported():
            # Compute the SQNR once and reuse it in the failure message
            # (mirrors the pattern used by the other tests in this file).
            sqnr = compute_error(output_fp, output_quantized)
            self.assertGreater(
                sqnr, 20, f"Quantization error is too high, got a SQNR of {sqnr}"
            )


if __name__ == "__main__":
    # Use torch's common_utils runner so parametrized tests are expanded.
    common_utils.run_tests()
20 changes: 20 additions & 0 deletions torchao/quantization/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,6 +931,10 @@ class Int8DynamicActivationInt8WeightConfig(AOBaseConfig):
SYMMETRIC and ASYMMETRIC are supported.
set_inductor_config: bool = True - If True, adjusts `torchinductor` settings to recommended values
for better performance with this quantization scheme.
version (int): the version of the config
reduce_range (Optional[bool] = False): If True, both activation and weight int8 quantization use reduced range
[_REDUCED_QUANT_MIN, _REDUCED_QUANT_MAX] instead of full range
[_FULL_QUANT_MIN, _FULL_QUANT_MAX] to reduce overflow risk on platforms without VNNI instructions.

Example:

Expand All @@ -945,6 +949,7 @@ class Int8DynamicActivationInt8WeightConfig(AOBaseConfig):
] = PerRow()
set_inductor_config: bool = True
version: int = 2
reduce_range: Optional[bool] = False

def __post_init__(self):
torch._C._log_api_usage_once(
Expand All @@ -969,6 +974,9 @@ def __post_init__(self):
"Please set it to MappingType.SYMMETRIC or "
"MappingType.ASYMMETRIC."
)
assert self.reduce_range in (True, False), (
"`reduce_range` must be True or False"
)


def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config):
Expand All @@ -983,7 +991,9 @@ def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config):
act_quant_kwargs=QuantizeTensorToInt8Kwargs(
granularity=act_granularity,
mapping_type=config.act_mapping_type,
reduce_range=config.reduce_range,
),
reduce_range=config.reduce_range,
)

return quantized_weight
Expand Down Expand Up @@ -1037,6 +1047,9 @@ class Int8StaticActivationInt8WeightConfig(AOBaseConfig):
act_mapping_type (MappingType): The mapping type for activation quantization. SYMMETRIC and ASYMMETRIC are supported.
set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
version (int): the version of the config
reduce_range (Optional[bool] = False): If True, both activation and weight int8 quantization use reduced range
[_REDUCED_QUANT_MIN, _REDUCED_QUANT_MAX] instead of full range
[_FULL_QUANT_MIN, _FULL_QUANT_MAX] to reduce overflow risk on platforms without VNNI instructions.
"""

act_quant_scale: Optional[torch.Tensor] = None
Expand All @@ -1047,6 +1060,7 @@ class Int8StaticActivationInt8WeightConfig(AOBaseConfig):
act_mapping_type: Optional[MappingType] = MappingType.SYMMETRIC
set_inductor_config: bool = True
version: int = 1
reduce_range: Optional[bool] = False

def __post_init__(self):
torch._C._log_api_usage_once(
Expand All @@ -1066,6 +1080,9 @@ def __post_init__(self):
"Please set it to MappingType.SYMMETRIC or "
"MappingType.ASYMMETRIC."
)
assert self.reduce_range in (True, False), (
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here

"`reduce_range` must be True or False"
)

def get_act_quant_kwargs(self) -> QuantizeTensorToInt8Kwargs:
"""Get the activation quantization kwargs for static quantization.
Expand All @@ -1080,6 +1097,7 @@ def get_act_quant_kwargs(self) -> QuantizeTensorToInt8Kwargs:
return QuantizeTensorToInt8Kwargs(
granularity=act_granularity,
mapping_type=self.act_mapping_type,
reduce_range=self.reduce_range,
)


Expand Down Expand Up @@ -1110,9 +1128,11 @@ def _int8_static_activation_int8_weight_transform(
act_quant_kwargs=QuantizeTensorToInt8Kwargs(
granularity=activation_granularity,
mapping_type=config.act_mapping_type,
reduce_range=config.reduce_range,
),
act_quant_scale=config.act_quant_scale.detach(),
act_quant_zero_point=act_quant_zero_point,
reduce_range=config.reduce_range,
)

setattr(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def _choose_quant_func_and_quantize_tensor(
mapping_type=quant_kwargs.mapping_type,
scale=scale,
zero_point=zero_point,
reduce_range=quant_kwargs.reduce_range,
)

raise NotImplementedError(f"Quant kwargs not supported: {quant_kwargs}")
2 changes: 0 additions & 2 deletions torchao/quantization/quantize_/workflows/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,6 @@
"Sparse2x4CUTLASSFloat8Tensor",
"Float8PackingFormat",
"QuantizeTensorToFloat8Kwargs",
"Int8Tensor",
"QuantizeTensorToInt8Kwargs",
"Int4ChooseQParamsAlgorithm",
"Int4PackingFormat",
"IntxChooseQParamsAlgorithm",
Expand Down
Loading