Skip to content

Commit 30d6888

Browse files
committed
0.0.14
1 parent 1782570 commit 30d6888

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "transfusion-pytorch"
3-
version = "0.0.12"
3+
version = "0.0.14"
44
description = "Transfusion in Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

transfusion_pytorch/transfusion.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -554,15 +554,17 @@ def __init__(
554554

555555
self.dim_latents = cast_tuple(dim_latent)
556556

557+
# number of modalities
558+
559+
self.num_modalities = len(self.dim_latents)
560+
557561
# modality transforms
558562

559-
modality_token_transform = cast_tuple(modality_token_transform)
563+
modality_token_transform = cast_tuple(modality_token_transform, self.num_modalities)
560564
modality_token_transform = [default(transform, identity) for transform in modality_token_transform]
561565
self.modality_token_transform = [Rearrange(maybe_einops_eq) if isinstance(maybe_einops_eq, str) else maybe_einops_eq for maybe_einops_eq in modality_token_transform]
562566

563-
# number of modalities
564-
565-
self.num_modalities = len(self.dim_latents)
567+
assert len(self.modality_token_transform) == self.num_modalities
566568

567569
self.latent_to_model_projs = ModuleList([Linear(dim_latent, dim) if dim_latent != dim else nn.Identity() for dim_latent in self.dim_latents])
568570

0 commit comments

Comments
 (0)