
Commit 6be1b10

first make contrived unet work with modality only training
1 parent 506134b commit 6be1b10

File tree

3 files changed: +145 -17 lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "transfusion-pytorch"
-version = "0.4.12"
+version = "0.4.14"
 description = "Transfusion in Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

train_image_only_with_unet.py

Lines changed: 112 additions & 0 deletions (new file)

from shutil import rmtree
from pathlib import Path

import torch
from torch import tensor, nn
from torch.nn import Module
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam

from einops import rearrange

import torchvision
import torchvision.transforms as T
from torchvision.utils import save_image

from transfusion_pytorch import Transfusion, print_modality_sample

rmtree('./results', ignore_errors = True)
results_folder = Path('./results')
results_folder.mkdir(exist_ok = True, parents = True)

# functions

def divisible_by(num, den):
    return (num % den) == 0

# encoder / decoder

class Encoder(Module):
    def forward(self, x):
        # patchify 28 x 28 mnist into 4 channel 14 x 14 latents, scaled to [-1, 1]
        x = rearrange(x, '... 1 (h p1) (w p2) -> ... (p1 p2) h w', p1 = 2, p2 = 2)
        return x * 2 - 1

class Decoder(Module):
    def forward(self, x):
        # unpatchify back to 28 x 28 pixels in [0, 1]
        x = rearrange(x, '... (p1 p2) h w -> ... 1 (h p1) (w p2)', p1 = 2, p2 = 2, h = 14)
        return ((x + 1) * 0.5).clamp(min = 0., max = 1.)

model = Transfusion(
    num_text_tokens = 10,
    dim_latent = 4,
    channel_first_latent = True,
    modality_default_shape = (14, 14),
    modality_encoder = Encoder(),
    modality_decoder = Decoder(),
    pre_post_transformer_enc_dec = (
        nn.Conv2d(4, 64, 3, 2, 1),                               # unet-like downsample before attention, 14 x 14 -> 7 x 7
        nn.ConvTranspose2d(64, 4, 3, 2, 1, output_padding = 1),  # upsample back after attention
    ),
    add_pos_emb = True,
    modality_num_dim = 2,
    velocity_consistency_loss_weight = 0.1,
    transformer = dict(
        dim = 64,
        depth = 4,
        dim_head = 32,
        heads = 8
    )
).cuda()

ema_model = model.create_ema()

class MnistDataset(Dataset):
    def __init__(self):
        self.mnist = torchvision.datasets.MNIST(
            './data',
            download = True
        )

    def __len__(self):
        return len(self.mnist)

    def __getitem__(self, idx):
        pil, labels = self.mnist[idx]
        digit_tensor = T.PILToTensor()(pil)
        return (digit_tensor / 255).float()

def cycle(iter_dl):
    while True:
        for batch in iter_dl:
            yield batch

dataset = MnistDataset()

dataloader = DataLoader(dataset, batch_size = 32, shuffle = True)
iter_dl = cycle(dataloader)

optimizer = Adam(model.parameters(), lr = 8e-4)

# train loop

for step in range(1, 100_000 + 1):

    loss = model(next(iter_dl), velocity_consistency_ema_model = ema_model)
    loss.backward()

    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)

    optimizer.step()
    optimizer.zero_grad()

    ema_model.update()

    print(f'{step}: {loss.item():.3f}')

    if divisible_by(step, 500):
        image = ema_model.generate_modality_only(batch_size = 64)

        save_image(
            rearrange(image, '(gh gw) 1 h w -> 1 (gh h) (gw w)', gh = 8).detach().cpu(),
            str(results_folder / f'{step}.png')
        )
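For orientation (a sketch, not part of the commit): the Encoder patchifies each 28 x 28 MNIST image into a 4-channel 14 x 14 latent, and the pre/post transformer conv pair halves and then restores that spatial resolution, so the transformer attends over 7 x 7 = 49 tokens.

# shape sanity check for the contrived patchify encoder and the
# unet-style conv pair above -- a sketch, not part of the commit

import torch
from torch import nn
from einops import rearrange

images = torch.rand(2, 1, 28, 28)                                        # mnist-sized batch

latents = rearrange(images, 'b 1 (h p1) (w p2) -> b (p1 p2) h w', p1 = 2, p2 = 2)
assert latents.shape == (2, 4, 14, 14)                                   # dim_latent = 4, modality_default_shape = (14, 14)

down = nn.Conv2d(4, 64, 3, 2, 1)(latents)
assert down.shape == (2, 64, 7, 7)                                       # transformer sees 7 x 7 = 49 tokens at dim 64

up = nn.ConvTranspose2d(64, 4, 3, 2, 1, output_padding = 1)(down)
assert up.shape == latents.shape                                         # decoder side restores the latent shape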

transfusion_pytorch/transfusion.py

Lines changed: 32 additions & 16 deletions

@@ -1050,6 +1050,7 @@ def __init__(
         add_pos_emb: bool | tuple[bool, ...] = False,
         modality_encoder: Module | tuple[Module, ...] | None = None,
         modality_decoder: Module | tuple[Module, ...] | None = None,
+        pre_post_transformer_enc_dec: tuple[Module, Module] | tuple[tuple[Module, Module], ...] | None = None,
         modality_token_transform: tuple[ModalityTokenTransform, ...] | ModalityTokenTransform | None = None,
         modality_default_shape: tuple[int, ...] | tuple[tuple[int, ...], ...] | None = None,
         fallback_to_default_shape_if_invalid = False,
@@ -1188,14 +1189,32 @@ def __init__(
 
         assert len(self.modality_token_transform) == self.num_modalities
 
+        # prepare pre-post transformer encoder / decoder, for the learnable unets as in paper
+
+        if is_bearable(pre_post_transformer_enc_dec, tuple[Module, Module]):
+            pre_post_transformer_enc_dec = (pre_post_transformer_enc_dec,)
+
+        pre_post_transformer_enc_dec = cast_tuple(pre_post_transformer_enc_dec, self.num_modalities)
+        assert len(pre_post_transformer_enc_dec) == self.num_modalities
+
         # latent to model and back
         # by default will be Linear, with or without rearranges depending on channel_first_latent setting
-        # can also be overridden for the unet down/up as in the paper
+        # can also be overridden for the unet down/up as in the paper with `pre_post_transformer_enc_dec: tuple[Module, Module]`
 
         latent_to_model_projs = []
         model_to_latent_projs = []
 
-        for dim_latent, one_channel_first_latent in zip(self.dim_latents, self.channel_first_latent):
+        for (
+            dim_latent,
+            one_channel_first_latent,
+            enc_dec,
+        ) in zip(
+            self.dim_latents,
+            self.channel_first_latent,
+            pre_post_transformer_enc_dec
+        ):
+
+            pre_attend_enc, post_attend_dec = default(enc_dec, (None, None))
 
             latent_to_model_proj = Linear(dim_latent, dim) if dim_latent != dim else nn.Identity()
             model_to_latent_proj = Linear(dim, dim_latent, bias = False)

@@ -1204,8 +1223,8 @@ def __init__(
                 latent_to_model_proj = nn.Sequential(Rearrange('b d ... -> b ... d'), latent_to_model_proj, Rearrange('b ... d -> b d ...'))
                 model_to_latent_proj = nn.Sequential(Rearrange('b d ... -> b ... d'), model_to_latent_proj, Rearrange('b ... d -> b d ...'))
 
-            latent_to_model_projs.append(latent_to_model_proj)
-            model_to_latent_projs.append(model_to_latent_proj)
+            latent_to_model_projs.append(default(pre_attend_enc, latent_to_model_proj))
+            model_to_latent_projs.append(default(post_attend_dec, model_to_latent_proj))
 
         self.latent_to_model_projs = ModuleList(latent_to_model_projs)
         self.model_to_latent_projs = ModuleList(model_to_latent_projs)
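In other words, when a (pre-attention encoder, post-attention decoder) pair is supplied for a modality, it stands in for the default Linear latent -> model and model -> latent projections. A sketch of the two accepted forms of the new argument; the multi-modality form is an assumption extrapolated from the cast_tuple handling:

from torch import nn

# single modality: one (encoder, decoder) pair, as in the training script above
pre_post = (
    nn.Conv2d(4, 64, 3, 2, 1),                                 # replaces latent -> model projection
    nn.ConvTranspose2d(64, 4, 3, 2, 1, output_padding = 1)     # replaces model -> latent projection
)

# multiple modalities: one pair per modality (assumed form, per cast_tuple)
pre_post_multi = (pre_post, pre_post)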
@@ -1706,18 +1725,6 @@ def forward_modality(
             mod.encoder.eval()
             modalities = self.maybe_add_temp_batch_dim(mod.encoder)(modalities).detach()
 
-        # axial positions
-
-        if mod.add_pos_emb:
-            assert exists(mod.num_dim), f'modality_num_dim must be set for modality {modality_type} if further injecting axial positional embedding'
-
-            if mod.channel_first_latent:
-                _, _, *axial_dims = modalities.shape
-            else:
-                _, *axial_dims, _ = modalities.shape
-
-            assert len(axial_dims) == mod.num_dim, f'received modalities of ndim {len(axial_dims)} but expected {modality_num_dim}'
-
         # shapes and device
 
         tokens = modalities
@@ -1753,6 +1760,15 @@ def forward_modality(
         if mod.channel_first_latent:
             noised_tokens = rearrange(noised_tokens, 'b d ... -> b ... d')
 
+        # axial positions
+
+        if mod.add_pos_emb:
+            assert exists(mod.num_dim), f'modality_num_dim must be set for modality {modality_type} if further injecting axial positional embedding'
+
+            _, *axial_dims, _ = noised_tokens.shape
+
+            assert len(axial_dims) == mod.num_dim, f'received modalities of ndim {len(axial_dims)} but expected {modality_num_dim}'
+
         # maybe transform
 
         noised_tokens = mod.token_transform(noised_tokens)
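Note the move: the axial position check now runs after the channel-first rearrange, so the tokens it inspects are always channels-last and a single unpacking form covers both layouts. A minimal sketch (not part of the diff) of the invariant it relies on:

# after rearrange('b d ... -> b ... d'), the axial dims always sit
# between the batch and feature dims -- a sketch, not part of the diff

import torch

noised_tokens = torch.rand(2, 14, 14, 4)    # (batch, *axial_dims, dim_latent)
_, *axial_dims, _ = noised_tokens.shape
assert axial_dims == [14, 14]               # length matches modality_num_dim = 2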
