File tree Expand file tree Collapse file tree 2 files changed +7
-5
lines changed
Expand file tree Collapse file tree 2 files changed +7
-5
lines changed Original file line number Diff line number Diff line change 11[project ]
22name = " transfusion-pytorch"
3- version = " 0.0.12 "
3+ version = " 0.0.14 "
44description = " Transfusion in Pytorch"
55authors = [
66 { name = " Phil Wang" , email = " lucidrains@gmail.com" }
Original file line number Diff line number Diff line change @@ -554,15 +554,17 @@ def __init__(
554554
555555 self .dim_latents = cast_tuple (dim_latent )
556556
557+ # number of modalities
558+
559+ self .num_modalities = len (self .dim_latents )
560+
557561 # modality transforms
558562
559- modality_token_transform = cast_tuple (modality_token_transform )
563+ modality_token_transform = cast_tuple (modality_token_transform , self . num_modalities )
560564 modality_token_transform = [default (transform , identity ) for transform in modality_token_transform ]
561565 self .modality_token_transform = [Rearrange (maybe_einops_eq ) if isinstance (maybe_einops_eq , str ) else maybe_einops_eq for maybe_einops_eq in modality_token_transform ]
562566
563- # number of modalities
564-
565- self .num_modalities = len (self .dim_latents )
567+ assert len (self .modality_token_transform ) == self .num_modalities
566568
567569 self .latent_to_model_projs = ModuleList ([Linear (dim_latent , dim ) if dim_latent != dim else nn .Identity () for dim_latent in self .dim_latents ])
568570
You can’t perform that action at this time.
0 commit comments