Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions t5/t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,26 @@ def encode(self, s: str) -> mx.array:
)["input_ids"]
)

def decode(self, t: List[int], with_sep: bool = True) -> str:
    """Turn token ids into text by joining their token strings.

    Each id is mapped to its token string via the wrapped tokenizer's
    ``convert_ids_to_tokens``; every ``▁`` marker in a token is then
    replaced by a space when ``with_sep`` is true, or removed otherwise.
    """
    marker_replacement = " " if with_sep else ""
    pieces = self._tokenizer.convert_ids_to_tokens(t)
    return "".join(piece.replace("▁", marker_replacement) for piece in pieces)
def decode(self, t: List[int]) -> str:
    """Decode token ids ``t`` to text with the wrapped tokenizer.

    Delegates to the wrapped tokenizer's ``decode`` with
    ``skip_special_tokens=True``.
    """
    hf_tokenizer = self._tokenizer
    return hf_tokenizer.decode(t, skip_special_tokens=True)

def new_stream_state(self) -> dict:
    """Create a fresh state object for incremental detokenization.

    Decoding one id at a time is unsafe for most HuggingFace tokenizers:
    byte-level BPE merges and SentencePiece byte-fallback can both span
    several ids. The streaming loop should therefore pass every new id to
    ``stream_decode`` together with the object returned here.

    Returns:
        A dict with an empty ``"ids"`` list and an empty ``"text"`` string.
        A new dict (and a new list) is built on every call.
    """
    return dict(ids=[], text="")

def stream_decode(self, state: dict, token_id: int) -> str:
    """Append ``token_id`` to ``state`` and return the new text fragment.

    The whole id prefix accumulated in ``state["ids"]`` is re-decoded on
    every call, and only the part that extends the already-emitted text
    (``state["text"]``) is returned.

    Args:
        state: Mutable streaming state with an ``"ids"`` list and a
            ``"text"`` string (the text emitted so far). Updated in place.
        token_id: The newly generated token id; coerced with ``int``.

    Returns:
        The text to emit for this step. This is the empty string while
        the decoded prefix ends in U+FFFD, i.e. while a multi-byte
        character is still incomplete.
    """
    state["ids"].append(int(token_id))
    text = self._tokenizer.decode(state["ids"], skip_special_tokens=True)

    # A trailing U+FFFD means the last id(s) hold only part of a multi-byte
    # character. Emitting now would print the replacement character, and
    # once the character completes the decoded text is rewritten in place
    # rather than extended, so the real character would never be emitted.
    # Hold the output back and leave ``state["text"]`` untouched so the
    # next complete decode yields the whole pending fragment.
    if text.endswith("\ufffd"):
        return ""

    delta = text[len(state["text"]):]
    state["text"] = text
    return delta


def _relative_position_bucket(
Expand Down Expand Up @@ -505,13 +522,14 @@ def sample(logits):
print("Input: ", args.prompt, flush=True)

start = perf_counter_ns()
stream_state = tokenizer.new_stream_state()
for token, n_tokens in zip(
generate(args.prompt, model, tokenizer, args.temp), range(args.max_tokens)
):
if token.item() == tokenizer.eos_id:
break
print(
tokenizer.decode([token.item()], with_sep=n_tokens > 0),
tokenizer.stream_decode(stream_state, token.item()),
end="",
flush=True,
)
Expand Down
63 changes: 63 additions & 0 deletions t5/test_tokenizer_stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""Tests for streaming detokenization in t5.py.

Reproduces and guards against the bug reported in
https://github.com/ml-explore/mlx-examples/issues/1021 where streaming
generation emitted raw subword markers (e.g. ``Ġ``/``Ċ`` for the byte-level
BPE tokenizer used by ``Salesforce/codet5p-220m``) instead of plain text.

The previous implementation called ``convert_ids_to_tokens`` per generated
token and only stripped the SentencePiece ``▁`` marker, which is wrong for
byte-level BPE tokenizers and for SentencePiece tokens that use byte-fallback.
The fix decodes the running prefix with the HuggingFace tokenizer's own
``decode`` and yields the new substring on each step.
"""

from types import SimpleNamespace

import pytest

from t5 import Tokenizer


def _make_tokenizer(model_name: str, decoder_start_id: int = 0) -> Tokenizer:
    """Build a ``Tokenizer`` for ``model_name`` from a minimal stand-in config.

    The config is a ``SimpleNamespace`` carrying only
    ``decoder_start_token_id``, set to ``decoder_start_id``.
    """
    return Tokenizer(
        SimpleNamespace(decoder_start_token_id=decoder_start_id),
        model_name,
    )


def _stream_decode(tokenizer: Tokenizer, ids):
    """Mimic the streaming loop in ``t5.py``'s ``__main__``.

    Feeds every id, in order, through ``tokenizer.stream_decode`` using a
    single state object from ``tokenizer.new_stream_state`` and returns the
    concatenation of all emitted fragments.
    """
    state = tokenizer.new_stream_state()
    return "".join(tokenizer.stream_decode(state, int(tid)) for tid in ids)


@pytest.mark.parametrize(
    "model_name,prompt",
    [
        ("t5-small", "translate English to German: That is good."),
        ("Salesforce/codet5p-220m", "def print_hello_world():"),
    ],
)
def test_stream_decode_matches_full_decode(model_name: str, prompt: str) -> None:
    """Streaming the ids one by one must reproduce a single full decode."""
    tokenizer = _make_tokenizer(model_name)
    hf_tokenizer = tokenizer._tokenizer

    # Round-trip: encode the prompt, then decode it both ways.
    encoded = hf_tokenizer(prompt, return_tensors="np")
    ids = encoded["input_ids"][0].tolist()

    expected = hf_tokenizer.decode(ids, skip_special_tokens=True)
    streamed = _stream_decode(tokenizer, ids)

    assert streamed == expected


def test_stream_decode_strips_byte_level_bpe_markers() -> None:
    """Regression: ``Ġ`` / ``Ċ`` must never leak into streamed output."""
    tokenizer = _make_tokenizer("Salesforce/codet5p-220m")
    encoded = tokenizer._tokenizer('def hello():\n print("hi")', return_tensors="np")
    ids = encoded["input_ids"][0].tolist()

    streamed = _stream_decode(tokenizer, ids)

    # Raw byte-level BPE markers for space and newline respectively.
    for marker in ("Ġ", "Ċ"):
        assert marker not in streamed