Skip to content

Commit 4b7219f

Browse files
FindHaometa-codesync[bot]
authored andcommitted
Detect function references in call arguments for reproducer (#348)
Summary: Pull Request resolved: #348 The reproducer's AST call graph analyzer (`ast_analyzer.py`) only detected direct function calls (`func(args)`) but failed to detect function references passed as arguments to higher-order functions. This caused the `_mask_scalar` function to be missing from generated repro scripts when it was passed by reference to `tl.map_elementwise(_mask_scalar, qk, ...)`, resulting in a `NameError` at compilation time. The fix adds argument scanning in `visit_Call` to detect `ast.Name` nodes in both positional and keyword arguments that resolve to known local functions, recording them as dependency edges in the call graph. This ensures the transitive closure correctly includes functions passed by reference. Reviewed By: bhuang7477 Differential Revision: D94126107 fbshipit-source-id: 64c9b137decc5fd2e2e4a6b2a7ec46dd3f8b72f7
1 parent 26b6517 commit 4b7219f

2 files changed

Lines changed: 244 additions & 0 deletions

File tree

tritonparse/reproducer/ast_analyzer.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,24 @@ def visit_Call(self, node: ast.Call) -> None:
530530
callee = "<dynamic_call>"
531531

532532
self._record_call(callee, node, maybe_triton=maybe_triton)
533+
534+
# Detect function references passed as positional arguments.
535+
# This handles higher-order function patterns such as:
536+
# tl.map_elementwise(_mask_scalar, qk, ...)
537+
# where _mask_scalar is a local function passed by reference.
538+
for arg in node.args:
539+
if isinstance(arg, ast.Name):
540+
resolved = self._resolve_name(arg.id)
541+
if resolved in self.local_functions:
542+
self._record_call(resolved, arg)
543+
544+
# Same for keyword arguments (e.g., fn=my_helper)
545+
for kw in node.keywords:
546+
if kw.value is not None and isinstance(kw.value, ast.Name):
547+
resolved = self._resolve_name(kw.value.id)
548+
if resolved in self.local_functions:
549+
self._record_call(resolved, kw.value)
550+
533551
self.generic_visit(node)
534552

535553

Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
3+
# pyre-strict
4+
5+
"""
6+
Tests for CallGraph AST analyzer.
7+
8+
This test suite validates call graph analysis, including detection of
9+
function references passed as arguments to higher-order functions
10+
(e.g., tl.map_elementwise(_mask_scalar, ...)).
11+
"""
12+
13+
import ast
14+
import os
15+
import tempfile
16+
import unittest
17+
18+
from tritonparse.reproducer.ast_analyzer import CallGraph
19+
20+
21+
class TestHigherOrderFunctionDetection(unittest.TestCase):
22+
"""Test that function references passed as call arguments are detected."""
23+
24+
def _analyze_source(
25+
self,
26+
source: str,
27+
backend: str,
28+
callee_prefix_filters: list[str] | None = None,
29+
) -> CallGraph:
30+
"""Helper: write source to a temp file, parse, and analyze."""
31+
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as tmp:
32+
tmp.write(source)
33+
tmp.flush()
34+
tmp_path = tmp.name
35+
36+
try:
37+
tree = ast.parse(source, filename=tmp_path)
38+
analyzer = CallGraph(
39+
filename=tmp_path,
40+
module_name="test_module",
41+
backends=[f"test_module.{backend}"],
42+
transitive_closure=True,
43+
callee_prefix_filters=callee_prefix_filters or [],
44+
)
45+
analyzer.visit(tree)
46+
return analyzer
47+
finally:
48+
os.unlink(tmp_path)
49+
50+
def test_function_ref_in_positional_arg(self) -> None:
51+
"""Test that a function passed as a positional argument is detected."""
52+
source = """\
53+
def helper(x):
54+
return x * 2
55+
56+
def apply_fn(fn, data):
57+
return fn(data)
58+
59+
def main_kernel(data):
60+
return apply_fn(helper, data)
61+
"""
62+
analyzer = self._analyze_source(source, "main_kernel")
63+
dependent = analyzer.get_dependent_functions()
64+
dep_short_names = {name.split(".")[-1] for name in dependent}
65+
66+
# Assert: both apply_fn (direct call) and helper (passed as arg) are dependencies
67+
self.assertIn("apply_fn", dep_short_names)
68+
self.assertIn("helper", dep_short_names)
69+
70+
def test_function_ref_in_keyword_arg(self) -> None:
71+
"""Test that a function passed as a keyword argument is detected."""
72+
source = """\
73+
def helper(x):
74+
return x + 1
75+
76+
def apply_fn(data, fn=None):
77+
return fn(data)
78+
79+
def main_kernel(data):
80+
return apply_fn(data, fn=helper)
81+
"""
82+
analyzer = self._analyze_source(source, "main_kernel")
83+
dependent = analyzer.get_dependent_functions()
84+
dep_short_names = {name.split(".")[-1] for name in dependent}
85+
86+
self.assertIn("apply_fn", dep_short_names)
87+
self.assertIn("helper", dep_short_names)
88+
89+
def test_map_elementwise_pattern(self) -> None:
90+
"""Test the tl.map_elementwise(_mask_scalar, ...) pattern.
91+
92+
This reproduces the exact bug: _mask_scalar is passed as a function
93+
reference to tl.map_elementwise, which is filtered out by prefix
94+
filters. The function reference should still be detected.
95+
"""
96+
source = """\
97+
import triton
98+
import triton.language as tl
99+
100+
@triton.jit
101+
def _mask_scalar(qk, col_limit_right, s, i):
102+
return qk
103+
104+
@triton.jit
105+
def _apply_causal_mask(qk, col_limit_right, BLOCK_N: tl.constexpr):
106+
offs_n = tl.arange(0, BLOCK_N)
107+
s = offs_n & ~15
108+
i = offs_n & 15
109+
return tl.map_elementwise(_mask_scalar, qk, col_limit_right, s, i)
110+
111+
@triton.jit
112+
def _softmax_inner_loop(qk):
113+
return _apply_causal_mask(qk, 0, 128)
114+
115+
@triton.jit
116+
def _attn_fwd_ws(data):
117+
return _softmax_inner_loop(data)
118+
"""
119+
analyzer = self._analyze_source(
120+
source, "_attn_fwd_ws", callee_prefix_filters=["triton.", "tl."]
121+
)
122+
dependent = analyzer.get_dependent_functions()
123+
dep_short_names = {name.split(".")[-1] for name in dependent}
124+
125+
# Assert: the full call chain is detected, including _mask_scalar
126+
self.assertIn("_softmax_inner_loop", dep_short_names)
127+
self.assertIn("_apply_causal_mask", dep_short_names)
128+
self.assertIn(
129+
"_mask_scalar",
130+
dep_short_names,
131+
"_mask_scalar should be detected even though it is passed as an "
132+
"argument to tl.map_elementwise (which is filtered by prefix)",
133+
)
134+
135+
def test_non_function_args_are_not_detected(self) -> None:
136+
"""Test that regular variable arguments do not create false edges."""
137+
source = """\
138+
def helper():
139+
return 42
140+
141+
def main_kernel(data):
142+
x = 10
143+
return helper() + x
144+
"""
145+
analyzer = self._analyze_source(source, "main_kernel")
146+
dependent = analyzer.get_dependent_functions()
147+
dep_short_names = {name.split(".")[-1] for name in dependent}
148+
149+
# Assert: helper is detected (direct call), but no extra functions
150+
self.assertIn("helper", dep_short_names)
151+
self.assertEqual(len(dep_short_names), 1)
152+
153+
def test_transitive_function_ref_detection(self) -> None:
154+
"""Test transitive closure through function references.
155+
156+
If A calls B(C) where C is passed by reference, and C calls D,
157+
then D should also be in the dependency set.
158+
"""
159+
source = """\
160+
def leaf_fn(x):
161+
return x
162+
163+
def mid_fn(x):
164+
return leaf_fn(x)
165+
166+
def apply_fn(fn, data):
167+
return fn(data)
168+
169+
def main_kernel(data):
170+
return apply_fn(mid_fn, data)
171+
"""
172+
analyzer = self._analyze_source(source, "main_kernel")
173+
dependent = analyzer.get_dependent_functions()
174+
dep_short_names = {name.split(".")[-1] for name in dependent}
175+
176+
# Assert: full chain is detected
177+
self.assertIn("apply_fn", dep_short_names)
178+
self.assertIn("mid_fn", dep_short_names)
179+
self.assertIn("leaf_fn", dep_short_names)
180+
181+
def test_multiple_function_refs_in_one_call(self) -> None:
182+
"""Test that multiple function references in a single call are all detected."""
183+
source = """\
184+
def fn_a(x):
185+
return x + 1
186+
187+
def fn_b(x):
188+
return x + 2
189+
190+
def combine(f1, f2, data):
191+
return f1(data) + f2(data)
192+
193+
def main_kernel(data):
194+
return combine(fn_a, fn_b, data)
195+
"""
196+
analyzer = self._analyze_source(source, "main_kernel")
197+
dependent = analyzer.get_dependent_functions()
198+
dep_short_names = {name.split(".")[-1] for name in dependent}
199+
200+
self.assertIn("fn_a", dep_short_names)
201+
self.assertIn("fn_b", dep_short_names)
202+
self.assertIn("combine", dep_short_names)
203+
204+
def test_function_ref_with_filtered_caller(self) -> None:
205+
"""Test function ref detection works even when the enclosing call is filtered.
206+
207+
When tl.map_elementwise is filtered by callee_prefix_filters,
208+
the function reference in its arguments should still be recorded.
209+
"""
210+
source = """\
211+
def _scalar_op(x):
212+
return x * 2
213+
214+
def kernel(data):
215+
return tl.map_elementwise(_scalar_op, data)
216+
"""
217+
analyzer = self._analyze_source(source, "kernel", callee_prefix_filters=["tl."])
218+
dependent = analyzer.get_dependent_functions()
219+
dep_short_names = {name.split(".")[-1] for name in dependent}
220+
221+
# Assert: _scalar_op should still be detected despite tl. being filtered
222+
self.assertIn("_scalar_op", dep_short_names)
223+
224+
225+
if __name__ == "__main__":
226+
unittest.main()

0 commit comments

Comments
 (0)