diff --git a/pyproject.toml b/pyproject.toml index 2c2e0fd..27d3859 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,9 @@ sim-mujoco = [ "imageio>=2.28.0,<3.0.0", "imageio-ffmpeg>=0.4.0,<1.0.0", ] +benchmark-libero = [ + "libero>=0.1.0,<1.0.0", +] all = [ "strands-robots[groot-service]", "strands-robots[lerobot]", @@ -132,7 +135,7 @@ ignore_missing_imports = false # Third-party libs without type stubs [[tool.mypy.overrides]] -module = ["lerobot.*", "gr00t.*", "draccus.*", "msgpack.*", "zmq.*", "huggingface_hub.*", "serial.*", "psutil.*", "torch.*", "torchvision.*", "transformers.*", "einops.*", "robot_descriptions.*", "mujoco.*", "imageio.*"] +module = ["lerobot.*", "gr00t.*", "draccus.*", "msgpack.*", "zmq.*", "huggingface_hub.*", "serial.*", "psutil.*", "torch.*", "torchvision.*", "transformers.*", "einops.*", "robot_descriptions.*", "mujoco.*", "imageio.*", "libero.*"] ignore_missing_imports = true # @tool decorator injects runtime signatures mypy cannot check diff --git a/strands_robots/benchmarks/__init__.py b/strands_robots/benchmarks/__init__.py new file mode 100644 index 0000000..4caa0f1 --- /dev/null +++ b/strands_robots/benchmarks/__init__.py @@ -0,0 +1,14 @@ +"""Strands Robots Benchmarks - per-benchmark adapters layered on :mod:`strands_robots.simulation.benchmark`. + +Adapters live in optional extras so the core package stays dependency-free. +Importing this namespace is cheap; the heavy work happens when a specific +adapter submodule is imported (e.g. ``from strands_robots.benchmarks.libero +import LiberoAdapter``). + +Currently shipped adapters: + +* ``strands_robots.benchmarks.libero`` - LIBERO (Panda-only, ~130 tasks). + Install with ``pip install 'strands-robots[benchmark-libero]'``. + +Tracked follow-ups: Meta-World (#108), RoboSuite (#109). 
+""" diff --git a/strands_robots/benchmarks/libero/__init__.py b/strands_robots/benchmarks/libero/__init__.py new file mode 100644 index 0000000..551317a --- /dev/null +++ b/strands_robots/benchmarks/libero/__init__.py @@ -0,0 +1,41 @@ +"""LIBERO benchmark adapter - see :mod:`strands_robots.benchmarks.libero.adapter`. + +Public surface (re-exported from submodules so agents can do +``from strands_robots.benchmarks.libero import LiberoAdapter``): + +* :class:`LiberoAdapter` - ``BenchmarkProtocol`` built around a BDDL task. +* :func:`load_libero_suite` - bulk-register every task in a suite. +* :class:`BDDLParseError` - raised on malformed BDDL input. + +The adapter and parser have **no** dependency on the ``libero`` pip +package - you can use them with your own BDDL files. Only +:func:`load_libero_suite` touches the upstream package (to discover task +files), and only when you don't pass an explicit ``bddl_dir=``. +""" + +from strands_robots.benchmarks.libero.adapter import BDDLParseError, LiberoAdapter +from strands_robots.benchmarks.libero.bddl_parser import ( + PREDICATE_VOCABULARY, + BDDLProblem, + compile_goal, + parse_bddl, + parse_bddl_file, +) +from strands_robots.benchmarks.libero.suite import ( + SUITE_NAMES, + available_suites, + load_libero_suite, +) + +__all__ = [ + "BDDLParseError", + "BDDLProblem", + "LiberoAdapter", + "PREDICATE_VOCABULARY", + "SUITE_NAMES", + "available_suites", + "compile_goal", + "load_libero_suite", + "parse_bddl", + "parse_bddl_file", +] diff --git a/strands_robots/benchmarks/libero/adapter.py b/strands_robots/benchmarks/libero/adapter.py new file mode 100644 index 0000000..f14472f --- /dev/null +++ b/strands_robots/benchmarks/libero/adapter.py @@ -0,0 +1,297 @@ +"""``LiberoAdapter`` - :class:`BenchmarkProtocol` driven by a LIBERO BDDL file. + +LIBERO is a suite of ~130 tabletop manipulation tasks built around a Franka +Panda. Each task ships as a BDDL problem file + an MJCF scene. 
The adapter +compiles the BDDL ``:goal`` into a sparse success predicate via +:mod:`strands_robots.benchmarks.libero.bddl_parser` and drives the scene +through the standard :class:`BenchmarkProtocol` lifecycle: + +1. :meth:`on_episode_start` - optional ``sim.load_scene(scene_path)``, then + the base ``BenchmarkProtocol`` compatibility check (Panda-only), then + per-episode jitter of ``(:init ...)`` object positions. +2. :meth:`on_step` - sparse: ``StepInfo(reward=0.0, done=False)``. LIBERO + does not define a dense reward. +3. :meth:`is_success` - walks the compiled ``:goal`` predicate tree against + the current sim state. + +**Panda-only by design.** LIBERO's scene MJCFs hard-code Panda geometry +and BDDL predicates reference Panda gripper body names +(``robot0_gripper_*``). Retargeting to a different robot would require +rewriting every BDDL predicate against different body names and is out of +scope for this adapter. Subclass :class:`LiberoAdapter` and override +:attr:`supported_robots` + :attr:`default_robot` if you know what you're +doing. + +The adapter does NOT require the ``libero`` Python package to be installed - +only a BDDL string / file and (optionally) an MJCF scene path. The +:func:`strands_robots.benchmarks.libero.suite.load_libero_suite` helper is +the one that pulls in the upstream package to discover task files.
+""" + +from __future__ import annotations + +import logging +from collections.abc import Callable +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from strands_robots.benchmarks.libero.bddl_parser import ( + BDDLParseError, + BDDLProblem, + Node, + compile_goal, + parse_bddl, + parse_bddl_file, +) +from strands_robots.simulation.benchmark import BenchmarkProtocol, StepInfo + +if TYPE_CHECKING: + import random + + from strands_robots.simulation.base import SimEngine + +logger = logging.getLogger(__name__) + + +class LiberoAdapter(BenchmarkProtocol): + """Panda-only :class:`BenchmarkProtocol` driven by a parsed LIBERO BDDL task. + + Construct with a BDDL file path (``from_file``) or raw BDDL text + (``from_text``) - direct ``__init__`` is for advanced use when you + already have a :class:`BDDLProblem`. + + Example:: + + from strands_robots.benchmarks.libero import LiberoAdapter + + adapter = LiberoAdapter.from_file( + "libero/tasks/libero_spatial/pick_up_the_red_cube.bddl", + scene_path="libero/assets/scenes/libero_spatial_scene.xml", + ) + sim.register_benchmark("pick-red-cube", adapter) + sim.evaluate_benchmark("pick-red-cube", policy_provider="mock", + n_episodes=10, seed=42) + + Attributes: + max_steps: Default 300 (LIBERO convention). Override per-task by + passing ``max_steps=`` to the constructor or mutating the + attribute after construction. + problem: The parsed :class:`BDDLProblem`. Stored for introspection + (agents may read ``problem.language`` as the instruction). + """ + + max_steps: int = 300 + supported_robots_list: list[str] = ["panda"] + default_robot_name: str = "panda" + + def __init__( + self, + problem: BDDLProblem, + *, + scene_path: str | None = None, + max_steps: int | None = None, + init_jitter: float = 0.02, + ): + """Construct from a pre-parsed :class:`BDDLProblem`. + + Args: + problem: Parsed BDDL problem with a non-``None`` ``goal``. + scene_path: Optional MJCF to ``sim.load_scene()`` on each + episode start. 
``None`` → assume scene is pre-loaded. + max_steps: Override the class-level 300. + init_jitter: Per-episode ±jitter (metres) applied to xy of every + object referenced by ``(:init (on A B))`` clauses. Set to 0 + to disable jitter. + + Raises: + ValueError: If ``problem.goal`` is ``None``. + """ + if problem.goal is None: + raise ValueError(f"LiberoAdapter: BDDL problem {problem.name!r} has no (:goal ...) block") + self.problem = problem + self.scene_path = scene_path + self._init_jitter = float(init_jitter) + if self._init_jitter < 0: + raise ValueError(f"init_jitter must be >= 0, got {init_jitter}") + if max_steps is not None: + self.max_steps = int(max_steps) + self._success_fn: Callable[[SimEngine], bool] = compile_goal(problem.goal) + + # Construction helpers + + @classmethod + def from_file( + cls, + bddl_path: str | Path, + *, + scene_path: str | None = None, + max_steps: int | None = None, + init_jitter: float = 0.02, + ) -> LiberoAdapter: + """Parse a ``.bddl`` file from disk and build an adapter. + + Raises :class:`FileNotFoundError` / :class:`BDDLParseError` on bad + input - callers that want structured error dicts should catch and + convert. 
+ """ + problem = parse_bddl_file(bddl_path) + return cls( + problem, + scene_path=scene_path, + max_steps=max_steps, + init_jitter=init_jitter, + ) + + @classmethod + def from_text( + cls, + bddl_text: str, + *, + scene_path: str | None = None, + max_steps: int | None = None, + init_jitter: float = 0.02, + ) -> LiberoAdapter: + """Parse a BDDL string directly - useful in tests.""" + problem = parse_bddl(bddl_text) + return cls( + problem, + scene_path=scene_path, + max_steps=max_steps, + init_jitter=init_jitter, + ) + + # BenchmarkProtocol interface + + @property + def supported_robots(self) -> list[str]: + return list(self.supported_robots_list) + + @property + def default_robot(self) -> str: + return self.default_robot_name + + @property + def instruction(self) -> str: + """Language instruction from the BDDL ``:language`` clause, or ``""``.""" + return self.problem.language or "" + + def on_episode_start(self, sim: SimEngine, rng: random.Random) -> None: + """Load the declared scene (if any), validate Panda, then apply init jitter. + + Order matters: load_scene MUST happen before ``super().on_episode_start`` + so the base compatibility check sees the scene's Panda robot rather + than reporting "sim is empty → load default_robot". + """ + if self.scene_path: + load_scene = getattr(sim, "load_scene", None) + if load_scene is None: + logger.warning( + "LiberoAdapter: sim has no load_scene(); skipping scene_path=%r", + self.scene_path, + ) + else: + result = load_scene(self.scene_path) + if isinstance(result, dict) and result.get("status") == "error": + msg = (result.get("content") or [{}])[0].get("text", "") + raise RuntimeError(f"LiberoAdapter: load_scene({self.scene_path!r}) failed: {msg}") + super().on_episode_start(sim, rng) + if self._init_jitter > 0: + self._apply_init_jitter(sim, rng) + + def on_step( + self, + sim: SimEngine, + obs: dict[str, Any], + action: dict[str, Any], + ) -> StepInfo: + """Sparse step: zero reward, never ``done``. 
Success is detected by + :meth:`is_success` at the outer eval loop.""" + return StepInfo(reward=0.0, done=False) + + def is_success(self, sim: SimEngine) -> bool: + return bool(self._success_fn(sim)) + + # Internals + + def _apply_init_jitter(self, sim: SimEngine, rng: random.Random) -> None: + """Apply ±jitter to xy of every body referenced by ``(:init (on A B))``. + + Best-effort: if the sim doesn't expose ``move_object`` / ``get_body_state``, + or the body isn't in the scene, silently skip. This matches LIBERO's + "small random perturbation per episode" convention without requiring + full BDDL init semantics. + """ + move_object = getattr(sim, "move_object", None) + if move_object is None: + logger.debug("LiberoAdapter: sim has no move_object(); skipping init jitter") + return + get_body_state = getattr(sim, "get_body_state", None) + if get_body_state is None: + return + + # Gather the set of bodies we want to jitter - BDDL init uses the same + # Pred grammar, so (on cube_1 table_1) means "jitter cube_1". 
+ from strands_robots.benchmarks.libero.bddl_parser import Pred as _Pred + + seen: set[str] = set() + for node in self.problem.init: + for body in _extract_init_targets(node): + seen.add(body) + _ = _Pred # referenced for clarity; actual test is inside _extract_init_targets + + for body in sorted(seen): + try: + state = get_body_state(body_name=body) + except Exception as e: # noqa: BLE001 - defensive + logger.debug("jitter lookup for %r failed: %s", body, e) + continue + if not isinstance(state, dict) or state.get("status") != "success": + continue + pos = _extract_position(state) + if pos is None: + continue + jx = rng.uniform(-self._init_jitter, self._init_jitter) + jy = rng.uniform(-self._init_jitter, self._init_jitter) + new_pos = [pos[0] + jx, pos[1] + jy, pos[2]] + try: + move_object(name=body, position=new_pos) + except Exception as e: # noqa: BLE001 - jitter failures are not fatal + logger.debug("jitter apply for %r failed: %s", body, e) + + +def _extract_init_targets(node: Node) -> list[str]: + """Return the first-arg body name of every leaf predicate in ``node``. + + Init clauses like ``(on cube_1 table_1)`` and ``(upright bottle_1)`` + share the convention that the first argument is the "subject" body - + the thing whose position we may want to jitter. Nested + ``and``/``or``/``not`` are traversed; non-predicates are ignored. 
+ """ + from strands_robots.benchmarks.libero.bddl_parser import And, Not, Or, Pred + + if isinstance(node, Pred): + return [node.args[0]] if node.args else [] + if isinstance(node, (And, Or)): + out: list[str] = [] + for c in node.clauses: + out.extend(_extract_init_targets(c)) + return out + if isinstance(node, Not): + return _extract_init_targets(node.clause) + return [] + + +def _extract_position(state: dict[str, Any]) -> list[float] | None: + """Pull ``{"json": {"position": [...]}}`` from a status-dict payload.""" + for block in state.get("content", []) or []: + if isinstance(block, dict) and isinstance(block.get("json"), dict): + pos = block["json"].get("position") + if isinstance(pos, list) and len(pos) == 3 and all(isinstance(c, (int, float)) for c in pos): + return [float(c) for c in pos] + return None + + +__all__ = [ + "BDDLParseError", + "LiberoAdapter", +] diff --git a/strands_robots/benchmarks/libero/bddl_parser.py b/strands_robots/benchmarks/libero/bddl_parser.py new file mode 100644 index 0000000..2c9fd60 --- /dev/null +++ b/strands_robots/benchmarks/libero/bddl_parser.py @@ -0,0 +1,413 @@ +"""BDDL parser - LIBERO task files → named-predicate AST. + +LIBERO ships its tasks as ``.bddl`` files written in a PDDL-derived +s-expression syntax: + +.. code-block:: lisp + + (define (problem libero_pick_cube) + (:domain kitchen) + (:language "pick up the red cube and place it on the plate") + (:objects cube_1 plate_1 - object) + (:init + (on cube_1 table_1)) + (:goal + (and + (on cube_1 plate_1) + (not (grasped cube_1))))) + +This module parses that into a :class:`BDDLProblem` whose ``:goal`` compiles +to a single ``(SimEngine) -> bool`` callable via +:mod:`strands_robots.simulation.predicates`. Crucially it **never** evaluates +user code - the BDDL grammar is a closed set of tokens plus a whitelisted +predicate vocabulary; anything outside that set raises +:class:`BDDLParseError`. 
"""BDDL parser - LIBERO task files -> named-predicate AST.

LIBERO ships its tasks as ``.bddl`` files written in a PDDL-derived
s-expression syntax.  This module parses that text into a
:class:`BDDLProblem` whose ``:goal`` compiles to a single
``(SimEngine) -> bool`` callable via
:mod:`strands_robots.simulation.predicates`.  Crucially it **never**
evaluates user code - the BDDL grammar is a closed set of tokens plus a
whitelisted predicate vocabulary; anything outside that set raises
:class:`BDDLParseError`.

Unknown predicates are rejected rather than silently evaluated to ``False``:
a BDDL parse failure is always preferable to a misleading success rate.

Scope of this parser (matches what LIBERO actually uses):

* Top-level form: ``(define (problem <name>) ...)``
* Section markers: ``:domain``, ``:objects``, ``:init``, ``:goal``, ``:language``
* Boolean combinators: ``and``, ``or``, ``not``
* Predicate vocabulary: ``on``, ``near``, ``inside``, ``open``, ``closed``,
  ``grasped``, ``upright``

Typed-object annotations (``obj1 - object``) are flattened to just the
object symbols; other unsupported constructs raise
:class:`BDDLParseError` where leniency would mask real bugs.
"""

import logging
from collections.abc import Callable
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Union

if TYPE_CHECKING:
    from strands_robots.simulation.base import SimEngine

logger = logging.getLogger(__name__)


class BDDLParseError(ValueError):
    """Raised when a BDDL file fails to tokenize, parse, or compile."""


# AST nodes


@dataclass(frozen=True)
class Pred:
    """Leaf predicate: ``(on cube_1 plate_1)`` -> ``Pred("on", ("cube_1", "plate_1"))``."""

    name: str
    args: tuple[str, ...]


@dataclass(frozen=True)
class And:
    clauses: tuple["Node", ...]


@dataclass(frozen=True)
class Or:
    clauses: tuple["Node", ...]


@dataclass(frozen=True)
class Not:
    clause: "Node"


Node = Union[Pred, And, Or, Not]


@dataclass
class BDDLProblem:
    """Parsed representation of a BDDL file."""

    name: str
    domain: "str | None" = None
    language: "str | None" = None
    objects: list[str] = field(default_factory=list)
    init: list[Node] = field(default_factory=list)
    goal: "Node | None" = None


# Tokenizer + s-expression parser


def _tokenize(text: str) -> list[str]:
    """Split BDDL text into paren / atom / quoted-string tokens.

    Single pass: ``;`` starts a comment running to end-of-line (a ``;``
    inside a double-quoted string is consumed with the string, so it cannot
    start a comment), and quoted ``:language`` strings are kept intact,
    quotes included.  LIBERO BDDL has no escape sequences, so quoted
    strings are a plain scan to the closing ``"``.
    """
    tokens: list[str] = []
    i = 0
    n = len(text)
    while i < n:
        c = text[i]
        if c == ";":
            # Comment: skip to end of line.
            nl = text.find("\n", i)
            i = n if nl == -1 else nl + 1
        elif c.isspace():
            i += 1
        elif c in "()":
            tokens.append(c)
            i += 1
        elif c == '"':
            end = text.find('"', i + 1)
            if end == -1:
                raise BDDLParseError(f"unterminated quoted string at offset {i}")
            tokens.append(text[i : end + 1])
            i = end + 1
        else:
            # Atom - grab until whitespace or paren.
            j = i
            while j < n and not text[j].isspace() and text[j] not in "()":
                j += 1
            tokens.append(text[i:j])
            i = j
    return tokens


def _parse_sexp(tokens: list[str], pos: int = 0) -> "tuple[Any, int]":
    """Parse one s-expression starting at ``tokens[pos]``.

    Returns ``(tree, next_pos)``.  Atoms stay ``str``; lists nest.  Index
    based (not ``pop(0)``) so parsing is O(n) rather than O(n^2) on large
    files.  The caller checks for leftover tokens via ``next_pos``.
    """
    if pos >= len(tokens):
        raise BDDLParseError("unexpected end of input")
    token = tokens[pos]
    if token == "(":
        out: list[Any] = []
        pos += 1
        while pos < len(tokens) and tokens[pos] != ")":
            child, pos = _parse_sexp(tokens, pos)
            out.append(child)
        if pos >= len(tokens):
            raise BDDLParseError("missing closing ')'")
        return out, pos + 1  # consume ")"
    if token == ")":
        raise BDDLParseError("unexpected ')'")
    return token, pos + 1


# Predicate vocabulary

# BDDL predicate name -> (predicate-registry name, args -> kwargs adapter).
# The adapter maps LIBERO's positional BDDL args to our kwarg-style
# predicate library; it also validates arity (fail-fast at parse time).


def _check_arity(pred: str, args: list[str], expected: int) -> None:
    """Raise :class:`BDDLParseError` unless ``args`` has ``expected`` entries."""
    if len(args) != expected:
        plural = "s" if expected != 1 else ""
        raise BDDLParseError(f"({pred} ...) expects {expected} arg{plural}, got {len(args)}: {args}")


def _on_kwargs(args: list[str]) -> dict[str, Any]:
    _check_arity("on", args, 2)
    return {"body_a": args[0], "body_b": args[1]}


def _near_kwargs(args: list[str]) -> dict[str, Any]:
    _check_arity("near", args, 2)
    return {"body_a": args[0], "body_b": args[1], "threshold": 0.1}


def _inside_kwargs(args: list[str]) -> dict[str, Any]:
    _check_arity("inside", args, 2)
    return {"body": args[0], "container": args[1]}


def _open_kwargs(args: list[str]) -> dict[str, Any]:
    _check_arity("open", args, 1)
    return {"joint": args[0], "value": 0.1}


def _closed_kwargs(args: list[str]) -> dict[str, Any]:
    _check_arity("closed", args, 1)
    return {"joint": args[0], "value": 0.01}


def _grasped_kwargs(args: list[str]) -> dict[str, Any]:
    _check_arity("grasped", args, 1)
    # LIBERO scenes use robot0_gripper_* for Panda gripper geoms.  Adapters
    # that need a different prefix can register a replacement entry in
    # PREDICATE_VOCABULARY (the documented extension point below).
    return {"body": args[0], "gripper_prefix": "robot0_gripper"}


def _upright_kwargs(args: list[str]) -> dict[str, Any]:
    _check_arity("upright", args, 1)
    return {"body": args[0], "tol": 0.15}


#: Whitelist of BDDL predicates -> predicate-registry entries.  Built once
#: at import; mutating this dict in-place is the extension point for
#: adapters that need benchmark-specific predicates.  Keep in sync with
#: the docstring at the top of this module.
PREDICATE_VOCABULARY: "dict[str, tuple[str, Callable[[list[str]], dict[str, Any]]]]" = {
    "on": ("body_on", _on_kwargs),
    "near": ("distance_less_than", _near_kwargs),
    "inside": ("body_inside", _inside_kwargs),
    "open": ("joint_above", _open_kwargs),
    "closed": ("joint_below", _closed_kwargs),
    "grasped": ("grasped", _grasped_kwargs),
    "upright": ("body_upright", _upright_kwargs),
}


# Top-level parser


def parse_bddl(text: str) -> BDDLProblem:
    """Parse a BDDL string into a :class:`BDDLProblem`.

    Args:
        text: Contents of a ``.bddl`` file.

    Raises:
        BDDLParseError: On tokenizer, parser, or vocabulary failures.  The
            message always names the offending construct.
    """
    tokens = _tokenize(text)
    if not tokens:
        raise BDDLParseError("empty BDDL input")
    sexp, end = _parse_sexp(tokens)
    if end != len(tokens):
        raise BDDLParseError(f"trailing tokens after top-level form: {tokens[end : end + 5]!r}")

    if not isinstance(sexp, list) or not sexp or sexp[0] != "define":
        raise BDDLParseError("expected top-level (define ...) form")

    problem_name = ""
    domain: "str | None" = None
    language: "str | None" = None
    objects: list[str] = []
    init_nodes: list[Node] = []
    goal_node: "Node | None" = None

    for child in sexp[1:]:
        if not isinstance(child, list) or not child:
            continue
        head = child[0]
        if head == "problem":
            if len(child) >= 2 and isinstance(child[1], str):
                problem_name = child[1]
        elif head == ":domain":
            if len(child) >= 2 and isinstance(child[1], str):
                domain = child[1]
        elif head == ":language":
            # Language strings are quoted ("pick up the cube"). Strip quotes.
            pieces: list[str] = []
            for c in child[1:]:
                if isinstance(c, str):
                    if c.startswith('"') and c.endswith('"'):
                        pieces.append(c[1:-1])
                    else:
                        pieces.append(c)
            language = " ".join(pieces) if pieces else None
        elif head == ":objects":
            # PDDL typed syntax: ``cube_1 plate_1 - object``.  The ``-``
            # AND the type name that follows it are annotations, not
            # objects - skip both (keeping the type name was a bug: the
            # type symbol leaked into the objects list).
            symbols = [c for c in child[1:] if isinstance(c, str)]
            k = 0
            while k < len(symbols):
                if symbols[k] == "-":
                    k += 2  # skip "-" and its type token
                else:
                    objects.append(symbols[k])
                    k += 1
        elif head == ":init":
            for c in child[1:]:
                if not isinstance(c, list):
                    continue
                # Init entries use the same predicate grammar - compile each.
                try:
                    init_nodes.append(_compile_ast(c))
                except BDDLParseError as e:
                    # Init failures are not fatal - they're just "declared
                    # initial state", which the adapter may or may not
                    # enforce.  Log and skip; the goal is the authoritative
                    # success criterion.
                    logger.debug("skipping unsupported (:init ...) clause: %s", e)
        elif head == ":goal":
            if len(child) < 2:
                raise BDDLParseError("(:goal ...) is empty")
            goal_node = _compile_ast(child[1])
        # Other markers (:requirements, :constants, etc.) are silently ignored.

    return BDDLProblem(
        name=problem_name or "unnamed",
        domain=domain,
        language=language,
        objects=objects,
        init=init_nodes,
        goal=goal_node,
    )


def _compile_ast(expr: Any) -> Node:
    """Compile a raw s-expression list into the typed :data:`Node` AST."""
    if not isinstance(expr, list) or not expr:
        raise BDDLParseError(f"expected predicate s-expression, got {expr!r}")
    head = expr[0]
    if not isinstance(head, str):
        raise BDDLParseError(f"expected symbol head, got {head!r}")
    if head == "and":
        if len(expr) == 1:
            raise BDDLParseError("(and ...) requires at least one clause")
        return And(tuple(_compile_ast(c) for c in expr[1:]))
    if head == "or":
        if len(expr) == 1:
            raise BDDLParseError("(or ...) requires at least one clause")
        return Or(tuple(_compile_ast(c) for c in expr[1:]))
    if head == "not":
        if len(expr) != 2:
            raise BDDLParseError(f"(not ...) expects 1 clause, got {len(expr) - 1}")
        return Not(_compile_ast(expr[1]))
    if head not in PREDICATE_VOCABULARY:
        valid = sorted(PREDICATE_VOCABULARY)
        raise BDDLParseError(f"unknown predicate {head!r}. Supported: {valid}")
    # Leaf predicate - args are the remainder, must all be strings.
    args: list[str] = []
    for a in expr[1:]:
        if not isinstance(a, str):
            raise BDDLParseError(f"predicate {head!r}: expected string args, got {a!r}")
        args.append(a)
    # Validate arity by attempting the kwargs conversion now (fail-fast).
    _, adapter = PREDICATE_VOCABULARY[head]
    adapter(args)  # raises BDDLParseError on bad arity
    return Pred(name=head, args=tuple(args))


# Compile AST -> callable


def compile_goal(node: Node) -> "Callable[[SimEngine], bool]":
    """Compile a :data:`Node` AST into a single ``(sim) -> bool`` callable.

    The compiled callable is a pure function of ``sim`` state: no hidden RNG,
    no per-call allocation past what the leaf predicate closures capture.
    Boolean combinators are evaluated with short-circuit semantics.
    """
    # Imported lazily so the parser stays importable without the predicate
    # library (e.g. for offline BDDL linting or parser-only tests).
    from strands_robots.simulation.predicates import make_predicate

    if isinstance(node, Pred):
        registry_name, adapter = PREDICATE_VOCABULARY[node.name]
        kwargs = adapter(list(node.args))
        return make_predicate(registry_name, **kwargs)
    if isinstance(node, And):
        compiled = [compile_goal(c) for c in node.clauses]
        return lambda sim: all(p(sim) for p in compiled)
    if isinstance(node, Or):
        compiled = [compile_goal(c) for c in node.clauses]
        return lambda sim: any(p(sim) for p in compiled)
    if isinstance(node, Not):
        inner = compile_goal(node.clause)
        return lambda sim: not inner(sim)
    raise BDDLParseError(f"unsupported AST node: {type(node).__name__}")


def parse_bddl_file(path: "str | Path") -> BDDLProblem:
    """Convenience loader - reads ``path`` (UTF-8) and runs :func:`parse_bddl`."""
    p = Path(path)
    if not p.exists():
        raise FileNotFoundError(f"BDDL file not found: {path}")
    if not p.is_file():
        raise ValueError(f"BDDL path is not a file: {path}")
    # Explicit encoding - BDDL files are ASCII/UTF-8; don't depend on locale.
    return parse_bddl(p.read_text(encoding="utf-8"))


__all__ = [
    "And",
    "BDDLParseError",
    "BDDLProblem",
    "Node",
    "Not",
    "Or",
    "PREDICATE_VOCABULARY",
    "Pred",
    "compile_goal",
    "parse_bddl",
    "parse_bddl_file",
]
+Tasks that fail to parse are logged and skipped - a single malformed BDDL +file should never block the whole suite from loading. + +Layout discovery +---------------- + +The ``libero`` pip package historically keeps BDDL files under +``/libero/bddl_files//*.bddl`` and scene MJCFs +alongside the benchmark code. Because the exact subpath has drifted +between releases, :func:`load_libero_suite` accepts an explicit +``bddl_dir=`` override and falls back to probing a handful of standard +locations when not given. The scene resolver behaves similarly via +``scene_dir=``. + +Callers who already have the BDDL files on disk (e.g. vendored into their +repo) do **not** need the ``libero`` package installed - just pass +``bddl_dir=`` and the function registers from there. +""" + +from __future__ import annotations + +import logging +from collections.abc import Iterable +from pathlib import Path +from typing import TYPE_CHECKING + +from strands_robots.benchmarks.libero.adapter import LiberoAdapter +from strands_robots.benchmarks.libero.bddl_parser import BDDLParseError +from strands_robots.simulation.benchmark import register_benchmark +from strands_robots.utils import require_optional + +if TYPE_CHECKING: + pass + +logger = logging.getLogger(__name__) + + +# Canonical suite names - these are the keys agents will use. The upstream +# directory names sometimes differ (snake_case vs kebab-case); the name +# resolver accepts both. +SUITE_NAMES = frozenset( + { + "libero_spatial", + "libero_object", + "libero_goal", + "libero_10", + "libero_90", + } +) + + +def _normalise_suite_name(name: str) -> str: + """Accept either ``libero-spatial`` or ``libero_spatial``; return the underscore form. + + Upstream LIBERO uses snake_case directory names; our benchmark registry + uses kebab-case keys. Keep the normalisation centralised. 
+ """ + key = name.strip().lower().replace("-", "_") + if not key.startswith("libero_"): + key = f"libero_{key}" + return key + + +def _candidate_bddl_dirs(libero_root: Path, suite: str) -> list[Path]: + """Return paths to try in order. First existing one wins.""" + return [ + libero_root / "libero" / "bddl_files" / suite, + libero_root / "libero" / "libero" / "bddl_files" / suite, + libero_root / "bddl_files" / suite, + libero_root / "libero" / "tasks" / suite, + libero_root / suite, + ] + + +def _resolve_libero_root() -> Path: + """Find the filesystem root of the installed ``libero`` package. + + Lazily imports ``libero`` via :func:`require_optional` with a helpful + install hint pointing at ``strands-robots[benchmark-libero]``. + """ + libero = require_optional( + "libero", + pip_install="libero", + extra="benchmark-libero", + purpose="LIBERO benchmark suite discovery", + ) + # __file__ lives inside the package; its parent is the package root. + libero_file = getattr(libero, "__file__", None) + if not libero_file: + raise RuntimeError("libero package has no __file__ attribute; cannot locate BDDL tasks") + return Path(libero_file).resolve().parent.parent + + +def load_libero_suite( + suite_name: str, + *, + bddl_dir: str | Path | None = None, + scene_dir: str | Path | None = None, + max_steps: int | None = None, + init_jitter: float = 0.02, + key_prefix: str = "libero", +) -> dict[str, LiberoAdapter]: + """Register every task in ``suite_name`` under the benchmark registry. + + Args: + suite_name: One of ``libero_spatial`` / ``libero_object`` / + ``libero_goal`` / ``libero_10`` / ``libero_90``. Accepts + ``libero-spatial`` form too. + bddl_dir: Explicit directory containing ``*.bddl`` files. When + omitted, tries the installed ``libero`` package layout. + scene_dir: Root under which per-task scene MJCFs live. 
When + provided, each adapter gets ``scene_path = scene_dir / + .xml`` if the file exists; otherwise scene is left as + ``None`` and the adapter assumes the scene is already loaded. + max_steps: Forwarded to every :class:`LiberoAdapter`. + init_jitter: Forwarded to every :class:`LiberoAdapter`. + key_prefix: Registry key format is ``--``. + Pass ``key_prefix=""`` for ``-``. + + Returns: + ``{registry_name: LiberoAdapter}`` for every successfully registered + task. Failed tasks are logged (at WARNING) and omitted. + + Raises: + FileNotFoundError: If no BDDL directory can be located. + ValueError: If ``suite_name`` isn't a recognised LIBERO suite. + """ + suite = _normalise_suite_name(suite_name) + if suite not in SUITE_NAMES: + raise ValueError(f"Unknown LIBERO suite {suite_name!r}. Valid: {sorted(SUITE_NAMES)}") + + resolved_bddl_dir = _locate_bddl_dir(suite, bddl_dir) + resolved_scene_dir = Path(scene_dir).expanduser().resolve() if scene_dir else None + + registered: dict[str, LiberoAdapter] = {} + failures: list[tuple[str, str]] = [] + + for bddl_file in sorted(resolved_bddl_dir.glob("*.bddl")): + task_stem = bddl_file.stem + registry_name = _format_registry_name(key_prefix, suite, task_stem) + + scene_path: str | None = None + if resolved_scene_dir is not None: + candidate = resolved_scene_dir / f"{task_stem}.xml" + if candidate.exists(): + scene_path = str(candidate) + + try: + adapter = LiberoAdapter.from_file( + bddl_file, + scene_path=scene_path, + max_steps=max_steps, + init_jitter=init_jitter, + ) + except (BDDLParseError, FileNotFoundError, ValueError) as e: + logger.warning("Skipping LIBERO task %s: %s", bddl_file.name, e) + failures.append((str(bddl_file), str(e))) + continue + register_benchmark(registry_name, adapter) + registered[registry_name] = adapter + + logger.info( + "📚 Registered %d LIBERO tasks from %s (skipped %d malformed)", + len(registered), + resolved_bddl_dir, + len(failures), + ) + return registered + + +def 
_format_registry_name(prefix: str, suite: str, task: str) -> str:
+    # Suite is in ``libero_spatial`` form; keys use kebab-case.
+    # * With prefix: ``<prefix>-<suite_kebab>-<task>``
+    #   (``libero-spatial-pick_cube`` when prefix="libero").
+    # * Without prefix: ``<suite_kebab>-<task>``
+    #   (``spatial-pick_cube``) - agents who supply their own key scheme
+    #   don't want the ``libero-`` doubled in.
+    suite_kebab = suite.replace("_", "-").removeprefix("libero-")
+    if prefix:
+        return f"{prefix}-{suite_kebab}-{task}"
+    return f"{suite_kebab}-{task}"
+
+
+def _locate_bddl_dir(suite: str, override: str | Path | None) -> Path:
+    if override is not None:
+        d = Path(override).expanduser().resolve()
+        if not d.is_dir():
+            raise FileNotFoundError(f"bddl_dir does not exist or is not a directory: {d}")
+        return d
+
+    libero_root = _resolve_libero_root()
+    for candidate in _candidate_bddl_dirs(libero_root, suite):
+        if candidate.is_dir():
+            return candidate
+    tried = [str(p) for p in _candidate_bddl_dirs(libero_root, suite)]
+    raise FileNotFoundError(
+        f"Could not locate BDDL directory for suite {suite!r}. Tried: {tried}. Pass bddl_dir= to override."
+ ) + + +def available_suites() -> Iterable[str]: + """Return the canonical set of LIBERO suite names - offline constant.""" + return frozenset(SUITE_NAMES) + + +__all__ = [ + "SUITE_NAMES", + "available_suites", + "load_libero_suite", +] diff --git a/strands_robots/simulation/__init__.py b/strands_robots/simulation/__init__.py index c272008..e832a51 100644 --- a/strands_robots/simulation/__init__.py +++ b/strands_robots/simulation/__init__.py @@ -51,6 +51,19 @@ # Light imports (no heavy deps - stdlib + dataclasses only) from strands_robots.simulation.base import SimEngine +from strands_robots.simulation.benchmark import ( + BenchmarkCompatibilityError, + BenchmarkProtocol, + StepInfo, + get_benchmark, + list_benchmarks, + register_benchmark, + unregister_benchmark, +) +from strands_robots.simulation.benchmark_spec import ( + DeclarativeBenchmark, + register_benchmark_from_file, +) from strands_robots.simulation.factory import ( create_simulation, list_backends, @@ -71,6 +84,11 @@ SimWorld, TrajectoryStep, ) +from strands_robots.simulation.predicates import ( + PREDICATE_REGISTRY, + make_predicate, + register_predicate, +) # Heavy imports (lazy - need strands SDK + mujoco) _LAZY_IMPORTS: dict[str, tuple[str, str]] = { @@ -108,6 +126,20 @@ "resolve_urdf", "list_registered_urdfs", "list_available_models", + # Benchmark protocol + registry + "BenchmarkProtocol", + "BenchmarkCompatibilityError", + "StepInfo", + "register_benchmark", + "unregister_benchmark", + "get_benchmark", + "list_benchmarks", + # Declarative DSL + predicates + "DeclarativeBenchmark", + "register_benchmark_from_file", + "PREDICATE_REGISTRY", + "make_predicate", + "register_predicate", ] diff --git a/strands_robots/simulation/base.py b/strands_robots/simulation/base.py index 54f0071..bc6cb90 100644 --- a/strands_robots/simulation/base.py +++ b/strands_robots/simulation/base.py @@ -449,6 +449,187 @@ def eval_policy( success_fn=success_fn, ) + # Benchmark protocol facades + + def evaluate_benchmark( + 
self, + benchmark_name: str, + robot_name: str | None = None, + policy_provider: str = "mock", + policy_config: dict[str, Any] | None = None, + instruction: str = "", + n_episodes: int = 1, + seed: int | None = None, + ) -> dict[str, Any]: + """Run a registered :class:`BenchmarkProtocol` against the current sim. + + Benchmark-agnostic evaluation entry point. Looks up ``benchmark_name`` + in the global benchmark registry, validates robot compatibility, and + forwards to :meth:`PolicyRunner.evaluate` with the spec. + ``max_steps`` comes from the benchmark (not a parameter here). + + Args: + benchmark_name: Key from :func:`register_benchmark` / + :func:`register_benchmark_from_file`. + robot_name: Robot to evaluate. If ``None`` and the benchmark has + exactly one supported robot that matches a loaded robot, that + robot is picked; otherwise returns an error. + policy_provider: Policy provider name (forwarded to + :func:`create_policy`). + policy_config: Provider-specific kwargs. + instruction: Natural-language instruction for the policy. + n_episodes: Number of episodes. + seed: Master RNG seed for per-episode reproducibility. + + Returns: + Standard status dict. On success, carries per-episode cumulative + reward + aggregate success_rate / avg_reward / avg_steps in the + JSON payload. + """ + from strands_robots.policies import create_policy + from strands_robots.simulation.benchmark import get_benchmark + + spec = get_benchmark(benchmark_name) + if spec is None: + from strands_robots.simulation.benchmark import list_benchmarks as _list + + available = sorted(_list().keys()) + return { + "status": "error", + "content": [ + { + "text": ( + f"evaluate_benchmark: no benchmark registered under " + f"{benchmark_name!r}. Registered: {available}. " + "Call register_benchmark_from_file or register_benchmark first." + ) + } + ], + } + + robots = self.list_robots() + if not robots: + return {"status": "error", "content": [{"text": "No robots in sim. 
Add one first."}]} + + resolved_robot = robot_name + if not resolved_robot: + # Try to pick a robot. Prefer single-robot scenes; multi-robot + # scenes require explicit selection. + if len(robots) == 1: + resolved_robot = robots[0] + else: + return { + "status": "error", + "content": [ + { + "text": ( + f"evaluate_benchmark: 'robot_name' is required when the sim has " + f"multiple robots. Loaded: {robots}" + ) + } + ], + } + if resolved_robot not in robots: + return { + "status": "error", + "content": [{"text": f"Robot '{resolved_robot}' not found. Loaded: {robots}"}], + } + + policy = create_policy(policy_provider, **(policy_config or {})) + policy.set_robot_state_keys(self.robot_joint_names(resolved_robot)) + + return PolicyRunner(self).evaluate( + resolved_robot, + policy, + instruction=instruction, + n_episodes=n_episodes, + spec=spec, + seed=seed, + ) + + def list_benchmarks(self) -> dict[str, Any]: + """Enumerate registered benchmarks. + + Returns a standard status dict whose JSON payload contains the + :func:`~strands_robots.simulation.benchmark.list_benchmarks` + metadata snapshot. Safe to call from any backend; the registry is + engine-agnostic. + """ + from strands_robots.simulation.benchmark import list_benchmarks as _list + + snapshot = _list() + if not snapshot: + text = "No benchmarks registered. Use register_benchmark_from_file to add one." + else: + lines = [f"Registered benchmarks ({len(snapshot)}):"] + for name, meta in snapshot.items(): + lines.append( + f" • {name}: {meta['class']} " + f"(robots={meta['supported_robots'] or 'any'}, " + f"default={meta['default_robot']}, " + f"max_steps={meta['max_steps']})" + ) + text = "\n".join(lines) + return { + "status": "success", + "content": [{"text": text}, {"json": {"benchmarks": snapshot}}], + } + + def register_benchmark_from_file( + self, + benchmark_name: str, + spec_path: str, + ) -> dict[str, Any]: + """Load a declarative benchmark spec from disk and register it. 
+ + Wraps :func:`strands_robots.simulation.benchmark_spec.register_benchmark_from_file` + so agents can author benchmarks as YAML / JSON at runtime. Parsing + errors surface as structured error dicts rather than exceptions. + """ + from strands_robots.simulation.benchmark_spec import ( + register_benchmark_from_file as _register, + ) + + if not benchmark_name: + return { + "status": "error", + "content": [{"text": "register_benchmark_from_file: 'benchmark_name' must be non-empty."}], + } + if not spec_path: + return { + "status": "error", + "content": [{"text": "register_benchmark_from_file: 'spec_path' must be non-empty."}], + } + try: + benchmark = _register(benchmark_name, spec_path) + except FileNotFoundError as e: + return {"status": "error", "content": [{"text": f"register_benchmark_from_file: {e}"}]} + except ValueError as e: + return {"status": "error", "content": [{"text": f"register_benchmark_from_file: {e}"}]} + except ImportError as e: + # YAML support requires pyyaml; surface the install hint verbatim. + return {"status": "error", "content": [{"text": f"{e}"}]} + except Exception as e: # noqa: BLE001 - defensive catch-all with clear message + return { + "status": "error", + "content": [{"text": f"register_benchmark_from_file: unexpected error: {e}"}], + } + + return { + "status": "success", + "content": [ + { + "text": ( + f"📋 Registered benchmark '{benchmark_name}' from {spec_path}\n" + f" class: {type(benchmark).__name__}\n" + f" supported_robots: {benchmark.supported_robots or 'any'}\n" + f" default_robot: {benchmark.default_robot}\n" + f" max_steps: {benchmark.max_steps}" + ) + } + ], + } + def _make_run_policy_hook(self, robot_name: str, instruction: str) -> Any: """Override to return an ``on_frame(step, obs, action)`` callable. 
diff --git a/strands_robots/simulation/benchmark.py b/strands_robots/simulation/benchmark.py new file mode 100644 index 0000000..243eac7 --- /dev/null +++ b/strands_robots/simulation/benchmark.py @@ -0,0 +1,307 @@ +"""Benchmark-agnostic evaluation protocol for any ``SimEngine``. + +Every standard benchmark (LIBERO, Meta-World, RoboSuite, ManiSkill, user-authored +tasks) has a different notion of "what a task is" - sparse-success, dense-reward, +procedural scenes, BDDL predicates, hardcoded robots, etc. The correct abstraction +is the protocol the eval loop calls into, not a benchmark-specific schema. + +:class:`BenchmarkProtocol` is that protocol. Each adapter implements a handful of +lifecycle hooks (``on_episode_start``, ``on_step``, ``is_success``, ``is_failure``) +and declares the robots it is compatible with. The evaluation loop +(:meth:`~strands_robots.simulation.policy_runner.PolicyRunner.evaluate`) drives +the protocol without knowing anything about the underlying benchmark. + +Adapters live in optional extras (``strands-robots[benchmark-libero]`` etc.); +the core package stays dependency-free. A reference :class:`DeclarativeBenchmark` +shipped in :mod:`strands_robots.simulation.benchmark_spec` turns a YAML/JSON +spec into a fully functional ``BenchmarkProtocol`` instance - LLMs can author +and register benchmarks at runtime without writing Python code. + +Registry: a module-level ``dict[str, BenchmarkProtocol]`` keyed by name, +mirroring the shape of :func:`~strands_robots.simulation.model_registry.register_urdf`. +Registration is idempotent-by-overwrite: re-registering the same name replaces +the previous entry and logs a warning. This matches how users iterate on a +spec file during development. + +Thread safety: the registry is guarded by an internal lock so concurrent +registrations from agent threads do not race. 
The benchmark instances +themselves are expected to be immutable after registration - adapters that +keep per-episode state MUST put it on the ``rng``-scoped call, not on ``self``. +""" + +from __future__ import annotations + +import logging +import random +import threading +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from strands_robots.simulation.base import SimEngine + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class StepInfo: + """Per-step feedback from a :class:`BenchmarkProtocol`. + + Returned by :meth:`BenchmarkProtocol.on_step`. The evaluation loop + accumulates ``reward`` across steps and terminates the episode early + when ``done`` is ``True`` *or* when :meth:`BenchmarkProtocol.is_success` + / :meth:`BenchmarkProtocol.is_failure` fires. + + Attributes: + reward: Dense reward for this step. Sparse-success benchmarks + return ``0.0`` on every step that isn't a terminal success. + done: Early-termination flag - set when the benchmark knows the + episode is over (e.g. the object fell off the table and + nothing further will happen). + info: Free-form metadata propagated into the per-episode result + under the ``info`` key. Safe for small scalars / diagnostics - + do NOT stuff large tensors here. + """ + + reward: float = 0.0 + done: bool = False + info: dict[str, Any] = field(default_factory=dict) + + +class BenchmarkProtocol(ABC): + """Protocol every benchmark (LIBERO, Meta-World, custom) implements. + + Subclass this, declare :attr:`supported_robots` + :attr:`default_robot`, + and implement :meth:`on_step` + :meth:`is_success`. The default + :meth:`on_episode_start` validates robot compatibility and auto-loads + :attr:`default_robot` when the sim is empty, which is the right behaviour + for 90% of adapters; override only if you need per-episode scene setup + beyond what :meth:`SimEngine.reset` provides. 
+ + Robot compatibility is first-class metadata. LIBERO's BDDL and scene + files reference Panda body names, Meta-World hardcodes Sawyer, RoboSuite + parameterizes over a fixed robot list - without declaring which robots a + benchmark accepts, agents will silently evaluate with the wrong robot. + :meth:`~strands_robots.simulation.policy_runner.PolicyRunner.evaluate` + validates the sim's robot against :attr:`supported_robots` before episode + 1 and returns a structured error on mismatch. + + Attributes: + max_steps: Per-episode horizon. Instance attribute (not abstract + property) so subclasses can set it in ``__init__`` or as a class + attribute. Defaults to ``300``. + """ + + max_steps: int = 300 + + # Robot compatibility (first-class metadata) + + @property + @abstractmethod + def supported_robots(self) -> list[str]: + """Registry ``data_config`` names this benchmark accepts. + + Empty list means "any robot" (unusual; dense-reward benchmarks rarely + generalise across embodiments). LIBERO-shaped adapters should return + a closed list of Panda variants; Meta-World should return its Sawyer + variants; etc. + """ + + @property + @abstractmethod + def default_robot(self) -> str: + """Robot :meth:`on_episode_start` loads when the sim is empty. + + Must be an element of :attr:`supported_robots` (or any compatible + registry name when ``supported_robots`` is empty). Declared + separately from ``supported_robots[0]`` so multi-robot benchmarks + can be explicit about their canonical default. + """ + + # Lifecycle hooks + + def on_episode_start(self, sim: SimEngine, rng: random.Random) -> None: + """Per-episode init. Called after ``sim.reset()`` and before the first obs. + + Default implementation enforces robot compatibility: + + * If the sim has no robots, add :attr:`default_robot` via + ``sim.add_robot(name="robot", data_config=default_robot)``. + * Otherwise, validate that every loaded robot's ``data_config`` is + in :attr:`supported_robots` (when non-empty). 
Mismatches raise + :class:`BenchmarkCompatibilityError` - the eval loop catches that + and returns a structured error with the allowed list. + + Override to layer on per-episode randomization, goal sampling, or + procedural scene generation. Always call ``super().on_episode_start`` + first unless you deliberately want to skip compatibility checks. + + Args: + sim: The engine being driven. + rng: Seeded per-episode RNG. Always use this - don't create your + own ``random.Random()`` or seeding will be non-reproducible. + """ + robots = sim.list_robots() + if not robots: + sim.add_robot(name="robot", data_config=self.default_robot) + return + + # Validate all loaded robots against supported_robots + supported = self.supported_robots + if not supported: # empty list means "any robot" + return + + # data_config lookup is backend-specific; MuJoCo stores it on SimRobot. + # Reach into sim._world if available (cheap duck-typing); otherwise skip + # the check rather than false-positive error. Adapters needing stricter + # checks should override on_episode_start. + world = getattr(sim, "_world", None) + if world is None or not hasattr(world, "robots"): + return + for rname in robots: + robot_obj = world.robots.get(rname) + if robot_obj is None: + continue + data_config = getattr(robot_obj, "data_config", None) + if data_config is None or data_config in supported: + continue + raise BenchmarkCompatibilityError( + robot_name=rname, + data_config=data_config, + supported=supported, + ) + + @abstractmethod + def on_step(self, sim: SimEngine, obs: dict[str, Any], action: dict[str, Any]) -> StepInfo: + """Return dense reward + done flag + info dict for this step. + + Called after every ``sim.send_action(action)`` with the observation + that produced the action and the action itself. Sparse-success + benchmarks return ``StepInfo(reward=0.0, done=False)`` on every + non-terminal step. 
+ """ + + @abstractmethod + def is_success(self, sim: SimEngine) -> bool: + """Terminal success predicate. + + Called every step after :meth:`on_step`. Returning ``True`` ends + the episode with ``success=True``. Must be side-effect-free: the + evaluation loop may call this multiple times per step depending on + how backends batch success / failure checks. + """ + + def is_failure(self, sim: SimEngine) -> bool: + """Optional early-termination failure condition. + + Default: always ``False``. Override to end an episode early without + marking success (e.g. the arm self-collided, the object fell off + the table, the agent picked the wrong object). Failure ends the + episode; it does not count as a success. + """ + return False + + +class BenchmarkCompatibilityError(ValueError): + """Raised when a benchmark's robot compatibility check fails. + + Carries enough context (robot name, loaded data_config, supported list) + for the eval loop to produce an actionable structured error. Subclasses + :class:`ValueError` so code that uses broad ``except ValueError`` still + catches it cleanly. + """ + + def __init__(self, robot_name: str, data_config: str, supported: list[str]): + self.robot_name = robot_name + self.data_config = data_config + self.supported = list(supported) + super().__init__( + f"Robot '{robot_name}' (data_config={data_config!r}) is not compatible " + f"with this benchmark. Supported: {self.supported}" + ) + + +# Registry + +# Module-level registry - mirrors model_registry._URDF_REGISTRY. Mutable dict +# plus an RLock for thread safety; registry ops are cheap so we do not shard. +_BENCHMARK_REGISTRY: dict[str, BenchmarkProtocol] = {} +_REGISTRY_LOCK = threading.RLock() + + +def register_benchmark(name: str, benchmark: BenchmarkProtocol) -> None: + """Register a :class:`BenchmarkProtocol` under ``name``. + + Idempotent-by-overwrite: re-registering the same name replaces the + previous entry and logs a warning. 
This matches how users iterate on a + spec file during development. + + Args: + name: String key. Must be non-empty; any other validation is up to + the caller (lowercase / underscores / hyphens are all fine). + benchmark: An instantiated :class:`BenchmarkProtocol` subclass. + + Raises: + TypeError: If ``benchmark`` is not a :class:`BenchmarkProtocol`. + ValueError: If ``name`` is empty. + """ + if not name or not isinstance(name, str): + raise ValueError(f"register_benchmark: name must be a non-empty string, got {name!r}") + if not isinstance(benchmark, BenchmarkProtocol): + raise TypeError(f"register_benchmark: expected BenchmarkProtocol instance, got {type(benchmark).__name__}") + with _REGISTRY_LOCK: + if name in _BENCHMARK_REGISTRY: + logger.warning("Overwriting existing benchmark registration: %s", name) + _BENCHMARK_REGISTRY[name] = benchmark + logger.info("📋 Registered benchmark '%s' (%s)", name, type(benchmark).__name__) + + +def unregister_benchmark(name: str) -> BenchmarkProtocol | None: + """Remove a benchmark from the registry. + + Returns the removed benchmark or ``None`` if it was not registered. + Primarily used by tests for cleanup; user code is rarely expected to + unregister benchmarks at runtime. + """ + with _REGISTRY_LOCK: + return _BENCHMARK_REGISTRY.pop(name, None) + + +def get_benchmark(name: str) -> BenchmarkProtocol | None: + """Return the registered benchmark or ``None`` if not found.""" + with _REGISTRY_LOCK: + return _BENCHMARK_REGISTRY.get(name) + + +def list_benchmarks() -> dict[str, dict[str, Any]]: + """Enumerate registered benchmarks with their metadata. + + Returns a shallow-copy snapshot keyed by name. Each value is a dict + with ``class``, ``supported_robots``, ``default_robot``, ``max_steps`` + - enough for an LLM to pick an appropriate benchmark without + instantiating one. Reads a snapshot under the registry lock so a + concurrent registration does not corrupt the returned dict. 
+ """ + with _REGISTRY_LOCK: + snapshot = dict(_BENCHMARK_REGISTRY) + return { + name: { + "class": type(bench).__name__, + "supported_robots": list(bench.supported_robots), + "default_robot": bench.default_robot, + "max_steps": bench.max_steps, + } + for name, bench in snapshot.items() + } + + +__all__ = [ + "BenchmarkCompatibilityError", + "BenchmarkProtocol", + "StepInfo", + "get_benchmark", + "list_benchmarks", + "register_benchmark", + "unregister_benchmark", +] diff --git a/strands_robots/simulation/benchmark_spec.py b/strands_robots/simulation/benchmark_spec.py new file mode 100644 index 0000000..adbd3af --- /dev/null +++ b/strands_robots/simulation/benchmark_spec.py @@ -0,0 +1,386 @@ +"""Declarative benchmark specs loaded from YAML / JSON files. + +This module is the LLM-facing surface for authoring benchmarks without +writing Python. A spec file declares scene, success predicate, failure +predicate, and dense reward terms using the named-predicate DSL from +:mod:`strands_robots.simulation.predicates`. Nothing in a spec ever reaches +``eval`` / ``exec`` - predicates are looked up in a closed registry and +kwargs are forwarded as-is, so spec files are safe to load from untrusted +input. + +Spec schema (top-level keys):: + + name: string # required + max_steps: int # default 300 + supported_robots: list[str] # default [] (any) + default_robot: string # required - registry data_config + scene: string # optional MJCF/URDF path for sim.load_scene() + success: + all: [, ...] # all must be true + any: [, ...] # at least one must be true + failure: + all: [, ...] + any: [, ...] + dense_reward: [, ...] 
# summed per step + +A ```` is a dict with a ``predicate`` key naming the +predicate and any remaining keys forwarded as kwargs:: + + {predicate: body_above_z, body: cube, z: 0.2} + +Example:: + + name: drawer-open + max_steps: 300 + supported_robots: [panda] + default_robot: panda + success: + all: + - {predicate: joint_above, joint: drawer_slide, value: 0.15} + failure: + any: + - {predicate: body_below_z, body: gripper, z: -0.1} + dense_reward: + - {predicate: distance_neg, body_a: gripper, body_b: drawer_handle, weight: 1.0} + - {predicate: joint_progress, joint: drawer_slide, target: 0.2, weight: 5.0} + +Load + register via :func:`register_benchmark_from_file`; agents call this +through the ``register_benchmark_from_file`` tool action. + +YAML files require ``pyyaml`` - not a core dep. JSON works out of the box. +The loader autodetects format by extension. +""" + +from __future__ import annotations + +import json +import logging +from collections.abc import Callable +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from strands_robots.simulation.benchmark import ( + BenchmarkProtocol, + StepInfo, + register_benchmark, +) +from strands_robots.simulation.predicates import make_predicate +from strands_robots.utils import require_optional + +if TYPE_CHECKING: + import random + + from strands_robots.simulation.base import SimEngine + +logger = logging.getLogger(__name__) + +# Canonical top-level keys allowed in a spec. Anything else is a user error +# and produces a clear message rather than silently being ignored. +_ALLOWED_TOP_LEVEL = frozenset( + { + "name", + "max_steps", + "supported_robots", + "default_robot", + "scene", + "success", + "failure", + "dense_reward", + } +) + + +def _compile_bool_group( + clause: dict[str, Any] | None, + *, + default: bool, + context: str, +) -> Callable[[SimEngine], bool]: + """Compile an ``{"all": [...], "any": [...]}`` bool group into a single callable. 
+ + * ``None`` / missing → returns a function always returning ``default``. + * ``all``: every listed predicate must be true. + * ``any``: at least one predicate must be true. + * Both: both conditions must hold (all AND any). + + Args: + clause: The ``success`` / ``failure`` dict from the spec. + default: Value returned when the clause is absent (``False`` for + success → "never succeeds", ``False`` for failure → "never + fails"; both are reasonable). + context: Name for error messages (``"success"`` or ``"failure"``). + + Raises: + ValueError: If the clause shape is wrong. + """ + if clause is None: + return lambda _sim: default + if not isinstance(clause, dict): + raise ValueError(f"{context}: expected a dict with 'all' / 'any' keys, got {type(clause).__name__}") + + unknown = set(clause.keys()) - {"all", "any"} + if unknown: + raise ValueError(f"{context}: unknown keys {sorted(unknown)}; allowed: ['all', 'any']") + + all_calls = [_compile_call(c, context=f"{context}.all") for c in (clause.get("all") or [])] + any_calls = [_compile_call(c, context=f"{context}.any") for c in (clause.get("any") or [])] + + if not all_calls and not any_calls: + return lambda _sim: default + + def check(sim: SimEngine) -> bool: + if all_calls and not all(bool(p(sim)) for p in all_calls): + return False + if any_calls and not any(bool(p(sim)) for p in any_calls): + return False + return True + + return check + + +def _compile_call(entry: Any, *, context: str) -> Callable[[SimEngine], Any]: + """Compile one ``{predicate: , **kwargs}`` entry to a callable.""" + if not isinstance(entry, dict): + raise ValueError(f"{context}: expected a dict like {{predicate: , ...}}, got {type(entry).__name__}") + pred_name = entry.get("predicate") + if not isinstance(pred_name, str) or not pred_name: + raise ValueError(f"{context}: each entry must have a non-empty 'predicate' string") + kwargs = {k: v for k, v in entry.items() if k != "predicate"} + try: + return make_predicate(pred_name, **kwargs) + 
except ValueError: + # Unknown predicate; surface verbatim (already carries the valid list). + raise + except TypeError as e: + # Bad kwargs; wrap so the caller knows which predicate failed to compile. + raise ValueError(f"{context}: predicate '{pred_name}' compilation failed: {e}") from e + + +def _compile_reward_terms(terms: list[Any] | None) -> list[Callable[[SimEngine], float]]: + if terms is None: + return [] + if not isinstance(terms, list): + raise ValueError(f"dense_reward: expected a list, got {type(terms).__name__}") + compiled: list[Callable[[SimEngine], float]] = [] + for i, t in enumerate(terms): + term = _compile_call(t, context=f"dense_reward[{i}]") + compiled.append(term) + return compiled + + +class DeclarativeBenchmark(BenchmarkProtocol): + """:class:`BenchmarkProtocol` backed by a compiled DSL spec. + + Use :func:`register_benchmark_from_file` or + :meth:`DeclarativeBenchmark.from_dict` to construct one - direct + instantiation is only for tests / internal use. + + Thread safety: the compiled predicate closures capture only the spec + kwargs (ints, floats, strings, lists of floats) so instances are safe + to share across threads. The evaluation loop still drives each episode + sequentially; we do not batch episodes. 
+ """ + + def __init__( + self, + *, + name: str, + supported_robots: list[str], + default_robot: str, + max_steps: int, + success_fn: Callable[[SimEngine], bool], + failure_fn: Callable[[SimEngine], bool], + reward_terms: list[Callable[[SimEngine], float]], + scene: str | None = None, + ): + self._name = name + self._supported_robots = list(supported_robots) + self._default_robot = default_robot + self.max_steps = int(max_steps) + self._success_fn = success_fn + self._failure_fn = failure_fn + self._reward_terms = list(reward_terms) + self._scene = scene + + @property + def name(self) -> str: + return self._name + + @property + def supported_robots(self) -> list[str]: + return list(self._supported_robots) + + @property + def default_robot(self) -> str: + return self._default_robot + + def on_episode_start(self, sim: SimEngine, rng: random.Random) -> None: + """Load the declared scene (if any) before delegating to the base impl. + + The base impl adds :attr:`default_robot` when the sim is empty and + validates robot compatibility. Scene loading happens *before* that so + a scene-declared robot is detected by the compatibility check. 
+ """ + if self._scene: + load_scene = getattr(sim, "load_scene", None) + if load_scene is None: + logger.warning( + "DeclarativeBenchmark '%s' declares scene=%r but sim has no load_scene()", + self._name, + self._scene, + ) + else: + result = load_scene(self._scene) + if isinstance(result, dict) and result.get("status") == "error": + msg = (result.get("content") or [{}])[0].get("text", "") + raise RuntimeError( + f"DeclarativeBenchmark '{self._name}': load_scene({self._scene!r}) failed: {msg}" + ) + super().on_episode_start(sim, rng) + + def on_step(self, sim: SimEngine, obs: dict[str, Any], action: dict[str, Any]) -> StepInfo: + """Sum all registered reward terms; ``done`` is False (handled by is_success/is_failure).""" + reward = 0.0 + for term in self._reward_terms: + try: + reward += float(term(sim)) + except Exception as e: # noqa: BLE001 - defensive: one bad term shouldn't kill the episode + logger.warning("reward term failed in '%s': %s", self._name, e) + return StepInfo(reward=reward, done=False) + + def is_success(self, sim: SimEngine) -> bool: + return bool(self._success_fn(sim)) + + def is_failure(self, sim: SimEngine) -> bool: + return bool(self._failure_fn(sim)) + + @classmethod + def from_dict(cls, spec: dict[str, Any]) -> DeclarativeBenchmark: + """Compile a spec dict (already parsed from YAML/JSON) into a benchmark.""" + if not isinstance(spec, dict): + raise ValueError(f"spec must be a dict, got {type(spec).__name__}") + + unknown = set(spec.keys()) - _ALLOWED_TOP_LEVEL + if unknown: + raise ValueError( + f"Unknown top-level keys in spec: {sorted(unknown)}. 
Allowed: {sorted(_ALLOWED_TOP_LEVEL)}" + ) + + name = spec.get("name") + if not isinstance(name, str) or not name: + raise ValueError("spec.name: required non-empty string") + + default_robot = spec.get("default_robot") + if not isinstance(default_robot, str) or not default_robot: + raise ValueError("spec.default_robot: required non-empty string") + + supported_robots = spec.get("supported_robots", []) + if not isinstance(supported_robots, list) or not all(isinstance(r, str) for r in supported_robots): + raise ValueError("spec.supported_robots: must be a list of strings") + + # default_robot should be in supported_robots (unless list is empty = any) + if supported_robots and default_robot not in supported_robots: + raise ValueError( + f"spec.default_robot={default_robot!r} not in supported_robots={supported_robots}; " + "either add it to supported_robots or leave supported_robots empty for any-robot benchmarks" + ) + + max_steps_raw = spec.get("max_steps", 300) + if not isinstance(max_steps_raw, int) or isinstance(max_steps_raw, bool) or max_steps_raw <= 0: + raise ValueError(f"spec.max_steps: must be a positive int, got {max_steps_raw!r}") + + scene = spec.get("scene") + if scene is not None and not isinstance(scene, str): + raise ValueError(f"spec.scene: must be a string path or omitted, got {type(scene).__name__}") + + success_fn = _compile_bool_group(spec.get("success"), default=False, context="success") + failure_fn = _compile_bool_group(spec.get("failure"), default=False, context="failure") + reward_terms = _compile_reward_terms(spec.get("dense_reward")) + + return cls( + name=name, + supported_robots=supported_robots, + default_robot=default_robot, + max_steps=max_steps_raw, + success_fn=success_fn, + failure_fn=failure_fn, + reward_terms=reward_terms, + scene=scene, + ) + + +def _load_spec_file(path: str | Path) -> dict[str, Any]: + """Parse a spec file by extension. JSON via stdlib, YAML via ``pyyaml`` (optional). 
+ + Return type is declared as ``dict[str, Any]`` but ``json.loads`` / + ``yaml.safe_load`` may produce lists, strings, etc. Caller + (``register_benchmark_from_file``) validates the parsed shape before + passing it to :meth:`DeclarativeBenchmark.from_dict`; we do the + ``isinstance`` check here so the returned value is actually a dict. + """ + p = Path(path) + if not p.exists(): + raise FileNotFoundError(f"Benchmark spec file not found: {path}") + if not p.is_file(): + raise ValueError(f"Benchmark spec path is not a file: {path}") + + suffix = p.suffix.lower() + text = p.read_text() + + parsed: Any + if suffix == ".json": + parsed = json.loads(text) + elif suffix in (".yaml", ".yml"): + yaml = require_optional( + "yaml", + pip_install="pyyaml", + purpose="YAML benchmark spec loading", + ) + parsed = yaml.safe_load(text) # type: ignore[attr-defined] + else: + raise ValueError(f"Unsupported spec file extension: {suffix!r}. Use .json, .yaml, or .yml.") + + if not isinstance(parsed, dict): + raise ValueError(f"Benchmark spec {path} must parse to a dict, got {type(parsed).__name__}") + return parsed + + +def register_benchmark_from_file( + name: str, + spec_path: str | Path, +) -> BenchmarkProtocol: + """Load a declarative benchmark spec from disk and register it under ``name``. + + Convenience wrapper that: + + 1. Parses ``spec_path`` (JSON or YAML, autodetected by extension). + 2. Compiles it into a :class:`DeclarativeBenchmark`. + 3. Registers it via :func:`register_benchmark`. + 4. Returns the instantiated benchmark for programmatic use. + + Args: + name: Registry key. Overrides any ``name`` declared inside the spec + (so the same spec file can be registered under multiple names). + spec_path: Path to a ``.json`` / ``.yaml`` / ``.yml`` file. + + Returns: + The registered :class:`DeclarativeBenchmark` instance. + + Raises: + FileNotFoundError / ValueError: From :func:`_load_spec_file`. + ValueError: From :meth:`DeclarativeBenchmark.from_dict` on bad schema. 
+ """ + if not isinstance(name, str) or not name: + raise ValueError(f"register_benchmark_from_file: name must be a non-empty string, got {name!r}") + spec_dict = _load_spec_file(spec_path) + # The registry key is always ``name``; a spec-declared "name" (if any) is kept as the benchmark's own display name. + spec_dict.setdefault("name", name) + benchmark = DeclarativeBenchmark.from_dict(spec_dict) + register_benchmark(name, benchmark) + return benchmark + + +__all__ = [ + "DeclarativeBenchmark", + "register_benchmark_from_file", +] diff --git a/strands_robots/simulation/mujoco/tool_spec.json b/strands_robots/simulation/mujoco/tool_spec.json index 099ce9a..cb07311 100644 --- a/strands_robots/simulation/mujoco/tool_spec.json +++ b/strands_robots/simulation/mujoco/tool_spec.json @@ -65,7 +65,10 @@ "render_all", "start_cameras_recording", "stop_cameras_recording", - "get_cameras_recording_status" + "get_cameras_recording_status", + "list_benchmarks", + "register_benchmark_from_file", + "evaluate_benchmark" ] }, "scene_path": { @@ -364,6 +367,14 @@ "items": { "type": "object" } + }, + "benchmark_name": { + "type": "string", + "description": "Name of a registered BenchmarkProtocol (for register_benchmark_from_file / evaluate_benchmark)" + }, + "spec_path": { + "type": "string", + "description": "Path to a declarative benchmark spec file (.json, .yaml, .yml). Used by register_benchmark_from_file." 
} }, "required": [ diff --git a/strands_robots/simulation/policy_runner.py b/strands_robots/simulation/policy_runner.py index 4968a7e..0452d53 100644 --- a/strands_robots/simulation/policy_runner.py +++ b/strands_robots/simulation/policy_runner.py @@ -33,6 +33,7 @@ import logging import os +import random import time from collections.abc import Callable from dataclasses import dataclass @@ -46,6 +47,7 @@ if TYPE_CHECKING: from strands_robots.policies.base import Policy from strands_robots.simulation.base import SimEngine + from strands_robots.simulation.benchmark import BenchmarkProtocol from strands_robots.simulation.models import TrajectoryStep @@ -512,26 +514,67 @@ def evaluate( n_episodes: int = 10, max_steps: int = 300, success_fn: SuccessFn | str | None = None, + spec: BenchmarkProtocol | None = None, + seed: int | None = None, ) -> dict[str, Any]: """Evaluate ``policy`` for ``n_episodes`` episodes. + Two evaluation paths: + + * **``spec=``** (preferred): drive a full :class:`BenchmarkProtocol`. + Per-episode seeded RNG, ``on_episode_start`` / ``on_step`` / + ``is_success`` / ``is_failure`` hooks, cumulative dense reward, + robot-compatibility validation. ``max_steps`` from the spec wins. + * **``success_fn=``**: legacy sparse-success path kept for + backwards compatibility with PR #85. Equivalent to a + ``BenchmarkProtocol`` whose ``on_step`` always returns + ``StepInfo(reward=0.0, done=False)``. + + Passing both ``spec`` and ``success_fn`` is an error - benchmarks + define their own success predicate. + Args: robot_name: Robot to evaluate. policy: Already-constructed ``Policy`` instance. instruction: Instruction forwarded to the policy. n_episodes: Number of reset → rollout episodes. - max_steps: Cap per episode. - success_fn: Either - - * ``None`` - never succeeds (dry run / performance probe). - * ``"contact"`` - success when ``sim.get_contacts()`` reports - any penetrating contact. 
Requires backend to implement - ``get_contacts``; falls back to ``False`` otherwise. - * callable ``(observation) -> bool``. + max_steps: Cap per episode. Ignored when ``spec`` is provided + (``spec.max_steps`` wins). + success_fn: Legacy success predicate (see above). + spec: :class:`BenchmarkProtocol` to drive the eval. When + provided, overrides the ``success_fn`` path. + seed: Master RNG seed. Each episode derives a child RNG from it, + so evaluations are reproducible within a process. Only used + when ``spec`` is provided. Returns: - Standard status dict with ``success_rate``, per-episode results. + Standard status dict. When ``spec`` is used, the JSON payload + also contains ``cumulative_reward`` and ``avg_reward`` fields + per episode and aggregate. """ + if spec is not None and success_fn is not None: + return { + "status": "error", + "content": [ + { + "text": ( + "evaluate() accepts either 'spec' or 'success_fn', not both. " + "'spec' defines its own success predicate." + ) + } + ], + } + + if spec is not None: + return self._evaluate_with_spec( + robot_name, + policy, + spec, + instruction=instruction, + n_episodes=n_episodes, + seed=seed, + ) + try: resolved_check = self._resolve_success_fn(success_fn) except ValueError as e: @@ -593,6 +636,156 @@ def evaluate( ], } + def _evaluate_with_spec( + self, + robot_name: str, + policy: Policy, + spec: BenchmarkProtocol, + *, + instruction: str, + n_episodes: int, + seed: int | None, + ) -> dict[str, Any]: + """Drive a :class:`BenchmarkProtocol` for ``n_episodes`` episodes. + + Split out from :meth:`evaluate` to keep the legacy-path body small; + both routes share the same return-dict schema plus the spec route + layers on cumulative-reward accounting. 
+ + Robot compatibility is validated before episode 1: if the sim's + loaded robot declares a ``data_config`` not in + ``spec.supported_robots`` (non-empty), we return a structured error + with the allowed list instead of silently running a mismatched + evaluation. + """ + # Lazy import to avoid circular reference (benchmark module imports + # `SimEngine` from base which imports this module under TYPE_CHECKING). + from strands_robots.simulation.benchmark import BenchmarkCompatibilityError + + # T26: skip camera rendering when the policy does not need images. + _skip_images = not getattr(policy, "requires_images", True) + master_rng = random.Random(seed) + spec_name = type(spec).__name__ + max_steps = spec.max_steps + results: list[dict[str, Any]] = [] + + for ep in range(n_episodes): + self.sim.reset() + # Per-episode seeded RNG - deterministic given the master seed + # and the episode index. + episode_seed = master_rng.randint(0, 2**31 - 1) + episode_rng = random.Random(episode_seed) + + try: + spec.on_episode_start(self.sim, episode_rng) + except BenchmarkCompatibilityError as e: + # Surface the structured error with the supported list - + # agents can fix this without retrying. + return { + "status": "error", + "content": [ + { + "text": ( + f"Benchmark compatibility error: robot '{e.robot_name}' " + f"has data_config={e.data_config!r}, but benchmark " + f"{spec_name} supports {e.supported}." 
+ ) + } + ], + } + except Exception as e: # noqa: BLE001 - surface as structured error + logger.exception("on_episode_start failed") + return { + "status": "error", + "content": [{"text": f"on_episode_start failed in {spec_name}: {e}"}], + } + + success = False + failure = False + steps = 0 + cumulative_reward = 0.0 + last_info: dict[str, Any] = {} + + for _ in range(max_steps): + observation = self.sim.get_observation(robot_name=robot_name, skip_images=_skip_images) + coro_or_result = policy.get_actions(observation, instruction) + actions = _resolve_coroutine(coro_or_result) + + if actions: + action_applied: dict[str, Any] = dict(actions[0]) + self.sim.send_action(action_applied, robot_name=robot_name) + else: + # Degenerate policy - advance physics so loop terminates. + action_applied = {} + self.sim.step(n_steps=1) + + steps += 1 + try: + info = spec.on_step(self.sim, observation, action_applied) + except Exception as e: # noqa: BLE001 + logger.exception("on_step failed in %s", spec_name) + return { + "status": "error", + "content": [{"text": f"on_step failed in {spec_name}: {e}"}], + } + cumulative_reward += float(info.reward) + last_info = dict(info.info) if info.info else {} + + if info.done: + break + if spec.is_failure(self.sim): + failure = True + break + if spec.is_success(self.sim): + success = True + break + + results.append( + { + "episode": ep, + "steps": steps, + "success": success, + "failure": failure, + "cumulative_reward": round(cumulative_reward, 4), + "seed": episode_seed, + "info": last_info, + } + ) + + n_success = sum(1 for r in results if r["success"]) + n_failure = sum(1 for r in results if r["failure"]) + success_rate = n_success / max(n_episodes, 1) + avg_steps = sum(r["steps"] for r in results) / max(n_episodes, 1) + avg_reward = sum(r["cumulative_reward"] for r in results) / max(n_episodes, 1) + + return { + "status": "success", + "content": [ + { + "text": ( + f"📊 Benchmark: {spec_name} | policy {type(policy).__name__} on 
'{robot_name}'\n" + f"Episodes: {n_episodes} | Success: {n_success} | Failure: {n_failure} " + f"({success_rate:.1%} success)\n" + f"Avg reward: {avg_reward:.2f} | Avg steps: {avg_steps:.0f}/{max_steps}" + ) + }, + { + "json": { + "success_rate": round(success_rate, 4), + "n_episodes": n_episodes, + "n_success": n_success, + "n_failure": n_failure, + "avg_steps": round(avg_steps, 1), + "avg_reward": round(avg_reward, 4), + "max_steps": max_steps, + "seed": seed, + "benchmark_class": spec_name, + "episodes": results, + } + }, + ], + } + # Helpers def _maybe_sim_time(self) -> float | None: diff --git a/strands_robots/simulation/predicates.py b/strands_robots/simulation/predicates.py new file mode 100644 index 0000000..03d2424 --- /dev/null +++ b/strands_robots/simulation/predicates.py @@ -0,0 +1,532 @@ +"""Named-predicate library for declarative :class:`BenchmarkProtocol` specs. + +Each entry in :data:`PREDICATE_REGISTRY` is a factory ``(**kwargs) -> callable`` +where the returned callable takes a :class:`SimEngine` and returns either +``bool`` (for success/failure predicates) or ``float`` (for reward terms). + +The registry is a closed set - the YAML/JSON loader in +:mod:`strands_robots.simulation.benchmark_spec` refuses predicates whose +name is not in this registry, so spec files are safe to parse from +untrusted / LLM-authored input. **No ``eval`` is ever called.** User-defined +predicates must be registered programmatically via :func:`register_predicate` +before loading the spec. + +Predicates are backend-aware but not backend-specific: they exclusively call +``SimEngine`` methods (abstract) or probe for MuJoCo-only methods via +``getattr`` and return a safe fallback (``False`` / ``0.0``) when the +backend does not support them. A predicate that silently evaluates to +``False`` because of an unimplemented backend call is a bug in the +predicate, not the benchmark - file an issue. 
+ +Available predicates (bool): + + body_above_z(body, z) + body_below_z(body, z) + joint_above(joint, value) + joint_below(joint, value) + distance_less_than(body_a, body_b, threshold) + inside_region(body, min, max) + contact_between(geom_a, geom_b) + contact_any() + body_on(body_a, body_b, z_offset=0.02, xy_tol=0.15) + body_inside(body, container, xy_tol=0.15, z_tol=0.15) + body_upright(body, tol=0.15) + grasped(body, gripper_prefix) + +Available reward terms (float): + + distance_neg(body_a, body_b, weight=1.0) + joint_progress(joint, target, weight=1.0) + constant(value) + +Register custom predicates with :func:`register_predicate`. +""" + +from __future__ import annotations + +import logging +from collections.abc import Callable +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from strands_robots.simulation.base import SimEngine + +logger = logging.getLogger(__name__) + +BoolPredicate = Callable[["SimEngine"], bool] +RewardTerm = Callable[["SimEngine"], float] +PredicateFactory = Callable[..., Callable[["SimEngine"], Any]] + + +# Helpers for digging values out of the structured ``{"status", "content"}`` +# dicts that MuJoCo-backend methods return. Defensive against empty content +# lists and missing keys - predicates should never crash the eval loop. + + +def _extract_json(result: dict[str, Any] | None) -> dict[str, Any]: + """Return the ``json`` content block payload, or ``{}`` if absent.""" + if not isinstance(result, dict): + return {} + for block in result.get("content", []) or []: + if isinstance(block, dict): + payload = block.get("json") + if isinstance(payload, dict): + # dict[str, Any] by construction of the content schema; mypy can't + # narrow through dict.get() so we cast via a new dict to keep it typed. + return dict(payload) + return {} + + +def _body_position(sim: SimEngine, body: str) -> list[float] | None: + """Best-effort body-position lookup. Returns ``None`` on any failure. 
+ + Requires the backend to implement ``get_body_state`` (MuJoCo only at time + of writing). Future backends can add the same method signature - see + :meth:`strands_robots.simulation.mujoco.physics.PhysicsMixin.get_body_state`. + """ + get_body_state = getattr(sim, "get_body_state", None) + if get_body_state is None: + return None + try: + result = get_body_state(body_name=body) + except Exception as e: # noqa: BLE001 - defensive: predicates never raise + logger.debug("body_position(%r) failed: %s", body, e) + return None + if not isinstance(result, dict) or result.get("status") != "success": + return None + payload = _extract_json(result) + pos = payload.get("position") + if isinstance(pos, list) and len(pos) == 3 and all(isinstance(c, (int, float)) for c in pos): + return [float(c) for c in pos] + return None + + +def _joint_position(sim: SimEngine, joint: str) -> float | None: + """Best-effort joint-position lookup via ``get_observation``. + + ``get_observation`` is on the ABC and returns ``{joint_name: float}``. + When the joint is absent from the observation dict (wrong robot, wrong + namespace) we return ``None`` so predicates can decide between ``False`` + and an explicit error path. + """ + try: + obs = sim.get_observation(skip_images=True) + except Exception as e: # noqa: BLE001 - defensive + logger.debug("get_observation() failed: %s", e) + return None + if not isinstance(obs, dict): + return None + val = obs.get(joint) + if isinstance(val, (int, float)) and not isinstance(val, bool): + return float(val) + return None + + +def _body_quaternion(sim: SimEngine, body: str) -> list[float] | None: + """Best-effort quaternion lookup. Returns ``None`` on any failure. + + Quaternion convention: MuJoCo reports ``[w, x, y, z]``. Callers that + need just an axis can derive it from the rotation matrix, but doing + the arithmetic inline here keeps the predicate library numpy-free. 
+ """ + get_body_state = getattr(sim, "get_body_state", None) + if get_body_state is None: + return None + try: + result = get_body_state(body_name=body) + except Exception as e: # noqa: BLE001 - defensive + logger.debug("body_quaternion(%r) failed: %s", body, e) + return None + if not isinstance(result, dict) or result.get("status") != "success": + return None + payload = _extract_json(result) + quat = payload.get("quaternion") + if isinstance(quat, list) and len(quat) == 4 and all(isinstance(c, (int, float)) for c in quat): + return [float(c) for c in quat] + return None + + +def _euclidean_distance(a: list[float], b: list[float]) -> float: + """Simple 3D Euclidean distance; no numpy so predicates stay dependency-free.""" + dx = a[0] - b[0] + dy = a[1] - b[1] + dz = a[2] - b[2] + return float((dx * dx + dy * dy + dz * dz) ** 0.5) + + +# Predicate factories + + +def _body_above_z(body: str, z: float) -> BoolPredicate: + def check(sim: SimEngine) -> bool: + pos = _body_position(sim, body) + return pos is not None and pos[2] > float(z) + + return check + + +def _body_below_z(body: str, z: float) -> BoolPredicate: + def check(sim: SimEngine) -> bool: + pos = _body_position(sim, body) + return pos is not None and pos[2] < float(z) + + return check + + +def _joint_above(joint: str, value: float) -> BoolPredicate: + def check(sim: SimEngine) -> bool: + q = _joint_position(sim, joint) + return q is not None and q > float(value) + + return check + + +def _joint_below(joint: str, value: float) -> BoolPredicate: + def check(sim: SimEngine) -> bool: + q = _joint_position(sim, joint) + return q is not None and q < float(value) + + return check + + +def _distance_less_than(body_a: str, body_b: str, threshold: float) -> BoolPredicate: + def check(sim: SimEngine) -> bool: + pos_a = _body_position(sim, body_a) + pos_b = _body_position(sim, body_b) + if pos_a is None or pos_b is None: + return False + return _euclidean_distance(pos_a, pos_b) < float(threshold) + + return check + + 
+def _inside_region(body: str, min: list[float], max: list[float]) -> BoolPredicate: # noqa: A002 - DSL keyword + if not (isinstance(min, list) and len(min) == 3 and isinstance(max, list) and len(max) == 3): + raise ValueError("inside_region: 'min' and 'max' must each be a list of 3 numbers") + lo = [float(c) for c in min] + hi = [float(c) for c in max] + if any(lo[i] > hi[i] for i in range(3)): + raise ValueError(f"inside_region: 'min' {lo} must be component-wise <= 'max' {hi}") + + def check(sim: SimEngine) -> bool: + pos = _body_position(sim, body) + if pos is None: + return False + return all(lo[i] <= pos[i] <= hi[i] for i in range(3)) + + return check + + +def _contact_between(geom_a: str, geom_b: str) -> BoolPredicate: + """Pairwise contact predicate. + + Requires ``get_contacts()`` (MuJoCo). Ignores contact ordering - a contact + reported as ``(geom_a, geom_b)`` matches the same predicate as + ``(geom_b, geom_a)``. + """ + + def check(sim: SimEngine) -> bool: + get_contacts = getattr(sim, "get_contacts", None) + if get_contacts is None: + return False + try: + result = get_contacts() + except Exception as e: # noqa: BLE001 - defensive + logger.debug("contact_between(%r,%r) failed: %s", geom_a, geom_b, e) + return False + payload = _extract_json(result) + contacts = payload.get("contacts") + if not isinstance(contacts, list): + return False + want = {geom_a, geom_b} + for c in contacts: + if not isinstance(c, dict): + continue + pair = {c.get("geom1"), c.get("geom2")} + if want <= pair: + return True + return False + + return check + + +def _contact_any() -> BoolPredicate: + """Sparse "any contact" predicate - matches the legacy ``success_fn='contact'`` path.""" + + def check(sim: SimEngine) -> bool: + get_contacts = getattr(sim, "get_contacts", None) + if get_contacts is None: + return False + try: + result = get_contacts() + except Exception as e: # noqa: BLE001 - defensive + logger.debug("contact_any() failed: %s", e) + return False + payload = 
_extract_json(result) + if payload.get("n_contacts", 0) > 0: + return True + contacts = payload.get("contacts") + return bool(isinstance(contacts, list) and contacts) + + return check + + +def _body_on( + body_a: str, + body_b: str, + z_offset: float = 0.02, + xy_tol: float = 0.15, +) -> BoolPredicate: + """Approximate ``(on A B)`` predicate - A resting on top of B. + + True when ``A.z > B.z + z_offset`` AND horizontal distance ``|A.xy - B.xy| + < xy_tol``. The z-offset parameter accounts for B's half-height + a small + buffer; tune per scene. Intended for sparse-success benchmarks (LIBERO, + etc.) where exact geometric containment isn't required. + + For full fidelity (MJCF geom size lookup + narrow-phase collision), write + a scene-specific predicate and register it via :func:`register_predicate`. + """ + + def check(sim: SimEngine) -> bool: + pos_a = _body_position(sim, body_a) + pos_b = _body_position(sim, body_b) + if pos_a is None or pos_b is None: + return False + dx = pos_a[0] - pos_b[0] + dy = pos_a[1] - pos_b[1] + if (dx * dx + dy * dy) ** 0.5 > float(xy_tol): + return False + return pos_a[2] > pos_b[2] + float(z_offset) + + return check + + +def _body_inside(body: str, container: str, xy_tol: float = 0.15, z_tol: float = 0.15) -> BoolPredicate: + """Approximate ``(inside A B)`` predicate - A contained within B's volume. + + True when A's position is within an axis-aligned box centered on B with + half-extents (``xy_tol``, ``xy_tol``, ``z_tol``). LIBERO-typical use is + "object inside basket / drawer / compartment" where exact bbox is + benchmark-specific; the defaults are tuned for table-top manipulation. + + When richer geometry is available, override by registering a + scene-specific predicate. 
+ """ + + def check(sim: SimEngine) -> bool: + pos_a = _body_position(sim, body) + pos_b = _body_position(sim, container) + if pos_a is None or pos_b is None: + return False + return ( + abs(pos_a[0] - pos_b[0]) <= float(xy_tol) + and abs(pos_a[1] - pos_b[1]) <= float(xy_tol) + and abs(pos_a[2] - pos_b[2]) <= float(z_tol) + ) + + return check + + +def _body_upright(body: str, tol: float = 0.15) -> BoolPredicate: + """True when ``body``'s local +Z axis is within ``tol`` of world +Z. + + Computes the rotation-matrix element ``R[2,2]`` from the body's + quaternion. Upright → ``R[2,2] > 1 - tol``. The math (all unit-quat + identities, w² + x² + y² + z² = 1): + + R[2,2] = 1 - 2*(x² + y²) + + so the check is ``2*(x² + y²) < tol``. This is monotonic in "how + tipped over" the body is, so a small tol (0.01-0.2) corresponds + directly to the maximum allowed tilt. + """ + t = float(tol) + if t < 0: + raise ValueError(f"body_upright: 'tol' must be >= 0, got {t}") + + def check(sim: SimEngine) -> bool: + quat = _body_quaternion(sim, body) + if quat is None: + return False + # MuJoCo quat layout is (w, x, y, z). + _, x, y, _ = quat + return 2.0 * (x * x + y * y) < t + + return check + + +def _grasped(body: str, gripper_prefix: str) -> BoolPredicate: + """True when ``body`` is in contact with any geom whose name starts with ``gripper_prefix``. + + Treats the gripper as a *set* of geoms (fingers, pads, tip sites) so + the caller only has to specify the common prefix - e.g. ``"robot0_gripper"`` + for Panda covers both fingers. A body is "grasped" as long as any one + gripper geom is in contact with any geom matching the body name. + + Backends must implement ``get_contacts()`` returning the MuJoCo + ``{"contacts": [{"geom1", "geom2", ...}]}`` shape. Other backends are + treated as "cannot check" and return ``False``. 
+ """ + + def check(sim: SimEngine) -> bool: + get_contacts = getattr(sim, "get_contacts", None) + if get_contacts is None: + return False + try: + result = get_contacts() + except Exception as e: # noqa: BLE001 - defensive + logger.debug("grasped(%r, %r) failed: %s", body, gripper_prefix, e) + return False + payload = _extract_json(result) + contacts = payload.get("contacts") + if not isinstance(contacts, list): + return False + for c in contacts: + if not isinstance(c, dict): + continue + g1 = c.get("geom1") or "" + g2 = c.get("geom2") or "" + # One side must be the grasped body (bare name or "_geom" suffix); + # the other must start with the gripper prefix. + body_match = {g1, g2} & {body, f"{body}_geom"} + gripper_match = any(isinstance(g, str) and g.startswith(gripper_prefix) for g in (g1, g2)) + if body_match and gripper_match: + return True + return False + + return check + + +# Reward terms (float-valued) + + +def _distance_neg(body_a: str, body_b: str, weight: float = 1.0) -> RewardTerm: + """Negative Euclidean distance between two bodies, weighted. + + The canonical "reach" reward: ``weight * -dist(a, b)``. Monotonic in + the distance, so naive policy improvement pulls the bodies together. + """ + w = float(weight) + + def term(sim: SimEngine) -> float: + pos_a = _body_position(sim, body_a) + pos_b = _body_position(sim, body_b) + if pos_a is None or pos_b is None: + return 0.0 + return -w * _euclidean_distance(pos_a, pos_b) + + return term + + +def _joint_progress(joint: str, target: float, weight: float = 1.0) -> RewardTerm: + """Negative absolute distance from a joint to its target, weighted. + + Useful for drawer/door tasks where success is "joint near target + position" and you want dense signal during training. 
+ """ + w = float(weight) + t = float(target) + + def term(sim: SimEngine) -> float: + q = _joint_position(sim, joint) + if q is None: + return 0.0 + return -w * abs(q - t) + + return term + + +def _constant(value: float) -> RewardTerm: + """Constant reward per step. Useful for shaping a survival bonus.""" + v = float(value) + + def term(_sim: SimEngine) -> float: + return v + + return term + + +# Registry + +PREDICATE_REGISTRY: dict[str, PredicateFactory] = { + # bool-valued + "body_above_z": _body_above_z, + "body_below_z": _body_below_z, + "joint_above": _joint_above, + "joint_below": _joint_below, + "distance_less_than": _distance_less_than, + "inside_region": _inside_region, + "contact_between": _contact_between, + "contact_any": _contact_any, + "body_on": _body_on, + "body_inside": _body_inside, + "body_upright": _body_upright, + "grasped": _grasped, + # float-valued + "distance_neg": _distance_neg, + "joint_progress": _joint_progress, + "constant": _constant, +} + + +def register_predicate(name: str, factory: PredicateFactory) -> None: + """Register a user-defined predicate factory. + + Must be called before loading a spec that references ``name``. Factories + registered at runtime are NOT sandboxed - by registering, you opt into + running the factory with kwargs parsed from the spec. Only register + predicates from trusted code paths; anything LLM-authored should use the + built-in DSL exclusively. + + Args: + name: Predicate name used in spec files. Must not shadow a built-in. + factory: Callable that takes DSL kwargs and returns a predicate + ``(sim) -> bool`` or reward term ``(sim) -> float``. + + Raises: + ValueError: If ``name`` shadows a built-in predicate. + TypeError: If ``factory`` is not callable. 
+ """ + if name in PREDICATE_REGISTRY: + raise ValueError(f"register_predicate: '{name}' shadows a built-in predicate; pick a different name") + if not callable(factory): + raise TypeError(f"register_predicate: factory must be callable, got {type(factory).__name__}") + PREDICATE_REGISTRY[name] = factory + + +def make_predicate(name: str, **kwargs: Any) -> Callable[[SimEngine], Any]: + """Instantiate a predicate from its name + kwargs. + + This is the single entry point the DSL loader uses - it never touches + ``eval`` or ``exec``. Unknown names produce a ``ValueError`` listing + the valid set; bad kwargs surface as whatever ``TypeError`` the factory + raises. + + Args: + name: Predicate name. Must be registered in :data:`PREDICATE_REGISTRY`. + **kwargs: Forwarded verbatim to the factory. + + Returns: + A callable ``(sim) -> bool`` or ``(sim) -> float`` depending on the + predicate. + + Raises: + ValueError: If ``name`` is unknown. + TypeError: If required factory kwargs are missing. + """ + factory = PREDICATE_REGISTRY.get(name) + if factory is None: + valid = sorted(PREDICATE_REGISTRY.keys()) + raise ValueError(f"Unknown predicate '{name}'. Valid: {valid}") + return factory(**kwargs) + + +__all__ = [ + "PREDICATE_REGISTRY", + "BoolPredicate", + "PredicateFactory", + "RewardTerm", + "make_predicate", + "register_predicate", +] diff --git a/tests/benchmarks/__init__.py b/tests/benchmarks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/benchmarks/libero/__init__.py b/tests/benchmarks/libero/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/benchmarks/libero/test_bddl_parser.py b/tests/benchmarks/libero/test_bddl_parser.py new file mode 100644 index 0000000..82fe00d --- /dev/null +++ b/tests/benchmarks/libero/test_bddl_parser.py @@ -0,0 +1,371 @@ +"""Tests for the LIBERO BDDL parser. + +Covers: + +* Tokenizer handling of comments, quoted strings, nested parens. +* S-expression parsing - depth, arity, EOF errors. 
+* Top-level ``(define ...)`` structure + section extraction + (``:domain``, ``:objects``, ``:init``, ``:goal``, ``:language``). +* Predicate compilation for every entry in ``PREDICATE_VOCABULARY``. +* Boolean combinators (``and`` / ``or`` / ``not``) with short-circuit behaviour. +* Rejection of unknown predicates / wrong arities. +* Round-trip on a curated 5-task subset covering each predicate family. + +The compiled callables are executed against the same fake sims used by +``tests/simulation/test_benchmark_predicates.py`` - no LIBERO / MuJoCo +dependency required. +""" + +from __future__ import annotations + +from typing import Any + +import pytest + +from strands_robots.benchmarks.libero.bddl_parser import ( + PREDICATE_VOCABULARY, + And, + BDDLParseError, + Not, + Or, + Pred, + _tokenize, + compile_goal, + parse_bddl, + parse_bddl_file, +) + +# Fake sims + + +class _BodyStateSim: + def __init__(self, bodies: dict[str, dict[str, Any]]): + self._bodies = bodies + + def get_body_state(self, body_name: str) -> dict[str, Any]: + if body_name not in self._bodies: + return {"status": "error", "content": [{"text": "missing"}]} + return { + "status": "success", + "content": [ + {"text": body_name}, + { + "json": { + "position": self._bodies[body_name].get("position", [0, 0, 0]), + "quaternion": self._bodies[body_name].get("quaternion", [1, 0, 0, 0]), + "mass": 1.0, + } + }, + ], + } + + def get_observation(self, *_, **__) -> dict[str, Any]: + return self._bodies.get("_joints", {}) + + +class _ContactSim: + def __init__(self, contacts: list[dict[str, str]]): + self._contacts = contacts + + def get_contacts(self) -> dict[str, Any]: + return { + "status": "success", + "content": [ + {"text": f"{len(self._contacts)} contacts"}, + {"json": {"contacts": self._contacts, "n_contacts": len(self._contacts)}}, + ], + } + + +class _CombinedSim(_BodyStateSim, _ContactSim): + """Both body state and contacts for multi-predicate goals.""" + + def __init__( + self, + bodies: dict[str, 
dict[str, Any]] | None = None, + contacts: list[dict[str, str]] | None = None, + ): + _BodyStateSim.__init__(self, bodies or {}) + _ContactSim.__init__(self, contacts or []) + + +# Tokenizer + + +class TestTokenize: + def test_basic(self): + assert _tokenize("(and a b)") == ["(", "and", "a", "b", ")"] + + def test_comments_stripped(self): + assert _tokenize("(foo) ; trailing comment\n(bar)") == ["(", "foo", ")", "(", "bar", ")"] + + def test_quoted_strings_preserved(self): + toks = _tokenize('(:language "pick the red cube")') + # The quoted region is a single token, including the quotes. + assert '"pick the red cube"' in toks + + def test_unterminated_quote_errors(self): + with pytest.raises(BDDLParseError, match="unterminated quoted string"): + _tokenize('(:language "unterminated') + + +# Top-level parser + + +class TestParseBDDL: + def test_minimal(self): + text = """ + (define (problem libero_pick) + (:domain kitchen) + (:goal (on cube_1 plate_1))) + """ + problem = parse_bddl(text) + assert problem.name == "libero_pick" + assert problem.domain == "kitchen" + assert isinstance(problem.goal, Pred) + assert problem.goal.name == "on" + assert problem.goal.args == ("cube_1", "plate_1") + + def test_extracts_language(self): + text = """ + (define (problem p1) + (:language "pick up the red cube") + (:goal (grasped cube_1))) + """ + problem = parse_bddl(text) + assert problem.language == "pick up the red cube" + + def test_extracts_objects_flattening_typed_syntax(self): + """PDDL-style ``obj1 obj2 - type`` annotations are flattened to symbols.""" + text = """ + (define (problem p) + (:objects cube_1 plate_1 - object table_1 - fixture) + (:goal (on cube_1 plate_1))) + """ + problem = parse_bddl(text) + assert problem.objects == ["cube_1", "plate_1", "object", "table_1", "fixture"] + + def test_extracts_init_clauses(self): + text = """ + (define (problem p) + (:init (on cube_1 table_1) (upright bottle_1)) + (:goal (on cube_1 plate_1))) + """ + problem = parse_bddl(text) 
+ assert len(problem.init) == 2 + # Each init clause is a compiled Pred. + assert all(isinstance(n, Pred) for n in problem.init) + + def test_goal_with_and(self): + text = """ + (define (problem p) + (:goal (and (on cube_1 plate_1) (upright cube_1)))) + """ + problem = parse_bddl(text) + assert isinstance(problem.goal, And) + assert len(problem.goal.clauses) == 2 + + def test_goal_with_or_and_not(self): + text = """ + (define (problem p) + (:goal (or (grasped cube_1) (not (on cube_1 table_1))))) + """ + problem = parse_bddl(text) + assert isinstance(problem.goal, Or) + inner = problem.goal.clauses[1] + assert isinstance(inner, Not) + + def test_missing_define_rejected(self): + with pytest.raises(BDDLParseError, match="top-level"): + parse_bddl("(problem foo)") + + def test_empty_input_rejected(self): + with pytest.raises(BDDLParseError): + parse_bddl("") + + def test_missing_paren_rejected(self): + with pytest.raises(BDDLParseError, match="closing"): + parse_bddl("(define (problem p)") + + def test_trailing_tokens_rejected(self): + with pytest.raises(BDDLParseError, match="trailing"): + parse_bddl("(define (problem p) (:goal (on a b))) (extra)") + + +# Predicate vocabulary + + +class TestPredicateVocabulary: + @pytest.mark.parametrize("bddl_name", sorted(PREDICATE_VOCABULARY.keys())) + def test_every_predicate_compiles(self, bddl_name: str): + """Each BDDL predicate must produce a compilable goal with a valid argc.""" + sample_args = { + "on": "cube_1 table_1", + "near": "cube_1 gripper_1", + "inside": "cube_1 basket_1", + "open": "drawer_joint", + "closed": "drawer_joint", + "grasped": "cube_1", + "upright": "bottle_1", + } + args = sample_args[bddl_name] + text = f"(define (problem p) (:goal ({bddl_name} {args})))" + problem = parse_bddl(text) + # Must compile without error. 
+ fn = compile_goal(problem.goal) # type: ignore[arg-type] + assert callable(fn) + + def test_unknown_predicate_rejected_with_list(self): + text = "(define (problem p) (:goal (telekinesis cube_1)))" + with pytest.raises(BDDLParseError) as exc: + parse_bddl(text) + assert "unknown predicate" in str(exc.value).lower() + # Error must list the valid vocabulary so the author can fix it. + for expected in ("on", "grasped", "upright"): + assert expected in str(exc.value) + + @pytest.mark.parametrize( + "expr,reason", + [ + ("(on cube_1)", "wrong arity"), + ("(on cube_1 plate_1 extra)", "extra arg"), + ("(grasped)", "no arg"), + ("(upright a b)", "extra arg"), + ], + ) + def test_wrong_arity_rejected(self, expr: str, reason: str): + with pytest.raises(BDDLParseError): + parse_bddl(f"(define (problem p) (:goal {expr}))") + + def test_not_with_wrong_arity(self): + with pytest.raises(BDDLParseError, match="not"): + parse_bddl("(define (problem p) (:goal (not (on a b) (on c d))))") + + +# Compiled goal evaluation + + +class TestCompileGoal: + def test_and_short_circuits(self): + """``and`` must evaluate to False as soon as one clause fails.""" + text = """ + (define (problem p) + (:goal (and + (on cube_1 table_1) + (upright bottle_1)))) + """ + problem = parse_bddl(text) + fn = compile_goal(problem.goal) # type: ignore[arg-type] + sim_hit = _BodyStateSim( + { + "cube_1": {"position": [0, 0, 0.2]}, + "table_1": {"position": [0, 0, 0.0]}, + "bottle_1": {"quaternion": [1.0, 0.0, 0.0, 0.0]}, + } + ) + sim_miss_upright = _BodyStateSim( + { + "cube_1": {"position": [0, 0, 0.2]}, + "table_1": {"position": [0, 0, 0.0]}, + "bottle_1": {"quaternion": [0.707, 0.707, 0, 0]}, + } + ) + assert fn(sim_hit) is True + assert fn(sim_miss_upright) is False + + def test_or(self): + text = "(define (problem p) (:goal (or (upright a) (upright b))))" + problem = parse_bddl(text) + fn = compile_goal(problem.goal) # type: ignore[arg-type] + only_b = _BodyStateSim( + { + "a": {"quaternion": [0.707, 
0.707, 0, 0]}, # tipped + "b": {"quaternion": [1.0, 0, 0, 0]}, # upright + } + ) + assert fn(only_b) is True + + def test_not(self): + text = "(define (problem p) (:goal (not (grasped cube_1))))" + problem = parse_bddl(text) + fn = compile_goal(problem.goal) # type: ignore[arg-type] + # Without any contacts, grasped is False, so (not grasped) is True. + no_contacts = _ContactSim([]) + assert fn(no_contacts) is True + # With a gripper contact, grasped is True, so (not grasped) is False. + with_grip = _ContactSim([{"geom1": "robot0_gripper_finger_r", "geom2": "cube_1_geom"}]) + assert fn(with_grip) is False + + +# Representative LIBERO-style round-trip + + +class TestRoundTrip: + """One example per predicate family so a regression in any predicate is caught here.""" + + def test_pick_task_on(self): + text = """ + (define (problem libero_spatial_pick_up_the_red_cube) + (:language "pick up the red cube and place it on the plate") + (:objects cube_1 plate_1 table_1 - object) + (:init (on cube_1 table_1)) + (:goal (on cube_1 plate_1))) + """ + problem = parse_bddl(text) + assert problem.language == "pick up the red cube and place it on the plate" + fn = compile_goal(problem.goal) # type: ignore[arg-type] + sim_success = _BodyStateSim({"cube_1": {"position": [0, 0, 0.25]}, "plate_1": {"position": [0, 0, 0.1]}}) + assert fn(sim_success) is True + + def test_open_task(self): + text = "(define (problem libero_open_drawer) (:goal (open drawer_slide)))" + problem = parse_bddl(text) + fn = compile_goal(problem.goal) # type: ignore[arg-type] + sim = _BodyStateSim({"_joints": {"drawer_slide": 0.2}}) + assert fn(sim) is True + sim2 = _BodyStateSim({"_joints": {"drawer_slide": 0.02}}) + assert fn(sim2) is False + + def test_grasp_task(self): + text = "(define (problem libero_grasp_cube) (:goal (grasped cube_1)))" + problem = parse_bddl(text) + fn = compile_goal(problem.goal) # type: ignore[arg-type] + sim = _ContactSim([{"geom1": "robot0_gripper_finger_l", "geom2": "cube_1"}]) + 
assert fn(sim) is True + + def test_upright_task(self): + text = "(define (problem libero_keep_upright) (:goal (upright bottle_1)))" + problem = parse_bddl(text) + fn = compile_goal(problem.goal) # type: ignore[arg-type] + sim = _BodyStateSim({"bottle_1": {"quaternion": [1.0, 0, 0, 0]}}) + assert fn(sim) is True + + def test_inside_task(self): + text = """ + (define (problem libero_put_inside) + (:goal (inside cube_1 basket_1))) + """ + problem = parse_bddl(text) + fn = compile_goal(problem.goal) # type: ignore[arg-type] + # Approximate-inside uses default tolerances (0.15, 0.15). + sim = _BodyStateSim({"cube_1": {"position": [0.05, 0.02, 0.1]}, "basket_1": {"position": [0, 0, 0.1]}}) + assert fn(sim) is True + + +# File loader + + +class TestParseBDDLFile: + def test_happy_path(self, tmp_path): + p = tmp_path / "task.bddl" + p.write_text("(define (problem p) (:goal (grasped cube_1)))") + problem = parse_bddl_file(p) + assert problem.name == "p" + + def test_missing_file(self, tmp_path): + with pytest.raises(FileNotFoundError): + parse_bddl_file(tmp_path / "nope.bddl") + + def test_not_a_file(self, tmp_path): + with pytest.raises(ValueError): + parse_bddl_file(tmp_path) diff --git a/tests/benchmarks/libero/test_libero_adapter.py b/tests/benchmarks/libero/test_libero_adapter.py new file mode 100644 index 0000000..af6abb3 --- /dev/null +++ b/tests/benchmarks/libero/test_libero_adapter.py @@ -0,0 +1,397 @@ +"""Tests for :class:`LiberoAdapter`. + +Covers: + +* Construction via ``from_file`` / ``from_text`` / raw ``__init__``. +* ``supported_robots`` / ``default_robot`` = Panda-only. +* ``instruction`` surfaces the BDDL ``:language`` string. +* ``is_success`` positive + negative cases against fake sims (no MuJoCo + needed - the predicates poll ``get_body_state`` / ``get_contacts``). +* ``on_episode_start`` loads the scene (or errors cleanly) and applies + per-episode jitter when the sim exposes ``move_object``. 
+* Integration with ``PolicyRunner.evaluate`` and + ``SimEngine.evaluate_benchmark`` via a minimal ``FakeSim`` stub. +* Error surface: unknown task via ``evaluate_benchmark`` returns a + structured error dict, never raises. +""" + +from __future__ import annotations + +import random +from typing import Any + +import pytest + +from strands_robots.benchmarks.libero import ( + BDDLParseError, + LiberoAdapter, +) +from strands_robots.policies.mock import MockPolicy +from strands_robots.simulation.base import SimEngine +from strands_robots.simulation.benchmark import ( + _BENCHMARK_REGISTRY, + register_benchmark, +) +from strands_robots.simulation.policy_runner import PolicyRunner + + +@pytest.fixture(autouse=True) +def _clean_registry(): + snapshot = dict(_BENCHMARK_REGISTRY) + _BENCHMARK_REGISTRY.clear() + yield + _BENCHMARK_REGISTRY.clear() + _BENCHMARK_REGISTRY.update(snapshot) + + +# Representative BDDL fragments + +PICK_CUBE_BDDL = """ +(define (problem libero_spatial_pick_cube) + (:domain kitchen) + (:language "pick up the red cube and place it on the plate") + (:objects cube_1 plate_1 table_1 - object) + (:init (on cube_1 table_1)) + (:goal (on cube_1 plate_1))) +""" + +COMPOUND_BDDL = """ +(define (problem libero_grasp_and_upright) + (:language "grasp the bottle and keep it upright") + (:goal (and (grasped bottle_1) (upright bottle_1)))) +""" + +NEGATED_BDDL = """ +(define (problem libero_release) + (:goal (not (grasped cube_1)))) +""" + + +# Fake sim helpers + + +class _FakeRobot: + def __init__(self, data_config: str): + self.data_config = data_config + + +class _FakeWorld: + def __init__(self, robots: dict[str, _FakeRobot]): + self.robots = dict(robots) + + +class FakeSim(SimEngine): + """Minimal ``SimEngine`` with get_body_state / get_contacts / move_object.""" + + def __init__( + self, + bodies: dict[str, dict[str, Any]] | None = None, + contacts: list[dict[str, str]] | None = None, + data_config: str = "panda", + ): + self._bodies = dict(bodies or {}) + 
self._contacts = list(contacts or []) + self._reset_count = 0 + self._move_calls: list[tuple[str, list[float]]] = [] + self._world = _FakeWorld({"fake_panda": _FakeRobot(data_config)}) + self._scenes_loaded: list[str] = [] + + def create_world(self, timestep=None, gravity=None, ground_plane=True): + return {"status": "success"} + + def destroy(self): + return {"status": "success"} + + def reset(self): + self._reset_count += 1 + return {"status": "success"} + + def step(self, n_steps: int = 1): + return {"status": "success"} + + def get_state(self): + return {} + + def add_robot(self, name, **kw): + dc = kw.get("data_config") or "panda" + self._world.robots[name] = _FakeRobot(dc) + return {"status": "success"} + + def remove_robot(self, name): + return {"status": "success"} + + def list_robots(self): + return list(self._world.robots.keys()) + + def robot_joint_names(self, robot_name): + return ["j0", "j1"] + + def add_object(self, name, **kw): + return {"status": "success"} + + def remove_object(self, name): + return {"status": "success"} + + def get_observation(self, robot_name=None, *, skip_images=False): + return {n: 0.0 for n in self.robot_joint_names(robot_name or "fake_panda")} + + def send_action(self, action, robot_name=None, n_substeps=1): + pass + + def render(self, camera_name="default", width=None, height=None): + return {"status": "success", "content": [{"text": "render"}]} + + # Optional helpers used by predicates + adapter + + def get_body_state(self, body_name: str) -> dict[str, Any]: + if body_name not in self._bodies: + return {"status": "error", "content": [{"text": "missing"}]} + return { + "status": "success", + "content": [ + {"text": body_name}, + { + "json": { + "position": self._bodies[body_name].get("position", [0, 0, 0]), + "quaternion": self._bodies[body_name].get("quaternion", [1, 0, 0, 0]), + "mass": 1.0, + } + }, + ], + } + + def get_contacts(self) -> dict[str, Any]: + return { + "status": "success", + "content": [ + {"text": 
f"{len(self._contacts)} contacts"}, + {"json": {"contacts": self._contacts, "n_contacts": len(self._contacts)}}, + ], + } + + def move_object(self, *, name: str, position: list[float]) -> dict[str, Any]: + self._move_calls.append((name, list(position))) + self._bodies.setdefault(name, {})["position"] = list(position) + return {"status": "success"} + + def load_scene(self, scene_path: str) -> dict[str, Any]: + self._scenes_loaded.append(scene_path) + return {"status": "success"} + + +# Construction + + +class TestConstruction: + def test_from_text_happy_path(self): + adapter = LiberoAdapter.from_text(PICK_CUBE_BDDL) + assert adapter.supported_robots == ["panda"] + assert adapter.default_robot == "panda" + assert adapter.max_steps == 300 + assert adapter.instruction == "pick up the red cube and place it on the plate" + + def test_from_text_respects_max_steps_override(self): + adapter = LiberoAdapter.from_text(PICK_CUBE_BDDL, max_steps=75) + assert adapter.max_steps == 75 + + def test_from_text_rejects_negative_jitter(self): + with pytest.raises(ValueError): + LiberoAdapter.from_text(PICK_CUBE_BDDL, init_jitter=-0.1) + + def test_rejects_bddl_without_goal(self): + text = '(define (problem no_goal) (:language "no goal block"))' + with pytest.raises(ValueError, match="no \\(:goal"): + LiberoAdapter.from_text(text) + + def test_from_file(self, tmp_path): + p = tmp_path / "task.bddl" + p.write_text(PICK_CUBE_BDDL) + adapter = LiberoAdapter.from_file(p) + assert adapter.problem.name == "libero_spatial_pick_cube" + + def test_from_text_propagates_parse_errors(self): + with pytest.raises(BDDLParseError): + LiberoAdapter.from_text("(define (problem p) (:goal (telekinesis cube_1)))") + + +# Lifecycle hooks + + +class TestIsSuccess: + def test_positive_case_on(self): + adapter = LiberoAdapter.from_text(PICK_CUBE_BDDL) + sim = FakeSim( + bodies={ + "cube_1": {"position": [0, 0, 0.25]}, + "plate_1": {"position": [0, 0, 0.1]}, + } + ) + assert adapter.is_success(sim) is True + + 
def test_negative_case_on(self): + adapter = LiberoAdapter.from_text(PICK_CUBE_BDDL) + sim = FakeSim( + bodies={ + "cube_1": {"position": [0, 0, 0.05]}, # not above plate + "plate_1": {"position": [0, 0, 0.1]}, + } + ) + assert adapter.is_success(sim) is False + + def test_compound_goal_needs_both(self): + adapter = LiberoAdapter.from_text(COMPOUND_BDDL) + sim_neither = FakeSim(bodies={"bottle_1": {"quaternion": [1.0, 0, 0, 0]}}) + sim_upright_only = FakeSim(bodies={"bottle_1": {"quaternion": [1.0, 0, 0, 0]}}) + sim_both = FakeSim( + bodies={"bottle_1": {"quaternion": [1.0, 0, 0, 0]}}, + contacts=[{"geom1": "robot0_gripper_finger_l", "geom2": "bottle_1"}], + ) + assert adapter.is_success(sim_neither) is False + assert adapter.is_success(sim_upright_only) is False + assert adapter.is_success(sim_both) is True + + def test_negated_goal(self): + adapter = LiberoAdapter.from_text(NEGATED_BDDL) + sim_empty = FakeSim() + sim_gripped = FakeSim(contacts=[{"geom1": "robot0_gripper_finger_l", "geom2": "cube_1"}]) + assert adapter.is_success(sim_empty) is True + assert adapter.is_success(sim_gripped) is False + + +class TestOnEpisodeStart: + def test_loads_scene_before_compat_check(self, tmp_path): + """``scene_path`` load must happen before ``super().on_episode_start`` + so the base compat check sees the scene's Panda robot.""" + adapter = LiberoAdapter.from_text(PICK_CUBE_BDDL, scene_path="/fake/scene.xml") + sim = FakeSim(data_config="panda") + adapter.on_episode_start(sim, random.Random(0)) + assert sim._scenes_loaded == ["/fake/scene.xml"] + + def test_scene_load_error_raises(self): + adapter = LiberoAdapter.from_text(PICK_CUBE_BDDL, scene_path="/bad/path.xml") + + sim = FakeSim(data_config="panda") + # Override load_scene to report failure. 
+ sim.load_scene = lambda path: { # type: ignore[assignment] + "status": "error", + "content": [{"text": f"no such file: {path}"}], + } + with pytest.raises(RuntimeError, match="load_scene"): + adapter.on_episode_start(sim, random.Random(0)) + + def test_jitter_applied_when_init_declares_subject(self): + adapter = LiberoAdapter.from_text( + """ + (define (problem p) + (:init (on cube_1 table_1)) + (:goal (on cube_1 plate_1))) + """, + init_jitter=0.01, + ) + sim = FakeSim( + bodies={ + "cube_1": {"position": [0.0, 0.0, 0.1]}, + "table_1": {"position": [0.0, 0.0, 0.0]}, + }, + data_config="panda", + ) + adapter.on_episode_start(sim, random.Random(42)) + # cube_1 got jittered; table_1 is only the reference and not moved. + assert any(call[0] == "cube_1" for call in sim._move_calls) + assert not any(call[0] == "table_1" for call in sim._move_calls) + + def test_jitter_disabled_when_zero(self): + adapter = LiberoAdapter.from_text( + """ + (define (problem p) + (:init (on cube_1 table_1)) + (:goal (on cube_1 plate_1))) + """, + init_jitter=0.0, + ) + sim = FakeSim( + bodies={"cube_1": {"position": [0, 0, 0.1]}, "table_1": {"position": [0, 0, 0]}}, + data_config="panda", + ) + adapter.on_episode_start(sim, random.Random(42)) + assert sim._move_calls == [] + + def test_non_panda_robot_rejected_by_base_compat_check(self): + adapter = LiberoAdapter.from_text(PICK_CUBE_BDDL) + sim = FakeSim(data_config="so100") # not panda + from strands_robots.simulation.benchmark import BenchmarkCompatibilityError + + with pytest.raises(BenchmarkCompatibilityError) as exc: + adapter.on_episode_start(sim, random.Random(0)) + assert exc.value.supported == ["panda"] + assert exc.value.data_config == "so100" + + +# Step semantics + + +class TestOnStep: + def test_sparse_reward_zero_and_not_done(self): + adapter = LiberoAdapter.from_text(PICK_CUBE_BDDL) + info = adapter.on_step(FakeSim(), {}, {}) + assert info.reward == 0.0 + assert info.done is False + + +# PolicyRunner + evaluate_benchmark 
integration + + +class TestEvaluateBenchmarkIntegration: + def test_evaluate_with_mock_policy_succeeds(self): + """Mock policy drives a loop; the benchmark loop returns a success_rate + without crashing even though the mock policy doesn't actually win.""" + adapter = LiberoAdapter.from_text(PICK_CUBE_BDDL, max_steps=4) + register_benchmark("libero-test-pick", adapter) + # Sim is loaded with Panda but predicate positions don't match the + # goal, so every episode should fall through to max_steps with + # success=False - that's fine, we're testing the loop. + sim = FakeSim( + bodies={ + "cube_1": {"position": [0, 0, 0.0]}, + "plate_1": {"position": [0, 0, 0.0]}, + }, + data_config="panda", + ) + # Rename so SimEngine.evaluate_benchmark can resolve the sole robot. + sim._world.robots.clear() + sim._world.robots["panda_arm"] = _FakeRobot("panda") + + result = sim.evaluate_benchmark( + benchmark_name="libero-test-pick", + policy_provider="mock", + n_episodes=2, + seed=0, + ) + assert result["status"] == "success" + payload = next(c["json"] for c in result["content"] if "json" in c) + assert payload["n_episodes"] == 2 + assert payload["benchmark_class"] == "LiberoAdapter" + + def test_unknown_task_returns_structured_error(self): + sim = FakeSim(data_config="panda") + result = sim.evaluate_benchmark(benchmark_name="libero-nonexistent") + assert result["status"] == "error" + assert "no benchmark registered" in result["content"][0]["text"].lower() + + def test_runner_counts_success_when_predicate_holds(self): + """Seed the sim with predicate-satisfying state so is_success returns + True on the first step - success_rate should be 1.0.""" + adapter = LiberoAdapter.from_text( + "(define (problem p) (:goal (upright bottle)))", + max_steps=3, + ) + sim = FakeSim( + bodies={"bottle": {"quaternion": [1.0, 0, 0, 0]}}, + data_config="panda", + ) + policy = MockPolicy() + policy.set_robot_state_keys(sim.robot_joint_names("fake_panda")) + result = 
PolicyRunner(sim).evaluate("fake_panda", policy, spec=adapter, n_episodes=1) + payload = next(c["json"] for c in result["content"] if "json" in c) + assert payload["n_success"] == 1 + assert payload["episodes"][0]["steps"] == 1 # terminated on first step diff --git a/tests/benchmarks/libero/test_libero_e2e.py b/tests/benchmarks/libero/test_libero_e2e.py new file mode 100644 index 0000000..3426cbe --- /dev/null +++ b/tests/benchmarks/libero/test_libero_e2e.py @@ -0,0 +1,148 @@ +"""End-to-end dispatch test for LiberoAdapter via the MuJoCo Simulation class. + +Exercises the full register_benchmark → evaluate_benchmark path through +``_dispatch_action`` with a real MuJoCo world. Does not require the +``libero`` pip package - uses a hand-written BDDL string and inline MJCF. + +Distinct from ``tests/simulation/mujoco/test_benchmark_dispatch.py`` which +covers the generic benchmark dispatch path; this test pins the LIBERO +adapter + BDDL compile pipeline against a live sim. +""" + +from __future__ import annotations + +import os +import shutil +import tempfile + +import pytest + +mj = pytest.importorskip("mujoco") + +from strands_robots.benchmarks.libero import LiberoAdapter # noqa: E402 +from strands_robots.simulation.benchmark import ( # noqa: E402 + _BENCHMARK_REGISTRY, + register_benchmark, +) +from strands_robots.simulation.mujoco.simulation import Simulation # noqa: E402 + +# Simple Panda-like MJCF - good enough for the benchmark compat check +# (declares data_config=panda on the robot) and for get_body_state lookups. 
+PANDA_LIKE_XML = """ + + + +""" + + +@pytest.fixture(autouse=True) +def _clean_registry(): + snapshot = dict(_BENCHMARK_REGISTRY) + _BENCHMARK_REGISTRY.clear() + yield + _BENCHMARK_REGISTRY.clear() + _BENCHMARK_REGISTRY.update(snapshot) + + +@pytest.fixture +def sim(): + s = Simulation(tool_name="libero_sim", mesh=False) + yield s + s.cleanup() + + +@pytest.fixture +def robot_xml_path(): + tmpdir = tempfile.mkdtemp() + path = os.path.join(tmpdir, "panda_lite.xml") + with open(path, "w") as f: + f.write(PANDA_LIKE_XML) + yield path + shutil.rmtree(tmpdir, ignore_errors=True) + + +@pytest.fixture +def sim_with_panda(sim, robot_xml_path): + sim.create_world() + # data_config="panda" so LiberoAdapter's compat check passes. + result = sim.add_robot("panda_arm", urdf_path=robot_xml_path, data_config="panda") + assert result["status"] == "success" + return sim + + +class TestLiberoEvaluateBenchmarkEndToEnd: + def test_registered_adapter_round_trips_via_dispatcher(self, sim_with_panda): + """The MuJoCo dispatcher resolves evaluate_benchmark → PolicyRunner → + LiberoAdapter.is_success without crashing.""" + adapter = LiberoAdapter.from_text( + """ + (define (problem libero_keep_arm_intact) + (:language "do anything; success is never") + (:goal (on nonexistent_cube nonexistent_plate))) + """, + max_steps=3, + init_jitter=0.0, # nonexistent bodies - don't jitter + ) + register_benchmark("libero-e2e", adapter) + + result = sim_with_panda._dispatch_action( + "evaluate_benchmark", + { + "action": "evaluate_benchmark", + "benchmark_name": "libero-e2e", + "robot_name": "panda_arm", + "policy_provider": "mock", + "n_episodes": 2, + "seed": 7, + }, + ) + assert result["status"] == "success" + payload = next(c["json"] for c in result["content"] if "json" in c) + assert payload["n_episodes"] == 2 + assert payload["benchmark_class"] == "LiberoAdapter" + # Mock policy can't satisfy the goal; success_rate is 0 across both eps. 
+ assert payload["n_success"] == 0 + # But the loop ran - each episode ran max_steps=3. + assert all(ep["steps"] == 3 for ep in payload["episodes"]) + + def test_non_panda_robot_surfaces_structured_compat_error(self, sim, robot_xml_path): + """A sim loaded with a non-Panda data_config must produce a structured + error, not a raw traceback, when the LIBERO adapter evaluates.""" + sim.create_world() + sim.add_robot("so100_arm", urdf_path=robot_xml_path, data_config="so100") + + adapter = LiberoAdapter.from_text("(define (problem t) (:goal (grasped cube)))") + register_benchmark("libero-compat-test", adapter) + + result = sim._dispatch_action( + "evaluate_benchmark", + { + "action": "evaluate_benchmark", + "benchmark_name": "libero-compat-test", + "robot_name": "so100_arm", + "policy_provider": "mock", + "n_episodes": 1, + }, + ) + assert result["status"] == "error" + text = result["content"][0]["text"] + assert "compatibility" in text.lower() or "supported" in text.lower() + assert "so100" in text + assert "panda" in text diff --git a/tests/benchmarks/libero/test_libero_suite.py b/tests/benchmarks/libero/test_libero_suite.py new file mode 100644 index 0000000..425f19c --- /dev/null +++ b/tests/benchmarks/libero/test_libero_suite.py @@ -0,0 +1,140 @@ +"""Tests for :func:`load_libero_suite` and the suite enumeration helpers. + +These tests do NOT require the ``libero`` pip package - they all use the +``bddl_dir=`` override to point at a temp directory of hand-written BDDL +files. The upstream-package path is covered indirectly (via the probe +fallback in :func:`_locate_bddl_dir`) but not exercised directly; that +requires the real package layout and would bloat CI. 
+""" + +from __future__ import annotations + +import pytest + +from strands_robots.benchmarks.libero.suite import ( + SUITE_NAMES, + _normalise_suite_name, + available_suites, + load_libero_suite, +) +from strands_robots.simulation.benchmark import _BENCHMARK_REGISTRY, get_benchmark + + +@pytest.fixture(autouse=True) +def _clean_registry(): + snapshot = dict(_BENCHMARK_REGISTRY) + _BENCHMARK_REGISTRY.clear() + yield + _BENCHMARK_REGISTRY.clear() + _BENCHMARK_REGISTRY.update(snapshot) + + +def _write(path, text): + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(text) + + +# Suite name normalisation + + +class TestSuiteNames: + @pytest.mark.parametrize( + "raw,expected", + [ + ("libero_spatial", "libero_spatial"), + ("libero-spatial", "libero_spatial"), + ("spatial", "libero_spatial"), + ("LIBERO-10", "libero_10"), + (" libero_90 ", "libero_90"), + ], + ) + def test_normalise(self, raw, expected): + assert _normalise_suite_name(raw) == expected + + def test_available_suites_matches_SUITE_NAMES(self): + assert set(available_suites()) == set(SUITE_NAMES) + + +# load_libero_suite with bddl_dir override + + +class TestLoadLiberoSuite: + def test_registers_all_tasks_under_prefix(self, tmp_path): + suite_dir = tmp_path / "libero_spatial" + _write( + suite_dir / "pick_up_the_red_cube.bddl", + "(define (problem t1) (:goal (on cube plate)))", + ) + _write( + suite_dir / "stack_blue_block.bddl", + "(define (problem t2) (:goal (on block base)))", + ) + + registered = load_libero_suite("libero_spatial", bddl_dir=suite_dir) + assert set(registered.keys()) == { + "libero-spatial-pick_up_the_red_cube", + "libero-spatial-stack_blue_block", + } + # Each one is retrievable from the global registry. 
+ assert get_benchmark("libero-spatial-pick_up_the_red_cube") is not None + + def test_custom_key_prefix(self, tmp_path): + suite_dir = tmp_path / "libero_object" + _write(suite_dir / "task_a.bddl", "(define (problem t) (:goal (grasped a)))") + registered = load_libero_suite("libero_object", bddl_dir=suite_dir, key_prefix="") + assert "object-task_a" in registered + + def test_resolves_scene_path_when_file_exists(self, tmp_path): + suite_dir = tmp_path / "libero_spatial" + scene_dir = tmp_path / "scenes" + _write(suite_dir / "pick_cube.bddl", "(define (problem t) (:goal (grasped cube)))") + _write(scene_dir / "pick_cube.xml", "") + + registered = load_libero_suite("libero_spatial", bddl_dir=suite_dir, scene_dir=scene_dir) + adapter = registered["libero-spatial-pick_cube"] + assert adapter.scene_path == str(scene_dir / "pick_cube.xml") + + def test_missing_scene_leaves_adapter_scene_none(self, tmp_path): + suite_dir = tmp_path / "libero_spatial" + scene_dir = tmp_path / "scenes" + scene_dir.mkdir() + _write(suite_dir / "pick_cube.bddl", "(define (problem t) (:goal (grasped cube)))") + + registered = load_libero_suite("libero_spatial", bddl_dir=suite_dir, scene_dir=scene_dir) + adapter = registered["libero-spatial-pick_cube"] + assert adapter.scene_path is None + + def test_malformed_bddl_is_skipped_not_fatal(self, tmp_path, caplog): + """A single bad BDDL file must not prevent the rest of the suite from loading.""" + suite_dir = tmp_path / "libero_spatial" + _write(suite_dir / "good.bddl", "(define (problem good) (:goal (grasped cube)))") + _write(suite_dir / "bad.bddl", "(this is not bddl") + + with caplog.at_level("WARNING"): + registered = load_libero_suite("libero_spatial", bddl_dir=suite_dir) + assert "libero-spatial-good" in registered + assert "libero-spatial-bad" not in registered + assert any("Skipping" in rec.message for rec in caplog.records) + + def test_forwards_max_steps_and_jitter(self, tmp_path): + suite_dir = tmp_path / "libero_spatial" + 
_write(suite_dir / "t.bddl", "(define (problem t) (:goal (grasped cube)))") + + registered = load_libero_suite("libero_spatial", bddl_dir=suite_dir, max_steps=42, init_jitter=0.0) + adapter = registered["libero-spatial-t"] + assert adapter.max_steps == 42 + assert adapter._init_jitter == 0.0 + + def test_unknown_suite_name_rejected(self, tmp_path): + with pytest.raises(ValueError, match="libero_"): + load_libero_suite("libero_unknown_suite", bddl_dir=tmp_path) + + def test_nonexistent_bddl_dir(self, tmp_path): + with pytest.raises(FileNotFoundError): + load_libero_suite("libero_spatial", bddl_dir=tmp_path / "nope") + + def test_empty_directory_registers_nothing(self, tmp_path): + suite_dir = tmp_path / "libero_spatial" + suite_dir.mkdir() + registered = load_libero_suite("libero_spatial", bddl_dir=suite_dir) + assert registered == {} diff --git a/tests/simulation/mujoco/test_benchmark_dispatch.py b/tests/simulation/mujoco/test_benchmark_dispatch.py new file mode 100644 index 0000000..f215195 --- /dev/null +++ b/tests/simulation/mujoco/test_benchmark_dispatch.py @@ -0,0 +1,238 @@ +"""Dispatch-path tests for the benchmark tool actions on the MuJoCo backend. + +Mirrors the ``test_agenttool_contract.py`` pattern: exercises ``_dispatch_action`` +with the new action names (``list_benchmarks``, ``register_benchmark_from_file``, +``evaluate_benchmark``) and asserts: + +* the tool_spec ``action`` enum exposes them, +* unknown / missing params produce the friendly structured errors that the + dispatcher generates from ``inspect.signature``, +* the underlying ``SimEngine`` facade is reached for valid inputs. + +Integration with real MuJoCo physics is deferred to ``tests_integ/``; these +tests only need the Simulation stub without a created world. 
+""" + +from __future__ import annotations + +import json +import os +import shutil +import tempfile +from pathlib import Path + +import pytest + +mj = pytest.importorskip("mujoco") + +from strands_robots.simulation.benchmark import _BENCHMARK_REGISTRY # noqa: E402 +from strands_robots.simulation.mujoco.simulation import Simulation # noqa: E402 + +# Robot XML reused from test_simulation.py (keep in sync if the canonical +# fixture changes). +ROBOT_XML = """ + + + +""" + + +@pytest.fixture(autouse=True) +def _clean_registry(): + snapshot = dict(_BENCHMARK_REGISTRY) + _BENCHMARK_REGISTRY.clear() + yield + _BENCHMARK_REGISTRY.clear() + _BENCHMARK_REGISTRY.update(snapshot) + + +@pytest.fixture +def sim(): + s = Simulation(tool_name="bench_sim", mesh=False) + yield s + s.cleanup() + + +@pytest.fixture +def robot_xml_path(): + tmpdir = tempfile.mkdtemp() + path = os.path.join(tmpdir, "test_arm.xml") + with open(path, "w") as f: + f.write(ROBOT_XML) + yield path + shutil.rmtree(tmpdir, ignore_errors=True) + + +@pytest.fixture +def sim_with_robot(sim, robot_xml_path): + sim.create_world() + sim.add_robot("arm1", urdf_path=robot_xml_path) + return sim + + +@pytest.fixture +def basic_spec_file(tmp_path: Path): + path = tmp_path / "basic.json" + path.write_text( + json.dumps( + { + "name": "basic-task", + "default_robot": "arm1", + "supported_robots": [], # any robot (so100 by data_config isn't loaded here) + "max_steps": 5, + } + ) + ) + return str(path) + + +# Tool spec: action enum + property surface + + +class TestToolSpecSurface: + def test_enum_includes_new_actions(self, sim): + # _TOOL_SPEC_SCHEMA lives at module level; read via module introspection. 
+ from strands_robots.simulation.mujoco import simulation as _sim_mod + + enum = _sim_mod._TOOL_SPEC_SCHEMA["properties"]["action"]["enum"] + assert "list_benchmarks" in enum + assert "register_benchmark_from_file" in enum + assert "evaluate_benchmark" in enum + + def test_property_surface_has_new_params(self): + from strands_robots.simulation.mujoco import simulation as _sim_mod + + props = _sim_mod._TOOL_SPEC_SCHEMA["properties"] + assert "benchmark_name" in props + assert "spec_path" in props + + +# Dispatch + + +class TestListBenchmarksDispatch: + def test_empty_registry(self, sim): + result = sim._dispatch_action("list_benchmarks", {"action": "list_benchmarks"}) + assert result["status"] == "success" + assert "No benchmarks" in result["content"][0]["text"] + + def test_lists_registered(self, sim, basic_spec_file): + sim._dispatch_action( + "register_benchmark_from_file", + { + "action": "register_benchmark_from_file", + "benchmark_name": "basic", + "spec_path": basic_spec_file, + }, + ) + result = sim._dispatch_action("list_benchmarks", {"action": "list_benchmarks"}) + assert result["status"] == "success" + payload = next(c["json"] for c in result["content"] if "json" in c) + assert "basic" in payload["benchmarks"] + + +class TestRegisterBenchmarkFromFileDispatch: + def test_happy_path(self, sim, basic_spec_file): + result = sim._dispatch_action( + "register_benchmark_from_file", + { + "action": "register_benchmark_from_file", + "benchmark_name": "happy", + "spec_path": basic_spec_file, + }, + ) + assert result["status"] == "success" + assert "happy" in result["content"][0]["text"] + + def test_no_args_friendly_error(self, sim): + """Dispatcher surfaces missing required params with a clear message.""" + result = sim._dispatch_action("register_benchmark_from_file", {"action": "register_benchmark_from_file"}) + assert result["status"] == "error" + text = result["content"][0]["text"] + assert "requires parameter" in text + + def 
test_bad_spec_path_returns_structured_error(self, sim): + result = sim._dispatch_action( + "register_benchmark_from_file", + { + "action": "register_benchmark_from_file", + "benchmark_name": "missing", + "spec_path": "/nonexistent/path/nope.json", + }, + ) + assert result["status"] == "error" + assert "not found" in result["content"][0]["text"].lower() + + def test_unknown_param_rejected(self, sim, basic_spec_file): + """Dispatcher rejects unknown kwargs so spec drift is caught early.""" + result = sim._dispatch_action( + "register_benchmark_from_file", + { + "action": "register_benchmark_from_file", + "benchmark_name": "x", + "spec_path": basic_spec_file, + "bogus": 1, + }, + ) + assert result["status"] == "error" + assert "Unknown parameter 'bogus'" in result["content"][0]["text"] + + +class TestEvaluateBenchmarkDispatch: + def test_requires_benchmark_name(self, sim): + result = sim._dispatch_action("evaluate_benchmark", {"action": "evaluate_benchmark"}) + assert result["status"] == "error" + assert "requires parameter" in result["content"][0]["text"] + + def test_unknown_benchmark(self, sim_with_robot): + result = sim_with_robot._dispatch_action( + "evaluate_benchmark", + {"action": "evaluate_benchmark", "benchmark_name": "never-registered"}, + ) + assert result["status"] == "error" + assert "no benchmark registered" in result["content"][0]["text"].lower() + + def test_evaluate_end_to_end(self, sim_with_robot, basic_spec_file): + """Register a no-op benchmark, evaluate with the mock policy - must succeed.""" + sim_with_robot._dispatch_action( + "register_benchmark_from_file", + { + "action": "register_benchmark_from_file", + "benchmark_name": "e2e", + "spec_path": basic_spec_file, + }, + ) + result = sim_with_robot._dispatch_action( + "evaluate_benchmark", + { + "action": "evaluate_benchmark", + "benchmark_name": "e2e", + "robot_name": "arm1", + "policy_provider": "mock", + "n_episodes": 1, + "seed": 0, + }, + ) + assert result["status"] == "success" + payload = 
next(c["json"] for c in result["content"] if "json" in c) + assert payload["n_episodes"] == 1 + assert payload["benchmark_class"] == "DeclarativeBenchmark" + # With no success/failure/done in the spec, loop runs to max_steps=5. + assert payload["episodes"][0]["steps"] == 5 diff --git a/tests/simulation/test_benchmark.py b/tests/simulation/test_benchmark.py new file mode 100644 index 0000000..57abda4 --- /dev/null +++ b/tests/simulation/test_benchmark.py @@ -0,0 +1,320 @@ +"""Tests for ``strands_robots.simulation.benchmark``. + +Covers: + +* :class:`BenchmarkProtocol` ABC contract (cannot instantiate abstract, + required methods must be implemented, optional hooks have usable defaults). +* :class:`StepInfo` dataclass. +* Registry operations (:func:`register_benchmark` / :func:`get_benchmark` / + :func:`list_benchmarks` / :func:`unregister_benchmark`), including + idempotent-overwrite and thread safety. +* Robot compatibility validation via :meth:`BenchmarkProtocol.on_episode_start`. +""" + +from __future__ import annotations + +import random +import threading +from typing import Any + +import pytest + +from strands_robots.simulation.base import SimEngine +from strands_robots.simulation.benchmark import ( + _BENCHMARK_REGISTRY, + BenchmarkCompatibilityError, + BenchmarkProtocol, + StepInfo, + get_benchmark, + list_benchmarks, + register_benchmark, + unregister_benchmark, +) + + +class _MinimalBenchmark(BenchmarkProtocol): + """Concrete benchmark used across tests.""" + + max_steps = 42 + + def __init__( + self, + *, + supported: list[str] | None = None, + default: str = "so100", + success: bool = False, + failure: bool = False, + reward: float = 0.0, + ): + self._supported = list(supported if supported is not None else ["so100"]) + self._default = default + self._success = success + self._failure = failure + self._reward = reward + + @property + def supported_robots(self) -> list[str]: + return list(self._supported) + + @property + def default_robot(self) -> str: + 
return self._default + + def on_step(self, sim: SimEngine, obs: dict[str, Any], action: dict[str, Any]) -> StepInfo: + return StepInfo(reward=self._reward) + + def is_success(self, sim: SimEngine) -> bool: + return self._success + + def is_failure(self, sim: SimEngine) -> bool: + return self._failure + + +# Registry fixtures + + +@pytest.fixture(autouse=True) +def _clean_registry(): + """Snapshot + restore the registry around every test so they stay isolated.""" + snapshot = dict(_BENCHMARK_REGISTRY) + _BENCHMARK_REGISTRY.clear() + yield + _BENCHMARK_REGISTRY.clear() + _BENCHMARK_REGISTRY.update(snapshot) + + +# StepInfo + + +class TestStepInfo: + def test_defaults(self): + info = StepInfo() + assert info.reward == 0.0 + assert info.done is False + assert info.info == {} + + def test_custom_values(self): + info = StepInfo(reward=3.5, done=True, info={"k": "v"}) + assert info.reward == 3.5 + assert info.done is True + assert info.info == {"k": "v"} + + def test_is_frozen(self): + info = StepInfo() + with pytest.raises(Exception): # dataclasses.FrozenInstanceError + info.reward = 1.0 # type: ignore[misc] + + +# BenchmarkProtocol ABC contract + + +class TestBenchmarkProtocolContract: + def test_cannot_instantiate_abstract(self): + """ABC with abstract methods must not be instantiable.""" + with pytest.raises(TypeError): + BenchmarkProtocol() # type: ignore[abstract] + + def test_concrete_instantiates(self): + bench = _MinimalBenchmark() + assert bench.supported_robots == ["so100"] + assert bench.default_robot == "so100" + assert bench.max_steps == 42 + + def test_is_failure_default_false(self): + """The default is_failure returns False, so sparse-success benchmarks + don't need to override it.""" + + class _Sparse(_MinimalBenchmark): + pass + + # Don't set failure=True; default should return False. 
+ bench = _Sparse() + assert bench.is_failure(None) is False # type: ignore[arg-type] + + def test_on_episode_start_has_default_impl(self): + """on_episode_start is NOT abstract - base impl handles empty-sim + compat checks.""" + + # Fake sim with no robots - should call add_robot with default_robot. + class FakeSim: + def __init__(self): + self._robots: list[str] = [] + self.add_robot_calls: list[dict[str, Any]] = [] + + def list_robots(self): + return list(self._robots) + + def add_robot(self, *, name, data_config): + self.add_robot_calls.append({"name": name, "data_config": data_config}) + self._robots.append(name) + + sim = FakeSim() + bench = _MinimalBenchmark(supported=["so100"], default="so100") + bench.on_episode_start(sim, random.Random(0)) # type: ignore[arg-type] + assert len(sim.add_robot_calls) == 1 + assert sim.add_robot_calls[0]["data_config"] == "so100" + + +# Robot compatibility + + +class TestRobotCompatibility: + def test_raises_when_loaded_robot_incompatible(self): + """Loading a robot whose data_config is not in supported_robots raises + BenchmarkCompatibilityError.""" + + class _Robot: + data_config = "panda" # not in supported + + class FakeSimWithWorld: + _world: Any = None + + def __init__(self): + self._world = type( + "World", + (), + {"robots": {"arm1": _Robot()}}, + )() + + def list_robots(self): + return ["arm1"] + + def add_robot(self, **kw): # pragma: no cover + raise AssertionError("should not be called when robot already present") + + sim = FakeSimWithWorld() + bench = _MinimalBenchmark(supported=["so100", "so101"], default="so100") + with pytest.raises(BenchmarkCompatibilityError) as excinfo: + bench.on_episode_start(sim, random.Random(0)) # type: ignore[arg-type] + assert excinfo.value.robot_name == "arm1" + assert excinfo.value.data_config == "panda" + assert excinfo.value.supported == ["so100", "so101"] + + def test_passes_when_supported_robots_empty(self): + """Empty supported_robots means 'any robot' - no compat check.""" + + 
class _Robot: + data_config = "any_weird_thing" + + class FakeSim: + def __init__(self): + self._world = type( + "World", + (), + {"robots": {"arm1": _Robot()}}, + )() + + def list_robots(self): + return ["arm1"] + + def add_robot(self, **kw): # pragma: no cover + raise AssertionError("should not be called") + + sim = FakeSim() + bench = _MinimalBenchmark(supported=[], default="any_weird_thing") + # Must not raise. + bench.on_episode_start(sim, random.Random(0)) # type: ignore[arg-type] + + def test_skips_check_when_sim_has_no_world_attr(self): + """Backends without a ``_world`` attribute are treated as "cannot verify" + and skip the compat check rather than false-positive.""" + + class FakeSim: + def list_robots(self): + return ["arm1"] + + def add_robot(self, **kw): # pragma: no cover + raise AssertionError("should not be called") + + sim = FakeSim() + bench = _MinimalBenchmark(supported=["so100"], default="so100") + # Must not raise even though arm1 has unknown data_config. + bench.on_episode_start(sim, random.Random(0)) # type: ignore[arg-type] + + +# Registry + + +class TestRegistry: + def test_register_and_get(self): + bench = _MinimalBenchmark() + register_benchmark("my-task", bench) + assert get_benchmark("my-task") is bench + + def test_get_unknown_returns_none(self): + assert get_benchmark("nonexistent") is None + + def test_register_rejects_non_string_name(self): + with pytest.raises(ValueError): + register_benchmark("", _MinimalBenchmark()) + with pytest.raises(ValueError): + register_benchmark(None, _MinimalBenchmark()) # type: ignore[arg-type] + + def test_register_rejects_non_benchmark(self): + with pytest.raises(TypeError): + register_benchmark("x", "not a benchmark") # type: ignore[arg-type] + + def test_register_overwrites_and_warns(self, caplog): + """Re-registering the same name replaces the entry and logs a warning.""" + first = _MinimalBenchmark() + second = _MinimalBenchmark() + register_benchmark("dup", first) + with 
caplog.at_level("WARNING"): + register_benchmark("dup", second) + assert get_benchmark("dup") is second + assert any("Overwriting existing" in rec.message for rec in caplog.records) + + def test_unregister_removes(self): + bench = _MinimalBenchmark() + register_benchmark("rm-me", bench) + removed = unregister_benchmark("rm-me") + assert removed is bench + assert get_benchmark("rm-me") is None + + def test_unregister_unknown_returns_none(self): + assert unregister_benchmark("never-registered") is None + + def test_list_benchmarks_metadata(self): + bench = _MinimalBenchmark(supported=["so100", "so101"], default="so100") + register_benchmark("listed", bench) + listed = list_benchmarks() + assert "listed" in listed + meta = listed["listed"] + assert meta["class"] == "_MinimalBenchmark" + assert meta["supported_robots"] == ["so100", "so101"] + assert meta["default_robot"] == "so100" + assert meta["max_steps"] == 42 + + def test_list_benchmarks_empty(self): + assert list_benchmarks() == {} + + +class TestRegistryThreadSafety: + """The registry guard is an RLock so concurrent registrations don't race.""" + + def test_concurrent_registrations_all_land(self): + benches = [_MinimalBenchmark() for _ in range(50)] + barrier = threading.Barrier(len(benches)) + + def register(i: int): + barrier.wait() + register_benchmark(f"thread-{i}", benches[i]) + + threads = [threading.Thread(target=register, args=(i,)) for i in range(len(benches))] + for t in threads: + t.start() + for t in threads: + t.join() + + listed = list_benchmarks() + for i in range(len(benches)): + assert f"thread-{i}" in listed + + +class TestBenchmarkCompatibilityError: + def test_carries_context(self): + e = BenchmarkCompatibilityError(robot_name="arm", data_config="foo", supported=["bar"]) + assert e.robot_name == "arm" + assert e.data_config == "foo" + assert e.supported == ["bar"] + # Subclasses ValueError so broad except ValueError still works. 
+ assert isinstance(e, ValueError) diff --git a/tests/simulation/test_benchmark_dsl.py b/tests/simulation/test_benchmark_dsl.py new file mode 100644 index 0000000..841b0f7 --- /dev/null +++ b/tests/simulation/test_benchmark_dsl.py @@ -0,0 +1,365 @@ +"""Tests for ``strands_robots.simulation.benchmark_spec`` (declarative YAML/JSON loader). + +Covers: + +* :meth:`DeclarativeBenchmark.from_dict` schema validation (good / bad specs). +* :func:`register_benchmark_from_file` end-to-end with JSON + YAML. +* The sandboxed contract: unknown predicates / unknown top-level keys / + non-dict predicate entries produce clear errors, not ``eval`` side-effects. +""" + +from __future__ import annotations + +import json +import random +from pathlib import Path +from typing import Any + +import pytest + +from strands_robots.simulation.benchmark import ( + _BENCHMARK_REGISTRY, + get_benchmark, +) +from strands_robots.simulation.benchmark_spec import ( + DeclarativeBenchmark, + register_benchmark_from_file, +) + + +@pytest.fixture(autouse=True) +def _clean_registry(): + snapshot = dict(_BENCHMARK_REGISTRY) + _BENCHMARK_REGISTRY.clear() + yield + _BENCHMARK_REGISTRY.clear() + _BENCHMARK_REGISTRY.update(snapshot) + + +class _BodyStateSim: + def __init__(self, positions: dict[str, list[float]]): + self._pos = positions + + def get_body_state(self, body_name: str) -> dict[str, Any]: + if body_name not in self._pos: + return {"status": "error", "content": [{"text": "missing"}]} + return { + "status": "success", + "content": [ + {"text": body_name}, + {"json": {"position": self._pos[body_name]}}, + ], + } + + def get_observation(self, *_, **__) -> dict[str, Any]: + return {} + + +# Schema validation + + +class TestFromDictValidation: + def test_minimal_valid_spec(self): + spec = { + "name": "minimal", + "default_robot": "so100", + "supported_robots": ["so100"], + } + bench = DeclarativeBenchmark.from_dict(spec) + assert bench.name == "minimal" + assert bench.default_robot == "so100" + assert 
bench.supported_robots == ["so100"] + assert bench.max_steps == 300 # default + + def test_rejects_non_dict_spec(self): + with pytest.raises(ValueError): + DeclarativeBenchmark.from_dict([1, 2, 3]) # type: ignore[arg-type] + + def test_rejects_unknown_top_level_keys(self): + with pytest.raises(ValueError) as exc: + DeclarativeBenchmark.from_dict( + {"name": "x", "default_robot": "y", "supported_robots": ["y"], "weird_key": 1} + ) + assert "weird_key" in str(exc.value) + + def test_rejects_missing_name(self): + with pytest.raises(ValueError): + DeclarativeBenchmark.from_dict({"default_robot": "y", "supported_robots": []}) + + def test_rejects_missing_default_robot(self): + with pytest.raises(ValueError): + DeclarativeBenchmark.from_dict({"name": "x"}) + + def test_rejects_default_not_in_supported(self): + with pytest.raises(ValueError) as exc: + DeclarativeBenchmark.from_dict({"name": "x", "default_robot": "ghost", "supported_robots": ["a", "b"]}) + assert "not in supported_robots" in str(exc.value) + + def test_allows_default_outside_supported_when_empty(self): + """Empty supported_robots means "any" - default outside makes sense.""" + bench = DeclarativeBenchmark.from_dict({"name": "x", "default_robot": "anything", "supported_robots": []}) + assert bench.default_robot == "anything" + + def test_rejects_non_positive_max_steps(self): + for bad in (-1, 0, "300", True): + with pytest.raises(ValueError): + DeclarativeBenchmark.from_dict( + { + "name": "x", + "default_robot": "y", + "supported_robots": ["y"], + "max_steps": bad, + } + ) + + +# Predicate compilation + + +class TestPredicateCompilation: + def _base_spec(self, **overrides: Any) -> dict[str, Any]: + spec = { + "name": "t", + "default_robot": "so100", + "supported_robots": ["so100"], + "max_steps": 10, + } + spec.update(overrides) + return spec + + def test_success_all_true(self): + spec = self._base_spec( + success={ + "all": [ + {"predicate": "body_above_z", "body": "cube", "z": 0.1}, + ] + } + ) + bench = 
DeclarativeBenchmark.from_dict(spec) + sim_hit = _BodyStateSim({"cube": [0, 0, 0.2]}) + sim_miss = _BodyStateSim({"cube": [0, 0, 0.05]}) + assert bench.is_success(sim_hit) is True + assert bench.is_success(sim_miss) is False + + def test_success_all_any_combined(self): + """When both 'all' and 'any' are provided, both must hold.""" + spec = self._base_spec( + success={ + "all": [{"predicate": "body_above_z", "body": "cube", "z": 0.0}], + "any": [ + {"predicate": "body_above_z", "body": "cube", "z": 10.0}, + {"predicate": "body_above_z", "body": "cube", "z": 0.05}, + ], + } + ) + bench = DeclarativeBenchmark.from_dict(spec) + sim = _BodyStateSim({"cube": [0, 0, 0.1]}) + # all: z>0.0 true. any: z>10 false OR z>0.05 true → any true. Combined: true. + assert bench.is_success(sim) is True + + def test_failure_any(self): + spec = self._base_spec(failure={"any": [{"predicate": "body_below_z", "body": "cube", "z": 0.0}]}) + bench = DeclarativeBenchmark.from_dict(spec) + assert bench.is_failure(_BodyStateSim({"cube": [0, 0, -0.01]})) is True + assert bench.is_failure(_BodyStateSim({"cube": [0, 0, 0.5]})) is False + + def test_dense_reward_sums_terms(self): + spec = self._base_spec( + dense_reward=[ + {"predicate": "constant", "value": 1.0}, + {"predicate": "constant", "value": -0.5}, + ] + ) + bench = DeclarativeBenchmark.from_dict(spec) + info = bench.on_step(None, {}, {}) # type: ignore[arg-type] + assert info.reward == pytest.approx(0.5) + + def test_rejects_unknown_predicate(self): + with pytest.raises(ValueError) as exc: + DeclarativeBenchmark.from_dict(self._base_spec(success={"all": [{"predicate": "totally_made_up"}]})) + assert "Unknown predicate" in str(exc.value) + + def test_rejects_non_dict_predicate_entry(self): + with pytest.raises(ValueError): + DeclarativeBenchmark.from_dict(self._base_spec(success={"all": ["just a string"]})) + + def test_rejects_missing_predicate_key(self): + with pytest.raises(ValueError): + 
DeclarativeBenchmark.from_dict(self._base_spec(success={"all": [{"body": "cube", "z": 0.1}]})) + + def test_rejects_bad_clause_keys(self): + """success/failure only allow 'all' / 'any'.""" + with pytest.raises(ValueError) as exc: + DeclarativeBenchmark.from_dict(self._base_spec(success={"all": [], "other": []})) + assert "other" in str(exc.value) + + def test_predicate_bad_kwargs_surface_compile_error(self): + """Bad predicate kwargs (wrong types, missing required) surface as a + compile-time error, not a runtime predicate crash.""" + with pytest.raises(ValueError) as exc: + DeclarativeBenchmark.from_dict( + self._base_spec(success={"all": [{"predicate": "inside_region", "body": "x"}]}) + ) + # Should mention the predicate name in the error for discoverability. + assert "inside_region" in str(exc.value) + + +# Empty / default clauses + + +class TestEmptyClauses: + def test_success_absent_defaults_to_false(self): + bench = DeclarativeBenchmark.from_dict({"name": "x", "default_robot": "so100", "supported_robots": ["so100"]}) + assert bench.is_success(_BodyStateSim({})) is False + + def test_failure_absent_defaults_to_false(self): + bench = DeclarativeBenchmark.from_dict({"name": "x", "default_robot": "so100", "supported_robots": ["so100"]}) + assert bench.is_failure(_BodyStateSim({})) is False + + def test_empty_success_returns_false(self): + """Non-None but empty success clause must not default to "always true".""" + bench = DeclarativeBenchmark.from_dict( + { + "name": "x", + "default_robot": "so100", + "supported_robots": ["so100"], + "success": {"all": [], "any": []}, + } + ) + assert bench.is_success(_BodyStateSim({})) is False + + +# File loading + + +class TestRegisterBenchmarkFromFile: + def test_register_from_json(self, tmp_path): + spec_path = tmp_path / "drawer.json" + spec_path.write_text( + json.dumps( + { + "name": "drawer", + "default_robot": "so100", + "supported_robots": ["so100"], + "max_steps": 50, + "success": { + "all": [ + {"predicate": 
"body_above_z", "body": "cube", "z": 0.1}, + ] + }, + } + ) + ) + bench = register_benchmark_from_file("drawer", str(spec_path)) + assert get_benchmark("drawer") is bench + assert bench.max_steps == 50 + assert bench.is_success(_BodyStateSim({"cube": [0, 0, 0.5]})) is True + + def test_register_from_yaml(self, tmp_path): + """YAML support is opt-in; skip if pyyaml isn't available in this env.""" + pytest.importorskip("yaml") + spec_path = tmp_path / "y.yaml" + spec_path.write_text( + """ +name: yml-task +default_robot: so100 +supported_robots: [so100] +max_steps: 99 +success: + all: + - {predicate: body_above_z, body: cube, z: 0.5} +""" + ) + bench = register_benchmark_from_file("yml-task", str(spec_path)) + assert bench.max_steps == 99 + + def test_file_not_found(self, tmp_path): + with pytest.raises(FileNotFoundError): + register_benchmark_from_file("missing", str(tmp_path / "nope.json")) + + def test_rejects_unsupported_extension(self, tmp_path): + p = tmp_path / "spec.toml" + p.write_text("") + with pytest.raises(ValueError) as exc: + register_benchmark_from_file("x", str(p)) + assert ".toml" in str(exc.value) or "extension" in str(exc.value) + + def test_spec_name_internal_overridden_by_registry_name(self, tmp_path): + """Registry name wins over any ``name`` declared inside the spec file.""" + p = tmp_path / "s.json" + p.write_text( + json.dumps( + { + "name": "internal-name", + "default_robot": "so100", + "supported_robots": ["so100"], + } + ) + ) + register_benchmark_from_file("external-name", str(p)) + assert get_benchmark("external-name") is not None + # The spec's internal name doesn't end up in the registry. 
+ assert get_benchmark("internal-name") is None + + def test_rejects_empty_name(self, tmp_path): + p = tmp_path / "s.json" + p.write_text('{"name": "x", "default_robot": "y", "supported_robots": []}') + with pytest.raises(ValueError): + register_benchmark_from_file("", str(p)) + + def test_bad_json_propagates(self, tmp_path): + p = tmp_path / "bad.json" + p.write_text("{not json}") + with pytest.raises(json.JSONDecodeError): + register_benchmark_from_file("x", str(p)) + + +# DeclarativeBenchmark lifecycle + + +class TestDeclarativeBenchmarkLifecycle: + def test_on_episode_start_delegates_to_base(self): + """Default on_episode_start loads the default_robot when sim is empty.""" + spec = { + "name": "x", + "default_robot": "so100", + "supported_robots": ["so100"], + } + bench = DeclarativeBenchmark.from_dict(spec) + + class FakeSim: + def __init__(self): + self.added: list[dict[str, Any]] = [] + + def list_robots(self): + return [] + + def add_robot(self, *, name, data_config): + self.added.append({"name": name, "data_config": data_config}) + + sim = FakeSim() + bench.on_episode_start(sim, random.Random(0)) # type: ignore[arg-type] + assert len(sim.added) == 1 + assert sim.added[0]["data_config"] == "so100" + + def test_scene_load_error_raises(self, tmp_path: Path): + """If the sim's load_scene returns an error dict, the benchmark must surface it.""" + spec = { + "name": "x", + "default_robot": "so100", + "supported_robots": ["so100"], + "scene": str(tmp_path / "missing.xml"), + } + bench = DeclarativeBenchmark.from_dict(spec) + + class FakeSim: + def load_scene(self, path): + return {"status": "error", "content": [{"text": f"no such file: {path}"}]} + + def list_robots(self): + return ["preloaded"] + + sim = FakeSim() + with pytest.raises(RuntimeError) as exc: + bench.on_episode_start(sim, random.Random(0)) # type: ignore[arg-type] + assert "load_scene" in str(exc.value) diff --git a/tests/simulation/test_benchmark_predicates.py 
b/tests/simulation/test_benchmark_predicates.py new file mode 100644 index 0000000..dbb9d96 --- /dev/null +++ b/tests/simulation/test_benchmark_predicates.py @@ -0,0 +1,403 @@ +"""Tests for ``strands_robots.simulation.predicates``. + +Each predicate is tested against a lightweight fake sim that implements +only the methods the predicate exercises. Real MuJoCo integration is out +of scope here - those predicates are covered end-to-end in the dispatch +tests under ``tests/simulation/mujoco/``. +""" + +from __future__ import annotations + +from typing import Any + +import pytest + +from strands_robots.simulation.predicates import ( + PREDICATE_REGISTRY, + make_predicate, + register_predicate, +) + +# Fake sim helpers + + +class _BodyStateSim: + """Sim that exposes ``get_body_state`` with caller-provided positions.""" + + def __init__(self, positions: dict[str, list[float]]): + self._pos = positions + + def get_body_state(self, body_name: str) -> dict[str, Any]: + if body_name not in self._pos: + return {"status": "error", "content": [{"text": f"Body '{body_name}' not found."}]} + return { + "status": "success", + "content": [ + {"text": f"body {body_name}"}, + { + "json": { + "position": self._pos[body_name], + "quaternion": [1, 0, 0, 0], + "mass": 1.0, + } + }, + ], + } + + # Predicates that probe `get_observation` for joint state need this stub. 
+ def get_observation(self, *_, **__) -> dict[str, Any]: + return {} + + +class _JointObsSim: + """Sim that exposes joint positions via ``get_observation``.""" + + def __init__(self, joints: dict[str, float]): + self._joints = joints + + def get_observation(self, *_, **__) -> dict[str, float]: + return dict(self._joints) + + def get_body_state(self, body_name: str) -> dict[str, Any]: # pragma: no cover + return {"status": "error", "content": [{"text": "no bodies"}]} + + +class _ContactSim: + """Sim that exposes ``get_contacts`` in the MuJoCo-backend shape.""" + + def __init__(self, contacts: list[dict[str, Any]]): + self._contacts = contacts + + def get_contacts(self) -> dict[str, Any]: + return { + "status": "success", + "content": [ + {"text": f"{len(self._contacts)} contacts"}, + { + "json": { + "contacts": self._contacts, + "n_contacts": len(self._contacts), + } + }, + ], + } + + +class _NoHelpersSim: + """Sim missing get_body_state / get_contacts entirely (e.g. future backend).""" + + def get_observation(self, *_, **__) -> dict[str, Any]: + return {} + + +# Registry + + +class TestRegistry: + def test_builtin_predicates_registered(self): + required = { + "body_above_z", + "body_below_z", + "joint_above", + "joint_below", + "distance_less_than", + "inside_region", + "contact_between", + "contact_any", + "body_on", + "body_inside", + "body_upright", + "grasped", + "distance_neg", + "joint_progress", + "constant", + } + assert required.issubset(PREDICATE_REGISTRY.keys()) + + def test_make_predicate_unknown_raises(self): + with pytest.raises(ValueError) as exc: + make_predicate("totally_made_up") + assert "Unknown predicate" in str(exc.value) + # Error message should list valid names so the user can fix the spec. 
+ assert "body_above_z" in str(exc.value) + + def test_register_predicate_rejects_shadow(self): + with pytest.raises(ValueError): + register_predicate("body_above_z", lambda **_: lambda _sim: True) + + def test_register_predicate_rejects_non_callable(self): + with pytest.raises(TypeError): + register_predicate("my_pred", "not a callable") # type: ignore[arg-type] + + def test_register_predicate_custom(self): + try: + + def factory(value: float): + return lambda _sim: value > 0 + + register_predicate("positive_constant", factory) + pred = make_predicate("positive_constant", value=1.5) + assert pred(None) is True + finally: + PREDICATE_REGISTRY.pop("positive_constant", None) + + +# Body-position predicates + + +class TestBodyPositionPredicates: + def test_body_above_z_true(self): + sim = _BodyStateSim({"cube": [0.1, 0.0, 0.25]}) + pred = make_predicate("body_above_z", body="cube", z=0.2) + assert pred(sim) is True + + def test_body_above_z_false(self): + sim = _BodyStateSim({"cube": [0.1, 0.0, 0.15]}) + pred = make_predicate("body_above_z", body="cube", z=0.2) + assert pred(sim) is False + + def test_body_above_z_missing_body_returns_false(self): + sim = _BodyStateSim({"other": [0, 0, 1]}) + pred = make_predicate("body_above_z", body="cube", z=0.2) + assert pred(sim) is False + + def test_body_below_z(self): + sim = _BodyStateSim({"cube": [0.0, 0.0, -0.05]}) + pred = make_predicate("body_below_z", body="cube", z=0.0) + assert pred(sim) is True + + def test_distance_less_than_true(self): + sim = _BodyStateSim({"a": [0, 0, 0], "b": [0.05, 0, 0]}) + pred = make_predicate("distance_less_than", body_a="a", body_b="b", threshold=0.1) + assert pred(sim) is True + + def test_distance_less_than_false(self): + sim = _BodyStateSim({"a": [0, 0, 0], "b": [1.0, 0, 0]}) + pred = make_predicate("distance_less_than", body_a="a", body_b="b", threshold=0.1) + assert pred(sim) is False + + def test_inside_region_matches(self): + sim = _BodyStateSim({"cube": [0.1, 0.2, 0.3]}) + pred = 
make_predicate("inside_region", body="cube", min=[-0.5, 0.0, 0.0], max=[0.5, 0.5, 1.0]) + assert pred(sim) is True + + def test_inside_region_outside(self): + sim = _BodyStateSim({"cube": [0.6, 0.0, 0.0]}) + pred = make_predicate("inside_region", body="cube", min=[0, 0, 0], max=[0.5, 0.5, 0.5]) + assert pred(sim) is False + + def test_inside_region_rejects_malformed_args(self): + with pytest.raises(ValueError): + make_predicate("inside_region", body="cube", min=[0, 0], max=[1, 1, 1]) + with pytest.raises(ValueError): + # min > max should error up front, not silently always return False. + make_predicate("inside_region", body="cube", min=[1, 1, 1], max=[0, 0, 0]) + + def test_body_predicate_without_get_body_state_returns_false(self): + sim = _NoHelpersSim() + pred = make_predicate("body_above_z", body="cube", z=0) + assert pred(sim) is False + + +# Joint predicates + + +class TestJointPredicates: + def test_joint_above(self): + sim = _JointObsSim({"drawer_slide": 0.18}) + assert make_predicate("joint_above", joint="drawer_slide", value=0.15)(sim) is True + assert make_predicate("joint_above", joint="drawer_slide", value=0.2)(sim) is False + + def test_joint_below(self): + sim = _JointObsSim({"gripper": 0.02}) + assert make_predicate("joint_below", joint="gripper", value=0.05)(sim) is True + + def test_joint_missing_returns_false(self): + sim = _JointObsSim({"other_joint": 1.0}) + assert make_predicate("joint_above", joint="missing", value=0.0)(sim) is False + + def test_joint_progress_reward(self): + sim = _JointObsSim({"drawer": 0.1}) + term = make_predicate("joint_progress", joint="drawer", target=0.2, weight=10.0) + # -weight * |q - target| = -10 * 0.1 = -1.0 + assert term(sim) == pytest.approx(-1.0) + + def test_joint_progress_at_target_gives_zero_reward(self): + sim = _JointObsSim({"drawer": 0.2}) + term = make_predicate("joint_progress", joint="drawer", target=0.2, weight=1.0) + assert term(sim) == pytest.approx(0.0) + + +# Contact predicates + + +class 
TestContactPredicates: + def test_contact_between_matches_either_order(self): + sim = _ContactSim([{"geom1": "cube", "geom2": "gripper", "dist": -0.001}]) + assert make_predicate("contact_between", geom_a="cube", geom_b="gripper")(sim) is True + assert make_predicate("contact_between", geom_a="gripper", geom_b="cube")(sim) is True + + def test_contact_between_no_match(self): + sim = _ContactSim([{"geom1": "cube", "geom2": "ground"}]) + assert make_predicate("contact_between", geom_a="cube", geom_b="gripper")(sim) is False + + def test_contact_any(self): + assert make_predicate("contact_any")(_ContactSim([{"geom1": "a", "geom2": "b"}])) is True + assert make_predicate("contact_any")(_ContactSim([])) is False + + def test_contact_predicate_without_get_contacts(self): + sim = _NoHelpersSim() + assert make_predicate("contact_any")(sim) is False + assert make_predicate("contact_between", geom_a="a", geom_b="b")(sim) is False + + +# Reward terms + + +class TestRewardTerms: + def test_distance_neg_monotonic(self): + far = _BodyStateSim({"a": [0, 0, 0], "b": [1, 0, 0]}) + near = _BodyStateSim({"a": [0, 0, 0], "b": [0.1, 0, 0]}) + term = make_predicate("distance_neg", body_a="a", body_b="b", weight=1.0) + # Closer is greater (less negative). 
+        assert term(near) > term(far)
+
+    def test_distance_neg_weight(self):
+        sim = _BodyStateSim({"a": [0, 0, 0], "b": [1, 0, 0]})
+        weighted = make_predicate("distance_neg", body_a="a", body_b="b", weight=5.0)
+        assert weighted(sim) == pytest.approx(-5.0)
+
+    def test_distance_neg_missing_body_returns_zero(self):
+        """Missing bodies should not crash or reward heavily - return 0.0."""
+        sim = _BodyStateSim({"a": [0, 0, 0]})
+        term = make_predicate("distance_neg", body_a="a", body_b="ghost", weight=1.0)
+        assert term(sim) == 0.0
+
+    def test_constant(self):
+        term = make_predicate("constant", value=-0.01)
+        assert term(None) == pytest.approx(-0.01)
+
+
+# LIBERO / #110 predicates
+
+
+class _BodyStateWithQuatSim:
+    """Extends _BodyStateSim with quaternion in the body-state payload."""
+
+    def __init__(self, bodies: dict[str, dict[str, Any]]):
+        self._bodies = bodies
+
+    def get_body_state(self, body_name: str) -> dict[str, Any]:
+        if body_name not in self._bodies:
+            return {"status": "error", "content": [{"text": "missing"}]}
+        payload = {
+            "position": self._bodies[body_name].get("position", [0, 0, 0]),
+            "quaternion": self._bodies[body_name].get("quaternion", [1, 0, 0, 0]),
+            "mass": 1.0,
+        }
+        return {"status": "success", "content": [{"text": body_name}, {"json": payload}]}
+
+    def get_observation(self, *_, **__) -> dict[str, Any]:
+        return {}
+
+
+class TestBodyOn:
+    def test_true_when_above_and_aligned(self):
+        sim = _BodyStateSim({"cube": [0.0, 0.0, 0.22], "table": [0.0, 0.0, 0.05]})
+        pred = make_predicate("body_on", body_a="cube", body_b="table", z_offset=0.1)
+        assert pred(sim) is True
+
+    def test_false_when_not_above(self):
+        sim = _BodyStateSim({"cube": [0.0, 0.0, 0.04], "table": [0.0, 0.0, 0.05]})
+        pred = make_predicate("body_on", body_a="cube", body_b="table", z_offset=0.01)
+        assert pred(sim) is False
+
+    def test_false_when_too_far_horizontally(self):
+        sim = _BodyStateSim({"cube": [1.0, 0.0, 0.2], "table": [0.0, 0.0, 0.05]})
+        pred = make_predicate("body_on", body_a="cube", body_b="table", xy_tol=0.1)
+        assert pred(sim) is False
+
+    def test_missing_body_returns_false(self):
+        sim = _BodyStateSim({"table": [0, 0, 0.05]})
+        pred = make_predicate("body_on", body_a="cube", body_b="table")
+        assert pred(sim) is False
+
+
+class TestBodyInside:
+    def test_true_inside_box(self):
+        sim = _BodyStateSim({"cube": [0.02, 0.01, 0.03], "basket": [0, 0, 0]})
+        pred = make_predicate("body_inside", body="cube", container="basket", xy_tol=0.1, z_tol=0.1)
+        assert pred(sim) is True
+
+    def test_false_outside_xy(self):
+        sim = _BodyStateSim({"cube": [0.5, 0.0, 0.0], "basket": [0, 0, 0]})
+        pred = make_predicate("body_inside", body="cube", container="basket", xy_tol=0.1, z_tol=0.1)
+        assert pred(sim) is False
+
+    def test_false_outside_z(self):
+        sim = _BodyStateSim({"cube": [0.0, 0.0, 0.5], "basket": [0, 0, 0]})
+        pred = make_predicate("body_inside", body="cube", container="basket", xy_tol=0.2, z_tol=0.1)
+        assert pred(sim) is False
+
+
+class TestBodyUpright:
+    def test_identity_quat_is_upright(self):
+        sim = _BodyStateWithQuatSim({"bottle": {"quaternion": [1.0, 0.0, 0.0, 0.0]}})
+        pred = make_predicate("body_upright", body="bottle")
+        assert pred(sim) is True
+
+    def test_tipped_on_side_is_not_upright(self):
+        # 90-deg rotation about x-axis: quat = (cos(pi/4), sin(pi/4), 0, 0) ≈ (0.707, 0.707, 0, 0)
+        sim = _BodyStateWithQuatSim({"bottle": {"quaternion": [0.7071, 0.7071, 0.0, 0.0]}})
+        pred = make_predicate("body_upright", body="bottle", tol=0.15)
+        assert pred(sim) is False
+
+    def test_small_tilt_within_tolerance(self):
+        # Small rotation about x-axis - x component ~= 0.1, so 2*(x²+y²) ~= 0.02 < default tol 0.15.
+        sim = _BodyStateWithQuatSim({"bottle": {"quaternion": [0.995, 0.1, 0.0, 0.0]}})
+        pred = make_predicate("body_upright", body="bottle", tol=0.15)
+        assert pred(sim) is True
+
+    def test_missing_body_returns_false(self):
+        sim = _BodyStateWithQuatSim({})
+        pred = make_predicate("body_upright", body="bottle")
+        assert pred(sim) is False
+
+    def test_negative_tol_rejected(self):
+        with pytest.raises(ValueError):
+            make_predicate("body_upright", body="bottle", tol=-0.1)
+
+
+class TestGrasped:
+    def test_detects_gripper_contact_by_prefix(self):
+        sim = _ContactSim(
+            [
+                {"geom1": "robot0_gripper_finger_r", "geom2": "cube_geom"},
+            ]
+        )
+        pred = make_predicate("grasped", body="cube", gripper_prefix="robot0_gripper")
+        assert pred(sim) is True
+
+    def test_contact_without_gripper_prefix_is_not_grasp(self):
+        sim = _ContactSim([{"geom1": "table", "geom2": "cube_geom"}])
+        pred = make_predicate("grasped", body="cube", gripper_prefix="robot0_gripper")
+        assert pred(sim) is False
+
+    def test_matches_either_ordering(self):
+        sim = _ContactSim(
+            [
+                {"geom1": "cube_geom", "geom2": "robot0_gripper_finger_l"},
+            ]
+        )
+        pred = make_predicate("grasped", body="cube", gripper_prefix="robot0_gripper")
+        assert pred(sim) is True
+
+    def test_no_contacts_returns_false(self):
+        sim = _ContactSim([])
+        pred = make_predicate("grasped", body="cube", gripper_prefix="robot0_gripper")
+        assert pred(sim) is False
+
+    def test_without_get_contacts_returns_false(self):
+        sim = _NoHelpersSim()
+        pred = make_predicate("grasped", body="cube", gripper_prefix="robot0_gripper")
+        assert pred(sim) is False
diff --git a/tests/simulation/test_policy_runner_benchmark.py b/tests/simulation/test_policy_runner_benchmark.py
new file mode 100644
index 0000000..ec31e11
--- /dev/null
+++ b/tests/simulation/test_policy_runner_benchmark.py
+"""Tests for ``PolicyRunner.evaluate`` with a :class:`BenchmarkProtocol` spec.
+
+Covers the new spec-driven evaluation path:
+
+* Cumulative reward accounting across an episode.
+* Early termination on ``is_success`` / ``is_failure`` / ``StepInfo.done``.
+* Per-episode seeded RNG so identical seeds produce identical rollouts.
+* Robot-compatibility error surfaces as a structured error dict (not raises).
+* Legacy ``success_fn`` path still works.
+* Passing both ``spec`` and ``success_fn`` is an error.
+* :meth:`SimEngine.evaluate_benchmark` facade end-to-end.
+* :meth:`SimEngine.list_benchmarks` / :meth:`register_benchmark_from_file`
+  facades return structured dicts.
+"""
+
+from __future__ import annotations
+
+import json
+from pathlib import Path
+
+import pytest
+
+from strands_robots.policies.mock import MockPolicy
+from strands_robots.simulation.base import SimEngine
+from strands_robots.simulation.benchmark import (
+    _BENCHMARK_REGISTRY,
+    BenchmarkProtocol,
+    StepInfo,
+    register_benchmark,
+)
+from strands_robots.simulation.policy_runner import PolicyRunner
+
+# Fixtures
+
+
+@pytest.fixture(autouse=True)
+def _clean_registry():
+    snapshot = dict(_BENCHMARK_REGISTRY)
+    _BENCHMARK_REGISTRY.clear()
+    yield
+    _BENCHMARK_REGISTRY.clear()
+    _BENCHMARK_REGISTRY.update(snapshot)
+
+
+class _FakeRobot:
+    """Simple object with a data_config attr so the base compat check passes."""
+
+    def __init__(self, data_config: str):
+        self.data_config = data_config
+
+
+class _FakeWorld:
+    """Minimal world with a ``robots`` dict - enough for the base compat check."""
+
+    def __init__(self, robots: dict[str, _FakeRobot]):
+        self.robots = dict(robots)
+
+
+class FakeSim(SimEngine):
+    """Minimal ``SimEngine`` that records a few per-episode counters.
+
+    Deliberately stripped-down: no cameras, no physics - just enough for
+    :class:`PolicyRunner` to step through.
+    """
+
+    def __init__(self, joint_names: tuple[str, ...] = ("j0", "j1"), data_config: str = "so100"):
+        self._joint_names = list(joint_names)
+        self._data_config = data_config
+        self._step_count = 0
+        self._reset_count = 0
+        self._world = _FakeWorld({"fake_robot": _FakeRobot(data_config)})
+
+    def create_world(self, timestep=None, gravity=None, ground_plane=True):
+        return {"status": "success"}
+
+    def destroy(self):
+        return {"status": "success"}
+
+    def reset(self):
+        self._step_count = 0
+        self._reset_count += 1
+        return {"status": "success"}
+
+    def step(self, n_steps: int = 1):
+        self._step_count += n_steps
+        return {"status": "success"}
+
+    def get_state(self):
+        return {"step_count": self._step_count}
+
+    def add_robot(self, name, **kw):
+        data_config = kw.get("data_config") or self._data_config
+        self._world.robots[name] = _FakeRobot(data_config)
+        return {"status": "success"}
+
+    def remove_robot(self, name):
+        return {"status": "success"}
+
+    def list_robots(self) -> list[str]:
+        return list(self._world.robots.keys())
+
+    def robot_joint_names(self, robot_name: str) -> list[str]:
+        return list(self._joint_names)
+
+    def add_object(self, name, **kw):
+        return {"status": "success"}
+
+    def remove_object(self, name):
+        return {"status": "success"}
+
+    def get_observation(self, robot_name=None, *, skip_images=False):
+        return {n: 0.0 for n in self._joint_names}
+
+    def send_action(self, action, robot_name=None, n_substeps=1):
+        self._step_count += 1
+
+    def render(self, camera_name="default", width=None, height=None):
+        return {"status": "success", "content": [{"text": "render"}]}
+
+
+# Test benchmarks
+
+
+class _CountingBenchmark(BenchmarkProtocol):
+    """Benchmark that tracks how many times each hook was called and rewards +1/step."""
+
+    max_steps = 20
+
+    def __init__(self, *, success_after: int = 10**9, fail_after: int = 10**9):
+        self.success_after = success_after
+        self.fail_after = fail_after
+        self.on_episode_start_calls = 0
+        self.on_step_calls = 0
+        self.rng_seeds_seen: list[int] = []
+
+    @property
+    def supported_robots(self) -> list[str]:
+        return ["so100"]
+
+    @property
+    def default_robot(self) -> str:
+        return "so100"
+
+    def on_episode_start(self, sim, rng):
+        self.on_episode_start_calls += 1
+        # Record the first draw for reproducibility tests.
+        self.rng_seeds_seen.append(rng.randint(0, 1000000))
+        super().on_episode_start(sim, rng)
+
+    def on_step(self, sim, obs, action):
+        self.on_step_calls += 1
+        return StepInfo(reward=1.0)
+
+    def is_success(self, sim):
+        return self.on_step_calls >= self.success_after
+
+    def is_failure(self, sim):
+        return self.on_step_calls >= self.fail_after
+
+
+class _DoneAfterBenchmark(BenchmarkProtocol):
+    """Returns StepInfo(done=True) on the Nth step."""
+
+    max_steps = 50
+
+    def __init__(self, done_after: int):
+        self._done_after = done_after
+        self._step = 0
+
+    @property
+    def supported_robots(self) -> list[str]:
+        return []
+
+    @property
+    def default_robot(self) -> str:
+        return "so100"
+
+    def on_step(self, sim, obs, action):
+        self._step += 1
+        return StepInfo(reward=0.5, done=self._step >= self._done_after)
+
+    def is_success(self, sim):
+        return False
+
+    def is_failure(self, sim):
+        return False
+
+
+# Cumulative reward
+
+
+class TestCumulativeReward:
+    def test_sums_reward_across_steps(self):
+        sim = FakeSim()
+        policy = MockPolicy()
+        policy.set_robot_state_keys(sim.robot_joint_names("fake_robot"))
+        spec = _CountingBenchmark()
+
+        result = PolicyRunner(sim).evaluate(
+            "fake_robot",
+            policy,
+            spec=spec,
+            n_episodes=1,
+        )
+        assert result["status"] == "success"
+        payload = next(c["json"] for c in result["content"] if "json" in c)
+        # _CountingBenchmark rewards +1/step; max_steps=20 → 20.0 cumulative.
+        assert payload["episodes"][0]["cumulative_reward"] == pytest.approx(20.0)
+        assert payload["avg_reward"] == pytest.approx(20.0)
+
+    def test_success_terminates_and_stops_reward_accumulation(self):
+        sim = FakeSim()
+        policy = MockPolicy()
+        policy.set_robot_state_keys(sim.robot_joint_names("fake_robot"))
+        spec = _CountingBenchmark(success_after=5)
+
+        result = PolicyRunner(sim).evaluate("fake_robot", policy, spec=spec, n_episodes=1)
+        payload = next(c["json"] for c in result["content"] if "json" in c)
+        assert payload["n_success"] == 1
+        # Reward per step = 1; terminates on the 5th step.
+        assert payload["episodes"][0]["cumulative_reward"] == pytest.approx(5.0)
+        assert payload["episodes"][0]["steps"] == 5
+
+    def test_failure_marks_episode_unsuccessful(self):
+        sim = FakeSim()
+        policy = MockPolicy()
+        policy.set_robot_state_keys(sim.robot_joint_names("fake_robot"))
+        spec = _CountingBenchmark(fail_after=3)
+
+        result = PolicyRunner(sim).evaluate("fake_robot", policy, spec=spec, n_episodes=1)
+        payload = next(c["json"] for c in result["content"] if "json" in c)
+        assert payload["n_failure"] == 1
+        assert payload["n_success"] == 0
+        assert payload["episodes"][0]["failure"] is True
+        assert payload["episodes"][0]["success"] is False
+
+    def test_done_flag_terminates_episode(self):
+        """StepInfo.done=True ends the episode even without is_success/is_failure."""
+        sim = FakeSim()
+        policy = MockPolicy()
+        policy.set_robot_state_keys(sim.robot_joint_names("fake_robot"))
+        spec = _DoneAfterBenchmark(done_after=4)
+
+        result = PolicyRunner(sim).evaluate("fake_robot", policy, spec=spec, n_episodes=1)
+        payload = next(c["json"] for c in result["content"] if "json" in c)
+        assert payload["episodes"][0]["steps"] == 4
+
+
+# Seed reproducibility
+
+
+class TestSeedReproducibility:
+    def test_same_seed_same_rng_draws(self):
+        """Two evaluations with the same seed must produce identical per-episode RNG draws."""
+        sim1 = FakeSim()
+        sim2 = FakeSim()
+        policy = MockPolicy()
+        policy.set_robot_state_keys(sim1.robot_joint_names("fake_robot"))
+
+        spec1 = _CountingBenchmark()
+        spec2 = _CountingBenchmark()
+
+        PolicyRunner(sim1).evaluate("fake_robot", policy, spec=spec1, n_episodes=3, seed=42)
+        PolicyRunner(sim2).evaluate("fake_robot", policy, spec=spec2, n_episodes=3, seed=42)
+
+        assert spec1.rng_seeds_seen == spec2.rng_seeds_seen
+
+    def test_different_seed_different_rng_draws(self):
+        sim1 = FakeSim()
+        sim2 = FakeSim()
+        policy = MockPolicy()
+        policy.set_robot_state_keys(sim1.robot_joint_names("fake_robot"))
+
+        spec1 = _CountingBenchmark()
+        spec2 = _CountingBenchmark()
+
+        PolicyRunner(sim1).evaluate("fake_robot", policy, spec=spec1, n_episodes=3, seed=42)
+        PolicyRunner(sim2).evaluate("fake_robot", policy, spec=spec2, n_episodes=3, seed=999)
+
+        assert spec1.rng_seeds_seen != spec2.rng_seeds_seen
+
+    def test_seed_recorded_in_per_episode_results(self):
+        sim = FakeSim()
+        policy = MockPolicy()
+        policy.set_robot_state_keys(sim.robot_joint_names("fake_robot"))
+
+        result = PolicyRunner(sim).evaluate("fake_robot", policy, spec=_CountingBenchmark(), n_episodes=3, seed=7)
+        payload = next(c["json"] for c in result["content"] if "json" in c)
+        # Every episode has a distinct seed derived from the master seed.
+        seeds = [e["seed"] for e in payload["episodes"]]
+        assert len(set(seeds)) == 3
+        assert payload["seed"] == 7
+
+
+# Robot compatibility via the spec path
+
+
+class TestRobotCompatibility:
+    def test_mismatched_robot_returns_structured_error(self):
+        """Spec with supported=['panda'] vs sim loaded with 'so100' → structured error."""
+
+        class _PandaOnly(_CountingBenchmark):
+            @property
+            def supported_robots(self) -> list[str]:
+                return ["panda"]
+
+            @property
+            def default_robot(self) -> str:
+                return "panda"
+
+        sim = FakeSim(data_config="so100")  # loaded with so100
+        policy = MockPolicy()
+        policy.set_robot_state_keys(sim.robot_joint_names("fake_robot"))
+
+        result = PolicyRunner(sim).evaluate("fake_robot", policy, spec=_PandaOnly(), n_episodes=1)
+        assert result["status"] == "error"
+        text = result["content"][0]["text"]
+        assert "compatibility" in text.lower() or "supported" in text.lower()
+        assert "so100" in text  # shows the offending data_config
+        assert "panda" in text  # shows the allowed list
+
+
+# Legacy success_fn path still works
+
+
+class TestBackwardCompatibility:
+    def test_legacy_success_fn_callable_still_works(self):
+        """The pre-PR success_fn=callable path must be unchanged."""
+        sim = FakeSim()
+        policy = MockPolicy()
+        policy.set_robot_state_keys(sim.robot_joint_names("fake_robot"))
+
+        result = PolicyRunner(sim).evaluate(
+            "fake_robot",
+            policy,
+            n_episodes=2,
+            max_steps=5,
+            success_fn=lambda _obs: True,
+        )
+        payload = next(c["json"] for c in result["content"] if "json" in c)
+        assert payload["success_rate"] == 1.0
+        # Legacy path doesn't emit cumulative_reward; just the pre-PR schema.
+        assert "cumulative_reward" not in payload
+
+    def test_legacy_success_fn_none_returns_zero_success(self):
+        sim = FakeSim()
+        policy = MockPolicy()
+        policy.set_robot_state_keys(sim.robot_joint_names("fake_robot"))
+
+        result = PolicyRunner(sim).evaluate("fake_robot", policy, n_episodes=1, max_steps=3, success_fn=None)
+        payload = next(c["json"] for c in result["content"] if "json" in c)
+        assert payload["n_success"] == 0
+
+    def test_cannot_pass_both_spec_and_success_fn(self):
+        sim = FakeSim()
+        policy = MockPolicy()
+        policy.set_robot_state_keys(sim.robot_joint_names("fake_robot"))
+
+        result = PolicyRunner(sim).evaluate(
+            "fake_robot",
+            policy,
+            spec=_CountingBenchmark(),
+            success_fn=lambda _obs: True,
+        )
+        assert result["status"] == "error"
+        assert "both" in result["content"][0]["text"].lower()
+
+
+# SimEngine facade
+
+
+class TestSimEngineFacades:
+    def test_evaluate_benchmark_dispatches_to_runner(self):
+        sim = FakeSim()
+        spec = _CountingBenchmark()
+        register_benchmark("eval-test", spec)
+
+        result = sim.evaluate_benchmark(
+            benchmark_name="eval-test",
+            robot_name="fake_robot",
+            policy_provider="mock",
+            n_episodes=1,
+            seed=3,
+        )
+        assert result["status"] == "success"
+        payload = next(c["json"] for c in result["content"] if "json" in c)
+        assert payload["benchmark_class"] == "_CountingBenchmark"
+        # Full per-step traversal (nothing terminates early).
+        assert payload["episodes"][0]["steps"] == 20
+
+    def test_evaluate_benchmark_unknown_name(self):
+        sim = FakeSim()
+        result = sim.evaluate_benchmark(benchmark_name="never-registered")
+        assert result["status"] == "error"
+        assert "no benchmark registered" in result["content"][0]["text"].lower()
+
+    def test_evaluate_benchmark_auto_picks_sole_robot(self):
+        """Single-robot scene: robot_name can be omitted."""
+        sim = FakeSim()
+        register_benchmark("auto-robot", _CountingBenchmark())
+        result = sim.evaluate_benchmark(benchmark_name="auto-robot", n_episodes=1)
+        assert result["status"] == "success"
+
+    def test_evaluate_benchmark_requires_robot_name_in_multi_robot(self):
+        sim = FakeSim()
+        sim.add_robot("second", data_config="so100")
+        register_benchmark("multi", _CountingBenchmark())
+        result = sim.evaluate_benchmark(benchmark_name="multi")
+        assert result["status"] == "error"
+        assert "robot_name" in result["content"][0]["text"]
+
+    def test_list_benchmarks_returns_snapshot(self):
+        sim = FakeSim()
+        register_benchmark("a", _CountingBenchmark())
+        register_benchmark("b", _CountingBenchmark())
+        result = sim.list_benchmarks()
+        assert result["status"] == "success"
+        payload = next(c["json"] for c in result["content"] if "json" in c)
+        assert set(payload["benchmarks"].keys()) == {"a", "b"}
+
+    def test_list_benchmarks_empty(self):
+        sim = FakeSim()
+        result = sim.list_benchmarks()
+        assert result["status"] == "success"
+        assert "No benchmarks" in result["content"][0]["text"]
+
+    def test_register_benchmark_from_file_success(self, tmp_path: Path):
+        sim = FakeSim()
+        spec_path = tmp_path / "s.json"
+        spec_path.write_text(
+            json.dumps(
+                {
+                    "name": "file-bench",
+                    "default_robot": "so100",
+                    "supported_robots": ["so100"],
+                    "max_steps": 7,
+                }
+            )
+        )
+        result = sim.register_benchmark_from_file(benchmark_name="file-bench", spec_path=str(spec_path))
+        assert result["status"] == "success"
+        assert "Registered benchmark" in result["content"][0]["text"]
+
+    def test_register_benchmark_from_file_missing_file(self, tmp_path: Path):
+        sim = FakeSim()
+        result = sim.register_benchmark_from_file(benchmark_name="missing", spec_path=str(tmp_path / "nope.json"))
+        assert result["status"] == "error"
+        assert "not found" in result["content"][0]["text"].lower()
+
+    def test_register_benchmark_from_file_empty_name(self):
+        sim = FakeSim()
+        result = sim.register_benchmark_from_file(benchmark_name="", spec_path="/tmp/x.json")
+        assert result["status"] == "error"
+        assert "benchmark_name" in result["content"][0]["text"]
+
+    def test_register_benchmark_from_file_bad_schema(self, tmp_path: Path):
+        sim = FakeSim()
+        spec_path = tmp_path / "bad.json"
+        spec_path.write_text('{"name": "x"}')  # missing default_robot
+        result = sim.register_benchmark_from_file(benchmark_name="bad", spec_path=str(spec_path))
+        assert result["status"] == "error"
+        assert "default_robot" in result["content"][0]["text"]