Skip to content

Commit bf5b4cd

Browse files
committed
make sampling work without time conditioning
1 parent fba69cf commit bf5b4cd

File tree

2 files changed

+95
-18
lines changed

2 files changed

+95
-18
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "transfusion-pytorch"
3 -	version = "0.0.25"
3 +	version = "0.0.26"
44
description = "Transfusion in Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

transfusion_pytorch/transfusion.py

Lines changed: 94 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,15 @@ def divisible_by(num, den):
7171
def cast_tuple(t, length = 1):
7272
return t if isinstance(t, tuple) else ((t,) * length)
7373

74+
def eval_decorator(fn):
75+
def inner(self, *args, **kwargs):
76+
was_training = self.training
77+
self.eval()
78+
out = fn(self, *args, **kwargs)
79+
self.train(was_training)
80+
return out
81+
return inner
82+
7483
# tensor helpers
7584

7685
def l2norm(t):
@@ -593,7 +602,7 @@ def __init__(
593602

594603
# modality start and end ids
595604

596-
self.som_ids, self.eom_ids = som_eom_tensor.tolist()
605+
self.som_ids, self.eom_ids = modality_start_end_tensor.tolist()
597606

598607
# entire "sentence" start and end id
599608

@@ -637,34 +646,93 @@ def device(self):
637646
return next(self.parameters()).device
638647

639648
@torch.no_grad()
649+
@eval_decorator
640650
def sample(
641651
self,
642652
prompt: ModalitySample | None = None,
643653
max_length = 8192,
644654
text_temperature = 1.5,
645655
text_min_p = 0.1,
656+
modality_length = 32, # fix the modality token length for now, but this will be determined by the language model in a metadata tag
657+
modality_steps = 16
646658
) -> ModalitySample:
647659

648-
was_training = self.training
649-
self.eval()
660+
device = self.device
650661

651-
seq = tensor([self.sos_id], device = self.device)
662+
init_text_seq = tensor([self.sos_id], device = self.device)
663+
modality_sample = [init_text_seq]
652664

653-
for _ in tqdm(range(max_length)):
654-
logits = self.forward([[seq]], return_loss = False)
655-
logits = logits[0][-1]
665+
curr_length = 0
666+
curr_modality_id = None
667+
is_decoding_text = True # starts off with text decoding, and alternates with modalities depending on [som] tokens detected
656668

657-
logits = min_p_filter(logits, min_p = text_min_p)
658-
probs = (logits / text_temperature).softmax(dim = -1)
669+
with tqdm(total = max_length) as pbar:
659670

660-
sampled = torch.multinomial(probs, 1)
661-
seq = torch.cat((seq, sampled), dim = -1)
671+
while curr_length <= max_length:
662672

663-
if sampled.item() == self.eos_id:
664-
break
673+
if is_decoding_text:
674+
*_, seq = modality_sample
665675

666-
self.train(was_training)
667-
return [seq]
676+
logits = self.forward([modality_sample], return_loss = False)
677+
logits = logits[0][-1]
678+
679+
logits = min_p_filter(logits, min_p = text_min_p)
680+
probs = (logits / text_temperature).softmax(dim = -1)
681+
682+
sampled = torch.multinomial(probs, 1)
683+
684+
seq = torch.cat((seq, sampled), dim = -1)
685+
modality_sample[-1] = seq
686+
687+
pbar.update(1)
688+
curr_length += 1
689+
690+
sampled_token_id = sampled.item()
691+
692+
if sampled_token_id == self.eos_id:
693+
break
694+
695+
if sampled_token_id in self.som_ids:
696+
curr_modality_id = self.som_ids.index(sampled_token_id)
697+
is_decoding_text = False
698+
699+
else:
700+
assert exists(curr_modality_id)
701+
702+
latent_dim = self.dim_latents[curr_modality_id]
703+
noise = torch.randn((modality_length, latent_dim), device = device)
704+
705+
def ode_step_fn(t, denoised):
706+
embeds = self.forward([[*modality_sample, denoised]], return_loss = False, return_embed = True)
707+
708+
to_flow_pred = self.model_to_latent_preds[curr_modality_id]
709+
flow = to_flow_pred(embeds)
710+
711+
return flow[0, -modality_length:]
712+
713+
times = torch.linspace(0, 1, modality_steps, device = device)
714+
715+
trajectory = self.odeint_fn(ode_step_fn, noise, times)
716+
717+
# add the sampled modality tokens
718+
719+
sampled_modality = trajectory[-1]
720+
modality_sample.append((curr_modality_id, sampled_modality))
721+
722+
# add the appropriate [eom]
723+
724+
eom_id = self.eom_ids[curr_modality_id]
725+
modality_sample.append(tensor([eom_id], device = device))
726+
727+
# back to decoding text
728+
729+
pbar.update(modality_length)
730+
curr_length += modality_length
731+
732+
curr_modality_id = None
733+
is_decoding_text = True
734+
735+
return modality_sample
668736

669737
def forward(
670738
self,
@@ -675,12 +743,16 @@ def forward(
675743
None
676744
) = None,
677745
return_loss = True,
678-
return_breakdown = False
746+
return_breakdown = False,
747+
return_embed = False
679748
) -> (
680749
Float['b n l'] |
750+
Float['b n d'] |
681751
Float[''] |
682752
tuple[Float[''], LossBreakdown]
683753
):
754+
return_loss &= not return_embed
755+
684756
device = self.device
685757

686758
# add "sentence" start and end tokens when training
@@ -717,7 +789,7 @@ def forward(
717789
modality_type, modality_tensor = modality
718790

719791
assert 0 <= modality_type < self.num_modalities, f'received a modality index that is out of range. only {self.num_modalities} modalities specified'
720-
assert self.dim_latents[modality_type] == modality_tensor.shape[-1], 'mismatch for modality latent dimension - expected {self.dim_latents[modality_type]} but received {modality_tensor.shape[-1]}'
792+
assert self.dim_latents[modality_type] == modality_tensor.shape[-1], f'mismatch for modality latent dimension - expected {self.dim_latents[modality_type]} but received {modality_tensor.shape[-1]}'
721793

722794
length = modality_tensor.shape[0]
723795

@@ -874,6 +946,11 @@ def forward(
874946
modality_positions = modality_positions
875947
)
876948

949+
# early return for embedding for decoding modality
950+
951+
if return_embed:
952+
return embed
953+
877954
# text unembedding
878955

879956
text_logits = self.to_text_logits(embed)

0 commit comments

Comments (0)