Skip to content

Commit fa507b3

Browse files
Fix OVERRIDE_TTIR reproducer with stub kernel + ir_override (#376)
Summary: Pull Request resolved: #376 The existing OVERRIDE_TTIR mode was broken: it skipped defining the kernel function (causing NameError), only worked for autotuned kernels, and discarded constexpr values. This rewrites OVERRIDE_TTIR to generate a stub triton.jit function (same name and params, `pass` body) wrapped with triton.autotune carrying the captured constexpr values, compile params, and ir_override pointing to the captured TTGIR. The stub compiles trivially through the frontend, then ir_override replaces the compilation output with the real kernel's IR. This eliminates the need to copy kernel source code and its transitive dependencies (which was flaky due to complex imports, generated code, and compiled extensions). Key changes: - New stub_generator.py: generates stub functions and extracts constexpr values - Fixed _replace_kernel_import for OVERRIDE_TTIR: generates stub + autotune config - Fixed _replace_kernel_invocation: filters constexpr/compile params (autotune provides them) - Save captured IR files from compilation event's file_content to captured_irs/ - Reuses _is_constexpr_annotation from function_extractor.py (no duplication) - lru_cache on extract_params_from_source to avoid redundant AST parses Reviewed By: FindHao Differential Revision: D99893169 fbshipit-source-id: ce1c42a9afcbce08cb853ce84f5c934551449911
1 parent 3337a0c commit fa507b3

5 files changed

Lines changed: 311 additions & 59 deletions

File tree

tritonparse/reproducer/orchestrator.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@
1111
)
1212
from tritonparse.reproducer.templates.loader import load_template_code
1313
from tritonparse.reproducer.types import KernelImportMode
14-
from tritonparse.reproducer.utils import determine_output_paths, format_python_code
14+
from tritonparse.reproducer.utils import (
15+
determine_output_paths,
16+
format_python_code,
17+
save_captured_irs,
18+
)
1519
from tritonparse.shared_vars import is_fbcode
1620
from tritonparse.tools.prettify_ndjson import load_ndjson, save_prettified_json
1721
from tritonparse.tp_logger import logger
@@ -102,6 +106,10 @@ def reproduce(
102106
)
103107
save_prettified_json(context_bundle.raw_comp_event, comp_json_path)
104108

109+
# Save IR files for OVERRIDE_TTIR mode
110+
if kernel_import == KernelImportMode.OVERRIDE_TTIR:
111+
save_captured_irs(context_bundle.raw_comp_event, out_py_path.parent)
112+
105113
logger.debug("Loading reproducer template.")
106114
template_code = load_template_code(template)
107115

tritonparse/reproducer/placeholder_replacer.py

