-
Notifications
You must be signed in to change notification settings - Fork 501
Add reduce_range to avoid overflow in int8 tensor #4266
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 12 commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
dd390af
Add reduce_range to avoid overflow in int8 tensor
2adbdd9
Refine UT
e087a60
Only test reduce_range=True on CPUs without VNNI support
3c0c12b
Hardcode recude range
9fd46cd
Move vnni check func into utils
4e825f1
Automatically set reduce_range
716cca8
Re-add user-configurable reduce_range
8a5840d
Add UT notes
741115a
Fix CI
6e8eb08
Set reduce_range default to False
6eaf74b
Use should_reduce_range in UT
6cdd804
Merge upstream/main into reduce_range
232882c
Move should_reduce_range to utils
ea96be6
Refine codes
c4d095c
Remove Optional
4877828
Modify assert
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
119 changes: 119 additions & 0 deletions
119
test/quantization/quantize_/workflows/int8/test_int8_tensor_cpu.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,119 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD 3-Clause license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| import copy | ||
|
|
||
| import torch | ||
| from torch.testing._internal import common_utils | ||
|
|
||
| from torchao.quantization import ( | ||
| Int8DynamicActivationInt8WeightConfig, | ||
| Int8StaticActivationInt8WeightConfig, | ||
| quantize_, | ||
| ) | ||
| from torchao.quantization.granularity import PerRow, PerTensor | ||
| from torchao.quantization.quant_primitives import ( | ||
| _DTYPE_TO_QVALUE_BOUNDS, | ||
| MappingType, | ||
| choose_qparams_affine, | ||
| ) | ||
| from torchao.quantization.quantize_.workflows.int8.int8_tensor import ( | ||
| Int8Tensor, | ||
| should_reduce_range, | ||
| ) | ||
| from torchao.quantization.utils import compute_error, get_block_size | ||
| from torchao.testing.model_architectures import ToyTwoLinearModel | ||
| from torchao.testing.utils import TorchAOIntegrationTestCase | ||
| from torchao.utils import ( | ||
| is_ROCM, | ||
| ) | ||
|
|
||
|
|
||
| @common_utils.instantiate_parametrized_tests | ||
| class TestInt8TensorCPU(TorchAOIntegrationTestCase): | ||
| # Note: The reduce_range parameter can be manually set by users via the config. | ||
| # This UT only tests automatic reduce_range to avoid CI failures on CPUs without VNNI support. | ||
| @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32]) | ||
| @common_utils.parametrize("compile", [True, False]) | ||
| @common_utils.parametrize("config_mode", ["dynamic", "static"]) | ||
| @common_utils.parametrize( | ||
| "granularity", | ||
| [PerRow(), PerTensor(), (PerRow(), PerTensor()), (PerTensor(), PerRow())], | ||
| ) | ||
| @common_utils.parametrize( | ||
| "act_mapping_type", [MappingType.SYMMETRIC, MappingType.ASYMMETRIC] | ||
| ) | ||
| def test_int8_tensor_cpu( | ||
| self, act_mapping_type, granularity, config_mode, compile, dtype | ||
| ): | ||
| device = "cpu" | ||
| if is_ROCM(): | ||
| self.skipTest("Don't test CPU for ROCM version of torch") | ||
|
|
||
| torch.compiler.reset() | ||
|
|
||
| M, N, K = 64, 256, 256 | ||
| input_tensor = torch.randn(M, K, dtype=dtype, device=device) | ||
| model = ToyTwoLinearModel(K, N, K, dtype=dtype, device=device).eval() | ||
| model_q = copy.deepcopy(model) | ||
| reduce_range = should_reduce_range(input_tensor.device) | ||
|
|
||
| if config_mode == "dynamic": | ||
| config = Int8DynamicActivationInt8WeightConfig( | ||
| version=2, | ||
| granularity=granularity, | ||
| act_mapping_type=act_mapping_type, | ||
| reduce_range=reduce_range, | ||
| ) | ||
| else: | ||
| act_granularity, _ = Int8Tensor._normalize_granularity(granularity) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: add assert for config_mode == "static" |
||
| quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[torch.int8] | ||
| if reduce_range: | ||
| quant_min, quant_max = quant_min // 2, quant_max // 2 | ||
| block_size = get_block_size(input_tensor.shape, act_granularity) | ||
| act_quant_scale, act_quant_zero_point = choose_qparams_affine( | ||
| input=input_tensor, | ||
| mapping_type=act_mapping_type, | ||
| block_size=block_size, | ||
| target_dtype=torch.int8, | ||
| quant_min=quant_min, | ||
| quant_max=quant_max, | ||
| scale_dtype=dtype, | ||
| zero_point_dtype=torch.int8, | ||
| keepdim=True, | ||
| eps=torch.finfo(torch.float32).eps, | ||
| ) | ||
| config = Int8StaticActivationInt8WeightConfig( | ||
| act_quant_scale=act_quant_scale, | ||
| act_quant_zero_point=act_quant_zero_point, | ||
| granularity=granularity, | ||
| act_mapping_type=act_mapping_type, | ||
| reduce_range=reduce_range, | ||
| ) | ||
|
|
||
| quantize_(model_q, config) | ||
|
|
||
| _, weight_granularity = Int8Tensor._normalize_granularity(config.granularity) | ||
| if isinstance(weight_granularity, PerRow): | ||
| self.assertEqual(model_q.linear2.weight.scale.shape, (K, 1)) | ||
| elif isinstance(weight_granularity, PerTensor): | ||
| self.assertEqual(model_q.linear2.weight.scale.shape, (1, 1)) | ||
|
|
||
| self.assertEqual(model_q.linear2.weight.scale.ndim, 2) | ||
|
|
||
| if compile: | ||
| model_q = torch.compile(model_q, fullgraph=True) | ||
|
|
||
| output_fp = model(input_tensor) | ||
| output_quantized = model_q(input_tensor) | ||
|
|
||
| assert compute_error(output_fp, output_quantized) > 20, ( | ||
| f"Quantization error is too high got a SQNR of {compute_error(output_fp, output_quantized)}" | ||
| ) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| common_utils.run_tests() | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what's the effect when reduce_range is not set correctly for cpu? should we test that as well?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the comments. On CPUs without VNNI, when using torch.compile with int8 matmul, there may be accuracy degradation compared to eager mode due to overflow in oneDNN qlinear. The
reduce_rangesetting is designed to prevent this.We don't need to test incorrect settings because this UT focuses on the recommended path, and I've already provided the helper function
should_reduce_range()and clear UT Notes.