2626
2727from transfusion_pytorch .tensor_typing import Float , Int , Bool
2828
29+ from rotary_embedding_torch import RotaryEmbedding , apply_rotary_emb
30+
2931pad_sequence = partial (pad_sequence , batch_first = True )
3032
3133# constants
@@ -109,6 +111,23 @@ def order_modality_positions_by_seq_offset(
109111
110112 return modalities , sorted_indices
111113
# deriving relative positions from modality positions

def derive_rotary_positions_from_modality_positions(
    seq_len: int,
    modalities: Int['b m 2']
) -> Int['b n']:
    """Collapse each modality span onto a single rotary position.

    Every token inside a modality shares one position, so there is zero
    relative distance within the span when conducting bidirectional
    attention, while text tokens (and the spans themselves) keep ordinary
    relative offsets to everything else.

    ex. a sequence of 10 with an image at offset 3, length 4 -
        tokens:    [t] [t] [t] [i] [i] [i] [i] [t] [t] [t]
        positions: [0] [1] [2] [3] [3] [3] [3] [4] [5] [6]
    """

    device = modalities.device

    # offset = [1, -1] shifts each (offset, length) span start by one token,
    # leaving the span end unchanged - the first modality token is excluded
    # from the mask so it still advances the position counter once
    per_modality_mask = modality_positions_to_is_modality_mask(
        seq_len,
        modalities,
        offset = torch.tensor([1, -1])
    )

    is_any_modality = per_modality_mask.any(dim = 1)

    # absolute indices minus a running count of "frozen" modality positions
    absolute_positions = torch.arange(seq_len, device = device)
    return absolute_positions - is_any_modality.cumsum(dim = -1)
130+
112131# modality tokens are given as list of tensors, can be then be embedded into the modality tokens for attending alongside text tokens
113132
114133def embed_modality_tokens (
@@ -139,10 +158,14 @@ def embed_modality_tokens(
139158def modality_positions_to_is_modality_mask (
140159 seq_len : int ,
141160 modalities : RawModalityPositions | Int ['b m 2' ],
161+ offset : Int ['2' ] | None = None ,
142162) -> Bool ['b m n' ]:
143163
144164 if isinstance (modalities , list ):
145- modalities = modalities_to_tensor (modalities )
165+ modalities = modality_positions_to_tensor (modalities )
166+
167+ if exists (offset ):
168+ modalities = modalities + offset .to (modalities )
146169
147170 left , right = modalities .cumsum (dim = - 1 ).unbind (dim = - 1 )
148171
@@ -322,12 +345,16 @@ def __init__(
322345 def forward (
323346 self ,
324347 x ,
325- attn_mask = None
348+ attn_mask = None ,
349+ rotary_emb : Tensor | None = None
326350 ):
327351 x = self .norm (x )
328352
329353 q , k , v = self .to_qkv (x )
330354
355+ if exists (rotary_emb ):
356+ q , k = tuple (apply_rotary_emb (rotary_emb , t ) for t in (q , k ))
357+
331358 q = q * self .scale
332359 sim = einsum (q , k , 'b h i d, b h j d -> b h i j' )
333360
@@ -360,6 +387,7 @@ def __init__(
360387 ):
361388 super ().__init__ ()
362389 self .dim = dim
390+ self .dim_head = dim_head
363391
364392 self .to_time_cond = nn .Sequential (
365393 RandomFourierEmbed (dim ),
@@ -388,7 +416,8 @@ def forward(
388416 times : Float ['' ] | Float ['b' ] | Float ['b n' ],
389417 attn_mask : Bool ['b i j' ] | None = None ,
390418 modality_positions : RawModalityPositions | Int ['b n 2' ] | None = None ,
391- is_any_modality : Bool ['b n' ] | None = None
419+ is_any_modality : Bool ['b n' ] | None = None ,
420+ rotary_emb : Tensor | None = None
392421 ):
393422 seq_len , device = x .shape [- 2 ], x .device
394423 assert exists (attn_mask ) ^ exists (modality_positions )
@@ -414,7 +443,7 @@ def forward(
414443 # transformer layers as usual, using mask from above
415444
416445 for attn , ff in self .layers :
417- x = attn (x , attn_mask = attn_mask , ** adaptive_kwargs ) + x
446+ x = attn (x , attn_mask = attn_mask , rotary_emb = rotary_emb , ** adaptive_kwargs ) + x
418447 x = ff (x , ** adaptive_kwargs ) + x
419448
420449 return self .norm (x )
@@ -438,9 +467,13 @@ def __init__(
438467 transformer = Transformer (** transformer )
439468
440469 self .transformer = transformer
441- dim = transformer .dim
470+ dim , dim_head = transformer .dim , transformer . dim_head
442471 self .dim = dim
443472
473+ # relative positions
474+
475+ self .rotary_emb = RotaryEmbedding (transformer .dim_head )
476+
444477 # embeddings and un-embeddings
445478
446479 self .text_embed = nn .Embedding (num_text_tokens , dim )
@@ -519,11 +552,18 @@ def forward(
519552
520553 tokens = einx .where ('b n, b n d, b n d' , is_any_modality , modality_tokens , text_tokens )
521554
555+ # derive rotary positions
556+
557+ rotary_positions = derive_rotary_positions_from_modality_positions (seq_len , modality_positions )
558+
559+ rotary_emb = self .rotary_emb (rotary_positions )
560+
522561 # attention
523562
524563 embed = self .transformer (
525564 tokens ,
526565 times = times ,
566+ rotary_emb = rotary_emb ,
527567 modality_positions = modality_positions
528568 )
529569
0 commit comments