-
Notifications
You must be signed in to change notification settings - Fork 501
Add reduce_range to avoid overflow in int8 tensor #4266
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
dd390af
2adbdd9
e087a60
3c0c12b
9fd46cd
4e825f1
716cca8
8a5840d
741115a
6e8eb08
6eaf74b
6cdd804
232882c
ea96be6
c4d095c
4877828
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,18 +12,24 @@ | |
| from torch.testing import FileCheck | ||
| from torch.testing._internal import common_utils | ||
|
|
||
| from torchao.kernel.intmm import _cpu_is_vnni_supported | ||
| from torchao.quantization import ( | ||
| Int8DynamicActivationInt8WeightConfig, | ||
| Int8StaticActivationInt8WeightConfig, | ||
| Int8WeightOnlyConfig, | ||
| quantize_, | ||
| ) | ||
| from torchao.quantization.granularity import PerGroup, PerRow, PerTensor | ||
| from torchao.quantization.quant_primitives import MappingType | ||
| from torchao.quantization.quant_primitives import MappingType, choose_qparams_affine | ||
| from torchao.quantization.quantize_.common import ( | ||
| _choose_quant_func_and_quantize_tensor, | ||
| ) | ||
| from torchao.quantization.quantize_.workflows.int8.int8_tensor import ( | ||
| _FULL_QUANT_MAX, | ||
| _FULL_QUANT_MIN, | ||
| _REDUCED_QUANT_MAX, | ||
| _REDUCED_QUANT_MIN, | ||
| Int8Tensor, | ||
| QuantizeTensorToInt8Kwargs, | ||
| ) | ||
| from torchao.quantization.utils import compute_error, get_block_size | ||
|
|
@@ -515,5 +521,89 @@ def test_int8_weight_only_v2_correct_eps(self, dtype, device): | |
| self.assertGreater(sqnr, 40, f"SQNR too low: {sqnr}") | ||
|
|
||
|
|
||
| @common_utils.instantiate_parametrized_tests | ||
| class TestInt8TensorCPU(TorchAOIntegrationTestCase): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: maybe create a new
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for comments, updated. |
||
| @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32]) | ||
| @common_utils.parametrize("compile", [True, False]) | ||
| @common_utils.parametrize("config_mode", ["dynamic", "static"]) | ||
| @common_utils.parametrize( | ||
| "granularity", | ||
| [PerRow(), PerTensor(), (PerRow(), PerTensor()), (PerTensor(), PerRow())], | ||
| ) | ||
| @common_utils.parametrize( | ||
| "act_mapping_type", [MappingType.SYMMETRIC, MappingType.ASYMMETRIC] | ||
| ) | ||
| @common_utils.parametrize("reduce_range", [False, True]) | ||
| def test_int8_tensor_cpu( | ||
| self, act_mapping_type, granularity, reduce_range, config_mode, compile, dtype | ||
| ): | ||
| device = "cpu" | ||
| if is_ROCM(): | ||
| self.skipTest("Don't test CPU for ROCM version of torch") | ||
|
|
||
| torch.compiler.reset() | ||
|
|
||
| M, N, K = 64, 256, 256 | ||
| input_tensor = torch.randn(M, K, dtype=dtype, device=device) | ||
| model = ToyTwoLinearModel(K, N, K, dtype=dtype, device=device).eval() | ||
| model_q = copy.deepcopy(model) | ||
|
|
||
| if config_mode == "dynamic": | ||
| config = Int8DynamicActivationInt8WeightConfig( | ||
| version=2, | ||
| granularity=granularity, | ||
| act_mapping_type=act_mapping_type, | ||
| reduce_range=reduce_range, | ||
| ) | ||
| else: | ||
| act_granularity, _ = Int8Tensor._normalize_granularity(granularity) | ||
| quant_min, quant_max = ( | ||
| (_REDUCED_QUANT_MIN, _REDUCED_QUANT_MAX) | ||
| if reduce_range | ||
| else (_FULL_QUANT_MIN, _FULL_QUANT_MAX) | ||
| ) | ||
| block_size = get_block_size(input_tensor.shape, act_granularity) | ||
| act_quant_scale, act_quant_zero_point = choose_qparams_affine( | ||
| input=input_tensor, | ||
| mapping_type=act_mapping_type, | ||
| block_size=block_size, | ||
| target_dtype=torch.int8, | ||
| quant_min=quant_min, | ||
| quant_max=quant_max, | ||
| scale_dtype=dtype, | ||
| zero_point_dtype=torch.int8, | ||
| keepdim=True, | ||
| eps=torch.finfo(torch.float32).eps, | ||
| ) | ||
| config = Int8StaticActivationInt8WeightConfig( | ||
| act_quant_scale=act_quant_scale, | ||
| act_quant_zero_point=act_quant_zero_point, | ||
| granularity=granularity, | ||
| act_mapping_type=act_mapping_type, | ||
| reduce_range=reduce_range, | ||
| ) | ||
|
|
||
| quantize_(model_q, config) | ||
|
|
||
| _, weight_granularity = Int8Tensor._normalize_granularity(config.granularity) | ||
| if isinstance(weight_granularity, PerRow): | ||
| self.assertEqual(model_q.linear2.weight.scale.shape, (K, 1)) | ||
| elif isinstance(weight_granularity, PerTensor): | ||
| self.assertEqual(model_q.linear2.weight.scale.shape, (1, 1)) | ||
|
|
||
| self.assertEqual(model_q.linear2.weight.scale.ndim, 2) | ||
|
|
||
| if compile: | ||
| model_q = torch.compile(model_q, fullgraph=True) | ||
|
|
||
| output_fp = model(input_tensor) | ||
| output_quantized = model_q(input_tensor) | ||
|
|
||
| if reduce_range or _cpu_is_vnni_supported(): | ||
| assert compute_error(output_fp, output_quantized) > 20, ( | ||
| f"Quantization error is too high got a SQNR of {compute_error(output_fp, output_quantized)}" | ||
| ) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| common_utils.run_tests() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -931,6 +931,10 @@ class Int8DynamicActivationInt8WeightConfig(AOBaseConfig): | |
| SYMMETRIC and ASYMMETRIC are supported. | ||
| set_inductor_config: bool = True - If True, adjusts `torchinductor` settings to recommended values | ||
| for better performance with this quantization scheme. | ||
| version (int): the version of the config | ||
| reduce_range (Optional[bool] = False): If True, both activation and weight int8 quantization use reduced range | ||
| [_REDUCED_QUANT_MIN, _REDUCED_QUANT_MAX] instead of full range | ||
| [_FULL_QUANT_MIN, _FULL_QUANT_MAX] to reduce overflow risk on platforms without VNNI instructions. | ||
|
|
||
| Example: | ||
|
|
||
|
|
@@ -945,6 +949,7 @@ class Int8DynamicActivationInt8WeightConfig(AOBaseConfig): | |
| ] = PerRow() | ||
| set_inductor_config: bool = True | ||
| version: int = 2 | ||
| reduce_range: Optional[bool] = False | ||
|
|
||
| def __post_init__(self): | ||
| torch._C._log_api_usage_once( | ||
|
|
@@ -969,6 +974,9 @@ def __post_init__(self): | |
| "Please set it to MappingType.SYMMETRIC or " | ||
| "MappingType.ASYMMETRIC." | ||
| ) | ||
| assert self.reduce_range in (True, False), ( | ||
| "`reduce_range` must be True or False" | ||
| ) | ||
|
|
||
|
|
||
| def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config): | ||
|
|
@@ -983,7 +991,9 @@ def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config): | |
| act_quant_kwargs=QuantizeTensorToInt8Kwargs( | ||
| granularity=act_granularity, | ||
| mapping_type=config.act_mapping_type, | ||
| reduce_range=config.reduce_range, | ||
| ), | ||
| reduce_range=config.reduce_range, | ||
| ) | ||
|
|
||
| return quantized_weight | ||
|
|
@@ -1037,6 +1047,9 @@ class Int8StaticActivationInt8WeightConfig(AOBaseConfig): | |
| act_mapping_type (MappingType): The mapping type for activation quantization. SYMMETRIC and ASYMMETRIC are supported. | ||
| set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values. | ||
| version (int): the version of the config | ||
| reduce_range (Optional[bool] = False): If True, both activation and weight int8 quantization use reduced range | ||
| [_REDUCED_QUANT_MIN, _REDUCED_QUANT_MAX] instead of full range | ||
| [_FULL_QUANT_MIN, _FULL_QUANT_MAX] to reduce overflow risk on platforms without VNNI instructions. | ||
| """ | ||
|
|
||
| act_quant_scale: Optional[torch.Tensor] = None | ||
|
|
@@ -1047,6 +1060,7 @@ class Int8StaticActivationInt8WeightConfig(AOBaseConfig): | |
| act_mapping_type: Optional[MappingType] = MappingType.SYMMETRIC | ||
| set_inductor_config: bool = True | ||
| version: int = 1 | ||
| reduce_range: Optional[bool] = False | ||
|
|
||
| def __post_init__(self): | ||
| torch._C._log_api_usage_once( | ||
|
|
@@ -1066,6 +1080,9 @@ def __post_init__(self): | |
| "Please set it to MappingType.SYMMETRIC or " | ||
| "MappingType.ASYMMETRIC." | ||
| ) | ||
| assert self.reduce_range in (True, False), ( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same here |
||
| "`reduce_range` must be True or False" | ||
| ) | ||
|
|
||
| def get_act_quant_kwargs(self) -> QuantizeTensorToInt8Kwargs: | ||
| """Get the activation quantization kwargs for static quantization. | ||
|
|
@@ -1080,6 +1097,7 @@ def get_act_quant_kwargs(self) -> QuantizeTensorToInt8Kwargs: | |
| return QuantizeTensorToInt8Kwargs( | ||
| granularity=act_granularity, | ||
| mapping_type=self.act_mapping_type, | ||
| reduce_range=self.reduce_range, | ||
| ) | ||
|
|
||
|
|
||
|
|
@@ -1110,9 +1128,11 @@ def _int8_static_activation_int8_weight_transform( | |
| act_quant_kwargs=QuantizeTensorToInt8Kwargs( | ||
| granularity=activation_granularity, | ||
| mapping_type=config.act_mapping_type, | ||
| reduce_range=config.reduce_range, | ||
| ), | ||
| act_quant_scale=config.act_quant_scale.detach(), | ||
| act_quant_zero_point=act_quant_zero_point, | ||
| reduce_range=config.reduce_range, | ||
| ) | ||
|
|
||
| setattr( | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.