Skip to content

Commit d4ad918

Browse files
committed
go even simpler
1 parent 30d6888 commit d4ad918

File tree

3 files changed

+54
-71
lines changed

3 files changed

+54
-71
lines changed

README.md

Lines changed: 7 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -17,83 +17,24 @@ $ pip install transfusion-pytorch
1717
One modality, say images
1818

1919
```python
20-
import torch
20+
from torch import randint, randn
2121
from transfusion_pytorch import Transfusion
2222

2323
model = Transfusion(
2424
num_text_tokens = 256,
25-
dim_latent = 192,
25+
dim_latent = 384,
2626
transformer = dict(
2727
dim = 512,
2828
depth = 8
2929
)
3030
)
3131

32-
text_ids = torch.randint(0, 256, (2, 1024))
32+
text_and_images = [
33+
[randint(0, 256, (16,)), randn(4, 384), randint(0, 256, (8,)), randn(6, 384)],
34+
[randint(0, 256, (16,)), randn(7, 384), randint(0, 256, (5,)), randn(2, 384), randint(0, 256, (9,))]
35+
]
3336

34-
modality_tokens = [[
35-
torch.randn(6, 192),
36-
torch.randn(4, 192)
37-
], [
38-
torch.randn(5, 192),
39-
torch.randn(3, 192)
40-
]]
41-
42-
modality_positions = [[
43-
(2, 6),
44-
(10, 4)
45-
], [
46-
(2, 5),
47-
(10, 3)
48-
]] # (offset, length)
49-
50-
loss, breakdown = model(
51-
text_ids,
52-
modality_tokens = modality_tokens,
53-
modality_positions = modality_positions
54-
)
55-
56-
loss.backward()
57-
```
58-
59-
Multiple modalities
60-
61-
```python
62-
import torch
63-
from transfusion_pytorch import Transfusion
64-
65-
model = Transfusion(
66-
num_text_tokens = 256,
67-
dim_latent = (384, 192),
68-
transformer = dict(
69-
dim = 512,
70-
depth = 8
71-
)
72-
)
73-
74-
text_ids = torch.randint(0, 256, (2, 1024))
75-
76-
modality_tokens = [[
77-
torch.randn(6, 384),
78-
torch.randn(4, 192)
79-
], [
80-
torch.randn(5, 192),
81-
torch.randn(3, 384)
82-
]]
83-
84-
modality_positions = [[
85-
(0, 2, 6),
86-
(1, 10, 4)
87-
], [
88-
(1, 2, 5),
89-
(0, 10, 3)
90-
]] # (type, offset, length)
91-
92-
loss, breakdown = model(
93-
text_ids,
94-
modality_tokens = modality_tokens,
95-
modality_positions = modality_positions
96-
)
37+
loss = model(text_and_images)
9738

9839
loss.backward()
9940
```

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "transfusion-pytorch"
3-
version = "0.0.14"
3+
version = "0.0.15"
44
description = "Transfusion in Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

transfusion_pytorch/transfusion.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -585,17 +585,54 @@ def __init__(
585585
self.ignore_index = ignore_index
586586
self.diffusion_loss_weight = diffusion_loss_weight
587587

588+
@property
589+
def device(self):
590+
return next(self.parameters()).device
591+
588592
def forward(
589593
self,
590-
text: Int['b n'],
591-
modality_tokens: list[list[Float['_ _']]] | list[Float['b n _']] | Float['b n _'],
592-
modality_positions: RawModalityPositions | Int['b m 2'] | Int['b m 3'],
594+
modalities: list[list[Int['_'] | Float['_ _']]],
593595
times: Float['b m'] | None = None,
594-
return_loss = True
596+
return_loss = True,
597+
return_breakdown = False
595598
) -> (
596599
Float['b n l'] |
600+
Float[''] |
597601
tuple[Float[''], LossBreakdown]
598602
):
603+
device = self.device
604+
605+
# process list of text and modalities interspersed with one another
606+
607+
modality_positions = []
608+
modality_tokens = []
609+
text = []
610+
611+
for batch_modalities in modalities:
612+
batch_modality_positions = []
613+
batch_modality_tokens = []
614+
batch_text = []
615+
616+
for modality in batch_modalities:
617+
is_text = modality.dtype in (torch.long, torch.int)
618+
619+
length = modality.shape[0]
620+
offset = 0
621+
622+
if is_text:
623+
batch_text.append(modality)
624+
else:
625+
batch_text.append(torch.full((length,), -1, device = device))
626+
batch_modality_tokens.append(modality)
627+
batch_modality_positions.append((offset, length))
628+
629+
offset += length
630+
631+
text.append(torch.cat(batch_text))
632+
modality_tokens.append(batch_modality_tokens)
633+
modality_positions.append(batch_modality_positions)
634+
635+
text = pad_sequence(text, padding_value = -1)
599636

600637
# if returning loss, split text for next token prediction
601638

@@ -653,6 +690,8 @@ def forward(
653690

654691
# embed text
655692

693+
text = text.masked_fill(text == -1, 0)
694+
656695
text_tokens = self.text_embed(text)
657696

658697
# noise the modality tokens
@@ -766,4 +805,7 @@ def forward(
766805
(torch.stack(diffusion_losses) * torch.stack(modality_loss_weights)).sum() * self.diffusion_loss_weight
767806
)
768807

808+
if not return_breakdown:
809+
return total_loss
810+
769811
return total_loss, LossBreakdown(total_loss, text_loss, diffusion_losses)

0 commit comments

Comments (0)