forked from pytorch/ao
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathquant_api.py
More file actions
1834 lines (1597 loc) · 70.3 KB
/
quant_api.py
File metadata and controls
1834 lines (1597 loc) · 70.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024-2025 Arm Limited and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Quantization APIs

Generally these APIs can be applied directly to any model
with Linear modules to obtain quantized linear ops. The intended
usage involves applying torch.compile to the model afterwards,
both because the primitives were designed around the fusions that
come along with it and because that is how we access the intended quantized
and mixed GEMM kernels.
"""
import inspect
import logging
import re
import types
import warnings
from collections import OrderedDict
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Callable, List, Optional, Tuple, Union
from typing import OrderedDict as OrderedDictType
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize
import torchao
from torchao.core.config import AOBaseConfig
from torchao.float8.config import e4m3_dtype
from torchao.float8.float8_linear import Float8Linear
from torchao.float8.inference import (
Float8MMConfig,
FP8Granularity,
_check_hardware_support,
_granularity_is_a_1_128_w_128_128,
_normalize_granularity,
)
# for BC, make sure to keep the `noqa: F401` comments to prevent
# ruff from removing "unused imports"
from torchao.prototype.quantization.quant_api import (
Float8StaticActivationFloat8WeightConfig, # noqa: F401
)
from torchao.quantization.linear_activation_weight_observed_tensor import (
LinearActivationWeightObservedTensor,
)
from torchao.quantization.observer import AffineQuantizedObserverBase
from torchao.quantization.quantize_.common import (
KernelPreference,
)
from torchao.quantization.quantize_.workflows import (
Float8PackingFormat,
Float8Tensor,
Int4ChooseQParamsAlgorithm,
Int4PackingFormat,
Int4PlainInt32Tensor,
Int4PreshuffledTensor,
Int4Tensor,
Int4TilePackedTo4dTensor,
Int8Tensor,
IntxChooseQParamsAlgorithm,
IntxOpaqueTensor,
IntxPackingFormat,
IntxUnpackedToInt8Tensor,
QuantizeTensorToFloat8Kwargs,
QuantizeTensorToInt8Kwargs,
Sparse2x4CUTLASSFloat8Tensor,
)
from torchao.quantization.transform_module import (
_QUANTIZE_CONFIG_HANDLER,
register_quantize_module_handler,
)
from torchao.quantization.utils import (
_fp8_mm_compat,
_linear_extra_repr,
_module_extra_repr,
_quantization_type,
)
from torchao.utils import (
is_MI300,
is_MI350,
is_sm_at_least_89,
)
from .granularity import (
Granularity,
PerAxis,
PerGroup,
PerRow,
PerTensor,
)
from .linear_quant_modules import (
Int4WeightOnlyQuantizer,
Int8DynActInt4WeightQuantizer,
)
from .qat import (
intx_quantization_aware_training,
)
from .quant_primitives import (
_DTYPE_TO_QVALUE_BOUNDS,
MappingType,
quantize_affine,
)
from .unified import Quantizer, TwoStepQuantizer
# Module-level logger (used, e.g., to report weights skipped by int4 weight-only quantization).
logger = logging.getLogger(__name__)
# Names exported by `from <this module> import *`.
# TODO: revisit this list?
__all__ = [
"swap_conv2d_1x1_to_linear",
"Quantizer",
"TwoStepQuantizer",
"Int4WeightOnlyQuantizer",
"_get_subclass_inserter",
"quantize_",
"intx_quantization_aware_training",
"Int8DynActInt4WeightQuantizer",
"ModuleFqnToConfig",
]
def _replace_with_custom_fn_if_matches_filter(
    model,
    replacement_fn,
    filter_fn,
    cur_fqn="",
    device=None,
    extra_args: Optional[Tuple[Any, ...]] = (),
) -> torch.nn.Module:
    """
    Recursively replaces each child module in `model` with the result of
    `replacement_fn(child, *extra_args)` if `filter_fn(child, fqn)` returns `True`.

    `model` itself is tested first; if it matches, it is replaced and its
    children are not visited.

    Args:
        model (torch.nn.Module): The model containing modules to be replaced.
        replacement_fn (Callable[..., torch.nn.Module]): Called as
            `replacement_fn(module, *extra_args)` on every matching module; its
            return value takes the module's place in its parent.
        filter_fn (Callable[[torch.nn.Module, str], bool]): Receives a module and
            its fully qualified name (no trailing dot, "" for the root) and
            decides whether the module is replaced.
        cur_fqn (str, optional): Fully qualified name prefix of `model`, including
            a trailing "." for non-root modules. Defaults to "".
        device (device, optional): If given, a matching module is moved to this
            device before `replacement_fn` runs, and every visited non-matching
            module is moved after its children are processed. Defaults to None.
        extra_args (Tuple[Any, ...], optional): Extra positional args forwarded
            to `replacement_fn`. `None` is treated like an empty tuple.

    Returns:
        torch.nn.Module: the replacement for `model` if it matched, otherwise
        `model` itself (with matching descendants replaced in place).
    """
    if extra_args is None:
        # The annotation allows None; unpacking None below would raise TypeError.
        extra_args = ()

    # cur_fqn carries a trailing "." (see the recursive call); strip it for filter_fn.
    if filter_fn(model, cur_fqn[:-1]):
        if device is not None:
            model.to(device=device)  # move to device before quantization
        return replacement_fn(model, *extra_args)

    # Snapshot the children because setattr below mutates the module tree.
    for name, child in list(model.named_children()):
        new_child = _replace_with_custom_fn_if_matches_filter(
            child,
            replacement_fn,
            filter_fn,
            f"{cur_fqn}{name}.",
            device,
            extra_args,
        )
        if new_child is not child and new_child is not None:
            setattr(model, name, new_child)
    if device is not None:
        model.to(device=device)  # move parent module to device
    return model
def _is_linear(mod, *args):
    """Return True if `mod` is a plain `torch.nn.Linear` that should be quantized.

    Extra positional arguments (such as a fully qualified name supplied by a
    filter caller) are accepted and ignored.

    A module is rejected when:
      * it is not a `torch.nn.Linear` or has no `weight` attribute,
      * its weight is already an `_AffineFakeQuantizedTensor` (so a weight shared
        by several linear modules is only wrapped once),
      * it is an `nn.modules.linear.NonDynamicallyQuantizableLinear`.
    """
    # local import to avoid circular dependencies
    from torchao.quantization.qat.affine_fake_quantized_tensor import (
        _AffineFakeQuantizedTensor,
    )

    if not isinstance(mod, torch.nn.Linear):
        return False
    if not hasattr(mod, "weight"):
        return False
    # TODO: check isinstance(TorchAOBaseTensor)?
    if isinstance(mod.weight, _AffineFakeQuantizedTensor):
        return False
    return not isinstance(mod, nn.modules.linear.NonDynamicallyQuantizableLinear)
def _get_subclass_inserter(cls, enable_parametrization=False, **kwargs):
    """
    Returns a function which replaces the weight of a linear module with a
    tensor subclass built from it.

    The new weight is `getattr(cls, method)(mod.weight, **kwargs)` wrapped in a
    frozen `torch.nn.Parameter`, where `method` defaults to "from_float". If
    parametrization is enabled, the new weight is then flattened with
    `__tensor_flatten__`, and `getattr(cls, constructor)(*ctx)` is registered
    as a parametrization of `weight` via torch.nn.utils.parametrize.

    Args:
        cls: Tensor subclass providing the conversion classmethod(s).
        enable_parametrization (bool): Register a parametrization instead of
            only assigning the subclass directly.
        kwargs (Any): Keyword args for the conversion method. Two keys are
            consumed here and not forwarded:
            `constructor` (default "subclass_constructor") - attribute of `cls`
            used to build the parametrization;
            `method` (default "from_float") - name of the conversion method.

    Returns:
        Callable[[torch.nn.Module], torch.nn.Module]: function that modifies the
        module in place and returns it.
    """
    constructor = kwargs.pop("constructor", "subclass_constructor")
    from_float = kwargs.pop("method", "from_float")

    def insert_subclass(lin):
        # Honor `method` in both branches. The parametrized path previously
        # always called `cls.from_float` and silently ignored it.
        converted = getattr(cls, from_float)(lin.weight, **kwargs)
        lin.weight = torch.nn.Parameter(converted, requires_grad=False)
        if enable_parametrization:
            _, args = lin.weight.__tensor_flatten__()
            parametrize.register_parametrization(
                lin, "weight", getattr(cls, constructor)(*args)
            )
        return lin

    return insert_subclass
def swap_conv2d_1x1_to_linear(model, filter_fn=None):
    """
    Changes all conv2d 1x1 modules to equivalent linear modules so that they can then be quantized.

    Each selected conv is replaced by a wrapper that permutes the (N, C, H, W)
    input to (N, H, W, C), applies a `torch.nn.Linear` that shares the conv's
    weight and bias, and permutes the result back to (N, C_out, H, W).

    Args:
        model (torch.nn.Module): model to modify in place.
        filter_fn (Callable[[torch.nn.Module, str], bool], optional): selects the
            modules to swap. Defaults to `torch.nn.Conv2d` modules that are
            pointwise: 1x1 kernel, stride 1, no padding and groups == 1. Only
            then is the linear rewrite exact.
    """

    class PermuteSandwich(torch.nn.Module):
        def __init__(self, mod):
            super().__init__()
            self.mod = mod

        def forward(self, *args):
            # channels-last so the linear layer mixes channels, then back to channels-first
            return self.mod(args[0].permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

    def replace_conv2d_1x1(conv):
        assert conv.kernel_size == (1, 1)
        # Only allocate a bias when the conv has one. The condition used to be
        # inverted; the reassignment of `lin.bias` below merely masked it.
        lin = torch.nn.Linear(
            conv.in_channels, conv.out_channels, bias=conv.bias is not None
        )
        # conv weight is (out, in, 1, 1); linear expects (out, in)
        lin.weight = torch.nn.Parameter(conv.weight.squeeze(-1, -2))
        lin.bias = conv.bias
        return PermuteSandwich(lin)

    def _is_pointwise_conv2d(mod, *args):
        # A 1x1 conv equals a per-pixel linear layer only when it neither strides,
        # pads nor groups its input; otherwise the swap changes the output shape
        # or fails on a weight shape mismatch.
        return (
            isinstance(mod, torch.nn.Conv2d)
            and mod.kernel_size == (1, 1)
            and mod.stride == (1, 1)
            and mod.groups == 1
            and mod.padding in ((0, 0), "valid", "same")
        )

    if filter_fn is None:
        filter_fn = _is_pointwise_conv2d

    _replace_with_custom_fn_if_matches_filter(
        model, replace_conv2d_1x1, filter_fn=filter_fn
    )
def insert_observers_(
    model: nn.Module,
    input_observer: Optional[AffineQuantizedObserverBase],
    weight_observer: Optional[AffineQuantizedObserverBase],
    *,
    filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None,
):
    """
    Wraps the weights of selected linear modules in `model` with
    LinearActivationWeightObservedTensor, in place.

    The wrapped weight enables observation of both input and weight tensors
    during forward passes. It is re-wrapped as an nn.Parameter, keeping the
    original weight's `requires_grad`, to stay compatible with PyTorch's module
    system. `model` itself is checked by the filter first, so passing a bare
    `nn.Linear` wraps its weight directly.

    Example::

        import torch
        import torch.nn as nn
        from torchao.quantization import PerTensor
        from torchao.quantization.observer import AffineQuantizedMinMaxObserver
        from torchao.quantization.quant_api import insert_observers_
        from torchao.quantization.quant_primitives import (
            MappingType,
            ZeroPointDomain,
        )

        # Create observers
        input_observer = AffineQuantizedMinMaxObserver(
            MappingType.SYMMETRIC,
            torch.float8_e4m3fn,
            granularity_type=PerTensor(),
            eps=torch.finfo(torch.float32).eps,
            scale_dtype=torch.float,
            zero_point_dtype=torch.int,
            zero_point_domain=ZeroPointDomain.NONE,
        )

        # Create a linear module
        linear_module = nn.Linear(10, 20)

        # Convert the linear module's weight to an observed tensor
        insert_observers_(linear_module, input_observer, weight_observer=None)

        # The linear_module can now be used as usual, with observers calculating statistics
        output = linear_module(torch.randn(10, 10))

        # Get the scale and zero point of the input observer
        scale, zero_point = linear_module.weight.input_observer.calculate_qparams()

    Args:
        model (nn.Module): The nn.Module to convert (modified in place).
        input_observer (Optional[AffineQuantizedObserverBase]): Observer for input tensor.
        weight_observer (Optional[AffineQuantizedObserverBase]): Observer for weight tensor.
        filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): Filter function to select which modules to convert.
            If not provided, all linear modules will be converted. This function should take a module and its fully qualified name.

    Returns:
        None. `model` is modified in place.
    """

    def convert_to_linear_observer(linear_module: nn.Linear):
        # Wrap the weight with LinearActivationWeightObservedTensor and then with nn.Parameter
        linear_module.weight = nn.Parameter(
            LinearActivationWeightObservedTensor.from_float(
                linear_module.weight,
                input_observer=input_observer,
                weight_observer=weight_observer,
            ),
            requires_grad=linear_module.weight.requires_grad,
        )
        return linear_module

    _replace_with_custom_fn_if_matches_filter(
        model,
        convert_to_linear_observer,
        _is_linear if filter_fn is None else filter_fn,
    )
def _embedding_extra_repr(self):
    # Meant to be bound as `extra_repr` on a module with an (at least) 2-d weight,
    # e.g. an embedding table: reports its size and how the weight is quantized.
    num_embeddings, embedding_dim = self.weight.shape[0], self.weight.shape[1]
    weight_kind = _quantization_type(self.weight)
    return f"num_embeddings={num_embeddings}, embedding_dim={embedding_dim}, weight={weight_kind}"
def _get_linear_subclass_inserter(
    constructor, *, allow_requires_grad=False, propagate_bias=False, **kwargs
):
    """Helper function to apply the constructor that quantizes the weight Tensor (with additional kwargs)
    to the weight of linear module

    Args:
        constructor: callable invoked as `constructor(weight, **kwargs)` that
            returns the new (e.g. quantized) weight tensor.
        allow_requires_grad (bool): if True, the new weight Parameter keeps the
            original weight's `requires_grad`; otherwise it is frozen.
        propagate_bias (bool): if True, the module's bias is passed to
            `constructor` as the `bias` keyword argument.
        kwargs: extra keyword arguments forwarded to `constructor`.

    Returns:
        Callable[[torch.nn.Module], torch.nn.Module]: function that replaces
        `lin.weight` in place, overrides `lin.extra_repr`, and returns `lin`.
    """

    def insert_subclass(lin):
        requires_grad = allow_requires_grad and lin.weight.requires_grad
        # Build per-call kwargs instead of mutating the captured dict. Mutating it
        # would keep the last module's bias alive and make the inserter non-reentrant.
        call_kwargs = {**kwargs, "bias": lin.bias} if propagate_bias else kwargs
        lin.weight = torch.nn.Parameter(
            constructor(lin.weight, **call_kwargs), requires_grad=requires_grad
        )
        lin.extra_repr = types.MethodType(_linear_extra_repr, lin)
        return lin

    return insert_subclass
def quantize_(
    model: torch.nn.Module,
    config: AOBaseConfig,
    filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = _is_linear,
    device: Optional[torch.types.Device] = None,
):
    """Convert the weight of linear modules in the model with `config`, model is modified inplace

    Args:
        model (torch.nn.Module): input model
        config (AOBaseConfig): a workflow configuration object.
        filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and
            fully qualified name of the module, returns True if we want to run `config` on the weight of the module.
            `None` means `_is_linear`. An `FqnToConfig` selects modules by name instead, so with it only the
            default (`_is_linear`) or `None` is accepted.
        device (device, optional): Device to move modules to. For non-`FqnToConfig` configs each selected module
            is moved before `config` is applied (set this to `"cuda"` to speed up quantization) and the model ends
            up on `device`. For `FqnToConfig` only the transformed modules are moved, after their transform.
            Defaults to None (do not change device).

    Raises:
        ValueError: if a custom `filter_fn` is combined with an `FqnToConfig`.
        AssertionError: if `config` is not an `AOBaseConfig` (e.g. a legacy callable).

    Example::

        import torch
        import torch.nn as nn
        from torchao import quantize_

        # quantize with some predefined `config` method that corresponds to
        # optimized execution paths or kernels (e.g. int4 tinygemm kernel)
        # also customizable with arguments
        # currently options are
        # Int8DynamicActivationInt8WeightConfig (optimized with int8 mm op and torch.compile)
        # Int4WeightOnlyConfig (optimized with int4 tinygemm kernel and torch.compile)
        # Int8WeightOnlyConfig (optimized with int8 mm op and torch.compile)
        from torchao.quantization.quant_api import Int4WeightOnlyConfig

        m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
        quantize_(m, Int4WeightOnlyConfig(group_size=32))
    """
    torch._C._log_api_usage_once("torchao.quantization.quantize_")

    if isinstance(config, FqnToConfig):
        # FqnToConfig picks modules by fully qualified name and never consults
        # filter_fn. Accept the signature default as "no custom filter"; before,
        # `quantize_(model, FqnToConfig(...))` always raised here.
        if filter_fn is not None and filter_fn is not _is_linear:
            raise ValueError(
                "Custom filter_fn and FqnToConfig were both specified. Only filter_fn=None is supported when FqnToConfig is specified."
            )
        # Snapshot the module tree once; swaps below are installed through the parents.
        named_modules = dict(model.named_modules())
        for module_fqn, module in named_modules.items():
            if (
                fqn_matches_fqn_config(module_fqn, config)
                or _module_param_matches_fqn_config(module, module_fqn, config)
                or ("_default" in config.fqn_to_config and _is_linear(module))
            ):
                replacement = _fqn_to_config_handler(module, module_fqn, config)
                if device is not None:
                    replacement = replacement.to(device=device)
                # handle module swap (a replacement for the root module is not installed)
                if replacement is not module and module_fqn != "":
                    parent_fqn, _, child_name = module_fqn.rpartition(".")
                    setattr(named_modules[parent_fqn], child_name, replacement)
    elif isinstance(config, AOBaseConfig):
        filter_fn = _is_linear if filter_fn is None else filter_fn
        handler = _QUANTIZE_CONFIG_HANDLER[type(config)]
        # for each linear in the model, apply the transform if filtering passes
        _replace_with_custom_fn_if_matches_filter(
            model,
            handler,
            filter_fn,
            device=device,
            extra_args=(config,),
        )
    else:
        raise AssertionError(
            """Passing a generic Callable to `quantize_` is no longer recommended and will be deprecated at a later release. Please see https://github.com/pytorch/ao/issues/1690 for instructions on how to pass in workflow configuration instead."""
        )
