1212i, j - sequence (row, col)
1313"""
1414
15+ import os
16+
1517import math
1618from functools import partial
1719from typing import NamedTuple , Callable , Literal
2931
3032from rotary_embedding_torch import RotaryEmbedding , apply_rotary_emb
3133
32- from beartype import beartype
33- from beartype .door import is_bearable
3434from tqdm import tqdm
3535
3636pad_sequence = partial (pad_sequence , batch_first = True )
3737
3838# tensor typing
3939
4040import jaxtyping
41+ from jaxtyping import jaxtyped
42+ from beartype import beartype
43+ from beartype .door import is_bearable
4144
4245class TorchTyping :
4346 def __init__ (self , abstract_dtype ):
@@ -63,11 +66,11 @@ def __getitem__(self, shapes: str):
6366
6467# constants
6568
66- ModalitySample = list [Int ['_' ] | Float ['_ _ ' ] | tuple [int , Float ['_ _' ]]]
69+ ModalitySample = list [Int ['_' ] | Float ['... ' ] | tuple [int , Float ['_ _' ]]]
6770
6871ModalityTokenTransform = str | Callable | None
6972
70- RawModalityPositions = list [list [tuple [int , int ]]]
73+ RawModalityPositions = list [list [tuple [int , int , int ]]]
7174
7275class LossBreakdown (NamedTuple ):
7376 total : Float ['' ]
@@ -112,6 +115,10 @@ def inner(self, *args, **kwargs):
112115 return out
113116 return inner
114117
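118+ # optional runtime typechecking - enabled only when the TYPECHECK environment variable is '1' or 'true' (case-insensitive), otherwise `typecheck` falls back to `identity`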
119+
120+ typecheck = jaxtyped (typechecker = beartype ) if os .environ .get ('TYPECHECK' , '' ).lower () in ('1' , 'true' ) else identity
121+
115122# default function for constituting modality shape from string
116123
117124def default_to_modality_shape_fn (maybe_shape_str ) -> tuple [int , ...]:
@@ -233,6 +240,7 @@ def inner(score, b, h, q_idx, kv_idx):
233240
234241# converting a raw list of modality offsets and lengths to tensor
235242
243+ @typecheck
236244def modality_positions_to_tensor (
237245 modalities : RawModalityPositions ,
238246 pad_value = 0 ,
@@ -245,10 +253,11 @@ def modality_positions_to_tensor(
245253 if modalities .ndim == 2 :
246254 modalities = modalities .reshape (* modalities .shape , 3 )
247255
248- return modalities
256+ return modalities . long ()
249257
250258# sanitizing modalities tensor, making sure it is ordered
251259
260+ @typecheck
252261def order_modality_positions_by_seq_offset (
253262 modalities : Int ['b m 3' ]
254263) -> tuple [Int ['b m 3' ], Int ['b m' ]]:
@@ -312,6 +321,7 @@ def embed_modality_tokens(
312321
313322# functions for managing modality token mask
314323
324+ @typecheck
315325def modality_positions_to_is_modality_mask (
316326 seq_len : int ,
317327 modalities : Int ['b m 3' ],
@@ -342,6 +352,7 @@ def modality_positions_to_is_modality_mask(
342352
343353 return einx .logical_and ('b t m, b m n -> b t m n' , is_instance_for_type , is_modality_along_seq )
344354
355+ @typecheck
345356def naive_attn_mask (
346357 seq_len : int ,
347358 modalities : Int ['b m 3' ],
@@ -383,10 +394,11 @@ def __init__(self, dim):
383394 self .dim = dim
384395 self .register_buffer ('weights' , torch .randn (dim // 2 ))
385396
397+ @typecheck
386398 def forward (
387399 self ,
388400 times : Float ['b n' ] | Float ['b' ]
389- ) -> Float ['b n {self.dim + 1}' ]:
401+ ) -> Float ['b n {self.dim+ 1}' ]:
390402
391403 if times .ndim == 1 :
392404 times = rearrange (times , 'b -> b 1' )
@@ -429,6 +441,7 @@ def __init__(
429441 nn .init .zeros_ (self .to_ada_ln_zero .weight )
430442 nn .init .constant_ (self .to_ada_ln_zero .bias , ada_ln_zero_init_bias )
431443
444+ @typecheck
432445 def forward_text (
433446 self ,
434447 x : Float ['b n {self.dim}' ],
@@ -453,6 +466,7 @@ def forward_text(
453466
454467 return out
455468
469+ @typecheck
456470 def forward_modality (
457471 self ,
458472 x : Float ['b n {self.dim}' ],
@@ -487,6 +501,7 @@ def forward_modality(
487501
488502 return modalities_out
489503
504+ @typecheck
490505 def forward (
491506 self ,
492507 x : Float ['b n {self.dim}' ],
@@ -746,17 +761,18 @@ def __init__(
746761 self .layers = layers
747762 self .norm = RMSNorm (dim )
748763
764+ @typecheck
749765 def forward (
750766 self ,
751767 x ,
752768 times : Float ['' ] | Float ['b' ] | Float ['b n' ] | None = None ,
753769 attn_mask : Bool ['b i j' ] | None = None ,
754- modality_positions : RawModalityPositions | Int ['b n 2 ' ] | None = None ,
770+ modality_positions : RawModalityPositions | Int ['b m 3 ' ] | None = None ,
755771 is_any_modality : bool | Bool ['b n' ] | None = None ,
756772 rotary_emb : Tensor | None = None ,
757773 cache : Tensor | None = None ,
758774 modality_only = False ,
759- causal_mask : bool = False ,
775+ causal_mask = False ,
760776 return_kv_cache = False
761777 ):
762778 batch , seq_len , device , input_is_cuda = x .shape [0 ], x .shape [- 2 ], x .device , x .is_cuda
@@ -1026,6 +1042,7 @@ def device(self):
10261042
10271043 @torch .no_grad ()
10281044 @eval_decorator
1045+ @typecheck
10291046 def sample (
10301047 self ,
10311048 prompt : ModalitySample | None = None ,
@@ -1251,6 +1268,7 @@ def ode_step_fn(step_times, denoised):
12511268
12521269 return processed_modality_sample
12531270
1271+ @typecheck
12541272 def forward_text (
12551273 self ,
12561274 text : Int ['b n' ],
@@ -1259,8 +1277,8 @@ def forward_text(
12591277 cache : Tensor | None = None ,
12601278 return_kv_cache = False
12611279 ) -> (
1262- Float ['' ],
1263- Float ['b n d' ],
1280+ Float ['' ] |
1281+ Float ['b n d' ] |
12641282 tuple [Float ['b n d' ], list [Float ['...' ]]]
12651283 ):
12661284
@@ -1309,6 +1327,7 @@ def forward_text(
13091327
13101328 return loss
13111329
1330+ @typecheck
13121331 def forward_modality (
13131332 self ,
13141333 modalities : Float ['b ...' ],
@@ -1371,6 +1390,7 @@ def forward_modality(
13711390
13721391 return F .mse_loss (pred_flow , flow )
13731392
1393+ @typecheck
13741394 def forward (
13751395 self ,
13761396 modalities : (
@@ -1388,9 +1408,9 @@ def forward(
13881408 return_embed = False ,
13891409 return_kv_cache = False ,
13901410 ) -> (
1391- Float ['b n l' ] |
1392- Float ['b n d' ] |
1393- tuple [Float ['b n _' ], list [ Float [ '...' ]] ] |
1411+ Float ['b _ l' ] |
1412+ Float ['b _ d' ] |
1413+ tuple [Float ['b _ _' ], Tensor ] |
13941414 Float ['' ] |
13951415 tuple [Float ['' ], LossBreakdown ]
13961416 ):
0 commit comments