
Commit 14b30bc
Parent: 78eaa5e

handle rotary and polar positional embeddings with caching when attention layers are not wrapped
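For context, the scenario this commit addresses is a bare Decoder (the attention layers used directly, without TransformerWrapper) decoding incrementally with a cache while using rotary or polar positional embeddings. Below is a minimal sketch of that usage, adapted from the test added in this commit; the shapes and tolerance are illustrative, and the assert expresses the parity the new test checks rather than a guaranteed bound.

import torch
from x_transformers import Decoder

# a bare decoder with rotary positional embeddings, not wrapped in TransformerWrapper
model = Decoder(dim = 128, depth = 1, heads = 4, rotary_pos_emb = True).eval()

seq = torch.randn(2, 15, 128)   # (batch, length, dim) of pre-embedded inputs

# full parallel pass over all 15 positions
parallel_out = model(seq)

# incremental pass: prime the cache with a 10 position prompt, then feed one position at a time
out, cache = model(seq[:, :10], return_hiddens = True)

for i in range(10, 15):
    out, cache = model(seq[:, i:i + 1], cache = cache, return_hiddens = True)

# with this commit, the cached single-step output for the last position matches the parallel pass
assert torch.allclose(out, parallel_out[:, -1:], atol = 1e-5)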

File tree

3 files changed (+50, -3 lines)

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.15.2"
+version = "2.16.0"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

tests/test_x_transformers.py

Lines changed: 42 additions & 1 deletion

@@ -1533,7 +1533,7 @@ def test_seq_start_pos_parity():
         input_not_include_cache = True,
         attn_layers = Decoder(
             dim = 32,
-            depth = 2
+            depth = 2,
         )
     )

@@ -1568,3 +1568,44 @@ def test_seq_start_pos_parity():
     is_not_masked = torch.arange(seq_len) >= seq_start_pos[:, None]

     assert torch.allclose(parallel_logits[is_not_masked], seq_logits[is_not_masked], atol = 1e-5)
+
+@param('pos_emb_type', ('rotary', 'polar'))
+def test_pos_emb_parity(pos_emb_type):
+    pos_emb_kwargs = {f'{pos_emb_type}_pos_emb': True}
+
+    model = Decoder(
+        dim = 128,
+        depth = 1,
+        heads = 4,
+        **pos_emb_kwargs
+    )
+
+    model.eval()
+
+    # parallel
+
+    seq = torch.randn(2, 15, 128)
+
+    parallel_logits = model(seq)
+
+    # prompt pass
+
+    prompt = seq[:, :10]
+    cache = None
+    all_seq_logits = []
+
+    logits, cache = model(prompt, cache = cache, return_hiddens = True)
+    all_seq_logits.append(logits[:, -1:])
+
+    # sequential
+
+    for i in range(4):
+        input_embeds = seq[:, 10 + i : 10 + i + 1]
+        logits, cache = model(input_embeds, cache = cache, return_hiddens = True)
+        all_seq_logits.append(logits[:, -1:])
+
+    seq_logits = torch.cat(all_seq_logits, dim = 1)
+
+    parallel_logits_without_prompt = parallel_logits[:, 9 : 14]
+
+    assert torch.allclose(seq_logits, parallel_logits_without_prompt, atol = 1e-5)
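A note on the indexing in the new test: the 10 position prompt produces its last output at parallel index 9, and the four single-position steps cover indices 10 through 13, which is why the concatenated cached outputs are compared against parallel_logits[:, 9 : 14]. A tiny, purely illustrative check of that bookkeeping:

# hypothetical sanity check of the index alignment used in test_pos_emb_parity
prompt_len, num_steps = 10, 4

kept_positions = [prompt_len - 1] + [prompt_len + i for i in range(num_steps)]
assert kept_positions == [9, 10, 11, 12, 13]   # matches parallel_logits[:, 9 : 14]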

x_transformers/x_transformers.py

Lines changed: 7 additions & 1 deletion

@@ -2710,7 +2710,7 @@ def forward(
         mems = None,
         mem_masks = None,
         seq_start_pos: Tensor | None = None,
-        seq_pos_offset: int = 0,
+        seq_pos_offset = None,
         cache: LayerIntermediates | None = None,
         input_not_include_cache = False,
         cache_age = 1,

@@ -2738,6 +2738,12 @@ def forward(
         assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
         assert not (exists(condition) ^ self.need_condition), 'condition needs to be passed in if using adaptive layernorm or vice versa'

+        # handle seq pos offset if not passed in from wrapper
+        # default to 0, but if a cache is detected, set it appropriately for the relative positional embeddings
+
+        if not exists(seq_pos_offset):
+            seq_pos_offset = cache.cache_length if exists(cache) else 0
+
         # handle condition

         if exists(condition):