Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions roll/distributed/strategy/megatron_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,10 +341,13 @@ def _pack_sequences(self, input_tensor, attention_mask, pad_packed_seq_to=None,
seq_tokens = input_tensor_unpadded[b]

# Pad sequence if needed
if padded_seq_len > seq_len:
seq_tokens = torch.nn.functional.pad(
seq_tokens, (0, padded_seq_len - seq_len), value=pad_val
)
pad_seq_len = padded_seq_len - seq_len
if pad_seq_len > 0:
if seq_tokens.dim() == 1:
pad_config = (0, pad_seq_len)
else:
pad_config = tuple([0, 0] * (seq_tokens.dim() - 1) + [0, pad_seq_len])
seq_tokens = torch.nn.functional.pad(seq_tokens, pad_config, value=pad_val)
all_input_tensor_padded.append(seq_tokens)

if cp_size > 1:
Expand Down Expand Up @@ -411,20 +414,32 @@ def inner_forward_step(self, loss_func, data_iterator: Iterator[DataProto], mode
attention_mask = data.batch["attention_mask"]
labels = data.batch["labels"] if "labels" in data.batch else None # labels is only used for sft
packed_seq_params = None
position_ids = None
raw_position_ids = None
if "position_ids" in data.batch.keys() and data.batch["position_ids"].dim() == 3: # qwen2vl/qwen3vl mrope
raw_position_ids = data.batch["position_ids"]
if raw_position_ids.size(1) == 4:
raw_position_ids = raw_position_ids[:, 1:, :].contiguous() # (bsz, 4, seqlen) -> (bsz, 3, seqlen)

if self.use_sequence_packing:
input_ids, packed_seq_params, cu_seqlens, cu_seqlens_padded = self._pack_sequences(
input_ids, attention_mask,
)
if labels is not None:
labels, _, _, _ = self._pack_sequences(labels, attention_mask, pad_val=IGNORE_INDEX)
if raw_position_ids is not None:
position_ids_for_pack = raw_position_ids.transpose(1, 2).contiguous() # (bsz, C, seqlen) -> (bsz, seqlen, C)
packed_position_ids, _, _, _ = self._pack_sequences(position_ids_for_pack, attention_mask, pad_val=0)
position_ids = packed_position_ids.squeeze(0).transpose(0, 1).unsqueeze(1).contiguous()
attention_mask = None
else:
input_ids = self._get_feature_on_this_cp_rank(input_ids, "input_ids")
attention_mask = self._get_feature_on_this_cp_rank(attention_mask, "attention_mask")
if labels is not None:
labels = self._get_feature_on_this_cp_rank(labels, "labels")
position_ids = None
if raw_position_ids is not None:
attention_mask = None
position_ids = raw_position_ids.transpose(0, 1) # (bsz, C, seqlen) -> (C, bsz, seqlen)
# attention_mask: SelfAttention defaults to TE DotProductAttention with
# AttnMaskType.causal, in which case attention_mask is not used; pass
# it mainly for the MoE aux loss without pad tokens, and it is 2D
Expand Down