Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/training.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ A condensed view of the knobs you'll most often tune. For trainer-side paralleli
| `orchestrator.batch_size` | Tasks per trainer step. |
| `orchestrator.group_size` | Rollouts generated per task. |
| `orchestrator.max_off_policy_steps` | How many distinct policies may have contributed to one rollout before it's discarded (default 8). The main off-policy dial on long agentic rollouts — bump for throughput, lower for tighter on-policyness. Watch `errored_rollouts` and `mismatch_kl/all/mean` when tuning. |
| `orchestrator.max_consecutive_errored_rollouts` | Halt the run after this many consecutive terminal errored rollouts, counted only once the verifiers retry budget is exhausted (default 10; set to `None` to disable). |
| `orchestrator.training_mode` | `rl` (default), `opd`, or `sft`. See [Training modes](#training-modes-rl--opd--sft). |
| `[[orchestrator.train.env]]` | Training environments. List multiple tables for multi-env training; weight them via `ratio`. See [Configuration § Environments](configuration.md#environments-orchestratortrainenv). |
| `[[orchestrator.eval.env]]` + `orchestrator.eval.interval` | Eval environments and cadence (default every 100 steps). |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,9 @@ def _preserve_mito_renderer(self, handler: SerializerFunctionWrapHandler) -> dic
max_off_policy_steps: int = Field(8, ge=0)
"""Maximum policies allowed to generate a single rollout. Rollouts generated more than ``max_off_policy_steps`` ahead of training are discarded. Higher values yield better throughput at the cost of off-policy noise."""

max_consecutive_errored_rollouts: int | None = Field(10, ge=1)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think 10 is a bit rough here?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

scared of cancelling runs on intermittent errors

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what do you suggest?

"""Halt after this many consecutive terminal errored rollouts, counted after the verifiers retry budget is exhausted. None disables the guard."""

bench: bool = False
"""Benchmark mode. Sets ``max_steps`` to 5 and disables W&B."""

Expand Down
37 changes: 37 additions & 0 deletions src/prime_rl/orchestrator/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ class Orchestrator:
draining: bool
last_batch_at: float | None
consecutive_empty_batches: int
consecutive_errored_rollouts: int
eval_triggered_at: dict[tuple[str, int], float]
ckpt_manager: CheckpointManager | None
component_tasks: list[asyncio.Task]
Expand Down Expand Up @@ -180,6 +181,7 @@ def __init__(self, config: OrchestratorConfig) -> None:
# Trigger timestamps so eval success logs can report epoch duration
self.eval_triggered_at = {}
self.consecutive_empty_batches = 0
self.consecutive_errored_rollouts = 0
self.component_tasks = []

# Optional attributes — ``setup()`` populates them when the relevant
Expand Down Expand Up @@ -502,6 +504,7 @@ async def main_loop(self) -> None:
except asyncio.TimeoutError:
continue

self.update_consecutive_errored_rollouts(rollout)
if isinstance(rollout, EvalRollout):
assert self.eval_sink is not None # eval rollouts only emitted when eval is configured
eval_batch = self.eval_sink.add(rollout)
Expand All @@ -516,6 +519,40 @@ async def main_loop(self) -> None:
if train_batch is not None and not self.draining and not self.stopped.is_set():
await self.finalize_train_batch(train_batch)

def update_consecutive_errored_rollouts(self, rollout: FinishedRollout) -> None:
    """Maintain the streak of terminal rollout errors and halt once it reaches the limit.

    Runs on each finished rollout before it is handed to a sink. The
    dispatcher only emits a terminal errored rollout once verifiers has used
    up its retry budget, so every error seen here counts as one strike.
    Synthetic ``Cancelled`` markers produced by off-policy cleanup leave the
    streak untouched, and any rollout without an error resets it to zero.

    Args:
        rollout: The finished train or eval rollout to account for.

    Raises:
        RuntimeError: When the streak reaches
            ``config.max_consecutive_errored_rollouts``.
    """
    max_streak = self.config.max_consecutive_errored_rollouts
    if max_streak is None:
        # Guard disabled: the counter is neither updated nor reset.
        return

    failure = rollout.error
    if failure is None:
        # A clean rollout breaks the streak.
        self.consecutive_errored_rollouts = 0
        return

    failure_name = failure.get("error", "Unknown")
    if failure_name == "Cancelled":
        # Cancellation markers neither extend nor reset the streak.
        return

    streak = self.consecutive_errored_rollouts + 1
    self.consecutive_errored_rollouts = streak

    if isinstance(rollout, EvalRollout):
        source = "eval"
    else:
        source = "train"
    message = (
        f"Terminal {source} rollout error | env={rollout.env_name} example_id={rollout.example_id} "
        f"error={failure_name} "
        f"(consecutive errored rollouts: {streak}/{max_streak})"
    )
    get_logger().warning(message)

    if streak < max_streak:
        return
    raise RuntimeError(
        f"{streak} consecutive terminal errored rollouts - "
        "check environment failures and verifiers retry settings, or set "
        "orchestrator.max_consecutive_errored_rollouts = None to disable this guard."
    )

async def _drain_token_export_metrics(self) -> None:
"""Token exports lag the orchestrator, so once the loop ends the trainer is
still finalizing the last shipped steps and ``build`` no longer runs to pick
Expand Down
139 changes: 139 additions & 0 deletions tests/unit/orchestrator/test_orchestrator_error_guard.py

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove

Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import asyncio
import uuid
from types import SimpleNamespace

import pytest

from prime_rl.orchestrator.orchestrator import Orchestrator
from prime_rl.orchestrator.types import EvalRollout, TrainRollout


def _make_orchestrator(*, limit: int | None) -> Orchestrator:
    """Return a bare ``Orchestrator`` carrying only the state the error guard reads.

    ``object.__new__`` is used so ``Orchestrator.__init__`` never runs; only
    ``config.max_consecutive_errored_rollouts`` and the streak counter are set.
    """
    stub = object.__new__(Orchestrator)
    stub.consecutive_errored_rollouts = 0
    stub.config = SimpleNamespace(max_consecutive_errored_rollouts=limit)
    return stub


def _make_raw(error_type: str | None = None) -> dict:
raw = {
"trajectory": [],
"reward": 0.0,
"metrics": {},
}
if error_type is not None:
raw["error"] = {
"error": error_type,
"error_chain_repr": error_type,
"error_chain_str": error_type,
}
return raw


def _make_train_rollout(error_type: str | None = None) -> TrainRollout:
    """Create a ``TrainRollout`` fixture; ``error_type`` marks it as errored."""
    payload = _make_raw(error_type)
    fresh_group = uuid.uuid4()
    return TrainRollout(
        raw=payload,
        env_name="test-env",
        example_id=123,
        group_id=fresh_group,
        policy_version=0,
        off_policy_steps=0,
    )


def _make_eval_rollout(error_type: str | None = None) -> EvalRollout:
    """Create an ``EvalRollout`` fixture; ``error_type`` marks it as errored."""
    payload = _make_raw(error_type)
    fresh_group = uuid.uuid4()
    return EvalRollout(
        raw=payload,
        env_name="eval-env",
        example_id=456,
        group_id=fresh_group,
        policy_version=0,
        off_policy_steps=0,
        eval_step=0,
    )


def test_consecutive_errored_train_rollouts_raise_at_configured_threshold():
    """With a limit of 2, the first error is counted and the second one raises."""
    guard = _make_orchestrator(limit=2)
    first_failure = _make_train_rollout("TimeoutError")

    guard.update_consecutive_errored_rollouts(first_failure)
    assert guard.consecutive_errored_rollouts == 1

    second_failure = _make_train_rollout("TimeoutError")
    with pytest.raises(RuntimeError, match="2 consecutive terminal errored rollouts"):
        guard.update_consecutive_errored_rollouts(second_failure)


def test_successful_train_rollout_resets_errored_rollout_streak():
    """A clean train rollout after an errored one brings the streak back to zero."""
    guard = _make_orchestrator(limit=10)
    failed = _make_train_rollout("TimeoutError")
    guard.update_consecutive_errored_rollouts(failed)

    succeeded = _make_train_rollout()
    guard.update_consecutive_errored_rollouts(succeeded)

    assert guard.consecutive_errored_rollouts == 0


def test_successful_eval_rollout_resets_errored_rollout_streak():
    """A clean eval rollout also resets a streak started by a train error."""
    guard = _make_orchestrator(limit=10)
    failed = _make_train_rollout("TimeoutError")
    guard.update_consecutive_errored_rollouts(failed)

    succeeded = _make_eval_rollout()
    guard.update_consecutive_errored_rollouts(succeeded)

    assert guard.consecutive_errored_rollouts == 0


def test_cancelled_train_rollout_markers_do_not_increment_errored_rollout_streak():
    """A ``Cancelled`` marker leaves an existing streak exactly where it was."""
    guard = _make_orchestrator(limit=10)
    guard.consecutive_errored_rollouts = 1

    cancelled = _make_train_rollout("Cancelled")
    guard.update_consecutive_errored_rollouts(cancelled)

    assert guard.consecutive_errored_rollouts == 1


def test_none_disables_errored_rollout_guard():
    """With the limit set to ``None``, repeated errors never raise or count."""
    guard = _make_orchestrator(limit=None)

    attempts = 3
    while attempts > 0:
        guard.update_consecutive_errored_rollouts(_make_train_rollout("TimeoutError"))
        attempts -= 1

    assert guard.consecutive_errored_rollouts == 0


def test_main_loop_raises_on_errored_train_rollout_before_sink_add():
    """``main_loop`` must raise on an errored train rollout before the train sink sees it."""

    class TrainSinkShouldNotBeCalled:
        async def add(self, rollout: TrainRollout):
            raise AssertionError("TrainSink.add should not be called after the guard trips")

    async def scenario() -> None:
        # Queue one errored rollout for the loop to pick up.
        out_q = asyncio.Queue()
        await out_q.put(_make_train_rollout("TimeoutError"))

        orchestrator = _make_orchestrator(limit=1)
        orchestrator.train_sink = TrainSinkShouldNotBeCalled()
        orchestrator.dispatcher = SimpleNamespace(out_q=out_q)
        orchestrator.draining = False
        orchestrator.stopped = asyncio.Event()

        with pytest.raises(RuntimeError, match="1 consecutive terminal errored rollouts"):
            await orchestrator.main_loop()

    asyncio.run(scenario())


def test_main_loop_raises_on_errored_eval_rollout_before_sink_add():
    """``main_loop`` must raise on an errored eval rollout before the eval sink sees it."""

    class EvalSinkShouldNotBeCalled:
        def add(self, rollout: EvalRollout):
            raise AssertionError("EvalSink.add should not be called after the guard trips")

    async def scenario() -> None:
        # Queue one errored rollout for the loop to pick up.
        out_q = asyncio.Queue()
        await out_q.put(_make_eval_rollout("TimeoutError"))

        orchestrator = _make_orchestrator(limit=1)
        orchestrator.eval_sink = EvalSinkShouldNotBeCalled()
        orchestrator.dispatcher = SimpleNamespace(out_q=out_q)
        orchestrator.draining = False
        orchestrator.stopped = asyncio.Event()

        with pytest.raises(RuntimeError, match="1 consecutive terminal errored rollouts"):
            await orchestrator.main_loop()

    asyncio.run(scenario())
14 changes: 14 additions & 0 deletions tests/unit/test_configs.py

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove

Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,20 @@ def test_trainer_enable_token_export_cli_flag():
assert cli(TrainerConfig, args=["--enable-token-export"]).enable_token_export


def test_orchestrator_max_consecutive_errored_rollouts_default():
    """An empty config validates with the guard limit defaulting to 10."""
    config = OrchestratorConfig.model_validate({})
    assert config.max_consecutive_errored_rollouts == 10


def test_orchestrator_max_consecutive_errored_rollouts_custom_value():
    """An explicit limit survives validation unchanged."""
    overrides = {"max_consecutive_errored_rollouts": 3}
    validated = OrchestratorConfig.model_validate(overrides)
    assert validated.max_consecutive_errored_rollouts == 3


def test_orchestrator_max_consecutive_errored_rollouts_none_disables():
    """``None`` is accepted as the limit and is preserved by validation."""
    overrides = {"max_consecutive_errored_rollouts": None}
    validated = OrchestratorConfig.model_validate(overrides)
    assert validated.max_consecutive_errored_rollouts is None


def test_single_node_auto_inference_client_dp_rank_count_matches_local_dp():
config = RLConfig.model_validate(
{
Expand Down
Loading