diff --git a/strands_robots/simulation/__init__.py b/strands_robots/simulation/__init__.py index c272008..e832a51 100644 --- a/strands_robots/simulation/__init__.py +++ b/strands_robots/simulation/__init__.py @@ -51,6 +51,19 @@ # Light imports (no heavy deps - stdlib + dataclasses only) from strands_robots.simulation.base import SimEngine +from strands_robots.simulation.benchmark import ( + BenchmarkCompatibilityError, + BenchmarkProtocol, + StepInfo, + get_benchmark, + list_benchmarks, + register_benchmark, + unregister_benchmark, +) +from strands_robots.simulation.benchmark_spec import ( + DeclarativeBenchmark, + register_benchmark_from_file, +) from strands_robots.simulation.factory import ( create_simulation, list_backends, @@ -71,6 +84,11 @@ SimWorld, TrajectoryStep, ) +from strands_robots.simulation.predicates import ( + PREDICATE_REGISTRY, + make_predicate, + register_predicate, +) # Heavy imports (lazy - need strands SDK + mujoco) _LAZY_IMPORTS: dict[str, tuple[str, str]] = { @@ -108,6 +126,20 @@ "resolve_urdf", "list_registered_urdfs", "list_available_models", + # Benchmark protocol + registry + "BenchmarkProtocol", + "BenchmarkCompatibilityError", + "StepInfo", + "register_benchmark", + "unregister_benchmark", + "get_benchmark", + "list_benchmarks", + # Declarative DSL + predicates + "DeclarativeBenchmark", + "register_benchmark_from_file", + "PREDICATE_REGISTRY", + "make_predicate", + "register_predicate", ] diff --git a/strands_robots/simulation/base.py b/strands_robots/simulation/base.py index 54f0071..bc6cb90 100644 --- a/strands_robots/simulation/base.py +++ b/strands_robots/simulation/base.py @@ -449,6 +449,187 @@ def eval_policy( success_fn=success_fn, ) + # Benchmark protocol facades + + def evaluate_benchmark( + self, + benchmark_name: str, + robot_name: str | None = None, + policy_provider: str = "mock", + policy_config: dict[str, Any] | None = None, + instruction: str = "", + n_episodes: int = 1, + seed: int | None = None, + ) -> dict[str, Any]: + """Run a registered :class:`BenchmarkProtocol` against the current sim. + + Benchmark-agnostic evaluation entry point. Looks up ``benchmark_name`` + in the global benchmark registry, validates robot compatibility, and + forwards to :meth:`PolicyRunner.evaluate` with the spec. + ``max_steps`` comes from the benchmark (not a parameter here). + + Args: + benchmark_name: Key from :func:`register_benchmark` / + :func:`register_benchmark_from_file`. + robot_name: Robot to evaluate. If ``None`` and the benchmark has + exactly one supported robot that matches a loaded robot, that + robot is picked; otherwise returns an error. + policy_provider: Policy provider name (forwarded to + :func:`create_policy`). + policy_config: Provider-specific kwargs. + instruction: Natural-language instruction for the policy. + n_episodes: Number of episodes. + seed: Master RNG seed for per-episode reproducibility. + + Returns: + Standard status dict. On success, carries per-episode cumulative + reward + aggregate success_rate / avg_reward / avg_steps in the + JSON payload. + """ + from strands_robots.policies import create_policy + from strands_robots.simulation.benchmark import get_benchmark + + spec = get_benchmark(benchmark_name) + if spec is None: + from strands_robots.simulation.benchmark import list_benchmarks as _list + + available = sorted(_list().keys()) + return { + "status": "error", + "content": [ + { + "text": ( + f"evaluate_benchmark: no benchmark registered under " + f"{benchmark_name!r}. Registered: {available}. 
" + "Call register_benchmark_from_file or register_benchmark first." + ) + } + ], + } + + robots = self.list_robots() + if not robots: + return {"status": "error", "content": [{"text": "No robots in sim. Add one first."}]} + + resolved_robot = robot_name + if not resolved_robot: + # Try to pick a robot. Prefer single-robot scenes; multi-robot + # scenes require explicit selection. + if len(robots) == 1: + resolved_robot = robots[0] + else: + return { + "status": "error", + "content": [ + { + "text": ( + f"evaluate_benchmark: 'robot_name' is required when the sim has " + f"multiple robots. Loaded: {robots}" + ) + } + ], + } + if resolved_robot not in robots: + return { + "status": "error", + "content": [{"text": f"Robot '{resolved_robot}' not found. Loaded: {robots}"}], + } + + policy = create_policy(policy_provider, **(policy_config or {})) + policy.set_robot_state_keys(self.robot_joint_names(resolved_robot)) + + return PolicyRunner(self).evaluate( + resolved_robot, + policy, + instruction=instruction, + n_episodes=n_episodes, + spec=spec, + seed=seed, + ) + + def list_benchmarks(self) -> dict[str, Any]: + """Enumerate registered benchmarks. + + Returns a standard status dict whose JSON payload contains the + :func:`~strands_robots.simulation.benchmark.list_benchmarks` + metadata snapshot. Safe to call from any backend; the registry is + engine-agnostic. + """ + from strands_robots.simulation.benchmark import list_benchmarks as _list + + snapshot = _list() + if not snapshot: + text = "No benchmarks registered. Use register_benchmark_from_file to add one." + else: + lines = [f"Registered benchmarks ({len(snapshot)}):"] + for name, meta in snapshot.items(): + lines.append( + f" • {name}: {meta['class']} " + f"(robots={meta['supported_robots'] or 'any'}, " + f"default={meta['default_robot']}, " + f"max_steps={meta['max_steps']})" + ) + text = "\n".join(lines) + return { + "status": "success", + "content": [{"text": text}, {"json": {"benchmarks": snapshot}}], + } + + def register_benchmark_from_file( + self, + benchmark_name: str, + spec_path: str, + ) -> dict[str, Any]: + """Load a declarative benchmark spec from disk and register it. + + Wraps :func:`strands_robots.simulation.benchmark_spec.register_benchmark_from_file` + so agents can author benchmarks as YAML / JSON at runtime. Parsing + errors surface as structured error dicts rather than exceptions. + """ + from strands_robots.simulation.benchmark_spec import ( + register_benchmark_from_file as _register, + ) + + if not benchmark_name: + return { + "status": "error", + "content": [{"text": "register_benchmark_from_file: 'benchmark_name' must be non-empty."}], + } + if not spec_path: + return { + "status": "error", + "content": [{"text": "register_benchmark_from_file: 'spec_path' must be non-empty."}], + } + try: + benchmark = _register(benchmark_name, spec_path) + except FileNotFoundError as e: + return {"status": "error", "content": [{"text": f"register_benchmark_from_file: {e}"}]} + except ValueError as e: + return {"status": "error", "content": [{"text": f"register_benchmark_from_file: {e}"}]} + except ImportError as e: + # YAML support requires pyyaml; surface the install hint verbatim. 
+ return {"status": "error", "content": [{"text": f"{e}"}]} + except Exception as e: # noqa: BLE001 - defensive catch-all with clear message + return { + "status": "error", + "content": [{"text": f"register_benchmark_from_file: unexpected error: {e}"}], + } + + return { + "status": "success", + "content": [ + { + "text": ( + f"📋 Registered benchmark '{benchmark_name}' from {spec_path}\n" + f" class: {type(benchmark).__name__}\n" + f" supported_robots: {benchmark.supported_robots or 'any'}\n" + f" default_robot: {benchmark.default_robot}\n" + f" max_steps: {benchmark.max_steps}" + ) + } + ], + } + def _make_run_policy_hook(self, robot_name: str, instruction: str) -> Any: """Override to return an ``on_frame(step, obs, action)`` callable. diff --git a/strands_robots/simulation/benchmark.py b/strands_robots/simulation/benchmark.py new file mode 100644 index 0000000..243eac7 --- /dev/null +++ b/strands_robots/simulation/benchmark.py @@ -0,0 +1,307 @@ +"""Benchmark-agnostic evaluation protocol for any ``SimEngine``. + +Every standard benchmark (LIBERO, Meta-World, RoboSuite, ManiSkill, user-authored +tasks) has a different notion of "what a task is" - sparse-success, dense-reward, +procedural scenes, BDDL predicates, hardcoded robots, etc. The correct abstraction +is the protocol the eval loop calls into, not a benchmark-specific schema. + +:class:`BenchmarkProtocol` is that protocol. Each adapter implements a handful of +lifecycle hooks (``on_episode_start``, ``on_step``, ``is_success``, ``is_failure``) +and declares the robots it is compatible with. The evaluation loop +(:meth:`~strands_robots.simulation.policy_runner.PolicyRunner.evaluate`) drives +the protocol without knowing anything about the underlying benchmark. + +Adapters live in optional extras (``strands-robots[benchmark-libero]`` etc.); +the core package stays dependency-free. A reference :class:`DeclarativeBenchmark` +shipped in :mod:`strands_robots.simulation.benchmark_spec` turns a YAML/JSON +spec into a fully functional ``BenchmarkProtocol`` instance - LLMs can author +and register benchmarks at runtime without writing Python code. + +Registry: a module-level ``dict[str, BenchmarkProtocol]`` keyed by name, +mirroring the shape of :func:`~strands_robots.simulation.model_registry.register_urdf`. +Registration is idempotent-by-overwrite: re-registering the same name replaces +the previous entry and logs a warning. This matches how users iterate on a +spec file during development. + +Thread safety: the registry is guarded by an internal lock so concurrent +registrations from agent threads do not race. The benchmark instances +themselves are expected to be immutable after registration - adapters that +keep per-episode state MUST put it on the ``rng``-scoped call, not on ``self``. +""" + +from __future__ import annotations + +import logging +import random +import threading +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from strands_robots.simulation.base import SimEngine + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class StepInfo: + """Per-step feedback from a :class:`BenchmarkProtocol`. + + Returned by :meth:`BenchmarkProtocol.on_step`. The evaluation loop + accumulates ``reward`` across steps and terminates the episode early + when ``done`` is ``True`` *or* when :meth:`BenchmarkProtocol.is_success` + / :meth:`BenchmarkProtocol.is_failure` fires. + + Attributes: + reward: Dense reward for this step. 
Sparse-success benchmarks + return ``0.0`` on every step that isn't a terminal success. + done: Early-termination flag - set when the benchmark knows the + episode is over (e.g. the object fell off the table and + nothing further will happen). + info: Free-form metadata propagated into the per-episode result + under the ``info`` key. Safe for small scalars / diagnostics - + do NOT stuff large tensors here. + """ + + reward: float = 0.0 + done: bool = False + info: dict[str, Any] = field(default_factory=dict) + + +class BenchmarkProtocol(ABC): + """Protocol every benchmark (LIBERO, Meta-World, custom) implements. + + Subclass this, declare :attr:`supported_robots` + :attr:`default_robot`, + and implement :meth:`on_step` + :meth:`is_success`. The default + :meth:`on_episode_start` validates robot compatibility and auto-loads + :attr:`default_robot` when the sim is empty, which is the right behaviour + for 90% of adapters; override only if you need per-episode scene setup + beyond what :meth:`SimEngine.reset` provides. + + Robot compatibility is first-class metadata. LIBERO's BDDL and scene + files reference Panda body names, Meta-World hardcodes Sawyer, RoboSuite + parameterizes over a fixed robot list - without declaring which robots a + benchmark accepts, agents will silently evaluate with the wrong robot. + :meth:`~strands_robots.simulation.policy_runner.PolicyRunner.evaluate` + validates the sim's robot against :attr:`supported_robots` before episode + 1 and returns a structured error on mismatch. + + Attributes: + max_steps: Per-episode horizon. Instance attribute (not abstract + property) so subclasses can set it in ``__init__`` or as a class + attribute. Defaults to ``300``. + """ + + max_steps: int = 300 + + # Robot compatibility (first-class metadata) + + @property + @abstractmethod + def supported_robots(self) -> list[str]: + """Registry ``data_config`` names this benchmark accepts. + + Empty list means "any robot" (unusual; dense-reward benchmarks rarely + generalise across embodiments). LIBERO-shaped adapters should return + a closed list of Panda variants; Meta-World should return its Sawyer + variants; etc. + """ + + @property + @abstractmethod + def default_robot(self) -> str: + """Robot :meth:`on_episode_start` loads when the sim is empty. + + Must be an element of :attr:`supported_robots` (or any compatible + registry name when ``supported_robots`` is empty). Declared + separately from ``supported_robots[0]`` so multi-robot benchmarks + can be explicit about their canonical default. + """ + + # Lifecycle hooks + + def on_episode_start(self, sim: SimEngine, rng: random.Random) -> None: + """Per-episode init. Called after ``sim.reset()`` and before the first obs. + + Default implementation enforces robot compatibility: + + * If the sim has no robots, add :attr:`default_robot` via + ``sim.add_robot(name="robot", data_config=default_robot)``. + * Otherwise, validate that every loaded robot's ``data_config`` is + in :attr:`supported_robots` (when non-empty). Mismatches raise + :class:`BenchmarkCompatibilityError` - the eval loop catches that + and returns a structured error with the allowed list. + + Override to layer on per-episode randomization, goal sampling, or + procedural scene generation. Always call ``super().on_episode_start`` + first unless you deliberately want to skip compatibility checks. + + Args: + sim: The engine being driven. + rng: Seeded per-episode RNG. Always use this - don't create your + own ``random.Random()`` or seeding will be non-reproducible. 
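+
+        Example - a hypothetical subclass layering goal randomization on
+        top of the default checks (``set_body_pose`` is an assumed backend
+        helper, not part of the ``SimEngine`` ABC - adapt to your engine)::
+
+            def on_episode_start(self, sim, rng):
+                super().on_episode_start(sim, rng)  # compat checks + auto-load
+                x = rng.uniform(0.3, 0.6)  # reproducible given the episode seed
+                sim.set_body_pose("goal", position=[x, 0.0, 0.1])  # assumed API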
+ """ + robots = sim.list_robots() + if not robots: + sim.add_robot(name="robot", data_config=self.default_robot) + return + + # Validate all loaded robots against supported_robots + supported = self.supported_robots + if not supported: # empty list means "any robot" + return + + # data_config lookup is backend-specific; MuJoCo stores it on SimRobot. + # Reach into sim._world if available (cheap duck-typing); otherwise skip + # the check rather than false-positive error. Adapters needing stricter + # checks should override on_episode_start. + world = getattr(sim, "_world", None) + if world is None or not hasattr(world, "robots"): + return + for rname in robots: + robot_obj = world.robots.get(rname) + if robot_obj is None: + continue + data_config = getattr(robot_obj, "data_config", None) + if data_config is None or data_config in supported: + continue + raise BenchmarkCompatibilityError( + robot_name=rname, + data_config=data_config, + supported=supported, + ) + + @abstractmethod + def on_step(self, sim: SimEngine, obs: dict[str, Any], action: dict[str, Any]) -> StepInfo: + """Return dense reward + done flag + info dict for this step. + + Called after every ``sim.send_action(action)`` with the observation + that produced the action and the action itself. Sparse-success + benchmarks return ``StepInfo(reward=0.0, done=False)`` on every + non-terminal step. + """ + + @abstractmethod + def is_success(self, sim: SimEngine) -> bool: + """Terminal success predicate. + + Called every step after :meth:`on_step`. Returning ``True`` ends + the episode with ``success=True``. Must be side-effect-free: the + evaluation loop may call this multiple times per step depending on + how backends batch success / failure checks. + """ + + def is_failure(self, sim: SimEngine) -> bool: + """Optional early-termination failure condition. + + Default: always ``False``. Override to end an episode early without + marking success (e.g. the arm self-collided, the object fell off + the table, the agent picked the wrong object). Failure ends the + episode; it does not count as a success. + """ + return False + + +class BenchmarkCompatibilityError(ValueError): + """Raised when a benchmark's robot compatibility check fails. + + Carries enough context (robot name, loaded data_config, supported list) + for the eval loop to produce an actionable structured error. Subclasses + :class:`ValueError` so code that uses broad ``except ValueError`` still + catches it cleanly. + """ + + def __init__(self, robot_name: str, data_config: str, supported: list[str]): + self.robot_name = robot_name + self.data_config = data_config + self.supported = list(supported) + super().__init__( + f"Robot '{robot_name}' (data_config={data_config!r}) is not compatible " + f"with this benchmark. Supported: {self.supported}" + ) + + +# Registry + +# Module-level registry - mirrors model_registry._URDF_REGISTRY. Mutable dict +# plus an RLock for thread safety; registry ops are cheap so we do not shard. +_BENCHMARK_REGISTRY: dict[str, BenchmarkProtocol] = {} +_REGISTRY_LOCK = threading.RLock() + + +def register_benchmark(name: str, benchmark: BenchmarkProtocol) -> None: + """Register a :class:`BenchmarkProtocol` under ``name``. + + Idempotent-by-overwrite: re-registering the same name replaces the + previous entry and logs a warning. This matches how users iterate on a + spec file during development. + + Args: + name: String key. Must be non-empty; any other validation is up to + the caller (lowercase / underscores / hyphens are all fine). 
+ benchmark: An instantiated :class:`BenchmarkProtocol` subclass. + + Raises: + TypeError: If ``benchmark`` is not a :class:`BenchmarkProtocol`. + ValueError: If ``name`` is empty. + """ + if not name or not isinstance(name, str): + raise ValueError(f"register_benchmark: name must be a non-empty string, got {name!r}") + if not isinstance(benchmark, BenchmarkProtocol): + raise TypeError(f"register_benchmark: expected BenchmarkProtocol instance, got {type(benchmark).__name__}") + with _REGISTRY_LOCK: + if name in _BENCHMARK_REGISTRY: + logger.warning("Overwriting existing benchmark registration: %s", name) + _BENCHMARK_REGISTRY[name] = benchmark + logger.info("📋 Registered benchmark '%s' (%s)", name, type(benchmark).__name__) + + +def unregister_benchmark(name: str) -> BenchmarkProtocol | None: + """Remove a benchmark from the registry. + + Returns the removed benchmark or ``None`` if it was not registered. + Primarily used by tests for cleanup; user code is rarely expected to + unregister benchmarks at runtime. + """ + with _REGISTRY_LOCK: + return _BENCHMARK_REGISTRY.pop(name, None) + + +def get_benchmark(name: str) -> BenchmarkProtocol | None: + """Return the registered benchmark or ``None`` if not found.""" + with _REGISTRY_LOCK: + return _BENCHMARK_REGISTRY.get(name) + + +def list_benchmarks() -> dict[str, dict[str, Any]]: + """Enumerate registered benchmarks with their metadata. + + Returns a shallow-copy snapshot keyed by name. Each value is a dict + with ``class``, ``supported_robots``, ``default_robot``, ``max_steps`` + - enough for an LLM to pick an appropriate benchmark without + instantiating one. Reads a snapshot under the registry lock so a + concurrent registration does not corrupt the returned dict. + """ + with _REGISTRY_LOCK: + snapshot = dict(_BENCHMARK_REGISTRY) + return { + name: { + "class": type(bench).__name__, + "supported_robots": list(bench.supported_robots), + "default_robot": bench.default_robot, + "max_steps": bench.max_steps, + } + for name, bench in snapshot.items() + } + + +__all__ = [ + "BenchmarkCompatibilityError", + "BenchmarkProtocol", + "StepInfo", + "get_benchmark", + "list_benchmarks", + "register_benchmark", + "unregister_benchmark", +] diff --git a/strands_robots/simulation/benchmark_spec.py b/strands_robots/simulation/benchmark_spec.py new file mode 100644 index 0000000..adbd3af --- /dev/null +++ b/strands_robots/simulation/benchmark_spec.py @@ -0,0 +1,386 @@ +"""Declarative benchmark specs loaded from YAML / JSON files. + +This module is the LLM-facing surface for authoring benchmarks without +writing Python. A spec file declares scene, success predicate, failure +predicate, and dense reward terms using the named-predicate DSL from +:mod:`strands_robots.simulation.predicates`. Nothing in a spec ever reaches +``eval`` / ``exec`` - predicates are looked up in a closed registry and +kwargs are forwarded as-is, so spec files are safe to load from untrusted +input. + +Spec schema (top-level keys):: + + name: string # required + max_steps: int # default 300 + supported_robots: list[str] # default [] (any) + default_robot: string # required - registry data_config + scene: string # optional MJCF/URDF path for sim.load_scene() + success: + all: [, ...] # all must be true + any: [, ...] # at least one must be true + failure: + all: [, ...] + any: [, ...] + dense_reward: [, ...] 
# summed per step + +A ```` is a dict with a ``predicate`` key naming the +predicate and any remaining keys forwarded as kwargs:: + + {predicate: body_above_z, body: cube, z: 0.2} + +Example:: + + name: drawer-open + max_steps: 300 + supported_robots: [panda] + default_robot: panda + success: + all: + - {predicate: joint_above, joint: drawer_slide, value: 0.15} + failure: + any: + - {predicate: body_below_z, body: gripper, z: -0.1} + dense_reward: + - {predicate: distance_neg, body_a: gripper, body_b: drawer_handle, weight: 1.0} + - {predicate: joint_progress, joint: drawer_slide, target: 0.2, weight: 5.0} + +Load + register via :func:`register_benchmark_from_file`; agents call this +through the ``register_benchmark_from_file`` tool action. + +YAML files require ``pyyaml`` - not a core dep. JSON works out of the box. +The loader autodetects format by extension. +""" + +from __future__ import annotations + +import json +import logging +from collections.abc import Callable +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from strands_robots.simulation.benchmark import ( + BenchmarkProtocol, + StepInfo, + register_benchmark, +) +from strands_robots.simulation.predicates import make_predicate +from strands_robots.utils import require_optional + +if TYPE_CHECKING: + import random + + from strands_robots.simulation.base import SimEngine + +logger = logging.getLogger(__name__) + +# Canonical top-level keys allowed in a spec. Anything else is a user error +# and produces a clear message rather than silently being ignored. +_ALLOWED_TOP_LEVEL = frozenset( + { + "name", + "max_steps", + "supported_robots", + "default_robot", + "scene", + "success", + "failure", + "dense_reward", + } +) + + +def _compile_bool_group( + clause: dict[str, Any] | None, + *, + default: bool, + context: str, +) -> Callable[[SimEngine], bool]: + """Compile an ``{"all": [...], "any": [...]}`` bool group into a single callable. + + * ``None`` / missing → returns a function always returning ``default``. + * ``all``: every listed predicate must be true. + * ``any``: at least one predicate must be true. + * Both: both conditions must hold (all AND any). + + Args: + clause: The ``success`` / ``failure`` dict from the spec. + default: Value returned when the clause is absent (``False`` for + success → "never succeeds", ``False`` for failure → "never + fails"; both are reasonable). + context: Name for error messages (``"success"`` or ``"failure"``). + + Raises: + ValueError: If the clause shape is wrong. 
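+
+    Example (sketch - the predicate name comes from the built-in registry;
+    the joint name and ``sim`` are illustrative)::
+
+        check = _compile_bool_group(
+            {"all": [{"predicate": "joint_above", "joint": "lift", "value": 0.1}]},
+            default=False,
+            context="success",
+        )
+        check(sim)  # True only when every 'all' entry holds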
+ """ + if clause is None: + return lambda _sim: default + if not isinstance(clause, dict): + raise ValueError(f"{context}: expected a dict with 'all' / 'any' keys, got {type(clause).__name__}") + + unknown = set(clause.keys()) - {"all", "any"} + if unknown: + raise ValueError(f"{context}: unknown keys {sorted(unknown)}; allowed: ['all', 'any']") + + all_calls = [_compile_call(c, context=f"{context}.all") for c in (clause.get("all") or [])] + any_calls = [_compile_call(c, context=f"{context}.any") for c in (clause.get("any") or [])] + + if not all_calls and not any_calls: + return lambda _sim: default + + def check(sim: SimEngine) -> bool: + if all_calls and not all(bool(p(sim)) for p in all_calls): + return False + if any_calls and not any(bool(p(sim)) for p in any_calls): + return False + return True + + return check + + +def _compile_call(entry: Any, *, context: str) -> Callable[[SimEngine], Any]: + """Compile one ``{predicate: , **kwargs}`` entry to a callable.""" + if not isinstance(entry, dict): + raise ValueError(f"{context}: expected a dict like {{predicate: , ...}}, got {type(entry).__name__}") + pred_name = entry.get("predicate") + if not isinstance(pred_name, str) or not pred_name: + raise ValueError(f"{context}: each entry must have a non-empty 'predicate' string") + kwargs = {k: v for k, v in entry.items() if k != "predicate"} + try: + return make_predicate(pred_name, **kwargs) + except ValueError: + # Unknown predicate; surface verbatim (already carries the valid list). + raise + except TypeError as e: + # Bad kwargs; wrap so the caller knows which predicate failed to compile. + raise ValueError(f"{context}: predicate '{pred_name}' compilation failed: {e}") from e + + +def _compile_reward_terms(terms: list[Any] | None) -> list[Callable[[SimEngine], float]]: + if terms is None: + return [] + if not isinstance(terms, list): + raise ValueError(f"dense_reward: expected a list, got {type(terms).__name__}") + compiled: list[Callable[[SimEngine], float]] = [] + for i, t in enumerate(terms): + term = _compile_call(t, context=f"dense_reward[{i}]") + compiled.append(term) + return compiled + + +class DeclarativeBenchmark(BenchmarkProtocol): + """:class:`BenchmarkProtocol` backed by a compiled DSL spec. + + Use :func:`register_benchmark_from_file` or + :meth:`DeclarativeBenchmark.from_dict` to construct one - direct + instantiation is only for tests / internal use. + + Thread safety: the compiled predicate closures capture only the spec + kwargs (ints, floats, strings, lists of floats) so instances are safe + to share across threads. The evaluation loop still drives each episode + sequentially; we do not batch episodes. 
+ """ + + def __init__( + self, + *, + name: str, + supported_robots: list[str], + default_robot: str, + max_steps: int, + success_fn: Callable[[SimEngine], bool], + failure_fn: Callable[[SimEngine], bool], + reward_terms: list[Callable[[SimEngine], float]], + scene: str | None = None, + ): + self._name = name + self._supported_robots = list(supported_robots) + self._default_robot = default_robot + self.max_steps = int(max_steps) + self._success_fn = success_fn + self._failure_fn = failure_fn + self._reward_terms = list(reward_terms) + self._scene = scene + + @property + def name(self) -> str: + return self._name + + @property + def supported_robots(self) -> list[str]: + return list(self._supported_robots) + + @property + def default_robot(self) -> str: + return self._default_robot + + def on_episode_start(self, sim: SimEngine, rng: random.Random) -> None: + """Load the declared scene (if any) before delegating to the base impl. + + The base impl adds :attr:`default_robot` when the sim is empty and + validates robot compatibility. Scene loading happens *before* that so + a scene-declared robot is detected by the compatibility check. + """ + if self._scene: + load_scene = getattr(sim, "load_scene", None) + if load_scene is None: + logger.warning( + "DeclarativeBenchmark '%s' declares scene=%r but sim has no load_scene()", + self._name, + self._scene, + ) + else: + result = load_scene(self._scene) + if isinstance(result, dict) and result.get("status") == "error": + msg = (result.get("content") or [{}])[0].get("text", "") + raise RuntimeError( + f"DeclarativeBenchmark '{self._name}': load_scene({self._scene!r}) failed: {msg}" + ) + super().on_episode_start(sim, rng) + + def on_step(self, sim: SimEngine, obs: dict[str, Any], action: dict[str, Any]) -> StepInfo: + """Sum all registered reward terms; ``done`` is False (handled by is_success/is_failure).""" + reward = 0.0 + for term in self._reward_terms: + try: + reward += float(term(sim)) + except Exception as e: # noqa: BLE001 - defensive: one bad term shouldn't kill the episode + logger.warning("reward term failed in '%s': %s", self._name, e) + return StepInfo(reward=reward, done=False) + + def is_success(self, sim: SimEngine) -> bool: + return bool(self._success_fn(sim)) + + def is_failure(self, sim: SimEngine) -> bool: + return bool(self._failure_fn(sim)) + + @classmethod + def from_dict(cls, spec: dict[str, Any]) -> DeclarativeBenchmark: + """Compile a spec dict (already parsed from YAML/JSON) into a benchmark.""" + if not isinstance(spec, dict): + raise ValueError(f"spec must be a dict, got {type(spec).__name__}") + + unknown = set(spec.keys()) - _ALLOWED_TOP_LEVEL + if unknown: + raise ValueError( + f"Unknown top-level keys in spec: {sorted(unknown)}. 
Allowed: {sorted(_ALLOWED_TOP_LEVEL)}" + ) + + name = spec.get("name") + if not isinstance(name, str) or not name: + raise ValueError("spec.name: required non-empty string") + + default_robot = spec.get("default_robot") + if not isinstance(default_robot, str) or not default_robot: + raise ValueError("spec.default_robot: required non-empty string") + + supported_robots = spec.get("supported_robots", []) + if not isinstance(supported_robots, list) or not all(isinstance(r, str) for r in supported_robots): + raise ValueError("spec.supported_robots: must be a list of strings") + + # default_robot should be in supported_robots (unless list is empty = any) + if supported_robots and default_robot not in supported_robots: + raise ValueError( + f"spec.default_robot={default_robot!r} not in supported_robots={supported_robots}; " + "either add it to supported_robots or leave supported_robots empty for any-robot benchmarks" + ) + + max_steps_raw = spec.get("max_steps", 300) + if not isinstance(max_steps_raw, int) or isinstance(max_steps_raw, bool) or max_steps_raw <= 0: + raise ValueError(f"spec.max_steps: must be a positive int, got {max_steps_raw!r}") + + scene = spec.get("scene") + if scene is not None and not isinstance(scene, str): + raise ValueError(f"spec.scene: must be a string path or omitted, got {type(scene).__name__}") + + success_fn = _compile_bool_group(spec.get("success"), default=False, context="success") + failure_fn = _compile_bool_group(spec.get("failure"), default=False, context="failure") + reward_terms = _compile_reward_terms(spec.get("dense_reward")) + + return cls( + name=name, + supported_robots=supported_robots, + default_robot=default_robot, + max_steps=max_steps_raw, + success_fn=success_fn, + failure_fn=failure_fn, + reward_terms=reward_terms, + scene=scene, + ) + + +def _load_spec_file(path: str | Path) -> dict[str, Any]: + """Parse a spec file by extension. JSON via stdlib, YAML via ``pyyaml`` (optional). + + Return type is declared as ``dict[str, Any]`` but ``json.loads`` / + ``yaml.safe_load`` may produce lists, strings, etc. Caller + (``register_benchmark_from_file``) validates the parsed shape before + passing it to :meth:`DeclarativeBenchmark.from_dict`; we do the + ``isinstance`` check here so the returned value is actually a dict. + """ + p = Path(path) + if not p.exists(): + raise FileNotFoundError(f"Benchmark spec file not found: {path}") + if not p.is_file(): + raise ValueError(f"Benchmark spec path is not a file: {path}") + + suffix = p.suffix.lower() + text = p.read_text() + + parsed: Any + if suffix == ".json": + parsed = json.loads(text) + elif suffix in (".yaml", ".yml"): + yaml = require_optional( + "yaml", + pip_install="pyyaml", + purpose="YAML benchmark spec loading", + ) + parsed = yaml.safe_load(text) # type: ignore[attr-defined] + else: + raise ValueError(f"Unsupported spec file extension: {suffix!r}. Use .json, .yaml, or .yml.") + + if not isinstance(parsed, dict): + raise ValueError(f"Benchmark spec {path} must parse to a dict, got {type(parsed).__name__}") + return parsed + + +def register_benchmark_from_file( + name: str, + spec_path: str | Path, +) -> BenchmarkProtocol: + """Load a declarative benchmark spec from disk and register it under ``name``. + + Convenience wrapper that: + + 1. Parses ``spec_path`` (JSON or YAML, autodetected by extension). + 2. Compiles it into a :class:`DeclarativeBenchmark`. + 3. Registers it via :func:`register_benchmark`. + 4. Returns the instantiated benchmark for programmatic use. + + Args: + name: Registry key. 
Overrides any ``name`` declared inside the spec + (so the same spec file can be registered under multiple names). + spec_path: Path to a ``.json`` / ``.yaml`` / ``.yml`` file. + + Returns: + The registered :class:`DeclarativeBenchmark` instance. + + Raises: + FileNotFoundError / ValueError: From :func:`_load_spec_file`. + ValueError: From :meth:`DeclarativeBenchmark.from_dict` on bad schema. + """ + if not isinstance(name, str) or not name: + raise ValueError(f"register_benchmark_from_file: name must be a non-empty string, got {name!r}") + spec_dict = _load_spec_file(spec_path) + # Spec-internal name is informational; the registry name always wins. + spec_dict.setdefault("name", name) + benchmark = DeclarativeBenchmark.from_dict(spec_dict) + register_benchmark(name, benchmark) + return benchmark + + +__all__ = [ + "DeclarativeBenchmark", + "register_benchmark_from_file", +] diff --git a/strands_robots/simulation/mujoco/tool_spec.json b/strands_robots/simulation/mujoco/tool_spec.json index 099ce9a..cb07311 100644 --- a/strands_robots/simulation/mujoco/tool_spec.json +++ b/strands_robots/simulation/mujoco/tool_spec.json @@ -65,7 +65,10 @@ "render_all", "start_cameras_recording", "stop_cameras_recording", - "get_cameras_recording_status" + "get_cameras_recording_status", + "list_benchmarks", + "register_benchmark_from_file", + "evaluate_benchmark" ] }, "scene_path": { @@ -364,6 +367,14 @@ "items": { "type": "object" } + }, + "benchmark_name": { + "type": "string", + "description": "Name of a registered BenchmarkProtocol (for register_benchmark_from_file / evaluate_benchmark)" + }, + "spec_path": { + "type": "string", + "description": "Path to a declarative benchmark spec file (.json, .yaml, .yml). Used by register_benchmark_from_file." } }, "required": [ diff --git a/strands_robots/simulation/policy_runner.py b/strands_robots/simulation/policy_runner.py index 4968a7e..0452d53 100644 --- a/strands_robots/simulation/policy_runner.py +++ b/strands_robots/simulation/policy_runner.py @@ -33,6 +33,7 @@ import logging import os +import random import time from collections.abc import Callable from dataclasses import dataclass @@ -46,6 +47,7 @@ if TYPE_CHECKING: from strands_robots.policies.base import Policy from strands_robots.simulation.base import SimEngine + from strands_robots.simulation.benchmark import BenchmarkProtocol from strands_robots.simulation.models import TrajectoryStep @@ -512,26 +514,67 @@ def evaluate( n_episodes: int = 10, max_steps: int = 300, success_fn: SuccessFn | str | None = None, + spec: BenchmarkProtocol | None = None, + seed: int | None = None, ) -> dict[str, Any]: """Evaluate ``policy`` for ``n_episodes`` episodes. + Two evaluation paths: + + * **``spec=``** (preferred): drive a full :class:`BenchmarkProtocol`. + Per-episode seeded RNG, ``on_episode_start`` / ``on_step`` / + ``is_success`` / ``is_failure`` hooks, cumulative dense reward, + robot-compatibility validation. ``max_steps`` from the spec wins. + * **``success_fn=``**: legacy sparse-success path kept for + backwards compatibility with PR #85. Equivalent to a + ``BenchmarkProtocol`` whose ``on_step`` always returns + ``StepInfo(reward=0.0, done=False)``. + + Passing both ``spec`` and ``success_fn`` is an error - benchmarks + define their own success predicate. + Args: robot_name: Robot to evaluate. policy: Already-constructed ``Policy`` instance. instruction: Instruction forwarded to the policy. n_episodes: Number of reset → rollout episodes. - max_steps: Cap per episode. 
- success_fn: Either - - * ``None`` - never succeeds (dry run / performance probe). - * ``"contact"`` - success when ``sim.get_contacts()`` reports - any penetrating contact. Requires backend to implement - ``get_contacts``; falls back to ``False`` otherwise. - * callable ``(observation) -> bool``. + max_steps: Cap per episode. Ignored when ``spec`` is provided + (``spec.max_steps`` wins). + success_fn: Legacy success predicate (see above). + spec: :class:`BenchmarkProtocol` to drive the eval. When + provided, overrides the ``success_fn`` path. + seed: Master RNG seed. Each episode derives a child RNG from it, + so evaluations are reproducible within a process. Only used + when ``spec`` is provided. Returns: - Standard status dict with ``success_rate``, per-episode results. + Standard status dict. When ``spec`` is used, the JSON payload + also contains ``cumulative_reward`` and ``avg_reward`` fields + per episode and aggregate. """ + if spec is not None and success_fn is not None: + return { + "status": "error", + "content": [ + { + "text": ( + "evaluate() accepts either 'spec' or 'success_fn', not both. " + "'spec' defines its own success predicate." + ) + } + ], + } + + if spec is not None: + return self._evaluate_with_spec( + robot_name, + policy, + spec, + instruction=instruction, + n_episodes=n_episodes, + seed=seed, + ) + try: resolved_check = self._resolve_success_fn(success_fn) except ValueError as e: @@ -593,6 +636,156 @@ def evaluate( ], } + def _evaluate_with_spec( + self, + robot_name: str, + policy: Policy, + spec: BenchmarkProtocol, + *, + instruction: str, + n_episodes: int, + seed: int | None, + ) -> dict[str, Any]: + """Drive a :class:`BenchmarkProtocol` for ``n_episodes`` episodes. + + Split out from :meth:`evaluate` to keep the legacy-path body small; + both routes share the same return-dict schema plus the spec route + layers on cumulative-reward accounting. + + Robot compatibility is validated before episode 1: if the sim's + loaded robot declares a ``data_config`` not in + ``spec.supported_robots`` (non-empty), we return a structured error + with the allowed list instead of silently running a mismatched + evaluation. + """ + # Lazy import to avoid circular reference (benchmark module imports + # `SimEngine` from base which imports this module under TYPE_CHECKING). + from strands_robots.simulation.benchmark import BenchmarkCompatibilityError + + # T26: skip camera rendering when the policy does not need images. + _skip_images = not getattr(policy, "requires_images", True) + master_rng = random.Random(seed) + spec_name = type(spec).__name__ + max_steps = spec.max_steps + results: list[dict[str, Any]] = [] + + for ep in range(n_episodes): + self.sim.reset() + # Per-episode seeded RNG - deterministic given the master seed + # and the episode index. + episode_seed = master_rng.randint(0, 2**31 - 1) + episode_rng = random.Random(episode_seed) + + try: + spec.on_episode_start(self.sim, episode_rng) + except BenchmarkCompatibilityError as e: + # Surface the structured error with the supported list - + # agents can fix this without retrying. + return { + "status": "error", + "content": [ + { + "text": ( + f"Benchmark compatibility error: robot '{e.robot_name}' " + f"has data_config={e.data_config!r}, but benchmark " + f"{spec_name} supports {e.supported}." 
+ ) + } + ], + } + except Exception as e: # noqa: BLE001 - surface as structured error + logger.exception("on_episode_start failed") + return { + "status": "error", + "content": [{"text": f"on_episode_start failed in {spec_name}: {e}"}], + } + + success = False + failure = False + steps = 0 + cumulative_reward = 0.0 + last_info: dict[str, Any] = {} + + for _ in range(max_steps): + observation = self.sim.get_observation(robot_name=robot_name, skip_images=_skip_images) + coro_or_result = policy.get_actions(observation, instruction) + actions = _resolve_coroutine(coro_or_result) + + if actions: + action_applied: dict[str, Any] = dict(actions[0]) + self.sim.send_action(action_applied, robot_name=robot_name) + else: + # Degenerate policy - advance physics so loop terminates. + action_applied = {} + self.sim.step(n_steps=1) + + steps += 1 + try: + info = spec.on_step(self.sim, observation, action_applied) + except Exception as e: # noqa: BLE001 + logger.exception("on_step failed in %s", spec_name) + return { + "status": "error", + "content": [{"text": f"on_step failed in {spec_name}: {e}"}], + } + cumulative_reward += float(info.reward) + last_info = dict(info.info) if info.info else {} + + if info.done: + break + if spec.is_failure(self.sim): + failure = True + break + if spec.is_success(self.sim): + success = True + break + + results.append( + { + "episode": ep, + "steps": steps, + "success": success, + "failure": failure, + "cumulative_reward": round(cumulative_reward, 4), + "seed": episode_seed, + "info": last_info, + } + ) + + n_success = sum(1 for r in results if r["success"]) + n_failure = sum(1 for r in results if r["failure"]) + success_rate = n_success / max(n_episodes, 1) + avg_steps = sum(r["steps"] for r in results) / max(n_episodes, 1) + avg_reward = sum(r["cumulative_reward"] for r in results) / max(n_episodes, 1) + + return { + "status": "success", + "content": [ + { + "text": ( + f"📊 Benchmark: {spec_name} | policy {type(policy).__name__} on '{robot_name}'\n" + f"Episodes: {n_episodes} | Success: {n_success} | Failure: {n_failure} " + f"({success_rate:.1%} success)\n" + f"Avg reward: {avg_reward:.2f} | Avg steps: {avg_steps:.0f}/{max_steps}" + ) + }, + { + "json": { + "success_rate": round(success_rate, 4), + "n_episodes": n_episodes, + "n_success": n_success, + "n_failure": n_failure, + "avg_steps": round(avg_steps, 1), + "avg_reward": round(avg_reward, 4), + "max_steps": max_steps, + "seed": seed, + "benchmark_class": spec_name, + "episodes": results, + } + }, + ], + } + # Helpers def _maybe_sim_time(self) -> float | None: diff --git a/strands_robots/simulation/predicates.py b/strands_robots/simulation/predicates.py new file mode 100644 index 0000000..65a9af5 --- /dev/null +++ b/strands_robots/simulation/predicates.py @@ -0,0 +1,373 @@ +"""Named-predicate library for declarative :class:`BenchmarkProtocol` specs. + +Each entry in :data:`PREDICATE_REGISTRY` is a factory ``(**kwargs) -> callable`` +where the returned callable takes a :class:`SimEngine` and returns either +``bool`` (for success/failure predicates) or ``float`` (for reward terms). + +The registry is a closed set - the YAML/JSON loader in +:mod:`strands_robots.simulation.benchmark_spec` refuses predicates whose +name is not in this registry, so spec files are safe to parse from +untrusted / LLM-authored input. **No ``eval`` is ever called.** User-defined +predicates must be registered programmatically via :func:`register_predicate` +before loading the spec. 
+ +Predicates are backend-aware but not backend-specific: they exclusively call +``SimEngine`` methods (abstract) or probe for MuJoCo-only methods via +``getattr`` and return a safe fallback (``False`` / ``0.0``) when the +backend does not support them. A predicate that silently evaluates to +``False`` because of an unimplemented backend call is a bug in the +predicate, not the benchmark - file an issue. + +Available predicates (bool): + + body_above_z(body, z) + body_below_z(body, z) + joint_above(joint, value) + joint_below(joint, value) + distance_less_than(body_a, body_b, threshold) + inside_region(body, min, max) + contact_between(geom_a, geom_b) + contact_any() + +Available reward terms (float): + + distance_neg(body_a, body_b, weight=1.0) + joint_progress(joint, target, weight=1.0) + constant(value) + +Register custom predicates with :func:`register_predicate`. +""" + +from __future__ import annotations + +import logging +from collections.abc import Callable +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from strands_robots.simulation.base import SimEngine + +logger = logging.getLogger(__name__) + +BoolPredicate = Callable[["SimEngine"], bool] +RewardTerm = Callable[["SimEngine"], float] +PredicateFactory = Callable[..., Callable[["SimEngine"], Any]] + + +# Helpers for digging values out of the structured ``{"status", "content"}`` +# dicts that MuJoCo-backend methods return. Defensive against empty content +# lists and missing keys - predicates should never crash the eval loop. + + +def _extract_json(result: dict[str, Any] | None) -> dict[str, Any]: + """Return the ``json`` content block payload, or ``{}`` if absent.""" + if not isinstance(result, dict): + return {} + for block in result.get("content", []) or []: + if isinstance(block, dict): + payload = block.get("json") + if isinstance(payload, dict): + # dict[str, Any] by construction of the content schema; mypy can't + # narrow through dict.get() so we cast via a new dict to keep it typed. + return dict(payload) + return {} + + +def _body_position(sim: SimEngine, body: str) -> list[float] | None: + """Best-effort body-position lookup. Returns ``None`` on any failure. + + Requires the backend to implement ``get_body_state`` (MuJoCo only at time + of writing). Future backends can add the same method signature - see + :meth:`strands_robots.simulation.mujoco.physics.PhysicsMixin.get_body_state`. + """ + get_body_state = getattr(sim, "get_body_state", None) + if get_body_state is None: + return None + try: + result = get_body_state(body_name=body) + except Exception as e: # noqa: BLE001 - defensive: predicates never raise + logger.debug("body_position(%r) failed: %s", body, e) + return None + if not isinstance(result, dict) or result.get("status") != "success": + return None + payload = _extract_json(result) + pos = payload.get("position") + if isinstance(pos, list) and len(pos) == 3 and all(isinstance(c, (int, float)) for c in pos): + return [float(c) for c in pos] + return None + + +def _joint_position(sim: SimEngine, joint: str) -> float | None: + """Best-effort joint-position lookup via ``get_observation``. + + ``get_observation`` is on the ABC and returns ``{: float}``. + When the joint is absent from the observation dict (wrong robot, wrong + namespace) we return ``None`` so predicates can decide between ``False`` + and an explicit error path. 
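+
+    Sketch of the intended use (joint name illustrative)::
+
+        q = _joint_position(sim, "drawer_slide")
+        is_open = q is not None and q > 0.15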
+ """ + try: + obs = sim.get_observation(skip_images=True) + except Exception as e: # noqa: BLE001 - defensive + logger.debug("get_observation() failed: %s", e) + return None + if not isinstance(obs, dict): + return None + val = obs.get(joint) + if isinstance(val, (int, float)) and not isinstance(val, bool): + return float(val) + return None + + +def _euclidean_distance(a: list[float], b: list[float]) -> float: + """Simple 3D Euclidean distance; no numpy so predicates stay dependency-free.""" + dx = a[0] - b[0] + dy = a[1] - b[1] + dz = a[2] - b[2] + return float((dx * dx + dy * dy + dz * dz) ** 0.5) + + +# Predicate factories + + +def _body_above_z(body: str, z: float) -> BoolPredicate: + def check(sim: SimEngine) -> bool: + pos = _body_position(sim, body) + return pos is not None and pos[2] > float(z) + + return check + + +def _body_below_z(body: str, z: float) -> BoolPredicate: + def check(sim: SimEngine) -> bool: + pos = _body_position(sim, body) + return pos is not None and pos[2] < float(z) + + return check + + +def _joint_above(joint: str, value: float) -> BoolPredicate: + def check(sim: SimEngine) -> bool: + q = _joint_position(sim, joint) + return q is not None and q > float(value) + + return check + + +def _joint_below(joint: str, value: float) -> BoolPredicate: + def check(sim: SimEngine) -> bool: + q = _joint_position(sim, joint) + return q is not None and q < float(value) + + return check + + +def _distance_less_than(body_a: str, body_b: str, threshold: float) -> BoolPredicate: + def check(sim: SimEngine) -> bool: + pos_a = _body_position(sim, body_a) + pos_b = _body_position(sim, body_b) + if pos_a is None or pos_b is None: + return False + return _euclidean_distance(pos_a, pos_b) < float(threshold) + + return check + + +def _inside_region(body: str, min: list[float], max: list[float]) -> BoolPredicate: # noqa: A002 - DSL keyword + if not (isinstance(min, list) and len(min) == 3 and isinstance(max, list) and len(max) == 3): + raise ValueError("inside_region: 'min' and 'max' must each be a list of 3 numbers") + lo = [float(c) for c in min] + hi = [float(c) for c in max] + if any(lo[i] > hi[i] for i in range(3)): + raise ValueError(f"inside_region: 'min' {lo} must be component-wise <= 'max' {hi}") + + def check(sim: SimEngine) -> bool: + pos = _body_position(sim, body) + if pos is None: + return False + return all(lo[i] <= pos[i] <= hi[i] for i in range(3)) + + return check + + +def _contact_between(geom_a: str, geom_b: str) -> BoolPredicate: + """Pairwise contact predicate. + + Requires ``get_contacts()`` (MuJoCo). Ignores contact ordering - a contact + reported as ``(geom_a, geom_b)`` matches the same predicate as + ``(geom_b, geom_a)``. 
+ """ + + def check(sim: SimEngine) -> bool: + get_contacts = getattr(sim, "get_contacts", None) + if get_contacts is None: + return False + try: + result = get_contacts() + except Exception as e: # noqa: BLE001 - defensive + logger.debug("contact_between(%r,%r) failed: %s", geom_a, geom_b, e) + return False + payload = _extract_json(result) + contacts = payload.get("contacts") + if not isinstance(contacts, list): + return False + want = {geom_a, geom_b} + for c in contacts: + if not isinstance(c, dict): + continue + pair = {c.get("geom1"), c.get("geom2")} + if want <= pair: + return True + return False + + return check + + +def _contact_any() -> BoolPredicate: + """Sparse "any contact" predicate - matches the legacy ``success_fn='contact'`` path.""" + + def check(sim: SimEngine) -> bool: + get_contacts = getattr(sim, "get_contacts", None) + if get_contacts is None: + return False + try: + result = get_contacts() + except Exception as e: # noqa: BLE001 - defensive + logger.debug("contact_any() failed: %s", e) + return False + payload = _extract_json(result) + if payload.get("n_contacts", 0) > 0: + return True + contacts = payload.get("contacts") + return bool(isinstance(contacts, list) and contacts) + + return check + + +# Reward terms (float-valued) + + +def _distance_neg(body_a: str, body_b: str, weight: float = 1.0) -> RewardTerm: + """Negative Euclidean distance between two bodies, weighted. + + The canonical "reach" reward: ``weight * -dist(a, b)``. Monotonic in + the distance, so naive policy improvement pulls the bodies together. + """ + w = float(weight) + + def term(sim: SimEngine) -> float: + pos_a = _body_position(sim, body_a) + pos_b = _body_position(sim, body_b) + if pos_a is None or pos_b is None: + return 0.0 + return -w * _euclidean_distance(pos_a, pos_b) + + return term + + +def _joint_progress(joint: str, target: float, weight: float = 1.0) -> RewardTerm: + """Negative absolute distance from a joint to its target, weighted. + + Useful for drawer/door tasks where success is "joint near target + position" and you want dense signal during training. + """ + w = float(weight) + t = float(target) + + def term(sim: SimEngine) -> float: + q = _joint_position(sim, joint) + if q is None: + return 0.0 + return -w * abs(q - t) + + return term + + +def _constant(value: float) -> RewardTerm: + """Constant reward per step. Useful for shaping a survival bonus.""" + v = float(value) + + def term(_sim: SimEngine) -> float: + return v + + return term + + +# Registry + +PREDICATE_REGISTRY: dict[str, PredicateFactory] = { + # bool-valued + "body_above_z": _body_above_z, + "body_below_z": _body_below_z, + "joint_above": _joint_above, + "joint_below": _joint_below, + "distance_less_than": _distance_less_than, + "inside_region": _inside_region, + "contact_between": _contact_between, + "contact_any": _contact_any, + # float-valued + "distance_neg": _distance_neg, + "joint_progress": _joint_progress, + "constant": _constant, +} + + +def register_predicate(name: str, factory: PredicateFactory) -> None: + """Register a user-defined predicate factory. + + Must be called before loading a spec that references ``name``. Factories + registered at runtime are NOT sandboxed - by registering, you opt into + running the factory with kwargs parsed from the spec. Only register + predicates from trusted code paths; anything LLM-authored should use the + built-in DSL exclusively. + + Args: + name: Predicate name used in spec files. Must not shadow a built-in. 
+ factory: Callable that takes DSL kwargs and returns a predicate + ``(sim) -> bool`` or reward term ``(sim) -> float``. + + Raises: + ValueError: If ``name`` shadows a built-in predicate. + TypeError: If ``factory`` is not callable. + """ + if name in PREDICATE_REGISTRY: + raise ValueError(f"register_predicate: '{name}' shadows a built-in predicate; pick a different name") + if not callable(factory): + raise TypeError(f"register_predicate: factory must be callable, got {type(factory).__name__}") + PREDICATE_REGISTRY[name] = factory + + +def make_predicate(name: str, **kwargs: Any) -> Callable[[SimEngine], Any]: + """Instantiate a predicate from its name + kwargs. + + This is the single entry point the DSL loader uses - it never touches + ``eval`` or ``exec``. Unknown names produce a ``ValueError`` listing + the valid set; bad kwargs surface as whatever ``TypeError`` the factory + raises. + + Args: + name: Predicate name. Must be registered in :data:`PREDICATE_REGISTRY`. + **kwargs: Forwarded verbatim to the factory. + + Returns: + A callable ``(sim) -> bool`` or ``(sim) -> float`` depending on the + predicate. + + Raises: + ValueError: If ``name`` is unknown. + TypeError: If required factory kwargs are missing. + """ + factory = PREDICATE_REGISTRY.get(name) + if factory is None: + valid = sorted(PREDICATE_REGISTRY.keys()) + raise ValueError(f"Unknown predicate '{name}'. Valid: {valid}") + return factory(**kwargs) + + +__all__ = [ + "PREDICATE_REGISTRY", + "BoolPredicate", + "PredicateFactory", + "RewardTerm", + "make_predicate", + "register_predicate", +] diff --git a/tests/simulation/mujoco/test_benchmark_dispatch.py b/tests/simulation/mujoco/test_benchmark_dispatch.py new file mode 100644 index 0000000..f215195 --- /dev/null +++ b/tests/simulation/mujoco/test_benchmark_dispatch.py @@ -0,0 +1,238 @@ +"""Dispatch-path tests for the benchmark tool actions on the MuJoCo backend. + +Mirrors the ``test_agenttool_contract.py`` pattern: exercises ``_dispatch_action`` +with the new action names (``list_benchmarks``, ``register_benchmark_from_file``, +``evaluate_benchmark``) and asserts: + +* the tool_spec ``action`` enum exposes them, +* unknown / missing params produce the friendly structured errors that the + dispatcher generates from ``inspect.signature``, +* the underlying ``SimEngine`` facade is reached for valid inputs. + +Integration with real MuJoCo physics is deferred to ``tests_integ/``; these +tests only need the Simulation stub without a created world. +""" + +from __future__ import annotations + +import json +import os +import shutil +import tempfile +from pathlib import Path + +import pytest + +mj = pytest.importorskip("mujoco") + +from strands_robots.simulation.benchmark import _BENCHMARK_REGISTRY # noqa: E402 +from strands_robots.simulation.mujoco.simulation import Simulation # noqa: E402 + +# Robot XML reused from test_simulation.py (keep in sync if the canonical +# fixture changes). 
+ROBOT_XML = """ + + + +""" + + +@pytest.fixture(autouse=True) +def _clean_registry(): + snapshot = dict(_BENCHMARK_REGISTRY) + _BENCHMARK_REGISTRY.clear() + yield + _BENCHMARK_REGISTRY.clear() + _BENCHMARK_REGISTRY.update(snapshot) + + +@pytest.fixture +def sim(): + s = Simulation(tool_name="bench_sim", mesh=False) + yield s + s.cleanup() + + +@pytest.fixture +def robot_xml_path(): + tmpdir = tempfile.mkdtemp() + path = os.path.join(tmpdir, "test_arm.xml") + with open(path, "w") as f: + f.write(ROBOT_XML) + yield path + shutil.rmtree(tmpdir, ignore_errors=True) + + +@pytest.fixture +def sim_with_robot(sim, robot_xml_path): + sim.create_world() + sim.add_robot("arm1", urdf_path=robot_xml_path) + return sim + + +@pytest.fixture +def basic_spec_file(tmp_path: Path): + path = tmp_path / "basic.json" + path.write_text( + json.dumps( + { + "name": "basic-task", + "default_robot": "arm1", + "supported_robots": [], # any robot (so100 by data_config isn't loaded here) + "max_steps": 5, + } + ) + ) + return str(path) + + +# Tool spec: action enum + property surface + + +class TestToolSpecSurface: + def test_enum_includes_new_actions(self, sim): + # _TOOL_SPEC_SCHEMA lives at module level; read via module introspection. + from strands_robots.simulation.mujoco import simulation as _sim_mod + + enum = _sim_mod._TOOL_SPEC_SCHEMA["properties"]["action"]["enum"] + assert "list_benchmarks" in enum + assert "register_benchmark_from_file" in enum + assert "evaluate_benchmark" in enum + + def test_property_surface_has_new_params(self): + from strands_robots.simulation.mujoco import simulation as _sim_mod + + props = _sim_mod._TOOL_SPEC_SCHEMA["properties"] + assert "benchmark_name" in props + assert "spec_path" in props + + +# Dispatch + + +class TestListBenchmarksDispatch: + def test_empty_registry(self, sim): + result = sim._dispatch_action("list_benchmarks", {"action": "list_benchmarks"}) + assert result["status"] == "success" + assert "No benchmarks" in result["content"][0]["text"] + + def test_lists_registered(self, sim, basic_spec_file): + sim._dispatch_action( + "register_benchmark_from_file", + { + "action": "register_benchmark_from_file", + "benchmark_name": "basic", + "spec_path": basic_spec_file, + }, + ) + result = sim._dispatch_action("list_benchmarks", {"action": "list_benchmarks"}) + assert result["status"] == "success" + payload = next(c["json"] for c in result["content"] if "json" in c) + assert "basic" in payload["benchmarks"] + + +class TestRegisterBenchmarkFromFileDispatch: + def test_happy_path(self, sim, basic_spec_file): + result = sim._dispatch_action( + "register_benchmark_from_file", + { + "action": "register_benchmark_from_file", + "benchmark_name": "happy", + "spec_path": basic_spec_file, + }, + ) + assert result["status"] == "success" + assert "happy" in result["content"][0]["text"] + + def test_no_args_friendly_error(self, sim): + """Dispatcher surfaces missing required params with a clear message.""" + result = sim._dispatch_action("register_benchmark_from_file", {"action": "register_benchmark_from_file"}) + assert result["status"] == "error" + text = result["content"][0]["text"] + assert "requires parameter" in text + + def test_bad_spec_path_returns_structured_error(self, sim): + result = sim._dispatch_action( + "register_benchmark_from_file", + { + "action": "register_benchmark_from_file", + "benchmark_name": "missing", + "spec_path": "/nonexistent/path/nope.json", + }, + ) + assert result["status"] == "error" + assert "not found" in result["content"][0]["text"].lower() + + 
def test_unknown_param_rejected(self, sim, basic_spec_file): + """Dispatcher rejects unknown kwargs so spec drift is caught early.""" + result = sim._dispatch_action( + "register_benchmark_from_file", + { + "action": "register_benchmark_from_file", + "benchmark_name": "x", + "spec_path": basic_spec_file, + "bogus": 1, + }, + ) + assert result["status"] == "error" + assert "Unknown parameter 'bogus'" in result["content"][0]["text"] + + +class TestEvaluateBenchmarkDispatch: + def test_requires_benchmark_name(self, sim): + result = sim._dispatch_action("evaluate_benchmark", {"action": "evaluate_benchmark"}) + assert result["status"] == "error" + assert "requires parameter" in result["content"][0]["text"] + + def test_unknown_benchmark(self, sim_with_robot): + result = sim_with_robot._dispatch_action( + "evaluate_benchmark", + {"action": "evaluate_benchmark", "benchmark_name": "never-registered"}, + ) + assert result["status"] == "error" + assert "no benchmark registered" in result["content"][0]["text"].lower() + + def test_evaluate_end_to_end(self, sim_with_robot, basic_spec_file): + """Register a no-op benchmark, evaluate with the mock policy - must succeed.""" + sim_with_robot._dispatch_action( + "register_benchmark_from_file", + { + "action": "register_benchmark_from_file", + "benchmark_name": "e2e", + "spec_path": basic_spec_file, + }, + ) + result = sim_with_robot._dispatch_action( + "evaluate_benchmark", + { + "action": "evaluate_benchmark", + "benchmark_name": "e2e", + "robot_name": "arm1", + "policy_provider": "mock", + "n_episodes": 1, + "seed": 0, + }, + ) + assert result["status"] == "success" + payload = next(c["json"] for c in result["content"] if "json" in c) + assert payload["n_episodes"] == 1 + assert payload["benchmark_class"] == "DeclarativeBenchmark" + # With no success/failure/done in the spec, loop runs to max_steps=5. + assert payload["episodes"][0]["steps"] == 5 diff --git a/tests/simulation/test_benchmark.py b/tests/simulation/test_benchmark.py new file mode 100644 index 0000000..57abda4 --- /dev/null +++ b/tests/simulation/test_benchmark.py @@ -0,0 +1,320 @@ +"""Tests for ``strands_robots.simulation.benchmark``. + +Covers: + +* :class:`BenchmarkProtocol` ABC contract (cannot instantiate abstract, + required methods must be implemented, optional hooks have usable defaults). +* :class:`StepInfo` dataclass. +* Registry operations (:func:`register_benchmark` / :func:`get_benchmark` / + :func:`list_benchmarks` / :func:`unregister_benchmark`), including + idempotent-overwrite and thread safety. +* Robot compatibility validation via :meth:`BenchmarkProtocol.on_episode_start`. 
+""" + +from __future__ import annotations + +import random +import threading +from typing import Any + +import pytest + +from strands_robots.simulation.base import SimEngine +from strands_robots.simulation.benchmark import ( + _BENCHMARK_REGISTRY, + BenchmarkCompatibilityError, + BenchmarkProtocol, + StepInfo, + get_benchmark, + list_benchmarks, + register_benchmark, + unregister_benchmark, +) + + +class _MinimalBenchmark(BenchmarkProtocol): + """Concrete benchmark used across tests.""" + + max_steps = 42 + + def __init__( + self, + *, + supported: list[str] | None = None, + default: str = "so100", + success: bool = False, + failure: bool = False, + reward: float = 0.0, + ): + self._supported = list(supported if supported is not None else ["so100"]) + self._default = default + self._success = success + self._failure = failure + self._reward = reward + + @property + def supported_robots(self) -> list[str]: + return list(self._supported) + + @property + def default_robot(self) -> str: + return self._default + + def on_step(self, sim: SimEngine, obs: dict[str, Any], action: dict[str, Any]) -> StepInfo: + return StepInfo(reward=self._reward) + + def is_success(self, sim: SimEngine) -> bool: + return self._success + + def is_failure(self, sim: SimEngine) -> bool: + return self._failure + + +# Registry fixtures + + +@pytest.fixture(autouse=True) +def _clean_registry(): + """Snapshot + restore the registry around every test so they stay isolated.""" + snapshot = dict(_BENCHMARK_REGISTRY) + _BENCHMARK_REGISTRY.clear() + yield + _BENCHMARK_REGISTRY.clear() + _BENCHMARK_REGISTRY.update(snapshot) + + +# StepInfo + + +class TestStepInfo: + def test_defaults(self): + info = StepInfo() + assert info.reward == 0.0 + assert info.done is False + assert info.info == {} + + def test_custom_values(self): + info = StepInfo(reward=3.5, done=True, info={"k": "v"}) + assert info.reward == 3.5 + assert info.done is True + assert info.info == {"k": "v"} + + def test_is_frozen(self): + info = StepInfo() + with pytest.raises(Exception): # dataclasses.FrozenInstanceError + info.reward = 1.0 # type: ignore[misc] + + +# BenchmarkProtocol ABC contract + + +class TestBenchmarkProtocolContract: + def test_cannot_instantiate_abstract(self): + """ABC with abstract methods must not be instantiable.""" + with pytest.raises(TypeError): + BenchmarkProtocol() # type: ignore[abstract] + + def test_concrete_instantiates(self): + bench = _MinimalBenchmark() + assert bench.supported_robots == ["so100"] + assert bench.default_robot == "so100" + assert bench.max_steps == 42 + + def test_is_failure_default_false(self): + """The default is_failure returns False, so sparse-success benchmarks + don't need to override it.""" + + class _Sparse(_MinimalBenchmark): + pass + + # Don't set failure=True; default should return False. + bench = _Sparse() + assert bench.is_failure(None) is False # type: ignore[arg-type] + + def test_on_episode_start_has_default_impl(self): + """on_episode_start is NOT abstract - base impl handles empty-sim + compat checks.""" + + # Fake sim with no robots - should call add_robot with default_robot. 
+ class FakeSim: + def __init__(self): + self._robots: list[str] = [] + self.add_robot_calls: list[dict[str, Any]] = [] + + def list_robots(self): + return list(self._robots) + + def add_robot(self, *, name, data_config): + self.add_robot_calls.append({"name": name, "data_config": data_config}) + self._robots.append(name) + + sim = FakeSim() + bench = _MinimalBenchmark(supported=["so100"], default="so100") + bench.on_episode_start(sim, random.Random(0)) # type: ignore[arg-type] + assert len(sim.add_robot_calls) == 1 + assert sim.add_robot_calls[0]["data_config"] == "so100" + + +# Robot compatibility + + +class TestRobotCompatibility: + def test_raises_when_loaded_robot_incompatible(self): + """Loading a robot whose data_config is not in supported_robots raises + BenchmarkCompatibilityError.""" + + class _Robot: + data_config = "panda" # not in supported + + class FakeSimWithWorld: + _world: Any = None + + def __init__(self): + self._world = type( + "World", + (), + {"robots": {"arm1": _Robot()}}, + )() + + def list_robots(self): + return ["arm1"] + + def add_robot(self, **kw): # pragma: no cover + raise AssertionError("should not be called when robot already present") + + sim = FakeSimWithWorld() + bench = _MinimalBenchmark(supported=["so100", "so101"], default="so100") + with pytest.raises(BenchmarkCompatibilityError) as excinfo: + bench.on_episode_start(sim, random.Random(0)) # type: ignore[arg-type] + assert excinfo.value.robot_name == "arm1" + assert excinfo.value.data_config == "panda" + assert excinfo.value.supported == ["so100", "so101"] + + def test_passes_when_supported_robots_empty(self): + """Empty supported_robots means 'any robot' - no compat check.""" + + class _Robot: + data_config = "any_weird_thing" + + class FakeSim: + def __init__(self): + self._world = type( + "World", + (), + {"robots": {"arm1": _Robot()}}, + )() + + def list_robots(self): + return ["arm1"] + + def add_robot(self, **kw): # pragma: no cover + raise AssertionError("should not be called") + + sim = FakeSim() + bench = _MinimalBenchmark(supported=[], default="any_weird_thing") + # Must not raise. + bench.on_episode_start(sim, random.Random(0)) # type: ignore[arg-type] + + def test_skips_check_when_sim_has_no_world_attr(self): + """Backends without a ``_world`` attribute are treated as "cannot verify" + and skip the compat check rather than false-positive.""" + + class FakeSim: + def list_robots(self): + return ["arm1"] + + def add_robot(self, **kw): # pragma: no cover + raise AssertionError("should not be called") + + sim = FakeSim() + bench = _MinimalBenchmark(supported=["so100"], default="so100") + # Must not raise even though arm1 has unknown data_config. 
+ bench.on_episode_start(sim, random.Random(0)) # type: ignore[arg-type] + + +# Registry + + +class TestRegistry: + def test_register_and_get(self): + bench = _MinimalBenchmark() + register_benchmark("my-task", bench) + assert get_benchmark("my-task") is bench + + def test_get_unknown_returns_none(self): + assert get_benchmark("nonexistent") is None + + def test_register_rejects_non_string_name(self): + with pytest.raises(ValueError): + register_benchmark("", _MinimalBenchmark()) + with pytest.raises(ValueError): + register_benchmark(None, _MinimalBenchmark()) # type: ignore[arg-type] + + def test_register_rejects_non_benchmark(self): + with pytest.raises(TypeError): + register_benchmark("x", "not a benchmark") # type: ignore[arg-type] + + def test_register_overwrites_and_warns(self, caplog): + """Re-registering the same name replaces the entry and logs a warning.""" + first = _MinimalBenchmark() + second = _MinimalBenchmark() + register_benchmark("dup", first) + with caplog.at_level("WARNING"): + register_benchmark("dup", second) + assert get_benchmark("dup") is second + assert any("Overwriting existing" in rec.message for rec in caplog.records) + + def test_unregister_removes(self): + bench = _MinimalBenchmark() + register_benchmark("rm-me", bench) + removed = unregister_benchmark("rm-me") + assert removed is bench + assert get_benchmark("rm-me") is None + + def test_unregister_unknown_returns_none(self): + assert unregister_benchmark("never-registered") is None + + def test_list_benchmarks_metadata(self): + bench = _MinimalBenchmark(supported=["so100", "so101"], default="so100") + register_benchmark("listed", bench) + listed = list_benchmarks() + assert "listed" in listed + meta = listed["listed"] + assert meta["class"] == "_MinimalBenchmark" + assert meta["supported_robots"] == ["so100", "so101"] + assert meta["default_robot"] == "so100" + assert meta["max_steps"] == 42 + + def test_list_benchmarks_empty(self): + assert list_benchmarks() == {} + + +class TestRegistryThreadSafety: + """The registry guard is an RLock so concurrent registrations don't race.""" + + def test_concurrent_registrations_all_land(self): + benches = [_MinimalBenchmark() for _ in range(50)] + barrier = threading.Barrier(len(benches)) + + def register(i: int): + barrier.wait() + register_benchmark(f"thread-{i}", benches[i]) + + threads = [threading.Thread(target=register, args=(i,)) for i in range(len(benches))] + for t in threads: + t.start() + for t in threads: + t.join() + + listed = list_benchmarks() + for i in range(len(benches)): + assert f"thread-{i}" in listed + + +class TestBenchmarkCompatibilityError: + def test_carries_context(self): + e = BenchmarkCompatibilityError(robot_name="arm", data_config="foo", supported=["bar"]) + assert e.robot_name == "arm" + assert e.data_config == "foo" + assert e.supported == ["bar"] + # Subclasses ValueError so broad except ValueError still works. + assert isinstance(e, ValueError) diff --git a/tests/simulation/test_benchmark_dsl.py b/tests/simulation/test_benchmark_dsl.py new file mode 100644 index 0000000..841b0f7 --- /dev/null +++ b/tests/simulation/test_benchmark_dsl.py @@ -0,0 +1,365 @@ +"""Tests for ``strands_robots.simulation.benchmark_spec`` (declarative YAML/JSON loader). + +Covers: + +* :meth:`DeclarativeBenchmark.from_dict` schema validation (good / bad specs). +* :func:`register_benchmark_from_file` end-to-end with JSON + YAML. 
+* The sandboxed contract: unknown predicates / unknown top-level keys / + non-dict predicate entries produce clear errors, not ``eval`` side-effects. +""" + +from __future__ import annotations + +import json +import random +from pathlib import Path +from typing import Any + +import pytest + +from strands_robots.simulation.benchmark import ( + _BENCHMARK_REGISTRY, + get_benchmark, +) +from strands_robots.simulation.benchmark_spec import ( + DeclarativeBenchmark, + register_benchmark_from_file, +) + + +@pytest.fixture(autouse=True) +def _clean_registry(): + snapshot = dict(_BENCHMARK_REGISTRY) + _BENCHMARK_REGISTRY.clear() + yield + _BENCHMARK_REGISTRY.clear() + _BENCHMARK_REGISTRY.update(snapshot) + + +class _BodyStateSim: + def __init__(self, positions: dict[str, list[float]]): + self._pos = positions + + def get_body_state(self, body_name: str) -> dict[str, Any]: + if body_name not in self._pos: + return {"status": "error", "content": [{"text": "missing"}]} + return { + "status": "success", + "content": [ + {"text": body_name}, + {"json": {"position": self._pos[body_name]}}, + ], + } + + def get_observation(self, *_, **__) -> dict[str, Any]: + return {} + + +# Schema validation + + +class TestFromDictValidation: + def test_minimal_valid_spec(self): + spec = { + "name": "minimal", + "default_robot": "so100", + "supported_robots": ["so100"], + } + bench = DeclarativeBenchmark.from_dict(spec) + assert bench.name == "minimal" + assert bench.default_robot == "so100" + assert bench.supported_robots == ["so100"] + assert bench.max_steps == 300 # default + + def test_rejects_non_dict_spec(self): + with pytest.raises(ValueError): + DeclarativeBenchmark.from_dict([1, 2, 3]) # type: ignore[arg-type] + + def test_rejects_unknown_top_level_keys(self): + with pytest.raises(ValueError) as exc: + DeclarativeBenchmark.from_dict( + {"name": "x", "default_robot": "y", "supported_robots": ["y"], "weird_key": 1} + ) + assert "weird_key" in str(exc.value) + + def test_rejects_missing_name(self): + with pytest.raises(ValueError): + DeclarativeBenchmark.from_dict({"default_robot": "y", "supported_robots": []}) + + def test_rejects_missing_default_robot(self): + with pytest.raises(ValueError): + DeclarativeBenchmark.from_dict({"name": "x"}) + + def test_rejects_default_not_in_supported(self): + with pytest.raises(ValueError) as exc: + DeclarativeBenchmark.from_dict({"name": "x", "default_robot": "ghost", "supported_robots": ["a", "b"]}) + assert "not in supported_robots" in str(exc.value) + + def test_allows_default_outside_supported_when_empty(self): + """Empty supported_robots means "any" - default outside makes sense.""" + bench = DeclarativeBenchmark.from_dict({"name": "x", "default_robot": "anything", "supported_robots": []}) + assert bench.default_robot == "anything" + + def test_rejects_non_positive_max_steps(self): + for bad in (-1, 0, "300", True): + with pytest.raises(ValueError): + DeclarativeBenchmark.from_dict( + { + "name": "x", + "default_robot": "y", + "supported_robots": ["y"], + "max_steps": bad, + } + ) + + +# Predicate compilation + + +class TestPredicateCompilation: + def _base_spec(self, **overrides: Any) -> dict[str, Any]: + spec = { + "name": "t", + "default_robot": "so100", + "supported_robots": ["so100"], + "max_steps": 10, + } + spec.update(overrides) + return spec + + def test_success_all_true(self): + spec = self._base_spec( + success={ + "all": [ + {"predicate": "body_above_z", "body": "cube", "z": 0.1}, + ] + } + ) + bench = DeclarativeBenchmark.from_dict(spec) + sim_hit = 
_BodyStateSim({"cube": [0, 0, 0.2]}) + sim_miss = _BodyStateSim({"cube": [0, 0, 0.05]}) + assert bench.is_success(sim_hit) is True + assert bench.is_success(sim_miss) is False + + def test_success_all_any_combined(self): + """When both 'all' and 'any' are provided, both must hold.""" + spec = self._base_spec( + success={ + "all": [{"predicate": "body_above_z", "body": "cube", "z": 0.0}], + "any": [ + {"predicate": "body_above_z", "body": "cube", "z": 10.0}, + {"predicate": "body_above_z", "body": "cube", "z": 0.05}, + ], + } + ) + bench = DeclarativeBenchmark.from_dict(spec) + sim = _BodyStateSim({"cube": [0, 0, 0.1]}) + # all: z>0.0 true. any: z>10 false OR z>0.05 true → any true. Combined: true. + assert bench.is_success(sim) is True + + def test_failure_any(self): + spec = self._base_spec(failure={"any": [{"predicate": "body_below_z", "body": "cube", "z": 0.0}]}) + bench = DeclarativeBenchmark.from_dict(spec) + assert bench.is_failure(_BodyStateSim({"cube": [0, 0, -0.01]})) is True + assert bench.is_failure(_BodyStateSim({"cube": [0, 0, 0.5]})) is False + + def test_dense_reward_sums_terms(self): + spec = self._base_spec( + dense_reward=[ + {"predicate": "constant", "value": 1.0}, + {"predicate": "constant", "value": -0.5}, + ] + ) + bench = DeclarativeBenchmark.from_dict(spec) + info = bench.on_step(None, {}, {}) # type: ignore[arg-type] + assert info.reward == pytest.approx(0.5) + + def test_rejects_unknown_predicate(self): + with pytest.raises(ValueError) as exc: + DeclarativeBenchmark.from_dict(self._base_spec(success={"all": [{"predicate": "totally_made_up"}]})) + assert "Unknown predicate" in str(exc.value) + + def test_rejects_non_dict_predicate_entry(self): + with pytest.raises(ValueError): + DeclarativeBenchmark.from_dict(self._base_spec(success={"all": ["just a string"]})) + + def test_rejects_missing_predicate_key(self): + with pytest.raises(ValueError): + DeclarativeBenchmark.from_dict(self._base_spec(success={"all": [{"body": "cube", "z": 0.1}]})) + + def test_rejects_bad_clause_keys(self): + """success/failure only allow 'all' / 'any'.""" + with pytest.raises(ValueError) as exc: + DeclarativeBenchmark.from_dict(self._base_spec(success={"all": [], "other": []})) + assert "other" in str(exc.value) + + def test_predicate_bad_kwargs_surface_compile_error(self): + """Bad predicate kwargs (wrong types, missing required) surface as a + compile-time error, not a runtime predicate crash.""" + with pytest.raises(ValueError) as exc: + DeclarativeBenchmark.from_dict( + self._base_spec(success={"all": [{"predicate": "inside_region", "body": "x"}]}) + ) + # Should mention the predicate name in the error for discoverability. 
+ assert "inside_region" in str(exc.value) + + +# Empty / default clauses + + +class TestEmptyClauses: + def test_success_absent_defaults_to_false(self): + bench = DeclarativeBenchmark.from_dict({"name": "x", "default_robot": "so100", "supported_robots": ["so100"]}) + assert bench.is_success(_BodyStateSim({})) is False + + def test_failure_absent_defaults_to_false(self): + bench = DeclarativeBenchmark.from_dict({"name": "x", "default_robot": "so100", "supported_robots": ["so100"]}) + assert bench.is_failure(_BodyStateSim({})) is False + + def test_empty_success_returns_false(self): + """Non-None but empty success clause must not default to "always true".""" + bench = DeclarativeBenchmark.from_dict( + { + "name": "x", + "default_robot": "so100", + "supported_robots": ["so100"], + "success": {"all": [], "any": []}, + } + ) + assert bench.is_success(_BodyStateSim({})) is False + + +# File loading + + +class TestRegisterBenchmarkFromFile: + def test_register_from_json(self, tmp_path): + spec_path = tmp_path / "drawer.json" + spec_path.write_text( + json.dumps( + { + "name": "drawer", + "default_robot": "so100", + "supported_robots": ["so100"], + "max_steps": 50, + "success": { + "all": [ + {"predicate": "body_above_z", "body": "cube", "z": 0.1}, + ] + }, + } + ) + ) + bench = register_benchmark_from_file("drawer", str(spec_path)) + assert get_benchmark("drawer") is bench + assert bench.max_steps == 50 + assert bench.is_success(_BodyStateSim({"cube": [0, 0, 0.5]})) is True + + def test_register_from_yaml(self, tmp_path): + """YAML support is opt-in; skip if pyyaml isn't available in this env.""" + pytest.importorskip("yaml") + spec_path = tmp_path / "y.yaml" + spec_path.write_text( + """ +name: yml-task +default_robot: so100 +supported_robots: [so100] +max_steps: 99 +success: + all: + - {predicate: body_above_z, body: cube, z: 0.5} +""" + ) + bench = register_benchmark_from_file("yml-task", str(spec_path)) + assert bench.max_steps == 99 + + def test_file_not_found(self, tmp_path): + with pytest.raises(FileNotFoundError): + register_benchmark_from_file("missing", str(tmp_path / "nope.json")) + + def test_rejects_unsupported_extension(self, tmp_path): + p = tmp_path / "spec.toml" + p.write_text("") + with pytest.raises(ValueError) as exc: + register_benchmark_from_file("x", str(p)) + assert ".toml" in str(exc.value) or "extension" in str(exc.value) + + def test_spec_name_internal_overridden_by_registry_name(self, tmp_path): + """Registry name wins over any ``name`` declared inside the spec file.""" + p = tmp_path / "s.json" + p.write_text( + json.dumps( + { + "name": "internal-name", + "default_robot": "so100", + "supported_robots": ["so100"], + } + ) + ) + register_benchmark_from_file("external-name", str(p)) + assert get_benchmark("external-name") is not None + # The spec's internal name doesn't end up in the registry. 
+ assert get_benchmark("internal-name") is None + + def test_rejects_empty_name(self, tmp_path): + p = tmp_path / "s.json" + p.write_text('{"name": "x", "default_robot": "y", "supported_robots": []}') + with pytest.raises(ValueError): + register_benchmark_from_file("", str(p)) + + def test_bad_json_propagates(self, tmp_path): + p = tmp_path / "bad.json" + p.write_text("{not json}") + with pytest.raises(json.JSONDecodeError): + register_benchmark_from_file("x", str(p)) + + +# DeclarativeBenchmark lifecycle + + +class TestDeclarativeBenchmarkLifecycle: + def test_on_episode_start_delegates_to_base(self): + """Default on_episode_start loads the default_robot when sim is empty.""" + spec = { + "name": "x", + "default_robot": "so100", + "supported_robots": ["so100"], + } + bench = DeclarativeBenchmark.from_dict(spec) + + class FakeSim: + def __init__(self): + self.added: list[dict[str, Any]] = [] + + def list_robots(self): + return [] + + def add_robot(self, *, name, data_config): + self.added.append({"name": name, "data_config": data_config}) + + sim = FakeSim() + bench.on_episode_start(sim, random.Random(0)) # type: ignore[arg-type] + assert len(sim.added) == 1 + assert sim.added[0]["data_config"] == "so100" + + def test_scene_load_error_raises(self, tmp_path: Path): + """If the sim's load_scene returns an error dict, the benchmark must surface it.""" + spec = { + "name": "x", + "default_robot": "so100", + "supported_robots": ["so100"], + "scene": str(tmp_path / "missing.xml"), + } + bench = DeclarativeBenchmark.from_dict(spec) + + class FakeSim: + def load_scene(self, path): + return {"status": "error", "content": [{"text": f"no such file: {path}"}]} + + def list_robots(self): + return ["preloaded"] + + sim = FakeSim() + with pytest.raises(RuntimeError) as exc: + bench.on_episode_start(sim, random.Random(0)) # type: ignore[arg-type] + assert "load_scene" in str(exc.value) diff --git a/tests/simulation/test_benchmark_predicates.py b/tests/simulation/test_benchmark_predicates.py new file mode 100644 index 0000000..cddafdb --- /dev/null +++ b/tests/simulation/test_benchmark_predicates.py @@ -0,0 +1,274 @@ +"""Tests for ``strands_robots.simulation.predicates``. + +Each predicate is tested against a lightweight fake sim that implements +only the methods the predicate exercises. Real MuJoCo integration is out +of scope here - those predicates are covered end-to-end in the dispatch +tests under ``tests/simulation/mujoco/``. +""" + +from __future__ import annotations + +from typing import Any + +import pytest + +from strands_robots.simulation.predicates import ( + PREDICATE_REGISTRY, + make_predicate, + register_predicate, +) + +# Fake sim helpers + + +class _BodyStateSim: + """Sim that exposes ``get_body_state`` with caller-provided positions.""" + + def __init__(self, positions: dict[str, list[float]]): + self._pos = positions + + def get_body_state(self, body_name: str) -> dict[str, Any]: + if body_name not in self._pos: + return {"status": "error", "content": [{"text": f"Body '{body_name}' not found."}]} + return { + "status": "success", + "content": [ + {"text": f"body {body_name}"}, + { + "json": { + "position": self._pos[body_name], + "quaternion": [1, 0, 0, 0], + "mass": 1.0, + } + }, + ], + } + + # Predicates that probe `get_observation` for joint state need this stub. 
+ def get_observation(self, *_, **__) -> dict[str, Any]: + return {} + + +class _JointObsSim: + """Sim that exposes joint positions via ``get_observation``.""" + + def __init__(self, joints: dict[str, float]): + self._joints = joints + + def get_observation(self, *_, **__) -> dict[str, float]: + return dict(self._joints) + + def get_body_state(self, body_name: str) -> dict[str, Any]: # pragma: no cover + return {"status": "error", "content": [{"text": "no bodies"}]} + + +class _ContactSim: + """Sim that exposes ``get_contacts`` in the MuJoCo-backend shape.""" + + def __init__(self, contacts: list[dict[str, Any]]): + self._contacts = contacts + + def get_contacts(self) -> dict[str, Any]: + return { + "status": "success", + "content": [ + {"text": f"{len(self._contacts)} contacts"}, + { + "json": { + "contacts": self._contacts, + "n_contacts": len(self._contacts), + } + }, + ], + } + + +class _NoHelpersSim: + """Sim missing get_body_state / get_contacts entirely (e.g. future backend).""" + + def get_observation(self, *_, **__) -> dict[str, Any]: + return {} + + +# Registry + + +class TestRegistry: + def test_builtin_predicates_registered(self): + required = { + "body_above_z", + "body_below_z", + "joint_above", + "joint_below", + "distance_less_than", + "inside_region", + "contact_between", + "contact_any", + "distance_neg", + "joint_progress", + "constant", + } + assert required.issubset(PREDICATE_REGISTRY.keys()) + + def test_make_predicate_unknown_raises(self): + with pytest.raises(ValueError) as exc: + make_predicate("totally_made_up") + assert "Unknown predicate" in str(exc.value) + # Error message should list valid names so the user can fix the spec. + assert "body_above_z" in str(exc.value) + + def test_register_predicate_rejects_shadow(self): + with pytest.raises(ValueError): + register_predicate("body_above_z", lambda **_: lambda _sim: True) + + def test_register_predicate_rejects_non_callable(self): + with pytest.raises(TypeError): + register_predicate("my_pred", "not a callable") # type: ignore[arg-type] + + def test_register_predicate_custom(self): + try: + + def factory(value: float): + return lambda _sim: value > 0 + + register_predicate("positive_constant", factory) + pred = make_predicate("positive_constant", value=1.5) + assert pred(None) is True + finally: + PREDICATE_REGISTRY.pop("positive_constant", None) + + +# Body-position predicates + + +class TestBodyPositionPredicates: + def test_body_above_z_true(self): + sim = _BodyStateSim({"cube": [0.1, 0.0, 0.25]}) + pred = make_predicate("body_above_z", body="cube", z=0.2) + assert pred(sim) is True + + def test_body_above_z_false(self): + sim = _BodyStateSim({"cube": [0.1, 0.0, 0.15]}) + pred = make_predicate("body_above_z", body="cube", z=0.2) + assert pred(sim) is False + + def test_body_above_z_missing_body_returns_false(self): + sim = _BodyStateSim({"other": [0, 0, 1]}) + pred = make_predicate("body_above_z", body="cube", z=0.2) + assert pred(sim) is False + + def test_body_below_z(self): + sim = _BodyStateSim({"cube": [0.0, 0.0, -0.05]}) + pred = make_predicate("body_below_z", body="cube", z=0.0) + assert pred(sim) is True + + def test_distance_less_than_true(self): + sim = _BodyStateSim({"a": [0, 0, 0], "b": [0.05, 0, 0]}) + pred = make_predicate("distance_less_than", body_a="a", body_b="b", threshold=0.1) + assert pred(sim) is True + + def test_distance_less_than_false(self): + sim = _BodyStateSim({"a": [0, 0, 0], "b": [1.0, 0, 0]}) + pred = make_predicate("distance_less_than", body_a="a", body_b="b", threshold=0.1) + 
assert pred(sim) is False + + def test_inside_region_matches(self): + sim = _BodyStateSim({"cube": [0.1, 0.2, 0.3]}) + pred = make_predicate("inside_region", body="cube", min=[-0.5, 0.0, 0.0], max=[0.5, 0.5, 1.0]) + assert pred(sim) is True + + def test_inside_region_outside(self): + sim = _BodyStateSim({"cube": [0.6, 0.0, 0.0]}) + pred = make_predicate("inside_region", body="cube", min=[0, 0, 0], max=[0.5, 0.5, 0.5]) + assert pred(sim) is False + + def test_inside_region_rejects_malformed_args(self): + with pytest.raises(ValueError): + make_predicate("inside_region", body="cube", min=[0, 0], max=[1, 1, 1]) + with pytest.raises(ValueError): + # min > max should error up front, not silently always return False. + make_predicate("inside_region", body="cube", min=[1, 1, 1], max=[0, 0, 0]) + + def test_body_predicate_without_get_body_state_returns_false(self): + sim = _NoHelpersSim() + pred = make_predicate("body_above_z", body="cube", z=0) + assert pred(sim) is False + + +# Joint predicates + + +class TestJointPredicates: + def test_joint_above(self): + sim = _JointObsSim({"drawer_slide": 0.18}) + assert make_predicate("joint_above", joint="drawer_slide", value=0.15)(sim) is True + assert make_predicate("joint_above", joint="drawer_slide", value=0.2)(sim) is False + + def test_joint_below(self): + sim = _JointObsSim({"gripper": 0.02}) + assert make_predicate("joint_below", joint="gripper", value=0.05)(sim) is True + + def test_joint_missing_returns_false(self): + sim = _JointObsSim({"other_joint": 1.0}) + assert make_predicate("joint_above", joint="missing", value=0.0)(sim) is False + + def test_joint_progress_reward(self): + sim = _JointObsSim({"drawer": 0.1}) + term = make_predicate("joint_progress", joint="drawer", target=0.2, weight=10.0) + # -weight * |q - target| = -10 * 0.1 = -1.0 + assert term(sim) == pytest.approx(-1.0) + + def test_joint_progress_at_target_gives_zero_reward(self): + sim = _JointObsSim({"drawer": 0.2}) + term = make_predicate("joint_progress", joint="drawer", target=0.2, weight=1.0) + assert term(sim) == pytest.approx(0.0) + + +# Contact predicates + + +class TestContactPredicates: + def test_contact_between_matches_either_order(self): + sim = _ContactSim([{"geom1": "cube", "geom2": "gripper", "dist": -0.001}]) + assert make_predicate("contact_between", geom_a="cube", geom_b="gripper")(sim) is True + assert make_predicate("contact_between", geom_a="gripper", geom_b="cube")(sim) is True + + def test_contact_between_no_match(self): + sim = _ContactSim([{"geom1": "cube", "geom2": "ground"}]) + assert make_predicate("contact_between", geom_a="cube", geom_b="gripper")(sim) is False + + def test_contact_any(self): + assert make_predicate("contact_any")(_ContactSim([{"geom1": "a", "geom2": "b"}])) is True + assert make_predicate("contact_any")(_ContactSim([])) is False + + def test_contact_predicate_without_get_contacts(self): + sim = _NoHelpersSim() + assert make_predicate("contact_any")(sim) is False + assert make_predicate("contact_between", geom_a="a", geom_b="b")(sim) is False + + +# Reward terms + + +class TestRewardTerms: + def test_distance_neg_monotonic(self): + far = _BodyStateSim({"a": [0, 0, 0], "b": [1, 0, 0]}) + near = _BodyStateSim({"a": [0, 0, 0], "b": [0.1, 0, 0]}) + term = make_predicate("distance_neg", body_a="a", body_b="b", weight=1.0) + # Closer is greater (less negative). 
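+        # Concretely, with weight=1.0: near is distance 0.1 → -0.1, far is
+        # distance 1.0 → -1.0, and -0.1 > -1.0.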
+ assert term(near) > term(far) + + def test_distance_neg_weight(self): + sim = _BodyStateSim({"a": [0, 0, 0], "b": [1, 0, 0]}) + weighted = make_predicate("distance_neg", body_a="a", body_b="b", weight=5.0) + assert weighted(sim) == pytest.approx(-5.0) + + def test_distance_neg_missing_body_returns_zero(self): + """Missing bodies should not crash or reward heavily - return 0.0.""" + sim = _BodyStateSim({"a": [0, 0, 0]}) + term = make_predicate("distance_neg", body_a="a", body_b="ghost", weight=1.0) + assert term(sim) == 0.0 + + def test_constant(self): + term = make_predicate("constant", value=-0.01) + assert term(None) == pytest.approx(-0.01) diff --git a/tests/simulation/test_policy_runner_benchmark.py b/tests/simulation/test_policy_runner_benchmark.py new file mode 100644 index 0000000..ec31e11 --- /dev/null +++ b/tests/simulation/test_policy_runner_benchmark.py @@ -0,0 +1,462 @@ +"""Tests for ``PolicyRunner.evaluate`` with a :class:`BenchmarkProtocol` spec. + +Covers the new spec-driven evaluation path: + +* Cumulative reward accounting across an episode. +* Early termination on ``is_success`` / ``is_failure`` / ``StepInfo.done``. +* Per-episode seeded RNG so identical seeds produce identical rollouts. +* Robot-compatibility error surfaces as a structured error dict (not raises). +* Legacy ``success_fn`` path still works. +* Passing both ``spec`` and ``success_fn`` is an error. +* :meth:`SimEngine.evaluate_benchmark` facade end-to-end. +* :meth:`SimEngine.list_benchmarks` / :meth:`register_benchmark_from_file` + facades return structured dicts. +""" + +from __future__ import annotations + +import json +from pathlib import Path + +import pytest + +from strands_robots.policies.mock import MockPolicy +from strands_robots.simulation.base import SimEngine +from strands_robots.simulation.benchmark import ( + _BENCHMARK_REGISTRY, + BenchmarkProtocol, + StepInfo, + register_benchmark, +) +from strands_robots.simulation.policy_runner import PolicyRunner + +# Fixtures + + +@pytest.fixture(autouse=True) +def _clean_registry(): + snapshot = dict(_BENCHMARK_REGISTRY) + _BENCHMARK_REGISTRY.clear() + yield + _BENCHMARK_REGISTRY.clear() + _BENCHMARK_REGISTRY.update(snapshot) + + +class _FakeRobot: + """Simple object with a data_config attr so the base compat check passes.""" + + def __init__(self, data_config: str): + self.data_config = data_config + + +class _FakeWorld: + """Minimal world with a ``robots`` dict - enough for the base compat check.""" + + def __init__(self, robots: dict[str, _FakeRobot]): + self.robots = dict(robots) + + +class FakeSim(SimEngine): + """Minimal ``SimEngine`` that records a few per-episode counters. + + Deliberately stripped-down: no cameras, no physics - just enough for + :class:`PolicyRunner` to step through. + """ + + def __init__(self, joint_names: tuple[str, ...] 
= ("j0", "j1"), data_config: str = "so100"): + self._joint_names = list(joint_names) + self._data_config = data_config + self._step_count = 0 + self._reset_count = 0 + self._world = _FakeWorld({"fake_robot": _FakeRobot(data_config)}) + + def create_world(self, timestep=None, gravity=None, ground_plane=True): + return {"status": "success"} + + def destroy(self): + return {"status": "success"} + + def reset(self): + self._step_count = 0 + self._reset_count += 1 + return {"status": "success"} + + def step(self, n_steps: int = 1): + self._step_count += n_steps + return {"status": "success"} + + def get_state(self): + return {"step_count": self._step_count} + + def add_robot(self, name, **kw): + data_config = kw.get("data_config") or self._data_config + self._world.robots[name] = _FakeRobot(data_config) + return {"status": "success"} + + def remove_robot(self, name): + return {"status": "success"} + + def list_robots(self) -> list[str]: + return list(self._world.robots.keys()) + + def robot_joint_names(self, robot_name: str) -> list[str]: + return list(self._joint_names) + + def add_object(self, name, **kw): + return {"status": "success"} + + def remove_object(self, name): + return {"status": "success"} + + def get_observation(self, robot_name=None, *, skip_images=False): + return {n: 0.0 for n in self._joint_names} + + def send_action(self, action, robot_name=None, n_substeps=1): + self._step_count += 1 + + def render(self, camera_name="default", width=None, height=None): + return {"status": "success", "content": [{"text": "render"}]} + + +# Test benchmarks + + +class _CountingBenchmark(BenchmarkProtocol): + """Benchmark that tracks how many times each hook was called and rewards +1/step.""" + + max_steps = 20 + + def __init__(self, *, success_after: int = 10**9, fail_after: int = 10**9): + self.success_after = success_after + self.fail_after = fail_after + self.on_episode_start_calls = 0 + self.on_step_calls = 0 + self.rng_seeds_seen: list[int] = [] + + @property + def supported_robots(self) -> list[str]: + return ["so100"] + + @property + def default_robot(self) -> str: + return "so100" + + def on_episode_start(self, sim, rng): + self.on_episode_start_calls += 1 + # Record the first draw for reproducibility tests. 
+ self.rng_seeds_seen.append(rng.randint(0, 1000000)) + super().on_episode_start(sim, rng) + + def on_step(self, sim, obs, action): + self.on_step_calls += 1 + return StepInfo(reward=1.0) + + def is_success(self, sim): + return self.on_step_calls >= self.success_after + + def is_failure(self, sim): + return self.on_step_calls >= self.fail_after + + +class _DoneAfterBenchmark(BenchmarkProtocol): + """Returns StepInfo(done=True) on the Nth step.""" + + max_steps = 50 + + def __init__(self, done_after: int): + self._done_after = done_after + self._step = 0 + + @property + def supported_robots(self) -> list[str]: + return [] + + @property + def default_robot(self) -> str: + return "so100" + + def on_step(self, sim, obs, action): + self._step += 1 + return StepInfo(reward=0.5, done=self._step >= self._done_after) + + def is_success(self, sim): + return False + + def is_failure(self, sim): + return False + + +# Cumulative reward + + +class TestCumulativeReward: + def test_sums_reward_across_steps(self): + sim = FakeSim() + policy = MockPolicy() + policy.set_robot_state_keys(sim.robot_joint_names("fake_robot")) + spec = _CountingBenchmark() + + result = PolicyRunner(sim).evaluate( + "fake_robot", + policy, + spec=spec, + n_episodes=1, + ) + assert result["status"] == "success" + payload = next(c["json"] for c in result["content"] if "json" in c) + # _CountingBenchmark rewards +1/step; max_steps=20 → 20.0 cumulative. + assert payload["episodes"][0]["cumulative_reward"] == pytest.approx(20.0) + assert payload["avg_reward"] == pytest.approx(20.0) + + def test_success_terminates_and_stops_reward_accumulation(self): + sim = FakeSim() + policy = MockPolicy() + policy.set_robot_state_keys(sim.robot_joint_names("fake_robot")) + spec = _CountingBenchmark(success_after=5) + + result = PolicyRunner(sim).evaluate("fake_robot", policy, spec=spec, n_episodes=1) + payload = next(c["json"] for c in result["content"] if "json" in c) + assert payload["n_success"] == 1 + # Reward per step = 1; terminates on the 5th step. 
+ assert payload["episodes"][0]["cumulative_reward"] == pytest.approx(5.0) + assert payload["episodes"][0]["steps"] == 5 + + def test_failure_marks_episode_unsuccessful(self): + sim = FakeSim() + policy = MockPolicy() + policy.set_robot_state_keys(sim.robot_joint_names("fake_robot")) + spec = _CountingBenchmark(fail_after=3) + + result = PolicyRunner(sim).evaluate("fake_robot", policy, spec=spec, n_episodes=1) + payload = next(c["json"] for c in result["content"] if "json" in c) + assert payload["n_failure"] == 1 + assert payload["n_success"] == 0 + assert payload["episodes"][0]["failure"] is True + assert payload["episodes"][0]["success"] is False + + def test_done_flag_terminates_episode(self): + """StepInfo.done=True ends the episode even without is_success/is_failure.""" + sim = FakeSim() + policy = MockPolicy() + policy.set_robot_state_keys(sim.robot_joint_names("fake_robot")) + spec = _DoneAfterBenchmark(done_after=4) + + result = PolicyRunner(sim).evaluate("fake_robot", policy, spec=spec, n_episodes=1) + payload = next(c["json"] for c in result["content"] if "json" in c) + assert payload["episodes"][0]["steps"] == 4 + + +# Seed reproducibility + + +class TestSeedReproducibility: + def test_same_seed_same_rng_draws(self): + """Two evaluations with the same seed must produce identical per-episode RNG draws.""" + sim1 = FakeSim() + sim2 = FakeSim() + policy = MockPolicy() + policy.set_robot_state_keys(sim1.robot_joint_names("fake_robot")) + + spec1 = _CountingBenchmark() + spec2 = _CountingBenchmark() + + PolicyRunner(sim1).evaluate("fake_robot", policy, spec=spec1, n_episodes=3, seed=42) + PolicyRunner(sim2).evaluate("fake_robot", policy, spec=spec2, n_episodes=3, seed=42) + + assert spec1.rng_seeds_seen == spec2.rng_seeds_seen + + def test_different_seed_different_rng_draws(self): + sim1 = FakeSim() + sim2 = FakeSim() + policy = MockPolicy() + policy.set_robot_state_keys(sim1.robot_joint_names("fake_robot")) + + spec1 = _CountingBenchmark() + spec2 = _CountingBenchmark() + + PolicyRunner(sim1).evaluate("fake_robot", policy, spec=spec1, n_episodes=3, seed=42) + PolicyRunner(sim2).evaluate("fake_robot", policy, spec=spec2, n_episodes=3, seed=999) + + assert spec1.rng_seeds_seen != spec2.rng_seeds_seen + + def test_seed_recorded_in_per_episode_results(self): + sim = FakeSim() + policy = MockPolicy() + policy.set_robot_state_keys(sim.robot_joint_names("fake_robot")) + + result = PolicyRunner(sim).evaluate("fake_robot", policy, spec=_CountingBenchmark(), n_episodes=3, seed=7) + payload = next(c["json"] for c in result["content"] if "json" in c) + # Every episode has a distinct seed derived from the master seed. 
+ seeds = [e["seed"] for e in payload["episodes"]] + assert len(set(seeds)) == 3 + assert payload["seed"] == 7 + + +# Robot compatibility via the spec path + + +class TestRobotCompatibility: + def test_mismatched_robot_returns_structured_error(self): + """Spec with supported=['panda'] vs sim loaded with 'so100' → structured error.""" + + class _PandaOnly(_CountingBenchmark): + @property + def supported_robots(self) -> list[str]: + return ["panda"] + + @property + def default_robot(self) -> str: + return "panda" + + sim = FakeSim(data_config="so100") # loaded with so100 + policy = MockPolicy() + policy.set_robot_state_keys(sim.robot_joint_names("fake_robot")) + + result = PolicyRunner(sim).evaluate("fake_robot", policy, spec=_PandaOnly(), n_episodes=1) + assert result["status"] == "error" + text = result["content"][0]["text"] + assert "compatibility" in text.lower() or "supported" in text.lower() + assert "so100" in text # shows the offending data_config + assert "panda" in text # shows the allowed list + + +# Legacy success_fn path still works + + +class TestBackwardCompatibility: + def test_legacy_success_fn_callable_still_works(self): + """The pre-PR success_fn=callable path must be unchanged.""" + sim = FakeSim() + policy = MockPolicy() + policy.set_robot_state_keys(sim.robot_joint_names("fake_robot")) + + result = PolicyRunner(sim).evaluate( + "fake_robot", + policy, + n_episodes=2, + max_steps=5, + success_fn=lambda _obs: True, + ) + payload = next(c["json"] for c in result["content"] if "json" in c) + assert payload["success_rate"] == 1.0 + # Legacy path doesn't emit cumulative_reward; just the pre-PR schema. + assert "cumulative_reward" not in payload + + def test_legacy_success_fn_none_returns_zero_success(self): + sim = FakeSim() + policy = MockPolicy() + policy.set_robot_state_keys(sim.robot_joint_names("fake_robot")) + + result = PolicyRunner(sim).evaluate("fake_robot", policy, n_episodes=1, max_steps=3, success_fn=None) + payload = next(c["json"] for c in result["content"] if "json" in c) + assert payload["n_success"] == 0 + + def test_cannot_pass_both_spec_and_success_fn(self): + sim = FakeSim() + policy = MockPolicy() + policy.set_robot_state_keys(sim.robot_joint_names("fake_robot")) + + result = PolicyRunner(sim).evaluate( + "fake_robot", + policy, + spec=_CountingBenchmark(), + success_fn=lambda _obs: True, + ) + assert result["status"] == "error" + assert "both" in result["content"][0]["text"].lower() + + +# SimEngine facade + + +class TestSimEngineFacades: + def test_evaluate_benchmark_dispatches_to_runner(self): + sim = FakeSim() + spec = _CountingBenchmark() + register_benchmark("eval-test", spec) + + result = sim.evaluate_benchmark( + benchmark_name="eval-test", + robot_name="fake_robot", + policy_provider="mock", + n_episodes=1, + seed=3, + ) + assert result["status"] == "success" + payload = next(c["json"] for c in result["content"] if "json" in c) + assert payload["benchmark_class"] == "_CountingBenchmark" + # Full per-step traversal (nothing terminates early). 
+ assert payload["episodes"][0]["steps"] == 20 + + def test_evaluate_benchmark_unknown_name(self): + sim = FakeSim() + result = sim.evaluate_benchmark(benchmark_name="never-registered") + assert result["status"] == "error" + assert "no benchmark registered" in result["content"][0]["text"].lower() + + def test_evaluate_benchmark_auto_picks_sole_robot(self): + """Single-robot scene: robot_name can be omitted.""" + sim = FakeSim() + register_benchmark("auto-robot", _CountingBenchmark()) + result = sim.evaluate_benchmark(benchmark_name="auto-robot", n_episodes=1) + assert result["status"] == "success" + + def test_evaluate_benchmark_requires_robot_name_in_multi_robot(self): + sim = FakeSim() + sim.add_robot("second", data_config="so100") + register_benchmark("multi", _CountingBenchmark()) + result = sim.evaluate_benchmark(benchmark_name="multi") + assert result["status"] == "error" + assert "robot_name" in result["content"][0]["text"] + + def test_list_benchmarks_returns_snapshot(self): + sim = FakeSim() + register_benchmark("a", _CountingBenchmark()) + register_benchmark("b", _CountingBenchmark()) + result = sim.list_benchmarks() + assert result["status"] == "success" + payload = next(c["json"] for c in result["content"] if "json" in c) + assert set(payload["benchmarks"].keys()) == {"a", "b"} + + def test_list_benchmarks_empty(self): + sim = FakeSim() + result = sim.list_benchmarks() + assert result["status"] == "success" + assert "No benchmarks" in result["content"][0]["text"] + + def test_register_benchmark_from_file_success(self, tmp_path: Path): + sim = FakeSim() + spec_path = tmp_path / "s.json" + spec_path.write_text( + json.dumps( + { + "name": "file-bench", + "default_robot": "so100", + "supported_robots": ["so100"], + "max_steps": 7, + } + ) + ) + result = sim.register_benchmark_from_file(benchmark_name="file-bench", spec_path=str(spec_path)) + assert result["status"] == "success" + assert "Registered benchmark" in result["content"][0]["text"] + + def test_register_benchmark_from_file_missing_file(self, tmp_path: Path): + sim = FakeSim() + result = sim.register_benchmark_from_file(benchmark_name="missing", spec_path=str(tmp_path / "nope.json")) + assert result["status"] == "error" + assert "not found" in result["content"][0]["text"].lower() + + def test_register_benchmark_from_file_empty_name(self): + sim = FakeSim() + result = sim.register_benchmark_from_file(benchmark_name="", spec_path="/tmp/x.json") + assert result["status"] == "error" + assert "benchmark_name" in result["content"][0]["text"] + + def test_register_benchmark_from_file_bad_schema(self, tmp_path: Path): + sim = FakeSim() + spec_path = tmp_path / "bad.json" + spec_path.write_text('{"name": "x"}') # missing default_robot + result = sim.register_benchmark_from_file(benchmark_name="bad", spec_path=str(spec_path)) + assert result["status"] == "error" + assert "default_robot" in result["content"][0]["text"]