@@ -147,6 +147,15 @@ def inner(t: Tensor, *args, **kwargs) -> Tensor:
         return out
     return inner
 
+def pack_with_inverse(t, pattern):
+    packed, packed_shape = pack(t, pattern)
+
+    def inverse(out, inv_pattern = None):
+        inv_pattern = default(inv_pattern, pattern)
+        return unpack(out, packed_shape, inv_pattern)
+
+    return packed, inverse
+
 def pack_one_with_inverse(t, pattern):
     packed, packed_shape = pack([t], pattern)
 
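For reference, a minimal usage sketch of the new `pack_with_inverse` helper. The shapes and the `default` shim below are illustrative; the module itself gets `pack` / `unpack` from einops and already defines `default`:

import torch
from einops import pack, unpack

def default(v, d):  # stand-in for the repo's own helper
    return v if v is not None else d

def pack_with_inverse(t, pattern):  # as added in this commit
    packed, packed_shape = pack(t, pattern)

    def inverse(out, inv_pattern = None):
        inv_pattern = default(inv_pattern, pattern)
        return unpack(out, packed_shape, inv_pattern)

    return packed, inverse

register = torch.randn(2, 4, 512)    # (batch, registers, dim) - made-up shapes
tokens = torch.randn(2, 100, 512)    # (batch, seq, dim)

packed, inverse = pack_with_inverse((register, tokens), 'b * d')
assert packed.shape == (2, 104, 512)

register_out, tokens_out = inverse(packed)   # round-trips to the original shapes
assert tokens_out.shape == tokens.shape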
@@ -1115,6 +1124,7 @@ def __init__(
         self,
         *,
         num_text_tokens,
+        num_register_tokens = 16,
         transformer: dict | Transformer,
         dim_latent: int | tuple[int, ...] | None = None,
         channel_first_latent: bool | tuple[bool, ...] = False,
@@ -1298,6 +1308,11 @@ def __init__(
         self.latent_to_model_projs = ModuleList(latent_to_model_projs)
         self.model_to_latent_projs = ModuleList(model_to_latent_projs)
 
+        # maybe register tokens (used in hymba; renamed from "meta" to "register", since "meta" is already reserved above for the modality meta tag)
+
+        self.register_tokens = nn.Parameter(torch.zeros(num_register_tokens, dim))
+        nn.init.normal_(self.register_tokens, std = 0.02)
+
         # relative positions
 
         self.rotary_emb = RotaryEmbedding(transformer.dim_head)
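A standalone sketch of what this parameter amounts to (names and sizes here are illustrative): a small set of learned vectors, initialized as above, that get broadcast over the batch and later packed in front of the token sequence.

import torch
from torch import nn
from einops import repeat

dim, num_register_tokens, batch = 512, 16, 2

register_tokens = nn.Parameter(torch.zeros(num_register_tokens, dim))
nn.init.normal_(register_tokens, std = 0.02)

# broadcast across the batch, ready to be prepended to the token sequence
batched = repeat(register_tokens, 'n d -> b n d', b = batch)
assert batched.shape == (batch, num_register_tokens, dim)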
@@ -2392,6 +2407,7 @@ def inner(pred_flow):
         if modality_positions.numel() == 0:
             modality_positions = F.pad(modality_positions, (0, 0, 0, 1))
 
+
         # sort the modalities tensor and sanitize, readying for noising of modalities
 
         modality_positions, sorted_indices = order_modality_positions_by_seq_offset(modality_positions)
@@ -2415,6 +2431,18 @@ def inner(pred_flow):
 
         tokens = einx.where('b n, b n d, b n d', is_any_modality, modality_tokens, text_tokens)
 
+        # handle maybe meta / register tokens
+
+        register_tokens = repeat(self.register_tokens, '... -> b ...', b = batch)
+
+        num_register_tokens = register_tokens.shape[-2]
+        seq_len += num_register_tokens
+
+        tokens, unpack_register_tokens = pack_with_inverse((register_tokens, tokens), 'b * d')
+        modality_positions[..., 1] += num_register_tokens
+
+        is_modalities = F.pad(is_modalities, (num_register_tokens, 0), value = False)
+
         # derive rotary positions
 
         rotary_positions = derive_rotary_positions_from_modality_positions(seq_len, modality_positions)
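To make the bookkeeping concrete, a toy sketch of the packing and offset shift. It assumes each `modality_positions` entry is a (type, offset, length) triple, which is what the `[..., 1]` increment above suggests; shapes are made up:

import torch
from einops import pack

batch, seq_len, dim, num_register_tokens = 1, 10, 8, 4

tokens = torch.randn(batch, seq_len, dim)
register_tokens = torch.randn(batch, num_register_tokens, dim)

# one modality: type 0, starting at offset 2, spanning 5 tokens
modality_positions = torch.tensor([[[0, 2, 5]]])

# register tokens are packed in front, so the sequence grows ...
packed, _ = pack((register_tokens, tokens), 'b * d')
assert packed.shape[1] == seq_len + num_register_tokens

# ... and every modality offset shifts right by the same amount
modality_positions[..., 1] += num_register_tokens
assert modality_positions[0, 0, 1].item() == 6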
@@ -2455,6 +2483,10 @@ def inner(pred_flow):
             return_kv_cache = True
         )
 
+        # remove register tokens
+
+        _, embed = unpack_register_tokens(embed)
+
         # early return for embedding for decoding modality
 
         if return_embed: