Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/prime_rl/orchestrator/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
)
from prime_rl.orchestrator.watcher import NoOpWeightWatcher, WeightWatcher
from prime_rl.transport import TrainingBatch, setup_training_batch_sender
from prime_rl.transport.compact import compact_training_samples
from prime_rl.utils.async_utils import safe_cancel
from prime_rl.utils.client import init_nccl_broadcast, setup_inference_pool
from prime_rl.utils.heartbeat import Heartbeat
Expand Down Expand Up @@ -635,6 +636,7 @@ async def finalize_train_batch(self, batch: TrainBatch) -> None:
ex.teacher_logprobs = lp
teacher_logprobs_time = time.perf_counter() - t

compact_training_samples(batch.samples)
await self.sender.send(TrainingBatch(examples=batch.samples, step=step))
self.release_train_batch_samples(batch)
if config.debug.no_trainer:
Expand Down
43 changes: 32 additions & 11 deletions src/prime_rl/trainer/batch.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
import copy

from prime_rl.transport.compact import (
training_sample_completion_ids,
training_sample_completion_len,
training_sample_completion_logprobs,
training_sample_completion_mask,
training_sample_completion_temperatures,
training_sample_mm_token_type_ids,
training_sample_prompt_ids,
training_sample_prompt_len,
training_sample_prompt_mask,
training_sample_teacher_logprobs,
)
from prime_rl.transport.types import MicroBatch, RoutedExperts, TrainingSample

ROUTED_EXPERTS_DTYPE_ITEMSIZE = {
Expand Down Expand Up @@ -30,14 +42,15 @@ def _slice_routed_experts(routed_experts: RoutedExperts, seq_len: int) -> Routed
)


def _completion_temperatures(training_example: TrainingSample) -> tuple[float, list[float]]:
if training_example.completion_temperatures:
return training_example.completion_temperatures[0], training_example.completion_temperatures
def _completion_temperatures(training_example: TrainingSample, completion_len: int) -> tuple[float, list[float]]:
completion_temperatures = training_sample_completion_temperatures(training_example)
if completion_temperatures:
return completion_temperatures[0], completion_temperatures

temperature = training_example.completion_temperature
if temperature is None:
temperature = 1.0
return temperature, [temperature] * len(training_example.completion_ids)
return temperature, [temperature] * completion_len


