[train] Fix chunked-logprob backward and skip chunking for short sequences #1650
Changes from 9 commits
Changes to `ChunkedDistributedLogprob.backward` and its call sites (the test file below imports these from `skyrl.backends.skyrl_train.distributed.megatron.model_utils`):

```diff
@@ -216,9 +216,12 @@ def backward(
         all_grad_input = []

+        batch_size = int(vocab_parallel_logits.shape[0])
+
         for chunk_idx in range(num_chunks):
             chunk_start = chunk_idx * chunk_size
             chunk_end = min(seq_size, (chunk_idx + 1) * chunk_size)
+            chunk_len = chunk_end - chunk_start

             logits = vocab_parallel_logits[:, chunk_start:chunk_end, :]
             logits = logits.to(dtype=torch.float32)
```
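The loop bounds above clamp the final chunk with `min()`, so a sequence whose length is not a multiple of `chunk_size` ends with a short chunk. A quick illustration with made-up values (the ceil-division for `num_chunks` is an assumption; its definition sits outside this hunk):

```python
# Hypothetical sizes; only the boundary arithmetic matches the hunk above.
seq_size, chunk_size = 10, 4
num_chunks = (seq_size + chunk_size - 1) // chunk_size  # assumed ceil-division: 3

bounds = [(i * chunk_size, min(seq_size, (i + 1) * chunk_size)) for i in range(num_chunks)]
assert bounds == [(0, 4), (4, 8), (8, 10)]  # last chunk has chunk_len == 2
```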
```diff
@@ -229,15 +232,30 @@ def backward(
             )
             softmax_output = softmax_output.exp()

-            # 1 if it's the chosen log prob, 0 otherwise
-            is_chosen = (~(target_mask[:, chunk_start:chunk_end])).unsqueeze(-1) * torch.nn.functional.one_hot(
-                masked_target[:, chunk_start:chunk_end],
-                num_classes=partition_vocab_size,
-            )
-
-            grad_input = is_chosen.float().sub_(softmax_output)
-
-            grad_input.mul_(grad_output[:, chunk_start:chunk_end].unsqueeze(dim=-1))
+            # Memory-efficient scatter-add fast path (ported from DistributedLogprob.backward).
+            # Materializing one_hot(masked_target, num_classes=partition_vocab_size) would
+            # allocate a [B, chunk_len, partition_vocab_size] int64 tensor (twice the bytes
+            # of the float32 softmax_output, plus another fp32 copy from .float()), which
+            # causes OOM for large vocabularies. Instead, compute -softmax * grad_output in
+            # place and add grad_output at the chosen-token positions via scatter_add_.
+            chunk_target_mask = target_mask[:, chunk_start:chunk_end]
+            chunk_masked_target = masked_target[:, chunk_start:chunk_end]
+            chunk_grad_output = grad_output[:, chunk_start:chunk_end]
+
+            row = torch.arange(batch_size, device=softmax_output.device).view(-1, 1).expand(-1, chunk_len).reshape(-1)
+            col = torch.arange(chunk_len, device=softmax_output.device).expand(batch_size, -1).reshape(-1)
+            # Flat offset to the start of each [b, s, :] row in the chunk's flattened tensor.
+            flat_idx = (row * chunk_len + col) * partition_vocab_size
+
+            valid_mask = ~chunk_target_mask
+            flat_chosen = flat_idx.masked_select(valid_mask.reshape(-1)) + chunk_masked_target.masked_select(valid_mask)
+
+            # `neg` allocates the gradient buffer once; the subsequent mul_ writes into it in place.
+            grad_input = softmax_output.neg()
+            grad_input.mul_(chunk_grad_output.unsqueeze(-1))
+
+            grad_output_selected = chunk_grad_output.masked_select(valid_mask)
+            grad_input.view(-1).scatter_add_(0, flat_chosen, grad_output_selected)

             all_grad_input.append(grad_input)
```
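To see why the scatter-add rewrite is a pure refactor of the gradient math, here is a self-contained sketch (not from the PR; toy shapes, with `V` standing in for `partition_vocab_size`) checking that `(one_hot - softmax) * g` equals `-softmax * g` with `g` scatter-added at the valid chosen-token slots:

```python
import torch
import torch.nn.functional as F

B, L, V = 2, 5, 7
softmax = torch.rand(B, L, V)
target = torch.randint(0, 2 * V, (B, L))  # some targets fall outside this rank's slice
target_mask = target >= V                 # True = not on this vocab partition
masked_target = target.clamp(max=V - 1)
g = torch.rand(B, L)                      # upstream gradient per position

# Old formulation: materialize one_hot, then (one_hot - softmax) * g.
is_chosen = (~target_mask).unsqueeze(-1) * F.one_hot(masked_target, num_classes=V)
ref = (is_chosen.float() - softmax) * g.unsqueeze(-1)

# New formulation: -softmax * g, then scatter-add g at valid chosen positions.
out = softmax.neg() * g.unsqueeze(-1)
row = torch.arange(B).view(-1, 1).expand(-1, L).reshape(-1)
col = torch.arange(L).expand(B, -1).reshape(-1)
flat_idx = (row * L + col) * V  # flat offset of each [b, s, :] row
valid = ~target_mask
flat_chosen = flat_idx.masked_select(valid.reshape(-1)) + masked_target.masked_select(valid)
out.view(-1).scatter_add_(0, flat_chosen, g.masked_select(valid))

torch.testing.assert_close(out, ref)
```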
```diff
@@ -290,7 +308,13 @@ def from_parallel_logits_to_logprobs(
         cp_rank = torch.distributed.get_rank(cp_group)
         target = _get_tokens_on_this_cp_rank(target, cp_rank, cp_size, seq_dim=1)

-    if chunk_size is not None:
+    # Only use the chunked path when chunking actually splits the sequence into
+    # multiple chunks. When chunk_size >= seq_len the whole sequence is one
+    # chunk, but ChunkedDistributedLogprob still saves the raw
+    # vocab_parallel_logits and recomputes softmax in backward (~3x peak memory
+    # vs DistributedLogprob's ~2x), so chunking actively hurts in that regime.
+    seq_len_local = vocab_parallel_logits.shape[1]
+    if chunk_size is not None and chunk_size < seq_len_local:
         logprobs: torch.Tensor = ChunkedDistributedLogprob.apply(  # type: ignore
             vocab_parallel_logits,
             target,
```
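A back-of-envelope reading of the peak-memory figures in that comment, assuming a single fp32 logits-shaped tensor dominates (shapes here are illustrative, not measured):

```python
# Illustrative shapes only; the ~2x / ~3x multiples come from the comment above.
B, S, V = 8, 4096, 32_000
fp32 = 4  # bytes per element
logits_bytes = B * S * V * fp32

# DistributedLogprob saves softmax (logits-shaped) -> logits + saved softmax in backward.
non_chunked_peak = 2 * logits_bytes

# ChunkedDistributedLogprob with one chunk saves the raw logits and, in backward,
# recomputes a logits-shaped fp32 softmax and allocates grad_input on top.
single_chunk_peak = 3 * logits_bytes

print(f"{non_chunked_peak / 2**30:.1f} GiB vs {single_chunk_peak / 2**30:.1f} GiB")
# -> 7.8 GiB vs 11.7 GiB for this shape
```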
```diff
@@ -378,8 +402,15 @@ def from_parallel_logits_to_logprobs_packed_sequences(
         rolled_targets = rolled_targets.unsqueeze(0)
         vocab_parallel_logits = vocab_parallel_logits.unsqueeze(0)

-    # Apply distributed log probability computation
-    if chunk_size is not None:
+    # Apply distributed log probability computation.
+    #
+    # Only use the chunked path when chunking actually splits the sequence into
+    # multiple chunks. When chunk_size >= seq_len the whole sequence is one
+    # chunk, but ChunkedDistributedLogprob still saves the raw
+    # vocab_parallel_logits and recomputes softmax in backward (~3x peak memory
+    # vs DistributedLogprob's ~2x), so chunking actively hurts in that regime.
+    seq_len_local = vocab_parallel_logits.shape[1]
+    if chunk_size is not None and chunk_size < seq_len_local:
         probs: torch.Tensor = ChunkedDistributedLogprob.apply(  # type: ignore
             vocab_parallel_logits,
             rolled_targets,
```
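Both call sites now share the same dispatch rule. A minimal sketch of just that predicate (the function name and return values are placeholders for illustration, not the repository's API):

```python
import torch

def pick_logprob_path(vocab_parallel_logits: torch.Tensor, chunk_size) -> str:
    # Chunking only pays off when it actually splits the local sequence.
    seq_len_local = vocab_parallel_logits.shape[1]
    if chunk_size is not None and chunk_size < seq_len_local:
        return "ChunkedDistributedLogprob"
    return "DistributedLogprob"

logits = torch.empty(2, 32, 128)  # [batch, local seq_len, partition vocab]
assert pick_logprob_path(logits, 8) == "ChunkedDistributedLogprob"
assert pick_logprob_path(logits, 64) == "DistributedLogprob"  # chunk_size >= seq_len
assert pick_logprob_path(logits, None) == "DistributedLogprob"
```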
New test file `tests/backends/skyrl_train/gpu/gpu_ci/megatron/test_chunked_logprob_backward.py` (@@ -0,0 +1,187 @@):

```python
"""
uv run --isolated --extra dev --extra megatron -- pytest -s tests/backends/skyrl_train/gpu/gpu_ci/megatron/test_chunked_logprob_backward.py
"""

import os

import pytest
import torch
import torch.distributed as dist

from skyrl.backends.skyrl_train.distributed.megatron.model_utils import (
    ChunkedDistributedLogprob,
    DistributedLogprob,
)
from skyrl.train.utils.utils import get_free_port


@pytest.fixture(scope="module")
def tp_group():
    """Single-rank TP process group used by both autograd functions.

    Uses the gloo distributed backend for simplicity because the world size is 1.
    """
    if not dist.is_initialized():
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(get_free_port())
        os.environ["RANK"] = "0"
        os.environ["WORLD_SIZE"] = "1"
        dist.init_process_group(backend="gloo", rank=0, world_size=1)
    yield dist.group.WORLD
    if dist.is_initialized():
        dist.destroy_process_group()


def _forward_backward(func_cls, logits, target, vocab_start, vocab_end, tp_group, *, chunk_size=None):
    """Run forward+backward through a logprob autograd function and return (out, grad_logits).

    Uses a non-uniform upstream gradient so that any per-position bug surfaces.
    """
    leaf = logits.detach().clone().requires_grad_(True)
    if chunk_size is None:
        out = func_cls.apply(leaf, target, vocab_start, vocab_end, tp_group, False)
    else:
        out = func_cls.apply(leaf, target, vocab_start, vocab_end, chunk_size, tp_group, False)
    grad_seed = torch.linspace(0.5, 1.5, steps=out.numel(), device=out.device, dtype=out.dtype).reshape(out.shape)
    out.backward(grad_seed)
    return out.detach(), leaf.grad.detach()


@pytest.mark.parametrize("chunk_size", [1, 7, 16, 64, 512])
@pytest.mark.parametrize("with_oov_targets", [False, True])
def test_chunked_matches_non_chunked(tp_group, chunk_size, with_oov_targets):
    """Chunked and non-chunked logprob produce matching forwards and gradients.

    Sweeps several chunk sizes (including sizes larger than the sequence length,
    which collapse the chunk loop to a single iteration) and toggles whether
    targets can fall outside the TP rank's vocab slice. The out-of-vocab path
    exercises the ``target_mask`` branch in both functions.
    """
    device = torch.device("cuda")
    torch.manual_seed(0)

    batch_size = 4
    seq_len = 32
    vocab_size = 32_000

    target_high = vocab_size + 1024 if with_oov_targets else vocab_size

    logits = torch.randn(batch_size, seq_len, vocab_size, dtype=torch.bfloat16, device=device) * 2.0
    target = torch.randint(0, target_high, (batch_size, seq_len), device=device, dtype=torch.long)

    out_ref, grad_ref = _forward_backward(DistributedLogprob, logits, target, 0, vocab_size, tp_group)
    out_chunk, grad_chunk = _forward_backward(
        ChunkedDistributedLogprob,
        logits,
        target,
        0,
        vocab_size,
        tp_group,
        chunk_size=chunk_size,
    )

    # Output dtype contract: log-softmax upcasts to fp32 internally.
    assert out_chunk.dtype == torch.float32

    # Forward parity. Both paths do the same fp32 math, but for small chunk sizes
    # the reduction order across chunks differs from the single-shot path, which
    # introduces ~1e-6 rounding noise. A loose tolerance still rules out real bugs.
    torch.testing.assert_close(out_chunk, out_ref, atol=1e-5, rtol=1e-5)

    # Gradient parity. Both paths use the same scatter-add formulation, so the
    # tolerance can be tight relative to the bf16-logits/fp32-grad pipeline.
    torch.testing.assert_close(grad_chunk, grad_ref, atol=1e-5, rtol=1e-4)


@pytest.mark.parametrize(
    "case",
    [
        # (batch, seq_len, vocab, chunk_size, mask_mode)
        # mask_mode: "default"/"all_in" (targets all in-vocab), "all_out" (all OOV)
        pytest.param((1, 1, 1024, 4, "default"), id="seq1"),
        pytest.param((2, 8, 1024, 32, "all_in"), id="all_in_vocab"),
        pytest.param((2, 8, 1024, 32, "all_out"), id="all_out_vocab"),
        pytest.param((2, 8, 8, 4, "default"), id="tiny_vocab"),
    ],
)
def test_chunked_matches_non_chunked_edge_cases(tp_group, case):
    """Edge cases for the chunked path: short sequences, mask extremes, tiny vocab.

    Covers configurations the main sweep does not: ``seq_len=1`` (the chunk loop
    runs once with ``chunk_len=1``), a target_mask that is entirely False or
    entirely True (the empty-``scatter_add_`` path), and a very small vocab that
    stresses the masked-select/scatter arithmetic.
    """
    batch_size, seq_len, vocab_size, chunk_size, mask_mode = case
    device = torch.device("cuda")
    torch.manual_seed(1)

    logits = torch.randn(batch_size, seq_len, vocab_size, dtype=torch.bfloat16, device=device) * 2.0
    if mask_mode == "all_out":
        # Every target is outside the TP rank's vocab slice [0, vocab_size).
        target = torch.full((batch_size, seq_len), vocab_size + 5, device=device, dtype=torch.long)
    else:
        target = torch.randint(0, vocab_size, (batch_size, seq_len), device=device, dtype=torch.long)

    out_ref, grad_ref = _forward_backward(DistributedLogprob, logits, target, 0, vocab_size, tp_group)
    out_chunk, grad_chunk = _forward_backward(
        ChunkedDistributedLogprob,
        logits,
        target,
        0,
        vocab_size,
        tp_group,
        chunk_size=chunk_size,
    )

    torch.testing.assert_close(out_chunk, out_ref, atol=1e-5, rtol=1e-5)
    torch.testing.assert_close(grad_chunk, grad_ref, atol=1e-5, rtol=1e-4)


def test_chunked_backward_uses_scatter_add_path(tp_group):
    """Chunked backward uses ``scatter_add_`` and does not materialize one_hot.

    Asserts the implementation contract directly: ``scatter_add_`` must be
    invoked during backward, and ``torch.nn.functional.one_hot`` must not.
    Verifying the code path via mocks is more reliable than a peak-memory
    assertion, which is sensitive to caching-allocator behaviour and other
    in-flight tensors.

    We patch ``F.one_hot`` only after ``forward`` has finished (forward doesn't
    use it either; restricting the patch window keeps the assertion tight to
    the backward path).
    """
    from unittest.mock import patch

    device = torch.device("cuda")
    torch.manual_seed(0)

    # chunk_size >= seq_len collapses the chunk loop to a single iteration --
    # the path most likely to regress to a one_hot formulation.
    batch_size = 2
    seq_len = 16
    vocab_size = 1024
    chunk_size = 1024

    logits = torch.randn(batch_size, seq_len, vocab_size, dtype=torch.bfloat16, device=device, requires_grad=True)
    target = torch.randint(0, vocab_size, (batch_size, seq_len), device=device, dtype=torch.long)

    out = ChunkedDistributedLogprob.apply(logits, target, 0, vocab_size, chunk_size, tp_group, False)

    real_scatter_add_ = torch.Tensor.scatter_add_
    scatter_add_calls = []

    def _tracking_scatter_add_(self, dim, index, src):
        scatter_add_calls.append((tuple(self.shape), int(dim), tuple(index.shape)))
        return real_scatter_add_(self, dim, index, src)

    with (
        patch(
            "torch.nn.functional.one_hot",
            side_effect=AssertionError("one_hot must not be called in chunked backward"),
        ),
        patch.object(torch.Tensor, "scatter_add_", _tracking_scatter_add_),
    ):
        out.sum().backward()

    assert scatter_add_calls, "Chunked backward must call scatter_add_ to place chosen-token grads"
```