@dataclass
class Int8DynamicActivationIntxWeightConfig(AOBaseConfig):
    """
    Configuration for dynamically quantizing activations to torch.int8 and weights to torch.intx, with 1 <= x <= 8.
    More specifically, activations are dynamically quantized to 8-bits at a per-token granularity with scales/zeros.
    Weights are quantized with scales/zeros in a groupwise or channelwise manner using the number of bits specified by weight_dtype.

    args:
        `weight_dtype`: The dtype to use for weight quantization. Must be torch.intx, where 1 <= x <= 8.
        `weight_granularity`: The granularity to use for weight quantization. Must be PerGroup or PerAxis(axis=0).
        `weight_mapping_type`: The type of mapping to use for the weight quantization.
            Must be one of MappingType.ASYMMETRIC, MappingType.SYMMETRIC or MappingType.SYMMETRIC_NO_CLIPPING_ERR.
        `weight_scale_dtype`: The dtype to use for the weight scale. If None (or equal to the weight's dtype),
            the scale produced by the quantization routine is kept as is.
        `act_mapping_type`: The type of mapping to use for the activation quantization.
            Must be one of MappingType.ASYMMETRIC or MappingType.SYMMETRIC.
            NOTE: the version 2 module transform currently only accepts MappingType.ASYMMETRIC.
        `intx_packing_format`: The format to use for the packed weight tensor (version 2 only).
            - unpacked_to_int8: this format is the default and is intended for export applications like ExecuTorch.
            - opaque_torchao_auto: this format is optimized for CPU performance.
        `intx_choose_qparams_algorithm`: The algorithm to use for choosing the quantization parameters.
        `version`: version of the config to use; only version 2 is supported by the module transform.

    Example:

    .. literalinclude:: ../../examples/inference/int8_dynamic_activation_intx_weight.py
       :language: python
    """

    weight_dtype: torch.dtype = torch.int8
    weight_granularity: Granularity = PerGroup(32)
    weight_mapping_type: MappingType = MappingType.SYMMETRIC
    weight_scale_dtype: Optional[torch.dtype] = None
    act_mapping_type: MappingType = MappingType.ASYMMETRIC
    intx_packing_format: IntxPackingFormat = IntxPackingFormat.UNPACKED_TO_INT8
    intx_choose_qparams_algorithm: IntxChooseQParamsAlgorithm = (
        IntxChooseQParamsAlgorithm.AFFINE
    )
    version: int = 2

    def __post_init__(self):
        torch._C._log_api_usage_once(
            "torchao.quantization.Int8DynamicActivationIntxWeightConfig"
        )
        assert self.weight_dtype in [getattr(torch, f"int{b}") for b in range(1, 9)], (
            f"weight_dtype must be torch.intx, where 1 <= x <= 8, but got {self.weight_dtype}"
        )
        assert isinstance(self.weight_granularity, (PerAxis, PerGroup)), (
            f"weight_granularity must be PerAxis or PerGroup, but got {self.weight_granularity}"
        )
        if isinstance(self.weight_granularity, PerAxis):
            assert self.weight_granularity.axis == 0, (
                f"axis must be 0, but got {self.weight_granularity.axis}"
            )
        assert self.weight_mapping_type in [
            MappingType.ASYMMETRIC,
            MappingType.SYMMETRIC,
            MappingType.SYMMETRIC_NO_CLIPPING_ERR,
        ], (
            f"weight_mapping_type must be MappingType.ASYMMETRIC or MappingType.SYMMETRIC or MappingType.SYMMETRIC_NO_CLIPPING_ERR, but got {self.weight_mapping_type}"
        )
        assert self.act_mapping_type in [
            MappingType.ASYMMETRIC,
            MappingType.SYMMETRIC,
        ], (
            f"act_mapping_type must be MappingType.ASYMMETRIC or MappingType.SYMMETRIC, but got {self.act_mapping_type}"
        )
