Skip to content

Commit 7911823

Browse files
andrewor14 and danielvegamyhre
authored and committed
Deprecate AQT and related classes (#4074)
**Summary:** Deprecate AffineQuantizedTensor, AQTTensorImpl, Layout, and all subclasses of the above in torchao.dtypes. We are planning to remove these classes in the future, so we deprecate them here in advance. ``` ● Here are all the classes that now have new deprecation warnings: Base classes (torchao/dtypes/utils.py): 1. Layout 2. PlainLayout 3. AQTTensorImpl torchao/dtypes/affine_quantized_tensor.py: 4. AffineQuantizedTensor torchao/dtypes/floatx/: 5. Float8Layout 6. Float8AQTTensorImpl 7. CutlassSemiSparseLayout 8. CutlassSemiSparseTensorImpl torchao/dtypes/uintx/: 9. TensorCoreTiledLayout 10. TensorCoreTiledAQTTensorImpl 11. SemiSparseLayout 12. SemiSparseAQTTensorImpl 13. Int4CPULayout 14. Int4CPUAQTTensorImpl 15. Int4XPULayout 16. Int4XPUAQTTensorImpl 17. QDQLayout 18. QDQTensorImpl 19. PlainAQTTensorImpl 20. PackedLinearInt8DynamicActivationIntxWeightLayout 21. PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl ``` More context here: #2752 **Test Plan:** Manual testing
1 parent 22430f4 commit 7911823

15 files changed

Lines changed: 138 additions & 16 deletions

benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py

Lines changed: 17 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -116,7 +116,17 @@ def run_experiment(
116116
requires_grad=True,
117117
).transpose(-2, -1)
118118

119-
offs = generate_jagged_offs(G, total_M, multiple_of=1)
119+
# Create config object from recipe
120+
if isinstance(config.recipe, Float8TrainingRecipe):
121+
quant_config = Float8TrainingOpConfig.from_recipe(config.recipe)
122+
alignment_size = 16 if args.aligned else 1
123+
# TODO: support pad_token_groups_for_grouped_mm option in Float8TrainingOpConfig
124+
else:
125+
quant_config = MXFP8TrainingOpConfig.from_recipe(config.recipe)
126+
quant_config.pad_token_groups_for_grouped_mm = not args.aligned
127+
alignment_size = 32 if args.aligned else 1
128+
129+
offs = generate_jagged_offs(G, total_M, multiple_of=alignment_size)
120130

121131
# fwd_bwd bf16 benchmark + profiling
122132
bf16_fwd_bwd_us = bench_fwd_bwd_microseconds(
@@ -138,12 +148,6 @@ def run_experiment(
138148
profile_name="bf16_profile",
139149
)
140150

141-
# Create config object from recipe
142-
if isinstance(config.recipe, Float8TrainingRecipe):
143-
quant_config = Float8TrainingOpConfig.from_recipe(config.recipe)
144-
else:
145-
quant_config = MXFP8TrainingOpConfig.from_recipe(config.recipe)
146-
147151
# fwd_bwd scaled benchmark + profiling
148152
scaled_fwd_bwd_us = bench_fwd_bwd_microseconds(
149153
_quantize_then_scaled_grouped_mm,
@@ -262,5 +266,11 @@ def main(args: argparse.Namespace):
262266
arg_parser = argparse.ArgumentParser()
263267
arg_parser.add_argument("--compile", action="store_true")
264268
arg_parser.add_argument("--profile", action="store_true")
269+
arg_parser.add_argument(
270+
"--aligned",
271+
action="store_true",
272+
help="If true, token group sizes are pre-aligned, to simulate flow with HybridEP or similar",
273+
)
274+
265275
args = arg_parser.parse_args()
266276
main(args)

test/prototype/moe_training/test_training.py

Lines changed: 14 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,7 @@
3737
@pytest.mark.parametrize(
3838
"kernel_preference", [KernelPreference.AUTO, KernelPreference.EMULATED]
3939
)
40+
@pytest.mark.parametrize("token_groups_aligned", [False])
4041
@pytest.mark.parametrize(
4142
"recipe_config",
4243
[
@@ -74,6 +75,7 @@ def test_moe_training(
7475
target_fqns: list[str],
7576
compile: bool,
7677
kernel_preference: KernelPreference,
78+
token_groups_aligned: bool,
7779
recipe_config: dict,
7880
):
7981
(
@@ -110,6 +112,8 @@ def test_moe_training(
110112
pytest.skip(
111113
f"Skipping FP8 rowwise tests, only supported on compute capability 9.0 and found {torch.cuda.get_device_capability()}"
112114
)
115+
if not token_groups_aligned:
116+
pytest.skip("FP8 rowwise doesn't support per group token padding yet")
113117

114118
# MXFP8 hardware path requires SM100
115119
if recipe in (
@@ -123,7 +127,11 @@ def test_moe_training(
123127
f"Skipping MXFP8 hardware mode tests, only supported on compute capability 10.0 and found {torch.cuda.get_device_capability()}"
124128
)
125129

126-
set_token_group_alignment_size_m(1)
130+
alignment_size = 32 if isinstance(recipe, MXFP8TrainingRecipe) else 16
131+
if not token_groups_aligned:
132+
alignment_size = 1
133+
set_token_group_alignment_size_m(alignment_size)
134+
127135
model_args = MoEArgs(
128136
num_experts=8,
129137
num_shared_experts=1,
@@ -159,6 +167,11 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
159167
else Float8TrainingOpConfig
160168
)
161169
config = config_cls.from_recipe(recipe)
170+
171+
# TODO: support pad_token_groups_for_grouped_mm in Float8TrainingOpConfig
172+
if isinstance(recipe, MXFP8TrainingRecipe) and not token_groups_aligned:
173+
config.pad_token_groups_for_grouped_mm = True
174+
162175
quantize_(model, config=config, filter_fn=moe_module_filter_fn)
163176

164177
# validate that only the experts were converted

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77

88
import logging
99
import math
10+
import warnings
1011
from typing import TYPE_CHECKING, Optional, Tuple, Union
1112

1213
import torch
@@ -112,6 +113,9 @@ def __init__(
112113
if zero_point_domain is _DEFAULT_ZPD:
113114
zero_point_domain = ZeroPointDomain.INT
114115
torch._C._log_api_usage_once(str(type(self)))
116+
warnings.warn(
117+
"Deprecation: AffineQuantizedTensor is deprecated and will be removed in a future release of torchao, see https://github.com/pytorch/ao/issues/2752 for more details"
118+
)
115119
self.tensor_impl = tensor_impl
116120
self.block_size = block_size
117121
self.quant_min = quant_min

torchao/dtypes/floatx/cutlass_semi_sparse_layout.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
#
44
# This source code is licensed under the BSD 3-Clause license found in the
55
# LICENSE file in the root directory of this source tree.
6+
import warnings
67
from dataclasses import dataclass
78
from typing import Optional
89

@@ -42,6 +43,12 @@ def _same_metadata(
4243
class CutlassSemiSparseLayout(Layout):
4344
"""Layout class for float8 2:4 sparsity layout for affine quantized tensor, for cutlass kernel."""
4445

46+
def __post_init__(self):
47+
super().__post_init__()
48+
warnings.warn(
49+
"Deprecation: CutlassSemiSparseLayout is deprecated and will be removed in a future release of torchao, see https://github.com/pytorch/ao/issues/2752 for more details"
50+
)
51+
4552
def pre_process(self, dense: torch.Tensor) -> torch.Tensor:
4653
# prune to 2:4 if not already
4754
from torchao.sparsity.utils import mask_creator
@@ -76,6 +83,9 @@ def __init__(
7683
scale: torch.Tensor,
7784
_layout: Layout,
7885
):
86+
warnings.warn(
87+
"Deprecation: CutlassSemiSparseTensorImpl is deprecated and will be removed in a future release of torchao, see https://github.com/pytorch/ao/issues/2752 for more details"
88+
)
7989
self.sparse = sparse
8090
self.meta = meta
8191
self.scale = scale

torchao/dtypes/floatx/float8_layout.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -69,6 +69,12 @@ class Float8Layout(Layout):
6969

7070
mm_config: Optional[Float8MMConfig] = None
7171

72+
def __post_init__(self):
73+
super().__post_init__()
74+
warnings.warn(
75+
"Deprecation: Float8Layout is deprecated and will be removed in a future release of torchao, see https://github.com/pytorch/ao/issues/2752 for more details"
76+
)
77+
7278

7379
_fallback_warning_shown = False
7480

@@ -110,6 +116,9 @@ def __init__(
110116
transposed: bool,
111117
_layout: Layout,
112118
):
119+
warnings.warn(
120+
"Deprecation: Float8AQTTensorImpl is deprecated and will be removed in a future release of torchao, see https://github.com/pytorch/ao/issues/2752 for more details"
121+
)
113122
warnings.warn(
114123
"Models quantized with version 1 of Float8DynamicActivationFloat8WeightConfig is deprecated and will no longer be supported in a future release, please upgrade torchao and quantize again, or download a newer torchao checkpoint, see https://github.com/pytorch/ao/issues/2649 for more details"
115124
)

torchao/dtypes/uintx/int4_cpu_layout.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,11 @@ class Int4CPULayout(Layout):
2929
Only for PyTorch version at least 2.6
3030
"""
3131

32-
pass
32+
def __post_init__(self):
33+
super().__post_init__()
34+
warnings.warn(
35+
"Deprecation: Int4CPULayout is deprecated and will be removed in a future release of torchao, see https://github.com/pytorch/ao/issues/2752 for more details"
36+
)
3337

3438

3539
@register_layout(Int4CPULayout)
@@ -75,6 +79,9 @@ def __init__(
7579
transposed: bool,
7680
_layout: Layout,
7781
):
82+
warnings.warn(
83+
"Deprecation: Int4CPUAQTTensorImpl is deprecated and will be removed in a future release of torchao, see https://github.com/pytorch/ao/issues/2752 for more details"
84+
)
7885
warnings.warn(
7986
"Models quantized with version 1 of Int4WeightOnlyConfig is deprecated and will no longer be supported in a future release, please upgrade torchao and quantize again, or download a newer torchao checkpoint, see https://github.com/pytorch/ao/issues/2948 for more details"
8087
)

torchao/dtypes/uintx/int4_xpu_layout.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -158,7 +158,11 @@ def _linear_fp_act_uint4_weight_int8_zero_impl(input_tensor, weight_tensor, bias
158158
class Int4XPULayout(Layout):
159159
"""Only for PyTorch version at least 2.7"""
160160

161-
pass
161+
def __post_init__(self):
162+
super().__post_init__()
163+
warnings.warn(
164+
"Deprecation: Int4XPULayout is deprecated and will be removed in a future release of torchao, see https://github.com/pytorch/ao/issues/2752 for more details"
165+
)
162166

163167

164168
@register_layout(Int4XPULayout)
@@ -211,6 +215,9 @@ def __init__(
211215
scale: torch.Tensor = None,
212216
zero: torch.Tensor = None,
213217
):
218+
warnings.warn(
219+
"Deprecation: Int4XPUAQTTensorImpl is deprecated and will be removed in a future release of torchao, see https://github.com/pytorch/ao/issues/2752 for more details"
220+
)
214221
warnings.warn(
215222
"Models quantized with version 1 of Int4WeightOnlyConfig is deprecated and will no longer be supported in a future release, please upgrade torchao and quantize again, or download a newer torchao checkpoint, see https://github.com/pytorch/ao/issues/2948 for more details"
216223
)

torchao/dtypes/uintx/packed_linear_int8_dynamic_activation_intx_weight_layout.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -70,6 +70,9 @@ def __init__(
7070
self,
7171
target: Union[str, Target] = "auto",
7272
):
73+
warnings.warn(
74+
"Deprecation: PackedLinearInt8DynamicActivationIntxWeightLayout is deprecated and will be removed in a future release of torchao, see https://github.com/pytorch/ao/issues/2752 for more details"
75+
)
7376
warnings.warn(
7477
"Models quantized with version 1 of IntxWeightOnlyConfig/Int8DynamicActivationIntxWeightConfig are deprecated and will no longer be supported in a future release, please upgrade torchao and quantize again, or download a newer torchao checkpoint, see https://github.com/pytorch/ao/issues/2967 for more details"
7578
)
@@ -130,6 +133,9 @@ def __init__(
130133
packed_weight: torch.Tensor,
131134
_layout: Layout,
132135
):
136+
warnings.warn(
137+
"Deprecation: PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl is deprecated and will be removed in a future release of torchao, see https://github.com/pytorch/ao/issues/2752 for more details"
138+
)
133139
assert isinstance(_layout, PackedLinearInt8DynamicActivationIntxWeightLayout)
134140
self.packed_weight = packed_weight
135141
self._layout = _layout

torchao/dtypes/uintx/plain_layout.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
#
44
# This source code is licensed under the BSD 3-Clause license found in the
55
# LICENSE file in the root directory of this source tree.
6+
import warnings
67
from typing import Optional, Tuple
78

89
import torch
@@ -77,6 +78,10 @@ def __init__(
7778
zero_point: Optional[torch.Tensor],
7879
_layout: Layout,
7980
):
81+
if type(self) is PlainAQTTensorImpl:
82+
warnings.warn(
83+
"Deprecation: PlainAQTTensorImpl is deprecated and will be removed in a future release of torchao, see https://github.com/pytorch/ao/issues/2752 for more details"
84+
)
8085
self.int_data = int_data
8186
self.scale = scale
8287
self.zero_point = zero_point

torchao/dtypes/uintx/q_dq_layout.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,11 @@
4040

4141
@dataclass(frozen=True)
4242
class QDQLayout(Layout):
43-
pass
43+
def __post_init__(self):
44+
super().__post_init__()
45+
warnings.warn(
46+
"Deprecation: QDQLayout is deprecated and will be removed in a future release of torchao, see https://github.com/pytorch/ao/issues/2752 for more details"
47+
)
4448

4549

4650
def _same_metadata(self: "QDQTensorImpl", src: "QDQTensorImpl") -> bool:
@@ -96,6 +100,9 @@ def __init__(
96100
zero_point: Optional[torch.Tensor],
97101
_layout: Layout,
98102
):
103+
warnings.warn(
104+
"Deprecation: QDQTensorImpl is deprecated and will be removed in a future release of torchao, see https://github.com/pytorch/ao/issues/2752 for more details"
105+
)
99106
warnings.warn(
100107
"Models quantized with version 1 of IntxWeightOnlyConfig/Int8DynamicActivationIntxWeightConfig are deprecated and will no longer be supported in a future release, please upgrade torchao and quantize again, or download a newer torchao checkpoint, see https://github.com/pytorch/ao/issues/2967 for more details"
101108
)

0 commit comments

Comments (0)