@@ -90,6 +90,7 @@ class LossBreakdown(NamedTuple):
     text: Scalar
     flow: list[Scalar]
     velocity: list[Scalar] | None
+    recon: list[Scalar] | None
 
 class ModalityInfo(NamedTuple):
     encoder: Module | None
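
The new `recon` field mirrors `flow` and `velocity`: a list of scalar losses (presumably one per modality), or `None` when the reconstruction loss is disabled. A minimal sketch of reading it back; how a `LossBreakdown` instance is obtained from the model is an assumption here, not part of this diff:

# hypothetical usage -- assumes the model can return a LossBreakdown when asked
loss, breakdown = model(batch, return_loss_breakdown = True)

if breakdown.recon is not None:   # None when reconstruction_loss_weight == 0.
    print(breakdown.recon)        # list of scalar reconstruction losses
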
@@ -1085,6 +1086,7 @@ def __init__(
         flow_loss_weight = 1.,
         text_loss_weight = 1.,
         velocity_consistency_loss_weight = 0.1,
+        reconstruction_loss_weight = 0.,
         modality_encoder_decoder_requires_batch_dim = True, # whether the modality encoder / decoder requires batch dimension, will auto assume it is needed
         odeint_kwargs: dict = dict(
             atol = 1e-5,
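
Since `reconstruction_loss_weight` defaults to `0.`, existing configurations are unaffected; it must be set above zero to activate the loss. A construction sketch under that assumption, where `Model` stands in for the class owning this `__init__` and the encoder / decoder keyword names are illustrative rather than taken from this diff:

from torch import nn

# hypothetical construction -- keyword names other than the two visible in this
# hunk are assumptions for illustration
model = Model(
    modality_encoder = nn.Identity(),   # the recon loss asserts the modality was encoded
    modality_decoder = nn.Identity(),   # and decodes the prediction back when a decoder exists
    reconstruction_loss_weight = 0.1,   # > 0. turns the auxiliary loss on (default stays 0.)
    modality_encoder_decoder_requires_batch_dim = True,
)
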
@@ -1277,6 +1279,11 @@ def __init__(
 
         self.velocity_consistency_loss_weight = velocity_consistency_loss_weight
 
+        # additional reconstruction loss, through the decoder
+
+        self.has_recon_loss = reconstruction_loss_weight > 0.
+        self.reconstruction_loss_weight = reconstruction_loss_weight
+
         # flow sampling related
 
         self.odeint_fn = partial(odeint, **odeint_kwargs)
@@ -1711,10 +1718,10 @@ def forward_modality(
         return_loss = True,
         return_loss_breakdown = False
     ) -> Scalar | Float['b ...']:
-
         requires_velocity_consistency = exists(velocity_consistency_ema_model)
 
         modalities = modalities.to(self.device)
+        orig_modalities = modalities
 
         if self.num_modalities > 1:
             assert exists(modality_type), '`modality_type` must be explicitly passed in on forward when training on greater than 1 modality'
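
Stashing `orig_modalities` before the modality encoder runs matters for the new loss: the decoded reconstruction at the end of `forward_modality` is compared against this raw, pre-encoding input rather than against the encoded latents.
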
@@ -1754,6 +1761,7 @@ def forward_modality(
             noised_tokens = padded_times * tokens + (1. - padded_times) * noise
 
             flow = tokens - noise
+
         else:
             noised_tokens = tokens
 
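
The lines above are the linear interpolation used for flow matching; a quick numeric check of the identities the reconstruction step below leans on (illustrative only, independent of the model):

import torch

# noised = t * x1 + (1 - t) * x0  with  flow = x1 - x0,
# so the clean sample is recoverable as  x0 + flow  or as  noised + (1 - t) * flow
x1, x0, t = torch.randn(8), torch.randn(8), torch.rand(())

noised = t * x1 + (1 - t) * x0
flow = x1 - x0

assert torch.allclose(x0 + flow, x1, atol = 1e-6)
assert torch.allclose(noised + (1 - t) * flow, x1, atol = 1e-6)
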
@@ -1816,17 +1824,37 @@ def forward_modality(
 
             velocity_loss = F.mse_loss(flow, flow_with_delta_time)
 
+        # maybe recon loss
+
+        recon_loss = self.zero
+
+        if self.has_recon_loss:
+            assert encode_modality
+
+            recon = noise + pred_flow * (1. - padded_times)
+
+            if exists(mod.decoder):
+                with torch.no_grad():
+                    mod.decoder.eval()
+                    recon = self.maybe_add_temp_batch_dim(mod.decoder)(recon)
+
+            recon_loss = F.mse_loss(
+                recon,
+                orig_modalities
+            )
+
         # total loss
 
         total_loss = (
             flow_loss +
-            velocity_loss * self.velocity_consistency_loss_weight
+            velocity_loss * self.velocity_consistency_loss_weight +
+            recon_loss * self.reconstruction_loss_weight
         )
 
         if not return_loss_breakdown:
             return total_loss
 
-        return total_loss, (flow_loss, velocity_loss)
+        return total_loss, (flow_loss, velocity_loss, recon_loss)
 
     @torch.no_grad()
     @eval_decorator
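
Finally, a sketch of consuming the widened breakdown from `forward_modality`; calling it directly, the tensor shape, and the modality type are placeholders, not prescribed by this diff:

import torch

# hypothetical call -- only return_loss_breakdown and the three-part breakdown
# are grounded in the diff above
total_loss, (flow_loss, velocity_loss, recon_loss) = model.forward_modality(
    torch.randn(1, 4, 384),
    modality_type = 0,
    return_loss_breakdown = True
)

total_loss.backward()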