|
11 | 11 | """ |
12 | 12 |
|
13 | 13 | import contextlib |
| 14 | +import logging |
14 | 15 | import io |
15 | 16 | import unittest |
16 | 17 | from functools import partial |
@@ -118,18 +119,29 @@ def tearDown(self): |
118 | 119 | torch._dynamo.reset() |
119 | 120 |
|
120 | 121 | def _run_fusion_pass(self, model, *args): |
121 | | - """Compile model with fusion pass, return captured stdout.""" |
| 122 | + """Compile model with fusion pass, return captured logger output.""" |
122 | 123 | inductor_config.pre_grad_custom_pass = partial( |
123 | 124 | rope_sdpa_fusion_pass, |
124 | 125 | rope_sdpa_op=_ops.rope_sdpa_op, |
125 | 126 | fp8_sdpa_op=_ops.fp8_sdpa_op, |
126 | 127 | backend_name="TEST", |
127 | 128 | ) |
128 | 129 | compiled = torch.compile(model) |
129 | | - buf = io.StringIO() |
130 | | - with torch.no_grad(), contextlib.redirect_stdout(buf): |
131 | | - compiled(*args) |
132 | | - return buf.getvalue() |
| 130 | + fusion_logger = logging.getLogger( |
| 131 | + "torchao.prototype.attention.shared_utils.fusion_utils" |
| 132 | + ) |
| 133 | + old_level = fusion_logger.level |
| 134 | + fusion_logger.setLevel(logging.DEBUG) |
| 135 | + handler = logging.StreamHandler(io.StringIO()) |
| 136 | + handler.setLevel(logging.DEBUG) |
| 137 | + fusion_logger.addHandler(handler) |
| 138 | + try: |
| 139 | + with torch.no_grad(): |
| 140 | + compiled(*args) |
| 141 | + return handler.stream.getvalue() |
| 142 | + finally: |
| 143 | + fusion_logger.removeHandler(handler) |
| 144 | + fusion_logger.setLevel(old_level) |
133 | 145 |
|
134 | 146 | def _assert_fused(self, model, *extra_args): |
135 | 147 | """Create BSHD inputs, run fusion pass, assert 1 node was fused.""" |
|
0 commit comments