@@ -1050,6 +1050,7 @@ def __init__(
         add_pos_emb: bool | tuple[bool, ...] = False,
         modality_encoder: Module | tuple[Module, ...] | None = None,
         modality_decoder: Module | tuple[Module, ...] | None = None,
+        pre_post_transformer_enc_dec: tuple[Module, Module] | tuple[tuple[Module, Module], ...] | None = None,
         modality_token_transform: tuple[ModalityTokenTransform, ...] | ModalityTokenTransform | None = None,
         modality_default_shape: tuple[int, ...] | tuple[tuple[int, ...], ...] | None = None,
         fallback_to_default_shape_if_invalid = False,
@@ -1188,14 +1189,32 @@ def __init__(

         assert len(self.modality_token_transform) == self.num_modalities

+        # prepare pre-post transformer encoder / decoder, for the learnable unets as in paper
+
+        if is_bearable(pre_post_transformer_enc_dec, tuple[Module, Module]):
+            pre_post_transformer_enc_dec = (pre_post_transformer_enc_dec,)
+
+        pre_post_transformer_enc_dec = cast_tuple(pre_post_transformer_enc_dec, self.num_modalities)
+        assert len(pre_post_transformer_enc_dec) == self.num_modalities
+
         # latent to model and back
         # by default will be Linear, with or without rearranges depending on channel_first_latent setting
-        # can also be overridden for the unet down/up as in the paper
+        # can also be overridden for the unet down/up as in the paper with `pre_post_transformer_enc_dec: tuple[Module, Module]`

         latent_to_model_projs = []
         model_to_latent_projs = []

-        for dim_latent, one_channel_first_latent in zip(self.dim_latents, self.channel_first_latent):
+        for (
+            dim_latent,
+            one_channel_first_latent,
+            enc_dec,
+        ) in zip(
+            self.dim_latents,
+            self.channel_first_latent,
+            pre_post_transformer_enc_dec
+        ):
+
+            pre_attend_enc, post_attend_dec = default(enc_dec, (None, None))

             latent_to_model_proj = Linear(dim_latent, dim) if dim_latent != dim else nn.Identity()
             model_to_latent_proj = Linear(dim, dim_latent, bias = False)
@@ -1204,8 +1223,8 @@ def __init__(
                 latent_to_model_proj = nn.Sequential(Rearrange('b d ... -> b ... d'), latent_to_model_proj, Rearrange('b ... d -> b d ...'))
                 model_to_latent_proj = nn.Sequential(Rearrange('b d ... -> b ... d'), model_to_latent_proj, Rearrange('b ... d -> b d ...'))

-            latent_to_model_projs.append(latent_to_model_proj)
-            model_to_latent_projs.append(model_to_latent_proj)
+            latent_to_model_projs.append(default(pre_attend_enc, latent_to_model_proj))
+            model_to_latent_projs.append(default(post_attend_dec, model_to_latent_proj))

         self.latent_to_model_projs = ModuleList(latent_to_model_projs)
         self.model_to_latent_projs = ModuleList(model_to_latent_projs)
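
For context, a minimal usage sketch of the new `pre_post_transformer_enc_dec` kwarg (the modules and dims below are illustrative stand-ins, not the learnable unet down / up from the paper): the first module of a pair replaces the default latent-to-model projection and the second replaces the model-to-latent projection for a modality; a single pair is broadcast to every modality by `cast_tuple`, while a tuple of pairs sets one per modality, and either form is passed straight into the `__init__` above.

from torch import nn
from einops.layers.torch import Rearrange

dim, dim_latent = 512, 4  # assumed model and latent dims

# enc: latent dim -> model dim, dec: model dim -> latent dim
# the rearranges mimic the channel-first default and could be dropped for channel-last latents

enc = nn.Sequential(Rearrange('b d ... -> b ... d'), nn.Linear(dim_latent, dim), Rearrange('b ... d -> b d ...'))
dec = nn.Sequential(Rearrange('b d ... -> b ... d'), nn.Linear(dim, dim_latent), Rearrange('b ... d -> b d ...'))

pre_post_transformer_enc_dec = (enc, dec)        # one pair, broadcast to all modalities
# pre_post_transformer_enc_dec = ((enc, dec),)   # or one (enc, dec) pair per modality
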
@@ -1706,18 +1725,6 @@ def forward_modality(
                 mod.encoder.eval()
                 modalities = self.maybe_add_temp_batch_dim(mod.encoder)(modalities).detach()

-        # axial positions
-
-        if mod.add_pos_emb:
-            assert exists(mod.num_dim), f'modality_num_dim must be set for modality {modality_type} if further injecting axial positional embedding'
-
-            if mod.channel_first_latent:
-                _, _, *axial_dims = modalities.shape
-            else:
-                _, *axial_dims, _ = modalities.shape
-
-            assert len(axial_dims) == mod.num_dim, f'received modalities of ndim {len(axial_dims)} but expected {modality_num_dim}'
-
         # shapes and device

         tokens = modalities
@@ -1753,6 +1760,15 @@ def forward_modality(
         if mod.channel_first_latent:
             noised_tokens = rearrange(noised_tokens, 'b d ... -> b ... d')

+        # axial positions
+
+        if mod.add_pos_emb:
+            assert exists(mod.num_dim), f'modality_num_dim must be set for modality {modality_type} if further injecting axial positional embedding'
+
+            _, *axial_dims, _ = noised_tokens.shape
+
+            assert len(axial_dims) == mod.num_dim, f'received modalities of ndim {len(axial_dims)} but expected {modality_num_dim}'
+
         # maybe transform

         noised_tokens = mod.token_transform(noised_tokens)
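
To make the relocated axial-position check concrete, a small sketch (shapes assumed): once the tokens have been rearranged to channel-last, the axial dims always sit between the batch and feature dims, so a single star-unpack recovers them regardless of how the latents were stored.

import torch

noised_tokens = torch.randn(2, 8, 8, 512)  # (batch, *axial_dims, dim) for a hypothetical 2-dim image modality

_, *axial_dims, _ = noised_tokens.shape
assert axial_dims == [8, 8]                # len(axial_dims) == mod.num_dim == 2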