Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 47 additions & 20 deletions src/prime_rl/orchestrator/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from __future__ import annotations

import math
from collections.abc import Iterator
from dataclasses import dataclass
from typing import TYPE_CHECKING, Protocol

Expand All @@ -19,6 +20,40 @@
from prime_rl.orchestrator.types import TrainRollout


@dataclass(frozen=True)
class GeneratedTokenLogprob:
    """Immutable pairing of one completion token with its log-probability."""

    # Vocabulary id of the token.
    token_id: int
    # Log-probability recorded for that token.
    logprob: float


def _iter_generated_token_logprobs(rollout: "TrainRollout") -> Iterator[GeneratedTokenLogprob]:
if rollout.samples:
for sample in rollout.samples:
for token_id, logprob, is_generated in zip(
sample.completion_ids,
sample.completion_logprobs,
sample.completion_mask,
):
if is_generated:
yield GeneratedTokenLogprob(token_id=token_id, logprob=logprob)
return

for step in rollout.raw.get("trajectory") or []:
tokens = step.get("tokens") if isinstance(step, dict) else None
if not tokens or "completion_ids" not in tokens or "completion_logprobs" not in tokens:
continue
completion_mask = tokens.get("completion_mask")
if completion_mask is None:
completion_mask = [True] * len(tokens["completion_ids"])
for token_id, logprob, is_generated in zip(
tokens["completion_ids"],
tokens["completion_logprobs"],
completion_mask,
):
if is_generated:
yield GeneratedTokenLogprob(token_id=token_id, logprob=logprob)


@dataclass
class FilterResult:
detected: bool
Expand Down Expand Up @@ -51,14 +86,10 @@ class GibberishFilter:

def check(self, rollout: "TrainRollout") -> FilterResult:
    """Flag the rollout at the first token that looks like gibberish.

    A token matches when its id is above ``self.token_id_threshold`` and its
    logprob is below ``self.logprob_threshold``. Tokens come from
    ``_iter_generated_token_logprobs``, so ``detection_index`` is the
    position of the match within that stream of generated tokens.

    Returns ``FilterResult(detected=False)`` when no token matches.
    """
    # Single pass over the helper's stream: the earlier per-step loop over
    # ``rollout.raw["trajectory"]`` is gone, so each token is inspected once
    # and the index is not advanced twice.
    for index, token in enumerate(_iter_generated_token_logprobs(rollout)):
        if token.token_id > self.token_id_threshold and token.logprob < self.logprob_threshold:
            return FilterResult(detected=True, detection_index=index)
    return FilterResult(detected=False)


Expand All @@ -82,18 +113,14 @@ class RepetitionFilter:
def check(self, rollout: "TrainRollout") -> FilterResult:
    """Flag the rollout once a run of high-confidence tokens gets too long.

    Counts consecutive tokens whose logprob is above
    ``self.logprob_threshold``; any other token resets the run. As soon as
    the run length reaches ``self.window`` the rollout is reported, with
    ``detection_index`` set to the position of the token that completed the
    run within the stream produced by ``_iter_generated_token_logprobs``.

    Returns ``FilterResult(detected=False)`` when no run is long enough.
    """
    consecutive = 0
    # Single pass over the helper's stream: the earlier per-step loop over
    # ``rollout.raw["trajectory"]`` is gone, so the run counter and index are
    # not advanced twice for the same tokens.
    for index, token in enumerate(_iter_generated_token_logprobs(rollout)):
        consecutive = consecutive + 1 if token.logprob > self.logprob_threshold else 0
        if consecutive >= self.window:
            return FilterResult(detected=True, detection_index=index)
    return FilterResult(detected=False)


