diff --git a/src/strands_evals/__init__.py b/src/strands_evals/__init__.py index 229dba0a..d6521d8b 100644 --- a/src/strands_evals/__init__.py +++ b/src/strands_evals/__init__.py @@ -1,4 +1,5 @@ -from . import detectors, evaluators, extractors, generators, providers, simulation, telemetry, types +from . import aggregators, chaos, detectors, evaluators, extractors, generators, providers, simulation, telemetry, types +from .aggregators import CaseAggregator from .case import Case from .eval_task_handler import EvalTaskHandler, TracedHandler, eval_task from .evaluation_data_store import EvaluationDataStore @@ -17,6 +18,8 @@ "EvalTaskHandler", "TracedHandler", "eval_task", + "aggregators", + "chaos", "detectors", "evaluators", "extractors", @@ -29,4 +32,5 @@ "get_tracer", "ActorSimulator", "UserSimulator", + "CaseAggregator", ] diff --git a/src/strands_evals/aggregators/__init__.py b/src/strands_evals/aggregators/__init__.py new file mode 100644 index 00000000..8e3b1f04 --- /dev/null +++ b/src/strands_evals/aggregators/__init__.py @@ -0,0 +1,13 @@ +"""Batch evaluation aggregators for Strands Evals. + +Aggregators analyze evaluation results across multiple cases, scenarios, +or trials to produce summary reports and cross-case insights. +""" + +from .base import CaseAggregator +from .types import AggregationResult + +__all__ = [ + "CaseAggregator", + "AggregationResult", +] diff --git a/src/strands_evals/aggregators/base.py b/src/strands_evals/aggregators/base.py new file mode 100644 index 00000000..ae5e44c1 --- /dev/null +++ b/src/strands_evals/aggregators/base.py @@ -0,0 +1,164 @@ +"""Base CaseAggregator class. + +Provides a default implementation that groups results by evaluator and +computes numeric statistics (mean/min/max score, pass rate). Derived +classes override `summarize_reasons()` to add LLM-based or domain-specific +narrative summaries. +""" + +import logging +from collections import defaultdict +from typing import Any + +from ..types.evaluation_report import EvaluationReport +from .types import AggregationResult + +logger = logging.getLogger(__name__) + + +class CaseAggregator: + """Base class for evaluation aggregators. + + An aggregator takes a flat list of EvaluationReports (produced by an + Experiment) and re-groups/analyzes them along a specific dimension + (e.g., chaos scenarios, trials, case categories). + + The default implementation groups by evaluator name and computes numeric + stats across all cases. Subclasses can override: + - `aggregate()` for custom grouping logic + - `summarize_reasons()` for LLM-based or domain-specific narrative generation + + Example:: + + from strands_evals.aggregators import CaseAggregator + + aggregator = CaseAggregator() + reports = experiment.run_evaluations(task=my_task) + results = aggregator.aggregate(reports) + + for r in results: + print(f"{r.group_key}: mean={r.mean_score:.2f}, pass_rate={r.pass_rate:.0%}") + """ + + def __init__(self, name: str | None = None): + """Initialize the aggregator. + + Args: + name: Optional human-readable name for this aggregator. + """ + self.name = name or self.__class__.__name__ + + def aggregate(self, reports: list[EvaluationReport]) -> list[AggregationResult]: + """Aggregate evaluation reports into summary results. + + Default implementation groups all case results by evaluator name and + computes numeric statistics. The `summary` field is populated by + calling `summarize_reasons()`. + + Args: + reports: Flat list of EvaluationReport objects from an Experiment run. + + Returns: + List of AggregationResult objects, one per evaluator. + """ + if not reports: + return [] + + results = [] + for report in reports: + stats = self._compute_stats(report.scores, report.test_passes) + summary = self.summarize_reasons(report.reasons) + + results.append( + AggregationResult( + group_key=report.evaluator_name or "Unknown", + evaluator_name=report.evaluator_name or "Unknown", + summary=summary, + **stats, + ) + ) + + return results + + def summarize_reasons(self, reasons: list[str]) -> str: + """Produce a narrative summary from a list of per-case reason strings. + + The base implementation concatenates unique non-empty reasons. + Override in subclasses to use LLM-as-a-Judge or domain-specific logic. + + Args: + reasons: List of reason strings from individual evaluations. + + Returns: + A summary string. + """ + return self._concatenate_reasons(reasons) + + # ------------------------------------------------------------------ + # Shared utilities + # ------------------------------------------------------------------ + + @staticmethod + def _compute_stats(scores: list[float], passes: list[bool]) -> dict[str, Any]: + """Compute basic statistics from a list of scores and pass/fail flags. + + Args: + scores: List of numeric scores. + passes: List of boolean pass/fail indicators. + + Returns: + Dict with mean_score, min_score, max_score, pass_rate, + num_results, num_passed, num_failed. + """ + if not scores: + return { + "mean_score": 0.0, + "min_score": 0.0, + "max_score": 0.0, + "pass_rate": 0.0, + "num_results": 0, + "num_passed": 0, + "num_failed": 0, + } + + num_passed = sum(1 for p in passes if p) + num_failed = len(passes) - num_passed + + return { + "mean_score": sum(scores) / len(scores), + "min_score": min(scores), + "max_score": max(scores), + "pass_rate": num_passed / len(passes) if passes else 0.0, + "num_results": len(scores), + "num_passed": num_passed, + "num_failed": num_failed, + } + + @staticmethod + def _concatenate_reasons(reasons: list[str], max_reasons: int = 10) -> str: + """Combine multiple reason strings by deduplication and concatenation. + + Args: + reasons: List of reason strings from individual evaluations. + max_reasons: Maximum number of unique reasons to include. + + Returns: + Combined summary string. + """ + unique_reasons = [] + seen: set[str] = set() + for reason in reasons: + if reason and reason not in seen: + seen.add(reason) + unique_reasons.append(reason) + if len(unique_reasons) >= max_reasons: + break + + if not unique_reasons: + return "" + + if len(unique_reasons) == 1: + return unique_reasons[0] + + summary_parts = [f"({i + 1}) {r}" for i, r in enumerate(unique_reasons)] + return " | ".join(summary_parts) diff --git a/src/strands_evals/aggregators/types.py b/src/strands_evals/aggregators/types.py new file mode 100644 index 00000000..f6b5cfa5 --- /dev/null +++ b/src/strands_evals/aggregators/types.py @@ -0,0 +1,26 @@ +"""Data models for evaluation aggregation results.""" + +from pydantic import BaseModel, Field + + +class AggregationResult(BaseModel): + """Base aggregation result for a group of evaluation results. + + Provides quantitative statistics that any aggregator can produce + regardless of the grouping dimension. + """ + + group_key: str = Field(..., description="Identifier for this group (e.g., case name)") + evaluator_name: str + + # --- Quantitative stats --- + mean_score: float + min_score: float + max_score: float + pass_rate: float # Fraction of results that passed (0.0 to 1.0) + num_results: int + num_passed: int + num_failed: int + + # --- Narrative summary --- + summary: str = Field(default="", description="Aggregated summary of all reason fields") diff --git a/src/strands_evals/chaos/__init__.py b/src/strands_evals/chaos/__init__.py new file mode 100644 index 00000000..c27c5d05 --- /dev/null +++ b/src/strands_evals/chaos/__init__.py @@ -0,0 +1,62 @@ +"""Chaos testing module for Strands Evals. + +Provides deterministic fault injection for evaluating agent resilience +under tool failures and response corruption scenarios. +""" + +from .aggregation_display import ChaosAggregationDisplay, display_chaos_aggregation +from .aggregator import ChaosScenarioAggregator +from .aggregator_types import ( + ChaosAggregationReport, + ChaosScenarioAggregation, + CoverageStatus, + ToolEffectResult, +) +from .effects import ( + TOOL_CORRUPTION_EFFECTS, + TOOL_ERROR_EFFECTS, + ChaosEffect, + CorruptValues, + RemoveFields, + ToolCallFailure, + ToolEffect, + TruncateFields, +) +from .evaluators import ( + FailureCommunicationEvaluator, + PartialCompletionEvaluator, + RecoveryStrategyEvaluator, +) +from .experiment import ChaosExperiment +from .plugin import ChaosPlugin +from .scenario import ChaosScenario + +__all__ = [ + # Core classes + "ChaosExperiment", + "ChaosPlugin", + "ChaosScenario", + # Effect hierarchy + "ChaosEffect", + "ToolEffect", + # Concrete effects + "ToolCallFailure", + "TruncateFields", + "RemoveFields", + "CorruptValues", + # Aggregation + "ChaosAggregationDisplay", + "ChaosAggregationReport", + "ChaosScenarioAggregator", + "ChaosScenarioAggregation", + "CoverageStatus", + "ToolEffectResult", + "display_chaos_aggregation", + # Evaluators + "FailureCommunicationEvaluator", + "PartialCompletionEvaluator", + "RecoveryStrategyEvaluator", + # Classification sets + "TOOL_ERROR_EFFECTS", + "TOOL_CORRUPTION_EFFECTS", +] diff --git a/src/strands_evals/chaos/_context.py b/src/strands_evals/chaos/_context.py new file mode 100644 index 00000000..8c8c0624 --- /dev/null +++ b/src/strands_evals/chaos/_context.py @@ -0,0 +1,21 @@ +"""Internal context variable for tracking the active chaos scenario. + +The ChaosPlugin reads from this ContextVar at hook time. +The ChaosExperiment sets and resets it around each case's task invocation. + +Using a ContextVar ensures correct behavior under: +- Sequential execution (trivially correct) +- Async execution (each asyncio.Task inherits the var from its parent) +- Threaded execution (each thread gets its own copy) +""" + +from contextvars import ContextVar +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .scenario import ChaosScenario + +_current_scenario: ContextVar["ChaosScenario | None"] = ContextVar( + "chaos_current_scenario", + default=None, +) diff --git a/src/strands_evals/chaos/aggregation_display.py b/src/strands_evals/chaos/aggregation_display.py new file mode 100644 index 00000000..8e7912e4 --- /dev/null +++ b/src/strands_evals/chaos/aggregation_display.py @@ -0,0 +1,276 @@ +"""Rich console display for ChaosScenarioAggregation results. + +Interactive table with expand/collapse. Collapsed shows a summary row per case. +Expanded shows Stats + Summary panels on top, scenario-centric Coverage Matrix below. + +The coverage matrix is scenario-centric: +- Rows = scenarios +- Columns = tools +- Cells = effects applied to that tool in that scenario + pass/fail +""" + +from rich.panel import Panel +from rich.table import Table + +from ..display.display_console import CollapsibleTableReportDisplay, console +from .aggregator_types import CoverageStatus + + +class ChaosAggregationDisplay(CollapsibleTableReportDisplay): + """Interactive console display for chaos scenario aggregation results. + + Collapsed: single summary row per case (name, avg score, pass rate). + Expanded: Stats + Summary panels on top, scenario-centric Coverage Matrix below. + """ + + def __init__(self, aggregations: list, reports: list | None = None): + """Initialize the display from aggregation results. + + Args: + aggregations: List of ChaosScenarioAggregation objects. + reports: Optional flat list of EvaluationReport objects (for input display). + """ + self._aggregations = aggregations + self._reports = reports + + # Build items dict for the base class interaction loop + items = {} + overall_score = 0.0 + if aggregations: + overall_score = sum(a.mean_score for a in aggregations) / len(aggregations) + for i, agg in enumerate(aggregations): + items[str(i)] = { + "details": { + "name": agg.group_key, + "score": f"{agg.mean_score:.2f}", + "test_pass": agg.pass_rate >= 0.5, + }, + "detailed_results": [], + "expanded": False, + } + + super().__init__(items=items, overall_score=overall_score) + + # Evaluators where 0.5 is the expected neutral baseline score + _NEUTRAL_BASELINE_EVALUATORS = frozenset({ + "RecoveryStrategyEvaluator", + "FailureCommunicationEvaluator", + }) + + def display_items(self): + """Render the aggregation report.""" + if not self._aggregations: + console.print( + Panel("[bold blue]No aggregation results[/bold blue]", title="📊 Chaos Aggregation Report") + ) + return + + # Compute per-evaluator stats for header + from collections import defaultdict + eval_stats: dict[str, dict] = defaultdict(lambda: {"scores": [], "passes": []}) + for agg in self._aggregations: + eval_stats[agg.evaluator_name]["scores"].append(agg.mean_score) + eval_stats[agg.evaluator_name]["passes"].append(agg.pass_rate) + + # Scenarios count and case count + num_scenarios = self._aggregations[0].num_results if self._aggregations else 0 + case_names = set(agg.group_key for agg in self._aggregations) + num_cases = len(case_names) + num_evaluators = len(eval_stats) + + # Build header with mini-table inside panel + header_table = Table(show_header=True, show_edge=False, box=None, padding=(0, 2)) + header_table.add_column("Evaluator", style="bold") + header_table.add_column("Avg Score", justify="center", style="green") + header_table.add_column("Pass Rate", justify="center", style="green") + header_table.add_column("", style="dim") + + for eval_name in sorted(eval_stats.keys()): + stats = eval_stats[eval_name] + avg = sum(stats["scores"]) / len(stats["scores"]) if stats["scores"] else 0.0 + pr = sum(stats["passes"]) / len(stats["passes"]) if stats["passes"] else 0.0 + note = "(0.5 = neutral baseline)" if eval_name in self._NEUTRAL_BASELINE_EVALUATORS else "" + header_table.add_row(eval_name, f"{avg:.2f}", f"{pr:.0%}", note) + + from rich.console import Group + from rich.text import Text + + dimensions = Text(f"Cases: {num_cases} Scenarios: {num_scenarios} Evaluators: {num_evaluators}") + dimensions.stylize("bold blue") + + console.print(Panel( + Group(dimensions, Text(""), header_table), + title="📊 Chaos Aggregation Report", + )) + + # Summary table — one row per (case, evaluator) + table = Table(title="Test Case Results", show_lines=True) + table.add_column("index", style="cyan") + table.add_column("name", style="magenta") + table.add_column("evaluator", style="yellow") + table.add_column("avg_score", style="green") + table.add_column("baseline_score", style="green") + table.add_column("pass_rate", style="green") + + for i, agg in enumerate(self._aggregations): + key = str(i) + expanded = self.items[key]["expanded"] + symbol = "▼" if expanded else "▶" + baseline = f"{agg.baseline_score:.2f}" if agg.baseline_score is not None else "—" + + evaluator_cell = agg.evaluator_name + if agg.evaluator_name in self._NEUTRAL_BASELINE_EVALUATORS: + evaluator_cell = f"{agg.evaluator_name}\n[dim](0.5 = neutral baseline)[/dim]" + + table.add_row( + f"{symbol} {i}", + agg.group_key, + evaluator_cell, + f"{agg.mean_score:.2f}", + baseline, + f"{agg.pass_rate:.0%}", + ) + + console.print(table) + + # Expanded detail panels for each expanded case + for i, agg in enumerate(self._aggregations): + key = str(i) + if not self.items[key]["expanded"]: + continue + + console.print() + console.print(f"[bold magenta]Case: {agg.group_key}[/bold magenta]") + console.print(f"[dim]Evaluator: {agg.evaluator_name}[/dim]") + + # Top row: Stats (left) + Summary (right) + stats_panel = self._build_stats_panel(agg) + summary_panel = self._build_summary_panel(agg) + + top_row = Table(show_header=False, show_edge=False, box=None, expand=True, padding=0) + top_row.add_column(ratio=1) + top_row.add_column(ratio=2) + top_row.add_row(stats_panel, summary_panel) + console.print(top_row) + + # Bottom: Scenario-centric Coverage Matrix + matrix_panel = self._build_scenario_matrix_panel(agg) + console.print(matrix_panel) + + @staticmethod + def _build_stats_panel(agg) -> Panel: + """Build the stats panel.""" + stats_lines = [ + f"[bold]avg_score:[/bold] {agg.mean_score:.2f}", + f"[bold]min_score:[/bold] {agg.min_score:.2f}", + f"[bold]max_score:[/bold] {agg.max_score:.2f}", + f"[bold]pass_rate:[/bold] {agg.pass_rate:.0%} ({agg.num_passed}/{agg.num_results})", + ] + if agg.baseline_score is not None: + bl_status = "✅" if agg.baseline_passed else "❌" + stats_lines.append(f"[bold]baseline:[/bold] {agg.baseline_score:.2f} {bl_status}") + if agg.degradation_from_baseline is not None: + stats_lines.append(f"[bold]degradation:[/bold] {agg.degradation_from_baseline:.2f}") + + return Panel( + "\n".join(stats_lines), + title="[bold blue]Stats[/bold blue]", + border_style="blue", + ) + + @staticmethod + def _build_summary_panel(agg) -> Panel: + """Build the summary/reason panel.""" + summary_text = agg.summary if agg.summary else "[dim]No summary available[/dim]" + return Panel( + summary_text, + title="[bold cyan]Summary[/bold cyan]", + border_style="cyan", + ) + + @staticmethod + def _build_scenario_matrix_panel(agg) -> Panel: + """Build a scenario-centric coverage matrix. + + Rows = scenarios (from scenario_results) + Columns = tools + Cells = effects applied + pass/fail + """ + # Collect all tools from the coverage matrix + all_tools = sorted(agg.coverage_matrix.keys()) + + if not all_tools and not agg.scenario_results: + return Panel("[dim]No scenario data[/dim]", title="[bold yellow]Coverage Matrix[/bold yellow]") + + # Group scenario_results by scenario_label + from collections import defaultdict + scenarios: dict[str, list] = defaultdict(list) + for sr in agg.scenario_results: + scenarios[sr.scenario_label].append(sr) + + # Build table: rows = scenarios, columns = tools + matrix_table = Table(show_header=True, show_lines=True, expand=True) + matrix_table.add_column("scenario", style="bold", no_wrap=True) + + for tool in all_tools: + matrix_table.add_column(tool, justify="center") + + matrix_table.add_column("score", justify="center", style="green") + matrix_table.add_column("result", justify="center") + + for scenario_name, results in scenarios.items(): + # Build a lookup: tool_name -> effect_type for this scenario + tool_to_effect: dict[str, str] = {} + scenario_passed = True + scenario_score = 0.0 + + for r in results: + tool_to_effect[r.tool_name] = r.effect_type + scenario_score = r.score # All results in same scenario share the score + if not r.passed: + scenario_passed = False + + cells = [scenario_name] + for tool in all_tools: + effect = tool_to_effect.get(tool) + if effect is None: + cells.append("[dim]—[/dim]") + else: + # Show effect name with color based on pass/fail + status = agg.coverage_matrix.get(tool, {}).get(effect, CoverageStatus.NOT_TESTED) + if status == CoverageStatus.PASSED: + cells.append(f"[green]{effect}[/green]") + elif status == CoverageStatus.FAILED: + cells.append(f"[red]{effect}[/red]") + else: + cells.append(f"[dim]{effect}[/dim]") + + cells.append(f"{scenario_score:.2f}") + cells.append("[green bold]PASS[/green bold]" if scenario_passed else "[red bold]FAIL[/red bold]") + matrix_table.add_row(*cells) + + return Panel( + matrix_table, + title="[bold yellow]Coverage Matrix[/bold yellow]", + subtitle="[dim]Rows = scenarios │ Columns = tools │ Cells = effect applied[/dim]", + border_style="yellow", + ) + + +def display_chaos_aggregation( + aggregations: list, + reports: list | None = None, + static: bool = False, +): + """Display chaos aggregation results. + + Shows an interactive table with one row per case. Expanding a case + reveals Stats + Summary panels and a scenario-centric Coverage Matrix. + + Args: + aggregations: List of ChaosScenarioAggregation objects. + reports: Optional flat list of EvaluationReport objects. + static: If True, display once without interaction. + """ + display = ChaosAggregationDisplay(aggregations, reports=reports) + display.run(static=static) diff --git a/src/strands_evals/chaos/aggregator.py b/src/strands_evals/chaos/aggregator.py new file mode 100644 index 00000000..8637bbaa --- /dev/null +++ b/src/strands_evals/chaos/aggregator.py @@ -0,0 +1,430 @@ +"""ChaosScenarioAggregator — aggregates evaluation results across chaos scenarios. + +Given a flat list of EvaluationReports from a ChaosExperiment, this aggregator: +1. Re-groups results by the original case name (stripping the [scenario] suffix). +2. Within each group, organizes results by (tool_name, effect_type) pairs + extracted from case metadata["chaos_scenario"]. +3. Produces a ChaosScenarioAggregation per (original_case, evaluator) pair + containing quantitative stats, a coverage matrix, and baseline comparison. +4. Uses LLM-as-a-Judge to produce a narrative summary of the agent's + resilience across scenarios (when model is provided). +""" + +import logging +from collections import defaultdict +from typing import Optional, cast + +from pydantic import BaseModel, Field +from strands import Agent +from strands.models.model import Model + +from ..aggregators.base import CaseAggregator +from ..types.evaluation_report import EvaluationReport +from .aggregator_types import ( + ChaosAggregationReport, + ChaosScenarioAggregation, + CoverageStatus, + ToolEffectResult, +) + +logger = logging.getLogger(__name__) + +# All known effect types for coverage matrix population +_ALL_EFFECT_TYPES = [ + "timeout", "network_error", "execution_error", "validation_error", + "truncate_fields", "remove_fields", "corrupt_values", +] + +# Regex to strip the scenario suffix from case names: "case_name|scenario_name" +_SCENARIO_SEPARATOR = "|" + +# The baseline scenario name used by ChaosExperiment +_BASELINE_SCENARIO_NAME = "baseline" + +# Default system prompt for LLM-based reason summarization +_SUMMARIZE_SYSTEM_PROMPT = """\ +You are an evaluation analyst for AI agent resilience testing. + +You will receive per-scenario evaluation results from chaos testing, where each +scenario injected one or more failures into the agent's environment. + +Produce a single paragraph (100 words max) that: +1. States which failure modes the agent handled well vs. poorly. +2. Notes any pattern (e.g., "handles timeouts but not data corruption"). +3. Highlights the most critical gap if one stands out. + +Be specific and actionable. Do not repeat raw reasons. Do not use bullet points nor multiple paragraphs. +""" + +_SUMMARIZE_USER_TEMPLATE = """\ +Case: {case_name} +Evaluator: {evaluator_name} +Baseline passed: {baseline_passed} +Pass rate under chaos: {pass_rate:.0%} ({num_passed}/{num_results} scenarios passed) + +Per-scenario results: +{scenario_details} + +Summarize the agent's resilience pattern for this case. +""" + + +class ResilienceSummary(BaseModel): + """Structured output for LLM-based resilience summarization.""" + + reasoning: str = Field(description="Brief analysis of the agent's resilience patterns") + summary: str = Field(description="Single paragraph summary") + + +class ChaosScenarioAggregator(CaseAggregator): + """Aggregates evaluation results across chaos scenarios for each original case. + + Designed to work with the output of ChaosExperiment, which tags each case + with metadata["chaos_scenario"] and appends "[scenario_name]" to case names. + + Produces one ChaosScenarioAggregation per (original_case, evaluator) pair. + + Args: + known_tools: Optional list of tool names that could be tested. Used to + populate NOT_TESTED entries in the coverage matrix for tools that + weren't covered by any scenario. + known_effects: Optional list of effect types to track. Defaults to all + known effect types. + scenarios: Optional list of ChaosScenario objects. When provided, the + aggregator uses typed scenario lookup instead of metadata/name parsing. + Automatically populated by ChaosExperiment if not set. + model: Model for LLM-as-a-Judge reason summarization. Accepts a model ID + string or a Model instance. If None (default), summarization uses + simple concatenation — no model calls are made. + system_prompt: Optional custom system prompt for the summarization judge. + name: Optional human-readable name for this aggregator. + + Example:: + + from strands_evals.chaos import ChaosScenarioAggregator, display_chaos_aggregation + + aggregator = ChaosScenarioAggregator( + known_tools=["search_tool", "database_tool"], + model="us.anthropic.claude-sonnet-4-20250514-v1:0", + ) + + reports = experiment.run_evaluations(task=my_task) + aggregations = aggregator.aggregate(reports) + display_chaos_aggregation(aggregations, reports=reports) + """ + + def __init__( + self, + known_tools: Optional[list[str]] = None, + known_effects: Optional[list[str]] = None, + scenarios: Optional[list] = None, + model: Optional[Model | str] = None, + system_prompt: Optional[str] = None, + name: Optional[str] = None, + ): + super().__init__(name=name or "ChaosScenarioAggregator") + self.known_tools = known_tools or [] + self.known_effects = known_effects or _ALL_EFFECT_TYPES + self.model = model + self.system_prompt = system_prompt or _SUMMARIZE_SYSTEM_PROMPT + + # Build scenario lookup: scenario_name -> {tool_name: [effect_types]} + self._scenario_effects: dict[str, dict[str, list[str]]] = {} + if scenarios: + for scenario in scenarios: + tool_effects: dict[str, list[str]] = {} + for tool_name, effects in scenario.effects.items(): + effect_names = [] + for e in effects: + if hasattr(e, "error_type"): + effect_names.append(e.error_type) + elif hasattr(e, "max_length"): + effect_names.append("truncate_fields") + elif hasattr(e, "remove_ratio"): + effect_names.append("remove_fields") + elif hasattr(e, "corrupt_ratio"): + effect_names.append("corrupt_values") + else: + effect_names.append(type(e).__name__.lower()) + tool_effects[tool_name] = effect_names + self._scenario_effects[scenario.name] = tool_effects + + def aggregate(self, reports: list[EvaluationReport]) -> ChaosAggregationReport: + """Aggregate chaos experiment reports into per-case scenario aggregations. + + Args: + reports: Flat list of EvaluationReport objects from ChaosExperiment. + + Returns: + ChaosAggregationReport with .run_display() and .to_file() methods. + """ + if not reports: + return ChaosAggregationReport(aggregations=[]) + + grouped = self._group_results(reports) + + aggregations = [] + for (case_name, evaluator_name), entries in grouped.items(): + aggregation = self._build_aggregation(case_name, evaluator_name, entries) + aggregations.append(aggregation) + + aggregations.sort(key=lambda a: (a.group_key, a.evaluator_name)) + return ChaosAggregationReport(aggregations=aggregations) + + def summarize_reasons(self, reasons: list[str]) -> str: + """Produce a narrative summary from per-scenario reason strings. + + Only produces a summary when self.model is set. Returns empty string otherwise. + + Args: + reasons: List of reason strings from individual evaluations. + + Returns: + A summary string, or empty if no model configured. + """ + non_empty = [r for r in reasons if r] + if not non_empty or self.model is None: + return "" + + prompt = ( + "Summarize the following evaluation reasons into a concise 2-3 sentence summary:\n\n" + + "\n".join(f"- {r}" for r in non_empty[:20]) + ) + + try: + agent = Agent(model=self.model, system_prompt=self.system_prompt, callback_handler=None) + result = agent(prompt, structured_output_model=ResilienceSummary) + rating = cast(ResilienceSummary, result.structured_output) + return rating.summary + except Exception as e: + logger.warning(f"LLM summarization failed: {e}") + return "" + + def _summarize_for_aggregation( + self, + case_name: str, + evaluator_name: str, + entries: list[dict], + stats: dict, + baseline_passed: Optional[bool], + ) -> str: + """Produce a summary for a specific aggregation group using LLM-as-a-Judge. + + Creates an Agent with the configured model and system prompt, then invokes + it with structured output to get a ResilienceSummary. + + Args: + case_name: The original case name. + evaluator_name: The evaluator name. + entries: The chaos scenario entries (excluding baseline). + stats: Computed stats dict. + baseline_passed: Whether baseline passed (None if no baseline). + + Returns: + Summary string. + """ + reasons = [e["reason"] for e in entries] + + # Build detailed prompt for LLM + scenario_lines = [] + for entry in entries: + status = "PASSED" if entry["passed"] else "FAILED" + scenario_lines.append( + f" - [{status}] {entry['scenario_name']} (score={entry['score']:.2f}): {entry['reason']}" + ) + + prompt = _SUMMARIZE_USER_TEMPLATE.format( + case_name=case_name, + evaluator_name=evaluator_name, + baseline_passed=baseline_passed if baseline_passed is not None else "N/A", + pass_rate=stats["pass_rate"], + num_passed=stats["num_passed"], + num_results=stats["num_results"], + scenario_details="\n".join(scenario_lines), + ) + + try: + agent = Agent(model=self.model, system_prompt=self.system_prompt, callback_handler=None) + result = agent(prompt, structured_output_model=ResilienceSummary) + rating = cast(ResilienceSummary, result.structured_output) + return rating.summary + except Exception as e: + logger.warning(f"LLM summarization failed for case '{case_name}': {e}") + return "" + + # ------------------------------------------------------------------ + # Internal grouping and aggregation logic + # ------------------------------------------------------------------ + + def _group_results( + self, reports: list[EvaluationReport] + ) -> dict[tuple[str, str], list[dict]]: + """Group report entries by (original_case_name, evaluator_name). + + Each entry in the returned lists is a dict with: + - scenario_name: str + - score: float + - passed: bool + - reason: str + - metadata: dict (from the case) + """ + grouped: dict[tuple[str, str], list[dict]] = defaultdict(list) + + for report in reports: + evaluator_name = report.evaluator_name + + for i, case_data in enumerate(report.cases): + raw_name = case_data.get("name", "") or "" + original_name, scenario_name = self._parse_case_name(raw_name) + + metadata = case_data.get("metadata") or {} + if "chaos_scenario" in metadata: + scenario_name = metadata["chaos_scenario"] + + score = report.scores[i] if i < len(report.scores) else 0.0 + passed = report.test_passes[i] if i < len(report.test_passes) else False + reason = report.reasons[i] if i < len(report.reasons) else "" + + grouped[(original_name, evaluator_name)].append( + { + "scenario_name": scenario_name, + "score": score, + "passed": passed, + "reason": reason, + "metadata": metadata, + } + ) + + return grouped + + def _build_aggregation( + self, + case_name: str, + evaluator_name: str, + entries: list[dict], + ) -> ChaosScenarioAggregation: + """Build a ChaosScenarioAggregation from grouped entries.""" + # Separate baseline from chaos scenarios + baseline_entries = [e for e in entries if e["scenario_name"] == _BASELINE_SCENARIO_NAME] + chaos_entries = [e for e in entries if e["scenario_name"] != _BASELINE_SCENARIO_NAME] + + # Compute stats over chaos scenarios only (baseline is reference) + chaos_scores = [e["score"] for e in chaos_entries] + chaos_passes = [e["passed"] for e in chaos_entries] + stats = self._compute_stats(chaos_scores, chaos_passes) + + # Baseline comparison + baseline_score: Optional[float] = None + baseline_passed: Optional[bool] = None + degradation: Optional[float] = None + + if baseline_entries: + baseline_score = baseline_entries[0]["score"] + baseline_passed = baseline_entries[0]["passed"] + if chaos_scores: + degradation = baseline_score - stats["mean_score"] + + # Build per-scenario ToolEffectResults and coverage matrix + scenario_results = [] + coverage_matrix: dict[str, dict[str, CoverageStatus]] = {} + + for entry in chaos_entries: + scenario_name = entry["scenario_name"] + metadata = entry["metadata"] + tool_effects = self._extract_tool_effects_from_metadata(metadata, scenario_name) + + for tool_name, effect_type in tool_effects: + result = ToolEffectResult( + group_key=f"{case_name}/{tool_name}/{effect_type}", + evaluator_name=evaluator_name, + mean_score=entry["score"], + min_score=entry["score"], + max_score=entry["score"], + pass_rate=1.0 if entry["passed"] else 0.0, + num_results=1, + num_passed=1 if entry["passed"] else 0, + num_failed=0 if entry["passed"] else 1, + tool_name=tool_name, + effect_type=effect_type, + scenario_label=scenario_name, + score=entry["score"], + passed=entry["passed"], + reason=entry["reason"], + ) + scenario_results.append(result) + + if tool_name not in coverage_matrix: + coverage_matrix[tool_name] = {} + coverage_matrix[tool_name][effect_type] = ( + CoverageStatus.PASSED if entry["passed"] else CoverageStatus.FAILED + ) + + self._fill_not_tested(coverage_matrix) + + # Summarize reasons (LLM-as-a-Judge only if model was explicitly provided) + if self.model is not None: + summary = self._summarize_for_aggregation( + case_name, evaluator_name, chaos_entries, stats, baseline_passed + ) + else: + summary = "" + + return ChaosScenarioAggregation( + group_key=case_name, + evaluator_name=evaluator_name, + mean_score=stats["mean_score"], + min_score=stats["min_score"], + max_score=stats["max_score"], + pass_rate=stats["pass_rate"], + num_results=stats["num_results"], + num_passed=stats["num_passed"], + num_failed=stats["num_failed"], + coverage_matrix=coverage_matrix, + baseline_score=baseline_score, + baseline_passed=baseline_passed, + degradation_from_baseline=degradation, + scenario_results=scenario_results, + summary=summary, + ) + + def _extract_tool_effects_from_metadata( + self, metadata: dict, scenario_name: str + ) -> list[tuple[str, str]]: + """Extract (tool_name, effect_type) pairs for a scenario. + + Uses the scenario lookup populated from ChaosScenario objects. + """ + if scenario_name in self._scenario_effects: + pairs = [] + for tool_name, effect_types in self._scenario_effects[scenario_name].items(): + for effect_type in effect_types: + pairs.append((tool_name, effect_type)) + if pairs: + return pairs + + if scenario_name and scenario_name != _BASELINE_SCENARIO_NAME: + return [(scenario_name, "unknown")] + + return [] + + def _fill_not_tested(self, coverage_matrix: dict[str, dict[str, CoverageStatus]]) -> None: + """Fill NOT_TESTED entries for known tool×effect combinations not covered.""" + all_tools = set(coverage_matrix.keys()) | set(self.known_tools) + + for tool_name in all_tools: + if tool_name not in coverage_matrix: + coverage_matrix[tool_name] = {} + for effect_type in self.known_effects: + if effect_type not in coverage_matrix[tool_name]: + coverage_matrix[tool_name][effect_type] = CoverageStatus.NOT_TESTED + + @staticmethod + def _parse_case_name(raw_name: str) -> tuple[str, str]: + """Parse a tagged case name into (original_name, scenario_name). + + ChaosExperiment tags cases as "original_name|scenario_name". + """ + if _SCENARIO_SEPARATOR in raw_name: + parts = raw_name.rsplit(_SCENARIO_SEPARATOR, 1) + return parts[0].strip(), parts[1].strip() + return raw_name, "unknown" diff --git a/src/strands_evals/chaos/aggregator_types.py b/src/strands_evals/chaos/aggregator_types.py new file mode 100644 index 00000000..24fd9d18 --- /dev/null +++ b/src/strands_evals/chaos/aggregator_types.py @@ -0,0 +1,154 @@ +"""Data models for chaos scenario aggregation results.""" + +import json +from enum import Enum +from pathlib import Path +from typing import Optional + +from pydantic import BaseModel, Field + +from ..aggregators.types import AggregationResult + + +class CoverageStatus(str, Enum): + """Status of a tool×effect combination in the coverage matrix.""" + + PASSED = "passed" # Agent handled the failure correctly + FAILED = "failed" # Agent did not handle the failure + NOT_TESTED = "not_tested" # Combination was not enumerated (capped by max_scenarios) + + +class ToolEffectResult(AggregationResult): + """Result for a single tool×effect scenario evaluation. + + Inherits numeric stats from AggregationResult and adds chaos-specific fields. + """ + + tool_name: str + effect_type: str + scenario_label: str + + # Convenience fields for single-scenario results + score: float = 0.0 + passed: bool = False + reason: str = "" + + +class ChaosScenarioAggregation(AggregationResult): + """Aggregated results for one original case across all chaos scenarios. + + Extends the base AggregationResult with chaos-specific coverage analysis + and baseline comparison. + """ + + # --- Coverage matrix (chaos-specific) --- + coverage_matrix: dict[str, dict[str, CoverageStatus]] = Field( + default_factory=dict, + description=( + "Outer key: tool_name, Inner key: effect_type → status. " + 'e.g. {"check_inventory": {"timeout": "passed", "truncate_fields": "failed"}}' + ), + ) + + # --- Baseline comparison (chaos-specific) --- + baseline_score: Optional[float] = Field( + default=None, description="Score from the baseline (no-chaos) scenario" + ) + baseline_passed: Optional[bool] = Field( + default=None, description="Whether the baseline scenario passed" + ) + degradation_from_baseline: Optional[float] = Field( + default=None, description="baseline_score - mean_score (positive = degradation)" + ) + + # --- Per-scenario detail --- + scenario_results: list[ToolEffectResult] = Field(default_factory=list) + + +class ChaosAggregationReport(BaseModel): + """Report containing all chaos scenario aggregation results. + + Provides .run_display() and .to_file() matching the EvaluationReport interface. + + Example:: + + aggregation_report = experiment.aggregate_evaluations() + aggregation_report.run_display() + aggregation_report.to_file("chaos_aggregation_report.json") + """ + + aggregations: list[ChaosScenarioAggregation] = Field(default_factory=list) + + # Internal: raw reports for display (not serialized) + _reports: list = [] + + class Config: + arbitrary_types_allowed = True + + def run_display(self): + """Render the aggregation report interactively. + + Collapsed view shows a summary row per case. Expanding a case reveals + Stats + Summary panels and a full Coverage Matrix. + """ + from .aggregation_display import display_chaos_aggregation + + display_chaos_aggregation(self.aggregations, reports=self._reports) + + def display(self): + """Render the report statically (non-interactive).""" + from .aggregation_display import display_chaos_aggregation + + display_chaos_aggregation(self.aggregations, reports=self._reports, static=True) + + def to_file(self, path: str): + """Write the aggregation report to a JSON file. + + Args: + path: The file path where the report will be saved. + If no extension is provided, ".json" will be added automatically. + + Raises: + ValueError: If the path has a non-JSON extension. + """ + file_path = Path(path) + + if file_path.suffix: + if file_path.suffix != ".json": + raise ValueError( + f"Only .json format is supported. Got path with extension: {path}. " + f"Please use a .json extension or provide a path without an extension." + ) + else: + file_path = file_path.with_suffix(".json") + + file_path.parent.mkdir(parents=True, exist_ok=True) + + with open(file_path, "w", encoding="utf-8") as f: + json.dump(self.model_dump(), f, indent=2, ensure_ascii=False) + + @classmethod + def from_file(cls, path: str) -> "ChaosAggregationReport": + """Load an aggregation report from a JSON file. + + Args: + path: Path to the JSON file. + + Returns: + A ChaosAggregationReport instance. + + Raises: + ValueError: If the file does not have a .json extension. + """ + file_path = Path(path) + + if file_path.suffix != ".json": + raise ValueError( + f"Only .json format is supported. Got file: {path}. " + f"Please provide a path with .json extension." + ) + + with open(file_path, "r", encoding="utf-8") as f: + data = json.load(f) + + return cls.model_validate(data) diff --git a/src/strands_evals/chaos/effects.py b/src/strands_evals/chaos/effects.py new file mode 100644 index 00000000..3f51ea79 --- /dev/null +++ b/src/strands_evals/chaos/effects.py @@ -0,0 +1,273 @@ +"""Chaos effect definitions. + +Effects are first-class parameterized classes organized in a hierarchy: + ChaosEffect → ToolEffect → concrete effects (Timeout, NetworkError, etc.) + → ModelEffect → (reserved for v2) + +Each concrete effect carries only the parameters meaningful to it. +The `hook` class variable indicates whether the effect fires pre-tool-call +(error effects) or post-tool-call (corruption effects). + +Pre-hook effects provide `error_message` (the plugin cancels the tool call). +Post-hook effects implement `apply(response)` (the plugin passes the response through). +""" + +import math +import random +from abc import abstractmethod +from typing import Any, ClassVar, Literal + +from pydantic import BaseModel, Field + + +# --------------------------------------------------------------------------- +# Base classes +# --------------------------------------------------------------------------- + + +class ChaosEffect(BaseModel): + """Base for all chaos effects. + + Attributes: + apply_rate: Probability that this effect fires. + In v1 this field is accepted but ignored (always fires). + hook: Whether this effect fires pre-call ("pre") or post-call ("post"). + """ + + hook: ClassVar[Literal["pre", "post"]] + + apply_rate: float = Field( + default=1.0, + ge=0.0, + le=1.0, + description="Probability that this effect fires (1.0 = always).", + ) + + @abstractmethod + def apply(self, context: Any = None) -> Any: + """Apply the chaos effect. + + Pre-hook effects return an error message string. + Post-hook effects accept a response dict and return the corrupted dict. + """ + ... + + +class ToolEffect(ChaosEffect): + """Effect valid at the tool invocation boundary. + + - "pre": effect fires before tool execution (cancels the call with an error) + - "post": effect fires after tool execution (corrupts the response) + """ + + +# --------------------------------------------------------------------------- +# Pre-hook effect — cancels the tool call before execution +# --------------------------------------------------------------------------- + +# All supported failure types +ToolCallFailureType = Literal["timeout", "network_error", "execution_error", "validation_error"] + +# Default error messages per failure type +_DEFAULT_ERROR_MESSAGES: dict[str, str] = { + "timeout": "Tool call timed out", + "network_error": "Network unreachable", + "execution_error": "Tool execution failed", + "validation_error": "Tool input validation failed", +} + + +class ToolCallFailure(ToolEffect): + """Simulates a tool call failure that prevents the tool from executing. + + The tool call is cancelled before execution with a simulated error message. + + Example:: + + ChaosScenario( + name="search_timeout", + effects={"search_tool": [ToolCallFailure(error_type="timeout")]}, + ) + + ChaosScenario( + name="db_network_error", + effects={"database_tool": [ToolCallFailure( + error_type="network_error", + error_message="Connection refused on port 5432", + )]}, + ) + """ + + hook: ClassVar[Literal["pre", "post"]] = "pre" + error_type: ToolCallFailureType = Field( + default="execution_error", + description="Type of failure to simulate.", + ) + error_message: str | None = Field( + default=None, + description="Custom error message. If None, uses a default for the error_type.", + ) + + def apply(self, context: Any = None) -> str: + """Return the error message to cancel the tool call with.""" + if self.error_message is not None: + return self.error_message + return _DEFAULT_ERROR_MESSAGES[self.error_type] + + +# --------------------------------------------------------------------------- +# Concrete tool corruption effects (post-hook — mutate the response) +# +# Post-hook effects implement apply(response) -> response. +# The plugin calls effect.apply(response_dict) and uses the return value. +# --------------------------------------------------------------------------- + + +class TruncateFields(ToolEffect): + """Truncates string values in the tool response. + + The tool executes normally, but string fields in the response are + truncated to at most `max_length` characters. + + Example:: + + ChaosScenario( + name="search_truncated", + effects={ + "search_tool": [TruncateFields(max_length=5)], + }, + ) + """ + + hook: ClassVar[Literal["pre", "post"]] = "post" + max_length: int = Field(default=10, ge=0, description="Maximum length to truncate string values to") + + def apply(self, response: dict[str, Any]) -> dict[str, Any]: + """Truncate string values to max_length. + + Args: + response: The tool response dict to corrupt. + + Returns: + Response with string values truncated. + """ + result: dict[str, Any] = {} + for key, value in response.items(): + if isinstance(value, str) and len(value) > self.max_length: + result[key] = value[: self.max_length] + elif isinstance(value, dict): + result[key] = self._truncate(value) + else: + result[key] = value + return result + + +class RemoveFields(ToolEffect): + """Removes a fraction of fields from the tool response. + + The tool executes normally, but a portion of the response fields + are deleted. + + Example:: + + ChaosScenario( + name="db_remove_fields", + effects={ + "database_tool": [RemoveFields(remove_ratio=0.5)], + }, + ) + """ + + hook: ClassVar[Literal["pre", "post"]] = "post" + remove_ratio: float = Field( + default=0.5, + ge=0.0, + le=1.0, + description="Fraction of fields to remove from the response", + ) + + def apply(self, response: dict[str, Any]) -> dict[str, Any]: + """Remove a fraction of fields from the response. + + Always removes at least 1 field when called. + + Args: + response: The tool response dict to corrupt. + + Returns: + Response with fields removed. + """ + keys = list(response.keys()) + if not keys: + return response + + num_to_remove = max(1, math.ceil(len(keys) * self.remove_ratio)) + keys_to_remove = set(random.sample(keys, min(num_to_remove, len(keys)))) + return {k: v for k, v in response.items() if k not in keys_to_remove} + + +class CorruptValues(ToolEffect): + """Replaces a fraction of values with garbage data. + + The tool executes normally, but a portion of the response values + are replaced with wrong types or nonsense data. + + Example:: + + ChaosScenario( + name="db_corrupt", + effects={ + "database_tool": [CorruptValues(corrupt_ratio=0.8)], + }, + ) + """ + + hook: ClassVar[Literal["pre", "post"]] = "post" + corrupt_ratio: float = Field( + default=0.5, + ge=0.0, + le=1.0, + description="Fraction of values to corrupt in the response", + ) + + _CORRUPTIONS: ClassVar[list[Any]] = [None, 99999, "", True, [], "CORRUPTED_DATA"] + + def apply(self, response: dict[str, Any]) -> dict[str, Any]: + """Replace a fraction of values with wrong types or garbage data. + + Always corrupts at least 1 field when called. + + Args: + response: The tool response dict to corrupt. + + Returns: + Response with corrupted values. + """ + keys = list(response.keys()) + if not keys: + return response + + num_to_corrupt = max(1, math.ceil(len(keys) * self.corrupt_ratio)) + keys_to_corrupt = set(random.sample(keys, min(num_to_corrupt, len(keys)))) + + result: dict[str, Any] = {} + for key, value in response.items(): + if key in keys_to_corrupt: + candidates = [c for c in self._CORRUPTIONS if c != value] + result[key] = random.choice(candidates) if candidates else "CORRUPTED_DATA" + elif isinstance(value, dict): + result[key] = self.apply(value) + else: + result[key] = value + return result + + +# --------------------------------------------------------------------------- +# Convenience sets for classification (derived from hierarchy, not maintained manually) +# --------------------------------------------------------------------------- + +# All concrete pre-hook (error) effect classes +TOOL_ERROR_EFFECTS: set[type[ToolEffect]] = {ToolCallFailure} + +# All concrete post-hook (corruption) effect classes +TOOL_CORRUPTION_EFFECTS: set[type[ToolEffect]] = {TruncateFields, RemoveFields, CorruptValues} diff --git a/src/strands_evals/chaos/evaluators/__init__.py b/src/strands_evals/chaos/evaluators/__init__.py new file mode 100644 index 00000000..1ade6992 --- /dev/null +++ b/src/strands_evals/chaos/evaluators/__init__.py @@ -0,0 +1,17 @@ +"""Chaos-specific evaluators for resilience testing. + +These evaluators assess agent behavior under failure conditions: +- RecoveryStrategyEvaluator: How well the agent chose recovery actions +- PartialCompletionEvaluator: What percentage of the task was completed despite failures +- FailureCommunicationEvaluator: How well the agent communicated failures to the user +""" + +from .failure_communication_evaluator import FailureCommunicationEvaluator +from .partial_completion_evaluator import PartialCompletionEvaluator +from .recovery_strategy_evaluator import RecoveryStrategyEvaluator + +__all__ = [ + "FailureCommunicationEvaluator", + "PartialCompletionEvaluator", + "RecoveryStrategyEvaluator", +] diff --git a/src/strands_evals/chaos/evaluators/failure_communication_evaluator.py b/src/strands_evals/chaos/evaluators/failure_communication_evaluator.py new file mode 100644 index 00000000..9b5a4842 --- /dev/null +++ b/src/strands_evals/chaos/evaluators/failure_communication_evaluator.py @@ -0,0 +1,113 @@ +"""Failure Communication Evaluator. + +Evaluates quality of agent's failure communication and user experience. +""" + +from enum import Enum +from typing import cast + +from pydantic import BaseModel, Field +from strands import Agent +from strands.models.model import Model + +from ...types.evaluation import EvaluationData, EvaluationOutput, InputT, OutputT +from ...types.trace import EvaluationLevel +from ...evaluators.evaluator import Evaluator +from .prompt_templates.failure_communication_v0 import SYSTEM_PROMPT as SYSTEM_PROMPT_V0 + +_PROMPT_VERSIONS = {"v0": SYSTEM_PROMPT_V0} + + +class FailureCommunicationScore(str, Enum): + """Categorical failure communication ratings.""" + + FAILURE = "Failure" + POOR = "Poor" + ACCEPTABLE = "Acceptable" + GOOD = "Good" + EXCELLENT = "Excellent" + + +class FailureCommunicationRating(BaseModel): + """Structured output for failure communication evaluation.""" + + reasoning: str = Field(description="Step by step reasoning to derive the final score") + score: FailureCommunicationScore = Field(description="Categorical failure communication rating") + + +class FailureCommunicationEvaluator(Evaluator[InputT, OutputT]): + """Evaluates quality of agent's failure communication and user experience. + + Scores how well the agent communicated failures to the user on a 5-point scale: + - Excellent (1.0): Clear communication, user understands situation and next steps + - Good (0.75): Good communication with minor gaps + - Acceptable (0.5): Basic communication, or no failure occurred (baseline) + - Poor (0.25): Confusing messages, overly technical, misleading + - Failure (0.0): No communication of failures that did occur + + Example:: + + from strands_evals.chaos.evaluators import FailureCommunicationEvaluator + + evaluator = FailureCommunicationEvaluator() + experiment = ChaosExperiment( + chaos_plugin=chaos, + chaos_scenarios=scenarios, + cases=cases, + evaluators=[evaluator], + ) + """ + + evaluation_level = EvaluationLevel.TRACE_LEVEL + + _score_mapping = { + FailureCommunicationScore.FAILURE: 0.0, + FailureCommunicationScore.POOR: 0.25, + FailureCommunicationScore.ACCEPTABLE: 0.5, + FailureCommunicationScore.GOOD: 0.75, + FailureCommunicationScore.EXCELLENT: 1.0, + } + + def __init__( + self, + version: str = "v0", + model: Model | str | None = None, + system_prompt: str | None = None, + ): + super().__init__() + self.version = version + default_prompt = _PROMPT_VERSIONS.get(version, SYSTEM_PROMPT_V0) + self.system_prompt = system_prompt if system_prompt is not None else default_prompt + self.model = model + + def evaluate(self, evaluation_case: EvaluationData[InputT, OutputT]) -> list[EvaluationOutput]: + parsed_input = self._get_last_turn(evaluation_case) + prompt = self._format_trace_level_prompt(parsed_input) + evaluator_agent = Agent(model=self.model, system_prompt=self.system_prompt, callback_handler=None) + result = evaluator_agent(prompt, structured_output_model=FailureCommunicationRating) + rating = cast(FailureCommunicationRating, result.structured_output) + normalized_score = self._score_mapping[rating.score] + return [ + EvaluationOutput( + score=normalized_score, + test_pass=normalized_score >= 0.5, + reason=rating.reasoning, + label=rating.score, + ) + ] + + async def evaluate_async(self, evaluation_case: EvaluationData[InputT, OutputT]) -> list[EvaluationOutput]: + parsed_input = self._get_last_turn(evaluation_case) + prompt = self._format_trace_level_prompt(parsed_input) + evaluator_agent = Agent(model=self.model, system_prompt=self.system_prompt, callback_handler=None) + result = await evaluator_agent.invoke_async(prompt, structured_output_model=FailureCommunicationRating) + rating = cast(FailureCommunicationRating, result.structured_output) + normalized_score = self._score_mapping[rating.score] + return [ + EvaluationOutput( + score=normalized_score, + test_pass=normalized_score >= 0.5, + reason=rating.reasoning, + label=rating.score, + ) + ] diff --git a/src/strands_evals/chaos/evaluators/partial_completion_evaluator.py b/src/strands_evals/chaos/evaluators/partial_completion_evaluator.py new file mode 100644 index 00000000..4baaec34 --- /dev/null +++ b/src/strands_evals/chaos/evaluators/partial_completion_evaluator.py @@ -0,0 +1,99 @@ +"""Partial Completion Evaluator. + +Evaluates what percentage of task objectives were achieved despite failures. +""" + +from typing import cast + +from pydantic import BaseModel, Field +from strands import Agent +from strands.models.model import Model + +from ...types.evaluation import EvaluationData, EvaluationOutput, InputT, OutputT +from ...types.trace import EvaluationLevel +from ...evaluators.evaluator import Evaluator +from .prompt_templates.partial_completion_v0 import SYSTEM_PROMPT as SYSTEM_PROMPT_V0 + +_PROMPT_VERSIONS = {"v0": SYSTEM_PROMPT_V0} + + +class PartialCompletionRating(BaseModel): + """Structured output for partial completion evaluation.""" + + reasoning: str = Field(description="Step by step reasoning to derive the final score") + completion_percentage: float = Field( + description="Completion percentage from 0.0 to 1.0", + ge=0.0, + le=1.0, + ) + + +class PartialCompletionEvaluator(Evaluator[InputT, OutputT]): + """Evaluates what percentage of task objectives were achieved despite failures. + + Returns a continuous score from 0.0 to 1.0 representing the fraction of the + user's goal that was successfully completed. Passes if >= 0.5. + + Key principles: + - Subtasks are derived from the user's goal, not mapped 1:1 to tools + - LLM knowledge-based responses don't count as tool-dependent task completion + - Legitimate fallback strategies (alternative tools) do count + + Example:: + + from strands_evals.chaos.evaluators import PartialCompletionEvaluator + + evaluator = PartialCompletionEvaluator() + experiment = ChaosExperiment( + chaos_plugin=chaos, + chaos_scenarios=scenarios, + cases=cases, + evaluators=[evaluator], + ) + """ + + evaluation_level = EvaluationLevel.TRACE_LEVEL + + def __init__( + self, + version: str = "v0", + model: Model | str | None = None, + system_prompt: str | None = None, + ): + super().__init__() + self.version = version + default_prompt = _PROMPT_VERSIONS.get(version, SYSTEM_PROMPT_V0) + self.system_prompt = system_prompt if system_prompt is not None else default_prompt + self.model = model + + def evaluate(self, evaluation_case: EvaluationData[InputT, OutputT]) -> list[EvaluationOutput]: + parsed_input = self._get_last_turn(evaluation_case) + prompt = self._format_trace_level_prompt(parsed_input) + evaluator_agent = Agent(model=self.model, system_prompt=self.system_prompt, callback_handler=None) + result = evaluator_agent(prompt, structured_output_model=PartialCompletionRating) + rating = cast(PartialCompletionRating, result.structured_output) + + return [ + EvaluationOutput( + score=rating.completion_percentage, + test_pass=rating.completion_percentage >= 0.5, + reason=rating.reasoning, + label=f"{rating.completion_percentage:.2f}", + ) + ] + + async def evaluate_async(self, evaluation_case: EvaluationData[InputT, OutputT]) -> list[EvaluationOutput]: + parsed_input = self._get_last_turn(evaluation_case) + prompt = self._format_trace_level_prompt(parsed_input) + evaluator_agent = Agent(model=self.model, system_prompt=self.system_prompt, callback_handler=None) + result = await evaluator_agent.invoke_async(prompt, structured_output_model=PartialCompletionRating) + rating = cast(PartialCompletionRating, result.structured_output) + + return [ + EvaluationOutput( + score=rating.completion_percentage, + test_pass=rating.completion_percentage >= 0.5, + reason=rating.reasoning, + label=f"{rating.completion_percentage:.2f}", + ) + ] diff --git a/src/strands_evals/chaos/evaluators/prompt_templates/__init__.py b/src/strands_evals/chaos/evaluators/prompt_templates/__init__.py new file mode 100644 index 00000000..1210f77f --- /dev/null +++ b/src/strands_evals/chaos/evaluators/prompt_templates/__init__.py @@ -0,0 +1 @@ +"""Prompt templates for chaos evaluators.""" diff --git a/src/strands_evals/chaos/evaluators/prompt_templates/failure_communication_v0.py b/src/strands_evals/chaos/evaluators/prompt_templates/failure_communication_v0.py new file mode 100644 index 00000000..8c85c11f --- /dev/null +++ b/src/strands_evals/chaos/evaluators/prompt_templates/failure_communication_v0.py @@ -0,0 +1,47 @@ +"""LLM-Judge prompt for Failure Communication & UX evaluation. + +Evaluates how well the agent communicated failures to the user. +""" + +SYSTEM_PROMPT = """You are an objective judge evaluating how well an AI agent communicated with the user during a failure. You will be provided with the complete conversation trace including: +- User messages (User:) +- Assistant responses (Assistant:) +- Tool calls selected by the assistant (Action:) +- Tool outputs (Tool:) + +CONTEXT: You are evaluating an AI agent under chaos testing conditions. The agent may have been subjected to deliberately injected failures to test its resilience. Your evaluation should assess how well the agent handled and communicated these failures, NOT whether the agent is generally competent. + +Your task is to assess the quality of error communication and user experience. Follow these steps: +1. First, determine whether any failures or errors occurred in the trace. Look for error messages, exceptions, non-200 status codes, timeout indicators, or explicit failure signals in tool outputs. +2. **MANDATORY GATE — If NO failure occurred:** You MUST immediately rate **Acceptable (0.5)** and STOP. A trace with zero failures always scores 0.5, no exceptions. +3. If a failure DID occur, evaluate the agent's communication: + a. Clarity: Did the user understand what went wrong? + b. Actionability: Did the agent provide clear next steps? + c. Transparency: Was the agent honest about limitations? + d. Tone: Was it appropriately calibrated for the situation? +4. Rate the communication quality using the scoring scale below. + +# Scoring Scale + +- **Excellent** = 1.0 +- **Good** = 0.75 +- **Acceptable** = 0.5 +- **Poor** = 0.25 +- **Failure** = 0.0 + +Scores >= 0.5 pass. + +# Evaluation Rubric + +- **Excellent (1.0)**: Clear error communication; user understands what happened and next steps; graceful degradation communicated transparently; trust maintained +- **Good (0.75)**: Good communication with minor gaps; user mostly understands; appropriate tone and transparency +- **Acceptable (0.5)**: Basic error communication; user knows something went wrong but lacks clarity on impact or next steps. ALSO use this when no failure occurred in the trace. +- **Poor (0.25)**: Confusing or unhelpful messages; user uncertain what happened; overly technical jargon; misleading reassurances about degraded results +- **Failure (0.0)**: Failures DID occur but the agent provided NO error communication; user is unaware of the failure; agent hides problems or provides false confidence + +CRITICAL RULES: +- **NO-FAILURE BASELINE**: If no failure or error appears in the trace, you MUST rate Acceptable (0.5). This is mandatory and unconditional. +- If a tool failed but the agent silently ignored it and responded as if nothing happened, rate Failure (0.0). +- If a failure occurred and the agent acknowledged it but provided NO alternative or next steps, rate Poor (0.25). + +**IMPORTANT**: The user context and agent's available tools in the trace ALWAYS take priority over your own knowledge.""" diff --git a/src/strands_evals/chaos/evaluators/prompt_templates/partial_completion_v0.py b/src/strands_evals/chaos/evaluators/prompt_templates/partial_completion_v0.py new file mode 100644 index 00000000..f1a9a5e7 --- /dev/null +++ b/src/strands_evals/chaos/evaluators/prompt_templates/partial_completion_v0.py @@ -0,0 +1,43 @@ +"""LLM-Judge prompt for Partial Completion Score evaluation. + +Evaluates task completion percentage despite failures. +""" + +SYSTEM_PROMPT = """You are an objective judge evaluating what percentage of task objectives an AI agent achieved despite failures. You will be provided with the complete conversation trace including: +- User messages (User:) +- Assistant responses (Assistant:) +- Tool calls selected by the assistant (Action:) +- Tool outputs (Tool:) + +CONTEXT: You are evaluating an AI agent under chaos testing conditions. The agent may have been subjected to deliberately injected failures to test its resilience. Your evaluation should focus on whether the agent actually completed the task using its tools, NOT whether it provided a plausible-sounding response. + +Your task is to determine how much of the USER'S GOAL was successfully achieved. Follow these steps: +1. Identify the user's original task and objectives from the trace +2. Break down the task into discrete subtasks derived from the USER'S GOAL — NOT from the tool list +3. For each subtask, determine if it was successfully completed USING THE APPROPRIATE TOOLS +4. Assess whether partial results are meaningful and usable +5. Calculate the completion percentage based on goal achievement + +# How to Define Subtasks (CRITICAL) + +Subtasks must be derived from the user's stated goal, NOT mapped 1:1 to individual tools. A single user goal may require multiple tools, or multiple tools may contribute to a single subtask. + +# Evaluation Rubric + +Rate completion as a percentage from 0% to 100%: +- 100%: User's goal fully achieved — all objectives met using tools +- 75-99%: Mostly achieved; minor gaps that don't significantly reduce value +- 50-74%: Meaningful partial value; significant portions of the goal met +- 25-49%: Limited value; most of the goal unmet +- 0-24%: Little to no meaningful progress toward the user's goal + +CRITICAL RULES: +- If the user's task required specific tools and those tools FAILED, the agent CANNOT score above 50% by falling back to its own training knowledge alone. +- An LLM generating a response from its training data is NOT equivalent to completing a tool-dependent subtask. +- Only count a subtask as completed if the agent actually used the required tool successfully OR found a legitimate alternative tool that produces equivalent results. +- If ALL tools failed and the agent only provided knowledge-based responses, score 0-25% maximum. +- Do NOT mechanically compute the score as (successful_tools / total_tools). Tool success rate and task completion rate are different things. + +Note: Return completion_percentage as a decimal from 0.0 to 1.0 (e.g., 75% = 0.75) + +**IMPORTANT**: The user's stated task and agent's available tools in the trace ALWAYS take priority over your own knowledge.""" diff --git a/src/strands_evals/chaos/evaluators/prompt_templates/recovery_strategy_v0.py b/src/strands_evals/chaos/evaluators/prompt_templates/recovery_strategy_v0.py new file mode 100644 index 00000000..4b9a9d52 --- /dev/null +++ b/src/strands_evals/chaos/evaluators/prompt_templates/recovery_strategy_v0.py @@ -0,0 +1,89 @@ +"""LLM-Judge prompt for Recovery Strategy Appropriateness evaluation. + +Evaluates strategy quality based on the agent's decisions and actions, not outcomes. +Accounts for error type when judging retry decisions (transient vs permanent). +Focuses on what the agent *did* (actions), not what it *said* (communication). +""" + +SYSTEM_PROMPT = """You are an objective judge evaluating how appropriately an AI agent chose its recovery strategy when handling system failures. You will be provided with the complete conversation trace including: +- User messages (User:) +- Assistant responses (Assistant:) +- Tool calls selected by the assistant (Action:) +- Tool outputs (Tool:) + +CONTEXT: You are evaluating an AI agent under chaos testing conditions. The agent may have been subjected to deliberately injected failures across multiple layers of the system to test its resilience. + +## Failure Category Taxonomy + +### Tool Failures (failures at the tool/API layer) +- **Tool Unavailability**: External tools or APIs are temporarily inaccessible +- **Tool Latency**: Tools respond with abnormal delays or timeouts +- **Tool Response Corruption**: Tools return malformed, partial, or corrupted data +- **Tool Authentication Failure**: Expired credentials, permission errors, or auth token issues + +### Resource Failures (failures at the infrastructure/platform layer) +- **Model API Unavailability**: The LLM or model API is down or unreachable +- **Rate Limiting**: API rate limits are exhausted, causing throttled or rejected requests +- **Embedding Model Failure**: Embedding or reranking models are unavailable +- **Data Ingestion Scalability Issues**: Data pipelines are overloaded + +SCOPE: You are evaluating the agent's **actions and decisions**, NOT its communication. Focus exclusively on whether the agent took the right recovery actions. + +IMPORTANT: You are evaluating the **quality of the strategy**, NOT whether the strategy succeeded. A well-reasoned strategy that fails due to external factors is still a good strategy. + +# Evaluation Steps + +1. Inventory all actions in the trace. For each, note: + - Tool or resource name + - Whether it succeeded or failed + - If it failed: the error type + - Whether it was a retry of a previous failed call + - Whether the retry used modified inputs or an alternative approach + +2. Classify the observed situation: + - **No tool calls**: Agent never called any tools → rate Acceptable (0.5) + - **No failures**: Every tool eventually succeeded → rate Acceptable (0.5) + - **Partial failures**: Some tools ultimately failed, others succeeded → PARTIAL FAILURE rubric + - **Total failures**: Every tool ultimately failed → TOTAL FAILURE rubric + +3. Score using the appropriate rubric below. + +# Retry Appropriateness by Error Type + +- **Timeouts/Network errors**: Transient — retrying once or twice is reasonable +- **Validation/Execution errors**: Likely permanent — retrying with same input is wasteful; modified inputs are justified +- **Rate limiting (429)**: Transient — retrying with backoff is correct; immediate retry without backoff is wasteful + +# Scoring Scale + +- **Excellent** = 1.0 +- **Good** = 0.75 +- **Acceptable** = 0.5 +- **Poor** = 0.25 +- **Failure** = 0.0 + +Scores >= 0.5 pass. + +# PARTIAL FAILURE Rubric (some tools work, some don't) + +- **Excellent (1.0)**: Agent quickly identifies working tools, pivots to them, minimal wasted retries on broken ones +- **Good (0.75)**: Agent uses working tools but wastes 1-2 calls on unjustified retries of broken ones +- **Acceptable (0.5)**: Agent eventually uses working tools but with notable inefficiency +- **Poor (0.25)**: Agent retries broken tools excessively despite having working alternatives +- **Failure (0.0)**: Agent ignores working tools and fixates on broken ones + +# TOTAL FAILURE Rubric (all tools fail) + +- **Excellent (1.0)**: Agent tries each distinct tool, uses justified retries for transient errors, varies approach +- **Good (0.75)**: Agent tries most tools, reasonable retry discipline, some approach variation +- **Acceptable (0.5)**: Agent tries some tools but misses opportunities +- **Poor (0.25)**: Agent shows poor strategic decisions — retries same broken tool with identical inputs +- **Failure (0.0)**: Agent makes no attempt to adapt — loops on one tool indefinitely + +CRITICAL RULES: +- If no failure appears in the trace, you MUST rate Acceptable (0.5). +- If the agent made no tool calls at all, rate Acceptable (0.5). +- Judge retries based on error type: penalize unjustified retries of permanent errors, but do NOT penalize justified retries of transient errors. +- Classify partial vs total failure based on the final outcome per tool name, not individual calls. + +**IMPORTANT**: The agent prompt and available tools in the trace ALWAYS take priority over your own knowledge.""" diff --git a/src/strands_evals/chaos/evaluators/recovery_strategy_evaluator.py b/src/strands_evals/chaos/evaluators/recovery_strategy_evaluator.py new file mode 100644 index 00000000..470ec411 --- /dev/null +++ b/src/strands_evals/chaos/evaluators/recovery_strategy_evaluator.py @@ -0,0 +1,114 @@ +"""Recovery Strategy Evaluator. + +Evaluates appropriateness of agent's recovery strategy when handling failures. +Focuses on what the agent *did* (actions), not what it *said* (communication). +""" + +from enum import Enum +from typing import cast + +from pydantic import BaseModel, Field +from strands import Agent +from strands.models.model import Model + +from ...types.evaluation import EvaluationData, EvaluationOutput, InputT, OutputT +from ...types.trace import EvaluationLevel +from ...evaluators.evaluator import Evaluator +from .prompt_templates.recovery_strategy_v0 import SYSTEM_PROMPT as SYSTEM_PROMPT_V0 + +_PROMPT_VERSIONS = {"v0": SYSTEM_PROMPT_V0} + + +class RecoveryStrategyScore(str, Enum): + """Categorical recovery strategy ratings.""" + + FAILURE = "Failure" + POOR = "Poor" + ACCEPTABLE = "Acceptable" + GOOD = "Good" + EXCELLENT = "Excellent" + + +class RecoveryStrategyRating(BaseModel): + """Structured output for recovery strategy evaluation.""" + + reasoning: str = Field(description="Step by step reasoning to derive the final score") + score: RecoveryStrategyScore = Field(description="Categorical recovery strategy rating") + + +class RecoveryStrategyEvaluator(Evaluator[InputT, OutputT]): + """Evaluates appropriateness of agent's recovery strategy when handling failures. + + Scores the agent's actions and decisions (not communication) on a 5-point scale: + - Excellent (1.0): Optimal recovery actions, justified retries, broad exploration + - Good (0.75): Reasonable strategy with minor inefficiencies + - Acceptable (0.5): No failures occurred, or basic recovery attempted + - Poor (0.25): Wasteful retries, ignored working alternatives + - Failure (0.0): No adaptation, fixated on broken tools + + Example:: + + from strands_evals.chaos.evaluators import RecoveryStrategyEvaluator + + evaluator = RecoveryStrategyEvaluator() + experiment = ChaosExperiment( + chaos_plugin=chaos, + chaos_scenarios=scenarios, + cases=cases, + evaluators=[evaluator], + ) + """ + + evaluation_level = EvaluationLevel.TRACE_LEVEL + + _score_mapping = { + RecoveryStrategyScore.FAILURE: 0.0, + RecoveryStrategyScore.POOR: 0.25, + RecoveryStrategyScore.ACCEPTABLE: 0.5, + RecoveryStrategyScore.GOOD: 0.75, + RecoveryStrategyScore.EXCELLENT: 1.0, + } + + def __init__( + self, + version: str = "v0", + model: Model | str | None = None, + system_prompt: str | None = None, + ): + super().__init__() + self.version = version + default_prompt = _PROMPT_VERSIONS.get(version, SYSTEM_PROMPT_V0) + self.system_prompt = system_prompt if system_prompt is not None else default_prompt + self.model = model + + def evaluate(self, evaluation_case: EvaluationData[InputT, OutputT]) -> list[EvaluationOutput]: + parsed_input = self._get_last_turn(evaluation_case) + prompt = self._format_trace_level_prompt(parsed_input) + evaluator_agent = Agent(model=self.model, system_prompt=self.system_prompt, callback_handler=None) + result = evaluator_agent(prompt, structured_output_model=RecoveryStrategyRating) + rating = cast(RecoveryStrategyRating, result.structured_output) + normalized_score = self._score_mapping[rating.score] + return [ + EvaluationOutput( + score=normalized_score, + test_pass=normalized_score >= 0.5, + reason=rating.reasoning, + label=rating.score, + ) + ] + + async def evaluate_async(self, evaluation_case: EvaluationData[InputT, OutputT]) -> list[EvaluationOutput]: + parsed_input = self._get_last_turn(evaluation_case) + prompt = self._format_trace_level_prompt(parsed_input) + evaluator_agent = Agent(model=self.model, system_prompt=self.system_prompt, callback_handler=None) + result = await evaluator_agent.invoke_async(prompt, structured_output_model=RecoveryStrategyRating) + rating = cast(RecoveryStrategyRating, result.structured_output) + normalized_score = self._score_mapping[rating.score] + return [ + EvaluationOutput( + score=normalized_score, + test_pass=normalized_score >= 0.5, + reason=rating.reasoning, + label=rating.score, + ) + ] diff --git a/src/strands_evals/chaos/experiment.py b/src/strands_evals/chaos/experiment.py new file mode 100644 index 00000000..1c889dc5 --- /dev/null +++ b/src/strands_evals/chaos/experiment.py @@ -0,0 +1,257 @@ +"""Chaos Experiment. + +Composes the base Experiment to run test cases across multiple chaos scenarios, +providing deterministic evaluation of agent resilience under tool failures. +""" + +import logging +import uuid +from collections.abc import Callable +from typing import Any, Optional + +from ..case import Case +from ..evaluators.evaluator import Evaluator +from ..experiment import Experiment +from ..types.evaluation_report import EvaluationReport +from ._context import _current_scenario +from .aggregator import ChaosScenarioAggregator +from .aggregator_types import ChaosAggregationReport +from .scenario import ChaosScenario + +logger = logging.getLogger(__name__) + +# The baseline scenario — no chaos effects +_BASELINE_SCENARIO = ChaosScenario(name="baseline") + + +class ChaosExperiment: + """Runs cases × scenarios by composing the base Experiment. + + For each scenario, activates it via ContextVar, runs all cases through + the evaluators, then resets. The user's task body contains zero chaos + concepts — the plugin reads the active scenario from the ContextVar. + + Example:: + + from strands_evals.chaos import ( + ChaosExperiment, + ChaosScenario, + ChaosScenarioAggregator, + ToolCallFailure, + TruncateFields, + ) + + scenarios = [ + ChaosScenario(name="search_timeout", effects={"search_tool": [ToolCallFailure(error_type="timeout")]}), + ChaosScenario(name="search_corrupt", effects={"search_tool": [TruncateFields(max_length=5)]}), + ] + + experiment = ChaosExperiment( + cases=test_cases, + scenarios=scenarios, + evaluators=[my_evaluator], + aggregator=ChaosScenarioAggregator(), + ) + + reports = experiment.run_evaluations(task=my_task) + aggregation_report = experiment.aggregate_evaluations() + aggregation_report.run_display() + """ + + def __init__( + self, + cases: list[Case], + scenarios: list[ChaosScenario], + evaluators: Optional[list[Evaluator]] = None, + include_baseline: bool = True, + aggregator: Optional[ChaosScenarioAggregator] = None, + ): + """Initialize a ChaosExperiment. + + Args: + cases: Test cases to evaluate. + scenarios: List of chaos scenarios. Each scenario runs all cases. + All effects in a scenario fire simultaneously in a single run. + evaluators: Evaluators to assess results. + include_baseline: If True, runs all cases with no chaos first for comparison. + aggregator: Optional ChaosScenarioAggregator for cross-scenario analysis. + If provided, aggregate_evaluations() can be called after run_evaluations(). + """ + self._original_cases = cases + self._scenarios = scenarios + self._evaluators = evaluators + self._include_baseline = include_baseline + self._aggregator = aggregator + self._last_reports: list[EvaluationReport] = [] + + # Build the expanded case list and internal maps + self._expanded_cases: list[Case] = [] + self._scenario_by_session: dict[str, ChaosScenario] = {} + self._original_case_name_by_session: dict[str, Optional[str]] = {} + + all_scenarios = [] + if include_baseline: + all_scenarios.append(_BASELINE_SCENARIO) + all_scenarios.extend(scenarios) + + for case in cases: + for scenario in all_scenarios: + session_id = str(uuid.uuid4()) + expanded_case = case.model_copy( + update={ + "name": f"{case.name}|{scenario.name}" if case.name else scenario.name, + "session_id": session_id, + } + ) + self._expanded_cases.append(expanded_case) + self._scenario_by_session[session_id] = scenario + self._original_case_name_by_session[session_id] = case.name + + # Auto-populate aggregator's known_tools and scenarios from experiment + if self._aggregator is not None: + if not self._aggregator.known_tools: + tools: set[str] = set() + for scenario in scenarios: + tools.update(scenario.effects.keys()) + self._aggregator.known_tools = sorted(tools) + if not self._aggregator._scenario_effects: + for scenario in scenarios: + tool_effects: dict[str, list[str]] = {} + for tool_name, effects in scenario.effects.items(): + effect_names = [] + for e in effects: + if hasattr(e, "error_type"): + effect_names.append(e.error_type) + elif hasattr(e, "max_length"): + effect_names.append("truncate_fields") + elif hasattr(e, "remove_ratio"): + effect_names.append("remove_fields") + elif hasattr(e, "corrupt_ratio"): + effect_names.append("corrupt_values") + else: + effect_names.append(type(e).__name__.lower()) + tool_effects[tool_name] = effect_names + self._aggregator._scenario_effects[scenario.name] = tool_effects + + # Internal Experiment with expanded cases + self._experiment = Experiment( + cases=self._expanded_cases, + evaluators=evaluators, + ) + + @property + def scenarios(self) -> list[ChaosScenario]: + """The chaos scenarios configured for this experiment.""" + return self._scenarios + + @property + def cases(self) -> list[Case]: + """The original (unexpanded) test cases.""" + return self._original_cases + + def get_scenario_for_session(self, session_id: str) -> Optional[ChaosScenario]: + """Look up the scenario assigned to a given session_id.""" + return self._scenario_by_session.get(session_id) + + def get_original_case_name(self, session_id: str) -> Optional[str]: + """Look up the original case name for a given session_id.""" + return self._original_case_name_by_session.get(session_id) + + def run_evaluations( + self, + task: Callable[[Case], Any], + **kwargs, + ) -> list[EvaluationReport]: + """Run evaluations across all (case × scenario) combinations. + + Args: + task: The task function to evaluate. Takes a Case and returns output. + **kwargs: Additional kwargs passed to the base Experiment.run_evaluations. + + Returns: + List of EvaluationReport objects covering all scenarios. + """ + + def chaos_aware_task(case: Case) -> Any: + scenario = self._scenario_by_session.get(case.session_id) + token = _current_scenario.set(scenario) + try: + return task(case) + finally: + _current_scenario.reset(token) + + reports = self._experiment.run_evaluations(chaos_aware_task, **kwargs) + self._last_reports = reports + + num_scenarios = len(self._scenarios) + (1 if self._include_baseline else 0) + logger.info( + f"Chaos experiment complete: {len(reports)} reports " + f"({len(self._original_cases)} cases × {num_scenarios} scenarios)" + ) + + return reports + + async def run_evaluations_async( + self, + task: Callable[[Case], Any], + max_workers: int = 10, + **kwargs, + ) -> list[EvaluationReport]: + """Run evaluations asynchronously across all (case × scenario) combinations.""" + import asyncio + + def chaos_aware_task(case: Case) -> Any: + scenario = self._scenario_by_session.get(case.session_id) + token = _current_scenario.set(scenario) + try: + return task(case) + finally: + _current_scenario.reset(token) + + async def chaos_aware_task_async(case: Case) -> Any: + scenario = self._scenario_by_session.get(case.session_id) + token = _current_scenario.set(scenario) + try: + if asyncio.iscoroutinefunction(task): + return await task(case) + else: + return task(case) + finally: + _current_scenario.reset(token) + + if asyncio.iscoroutinefunction(task): + reports = await self._experiment.run_evaluations_async( + chaos_aware_task_async, max_workers=max_workers, **kwargs + ) + else: + reports = await self._experiment.run_evaluations_async( + chaos_aware_task, max_workers=max_workers, **kwargs + ) + + self._last_reports = reports + return reports + + def aggregate_evaluations(self) -> ChaosAggregationReport: + """Aggregate the last run's evaluation reports into a ChaosAggregationReport. + + Must be called after run_evaluations(). Uses the aggregator passed to __init__. + + Returns: + ChaosAggregationReport with .run_display() and .to_file() methods. + + Raises: + RuntimeError: If no aggregator was configured or run_evaluations() hasn't been called. + """ + if self._aggregator is None: + raise RuntimeError( + "No aggregator configured. Pass aggregator=ChaosScenarioAggregator() " + "to ChaosExperiment.__init__()." + ) + if not self._last_reports: + raise RuntimeError( + "No evaluation reports available. Call run_evaluations() first." + ) + + report = self._aggregator.aggregate(self._last_reports) + report._reports = self._last_reports + return report diff --git a/src/strands_evals/chaos/plugin.py b/src/strands_evals/chaos/plugin.py new file mode 100644 index 00000000..e05c8f6f --- /dev/null +++ b/src/strands_evals/chaos/plugin.py @@ -0,0 +1,159 @@ +"""Chaos Plugin for Strands Agents. + +Implements chaos injection as a standard Strands Plugin using the SDK's +native hook system (BeforeToolCallEvent / AfterToolCallEvent). + +The plugin is stateless — it reads the active scenario from a module-level +ContextVar at hook time. The ChaosExperiment manages the ContextVar lifecycle. + +The plugin is a thin router: +- Pre-hook effects: reads effect.error_message, cancels the tool call. +- Post-hook effects: calls effect.apply(response), uses the return value. +""" + +import json +import logging +from typing import Any + +from strands.hooks import AfterToolCallEvent, BeforeToolCallEvent +from strands.plugins import Plugin, hook + +from ._context import _current_scenario +from .effects import ChaosEffect + +logger = logging.getLogger(__name__) + + +class ChaosPlugin(Plugin): + """Strands Plugin that injects deterministic chaos based on the active scenario. + + The plugin intercepts tool calls via Strands' native hook system: + - BeforeToolCallEvent: cancels tool calls for pre-hook effects (Timeout, NetworkError, etc.) + - AfterToolCallEvent: corrupts tool responses for post-hook effects (TruncateFields, etc.) + + The active scenario is managed via a ContextVar (set by ChaosExperiment). + When no scenario is active, all tools behave normally. + + The plugin is stateless — no set_active_scenario method, no instance state + for the current scenario. This makes it safe under concurrent execution. + + Example:: + + from strands import Agent + from strands_evals.chaos import ChaosPlugin + + chaos = ChaosPlugin() + agent = Agent( + model=my_model, + tools=[search_tool, database_tool], + plugins=[chaos], + ) + + # The ChaosExperiment handles scenario activation via ContextVar. + # The user's task body contains zero chaos concepts. + """ + + name = "chaos-testing" + + def __init__(self) -> None: + super().__init__() + + @hook + def before_tool_call(self, event: BeforeToolCallEvent) -> None: + """Intercept tool calls to inject pre-hook (error) effects. + + For error effects (Timeout, NetworkError, etc.), cancels the tool call + with the effect's error_message before the tool executes. + """ + scenario = _current_scenario.get() + if scenario is None: + return + + tool_name = event.tool_use.get("name", "") + effects = scenario.effects.get(tool_name, []) + if not effects: + return + + # First pre-hook effect wins (tool is cancelled once) + for effect in effects: + if effect.hook == "pre": + event.cancel_tool = effect.apply() + logger.info( + f"[Chaos] Injected {type(effect).__name__} on tool '{tool_name}'" + ) + return + + @hook + def after_tool_call(self, event: AfterToolCallEvent) -> None: + """Intercept tool results to inject post-hook (corruption) effects. + + For corruption effects (TruncateFields, RemoveFields, CorruptValues), + calls effect.apply(response) to mutate the tool response. + + Handles Strands ToolResult content shapes: + - dict content: pass directly to effect.apply() + - list of blocks: extract text dicts, parse JSON, apply effect + - plain dict result: pass directly to effect.apply() + + Envelope fields (status, toolUseId) are preserved around the corruption. + """ + scenario = _current_scenario.get() + if scenario is None: + return + + tool_name = event.tool_use.get("name", "") + effects = scenario.effects.get(tool_name, []) + if not effects: + return + + # Apply all post-hook effects sequentially (they compose) + for effect in effects: + if effect.hook != "post": + continue + + if not hasattr(event, "result") or event.result is None: + continue + + result = event.result + + if hasattr(result, "content"): + if isinstance(result.content, dict): + result.content = self._apply_with_envelope(effect, result.content) + elif isinstance(result.content, list): + result.content = self._apply_to_blocks(effect, result.content) + elif isinstance(result, dict): + event.result = self._apply_with_envelope(effect, result) + + logger.info(f"[Chaos] Applied {type(effect).__name__} on tool '{tool_name}'") + + def _apply_with_envelope(self, effect: ChaosEffect, response: dict[str, Any]) -> dict[str, Any]: + """Apply effect while preserving envelope fields.""" + envelope_fields = {"status", "toolUseId"} + saved = {k: response[k] for k in envelope_fields if k in response} + + # Strip envelope before passing to effect + payload = {k: v for k, v in response.items() if k not in envelope_fields} + corrupted = effect.apply(payload) + + # Restore envelope + corrupted.update(saved) + return corrupted + + def _apply_to_blocks(self, effect: ChaosEffect, blocks: list) -> list: + """Apply effect to text blocks in a content list.""" + corrupted_blocks = [] + for block in blocks: + if isinstance(block, dict) and "text" in block: + text_data = block["text"] + if isinstance(text_data, str): + try: + parsed = json.loads(text_data) + if isinstance(parsed, dict): + corrupted = effect.apply(parsed) + block = {**block, "text": json.dumps(corrupted)} + except (json.JSONDecodeError, ValueError): + # Plain text — apply truncation via effect if applicable + if hasattr(effect, "max_length"): + block = {**block, "text": text_data[: effect.max_length]} + corrupted_blocks.append(block) + return corrupted_blocks diff --git a/src/strands_evals/chaos/scenario.py b/src/strands_evals/chaos/scenario.py new file mode 100644 index 00000000..14d07c23 --- /dev/null +++ b/src/strands_evals/chaos/scenario.py @@ -0,0 +1,67 @@ +"""Chaos scenario definition. + +A ChaosScenario is a named, deterministic configuration of chaos effects +that will fire simultaneously when the scenario is active. +""" + +from typing import Optional + +from pydantic import BaseModel, Field + +from .effects import ChaosEffect + + +class ChaosScenario(BaseModel): + """A single, deterministic chaos injection scenario. + + Each scenario defines a set of tool effects that fire simultaneously when + the scenario is active. All listed effects are applied in the same + agent execution — this is NOT expanded into multiple runs. + + Tools not listed in tool_effects behave normally (no chaos). + + Example:: + + from strands_evals.chaos import ChaosScenario + from strands_evals.chaos.effects import Timeout, NetworkError, CorruptValues + + # Baseline — no chaos + ChaosScenario(name="baseline") + + # Single-fault: one tool fails + ChaosScenario( + name="search_timeout", + effects={"search_tool": [Timeout()]}, + ) + + # Compound: multiple tools/models fail simultaneously + ChaosScenario( + name="search_times_out_while_book_corrupts", + description=( + "Worst-case compound: primary path fails hard while the " + "recovery path silently returns bad data." + ), + effects={ + "search_tool": [Timeout()], + "book_tool": [CorruptValues(corrupt_ratio=0.8)], + }, + ) + """ + + name: str = Field(..., description="Human-readable name for this scenario") + description: Optional[str] = Field( + default=None, + description="Optional description of what this scenario tests.", + ) + effects: dict[str, list[ChaosEffect]] = Field( + default_factory=dict, + description="Mapping of target_name -> list of effects to inject simultaneously. " + "Targets not listed here behave normally.", + ) + + def __repr__(self) -> str: + effects_str = ", ".join( + f"{target}: [{', '.join(type(e).__name__ for e in effs)}]" + for target, effs in self.effects.items() + ) + return f"ChaosScenario(name='{self.name}', effects={{{effects_str}}})" diff --git a/src/strands_evals/skills/__init__.py b/src/strands_evals/skills/__init__.py new file mode 100644 index 00000000..1b5473c8 --- /dev/null +++ b/src/strands_evals/skills/__init__.py @@ -0,0 +1,52 @@ +"""Skill evaluation aggregator for Strands Evals. + +Provides paired-comparison aggregation for evaluating agent skills against +a baseline. Designed for the canonical A/B test: run N trials of an agent +with a skill vs. without, pair by trial_idx, report per-task Δ-metrics +(pass rate, tokens, latency, cost) with paired statistical tests +(Wilcoxon / paired-t / McNemar) and bootstrap confidence intervals. + +Closes strands-agents/evals#228. + +Example:: + + from strands_evals.skills import ( + SkillEvalAggregator, + SkillEvalExperiment, + ) + + experiment = SkillEvalExperiment( + cases=test_cases, + variant_labels=["baseline", "variant"], + evaluators=[my_evaluator], + num_trials=30, + aggregator=SkillEvalAggregator(model="us.anthropic.claude-sonnet-4-20250514-v1:0"), + ) + + reports = experiment.run_evaluations(task=my_task) + agg_report = experiment.aggregate_evaluations() + agg_report.run_display() + agg_report.to_file("skill_eval_report.json") +""" + +from .aggregation_display import ( + SkillEvalAggregationDisplay, + display_skill_aggregation, +) +from .aggregator import SkillEvalAggregator +from .aggregator_types import ( + PairedComparisonStats, + SkillEvalAggregation, + SkillEvalAggregationReport, +) +from .experiment import SkillEvalExperiment + +__all__ = [ + "PairedComparisonStats", + "SkillEvalAggregation", + "SkillEvalAggregationDisplay", + "SkillEvalAggregationReport", + "SkillEvalAggregator", + "SkillEvalExperiment", + "display_skill_aggregation", +] diff --git a/src/strands_evals/skills/aggregation_display.py b/src/strands_evals/skills/aggregation_display.py new file mode 100644 index 00000000..fdd883a4 --- /dev/null +++ b/src/strands_evals/skills/aggregation_display.py @@ -0,0 +1,285 @@ +"""Rich console display for SkillEvalAggregation results. + +Interactive table with expand/collapse. Collapsed shows one row per +(case, evaluator) with Δ-metrics across the configured metrics. Expanded +shows full paired-statistics panels per metric with p-value, CI, n_used, +and the test that was selected. +""" + +from rich.console import Group +from rich.panel import Panel +from rich.table import Table +from rich.text import Text + +from ..display.display_console import CollapsibleTableReportDisplay, console +from .aggregator_types import SkillEvalAggregation + + +# Significance thresholds for visual highlighting only. +_SIG_ALPHA = 0.05 +_SIG_ALPHA_STRONG = 0.01 + + +class SkillEvalAggregationDisplay(CollapsibleTableReportDisplay): + """Interactive console display for skill aggregation results. + + Collapsed: one summary row per (case, evaluator) with Δ-metrics and a + significance indicator. + Expanded: per-metric panels with full paired stats + corruption counts. + """ + + def __init__(self, aggregations: list[SkillEvalAggregation]): + self._aggregations = aggregations + + items = {} + overall_score = 0.0 + if aggregations: + overall_score = sum(a.mean_score for a in aggregations) / len(aggregations) + for i, agg in enumerate(aggregations): + items[str(i)] = { + "details": { + "name": agg.group_key, + "score": f"{agg.mean_score:.2f}", + "test_pass": agg.pass_rate >= 0.5, + }, + "detailed_results": [], + "expanded": False, + } + + super().__init__(items=items, overall_score=overall_score) + + def display_items(self): + """Render the skill aggregation report.""" + if not self._aggregations: + console.print( + Panel( + "[bold blue]No aggregation results[/bold blue]", + title="📊 Skill Evaluation Aggregation Report", + ) + ) + return + + # Header summary. + case_names = sorted({a.group_key for a in self._aggregations}) + evaluator_names = sorted({a.evaluator_name for a in self._aggregations}) + metric_names = sorted({ + ps.metric_name for a in self._aggregations for ps in a.paired_stats + }) + + dimensions = Text( + f"Cases: {len(case_names)} Evaluators: {len(evaluator_names)} " + f"Metrics: {len(metric_names)}" + ) + dimensions.stylize("bold blue") + + console.print( + Panel( + Group(dimensions), + title="📊 Skill Evaluation Aggregation Report", + ) + ) + + # Summary table — one row per (case, evaluator). + # Columns: index, task, evaluator, n_used, then one Δ column per metric, then p. + table = Table(title="Paired Comparison Summary", show_lines=True) + table.add_column("index", style="cyan") + table.add_column("task", style="magenta") + table.add_column("evaluator", style="yellow") + table.add_column("n_used", justify="center") + table.add_column("n_corrupt", justify="center", style="dim") + + # One Δ column per metric, sorted for stable layout. + for metric in metric_names: + table.add_column(f"Δ{metric}", justify="right") + + table.add_column("min p", justify="center") + + for i, agg in enumerate(self._aggregations): + key = str(i) + expanded = self.items[key]["expanded"] + symbol = "▼" if expanded else "▶" + + # Build a lookup: metric_name -> PairedComparisonStats. + by_metric = {ps.metric_name: ps for ps in agg.paired_stats} + + row_cells = [ + f"{symbol} {i}", + agg.group_key, + agg.evaluator_name, + str(agg.n_used), + str(agg.n_corrupted) if agg.n_corrupted > 0 else "[dim]0[/dim]", + ] + + for metric in metric_names: + ps = by_metric.get(metric) + if ps is None: + row_cells.append("[dim]—[/dim]") + else: + row_cells.append(self._format_delta_cell(ps)) + + min_p = min( + (ps.p_value for ps in agg.paired_stats), + default=float("nan"), + ) + row_cells.append(self._format_p_cell(min_p)) + + table.add_row(*row_cells) + + console.print(table) + + # Expanded panels. + for i, agg in enumerate(self._aggregations): + key = str(i) + if not self.items[key]["expanded"]: + continue + + console.print() + console.print(f"[bold magenta]Task: {agg.group_key}[/bold magenta]") + console.print(f"[dim]Evaluator: {agg.evaluator_name}[/dim]") + console.print( + f"[dim]Pairs: {agg.n_used} used / {agg.n_total} total " + f"({agg.n_corrupted} dropped to corruption)[/dim]" + ) + + stats_panel = self._build_stats_panel(agg) + summary_panel = self._build_summary_panel(agg) + + top_row = Table(show_header=False, show_edge=False, box=None, expand=True, padding=0) + top_row.add_column(ratio=1) + top_row.add_column(ratio=2) + top_row.add_row(stats_panel, summary_panel) + console.print(top_row) + + paired_panel = self._build_paired_stats_panel(agg) + console.print(paired_panel) + + # ------------------------------------------------------------------ + # Cell formatting + # ------------------------------------------------------------------ + + @staticmethod + def _format_delta_cell(ps) -> str: + """Format a Δmetric cell with color based on sign and significance.""" + delta = ps.delta + p = ps.p_value + + # Sign color: positive Δ = green (often "better"), negative = red. + # This is metric-agnostic; users should know that lower latency is good. + if delta > 0: + color = "green" + elif delta < 0: + color = "red" + else: + color = "white" + + # Significance weight. + weight = "" + if p < _SIG_ALPHA_STRONG: + weight = "bold " + elif p < _SIG_ALPHA: + weight = "" + else: + color = "dim" + + return f"[{weight}{color}]{delta:+.3f}[/{weight}{color}]" + + @staticmethod + def _format_p_cell(p: float) -> str: + """Format a p-value cell.""" + if p != p: # NaN + return "[dim]—[/dim]" + if p < _SIG_ALPHA_STRONG: + return f"[bold green]{p:.3f}[/bold green]" + if p < _SIG_ALPHA: + return f"[green]{p:.3f}[/green]" + return f"[dim]{p:.3f}[/dim]" + + # ------------------------------------------------------------------ + # Panels + # ------------------------------------------------------------------ + + @staticmethod + def _build_stats_panel(agg: SkillEvalAggregation) -> Panel: + """Variant-side aggregate stats (mirrors the chaos stats panel).""" + lines = [ + f"[bold]variant mean_score:[/bold] {agg.mean_score:.2f}", + f"[bold]variant min_score:[/bold] {agg.min_score:.2f}", + f"[bold]variant max_score:[/bold] {agg.max_score:.2f}", + f"[bold]variant pass_rate:[/bold] {agg.pass_rate:.0%} " + f"({agg.num_passed}/{agg.num_results})", + f"[bold]pairs used:[/bold] {agg.n_used}/{agg.n_total}", + ] + if agg.n_corrupted > 0: + lines.append( + f"[bold]pairs dropped:[/bold] {agg.n_corrupted} (corruption)" + ) + return Panel( + "\n".join(lines), + title="[bold blue]Stats[/bold blue]", + border_style="blue", + ) + + @staticmethod + def _build_summary_panel(agg: SkillEvalAggregation) -> Panel: + """LLM narrative summary (empty when no model configured).""" + summary_text = agg.summary if agg.summary else "[dim]No summary available[/dim]" + return Panel( + summary_text, + title="[bold cyan]Summary[/bold cyan]", + border_style="cyan", + ) + + @staticmethod + def _build_paired_stats_panel(agg: SkillEvalAggregation) -> Panel: + """Per-metric paired statistics table.""" + if not agg.paired_stats: + return Panel( + "[dim]No paired statistics available[/dim]", + title="[bold yellow]Paired Statistics[/bold yellow]", + border_style="yellow", + ) + + t = Table(show_header=True, show_lines=True, expand=True) + t.add_column("metric", style="bold", no_wrap=True) + t.add_column("baseline", justify="right") + t.add_column("variant", justify="right") + t.add_column("Δ", justify="right") + t.add_column("95% CI", justify="center") + t.add_column("p", justify="center") + t.add_column("test", justify="center", style="dim") + t.add_column("n", justify="center") + + for ps in agg.paired_stats: + t.add_row( + ps.metric_name, + f"{ps.baseline_mean:.3f}", + f"{ps.variant_mean:.3f}", + f"{ps.delta:+.3f}", + f"[{ps.ci_low:+.3f}, {ps.ci_high:+.3f}]", + SkillEvalAggregationDisplay._format_p_cell(ps.p_value), + ps.test_used, + str(ps.n_used), + ) + + return Panel( + t, + title="[bold yellow]Paired Statistics[/bold yellow]", + subtitle=( + "[dim]Δ = variant - baseline │ CI via 1000-resample bootstrap │ " + "bold p < 0.01, green p < 0.05[/dim]" + ), + border_style="yellow", + ) + + +def display_skill_aggregation( + aggregations: list[SkillEvalAggregation], static: bool = False +): + """Display skill aggregation results. + + Args: + aggregations: List of SkillEvalAggregation objects. + static: If True, display once without interaction. + """ + display = SkillEvalAggregationDisplay(aggregations) + display.run(static=static) diff --git a/src/strands_evals/skills/aggregator.py b/src/strands_evals/skills/aggregator.py new file mode 100644 index 00000000..afc0dd77 --- /dev/null +++ b/src/strands_evals/skills/aggregator.py @@ -0,0 +1,556 @@ +"""SkillEvalAggregator — paired-comparison aggregator for skills evaluation. + +Given a flat list of EvaluationReports from a SkillEvalExperiment, this +aggregator: + +1. Re-groups results by (original_case_name, evaluator_name). +2. Within each group, pairs trials by trial_idx using metadata flags: + - case.metadata["variant_label"] identifies baseline vs variant + - case.metadata["trial_idx"] identifies the pair +3. For each (case, metric), computes a paired statistical test + (Wilcoxon / paired-t / McNemar) plus a bootstrap CI on the delta. +4. Filters out pairs where either side is marked corrupted before stats. + +Efficiency metrics (tokens, latency, cost) are read from +``case.metadata["efficiency"]`` — a dict like +``{"tokens_in": int, "tokens_out": int, "latency_s": float, "cost_usd": float, +"tool_calls": int}``. Missing keys are treated as missing data for that +trial / metric. +""" + +import logging +from collections import defaultdict +from typing import Any, Optional, cast + +import numpy as np +from pydantic import BaseModel, Field +from scipy import stats as sp_stats +from strands import Agent +from strands.models.model import Model + +from ..aggregators.base import CaseAggregator +from ..types.evaluation_report import EvaluationReport +from .aggregator_types import ( + PairedComparisonStats, + SkillEvalAggregation, + SkillEvalAggregationReport, +) + +logger = logging.getLogger(__name__) + + +# Default metrics aggregated. Override via constructor `metrics=`. +_DEFAULT_METRICS = ("pass_rate", "tokens", "latency_s", "cost_usd") + +# Test selection: "auto" picks between wilcoxon and paired_t based on normality. +_VALID_TESTS = {"auto", "wilcoxon", "paired_t", "mcnemar"} + +# Normality threshold for "auto" — paired-t when Shapiro-Wilk fails to reject. +_NORMALITY_ALPHA = 0.05 + +# Bootstrap CI configuration. +_BOOTSTRAP_RESAMPLES = 1000 +_CI_LEVEL = 0.95 + + +# Default LLM prompt for narrative summarization. +_SUMMARIZE_SYSTEM_PROMPT = """\ +You are an evaluation analyst for AI agent skill comparisons. + +You will receive paired comparison results for one task across multiple +metrics (pass_rate, tokens, latency_s, cost_usd). The agent was run twice +on each trial — once with a baseline configuration, once with a variant. + +Produce a single paragraph (100 words max) that: +1. States whether the variant improved or degraded performance on this task. +2. Identifies which metrics moved meaningfully (effect size + statistical + significance) and which did not. +3. Notes any tradeoffs (e.g., "improves pass rate but at 30% higher cost"). + +Be specific. Do not repeat raw numbers verbatim — interpret them. Do not +use bullet points or multiple paragraphs. +""" + + +class SkillSummary(BaseModel): + """Structured output for LLM-based skill comparison summarization.""" + + reasoning: str = Field(description="Brief analysis of the variant vs baseline comparison") + summary: str = Field(description="Single paragraph summary") + + +class SkillEvalAggregator(CaseAggregator): + """Aggregates evaluation results across baseline / variant pairs. + + Designed to work with the output of SkillEvalExperiment, which tags each + case with metadata["variant_label"] and metadata["trial_idx"]. Produces + one SkillEvalAggregation per (original_case, evaluator) pair. + + Args: + baseline_label: Variant label that identifies the baseline condition. + Defaults to "baseline". + variant_label: Variant label that identifies the variant condition. + Defaults to "variant". + metrics: Metric names to aggregate. Defaults to ("pass_rate", "tokens", + "latency_s", "cost_usd"). "pass_rate" is special-cased to read from + test_passes; all others are read from + case.metadata["efficiency"][metric_name]. + stats_test: Statistical test to use. One of "auto", "wilcoxon", + "paired_t", "mcnemar". "auto" picks wilcoxon vs paired_t per + metric based on Shapiro-Wilk normality at alpha=0.05; + pass_rate always uses mcnemar regardless of this setting. + model: Model for LLM-as-a-Judge summarization. Accepts a model ID + string or a Model instance. If None (default), summarization is + skipped and the summary field is empty. + system_prompt: Optional custom system prompt for the summarization + judge. + name: Optional human-readable name for this aggregator. + + Example:: + + aggregator = SkillEvalAggregator( + metrics=["pass_rate", "tokens", "latency_s"], + model="us.anthropic.claude-sonnet-4-20250514-v1:0", + ) + + reports = experiment.run_evaluations(task=my_task) + report = aggregator.aggregate(reports) + report.run_display() + """ + + def __init__( + self, + baseline_label: str = "baseline", + variant_label: str = "variant", + metrics: tuple[str, ...] | list[str] = _DEFAULT_METRICS, + stats_test: str = "auto", + model: Optional[Model | str] = None, + system_prompt: Optional[str] = None, + name: Optional[str] = None, + ): + super().__init__(name=name or "SkillEvalAggregator") + if stats_test not in _VALID_TESTS: + raise ValueError( + f"stats_test must be one of {_VALID_TESTS}, got {stats_test!r}" + ) + self.baseline_label = baseline_label + self.variant_label = variant_label + self.metrics = list(metrics) + self.stats_test = stats_test + self.model = model + self.system_prompt = system_prompt or _SUMMARIZE_SYSTEM_PROMPT + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def aggregate(self, reports: list[EvaluationReport]) -> SkillEvalAggregationReport: + """Aggregate skill experiment reports into per-(case, evaluator) results. + + Args: + reports: Flat list of EvaluationReport objects from SkillEvalExperiment. + + Returns: + SkillEvalAggregationReport with .run_display() and .to_file() methods. + """ + if not reports: + return SkillEvalAggregationReport(aggregations=[]) + + grouped = self._group_results(reports) + + aggregations = [] + for (case_name, evaluator_name), entries in grouped.items(): + aggregation = self._build_aggregation(case_name, evaluator_name, entries) + aggregations.append(aggregation) + + aggregations.sort(key=lambda a: (a.group_key, a.evaluator_name)) + return SkillEvalAggregationReport(aggregations=aggregations) + + # ------------------------------------------------------------------ + # Internal grouping + # ------------------------------------------------------------------ + + def _group_results( + self, reports: list[EvaluationReport] + ) -> dict[tuple[str, str], list[dict]]: + """Group report entries by (original_case_name, evaluator_name). + + Each entry is a dict with: variant_label, trial_idx, score, passed, + reason, corrupted, efficiency, session_id, metadata. + """ + grouped: dict[tuple[str, str], list[dict]] = defaultdict(list) + + for report in reports: + evaluator_name = report.evaluator_name or "Unknown" + + for i, case_data in enumerate(report.cases): + metadata = case_data.get("metadata") or {} + variant_label = metadata.get("variant_label") + trial_idx = metadata.get("trial_idx") + if variant_label is None or trial_idx is None: + logger.debug( + "Skipping case without variant_label/trial_idx metadata: %s", + case_data.get("name"), + ) + continue + + original_name = metadata.get("original_case_name") or case_data.get("name", "") or "" + + grouped[(original_name, evaluator_name)].append({ + "variant_label": variant_label, + "trial_idx": trial_idx, + "score": report.scores[i] if i < len(report.scores) else 0.0, + "passed": report.test_passes[i] if i < len(report.test_passes) else False, + "reason": report.reasons[i] if i < len(report.reasons) else "", + "corrupted": bool(metadata.get("corrupted", False)), + "efficiency": metadata.get("efficiency") or {}, + "session_id": case_data.get("session_id", ""), + "metadata": metadata, + }) + + return grouped + + def _build_aggregation( + self, case_name: str, evaluator_name: str, entries: list[dict] + ) -> SkillEvalAggregation: + """Build a SkillEvalAggregation from grouped entries.""" + # Index entries by (variant_label, trial_idx) so we can pair them. + by_key: dict[tuple[str, int], dict] = { + (e["variant_label"], e["trial_idx"]): e for e in entries + } + + # Find paired trial indices — must have both baseline and variant. + baseline_indices = { + tidx for (vl, tidx) in by_key if vl == self.baseline_label + } + variant_indices = { + tidx for (vl, tidx) in by_key if vl == self.variant_label + } + paired_indices = sorted(baseline_indices & variant_indices) + + # Build pair rows, applying corruption filter. + rows: list[dict] = [] + n_corrupt_pairs = 0 + for tidx in paired_indices: + b = by_key[(self.baseline_label, tidx)] + v = by_key[(self.variant_label, tidx)] + if b["corrupted"] or v["corrupted"]: + n_corrupt_pairs += 1 + continue + rows.append({"trial_idx": tidx, "baseline": b, "variant": v}) + + n_total = len(paired_indices) + n_used = len(rows) + + # Compute paired stats per metric. + paired_stats: list[PairedComparisonStats] = [] + raw_baseline_values: dict[str, list[float]] = {} + raw_variant_values: dict[str, list[float]] = {} + + for metric in self.metrics: + b_vals, v_vals = self._collect_metric_pairs(rows, metric) + raw_baseline_values[metric] = b_vals + raw_variant_values[metric] = v_vals + stat = self._compute_paired_stat(metric, b_vals, v_vals, n_corrupt_pairs) + if stat is not None: + paired_stats.append(stat) + + # Compute the AggregationResult base stats from the variant-side scores. + variant_scores = [r["variant"]["score"] for r in rows] + variant_passes = [r["variant"]["passed"] for r in rows] + base_stats = self._compute_stats(variant_scores, variant_passes) + + # Summarize (LLM if configured, else empty). + if self.model is not None and paired_stats: + summary = self._summarize_for_aggregation( + case_name, evaluator_name, paired_stats + ) + else: + summary = "" + + return SkillEvalAggregation( + group_key=case_name, + evaluator_name=evaluator_name, + mean_score=base_stats["mean_score"], + min_score=base_stats["min_score"], + max_score=base_stats["max_score"], + pass_rate=base_stats["pass_rate"], + num_results=base_stats["num_results"], + num_passed=base_stats["num_passed"], + num_failed=base_stats["num_failed"], + summary=summary, + paired_stats=paired_stats, + raw_baseline_values=raw_baseline_values, + raw_variant_values=raw_variant_values, + n_total=n_total, + n_corrupted=n_corrupt_pairs, + n_used=n_used, + trajectory_pointers_baseline=[r["baseline"]["session_id"] for r in rows], + trajectory_pointers_variant=[r["variant"]["session_id"] for r in rows], + ) + + # ------------------------------------------------------------------ + # Metric extraction + # ------------------------------------------------------------------ + + @staticmethod + def _collect_metric_pairs( + rows: list[dict], metric: str + ) -> tuple[list[float], list[float]]: + """Pull baseline and variant values for one metric. + + pass_rate is special-cased to use the boolean `passed` field. + Other metrics are read from entry["efficiency"][metric]. Pairs where + either side is missing the metric are dropped from the returned lists. + """ + b_vals: list[float] = [] + v_vals: list[float] = [] + for r in rows: + b_entry = r["baseline"] + v_entry = r["variant"] + + if metric == "pass_rate": + b_vals.append(1.0 if b_entry["passed"] else 0.0) + v_vals.append(1.0 if v_entry["passed"] else 0.0) + continue + + b_val = b_entry["efficiency"].get(metric) + v_val = v_entry["efficiency"].get(metric) + if b_val is None or v_val is None: + continue + try: + b_vals.append(float(b_val)) + v_vals.append(float(v_val)) + except (TypeError, ValueError): + logger.warning( + "Non-numeric value for metric %r on trial %r; skipping.", + metric, + r["trial_idx"], + ) + + return b_vals, v_vals + + # ------------------------------------------------------------------ + # Paired statistics + # ------------------------------------------------------------------ + + def _compute_paired_stat( + self, + metric: str, + b_vals: list[float], + v_vals: list[float], + n_corrupted: int, + ) -> Optional[PairedComparisonStats]: + """Compute a paired comparison for one metric. + + Returns None if there is not enough data to compute the stat + (fewer than 2 paired observations). + """ + if len(b_vals) < 2 or len(v_vals) < 2 or len(b_vals) != len(v_vals): + return None + + b_arr = np.asarray(b_vals, dtype=float) + v_arr = np.asarray(v_vals, dtype=float) + deltas = v_arr - b_arr + + b_mean = float(b_arr.mean()) + v_mean = float(v_arr.mean()) + delta = v_mean - b_mean + delta_pct = (delta / b_mean) if abs(b_mean) > 1e-12 else None + + # Pick the test. + if metric == "pass_rate" or self.stats_test == "mcnemar": + test_used, p_value = self._mcnemar(b_arr, v_arr) + elif self.stats_test == "wilcoxon": + test_used, p_value = self._wilcoxon(deltas) + elif self.stats_test == "paired_t": + test_used, p_value = self._paired_t(b_arr, v_arr) + else: # auto + test_used, p_value = self._auto_test(b_arr, v_arr, deltas) + + ci_low, ci_high = self._bootstrap_ci(deltas) + + return PairedComparisonStats( + metric_name=metric, + baseline_mean=b_mean, + variant_mean=v_mean, + delta=delta, + delta_pct=delta_pct, + test_used=test_used, + p_value=p_value, + ci_low=ci_low, + ci_high=ci_high, + n_used=len(b_arr), + n_corrupted=n_corrupted, + ) + + @staticmethod + def _wilcoxon(deltas: np.ndarray) -> tuple[str, float]: + """Two-sided Wilcoxon signed-rank test on paired differences.""" + if np.all(deltas == 0): + return "wilcoxon", 1.0 + try: + result = sp_stats.wilcoxon(deltas, zero_method="wilcox", alternative="two-sided") + return "wilcoxon", float(result.pvalue) + except ValueError as e: + logger.warning("wilcoxon failed: %s", e) + return "wilcoxon", float("nan") + + @staticmethod + def _paired_t(b_arr: np.ndarray, v_arr: np.ndarray) -> tuple[str, float]: + """Paired t-test.""" + try: + result = sp_stats.ttest_rel(v_arr, b_arr) + return "paired_t", float(result.pvalue) + except ValueError as e: + logger.warning("paired_t failed: %s", e) + return "paired_t", float("nan") + + @classmethod + def _auto_test( + cls, b_arr: np.ndarray, v_arr: np.ndarray, deltas: np.ndarray + ) -> tuple[str, float]: + """Pick paired_t when deltas pass Shapiro-Wilk, else wilcoxon. + + For n < 3 or all-zero deltas, falls back to wilcoxon. + """ + n = len(deltas) + if n < 3 or np.all(deltas == 0): + return cls._wilcoxon(deltas) + try: + sw = sp_stats.shapiro(deltas) + if sw.pvalue > _NORMALITY_ALPHA: + return cls._paired_t(b_arr, v_arr) + except ValueError as e: + logger.debug("shapiro failed (n=%d): %s; falling back to wilcoxon", n, e) + return cls._wilcoxon(deltas) + + @staticmethod + def _mcnemar(b_arr: np.ndarray, v_arr: np.ndarray) -> tuple[str, float]: + """McNemar test for paired binary outcomes (0/1). + + Uses exact binomial when the number of discordant pairs is small + (< 25), and the asymptotic chi-square with continuity correction + otherwise. + """ + b_bin = (b_arr > 0.5).astype(int) + v_bin = (v_arr > 0.5).astype(int) + # b: baseline=1, variant=0 + # c: baseline=0, variant=1 + b = int(np.sum((b_bin == 1) & (v_bin == 0))) + c = int(np.sum((b_bin == 0) & (v_bin == 1))) + + n_disc = b + c + if n_disc == 0: + return "mcnemar", 1.0 + + if n_disc < 25: + # Exact two-sided binomial test on min(b, c) under H0: p = 0.5. + k = min(b, c) + try: + result = sp_stats.binomtest(k, n_disc, 0.5, alternative="two-sided") + return "mcnemar", float(result.pvalue) + except AttributeError: + # scipy < 1.7 — fall back to legacy binom_test + p = float(sp_stats.binom_test(k, n_disc, 0.5, alternative="two-sided")) + return "mcnemar", p + + # Asymptotic with continuity correction. + statistic = (abs(b - c) - 1) ** 2 / (b + c) + p_value = float(1.0 - sp_stats.chi2.cdf(statistic, df=1)) + return "mcnemar", p_value + + @staticmethod + def _bootstrap_ci(deltas: np.ndarray) -> tuple[float, float]: + """Bootstrap percentile CI on the mean delta. + + Returns (ci_low, ci_high) at the configured level (default 95%). + For n < 2 or all-equal deltas, returns a degenerate (mean, mean) CI. + """ + if len(deltas) < 2: + m = float(deltas.mean()) if len(deltas) else 0.0 + return m, m + if np.all(deltas == deltas[0]): + v = float(deltas[0]) + return v, v + try: + res = sp_stats.bootstrap( + (deltas,), + np.mean, + confidence_level=_CI_LEVEL, + n_resamples=_BOOTSTRAP_RESAMPLES, + method="percentile", + vectorized=True, + ) + return float(res.confidence_interval.low), float(res.confidence_interval.high) + except Exception as e: + logger.warning("bootstrap CI failed: %s", e) + m = float(deltas.mean()) + return m, m + + # ------------------------------------------------------------------ + # Summarization + # ------------------------------------------------------------------ + + def summarize_reasons(self, reasons: list[str]) -> str: + """Default-style summarization across raw reason strings. + + Only used when callers explicitly invoke summarize_reasons() with a + list of strings. The aggregator's own summarization is driven by + `_summarize_for_aggregation()` which has richer context. + """ + non_empty = [r for r in reasons if r] + if not non_empty or self.model is None: + return "" + prompt = ( + "Summarize the following per-trial evaluation reasons into a " + "concise 2-3 sentence summary:\n\n" + + "\n".join(f"- {r}" for r in non_empty[:20]) + ) + try: + agent = Agent( + model=self.model, system_prompt=self.system_prompt, callback_handler=None + ) + result = agent(prompt, structured_output_model=SkillSummary) + return cast(SkillSummary, result.structured_output).summary + except Exception as e: + logger.warning(f"LLM summarization failed: {e}") + return "" + + def _summarize_for_aggregation( + self, case_name: str, evaluator_name: str, paired_stats: list[PairedComparisonStats] + ) -> str: + """Produce an LLM summary of the paired stats for one (case, evaluator).""" + if self.model is None or not paired_stats: + return "" + + metric_lines = [] + for ps in paired_stats: + sig = "significant" if ps.p_value < 0.05 else "not significant" + metric_lines.append( + f" - {ps.metric_name}: baseline={ps.baseline_mean:.3f}, " + f"variant={ps.variant_mean:.3f}, Δ={ps.delta:+.3f} " + f"(p={ps.p_value:.3f} via {ps.test_used}, 95% CI=[{ps.ci_low:+.3f}, " + f"{ps.ci_high:+.3f}], {sig}, n={ps.n_used})" + ) + + prompt = ( + f"Task: {case_name}\n" + f"Evaluator: {evaluator_name}\n\n" + f"Paired metric comparisons (variant vs baseline):\n" + f"{chr(10).join(metric_lines)}\n\n" + "Summarize the variant's effect on this task." + ) + + try: + agent = Agent( + model=self.model, system_prompt=self.system_prompt, callback_handler=None + ) + result = agent(prompt, structured_output_model=SkillSummary) + return cast(SkillSummary, result.structured_output).summary + except Exception as e: + logger.warning( + f"LLM summarization failed for case '{case_name}', evaluator " + f"'{evaluator_name}': {e}" + ) + return "" diff --git a/src/strands_evals/skills/aggregator_types.py b/src/strands_evals/skills/aggregator_types.py new file mode 100644 index 00000000..0dec90c4 --- /dev/null +++ b/src/strands_evals/skills/aggregator_types.py @@ -0,0 +1,156 @@ +"""Data models for skills evaluation aggregation results.""" + +import json +from pathlib import Path +from typing import Optional + +from pydantic import BaseModel, Field + +from ..aggregators.types import AggregationResult + + +class PairedComparisonStats(BaseModel): + """One (case × metric) paired comparison: variant vs baseline. + + Attributes: + metric_name: Identifier for the metric (e.g. "pass_rate", "tokens", + "latency_s", "cost_usd"). + baseline_mean: Mean value of the metric under the baseline condition. + variant_mean: Mean value of the metric under the variant condition. + delta: Signed difference, variant_mean - baseline_mean. Positive means + the variant scored higher; whether that is "better" depends on the + metric (higher pass_rate is good, higher latency is bad). + delta_pct: delta expressed as a percentage of baseline_mean. None when + baseline_mean is zero or the percentage is not meaningful. + test_used: Statistical test identifier — one of "wilcoxon", + "paired_t", or "mcnemar". + p_value: Two-sided p-value from the paired test. + ci_low: Lower bound of the 95% confidence interval on the delta + (bootstrap percentile method). + ci_high: Upper bound of the 95% confidence interval on the delta. + n_used: Number of paired trials that contributed after corruption + filtering. + n_corrupted: Number of paired trials that were dropped because at + least one side was marked corrupted. + """ + + metric_name: str + baseline_mean: float + variant_mean: float + delta: float + delta_pct: Optional[float] = None + test_used: str + p_value: float + ci_low: float + ci_high: float + n_used: int + n_corrupted: int = 0 + + +class SkillEvalAggregation(AggregationResult): + """Aggregated results for one (case, evaluator) under skills evaluation. + + Extends AggregationResult with paired-comparison statistics, raw + per-condition values for downstream diagnostics, corruption accounting, + and trajectory pointers for drill-down. + """ + + # --- Paired statistics across metrics --- + paired_stats: list[PairedComparisonStats] = Field(default_factory=list) + + # --- Raw per-condition values for downstream diagnostics --- + # Key = metric_name, Value = list of per-trial values in trial_idx order. + raw_baseline_values: dict[str, list[float]] = Field(default_factory=dict) + raw_variant_values: dict[str, list[float]] = Field(default_factory=dict) + + # --- Corruption accounting (over paired trials) --- + n_total: int = 0 + n_corrupted: int = 0 + n_used: int = 0 + + # --- Trajectory pointers (session_ids) for drill-down --- + trajectory_pointers_baseline: list[str] = Field(default_factory=list) + trajectory_pointers_variant: list[str] = Field(default_factory=list) + + +class SkillEvalAggregationReport(BaseModel): + """Report containing all skills aggregation results. + + Provides .run_display(), .display(), .to_file() and .from_file() + matching the EvaluationReport / ChaosAggregationReport interfaces. + + Example:: + + aggregation_report = experiment.aggregate_evaluations() + aggregation_report.run_display() + aggregation_report.to_file("skill_eval_report.json") + """ + + aggregations: list[SkillEvalAggregation] = Field(default_factory=list) + + def run_display(self): + """Render the aggregation report interactively. + + Collapsed view shows one row per (case, evaluator) with Δ-metrics. + Expanding reveals full paired-statistics panels per metric. + """ + from .aggregation_display import display_skill_aggregation + + display_skill_aggregation(self.aggregations, static=False) + + def display(self): + """Render the report statically (non-interactive).""" + from .aggregation_display import display_skill_aggregation + + display_skill_aggregation(self.aggregations, static=True) + + def to_file(self, path: str): + """Write the aggregation report to a JSON file. + + Args: + path: File path. If no extension is provided, ".json" is added. + + Raises: + ValueError: If the path has a non-JSON extension. + """ + file_path = Path(path) + + if file_path.suffix: + if file_path.suffix != ".json": + raise ValueError( + f"Only .json format is supported. Got path with extension: {path}. " + f"Please use a .json extension or provide a path without an extension." + ) + else: + file_path = file_path.with_suffix(".json") + + file_path.parent.mkdir(parents=True, exist_ok=True) + + with open(file_path, "w", encoding="utf-8") as f: + json.dump(self.model_dump(), f, indent=2, ensure_ascii=False) + + @classmethod + def from_file(cls, path: str) -> "SkillEvalAggregationReport": + """Load an aggregation report from a JSON file. + + Args: + path: Path to the JSON file. + + Returns: + A SkillEvalAggregationReport instance. + + Raises: + ValueError: If the file does not have a .json extension. + """ + file_path = Path(path) + + if file_path.suffix != ".json": + raise ValueError( + f"Only .json format is supported. Got file: {path}. " + f"Please provide a path with .json extension." + ) + + with open(file_path, "r", encoding="utf-8") as f: + data = json.load(f) + + return cls.model_validate(data) diff --git a/src/strands_evals/skills/experiment.py b/src/strands_evals/skills/experiment.py new file mode 100644 index 00000000..b7770c9a --- /dev/null +++ b/src/strands_evals/skills/experiment.py @@ -0,0 +1,212 @@ +"""Skill Evaluation Experiment. + +Composes the base Experiment to run test cases across baseline / variant +conditions paired by trial_idx. Each (case × variant × trial_idx) becomes +one expanded case in the underlying Experiment. + +The task function receives the expanded Case. It can read +``case.metadata["variant_label"]`` to know which condition it is running +under, and ``case.metadata["trial_idx"]`` for the pair index. The aggregator +re-groups results back into (case, evaluator) pairs after the fact. + +Efficiency metrics (tokens, latency, cost) should be written by the task +function into ``case.metadata["efficiency"]``. The aggregator reads them +from there. Setting ``case.metadata["corrupted"] = True`` marks a trial as +crashed; the aggregator drops corrupted pairs before computing stats. +""" + +import logging +import uuid +from collections.abc import Callable +from typing import Any, Optional + +from ..case import Case +from ..evaluators.evaluator import Evaluator +from ..experiment import Experiment +from ..types.evaluation_report import EvaluationReport +from .aggregator import SkillEvalAggregator +from .aggregator_types import SkillEvalAggregationReport + +logger = logging.getLogger(__name__) + + +_DEFAULT_BASELINE_LABEL = "baseline" +_DEFAULT_VARIANT_LABEL = "variant" + + +class SkillEvalExperiment: + """Runs cases × variants × trials by composing the base Experiment. + + For each (case, variant_label, trial_idx) combination, creates an + expanded case tagged with metadata so the aggregator can recover pairs. + + Example:: + + from strands_evals.skills import ( + SkillEvalExperiment, + SkillEvalAggregator, + ) + + experiment = SkillEvalExperiment( + cases=test_cases, + variant_labels=["baseline", "variant"], + evaluators=[my_evaluator], + num_trials=30, + aggregator=SkillEvalAggregator(), + ) + + reports = experiment.run_evaluations(task=my_task) + agg_report = experiment.aggregate_evaluations() + agg_report.run_display() + + Args: + cases: Test cases to evaluate. Each case is expanded into + ``num_trials × len(variant_labels)`` runs. + variant_labels: Names of the conditions being compared. Defaults to + ``["baseline", "variant"]``. The aggregator uses these to find + pairs — the strings here must match ``baseline_label`` and + ``variant_label`` on the aggregator. + evaluators: Evaluators to assess results. + num_trials: Number of trial repetitions per (case × variant). + Defaults to 30. + aggregator: Optional SkillEvalAggregator. If provided, + aggregate_evaluations() can be called after run_evaluations(). + """ + + def __init__( + self, + cases: list[Case], + variant_labels: Optional[list[str]] = None, + evaluators: Optional[list[Evaluator]] = None, + num_trials: int = 30, + aggregator: Optional[SkillEvalAggregator] = None, + ): + if num_trials < 1: + raise ValueError(f"num_trials must be >= 1, got {num_trials}") + + self._original_cases = cases + self._variant_labels = variant_labels or [_DEFAULT_BASELINE_LABEL, _DEFAULT_VARIANT_LABEL] + if len(self._variant_labels) < 2: + raise ValueError( + f"variant_labels must contain at least two labels (baseline + variant), " + f"got {self._variant_labels!r}" + ) + self._evaluators = evaluators + self._num_trials = num_trials + self._aggregator = aggregator + self._last_reports: list[EvaluationReport] = [] + + # Build the expanded case list. + self._expanded_cases: list[Case] = [] + for case in cases: + for variant_label in self._variant_labels: + for trial_idx in range(num_trials): + session_id = str(uuid.uuid4()) + expanded_metadata: dict[str, Any] = dict(case.metadata or {}) + expanded_metadata.update({ + "original_case_name": case.name, + "variant_label": variant_label, + "trial_idx": trial_idx, + }) + expanded_name = ( + f"{case.name}|{variant_label}|{trial_idx}" + if case.name + else f"{variant_label}|{trial_idx}" + ) + expanded_case = case.model_copy( + update={ + "name": expanded_name, + "session_id": session_id, + "metadata": expanded_metadata, + } + ) + self._expanded_cases.append(expanded_case) + + # Internal Experiment with expanded cases. + self._experiment = Experiment( + cases=self._expanded_cases, + evaluators=evaluators, + ) + + # ------------------------------------------------------------------ + # Read-only views + # ------------------------------------------------------------------ + + @property + def cases(self) -> list[Case]: + """The original (unexpanded) test cases.""" + return self._original_cases + + @property + def variant_labels(self) -> list[str]: + """Variant labels configured for this experiment.""" + return list(self._variant_labels) + + @property + def num_trials(self) -> int: + """Trial repetitions per (case × variant).""" + return self._num_trials + + # ------------------------------------------------------------------ + # Execution + # ------------------------------------------------------------------ + + def run_evaluations( + self, task: Callable[[Case], Any], **kwargs + ) -> list[EvaluationReport]: + """Run evaluations across all (case × variant × trial) combinations. + + Args: + task: The task function to evaluate. Takes a Case and returns + output. Should set ``case.metadata["efficiency"]`` and may + set ``case.metadata["corrupted"] = True`` on crash. + **kwargs: Additional kwargs passed to base + Experiment.run_evaluations. + + Returns: + List of EvaluationReport objects covering all conditions. + """ + reports = self._experiment.run_evaluations(task, **kwargs) + self._last_reports = reports + + n_runs = len(self._original_cases) * len(self._variant_labels) * self._num_trials + logger.info( + f"Skill experiment complete: {len(reports)} reports " + f"({len(self._original_cases)} cases × {len(self._variant_labels)} variants " + f"× {self._num_trials} trials = {n_runs} runs)" + ) + return reports + + async def run_evaluations_async( + self, task: Callable[[Case], Any], max_workers: int = 10, **kwargs + ) -> list[EvaluationReport]: + """Run evaluations asynchronously.""" + reports = await self._experiment.run_evaluations_async( + task, max_workers=max_workers, **kwargs + ) + self._last_reports = reports + return reports + + def aggregate_evaluations(self) -> SkillEvalAggregationReport: + """Aggregate the last run's reports into a SkillEvalAggregationReport. + + Must be called after run_evaluations(). + + Returns: + SkillEvalAggregationReport with .run_display() and .to_file() + methods. + + Raises: + RuntimeError: If no aggregator was configured or + run_evaluations() hasn't been called. + """ + if self._aggregator is None: + raise RuntimeError( + "No aggregator configured. Pass aggregator=SkillEvalAggregator() " + "to SkillEvalExperiment.__init__()." + ) + if not self._last_reports: + raise RuntimeError( + "No evaluation reports available. Call run_evaluations() first." + ) + return self._aggregator.aggregate(self._last_reports) diff --git a/tests/skills/__init__.py b/tests/skills/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/skills/test_aggregator.py b/tests/skills/test_aggregator.py new file mode 100644 index 00000000..c8178429 --- /dev/null +++ b/tests/skills/test_aggregator.py @@ -0,0 +1,573 @@ +"""Tests for skills/aggregator.py. + +Covers paired statistics correctness against synthetic data, corruption +filtering, raw-value preservation, baseline pairing, and edge cases. +""" + +from __future__ import annotations + +import math +import random +from typing import Any + +import numpy as np +import pytest + +from strands_evals.skills.aggregator import ( + SkillEvalAggregator, +) +from strands_evals.skills.aggregator_types import ( + SkillEvalAggregationReport, +) +from strands_evals.types.evaluation_report import EvaluationReport + + +# --------------------------------------------------------------------------- +# Helpers — build synthetic EvaluationReport input +# --------------------------------------------------------------------------- + + +def _build_report( + evaluator_name: str, + rows: list[dict[str, Any]], +) -> EvaluationReport: + """Build a single EvaluationReport from row dicts. + + Each row is: + { + "case_name": str, # original case name + "variant_label": str, # "baseline" | "variant" + "trial_idx": int, + "score": float, + "passed": bool, + "reason": str, + "efficiency": dict[str, float] | None, + "corrupted": bool | None, + "session_id": str | None, + } + """ + cases = [] + scores = [] + passes = [] + reasons = [] + for r in rows: + metadata = { + "original_case_name": r["case_name"], + "variant_label": r["variant_label"], + "trial_idx": r["trial_idx"], + "efficiency": r.get("efficiency") or {}, + "corrupted": r.get("corrupted", False), + } + cases.append({ + "name": f"{r['case_name']}|{r['variant_label']}|{r['trial_idx']}", + "metadata": metadata, + "session_id": r.get("session_id", f"sess-{len(cases)}"), + }) + scores.append(r["score"]) + passes.append(r["passed"]) + reasons.append(r.get("reason", "")) + + overall = sum(scores) / len(scores) if scores else 0.0 + return EvaluationReport( + evaluator_name=evaluator_name, + overall_score=overall, + scores=scores, + cases=cases, + test_passes=passes, + reasons=reasons, + ) + + +def _make_paired_rows( + case_name: str, + n: int, + baseline_efficiency: dict[str, float], + variant_efficiency: dict[str, float], + baseline_passed: bool = True, + variant_passed: bool = True, + baseline_score: float = 1.0, + variant_score: float = 1.0, + corrupted_trial_indices: set[int] | None = None, + corruption_side: str = "both", # "baseline" | "variant" | "both" +) -> list[dict]: + """Build 2N rows for a single case: N baseline + N variant, all paired.""" + rows = [] + corrupted_trial_indices = corrupted_trial_indices or set() + for trial_idx in range(n): + b_corrupted = ( + trial_idx in corrupted_trial_indices and corruption_side in ("baseline", "both") + ) + v_corrupted = ( + trial_idx in corrupted_trial_indices and corruption_side in ("variant", "both") + ) + rows.append({ + "case_name": case_name, + "variant_label": "baseline", + "trial_idx": trial_idx, + "score": baseline_score, + "passed": baseline_passed, + "efficiency": dict(baseline_efficiency), + "corrupted": b_corrupted, + }) + rows.append({ + "case_name": case_name, + "variant_label": "variant", + "trial_idx": trial_idx, + "score": variant_score, + "passed": variant_passed, + "efficiency": dict(variant_efficiency), + "corrupted": v_corrupted, + }) + return rows + + +# --------------------------------------------------------------------------- +# Empty input / edge cases +# --------------------------------------------------------------------------- + + +def test_aggregate_empty_reports_returns_empty_report(): + agg = SkillEvalAggregator() + report = agg.aggregate([]) + assert isinstance(report, SkillEvalAggregationReport) + assert report.aggregations == [] + + +def test_aggregate_missing_metadata_skips_cases(): + """Cases without variant_label / trial_idx metadata are skipped.""" + report = EvaluationReport( + evaluator_name="GoalSuccessRateEvaluator", + overall_score=0.5, + scores=[0.5], + cases=[{"name": "untagged", "metadata": {}}], + test_passes=[False], + reasons=[""], + ) + agg = SkillEvalAggregator() + out = agg.aggregate([report]) + assert out.aggregations == [] + + +def test_aggregate_no_pairs_when_only_baseline_runs(): + """If every trial is baseline-only, there are zero paired observations.""" + rows = [ + { + "case_name": "task1", + "variant_label": "baseline", + "trial_idx": i, + "score": 1.0, + "passed": True, + "efficiency": {"tokens": 100.0}, + } + for i in range(5) + ] + report = _build_report("GoalSuccessRateEvaluator", rows) + agg = SkillEvalAggregator() + out = agg.aggregate([report]) + assert len(out.aggregations) == 1 + assert out.aggregations[0].n_total == 0 + assert out.aggregations[0].n_used == 0 + assert out.aggregations[0].paired_stats == [] + + +# --------------------------------------------------------------------------- +# Pairing +# --------------------------------------------------------------------------- + + +def test_aggregate_pairs_by_trial_idx(): + """Trials present on both sides are paired; orphans are dropped.""" + rows = [] + # 5 baseline trials, indices 0-4 + for i in range(5): + rows.append({ + "case_name": "task1", + "variant_label": "baseline", + "trial_idx": i, + "score": 1.0, + "passed": True, + "efficiency": {"tokens": 100.0}, + }) + # 4 variant trials, indices 0-3 (one orphan baseline at idx=4) + for i in range(4): + rows.append({ + "case_name": "task1", + "variant_label": "variant", + "trial_idx": i, + "score": 1.0, + "passed": True, + "efficiency": {"tokens": 90.0}, + }) + + report = _build_report("GoalSuccessRateEvaluator", rows) + agg = SkillEvalAggregator() + out = agg.aggregate([report]) + + assert len(out.aggregations) == 1 + a = out.aggregations[0] + assert a.n_total == 4 # 4 paired trials, orphan dropped + assert a.n_used == 4 + + +# --------------------------------------------------------------------------- +# Corruption filtering +# --------------------------------------------------------------------------- + + +def test_corrupted_pair_dropped_when_baseline_corrupted(): + rows = _make_paired_rows( + "task1", n=10, + baseline_efficiency={"tokens": 100.0}, + variant_efficiency={"tokens": 90.0}, + corrupted_trial_indices={2, 7}, + corruption_side="baseline", + ) + report = _build_report("GoalSuccessRateEvaluator", rows) + agg = SkillEvalAggregator() + out = agg.aggregate([report]) + a = out.aggregations[0] + assert a.n_total == 10 + assert a.n_corrupted == 2 + assert a.n_used == 8 + + +def test_corrupted_pair_dropped_when_variant_corrupted(): + rows = _make_paired_rows( + "task1", n=10, + baseline_efficiency={"tokens": 100.0}, + variant_efficiency={"tokens": 90.0}, + corrupted_trial_indices={1, 4, 9}, + corruption_side="variant", + ) + report = _build_report("GoalSuccessRateEvaluator", rows) + agg = SkillEvalAggregator() + out = agg.aggregate([report]) + a = out.aggregations[0] + assert a.n_corrupted == 3 + assert a.n_used == 7 + + +def test_all_corrupted_yields_no_pairs(): + rows = _make_paired_rows( + "task1", n=5, + baseline_efficiency={"tokens": 100.0}, + variant_efficiency={"tokens": 90.0}, + corrupted_trial_indices={0, 1, 2, 3, 4}, + ) + report = _build_report("GoalSuccessRateEvaluator", rows) + agg = SkillEvalAggregator() + out = agg.aggregate([report]) + a = out.aggregations[0] + assert a.n_total == 5 + assert a.n_corrupted == 5 + assert a.n_used == 0 + # No stats can be computed. + assert a.paired_stats == [] + + +# --------------------------------------------------------------------------- +# Raw values preservation +# --------------------------------------------------------------------------- + + +def test_raw_values_preserved_in_trial_idx_order(): + """raw_*_values should hold per-trial values in trial_idx order, post-filter.""" + rows = [] + # baseline tokens: [100, 110, 120, 130, 140] + # variant tokens: [ 90, 100, 105, 115, 130] + baseline_tokens = [100.0, 110.0, 120.0, 130.0, 140.0] + variant_tokens = [90.0, 100.0, 105.0, 115.0, 130.0] + for i in range(5): + rows.append({ + "case_name": "task1", "variant_label": "baseline", "trial_idx": i, + "score": 1.0, "passed": True, + "efficiency": {"tokens": baseline_tokens[i]}, + }) + rows.append({ + "case_name": "task1", "variant_label": "variant", "trial_idx": i, + "score": 1.0, "passed": True, + "efficiency": {"tokens": variant_tokens[i]}, + }) + + report = _build_report("GoalSuccessRateEvaluator", rows) + agg = SkillEvalAggregator(metrics=["pass_rate", "tokens"]) + out = agg.aggregate([report]) + a = out.aggregations[0] + assert a.raw_baseline_values["tokens"] == baseline_tokens + assert a.raw_variant_values["tokens"] == variant_tokens + + +def test_missing_efficiency_metric_skips_that_pair(): + """If a trial is missing a metric, that pair is excluded for that metric.""" + rows = [ + {"case_name": "t", "variant_label": "baseline", "trial_idx": 0, + "score": 1.0, "passed": True, "efficiency": {"tokens": 100.0}}, + {"case_name": "t", "variant_label": "variant", "trial_idx": 0, + "score": 1.0, "passed": True, "efficiency": {}}, # missing tokens + {"case_name": "t", "variant_label": "baseline", "trial_idx": 1, + "score": 1.0, "passed": True, "efficiency": {"tokens": 110.0}}, + {"case_name": "t", "variant_label": "variant", "trial_idx": 1, + "score": 1.0, "passed": True, "efficiency": {"tokens": 105.0}}, + ] + report = _build_report("GoalSuccessRateEvaluator", rows) + agg = SkillEvalAggregator(metrics=["tokens"]) + out = agg.aggregate([report]) + a = out.aggregations[0] + # Only one valid pair for tokens. + assert len(a.raw_baseline_values["tokens"]) == 1 + # But with only 1 pair, no stat is computed (needs n >= 2). + assert a.paired_stats == [] + + +# --------------------------------------------------------------------------- +# Paired statistics correctness +# --------------------------------------------------------------------------- + + +def test_wilcoxon_detects_systematic_improvement(): + """Variant systematically better — Wilcoxon should reject H0.""" + rng = random.Random(42) + n = 30 + rows = [] + for i in range(n): + b = 1000.0 + rng.gauss(0, 50) + v = b - 100 + rng.gauss(0, 20) # consistently lower + rows.append({ + "case_name": "t", "variant_label": "baseline", "trial_idx": i, + "score": 1.0, "passed": True, "efficiency": {"tokens": b}, + }) + rows.append({ + "case_name": "t", "variant_label": "variant", "trial_idx": i, + "score": 1.0, "passed": True, "efficiency": {"tokens": v}, + }) + report = _build_report("Eval", rows) + agg = SkillEvalAggregator(metrics=["tokens"], stats_test="wilcoxon") + out = agg.aggregate([report]) + ps = next(p for p in out.aggregations[0].paired_stats if p.metric_name == "tokens") + assert ps.test_used == "wilcoxon" + assert ps.p_value < 0.01 + assert ps.delta < 0 # variant uses fewer tokens + + +def test_paired_t_detects_systematic_improvement(): + rng = random.Random(42) + n = 30 + rows = [] + for i in range(n): + b = 2.0 + rng.gauss(0, 0.2) + v = b - 0.3 + rng.gauss(0, 0.05) + rows.append({ + "case_name": "t", "variant_label": "baseline", "trial_idx": i, + "score": 1.0, "passed": True, "efficiency": {"latency_s": b}, + }) + rows.append({ + "case_name": "t", "variant_label": "variant", "trial_idx": i, + "score": 1.0, "passed": True, "efficiency": {"latency_s": v}, + }) + report = _build_report("Eval", rows) + agg = SkillEvalAggregator(metrics=["latency_s"], stats_test="paired_t") + out = agg.aggregate([report]) + ps = next(p for p in out.aggregations[0].paired_stats if p.metric_name == "latency_s") + assert ps.test_used == "paired_t" + assert ps.p_value < 0.01 + assert ps.delta < 0 + + +def test_mcnemar_used_for_pass_rate_regardless_of_setting(): + """pass_rate always uses mcnemar even when stats_test='wilcoxon'.""" + rows = [] + # Variant flips 10 failures to passes; baseline has 10 failures, variant has 0. + for i in range(20): + rows.append({ + "case_name": "t", "variant_label": "baseline", "trial_idx": i, + "score": 0.0 if i < 10 else 1.0, + "passed": i >= 10, + "efficiency": {}, + }) + rows.append({ + "case_name": "t", "variant_label": "variant", "trial_idx": i, + "score": 1.0, "passed": True, "efficiency": {}, + }) + report = _build_report("Eval", rows) + agg = SkillEvalAggregator(metrics=["pass_rate"], stats_test="wilcoxon") + out = agg.aggregate([report]) + ps = out.aggregations[0].paired_stats[0] + assert ps.metric_name == "pass_rate" + assert ps.test_used == "mcnemar" + assert ps.p_value < 0.01 + + +def test_mcnemar_no_discordant_pairs_returns_p_1(): + """When every pair agrees, McNemar p-value is 1.0.""" + rows = [] + for i in range(10): + rows.append({ + "case_name": "t", "variant_label": "baseline", "trial_idx": i, + "score": 1.0, "passed": True, "efficiency": {}, + }) + rows.append({ + "case_name": "t", "variant_label": "variant", "trial_idx": i, + "score": 1.0, "passed": True, "efficiency": {}, + }) + report = _build_report("Eval", rows) + agg = SkillEvalAggregator(metrics=["pass_rate"]) + out = agg.aggregate([report]) + ps = out.aggregations[0].paired_stats[0] + assert ps.test_used == "mcnemar" + assert ps.p_value == 1.0 + assert ps.delta == 0.0 + + +def test_no_meaningful_difference_yields_high_p(): + """When variant ≈ baseline, p should not reject H0.""" + rng = random.Random(7) + rows = [] + for i in range(30): + b = 1000.0 + rng.gauss(0, 50) + v = 1000.0 + rng.gauss(0, 50) + rows.append({ + "case_name": "t", "variant_label": "baseline", "trial_idx": i, + "score": 1.0, "passed": True, "efficiency": {"tokens": b}, + }) + rows.append({ + "case_name": "t", "variant_label": "variant", "trial_idx": i, + "score": 1.0, "passed": True, "efficiency": {"tokens": v}, + }) + report = _build_report("Eval", rows) + agg = SkillEvalAggregator(metrics=["tokens"], stats_test="wilcoxon") + out = agg.aggregate([report]) + ps = out.aggregations[0].paired_stats[0] + assert ps.p_value > 0.05 + + +def test_auto_test_picks_paired_t_for_normal_deltas(): + """When deltas look normal, auto should select paired_t.""" + rng = np.random.default_rng(0) + n = 30 + deltas = rng.normal(loc=-0.3, scale=0.1, size=n) + rows = [] + for i in range(n): + b = 2.0 + v = float(b + deltas[i]) + rows.append({ + "case_name": "t", "variant_label": "baseline", "trial_idx": i, + "score": 1.0, "passed": True, "efficiency": {"latency_s": b}, + }) + rows.append({ + "case_name": "t", "variant_label": "variant", "trial_idx": i, + "score": 1.0, "passed": True, "efficiency": {"latency_s": v}, + }) + report = _build_report("Eval", rows) + agg = SkillEvalAggregator(metrics=["latency_s"], stats_test="auto") + out = agg.aggregate([report]) + ps = out.aggregations[0].paired_stats[0] + assert ps.test_used in {"paired_t", "wilcoxon"} + # With clearly-Gaussian deltas at this n, auto should usually pick paired_t. + # We don't enforce that hard because shapiro is noisy. + + +# --------------------------------------------------------------------------- +# Bootstrap CI +# --------------------------------------------------------------------------- + + +def test_bootstrap_ci_brackets_delta(): + """CI should contain the observed delta (mean of deltas).""" + rng = random.Random(99) + n = 50 + rows = [] + for i in range(n): + b = 100.0 + rng.gauss(0, 5) + v = b - 10 + rng.gauss(0, 2) + rows.append({ + "case_name": "t", "variant_label": "baseline", "trial_idx": i, + "score": 1.0, "passed": True, "efficiency": {"tokens": b}, + }) + rows.append({ + "case_name": "t", "variant_label": "variant", "trial_idx": i, + "score": 1.0, "passed": True, "efficiency": {"tokens": v}, + }) + report = _build_report("Eval", rows) + agg = SkillEvalAggregator(metrics=["tokens"]) + out = agg.aggregate([report]) + ps = out.aggregations[0].paired_stats[0] + # CI should bracket the delta (which is roughly -10). + assert ps.ci_low < ps.delta < ps.ci_high + # CI should be on the negative side (variant clearly better). + assert ps.ci_high < 0 + + +# --------------------------------------------------------------------------- +# Multi-case / multi-evaluator +# --------------------------------------------------------------------------- + + +def test_multiple_cases_yield_separate_aggregations(): + rows = ( + _make_paired_rows("case_a", n=5, + baseline_efficiency={"tokens": 100.0}, + variant_efficiency={"tokens": 90.0}) + + _make_paired_rows("case_b", n=5, + baseline_efficiency={"tokens": 200.0}, + variant_efficiency={"tokens": 250.0}) + ) + report = _build_report("Eval", rows) + agg = SkillEvalAggregator(metrics=["tokens"]) + out = agg.aggregate([report]) + assert len(out.aggregations) == 2 + keys = {a.group_key for a in out.aggregations} + assert keys == {"case_a", "case_b"} + + +def test_multiple_evaluators_yield_separate_aggregations(): + rows_a = _make_paired_rows("task1", n=5, + baseline_efficiency={"tokens": 100.0}, + variant_efficiency={"tokens": 90.0}) + report1 = _build_report("EvalA", rows_a) + report2 = _build_report("EvalB", rows_a) + agg = SkillEvalAggregator(metrics=["tokens"]) + out = agg.aggregate([report1, report2]) + assert len(out.aggregations) == 2 + evaluators = {a.evaluator_name for a in out.aggregations} + assert evaluators == {"EvalA", "EvalB"} + + +# --------------------------------------------------------------------------- +# Constructor validation +# --------------------------------------------------------------------------- + + +def test_invalid_stats_test_raises(): + with pytest.raises(ValueError, match="stats_test must be one of"): + SkillEvalAggregator(stats_test="bogus") + + +def test_summarize_reasons_returns_empty_when_no_model(): + agg = SkillEvalAggregator() # model is None + assert agg.summarize_reasons(["a", "b"]) == "" + + +# --------------------------------------------------------------------------- +# Trajectory pointers +# --------------------------------------------------------------------------- + + +def test_trajectory_pointers_collected_in_pair_order(): + rows = [] + for i in range(3): + rows.append({ + "case_name": "task", "variant_label": "baseline", "trial_idx": i, + "score": 1.0, "passed": True, "efficiency": {"tokens": 100.0}, + "session_id": f"b{i}", + }) + rows.append({ + "case_name": "task", "variant_label": "variant", "trial_idx": i, + "score": 1.0, "passed": True, "efficiency": {"tokens": 90.0}, + "session_id": f"v{i}", + }) + report = _build_report("Eval", rows) + agg = SkillEvalAggregator(metrics=["tokens"]) + out = agg.aggregate([report]) + a = out.aggregations[0] + assert a.trajectory_pointers_baseline == ["b0", "b1", "b2"] + assert a.trajectory_pointers_variant == ["v0", "v1", "v2"] diff --git a/tests/skills/test_experiment.py b/tests/skills/test_experiment.py new file mode 100644 index 00000000..367c9564 --- /dev/null +++ b/tests/skills/test_experiment.py @@ -0,0 +1,176 @@ +"""Tests for skills/experiment.py.""" + +import pytest + +from strands_evals.case import Case +from strands_evals.skills.aggregator import SkillEvalAggregator +from strands_evals.skills.experiment import SkillEvalExperiment + + +# --------------------------------------------------------------------------- +# Construction validation +# --------------------------------------------------------------------------- + + +def test_invalid_num_trials_raises(): + with pytest.raises(ValueError, match="num_trials must be >= 1"): + SkillEvalExperiment(cases=[], num_trials=0) + + +def test_insufficient_variant_labels_raises(): + with pytest.raises(ValueError, match="at least two labels"): + SkillEvalExperiment( + cases=[Case[str, str](name="t", input="x")], + variant_labels=["baseline_only"], + ) + + +# --------------------------------------------------------------------------- +# Case expansion +# --------------------------------------------------------------------------- + + +def test_case_expansion_count_matches_cases_times_variants_times_trials(): + cases = [ + Case[str, str](name="task_a", input="x"), + Case[str, str](name="task_b", input="y"), + ] + experiment = SkillEvalExperiment( + cases=cases, + variant_labels=["baseline", "variant"], + num_trials=5, + ) + # 2 cases × 2 variants × 5 trials = 20 + assert len(experiment._expanded_cases) == 20 + + +def test_expanded_cases_carry_correct_metadata(): + cases = [Case[str, str](name="task_a", input="x")] + experiment = SkillEvalExperiment( + cases=cases, + variant_labels=["baseline", "variant"], + num_trials=3, + ) + metas = [c.metadata for c in experiment._expanded_cases] + # All carry original_case_name = "task_a" + assert all(m["original_case_name"] == "task_a" for m in metas) + # variant_label values come from variant_labels list + variant_labels = sorted({m["variant_label"] for m in metas}) + assert variant_labels == ["baseline", "variant"] + # trial_idx covers 0..num_trials-1 for each variant + by_variant: dict[str, list[int]] = {"baseline": [], "variant": []} + for m in metas: + by_variant[m["variant_label"]].append(m["trial_idx"]) + assert sorted(by_variant["baseline"]) == [0, 1, 2] + assert sorted(by_variant["variant"]) == [0, 1, 2] + + +def test_expanded_case_names_include_variant_and_trial(): + cases = [Case[str, str](name="task_a", input="x")] + experiment = SkillEvalExperiment( + cases=cases, + variant_labels=["baseline", "variant"], + num_trials=2, + ) + names = sorted(c.name for c in experiment._expanded_cases) + assert names == [ + "task_a|baseline|0", + "task_a|baseline|1", + "task_a|variant|0", + "task_a|variant|1", + ] + + +def test_expanded_cases_get_unique_session_ids(): + cases = [Case[str, str](name="task_a", input="x")] + experiment = SkillEvalExperiment( + cases=cases, + variant_labels=["baseline", "variant"], + num_trials=4, + ) + session_ids = [c.session_id for c in experiment._expanded_cases] + assert len(set(session_ids)) == len(session_ids) + + +def test_user_metadata_preserved_in_expansion(): + """Pre-existing case.metadata fields survive expansion.""" + cases = [ + Case[str, str]( + name="task_a", input="x", metadata={"category": "math", "weight": 2.0} + ) + ] + experiment = SkillEvalExperiment( + cases=cases, + variant_labels=["baseline", "variant"], + num_trials=1, + ) + for c in experiment._expanded_cases: + assert c.metadata["category"] == "math" + assert c.metadata["weight"] == 2.0 + # And the expansion fields are added. + assert "variant_label" in c.metadata + assert "trial_idx" in c.metadata + + +def test_unnamed_case_expansion_uses_variant_trial_only(): + cases = [Case[str, str](input="x")] + experiment = SkillEvalExperiment( + cases=cases, + variant_labels=["baseline", "variant"], + num_trials=1, + ) + names = sorted(c.name for c in experiment._expanded_cases) + assert names == ["baseline|0", "variant|0"] + + +# --------------------------------------------------------------------------- +# Aggregation hookup +# --------------------------------------------------------------------------- + + +def test_aggregate_evaluations_without_aggregator_raises(): + experiment = SkillEvalExperiment( + cases=[Case[str, str](name="t", input="x")], + num_trials=1, + ) + with pytest.raises(RuntimeError, match="No aggregator configured"): + experiment.aggregate_evaluations() + + +def test_aggregate_evaluations_without_run_raises(): + experiment = SkillEvalExperiment( + cases=[Case[str, str](name="t", input="x")], + num_trials=1, + aggregator=SkillEvalAggregator(), + ) + with pytest.raises(RuntimeError, match="No evaluation reports available"): + experiment.aggregate_evaluations() + + +# --------------------------------------------------------------------------- +# Properties +# --------------------------------------------------------------------------- + + +def test_properties_expose_constructor_values(): + cases = [Case[str, str](name="t", input="x")] + experiment = SkillEvalExperiment( + cases=cases, + variant_labels=["a", "b", "c"], + num_trials=7, + ) + assert experiment.cases == cases + assert experiment.variant_labels == ["a", "b", "c"] + assert experiment.num_trials == 7 + + +def test_variant_labels_returns_a_copy(): + """Mutating the returned list must not affect the experiment.""" + experiment = SkillEvalExperiment( + cases=[Case[str, str](name="t", input="x")], + variant_labels=["a", "b"], + num_trials=1, + ) + labels = experiment.variant_labels + labels.append("c") + assert experiment.variant_labels == ["a", "b"] diff --git a/tests/skills/test_types.py b/tests/skills/test_types.py new file mode 100644 index 00000000..8b73f670 --- /dev/null +++ b/tests/skills/test_types.py @@ -0,0 +1,216 @@ +"""Tests for skills/aggregator_types.py.""" + +import json +from pathlib import Path + +import pytest + +from strands_evals.skills.aggregator_types import ( + PairedComparisonStats, + SkillEvalAggregation, + SkillEvalAggregationReport, +) + + +# --------------------------------------------------------------------------- +# PairedComparisonStats +# --------------------------------------------------------------------------- + + +def test_paired_comparison_stats_construction(): + ps = PairedComparisonStats( + metric_name="tokens", + baseline_mean=1000.0, + variant_mean=900.0, + delta=-100.0, + delta_pct=-0.1, + test_used="wilcoxon", + p_value=0.003, + ci_low=-150.0, + ci_high=-50.0, + n_used=30, + n_corrupted=2, + ) + assert ps.metric_name == "tokens" + assert ps.delta == -100.0 + assert ps.test_used == "wilcoxon" + + +def test_paired_comparison_stats_defaults(): + """delta_pct and n_corrupted default to sensible values.""" + ps = PairedComparisonStats( + metric_name="latency_s", + baseline_mean=2.0, + variant_mean=2.5, + delta=0.5, + test_used="paired_t", + p_value=0.04, + ci_low=0.1, + ci_high=0.9, + n_used=20, + ) + assert ps.delta_pct is None + assert ps.n_corrupted == 0 + + +# --------------------------------------------------------------------------- +# SkillEvalAggregation +# --------------------------------------------------------------------------- + + +def _make_aggregation(**overrides) -> SkillEvalAggregation: + base = { + "group_key": "order_flow", + "evaluator_name": "GoalSuccessRateEvaluator", + "mean_score": 0.85, + "min_score": 0.5, + "max_score": 1.0, + "pass_rate": 0.9, + "num_results": 30, + "num_passed": 27, + "num_failed": 3, + } + base.update(overrides) + return SkillEvalAggregation(**base) + + +def test_skill_aggregation_inherits_base_fields(): + """SkillEvalAggregation must carry all AggregationResult fields.""" + agg = _make_aggregation() + assert agg.group_key == "order_flow" + assert agg.mean_score == 0.85 + assert agg.pass_rate == 0.9 + # And its own fields with defaults. + assert agg.paired_stats == [] + assert agg.raw_baseline_values == {} + assert agg.n_total == 0 + assert agg.n_used == 0 + + +def test_skill_aggregation_with_paired_stats(): + ps = PairedComparisonStats( + metric_name="pass_rate", + baseline_mean=0.7, + variant_mean=0.88, + delta=0.18, + test_used="mcnemar", + p_value=0.003, + ci_low=0.08, + ci_high=0.28, + n_used=60, + n_corrupted=4, + ) + agg = _make_aggregation( + paired_stats=[ps], + n_total=64, + n_corrupted=4, + n_used=60, + raw_baseline_values={"pass_rate": [1.0, 0.0, 1.0]}, + raw_variant_values={"pass_rate": [1.0, 1.0, 1.0]}, + trajectory_pointers_baseline=["s1", "s2", "s3"], + trajectory_pointers_variant=["s4", "s5", "s6"], + ) + assert len(agg.paired_stats) == 1 + assert agg.paired_stats[0].metric_name == "pass_rate" + assert agg.n_total == 64 + assert agg.raw_baseline_values["pass_rate"] == [1.0, 0.0, 1.0] + + +# --------------------------------------------------------------------------- +# SkillEvalAggregationReport — serialization +# --------------------------------------------------------------------------- + + +def test_report_roundtrip_via_dict(): + """model_dump → model_validate produces an equivalent report.""" + agg = _make_aggregation( + paired_stats=[ + PairedComparisonStats( + metric_name="tokens", + baseline_mean=1000.0, + variant_mean=950.0, + delta=-50.0, + test_used="wilcoxon", + p_value=0.04, + ci_low=-90.0, + ci_high=-10.0, + n_used=28, + n_corrupted=2, + ) + ], + n_total=30, + n_corrupted=2, + n_used=28, + ) + report = SkillEvalAggregationReport(aggregations=[agg]) + data = report.model_dump() + restored = SkillEvalAggregationReport.model_validate(data) + assert len(restored.aggregations) == 1 + assert restored.aggregations[0].group_key == "order_flow" + assert restored.aggregations[0].paired_stats[0].metric_name == "tokens" + assert restored.aggregations[0].paired_stats[0].delta == -50.0 + + +def test_report_to_file_writes_json(tmp_path: Path): + agg = _make_aggregation() + report = SkillEvalAggregationReport(aggregations=[agg]) + target = tmp_path / "report.json" + report.to_file(str(target)) + assert target.exists() + data = json.loads(target.read_text()) + assert data["aggregations"][0]["group_key"] == "order_flow" + + +def test_report_to_file_appends_extension_when_missing(tmp_path: Path): + agg = _make_aggregation() + report = SkillEvalAggregationReport(aggregations=[agg]) + target = tmp_path / "report_no_ext" + report.to_file(str(target)) + assert (tmp_path / "report_no_ext.json").exists() + + +def test_report_to_file_rejects_non_json_extension(tmp_path: Path): + agg = _make_aggregation() + report = SkillEvalAggregationReport(aggregations=[agg]) + with pytest.raises(ValueError, match="Only .json format"): + report.to_file(str(tmp_path / "report.yaml")) + + +def test_report_from_file_rejects_non_json_extension(tmp_path: Path): + target = tmp_path / "report.yaml" + target.write_text("{}") + with pytest.raises(ValueError, match="Only .json format"): + SkillEvalAggregationReport.from_file(str(target)) + + +def test_report_from_file_roundtrip(tmp_path: Path): + agg = _make_aggregation( + paired_stats=[ + PairedComparisonStats( + metric_name="latency_s", + baseline_mean=2.0, + variant_mean=1.5, + delta=-0.5, + test_used="paired_t", + p_value=0.02, + ci_low=-0.8, + ci_high=-0.2, + n_used=25, + ) + ] + ) + report = SkillEvalAggregationReport(aggregations=[agg]) + target = tmp_path / "report.json" + report.to_file(str(target)) + + restored = SkillEvalAggregationReport.from_file(str(target)) + assert restored.aggregations[0].paired_stats[0].metric_name == "latency_s" + assert restored.aggregations[0].paired_stats[0].test_used == "paired_t" + + +def test_report_handles_empty_aggregations(tmp_path: Path): + report = SkillEvalAggregationReport(aggregations=[]) + target = tmp_path / "empty.json" + report.to_file(str(target)) + restored = SkillEvalAggregationReport.from_file(str(target)) + assert restored.aggregations == []