
Commit de91932

use own rotary embedding library for relative positions, and take the strategy of treating a single modality made up of multiple tokens as one position

1 parent bec21b0 commit de91932

File tree

2 files changed: +46 -5 lines

2 files changed

+46
-5
lines changed

pyproject.toml

1 addition & 0 deletions

@@ -27,6 +27,7 @@ dependencies = [
     'einx>=0.3.0',
     'einops>=0.8.0',
     'jaxtyping',
+    'rotary_embedding_torch',
     'torch>=2.0',
 ]

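The added dependency is the author's own rotary-embedding-torch package on PyPI (pip normalizes the underscore spelling used above); for an existing checkout it can also be installed directly:

pip install rotary-embedding-torch
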
transfusion_pytorch/transfusion.py

45 additions & 5 deletions

@@ -26,6 +26,8 @@
 
 from transfusion_pytorch.tensor_typing import Float, Int, Bool
 
+from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
+
 pad_sequence = partial(pad_sequence, batch_first = True)
 
 # constants
@@ -109,6 +111,23 @@ def order_modality_positions_by_seq_offset(
 
     return modalities, sorted_indices
 
+# deriving relative positions from modality positions
+# ex. given a sequence of 10 with an image at offset 3 with length 4 - [t] [t] [t] [i] [i] [i] [i] [t] [t] [t]
+# relative positions for rotary will be [0] [1] [2] [3] [3] [3] [3] [4] [5] [6]
+# rationale is that each modality will need the same position so there is no distance when conducting bidirectional attention, but should still have a relative distance to other text tokens and modalities
+
+def derive_rotary_positions_from_modality_positions(
+    seq_len: int,
+    modalities: Int['b m 2']
+) -> Int['b n']:
+
+    device = modalities.device
+
+    modality_mask = modality_positions_to_is_modality_mask(seq_len, modalities, offset = torch.tensor([1, -1]))
+    is_any_modality = modality_mask.any(dim = 1)
+
+    return torch.arange(seq_len, device = device) - is_any_modality.cumsum(dim = -1)
+
 # modality tokens are given as list of tensors, can be then be embedded into the modality tokens for attending alongside text tokens
 
 def embed_modality_tokens(
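
To see the scheme in action, a quick check (assuming the package is installed with this commit applied) reproduces the example documented in the comments above:

import torch
from transfusion_pytorch.transfusion import derive_rotary_positions_from_modality_positions

# one sample with a single image modality at offset 3, length 4, inside a length-10 sequence
modality_positions = torch.tensor([[[3, 4]]])

print(derive_rotary_positions_from_modality_positions(10, modality_positions))
# tensor([[0, 1, 2, 3, 3, 3, 3, 4, 5, 6]])
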
@@ -139,10 +158,14 @@ def embed_modality_tokens(
 def modality_positions_to_is_modality_mask(
     seq_len: int,
     modalities: RawModalityPositions | Int['b m 2'],
+    offset: Int['2'] | None = None,
 ) -> Bool['b m n']:
 
     if isinstance(modalities, list):
-        modalities = modalities_to_tensor(modalities)
+        modalities = modality_positions_to_tensor(modalities)
+
+    if exists(offset):
+        modalities = modalities + offset.to(modalities)
 
     left, right = modalities.cumsum(dim = -1).unbind(dim = -1)

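The new offset argument broadcasts against each (offset, length) pair before the cumulative sum, so the torch.tensor([1, -1]) passed in by the new position helper simply trims the first token off every modality span while leaving its right boundary unchanged. A tiny standalone check (values are illustrative):

import torch

modalities = torch.tensor([[3, 4]])             # one modality: starts at 3, length 4
offset = torch.tensor([1, -1])

print(modalities + offset)                      # tensor([[4, 3]]) -> span now starts at 4, length 3
print((modalities + offset).cumsum(dim = -1))   # tensor([[4, 7]]) -> left / right span boundaries

That excluded first token still advances the running position counter, which is what pins every remaining token of the modality onto a single shared position.
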
@@ -322,12 +345,16 @@ def __init__(
     def forward(
         self,
         x,
-        attn_mask = None
+        attn_mask = None,
+        rotary_emb: Tensor | None = None
     ):
         x = self.norm(x)
 
         q, k, v = self.to_qkv(x)
 
+        if exists(rotary_emb):
+            q, k = tuple(apply_rotary_emb(rotary_emb, t) for t in (q, k))
+
         q = q * self.scale
         sim = einsum(q, k, 'b h i d, b h j d -> b h i j')

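For reference, apply_rotary_emb comes from the same rotary-embedding-torch library and rotates query/key feature pairs by position-dependent angles, so the attention logits end up depending only on relative position. A minimal standalone sketch of the pattern (shapes and values are illustrative, not taken from the repository):

import torch
from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb

batch, heads, seq_len, dim_head = 1, 8, 10, 64      # illustrative shapes

rotary = RotaryEmbedding(dim_head)

# modality-aware positions: the four image tokens share position 3
positions = torch.tensor([0, 1, 2, 3, 3, 3, 3, 4, 5, 6])
freqs = rotary(positions)

q = torch.randn(batch, heads, seq_len, dim_head)
k = torch.randn(batch, heads, seq_len, dim_head)

# rotate both queries and keys; tokens that share a position get zero relative rotation between them
q, k = tuple(apply_rotary_emb(freqs, t) for t in (q, k))
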
@@ -360,6 +387,7 @@ def __init__(
     ):
         super().__init__()
         self.dim = dim
+        self.dim_head = dim_head
 
         self.to_time_cond = nn.Sequential(
             RandomFourierEmbed(dim),
@@ -388,7 +416,8 @@ def forward(
         times: Float[''] | Float['b'] | Float['b n'],
         attn_mask: Bool['b i j'] | None = None,
         modality_positions: RawModalityPositions | Int['b n 2'] | None = None,
-        is_any_modality: Bool['b n'] | None = None
+        is_any_modality: Bool['b n'] | None = None,
+        rotary_emb: Tensor | None = None
     ):
         seq_len, device = x.shape[-2], x.device
         assert exists(attn_mask) ^ exists(modality_positions)
@@ -414,7 +443,7 @@ def forward(
         # transformer layers as usual, using mask from above
 
         for attn, ff in self.layers:
-            x = attn(x, attn_mask = attn_mask, **adaptive_kwargs) + x
+            x = attn(x, attn_mask = attn_mask, rotary_emb = rotary_emb, **adaptive_kwargs) + x
             x = ff(x, **adaptive_kwargs) + x
 
         return self.norm(x)
@@ -438,9 +467,13 @@ def __init__(
         transformer = Transformer(**transformer)
 
         self.transformer = transformer
-        dim = transformer.dim
+        dim, dim_head = transformer.dim, transformer.dim_head
         self.dim = dim
 
+        # relative positions
+
+        self.rotary_emb = RotaryEmbedding(transformer.dim_head)
+
         # embeddings and un-embeddings
 
         self.text_embed = nn.Embedding(num_text_tokens, dim)
@@ -519,11 +552,18 @@ def forward(
 
         tokens = einx.where('b n, b n d, b n d', is_any_modality, modality_tokens, text_tokens)
 
+        # derive rotary positions
+
+        rotary_positions = derive_rotary_positions_from_modality_positions(seq_len, modality_positions)
+
+        rotary_emb = self.rotary_emb(rotary_positions)
+
         # attention
 
         embed = self.transformer(
             tokens,
             times = times,
+            rotary_emb = rotary_emb,
             modality_positions = modality_positions
         )