Expand Down
2 changes: 1 addition & 1 deletion src/prime_rl/orchestrator/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def build(
"prefill_len": metrics.rollout_prefill_lens,
"decode_len": metrics.rollout_decode_lens,
"samples_per_rollout": metrics.samples_per_rollout,
"num_turns": [len(r.raw["trajectory"]) for r in rollouts],
"num_turns": [r.turn_count for r in rollouts],
}
)
metrics_df = pd.DataFrame([(r.raw.get("metrics") or {}) for r in rollouts])
Expand Down
8 changes: 2 additions & 6 deletions src/prime_rl/orchestrator/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,7 +771,7 @@ def log_train_batch(self, batch: TrainBatch, *, step: int, step_time: float) ->
trainable_rate = (n_trainable / n_survivors) if n_survivors else 0.0
reward_mean = sum(r.reward for r in batch.rollouts) / max(n_survivors, 1)
max_off_policy = max((r.off_policy_steps for r in batch.rollouts), default=0)
turns_mean = sum(len(r.raw.get("trajectory") or []) for r in batch.rollouts) / max(n_survivors, 1)
turns_mean = sum(r.turn_count for r in batch.rollouts) / max(n_survivors, 1)
truncation_rate = sum(1 for r in batch.rollouts if r.is_truncated) / max(n_survivors, 1)

