Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 43 additions & 99 deletions renderers/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from __future__ import annotations

import contextlib
import enum
import io
import logging
import queue
import threading
Expand Down Expand Up @@ -1164,7 +1162,7 @@ def _model_has_vision_config(model_name: str) -> bool:


# Models for which ``fastokens`` is known to diverge from vanilla
# ``transformers.AutoTokenizer`` and therefore must NOT be patched.
# ``transformers.AutoTokenizer`` and therefore must NOT be adapted.
# Empirical audit ran each entry of ``MODEL_RENDERER_MAP`` through both
# backends. The entries below fail to load under fastokens (DeepSeek-V3
# family — Metaspace pretokenizer not yet implemented).
Expand All @@ -1182,7 +1180,7 @@ def _model_has_vision_config(model_name: str) -> bool:
)


_FASTOKENS_PATCH_LOCK = threading.Lock()
_FASTOKENS_ANNOUNCE_LOCK = threading.Lock()
_FASTOKENS_ANNOUNCED = False


Expand Down Expand Up @@ -1222,46 +1220,27 @@ def _preserve_requested_tokenizer_name(
return tokenizer


def _patched_load(model_name_or_path: str, **kwargs):
"""Run ``AutoTokenizer.from_pretrained`` with fastokens patched in
process-locally — patch around the load, unpatch right after.

fastokens captures the loaded backend on a per-tokenizer basis, so
after we unpatch the returned tokenizer object continues to use
fastokens for ``encode``/``decode`` while subsequent
``AutoTokenizer.from_pretrained`` calls (outside our control) go
back to vanilla. This keeps the global side effect minimal.

fastokens itself prints ``[fastokens] patch_transformers: ...`` to
stdout on every patch/unpatch call. Building a pool of size N would
therefore emit ~N lines (more under thread contention, where some
threads see ``already patched``). We swallow those prints under a
lock — ``contextlib.redirect_stdout`` swaps ``sys.stdout``
process-wide, so the lock keeps unrelated stdout writes from other
threads from disappearing into our buffer. The patch/unpatch calls
are cheap; only the brief patch+unpatch is serialized, the actual
``from_pretrained`` still runs concurrently across pool slots. A
single ``logger.info`` is emitted on the first patch so the fast
path is still discoverable in logs.
"""
import fastokens
def _adapt_tokenizer_with_fastokens(tokenizer):
"""Replace only one fully loaded tokenizer's backend with fastokens."""
# This is the same shim used by patch_transformers, scoped to one object.
from fastokens._compat import _TokenizerShim

global _FASTOKENS_ANNOUNCED

with _FASTOKENS_PATCH_LOCK:
with contextlib.redirect_stdout(io.StringIO()):
fastokens.patch_transformers()
backend = getattr(tokenizer, "_tokenizer", None)
if backend is None:
raise TypeError(
f"{type(tokenizer).__name__} has no fast tokenizer backend to adapt"
)
tokenizer._tokenizer = _TokenizerShim(backend)

with _FASTOKENS_ANNOUNCE_LOCK:
if not _FASTOKENS_ANNOUNCED:
logger.info(
"fastokens enabled — tokenizers load through the Rust BPE fast path (~10x encode speedup)."
)
_FASTOKENS_ANNOUNCED = True
try:
return _load_tokenizer_via_auto(model_name_or_path, **kwargs)
finally:
with _FASTOKENS_PATCH_LOCK:
with contextlib.redirect_stdout(io.StringIO()):
fastokens.unpatch_transformers()
return tokenizer


def _load_fast_tokenizer_directly(
Expand Down Expand Up @@ -1335,23 +1314,21 @@ def load_tokenizer(
``trust_remote_code=True`` AND a pinned ``revision=<sha>`` so
transformers only executes the reviewed commit's tokenizer Python.

**Performance** — ``use_fastokens=True`` (default) routes the load
through ``fastokens.patch_transformers()`` so the resulting tokenizer
encodes ~10x faster than vanilla ``tokenizers``. The patch is
bracketed: it's applied before ``from_pretrained`` and removed
immediately after, so global ``AutoTokenizer.from_pretrained`` calls
elsewhere in the user's process are not affected.
**Performance** — ``use_fastokens=True`` (default) replaces only the
returned tokenizer's backend with the fastokens compatibility shim, so it
encodes ~10x faster than vanilla ``tokenizers`` without patching
Transformers process-wide during the load.

Models in ``FASTOKENS_INCOMPATIBLE`` (DeepSeek-V3 family) skip the
patch — fastokens currently fails to load them. Pass
adaptation — fastokens currently fails to load them. Pass
``use_fastokens=False`` to force the vanilla backend for any other
model.

Unknown / fine-tuned model paths fall through to
``trust_remote_code=False`` and the patched-load fast path. If
fastokens raises during the patched load (e.g. an unknown
pre-tokenizer type), we automatically retry with the vanilla
backend and emit an INFO log.
``trust_remote_code=False`` and the fastokens adaptation path. If
fastokens rejects the loaded backend (e.g. an unknown pre-tokenizer
type), we keep the already-loaded vanilla tokenizer (no second load)
and emit an INFO log.

``AutoTokenizer.from_pretrained`` eagerly builds the model config to
resolve the tokenizer class. If that construction raises on a
Expand All @@ -1368,26 +1345,19 @@ def load_tokenizer(
load_name_or_path = _tokenizer_source_for(model_name_or_path)
kwargs = _tokenizer_load_kwargs(load_name_or_path)

if not use_fastokens or load_name_or_path in FASTOKENS_INCOMPATIBLE:
tok = _load_tokenizer_via_auto(load_name_or_path, **kwargs)
return _preserve_requested_tokenizer_name(
tok,
requested_name_or_path=model_name_or_path,
loaded_name_or_path=load_name_or_path,
)

try:
tok = _patched_load(load_name_or_path, **kwargs)
except Exception as exc:
logger.info(
"fastokens could not load %r (%s: %s); falling back to vanilla "
"AutoTokenizer. Add this model to FASTOKENS_INCOMPATIBLE in "
"renderers.base to suppress the retry.",
load_name_or_path,
type(exc).__name__,
str(exc)[:160],
)
tok = _load_tokenizer_via_auto(load_name_or_path, **kwargs)
tok = _load_tokenizer_via_auto(load_name_or_path, **kwargs)
if use_fastokens and load_name_or_path not in FASTOKENS_INCOMPATIBLE:
try:
tok = _adapt_tokenizer_with_fastokens(tok)
except Exception as exc:
logger.info(
"fastokens could not adapt %r (%s: %s); using vanilla "
"AutoTokenizer. Add this model to FASTOKENS_INCOMPATIBLE in "
"renderers.base to suppress this message.",
load_name_or_path,
type(exc).__name__,
str(exc)[:160],
)

return _preserve_requested_tokenizer_name(
tok,
Expand Down Expand Up @@ -1721,8 +1691,8 @@ def trim_to_turn_close(
# Per-model offset-aware tokenizer cache. ``attribute_text_segments``
# uses the fast HuggingFace tokenizer's ``offset_mapping`` to attribute
# each token to its source text segment under one BPE pass. Fastokens
# (the Rust BPE we patch in by default for ~10x faster encode) does not
# track character offsets — the patched tokenizer's
# (the Rust BPE backend we install by default for ~10x faster encode) does not
# track character offsets — the adapted tokenizer's
# ``return_offsets_mapping=True`` raises ``NotImplementedError``. So we
# keep a parallel vanilla tokenizer per model purely for offset queries.
# Memory cost is one extra tokenizer per *unique* model name across all
Expand Down Expand Up @@ -1773,44 +1743,18 @@ def _has_offsets(tok) -> bool:
except (NotImplementedError, ValueError, TypeError):
return False

# We want HF's Rust tokenizer with offset tracking, not the fastokens
# shim. The shim is installed by a *process-global* monkeypatch that
# ``load_tokenizer`` toggles per pool-slot load, so a plain reload here
# can race a concurrent slot's open patch window and silently pick up
# the offset-less shim (then get cached, poisoning the process). So:
# load, verify offsets, and if missing, reload with the patch forced
# off — serialized against pool patch/unpatch via ``_FASTOKENS_PATCH_LOCK``
# so no concurrent window can swap the shim back in mid-load — then
# restore the prior patch state. Never cache a non-offset tokenizer.
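# This path deliberately loads through ``_load_tokenizer_via_auto`` without
# adaptation, because fastokens adaptation is scoped to ``load_tokenizer``'s
# returned object.
# NOTE(review): the result is only vanilla if nothing else in the process has
# called ``fastokens.patch_transformers()`` — confirm whether a globally
# patched host needs an explicit unpatched reload here.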
offset_tok = _load_tokenizer_via_auto(load_name_or_path, **kwargs)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Restore unpatched offset reloads under global fastokens

When the host process has already called fastokens.patch_transformers() (for example because another serving stack enables it globally), this reload is not vanilla: _load_tokenizer_via_auto() returns another fastokens shim, _has_offsets() stays false, and hand-coded renderers that call attribute_text_segments() raise instead of using the offset cache. The deleted fallback used to temporarily unpatch in exactly this case, so this path still needs to force the offset tokenizer reload through unpatched Transformers.

Useful? React with 👍 / 👎.

offset_tok = _preserve_requested_tokenizer_name(
offset_tok,
requested_name_or_path=name_or_path,
loaded_name_or_path=load_name_or_path,
)
if not _has_offsets(offset_tok):
import fastokens

with _FASTOKENS_PATCH_LOCK:
was_patched = bool(getattr(fastokens, "_patched", False))
if was_patched:
with contextlib.redirect_stdout(io.StringIO()):
fastokens.unpatch_transformers()
try:
offset_tok = _load_tokenizer_via_auto(load_name_or_path, **kwargs)
offset_tok = _preserve_requested_tokenizer_name(
offset_tok,
requested_name_or_path=name_or_path,
loaded_name_or_path=load_name_or_path,
)
finally:
if was_patched:
with contextlib.redirect_stdout(io.StringIO()):
fastokens.patch_transformers()
if not _has_offsets(offset_tok):
raise RuntimeError(
f"Could not load an offset-capable tokenizer for {name_or_path!r}: "
"offset_mapping is unavailable even with the fastokens patch off. "
"offset_mapping is unavailable from vanilla Transformers. "
"Hand-coded renderers require a fast tokenizer for body/scaffold "
"attribution."
)
Expand Down Expand Up @@ -1840,7 +1784,7 @@ def attribute_text_segments(
the most recently entered segment.

Requires a HuggingFace fast tokenizer with offset tracking. The
``fastokens`` patch ``load_tokenizer`` applies by default does
The ``fastokens`` backend ``load_tokenizer`` installs by default does
**not** track offsets — when that's the case we transparently load
a vanilla offset-capable tokenizer for the same model and cache it
(see :func:`_get_offset_tokenizer`). Hand-coded renderers are only
Expand Down
94 changes: 61 additions & 33 deletions tests/test_load_tokenizer_fastokens.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Coverage for the fastokens fast-path in ``renderers.base.load_tokenizer``.

``load_tokenizer`` defaults to routing every supported model through
``fastokens.patch_transformers()`` for ~10x faster encode. Models in
``FASTOKENS_INCOMPATIBLE`` skip the patch (DeepSeek's Metaspace
``load_tokenizer`` defaults to adapting every supported model's returned
backend with fastokens for ~10x faster encode. Models in
``FASTOKENS_INCOMPATIBLE`` skip adaptation (DeepSeek's Metaspace
pretokenizer isn't supported). Callers can opt out per-call with
``use_fastokens=False``.

Expand All @@ -16,16 +16,18 @@
3. With ``use_fastokens=False``, the resulting tokenizer is vanilla.
4. For incompat models, the fast path is silently skipped and the
tokenizer still loads + encodes correctly.
5. The fastokens patch is removed immediately after the load so it
doesn't leak into the caller's process — subsequent
``AutoTokenizer.from_pretrained`` calls outside ``load_tokenizer``
use vanilla.
5. Fastokens adaptation is scoped to the returned tokenizer, so concurrent
and subsequent ``AutoTokenizer.from_pretrained`` calls stay vanilla.
"""

from __future__ import annotations

import concurrent.futures
import threading

import pytest
from transformers import AutoTokenizer
from tokenizers import Tokenizer, models
from transformers import AutoTokenizer, PreTrainedTokenizerFast

from renderers.base import (
FASTOKENS_INCOMPATIBLE,
Expand Down Expand Up @@ -93,13 +95,16 @@ def test_fast_and_vanilla_encode_identically_on_compatible_model():
" ".join([f"word_{i}" for i in range(50)]),
]
for s in samples:
assert fast.encode(s, add_special_tokens=False) == vanilla.encode(
s, add_special_tokens=False
), f"encode diverged on {s!r}"
fast_ids = fast.encode(s, add_special_tokens=False)
vanilla_ids = vanilla.encode(s, add_special_tokens=False)
assert fast_ids == vanilla_ids, f"encode diverged on {s!r}"
assert fast.decode(fast_ids) == vanilla.decode(vanilla_ids), (
f"decode diverged on {s!r}"
)


# ---------------------------------------------------------------------------
# Denylist: incompat models silently skip the patch and still load.
# Denylist: incompat models silently skip adaptation and still load.
# ---------------------------------------------------------------------------


Expand All @@ -120,49 +125,78 @@ def test_incompat_model_loads_via_vanilla_backend(model):
pytest.skip(f"{model}: repo unreachable in this env ({e})")
tok = load_tokenizer(model)
assert "Shim" not in _backend_class_name(tok), (
f"{model}: should NOT have been patched; got {_backend_class_name(tok)!r}"
f"{model}: should NOT have been adapted; got {_backend_class_name(tok)!r}"
)
# And it still encodes.
ids = tok.encode("hello", add_special_tokens=False)
assert len(ids) > 0


# ---------------------------------------------------------------------------
# Patch must not leak: AutoTokenizer.from_pretrained calls OUTSIDE
# load_tokenizer should still produce a vanilla tokenizer.
# Adaptation must not leak: AutoTokenizer.from_pretrained calls OUTSIDE
# load_tokenizer should always produce a vanilla tokenizer.
# ---------------------------------------------------------------------------


def test_patch_is_unloaded_after_call():
"""``load_tokenizer`` brackets the fastokens patch. After it returns
a fastokens-shimmed tokenizer, a fresh ``AutoTokenizer.from_pretrained``
call must NOT pick up the patch — the user's process stays clean."""
def test_fastokens_is_scoped_to_loaded_tokenizer():
"""A fresh ``AutoTokenizer`` call stays vanilla after fast adaptation."""
fast = load_tokenizer(_FAST_MODEL)
assert "Shim" in _backend_class_name(fast), "preconditions: fast path active"

# Now call AutoTokenizer.from_pretrained directly. It MUST be vanilla.
direct = AutoTokenizer.from_pretrained(_FAST_MODEL, trust_remote_code=False)
assert "Shim" not in _backend_class_name(direct), (
f"fastokens patch leaked into user-side AutoTokenizer call: "
f"fastokens leaked into user-side AutoTokenizer call: "
f"got {_backend_class_name(direct)!r}"
)


def test_fastokens_load_does_not_patch_transformers_concurrently(monkeypatch, tmp_path):
    """A slow renderer load must not expose fastokens to unrelated callers.

    The test parks a ``load_tokenizer`` call in a worker thread while it is
    still inside ``_load_tokenizer_via_auto`` (replaced by ``slow_load``),
    then performs a direct ``AutoTokenizer.from_pretrained`` from the main
    thread and checks that the directly loaded tokenizer is not shimmed.
    Once the worker is released, the tokenizer returned by
    ``load_tokenizer`` is expected to report a ``Shim`` backend.
    """
    import renderers.base as rb

    # Build a tiny two-token BPE tokenizer and save it under ``tmp_path`` so
    # both loads below read from local files.
    backend = Tokenizer(models.BPE({"[UNK]": 0, "hello": 1}, [], unk_token="[UNK]"))
    PreTrainedTokenizerFast(
        tokenizer_object=backend, unk_token="[UNK]"
    ).save_pretrained(tmp_path)

    # ``started`` is set by the worker once it is inside the load;
    # ``release`` is set by the main thread to let the load finish.
    started = threading.Event()
    release = threading.Event()
    # Loaded up front; ``slow_load`` hands this exact object back to
    # ``load_tokenizer``.
    tokenizer = AutoTokenizer.from_pretrained(tmp_path)

    def slow_load(*args, **kwargs):
        # Signal that the load is in progress, then block until the main
        # thread has finished its direct ``from_pretrained`` check.
        started.set()
        assert release.wait(timeout=5)
        return tokenizer

    monkeypatch.setattr(rb, "_load_tokenizer_via_auto", slow_load)

    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        loaded = executor.submit(load_tokenizer, "test-model")
        # Wait until the worker is parked inside ``slow_load``.
        assert started.wait(timeout=5)
        try:
            # While the renderer load is still open, a direct load must not
            # come back with a ``Shim`` backend.
            direct = AutoTokenizer.from_pretrained(tmp_path)
            assert "Shim" not in _backend_class_name(direct)
        finally:
            # Always unblock the worker, even if the check above failed, so
            # the executor can shut down.
            release.set()

    # The tokenizer returned by ``load_tokenizer`` itself should be adapted.
    # NOTE(review): relies on ``_backend_class_name`` (defined elsewhere in
    # this test module) reporting the backend's class name — confirm.
    assert "Shim" in _backend_class_name(loaded.result())


# ---------------------------------------------------------------------------
# Failure-mode fallback: if fastokens raises during the patched load,
# Failure-mode fallback: if fastokens raises during per-tokenizer adaptation,
# load_tokenizer falls back to vanilla without surfacing the error.
# ---------------------------------------------------------------------------


def test_fallback_on_fastokens_load_error(monkeypatch):
"""Simulate fastokens raising during patched load — load_tokenizer
should fall back to vanilla and return a working tokenizer."""
def test_fallback_on_fastokens_adaptation_error(monkeypatch):
"""An adaptation error returns the already-loaded vanilla tokenizer."""
import renderers.base as rb

def _boom(*args, **kwargs):
raise ValueError("simulated fastokens failure: unsupported pre-tokenizer")

monkeypatch.setattr(rb, "_patched_load", _boom)
monkeypatch.setattr(rb, "_adapt_tokenizer_with_fastokens", _boom)

tok = load_tokenizer(_FAST_MODEL) # default use_fastokens=True
# The vanilla fallback ran — backend is not a fastokens shim.
Expand All @@ -172,18 +206,12 @@ def _boom(*args, **kwargs):


# ---------------------------------------------------------------------------
# Print suppression: fastokens itself prints "[fastokens]
# patch_transformers: ..." on every patch/unpatch call. Building a
# RendererPool of size N would emit ~N lines (the pool factory calls
# load_tokenizer once per slot). load_tokenizer swallows that stdout
# chatter and emits a single INFO log on the first patch instead.
# Fastokens adaptation emits one INFO log per process, not once per pool slot.
# ---------------------------------------------------------------------------


def test_no_fastokens_stdout_chatter(capsys, caplog):
"""``load_tokenizer`` must not leak ``[fastokens]`` prints onto
stdout, and must emit exactly one INFO log per process announcing
the fast path (not once per call)."""
"""Fast adaptation stays quiet and announces its path once per process."""
import logging

import renderers.base as rb
Expand Down
Loading