[train] Fix chunked-logprob backward and skip chunking for short sequences #1650
Merged
10 commits
d7b31d4 fix: ChunkedDistributedLogprob.backward use scatter-add to avoid one_… (SumanthRH)
7fd7ee4 fix: _VocabParallelEntropy.backward use out-of-place subtraction to p… (SumanthRH)
8ffca86 perf: skip chunked logprob dispatch when seq_len <= chunk_size (SumanthRH)
ab4f8e8 test: bump policy_loss tolerance from 0.4 to 0.5 in megatron entropy-… (SumanthRH)
98f419f refactor (SumanthRH)
c74dff9 Merge remote-tracking branch 'origin/main' into fix-chunk-logprobs-sm… (SumanthRH)
8a88c60 nit (SumanthRH)
f3892e3 Merge remote-tracking branch 'origin/main' into fix-chunk-logprobs-sm… (SumanthRH)
abcb01b x (SumanthRH)
c76e1ad Update skyrl/backends/skyrl_train/distributed/megatron/model_utils.py (SumanthRH)
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -216,9 +216,12 @@ def backward( | |||||
|
|
||||||
| all_grad_input = [] | ||||||
|
|
||||||
| batch_size = int(vocab_parallel_logits.shape[0]) | ||||||
|
|
||||||
| for chunk_idx in range(num_chunks): | ||||||
| chunk_start = chunk_idx * chunk_size | ||||||
| chunk_end = min(seq_size, (chunk_idx + 1) * chunk_size) | ||||||
| chunk_len = chunk_end - chunk_start | ||||||
|
|
||||||
| logits = vocab_parallel_logits[:, chunk_start:chunk_end, :] | ||||||
| logits = logits.to(dtype=torch.float32) | ||||||
|
|
@@ -229,15 +232,30 @@ def backward( | |||||
| ) | ||||||
| softmax_output = softmax_output.exp() | ||||||
|
|
||||||
| # 1 if it's the chosen log prob, 0 otherwise | ||||||
| is_chosen = (~(target_mask[:, chunk_start:chunk_end])).unsqueeze(-1) * torch.nn.functional.one_hot( | ||||||
| masked_target[:, chunk_start:chunk_end], | ||||||
| num_classes=partition_vocab_size, | ||||||
| ) | ||||||
|
|
||||||
| grad_input = is_chosen.float().sub_(softmax_output) | ||||||
|
|
||||||
| grad_input.mul_(grad_output[:, chunk_start:chunk_end].unsqueeze(dim=-1)) | ||||||
| # Memory-efficient scatter-add fast path (ported from DistributedLogprob.backward). | ||||||
| # Materializing one_hot(masked_target, num_classes=partition_vocab_size) would | ||||||
| # allocate a [B, chunk_len, partition_vocab_size] int64 tensor (~8x the size of | ||||||
| # softmax_output in float32), which causes OOM for large vocabularies. Instead, | ||||||
| # compute -softmax * grad_output in place and add grad_output at the chosen-token | ||||||
| # positions via scatter_add_. | ||||||
| chunk_target_mask = target_mask[:, chunk_start:chunk_end] | ||||||
| chunk_masked_target = masked_target[:, chunk_start:chunk_end] | ||||||
| chunk_grad_output = grad_output[:, chunk_start:chunk_end] | ||||||
|
|
||||||
| row = torch.arange(batch_size, device=softmax_output.device).view(-1, 1).expand(-1, chunk_len).reshape(-1) | ||||||
| col = torch.arange(chunk_len, device=softmax_output.device).expand(batch_size, -1).reshape(-1) | ||||||
| # Flat offset to the start of each [b, s, :] row in the chunk's flattened tensor. | ||||||
| flat_idx = (row * chunk_len + col) * partition_vocab_size | ||||||
|
|
||||||
| valid_mask = ~chunk_target_mask | ||||||
| flat_chosen = flat_idx.masked_select(valid_mask.reshape(-1)) + chunk_masked_target.masked_select(valid_mask) | ||||||
|
|
||||||
| # `neg` is zero-copy; the subsequent mul_ writes in place. | ||||||
| grad_input = softmax_output.neg_() | ||||||
| grad_input.mul_(chunk_grad_output.unsqueeze(-1)) | ||||||
|
|
||||||
| grad_output_selected = chunk_grad_output.masked_select(valid_mask) | ||||||
| grad_input.view(-1).scatter_add_(0, flat_chosen, grad_output_selected) | ||||||
|
|
||||||
| all_grad_input.append(grad_input) | ||||||
|
|
||||||
|
|
@@ -290,7 +308,13 @@ def from_parallel_logits_to_logprobs( | |||||
| cp_rank = torch.distributed.get_rank(cp_group) | ||||||
| target = _get_tokens_on_this_cp_rank(target, cp_rank, cp_size, seq_dim=1) | ||||||
|
|
||||||
| if chunk_size is not None: | ||||||
| # Only use the chunked path when chunking actually splits the sequence into | ||||||
| # multiple chunks. When chunk_size >= seq_len the whole sequence is one | ||||||
| # chunk, but ChunkedDistributedLogprob still saves the raw | ||||||
| # vocab_parallel_logits and recomputes softmax in backward (~3x peak memory | ||||||
| # vs DistributedLogprob's ~2x), so chunking actively hurts in that regime. | ||||||
| seq_len_local = vocab_parallel_logits.shape[1] | ||||||
| if chunk_size is not None and chunk_size < seq_len_local: | ||||||
| logprobs: torch.Tensor = ChunkedDistributedLogprob.apply( # type: ignore | ||||||
| vocab_parallel_logits, | ||||||
| target, | ||||||
|
|
@@ -378,8 +402,15 @@ def from_parallel_logits_to_logprobs_packed_sequences( | |||||
| rolled_targets = rolled_targets.unsqueeze(0) | ||||||
| vocab_parallel_logits = vocab_parallel_logits.unsqueeze(0) | ||||||
|
|
||||||
| # Apply distributed log probability computation | ||||||
| if chunk_size is not None: | ||||||
| # Apply distributed log probability computation. | ||||||
| # | ||||||
| # Only use the chunked path when chunking actually splits the sequence into | ||||||
| # multiple chunks. When chunk_size >= seq_len the whole sequence is one | ||||||
| # chunk, but ChunkedDistributedLogprob still saves the raw | ||||||
| # vocab_parallel_logits and recomputes softmax in backward (~3x peak memory | ||||||
| # vs DistributedLogprob's ~2x), so chunking actively hurts in that regime. | ||||||
| seq_len_local = vocab_parallel_logits.shape[1] | ||||||
| if chunk_size is not None and chunk_size < seq_len_local: | ||||||
|
Contributor
Similar to the change in
Suggested change
|
||||||
| probs: torch.Tensor = ChunkedDistributedLogprob.apply( # type: ignore | ||||||
| vocab_parallel_logits, | ||||||
| rolled_targets, | ||||||
|
|
||||||
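The scatter-add formulation in the `backward` hunk above can be illustrated without PyTorch. The following is a minimal pure-Python sketch, not the PR's implementation: the helper names (`softmax`, `grad_one_hot`, `grad_scatter_add`) and the nested-list layout are assumptions made for illustration. Following the removed lines in the diff, it takes the per-position gradient to be `grad_output * (one_hot(target) - softmax)`, with the one-hot term dropped where `target_mask` is set. It then checks that filling a flat buffer with `-softmax * grad_output` and adding `grad_output` at offset `(b * chunk_len + s) * partition_vocab_size + target` yields the same values.

```python
import math


def softmax(row):
    """Numerically stable softmax over one vocab row."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def grad_one_hot(logits, targets, masked, grad_out):
    """Reference path: build an explicit one-hot row per position."""
    grads = []
    for b, batch in enumerate(logits):
        out_b = []
        for s, row in enumerate(batch):
            probs = softmax(row)
            # 1.0 at the chosen token, 0.0 elsewhere; all zeros if masked.
            chosen = [
                0.0 if masked[b][s] else float(v == targets[b][s])
                for v in range(len(row))
            ]
            out_b.append([grad_out[b][s] * (c - p) for c, p in zip(chosen, probs)])
        grads.append(out_b)
    return grads


def grad_scatter_add(logits, targets, masked, grad_out):
    """Flat-buffer path: -softmax * g everywhere, then add g at chosen offsets."""
    batch_size, chunk_len, vocab = len(logits), len(logits[0]), len(logits[0][0])
    flat = []
    for b in range(batch_size):
        for s in range(chunk_len):
            flat.extend(-p * grad_out[b][s] for p in softmax(logits[b][s]))
    for b in range(batch_size):
        for s in range(chunk_len):
            if not masked[b][s]:
                # Start of row [b, s, :] in the flat buffer, plus the target id.
                flat[(b * chunk_len + s) * vocab + targets[b][s]] += grad_out[b][s]
    # Reshape the flat buffer back to [batch, chunk_len, vocab].
    return [
        [flat[(b * chunk_len + s) * vocab:(b * chunk_len + s + 1) * vocab]
         for s in range(chunk_len)]
        for b in range(batch_size)
    ]
```

The flat-buffer path never allocates a `[B, chunk_len, vocab]` indicator array; it only touches one element per unmasked position, which is the memory saving the PR's code comment describes.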
187 changes: 187 additions & 0 deletions
tests/backends/skyrl_train/gpu/gpu_ci/megatron/test_chunked_logprob_backward.py
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,187 @@ | ||
| """ | ||
| uv run --isolated --extra dev --extra megatron -- pytest -s tests/backends/skyrl_train/gpu/gpu_ci/megatron/test_chunked_logprob_backward.py | ||
| """ | ||
|
|
||
| import os | ||
|
|
||
| import pytest | ||
| import torch | ||
| import torch.distributed as dist | ||
|
|
||
| from skyrl.backends.skyrl_train.distributed.megatron.model_utils import ( | ||
| ChunkedDistributedLogprob, | ||
| DistributedLogprob, | ||
| ) | ||
| from skyrl.train.utils.utils import get_free_port | ||
|
|
||
|
|
||
| @pytest.fixture(scope="module") | ||
| def tp_group(): | ||
| """Single-rank TP process group used by both autograd functions. | ||
|
|
||
| Uses gloo distributed backend for simplicity because the world size is 1. | ||
| """ | ||
| if not dist.is_initialized(): | ||
| os.environ["MASTER_ADDR"] = "localhost" | ||
| os.environ["MASTER_PORT"] = str(get_free_port()) | ||
| os.environ["RANK"] = "0" | ||
| os.environ["WORLD_SIZE"] = "1" | ||
| dist.init_process_group(backend="gloo", rank=0, world_size=1) | ||
| yield dist.group.WORLD | ||
| if dist.is_initialized(): | ||
| dist.destroy_process_group() | ||
|
|
||
|
|
||
| def _forward_backward(func_cls, logits, target, vocab_start, vocab_end, tp_group, *, chunk_size=None): | ||
| """Run forward+backward through a logprob autograd function and return (out, grad_logits). | ||
|
|
||
| Uses a non-uniform upstream gradient so that any per-position bug surfaces. | ||
| """ | ||
| leaf = logits.detach().clone().requires_grad_(True) | ||
| if chunk_size is None: | ||
| out = func_cls.apply(leaf, target, vocab_start, vocab_end, tp_group, False) | ||
| else: | ||
| out = func_cls.apply(leaf, target, vocab_start, vocab_end, chunk_size, tp_group, False) | ||
| grad_seed = torch.linspace(0.5, 1.5, steps=out.numel(), device=out.device, dtype=out.dtype).reshape(out.shape) | ||
| out.backward(grad_seed) | ||
| return out.detach(), leaf.grad.detach() | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("chunk_size", [1, 7, 16, 64, 512]) | ||
| @pytest.mark.parametrize("with_oov_targets", [False, True]) | ||
| def test_chunked_matches_non_chunked(tp_group, chunk_size, with_oov_targets): | ||
| """Chunked and non-chunked logprob produce matching forwards and gradients. | ||
|
|
||
| Sweeps several chunk sizes (including one larger than the sequence length, | ||
| which collapses the chunk loop to a single iteration) and toggles whether | ||
| targets can fall outside the TP rank's vocab slice. The out-of-vocab path | ||
| exercises the ``target_mask`` branch in both functions. | ||
| """ | ||
| device = torch.device("cuda") | ||
| torch.manual_seed(0) | ||
|
|
||
| batch_size = 4 | ||
| seq_len = 32 | ||
| vocab_size = 32_000 | ||
|
|
||
| target_high = vocab_size + 1024 if with_oov_targets else vocab_size | ||
|
|
||
| logits = torch.randn(batch_size, seq_len, vocab_size, dtype=torch.bfloat16, device=device) * 2.0 | ||
| target = torch.randint(0, target_high, (batch_size, seq_len), device=device, dtype=torch.long) | ||
|
|
||
| out_ref, grad_ref = _forward_backward(DistributedLogprob, logits, target, 0, vocab_size, tp_group) | ||
| out_chunk, grad_chunk = _forward_backward( | ||
| ChunkedDistributedLogprob, | ||
| logits, | ||
| target, | ||
| 0, | ||
| vocab_size, | ||
| tp_group, | ||
| chunk_size=chunk_size, | ||
| ) | ||
|
|
||
| # Output dtype contract: log-softmax upcasts to fp32 internally. | ||
| assert out_chunk.dtype == torch.float32 | ||
|
|
||
| # Forward parity. Both paths do the same fp32 math, but for small chunk sizes | ||
| # the reduction order across chunks differs from the single-shot path, which | ||
| # introduces ~1e-6 rounding noise. A loose tolerance still rules out real bugs. | ||
| torch.testing.assert_close(out_chunk, out_ref, atol=1e-5, rtol=1e-5) | ||
|
|
||
| # Gradient parity. Both paths use the same scatter-add formulation, so the | ||
| # tolerance can be tight relative to the bf16-logits/fp32-grad pipeline. | ||
| torch.testing.assert_close(grad_chunk, grad_ref, atol=1e-5, rtol=1e-4) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "case", | ||
| [ | ||
| # (batch, seq_len, vocab, chunk_size, mask_mode) | ||
| # mask_mode: "default" (mixed), "all_in" (no OOV), "all_out" (all OOV) | ||
| pytest.param((1, 1, 1024, 4, "default"), id="seq1"), | ||
| pytest.param((2, 8, 1024, 32, "all_in"), id="all_in_vocab"), | ||
| pytest.param((2, 8, 1024, 32, "all_out"), id="all_out_vocab"), | ||
| pytest.param((2, 8, 8, 4, "default"), id="tiny_vocab"), | ||
| ], | ||
| ) | ||
| def test_chunked_matches_non_chunked_edge_cases(tp_group, case): | ||
| """Edge cases for the chunked path: short sequences, mask extremes, tiny vocab. | ||
|
|
||
| Covers configurations the main sweep does not: ``seq_len=1`` (chunk loop runs | ||
| once with ``chunk_len=1``), a target_mask that is entirely False or entirely | ||
| True (the empty-``scatter_add_`` path), and a very small vocab that stresses | ||
| the masked-select/scatter arithmetic. | ||
| """ | ||
| batch_size, seq_len, vocab_size, chunk_size, mask_mode = case | ||
| device = torch.device("cuda") | ||
| torch.manual_seed(1) | ||
|
|
||
| logits = torch.randn(batch_size, seq_len, vocab_size, dtype=torch.bfloat16, device=device) * 2.0 | ||
| if mask_mode == "all_out": | ||
| # Every target is outside the TP rank's vocab slice [0, vocab_size). | ||
| target = torch.full((batch_size, seq_len), vocab_size + 5, device=device, dtype=torch.long) | ||
| else: | ||
| target = torch.randint(0, vocab_size, (batch_size, seq_len), device=device, dtype=torch.long) | ||
|
|
||
| out_ref, grad_ref = _forward_backward(DistributedLogprob, logits, target, 0, vocab_size, tp_group) | ||
| out_chunk, grad_chunk = _forward_backward( | ||
| ChunkedDistributedLogprob, | ||
| logits, | ||
| target, | ||
| 0, | ||
| vocab_size, | ||
| tp_group, | ||
| chunk_size=chunk_size, | ||
| ) | ||
|
|
||
| torch.testing.assert_close(out_chunk, out_ref, atol=1e-5, rtol=1e-5) | ||
| torch.testing.assert_close(grad_chunk, grad_ref, atol=1e-5, rtol=1e-4) | ||
|
|
||
|
|
||
| def test_chunked_backward_uses_scatter_add_path(tp_group): | ||
| """Chunked backward uses ``scatter_add_`` and does not materialize one_hot. | ||
|
|
||
| Asserts the implementation contract directly: ``scatter_add_`` must be | ||
| invoked during backward, and ``torch.nn.functional.one_hot`` must not. | ||
| Verifying the code path via mocks is more reliable than a peak-memory | ||
| assertion, which is sensitive to caching allocator behaviour and other | ||
| in-flight tensors. | ||
|
|
||
| We patch ``F.one_hot`` after ``forward`` has finished (forward doesn't use | ||
| it either; restricting the patch window keeps the assertion tight to the | ||
| backward path). | ||
| """ | ||
| from unittest.mock import patch | ||
|
|
||
| device = torch.device("cuda") | ||
| torch.manual_seed(0) | ||
|
|
||
| # chunk_size >= seq_len collapses the chunk loop to a single iteration -- | ||
| # the path most likely to regress to a one_hot formulation. | ||
| batch_size = 2 | ||
| seq_len = 16 | ||
| vocab_size = 1024 | ||
| chunk_size = 1024 | ||
|
|
||
| logits = torch.randn(batch_size, seq_len, vocab_size, dtype=torch.bfloat16, device=device, requires_grad=True) | ||
| target = torch.randint(0, vocab_size, (batch_size, seq_len), device=device, dtype=torch.long) | ||
|
|
||
| out = ChunkedDistributedLogprob.apply(logits, target, 0, vocab_size, chunk_size, tp_group, False) | ||
|
|
||
| real_scatter_add_ = torch.Tensor.scatter_add_ | ||
| scatter_add_calls = [] | ||
|
|
||
| def _tracking_scatter_add_(self, dim, index, src): | ||
| scatter_add_calls.append((tuple(self.shape), int(dim), tuple(index.shape))) | ||
| return real_scatter_add_(self, dim, index, src) | ||
|
|
||
| with ( | ||
| patch( | ||
| "torch.nn.functional.one_hot", | ||
| side_effect=AssertionError("one_hot must not be called in chunked backward"), | ||
| ), | ||
| patch.object(torch.Tensor, "scatter_add_", _tracking_scatter_add_), | ||
| ): | ||
| out.sum().backward() | ||
|
|
||
| assert scatter_add_calls, "Chunked backward must call scatter_add_ to place chosen-token grads" |
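The mock-based check in `test_chunked_backward_uses_scatter_add_path` follows a general pattern: wrap the method that must be called so calls are recorded, and replace the method that must not be called with one that raises. A minimal standard-library sketch of that pattern is shown here. The `Accumulator` class, its `add_` and `slow_path` methods, and `run_backward` are hypothetical stand-ins for the tensor and autograd code, not part of SkyRL or PyTorch.

```python
from unittest.mock import patch


class Accumulator:
    """Toy stand-in for a tensor (hypothetical, for illustration only)."""

    def __init__(self):
        self.total = 0.0

    def add_(self, value):
        # The code path that is expected to run.
        self.total += value
        return self

    def slow_path(self, value):
        # The code path that must not run.
        self.total += value
        return self


def run_backward(acc):
    """Toy 'backward' that should use add_ and never slow_path."""
    acc.add_(1.5)
    acc.add_(2.5)
    return acc.total


def check_code_path():
    real_add_ = Accumulator.add_
    calls = []

    def tracking_add_(self, value):
        # Record the call, then delegate to the real method.
        calls.append(value)
        return real_add_(self, value)

    # Any call to slow_path inside the block raises AssertionError.
    with patch.object(
        Accumulator, "slow_path", side_effect=AssertionError("slow_path must not be called")
    ):
        with patch.object(Accumulator, "add_", tracking_add_):
            total = run_backward(Accumulator())
    return calls, total
```

Both patches are undone when the `with` blocks exit, so the patch window stays limited to the code under test, as the test's docstring recommends for the backward pass.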
While `chunk_size` is typically positive, it's safer to ensure it is greater than zero to avoid a potential `ZeroDivisionError` when calculating `num_chunks` inside `ChunkedDistributedLogprob.forward` (line 167).
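One way to combine the reviewer's suggestion with the PR's dispatch rule is sketched here. This is an assumption-laden illustration rather than the merged code: the helper names `num_chunks_for` and `use_chunked_path` are hypothetical, and the ceil-division used for the chunk count is assumed, since the body of `ChunkedDistributedLogprob.forward` is not shown on this page.

```python
from typing import Optional


def num_chunks_for(seq_len: int, chunk_size: int) -> int:
    """Ceil-divide seq_len by chunk_size (assumed formula), rejecting bad sizes."""
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    return (seq_len + chunk_size - 1) // chunk_size


def use_chunked_path(chunk_size: Optional[int], seq_len: int) -> bool:
    """True only when chunking splits the sequence into more than one chunk.

    A chunk_size of None, zero, a negative value, or anything >= seq_len
    falls back to the non-chunked path.
    """
    return chunk_size is not None and 0 < chunk_size < seq_len
```

With this guard, `chunk_size=0` routes to the non-chunked path instead of reaching the division, and `chunk_size >= seq_len` skips chunking as the PR intends.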