@@ -1546,7 +1546,9 @@ def ode_step_fn(step_times, denoised):
15461546
15471547 parse_embed = get_pred_flows [curr_modality_id ][- 1 ]
15481548
1549- flow = add_temp_batch_dim (mod .model_to_latent )(parse_embed (embeds ))
1549+ parsed_embed = parse_embed (embeds , need_splice = not exists (cache ))
1550+
1551+ flow = add_temp_batch_dim (mod .model_to_latent )(parsed_embed )
15501552
15511553 return flow
15521554
@@ -2161,10 +2163,14 @@ def forward(
21612163
21622164 def model_to_pred_flow (batch_index , start_index , modality_length , unpack_fn ):
21632165
2164- def inner (embed : Float ['b n d' ]) -> Float ['...' ]:
2165- modality_embed = embed [batch_index , start_index :(start_index + modality_length )]
2166- modality_embed = unpack_fn (modality_embed )
2167- return modality_embed
2166+ def inner (embed : Float ['b n d' ], need_splice = True ) -> Float ['...' ]:
2167+ embed = embed [batch_index ]
2168+
2169+ if need_splice :
2170+ embed = embed [start_index :(start_index + modality_length )]
2171+
2172+ embed = unpack_fn (embed )
2173+ return embed
21682174
21692175 return inner
21702176
@@ -2334,9 +2340,9 @@ def inner(embed: Float['b n d']) -> Float['...']:
23342340 modality_get_pred_flows = get_pred_flows [modality_id ]
23352341
23362342 modality_pred_flows = []
2343+
23372344 for get_pred_flow in modality_get_pred_flows :
23382345 pred_flow = get_pred_flow (embed )
2339-
23402346 pred_flow = add_temp_batch_dim (mod .model_to_latent )(pred_flow )
23412347 modality_pred_flows .append (pred_flow )
23422348
0 commit comments