@@ -664,6 +664,7 @@ def sample(
664664
665665 curr_length = 0
666666 curr_modality_id = None
667+ num_past_modalities = 0 # starts off with no modalities in output
667668 is_decoding_text = True # starts off with text decoding, and alternates with modalities depending on [som] tokens detected
668669
669670 with tqdm (total = max_length ) as pbar :
@@ -702,8 +703,15 @@ def sample(
702703 latent_dim = self .dim_latents [curr_modality_id ]
703704 noise = torch .randn ((modality_length , latent_dim ), device = device )
704705
705- def ode_step_fn (t , denoised ):
706- embeds = self .forward ([[* modality_sample , denoised ]], return_loss = False , return_embed = True )
706+ def ode_step_fn (step_times , denoised ):
707+ step_times = rearrange (step_times , ' -> 1 1' ) # batch size of 1
708+ step_times = F .pad (step_times , (num_past_modalities , 0 ), value = 1. ) # past decoded modalities receive a time conditioning of 1.
709+
710+ embeds = self .forward (
711+ [[* modality_sample , denoised ]],
712+ times = step_times ,
713+ return_embed = True ,
714+ )
707715
708716 to_flow_pred = self .model_to_latent_preds [curr_modality_id ]
709717 flow = to_flow_pred (embeds )
@@ -729,6 +737,7 @@ def ode_step_fn(t, denoised):
729737 pbar .update (modality_length )
730738 curr_length += modality_length
731739
740+ num_past_modalities += 1
732741 curr_modality_id = None
733742 is_decoding_text = True
734743
0 commit comments