Commit 6834341

first handle axial positional embedding for single modality training

1 parent: c1a1930

File tree

3 files changed: +34 −5 lines

pyproject.toml
Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "transfusion-pytorch"
-version = "0.1.12"
+version = "0.1.14"
 description = "Transfusion in Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

tests/test_transfusion.py
Lines changed: 1 addition & 1 deletion

@@ -215,7 +215,7 @@ def test_velocity_consistency():
         dim_latent = 384,
         channel_first_latent = True,
         modality_default_shape = ((4, 4)),
-        modality_validate_num_dim = 2,
+        modality_num_dim = 2,
         modality_encoder = mock_encoder,
         modality_decoder = mock_decoder,
         transformer = dict(
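The rename only touches the keyword. After this commit, a single-modality setup would look roughly like the sketch below; the kwargs shown in the test above are kept as-is, while num_text_tokens and the transformer config are illustrative assumptions (they are not part of this diff), and mock_encoder / mock_decoder stand in for whatever modules the test defines:

from transfusion_pytorch import Transfusion

model = Transfusion(
    num_text_tokens = 256,              # assumed value, not shown in this diff
    dim_latent = 384,
    channel_first_latent = True,
    modality_default_shape = (4, 4),
    modality_num_dim = 2,               # renamed from modality_validate_num_dim
    modality_encoder = mock_encoder,    # as defined in the test
    modality_decoder = mock_decoder,
    transformer = dict(
        dim = 512,                      # assumed transformer config
        depth = 2
    )
)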

transfusion_pytorch/transfusion.py
Lines changed: 32 additions & 3 deletions

@@ -1424,6 +1424,8 @@ def forward_modality(
         return_loss = True
     ) -> Float['']:

+        shape = modalities.shape
+
         if self.num_modalities > 1:
             assert exists(modality_type), '`modality_type` must be explicitly passed in on forward when training on greater than 1 modality'

@@ -1433,12 +1435,32 @@ def forward_modality(
         latent_to_model_fn = self.latent_to_model_projs[modality_type]
         model_to_flow_pred_fn = self.model_to_latent_preds[modality_type]

+        # grab the shape of the modality, for maybe axial pos emb
+
+        add_pos_emb = self.add_pos_emb[modality_type]
+        maybe_pos_emb_mlp = self.pos_emb_mlp[modality_type]
+        modality_num_dim = self.modality_num_dim[modality_type]
+
+        if add_pos_emb:
+            assert exists(modality_num_dim), f'modality_num_dim must be set for modality {modality_type} if further injecting axial positional embedding'
+
+            if self.channel_first_latent:
+                _, _, *axial_dims = shape
+            else:
+                _, *axial_dims, _ = shape
+
+            assert len(axial_dims) == modality_num_dim, f'received modalities of ndim {len(axial_dims)} but expected {modality_num_dim}'
+
+        # maybe transform
+
         tokens = transform(modalities)

         # maybe channel first

         if self.channel_first_latent:
-            tokens = rearrange(tokens, 'b d ... -> b (...) d')
+            tokens = rearrange(tokens, 'b d ... -> b ... d')
+
+        tokens = rearrange(tokens, 'b ... d -> b (...) d')

         # rotary

@@ -1459,10 +1481,17 @@ def forward_modality(

         flow = tokens - noise

-        # attention
-
         noised_tokens = latent_to_model_fn(noised_tokens)

+        # maybe add axial pos emb
+
+        if add_pos_emb:
+            axial_pos_emb = maybe_pos_emb_mlp(tensor(axial_dims))
+
+            noised_tokens = noised_tokens + rearrange(axial_pos_emb, '... d -> (...) d')
+
+        # attention
+
         embed = self.transformer(
             noised_tokens,
             times = times,