Skip to content

Commit d76624e

Browse files
authored
Suppress deprecation warnings from #4074 and #4100 (#4155)
**Summary:** #4074 and #4100 deprecated a few classes, but this triggered the following warnings when the user imports torchao from the top-level. This commit suppresses these warnings in this case. Before: ``` import torchao /data/users/andrewor/ao/torchao/dtypes/utils.py:89: UserWarning: Deprecation: PlainLayout is deprecated and will be removed in a future release of torchao, see #2752 for more details warnings.warn( /data/users/andrewor/ao/torchao/quantization/quant_primitives.py:95: UserWarning: Deprecation: TorchAODType is deprecated, please use the torch.intN dtype instead (e.g. TorchAODType.INT4 -> torch.int4) warnings.warn( /data/users/andrewor/ao/torchao/dtypes/utils.py:89: UserWarning: Deprecation: PlainLayout is deprecated and will be removed in a future release of torchao, see #2752 for more details warnings.warn( ``` After: ``` import torchao # No warnings ``` **Test Plan:** Manual testing.
1 parent f0f9a05 commit d76624e

File tree

4 files changed

+50
-40
lines changed

4 files changed

+50
-40
lines changed

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ def from_hp_to_intx(
234234
zero_point_dtype: Optional[torch.dtype] = None,
235235
preserve_zero: bool = True,
236236
zero_point_domain: ZeroPointDomain = _DEFAULT_ZPD, # type: ignore[assignment]
237-
_layout: Layout = PlainLayout(),
237+
_layout: Optional[Layout] = None,
238238
use_hqq: bool = False,
239239
*,
240240
custom_scale: Optional[torch.Tensor] = None,
@@ -253,6 +253,8 @@ def from_hp_to_intx(
253253
quantize_affine,
254254
)
255255

256+
if _layout is None:
257+
_layout = PlainLayout()
256258
if zero_point_domain is _DEFAULT_ZPD:
257259
zero_point_domain = ZeroPointDomain.INT
258260

@@ -399,7 +401,7 @@ def from_hp_to_intx_static(
399401
quant_min: Optional[int] = None,
400402
quant_max: Optional[int] = None,
401403
zero_point_domain: ZeroPointDomain = _DEFAULT_ZPD, # type: ignore[assignment]
402-
_layout: Layout = PlainLayout(),
404+
_layout: Optional[Layout] = None,
403405
):
404406
"""Create an integer AffineQuantizedTensor from a high precision tensor using static parameters."""
405407
from torchao.quantization.quant_primitives import (
@@ -409,6 +411,8 @@ def from_hp_to_intx_static(
409411
quantize_affine,
410412
)
411413

414+
if _layout is None:
415+
_layout = PlainLayout()
412416
if zero_point_domain is _DEFAULT_ZPD:
413417
zero_point_domain = ZeroPointDomain.INT
414418
if zero_point_domain is None:

torchao/quantization/qat/fake_quantize_config.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,6 @@ class IntxFakeQuantizeConfig(FakeQuantizeConfigBase):
109109
110110
Args:
111111
dtype: dtype to simulate during fake quantization, e.g. torch.int8.
112-
For PyTorch versions older than 2.6, you may use `TorchAODType` to represent
113-
torch.int1 to torch.int7 instead, e.g. TorchAODType.INT4.
114112
granularity: granularity of scales and zero points, e.g. PerGroup(32).
115113
We also support the following strings:
116114
1) 'per_token': equivalent to PerToken()
@@ -151,7 +149,7 @@ class IntxFakeQuantizeConfig(FakeQuantizeConfigBase):
151149
IntxFakeQuantizeConfig(torch.int4, PerGroup(32), MappingType.SYMMETRIC)
152150
"""
153151

154-
dtype: Union[torch.dtype, TorchAODType]
152+
dtype: Union[torch.dtype, "TorchAODType"]
155153
granularity: Granularity
156154
mapping_type: MappingType
157155
scale_precision: torch.dtype
@@ -163,7 +161,7 @@ class IntxFakeQuantizeConfig(FakeQuantizeConfigBase):
163161

164162
def __init__(
165163
self,
166-
dtype: Union[torch.dtype, TorchAODType],
164+
dtype: Union[torch.dtype, "TorchAODType"],
167165
granularity: Union[Granularity, str, None] = None,
168166
mapping_type: Optional[MappingType] = None,
169167
scale_precision: torch.dtype = torch.float32,

torchao/quantization/quant_api.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1095,7 +1095,7 @@ class Int8DynamicActivationInt8WeightConfig(AOBaseConfig):
10951095
:language: python
10961096
"""
10971097

1098-
layout: Optional[Layout] = PlainLayout()
1098+
layout: Optional[Layout] = None
10991099
act_mapping_type: Optional[MappingType] = MappingType.SYMMETRIC
11001100
weight_only_decode: bool = False
11011101
granularity: Optional[
@@ -1108,6 +1108,8 @@ def __post_init__(self):
11081108
torch._C._log_api_usage_once(
11091109
"torchao.quantization.Int8DynamicActivationInt8WeightConfig"
11101110
)
1111+
if self.layout is None:
1112+
self.layout = PlainLayout()
11111113
if self.version == 2:
11121114
act_granularity, weight_granularity = Int8Tensor._normalize_granularity(
11131115
self.granularity

torchao/quantization/quant_primitives.py

Lines changed: 39 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,11 @@ class ZeroPointDomain(Enum):
9292
class _TorchAODTypeMeta(EnumMeta):
9393
def __getattribute__(cls, name):
9494
result = super().__getattribute__(name)
95-
warnings.warn(
96-
"Deprecation: TorchAODType is deprecated, please use the torch.intN dtype instead "
97-
"(e.g. TorchAODType.INT4 -> torch.int4)"
98-
)
95+
if isinstance(result, cls):
96+
warnings.warn(
97+
"Deprecation: TorchAODType is deprecated, please use the torch.intN dtype instead "
98+
"(e.g. TorchAODType.INT4 -> torch.int4)"
99+
)
99100
return result
100101

101102

@@ -128,36 +129,41 @@ class TorchAODType(Enum, metaclass=_TorchAODTypeMeta):
128129
Map from dtype to the bound value of integers
129130
TODO: maybe can replace this with call to torch.iinfo
130131
"""
131-
_DTYPE_TO_QVALUE_BOUNDS: Dict[Union[torch.dtype, TorchAODType], Tuple[int, int]] = {
132-
torch.uint8: (0, 255),
133-
torch.int8: (-128, 127),
134-
torch.int16: (-(2**15), 2**15 - 1),
135-
torch.int32: (-(2**31), 2**31 - 1),
136-
}
137-
_DTYPE_TO_BIT_WIDTH: Dict[Union[torch.dtype, TorchAODType], Tuple[int, int]] = {
138-
TorchAODType.INT1: 1,
139-
TorchAODType.INT2: 2,
140-
TorchAODType.INT3: 3,
141-
TorchAODType.INT4: 4,
142-
TorchAODType.INT5: 5,
143-
TorchAODType.INT6: 6,
144-
TorchAODType.INT7: 7,
145-
torch.uint8: 8,
146-
torch.int8: 8,
147-
torch.int16: 16,
148-
torch.int32: 32,
149-
}
132+
# Suppress TorchAODType deprecation warnings for internal usage
133+
with warnings.catch_warnings():
134+
warnings.simplefilter("ignore", UserWarning)
135+
136+
_DTYPE_TO_QVALUE_BOUNDS: Dict[Union[torch.dtype, TorchAODType], Tuple[int, int]] = {
137+
torch.uint8: (0, 255),
138+
torch.int8: (-128, 127),
139+
torch.int16: (-(2**15), 2**15 - 1),
140+
torch.int32: (-(2**31), 2**31 - 1),
141+
}
150142

151-
_SUB_BYTE_UINT_BOUNDS: Dict[Union[torch.dtype, TorchAODType], Tuple[int, int]] = {}
152-
_SUB_BYTE_INT_BOUNDS: Dict[Union[torch.dtype, TorchAODType], Tuple[int, int]] = {
153-
TorchAODType.INT1: (-(2**0), 2**0 - 1),
154-
TorchAODType.INT2: (-(2**1), 2**1 - 1),
155-
TorchAODType.INT3: (-(2**2), 2**2 - 1),
156-
TorchAODType.INT4: (-(2**3), 2**3 - 1),
157-
TorchAODType.INT5: (-(2**4), 2**4 - 1),
158-
TorchAODType.INT6: (-(2**5), 2**5 - 1),
159-
TorchAODType.INT7: (-(2**6), 2**6 - 1),
160-
}
143+
_DTYPE_TO_BIT_WIDTH: Dict[Union[torch.dtype, TorchAODType], int] = {
144+
TorchAODType.INT1: 1,
145+
TorchAODType.INT2: 2,
146+
TorchAODType.INT3: 3,
147+
TorchAODType.INT4: 4,
148+
TorchAODType.INT5: 5,
149+
TorchAODType.INT6: 6,
150+
TorchAODType.INT7: 7,
151+
torch.uint8: 8,
152+
torch.int8: 8,
153+
torch.int16: 16,
154+
torch.int32: 32,
155+
}
156+
157+
_SUB_BYTE_UINT_BOUNDS: Dict[Union[torch.dtype, TorchAODType], Tuple[int, int]] = {}
158+
_SUB_BYTE_INT_BOUNDS: Dict[Union[torch.dtype, TorchAODType], Tuple[int, int]] = {
159+
TorchAODType.INT1: (-(2**0), 2**0 - 1),
160+
TorchAODType.INT2: (-(2**1), 2**1 - 1),
161+
TorchAODType.INT3: (-(2**2), 2**2 - 1),
162+
TorchAODType.INT4: (-(2**3), 2**3 - 1),
163+
TorchAODType.INT5: (-(2**4), 2**4 - 1),
164+
TorchAODType.INT6: (-(2**5), 2**5 - 1),
165+
TorchAODType.INT7: (-(2**6), 2**6 - 1),
166+
}
161167

162168
_SUB_BYTE_UINT_BOUNDS = {
163169
torch.uint1: (0, 2**1 - 1),

0 commit comments

Comments
 (0)