def _int8_dynamic_activation_intx_weight_quantize_tensor(
    weight,
    bias,
    config,
    *,
    custom_scale: Optional[torch.Tensor] = None,
    custom_zero_point: Optional[torch.Tensor] = None,
):
    """Quantize a 2-d `weight` (and, for opaque formats, pack `bias`) per an
    Int8DynamicActivationIntxWeightConfig.

    Args:
        weight: 2-d high-precision weight tensor.
        bias: bias tensor or None.
        config: the config to apply (version 2 only, ASYMMETRIC activations only).
        custom_scale: optional precomputed scale forwarded to
            `IntxUnpackedToInt8Tensor.from_hp`.
        custom_zero_point: optional precomputed zero point; int32 values are cast
            to int8 first.

    Returns:
        (new_weight, new_bias). For opaque packing formats `new_bias` is None
        because the bias is packed together with the weight.
    """
    # read every config field up front
    target_dtype = config.weight_dtype
    granularity = config.weight_granularity
    weight_mapping = config.weight_mapping_type
    scale_dtype = config.weight_scale_dtype
    act_mapping = config.act_mapping_type
    packing_format = config.intx_packing_format
    qparams_algorithm = config.intx_choose_qparams_algorithm

    assert weight.dim() == 2, (
        f"Int8DynamicActivationIntxWeightConfig only works for 2-d Tensor, got: {weight.dim()}"
    )

    if isinstance(granularity, PerGroup):
        group_size = granularity.group_size
    elif isinstance(granularity, PerAxis):
        assert granularity.axis == 0, (
            f"axis must be 0 with PerAxis, but got {granularity.axis}"
        )
        # channelwise: a single group spans the whole last dimension
        group_size = weight.shape[-1]
    else:
        raise ValueError(
            f"weight_granularity must be PerGroup or PerAxis, got {granularity}"
        )
    block_size = (1, group_size)

    assert config.version == 2
    assert act_mapping == MappingType.ASYMMETRIC

    opaque_formats = (
        IntxPackingFormat.OPAQUE_ATEN_KLEIDIAI,
        IntxPackingFormat.OPAQUE_TORCHAO_AUTO,
        IntxPackingFormat.OPAQUE_TORCHAO_KLEIDIAI,
        IntxPackingFormat.OPAQUE_TORCHAO_LOWBIT,
    )
    is_opaque = packing_format in opaque_formats
    assert packing_format == IntxPackingFormat.UNPACKED_TO_INT8 or is_opaque, (
        f"Unsupported packing format: {packing_format}"
    )

    if custom_zero_point is not None and custom_zero_point.dtype == torch.int32:
        custom_zero_point = custom_zero_point.to(torch.int8)

    unpacked = IntxUnpackedToInt8Tensor.from_hp(
        weight,
        block_size,
        target_dtype,
        mapping_type=weight_mapping,
        activation_quantization="int8_asym_per_token",
        intx_choose_qparams_algorithm=qparams_algorithm,
        custom_scale=custom_scale,
        custom_zero_point=custom_zero_point,
    )
    if scale_dtype is not None and scale_dtype != weight.dtype:
        _adjust_scale_dtype_in_intx_unpacked_tensor(unpacked, weight, scale_dtype)

    if not is_opaque:
        return unpacked, bias

    # Create packed tensor; the bias travels inside it from now on
    packed = IntxOpaqueTensor.from_intx_unpacked_to_int8_tensor(
        unpacked, bias=bias, intx_packing_format=packing_format
    )
    return packed, None
