-
Notifications
You must be signed in to change notification settings - Fork 501
Add reduce_range to avoid overflow in int8 tensor #4266
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
dd390af
2adbdd9
e087a60
3c0c12b
9fd46cd
4e825f1
716cca8
8a5840d
741115a
6e8eb08
6eaf74b
6cdd804
232882c
ea96be6
c4d095c
4877828
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,18 +12,20 @@ | |
| from torch.testing import FileCheck | ||
| from torch.testing._internal import common_utils | ||
|
|
||
| from torchao.kernel.intmm import _cpu_is_vnni_supported | ||
| from torchao.quantization import ( | ||
| Int8DynamicActivationInt8WeightConfig, | ||
| Int8StaticActivationInt8WeightConfig, | ||
| Int8WeightOnlyConfig, | ||
| quantize_, | ||
| ) | ||
| from torchao.quantization.granularity import PerGroup, PerRow, PerTensor | ||
| from torchao.quantization.quant_primitives import MappingType | ||
| from torchao.quantization.quant_primitives import MappingType, choose_qparams_affine | ||
| from torchao.quantization.quantize_.common import ( | ||
| _choose_quant_func_and_quantize_tensor, | ||
| ) | ||
| from torchao.quantization.quantize_.workflows.int8.int8_tensor import ( | ||
| Int8Tensor, | ||
| QuantizeTensorToInt8Kwargs, | ||
| ) | ||
| from torchao.quantization.utils import compute_error, get_block_size | ||
|
|
@@ -515,5 +517,87 @@ def test_int8_weight_only_v2_correct_eps(self, dtype, device): | |
| self.assertGreater(sqnr, 40, f"SQNR too low: {sqnr}") | ||
|
|
||
|
|
||
| @common_utils.instantiate_parametrized_tests | ||
| class TestInt8TensorCPU(TorchAOIntegrationTestCase): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: maybe create a new
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thanks for the comments, updated. |
||
| @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32]) | ||
| @common_utils.parametrize("compile", [True, False]) | ||
| @common_utils.parametrize("config_mode", ["dynamic", "static"]) | ||
| @common_utils.parametrize( | ||
| "granularity", | ||
| [PerRow(), PerTensor(), (PerRow(), PerTensor()), (PerTensor(), PerRow())], | ||
| ) | ||
| @common_utils.parametrize( | ||
| "act_mapping_type", [MappingType.SYMMETRIC, MappingType.ASYMMETRIC] | ||
| ) | ||
| @common_utils.parametrize("reduce_range", [False, True]) | ||
| def test_int8_tensor_cpu( | ||
| self, act_mapping_type, granularity, reduce_range, config_mode, compile, dtype | ||
| ): | ||
| device = "cpu" | ||
| if is_ROCM(): | ||
| self.skipTest("Don't test CPU for ROCM version of torch") | ||
| if not reduce_range and not _cpu_is_vnni_supported(): | ||
| self.skipTest( | ||
| "Only test reduce_range=True on CPUs without VNNI support to avoid int8 overflow." | ||
| ) | ||
|
|
||
| torch.compiler.reset() | ||
|
|
||
| M, N, K = 64, 256, 256 | ||
| input_tensor = torch.randn(M, K, dtype=dtype, device=device) | ||
| model = ToyTwoLinearModel(K, N, K, dtype=dtype, device=device).eval() | ||
| model_q = copy.deepcopy(model) | ||
|
|
||
| if config_mode == "dynamic": | ||
| config = Int8DynamicActivationInt8WeightConfig( | ||
| version=2, | ||
| granularity=granularity, | ||
| act_mapping_type=act_mapping_type, | ||
| reduce_range=reduce_range, | ||
| ) | ||
| else: | ||
| act_granularity, _ = Int8Tensor._normalize_granularity(granularity) | ||
| quant_min, quant_max = (-64, 63) if reduce_range else (-128, 127) | ||
| block_size = get_block_size(input_tensor.shape, act_granularity) | ||
| act_quant_scale, act_quant_zero_point = choose_qparams_affine( | ||
| input=input_tensor, | ||
| mapping_type=act_mapping_type, | ||
| block_size=block_size, | ||
| target_dtype=torch.int8, | ||
| quant_min=quant_min, | ||
| quant_max=quant_max, | ||
| scale_dtype=dtype, | ||
| zero_point_dtype=torch.int8, | ||
| keepdim=True, | ||
| eps=torch.finfo(torch.float32).eps, | ||
| ) | ||
| config = Int8StaticActivationInt8WeightConfig( | ||
| act_quant_scale=act_quant_scale, | ||
| act_quant_zero_point=act_quant_zero_point, | ||
| granularity=granularity, | ||
| act_mapping_type=act_mapping_type, | ||
| reduce_range=reduce_range, | ||
| ) | ||
|
|
||
| quantize_(model_q, config) | ||
|
|
||
| _, weight_granularity = Int8Tensor._normalize_granularity(config.granularity) | ||
| if isinstance(weight_granularity, PerRow): | ||
| self.assertEqual(model_q.linear2.weight.scale.shape, (K, 1)) | ||
| elif isinstance(weight_granularity, PerTensor): | ||
| self.assertEqual(model_q.linear2.weight.scale.shape, (1, 1)) | ||
|
|
||
| self.assertEqual(model_q.linear2.weight.scale.ndim, 2) | ||
|
|
||
| if compile: | ||
| model_q = torch.compile(model_q, fullgraph=True) | ||
|
|
||
| output_fp = model(input_tensor) | ||
| output_quantized = model_q(input_tensor) | ||
| assert compute_error(output_fp, output_quantized) > 20, ( | ||
| f"Quantization error is too high got a SQNR of {compute_error(output_fp, output_quantized)}" | ||
| ) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| common_utils.run_tests() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -931,6 +931,9 @@ class Int8DynamicActivationInt8WeightConfig(AOBaseConfig): | |
| SYMMETRIC and ASYMMETRIC are supported. | ||
| set_inductor_config: bool = True - If True, adjusts `torchinductor` settings to recommended values | ||
| for better performance with this quantization scheme. | ||
| version (int): the version of the config | ||
| reduce_range (Optional[bool] = False): If True, both activation and weight int8 quantization use reduced range | ||
| [-64, 63] instead of full range [-128, 127] to reduce overflow risk on platforms without VNNI instructions. | ||
|
|
||
| Example: | ||
|
|
||
|
|
@@ -945,6 +948,7 @@ class Int8DynamicActivationInt8WeightConfig(AOBaseConfig): | |
| ] = PerRow() | ||
| set_inductor_config: bool = True | ||
| version: int = 2 | ||
| reduce_range: Optional[bool] = False | ||
|
|
||
| def __post_init__(self): | ||
| torch._C._log_api_usage_once( | ||
|
|
@@ -969,6 +973,9 @@ def __post_init__(self): | |
| "Please set it to MappingType.SYMMETRIC or " | ||
| "MappingType.ASYMMETRIC." | ||
| ) | ||
| assert self.reduce_range in (True, False), ( | ||
| "`reduce_range` must be True or False" | ||
| ) | ||
|
|
||
|
|
||
| def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config): | ||
|
|
@@ -983,7 +990,9 @@ def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config): | |
| act_quant_kwargs=QuantizeTensorToInt8Kwargs( | ||
| granularity=act_granularity, | ||
| mapping_type=config.act_mapping_type, | ||
| reduce_range=config.reduce_range, | ||
| ), | ||
| reduce_range=config.reduce_range, | ||
| ) | ||
|
|
||
| return quantized_weight | ||
|
|
@@ -1037,6 +1046,8 @@ class Int8StaticActivationInt8WeightConfig(AOBaseConfig): | |
| act_mapping_type (MappingType): The mapping type for activation quantization. SYMMETRIC and ASYMMETRIC are supported. | ||
| set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values. | ||
| version (int): the version of the config | ||
| reduce_range (Optional[bool] = False): If True, both activation and weight int8 quantization use reduced range | ||
| [-64, 63] instead of full range [-128, 127] to reduce overflow risk on platforms without VNNI instructions. | ||
| """ | ||
|
|
||
| act_quant_scale: Optional[torch.Tensor] = None | ||
|
|
@@ -1047,6 +1058,7 @@ class Int8StaticActivationInt8WeightConfig(AOBaseConfig): | |
| act_mapping_type: Optional[MappingType] = MappingType.SYMMETRIC | ||
| set_inductor_config: bool = True | ||
| version: int = 1 | ||
| reduce_range: Optional[bool] = False | ||
|
|
||
| def __post_init__(self): | ||
| torch._C._log_api_usage_once( | ||
|
|
@@ -1066,6 +1078,9 @@ def __post_init__(self): | |
| "Please set it to MappingType.SYMMETRIC or " | ||
| "MappingType.ASYMMETRIC." | ||
| ) | ||
| assert self.reduce_range in (True, False), ( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. same here |
||
| "`reduce_range` must be True or False" | ||
| ) | ||
|
|
||
| def get_act_quant_kwargs(self) -> QuantizeTensorToInt8Kwargs: | ||
| """Get the activation quantization kwargs for static quantization. | ||
|
|
@@ -1080,6 +1095,7 @@ def get_act_quant_kwargs(self) -> QuantizeTensorToInt8Kwargs: | |
| return QuantizeTensorToInt8Kwargs( | ||
| granularity=act_granularity, | ||
| mapping_type=self.act_mapping_type, | ||
| reduce_range=self.reduce_range, | ||
| ) | ||
|
|
||
|
|
||
|
|
@@ -1110,9 +1126,11 @@ def _int8_static_activation_int8_weight_transform( | |
| act_quant_kwargs=QuantizeTensorToInt8Kwargs( | ||
| granularity=activation_granularity, | ||
| mapping_type=config.act_mapping_type, | ||
| reduce_range=config.reduce_range, | ||
| ), | ||
| act_quant_scale=config.act_quant_scale.detach(), | ||
| act_quant_zero_point=act_quant_zero_point, | ||
| reduce_range=config.reduce_range, | ||
| ) | ||
|
|
||
| setattr( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -43,10 +43,15 @@ class QuantizeTensorToInt8Kwargs(QuantizeTensorKwargs): | |
| Args: | ||
| granularity: the granularity for the Tensor, currently either PerRow() or PerTensor() | ||
| mapping_type: whether to use symmetric or asymmetric quant | ||
| reduce_range: optional flag. If True, use reduced int8 range [-64, 63] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: we should be consistent on using numbers vs.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. OK to use variables, but it might also be good to mention what these values are in the docstring as well.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. As commented in #3784 (comment), I'm wondering if we can just set it automatically and print out a warning, instead of leaving this to the user?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thanks for the comments, updated. |
||
| instead of full range [-128, 127] to reduce overflow risk on | ||
| platforms without VNNI instructions. Kept optional for backward | ||
| compatibility with older call sites and serialized configs. | ||
| """ | ||
|
|
||
| granularity: Granularity | ||
| mapping_type: MappingType = MappingType.SYMMETRIC | ||
| reduce_range: Optional[bool] = False | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you add a comment on why this is Optional, and whether we can remove it in the future? I don't think we need to guarantee BC for old checkpoints, btw.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thanks for the comments, removed. |
||
|
|
||
|
|
||
| class Int8Tensor(TorchAOBaseTensor): | ||
|
|
@@ -61,6 +66,7 @@ class Int8Tensor(TorchAOBaseTensor): | |
| Non-Tensor Attributes: | ||
| granularity: the granularity for quantization (e.g., PerRow(), PerTensor()) | ||
| act_quant_kwargs: flags for dynamic activation quantization | ||
| reduce_range: optional flag for reduced int8 quantization range | ||
| """ | ||
|
|
||
| tensor_data_names = ["qdata", "scale"] | ||
|
|
@@ -73,6 +79,7 @@ class Int8Tensor(TorchAOBaseTensor): | |
| tensor_attribute_names = ["block_size", "dtype"] | ||
| optional_tensor_attribute_names = [ | ||
| "act_quant_kwargs", | ||
| "reduce_range", | ||
| ] | ||
|
|
||
| def __new__( | ||
|
|
@@ -86,6 +93,7 @@ def __new__( | |
| act_quant_zero_point: Optional[torch.Tensor] = None, | ||
| act_pre_scale: Optional[torch.Tensor] = None, | ||
| act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None, | ||
| reduce_range: Optional[bool] = False, | ||
| ): | ||
| kwargs = { | ||
| "device": qdata.device, | ||
|
|
@@ -105,6 +113,7 @@ def __init__( | |
| act_quant_zero_point: Optional[torch.Tensor] = None, | ||
| act_pre_scale: Optional[torch.Tensor] = None, | ||
| act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None, | ||
| reduce_range: Optional[bool] = False, | ||
| ): | ||
| super().__init__() | ||
| self.qdata = qdata | ||
|
|
@@ -116,6 +125,7 @@ def __init__( | |
| self.act_quant_scale = act_quant_scale | ||
| self.act_quant_zero_point = act_quant_zero_point | ||
| self.act_pre_scale = act_pre_scale | ||
| self.reduce_range = reduce_range | ||
|
|
||
| def __repr__(self): | ||
| return ( | ||
|
|
@@ -130,7 +140,8 @@ def __repr__(self): | |
| f"block_size={self.block_size}, " | ||
| f"shape={self.shape}, " | ||
| f"device={self.device}, " | ||
| f"dtype={self.dtype})" | ||
| f"dtype={self.dtype}, " | ||
| f"reduce_range={self.reduce_range})" | ||
| ) | ||
|
|
||
| @classmethod | ||
|
|
@@ -173,19 +184,21 @@ def from_hp( | |
| act_quant_scale: Optional[torch.Tensor] = None, | ||
| act_quant_zero_point: Optional[torch.Tensor] = None, | ||
| act_pre_scale: Optional[torch.Tensor] = None, | ||
| reduce_range: Optional[bool] = False, | ||
| ): | ||
| """Create Int8Tensor from high-precision tensor""" | ||
| block_size = get_block_size(hp_tensor.shape, granularity) | ||
| block_size = list(block_size) | ||
|
|
||
| quant_min, quant_max = (-64, 63) if reduce_range else (-128, 127) | ||
| if scale is None: | ||
| scale, zero_point = choose_qparams_affine( | ||
| input=hp_tensor, | ||
| mapping_type=mapping_type, | ||
| block_size=block_size, | ||
| target_dtype=torch.int8, | ||
| quant_min=-128, | ||
| quant_max=127, | ||
| quant_min=quant_min, | ||
| quant_max=quant_max, | ||
| scale_dtype=hp_tensor.dtype, | ||
| zero_point_dtype=torch.int8, | ||
| keepdim=True, | ||
|
|
@@ -208,6 +221,8 @@ def from_hp( | |
| scale=scale, | ||
| zero_point=zero_point, | ||
| output_dtype=torch.int8, | ||
| quant_min=quant_min, | ||
| quant_max=quant_max, | ||
| ) | ||
|
|
||
| if mapping_type == MappingType.ASYMMETRIC: | ||
|
|
@@ -225,18 +240,20 @@ def from_hp( | |
| act_quant_zero_point=act_quant_zero_point, | ||
| act_pre_scale=act_pre_scale, | ||
| act_quant_kwargs=act_quant_kwargs, | ||
| reduce_range=reduce_range, | ||
| ) | ||
|
|
||
| def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: | ||
| """Dequantize int8 tensor to floating point""" | ||
| quant_min, quant_max = (-64, 63) if self.reduce_range else (-128, 127) | ||
| return dequantize_affine( | ||
| input=self.qdata, | ||
| block_size=self.block_size, | ||
| scale=self.scale, | ||
| zero_point=self.zero_point, | ||
| input_dtype=torch.int8, | ||
| quant_min=-128, | ||
| quant_max=127, | ||
| quant_min=quant_min, | ||
| quant_max=quant_max, | ||
| output_dtype=output_dtype if output_dtype is not None else self.dtype, | ||
| ) | ||
|
|
||
|
|
@@ -399,6 +416,7 @@ def _(func, types, args, kwargs): | |
| act_quant_scale=self.act_quant_scale, | ||
| act_quant_zero_point=self.act_quant_zero_point, | ||
| act_pre_scale=self.act_pre_scale, | ||
| reduce_range=self.reduce_range, | ||
| ), | ||
| ) | ||
|
|
||
|
|
@@ -450,6 +468,10 @@ def _(func, types, args, kwargs): | |
| if args[0].act_quant_zero_point is not None: | ||
| pinned_act_quant_zero_point = args[0].act_quant_zero_point.pin_memory() | ||
|
|
||
| pinned_act_pre_scale = None | ||
| if args[0].act_pre_scale is not None: | ||
| pinned_act_pre_scale = args[0].act_pre_scale.pin_memory() | ||
|
|
||
| return Int8Tensor( | ||
| pinned_qdata, | ||
| pinned_scale, | ||
|
|
@@ -458,7 +480,9 @@ def _(func, types, args, kwargs): | |
| zero_point=pinned_zero_point, | ||
| act_quant_scale=pinned_act_quant_scale, | ||
| act_quant_zero_point=pinned_act_quant_zero_point, | ||
| act_pre_scale=pinned_act_pre_scale, | ||
| act_quant_kwargs=args[0].act_quant_kwargs, | ||
| reduce_range=args[0].reduce_range, | ||
| ) | ||
|
|
||
|
|
||
|
|
@@ -487,6 +511,7 @@ def _(func, types, args, kwargs): | |
| act_quant_zero_point=old_int8_tensor.act_quant_zero_point, | ||
| act_pre_scale=old_int8_tensor.act_pre_scale, | ||
| act_quant_kwargs=old_int8_tensor.act_quant_kwargs, | ||
| reduce_range=old_int8_tensor.reduce_range, | ||
| ) | ||
| return return_and_correct_aliasing(func, args, kwargs, new_int8_tensor) | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.