Commit d9d14aa
fix: bugs and code quality improvements across prototype modules
- Replace print() with logger.info() in fusion_utils.py (unconsoleable output during torch.compile)
- Add kernel availability guard to Int4OpaqueTensor.from_hp_da8w4() with a clear error message
- Add kernel availability guard to QuantizedLinear._forward_2d() to prevent AttributeError
- Replace mutable default kwargs={} in _replace_embedding_with_quantized_embedding()
- Fix fragile stdout capture in test_rope_fusion_detection.py to use logger capture
- Add public API exports to embedding/__init__.py (EmbeddingQuantizer, QuantizedLinear, etc.)
- Remove unused _is_blackwell() from attention/utils.py
- Remove misconfigured @triton.autotune decorators (empty configs, constexpr key)
1 parent: a5da06e

8 files changed: 47 additions & 39 deletions

test/prototype/attention/test_rope_fusion_detection.py

Lines changed: 17 additions & 5 deletions
@@ -11,6 +11,7 @@
 """
 
 import contextlib
+import logging
 import io
 import unittest
 from functools import partial
@@ -118,18 +119,29 @@ def tearDown(self):
         torch._dynamo.reset()
 
     def _run_fusion_pass(self, model, *args):
-        """Compile model with fusion pass, return captured stdout."""
+        """Compile model with fusion pass, return captured logger output."""
         inductor_config.pre_grad_custom_pass = partial(
             rope_sdpa_fusion_pass,
             rope_sdpa_op=_ops.rope_sdpa_op,
             fp8_sdpa_op=_ops.fp8_sdpa_op,
             backend_name="TEST",
         )
         compiled = torch.compile(model)
-        buf = io.StringIO()
-        with torch.no_grad(), contextlib.redirect_stdout(buf):
-            compiled(*args)
-        return buf.getvalue()
+        fusion_logger = logging.getLogger(
+            "torchao.prototype.attention.shared_utils.fusion_utils"
+        )
+        old_level = fusion_logger.level
+        fusion_logger.setLevel(logging.DEBUG)
+        handler = logging.StreamHandler(io.StringIO())
+        handler.setLevel(logging.DEBUG)
+        fusion_logger.addHandler(handler)
+        try:
+            with torch.no_grad():
+                compiled(*args)
+            return handler.stream.getvalue()
+        finally:
+            fusion_logger.removeHandler(handler)
+            fusion_logger.setLevel(old_level)
 
     def _assert_fused(self, model, *extra_args):
         """Create BSHD inputs, run fusion pass, assert 1 node was fused."""

torchao/prototype/attention/quantization/triton_hadamard_qkv_quantization.py

Lines changed: 0 additions & 8 deletions
@@ -38,14 +38,6 @@
 )
 
 
-@triton.autotune(
-    configs=[
-        triton.Config({}, num_warps=2),
-        triton.Config({}, num_warps=4),
-        triton.Config({}, num_warps=8),
-    ],
-    key=["D"],
-)
 @triton.jit
 def hadamard_single_phase1_kernel(
     # Input tensor [B, H, S, D]
torchao/prototype/attention/quantization/triton_hadamard_rope_qkv_quantization.py

Lines changed: 0 additions & 16 deletions
@@ -33,14 +33,6 @@
 )
 
 
-@triton.autotune(
-    configs=[
-        triton.Config({}, num_warps=2),
-        triton.Config({}, num_warps=4),
-        triton.Config({}, num_warps=8),
-    ],
-    key=["D"],
-)
 @triton.jit
 def hadamard_rope_single_phase1_kernel(
     # Input tensor [B, S, H, D]
@@ -160,14 +152,6 @@ def hadamard_rope_single_phase1_kernel(
     tl.store(partial_max_ptr + chunk_idx, x_max_scalar)
 
 
-@triton.autotune(
-    configs=[
-        triton.Config({}, num_warps=2),
-        triton.Config({}, num_warps=4),
-        triton.Config({}, num_warps=8),
-    ],
-    key=["D"],
-)
 @triton.jit
 def hadamard_v_phase1_kernel(
     # Input tensor [B, S, H, D]
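
For reference, the decorators removed from both Hadamard kernel files paired empty triton.Config({}) entries with key=["D"], where D is a tl.constexpr: constexpr arguments already force a separate kernel specialization, so keying the autotune cache on one adds nothing, and configs that differ only in num_warps leave little to tune. A hedged sketch of a conventionally configured @triton.autotune on a toy kernel (scale_kernel is illustrative, not one of the kernels above):

    import torch
    import triton
    import triton.language as tl

    @triton.autotune(
        configs=[
            # Each config varies a real tunable (the block size) alongside
            # num_warps, so the autotuner has meaningfully distinct candidates.
            triton.Config({"BLOCK_SIZE": 256}, num_warps=2),
            triton.Config({"BLOCK_SIZE": 512}, num_warps=4),
            triton.Config({"BLOCK_SIZE": 1024}, num_warps=8),
        ],
        # Key on a runtime argument whose value changes the best config.
        key=["n_elements"],
    )
    @triton.jit
    def scale_kernel(x_ptr, out_ptr, scale, n_elements, BLOCK_SIZE: tl.constexpr):
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        tl.store(out_ptr + offsets, x * scale, mask=mask)

    def scale(x: torch.Tensor, s: float) -> torch.Tensor:
        out = torch.empty_like(x)
        n = x.numel()
        # BLOCK_SIZE is chosen by the autotuner, so it is not passed here;
        # the grid reads the selected value from the config metadata.
        grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
        scale_kernel[grid](x, out, s, n)
        return out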

torchao/prototype/attention/shared_utils/fusion_utils.py

Lines changed: 2 additions & 2 deletions
@@ -964,7 +964,7 @@ def rope_sdpa_fusion_pass(
     fp8_sdpa_nodes = [n for n in graph.nodes if _is_fp8_sdpa_node(n, fp8_sdpa_op)]
 
     if not fp8_sdpa_nodes:
-        print(
+        logger.info(
             f"[low_precision_attention] RoPE fusion pass ({backend_name}): "
             f"found 0 FP8 SDPA nodes in graph"
         )
@@ -1102,7 +1102,7 @@ def rope_sdpa_fusion_pass(
             fused_count += 1
             continue
 
-    print(
+    logger.info(
         f"[low_precision_attention] RoPE fusion pass ({backend_name}): "
         f"found {len(fp8_sdpa_nodes)} FP8 SDPA node(s), "
         f"{fused_count} fused with RoPE"

torchao/prototype/attention/utils.py

Lines changed: 0 additions & 7 deletions
@@ -16,13 +16,6 @@ def _is_hopper() -> bool:
     return major == 9
 
 
-def _is_blackwell() -> bool:
-    if not torch.cuda.is_available():
-        return False
-    major, _ = torch.cuda.get_device_capability()
-    return major == 10
-
-
 def _is_fa3_available() -> bool:
     try:
         importlib.import_module("flash_attn_interface")

torchao/prototype/quantization/embedding/__init__.py

Lines changed: 17 additions & 0 deletions

@@ -0,0 +1,17 @@
+from .api import (
+    EmbeddingQuantizer,
+    QuantizedEmbedding,
+    QuantizedEmbeddingFallback,
+    QuantizedLinear,
+    QuantizedTiedEmbedding,
+    TiedEmbeddingQuantizer,
+)
+
+__all__ = [
+    "EmbeddingQuantizer",
+    "QuantizedEmbedding",
+    "QuantizedEmbeddingFallback",
+    "QuantizedLinear",
+    "QuantizedTiedEmbedding",
+    "TiedEmbeddingQuantizer",
+]
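
With these re-exports in place, callers can import the quantization entry points from the package itself instead of reaching into the api submodule:

    from torchao.prototype.quantization.embedding import (
        EmbeddingQuantizer,
        QuantizedLinear,
    )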

torchao/prototype/quantization/embedding/api.py

Lines changed: 7 additions & 1 deletion
@@ -142,9 +142,11 @@ def forward(self, x):
 
 def _replace_embedding_with_quantized_embedding(
     module: nn.Module,
-    kwargs={},
+    kwargs=None,
     fqn: str = "",
 ):
+    if kwargs is None:
+        kwargs = {}
     group_size = kwargs.get("group_size", None)
     bit_width = kwargs.get("bit_width", None)
     use_fallback = kwargs.get("use_fallback", None)
@@ -254,6 +256,10 @@ def _forward_2d(self, x):
         assert x.dim() == 2
         m, k = x.shape
         assert k == self.k
+        assert _is_kernel_library_loaded(), (
+            "QuantizedLinear requires the torchao kernel library to be loaded. "
+            "Please build torchao with C++ extensions enabled (USE_CPP=1)."
+        )
         return getattr(
             torch.ops.torchao, f"_linear_8bit_act_{self.bit_width}bit_weight"
         )(x, self.packed_weight, self.group_size, self.n, self.k)
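
The kwargs={} change in the first hunk addresses Python's mutable-default pitfall: a default dict is created once, when the def statement runs, and is then shared across every call that omits the argument. In the lines shown, kwargs is only read via .get, so the bug was latent, but the None-sentinel form removes the hazard. A standalone illustration of the failure mode (not torchao code):

    def record_bad(item, acc=[]):
        # `acc` is one shared list, created when the def statement ran
        acc.append(item)
        return acc

    def record_good(item, acc=None):
        # None sentinel: each call that omits `acc` gets its own fresh list
        if acc is None:
            acc = []
        acc.append(item)
        return acc

    record_bad(1)
    print(record_bad(2))   # [1, 2] -- state leaked from the first call
    record_good(1)
    print(record_good(2))  # [2]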

torchao/prototype/quantization/int4/int4_opaque_tensor.py

Lines changed: 4 additions & 0 deletions
@@ -233,6 +233,10 @@ def from_hp_da8w4(
             act_mapping_type: MappingType.ASYMMETRIC (uint8 activation, default) or
                 MappingType.SYMMETRIC (int8 activation, requires PyTorch >= 2.8)
         """
+        assert "CPU" in torch._C._dispatch_dump("torchao::da8w4_linear_prepack_cpu"), (
+            "DA8W4 on CPU requires the da8w4_linear_cpu kernel to be built and available. "
+            "Please build torchao with C++ extensions enabled (USE_CPP=1)."
+        )
         assert w.ndim == 2 and w.device.type == "cpu", (
             f"Expecting 2D tensor on CPU, but got: {w.shape} on {w.device.type}"
         )
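
Both availability guards probe the dispatcher before the first call into a C++ op, turning a confusing AttributeError into an actionable build message. A hypothetical helper distilling the da8w4 guard above (_cpu_kernel_registered is not part of this commit, and _dispatch_dump's behavior for unregistered operators may vary by PyTorch version):

    import torch

    def _cpu_kernel_registered(qualname: str) -> bool:
        # torch._C._dispatch_dump returns a textual dump of an operator's
        # dispatch table; "CPU" appears in it when a CPU kernel is registered.
        try:
            return "CPU" in torch._C._dispatch_dump(qualname)
        except RuntimeError:
            # The operator itself was never registered, e.g. a build
            # without C++ extensions (USE_CPP=0).
            return False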
