Skip to content

Commit 643bc80

Browse files
committed
Remove unified.py (Quantizer and TwoStepQuantizer ABCs)
Summary: Delete `torchao/quantization/unified.py`, which defined the `Quantizer` and `TwoStepQuantizer` abstract base classes. These were trivial ABCs that only declared method signatures (`quantize`, `prepare`, `convert`) which all subclasses already implement. Remove the base-class inheritance from all subclasses and clean up imports.

Test Plan:
pytest test/quantization/test_qat.py -x
pytest test/quantization/test_quant_api.py -x

ghstack-source-id: a5c1952
Pull Request resolved: #4264
1 parent c0708ce commit 643bc80

File tree

9 files changed

+10
-66
lines changed

9 files changed

+10
-66
lines changed

test/quantization/test_qat.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,6 @@
8585
quantize_affine,
8686
)
8787
from torchao.quantization.quantize_.workflows import Int4PackingFormat
88-
from torchao.quantization.unified import (
89-
TwoStepQuantizer,
90-
)
9188
from torchao.quantization.utils import (
9289
_get_per_token_block_size,
9390
compute_error,
@@ -751,7 +748,7 @@ def test_qat_4w_quantizer(self):
751748
ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0
752749
)
753750

754-
class _MyQATQuantizer(TwoStepQuantizer):
751+
class _MyQATQuantizer:
755752
"""
756753
Dummy quantizer that attaches a certain value to each nn.Linear's
757754
`_temp_quantizer_values` attribute.

test/quantization/test_quant_api.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,6 @@
4747
ModuleFqnToConfig,
4848
PerRow,
4949
PerTensor,
50-
Quantizer,
51-
TwoStepQuantizer,
5250
_replace_with_custom_fn_if_matches_filter,
5351
)
5452
from torchao.quantization.quant_primitives import MappingType
@@ -90,7 +88,7 @@ def capture_and_prepare(model, example_inputs):
9088
return m
9189

9290

93-
class XNNPackDynamicQuantizer(TwoStepQuantizer):
91+
class XNNPackDynamicQuantizer:
9492
def prepare(self, model: torch.nn.Module) -> torch.nn.Module:
9593
_replace_with_custom_fn_if_matches_filter(
9694
model,
@@ -110,7 +108,7 @@ def convert(self, model: torch.nn.Module) -> torch.nn.Module:
110108
return model
111109

112110

113-
class TorchCompileDynamicQuantizer(Quantizer):
111+
class TorchCompileDynamicQuantizer:
114112
def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
115113
quantize_(model, Int8DynamicActivationInt8WeightConfig())
116114
return model

torchao/quantization/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@
6060
IntxUnpackedToInt8Tensor,
6161
)
6262
from .transform_module import register_quantize_module_handler
63-
from .unified import Quantizer, TwoStepQuantizer
6463
from .utils import (
6564
compute_error,
6665
)
@@ -124,7 +123,5 @@
124123
"Int4WeightOnlyQuantizer",
125124
"Int8DynActInt4WeightQuantizer",
126125
"Int8DynActInt4WeightLinear",
127-
"TwoStepQuantizer",
128-
"Quantizer",
129126
"Float8MMConfig",
130127
]

torchao/quantization/linear_quant_modules.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
MappingType,
2222
dequantize_affine,
2323
)
24-
from .unified import Quantizer
2524
from .utils import (
2625
group_quantize_tensor_symmetric,
2726
groupwise_affine_quantize_tensor,
@@ -232,7 +231,7 @@ def replace_linear_int4(
232231
)
233232

234233

235-
class Int4WeightOnlyQuantizer(Quantizer):
234+
class Int4WeightOnlyQuantizer:
236235
def __init__(
237236
self,
238237
groupsize: int = 256,
@@ -532,7 +531,7 @@ def replace_linear_8da4w(
532531
)
533532

534533

535-
class Int8DynActInt4WeightQuantizer(Quantizer):
534+
class Int8DynActInt4WeightQuantizer:
536535
def __init__(
537536
self,
538537
groupsize: int = 256,

torchao/quantization/qat/api.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import logging
99
from dataclasses import dataclass
1010
from enum import Enum
11-
from typing import Any, List, Optional, Tuple
11+
from typing import Any, Optional, Tuple
1212

1313
import torch
1414

@@ -17,7 +17,6 @@
1717
_QUANTIZE_CONFIG_HANDLER,
1818
register_quantize_module_handler,
1919
)
20-
from torchao.quantization.unified import TwoStepQuantizer
2120

2221
from .embedding import FakeQuantizedEmbedding
2322
from .fake_quantize_config import (
@@ -420,7 +419,7 @@ def _from_intx_quantization_aware_training_transform(
420419
return mod
421420

422421

423-
class ComposableQATQuantizer(TwoStepQuantizer):
422+
class ComposableQATQuantizer:
424423
"""
425424
Composable quantizer that users can use to apply multiple QAT quantizers easily.
426425
Quantizers will be applied in the order they are specified in the constructor.
@@ -440,7 +439,7 @@ class ComposableQATQuantizer(TwoStepQuantizer):
440439
model = my_quantizer.convert(model)
441440
"""
442441

443-
def __init__(self, quantizers: List[TwoStepQuantizer]):
442+
def __init__(self, quantizers: list):
444443
torch._C._log_api_usage_once("torchao.quantization.qat.ComposableQATQuantizer")
445444
self.quantizers = quantizers
446445

torchao/quantization/qat/embedding.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import torch.nn.functional as F
1111

1212
from torchao.quantization.quant_primitives import TorchAODType
13-
from torchao.quantization.unified import TwoStepQuantizer
1413
from torchao.quantization.utils import get_group_qparams_symmetric
1514

1615
from .fake_quantize_config import (
@@ -136,7 +135,7 @@ def from_embedding(
136135
# ======================================
137136

138137

139-
class Int4WeightOnlyEmbeddingQATQuantizer(TwoStepQuantizer):
138+
class Int4WeightOnlyEmbeddingQATQuantizer:
140139
"""
141140
Quantizer for performing QAT on a model, where embedding layers have
142141
int4 fake quantized grouped per channel weights.

torchao/quantization/qat/linear.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
TorchAODType,
2323
ZeroPointDomain,
2424
)
25-
from torchao.quantization.unified import TwoStepQuantizer
2625
from torchao.quantization.utils import get_group_qparams_symmetric
2726
from torchao.utils import _is_device
2827

@@ -181,7 +180,7 @@ def disable_linear_fake_quant(mod: torch.nn.Module):
181180
# ===========================
182181

183182

184-
class _LegacyQATQuantizer(TwoStepQuantizer):
183+
class _LegacyQATQuantizer:
185184
"""
186185
Base class for sharing common methods across legacy QAT quantizers.
187186
"""

torchao/quantization/quant_api.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,15 +103,12 @@
103103
MappingType,
104104
quantize_affine,
105105
)
106-
from .unified import Quantizer, TwoStepQuantizer
107106

108107
logger = logging.getLogger(__name__)
109108

110109
# TODO: revisit this list?
111110
__all__ = [
112111
"swap_conv2d_1x1_to_linear",
113-
"Quantizer",
114-
"TwoStepQuantizer",
115112
"Int4WeightOnlyQuantizer",
116113
"_get_subclass_inserter",
117114
"quantize_",

torchao/quantization/unified.py

Lines changed: 0 additions & 41 deletions
This file was deleted.

0 commit comments

Comments (0)