diff --git a/roll/distributed/strategy/megatron_strategy.py b/roll/distributed/strategy/megatron_strategy.py index 4eeb4cc74..0c696ce94 100644 --- a/roll/distributed/strategy/megatron_strategy.py +++ b/roll/distributed/strategy/megatron_strategy.py @@ -341,10 +341,13 @@ def _pack_sequences(self, input_tensor, attention_mask, pad_packed_seq_to=None, seq_tokens = input_tensor_unpadded[b] # Pad sequence if needed - if padded_seq_len > seq_len: - seq_tokens = torch.nn.functional.pad( - seq_tokens, (0, padded_seq_len - seq_len), value=pad_val - ) + pad_seq_len = padded_seq_len - seq_len + if pad_seq_len > 0: + if seq_tokens.dim() == 1: + pad_config = (0, pad_seq_len) + else: + pad_config = tuple([0, 0] * (seq_tokens.dim() - 1) + [0, pad_seq_len]) + seq_tokens = torch.nn.functional.pad(seq_tokens, pad_config, value=pad_val) all_input_tensor_padded.append(seq_tokens) if cp_size > 1: @@ -411,6 +414,12 @@ def inner_forward_step(self, loss_func, data_iterator: Iterator[DataProto], mode attention_mask = data.batch["attention_mask"] labels = data.batch["labels"] if "labels" in data.batch else None # labels is only used for sft packed_seq_params = None + position_ids = None + raw_position_ids = None + if "position_ids" in data.batch.keys() and data.batch["position_ids"].dim() == 3: # qwen2vl/qwen3vl mrope + raw_position_ids = data.batch["position_ids"] + if raw_position_ids.size(1) == 4: + raw_position_ids = raw_position_ids[:, 1:, :].contiguous() # (bsz, 4, seqlen) -> (bsz, 3, seqlen) if self.use_sequence_packing: input_ids, packed_seq_params, cu_seqlens, cu_seqlens_padded = self._pack_sequences( @@ -418,13 +427,19 @@ def inner_forward_step(self, loss_func, data_iterator: Iterator[DataProto], mode ) if labels is not None: labels, _, _, _ = self._pack_sequences(labels, attention_mask, pad_val=IGNORE_INDEX) + if raw_position_ids is not None: + position_ids_for_pack = raw_position_ids.transpose(1, 2).contiguous() # (bsz, C, seqlen) -> (bsz, seqlen, C) + packed_position_ids, _, _, _ = self._pack_sequences(position_ids_for_pack, attention_mask, pad_val=0) + position_ids = packed_position_ids.squeeze(0).transpose(0, 1).unsqueeze(1).contiguous() attention_mask = None else: input_ids = self._get_feature_on_this_cp_rank(input_ids, "input_ids") attention_mask = self._get_feature_on_this_cp_rank(attention_mask, "attention_mask") if labels is not None: labels = self._get_feature_on_this_cp_rank(labels, "labels") - position_ids = None + if raw_position_ids is not None: + attention_mask = None + position_ids = raw_position_ids.transpose(0, 1) # (bsz, C, seqlen) -> (C, bsz, seqlen) # attention_mask: SelfAttention defalt to te DotProductAttention with # AttnMaskType.causal in which attention_mask would not be used, pass # it mainly for moe aux loss without pad token and it is 2D