Skip to content

Commit 05ac097

Browse files
committed
support multiple modalities again, except simpler
1 parent d4ad918 commit 05ac097

File tree

3 files changed

+48
-7
lines changed

3 files changed

+48
-7
lines changed

README.md

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,33 @@ loss = model(text_and_images)
3939
loss.backward()
4040
```
4141

42+
Multiple different modalities
43+
44+
```python
45+
from torch import randint, randn
46+
from transfusion_pytorch import Transfusion
47+
48+
model = Transfusion(
49+
num_text_tokens = 256,
50+
dim_latent = (384, 192), # specify multiple latent dimensions
51+
transformer = dict(
52+
dim = 512,
53+
depth = 8
54+
)
55+
)
56+
57+
# then for the Tensors of type float, you can pass a tuple[int, Tensor] and specify the modality index in the first position
58+
59+
text_images_and_audio = [
60+
[randint(0, 256, (16,)), (0, randn(4, 384)), randint(0, 256, (8,)), (1, randn(6, 192))],
61+
[randint(0, 256, (16,)), randn(7, 384), randint(0, 256, (5,)), (1, randn(2, 192)), randint(0, 256, (9,))]
62+
]
63+
64+
loss = model(text_images_and_audio)
65+
66+
loss.backward()
67+
```
68+
4269
## Citations
4370

4471
```bibtex

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "transfusion-pytorch"
3-
version = "0.0.15"
3+
version = "0.0.16"
44
description = "Transfusion in Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

transfusion_pytorch/transfusion.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,7 @@ def device(self):
591591

592592
def forward(
593593
self,
594-
modalities: list[list[Int['_'] | Float['_ _']]],
594+
modalities: list[list[Int['_'] | Float['_ _'] | tuple[int, Float['_ _']]]],
595595
times: Float['b m'] | None = None,
596596
return_loss = True,
597597
return_breakdown = False
@@ -614,17 +614,31 @@ def forward(
614614
batch_text = []
615615

616616
for modality in batch_modalities:
617-
is_text = modality.dtype in (torch.long, torch.int)
617+
# if non-text modality detected and not given as a tuple
618+
# cast to (int, Tensor) where int is defaulted to type 0 (convenience for one modality)
619+
620+
if torch.is_tensor(modality) and modality.dtype == torch.float:
621+
modality = (0, modality)
622+
623+
is_text = not isinstance(modality, tuple)
624+
625+
if is_text:
626+
modality_tensor = modality
627+
else:
628+
modality_type, modality_tensor = modality
629+
630+
assert 0 <= modality_type < self.num_modalities, f'received a modality index that is out of range. only {self.num_modalities} modalities specified'
631+
assert self.dim_latents[modality_type] == modality_tensor.shape[-1], f'mismatch for modality latent dimension - expected {self.dim_latents[modality_type]} but received {modality_tensor.shape[-1]}'
618632

619-
length = modality.shape[0]
620633
offset = 0
634+
length = modality_tensor.shape[0]
621635

622636
if is_text:
623-
batch_text.append(modality)
637+
batch_text.append(modality_tensor)
624638
else:
625639
batch_text.append(torch.full((length,), -1, device = device))
626-
batch_modality_tokens.append(modality)
627-
batch_modality_positions.append((offset, length))
640+
batch_modality_tokens.append(modality_tensor)
641+
batch_modality_positions.append((modality_type, offset, length))
628642

629643
offset += length
630644

0 commit comments

Comments
 (0)