Skip to content

Commit b04a8de

Browse files
committed
fall back to inefficient masking if input is on CPU
1 parent 09aae3d commit b04a8de

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "transfusion-pytorch"
3-
version = "0.1.0"
3+
version = "0.1.1"
44
description = "Transfusion in Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

transfusion_pytorch/transfusion.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -613,16 +613,17 @@ def __init__(
613613
def forward(
614614
self,
615615
x,
616-
attn_mask: Tensor | None = None,
616+
attn_mask: Tensor | None = None, # for manual masking
617617
rotary_emb: Tensor | None = None,
618618
cache: Tensor | None = None,
619619
causal = False,
620-
block_mask = None,
620+
block_mask = None, # only passed in for flex attention
621621
return_kv_cache = False
622622
):
623623
device = x.device
624624

625625
assert not (exists(block_mask) and exists(attn_mask))
626+
assert not (not self.use_flex_attn and exists(block_mask)), 'you cannot pass in the `block_mask` if `use_flex_attn` was not set to be `True`'
626627

627628
x = self.norm(x)
628629

@@ -748,7 +749,10 @@ def forward(
748749
causal_mask: bool = False,
749750
return_kv_cache = False
750751
):
751-
batch, seq_len, device = x.shape[0], x.shape[-2], x.device
752+
batch, seq_len, device, input_is_cuda = x.shape[0], x.shape[-2], x.device, x.is_cuda
753+
754+
should_use_flex_attn = input_is_cuda and self.use_flex_attn
755+
752756
assert not (exists(attn_mask) and exists(modality_positions))
753757

754758
# handle time
@@ -766,7 +770,7 @@ def forward(
766770
attn_mask_kwargs = dict()
767771

768772
if causal_mask:
769-
if self.use_flex_attn:
773+
if should_use_flex_attn:
770774
block_mask = create_block_mask(causal, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, device = device)
771775
attn_mask_kwargs.update(block_mask = block_mask)
772776
else:
@@ -775,7 +779,7 @@ def forward(
775779
if exists(modality_positions):
776780
assert not causal_mask
777781

778-
if self.use_flex_attn:
782+
if should_use_flex_attn:
779783
transfusion_mask_fn = transfusion_attn_mask(modality_positions)
780784
block_mask = create_block_mask(transfusion_mask_fn, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, device = device)
781785
attn_mask_kwargs.update(block_mask = block_mask)

0 commit comments

Comments (0)