Skip to content

Commit 80c3fbf

Browse files
committed
Support the model output being the cleaned image, following the new paper from Kaiming He's group; still needs to be verified with experiments
1 parent 80b7eac commit 80c3fbf

3 files changed

Lines changed: 62 additions & 2 deletions

File tree

README.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,3 +276,15 @@ $ pip install -U diffusers transformers accelerate scipy ftfy safetensors
276276
url = {https://api.semanticscholar.org/CorpusID:277104144}
277277
}
278278
```
279+
280+
```bibtex
281+
@misc{li2025basicsletdenoisinggenerative,
282+
title = {Back to Basics: Let Denoising Generative Models Denoise},
283+
author = {Tianhong Li and Kaiming He},
284+
year = {2025},
285+
eprint = {2511.13720},
286+
archivePrefix = {arXiv},
287+
primaryClass = {cs.CV},
288+
url = {https://arxiv.org/abs/2511.13720},
289+
}
290+
```

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "transfusion-pytorch"
3-
version = "0.12.0"
3+
version = "0.14.0"
44
description = "Transfusion in Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

transfusion_pytorch/transfusion.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -624,6 +624,30 @@ def apply_fn_modality_type(
624624

625625
return tree_unflatten(out, tree_spec)
626626

627+
# converts an x0-prediction (clean image) into a flow prediction, either as a
# plain function or as a decorator around a model-output function

def get_model_output_to_flow_fn(
    noised: Tensor,
    times: Tensor,
    eps = 5e-2,
    return_decorator = False
):
    """Build a closure that rewrites a clean-image prediction as a flow.

    Given the noised input and its flow-matching times, the returned
    ``to_flow`` maps a predicted clean sample ``out`` to the flow
    ``(out - noised) / (1 - t)``, with the denominator clamped to ``eps``
    to avoid blow-up as t approaches 1.

    When ``return_decorator`` is True, instead returns a decorator that
    wraps a ``fn(embed) -> out`` callable so its result is converted to a
    flow automatically.
    """

    def to_flow(out):
        # broadcast times over the trailing dims of the prediction
        t = append_dims(times, out.ndim - 1)
        denom = (1. - t).clamp_min(eps)
        return (out - noised) / denom

    if not return_decorator:
        return to_flow

    # decorator form: compose the conversion after the wrapped callable

    def decorator(fn):
        def wrapped(embed):
            return to_flow(fn(embed))
        return wrapped

    return decorator
650+
627651
# sampling related functions
628652

629653
# min_p for text
@@ -1230,6 +1254,7 @@ def __init__(
12301254
*,
12311255
num_text_tokens,
12321256
transformer: dict | Transformer,
1257+
pred_clean = False,
12331258
dim_latent: int | tuple[int, ...] | None = None,
12341259
channel_first_latent: bool | tuple[bool, ...] = False,
12351260
add_pos_emb: bool | tuple[bool, ...] = False,
@@ -1251,6 +1276,7 @@ def __init__(
12511276
rtol = 1e-5,
12521277
method = 'midpoint'
12531278
),
1279+
eps = 5e-2
12541280
):
12551281
super().__init__()
12561282

@@ -1453,6 +1479,11 @@ def __init__(
14531479
self.has_recon_loss = reconstruction_loss_weight > 0.
14541480
self.reconstruction_loss_weight = reconstruction_loss_weight
14551481

1482+
# whether model is predicting clean
1483+
1484+
self.pred_clean = pred_clean
1485+
self.eps = eps
1486+
14561487
# flow sampling related
14571488

14581489
self.odeint_fn = partial(odeint, **odeint_kwargs)
@@ -2001,6 +2032,13 @@ def forward_modality(
20012032
else:
20022033
noised_tokens = tokens
20032034

2035+
# save the noised and times
2036+
2037+
model_output_to_flow = identity
2038+
2039+
if self.pred_clean:
2040+
model_output_to_flow = get_model_output_to_flow_fn(noised_tokens, times)
2041+
20042042
# from latent to model tokens
20052043

20062044
noised_tokens = mod.latent_to_model(noised_tokens)
@@ -2034,7 +2072,9 @@ def forward_modality(
20342072

20352073
embed = inverse_pack_axial_dims(embed)
20362074

2037-
pred_flow = mod.model_to_latent(embed)
2075+
model_output = mod.model_to_latent(embed)
2076+
2077+
pred_flow = model_output_to_flow(model_output)
20382078

20392079
if not return_loss:
20402080
return pred_flow
@@ -2475,6 +2515,14 @@ def inner(pred_flow):
24752515

24762516
inverse_fn = model_to_pred_flow(batch_index, offset + precede_modality_tokens, modality_length, unpack_modality_shape)
24772517

2518+
# maybe decorate the function if model output is predicting clean
2519+
2520+
if self.pred_clean:
2521+
decorator = get_model_output_to_flow_fn(modality_tensor, modality_time, self.eps, return_decorator = True)
2522+
inverse_fn = decorator(inverse_fn)
2523+
2524+
# store function for extracting flow later
2525+
24782526
get_pred_flows[modality_type].append(inverse_fn)
24792527

24802528
# increment offset

0 commit comments

Comments
 (0)