Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
FileSystemTransportConfig,
HeartbeatConfig,
LogConfig,
NoOpTransportConfig,
PrimeMonitorConfig,
TransportConfig,
WandbWithExtrasConfig,
Expand Down Expand Up @@ -489,6 +490,20 @@ class OrchestratorExperimentalConfig(BaseConfig):
pass


# Debug-only switches for the orchestrator. Every flag defaults to False (off).
# Each field is immediately followed by its description string.
class OrchestratorDebugConfig(BaseConfig):
# When set, the training-mode, pool-size, VLM and renderer-auto validators return early,
# and resolve_debug clears `renderer` and `collect_inference_metrics`.
no_inference: bool = False
"""Use an in-process no-op inference pool. Rollout environments must not call the model client."""

# When set, resolve_debug replaces `rollout_transport` with NoOpTransportConfig().
no_trainer: bool = False
"""Run without a trainer. Uses no-op batch sending and advances policy version locally."""

# resolve_debug raises ValueError if this is set while `wandb` or `prime_monitor` is configured.
fake_tokenizer: bool = False
"""Use a tiny tokenizer stub. Only safe when rollouts already include tokens and sample monitors are disabled."""

log_memory: bool = False
"""Log orchestrator and child-process RSS after each completed orchestrator step."""


class RolloutModelConfig(BaseConfig):
model: ModelConfig = ModelConfig()

Expand Down Expand Up @@ -624,6 +639,8 @@ def _preserve_mito_renderer(self, handler: SerializerFunctionWrapHandler) -> dic
env_install_prerelease: bool = False
"""Allow pre-release versions when installing environments (e.g. ``verifiers>=0.1.12.dev5``). Passes ``--prerelease`` to ``prime env install``."""

debug: OrchestratorDebugConfig = OrchestratorDebugConfig()

experimental: OrchestratorExperimentalConfig = OrchestratorExperimentalConfig()