@register_quantize_module_handler(Int8DynamicActivationIntxWeightConfig)
def _int8_dynamic_activation_intx_weight_transform(
    module: torch.nn.Module,
    config: Int8DynamicActivationIntxWeightConfig,
    *,
    parameter_name: str = "weight",
    custom_scale: Optional[torch.Tensor] = None,
    custom_zero_point: Optional[torch.Tensor] = None,
) -> torch.nn.Module:
    """Module handler for Int8DynamicActivationIntxWeightConfig.

    Replaces `module.<parameter_name>` with its quantized version as a frozen
    Parameter. When the quantization returns no separate bias (opaque packing
    formats pack it into the weight), `module.bias` is set to None.

    Args:
        module: module owning the parameter to quantize.
        config: the quantization config.
        parameter_name: name of the 2-d parameter to quantize. Defaults to "weight".
        custom_scale / custom_zero_point: optional precomputed quantization
            parameters forwarded to the tensor quantization routine.

    Returns:
        The same module, modified in place.
    """
    # same precondition check as the other module handlers in this file
    assert hasattr(module, parameter_name), (
        f"applying int8 dynamic activation intx weight quant requires module to have {parameter_name} attribute"
        + f" but {module} does not have one"
    )
    # modules without a `bias` attribute are treated as bias-free instead of
    # raising AttributeError
    bias = getattr(module, "bias", None)
    new_weight, new_bias = _int8_dynamic_activation_intx_weight_quantize_tensor(
        getattr(module, parameter_name),
        bias,
        config,
        custom_scale=custom_scale,
        custom_zero_point=custom_zero_point,
    )
    setattr(
        module,
        parameter_name,
        torch.nn.Parameter(new_weight, requires_grad=False),
    )
    if new_bias is None and hasattr(module, "bias"):
        module.bias = None
    module.extra_repr = types.MethodType(
        partial(
            _module_extra_repr,
            original_extra_repr=module.extra_repr,
            parameter_name=parameter_name,
        ),
        module,
    )
    return module
