Skip to content

Commit 506134b

Browse files
committed
one more fix
1 parent 477ddbe commit 506134b

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "transfusion-pytorch"
3-
version = "0.4.11"
3+
version = "0.4.12"
44
description = "Transfusion in Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

transfusion_pytorch/transfusion.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2370,7 +2370,7 @@ def forward(
23702370

23712371
velocity_match_losses = []
23722372

2373-
for ema_pred_flow, pred_flow, is_one_modality in zip(ema_pred_flows, pred_flows, is_modalities.unbind(dim = 1)):
2373+
for mod, ema_pred_flow, pred_flow, is_one_modality in zip(self.get_all_modality_info(), ema_pred_flows, pred_flows, is_modalities.unbind(dim = 1)):
23742374

23752375
velocity_match_loss = F.mse_loss(
23762376
pred_flow,
@@ -2380,6 +2380,9 @@ def forward(
23802380

23812381
is_one_modality = reduce(is_one_modality, 'b m n -> b n', 'any')
23822382

2383+
if mod.channel_first_latent:
2384+
velocity_match_loss = rearrange(velocity_match_loss, 'b d ... -> b ... d')
2385+
23832386
velocity_match_loss = velocity_match_loss[is_one_modality].mean()
23842387

23852388
velocity_match_losses.append(velocity_match_loss)

0 commit comments

Comments
 (0)