head = (
Expand All @@ -795,11 +795,7 @@ def log_train_batch(self, batch: TrainBatch, *, step: int, step_time: float) ->
env_error_rate = (n_env_errors / n_env_arrivals) if n_env_arrivals else 0.0
env_reward = (sum(r.reward for r in env_rollouts) / len(env_rollouts)) if env_rollouts else 0.0
env_max_off_policy = max((r.off_policy_steps for r in env_rollouts), default=0)
env_turns = (
sum(len(r.raw.get("trajectory") or []) for r in env_rollouts) / len(env_rollouts)
if env_rollouts
else 0.0
)
env_turns = sum(r.turn_count for r in env_rollouts) / len(env_rollouts) if env_rollouts else 0.0
env_truncation = sum(1 for r in env_rollouts if r.is_truncated) / len(env_rollouts) if env_rollouts else 0.0
lines.append(
f"╰─ {env_name:<{name_width}} | Ratio {ratio:.1%} | Reward {env_reward:.4f} | "
Expand Down
2 changes: 2 additions & 0 deletions src/prime_rl/orchestrator/train_sink.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ async def process_rollout(self, rollout: TrainRollout) -> None:
if rollout.error is not None:
return
raw = rollout.raw
rollout.num_turns = len(raw.get("trajectory") or [])
needs_backfill = any(s["tokens"] is None for s in raw.get("trajectory") or [])
if needs_backfill:
await asyncio.to_thread(backfill_rollout_tokens, raw, self.tokenizer, renderer=self.renderer)
Expand All @@ -161,6 +162,7 @@ async def process_rollout(self, rollout: TrainRollout) -> None:
raw,
mm_token_type_ids_mapping=self.mm_token_type_ids_mapping,
env_name=rollout.env_name,
prune_raw_payload=True,
)
rollout.samples = samples or []
# Offload base64 image bytes to disk as soon as the rollout is
Expand Down
51 changes: 50 additions & 1 deletion src/prime_rl/orchestrator/trajectories.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,51 @@
# primitives are immutable. mm_kwargs payloads are not mutated after creation.


def _safe_len(value: Any) -> int | None:
if value is None:
return None
try:
return len(value)
except TypeError:
return None


def prune_token_payload(tokens: dict[str, Any]) -> dict[str, Any]:
    """Build a compact summary of a step's token dict.

    The summary always carries ``has_routed_experts``. For each of the token
    array fields that has a measurable length it carries ``<field>_len``, and
    when ``routed_experts`` is a dict it carries the non-``None`` ``shape``,
    ``dtype`` and ``start`` entries as ``routed_experts_<name>``. The input
    dict is not modified.
    """
    routed_experts = tokens.get("routed_experts")
    summary: dict[str, Any] = {"has_routed_experts": routed_experts is not None}

    for field_name in ("prompt_ids", "prompt_mask", "completion_ids", "completion_mask", "completion_logprobs"):
        size = _safe_len(tokens.get(field_name))
        if size is None:
            continue
        summary[f"{field_name}_len"] = size

    if not isinstance(routed_experts, dict):
        return summary

    for attr in ("shape", "dtype", "start"):
        attr_value = routed_experts.get(attr)
        if attr_value is not None:
            summary[f"routed_experts_{attr}"] = attr_value
    return summary


def prune_train_rollout_payload(raw: dict[str, Any]) -> None:
    """Compact the trajectory of ``raw`` in place.

    For every dict step, a dict-valued ``tokens`` entry is replaced by the
    summary from ``prune_token_payload`` and the step's ``response`` entry is
    removed. ``raw["trajectory_payload_pruned"]`` is then set to ``True``.

    Nothing happens when that flag is already truthy or when
    ``raw["trajectory"]`` is not a list.
    """
    steps = raw.get("trajectory")
    if raw.get("trajectory_payload_pruned") or not isinstance(steps, list):
        return

    for entry in (candidate for candidate in steps if isinstance(candidate, dict)):
        token_payload = entry.get("tokens")
        if isinstance(token_payload, dict):
            entry["tokens"] = prune_token_payload(token_payload)
        entry.pop("response", None)

    # Marker that makes a repeated call a no-op.
    raw["trajectory_payload_pruned"] = True


def align_routed_experts(
routed_experts: np.ndarray | None,
expected_len: int,
Expand Down Expand Up @@ -206,6 +251,7 @@ def interleave_rollout(
mm_token_type_ids_mapping: dict[int, int] | None = None,
*,
env_name: str = "",
prune_raw_payload: bool = False,
) -> list[TrainingSample] | None:
"""
Convert vf.RolloutOutput to trainable rollouts by interleaving trajectory steps
Expand Down Expand Up @@ -521,7 +567,10 @@ def extend_sample(
for token_id in sample.prompt_ids + sample.completion_ids
]

return [sample for _, sample, _ in active_samples]
samples = [sample for _, sample, _ in active_samples]
if prune_raw_payload:
prune_train_rollout_payload(output)
return samples


def _union_step_mm_data(
Expand Down
9 changes: 9 additions & 0 deletions src/prime_rl/orchestrator/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class FinishedRollout:
policy_version: int
off_policy_steps: int
rollout_id: uuid.UUID = field(default_factory=uuid.uuid4)
num_turns: int | None = None

@property
def error(self) -> dict | None:
Expand All @@ -92,6 +93,12 @@ def reward(self) -> float:
def is_truncated(self) -> bool:
return bool(self.raw.get("is_truncated", False))

@property
def turn_count(self) -> int:
    """Number of turns: ``num_turns`` when set, else the length of the raw trajectory."""
    if self.num_turns is None:
        # No stored count; fall back to the raw trajectory (missing/None -> 0).
        return len(self.raw.get("trajectory") or [])
    return self.num_turns

def to_dict(self) -> vf.RolloutOutput:
"""``raw`` + metadata merged for I/O (``save_rollouts``,
``monitor.log_samples``). Shallow copy; never mutates ``self.raw``."""
Expand All @@ -103,6 +110,8 @@ def to_dict(self) -> vf.RolloutOutput:
if f.name == "filter_results":
out["filters"] = dict(val)
continue
if f.name == "num_turns" and val is None:
continue
out[f.name] = str(val) if isinstance(val, uuid.UUID) else val
return out

Expand Down
14 changes: 13 additions & 1 deletion src/prime_rl/orchestrator/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,19 @@ async def setup_student_inference_pool(*, config: OrchestratorConfig, tokenizer)
def get_model_completion_len(output: vf.RolloutOutput) -> int:
    """Sum of model-generated completion tokens across all turns (excludes
    environment-injected tokens between turns).

    For each trajectory step the length of ``tokens["completion_ids"]`` is
    used when that array is present. Otherwise an integer
    ``tokens["completion_ids_len"]`` summary is used. Steps without
    ``tokens``, or with neither field, contribute nothing.
    """
    total = 0
    for step in output["trajectory"]:
        tokens = step.get("tokens")
        if not tokens:
            continue
        completion_ids = tokens.get("completion_ids")
        if completion_ids is not None:
            total += len(completion_ids)
            continue
        # Token array is absent: use the stored length summary instead.
        completion_ids_len = tokens.get("completion_ids_len")
        if isinstance(completion_ids_len, int):
            total += completion_ids_len
    return total


def get_tool_response_len(output: vf.RolloutOutput) -> int:
Expand Down
12 changes: 10 additions & 2 deletions src/prime_rl/utils/monitor/prime.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,8 +354,16 @@ def _rollouts_to_parquet_bytes(self, rollouts: list[vf.RolloutOutput], step: int
"reward": ts.get("reward"),
"advantage": ts.get("advantage"),
"extras": ts.get("extras", {}),
"num_input_tokens": len(ts["tokens"]["prompt_ids"]) if ts.get("tokens") else None,
"num_output_tokens": len(ts["tokens"]["completion_ids"]) if ts.get("tokens") else None,
"num_input_tokens": (
len(ts["tokens"]["prompt_ids"])
if ts.get("tokens") and "prompt_ids" in ts["tokens"]
else (ts["tokens"].get("prompt_ids_len") if ts.get("tokens") else None)
),
"num_output_tokens": (
len(ts["tokens"]["completion_ids"])
if ts.get("tokens") and "completion_ids" in ts["tokens"]
else (ts["tokens"].get("completion_ids_len") if ts.get("tokens") else None)
),
}
for ts in trajectory
]
Expand Down
20 changes: 16 additions & 4 deletions src/prime_rl/utils/monitor/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,16 +176,28 @@ def log_samples(self, rollouts: list[vf.RolloutOutput], step: int) -> None:
if not trajectory:
continue
last_step = trajectory[-1]
tokens = last_step["tokens"]
full_ids = tokens["prompt_ids"] + tokens["completion_ids"]
messages_text = self.tokenizer.decode(full_ids)
tokens = last_step.get("tokens") or {}
prompt_ids = tokens.get("prompt_ids")
completion_ids = tokens.get("completion_ids")
if prompt_ids is not None and completion_ids is not None:
full_ids = prompt_ids + completion_ids
messages_text = self.tokenizer.decode(full_ids)
input_ids = str(full_ids)
else:
messages_text = json.dumps(
{
"prompt": rollout.get("prompt"),
"completion": rollout.get("completion"),
}
)
input_ids = ""
sample = {
"step": step,
"env_name": rollout.get("env_name"),
"task": rollout.get("task"),
"example_id": rollout["example_id"],
"messages": messages_text,
"input_ids": str(full_ids),
"input_ids": input_ids,
"reward": rollout["reward"],
}
assert list(sample.keys()) == self.samples_cols, (
Expand Down
25 changes: 25 additions & 0 deletions tests/unit/orchestrator/test_advantage.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
)
from prime_rl.orchestrator.envs import Env, TrainEnv
from prime_rl.orchestrator.types import TrainRollout
from prime_rl.orchestrator.utils import get_model_completion_len


def _make_rollout(reward: float, completion_len: int = 0) -> dict:
Expand Down Expand Up @@ -183,6 +184,30 @@ def test_length_weighted_baseline():
assert sum(length * adv for length, adv in zip((10, 30, 60), result.advantages)) == pytest.approx(0.0, abs=1e-5)


def test_length_penalty_uses_compacted_completion_lengths():
    """Length shaping must work on trajectories whose token dicts hold only
    ``*_len`` summaries: ``get_model_completion_len`` sums the
    ``completion_ids_len`` values, and the length penalty ranks the rollout
    by that summed length."""
    compacted = _make_rollout(reward=1.0, completion_len=0)
    compacted["trajectory"] = [
        {"tokens": {"prompt_ids_len": prompt_len, "completion_ids_len": completion_len}}
        for prompt_len, completion_len in ((10, 12), (20, 34))
    ]

    assert get_model_completion_len(compacted) == 46

    longer_same_reward = _make_rollout(reward=1.0, completion_len=92)
    zero_reward = _make_rollout(reward=0.0, completion_len=46)
    result = default_advantage_fn(
        AdvantageInputs(rollouts=[compacted, longer_same_reward, zero_reward]),
        length_penalty=LinearLengthPenaltyConfig(coef=1.0),
        max_seq_len=100,
    )

    # Equal reward, but the compacted rollout is shorter (46 < 92), so it
    # must receive the higher advantage.
    assert result.advantages[0] > result.advantages[1]


def _train_rollouts(rewards: list[float]) -> list[TrainRollout]:
"""Wrap a list of rewards into ``TrainRollout``\\ s sharing a single
``group_id`` — ``assign_advantages`` works on one group at a time
Expand Down
Loading
Loading