Skip to content

Commit eb04a8f

Browse files
FindHao and meta-codesync[bot]
authored and committed
Fix WS kernel num_warps in reproducer scripts (#349)
Summary: Pull Request resolved: #349 For Warp Specialization (WS) kernels, the Triton compiler overwrites `num_warps` in compilation metadata from the user-specified value (e.g., 4) to the total warp count across all async_task groups (e.g., 16). When the reproducer used this inflated value to launch the kernel, it failed with an assertion in TLXResolvePlaceholderLayouts: `numWarps == 4 || numWarps == 8`. The fix traces from the kernel's `triton.autotune` decorator to its exact `triton.Config(...)` definitions — handling variable references, inline lists, and gracefully falling back for dynamic configs. It extracts the original compile param values and overrides the overwritten values before generating the kernel invocation. Multi-kernel files are handled by isolating each kernel's configs through its decorator AST. Reviewed By: wychi Differential Revision: D94154539 fbshipit-source-id: a6bd84b4531e39a197632f836afa364f3d645220
1 parent 4b7219f commit eb04a8f

3 files changed

Lines changed: 643 additions & 4 deletions

File tree

tritonparse/reproducer/function_extractor.py

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,225 @@ def _extract_dict_keys(dict_node: ast.Dict) -> set[str]:
459459
return keys
460460

461461

462+
def _is_autotune_decorator(node: ast.expr) -> bool:
463+
"""Check if a decorator node is @triton.autotune(...) or @autotune(...)."""
464+
if not isinstance(node, ast.Call):
465+
return False
466+
func = node.func
467+
# Handle: @triton.autotune(...)
468+
if isinstance(func, ast.Attribute) and func.attr == "autotune":
469+
if isinstance(func.value, ast.Name) and func.value.id == "triton":
470+
return True
471+
# Handle: @autotune(...) when imported directly
472+
if isinstance(func, ast.Name) and func.id == "autotune":
473+
return True
474+
return False
475+
476+
477+
def _get_configs_kwarg(node: ast.Call) -> ast.expr | None:
478+
"""Extract the 'configs' argument AST node from a triton.autotune(...) call."""
479+
for kw in node.keywords:
480+
if kw.arg == "configs":
481+
return kw.value
482+
# Also check first positional argument
483+
if node.args:
484+
return node.args[0]
485+
return None
486+
487+
488+
def _extract_config_values(
489+
call_node: ast.Call,
490+
var_bindings: dict[str, ast.Dict],
491+
) -> dict[str, object]:
492+
"""Extract all key-value pairs from a single triton.Config(...) call.
493+
494+
Extracts both:
495+
- Dict params (first arg): {"BLOCK_M": 256, "BLOCK_N": 128, ...}
496+
- Compile params (keyword args): num_warps=4, num_stages=0, ...
497+
498+
Non-constant values (e.g., pre_hook=func_ref) are skipped.
499+
"""
500+
config: dict[str, object] = {}
501+
502+
# Extract from first positional argument (dict)
503+
if call_node.args:
504+
first_arg = call_node.args[0]
505+
dict_node = None
506+
if isinstance(first_arg, ast.Dict):
507+
dict_node = first_arg
508+
elif isinstance(first_arg, ast.Name) and first_arg.id in var_bindings:
509+
dict_node = var_bindings[first_arg.id]
510+
511+
if dict_node is not None:
512+
for key, val in zip(dict_node.keys, dict_node.values):
513+
k = None
514+
if key is not None:
515+
if isinstance(key, ast.Constant) and isinstance(key.value, str):
516+
k = key.value
517+
elif isinstance(key, ast.Name):
518+
k = key.id
519+
if k is not None and isinstance(val, ast.Constant):
520+
config[k] = val.value
521+
522+
# Extract from keyword args (num_warps=4, num_stages=0, etc.)
523+
for kw in call_node.keywords:
524+
if kw.arg is not None and isinstance(kw.value, ast.Constant):
525+
config[kw.arg] = kw.value.value
526+
527+
return config
528+
529+
530+
def extract_kernel_autotune_configs(
    source_code: str,
    function_name: str,
) -> list[dict[str, object]] | None:
    """Extract compile param values from a kernel's @triton.autotune configs.

    Traces from the kernel's decorator to the exact configs it references,
    isolating from other kernels' configs in the same file.

    Handles three patterns for the configs= argument:
    - ast.Name (variable reference): configs=configs_var → traces to module-level
      assignments of that variable, both plain (``configs_var = [...]``) and
      annotated (``configs_var: list[triton.Config] = [...]``)
    - ast.List (inline list): configs=[triton.Config(...), ...] → parses directly
    - anything else (e.g. a function call, configs=get_configs()) → returns None
      because it cannot be resolved statically

    Args:
        source_code: Full source file content.
        function_name: Target kernel function name.

    Returns:
        List of config dicts (each containing dict params + compile params),
        or None if extraction fails.
    """
    try:
        tree, _ = _parse_source_code(source_code)
    except SyntaxError:
        return None

    # Step 1: Locate the target function (first match in ast.walk order).
    func_node = next(
        (
            node
            for node in ast.walk(tree)
            if isinstance(node, ast.FunctionDef) and node.name == function_name
        ),
        None,
    )
    if func_node is None:
        return None

    # Step 2: Find the autotune decorator and take its configs argument.
    configs_ast_node = None
    for dec in func_node.decorator_list:
        if _is_autotune_decorator(dec):
            configs_ast_node = _get_configs_kwarg(dec)
            break
    if configs_ast_node is None:
        return None

    # Step 3: Resolve the configs AST node to a list of triton.Config calls.
    config_call_nodes: list[ast.Call] = []
    if isinstance(configs_ast_node, ast.List):
        # Inline: configs=[triton.Config(...), ...]
        config_call_nodes = [
            elt
            for elt in configs_ast_node.elts
            if isinstance(elt, ast.Call) and _is_triton_config_call(elt)
        ]
    elif isinstance(configs_ast_node, ast.Name):
        # Variable reference: configs=configs_var
        # Collect ALL triton.Config calls from ALL module-level assignments to
        # this variable. Multiple assignments may exist (e.g., different config
        # sets for different kernels reusing the same variable name, or a
        # cleanup `configs = []`). We collect broadly here and let
        # match_autotune_config disambiguate using extracted_args.
        var_name = configs_ast_node.id
        for stmt in tree.body:
            if isinstance(stmt, ast.Assign):
                targets = stmt.targets
            elif isinstance(stmt, ast.AnnAssign):
                # Annotated assignment has one target and an optional value.
                targets = [stmt.target]
            else:
                continue
            if not isinstance(stmt.value, ast.List):
                continue
            for target in targets:
                if isinstance(target, ast.Name) and target.id == var_name:
                    config_call_nodes.extend(
                        elt
                        for elt in stmt.value.elts
                        if isinstance(elt, ast.Call)
                        and _is_triton_config_call(elt)
                    )
    else:
        # Dynamic (configs=get_fwd_configs()) or any other expression:
        # cannot resolve statically.
        return None

    if not config_call_nodes:
        return None

    # Step 4: Extract key-value pairs from each Config, dropping empty results.
    var_bindings = _build_variable_bindings(tree)
    configs = []
    for call_node in config_call_nodes:
        config = _extract_config_values(call_node, var_bindings)
        if config:
            configs.append(config)

    return configs if configs else None
621+
622+
623+
def match_autotune_config(
    configs: list[dict[str, object]],
    extracted_args: dict[str, object],
) -> dict[str, object] | None:
    """Match the selected autotune config and return its compile params.

    Strategy:
    1. A compile param is "agreed" only when EVERY config specifies it with the
       same value. A param that some configs omit (or set to a non-literal,
       which the extractor drops) is not agreed: the omitting config would run
       with a value this function cannot see. If every specified param is
       agreed, the agreed values are returned directly.
    2. Otherwise, match by comparing config dict params against
       extracted_args. If several configs match, they must carry identical
       compile params; an ambiguous match is treated as no match instead of
       guessing the first one.
    3. If matching fails, the agreed subset from step 1 (if any) is still
       returned, since it holds whichever config was selected.

    Args:
        configs: List of config dicts from extract_kernel_autotune_configs.
        extracted_args: Extracted argument info from launch event.
            Format: {"BLOCK_M": {"type": "int", "value": 256}, ...}

    Returns:
        Dict of compile params (e.g., {"num_warps": 4, "num_stages": 0}),
        or None if nothing can be determined.
    """
    if not configs:
        return None

    # Strategy 1: find compile params that hold for every config.
    agreed: dict[str, object] = {}
    needs_match = False
    for param in TRITON_COMPILE_PARAMS:
        specified = [c[param] for c in configs if param in c]
        if not specified:
            # No config mentions this param; nothing to override.
            continue
        if len(specified) == len(configs) and len(set(specified)) == 1:
            agreed[param] = specified[0]
        else:
            # Conflicting values, or only some configs specify it.
            needs_match = True

    if agreed and not needs_match:
        return agreed

    # Strategy 2: match by config dict params against extracted_args.
    candidates: list[dict[str, object]] = []
    for config in configs:
        dict_params = {
            k: v
            for k, v in config.items()
            if k not in TRITON_COMPILE_PARAMS and k not in _TRITON_CONFIG_KWARGS
        }
        if not dict_params:
            # Nothing to compare against; this config cannot be identified.
            continue
        is_match = all(
            isinstance(extracted_args.get(k), dict)
            and extracted_args[k].get("value") == v
            for k, v in dict_params.items()
        )
        if is_match:
            candidates.append(
                {k: v for k, v in config.items() if k in TRITON_COMPILE_PARAMS}
            )

    # Accept the match only when every matching config yields the same
    # compile params; otherwise the selection is ambiguous.
    if candidates and all(c == candidates[0] for c in candidates[1:]):
        return candidates[0]

    # Strategy 3: fall back to the params that are safe for any config.
    return agreed or None
679+
680+
462681
def _is_constexpr_annotation(annotation: ast.expr) -> bool:
463682
"""Check if an annotation is tl.constexpr or triton.language.constexpr."""
464683
if isinstance(annotation, ast.Attribute) and annotation.attr == "constexpr":

tritonparse/reproducer/placeholder_replacer.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99
from tritonparse.reproducer.function_extractor import (
1010
extract_autotune_config_params,
1111
extract_function_with_decorators,
12+
extract_kernel_autotune_configs,
1213
extract_utility_functions,
1314
is_constexpr_param,
15+
match_autotune_config,
1416
)
1517
from tritonparse.reproducer.ingestion.ndjson import ContextBundle
1618
from tritonparse.reproducer.templates.utils import (
@@ -23,6 +25,7 @@
2325
_generate_invocation_snippet,
2426
_get_compile_params_for_invocation,
2527
_parse_kernel_signature,
28+
TRITON_COMPILE_PARAMS,
2629
)
2730
from tritonparse.tp_logger import logger
2831

@@ -513,8 +516,6 @@ def _replace_kernel_invocation(
513516

514517
# Get compile params that are NOT provided by autotune configs
515518
# (e.g., num_warps=8 when configs don't specify num_warps)
516-
from tritonparse.reproducer.utils import TRITON_COMPILE_PARAMS
517-
518519
autotune_compile_params = set(autotune_config_params) & set(
519520
TRITON_COMPILE_PARAMS
520521
)
@@ -530,9 +531,49 @@ def _replace_kernel_invocation(
530531
filtered_pos_args, all_filtered_kw_args, remaining_compile_params
531532
)
532533
else:
533-
# Standard behavior: pass compile params
534+
# Standard behavior: pass compile params.
535+
# Override compile params with values from triton.Config if available,
536+
# because the compiler may overwrite them (e.g., WS kernels overwrite
537+
# num_warps with the total warp count across all async_task groups).
538+
effective_compile = dict(context_bundle.compile)
539+
540+
file_path = context_bundle.kernel_info.file_path
541+
func_name = context_bundle.kernel_info.function_name
542+
source_path = Path(file_path)
543+
544+
if not source_path.exists() and context_bundle.source_repo_dir:
545+
source_repo_path = Path(context_bundle.source_repo_dir)
546+
if source_repo_path.exists():
547+
for i in range(1, len(source_path.parts)):
548+
new_path = source_repo_path / Path(
549+
"/".join(source_path.parts[i:])
550+
)
551+
if new_path.exists():
552+
source_path = new_path
553+
break
554+
555+
if source_path.exists():
556+
try:
557+
file_source = source_path.read_text()
558+
configs = extract_kernel_autotune_configs(file_source, func_name)
559+
if configs:
560+
override = match_autotune_config(configs, context_bundle.args)
561+
if override:
562+
for param in TRITON_COMPILE_PARAMS:
563+
if param in override:
564+
effective_compile[param] = override[param]
565+
logger.debug(
566+
"Overrode compile params from triton.Config: %s",
567+
override,
568+
)
569+
except Exception:
570+
logger.debug(
571+
"Failed to extract autotune configs from source",
572+
exc_info=True,
573+
)
574+
534575
compile_params = _get_compile_params_for_invocation(
535-
context_bundle.compile, kw_args
576+
effective_compile, kw_args
536577
)
537578
invocation_snippet = _generate_invocation_snippet(
538579
pos_args, kw_args, compile_params

0 commit comments

Comments
 (0)