Skip to content

Commit 2caa45a

Browse files
committed
allow for attending to nothing
1 parent 8d135c3 commit 2caa45a

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "transfusion-pytorch"
3-
version = "0.1.5"
3+
version = "0.1.6"
44
description = "Transfusion in Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

transfusion_pytorch/transfusion.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -600,7 +600,8 @@ def __init__(
600600
heads = 8,
601601
dropout = 0.,
602602
softcap_value = 50.,
603-
use_flex_attn = False
603+
use_flex_attn = False,
604+
gate_values = True
604605
):
605606
super().__init__()
606607
self.scale = dim_head ** -0.5
@@ -616,6 +617,11 @@ def __init__(
616617
Rearrange('b n (qkv h d) -> qkv b h n d', qkv = 3, h = heads)
617618
)
618619

620+
self.to_gates = nn.Sequential(
621+
nn.Linear(dim, heads, bias = False),
622+
Rearrange('b n h -> b h n 1', h = heads)
623+
) if gate_values else None
624+
619625
self.softcap_value = softcap_value
620626

621627
self.dropout = nn.Dropout(dropout)
@@ -705,6 +711,11 @@ def forward(
705711

706712
out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
707713

714+
# maybe gate values
715+
716+
if exists(self.to_gates):
717+
out = out * self.to_gates(x).sigmoid()
718+
708719
# combine heads and out
709720

710721
out = self.to_out(out)

0 commit comments

Comments (0)