@dataclass
class Int4WeightOnlyConfig(AOBaseConfig):
    """
    Configuration for int4 weight only quantization, only groupwise quantization is supported.

    Args:
        `group_size`: parameter for quantization, controls the granularity of quantization, smaller
            size is more fine grained, choices are [256, 128, 64, 32]. Weights whose last dimension
            is not divisible by `group_size` are left unquantized.
        `int4_packing_format`: the packing format for int4 tensor
        `int4_choose_qparams_algorithm`: variants of choose qparams algorithm to use for int4,
            currently support TINYGEMM ("tinygemm") and HQQ ("hqq"). HQQ requires
            `int4_packing_format=Int4PackingFormat.TILE_PACKED_TO_4D`.
        `set_inductor_config`: if True, adjusts `torchinductor` settings to recommended values.
        `int4_tile_packed_ntile`: ntile size for the TILE_PACKED_TO_4D format, must be 8 or 16.
            Defaults to 8 (the CUDA value). It is not chosen per platform automatically, so set it
            to 16 explicitly when targeting ROCm.
        `version`: only version 2 is supported.

    Example:

    .. literalinclude:: ../../examples/inference/int4_weight_only.py
       :language: python
    """

    group_size: int = 128
    set_inductor_config: bool = True
    int4_packing_format: Int4PackingFormat = Int4PackingFormat.PLAIN
    int4_choose_qparams_algorithm: Int4ChooseQParamsAlgorithm = (
        Int4ChooseQParamsAlgorithm.TINYGEMM
    )
    int4_tile_packed_ntile: int = 8
    version: int = 2

    def __post_init__(self):
        assert self.int4_tile_packed_ntile in [8, 16], (
            "int4_tile_packed_ntile must be either 8 or 16"
        )
        torch._C._log_api_usage_once("torchao.quantization.Int4WeightOnlyConfig")
