@@ -585,17 +585,54 @@ def __init__(
585585 self .ignore_index = ignore_index
586586 self .diffusion_loss_weight = diffusion_loss_weight
587587
588+ @property
589+ def device (self ):
590+ return next (self .parameters ()).device
591+
588592 def forward (
589593 self ,
590- text : Int ['b n' ],
591- modality_tokens : list [list [Float ['_ _' ]]] | list [Float ['b n _' ]] | Float ['b n _' ],
592- modality_positions : RawModalityPositions | Int ['b m 2' ] | Int ['b m 3' ],
594+ modalities : list [list [Int ['_' ] | Float ['_ _' ]]],
593595 times : Float ['b m' ] | None = None ,
594- return_loss = True
596+ return_loss = True ,
597+ return_breakdown = False
595598 ) -> (
596599 Float ['b n l' ] |
600+ Float ['' ] |
597601 tuple [Float ['' ], LossBreakdown ]
598602 ):
603+ device = self .device
604+
605+ # process list of text and modalities interspersed with one another
606+
607+ modality_positions = []
608+ modality_tokens = []
609+ text = []
610+
611+ for batch_modalities in modalities :
612+ batch_modality_positions = []
613+ batch_modality_tokens = []
614+ batch_text = []
615+
616+ for modality in batch_modalities :
617+ is_text = modality .dtype in (torch .long , torch .int )
618+
619+ length = modality .shape [0 ]
620+ offset = 0
621+
622+ if is_text :
623+ batch_text .append (modality )
624+ else :
625+ batch_text .append (torch .full ((length ,), - 1 , device = device ))
626+ batch_modality_tokens .append (modality )
627+ batch_modality_positions .append ((offset , length ))
628+
629+ offset += length
630+
631+ text .append (torch .cat (batch_text ))
632+ modality_tokens .append (batch_modality_tokens )
633+ modality_positions .append (batch_modality_positions )
634+
635+ text = pad_sequence (text , padding_value = - 1 )
599636
600637 # if returning loss, split text for next token prediction
601638
@@ -653,6 +690,8 @@ def forward(
653690
654691 # embed text
655692
693+ text = text .masked_fill (text == - 1 , 0 )
694+
656695 text_tokens = self .text_embed (text )
657696
658697 # noise the modality tokens
@@ -766,4 +805,7 @@ def forward(
766805 (torch .stack (diffusion_losses ) * torch .stack (modality_loss_weights )).sum () * self .diffusion_loss_weight
767806 )
768807
808+ if not return_breakdown :
809+ return total_loss
810+
769811 return total_loss , LossBreakdown (total_loss , text_loss , diffusion_losses )
0 commit comments