|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | + |
| 3 | +# pyre-strict |
| 4 | + |
| 5 | +""" |
| 6 | +Tests for CallGraph AST analyzer. |
| 7 | +
|
| 8 | +This test suite validates call graph analysis, including detection of |
| 9 | +function references passed as arguments to higher-order functions |
| 10 | +(e.g., tl.map_elementwise(_mask_scalar, ...)). |
| 11 | +""" |
| 12 | + |
| 13 | +import ast |
| 14 | +import os |
| 15 | +import tempfile |
| 16 | +import unittest |
| 17 | + |
| 18 | +from tritonparse.reproducer.ast_analyzer import CallGraph |
| 19 | + |
| 20 | + |
| 21 | +class TestHigherOrderFunctionDetection(unittest.TestCase): |
| 22 | + """Test that function references passed as call arguments are detected.""" |
| 23 | + |
| 24 | + def _analyze_source( |
| 25 | + self, |
| 26 | + source: str, |
| 27 | + backend: str, |
| 28 | + callee_prefix_filters: list[str] | None = None, |
| 29 | + ) -> CallGraph: |
| 30 | + """Helper: write source to a temp file, parse, and analyze.""" |
| 31 | + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as tmp: |
| 32 | + tmp.write(source) |
| 33 | + tmp.flush() |
| 34 | + tmp_path = tmp.name |
| 35 | + |
| 36 | + try: |
| 37 | + tree = ast.parse(source, filename=tmp_path) |
| 38 | + analyzer = CallGraph( |
| 39 | + filename=tmp_path, |
| 40 | + module_name="test_module", |
| 41 | + backends=[f"test_module.{backend}"], |
| 42 | + transitive_closure=True, |
| 43 | + callee_prefix_filters=callee_prefix_filters or [], |
| 44 | + ) |
| 45 | + analyzer.visit(tree) |
| 46 | + return analyzer |
| 47 | + finally: |
| 48 | + os.unlink(tmp_path) |
| 49 | + |
| 50 | + def test_function_ref_in_positional_arg(self) -> None: |
| 51 | + """Test that a function passed as a positional argument is detected.""" |
| 52 | + source = """\ |
| 53 | +def helper(x): |
| 54 | + return x * 2 |
| 55 | +
|
| 56 | +def apply_fn(fn, data): |
| 57 | + return fn(data) |
| 58 | +
|
| 59 | +def main_kernel(data): |
| 60 | + return apply_fn(helper, data) |
| 61 | +""" |
| 62 | + analyzer = self._analyze_source(source, "main_kernel") |
| 63 | + dependent = analyzer.get_dependent_functions() |
| 64 | + dep_short_names = {name.split(".")[-1] for name in dependent} |
| 65 | + |
| 66 | + # Assert: both apply_fn (direct call) and helper (passed as arg) are dependencies |
| 67 | + self.assertIn("apply_fn", dep_short_names) |
| 68 | + self.assertIn("helper", dep_short_names) |
| 69 | + |
| 70 | + def test_function_ref_in_keyword_arg(self) -> None: |
| 71 | + """Test that a function passed as a keyword argument is detected.""" |
| 72 | + source = """\ |
| 73 | +def helper(x): |
| 74 | + return x + 1 |
| 75 | +
|
| 76 | +def apply_fn(data, fn=None): |
| 77 | + return fn(data) |
| 78 | +
|
| 79 | +def main_kernel(data): |
| 80 | + return apply_fn(data, fn=helper) |
| 81 | +""" |
| 82 | + analyzer = self._analyze_source(source, "main_kernel") |
| 83 | + dependent = analyzer.get_dependent_functions() |
| 84 | + dep_short_names = {name.split(".")[-1] for name in dependent} |
| 85 | + |
| 86 | + self.assertIn("apply_fn", dep_short_names) |
| 87 | + self.assertIn("helper", dep_short_names) |
| 88 | + |
| 89 | + def test_map_elementwise_pattern(self) -> None: |
| 90 | + """Test the tl.map_elementwise(_mask_scalar, ...) pattern. |
| 91 | +
|
| 92 | + This reproduces the exact bug: _mask_scalar is passed as a function |
| 93 | + reference to tl.map_elementwise, which is filtered out by prefix |
| 94 | + filters. The function reference should still be detected. |
| 95 | + """ |
| 96 | + source = """\ |
| 97 | +import triton |
| 98 | +import triton.language as tl |
| 99 | +
|
| 100 | +@triton.jit |
| 101 | +def _mask_scalar(qk, col_limit_right, s, i): |
| 102 | + return qk |
| 103 | +
|
| 104 | +@triton.jit |
| 105 | +def _apply_causal_mask(qk, col_limit_right, BLOCK_N: tl.constexpr): |
| 106 | + offs_n = tl.arange(0, BLOCK_N) |
| 107 | + s = offs_n & ~15 |
| 108 | + i = offs_n & 15 |
| 109 | + return tl.map_elementwise(_mask_scalar, qk, col_limit_right, s, i) |
| 110 | +
|
| 111 | +@triton.jit |
| 112 | +def _softmax_inner_loop(qk): |
| 113 | + return _apply_causal_mask(qk, 0, 128) |
| 114 | +
|
| 115 | +@triton.jit |
| 116 | +def _attn_fwd_ws(data): |
| 117 | + return _softmax_inner_loop(data) |
| 118 | +""" |
| 119 | + analyzer = self._analyze_source( |
| 120 | + source, "_attn_fwd_ws", callee_prefix_filters=["triton.", "tl."] |
| 121 | + ) |
| 122 | + dependent = analyzer.get_dependent_functions() |
| 123 | + dep_short_names = {name.split(".")[-1] for name in dependent} |
| 124 | + |
| 125 | + # Assert: the full call chain is detected, including _mask_scalar |
| 126 | + self.assertIn("_softmax_inner_loop", dep_short_names) |
| 127 | + self.assertIn("_apply_causal_mask", dep_short_names) |
| 128 | + self.assertIn( |
| 129 | + "_mask_scalar", |
| 130 | + dep_short_names, |
| 131 | + "_mask_scalar should be detected even though it is passed as an " |
| 132 | + "argument to tl.map_elementwise (which is filtered by prefix)", |
| 133 | + ) |
| 134 | + |
| 135 | + def test_non_function_args_are_not_detected(self) -> None: |
| 136 | + """Test that regular variable arguments do not create false edges.""" |
| 137 | + source = """\ |
| 138 | +def helper(): |
| 139 | + return 42 |
| 140 | +
|
| 141 | +def main_kernel(data): |
| 142 | + x = 10 |
| 143 | + return helper() + x |
| 144 | +""" |
| 145 | + analyzer = self._analyze_source(source, "main_kernel") |
| 146 | + dependent = analyzer.get_dependent_functions() |
| 147 | + dep_short_names = {name.split(".")[-1] for name in dependent} |
| 148 | + |
| 149 | + # Assert: helper is detected (direct call), but no extra functions |
| 150 | + self.assertIn("helper", dep_short_names) |
| 151 | + self.assertEqual(len(dep_short_names), 1) |
| 152 | + |
| 153 | + def test_transitive_function_ref_detection(self) -> None: |
| 154 | + """Test transitive closure through function references. |
| 155 | +
|
| 156 | + If A calls B(C) where C is passed by reference, and C calls D, |
| 157 | + then D should also be in the dependency set. |
| 158 | + """ |
| 159 | + source = """\ |
| 160 | +def leaf_fn(x): |
| 161 | + return x |
| 162 | +
|
| 163 | +def mid_fn(x): |
| 164 | + return leaf_fn(x) |
| 165 | +
|
| 166 | +def apply_fn(fn, data): |
| 167 | + return fn(data) |
| 168 | +
|
| 169 | +def main_kernel(data): |
| 170 | + return apply_fn(mid_fn, data) |
| 171 | +""" |
| 172 | + analyzer = self._analyze_source(source, "main_kernel") |
| 173 | + dependent = analyzer.get_dependent_functions() |
| 174 | + dep_short_names = {name.split(".")[-1] for name in dependent} |
| 175 | + |
| 176 | + # Assert: full chain is detected |
| 177 | + self.assertIn("apply_fn", dep_short_names) |
| 178 | + self.assertIn("mid_fn", dep_short_names) |
| 179 | + self.assertIn("leaf_fn", dep_short_names) |
| 180 | + |
| 181 | + def test_multiple_function_refs_in_one_call(self) -> None: |
| 182 | + """Test that multiple function references in a single call are all detected.""" |
| 183 | + source = """\ |
| 184 | +def fn_a(x): |
| 185 | + return x + 1 |
| 186 | +
|
| 187 | +def fn_b(x): |
| 188 | + return x + 2 |
| 189 | +
|
| 190 | +def combine(f1, f2, data): |
| 191 | + return f1(data) + f2(data) |
| 192 | +
|
| 193 | +def main_kernel(data): |
| 194 | + return combine(fn_a, fn_b, data) |
| 195 | +""" |
| 196 | + analyzer = self._analyze_source(source, "main_kernel") |
| 197 | + dependent = analyzer.get_dependent_functions() |
| 198 | + dep_short_names = {name.split(".")[-1] for name in dependent} |
| 199 | + |
| 200 | + self.assertIn("fn_a", dep_short_names) |
| 201 | + self.assertIn("fn_b", dep_short_names) |
| 202 | + self.assertIn("combine", dep_short_names) |
| 203 | + |
| 204 | + def test_function_ref_with_filtered_caller(self) -> None: |
| 205 | + """Test function ref detection works even when the enclosing call is filtered. |
| 206 | +
|
| 207 | + When tl.map_elementwise is filtered by callee_prefix_filters, |
| 208 | + the function reference in its arguments should still be recorded. |
| 209 | + """ |
| 210 | + source = """\ |
| 211 | +def _scalar_op(x): |
| 212 | + return x * 2 |
| 213 | +
|
| 214 | +def kernel(data): |
| 215 | + return tl.map_elementwise(_scalar_op, data) |
| 216 | +""" |
| 217 | + analyzer = self._analyze_source(source, "kernel", callee_prefix_filters=["tl."]) |
| 218 | + dependent = analyzer.get_dependent_functions() |
| 219 | + dep_short_names = {name.split(".")[-1] for name in dependent} |
| 220 | + |
| 221 | + # Assert: _scalar_op should still be detected despite tl. being filtered |
| 222 | + self.assertIn("_scalar_op", dep_short_names) |
| 223 | + |
| 224 | + |
| 225 | +if __name__ == "__main__": |
| 226 | + unittest.main() |
0 commit comments