diff --git a/src/prime_rl/orchestrator/filters.py b/src/prime_rl/orchestrator/filters.py index f8deda1230..37c00032fb 100644 --- a/src/prime_rl/orchestrator/filters.py +++ b/src/prime_rl/orchestrator/filters.py @@ -9,6 +9,7 @@ from __future__ import annotations import math +from collections.abc import Iterator from dataclasses import dataclass from typing import TYPE_CHECKING, Protocol @@ -19,6 +20,40 @@ from prime_rl.orchestrator.types import TrainRollout +@dataclass(frozen=True) +class GeneratedTokenLogprob: + token_id: int + logprob: float + + +def _iter_generated_token_logprobs(rollout: "TrainRollout") -> Iterator[GeneratedTokenLogprob]: + if rollout.samples: + for sample in rollout.samples: + for token_id, logprob, is_generated in zip( + sample.completion_ids, + sample.completion_logprobs, + sample.completion_mask, + ): + if is_generated: + yield GeneratedTokenLogprob(token_id=token_id, logprob=logprob) + return + + for step in rollout.raw.get("trajectory") or []: + tokens = step.get("tokens") if isinstance(step, dict) else None + if not tokens or "completion_ids" not in tokens or "completion_logprobs" not in tokens: + continue + completion_mask = tokens.get("completion_mask") + if completion_mask is None: + completion_mask = [True] * len(tokens["completion_ids"]) + for token_id, logprob, is_generated in zip( + tokens["completion_ids"], + tokens["completion_logprobs"], + completion_mask, + ): + if is_generated: + yield GeneratedTokenLogprob(token_id=token_id, logprob=logprob) + + @dataclass class FilterResult: detected: bool @@ -51,14 +86,10 @@ class GibberishFilter: def check(self, rollout: "TrainRollout") -> FilterResult: global_idx = 0 - for step in rollout.raw["trajectory"]: - tokens = step["tokens"] - if tokens is None: - continue - for token_id, logprob in zip(tokens["completion_ids"], tokens["completion_logprobs"]): - if token_id > self.token_id_threshold and logprob < self.logprob_threshold: - return FilterResult(detected=True, detection_index=global_idx) - global_idx += 1 + for token in _iter_generated_token_logprobs(rollout): + if token.token_id > self.token_id_threshold and token.logprob < self.logprob_threshold: + return FilterResult(detected=True, detection_index=global_idx) + global_idx += 1 return FilterResult(detected=False) @@ -82,18 +113,14 @@ class RepetitionFilter: def check(self, rollout: "TrainRollout") -> FilterResult: consecutive = 0 global_idx = 0 - for step in rollout.raw["trajectory"]: - tokens = step["tokens"] - if tokens is None: - continue - for logprob in tokens["completion_logprobs"]: - if logprob > self.logprob_threshold: - consecutive += 1 - else: - consecutive = 0 - if consecutive >= self.window: - return FilterResult(detected=True, detection_index=global_idx) - global_idx += 1 + for token in _iter_generated_token_logprobs(rollout): + if token.logprob > self.logprob_threshold: + consecutive += 1 + else: + consecutive = 0 + if consecutive >= self.window: + return FilterResult(detected=True, detection_index=global_idx) + global_idx += 1 return FilterResult(detected=False) diff --git a/src/prime_rl/orchestrator/metrics.py b/src/prime_rl/orchestrator/metrics.py index d32dbc02de..9257a84197 100644 --- a/src/prime_rl/orchestrator/metrics.py +++ b/src/prime_rl/orchestrator/metrics.py @@ -61,7 +61,7 @@ def build( "prefill_len": metrics.rollout_prefill_lens, "decode_len": metrics.rollout_decode_lens, "samples_per_rollout": metrics.samples_per_rollout, - "num_turns": [len(r.raw["trajectory"]) for r in rollouts], + "num_turns": [r.turn_count for r in rollouts], } ) metrics_df = pd.DataFrame([(r.raw.get("metrics") or {}) for r in rollouts]) diff --git a/src/prime_rl/orchestrator/orchestrator.py b/src/prime_rl/orchestrator/orchestrator.py index 54be327baf..d330cc01ac 100644 --- a/src/prime_rl/orchestrator/orchestrator.py +++ b/src/prime_rl/orchestrator/orchestrator.py @@ -771,7 +771,7 @@ def log_train_batch(self, batch: TrainBatch, *, step: int, step_time: float) -> trainable_rate = (n_trainable / n_survivors) if n_survivors else 0.0 reward_mean = sum(r.reward for r in batch.rollouts) / max(n_survivors, 1) max_off_policy = max((r.off_policy_steps for r in batch.rollouts), default=0) - turns_mean = sum(len(r.raw.get("trajectory") or []) for r in batch.rollouts) / max(n_survivors, 1) + turns_mean = sum(r.turn_count for r in batch.rollouts) / max(n_survivors, 1) truncation_rate = sum(1 for r in batch.rollouts if r.is_truncated) / max(n_survivors, 1) head = ( @@ -795,11 +795,7 @@ def log_train_batch(self, batch: TrainBatch, *, step: int, step_time: float) -> env_error_rate = (n_env_errors / n_env_arrivals) if n_env_arrivals else 0.0 env_reward = (sum(r.reward for r in env_rollouts) / len(env_rollouts)) if env_rollouts else 0.0 env_max_off_policy = max((r.off_policy_steps for r in env_rollouts), default=0) - env_turns = ( - sum(len(r.raw.get("trajectory") or []) for r in env_rollouts) / len(env_rollouts) - if env_rollouts - else 0.0 - ) + env_turns = sum(r.turn_count for r in env_rollouts) / len(env_rollouts) if env_rollouts else 0.0 env_truncation = sum(1 for r in env_rollouts if r.is_truncated) / len(env_rollouts) if env_rollouts else 0.0 lines.append( f"╰─ {env_name:<{name_width}} | Ratio {ratio:.1%} | Reward {env_reward:.4f} | " diff --git a/src/prime_rl/orchestrator/train_sink.py b/src/prime_rl/orchestrator/train_sink.py index f79a0d5eff..9c18ff5815 100644 --- a/src/prime_rl/orchestrator/train_sink.py +++ b/src/prime_rl/orchestrator/train_sink.py @@ -153,6 +153,7 @@ async def process_rollout(self, rollout: TrainRollout) -> None: if rollout.error is not None: return raw = rollout.raw + rollout.num_turns = len(raw.get("trajectory") or []) needs_backfill = any(s["tokens"] is None for s in raw.get("trajectory") or []) if needs_backfill: await asyncio.to_thread(backfill_rollout_tokens, raw, self.tokenizer, renderer=self.renderer) @@ -161,6 +162,7 @@ async def process_rollout(self, rollout: TrainRollout) -> None: raw, mm_token_type_ids_mapping=self.mm_token_type_ids_mapping, env_name=rollout.env_name, + prune_raw_payload=True, ) rollout.samples = samples or [] # Offload base64 image bytes to disk as soon as the rollout is diff --git a/src/prime_rl/orchestrator/trajectories.py b/src/prime_rl/orchestrator/trajectories.py index 3e8431c12a..90422b2bd2 100644 --- a/src/prime_rl/orchestrator/trajectories.py +++ b/src/prime_rl/orchestrator/trajectories.py @@ -23,6 +23,51 @@ # primitives are immutable. mm_kwargs payloads are not mutated after creation. +def _safe_len(value: Any) -> int | None: + if value is None: + return None + try: + return len(value) + except TypeError: + return None + + +def prune_token_payload(tokens: dict[str, Any]) -> dict[str, Any]: + compact: dict[str, Any] = {"has_routed_experts": tokens.get("routed_experts") is not None} + for key in ("prompt_ids", "prompt_mask", "completion_ids", "completion_mask", "completion_logprobs"): + length = _safe_len(tokens.get(key)) + if length is not None: + compact[f"{key}_len"] = length + + routed = tokens.get("routed_experts") + if isinstance(routed, dict): + for key in ("shape", "dtype", "start"): + value = routed.get(key) + if value is not None: + compact[f"routed_experts_{key}"] = value + return compact + + +def prune_train_rollout_payload(raw: dict[str, Any]) -> None: + """Replace raw trajectory token arrays with length and shape summaries.""" + if raw.get("trajectory_payload_pruned"): + return + + trajectory = raw.get("trajectory") + if not isinstance(trajectory, list): + return + + for step in trajectory: + if not isinstance(step, dict): + continue + tokens = step.get("tokens") + if isinstance(tokens, dict): + step["tokens"] = prune_token_payload(tokens) + step.pop("response", None) + + raw["trajectory_payload_pruned"] = True + + def align_routed_experts( routed_experts: np.ndarray | None, expected_len: int, @@ -206,6 +251,7 @@ def interleave_rollout( mm_token_type_ids_mapping: dict[int, int] | None = None, *, env_name: str = "", + prune_raw_payload: bool = False, ) -> list[TrainingSample] | None: """ Convert vf.RolloutOutput to trainable rollouts by interleaving trajectory steps @@ -521,7 +567,10 @@ def extend_sample( for token_id in sample.prompt_ids + sample.completion_ids ] - return [sample for _, sample, _ in active_samples] + samples = [sample for _, sample, _ in active_samples] + if prune_raw_payload: + prune_train_rollout_payload(output) + return samples def _union_step_mm_data( diff --git a/src/prime_rl/orchestrator/types.py b/src/prime_rl/orchestrator/types.py index 54c3f827b9..9231534911 100644 --- a/src/prime_rl/orchestrator/types.py +++ b/src/prime_rl/orchestrator/types.py @@ -79,6 +79,7 @@ class FinishedRollout: policy_version: int off_policy_steps: int rollout_id: uuid.UUID = field(default_factory=uuid.uuid4) + num_turns: int | None = None @property def error(self) -> dict | None: @@ -92,6 +93,12 @@ def reward(self) -> float: def is_truncated(self) -> bool: return bool(self.raw.get("is_truncated", False)) + @property + def turn_count(self) -> int: + if self.num_turns is not None: + return self.num_turns + return len(self.raw.get("trajectory") or []) + def to_dict(self) -> vf.RolloutOutput: """``raw`` + metadata merged for I/O (``save_rollouts``, ``monitor.log_samples``). Shallow copy; never mutates ``self.raw``.""" @@ -103,6 +110,8 @@ def to_dict(self) -> vf.RolloutOutput: if f.name == "filter_results": out["filters"] = dict(val) continue + if f.name == "num_turns" and val is None: + continue out[f.name] = str(val) if isinstance(val, uuid.UUID) else val return out diff --git a/src/prime_rl/orchestrator/utils.py b/src/prime_rl/orchestrator/utils.py index 9c68f361de..087ec5beb1 100644 --- a/src/prime_rl/orchestrator/utils.py +++ b/src/prime_rl/orchestrator/utils.py @@ -59,7 +59,19 @@ async def setup_student_inference_pool(*, config: OrchestratorConfig, tokenizer) def get_model_completion_len(output: vf.RolloutOutput) -> int: """Sum of model-generated completion tokens across all turns (excludes environment-injected tokens between turns).""" - return sum(len(step["tokens"]["completion_ids"]) for step in output["trajectory"] if step.get("tokens")) + total = 0 + for step in output["trajectory"]: + tokens = step.get("tokens") + if not tokens: + continue + completion_ids = tokens.get("completion_ids") + if completion_ids is not None: + total += len(completion_ids) + continue + completion_ids_len = tokens.get("completion_ids_len") + if isinstance(completion_ids_len, int): + total += completion_ids_len + return total def get_tool_response_len(output: vf.RolloutOutput) -> int: diff --git a/src/prime_rl/utils/monitor/prime.py b/src/prime_rl/utils/monitor/prime.py index cd1fa8375a..752b46ddd2 100644 --- a/src/prime_rl/utils/monitor/prime.py +++ b/src/prime_rl/utils/monitor/prime.py @@ -354,8 +354,16 @@ def _rollouts_to_parquet_bytes(self, rollouts: list[vf.RolloutOutput], step: int "reward": ts.get("reward"), "advantage": ts.get("advantage"), "extras": ts.get("extras", {}), - "num_input_tokens": len(ts["tokens"]["prompt_ids"]) if ts.get("tokens") else None, - "num_output_tokens": len(ts["tokens"]["completion_ids"]) if ts.get("tokens") else None, + "num_input_tokens": ( + len(ts["tokens"]["prompt_ids"]) + if ts.get("tokens") and "prompt_ids" in ts["tokens"] + else (ts["tokens"].get("prompt_ids_len") if ts.get("tokens") else None) + ), + "num_output_tokens": ( + len(ts["tokens"]["completion_ids"]) + if ts.get("tokens") and "completion_ids" in ts["tokens"] + else (ts["tokens"].get("completion_ids_len") if ts.get("tokens") else None) + ), } for ts in trajectory ] diff --git a/src/prime_rl/utils/monitor/wandb.py b/src/prime_rl/utils/monitor/wandb.py index 7935d13a6a..c0066ee6c8 100644 --- a/src/prime_rl/utils/monitor/wandb.py +++ b/src/prime_rl/utils/monitor/wandb.py @@ -176,16 +176,28 @@ def log_samples(self, rollouts: list[vf.RolloutOutput], step: int) -> None: if not trajectory: continue last_step = trajectory[-1] - tokens = last_step["tokens"] - full_ids = tokens["prompt_ids"] + tokens["completion_ids"] - messages_text = self.tokenizer.decode(full_ids) + tokens = last_step.get("tokens") or {} + prompt_ids = tokens.get("prompt_ids") + completion_ids = tokens.get("completion_ids") + if prompt_ids is not None and completion_ids is not None: + full_ids = prompt_ids + completion_ids + messages_text = self.tokenizer.decode(full_ids) + input_ids = str(full_ids) + else: + messages_text = json.dumps( + { + "prompt": rollout.get("prompt"), + "completion": rollout.get("completion"), + } + ) + input_ids = "" sample = { "step": step, "env_name": rollout.get("env_name"), "task": rollout.get("task"), "example_id": rollout["example_id"], "messages": messages_text, - "input_ids": str(full_ids), + "input_ids": input_ids, "reward": rollout["reward"], } assert list(sample.keys()) == self.samples_cols, ( diff --git a/tests/unit/orchestrator/test_advantage.py b/tests/unit/orchestrator/test_advantage.py index 0f5fd9bf92..7666173f13 100644 --- a/tests/unit/orchestrator/test_advantage.py +++ b/tests/unit/orchestrator/test_advantage.py @@ -17,6 +17,7 @@ ) from prime_rl.orchestrator.envs import Env, TrainEnv from prime_rl.orchestrator.types import TrainRollout +from prime_rl.orchestrator.utils import get_model_completion_len def _make_rollout(reward: float, completion_len: int = 0) -> dict: @@ -183,6 +184,30 @@ def test_length_weighted_baseline(): assert sum(length * adv for length, adv in zip((10, 30, 60), result.advantages)) == pytest.approx(0.0, abs=1e-5) +def test_length_penalty_uses_compacted_completion_lengths(): + """``get_model_completion_len`` reads compacted ``completion_ids_len`` summaries + (raw token arrays pruned post-interleave) so length shaping still sees real lengths.""" + rollout = _make_rollout(reward=1.0, completion_len=0) + rollout["trajectory"] = [ + {"tokens": {"prompt_ids_len": 10, "completion_ids_len": 12}}, + {"tokens": {"prompt_ids_len": 20, "completion_ids_len": 34}}, + ] + + assert get_model_completion_len(rollout) == 46 + + inputs = AdvantageInputs( + rollouts=[ + rollout, + _make_rollout(reward=1.0, completion_len=92), + _make_rollout(reward=0.0, completion_len=46), + ] + ) + result = default_advantage_fn(inputs, length_penalty=LinearLengthPenaltyConfig(coef=1.0), max_seq_len=100) + + # Both reward-1 rollouts, but the compacted one is shorter (46 < 92) → higher advantage. + assert result.advantages[0] > result.advantages[1] + + def _train_rollouts(rewards: list[float]) -> list[TrainRollout]: """Wrap a list of rewards into ``TrainRollout``\\ s sharing a single ``group_id`` — ``assign_advantages`` works on one group at a time diff --git a/tests/unit/orchestrator/test_filters.py b/tests/unit/orchestrator/test_filters.py index 2643bf71bb..ce1ad0a8f4 100644 --- a/tests/unit/orchestrator/test_filters.py +++ b/tests/unit/orchestrator/test_filters.py @@ -10,6 +10,7 @@ setup_filters, ) from prime_rl.orchestrator.types import TrainRollout +from prime_rl.transport import TrainingSample def _make_rollout( @@ -134,6 +135,28 @@ def test_gibberish_works_across_trajectory_steps(): assert result.detection_index == 2 +def test_gibberish_uses_samples_when_raw_tokens_are_pruned(): + gibberish_filter = _make_gibberish_filter() + rollout = _make_rollout(completion_ids=[], completion_logprobs=[]) + rollout.raw["trajectory"][0]["tokens"] = {"prompt_ids_len": 4, "completion_ids_len": 2} + rollout.samples = [ + TrainingSample( + prompt_ids=[1, 2], + prompt_mask=[False, False], + completion_ids=[777, 120_000], + completion_mask=[False, True], + completion_logprobs=[0.0, gibberish_filter.logprob_threshold - 1.0], + completion_temperatures=[1.0, 1.0], + env_name="test", + ) + ] + + result = gibberish_filter.check(rollout) + + assert result.detected is True + assert result.detection_index == 0 + + # --- RepetitionFilter tests --- @@ -187,6 +210,28 @@ def test_repetition_varied_probs_no_trigger(): assert result.detected is False +def test_repetition_uses_samples_when_raw_tokens_are_pruned(): + repetition_filter = _make_repetition_filter(window=3) + rollout = _make_rollout(completion_ids=[], completion_logprobs=[]) + rollout.raw["trajectory"][0]["tokens"] = {"prompt_ids_len": 4, "completion_ids_len": 4} + rollout.samples = [ + TrainingSample( + prompt_ids=[1, 2], + prompt_mask=[False, False], + completion_ids=[10, 11, 12, 13], + completion_mask=[False, True, True, True], + completion_logprobs=[0.0, -0.001, -0.001, -0.001], + completion_temperatures=[1.0, 1.0, 1.0, 1.0], + env_name="test", + ) + ] + + result = repetition_filter.check(rollout) + + assert result.detected is True + assert result.detection_index == 2 + + # --- setup_filter / setup_filters tests --- diff --git a/tests/unit/orchestrator/test_trajectories.py b/tests/unit/orchestrator/test_trajectories.py index bda129cd43..16bb028d4d 100644 --- a/tests/unit/orchestrator/test_trajectories.py +++ b/tests/unit/orchestrator/test_trajectories.py @@ -9,6 +9,8 @@ _deserialize_tool_calls, align_routed_experts, interleave_rollout, + prune_token_payload, + prune_train_rollout_payload, ) _interleave_rollout = interleave_rollout @@ -55,6 +57,133 @@ def test_deserialize_tool_calls_does_not_inject_missing_key(): assert "tool_calls" not in deserialized[0] +def test_prune_train_rollout_payload_compacts_heavy_token_arrays(): + raw = { + "trajectory": [ + { + "prompt": [{"role": "user", "content": "hello"}], + "completion": [{"role": "assistant", "content": "world"}], + "response": object(), + "tokens": { + "prompt_ids": [1, 2, 3], + "prompt_mask": [0, 0, 0], + "completion_ids": [4, 5], + "completion_mask": [1, 1], + "completion_logprobs": [-0.1, -0.2], + "routed_experts": { + "data": "large-base64-payload", + "shape": [5, 78, 8], + "dtype": "uint8", + "start": 0, + }, + "overlong_prompt": False, + "is_truncated": False, + "debug_scalar": "drop-me", + }, + } + ] + } + + prune_train_rollout_payload(raw) + + step = raw["trajectory"][0] + assert "response" not in step + assert raw["trajectory_payload_pruned"] is True + assert step["tokens"] == { + "prompt_ids_len": 3, + "prompt_mask_len": 3, + "completion_ids_len": 2, + "completion_mask_len": 2, + "completion_logprobs_len": 2, + "has_routed_experts": True, + "routed_experts_shape": [5, 78, 8], + "routed_experts_dtype": "uint8", + "routed_experts_start": 0, + } + + +def test_prune_token_payload_returns_allowlisted_compact_fields_only(): + compact = prune_token_payload( + { + "prompt_ids": [1, 2, 3], + "prompt_mask": [0, 0, 0], + "completion_ids": [4], + "completion_mask": [1], + "completion_logprobs": [-0.1], + "routed_experts": None, + "multi_modal_data": object(), + } + ) + + assert compact == { + "has_routed_experts": False, + "prompt_ids_len": 3, + "prompt_mask_len": 3, + "completion_ids_len": 1, + "completion_mask_len": 1, + "completion_logprobs_len": 1, + } + + +def test_interleave_rollout_prunes_raw_payload_after_preparation(): + raw = { + "example_id": 0, + "error": None, + "trajectory": [ + { + "tokens": { + "prompt_ids": [1, 2], + "prompt_mask": [0, 0], + "completion_ids": [3], + "completion_mask": [1], + "completion_logprobs": [-0.1], + "routed_experts": _routed_experts_payload([[[7]], [[8]]]), + }, + "response": object(), + } + ], + } + + samples = interleave_rollout(raw, prune_raw_payload=True) + + assert samples is not None + assert samples[0].routed_experts is not None + assert raw["trajectory_payload_pruned"] is True + step = raw["trajectory"][0] + assert "response" not in step + assert step["tokens"]["has_routed_experts"] is True + assert step["tokens"]["routed_experts_shape"] == [2, 1, 1] + assert "routed_experts" not in step["tokens"] + + +def test_interleave_rollout_does_not_partially_prune_on_prepare_failure(): + raw = { + "example_id": 0, + "error": None, + "trajectory": [ + { + "tokens": { + "prompt_ids": [1, 2], + "prompt_mask": [0, 0], + "completion_ids": [3], + "completion_mask": [1], + "completion_logprobs": [-0.1], + "routed_experts": _routed_experts_payload([[[7]], [[8]]]), + }, + "response": object(), + }, + {"tokens": None}, + ], + } + + samples = interleave_rollout(raw, prune_raw_payload=True) + + assert samples is None + assert "trajectory_payload_pruned" not in raw + assert "response" in raw["trajectory"][0] + assert "routed_experts" in raw["trajectory"][0]["tokens"] + + def test_deserialize_tool_calls_parses_arguments_when_present(): messages = [ {