Skip to content

Commit 712c565

Browse files
committed
Update on "Remove WeightTensorWithLinearActivationScaleMetadata and related code"
Summary: Delete `linear_activation_scale.py` which defined `WeightTensorWithLinearActivationScaleMetadata` and its helper `to_weight_tensor_with_linear_activation_scale_metadata`. Remove the import and `__all__` entry from `torchao/quantization/__init__.py`. Test Plan: python -c "import torchao.quantization" [ghstack-poisoned]
2 parents 0e1ceff + 30a1614 commit 712c565

File tree

3 files changed

+94
-6
lines changed

3 files changed

+94
-6
lines changed

.github/scripts/ci_test_xpu.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ python3 -c "import torch; import torchao; print(f'Torch version: {torch.__versio
1414

1515
pip install pytest expecttest parameterized accelerate hf_transfer 'modelscope!=1.15.0' transformers tabulate fire
1616

17-
pytest -v -s torchao/test/quantization/pt2e/ \
17+
pytest -v -s --ignore=torchao/test/quantization/pt2e/test_x86inductor_fusion.py \
18+
torchao/test/quantization/pt2e/ \
1819
torchao/test/quantization/*.py \
1920
torchao/test/dtypes/ \
2021
torchao/test/float8/ \

test/quantization/pt2e/test_quantize_pt2e.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3330,6 +3330,74 @@ def has_inplace_ops(graph_module: torch.fx.GraphModule) -> bool:
33303330
result = m(*example_inputs)
33313331
self.assertIsNotNone(result)
33323332

3333+
def test_quantize_in_place_index_put(self):
    """In-place ``index_put_`` on a named buffer must survive convert_pt2e.

    Builds a quantizer that annotates ``aten.index_put_`` (sharing one
    qspec between the destination buffer, the written values, and the
    output), exports/prepares/converts a tiny module, and checks that the
    mutated buffer is NOT constant-folded away (it would otherwise be
    renamed ``_frozen_param0``).
    """

    class IndexPutQuantizer(Quantizer):
        def __init__(self) -> None:
            super().__init__()
            # Symmetric int8 per-tensor spec used for every annotated edge.
            self.qspec = QuantizationSpec(
                dtype=torch.int8,
                observer_or_fake_quant_ctr=observer.default_observer,
                quant_min=-128,
                quant_max=127,
                qscheme=torch.per_tensor_symmetric,
            )

        def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
            for node in model.graph.nodes:
                if node.op != "call_function":
                    continue
                if node.target != torch.ops.aten.index_put_.default:
                    continue

                # index_put_(dst, indices, values): quantize dst directly,
                # and tie values + output to dst's quantization params.
                dst = node.args[0]
                value = node.args[2]
                node.meta["quantization_annotation"] = QuantizationAnnotation(
                    input_qspec_map={
                        dst: self.qspec,
                        value: SharedQuantizationSpec((dst, node)),
                    },
                    output_qspec=SharedQuantizationSpec((dst, node)),
                    _annotated=True,
                )
            return model

        def transform_for_annotation(
            self, model: torch.fx.GraphModule
        ) -> torch.fx.GraphModule:
            # No pre-annotation graph transform needed for this test.
            return model

        def validate(self, model: torch.fx.GraphModule) -> None:
            return None

    class M(torch.nn.Module):
        """Writes ``x`` into a registered buffer at ``idx`` in place."""

        def __init__(self) -> None:
            super().__init__()
            self.register_buffer("buf", torch.zeros(4, dtype=torch.float32))

        def forward(self, x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
            updated = self.buf.index_put_((idx,), x)
            # clone() so the graph output is not aliased to the buffer.
            return updated.clone()

    m = M().eval()
    quantizer = IndexPutQuantizer()
    example_inputs = (
        torch.tensor([1.0, 2.0], dtype=torch.float32),
        torch.tensor([1, 3], dtype=torch.int64),
    )
    m = torch.export.export(m, example_inputs, strict=True).module()

    m = prepare_pt2e(m, quantizer)
    # Run once so observers record calibration statistics.
    m(*example_inputs)
    m = convert_pt2e(m, fold_quantize=True)

    # Check that the named buffer is not folded
    # If it folded it will be named _frozen_param0
    self.assertTrue("buf" in dict(m.named_buffers()))

    # Verify the quantized model works
    result = m(*example_inputs)
    self.assertIsNotNone(result)
3400+
33333401
def test_scan_op_quantization(self):
33343402
"""Test that prepare_pt2e and convert_pt2e correctly quantize ops
33353403
inside the combine_fn subgraph of torch._higher_order_ops.scan.

torchao/quantization/pt2e/constant_fold.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import torch
1111
import torch.utils._pytree as pytree
1212
from torch._inductor.freezing_utils import maybe_set_is_frozen_param
13+
from torch.ao.quantization.fx.utils import collect_producer_nodes
1314
from torch.utils._ordered_set import OrderedSet
1415

1516
aten = torch.ops.aten
@@ -94,19 +95,37 @@ def __init__(
9495
# Identify mutable buffers by finding copy_ operations
9596
self.mutable_buffers = self._find_mutable_buffers()
9697

98+
def _is_mutable_buffer(self, node: torch.fx.Node) -> bool:
99+
"""Check if a node is a mutable buffer."""
100+
named_buffers = dict(self.module.named_buffers())
101+
if node.op == "placeholder":
102+
return True
103+
104+
if node.op == "get_attr" and str(node.target) in named_buffers:
105+
return True
106+
107+
return False
108+
97109
def _find_mutable_buffers(self) -> set[torch.fx.Node]:
    """Find mutable buffers by identifying copy_ or put_ operations.
    The graph then traces all nodes that lead to a mutable buffer.

    Returns the set of FX nodes that must be excluded from constant
    folding because they feed (directly or transitively) the mutated
    first argument of an in-place ``copy_`` / ``*put_`` op.
    """
    mutable_buffers = set()
    for node in self.module.graph.nodes:
        # Substring match on the op name: "put_" also catches ops such
        # as index_put_, not just a literal "put_" target.
        if (
            node.op == "call_function"
            and hasattr(node.target, "_schema")
            and ("copy_" in str(node.target) or "put_" in str(node.target))
        ):
            # The first argument of copy_ or put_ is the mutable input.
            # If any producer in the chain is a mutable buffer, mark
            # all producers as mutable to prevent constant folding.
            if len(node.args) > 0 and isinstance(node.args[0], torch.fx.Node):
                # collect_producer_nodes returns None when the chain
                # cannot be fully traced; in that case nothing is marked.
                producer_nodes = collect_producer_nodes(node.args[0])
                if producer_nodes is not None and any(
                    self._is_mutable_buffer(p) for p in producer_nodes
                ):
                    mutable_buffers.update(producer_nodes)

    return mutable_buffers
111130

112131
def _support_dynamic_shape(self) -> bool:

0 commit comments

Comments
 (0)