Skip to content

Commit 09f0b51

Browse files
committed
final touch-up to free transformer for the day
1 parent d41cffd commit 09f0b51

File tree

4 files changed

+13
-5
lines changed

4 files changed

+13
-5
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "x-transformers"
3-
version = "2.11.12"
3+
version = "2.11.14"
44
description = "X-Transformers"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

tests/test_x_transformers.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1437,6 +1437,11 @@ def test_free(
14371437

14381438
assert aux_loss.numel() == 1
14391439

1440+
rand_indices = torch.randint(0, 2 ** 8, ())
1441+
generated = model.generate(seq[:, :1], 32, latents = rand_indices)
1442+
1443+
assert generated.shape == (1, 32)
1444+
14401445
def test_kv_input_residual():
14411446
attn = Decoder(
14421447
dim = 256,

train_free.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ def decode_tokens(tokens):
6363
latent_bits = LATENT_BITS
6464
).cuda()
6565

66-
rand_index = torch.randint(0, 2 ** LATENT_BITS, ())
67-
latents = F.one_hot(rand_index, 2 ** LATENT_BITS).float().cuda()
66+
one_hot_indices = torch.randint(0, 2 ** LATENT_BITS, ())
6867

6968
# prepare enwik8 data
7069

@@ -126,9 +125,9 @@ def __len__(self):
126125
sample = model.generate(
127126
prompts = inp,
128127
seq_len = GENERATE_LENGTH,
129-
latents = latents
128+
latents = one_hot_indices
130129
)
131130

132131
output_str = decode_tokens(sample)
133132

134-
print(f'\n\nlatent {rand_index.tolist()} - ', output_str)
133+
print(f'\n\nlatent {one_hot_indices.tolist()} - ', output_str)

x_transformers/free_transformer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,10 @@ def generate(
282282
if not is_tensor(latents):
283283
latents = tensor(latents, device = self.device)
284284

285+
if latents.dtype in (torch.int, torch.long):
286+
# if given as indices
287+
latents = F.one_hot(latents, self.binary_mapper.num_codes).float()
288+
285289
if latents.ndim == 1: # repeat latents
286290
latents = repeat(latents, 'd -> b 1 d', b = batch)
287291
elif latents.ndim == 2:

0 commit comments

Comments (0)