From 20b7c2a6c207c382e44d8cb0f302d8a618bc8d53 Mon Sep 17 00:00:00 2001 From: alvinttang Date: Sat, 25 Apr 2026 22:30:18 +0800 Subject: [PATCH] fix(t5): correct streaming detokenization for byte-level BPE and SentencePiece MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The streaming generation loop in `t5.py` was decoding each new token in isolation via `convert_ids_to_tokens` and only stripping the SentencePiece `▁` marker. This is incorrect for tokenizers whose ids do not map 1:1 to characters: byte-level BPE tokenizers (e.g. `Salesforce/codet5p-220m`, which uses `RobertaTokenizerFast`) produced raw `Ġ`/`Ċ` markers in the output, and SentencePiece byte-fallback would leak hex byte tokens. Issue #1021 demonstrated: $ python3 t5.py --model codet5p-220m --prompt 'def print_hello_world():' ĊĠĠĠĠprintĠ"HelloĠWorld"ĊĊ Fix: introduce a small streaming-state API on `Tokenizer` that delegates to the underlying HuggingFace tokenizer's own `decode` over the full prefix and yields the new substring on each step. This is the simplest correct fix suggested by the maintainer in the issue thread. Adds a regression test that compares streamed output to full-prefix decode for both `t5-small` (SentencePiece) and `Salesforce/codet5p-220m` (byte-level BPE), and asserts no `Ġ`/`Ċ` markers leak. --- t5/t5.py | 26 ++++++++++++--- t5/test_tokenizer_stream.py | 63 +++++++++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+), 4 deletions(-) create mode 100644 t5/test_tokenizer_stream.py diff --git a/t5/t5.py b/t5/t5.py index 04a0da8c5..c0e22e208 100644 --- a/t5/t5.py +++ b/t5/t5.py @@ -37,9 +37,26 @@ def encode(self, s: str) -> mx.array: )["input_ids"] ) - def decode(self, t: List[int], with_sep: bool = True) -> str: - tokens = self._tokenizer.convert_ids_to_tokens(t) - return "".join(t.replace("▁", " " if with_sep else "") for t in tokens) + def decode(self, t: List[int]) -> str: + return self._tokenizer.decode(t, skip_special_tokens=True) + + def new_stream_state(self) -> dict: + """Return state for incremental streaming detokenization. + + Per-token detokenization is unsafe for most HuggingFace tokenizers + (byte-level BPE merges and SentencePiece byte-fallback both span + multiple ids), so the streaming loop must hand each new id to + ``stream_decode`` along with this state object. + """ + return {"ids": [], "text": ""} + + def stream_decode(self, state: dict, token_id: int) -> str: + """Append ``token_id`` to ``state`` and return the new text fragment.""" + state["ids"].append(int(token_id)) + text = self._tokenizer.decode(state["ids"], skip_special_tokens=True) + delta = text[len(state["text"]) :] + state["text"] = text + return delta def _relative_position_bucket( @@ -505,13 +522,14 @@ def sample(logits): print("Input: ", args.prompt, flush=True) start = perf_counter_ns() + stream_state = tokenizer.new_stream_state() for token, n_tokens in zip( generate(args.prompt, model, tokenizer, args.temp), range(args.max_tokens) ): if token.item() == tokenizer.eos_id: break print( - tokenizer.decode([token.item()], with_sep=n_tokens > 0), + tokenizer.stream_decode(stream_state, token.item()), end="", flush=True, ) diff --git a/t5/test_tokenizer_stream.py b/t5/test_tokenizer_stream.py new file mode 100644 index 000000000..d86829dd5 --- /dev/null +++ b/t5/test_tokenizer_stream.py @@ -0,0 +1,63 @@ +"""Tests for streaming detokenization in t5.py. + +Reproduces and guards against the bug reported in +https://github.com/ml-explore/mlx-examples/issues/1021 where streaming +generation emitted raw subword markers (e.g. ``Ġ``/``Ċ`` for the byte-level +BPE tokenizer used by ``Salesforce/codet5p-220m``) instead of plain text. + +The previous implementation called ``convert_ids_to_tokens`` per generated +token and only stripped the SentencePiece ``▁`` marker, which is wrong for +byte-level BPE tokenizers and for SentencePiece tokens that use byte-fallback. +The fix decodes the running prefix with the HuggingFace tokenizer's own +``decode`` and yields the new substring on each step. +""" + +from types import SimpleNamespace + +import pytest + +from t5 import Tokenizer + + +def _make_tokenizer(model_name: str, decoder_start_id: int = 0) -> Tokenizer: + config = SimpleNamespace(decoder_start_token_id=decoder_start_id) + return Tokenizer(config, model_name) + + +def _stream_decode(tokenizer: Tokenizer, ids): + """Mimic the streaming loop in ``t5.py``'s ``__main__``.""" + out = [] + state = tokenizer.new_stream_state() + for tid in ids: + out.append(tokenizer.stream_decode(state, int(tid))) + return "".join(out) + + +@pytest.mark.parametrize( + "model_name,prompt", + [ + ("t5-small", "translate English to German: That is good."), + ("Salesforce/codet5p-220m", "def print_hello_world():"), + ], +) +def test_stream_decode_matches_full_decode(model_name: str, prompt: str) -> None: + tokenizer = _make_tokenizer(model_name) + ids = tokenizer._tokenizer(prompt, return_tensors="np")["input_ids"][0].tolist() + + expected = tokenizer._tokenizer.decode(ids, skip_special_tokens=True) + streamed = _stream_decode(tokenizer, ids) + + assert streamed == expected + + +def test_stream_decode_strips_byte_level_bpe_markers() -> None: + """Regression: ``Ġ`` / ``Ċ`` must never leak into streamed output.""" + tokenizer = _make_tokenizer("Salesforce/codet5p-220m") + ids = tokenizer._tokenizer('def hello():\n print("hi")', return_tensors="np")[ + "input_ids" + ][0].tolist() + + streamed = _stream_decode(tokenizer, ids) + + assert "Ġ" not in streamed + assert "Ċ" not in streamed