@model_validator(mode="before")
Expand Down Expand Up @@ -760,6 +777,8 @@ def _force_no_renderer_for_sft(self):
@model_validator(mode="after")
def validate_training_mode(self):
"""Enforce training mode invariants that involve only orchestrator fields."""
if self.debug.no_inference:
return self
has_teacher = self.teacher is not None
if self.training_mode == "rl" and has_teacher:
raise ValueError("orchestrator.teacher must not be set when training_mode = 'rl'.")
Expand All @@ -772,6 +791,8 @@ def validate_pool_size(self):
"""``pool_size`` is only meaningful when the renderer is enabled
(``renderer is not None``). Reject otherwise so callers don't
silently pass it and wonder why it's ignored."""
if self.debug.no_inference:
return self
if self.renderer is None and self.pool_size is not None:
raise ValueError(
f"orchestrator.pool_size={self.pool_size!r} is set but "
Expand All @@ -788,6 +809,8 @@ def vlm_requires_renderer(self):
tokens, and ships generic ``mm_kwargs`` keyed by whatever the
model's forward signature expects.
"""
if self.debug.no_inference:
return self
if self.student.model.vlm is not None and self.renderer is None:
raise ValueError(
"orchestrator.renderer must be set when model.vlm is set. "
Expand All @@ -810,6 +833,8 @@ def validate_renderer_auto_resolves(self):
"""
if self.renderer is None or self.renderer.name != "auto":
return self
if self.debug.no_inference:
return self
from renderers.base import MODEL_RENDERER_MAP

model_id = self.tokenizer.name or self.student.model.name
Expand Down Expand Up @@ -898,6 +923,17 @@ def auto_setup_bench(self):

return self

@model_validator(mode="after")
def resolve_debug(self):
    """Apply the ``debug`` switches to the rest of the orchestrator config.

    * ``no_trainer`` replaces ``rollout_transport`` with a no-op transport.
    * ``fake_tokenizer`` is rejected while ``wandb`` or ``prime_monitor`` is configured.
    * ``no_inference`` clears ``renderer`` and turns ``collect_inference_metrics`` off.

    Returns:
        This config instance, mutated in place where a switch applies.

    Raises:
        ValueError: If ``fake_tokenizer`` is set while a monitor is configured.
    """
    flags = self.debug

    if flags.no_trainer:
        self.rollout_transport = NoOpTransportConfig()

    if flags.fake_tokenizer:
        # The check runs after the transport swap, matching the original statement order.
        monitors_disabled = self.wandb is None and self.prime_monitor is None
        if not monitors_disabled:
            raise ValueError("orchestrator.debug.fake_tokenizer requires wandb and prime_monitor to be disabled.")

    if flags.no_inference:
        self.renderer = None
        self.collect_inference_metrics = False

    return self

@model_validator(mode="after")
def resolve_env_config(self):
"""Populate extra_env_kwargs and vLLM sampling defaults from top-level fields."""
Expand Down
10 changes: 9 additions & 1 deletion packages/prime-rl-configs/src/prime_rl/configs/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,4 +261,12 @@ class ZMQTransportConfig(BaseTransportConfig):
"""High-water mark (max in-flight messages per ZMQ socket)."""


TransportConfig: TypeAlias = Annotated[FileSystemTransportConfig | ZMQTransportConfig, Field(discriminator="type")]
# Transport variant tagged by `type == "noop"`; it is one member of the
# `TransportConfig` union, which discriminates on the `type` field.
class NoOpTransportConfig(BaseTransportConfig):
# Discriminator value for this variant; the only accepted value is "noop".
type: Literal["noop"] = "noop"
"""Drop training batches after orchestrator-side processing. Intended for orchestrator debug runs without a trainer."""


TransportConfig: TypeAlias = Annotated[
FileSystemTransportConfig | ZMQTransportConfig | NoOpTransportConfig,
Field(discriminator="type"),
]
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ envs = [
"code-env",
"color-codeword",
"deepdive",
"fake-r3-trajectory",
"general-agent",
"gpqa",
"hle",
Expand Down Expand Up @@ -209,6 +210,7 @@ alphabet-sort = { path = "deps/verifiers/environments/alphabet_sort", editable =
code-env = { path = "deps/research-environments/environments/code_env", editable = true }
color-codeword = { path = "deps/research-environments/environments/color_codeword", editable = true }
deepdive = { path = "deps/research-environments/environments/deepdive", editable = true }
fake-r3-trajectory = { path = "tests/debug_envs/fake_r3_trajectory", editable = true }
general-agent = { path = "deps/research-environments/environments/general_agent", editable = true }
gpqa = { path = "deps/research-environments/environments/gpqa", editable = true }
harnesses = { path = "deps/verifiers/packages/harnesses", editable = true }
Expand Down
143 changes: 80 additions & 63 deletions src/prime_rl/entrypoints/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,22 +83,32 @@ def rl_local(config: RLConfig):
logger.success("Dry run complete. To start an RL run locally, remove --dry-run from your command.")
return

debug_no_inference = config.orchestrator.debug.no_inference
debug_no_trainer = config.orchestrator.debug.no_trainer

# Derive launcher-local GPU IDs from deployment config
gpu_offset = 0
num_infer_gpus = config.deployment.num_infer_gpus if config.inference is not None else 0
num_infer_gpus = (
0 if debug_no_inference else (config.deployment.num_infer_gpus if config.inference is not None else 0)
)
infer_local_gpu_ids = list(range(gpu_offset, gpu_offset + num_infer_gpus))
gpu_offset += num_infer_gpus
trainer_local_gpu_ids = list(range(gpu_offset, gpu_offset + config.deployment.num_train_gpus))

total_requested_gpus = num_infer_gpus + config.deployment.num_train_gpus
physical_gpu_ids = get_physical_gpu_ids()
if total_requested_gpus > len(physical_gpu_ids):
raise ValueError(
f"Requested {total_requested_gpus} GPUs via deployment settings, but only "
f"{len(physical_gpu_ids)} physical GPU(s) are available: {physical_gpu_ids}"
)
physical_gpu_mapping = {local_id: physical_gpu_ids[local_id] for local_id in range(total_requested_gpus)}
logger.info(f"Using local->physical GPU mapping: {physical_gpu_mapping}")
num_train_gpus = 0 if debug_no_trainer else config.deployment.num_train_gpus
trainer_local_gpu_ids = list(range(gpu_offset, gpu_offset + num_train_gpus))

total_requested_gpus = num_infer_gpus + num_train_gpus
if total_requested_gpus > 0:
physical_gpu_ids = get_physical_gpu_ids()
if total_requested_gpus > len(physical_gpu_ids):
raise ValueError(
f"Requested {total_requested_gpus} GPUs via deployment settings, but only "
f"{len(physical_gpu_ids)} physical GPU(s) are available: {physical_gpu_ids}"
)
physical_gpu_mapping = {local_id: physical_gpu_ids[local_id] for local_id in range(total_requested_gpus)}
logger.info(f"Using local->physical GPU mapping: {physical_gpu_mapping}")
else:
physical_gpu_mapping = {}
logger.info("No GPUs requested for orchestrator debug run")

infer_gpu_ids = [physical_gpu_mapping[local_gpu_id] for local_gpu_id in infer_local_gpu_ids]
trainer_gpu_ids = [physical_gpu_mapping[local_gpu_id] for local_gpu_id in trainer_local_gpu_ids]
Expand Down Expand Up @@ -150,7 +160,7 @@ def sigterm_handler(signum, frame):

try:
# Optionally, start inference process
if config.inference:
if config.inference and not debug_no_inference:
inference_cmd = ["inference", "@", (config_dir / INFERENCE_TOML).as_posix()]
logger.info(f"Starting inference on GPU(s) {' '.join(map(str, infer_gpu_ids))}")
logger.debug(f"Inference start command: {' '.join(inference_cmd)}")
Expand All @@ -177,6 +187,8 @@ def sigterm_handler(signum, frame):
)
monitor_thread.start()
monitor_threads.append(monitor_thread)
elif debug_no_inference:
logger.warning("Skipping inference process for orchestrator debug no-inference mode")
else:
logger.warning(
"No [inference] block configured - the student inference server will not be started here. "
Expand Down Expand Up @@ -224,54 +236,58 @@ def sigterm_handler(signum, frame):
monitor_thread.start()
monitor_threads.append(monitor_thread)

# Start training process
from prime_rl.utils.utils import get_free_port

trainer_cmd = [
"torchrun",
"--role=trainer",
f"--rdzv-endpoint=localhost:{get_free_port()}",
f"--rdzv-id={uuid.uuid4().hex}",
# Pipe all logs to file, and only master rank logs to stdout
f"--log-dir={log_dir / 'trainer' / 'torchrun'}",
f"--local-ranks-filter={','.join(map(str, config.trainer.log.ranks_filter))}",
"--redirect=3",
"--tee=3",
f"--nproc-per-node={len(trainer_gpu_ids)}",
"-m",
"prime_rl.trainer.rl.train",
"@",
(config_dir / TRAINER_TOML).as_posix(),
]
logger.info(f"Starting trainer on GPU(s) {' '.join(map(str, trainer_gpu_ids))}")
logger.debug(f"Training start command: {' '.join(trainer_cmd)}")
with open(log_dir / "trainer.log", "w") as log_file:
trainer_process = Popen(
trainer_cmd,
env={
**os.environ,
**wandb_shared_env,
"WANDB_SHARED_LABEL": "trainer",
"CUDA_VISIBLE_DEVICES": ",".join(map(str, trainer_gpu_ids)),
"PYTHONUNBUFFERED": "1",
"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
"LOGURU_FORCE_COLORS": "1",
"WANDB_PROGRAM": "uv run rl",
"WANDB_ARGS": json.dumps(start_command),
},
stdout=log_file,
stderr=log_file,
)
processes.append(trainer_process)
if debug_no_trainer:
logger.warning("Skipping trainer process for orchestrator debug no-trainer mode")
trainer_process = None
else:
# Start training process
from prime_rl.utils.utils import get_free_port

trainer_cmd = [
"torchrun",
"--role=trainer",
f"--rdzv-endpoint=localhost:{get_free_port()}",
f"--rdzv-id={uuid.uuid4().hex}",
# Pipe all logs to file, and only master rank logs to stdout
f"--log-dir={log_dir / 'trainer' / 'torchrun'}",
f"--local-ranks-filter={','.join(map(str, config.trainer.log.ranks_filter))}",
"--redirect=3",
"--tee=3",
f"--nproc-per-node={len(trainer_gpu_ids)}",
"-m",
"prime_rl.trainer.rl.train",
"@",
(config_dir / TRAINER_TOML).as_posix(),
]
logger.info(f"Starting trainer on GPU(s) {' '.join(map(str, trainer_gpu_ids))}")
logger.debug(f"Training start command: {' '.join(trainer_cmd)}")
with open(log_dir / "trainer.log", "w") as log_file:
trainer_process = Popen(
trainer_cmd,
env={
**os.environ,
**wandb_shared_env,
"WANDB_SHARED_LABEL": "trainer",
"CUDA_VISIBLE_DEVICES": ",".join(map(str, trainer_gpu_ids)),
"PYTHONUNBUFFERED": "1",
"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
"LOGURU_FORCE_COLORS": "1",
"WANDB_PROGRAM": "uv run rl",
"WANDB_ARGS": json.dumps(start_command),
},
stdout=log_file,
stderr=log_file,
)
processes.append(trainer_process)

# Start monitoring thread
stop_event = Event()
stop_events["trainer"] = stop_event
monitor_thread = Thread(
target=monitor_process, args=(trainer_process, stop_event, error_queue, "trainer"), daemon=True
)
monitor_thread.start()
monitor_threads.append(monitor_thread)
# Start monitoring thread
stop_event = Event()
stop_events["trainer"] = stop_event
monitor_thread = Thread(
target=monitor_process, args=(trainer_process, stop_event, error_queue, "trainer"), daemon=True
)
monitor_thread.start()
monitor_threads.append(monitor_thread)

# Monitor all processes for failures
logger.success("Startup complete. Showing orchestrator logs...")
Expand All @@ -283,7 +299,8 @@ def sigterm_handler(signum, frame):
processes.append(tail_process)

# Check for errors from monitor threads
while not (stop_events["orchestrator"].is_set() and stop_events["trainer"].is_set()):
required_processes = ["orchestrator"] + ([] if debug_no_trainer else ["trainer"])
while not all(stop_events[name].is_set() for name in required_processes):
if error_queue:
error = error_queue[0]
logger.error(f"Error: {error}")
Expand All @@ -302,7 +319,7 @@ def sigterm_handler(signum, frame):
cleanup_processes(processes)
sys.exit(1)

if trainer_process.returncode != 0:
if trainer_process is not None and trainer_process.returncode != 0:
logger.error(f"Trainer failed with exit code {trainer_process.returncode}")
cleanup_threads(monitor_threads)
cleanup_processes(processes)
Expand Down Expand Up @@ -501,7 +518,7 @@ def rl(config: RLConfig):
get_logger().info("Training from scratch, cleaning any stale rollouts and broadcasts")
clean_future_steps(config.output_dir, -1)

if not config.dry_run:
if not config.dry_run and not config.orchestrator.debug.no_trainer:
from prime_rl.trainer.model import pre_download_model

pre_download_model(config.trainer.model.name)
Expand Down
Loading
Loading