@@ -201,11 +201,11 @@ def concat_contiguous_text(
     """ within a modality sample, any two tensors of type int / long will be concatted together if next to each other, so all text is followed by a modality, and all modality followed by text """

     output = []
-    curr_modality = None

     for modality in modality_sample:
         if (
             len(output) > 0 and
+            is_tensor(output[-1]) and is_tensor(modality) and
             output[-1].dtype == modality.dtype and
             modality.dtype in (torch.int, torch.long)
         ):
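For context, a minimal standalone sketch of the rule this hunk guards (the sample list and helper name below are illustrative, and the concatenation inside the branch is assumed from the docstring rather than shown in the hunk): adjacent int/long text tensors are merged, while non-tensor entries such as (modality_type, tensor) tuples pass through untouched, which is why the is_tensor checks are needed before reading .dtype.

import torch
from torch import is_tensor

def concat_contiguous_text_sketch(modality_sample):
    # merge neighboring int/long (text) tensors; leave other entries (e.g. (type, float tensor) tuples) as-is
    output = []
    for modality in modality_sample:
        if (
            len(output) > 0 and
            is_tensor(output[-1]) and is_tensor(modality) and
            output[-1].dtype == modality.dtype and
            modality.dtype in (torch.int, torch.long)
        ):
            output[-1] = torch.cat((output[-1], modality))  # assumed branch body
        else:
            output.append(modality)
    return output

sample = [
    torch.tensor([1, 2]),
    torch.tensor([3]),          # merged with the previous text tensor
    (0, torch.randn(4, 8)),     # modality tuple - skipped by the is_tensor / dtype checks
    torch.tensor([4, 5]),
]
print(len(concat_contiguous_text_sketch(sample)))  # 3 entries: text, modality tuple, text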
@@ -1365,6 +1365,19 @@ def get_modality_info(
     def get_all_modality_info(self) -> list[ModalityInfo]:
         return [self.get_modality_info(i) for i in range(self.num_modalities)]

+    def get_modality_shape(
+        self,
+        modality: Float['...'],
+        modality_type: int | None = None
+    ) -> tuple[int, ...]:
+
+        mod = self.get_modality_info(modality_type)
+
+        if mod.channel_first_latent:
+            modality = rearrange(modality, 'c ... -> ... c')
+
+        return tuple(modality.shape[:-1])
+
     def parameters_without_encoder_decoder(self):
         return (
             set(self.parameters()) -
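A quick illustration of what the new get_modality_shape returns, with made-up tensors and einops semantics: the feature axis is moved last for channel-first latents before being dropped, so either layout yields the same spatial shape.

import torch
from einops import rearrange

def modality_shape_sketch(modality, channel_first_latent):
    # mirror of get_modality_shape: drop the feature dim, keep only the spatial axes
    if channel_first_latent:
        modality = rearrange(modality, 'c ... -> ... c')
    return tuple(modality.shape[:-1])

latent_channel_first = torch.randn(8, 16, 16)   # (c, h, w)
latent_channel_last  = torch.randn(16, 16, 8)   # (h, w, c)

assert modality_shape_sketch(latent_channel_first, True) == (16, 16)
assert modality_shape_sketch(latent_channel_last, False) == (16, 16)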
@@ -1402,7 +1415,7 @@ def create_ema(
     @typecheck
     def sample(
         self,
-        prompt: ModalitySample | None = None,
+        prompt: ModalitySample | Tensor | tuple[int, Float['...']] | None = None,
         max_length = 2048,
         text_temperature = 1.5,
         text_min_p = 0.1,
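A hypothetical usage sketch of the widened prompt argument; `model` stands in for a constructed instance of the class this method belongs to (not built here) and the shapes are illustrative. An int/long tensor is treated as text tokens, a float tensor as a modality of type 0, and a (type, tensor) tuple names the modality type explicitly.

import torch

text_prompt  = torch.randint(0, 256, (32,))   # int/long tensor -> text-only prompt
image_latent = torch.randn(8, 16, 16)         # float tensor -> modality, type 0 implied

# model.sample(prompt = text_prompt)
# model.sample(prompt = image_latent)                  # same as prompt = (0, image_latent)
# model.sample(prompt = (1, torch.randn(8, 16, 16)))   # explicit modality type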
@@ -1415,22 +1428,52 @@ def sample(

         device = self.device

+        # take care of prompt being a raw tensor, either text or raw modality (image, video, actions, latents, etc)
+
+        if is_tensor(prompt) and prompt.dtype == torch.float: # is modality with type 0 implicit
+            prompt = (0, prompt)
+
+        if is_tensor(prompt) and prompt.dtype in (torch.int, torch.long): # is text only prompt
+            prompt = [prompt]
+
+        elif isinstance(prompt, tuple):
+            modality_type, modality = prompt
+
+            mod = self.get_modality_info(modality_type)
+
+            if exists(mod.encoder):
+                with torch.no_grad():
+                    mod.encoder.eval()
+                    modality = self.maybe_add_temp_batch_dim(mod.encoder)(modality).detach()
+
+            modality_shape_tuple = self.get_modality_shape(modality, modality_type)
+            modality_shape_str = join([*map(str, modality_shape_tuple)], ',')
+            modality_meta_info = self.char_tokenizer(modality_shape_str, device = device)
+
+            prompt = [
+                tensor([self.meta_id]),
+                modality_meta_info,
+                tensor([mod.som_id]),
+                (modality_type, modality),
+                tensor([mod.eom_id]),
+            ]
+
+        # sos
+
         init_text_seq = tensor([self.sos_id], device = device)

         # just take care of prompt being zero dimensions

-        prompt = tree_map_tensor(prompt, lambda t: rearrange(t, '-> 1') if t.ndim == 0 else t)
-
         modality_sample = [init_text_seq, *default(prompt, [])]

         # take care of moving to device

         modality_sample = tree_map_tensor(modality_sample, lambda t: t.to(device))
+        modality_sample = tree_map_tensor(modality_sample, lambda t: rearrange(t, '-> 1') if t.ndim == 0 else t)

         modality_sample = concat_contiguous_text(modality_sample)

         *_, last_modality_sample = modality_sample
-        assert last_modality_sample.dtype in (torch.int, torch.long), 'prompt must be text tokens'

         curr_length = 0
         curr_modality_id = None
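A sketch of how a raw modality prompt ends up framed before sampling, per the added branch above. The special ids below are illustrative stand-ins for self.meta_id, mod.som_id and mod.eom_id, and the ord-based encoding is only an assumption about what self.char_tokenizer does with the shape string.

import torch

META_ID, SOM_ID, EOM_ID = 0, 1, 2                          # placeholder special token ids

shape_tuple = (16, 16)                                     # e.g. from get_modality_shape
shape_str = ','.join(map(str, shape_tuple))                # "16,16"
meta_tokens = torch.tensor([ord(c) for c in shape_str])    # assumed character-level encoding

framed_prompt = [
    torch.tensor([META_ID]),       # meta token announcing the shape string
    meta_tokens,                   # "16,16" as character tokens
    torch.tensor([SOM_ID]),        # start-of-modality
    (0, torch.randn(16, 16, 8)),   # (modality_type, latent)
    torch.tensor([EOM_ID]),        # end-of-modality
]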