@@ -459,6 +459,225 @@ def _extract_dict_keys(dict_node: ast.Dict) -> set[str]:
459459 return keys
460460
461461
462+ def _is_autotune_decorator (node : ast .expr ) -> bool :
463+ """Check if a decorator node is @triton.autotune(...) or @autotune(...)."""
464+ if not isinstance (node , ast .Call ):
465+ return False
466+ func = node .func
467+ # Handle: @triton.autotune(...)
468+ if isinstance (func , ast .Attribute ) and func .attr == "autotune" :
469+ if isinstance (func .value , ast .Name ) and func .value .id == "triton" :
470+ return True
471+ # Handle: @autotune(...) when imported directly
472+ if isinstance (func , ast .Name ) and func .id == "autotune" :
473+ return True
474+ return False
475+
476+
477+ def _get_configs_kwarg (node : ast .Call ) -> ast .expr | None :
478+ """Extract the 'configs' argument AST node from a triton.autotune(...) call."""
479+ for kw in node .keywords :
480+ if kw .arg == "configs" :
481+ return kw .value
482+ # Also check first positional argument
483+ if node .args :
484+ return node .args [0 ]
485+ return None
486+
487+
488+ def _extract_config_values (
489+ call_node : ast .Call ,
490+ var_bindings : dict [str , ast .Dict ],
491+ ) -> dict [str , object ]:
492+ """Extract all key-value pairs from a single triton.Config(...) call.
493+
494+ Extracts both:
495+ - Dict params (first arg): {"BLOCK_M": 256, "BLOCK_N": 128, ...}
496+ - Compile params (keyword args): num_warps=4, num_stages=0, ...
497+
498+ Non-constant values (e.g., pre_hook=func_ref) are skipped.
499+ """
500+ config : dict [str , object ] = {}
501+
502+ # Extract from first positional argument (dict)
503+ if call_node .args :
504+ first_arg = call_node .args [0 ]
505+ dict_node = None
506+ if isinstance (first_arg , ast .Dict ):
507+ dict_node = first_arg
508+ elif isinstance (first_arg , ast .Name ) and first_arg .id in var_bindings :
509+ dict_node = var_bindings [first_arg .id ]
510+
511+ if dict_node is not None :
512+ for key , val in zip (dict_node .keys , dict_node .values ):
513+ k = None
514+ if key is not None :
515+ if isinstance (key , ast .Constant ) and isinstance (key .value , str ):
516+ k = key .value
517+ elif isinstance (key , ast .Name ):
518+ k = key .id
519+ if k is not None and isinstance (val , ast .Constant ):
520+ config [k ] = val .value
521+
522+ # Extract from keyword args (num_warps=4, num_stages=0, etc.)
523+ for kw in call_node .keywords :
524+ if kw .arg is not None and isinstance (kw .value , ast .Constant ):
525+ config [kw .arg ] = kw .value .value
526+
527+ return config
528+
529+
530+ def extract_kernel_autotune_configs (
531+ source_code : str ,
532+ function_name : str ,
533+ ) -> list [dict [str , object ]] | None :
534+ """Extract compile param values from a kernel's @triton.autotune configs.
535+
536+ Traces from the kernel's decorator to the exact configs it references,
537+ isolating from other kernels' configs in the same file.
538+
539+ Handles three patterns for the configs= argument:
540+ - ast.Name (variable reference): configs=configs_var → traces to variable assignment
541+ - ast.List (inline list): configs=[triton.Config(...), ...] → parses directly
542+ - ast.Call (function call): configs=get_configs() → returns None (cannot resolve)
543+
544+ Args:
545+ source_code: Full source file content.
546+ function_name: Target kernel function name.
547+
548+ Returns:
549+ List of config dicts (each containing dict params + compile params),
550+ or None if extraction fails.
551+ """
552+ try :
553+ tree , _ = _parse_source_code (source_code )
554+ except SyntaxError :
555+ return None
556+
557+ # Step 1: Locate the target function
558+ func_node = None
559+ for node in ast .walk (tree ):
560+ if isinstance (node , ast .FunctionDef ) and node .name == function_name :
561+ func_node = node
562+ break
563+ if func_node is None :
564+ return None
565+
566+ # Step 2: Find @triton.autotune decorator, extract configs= argument
567+ configs_ast_node = None
568+ for dec in func_node .decorator_list :
569+ if _is_autotune_decorator (dec ):
570+ configs_ast_node = _get_configs_kwarg (dec )
571+ break
572+ if configs_ast_node is None :
573+ return None
574+
575+ # Step 3: Resolve configs AST node to a list of triton.Config calls
576+ config_call_nodes : list [ast .Call ] = []
577+ if isinstance (configs_ast_node , ast .List ):
578+ # Inline: configs=[triton.Config(...), ...]
579+ config_call_nodes = [
580+ elt
581+ for elt in configs_ast_node .elts
582+ if isinstance (elt , ast .Call ) and _is_triton_config_call (elt )
583+ ]
584+ elif isinstance (configs_ast_node , ast .Name ):
585+ # Variable reference: configs=configs_var
586+ # Collect ALL triton.Config calls from ALL assignments to this variable.
587+ # Multiple assignments may exist (e.g., different config sets for different
588+ # kernels reusing the same variable name, or a cleanup `configs = []`).
589+ # We collect broadly here and let match_autotune_config disambiguate
590+ # using extracted_args.
591+ var_name = configs_ast_node .id
592+ for stmt in tree .body :
593+ if isinstance (stmt , ast .Assign ):
594+ for target in stmt .targets :
595+ if isinstance (target , ast .Name ) and target .id == var_name :
596+ if isinstance (stmt .value , ast .List ):
597+ config_call_nodes .extend (
598+ elt
599+ for elt in stmt .value .elts
600+ if isinstance (elt , ast .Call )
601+ and _is_triton_config_call (elt )
602+ )
603+ elif isinstance (configs_ast_node , ast .Call ):
604+ # Dynamic: configs=get_fwd_configs() — cannot resolve statically
605+ return None
606+ else :
607+ return None
608+
609+ if not config_call_nodes :
610+ return None
611+
612+ # Step 4: Extract key-value pairs from each Config
613+ var_bindings = _build_variable_bindings (tree )
614+ configs = []
615+ for call_node in config_call_nodes :
616+ config = _extract_config_values (call_node , var_bindings )
617+ if config :
618+ configs .append (config )
619+
620+ return configs if configs else None
621+
622+
623+ def match_autotune_config (
624+ configs : list [dict [str , object ]],
625+ extracted_args : dict [str , object ],
626+ ) -> dict [str , object ] | None :
627+ """Match the selected autotune config and return its compile params.
628+
629+ Strategy:
630+ 1. If all configs agree on every compile param value, use them directly.
631+ 2. Otherwise, match by comparing config dict params against extracted_args.
632+
633+ Args:
634+ configs: List of config dicts from extract_kernel_autotune_configs.
635+ extracted_args: Extracted argument info from launch event.
636+ Format: {"BLOCK_M": {"type": "int", "value": 256}, ...}
637+
638+ Returns:
639+ Dict of compile params (e.g., {"num_warps": 4, "num_stages": 0}),
640+ or None if matching fails.
641+ """
642+ if not configs :
643+ return None
644+
645+ # Strategy 1: Check if all configs agree on compile params
646+ compile_params_result : dict [str , object ] = {}
647+ all_agree = True
648+ for param in TRITON_COMPILE_PARAMS :
649+ values = {c [param ] for c in configs if param in c }
650+ if len (values ) == 1 :
651+ compile_params_result [param ] = values .pop ()
652+ elif len (values ) > 1 :
653+ all_agree = False
654+ break
655+
656+ if all_agree and compile_params_result :
657+ return compile_params_result
658+
659+ # Strategy 2: Match by config dict params against extracted_args
660+ for config in configs :
661+ dict_params = {
662+ k : v
663+ for k , v in config .items ()
664+ if k not in TRITON_COMPILE_PARAMS and k not in _TRITON_CONFIG_KWARGS
665+ }
666+ if not dict_params :
667+ continue
668+ match = all (
669+ (
670+ isinstance (extracted_args .get (k ), dict )
671+ and extracted_args [k ].get ("value" ) == v
672+ )
673+ for k , v in dict_params .items ()
674+ )
675+ if match :
676+ return {k : v for k , v in config .items () if k in TRITON_COMPILE_PARAMS }
677+
678+ return None
679+
680+
462681def _is_constexpr_annotation (annotation : ast .expr ) -> bool :
463682 """Check if an annotation is tl.constexpr or triton.language.constexpr."""
464683 if isinstance (annotation , ast .Attribute ) and annotation .attr == "constexpr" :
0 commit comments