Merged
55 changes: 43 additions & 12 deletions skyrl/backends/skyrl_train/distributed/megatron/model_utils.py
@@ -216,9 +216,12 @@ def backward(

all_grad_input = []

batch_size = int(vocab_parallel_logits.shape[0])

for chunk_idx in range(num_chunks):
chunk_start = chunk_idx * chunk_size
chunk_end = min(seq_size, (chunk_idx + 1) * chunk_size)
chunk_len = chunk_end - chunk_start

logits = vocab_parallel_logits[:, chunk_start:chunk_end, :]
logits = logits.to(dtype=torch.float32)
@@ -229,15 +232,30 @@
)
softmax_output = softmax_output.exp()

# 1 if it's the chosen log prob, 0 otherwise
is_chosen = (~(target_mask[:, chunk_start:chunk_end])).unsqueeze(-1) * torch.nn.functional.one_hot(
masked_target[:, chunk_start:chunk_end],
num_classes=partition_vocab_size,
)

grad_input = is_chosen.float().sub_(softmax_output)

grad_input.mul_(grad_output[:, chunk_start:chunk_end].unsqueeze(dim=-1))
# Memory-efficient scatter-add fast path (ported from DistributedLogprob.backward).
# Materializing one_hot(masked_target, num_classes=partition_vocab_size) would
# allocate a [B, chunk_len, partition_vocab_size] int64 tensor (twice the bytes of
# the fp32 softmax_output, plus same-shape temporaries from the mask multiply and
# float cast), which causes OOM for large vocabularies. Instead, compute
# -softmax * grad_output in place and add grad_output at the chosen-token
# positions via scatter_add_.
chunk_target_mask = target_mask[:, chunk_start:chunk_end]
chunk_masked_target = masked_target[:, chunk_start:chunk_end]
chunk_grad_output = grad_output[:, chunk_start:chunk_end]

row = torch.arange(batch_size, device=softmax_output.device).view(-1, 1).expand(-1, chunk_len).reshape(-1)
col = torch.arange(chunk_len, device=softmax_output.device).expand(batch_size, -1).reshape(-1)
# Flat offset to the start of each [b, s, :] row in the chunk's flattened tensor.
flat_idx = (row * chunk_len + col) * partition_vocab_size

valid_mask = ~chunk_target_mask
flat_chosen = flat_idx.masked_select(valid_mask.reshape(-1)) + chunk_masked_target.masked_select(valid_mask)

# `neg` allocates grad_input once; the subsequent mul_ and scatter_add_ write in place.
grad_input = softmax_output.neg()
grad_input.mul_(chunk_grad_output.unsqueeze(-1))

grad_output_selected = chunk_grad_output.masked_select(valid_mask)
grad_input.view(-1).scatter_add_(0, flat_chosen, grad_output_selected)

all_grad_input.append(grad_input)
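
As a sanity check on the equivalence of the two formulations, here is a minimal standalone sketch (hypothetical shapes, single rank, no out-of-vocab masking) showing that the scatter-add gradient matches the one_hot gradient it replaces:

import torch

B, S, V = 2, 3, 5
softmax = torch.softmax(torch.randn(B, S, V), dim=-1)
target = torch.randint(0, V, (B, S))
grad_out = torch.rand(B, S)

# one_hot formulation: (one_hot - softmax) * grad_out
one_hot = torch.nn.functional.one_hot(target, num_classes=V).float()
ref = (one_hot - softmax) * grad_out.unsqueeze(-1)

# scatter-add formulation: -softmax * grad_out, then add grad_out at chosen ids
alt = softmax.neg() * grad_out.unsqueeze(-1)
flat = torch.arange(B * S) * V + target.reshape(-1)
alt.view(-1).scatter_add_(0, flat, grad_out.reshape(-1))

torch.testing.assert_close(alt, ref)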

@@ -290,7 +308,13 @@ def from_parallel_logits_to_logprobs(
cp_rank = torch.distributed.get_rank(cp_group)
target = _get_tokens_on_this_cp_rank(target, cp_rank, cp_size, seq_dim=1)

if chunk_size is not None:
# Only use the chunked path when chunking actually splits the sequence into
# multiple chunks. When chunk_size >= seq_len the whole sequence is one
# chunk, but ChunkedDistributedLogprob still saves the raw
# vocab_parallel_logits and recomputes softmax in backward (~3x peak memory
# vs DistributedLogprob's ~2x), so chunking actively hurts in that regime.
seq_len_local = vocab_parallel_logits.shape[1]
if chunk_size is not None and chunk_size < seq_len_local:
Contributor review comment (severity: medium):

While chunk_size is typically positive, it's safer to ensure it is greater than zero to avoid a potential ZeroDivisionError when calculating num_chunks inside ChunkedDistributedLogprob.forward (line 167).

Suggested change
if chunk_size is not None and chunk_size < seq_len_local:
if chunk_size is not None and 0 < chunk_size < seq_len_local:
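
For reference, this is the failure mode the comment describes (a sketch; the exact num_chunks expression in ChunkedDistributedLogprob.forward is assumed to be the usual ceiling division):

seq_len, chunk_size = 32, 0
num_chunks = (seq_len + chunk_size - 1) // chunk_size  # ZeroDivisionError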

logprobs: torch.Tensor = ChunkedDistributedLogprob.apply( # type: ignore
vocab_parallel_logits,
target,
@@ -378,8 +402,15 @@ def from_parallel_logits_to_logprobs_packed_sequences(
rolled_targets = rolled_targets.unsqueeze(0)
vocab_parallel_logits = vocab_parallel_logits.unsqueeze(0)

# Apply distributed log probability computation
if chunk_size is not None:
# Apply distributed log probability computation.
#
# Only use the chunked path when chunking actually splits the sequence into
# multiple chunks. When chunk_size >= seq_len the whole sequence is one
# chunk, but ChunkedDistributedLogprob still saves the raw
# vocab_parallel_logits and recomputes softmax in backward (~3x peak memory
# vs DistributedLogprob's ~2x), so chunking actively hurts in that regime.
seq_len_local = vocab_parallel_logits.shape[1]
if chunk_size is not None and chunk_size < seq_len_local:
Contributor review comment (severity: medium):

Similar to the change in from_parallel_logits_to_logprobs, adding a check to ensure chunk_size > 0 prevents a division by zero in the autograd function's forward pass.

Suggested change
if chunk_size is not None and chunk_size < seq_len_local:
if chunk_size is not None and 0 < chunk_size < seq_len_local:

probs: torch.Tensor = ChunkedDistributedLogprob.apply( # type: ignore
vocab_parallel_logits,
rolled_targets,
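As a rough check on the ~3x-vs-~2x peak-memory claim in the comments above, a back-of-envelope sketch (the shard shape is assumed; real sizes depend on the TP/CP layout and the dtype of the saved logits):

# One fp32 [B, S, V] buffer for an assumed per-rank logits shard.
B, S, V = 4, 4096, 128_000
fp32_bytes = B * S * V * 4
print(f"{fp32_bytes / 2**30:.1f} GiB")  # ~7.8 GiB per buffer
# DistributedLogprob backward holds roughly two such buffers (saved softmax +
# grad_input); the chunked path with a single chunk holds roughly three (saved
# logits, recomputed fp32 softmax, grad_input), so chunking only pays off when
# the chunks are genuinely smaller than the sequence.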
187 changes: 187 additions & 0 deletions tests/backends/skyrl_train/gpu/gpu_ci/megatron/test_chunked_logprob_backward.py
@@ -0,0 +1,187 @@
"""
uv run --isolated --extra dev --extra megatron -- pytest -s tests/backends/skyrl_train/gpu/gpu_ci/megatron/test_chunked_logprob_backward.py
"""

import os

import pytest
import torch
import torch.distributed as dist

from skyrl.backends.skyrl_train.distributed.megatron.model_utils import (
ChunkedDistributedLogprob,
DistributedLogprob,
)
from skyrl.train.utils.utils import get_free_port


@pytest.fixture(scope="module")
def tp_group():
"""Single-rank TP process group used by both autograd functions.

Uses gloo distributed backend for simplicity because the world size is 1.
"""
if not dist.is_initialized():
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(get_free_port())
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
dist.init_process_group(backend="gloo", rank=0, world_size=1)
yield dist.group.WORLD
if dist.is_initialized():
dist.destroy_process_group()


def _forward_backward(func_cls, logits, target, vocab_start, vocab_end, tp_group, *, chunk_size=None):
"""Run forward+backward through a logprob autograd function and return (out, grad_logits).

Uses a non-uniform upstream gradient so that any per-position bug surfaces.
"""
leaf = logits.detach().clone().requires_grad_(True)
if chunk_size is None:
out = func_cls.apply(leaf, target, vocab_start, vocab_end, tp_group, False)
else:
out = func_cls.apply(leaf, target, vocab_start, vocab_end, chunk_size, tp_group, False)
grad_seed = torch.linspace(0.5, 1.5, steps=out.numel(), device=out.device, dtype=out.dtype).reshape(out.shape)
out.backward(grad_seed)
return out.detach(), leaf.grad.detach()


@pytest.mark.parametrize("chunk_size", [1, 7, 16, 64, 512])
@pytest.mark.parametrize("with_oov_targets", [False, True])
def test_chunked_matches_non_chunked(tp_group, chunk_size, with_oov_targets):
"""Chunked and non-chunked logprob produce matching forwards and gradients.

Sweeps several chunk sizes (including one larger than the sequence length,
which collapses the chunk loop to a single iteration) and toggles whether
targets can fall outside the TP rank's vocab slice. The out-of-vocab path
exercises the ``target_mask`` branch in both functions.
"""
device = torch.device("cuda")
torch.manual_seed(0)

batch_size = 4
seq_len = 32
vocab_size = 32_000

target_high = vocab_size + 1024 if with_oov_targets else vocab_size

logits = torch.randn(batch_size, seq_len, vocab_size, dtype=torch.bfloat16, device=device) * 2.0
target = torch.randint(0, target_high, (batch_size, seq_len), device=device, dtype=torch.long)

out_ref, grad_ref = _forward_backward(DistributedLogprob, logits, target, 0, vocab_size, tp_group)
out_chunk, grad_chunk = _forward_backward(
ChunkedDistributedLogprob,
logits,
target,
0,
vocab_size,
tp_group,
chunk_size=chunk_size,
)

# Output dtype contract: log-softmax upcasts to fp32 internally.
assert out_chunk.dtype == torch.float32

# Forward parity. Both paths do the same fp32 math, but for small chunk sizes
# the reduction order across chunks differs from the single-shot path, which
# introduces ~1e-6 rounding noise. A loose tolerance still rules out real bugs.
torch.testing.assert_close(out_chunk, out_ref, atol=1e-5, rtol=1e-5)

# Gradient parity. Both paths use the same scatter-add formulation, so the
# tolerance can be tight relative to the bf16-logits/fp32-grad pipeline.
torch.testing.assert_close(grad_chunk, grad_ref, atol=1e-5, rtol=1e-4)


@pytest.mark.parametrize(
"case",
[
# (batch, seq_len, vocab, chunk_size, mask_mode)
# mask_mode: "default" (mixed), "all_in" (no OOV), "all_out" (all OOV)
pytest.param((1, 1, 1024, 4, "default"), id="seq1"),
pytest.param((2, 8, 1024, 32, "all_in"), id="all_in_vocab"),
pytest.param((2, 8, 1024, 32, "all_out"), id="all_out_vocab"),
pytest.param((2, 8, 8, 4, "default"), id="tiny_vocab"),
],
)
def test_chunked_matches_non_chunked_edge_cases(tp_group, case):
"""Edge cases for the chunked path: short sequences, mask extremes, tiny vocab.

Covers configurations the main sweep does not: ``seq_len=1`` (chunk loop runs
once with ``chunk_len=1``), a target_mask that is entirely False or entirely
True (the empty-``scatter_add_`` path), and a very small vocab that stresses
the masked-select/scatter arithmetic.
"""
batch_size, seq_len, vocab_size, chunk_size, mask_mode = case
device = torch.device("cuda")
torch.manual_seed(1)

logits = torch.randn(batch_size, seq_len, vocab_size, dtype=torch.bfloat16, device=device) * 2.0
if mask_mode == "all_out":
# Every target is outside the TP rank's vocab slice [0, vocab_size).
target = torch.full((batch_size, seq_len), vocab_size + 5, device=device, dtype=torch.long)
else:
target = torch.randint(0, vocab_size, (batch_size, seq_len), device=device, dtype=torch.long)

out_ref, grad_ref = _forward_backward(DistributedLogprob, logits, target, 0, vocab_size, tp_group)
out_chunk, grad_chunk = _forward_backward(
ChunkedDistributedLogprob,
logits,
target,
0,
vocab_size,
tp_group,
chunk_size=chunk_size,
)

torch.testing.assert_close(out_chunk, out_ref, atol=1e-5, rtol=1e-5)
torch.testing.assert_close(grad_chunk, grad_ref, atol=1e-5, rtol=1e-4)


def test_chunked_backward_uses_scatter_add_path(tp_group):
"""Chunked backward uses ``scatter_add_`` and does not materialize one_hot.

Asserts the implementation contract directly: ``scatter_add_`` must be
invoked during backward, and ``torch.nn.functional.one_hot`` must not.
Verifying the code path via mocks is more reliable than a peak-memory
assertion, which is sensitive to caching allocator behaviour and other
in-flight tensors.

We patch ``F.one_hot`` after ``forward`` has finished (forward doesn't use
it either; restricting the patch window keeps the assertion tight to the
backward path).
"""
from unittest.mock import patch

device = torch.device("cuda")
torch.manual_seed(0)

# chunk_size >= seq_len collapses the chunk loop to a single iteration --
# the path most likely to regress to a one_hot formulation.
batch_size = 2
seq_len = 16
vocab_size = 1024
chunk_size = 1024

logits = torch.randn(batch_size, seq_len, vocab_size, dtype=torch.bfloat16, device=device, requires_grad=True)
target = torch.randint(0, vocab_size, (batch_size, seq_len), device=device, dtype=torch.long)

out = ChunkedDistributedLogprob.apply(logits, target, 0, vocab_size, chunk_size, tp_group, False)

real_scatter_add_ = torch.Tensor.scatter_add_
scatter_add_calls = []

def _tracking_scatter_add_(self, dim, index, src):
scatter_add_calls.append((tuple(self.shape), int(dim), tuple(index.shape)))
return real_scatter_add_(self, dim, index, src)

with (
patch(
"torch.nn.functional.one_hot",
side_effect=AssertionError("one_hot must not be called in chunked backward"),
),
patch.object(torch.Tensor, "scatter_add_", _tracking_scatter_add_),
):
out.sum().backward()

assert scatter_add_calls, "Chunked backward must call scatter_add_ to place chosen-token grads"
@@ -588,7 +588,7 @@ async def test_megatron_train(
# the entropy calculation is different (fsdp has random logits for padding tokens)
continue
assert isinstance(result.metrics[k], (int, float)), f"{k} should be an int or float"
assert abs(result.metrics[k] - results_megatron[i].metrics[k]) < 4e-1, f"diff in {k} is too large!"
assert abs(result.metrics[k] - results_megatron[i].metrics[k]) < 5e-1, f"diff in {k} is too large!"


@pytest.mark.asyncio