diff --git a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py index 65363862e7..599391363a 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py @@ -13,6 +13,7 @@ FileSystemTransportConfig, HeartbeatConfig, LogConfig, + NoOpTransportConfig, PrimeMonitorConfig, TransportConfig, WandbWithExtrasConfig, @@ -489,6 +490,20 @@ class OrchestratorExperimentalConfig(BaseConfig): pass +class OrchestratorDebugConfig(BaseConfig): + no_inference: bool = False + """Use an in-process no-op inference pool. Rollout environments must not call the model client.""" + + no_trainer: bool = False + """Run without a trainer. Uses no-op batch sending and advances policy version locally.""" + + fake_tokenizer: bool = False + """Use a tiny tokenizer stub. Only safe when rollouts already include tokens and sample monitors are disabled.""" + + log_memory: bool = False + """Log orchestrator and child-process RSS after each completed orchestrator step.""" + + class RolloutModelConfig(BaseConfig): model: ModelConfig = ModelConfig() @@ -624,6 +639,8 @@ def _preserve_mito_renderer(self, handler: SerializerFunctionWrapHandler) -> dic env_install_prerelease: bool = False """Allow pre-release versions when installing environments (e.g. ``verifiers>=0.1.12.dev5``). Passes ``--prerelease`` to ``prime env install``.""" + debug: OrchestratorDebugConfig = OrchestratorDebugConfig() + experimental: OrchestratorExperimentalConfig = OrchestratorExperimentalConfig() @model_validator(mode="before") @@ -760,6 +777,8 @@ def _force_no_renderer_for_sft(self): @model_validator(mode="after") def validate_training_mode(self): """Enforce training mode invariants that involve only orchestrator fields.""" + if self.debug.no_inference: + return self has_teacher = self.teacher is not None if self.training_mode == "rl" and has_teacher: raise ValueError("orchestrator.teacher must not be set when training_mode = 'rl'.") @@ -772,6 +791,8 @@ def validate_pool_size(self): """``pool_size`` is only meaningful when the renderer is enabled (``renderer is not None``). Reject otherwise so callers don't silently pass it and wonder why it's ignored.""" + if self.debug.no_inference: + return self if self.renderer is None and self.pool_size is not None: raise ValueError( f"orchestrator.pool_size={self.pool_size!r} is set but " @@ -788,6 +809,8 @@ def vlm_requires_renderer(self): tokens, and ships generic ``mm_kwargs`` keyed by whatever the model's forward signature expects. """ + if self.debug.no_inference: + return self if self.student.model.vlm is not None and self.renderer is None: raise ValueError( "orchestrator.renderer must be set when model.vlm is set. " @@ -810,6 +833,8 @@ def validate_renderer_auto_resolves(self): """ if self.renderer is None or self.renderer.name != "auto": return self + if self.debug.no_inference: + return self from renderers.base import MODEL_RENDERER_MAP model_id = self.tokenizer.name or self.student.model.name @@ -898,6 +923,17 @@ def auto_setup_bench(self): return self + @model_validator(mode="after") + def resolve_debug(self): + if self.debug.no_trainer: + self.rollout_transport = NoOpTransportConfig() + if self.debug.fake_tokenizer and (self.wandb is not None or self.prime_monitor is not None): + raise ValueError("orchestrator.debug.fake_tokenizer requires wandb and prime_monitor to be disabled.") + if self.debug.no_inference: + self.renderer = None + self.collect_inference_metrics = False + return self + @model_validator(mode="after") def resolve_env_config(self): """Populate extra_env_kwargs and vLLM sampling defaults from top-level fields.""" diff --git a/packages/prime-rl-configs/src/prime_rl/configs/shared.py b/packages/prime-rl-configs/src/prime_rl/configs/shared.py index ff311f145d..4e40a6a994 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/shared.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/shared.py @@ -261,4 +261,12 @@ class ZMQTransportConfig(BaseTransportConfig): """High-water mark (max in-flight messages per ZMQ socket).""" -TransportConfig: TypeAlias = Annotated[FileSystemTransportConfig | ZMQTransportConfig, Field(discriminator="type")] +class NoOpTransportConfig(BaseTransportConfig): + type: Literal["noop"] = "noop" + """Drop training batches after orchestrator-side processing. Intended for orchestrator debug runs without a trainer.""" + + +TransportConfig: TypeAlias = Annotated[ + FileSystemTransportConfig | ZMQTransportConfig | NoOpTransportConfig, + Field(discriminator="type"), +] diff --git a/pyproject.toml b/pyproject.toml index 1f6fd89afb..885381bc10 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,6 +69,7 @@ envs = [ "code-env", "color-codeword", "deepdive", + "fake-r3-trajectory", "general-agent", "gpqa", "hle", @@ -209,6 +210,7 @@ alphabet-sort = { path = "deps/verifiers/environments/alphabet_sort", editable = code-env = { path = "deps/research-environments/environments/code_env", editable = true } color-codeword = { path = "deps/research-environments/environments/color_codeword", editable = true } deepdive = { path = "deps/research-environments/environments/deepdive", editable = true } +fake-r3-trajectory = { path = "tests/debug_envs/fake_r3_trajectory", editable = true } general-agent = { path = "deps/research-environments/environments/general_agent", editable = true } gpqa = { path = "deps/research-environments/environments/gpqa", editable = true } harnesses = { path = "deps/verifiers/packages/harnesses", editable = true } diff --git a/src/prime_rl/entrypoints/rl.py b/src/prime_rl/entrypoints/rl.py index db1e27b995..69f506cae9 100644 --- a/src/prime_rl/entrypoints/rl.py +++ b/src/prime_rl/entrypoints/rl.py @@ -83,22 +83,32 @@ def rl_local(config: RLConfig): logger.success("Dry run complete. To start an RL run locally, remove --dry-run from your command.") return + debug_no_inference = config.orchestrator.debug.no_inference + debug_no_trainer = config.orchestrator.debug.no_trainer + # Derive launcher-local GPU IDs from deployment config gpu_offset = 0 - num_infer_gpus = config.deployment.num_infer_gpus if config.inference is not None else 0 + num_infer_gpus = ( + 0 if debug_no_inference else (config.deployment.num_infer_gpus if config.inference is not None else 0) + ) infer_local_gpu_ids = list(range(gpu_offset, gpu_offset + num_infer_gpus)) gpu_offset += num_infer_gpus - trainer_local_gpu_ids = list(range(gpu_offset, gpu_offset + config.deployment.num_train_gpus)) - - total_requested_gpus = num_infer_gpus + config.deployment.num_train_gpus - physical_gpu_ids = get_physical_gpu_ids() - if total_requested_gpus > len(physical_gpu_ids): - raise ValueError( - f"Requested {total_requested_gpus} GPUs via deployment settings, but only " - f"{len(physical_gpu_ids)} physical GPU(s) are available: {physical_gpu_ids}" - ) - physical_gpu_mapping = {local_id: physical_gpu_ids[local_id] for local_id in range(total_requested_gpus)} - logger.info(f"Using local->physical GPU mapping: {physical_gpu_mapping}") + num_train_gpus = 0 if debug_no_trainer else config.deployment.num_train_gpus + trainer_local_gpu_ids = list(range(gpu_offset, gpu_offset + num_train_gpus)) + + total_requested_gpus = num_infer_gpus + num_train_gpus + if total_requested_gpus > 0: + physical_gpu_ids = get_physical_gpu_ids() + if total_requested_gpus > len(physical_gpu_ids): + raise ValueError( + f"Requested {total_requested_gpus} GPUs via deployment settings, but only " + f"{len(physical_gpu_ids)} physical GPU(s) are available: {physical_gpu_ids}" + ) + physical_gpu_mapping = {local_id: physical_gpu_ids[local_id] for local_id in range(total_requested_gpus)} + logger.info(f"Using local->physical GPU mapping: {physical_gpu_mapping}") + else: + physical_gpu_mapping = {} + logger.info("No GPUs requested for orchestrator debug run") infer_gpu_ids = [physical_gpu_mapping[local_gpu_id] for local_gpu_id in infer_local_gpu_ids] trainer_gpu_ids = [physical_gpu_mapping[local_gpu_id] for local_gpu_id in trainer_local_gpu_ids] @@ -150,7 +160,7 @@ def sigterm_handler(signum, frame): try: # Optionally, start inference process - if config.inference: + if config.inference and not debug_no_inference: inference_cmd = ["inference", "@", (config_dir / INFERENCE_TOML).as_posix()] logger.info(f"Starting inference on GPU(s) {' '.join(map(str, infer_gpu_ids))}") logger.debug(f"Inference start command: {' '.join(inference_cmd)}") @@ -177,6 +187,8 @@ def sigterm_handler(signum, frame): ) monitor_thread.start() monitor_threads.append(monitor_thread) + elif debug_no_inference: + logger.warning("Skipping inference process for orchestrator debug no-inference mode") else: logger.warning( "No [inference] block configured - the student inference server will not be started here. " @@ -224,54 +236,58 @@ def sigterm_handler(signum, frame): monitor_thread.start() monitor_threads.append(monitor_thread) - # Start training process - from prime_rl.utils.utils import get_free_port - - trainer_cmd = [ - "torchrun", - "--role=trainer", - f"--rdzv-endpoint=localhost:{get_free_port()}", - f"--rdzv-id={uuid.uuid4().hex}", - # Pipe all logs to file, and only master rank logs to stdout - f"--log-dir={log_dir / 'trainer' / 'torchrun'}", - f"--local-ranks-filter={','.join(map(str, config.trainer.log.ranks_filter))}", - "--redirect=3", - "--tee=3", - f"--nproc-per-node={len(trainer_gpu_ids)}", - "-m", - "prime_rl.trainer.rl.train", - "@", - (config_dir / TRAINER_TOML).as_posix(), - ] - logger.info(f"Starting trainer on GPU(s) {' '.join(map(str, trainer_gpu_ids))}") - logger.debug(f"Training start command: {' '.join(trainer_cmd)}") - with open(log_dir / "trainer.log", "w") as log_file: - trainer_process = Popen( - trainer_cmd, - env={ - **os.environ, - **wandb_shared_env, - "WANDB_SHARED_LABEL": "trainer", - "CUDA_VISIBLE_DEVICES": ",".join(map(str, trainer_gpu_ids)), - "PYTHONUNBUFFERED": "1", - "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True", - "LOGURU_FORCE_COLORS": "1", - "WANDB_PROGRAM": "uv run rl", - "WANDB_ARGS": json.dumps(start_command), - }, - stdout=log_file, - stderr=log_file, - ) - processes.append(trainer_process) + if debug_no_trainer: + logger.warning("Skipping trainer process for orchestrator debug no-trainer mode") + trainer_process = None + else: + # Start training process + from prime_rl.utils.utils import get_free_port + + trainer_cmd = [ + "torchrun", + "--role=trainer", + f"--rdzv-endpoint=localhost:{get_free_port()}", + f"--rdzv-id={uuid.uuid4().hex}", + # Pipe all logs to file, and only master rank logs to stdout + f"--log-dir={log_dir / 'trainer' / 'torchrun'}", + f"--local-ranks-filter={','.join(map(str, config.trainer.log.ranks_filter))}", + "--redirect=3", + "--tee=3", + f"--nproc-per-node={len(trainer_gpu_ids)}", + "-m", + "prime_rl.trainer.rl.train", + "@", + (config_dir / TRAINER_TOML).as_posix(), + ] + logger.info(f"Starting trainer on GPU(s) {' '.join(map(str, trainer_gpu_ids))}") + logger.debug(f"Training start command: {' '.join(trainer_cmd)}") + with open(log_dir / "trainer.log", "w") as log_file: + trainer_process = Popen( + trainer_cmd, + env={ + **os.environ, + **wandb_shared_env, + "WANDB_SHARED_LABEL": "trainer", + "CUDA_VISIBLE_DEVICES": ",".join(map(str, trainer_gpu_ids)), + "PYTHONUNBUFFERED": "1", + "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True", + "LOGURU_FORCE_COLORS": "1", + "WANDB_PROGRAM": "uv run rl", + "WANDB_ARGS": json.dumps(start_command), + }, + stdout=log_file, + stderr=log_file, + ) + processes.append(trainer_process) - # Start monitoring thread - stop_event = Event() - stop_events["trainer"] = stop_event - monitor_thread = Thread( - target=monitor_process, args=(trainer_process, stop_event, error_queue, "trainer"), daemon=True - ) - monitor_thread.start() - monitor_threads.append(monitor_thread) + # Start monitoring thread + stop_event = Event() + stop_events["trainer"] = stop_event + monitor_thread = Thread( + target=monitor_process, args=(trainer_process, stop_event, error_queue, "trainer"), daemon=True + ) + monitor_thread.start() + monitor_threads.append(monitor_thread) # Monitor all processes for failures logger.success("Startup complete. Showing orchestrator logs...") @@ -283,7 +299,8 @@ def sigterm_handler(signum, frame): processes.append(tail_process) # Check for errors from monitor threads - while not (stop_events["orchestrator"].is_set() and stop_events["trainer"].is_set()): + required_processes = ["orchestrator"] + ([] if debug_no_trainer else ["trainer"]) + while not all(stop_events[name].is_set() for name in required_processes): if error_queue: error = error_queue[0] logger.error(f"Error: {error}") @@ -302,7 +319,7 @@ def sigterm_handler(signum, frame): cleanup_processes(processes) sys.exit(1) - if trainer_process.returncode != 0: + if trainer_process is not None and trainer_process.returncode != 0: logger.error(f"Trainer failed with exit code {trainer_process.returncode}") cleanup_threads(monitor_threads) cleanup_processes(processes) @@ -501,7 +518,7 @@ def rl(config: RLConfig): get_logger().info("Training from scratch, cleaning any stale rollouts and broadcasts") clean_future_steps(config.output_dir, -1) - if not config.dry_run: + if not config.dry_run and not config.orchestrator.debug.no_trainer: from prime_rl.trainer.model import pre_download_model pre_download_model(config.trainer.model.name) diff --git a/src/prime_rl/orchestrator/filters.py b/src/prime_rl/orchestrator/filters.py index f8deda1230..acf11f9527 100644 --- a/src/prime_rl/orchestrator/filters.py +++ b/src/prime_rl/orchestrator/filters.py @@ -9,6 +9,7 @@ from __future__ import annotations import math +from collections.abc import Iterator from dataclasses import dataclass from typing import TYPE_CHECKING, Protocol @@ -19,6 +20,42 @@ from prime_rl.orchestrator.types import TrainRollout +@dataclass(frozen=True) +class GeneratedTokenLogprob: + token_id: int + logprob: float + + +def _iter_generated_token_logprobs(rollout: "TrainRollout") -> Iterator[GeneratedTokenLogprob]: + """Yield only model-generated completion tokens for rollout filters. + + ``TrainSink`` builds ``rollout.samples`` before pre-filters run. Those + samples survive raw trajectory compaction, so filters prefer them when + present. Tests and direct callers can still use the raw trajectory path. + """ + if rollout.samples: + for sample in rollout.samples: + for token_id, logprob, is_generated in zip( + sample.completion_ids, + sample.completion_logprobs, + sample.completion_mask, + ): + if is_generated: + yield GeneratedTokenLogprob(token_id=token_id, logprob=logprob) + return + + for step in rollout.raw.get("trajectory") or []: + tokens = step.get("tokens") if isinstance(step, dict) else None + if tokens is None or "completion_ids" not in tokens or "completion_logprobs" not in tokens: + continue + mask = tokens.get("completion_mask") + if mask is None: + mask = [True] * len(tokens["completion_ids"]) + for token_id, logprob, is_generated in zip(tokens["completion_ids"], tokens["completion_logprobs"], mask): + if is_generated: + yield GeneratedTokenLogprob(token_id=token_id, logprob=logprob) + + @dataclass class FilterResult: detected: bool @@ -51,14 +88,10 @@ class GibberishFilter: def check(self, rollout: "TrainRollout") -> FilterResult: global_idx = 0 - for step in rollout.raw["trajectory"]: - tokens = step["tokens"] - if tokens is None: - continue - for token_id, logprob in zip(tokens["completion_ids"], tokens["completion_logprobs"]): - if token_id > self.token_id_threshold and logprob < self.logprob_threshold: - return FilterResult(detected=True, detection_index=global_idx) - global_idx += 1 + for token in _iter_generated_token_logprobs(rollout): + if token.token_id > self.token_id_threshold and token.logprob < self.logprob_threshold: + return FilterResult(detected=True, detection_index=global_idx) + global_idx += 1 return FilterResult(detected=False) @@ -82,18 +115,14 @@ class RepetitionFilter: def check(self, rollout: "TrainRollout") -> FilterResult: consecutive = 0 global_idx = 0 - for step in rollout.raw["trajectory"]: - tokens = step["tokens"] - if tokens is None: - continue - for logprob in tokens["completion_logprobs"]: - if logprob > self.logprob_threshold: - consecutive += 1 - else: - consecutive = 0 - if consecutive >= self.window: - return FilterResult(detected=True, detection_index=global_idx) - global_idx += 1 + for token in _iter_generated_token_logprobs(rollout): + if token.logprob > self.logprob_threshold: + consecutive += 1 + else: + consecutive = 0 + if consecutive >= self.window: + return FilterResult(detected=True, detection_index=global_idx) + global_idx += 1 return FilterResult(detected=False) diff --git a/src/prime_rl/orchestrator/metrics.py b/src/prime_rl/orchestrator/metrics.py index d32dbc02de..9257a84197 100644 --- a/src/prime_rl/orchestrator/metrics.py +++ b/src/prime_rl/orchestrator/metrics.py @@ -61,7 +61,7 @@ def build( "prefill_len": metrics.rollout_prefill_lens, "decode_len": metrics.rollout_decode_lens, "samples_per_rollout": metrics.samples_per_rollout, - "num_turns": [len(r.raw["trajectory"]) for r in rollouts], + "num_turns": [r.turn_count for r in rollouts], } ) metrics_df = pd.DataFrame([(r.raw.get("metrics") or {}) for r in rollouts]) diff --git a/src/prime_rl/orchestrator/orchestrator.py b/src/prime_rl/orchestrator/orchestrator.py index fbdb89f57d..7b0bd61ebb 100644 --- a/src/prime_rl/orchestrator/orchestrator.py +++ b/src/prime_rl/orchestrator/orchestrator.py @@ -21,7 +21,6 @@ from __future__ import annotations import asyncio -import ctypes import logging import os import time @@ -69,12 +68,13 @@ compute_teacher_logprobs, get_weight_dir, intercept_vf_logging, + log_process_memory, save_rollouts, set_default_executor, setup_student_inference_pool, + trim_process_memory, ) -from prime_rl.orchestrator.watcher import WeightWatcher -from prime_rl.trainer.model import setup_tokenizer +from prime_rl.orchestrator.watcher import NoOpWeightWatcher, WeightWatcher from prime_rl.transport import TrainingBatch, setup_training_batch_sender from prime_rl.utils.async_utils import safe_cancel from prime_rl.utils.client import init_nccl_broadcast, setup_inference_pool @@ -114,6 +114,20 @@ TOKEN_EXPORT_DRAIN_POLL_S = 2.0 +class DebugTokenizer: + """Tiny tokenizer stub for pre-tokenized orchestrator debug rollouts.""" + + vocab_size = 200_000 + eos_token_id = 0 + pad_token_id = 0 + + def decode(self, token_ids, *args, **kwargs) -> str: + return " ".join(str(t) for t in token_ids) + + def batch_decode(self, sequences, *args, **kwargs) -> list[str]: + return [self.decode(seq, *args, **kwargs) for seq in sequences] + + class Orchestrator: # Set in ``__init__`` config: OrchestratorConfig @@ -217,8 +231,14 @@ async def setup(self) -> None: for env_id in env_ids_to_install: install_env(env_id, prerelease=config.env_install_prerelease) - get_logger().info(f"Initializing tokenizer ({config.tokenizer})") - self.tokenizer = setup_tokenizer(config.tokenizer) + if config.debug.fake_tokenizer: + get_logger().warning("Using debug tokenizer stub; rollout tokens must be present in env output") + self.tokenizer = DebugTokenizer() # type: ignore[assignment] + else: + get_logger().info(f"Initializing tokenizer ({config.tokenizer})") + from prime_rl.trainer.model import setup_tokenizer + + self.tokenizer = setup_tokenizer(config.tokenizer) # Student inference pool get_logger().info( @@ -234,7 +254,7 @@ async def setup(self) -> None: if self.mm_token_type_ids_mapping == {}: self.mm_token_type_ids_mapping = None - if config.teacher is not None: + if config.teacher is not None and not config.debug.no_inference: get_logger().info( f"Initializing teacher inference pool (base_url={', '.join(config.teacher.client.base_url)}, " f"model={config.teacher.model.name})" @@ -318,8 +338,11 @@ async def setup(self) -> None: ) await self.inference_metrics.start() - get_logger().info(f"Initializing weight broadcast ({config.weight_broadcast})") - if config.weight_broadcast.type == "nccl": + if config.debug.no_trainer: + get_logger().warning("Skipping weight broadcast setup for orchestrator debug no-trainer mode") + else: + get_logger().info(f"Initializing weight broadcast ({config.weight_broadcast})") + if not config.debug.no_trainer and config.weight_broadcast.type == "nccl": await init_nccl_broadcast( self.student_inference.admin_clients, config.weight_broadcast.host, @@ -400,14 +423,17 @@ async def setup(self) -> None: post_filters=post_filters, ) self.eval_sink = EvalSink(eval_envs=self.eval_envs) if self.eval_envs is not None else None - self.watcher = WeightWatcher( - config, - policy=self.policy, - inference=self.student_inference, - observers=[self.dispatcher, self], - lora_name=self.lora_name, - ckpt_step=self.progress.step, - ) + if config.debug.no_trainer: + self.watcher = NoOpWeightWatcher(policy=self.policy) # type: ignore[assignment] + else: + self.watcher = WeightWatcher( + config, + policy=self.policy, + inference=self.student_inference, + observers=[self.dispatcher, self], + lora_name=self.lora_name, + ckpt_step=self.progress.step, + ) # Single periodic logger for the whole pipeline. It's the only # consumer of ``dispatcher.metrics.drained()`` (which clears on read) self.lag_monitor = EventLoopLagMonitor() @@ -482,10 +508,7 @@ async def start(self) -> None: get_logger().success("Orchestrator finished.") else: get_logger().warning("Orchestrator cleanup complete (forced).") - try: - ctypes.CDLL("libc.so.6").malloc_trim(0) - except Exception as e: - get_logger().debug(f"malloc_trim(0) failed: {e}") + trim_process_memory() async def main_loop(self) -> None: """Consume ``FinishedRollout``\\ s from the dispatcher and route them @@ -613,6 +636,9 @@ async def finalize_train_batch(self, batch: TrainBatch) -> None: teacher_logprobs_time = time.perf_counter() - t await self.sender.send(TrainingBatch(examples=batch.samples, step=step)) + self.release_train_batch_samples(batch) + if config.debug.no_trainer: + self.policy.version = max(self.policy.version, step + 1) self.update_dispatch_gate() metrics = self.metrics.build( @@ -629,6 +655,7 @@ async def finalize_train_batch(self, batch: TrainBatch) -> None: ) self.monitor.log(metrics, step=step) self.monitor.log_samples(rollout_dicts, step=step) + rollout_dicts.clear() self.monitor.log_distributions( distributions={ "rewards": [r.reward for r in batch.rollouts], @@ -663,6 +690,22 @@ async def finalize_train_batch(self, batch: TrainBatch) -> None: self.train_sink.reset_pre_filter_stats() self.progress.step += 1 self.maybe_trigger_eval(self.progress.step) + self.release_train_batch_rollouts(batch) + trim_process_memory() + if config.debug.log_memory: + log_process_memory(f"after_step step={step}") + + @staticmethod + def release_train_batch_samples(batch: TrainBatch) -> None: + """Drop orchestrator-owned references to sent trainer samples.""" + batch.samples.clear() + for rollout in batch.rollouts: + rollout.samples.clear() + + @staticmethod + def release_train_batch_rollouts(batch: TrainBatch) -> None: + """Drop the finalized train batch after metrics and logs consume it.""" + batch.rollouts.clear() def maybe_trigger_eval(self, step: int) -> None: """Fire eligible eval epochs and flip to ``PREFER_EVAL`` if anything @@ -760,7 +803,7 @@ def log_train_batch(self, batch: TrainBatch, *, step: int, step_time: float) -> trainable_rate = (n_trainable / n_survivors) if n_survivors else 0.0 reward_mean = sum(r.reward for r in batch.rollouts) / max(n_survivors, 1) max_off_policy = max((r.off_policy_steps for r in batch.rollouts), default=0) - turns_mean = sum(len(r.raw.get("trajectory") or []) for r in batch.rollouts) / max(n_survivors, 1) + turns_mean = sum(r.turn_count for r in batch.rollouts) / max(n_survivors, 1) truncation_rate = sum(1 for r in batch.rollouts if r.is_truncated) / max(n_survivors, 1) head = ( @@ -784,11 +827,7 @@ def log_train_batch(self, batch: TrainBatch, *, step: int, step_time: float) -> env_error_rate = (n_env_errors / n_env_arrivals) if n_env_arrivals else 0.0 env_reward = (sum(r.reward for r in env_rollouts) / len(env_rollouts)) if env_rollouts else 0.0 env_max_off_policy = max((r.off_policy_steps for r in env_rollouts), default=0) - env_turns = ( - sum(len(r.raw.get("trajectory") or []) for r in env_rollouts) / len(env_rollouts) - if env_rollouts - else 0.0 - ) + env_turns = sum(r.turn_count for r in env_rollouts) / len(env_rollouts) if env_rollouts else 0.0 env_truncation = sum(1 for r in env_rollouts if r.is_truncated) / len(env_rollouts) if env_rollouts else 0.0 lines.append( f"╰─ {env_name:<{name_width}} | Ratio {ratio:.1%} | Reward {env_reward:.4f} | " diff --git a/src/prime_rl/orchestrator/train_sink.py b/src/prime_rl/orchestrator/train_sink.py index f79a0d5eff..622530931d 100644 --- a/src/prime_rl/orchestrator/train_sink.py +++ b/src/prime_rl/orchestrator/train_sink.py @@ -153,6 +153,7 @@ async def process_rollout(self, rollout: TrainRollout) -> None: if rollout.error is not None: return raw = rollout.raw + rollout.num_turns = len(raw.get("trajectory") or []) needs_backfill = any(s["tokens"] is None for s in raw.get("trajectory") or []) if needs_backfill: await asyncio.to_thread(backfill_rollout_tokens, raw, self.tokenizer, renderer=self.renderer) @@ -161,6 +162,7 @@ async def process_rollout(self, rollout: TrainRollout) -> None: raw, mm_token_type_ids_mapping=self.mm_token_type_ids_mapping, env_name=rollout.env_name, + prune_raw_payload=True, ) rollout.samples = samples or [] # Offload base64 image bytes to disk as soon as the rollout is @@ -200,8 +202,8 @@ def process_group(self, group_id: uuid.UUID) -> None: # Propagate to the pre-tokenized samples so the orchestrator can # collect samples at ship time without re-walking rollouts. The env - # has a single sampling temperature; fan it out across each sample's - # completion tokens here (interleave leaves it empty). + # has a single sampling temperature; keep it compact instead of + # fanning out a duplicate Python float per completion token. temperature = env.sampling_args["temperature"] for r in survivors: for sample in r.samples: @@ -209,7 +211,8 @@ def process_group(self, group_id: uuid.UUID) -> None: sample.reward = r.reward sample.env_name = r.env_name sample.training_mode = self.config.training_mode - sample.completion_temperatures = [temperature] * len(sample.completion_ids) + sample.completion_temperature = temperature + sample.completion_temperatures = [] if self.pre_filters: apply_filters(self.pre_filters, survivors) diff --git a/src/prime_rl/orchestrator/trajectories.py b/src/prime_rl/orchestrator/trajectories.py index 3e8431c12a..7428f40d4b 100644 --- a/src/prime_rl/orchestrator/trajectories.py +++ b/src/prime_rl/orchestrator/trajectories.py @@ -22,6 +22,68 @@ # We use list() instead of deepcopy() for flat lists (token IDs, logprobs) - safe because # primitives are immutable. mm_kwargs payloads are not mutated after creation. +_RAW_TOKEN_ARRAY_KEYS = { + "prompt_ids", + "prompt_mask", + "completion_ids", + "completion_mask", + "completion_logprobs", + "routed_experts", + "multi_modal_data", + "mm_token_type_ids", +} + + +def _safe_len(value) -> int | None: + if value is None: + return None + try: + return len(value) + except TypeError: + return None + + +def _compact_token_payload(tokens: dict) -> dict: + compact: dict = {} + for key in ("prompt_ids", "prompt_mask", "completion_ids", "completion_mask", "completion_logprobs"): + length = _safe_len(tokens.get(key)) + if length is not None: + compact[f"{key}_len"] = length + + routed = tokens.get("routed_experts") + compact["has_routed_experts"] = routed is not None + if isinstance(routed, dict): + for key in ("shape", "dtype", "start"): + if key in routed: + compact[f"routed_experts_{key}"] = routed[key] + + for key, value in tokens.items(): + if key in _RAW_TOKEN_ARRAY_KEYS: + continue + if isinstance(value, (str, int, float, bool)) or value is None: + compact[key] = value + return compact + + +def prune_train_rollout_payload(raw: dict) -> None: + """Replace duplicate raw trajectory token/R3 arrays with length summaries.""" + if raw.get("trajectory_payload_pruned"): + return + + trajectory = raw.get("trajectory") + if not isinstance(trajectory, list): + return + + for step in trajectory: + if not isinstance(step, dict): + continue + tokens = step.get("tokens") + if isinstance(tokens, dict): + step["tokens"] = _compact_token_payload(tokens) + step.pop("response", None) + + raw["trajectory_payload_pruned"] = True + def align_routed_experts( routed_experts: np.ndarray | None, @@ -206,6 +268,7 @@ def interleave_rollout( mm_token_type_ids_mapping: dict[int, int] | None = None, *, env_name: str = "", + prune_raw_payload: bool = False, ) -> list[TrainingSample] | None: """ Convert vf.RolloutOutput to trainable rollouts by interleaving trajectory steps @@ -276,6 +339,10 @@ def prepare_step_tokens(step: vf.TrajectoryStep, step_idx: int) -> dict[str, Any if prepared is None: return None prepared_steps.append(prepared) + if prune_raw_payload: + # Each step has been copied into prepared_steps, so raw tokens can be + # compacted without affecting sample construction below. + prune_train_rollout_payload(output) # Deferred routed_experts state per sample: O(N) chunk list concatenated # once at finalize, replacing the prior O(N²) per-extension unpack/repack. diff --git a/src/prime_rl/orchestrator/types.py b/src/prime_rl/orchestrator/types.py index 54c3f827b9..9231534911 100644 --- a/src/prime_rl/orchestrator/types.py +++ b/src/prime_rl/orchestrator/types.py @@ -79,6 +79,7 @@ class FinishedRollout: policy_version: int off_policy_steps: int rollout_id: uuid.UUID = field(default_factory=uuid.uuid4) + num_turns: int | None = None @property def error(self) -> dict | None: @@ -92,6 +93,12 @@ def reward(self) -> float: def is_truncated(self) -> bool: return bool(self.raw.get("is_truncated", False)) + @property + def turn_count(self) -> int: + if self.num_turns is not None: + return self.num_turns + return len(self.raw.get("trajectory") or []) + def to_dict(self) -> vf.RolloutOutput: """``raw`` + metadata merged for I/O (``save_rollouts``, ``monitor.log_samples``). Shallow copy; never mutates ``self.raw``.""" @@ -103,6 +110,8 @@ def to_dict(self) -> vf.RolloutOutput: if f.name == "filter_results": out["filters"] = dict(val) continue + if f.name == "num_turns" and val is None: + continue out[f.name] = str(val) if isinstance(val, uuid.UUID) else val return out diff --git a/src/prime_rl/orchestrator/utils.py b/src/prime_rl/orchestrator/utils.py index 5675ba3f34..7b49aeddfd 100644 --- a/src/prime_rl/orchestrator/utils.py +++ b/src/prime_rl/orchestrator/utils.py @@ -1,4 +1,5 @@ import asyncio +import ctypes import logging import time from concurrent.futures import ThreadPoolExecutor @@ -12,7 +13,7 @@ from prime_rl.configs.orchestrator import OrchestratorConfig from prime_rl.transport import TrainingSample -from prime_rl.utils.client import setup_inference_pool +from prime_rl.utils.client import NoOpInferencePool, setup_inference_pool from prime_rl.utils.logger import InterceptHandler, get_logger from prime_rl.utils.utils import ( get_broadcast_dir, @@ -25,6 +26,11 @@ async def setup_student_inference_pool(*, config: OrchestratorConfig, tokenizer) """Build the student inference pool + matching renderer. Returns ``(renderer | None, inference_pool)``; ``renderer`` is ``None`` on the MITO path (``config.renderer is None``).""" + if config.debug.no_inference: + model_name = config.student.model.name + get_logger().warning(f"Using no-op student inference pool for orchestrator debug mode ({model_name=})") + return None, NoOpInferencePool(model_name=model_name) + from renderers.base import create_renderer client_config = config.student.client @@ -57,7 +63,19 @@ async def setup_student_inference_pool(*, config: OrchestratorConfig, tokenizer) def get_model_completion_len(output: vf.RolloutOutput) -> int: """Sum of model-generated completion tokens across all turns (excludes environment-injected tokens between turns).""" - return sum(len(step["tokens"]["completion_ids"]) for step in output["trajectory"] if step.get("tokens")) + total = 0 + for step in output["trajectory"]: + tokens = step.get("tokens") + if not tokens: + continue + completion_ids = tokens.get("completion_ids") + if completion_ids is not None: + total += len(completion_ids) + continue + completion_ids_len = tokens.get("completion_ids_len") + if isinstance(completion_ids_len, int): + total += completion_ids_len + return total def get_tool_response_len(output: vf.RolloutOutput) -> int: @@ -96,6 +114,37 @@ def set_default_executor(max_workers: int = 64) -> None: asyncio.get_event_loop().set_default_executor(ThreadPoolExecutor(max_workers=max_workers)) +def log_process_memory(label: str) -> None: + """Debug-log RSS for the orchestrator process and its child processes.""" + try: + import psutil + + proc = psutil.Process() + rss = proc.memory_info().rss + child_rss = 0 + for child in proc.children(recursive=True): + try: + child_rss += child.memory_info().rss + except psutil.Error: + continue + total = rss + child_rss + gib = 1024**3 + get_logger().debug( + f"Memory | {label} | rss={rss / gib:.3f} GiB | child_rss={child_rss / gib:.3f} GiB | " + f"total_rss={total / gib:.3f} GiB" + ) + except Exception as exc: + get_logger().debug(f"Memory logging failed at {label}: {exc!r}") + + +def trim_process_memory() -> None: + """Return freed heap pages to the OS on glibc systems.""" + try: + ctypes.CDLL("libc.so.6").malloc_trim(0) + except Exception as exc: + get_logger().debug(f"malloc_trim(0) failed: {exc!r}") + + async def compute_teacher_logprobs( clients: list[vf.ClientConfig], model_name: str, diff --git a/src/prime_rl/orchestrator/watcher.py b/src/prime_rl/orchestrator/watcher.py index c01d349f40..ef17902602 100644 --- a/src/prime_rl/orchestrator/watcher.py +++ b/src/prime_rl/orchestrator/watcher.py @@ -138,3 +138,24 @@ def gauges(self) -> dict[str, float]: "watcher/last_update_weights_time": self.last_update_weights_time, "watcher/last_wait_for_ckpt_time": self.last_wait_for_ckpt_time, } + + +class NoOpWeightWatcher: + """Weight watcher for orchestrator debug runs without a trainer.""" + + def __init__(self, *, policy: Policy) -> None: + self.policy = policy + + async def start(self) -> None: + await asyncio.Event().wait() + + async def stop(self) -> None: + return None + + def gauges(self) -> dict[str, float]: + return { + "watcher/policy_version": float(self.policy.version), + "watcher/update_count": 0.0, + "watcher/last_update_weights_time": 0.0, + "watcher/last_wait_for_ckpt_time": 0.0, + } diff --git a/src/prime_rl/trainer/batch.py b/src/prime_rl/trainer/batch.py index 2c39a83e4c..ee3136fdb8 100644 --- a/src/prime_rl/trainer/batch.py +++ b/src/prime_rl/trainer/batch.py @@ -33,6 +33,27 @@ def _slice_routed_experts(routed_experts: RoutedExperts, seq_len: int) -> Routed ) +def _completion_temperatures(training_example: TrainingSample) -> tuple[float, list[float]]: + if training_example.completion_temperatures: + return training_example.completion_temperatures[0], training_example.completion_temperatures + + temperature = training_example.completion_temperature + if temperature is None: + temperature = 1.0 + return temperature, [temperature] * len(training_example.completion_ids) + + +def _append_routed_experts(dst: MicroBatch, src: MicroBatch) -> None: + dst_routed = dst.routed_experts + src_routed = src.routed_experts + assert dst_routed is not None + assert src_routed is not None + assert dst_routed.dtype == src_routed.dtype + assert dst_routed.shape[1:] == src_routed.shape[1:] + dst_routed.data += src_routed.data + dst_routed.shape[0] += src_routed.shape[0] + + def _pad_routed_experts(micro_batch: MicroBatch, padding_size: int) -> None: routed_experts = micro_batch.routed_experts assert routed_experts is not None @@ -57,10 +78,11 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch assert training_example.env_name != "all", "env_name='all' is reserved for aggregate metric keys" env_names = [training_example.env_name] * len(input_ids) - # Per-token temperatures: prompt tokens use first completion temp (masked out anyway) - # Default to 1.0 if completion is empty (e.g., model generated only tool calls with no text) - prompt_temp = training_example.completion_temperatures[0] if training_example.completion_temperatures else 1.0 - temperatures = [prompt_temp] * len(training_example.prompt_ids) + training_example.completion_temperatures + # Per-token temperatures: prompt tokens use the completion temperature + # (masked out anyway). The transport can carry a compact scalar for the + # common constant-temperature case. + prompt_temp, completion_temperatures = _completion_temperatures(training_example) + temperatures = [prompt_temp] * len(training_example.prompt_ids) + completion_temperatures # Teacher logprobs already cover the full sequence (prompt + completion), # computed via prefill in the orchestrator when a teacher model is configured diff --git a/src/prime_rl/trainer/rl/packer.py b/src/prime_rl/trainer/rl/packer.py index ae0d9863e1..456b29f6ef 100644 --- a/src/prime_rl/trainer/rl/packer.py +++ b/src/prime_rl/trainer/rl/packer.py @@ -176,10 +176,14 @@ def _validate_sample(self, sample: TrainingSample) -> tuple[bool, str | None]: False, f"Run wrote a sample with completion logprobs length != completion ids length ({len(sample.completion_logprobs)} != {len(sample.completion_ids)})", ) - if len(sample.completion_temperatures) != len(sample.completion_ids): + completion_temperatures_len = len(sample.completion_temperatures) + has_compact_temperature = sample.completion_temperature is not None + if completion_temperatures_len != len(sample.completion_ids) and not ( + completion_temperatures_len == 0 and has_compact_temperature + ): return ( False, - f"Run wrote a sample with completion temperatures length != completion ids length ({len(sample.completion_temperatures)} != {len(sample.completion_ids)})", + f"Run wrote a sample with completion temperatures length != completion ids length ({completion_temperatures_len} != {len(sample.completion_ids)})", ) if sample_length == 0: return False, "Run wrote a sample with no tokens" diff --git a/src/prime_rl/transport/__init__.py b/src/prime_rl/transport/__init__.py index 7e49ecbc81..297bc58cc8 100644 --- a/src/prime_rl/transport/__init__.py +++ b/src/prime_rl/transport/__init__.py @@ -7,6 +7,7 @@ FileSystemMicroBatchSender, FileSystemTrainingBatchReceiver, FileSystemTrainingBatchSender, + NoOpTrainingBatchSender, ) from prime_rl.transport.types import ( MicroBatch, @@ -27,6 +28,8 @@ def setup_training_batch_sender(output_dir: Path, transport: TransportConfig) -> return FileSystemTrainingBatchSender(output_dir) elif transport.type == "zmq": return ZMQTrainingBatchSender(output_dir, transport) + elif transport.type == "noop": + return NoOpTrainingBatchSender(output_dir) else: raise ValueError(f"Invalid transport type: {transport.type}") @@ -65,6 +68,7 @@ def setup_micro_batch_receiver( __all__ = [ "FileSystemTrainingBatchSender", "FileSystemTrainingBatchReceiver", + "NoOpTrainingBatchSender", "FileSystemMicroBatchSender", "FileSystemMicroBatchReceiver", "MicroBatchReceiver", diff --git a/src/prime_rl/transport/filesystem.py b/src/prime_rl/transport/filesystem.py index 5158b97539..a82d299dca 100644 --- a/src/prime_rl/transport/filesystem.py +++ b/src/prime_rl/transport/filesystem.py @@ -34,6 +34,13 @@ def _encode_and_write(self, batch: TrainingBatch, step_path: Path) -> None: tmp_path.rename(step_path / BATCH_FILE_NAME) +class NoOpTrainingBatchSender(TrainingBatchSender): + """Debug sender that drops batches after orchestrator-side processing.""" + + async def send(self, batch: TrainingBatch) -> None: + self.logger.info(f"No-op training transport dropping batch step={batch.step} examples={len(batch.examples)}") + + class FileSystemTrainingBatchReceiver(TrainingBatchReceiver): """Filesystem-based training batch receiver that reads batches from multiple run directories.""" diff --git a/src/prime_rl/transport/types.py b/src/prime_rl/transport/types.py index 1666c51092..44eaff69ff 100644 --- a/src/prime_rl/transport/types.py +++ b/src/prime_rl/transport/types.py @@ -36,6 +36,10 @@ class TrainingSample(msgspec.Struct, array_like=True, gc=False, omit_defaults=Tr teacher_logprobs: list[float] | None = None advantage: float | None = None reward: float | None = None + # Compact representation for the common case where all completion tokens + # were sampled with the same temperature. Legacy per-token temperatures + # remain supported for callers that need token-varying values. + completion_temperature: float | None = None # Generic multimodal kwargs: flat dict keyed by the kwarg names the # model's forward expects (e.g. {"pixel_values": ..., "image_grid_thw": diff --git a/src/prime_rl/utils/client.py b/src/prime_rl/utils/client.py index 533f6e2711..2a7376e79e 100644 --- a/src/prime_rl/utils/client.py +++ b/src/prime_rl/utils/client.py @@ -147,6 +147,55 @@ async def stop(self) -> None: pass +class NoOpInferencePool: + """InferencePool for orchestrator debug runs whose envs do not call the model.""" + + def __init__(self, model_name: str = "debug-noop-model"): + self.model_name = model_name + self._client = vf.ClientConfig( + client_idx=0, + client_type="openai_chat_completions", + api_base_url="http://debug-noop-inference/v1", + api_key_var="EMPTY", + timeout=1, + connect_timeout=1, + max_connections=1, + max_keepalive_connections=1, + max_retries=0, + extra_headers={}, + extra_headers_from_state={}, + ) + + @property + def train_clients(self) -> list[vf.ClientConfig]: + return [self._client] + + @property + def admin_clients(self) -> list[AsyncClient]: + return [] + + def update_model_name(self, model_name: str) -> None: + self.model_name = model_name + + async def get_eval_client(self) -> vf.ClientConfig: + return self._client + + async def select_train_client(self, load: Mapping[ClientIdentity, int]) -> vf.ClientConfig: + return self._client + + async def wait_for_ready(self, model_name: str, timeout: int | None = None) -> None: + return None + + async def update_weights(self, weight_dir: Path | None, lora_name: str | None = None, step: int = 0) -> None: + return None + + def get_metrics(self) -> dict[str, float]: + return {} + + async def stop(self) -> None: + return None + + async def setup_inference_pool( client_config: ClientConfig, model_name: str, diff --git a/src/prime_rl/utils/monitor/prime.py b/src/prime_rl/utils/monitor/prime.py index cd1fa8375a..9e89edd082 100644 --- a/src/prime_rl/utils/monitor/prime.py +++ b/src/prime_rl/utils/monitor/prime.py @@ -20,6 +20,7 @@ from prime_rl.configs.shared import PrimeMonitorConfig from prime_rl.utils.logger import get_logger from prime_rl.utils.monitor.base import Monitor, sample_items_for_logging +from prime_rl.utils.monitor.samples import token_payload_length def _json(val: Any) -> str: @@ -354,8 +355,8 @@ def _rollouts_to_parquet_bytes(self, rollouts: list[vf.RolloutOutput], step: int "reward": ts.get("reward"), "advantage": ts.get("advantage"), "extras": ts.get("extras", {}), - "num_input_tokens": len(ts["tokens"]["prompt_ids"]) if ts.get("tokens") else None, - "num_output_tokens": len(ts["tokens"]["completion_ids"]) if ts.get("tokens") else None, + "num_input_tokens": token_payload_length(ts.get("tokens"), "prompt_ids"), + "num_output_tokens": token_payload_length(ts.get("tokens"), "completion_ids"), } for ts in trajectory ] diff --git a/src/prime_rl/utils/monitor/samples.py b/src/prime_rl/utils/monitor/samples.py new file mode 100644 index 0000000000..9f73088e42 --- /dev/null +++ b/src/prime_rl/utils/monitor/samples.py @@ -0,0 +1,46 @@ +import json +from collections.abc import Mapping +from typing import Any + +from transformers.tokenization_utils import PreTrainedTokenizer + +from prime_rl.utils.chat_template import deserialize_tool_calls, normalize_messages + + +def token_payload_length(tokens: Mapping[str, Any] | None, key: str) -> int | None: + """Return a token-array length from either full or compacted raw tokens.""" + if not tokens: + return None + + value = tokens.get(key) + if value is not None: + try: + return len(value) + except TypeError: + pass + + compact_value = tokens.get(f"{key}_len") + return compact_value if isinstance(compact_value, int) else None + + +def token_payload_ids(tokens: Mapping[str, Any] | None) -> list[int] | None: + """Return prompt+completion token IDs when raw arrays are still present.""" + if not tokens: + return None + prompt_ids = tokens.get("prompt_ids") + completion_ids = tokens.get("completion_ids") + if not isinstance(prompt_ids, list) or not isinstance(completion_ids, list): + return None + return prompt_ids + completion_ids + + +def render_prompt_completion_text(tokenizer: PreTrainedTokenizer, prompt: Any, completion: Any) -> str: + """Render prompt/completion messages after raw token IDs have been compacted.""" + messages = normalize_messages(prompt, default_role="user") + messages.extend(normalize_messages(completion, default_role="assistant")) + if not messages: + return "" + try: + return tokenizer.apply_chat_template(deserialize_tool_calls(messages), tokenize=False) + except Exception: + return json.dumps(messages) diff --git a/src/prime_rl/utils/monitor/wandb.py b/src/prime_rl/utils/monitor/wandb.py index 7935d13a6a..1696218f30 100644 --- a/src/prime_rl/utils/monitor/wandb.py +++ b/src/prime_rl/utils/monitor/wandb.py @@ -16,6 +16,7 @@ from prime_rl.utils.config import BaseConfig from prime_rl.utils.logger import get_logger from prime_rl.utils.monitor.base import Monitor, sample_items_for_logging +from prime_rl.utils.monitor.samples import render_prompt_completion_text, token_payload_ids class WandbMonitor(Monitor): @@ -176,9 +177,15 @@ def log_samples(self, rollouts: list[vf.RolloutOutput], step: int) -> None: if not trajectory: continue last_step = trajectory[-1] - tokens = last_step["tokens"] - full_ids = tokens["prompt_ids"] + tokens["completion_ids"] - messages_text = self.tokenizer.decode(full_ids) + full_ids = token_payload_ids(last_step.get("tokens")) or [] + if full_ids: + messages_text = self.tokenizer.decode(full_ids) + else: + messages_text = render_prompt_completion_text( + self.tokenizer, + last_step.get("prompt") or rollout.get("prompt"), + last_step.get("completion") or rollout.get("completion"), + ) sample = { "step": step, "env_name": rollout.get("env_name"), diff --git a/tests/debug_envs/fake_r3_trajectory/fake_r3_trajectory.py b/tests/debug_envs/fake_r3_trajectory/fake_r3_trajectory.py new file mode 100644 index 0000000000..a627ea1c46 --- /dev/null +++ b/tests/debug_envs/fake_r3_trajectory/fake_r3_trajectory.py @@ -0,0 +1,315 @@ +from __future__ import annotations + +import time +from dataclasses import dataclass + +import numpy as np +import pybase64 +import verifiers as vf +from datasets import Dataset +from verifiers.clients import Client +from verifiers.types import RolloutInput, SamplingArgs, State, TimeSpan, TrajectoryStep + + +@dataclass(frozen=True) +class TokenPlan: + prompt_len: int + completion_lens: list[int] + env_delta_lens: list[int] + + +def _constant_reward(**kwargs) -> float: + return 1.0 + + +def _spread(total: int, parts: int) -> list[int]: + if parts <= 0: + return [] + q, r = divmod(total, parts) + return [q + (1 if i < r else 0) for i in range(parts)] + + +def _token_plan(seq_len: int, turns: int, prompt_len: int, completion_fraction: float) -> TokenPlan: + if turns < 1: + raise ValueError("turns must be >= 1") + if seq_len <= prompt_len + turns: + raise ValueError("seq_len is too small for the requested turns and prompt length") + body = seq_len - prompt_len + completion_total = max(turns, int(body * completion_fraction)) + completion_total = min(completion_total, body) + env_total = body - completion_total + return TokenPlan( + prompt_len=prompt_len, + completion_lens=_spread(completion_total, turns), + env_delta_lens=_spread(env_total, max(turns - 1, 0)), + ) + + +def _ids(start: int, count: int, *, vocab_size: int) -> list[int]: + # Keep IDs below the default gibberish threshold while still looking nontrivial. + usable = min(vocab_size - 100, 90_000) + return [100 + ((start + i) % usable) for i in range(count)] + + +def _routed_payload(length: int, layers: int, topk: int, n_experts: int, *, start: int, salt: int): + if length <= 0: + return None + pattern = np.arange(layers * topk, dtype=np.uint16).reshape(layers, topk) + pattern = ((pattern + salt) % min(n_experts, 256)).astype(np.uint8) + routed = np.empty((length, layers, topk), dtype=np.uint8) + routed[:] = pattern + routed += np.arange(length, dtype=np.uint8).reshape(length, 1, 1) + return { + "data": pybase64.b64encode(memoryview(np.ascontiguousarray(routed))).decode("ascii"), + "shape": [length, layers, topk], + "start": start, + } + + +class FakeR3TrajectoryEnv(vf.Environment): + def __init__( + self, + *, + turns: int = 30, + seq_len: int = 30_000, + prompt_len: int = 128, + completion_fraction: float = 0.70, + include_r3: bool = True, + # Defaults mirror the GLM-5 routed-expert replay shape. + routed_layers: int = 78, + routed_topk: int = 8, + n_routed_experts: int = 256, + num_examples: int = 16, + vocab_size: int = 154_880, + **kwargs, + ) -> None: + self.turns = turns + self.seq_len = seq_len + self.prompt_len = prompt_len + self.completion_fraction = completion_fraction + self.include_r3 = include_r3 + self.routed_layers = routed_layers + self.routed_topk = routed_topk + self.n_routed_experts = n_routed_experts + self.vocab_size = vocab_size + self.plan = _token_plan(seq_len, turns, prompt_len, completion_fraction) + + rows = [ + { + "question": f"Return the deterministic fake R3 trajectory for example {i}.", + "answer": "ok", + "info": { + "fake_r3": True, + "seq_len": seq_len, + "turns": turns, + "include_r3": include_r3, + "routed_layers": routed_layers, + "routed_topk": routed_topk, + "n_routed_experts": n_routed_experts, + "glm5_num_hidden_layers": 78, + "glm5_sparse_moe_layers": 75, + }, + } + for i in range(num_examples) + ] + + tool_defs = [ + vf.Tool( + name="lookup_record", + description="Look up a deterministic debug record.", + parameters={ + "type": "object", + "properties": {"key": {"type": "string"}}, + "required": ["key"], + }, + ), + vf.Tool( + name="write_note", + description="Write a deterministic debug note.", + parameters={ + "type": "object", + "properties": {"note": {"type": "string"}}, + "required": ["note"], + }, + ), + ] + + super().__init__( + dataset=Dataset.from_list(rows), + rubric=vf.Rubric(funcs=[_constant_reward]), + tool_defs=tool_defs, + score_rollouts=True, + **kwargs, + ) + + async def rollout( + self, + input: RolloutInput, + client: Client, + model: str, + sampling_args: SamplingArgs | None = None, + ) -> State: + state = await self.init_state(input, client, model, sampling_args) + timing = state["timing"] + start_time = time.time() + timing.generation.start = start_time + timing.setup.start = start_time + timing.setup.end = start_time + + trajectory, completion, final_input_tokens, final_output_tokens = self._build_trajectory( + state["trajectory_id"], + model=model, + example_id=int(state["example_id"]), + ) + state["trajectory"] = trajectory + state["completion"] = completion + state["is_completed"] = True + state["is_truncated"] = False + state["stop_condition"] = "fake_r3_complete" + state["token_usage"] = { + "input_tokens": float(sum(len(step["tokens"]["prompt_ids"]) for step in trajectory)), + "output_tokens": float(sum(len(step["tokens"]["completion_ids"]) for step in trajectory)), + "final_input_tokens": float(final_input_tokens), + "final_output_tokens": float(final_output_tokens), + } + end_time = time.time() + timing.generation.end = end_time + timing.model.spans.append(TimeSpan(start=start_time, end=end_time)) + return state + + def _build_trajectory(self, trajectory_id: str, *, model: str, example_id: int): + prompt_tokens = _ids(example_id * 17, self.plan.prompt_len, vocab_size=self.vocab_size) + current_prompt_ids = list(prompt_tokens) + current_messages = [ + { + "role": "system", + "content": "You are executing a deterministic fake R3 memory rollout.", + }, + { + "role": "user", + "content": f"Example {example_id}: produce the fixed heavy trajectory.", + }, + ] + + trajectory: list[TrajectoryStep] = [] + completion_messages: list[dict] = [] + token_cursor = self.plan.prompt_len + + for turn_idx, completion_len in enumerate(self.plan.completion_lens): + completion_ids = _ids(token_cursor, completion_len, vocab_size=self.vocab_size) + token_cursor += completion_len + + has_tool_call = turn_idx % 3 != 1 + tool_call_id = f"call_fake_{example_id}_{turn_idx}" + assistant_msg = { + "role": "assistant", + "content": f"Turn {turn_idx}: deterministic reasoning chunk.", + "reasoning_content": ( + f"Reasoning turn {turn_idx}: inspect synthetic state, decide whether to call a tool, " + "then continue the fixed replay." + ), + } + if has_tool_call: + assistant_msg["tool_calls"] = [ + { + "id": tool_call_id, + "name": "lookup_record" if turn_idx % 2 == 0 else "write_note", + "arguments": f'{{"key": "example-{example_id}-turn-{turn_idx}"}}', + } + ] + + routed = None + routed_start = 0 + if self.include_r3: + if turn_idx == 0: + routed_len = len(current_prompt_ids) + completion_len - 1 + else: + prefix_len = len(trajectory[-1]["tokens"]["prompt_ids"]) + len( + trajectory[-1]["tokens"]["completion_ids"] + ) + routed_start = prefix_len - 1 + routed_len = len(current_prompt_ids) + completion_len - prefix_len + routed = _routed_payload( + routed_len, + self.routed_layers, + self.routed_topk, + self.n_routed_experts, + start=routed_start, + salt=example_id + turn_idx, + ) + + tokens = { + "prompt_ids": list(current_prompt_ids), + "prompt_mask": [0] * len(current_prompt_ids), + "completion_ids": completion_ids, + "completion_mask": [1] * completion_len, + "completion_logprobs": [-0.25] * completion_len, + "overlong_prompt": False, + "is_truncated": False, + "routed_experts": routed, + } + + response = vf.Response( + id=f"fake-r3-{example_id}-{turn_idx}", + created=int(time.time()), + model=model, + usage=vf.Usage( + prompt_tokens=len(current_prompt_ids), + reasoning_tokens=completion_len // 3, + completion_tokens=completion_len, + total_tokens=len(current_prompt_ids) + completion_len, + ), + message=vf.ResponseMessage( + role="assistant", + content=assistant_msg["content"], + reasoning_content=assistant_msg["reasoning_content"], + tool_calls=assistant_msg.get("tool_calls"), + finish_reason="tool_calls" if has_tool_call else "stop", + is_truncated=False, + tokens=None, + ), + ) + + trajectory.append( + TrajectoryStep( + prompt=list(current_messages), + completion=[assistant_msg], + response=response, + tokens=tokens, + reward=None, + advantage=None, + is_truncated=False, + trajectory_id=trajectory_id, + extras={"turn_idx": turn_idx, "fake_r3": True}, + ) + ) + completion_messages.append(assistant_msg) + current_messages = [*current_messages, assistant_msg] + current_prompt_ids = [*current_prompt_ids, *completion_ids] + + if turn_idx < len(self.plan.env_delta_lens): + env_len = self.plan.env_delta_lens[turn_idx] + env_ids = _ids(token_cursor, env_len, vocab_size=self.vocab_size) + token_cursor += env_len + tool_msg = { + "role": "tool", + "tool_call_id": tool_call_id, + "content": f"Tool result for turn {turn_idx}: " + ("debug payload " * 8), + } + current_messages = [*current_messages, tool_msg] + completion_messages.append(tool_msg) + current_prompt_ids = [*current_prompt_ids, *env_ids] + + final_input_tokens = len(trajectory[-1]["tokens"]["prompt_ids"]) + final_output_tokens = len(trajectory[-1]["tokens"]["completion_ids"]) + assert len(prompt_tokens) + sum(self.plan.completion_lens) + sum(self.plan.env_delta_lens) == self.seq_len + return ( + trajectory, + completion_messages, + final_input_tokens, + final_output_tokens, + ) + + +def load_environment(**kwargs) -> vf.Environment: + return FakeR3TrajectoryEnv(**kwargs) diff --git a/tests/debug_envs/fake_r3_trajectory/pyproject.toml b/tests/debug_envs/fake_r3_trajectory/pyproject.toml new file mode 100644 index 0000000000..5821a0ba64 --- /dev/null +++ b/tests/debug_envs/fake_r3_trajectory/pyproject.toml @@ -0,0 +1,18 @@ +[project] +name = "fake-r3-trajectory" +version = "0.1.0" +description = "Deterministic heavy trajectory environment for PRIME-RL orchestrator memory debugging." +requires-python = ">=3.10" +dependencies = [ + "datasets", + "numpy", + "pybase64", + "verifiers", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build] +include = ["fake_r3_trajectory.py"] diff --git a/tests/unit/orchestrator/test_advantage.py b/tests/unit/orchestrator/test_advantage.py index 0f5fd9bf92..7666173f13 100644 --- a/tests/unit/orchestrator/test_advantage.py +++ b/tests/unit/orchestrator/test_advantage.py @@ -17,6 +17,7 @@ ) from prime_rl.orchestrator.envs import Env, TrainEnv from prime_rl.orchestrator.types import TrainRollout +from prime_rl.orchestrator.utils import get_model_completion_len def _make_rollout(reward: float, completion_len: int = 0) -> dict: @@ -183,6 +184,30 @@ def test_length_weighted_baseline(): assert sum(length * adv for length, adv in zip((10, 30, 60), result.advantages)) == pytest.approx(0.0, abs=1e-5) +def test_length_penalty_uses_compacted_completion_lengths(): + """``get_model_completion_len`` reads compacted ``completion_ids_len`` summaries + (raw token arrays pruned post-interleave) so length shaping still sees real lengths.""" + rollout = _make_rollout(reward=1.0, completion_len=0) + rollout["trajectory"] = [ + {"tokens": {"prompt_ids_len": 10, "completion_ids_len": 12}}, + {"tokens": {"prompt_ids_len": 20, "completion_ids_len": 34}}, + ] + + assert get_model_completion_len(rollout) == 46 + + inputs = AdvantageInputs( + rollouts=[ + rollout, + _make_rollout(reward=1.0, completion_len=92), + _make_rollout(reward=0.0, completion_len=46), + ] + ) + result = default_advantage_fn(inputs, length_penalty=LinearLengthPenaltyConfig(coef=1.0), max_seq_len=100) + + # Both reward-1 rollouts, but the compacted one is shorter (46 < 92) → higher advantage. + assert result.advantages[0] > result.advantages[1] + + def _train_rollouts(rewards: list[float]) -> list[TrainRollout]: """Wrap a list of rewards into ``TrainRollout``\\ s sharing a single ``group_id`` — ``assign_advantages`` works on one group at a time diff --git a/tests/unit/orchestrator/test_batch.py b/tests/unit/orchestrator/test_batch.py index bbe0e61724..19cc86cee0 100644 --- a/tests/unit/orchestrator/test_batch.py +++ b/tests/unit/orchestrator/test_batch.py @@ -147,6 +147,24 @@ def test_prepare_batch_packs_different_temperatures(make_training_example): assert flat_batches[0].sequence_lengths == [4, 4] +def test_prepare_sample_expands_compact_completion_temperature(): + sample = TrainingSample( + prompt_ids=[1, 2], + prompt_mask=[False, False], + completion_ids=[3, 4], + completion_mask=[True, True], + completion_logprobs=[-0.1, -0.2], + completion_temperatures=[], + completion_temperature=0.7, + advantage=1.0, + env_name="test-env", + ) + + micro_batch = prepare_sample(sample, seq_len=8) + + assert micro_batch.temperatures == [0.7, 0.7, 0.7, 0.7] + + def test_pad_micro_batch_preserves_explicit_sequence_lengths(make_training_example): micro_batch = prepare_sample(make_training_example(), seq_len=16) diff --git a/tests/unit/orchestrator/test_filters.py b/tests/unit/orchestrator/test_filters.py index 2643bf71bb..5b8422448f 100644 --- a/tests/unit/orchestrator/test_filters.py +++ b/tests/unit/orchestrator/test_filters.py @@ -10,6 +10,7 @@ setup_filters, ) from prime_rl.orchestrator.types import TrainRollout +from prime_rl.transport import TrainingSample def _make_rollout( @@ -134,6 +135,32 @@ def test_gibberish_works_across_trajectory_steps(): assert result.detection_index == 2 +def test_gibberish_uses_samples_when_raw_tokens_are_pruned(): + gibberish_filter = _make_gibberish_filter() + rollout = _make_rollout(completion_ids=[], completion_logprobs=[]) + rollout.raw["trajectory"][0]["tokens"] = { + "prompt_ids_len": 4, + "completion_ids_len": 2, + "has_routed_experts": True, + } + rollout.samples = [ + TrainingSample( + prompt_ids=[1, 2], + prompt_mask=[False, False], + completion_ids=[777, 120_000], + completion_mask=[False, True], + completion_logprobs=[0.0, gibberish_filter.logprob_threshold - 1.0], + completion_temperatures=[1.0, 1.0], + env_name="test", + ) + ] + + result = gibberish_filter.check(rollout) + + assert result.detected is True + assert result.detection_index == 0 + + # --- RepetitionFilter tests --- @@ -187,6 +214,32 @@ def test_repetition_varied_probs_no_trigger(): assert result.detected is False +def test_repetition_uses_samples_when_raw_tokens_are_pruned(): + repetition_filter = _make_repetition_filter(window=3) + rollout = _make_rollout(completion_ids=[], completion_logprobs=[]) + rollout.raw["trajectory"][0]["tokens"] = { + "prompt_ids_len": 4, + "completion_ids_len": 5, + "has_routed_experts": True, + } + rollout.samples = [ + TrainingSample( + prompt_ids=[1, 2], + prompt_mask=[False, False], + completion_ids=[10, 11, 12, 13], + completion_mask=[False, True, True, True], + completion_logprobs=[0.0, -0.001, -0.001, -0.001], + completion_temperatures=[1.0, 1.0, 1.0, 1.0], + env_name="test", + ) + ] + + result = repetition_filter.check(rollout) + + assert result.detected is True + assert result.detection_index == 2 + + # --- setup_filter / setup_filters tests --- diff --git a/tests/unit/orchestrator/test_orchestrator_setup.py b/tests/unit/orchestrator/test_orchestrator_setup.py index 2372b004fd..12f87117c6 100644 --- a/tests/unit/orchestrator/test_orchestrator_setup.py +++ b/tests/unit/orchestrator/test_orchestrator_setup.py @@ -19,6 +19,7 @@ async def run() -> None: ), renderer=renderer_settings, pool_size=None, + debug=SimpleNamespace(no_inference=False), ) renderer = object() inference_pool = object() @@ -63,6 +64,7 @@ async def run() -> None: client=SimpleNamespace(base_url=["http://localhost:8000/v1"]), model=SimpleNamespace(name="student-model"), ), + debug=SimpleNamespace(no_inference=False), ) inference_pool = object() diff --git a/tests/unit/orchestrator/test_train_sink_memory.py b/tests/unit/orchestrator/test_train_sink_memory.py new file mode 100644 index 0000000000..5957e4b95c --- /dev/null +++ b/tests/unit/orchestrator/test_train_sink_memory.py @@ -0,0 +1,204 @@ +import base64 +import uuid +from pathlib import Path +from types import SimpleNamespace + +import pytest + +from prime_rl.orchestrator.train_sink import TrainSink +from prime_rl.orchestrator.trajectories import interleave_rollout, prune_train_rollout_payload +from prime_rl.orchestrator.types import TrainRollout +from prime_rl.transport import TrainingSample + + +def test_prune_train_rollout_payload_compacts_heavy_token_arrays(): + raw = { + "trajectory": [ + { + "prompt": [{"role": "user", "content": "hello"}], + "completion": [{"role": "assistant", "content": "world"}], + "response": object(), + "tokens": { + "prompt_ids": [1, 2, 3], + "prompt_mask": [0, 0, 0], + "completion_ids": [4, 5], + "completion_mask": [1, 1], + "completion_logprobs": [-0.1, -0.2], + "routed_experts": { + "data": "large-base64-payload", + "shape": [5, 78, 8], + "dtype": "uint8", + "start": 0, + }, + "overlong_prompt": False, + "is_truncated": False, + }, + } + ] + } + + prune_train_rollout_payload(raw) + + step = raw["trajectory"][0] + assert "response" not in step + assert raw["trajectory_payload_pruned"] is True + assert step["tokens"] == { + "prompt_ids_len": 3, + "prompt_mask_len": 3, + "completion_ids_len": 2, + "completion_mask_len": 2, + "completion_logprobs_len": 2, + "has_routed_experts": True, + "routed_experts_shape": [5, 78, 8], + "routed_experts_dtype": "uint8", + "routed_experts_start": 0, + "overlong_prompt": False, + "is_truncated": False, + } + + +def test_interleave_rollout_prunes_raw_payload_during_preparation(): + raw = { + "example_id": 0, + "error": None, + "trajectory": [ + { + "tokens": { + "prompt_ids": [1, 2], + "prompt_mask": [0, 0], + "completion_ids": [3], + "completion_mask": [1], + "completion_logprobs": [-0.1], + "routed_experts": { + "data": base64.b64encode(bytes([7, 8])).decode(), + "shape": [2, 1, 1], + "dtype": "uint8", + "start": 0, + }, + "overlong_prompt": False, + "is_truncated": False, + }, + "response": object(), + } + ], + } + + samples = interleave_rollout(raw, env_name="test-env", prune_raw_payload=True) + + assert samples is not None + assert samples[0].routed_experts is not None + assert samples[0].routed_experts.shape == [3, 1, 1] + assert raw["trajectory_payload_pruned"] is True + step = raw["trajectory"][0] + assert "response" not in step + assert step["tokens"]["has_routed_experts"] is True + assert step["tokens"]["routed_experts_shape"] == [2, 1, 1] + assert "routed_experts" not in step["tokens"] + + +def test_interleave_rollout_does_not_partially_prune_on_prepare_failure(): + raw = { + "example_id": 0, + "error": None, + "trajectory": [ + { + "tokens": { + "prompt_ids": [1, 2], + "prompt_mask": [0, 0], + "completion_ids": [3], + "completion_mask": [1], + "completion_logprobs": [-0.1], + "routed_experts": { + "data": base64.b64encode(bytes([7, 8])).decode(), + "shape": [2, 1, 1], + "dtype": "uint8", + "start": 0, + }, + }, + "response": object(), + }, + {"tokens": None}, + ], + } + + samples = interleave_rollout(raw, env_name="test-env", prune_raw_payload=True) + + assert samples is None + assert "trajectory_payload_pruned" not in raw + first_step = raw["trajectory"][0] + assert "response" in first_step + assert "routed_experts" in first_step["tokens"] + assert "completion_ids" in first_step["tokens"] + + +@pytest.mark.asyncio +async def test_process_rollout_prunes_raw_payload_immediately(monkeypatch: pytest.MonkeyPatch, tmp_path: Path): + raw = { + "error": None, + "trajectory": [ + { + "tokens": { + "prompt_ids": [1, 2, 3], + "prompt_mask": [0, 0, 0], + "completion_ids": [4, 5], + "completion_mask": [1, 1], + "completion_logprobs": [-0.1, -0.2], + "routed_experts": { + "data": "large-base64-payload", + "shape": [5, 78, 8], + "dtype": "uint8", + "start": 0, + }, + }, + } + ], + } + sample = TrainingSample( + prompt_ids=[1, 2, 3], + prompt_mask=[False, False, False], + completion_ids=[4, 5], + completion_mask=[True, True], + completion_logprobs=[-0.1, -0.2], + completion_temperatures=[], + env_name="test-env", + ) + + def fake_interleave_rollout(output, *args, prune_raw_payload: bool, **kwargs): + assert prune_raw_payload is True + prune_train_rollout_payload(output) + return [sample] + + monkeypatch.setattr("prime_rl.orchestrator.train_sink.interleave_rollout", fake_interleave_rollout) + monkeypatch.setattr("prime_rl.orchestrator.train_sink.offload_images_to_disk", lambda *args, **kwargs: 0) + + sink = TrainSink( + SimpleNamespace( + output_dir=tmp_path, + training_mode="rl", + ), + tokenizer=None, + renderer=None, + train_envs=SimpleNamespace(), + mm_token_type_ids_mapping=None, + batch_size=1, + token_batch_size=None, + pre_filters=[], + post_filters=[], + ) + rollout = TrainRollout( + raw=raw, + env_name="test-env", + example_id=0, + group_id=uuid.uuid4(), + policy_version=0, + off_policy_steps=0, + ) + + await sink.process_rollout(rollout) + + assert rollout.samples == [sample] + assert raw["trajectory_payload_pruned"] is True + tokens = raw["trajectory"][0]["tokens"] + assert "routed_experts" not in tokens + assert tokens["has_routed_experts"] is True + assert tokens["routed_experts_shape"] == [5, 78, 8] diff --git a/tests/unit/utils/test_monitor_samples.py b/tests/unit/utils/test_monitor_samples.py new file mode 100644 index 0000000000..038e395870 --- /dev/null +++ b/tests/unit/utils/test_monitor_samples.py @@ -0,0 +1,30 @@ +from prime_rl.utils.monitor.samples import render_prompt_completion_text, token_payload_ids, token_payload_length + + +class FakeTokenizer: + def apply_chat_template(self, messages, tokenize=False): + assert tokenize is False + return "\n".join(f"{message['role']}: {message['content']}" for message in messages) + + +def test_token_payload_length_supports_full_and_compacted_tokens(): + assert token_payload_length({"prompt_ids": [1, 2, 3]}, "prompt_ids") == 3 + assert token_payload_length({"prompt_ids_len": 12}, "prompt_ids") == 12 + assert token_payload_length({"prompt_ids_len": "12"}, "prompt_ids") is None + assert token_payload_length(None, "prompt_ids") is None + + +def test_token_payload_ids_returns_full_prompt_and_completion_ids_only(): + assert token_payload_ids({"prompt_ids": [1, 2], "completion_ids": [3, 4]}) == [1, 2, 3, 4] + assert token_payload_ids({"prompt_ids_len": 2, "completion_ids_len": 2}) is None + assert token_payload_ids(None) is None + + +def test_render_prompt_completion_text_uses_chat_template_after_tokens_are_compacted(): + text = render_prompt_completion_text( + FakeTokenizer(), + [{"role": "user", "content": "question"}], + [{"role": "assistant", "content": "answer"}], + ) + + assert text == "user: question\nassistant: answer" diff --git a/tests/unit/utils/test_prime_monitor.py b/tests/unit/utils/test_prime_monitor.py index f44065a7c6..89c4679cc9 100644 --- a/tests/unit/utils/test_prime_monitor.py +++ b/tests/unit/utils/test_prime_monitor.py @@ -94,6 +94,28 @@ def test_rollouts_to_parquet_bytes_skips_rollouts_without_trajectory(): assert rows[0]["sample_id"] == 0 +def test_rollouts_to_parquet_bytes_uses_compacted_token_lengths(): + monitor = _new_monitor() + monitor.run_id = "run-compact" + rollout = _build_rollout(example_id=303, reward=1.0, task="task-c") + rollout["trajectory"][0]["tokens"] = { + "prompt_ids_len": 12, + "completion_ids_len": 34, + "has_routed_experts": True, + } + + parquet_bytes = monitor._rollouts_to_parquet_bytes([rollout], step=9) + + assert parquet_bytes is not None + + table = pq.read_table(io.BytesIO(parquet_bytes)) + rows = table.to_pylist() + trajectory = json.loads(rows[0]["trajectory"]) + + assert trajectory[0]["num_input_tokens"] == 12 + assert trajectory[0]["num_output_tokens"] == 34 + + def test_sanitize_json_payload_drops_non_finite_values_and_logs_paths(): monitor = _new_monitor() monitor.logger = Mock() diff --git a/uv.lock b/uv.lock index dfd6cc1377..372c8cf68e 100644 --- a/uv.lock +++ b/uv.lock @@ -1054,6 +1054,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c1/ea/53f2148663b321f21b5a606bd5f191517cf40b7072c0497d3c92c4a13b1e/executing-2.2.1-py2.py3-none-any.whl", hash = "sha256:760643d3452b4d777d295bb167ccc74c64a81df23fb5e08eff250c425a4b2017", size = 28317, upload-time = "2025-09-01T09:48:08.5Z" }, ] +[[package]] +name = "fake-r3-trajectory" +version = "0.1.0" +source = { editable = "tests/debug_envs/fake_r3_trajectory" } +dependencies = [ + { name = "datasets", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "numpy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "pybase64", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "verifiers", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, +] + +[package.metadata] +requires-dist = [ + { name = "datasets" }, + { name = "numpy" }, + { name = "pybase64" }, + { name = "verifiers" }, +] + [[package]] name = "farama-notifications" version = "0.0.6" @@ -3997,6 +4016,7 @@ envs = [ { name = "code-env", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "color-codeword", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "deepdive", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "fake-r3-trajectory", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "general-agent", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "gpqa", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "hle", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -4069,6 +4089,7 @@ requires-dist = [ { name = "deep-gemm", marker = "platform_machine == 'x86_64' and extra == 'disagg'", url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/deep_gemm-2.5.0+891d57b-cp312-cp312-linux_x86_64.whl" }, { name = "deepdive", marker = "extra == 'envs'", editable = "deps/research-environments/environments/deepdive" }, { name = "dion", git = "https://github.com/samsja/dion.git?rev=d891eeb" }, + { name = "fake-r3-trajectory", marker = "extra == 'envs'", editable = "tests/debug_envs/fake_r3_trajectory" }, { name = "flash-attn", marker = "platform_machine == 'x86_64' and extra == 'flash-attn'", url = "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.9.4/flash_attn-2.8.3+cu128torch2.11-cp312-cp312-linux_x86_64.whl" }, { name = "flash-attn-3", marker = "extra == 'flash-attn-3'", index = "https://download.pytorch.org/whl/test/cu128" }, { name = "flash-attn-4", marker = "extra == 'flash-attn-cute'", git = "https://github.com/Dao-AILab/flash-attention.git?subdirectory=flash_attn%2Fcute&rev=96bd151" },