Skip to content

Commit 6dc31cb

Browse files
committed
when sampling, past decoded modalities receive a time conditioning of 1.
1 parent ed73212 commit 6dc31cb

File tree

2 files changed

+12
-3
lines changed

2 files changed

+12
-3
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "transfusion-pytorch"
3-
version = "0.0.26"
3+
version = "0.0.27"
44
description = "Transfusion in Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

transfusion_pytorch/transfusion.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -664,6 +664,7 @@ def sample(
664664

665665
curr_length = 0
666666
curr_modality_id = None
667+
num_past_modalities = 0 # starts off with no modalities in output
667668
is_decoding_text = True # starts off with text decoding, and alternates with modalities depending on [som] tokens detected
668669

669670
with tqdm(total = max_length) as pbar:
@@ -702,8 +703,15 @@ def sample(
702703
latent_dim = self.dim_latents[curr_modality_id]
703704
noise = torch.randn((modality_length, latent_dim), device = device)
704705

705-
def ode_step_fn(t, denoised):
706-
embeds = self.forward([[*modality_sample, denoised]], return_loss = False, return_embed = True)
706+
def ode_step_fn(step_times, denoised):
707+
step_times = rearrange(step_times, ' -> 1 1') # batch size of 1
708+
step_times = F.pad(step_times, (num_past_modalities, 0), value = 1.) # past decoded modalities receive a time conditioning of 1.
709+
710+
embeds = self.forward(
711+
[[*modality_sample, denoised]],
712+
times = step_times,
713+
return_embed = True,
714+
)
707715

708716
to_flow_pred = self.model_to_latent_preds[curr_modality_id]
709717
flow = to_flow_pred(embeds)
@@ -729,6 +737,7 @@ def ode_step_fn(t, denoised):
729737
pbar.update(modality_length)
730738
curr_length += modality_length
731739

740+
num_past_modalities += 1
732741
curr_modality_id = None
733742
is_decoding_text = True
734743

0 commit comments

Comments
 (0)