@@ -624,6 +624,30 @@ def apply_fn_modality_type(
624624
625625 return tree_unflatten (out , tree_spec )
626626
627+ # decorator for model output to flow
628+
def get_model_output_to_flow_fn(
    noised: Tensor,
    times: Tensor,
    eps = 5e-2,
    return_decorator = False
):
    """Build a converter from a clean-data prediction to a flow target.

    Given the noised input and its flow times, returns ``to_flow(out)``
    which maps a model's predicted clean sample ``out`` to the implied
    flow ``(out - noised) / (1 - t)``, with the denominator clamped to
    ``eps`` to avoid blow-up as ``t`` approaches 1.

    If ``return_decorator`` is True, instead returns a decorator that
    wraps a ``fn(embed) -> out`` so its output is converted to a flow.
    """

    def to_flow(out):
        # right-pad times with singleton dims so it broadcasts over out
        t = append_dims(times, out.ndim - 1)
        denom = (1. - t).clamp_min(eps)
        return (out - noised) / denom

    if return_decorator:
        def decorator(fn):
            def wrapped(embed):
                return to_flow(fn(embed))
            return wrapped
        return decorator

    return to_flow
650+
627651# sampling related functions
628652
629653# min_p for text
@@ -1230,6 +1254,7 @@ def __init__(
12301254 * ,
12311255 num_text_tokens ,
12321256 transformer : dict | Transformer ,
1257+ pred_clean = False ,
12331258 dim_latent : int | tuple [int , ...] | None = None ,
12341259 channel_first_latent : bool | tuple [bool , ...] = False ,
12351260 add_pos_emb : bool | tuple [bool , ...] = False ,
@@ -1251,6 +1276,7 @@ def __init__(
12511276 rtol = 1e-5 ,
12521277 method = 'midpoint'
12531278 ),
1279+ eps = 5e-2
12541280 ):
12551281 super ().__init__ ()
12561282
@@ -1453,6 +1479,11 @@ def __init__(
14531479 self .has_recon_loss = reconstruction_loss_weight > 0.
14541480 self .reconstruction_loss_weight = reconstruction_loss_weight
14551481
1482+ # whether model is predicting clean
1483+
1484+ self .pred_clean = pred_clean
1485+ self .eps = eps
1486+
14561487 # flow sampling related
14571488
14581489 self .odeint_fn = partial (odeint , ** odeint_kwargs )
@@ -2001,6 +2032,13 @@ def forward_modality(
20012032 else :
20022033 noised_tokens = tokens
20032034
2035+ # save the noised and times
2036+
2037+ model_output_to_flow = identity
2038+
2039+ if self .pred_clean :
2040+ model_output_to_flow = get_model_output_to_flow_fn (noised_tokens , times )
2041+
20042042 # from latent to model tokens
20052043
20062044 noised_tokens = mod .latent_to_model (noised_tokens )
@@ -2034,7 +2072,9 @@ def forward_modality(
20342072
20352073 embed = inverse_pack_axial_dims (embed )
20362074
2037- pred_flow = mod .model_to_latent (embed )
2075+ model_output = mod .model_to_latent (embed )
2076+
2077+ pred_flow = model_output_to_flow (model_output )
20382078
20392079 if not return_loss :
20402080 return pred_flow
@@ -2475,6 +2515,14 @@ def inner(pred_flow):
24752515
24762516 inverse_fn = model_to_pred_flow (batch_index , offset + precede_modality_tokens , modality_length , unpack_modality_shape )
24772517
2518+ # maybe decorate the function if model output is predicting clean
2519+
2520+ if self .pred_clean :
2521+ decorator = get_model_output_to_flow_fn (modality_tensor , modality_time , self .eps , return_decorator = True )
2522+ inverse_fn = decorator (inverse_fn )
2523+
2524+ # store function for extracting flow later
2525+
24782526 get_pred_flows [modality_type ].append (inverse_fn )
24792527
24802528 # increment offset
0 commit comments