Skip to content

Commit a66ccf4

Browse files
committed
Enable turning on jaxtyped + beartype via an environment flag, and fix some type annotations
1 parent 6767221 commit a66ccf4

File tree

2 files changed

+34
-14
lines changed

2 files changed

+34
-14
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "transfusion-pytorch"
3-
version = "0.1.2"
3+
version = "0.1.4"
44
description = "Transfusion in Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

transfusion_pytorch/transfusion.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
i, j - sequence (row, col)
1313
"""
1414

15+
import os
16+
1517
import math
1618
from functools import partial
1719
from typing import NamedTuple, Callable, Literal
@@ -29,15 +31,16 @@
2931

3032
from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
3133

32-
from beartype import beartype
33-
from beartype.door import is_bearable
3434
from tqdm import tqdm
3535

3636
pad_sequence = partial(pad_sequence, batch_first = True)
3737

3838
# tensor typing
3939

4040
import jaxtyping
41+
from jaxtyping import jaxtyped
42+
from beartype import beartype
43+
from beartype.door import is_bearable
4144

4245
class TorchTyping:
4346
def __init__(self, abstract_dtype):
@@ -63,11 +66,11 @@ def __getitem__(self, shapes: str):
6366

6467
# constants
6568

66-
ModalitySample = list[Int['_'] | Float['_ _'] | tuple[int, Float['_ _']]]
69+
ModalitySample = list[Int['_'] | Float['...'] | tuple[int, Float['_ _']]]
6770

6871
ModalityTokenTransform = str | Callable | None
6972

70-
RawModalityPositions = list[list[tuple[int, int]]]
73+
RawModalityPositions = list[list[tuple[int, int, int]]]
7174

7275
class LossBreakdown(NamedTuple):
7376
total: Float['']
@@ -112,6 +115,10 @@ def inner(self, *args, **kwargs):
112115
return out
113116
return inner
114117

118+
# maybe typecheck
119+
120+
typecheck = jaxtyped(typechecker = beartype) if os.environ.get('TYPECHECK', '').lower() in ('1', 'true') else identity
121+
115122
# default function for constituting modality shape from string
116123

117124
def default_to_modality_shape_fn(maybe_shape_str) -> tuple[int, ...]:
@@ -233,6 +240,7 @@ def inner(score, b, h, q_idx, kv_idx):
233240

234241
# converting a raw list of modality offsets and lengths to tensor
235242

243+
@typecheck
236244
def modality_positions_to_tensor(
237245
modalities: RawModalityPositions,
238246
pad_value = 0,
@@ -245,10 +253,11 @@ def modality_positions_to_tensor(
245253
if modalities.ndim == 2:
246254
modalities = modalities.reshape(*modalities.shape, 3)
247255

248-
return modalities
256+
return modalities.long()
249257

250258
# sanitizing modalities tensor, making sure it is ordered
251259

260+
@typecheck
252261
def order_modality_positions_by_seq_offset(
253262
modalities: Int['b m 3']
254263
) -> tuple[Int['b m 3'], Int['b m']]:
@@ -312,6 +321,7 @@ def embed_modality_tokens(
312321

313322
# functions for managing modality token mask
314323

324+
@typecheck
315325
def modality_positions_to_is_modality_mask(
316326
seq_len: int,
317327
modalities: Int['b m 3'],
@@ -342,6 +352,7 @@ def modality_positions_to_is_modality_mask(
342352

343353
return einx.logical_and('b t m, b m n -> b t m n', is_instance_for_type, is_modality_along_seq)
344354

355+
@typecheck
345356
def naive_attn_mask(
346357
seq_len: int,
347358
modalities: Int['b m 3'],
@@ -383,10 +394,11 @@ def __init__(self, dim):
383394
self.dim = dim
384395
self.register_buffer('weights', torch.randn(dim // 2))
385396

397+
@typecheck
386398
def forward(
387399
self,
388400
times: Float['b n'] | Float['b']
389-
) -> Float['b n {self.dim + 1}']:
401+
) -> Float['b n {self.dim+1}']:
390402

391403
if times.ndim == 1:
392404
times = rearrange(times, 'b -> b 1')
@@ -429,6 +441,7 @@ def __init__(
429441
nn.init.zeros_(self.to_ada_ln_zero.weight)
430442
nn.init.constant_(self.to_ada_ln_zero.bias, ada_ln_zero_init_bias)
431443

444+
@typecheck
432445
def forward_text(
433446
self,
434447
x: Float['b n {self.dim}'],
@@ -453,6 +466,7 @@ def forward_text(
453466

454467
return out
455468

469+
@typecheck
456470
def forward_modality(
457471
self,
458472
x: Float['b n {self.dim}'],
@@ -487,6 +501,7 @@ def forward_modality(
487501

488502
return modalities_out
489503

504+
@typecheck
490505
def forward(
491506
self,
492507
x: Float['b n {self.dim}'],
@@ -746,17 +761,18 @@ def __init__(
746761
self.layers = layers
747762
self.norm = RMSNorm(dim)
748763

764+
@typecheck
749765
def forward(
750766
self,
751767
x,
752768
times: Float[''] | Float['b'] | Float['b n'] | None = None,
753769
attn_mask: Bool['b i j'] | None = None,
754-
modality_positions: RawModalityPositions | Int['b n 2'] | None = None,
770+
modality_positions: RawModalityPositions | Int['b m 3'] | None = None,
755771
is_any_modality: bool | Bool['b n'] | None = None,
756772
rotary_emb: Tensor | None = None,
757773
cache: Tensor | None = None,
758774
modality_only = False,
759-
causal_mask: bool = False,
775+
causal_mask = False,
760776
return_kv_cache = False
761777
):
762778
batch, seq_len, device, input_is_cuda = x.shape[0], x.shape[-2], x.device, x.is_cuda
@@ -1026,6 +1042,7 @@ def device(self):
10261042

10271043
@torch.no_grad()
10281044
@eval_decorator
1045+
@typecheck
10291046
def sample(
10301047
self,
10311048
prompt: ModalitySample | None = None,
@@ -1251,6 +1268,7 @@ def ode_step_fn(step_times, denoised):
12511268

12521269
return processed_modality_sample
12531270

1271+
@typecheck
12541272
def forward_text(
12551273
self,
12561274
text: Int['b n'],
@@ -1259,8 +1277,8 @@ def forward_text(
12591277
cache: Tensor | None = None,
12601278
return_kv_cache = False
12611279
) -> (
1262-
Float[''],
1263-
Float['b n d'],
1280+
Float[''] |
1281+
Float['b n d'] |
12641282
tuple[Float['b n d'], list[Float['...']]]
12651283
):
12661284

@@ -1309,6 +1327,7 @@ def forward_text(
13091327

13101328
return loss
13111329

1330+
@typecheck
13121331
def forward_modality(
13131332
self,
13141333
modalities: Float['b ...'],
@@ -1371,6 +1390,7 @@ def forward_modality(
13711390

13721391
return F.mse_loss(pred_flow, flow)
13731392

1393+
@typecheck
13741394
def forward(
13751395
self,
13761396
modalities: (
@@ -1388,9 +1408,9 @@ def forward(
13881408
return_embed = False,
13891409
return_kv_cache = False,
13901410
) -> (
1391-
Float['b n l'] |
1392-
Float['b n d'] |
1393-
tuple[Float['b n _'], list[Float['...']]] |
1411+
Float['b _ l'] |
1412+
Float['b _ d'] |
1413+
tuple[Float['b _ _'], Tensor] |
13941414
Float[''] |
13951415
tuple[Float[''], LossBreakdown]
13961416
):

0 commit comments

Comments (0)