def _append_routed_experts(dst: MicroBatch, src: MicroBatch) -> None:
Expand All @@ -64,26 +77,34 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch
Prepare a problem for sequence packing training.
Tokenize and prepare tensors.
"""
input_ids = training_example.prompt_ids + training_example.completion_ids
loss_mask = training_example.prompt_mask + training_example.completion_mask
inference_logprobs = [0.0] * len(training_example.prompt_ids) + training_example.completion_logprobs
prompt_ids = training_sample_prompt_ids(training_example)
prompt_mask = training_sample_prompt_mask(training_example)
completion_ids = training_sample_completion_ids(training_example)
completion_mask = training_sample_completion_mask(training_example)
completion_logprobs = training_sample_completion_logprobs(training_example)
prompt_len = training_sample_prompt_len(training_example)
completion_len = training_sample_completion_len(training_example)

input_ids = prompt_ids + completion_ids
loss_mask = prompt_mask + completion_mask
inference_logprobs = [0.0] * prompt_len + completion_logprobs
advantages = [training_example.advantage] * len(input_ids)
reward = training_example.reward if training_example.reward is not None else float("nan")
rewards = [reward] * len(input_ids)
position_ids = list(range(len(input_ids)))
mm_token_type_ids = training_example.mm_token_type_ids
mm_token_type_ids = training_sample_mm_token_type_ids(training_example)
assert training_example.env_name != "all", "env_name='all' is reserved for aggregate metric keys"
env_names = [training_example.env_name] * len(input_ids)

# Per-token temperatures: prompt tokens use the completion temperature
# (masked out anyway). The transport can carry a compact scalar for the
# common constant-temperature case.
prompt_temp, completion_temperatures = _completion_temperatures(training_example)
temperatures = [prompt_temp] * len(training_example.prompt_ids) + completion_temperatures
prompt_temp, completion_temperatures = _completion_temperatures(training_example, completion_len)
temperatures = [prompt_temp] * prompt_len + completion_temperatures

# Teacher logprobs already cover the full sequence (prompt + completion),
# computed via prefill in the orchestrator when a teacher model is configured
teacher_logprobs = training_example.teacher_logprobs
teacher_logprobs = training_sample_teacher_logprobs(training_example)
routed_experts = (
_copy_routed_experts(training_example.routed_experts) if training_example.routed_experts is not None else None
)
Expand Down
51 changes: 34 additions & 17 deletions src/prime_rl/trainer/rl/packer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,16 @@
setup_micro_batch_sender,
setup_training_batch_receiver,
)
from prime_rl.transport.compact import (
training_sample_completion_len,
training_sample_completion_logprobs_len,
training_sample_completion_mask_len,
training_sample_completion_temperatures_len,
training_sample_prompt_len,
training_sample_prompt_mask_len,
training_sample_teacher_logprobs_len,
training_sample_token_len,
)
from prime_rl.utils.logger import get_logger
from prime_rl.utils.pathing import get_rollout_dir

Expand Down Expand Up @@ -154,30 +164,35 @@ def _on_run_data_deleted(self, idx: int, run_id: str) -> None:

def _validate_sample(self, sample: TrainingSample) -> tuple[bool, str | None]:
"""Validate a sample to ensure it won't crash the trainer."""
sample_length = len(sample.prompt_ids) + len(sample.completion_ids)
if len(sample.prompt_mask) != len(sample.prompt_ids):
prompt_len = training_sample_prompt_len(sample)
completion_len = training_sample_completion_len(sample)
sample_length = prompt_len + completion_len
prompt_mask_len = training_sample_prompt_mask_len(sample)
completion_mask_len = training_sample_completion_mask_len(sample)
completion_logprobs_len = training_sample_completion_logprobs_len(sample)
if prompt_mask_len != prompt_len:
return (
False,
f"Run wrote a sample with prompt mask length != prompt ids length ({len(sample.prompt_mask)} != {len(sample.prompt_ids)})",
f"Run wrote a sample with prompt mask length != prompt ids length ({prompt_mask_len} != {prompt_len})",
)
if len(sample.completion_mask) != len(sample.completion_ids):
if completion_mask_len != completion_len:
return (
False,
f"Run wrote a sample with completion mask length != completion ids length ({len(sample.completion_mask)} != {len(sample.completion_ids)})",
f"Run wrote a sample with completion mask length != completion ids length ({completion_mask_len} != {completion_len})",
)
if len(sample.completion_logprobs) != len(sample.completion_ids):
if completion_logprobs_len != completion_len:
return (
False,
f"Run wrote a sample with completion logprobs length != completion ids length ({len(sample.completion_logprobs)} != {len(sample.completion_ids)})",
f"Run wrote a sample with completion logprobs length != completion ids length ({completion_logprobs_len} != {completion_len})",
)
completion_temperatures_len = len(sample.completion_temperatures)
completion_temperatures_len = training_sample_completion_temperatures_len(sample)
has_compact_temperature = sample.completion_temperature is not None
if completion_temperatures_len != len(sample.completion_ids) and not (
if completion_temperatures_len != completion_len and not (
completion_temperatures_len == 0 and has_compact_temperature
):
return (
False,
f"Run wrote a sample with completion temperatures length != completion ids length ({completion_temperatures_len} != {len(sample.completion_ids)})",
f"Run wrote a sample with completion temperatures length != completion ids length ({completion_temperatures_len} != {completion_len})",
)
if sample_length == 0:
return False, "Run wrote a sample with no tokens"
Expand All @@ -186,10 +201,11 @@ def _validate_sample(self, sample: TrainingSample) -> tuple[bool, str | None]:
False,
f"Run wrote a sample with length {sample_length} which exceeds max sequence length {self.seq_len}",
)
if sample.teacher_logprobs is not None and len(sample.teacher_logprobs) != sample_length:
teacher_logprobs_len = training_sample_teacher_logprobs_len(sample)
if teacher_logprobs_len is not None and teacher_logprobs_len != sample_length:
return (
False,
f"Run wrote a sample with teacher logprobs length != sample length ({len(sample.teacher_logprobs)} != {sample_length})",
f"Run wrote a sample with teacher logprobs length != sample length ({teacher_logprobs_len} != {sample_length})",
)
return True, None