def _int4_weight_only_quantize_tensor(weight, config):
    """Quantize `weight` to int4 according to an Int4WeightOnlyConfig.

    Returns `weight` unchanged (logging at INFO level) when its last dimension
    is not divisible by `config.group_size`. Otherwise it returns the tensor
    subclass selected by `config.int4_packing_format`.

    Raises:
        AssertionError: if `config.version != 2`, or HQQ is requested with a
            packing format other than TILE_PACKED_TO_4D.
        ValueError: for an unsupported packing format.
    """
    # TODO(future PR): perhaps move this logic to a different file, to keep the API
    # file clean of implementation details
    group_size = config.group_size
    qparams_algorithm = config.int4_choose_qparams_algorithm
    packing_format = config.int4_packing_format
    ntile = config.int4_tile_packed_ntile

    if weight.shape[-1] % group_size:
        logger.info(
            f"Skipping quantizing weight with int4 weight only quantization because the shape of weight {weight.shape} is not compatible with group_size {group_size}"
        )
        return weight

    # groups of `group_size` along the last dim, size 1 along every leading dim
    block_size = [1] * (weight.ndim - 1) + [group_size]
    assert config.version == 2

    if qparams_algorithm == Int4ChooseQParamsAlgorithm.HQQ:
        assert packing_format == Int4PackingFormat.TILE_PACKED_TO_4D, (
            f"Int4ChooseQParamsAlgorithm.HQQ is not supported by packing format {packing_format}, "
            f"it's only supported by Int4PackingFormat.TILE_PACKED_TO_4D currently"
        )

    if packing_format == Int4PackingFormat.PRESHUFFLED:
        return Int4PreshuffledTensor.from_hp(
            weight, block_size, activation_dtype=torch.bfloat16
        )
    if packing_format == Int4PackingFormat.PLAIN:
        return Int4Tensor.from_hp(weight, block_size)
    if packing_format == Int4PackingFormat.PLAIN_INT32:
        return Int4PlainInt32Tensor.from_hp(weight, block_size)
    if packing_format == Int4PackingFormat.TILE_PACKED_TO_4D:
        return Int4TilePackedTo4dTensor.from_hp(
            weight,
            block_size,
            int4_choose_qparams_algorithm=qparams_algorithm,
            ntile_size=ntile,
        )
    raise ValueError(f"Unsupported int4 packing format: {packing_format}")
@register_quantize_module_handler(Int4WeightOnlyConfig)
def _int4_weight_only_transform(
    module: torch.nn.Module,
    config: Int4WeightOnlyConfig,
    *,
    parameter_name: str = "weight",
) -> torch.nn.Module:
    """Module handler for Int4WeightOnlyConfig.

    Replaces `module.<parameter_name>` with its int4-quantized version as a
    frozen Parameter and wraps `extra_repr` with `_module_extra_repr`. If the
    weight is left unquantized (incompatible group size), it is still re-wrapped
    as a frozen Parameter. Returns the same module, modified in place.
    """
    if config.set_inductor_config:
        torchao.quantization.utils.recommended_inductor_config_setter()

    assert hasattr(module, parameter_name), (
        f"applying int4 weight only quant requires module to have {parameter_name} attribute"
        + f" but {module} does not have one"
    )

    source = getattr(module, parameter_name)
    quantized = _int4_weight_only_quantize_tensor(source, config)
    setattr(module, parameter_name, torch.nn.Parameter(quantized, requires_grad=False))

    repr_fn = partial(
        _module_extra_repr,
        original_extra_repr=module.extra_repr,
        parameter_name=parameter_name,
    )
    module.extra_repr = types.MethodType(repr_fn, module)
    return module
