Skip to content

Commit c90d8fe

Browse files
committed
improve is modality mask creation
1 parent 1fbee8c commit c90d8fe

File tree

2 files changed

+11
-14
lines changed

2 files changed

+11
-14
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "transfusion-pytorch"
3-
version = "0.0.6"
3+
version = "0.0.7"
44
description = "Transfusion in Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

transfusion_pytorch/transfusion.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -184,24 +184,21 @@ def modality_positions_to_is_modality_mask(
184184
offset = F.pad(offset, (1, 0))
185185
modalities = modalities + offset.to(modalities)
186186

187-
is_modalities = []
188-
189-
for modality_id in range(num_modalities):
190-
one_modality_type_mask = modalities[..., 0] == modality_id
191-
one_modality = modalities.masked_fill(~one_modality_type_mask[..., None], 0)
187+
seq = torch.arange(seq_len, device = device)
188+
type_seq = torch.arange(num_modalities, device = device)
192189

193-
left, right = one_modality[..., 1:].cumsum(dim = -1).unbind(dim = -1)
190+
modality_types = modalities[..., 0]
194191

195-
seq = torch.arange(seq_len, device = device)
192+
left, right = modalities[..., 1:].cumsum(dim = -1).unbind(dim = -1)
196193

197-
is_modality = (
198-
einx.greater_equal('i, b m -> b m i', seq, left) &
199-
einx.less('j, b m -> b m j', seq, right)
200-
)
194+
is_instance_for_type = einx.equal('b m, t -> b t m', modality_types, type_seq)
201195

202-
is_modalities.append(is_modality)
196+
is_modality_along_seq = (
197+
einx.greater_equal('i, b m -> b m i', seq, left) &
198+
einx.less('j, b m -> b m j', seq, right)
199+
)
203200

204-
return torch.stack(is_modalities, dim = 1)
201+
return einx.logical_and('b t m, b m n -> b t m n', is_instance_for_type, is_modality_along_seq)
205202

206203
def naive_attn_mask(
207204
seq_len: int,

0 commit comments

Comments
 (0)