@@ -71,6 +71,15 @@ def divisible_by(num, den):
def cast_tuple(t, length = 1):
    # pass tuples through untouched; otherwise repeat the value `length` times
    if isinstance(t, tuple):
        return t
    return (t,) * length
7373
def eval_decorator(fn):
    """Run `fn` with the module in eval mode, restoring the previous mode after.

    Intended for methods on objects exposing the `torch.nn.Module` interface
    (`.training` attribute, `.eval()`, `.train(mode)`), e.g. `sample` below.

    Fixes over the naive version:
    - `functools.wraps` preserves the wrapped method's name / docstring.
    - `try/finally` guarantees the training mode is restored even if `fn`
      raises, instead of leaving the module stuck in eval mode.
    """
    from functools import wraps

    @wraps(fn)
    def inner(self, *args, **kwargs):
        was_training = self.training
        self.eval()
        try:
            out = fn(self, *args, **kwargs)
        finally:
            self.train(was_training)
        return out

    return inner
82+
7483# tensor helpers
7584
7685def l2norm (t ):
@@ -593,7 +602,7 @@ def __init__(
593602
594603 # modality start and end ids
595604
596- self .som_ids , self .eom_ids = som_eom_tensor .tolist ()
605+ self .som_ids , self .eom_ids = modality_start_end_tensor .tolist ()
597606
598607 # entire "sentence" start and end id
599608
@@ -637,34 +646,93 @@ def device(self):
637646 return next (self .parameters ()).device
638647
    @torch.no_grad()
    @eval_decorator
    def sample(
        self,
        prompt: ModalitySample | None = None,
        max_length = 8192,
        text_temperature = 1.5,
        text_min_p = 0.1,
        modality_length = 32, # fix the modality token length for now, but this will be determined by the language model in a metadata tag
        modality_steps = 16
    ) -> ModalitySample:
        """Autoregressively sample one multimodal sequence.

        Alternates between two decoding modes:

        * text decoding - one token at a time with min-p filtering and
          temperature-scaled multinomial sampling, until [eos] ends the
          sequence or a modality-start ([som]) token switches modes;
        * modality decoding - after a [som] token for modality `i`, a
          `modality_length` block of latents is produced by integrating a
          flow ODE (`self.odeint_fn`) from Gaussian noise over
          `modality_steps` time points, then the matching [eom] token is
          appended and text decoding resumes.

        Returns the accumulated `modality_sample` list: text token-id
        tensors interleaved with `(modality_id, latent_tensor)` tuples.

        NOTE(review): `prompt` is not referenced in this body - sampling
        always starts from a bare [sos] token. Confirm whether prompt
        conditioning is intended here.
        """

        device = self.device

        # the running sample: starts as a single text segment holding [sos]
        init_text_seq = tensor([self.sos_id], device = self.device)
        modality_sample = [init_text_seq]

        curr_length = 0
        curr_modality_id = None
        is_decoding_text = True # starts off with text decoding, and alternates with modalities depending on [som] tokens detected

        with tqdm(total = max_length) as pbar:

            while curr_length <= max_length:

                if is_decoding_text:
                    # the trailing element is always the text segment being extended
                    *_, seq = modality_sample

                    logits = self.forward([modality_sample], return_loss = False)
                    logits = logits[0][-1]

                    logits = min_p_filter(logits, min_p = text_min_p)
                    probs = (logits / text_temperature).softmax(dim = -1)

                    sampled = torch.multinomial(probs, 1)

                    # extend the trailing text segment in place
                    seq = torch.cat((seq, sampled), dim = -1)
                    modality_sample[-1] = seq

                    pbar.update(1)
                    curr_length += 1

                    sampled_token_id = sampled.item()

                    # [eos] terminates the whole sample
                    if sampled_token_id == self.eos_id:
                        break

                    # a [som] token switches to decoding that modality next iteration
                    if sampled_token_id in self.som_ids:
                        curr_modality_id = self.som_ids.index(sampled_token_id)
                        is_decoding_text = False

                else:
                    # modality decoding requires a preceding [som] to have set the id
                    assert exists(curr_modality_id)

                    latent_dim = self.dim_latents[curr_modality_id]
                    # starting point for flow integration: pure noise latents
                    noise = torch.randn((modality_length, latent_dim), device = device)

                    def ode_step_fn(t, denoised):
                        # drift fn for the ODE solver: run the model over the sample
                        # plus the current latents, project embeddings to flow,
                        # and keep only the trailing modality positions
                        embeds = self.forward([[*modality_sample, denoised]], return_loss = False, return_embed = True)

                        to_flow_pred = self.model_to_latent_preds[curr_modality_id]
                        flow = to_flow_pred(embeds)

                        return flow[0, -modality_length:]

                    times = torch.linspace(0, 1, modality_steps, device = device)

                    trajectory = self.odeint_fn(ode_step_fn, noise, times)

                    # add the sampled modality tokens

                    # final state of the trajectory is the denoised modality latents
                    sampled_modality = trajectory[-1]
                    modality_sample.append((curr_modality_id, sampled_modality))

                    # add the appropriate [eom]

                    eom_id = self.eom_ids[curr_modality_id]
                    modality_sample.append(tensor([eom_id], device = device))

                    # back to decoding text

                    # the whole modality block counts toward the length budget
                    pbar.update(modality_length)
                    curr_length += modality_length

                    curr_modality_id = None
                    is_decoding_text = True

        return modality_sample
668736
669737 def forward (
670738 self ,
@@ -675,12 +743,16 @@ def forward(
675743 None
676744 ) = None ,
677745 return_loss = True ,
678- return_breakdown = False
746+ return_breakdown = False ,
747+ return_embed = False
679748 ) -> (
680749 Float ['b n l' ] |
750+ Float ['b n d' ] |
681751 Float ['' ] |
682752 tuple [Float ['' ], LossBreakdown ]
683753 ):
754+ return_loss &= not return_embed
755+
684756 device = self .device
685757
686758 # add "sentence" start and end tokens when training
@@ -717,7 +789,7 @@ def forward(
717789 modality_type , modality_tensor = modality
718790
719791 assert 0 <= modality_type < self .num_modalities , f'received a modality index that is out of range. only { self .num_modalities } modalities specified'
720- assert self .dim_latents [modality_type ] == modality_tensor .shape [- 1 ], 'mismatch for modality latent dimension - expected {self.dim_latents[modality_type]} but received {modality_tensor.shape[-1]}'
792+ assert self .dim_latents [modality_type ] == modality_tensor .shape [- 1 ], f 'mismatch for modality latent dimension - expected { self .dim_latents [modality_type ]} but received { modality_tensor .shape [- 1 ]} '
721793
722794 length = modality_tensor .shape [0 ]
723795
@@ -874,6 +946,11 @@ def forward(
874946 modality_positions = modality_positions
875947 )
876948
949+ # early return for embedding for decoding modality
950+
951+ if return_embed :
952+ return embed
953+
877954 # text unembedding
878955
879956 text_logits = self .to_text_logits (embed )
0 commit comments