1515 match_autotune_config ,
1616)
1717from tritonparse .reproducer .ingestion .ndjson import ContextBundle
18+ from tritonparse .reproducer .stub_generator import (
19+ _find_ir_override_file ,
20+ generate_stub_source ,
21+ get_constexpr_values ,
22+ )
1823from tritonparse .reproducer .templates .utils import (
1924 _disable_triton_autotune ,
2025 get_function_source ,
@@ -242,60 +247,7 @@ def _replace_ir_override_setup(
242247 self , code : str , context_bundle : ContextBundle , ** kwargs
243248 ) -> str :
244249 """Replace the IR override setup placeholder."""
245- kernel_import = kwargs .get ("kernel_import" , KernelImportMode .DEFAULT )
246-
247- if kernel_import != KernelImportMode .OVERRIDE_TTIR :
248- return code .replace (self .IR_OVERRIDE_SETUP_PLACEHOLDER , "" )
249-
250- comp_json_filename = kwargs .get ("comp_json_filename" )
251- if not comp_json_filename :
252- raise ValueError ("comp_json_filename is required for OVERRIDE_TTIR mode" )
253-
254- setup_code = f'''
255- def create_ttir_tempfile():
256- """Extract TTIR from compilation event and create temporary file."""
257- script_dir = Path(__file__).resolve().parent
258- comp_json_file = script_dir / "{ comp_json_filename } "
259-
260- with open(comp_json_file, 'r') as f:
261- comp_data = json.load(f)
262-
263- # Extract TTIR content
264- kernel_name = comp_data['payload']['metadata']['name']
265- ttir_key = f"{{kernel_name}}.ttir"
266- ttir_content = comp_data['payload']['file_content'][ttir_key]
267-
268- # Create temporary file
269- temp_file = tempfile.NamedTemporaryFile(
270- mode='w',
271- suffix='.ttir',
272- delete=False,
273- prefix=f'{{kernel_name}}_'
274- )
275- temp_file.write(ttir_content)
276- temp_file.close()
277- return temp_file.name
278-
279-
280- # Monkeypatch triton.autotune to use our TTIR
281- _ttir_file = create_ttir_tempfile()
282- _original_autotune = None
283-
284- def _patched_autotune(configs, key=None, **kwargs):
285- """Patched autotune that uses our TTIR file."""
286- import triton
287- # Replace configs with our single config using ir_override
288- new_configs = [triton.Config(kwargs={{}}, ir_override=_ttir_file)]
289- # Call original autotune with our config
290- return _original_autotune(new_configs, key=[], **kwargs)
291-
292- # Apply the monkeypatch before importing the kernel
293- import triton
294- _original_autotune = triton.autotune
295- triton.autotune = _patched_autotune
296- '''
297-
298- return code .replace (self .IR_OVERRIDE_SETUP_PLACEHOLDER , setup_code )
250+ return code .replace (self .IR_OVERRIDE_SETUP_PLACEHOLDER , "" )
299251
300252 def _replace_kernel_syspath (
301253 self , code : str , context_bundle : ContextBundle , ** kwargs
@@ -312,7 +264,7 @@ def _replace_kernel_syspath(
312264 )
313265 return code .replace (self .KERNEL_SYSPATH_PLACEHOLDER , comment )
314266 elif kernel_import == KernelImportMode .OVERRIDE_TTIR :
315- comment = "# Kernel sys.path setup skipped - using IR override mode "
267+ comment = "# Kernel sys.path setup skipped - using stub with IR override"
316268 return code .replace (self .KERNEL_SYSPATH_PLACEHOLDER , comment )
317269 else :
318270 raise ValueError (f"Unknown kernel_import mode: { kernel_import } " )
@@ -425,8 +377,96 @@ def _replace_kernel_import(
425377
426378 return code .replace (self .KERNEL_IMPORT_PLACEHOLDER , embedded_code )
427379 elif kernel_import == KernelImportMode .OVERRIDE_TTIR :
428- comment = "# Kernel import skipped - using IR override mode with TTIR"
429- return code .replace (self .KERNEL_IMPORT_PLACEHOLDER , comment )
380+ source_code = context_bundle .kernel_info .source_code
381+ func_name = context_bundle .kernel_info .function_name
382+
383+ if not source_code or not source_code .strip ():
384+ raise ValueError (
385+ "Kernel source code is empty, cannot use 'override-ttir' mode"
386+ )
387+ if not func_name :
388+ raise ValueError (
389+ "Cannot determine kernel function name for 'override-ttir' mode"
390+ )
391+
392+ # Build the autotune config with ir_override + constexpr values
393+ constexpr_vals = get_constexpr_values (context_bundle )
394+ compile_meta = context_bundle .compile
395+
396+ config_parts = []
397+ # Constexpr kwargs
398+ kwargs_repr = repr (constexpr_vals )
399+ config_parts .append (f"kwargs={ kwargs_repr } " )
400+ # Compile params
401+ for param in TRITON_COMPILE_PARAMS :
402+ value = compile_meta .get (param )
403+ if value is not None :
404+ config_parts .append (f"{ param } ={ value !r} " )
405+ # ir_override — find the best available IR file
406+ config_parts .append ("ir_override=_IR_OVERRIDE_FILE" )
407+
408+ config_str = ", " .join (config_parts )
409+
410+ # Generate imports + IR file discovery + stub function + autotune wrapper
411+ import_lines = list (_BASE_IMPORT_LINES ) + ["" ]
412+
413+ # Scan original source for extra imports (e.g., tlx)
414+ extra_imports = _detect_extra_imports (source_code )
415+ if extra_imports :
416+ import_lines .extend (extra_imports )
417+ import_lines .append ("" )
418+
419+ import_lines .extend (
420+ [
421+ "import os" ,
422+ "" ,
423+ ]
424+ + get_function_source (_find_ir_override_file , with_invocation = False )
425+ + [
426+ "" ,
427+ "# Find the best captured IR file for ir_override" ,
428+ "_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))" ,
429+ "_IR_DIR = os.path.join(_SCRIPT_DIR, 'captured_irs')" ,
430+ "_IR_OVERRIDE_FILE = _find_ir_override_file(_IR_DIR) if os.path.isdir(_IR_DIR) else None" ,
431+ "" ,
432+ ]
433+ )
434+
435+ # Generate the stub function source
436+ stub_code = generate_stub_source (func_name , source_code )
437+
438+ # Wrap with autotune carrying ir_override + constexpr values.
439+ # Apply @triton.autotune programmatically so we can conditionally
440+ # include ir_override only when the IR file exists.
441+ # Detect indent of the import placeholder for the if/else block.
442+ placeholder = self .KERNEL_IMPORT_PLACEHOLDER
443+ indent = next (
444+ (
445+ line [: len (line ) - len (line .lstrip ())]
446+ for line in code .splitlines ()
447+ if line .lstrip ().startswith (placeholder )
448+ ),
449+ "" ,
450+ )
451+ inner = indent + " "
452+
453+ import_lines .extend (
454+ [
455+ "# Stub kernel — body is replaced by captured IR via ir_override" ,
456+ stub_code ,
457+ "" ,
458+ f"{ indent } if _IR_OVERRIDE_FILE:" ,
459+ f"{ inner } { func_name } = triton.autotune(" ,
460+ f"{ inner } configs=[triton.Config({ config_str } )]," ,
461+ f"{ inner } key=[]," ,
462+ f"{ inner } )({ func_name } )" ,
463+ "" ,
464+ f"{ indent } imported_kernel_function = { func_name } " ,
465+ ]
466+ )
467+
468+ embedded_code = "\n " .join (import_lines )
469+ return code .replace (self .KERNEL_IMPORT_PLACEHOLDER , embedded_code )
430470 else :
431471 raise ValueError (f"Unknown kernel_import mode: { kernel_import } " )
432472
@@ -452,9 +492,46 @@ def _replace_kernel_invocation(
452492
453493 Otherwise, passes all args plus compile params as usual.
454494 """
495+ kernel_import = kwargs .get ("kernel_import" , KernelImportMode .DEFAULT )
455496 source_code = context_bundle .kernel_info .source_code
456497 pos_args , kw_args = _parse_kernel_signature (source_code )
457498
499+ if kernel_import == KernelImportMode .OVERRIDE_TTIR :
500+ # When ir_override is active, autotune provides constexpr values
501+ # and compile params — only pass non-constexpr args.
502+ # When ir_override is missing (no captured_irs/), the stub runs
503+ # as a plain @triton.jit and needs all args + compile params.
504+ constexpr_vals = get_constexpr_values (context_bundle )
505+ filtered_pos = [a for a in pos_args if a not in constexpr_vals ]
506+ filtered_kw = [a for a in kw_args if a not in constexpr_vals ]
507+ override_snippet = _generate_invocation_snippet (
508+ filtered_pos , filtered_kw , compile_params = None
509+ )
510+ compile_params = _get_compile_params_for_invocation (
511+ context_bundle .compile , kw_args
512+ )
513+ fallback_snippet = _generate_invocation_snippet (
514+ pos_args , kw_args , compile_params
515+ )
516+ # Find the placeholder's indent so the if/else aligns in any template.
517+ placeholder = self .KERNEL_INVOCATION_PLACEHOLDER
518+ indent = next (
519+ (
520+ line [: len (line ) - len (line .lstrip ())]
521+ for line in code .splitlines ()
522+ if line .lstrip ().startswith (placeholder )
523+ ),
524+ " " ,
525+ )
526+ inner = indent + " "
527+ invocation_snippet = (
528+ f"if _IR_OVERRIDE_FILE:\n "
529+ f"{ inner } { override_snippet } \n "
530+ f"{ indent } else:\n "
531+ f"{ inner } { fallback_snippet } "
532+ )
533+ return code .replace (self .KERNEL_INVOCATION_PLACEHOLDER , invocation_snippet )
534+
458535 if self .preserve_autotune :
459536 # Get full kernel source with decorators to find autotune config params
460537 func_name = context_bundle .kernel_info .function_name
0 commit comments