Skip to content

Commit 372e3d2

Browse files
committed
muon is here to stay
1 parent ea04570 commit 372e3d2

File tree

4 files changed

+50
-20
lines changed

4 files changed

+50
-20
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "transfusion-pytorch"
3-
version = "0.11.0"
3+
version = "0.12.0"
44
description = "Transfusion in Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -49,6 +49,7 @@ build-backend = "hatchling.build"
4949
[project.optional-dependencies]
5050

5151
examples = [
52+
"adam-atan2-pytorch>=0.2.2",
5253
"datasets",
5354
"diffusers"
5455
]

train_image_only.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
from torch import tensor
66
from torch.nn import Module
77
from torch.utils.data import Dataset, DataLoader
8-
from torch.optim import Adam
8+
9+
from adam_atan2_pytorch import MuonAdamAtan2
910

1011
from einops import rearrange
1112

@@ -83,7 +84,7 @@ def cycle(iter_dl):
8384
dataloader = DataLoader(dataset, batch_size = 32, shuffle = True)
8485
iter_dl = cycle(dataloader)
8586

86-
optimizer = Adam(model.parameters(), lr = 8e-4)
87+
optimizer = MuonAdamAtan2(model.muon_parameters(), model.parameters(), lr = 8e-4)
8788

8889
# train loop
8990

train_mnist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
IMAGE_AFTER_TEXT = True # False for captioning, True for text-to-image
2525
USE_PROMPT = False # whether to use prompting, or synthesize from start token
2626
NUM_TRAIN_STEPS = 20_000
27-
SAMPLE_EVERY = 250
27+
SAMPLE_EVERY = 500
2828
CHANNEL_FIRST = True
2929

3030
# functions

transfusion_pytorch/transfusion.py

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -813,18 +813,24 @@ def forward(self, x):
813813
x, gates = x.chunk(2, dim = -1)
814814
return F.gelu(gates) * x
815815

816-
def FeedForward(
817-
dim,
818-
expansion_factor = 4.,
819-
dropout = 0.
820-
):
821-
dim_inner = int(dim * expansion_factor * 2 / 3)
822-
return nn.Sequential(
823-
Linear(dim, dim_inner * 2),
824-
GEGLU(),
825-
nn.Dropout(dropout),
826-
Linear(dim_inner, dim)
827-
)
816+
class FeedForward(Module):
    """Gated (GEGLU) feedforward block.

    Projects up to twice the hidden width, gates with GEGLU (which halves
    the width again), applies dropout, then projects back down to `dim`.
    Kept as a Module (rather than a bare nn.Sequential factory) so the
    class can be detected via isinstance checks elsewhere in the file.
    """

    def __init__(
        self,
        dim,
        expansion_factor = 4.,
        dropout = 0.
    ):
        super().__init__()

        # 2/3 scaling keeps the gated variant's parameter count roughly
        # equal to a plain feedforward with the same expansion factor
        hidden = int(dim * expansion_factor * 2 / 3)

        layers = [
            Linear(dim, hidden * 2),
            GEGLU(),
            nn.Dropout(dropout),
            Linear(hidden, dim)
        ]

        # NOTE: net[0] / net[-1] are accessed positionally by muon_parameters
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out
828834

829835
class Attention(Module):
830836
def __init__(
@@ -847,9 +853,14 @@ def __init__(
847853
assert not (use_flex_attn and not exists(flex_attention)), 'flex attention is only available on torch 2.5.0 (nightly) onwards'
848854
self.use_flex_attn = use_flex_attn
849855

850-
self.to_qkv = nn.Sequential(
851-
Linear(dim, dim_inner * 3, bias = False),
852-
Rearrange('b n (qkv h d) -> qkv b h n d', qkv = 3, h = heads)
856+
self.to_qk = nn.Sequential(
857+
Linear(dim, dim_inner * 2, bias = False),
858+
Rearrange('b n (qk h d) -> qk b h n d', qk = 2, h = heads)
859+
)
860+
861+
self.to_v = nn.Sequential(
862+
Linear(dim, dim_inner, bias = False),
863+
Rearrange('b n (h d) -> b h n d', h = heads)
853864
)
854865

855866
self.to_learned_value_residual = nn.Sequential(
@@ -902,7 +913,7 @@ def forward(
902913

903914
# project to queries, keys, values
904915

905-
q, k, v = self.to_qkv(x)
916+
q, k, v = (*self.to_qk(x), self.to_v(x))
906917

907918
# value residual
908919

@@ -1522,6 +1533,23 @@ def parameters_without_encoder_decoder(self):
15221533
set(self.modality_decoder.parameters())
15231534
)
15241535

1536+
def muon_parameters(self):
    """Gather the parameters intended for the Muon optimizer.

    Walks all submodules and collects, per Attention block, the value and
    output projection parameters, and per FeedForward block, the first and
    last linear weights of its `net` Sequential. Everything else (query/key
    projections, norms, embeddings, etc.) is left for the companion
    optimizer. Returns a flat list of Parameters.
    """
    collected = []

    for module in self.modules():
        if isinstance(module, Attention):
            collected.extend(module.to_v.parameters())
            collected.extend(module.to_out.parameters())
        elif isinstance(module, FeedForward):
            # positional access assumes FeedForward.net is
            # Sequential(Linear, GEGLU, Dropout, Linear)
            collected.append(module.net[0].weight)
            collected.append(module.net[-1].weight)

    return collected
15251553
def create_dataloader(
15261554
self,
15271555
*args,

0 commit comments

Comments (0)