Lines changed: 134 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@
1515
match_autotune_config,
1616
)
1717
from tritonparse.reproducer.ingestion.ndjson import ContextBundle
18+
from tritonparse.reproducer.stub_generator import (
19+
_find_ir_override_file,
20+
generate_stub_source,
21+
get_constexpr_values,
22+
)
1823
from tritonparse.reproducer.templates.utils import (
1924
_disable_triton_autotune,
2025
get_function_source,
@@ -242,60 +247,7 @@ def _replace_ir_override_setup(
242247
self, code: str, context_bundle: ContextBundle, **kwargs
243248
) -> str:
244249
"""Replace the IR override setup placeholder."""
245-
kernel_import = kwargs.get("kernel_import", KernelImportMode.DEFAULT)
246-
247-
if kernel_import != KernelImportMode.OVERRIDE_TTIR:
248-
return code.replace(self.IR_OVERRIDE_SETUP_PLACEHOLDER, "")
249-
250-
comp_json_filename = kwargs.get("comp_json_filename")
251-
if not comp_json_filename:
252-
raise ValueError("comp_json_filename is required for OVERRIDE_TTIR mode")
253-
254-
setup_code = f'''
255-
def create_ttir_tempfile():
256-
"""Extract TTIR from compilation event and create temporary file."""
257-
script_dir = Path(__file__).resolve().parent
258-
comp_json_file = script_dir / "{comp_json_filename}"
259-
260-
with open(comp_json_file, 'r') as f:
261-
comp_data = json.load(f)
262-
263-
# Extract TTIR content
264-
kernel_name = comp_data['payload']['metadata']['name']
265-
ttir_key = f"{{kernel_name}}.ttir"
266-
ttir_content = comp_data['payload']['file_content'][ttir_key]
267-
268-
# Create temporary file
269-
temp_file = tempfile.NamedTemporaryFile(
270-
mode='w',
271-
suffix='.ttir',
272-
delete=False,
273-
prefix=f'{{kernel_name}}_'
274-
)
275-
temp_file.write(ttir_content)
276-
temp_file.close()
277-
return temp_file.name
278-
279-
280-
# Monkeypatch triton.autotune to use our TTIR
281-
_ttir_file = create_ttir_tempfile()
282-
_original_autotune = None
283-
284-
def _patched_autotune(configs, key=None, **kwargs):
285-
"""Patched autotune that uses our TTIR file."""
286-
import triton
287-
# Replace configs with our single config using ir_override
288-
new_configs = [triton.Config(kwargs={{}}, ir_override=_ttir_file)]
289-
# Call original autotune with our config
290-
return _original_autotune(new_configs, key=[], **kwargs)
291-
292-
# Apply the monkeypatch before importing the kernel
293-
import triton
294-
_original_autotune = triton.autotune
295-
triton.autotune = _patched_autotune
296-
'''
297-
298-
return code.replace(self.IR_OVERRIDE_SETUP_PLACEHOLDER, setup_code)
250+
return code.replace(self.IR_OVERRIDE_SETUP_PLACEHOLDER, "")
299251

300252
def _replace_kernel_syspath(
301253
self, code: str, context_bundle: ContextBundle, **kwargs
@@ -312,7 +264,7 @@ def _replace_kernel_syspath(
312264
)
313265
return code.replace(self.KERNEL_SYSPATH_PLACEHOLDER, comment)
314266
elif kernel_import == KernelImportMode.OVERRIDE_TTIR:
315-
comment = "# Kernel sys.path setup skipped - using IR override mode"
267+
comment = "# Kernel sys.path setup skipped - using stub with IR override"
316268
return code.replace(self.KERNEL_SYSPATH_PLACEHOLDER, comment)
317269
else:
318270
raise ValueError(f"Unknown kernel_import mode: {kernel_import}")
@@ -425,8 +377,96 @@ def _replace_kernel_import(
425377

426378
return code.replace(self.KERNEL_IMPORT_PLACEHOLDER, embedded_code)
427379
elif kernel_import == KernelImportMode.OVERRIDE_TTIR:
428-
comment = "# Kernel import skipped - using IR override mode with TTIR"
429-
return code.replace(self.KERNEL_IMPORT_PLACEHOLDER, comment)
380+
source_code = context_bundle.kernel_info.source_code
381+
func_name = context_bundle.kernel_info.function_name
382+
383+
if not source_code or not source_code.strip():
384+
raise ValueError(
385+
"Kernel source code is empty, cannot use 'override-ttir' mode"
386+
)
387+
if not func_name:
388+
raise ValueError(
389+
"Cannot determine kernel function name for 'override-ttir' mode"
390+
)
391+
392+
# Build the autotune config with ir_override + constexpr values
393+
constexpr_vals = get_constexpr_values(context_bundle)
394+
compile_meta = context_bundle.compile
395+
396+
config_parts = []
397+
# Constexpr kwargs
398+
kwargs_repr = repr(constexpr_vals)
399+
config_parts.append(f"kwargs={kwargs_repr}")
400+
# Compile params
401+
for param in TRITON_COMPILE_PARAMS:
402+
value = compile_meta.get(param)
403+
if value is not None:
404+
config_parts.append(f"{param}={value!r}")
405+
# ir_override — find the best available IR file
406+
config_parts.append("ir_override=_IR_OVERRIDE_FILE")
407+
408+
config_str = ", ".join(config_parts)
409+
410+
# Generate imports + IR file discovery + stub function + autotune wrapper
411+
import_lines = list(_BASE_IMPORT_LINES) + [""]
412+
413+
# Scan original source for extra imports (e.g., tlx)
414+
extra_imports = _detect_extra_imports(source_code)
415+
if extra_imports:
416+
import_lines.extend(extra_imports)
417+
import_lines.append("")
418+
419+
import_lines.extend(
420+
[
421+
"import os",
422+
"",
423+
]
424+
+ get_function_source(_find_ir_override_file, with_invocation=False)
425+
+ [
426+
"",
427+
"# Find the best captured IR file for ir_override",
428+
"_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))",
429+
"_IR_DIR = os.path.join(_SCRIPT_DIR, 'captured_irs')",
430+
"_IR_OVERRIDE_FILE = _find_ir_override_file(_IR_DIR) if os.path.isdir(_IR_DIR) else None",
431+
"",
432+
]
433+
)
434+
435+
# Generate the stub function source
436+
stub_code = generate_stub_source(func_name, source_code)
437+
438+
# Wrap with autotune carrying ir_override + constexpr values.
439+
# Apply @triton.autotune programmatically so we can conditionally
440+
# include ir_override only when the IR file exists.
441+
# Detect indent of the import placeholder for the if/else block.
442+
placeholder = self.KERNEL_IMPORT_PLACEHOLDER
443+
indent = next(
444+
(
445+
line[: len(line) - len(line.lstrip())]
446+
for line in code.splitlines()
447+
if line.lstrip().startswith(placeholder)
448+
),
449+
"",
450+
)
451+
inner = indent + " "
452+
453+
import_lines.extend(
454+
[
455+
"# Stub kernel — body is replaced by captured IR via ir_override",
456+
stub_code,
457+
"",
458+
f"{indent}if _IR_OVERRIDE_FILE:",
459+
f"{inner}{func_name} = triton.autotune(",
460+
f"{inner} configs=[triton.Config({config_str})],",
461+
f"{inner} key=[],",
462+
f"{inner})({func_name})",
463+
"",
464+
f"{indent}imported_kernel_function = {func_name}",
465+
]
466+
)
467+
468+
embedded_code = "\n".join(import_lines)
469+
return code.replace(self.KERNEL_IMPORT_PLACEHOLDER, embedded_code)
430470
else:
431471
raise ValueError(f"Unknown kernel_import mode: {kernel_import}")
432472

@@ -452,9 +492,46 @@ def _replace_kernel_invocation(
452492
453493
Otherwise, passes all args plus compile params as usual.
454494
"""
495+
kernel_import = kwargs.get("kernel_import", KernelImportMode.DEFAULT)
455496
source_code = context_bundle.kernel_info.source_code
456497
pos_args, kw_args = _parse_kernel_signature(source_code)
457498

499+
if kernel_import == KernelImportMode.OVERRIDE_TTIR:
500+
# When ir_override is active, autotune provides constexpr values
501+
# and compile params — only pass non-constexpr args.
502+
# When ir_override is missing (no captured_irs/), the stub runs
503+
# as a plain @triton.jit and needs all args + compile params.
504+
constexpr_vals = get_constexpr_values(context_bundle)
505+
filtered_pos = [a for a in pos_args if a not in constexpr_vals]
506+
filtered_kw = [a for a in kw_args if a not in constexpr_vals]
507+
override_snippet = _generate_invocation_snippet(
508+
filtered_pos, filtered_kw, compile_params=None
509+
)
510+
compile_params = _get_compile_params_for_invocation(
511+
context_bundle.compile, kw_args
512+
)
513+
fallback_snippet = _generate_invocation_snippet(
514+
pos_args, kw_args, compile_params
515+
)
516+
# Find the placeholder's indent so the if/else aligns in any template.
517+
placeholder = self.KERNEL_INVOCATION_PLACEHOLDER
518+
indent = next(
519+
(
520+
line[: len(line) - len(line.lstrip())]
521+
for line in code.splitlines()
522+
if line.lstrip().startswith(placeholder)
523+
),
524+
" ",
525+
)
526+
inner = indent + " "
527+
invocation_snippet = (
528+
f"if _IR_OVERRIDE_FILE:\n"
529+
f"{inner}{override_snippet}\n"
530+
f"{indent}else:\n"
531+
f"{inner}{fallback_snippet}"
532+
)
533+
return code.replace(self.KERNEL_INVOCATION_PLACEHOLDER, invocation_snippet)
534+
458535
if self.preserve_autotune:
459536
# Get full kernel source with decorators to find autotune config params
460537
func_name = context_bundle.kernel_info.function_name
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
3+
"""
4+
Generate stub @triton.jit functions for IR-based reproducers.
5+
6+
A stub kernel has the same function name and parameter signature as the original
7+
kernel but a trivial body (``pass``). When used with ``ir_override`` on a
8+
``triton.Config``, the stub compiles through the frontend (producing a valid but
9+
trivial TTIR) and the compilation output is replaced by the captured IR from the
10+
original kernel. This completely eliminates the need to copy the original
11+
kernel's source code and its transitive dependencies.
12+
"""
13+
14+
import ast
15+
from functools import lru_cache
16+
from typing import Any, Dict, List, Tuple
17+
18+
from tritonparse.reproducer.function_extractor import _is_constexpr_annotation
19+
from tritonparse.reproducer.ingestion.ndjson import ContextBundle
20+
from tritonparse.tp_logger import logger
21+
22+
23+
def generate_stub_source(kernel_name: str, source_code: str) -> str:
24+
"""Generate a stub ``@triton.jit`` function from the original kernel source.
25+
26+
Parses the original source to extract parameter names and ``tl.constexpr``
27+
annotations, then produces a stub with the same signature and a ``pass``
28+
body. Does NOT include ``@triton.autotune`` — that is added separately
29+
by the placeholder replacer.
30+
31+
Args:
32+
kernel_name: The kernel function name.
33+
source_code: The original kernel source code (from ``@triton.jit`` onward).
34+
35+
Returns:
36+
Python source code defining the stub function.
37+
"""
38+
params = extract_params_from_source(source_code)
39+
40+
# Build parameter list
41+
param_strs: List[str] = []
42+
for name, is_constexpr in params:
43+
if is_constexpr:
44+
param_strs.append(f"{name}: tl.constexpr")
45+
else:
46+
param_strs.append(name)
47+
48+
# Format as multi-line if many params
49+
if len(param_strs) > 3:
50+
params_block = "\n " + ",\n ".join(param_strs) + ",\n"
51+
else:
52+
params_block = ", ".join(param_strs)
53+
54+
return f"@triton.jit\ndef {kernel_name}({params_block}):\n pass"
55+
56+
57+
def _find_ir_override_file(ir_dir: str) -> "str | None":
58+
"""Find the captured TTIR file for ir_override.
59+
60+
Returns the full path to the first ``.ttir`` file, or None.
61+
62+
NOTE: This function is extracted by ``get_function_source()`` and embedded
63+
verbatim in generated reproducer scripts. Keep imports minimal (only
64+
``os`` and ``logging`` are available).
65+
"""
66+
import logging
67+
import os
68+
69+
for name in sorted(os.listdir(ir_dir)):
70+
if name.endswith(".ttir"):
71+
return os.path.join(ir_dir, name)
72+
logging.getLogger(__name__).warning(
73+
"captured_irs/ directory exists but contains no .ttir files. "
74+
"Stub kernel will run without ir_override (fallback mode)."
75+
)
76+
return None
77+
78+
79+
def get_constexpr_values(context_bundle: ContextBundle) -> Dict[str, Any]:
80+
"""Extract constexpr parameter names and their values from the context.
81+
82+
Args:
83+
context_bundle: The context bundle from the captured launch event.
84+
85+
Returns:
86+
Dictionary mapping constexpr parameter name to its value.
87+
"""
88+
source_code = context_bundle.kernel_info.source_code
89+
params = extract_params_from_source(source_code)
90+
constexpr_names = {name for name, is_ce in params if is_ce}
91+
args = context_bundle.args
92+
93+
result: Dict[str, Any] = {}
94+
for name in constexpr_names:
95+
arg_info = args.get(name)
96+
if arg_info is None:
97+
continue
98+
if isinstance(arg_info, dict):
99+
result[name] = arg_info.get("value")
100+
else:
101+
result[name] = arg_info
102+
103+
return result
104+
105+
106+
@lru_cache(maxsize=8)
107+
def extract_params_from_source(
108+
source_code: str,
109+
) -> Tuple[Tuple[str, bool], ...]:
110+
"""Extract parameter names and constexpr annotations from kernel source.
111+
112+
Results are cached since the same source is parsed multiple times during
113+
reproducer generation (stub generation, constexpr extraction, invocation).
114+
115+
Returns:
116+
Tuple of ``(param_name, is_constexpr)`` pairs.
117+
"""
118+
try:
119+
tree = ast.parse(source_code)
120+
except SyntaxError:
121+
logger.warning("Failed to parse kernel source for param extraction")
122+
return ()
123+
124+
for node in ast.walk(tree):
125+
if not isinstance(node, ast.FunctionDef):
126+
continue
127+
128+
params: List[Tuple[str, bool]] = []
129+
all_args = list(node.args.args) + list(node.args.kwonlyargs)
130+
for arg in all_args:
131+
is_ce = arg.annotation is not None and _is_constexpr_annotation(
132+
arg.annotation
133+
)
134+
params.append((arg.arg, is_ce))
135+
return tuple(params)
136+
137+
return ()

tritonparse/reproducer/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ class KernelImportMode(str, Enum):
1212
Attributes:
1313
DEFAULT: Import kernel from original file (current behavior).
1414
COPY: Embed kernel source code directly in reproducer.
15-
OVERRIDE_TTIR: Use TTIR from compilation event with monkeypatch.
15+
OVERRIDE_TTIR: Generate a stub kernel with captured IR override.
1616
"""
1717

1818
DEFAULT = "default"

0 commit comments

Comments
 (0)