@@ -1424,6 +1424,8 @@ def forward_modality(
         return_loss = True
     ) -> Float['']:
 
+        shape = modalities.shape
+
         if self.num_modalities > 1:
             assert exists(modality_type), '`modality_type` must be explicitly passed in on forward when training on greater than 1 modality'
@@ -1433,12 +1435,32 @@ def forward_modality(
         latent_to_model_fn = self.latent_to_model_projs[modality_type]
         model_to_flow_pred_fn = self.model_to_latent_preds[modality_type]
 
+        # grab the shape of the modality, for maybe axial pos emb
+
+        add_pos_emb = self.add_pos_emb[modality_type]
+        maybe_pos_emb_mlp = self.pos_emb_mlp[modality_type]
+        modality_num_dim = self.modality_num_dim[modality_type]
+
+        if add_pos_emb:
+            assert exists(modality_num_dim), f'modality_num_dim must be set for modality {modality_type} if further injecting axial positional embedding'
+
+            if self.channel_first_latent:
+                _, _, *axial_dims = shape
+            else:
+                _, *axial_dims, _ = shape
+
+            assert len(axial_dims) == modality_num_dim, f'received modalities of ndim {len(axial_dims)} but expected {modality_num_dim}'
+
+        # maybe transform
+
         tokens = transform(modalities)
 
         # maybe channel first
 
         if self.channel_first_latent:
-            tokens = rearrange(tokens, 'b d ... -> b (...) d')
+            tokens = rearrange(tokens, 'b d ... -> b ... d')
+
+        tokens = rearrange(tokens, 'b ... d -> b (...) d')
 
         # rotary
 
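As a quick illustration of what the changed rearranges above do (a standalone sketch, not code from the repository, with made-up tensor sizes): a channel-first latent is first moved to channel-last, its spatial dims are recorded for the axial positional embedding, and only then is it flattened into a token sequence.

```python
import torch
from einops import rearrange

tokens = torch.randn(2, 512, 8, 8)                  # (b, d, h, w) - channel-first latent

tokens = rearrange(tokens, 'b d ... -> b ... d')    # -> (2, 8, 8, 512), channel-last
axial_dims = tokens.shape[1:-1]                      # (8, 8) - what the axial pos emb needs

tokens = rearrange(tokens, 'b ... d -> b (...) d')   # -> (2, 64, 512), flattened to a sequence

assert tokens.shape == (2, 64, 512)
assert tuple(axial_dims) == (8, 8)
```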
@@ -1459,10 +1481,17 @@ def forward_modality(
 
         flow = tokens - noise
 
-        # attention
-
         noised_tokens = latent_to_model_fn(noised_tokens)
 
+        # maybe add axial pos emb
+
+        if add_pos_emb:
+            axial_pos_emb = maybe_pos_emb_mlp(tensor(axial_dims))
+
+            noised_tokens = noised_tokens + rearrange(axial_pos_emb, '... d -> (...) d')
+
+        # attention
+
         embed = self.transformer(
             noised_tokens,
             times = times,
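For context on the shape contract of the new axial positional embedding path: `maybe_pos_emb_mlp` is expected to map the axial dims to an embedding grid of shape `(*axial_dims, dim)`, which is then flattened and broadcast-added onto the `(batch, seq, dim)` noised tokens. The module below is a hypothetical stand-in, not the repository's implementation, written only to make those shapes concrete.

```python
import torch
from torch import nn, tensor
from einops import rearrange

class AxialPositionMLP(nn.Module):
    # hypothetical: maps normalized grid coordinates to per-position embeddings
    def __init__(self, dim, num_axial_dims = 2):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(num_axial_dims, dim),
            nn.SiLU(),
            nn.Linear(dim, dim)
        )

    def forward(self, axial_dims):
        # build a grid of normalized coordinates, one coordinate vector per position
        coords = torch.stack(torch.meshgrid(
            *[torch.linspace(0., 1., int(d)) for d in axial_dims],
            indexing = 'ij'
        ), dim = -1)                                  # (h, w, num_axial_dims)

        return self.mlp(coords)                       # (h, w, dim)

pos_emb_mlp = AxialPositionMLP(dim = 512)

noised_tokens = torch.randn(2, 64, 512)               # (b, h * w, d)
axial_pos_emb = pos_emb_mlp(tensor((8, 8)))            # (8, 8, 512)

noised_tokens = noised_tokens + rearrange(axial_pos_emb, '... d -> (...) d')  # (2, 64, 512)
```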