Expand Down Expand Up @@ -226,7 +242,7 @@ def _count_tokens(self, threshold: int | None = None) -> int:
for sample, step in buffer:
if step > current_step:
break
tokens += len(sample.prompt_ids) + len(sample.completion_ids)
tokens += training_sample_token_len(sample)
if threshold is not None and tokens >= threshold:
return tokens
return tokens
Expand Down Expand Up @@ -263,10 +279,11 @@ def _select_samples_round_robin(self, token_budget: int) -> list[tuple[int, Trai
if step > current_step:
# Samples from different steps should be consumed later
break
tokens_collected += len(sample.prompt_ids) + len(sample.completion_ids)
sample_tokens = training_sample_token_len(sample)
tokens_collected += sample_tokens
if tokens_collected > token_budget:
if tokens_collected == (len(sample.prompt_ids) + len(sample.completion_ids)):
tokens_collected -= len(sample.prompt_ids) + len(sample.completion_ids)
if tokens_collected == sample_tokens:
tokens_collected -= sample_tokens
# This means we have a sample that has more tokens than max seqlen
self.buffers[run_idx].popleft()
continue
Expand Down Expand Up @@ -320,7 +337,7 @@ def pack(self):
assert steps_by_run[run_idx] == step, "Micro batches for a run must come from a single run step"
samples_by_run[run_idx].append(sample)

num_tokens = len(sample.prompt_ids) + len(sample.completion_ids)
num_tokens = training_sample_token_len(sample)
if run_idx in per_run_stats:
cur_samples, cur_tokens = per_run_stats[run_idx]
per_run_stats[run_idx] = (cur_samples + 1, cur_tokens + num_tokens)
Expand Down
156 changes: 156 additions & 0 deletions src/prime_rl/transport/compact.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
from collections.abc import Sequence

import numpy as np

from prime_rl.transport.types import PackedArray, TrainingSample


def _pack_numeric(values: Sequence[int] | Sequence[float], dtype: str) -> PackedArray:
arr = np.asarray(values, dtype=np.dtype(dtype))
return PackedArray(data=arr.tobytes(), shape=[int(arr.shape[0])], dtype=dtype)


def _pack_bool(values: Sequence[bool]) -> PackedArray:
arr = np.asarray(values, dtype=np.bool_)
return PackedArray(data=np.packbits(arr, bitorder="little").tobytes(), shape=[int(arr.shape[0])], dtype="bool")


def _unpack_numeric(packed: PackedArray) -> list[int] | list[float]:
return np.frombuffer(packed.data, dtype=np.dtype(packed.dtype), count=packed.shape[0]).tolist()


def _unpack_bool(packed: PackedArray) -> list[bool]:
if packed.shape[0] == 0:
return []
packed_bytes = np.frombuffer(packed.data, dtype=np.uint8)
return np.unpackbits(packed_bytes, bitorder="little", count=packed.shape[0]).astype(np.bool_).tolist()


def _packed_len(packed: PackedArray | None) -> int | None:
if packed is None:
return None
return packed.shape[0]


def _field_len(values: Sequence | None, packed: PackedArray | None) -> int:
packed_len = _packed_len(packed)
return packed_len if packed_len is not None else len(values or [])


def compact_training_sample(sample: TrainingSample) -> None:
"""Replace large list fields with byte-backed arrays for transport."""
if sample.packed_prompt_ids is not None:
return

sample.packed_prompt_ids = _pack_numeric(sample.prompt_ids, "uint32")
sample.prompt_ids = []
sample.packed_prompt_mask = _pack_bool(sample.prompt_mask)
sample.prompt_mask = []
sample.packed_completion_ids = _pack_numeric(sample.completion_ids, "uint32")
sample.completion_ids = []
sample.packed_completion_mask = _pack_bool(sample.completion_mask)
sample.completion_mask = []
sample.packed_completion_logprobs = _pack_numeric(sample.completion_logprobs, "float32")
sample.completion_logprobs = []

if sample.completion_temperatures:
sample.packed_completion_temperatures = _pack_numeric(sample.completion_temperatures, "float32")
sample.completion_temperatures = []

if sample.teacher_logprobs is not None:
sample.packed_teacher_logprobs = _pack_numeric(sample.teacher_logprobs, "float32")
sample.teacher_logprobs = None

if sample.mm_token_type_ids is not None:
sample.packed_mm_token_type_ids = _pack_numeric(sample.mm_token_type_ids, "uint8")
sample.mm_token_type_ids = None


def compact_training_samples(samples: list[TrainingSample]) -> None:
for sample in samples:
compact_training_sample(sample)


def training_sample_prompt_ids(sample: TrainingSample) -> list[int]:
if sample.packed_prompt_ids is not None:
return _unpack_numeric(sample.packed_prompt_ids)
return sample.prompt_ids


def training_sample_prompt_mask(sample: TrainingSample) -> list[bool]:
if sample.packed_prompt_mask is not None:
return _unpack_bool(sample.packed_prompt_mask)
return sample.prompt_mask


def training_sample_completion_ids(sample: TrainingSample) -> list[int]:
if sample.packed_completion_ids is not None:
return _unpack_numeric(sample.packed_completion_ids)
return sample.completion_ids


def training_sample_completion_mask(sample: TrainingSample) -> list[bool]:
if sample.packed_completion_mask is not None:
return _unpack_bool(sample.packed_completion_mask)
return sample.completion_mask


def training_sample_completion_logprobs(sample: TrainingSample) -> list[float]:
if sample.packed_completion_logprobs is not None:
return _unpack_numeric(sample.packed_completion_logprobs)
return sample.completion_logprobs


def training_sample_completion_temperatures(sample: TrainingSample) -> list[float]:
if sample.packed_completion_temperatures is not None:
return _unpack_numeric(sample.packed_completion_temperatures)
return sample.completion_temperatures


def training_sample_teacher_logprobs(sample: TrainingSample) -> list[float] | None:
if sample.packed_teacher_logprobs is not None:
return _unpack_numeric(sample.packed_teacher_logprobs)
return sample.teacher_logprobs


def training_sample_mm_token_type_ids(sample: TrainingSample) -> list[int] | None:
if sample.packed_mm_token_type_ids is not None:
return _unpack_numeric(sample.packed_mm_token_type_ids)
return sample.mm_token_type_ids


def training_sample_prompt_len(sample: TrainingSample) -> int:
return _field_len(sample.prompt_ids, sample.packed_prompt_ids)


def training_sample_prompt_mask_len(sample: TrainingSample) -> int:
return _field_len(sample.prompt_mask, sample.packed_prompt_mask)


def training_sample_completion_len(sample: TrainingSample) -> int:
return _field_len(sample.completion_ids, sample.packed_completion_ids)


def training_sample_completion_mask_len(sample: TrainingSample) -> int:
return _field_len(sample.completion_mask, sample.packed_completion_mask)


def training_sample_completion_logprobs_len(sample: TrainingSample) -> int:
return _field_len(sample.completion_logprobs, sample.packed_completion_logprobs)


def training_sample_completion_temperatures_len(sample: TrainingSample) -> int:
return _field_len(sample.completion_temperatures, sample.packed_completion_temperatures)


def training_sample_teacher_logprobs_len(sample: TrainingSample) -> int | None:
packed_len = _packed_len(sample.packed_teacher_logprobs)
if packed_len is not None:
return packed_len
if sample.teacher_logprobs is None:
return None
return len(sample.teacher_logprobs)


def training_sample_token_len(sample: TrainingSample) -> int:
return training_sample_prompt_len(sample) + training_sample_completion_len(sample)
18 changes: 18 additions & 0 deletions src/prime_rl/transport/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ class RoutedExperts(msgspec.Struct, array_like=True, gc=False, omit_defaults=Tru
dtype: str


class PackedArray(msgspec.Struct, array_like=True, gc=False, omit_defaults=True):
data: bytes
shape: list[int]
dtype: str


# Orchestrator -> Packer
class TrainingSample(msgspec.Struct, array_like=True, gc=False, omit_defaults=True):
"""A single training example."""
Expand Down Expand Up @@ -60,6 +66,18 @@ class TrainingSample(msgspec.Struct, array_like=True, gc=False, omit_defaults=Tr
# taus), sft uses sft_loss_fn. Stamped by the orchestrator from training_mode.
training_mode: TrainingMode = "rl"

# Compact transport representation for large Python lists. Orchestrator
# compacts these at the train-batch send boundary; trainer inflates only
# when preparing the selected microbatch.
packed_prompt_ids: PackedArray | None = None
packed_prompt_mask: PackedArray | None = None
packed_completion_ids: PackedArray | None = None
packed_completion_mask: PackedArray | None = None
packed_completion_logprobs: PackedArray | None = None
packed_completion_temperatures: PackedArray | None = None
packed_teacher_logprobs: PackedArray | None = None
packed_mm_token_type_ids: PackedArray | None = None


class TrainingBatch(msgspec.Struct, array_like=True, gc=False, omit_defaults=True):
"""A batch of training examples with metadata for transport."""
Expand Down
Loading
Loading