Skip to content

Commit 09aae3d

Browse files
committed
Make the U-Net skip connections optional, since they are still somewhat unproven for language models
1 parent 46cdcdf commit 09aae3d

File tree

2 files changed: 7 additions and 4 deletions.

2 files changed

+7
-4
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "transfusion-pytorch"
3-
version = "0.0.45"
3+
version = "0.1.0"
44
description = "Transfusion in Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

transfusion_pytorch/transfusion.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -701,6 +701,7 @@ def __init__(
701701
ff_expansion_factor = 4,
702702
attn_kwargs: dict = dict(),
703703
ff_kwargs: dict = dict(),
704+
unet_skips = True,
704705
use_flex_attn = False
705706
):
706707
super().__init__()
@@ -720,7 +721,7 @@ def __init__(
720721
for ind in range(depth):
721722
is_latter_half = ind >= (depth // 2)
722723

723-
skip_proj = Linear(dim * 2, dim, bias = False) if is_latter_half else None
724+
skip_proj = Linear(dim * 2, dim, bias = False) if is_latter_half and unet_skips else None
724725

725726
attn = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = dropout, use_flex_attn = use_flex_attn, **attn_kwargs)
726727

@@ -815,10 +816,12 @@ def forward(
815816
if is_first_half:
816817
skips.append(x)
817818

818-
if is_later_half:
819+
if is_later_half and exists(skip_proj):
819820
skip = skips.pop()
821+
822+
residual = x
820823
x = torch.cat((x, skip), dim = -1)
821-
x = skip_proj(x)
824+
x = skip_proj(x) + residual
822825

823826
# attention and feedforward
824827

0 commit comments

Comments
 (0)