@dataclass
class Float8DynamicActivationInt4WeightConfig(AOBaseConfig):
    """Configuration for applying float8 dynamic per row quantization and int4
    per group weight quantization to linear layers.

    Only group_size 128 is supported right now: the underlying kernel only
    supports group sizes of 128 and above, and there is no benefit in making it bigger.

    Args:
        `int4_packing_format`: how the weight is packed, supported values are
            "preshuffled" (default) and "plain"

    Example:

    .. literalinclude:: ../../examples/inference/float8_dynamic_activation_int4_weight.py
       :language: python
    """

    int4_packing_format: Int4PackingFormat = "preshuffled"
@register_quantize_module_handler(Float8DynamicActivationInt4WeightConfig)
def _float8_dynamic_activation_int4_weight_transform(
    module: torch.nn.Module,
    config: Float8DynamicActivationInt4WeightConfig,
    *,
    parameter_name: str = "weight",
) -> torch.nn.Module:
    """Module handler for Float8DynamicActivationInt4WeightConfig.

    Quantizes `module.<parameter_name>` to int4 with a fixed group size of 128
    along the last dimension, with torch.float8_e4m3fn as the activation dtype.
    The result is stored back as a frozen Parameter.

    Raises:
        AssertionError: if the module lacks `parameter_name` or the packing
            format is not "preshuffled" or "plain".
    """
    assert hasattr(module, parameter_name), (
        f"applying float8 dynamic activation int4 weight quant requires module to have {parameter_name} attribute"
        + f" but {module} does not have one"
    )
    int4_packing_format = config.int4_packing_format
    assert int4_packing_format in (
        "preshuffled",
        "plain",
    ), (
        f"only preshuffled and plain int4_packing_format supported right now, got: {int4_packing_format}"
    )

    # Quantize the parameter that was asked for. Previously `module.weight` was
    # used here, which was wrong for any non-default `parameter_name`.
    weight = getattr(module, parameter_name)
    group_size = 128
    block_size = [1] * (weight.ndim - 1) + [group_size]
    if int4_packing_format == "preshuffled":
        new_weight = Int4PreshuffledTensor.from_hp(
            weight,
            block_size,
            activation_dtype=torch.float8_e4m3fn,
        )
    else:
        # plain format
        new_weight = Int4Tensor.from_hp(
            weight,
            block_size,
            activation_dtype=torch.float8_e4m3fn,
        )
    setattr(
        module,
        parameter_name,
        torch.nn.Parameter(new_weight, requires_grad=False),
    )
    module.extra_repr = types.MethodType(
        partial(
            _module_extra_repr,
            original_extra_repr=module.extra_repr,
            parameter_name=parameter_name,
        ),
        module,
    )
    return module
@dataclass
class Int8WeightOnlyConfig(AOBaseConfig):
    """
    Configuration for applying int8 weight-only symmetric quantization to linear layers.

    Args:
        granularity - Quantization granularity: PerRow() (default) for per-channel
            quantization, PerTensor() for per-tensor quantization,
            PerGroup(group_size) for per-group quantization.
        group_size - Only None is accepted. It was used by the removed version 1;
            use granularity=PerGroup(group_size) instead.
        set_inductor_config: bool = True - If True, adjusts `torchinductor` settings to recommended values
            for better performance with this quantization scheme.
        version - must be 2; version 1 has been removed and raises ValueError.

    Example:

    .. literalinclude:: ../../examples/inference/int8_weight_only.py
       :language: python
    """

    group_size: Optional[int] = None
    granularity: Optional[Granularity] = PerRow()
    set_inductor_config: bool = True
    version: int = 2

    def __post_init__(self):
        torch._C._log_api_usage_once("torchao.quantization.Int8WeightOnlyConfig")
        if self.version == 1:
            raise ValueError(
                "version 1 of Int8WeightOnlyConfig has been removed, please use version 2, "
                "see https://github.com/pytorch/ao/issues/2752 for more details"
            )
        assert self.group_size is None, (
            f"Only support version 2 with group_size=None, got {self.group_size}. "
            f"Use granularity=PerGroup({self.group_size}) instead."
        )
        assert isinstance(self.granularity, (PerTensor, PerRow, PerGroup)), (
            f"granularity must be PerTensor, PerRow, or PerGroup, but got {self.granularity}"
        )
def _int8_weight_only_quantize_tensor(weight, config):
    """Return `weight` as an Int8Tensor quantized with `config.granularity` (version 2 configs only)."""
    version = config.version
    assert version == 2, f"Unexpected version: {version}"
    return Int8Tensor.from_hp(weight, granularity=config.granularity)
@register_quantize_module_handler(Int8WeightOnlyConfig)
def _int8_weight_only_transform(
    module: torch.nn.Module,
    config: Int8WeightOnlyConfig,
    *,
    parameter_name: str = "weight",
):
    """Module handler for Int8WeightOnlyConfig.

    Replaces `module.<parameter_name>` with its int8-quantized version as a
    frozen Parameter and wraps `extra_repr` with `_module_extra_repr`.
    Returns the same module, modified in place.
    """
    if config.set_inductor_config:
        torchao.quantization.utils.recommended_inductor_config_setter()

    # f-strings so the message names the missing attribute and the module
    # (the placeholders were previously printed literally)
    assert hasattr(module, parameter_name), (
        f"applying int8 weight only quant requires module to have {parameter_name} attribute"
        + f" but {module} does not have one"
    )
    quantized_tensor = _int8_weight_only_quantize_tensor(
        getattr(module, parameter_name), config
    )
    setattr(
        module,
        parameter_name,
        torch.nn.Parameter(quantized_tensor, requires_grad=False),
    )
    module.extra_repr = types.MethodType(
        partial(
            _module_extra_repr,
            original_extra_repr=module.extra_repr,
            parameter_name=parameter_name,
        ),
        module,
    )
    return module
