@@ -184,24 +184,21 @@ def modality_positions_to_is_modality_mask(
     offset = F.pad(offset, (1, 0))
     modalities = modalities + offset.to(modalities)
 
-    is_modalities = []
-
-    for modality_id in range(num_modalities):
-        one_modality_type_mask = modalities[..., 0] == modality_id
-        one_modality = modalities.masked_fill(~one_modality_type_mask[..., None], 0)
+    seq = torch.arange(seq_len, device = device)
+    type_seq = torch.arange(num_modalities, device = device)
 
-        left, right = one_modality[..., 1:].cumsum(dim = -1).unbind(dim = -1)
+    modality_types = modalities[..., 0]
 
-        seq = torch.arange(seq_len, device = device)
+    left, right = modalities[..., 1:].cumsum(dim = -1).unbind(dim = -1)
 
-        is_modality = (
-            einx.greater_equal('i, b m -> b m i', seq, left) &
-            einx.less('j, b m -> b m j', seq, right)
-        )
+    is_instance_for_type = einx.equal('b m, t -> b t m', modality_types, type_seq)
 
-        is_modalities.append(is_modality)
+    is_modality_along_seq = (
+        einx.greater_equal('i, b m -> b m i', seq, left) &
+        einx.less('j, b m -> b m j', seq, right)
+    )
 
-    return torch.stack(is_modalities, dim = 1)
+    return einx.logical_and('b t m, b m n -> b t m n', is_instance_for_type, is_modality_along_seq)
 
 def naive_attn_mask(
     seq_len: int,
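To make the vectorized version concrete, here is a minimal sketch (not part of the commit) that reproduces the same mask construction with plain PyTorch broadcasting in place of the einx calls. The `is_modality_mask` helper name is hypothetical, and `modalities` is assumed to already hold `(modality_type, left, right)` triples per instance, i.e. the state after the offset/cumsum bookkeeping above.

# Minimal sketch (not from the commit): the same mask construction as the
# new vectorized code above, written with plain PyTorch broadcasting in
# place of the einx calls. `modalities` is assumed to already hold
# (modality_type, left, right) triples; the helper name is hypothetical.
import torch

def is_modality_mask(modalities, num_modalities, seq_len):
    # modalities: (batch b, instances m, 3) of (modality_type, left, right)
    device = modalities.device
    seq = torch.arange(seq_len, device = device)              # (n,)
    type_seq = torch.arange(num_modalities, device = device)  # (t,)

    modality_types = modalities[..., 0]                       # (b, m)
    left, right = modalities[..., 1], modalities[..., 2]      # (b, m)

    # einx.equal('b m, t -> b t m', ...): which instances are of each type
    is_instance_for_type = modality_types[:, None, :] == type_seq[None, :, None]

    # einx.greater_equal / einx.less ('i, b m -> b m i'): which sequence
    # positions fall inside each instance's half-open [left, right) span
    is_modality_along_seq = (seq >= left[..., None]) & (seq < right[..., None])

    # einx.logical_and('b t m, b m n -> b t m n', ...): combine both masks
    return is_instance_for_type[..., None] & is_modality_along_seq[:, None, :, :]

# one batch; type 0 occupies positions [1, 4), type 1 occupies [5, 7)
modalities = torch.tensor([[[0, 1, 4], [1, 5, 7]]])
mask = is_modality_mask(modalities, num_modalities = 2, seq_len = 8)
print(mask.shape)            # torch.Size([1, 2, 2, 8]) -> (b, t, m, n)
print(mask[0, 0, 0].long())  # tensor([0, 1, 1, 1, 0, 0, 0, 0])
print(mask[0, 1, 1].long())  # tensor([0, 0, 0, 0, 0, 1, 1, 0])

The key change in the commit is that the per-type Python loop is replaced by one extra `t` axis: type membership (`b t m`) and span membership (`b m n`) are computed once each and combined into a single `b t m n` mask, so no `torch.stack` over per-type results is needed.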