def _validate_granularity_int8(
    act_granularity: Granularity,
    weight_granularity: Granularity,
) -> None:
    """Reject activation/weight granularities the int8 dynamic path cannot handle.

    Both granularities must be ``PerTensor`` or ``PerRow``. An activation
    ``PerRow`` must additionally use ``dim=-1`` (per-token). Checks run in the
    order activation type, activation dim, weight type, and the first failure
    raises ``ValueError``.
    """
    allowed_types = (PerTensor, PerRow)

    def _require_allowed(granularity, role):
        # Shared type check for the activation and weight granularities.
        if isinstance(granularity, allowed_types):
            return
        raise ValueError(
            f"Unsupported {role} granularity type: {type(granularity)}. "
            "Only PerTensor and PerRow are supported."
        )

    _require_allowed(act_granularity, "activation")

    # Activations are quantized per token, so PerRow must reduce over the last dim.
    if isinstance(act_granularity, PerRow):
        act_dim = act_granularity.dim
        if act_dim != -1:
            raise ValueError(
                "Only PerRow(dim=-1) is supported for activation quantization, "
                f"got PerRow(dim={act_dim}). "
                "Per-feature activation quantization is not supported due to slicing limitations."
            )

    _require_allowed(weight_granularity, "weight")
@dataclass
class Int8DynamicActivationInt8WeightConfig(AOBaseConfig):
    """
    Configuration for applying int8 dynamic activation quantization and int8 weight
    quantization to linear layers.

    Args:
        act_mapping_type: Mapping type for activation quantization. Must be
            ``MappingType.SYMMETRIC`` (the default) or ``MappingType.ASYMMETRIC``;
            any other value, including ``None``, fails validation.
        weight_only_decode: Not read by any code in this part of the file.
            NOTE(review): its effect is defined elsewhere — confirm against its consumer.
        granularity: The granularity for quantization (default ``PerRow()``). It can be
            a single granularity, applied to both activations and weights. It can also
            be a tuple / list of two granularities, the first for activations and the
            second for weights. If None, defaults to PerRow for both. Only PerTensor
            and PerRow are supported, and an activation PerRow must use ``dim=-1``
            (per-token).
        set_inductor_config: If True, adjusts `torchinductor` settings to recommended
            values for better performance with this quantization scheme.
        version: The version of the config. Version 1 has been removed and raises
            ``ValueError``.
        reduce_range: If True, both activation and weight int8 quantization use the
            reduced range [_REDUCED_QUANT_MIN, _REDUCED_QUANT_MAX] instead of the full
            range [_FULL_QUANT_MIN, _FULL_QUANT_MAX]. This reduces overflow risk on
            platforms without VNNI instructions. Must be True or False; ``None`` is
            rejected.

    Example:
        .. literalinclude:: ../../examples/inference/int8_dynamic_activation_int8_weight.py
           :language: python
    """

    act_mapping_type: Optional[MappingType] = MappingType.SYMMETRIC
    weight_only_decode: bool = False
    granularity: Optional[
        Union[Granularity, Tuple[Granularity, Granularity], list[Granularity]]
    ] = PerRow()
    set_inductor_config: bool = True
    version: int = 2
    reduce_range: Optional[bool] = False

    def __post_init__(self):
        torch._C._log_api_usage_once(
            "torchao.quantization.Int8DynamicActivationInt8WeightConfig"
        )
        if self.version == 1:
            raise ValueError(
                "version 1 of Int8DynamicActivationInt8WeightConfig has been removed, please use version 2, "
                "see https://github.com/pytorch/ao/issues/2752 for more details"
            )
        # Split `granularity` into (activation, weight) and validate both up front,
        # so a bad config fails at construction rather than at quantization time.
        act_granularity, weight_granularity = Int8Tensor._normalize_granularity(
            self.granularity
        )
        _validate_granularity_int8(act_granularity, weight_granularity)
        assert self.act_mapping_type in (
            MappingType.SYMMETRIC,
            MappingType.ASYMMETRIC,
        ), (
            "Int8DynamicActivationInt8WeightConfig requires "
            "`act_mapping_type` in (MappingType.SYMMETRIC, "
            "MappingType.ASYMMETRIC). "
            "Please set it to MappingType.SYMMETRIC or "
            "MappingType.ASYMMETRIC."
        )
        assert self.reduce_range in (True, False), (
            "`reduce_range` must be True or False"
        )
def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config):
    """Quantize ``weight`` to an ``Int8Tensor`` that carries activation-quant settings.

    The config's granularity is split into an (activation, weight) pair. The weight
    is quantized at the weight granularity. The activation granularity, the
    activation mapping type and ``reduce_range`` are stored as
    ``QuantizeTensorToInt8Kwargs`` on the result. ``reduce_range`` is applied to
    the weight as well.
    """
    assert config.version == 2, f"Unexpected version: {config.version}"
    granularity_pair = Int8Tensor._normalize_granularity(config.granularity)
    act_granularity, weight_granularity = granularity_pair
    use_reduced_range = config.reduce_range

    activation_kwargs = QuantizeTensorToInt8Kwargs(
        granularity=act_granularity,
        mapping_type=config.act_mapping_type,
        reduce_range=use_reduced_range,
    )
    return Int8Tensor.from_hp(
        weight,
        granularity=weight_granularity,
        act_quant_kwargs=activation_kwargs,
        reduce_range=use_reduced_range,
    )