Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/python.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,9 @@ jobs:
pip install "xllamacpp>=0.2.0"
if [ "$MODULE" == "metal" ]; then
conda install -c conda-forge "ffmpeg<7"
pip install "mlx>=0.22.0"
pip install mlx-lm
pip install "mlx-vlm>=0.3.4"
pip install "mlx>=0.22.0,<0.31.2"
pip install "mlx-lm<0.31.3"
pip install "mlx-vlm>=0.3.4,<0.5.0"
pip install mlx-whisper
pip install f5-tts-mlx
pip install qwen-vl-utils!=0.0.9
Expand Down
5 changes: 3 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,9 @@ vllm =
sglang =
sglang[srt]>=0.4.2.post4 ; sys_platform=='linux'
mlx =
mlx-lm>=0.21.5 ; sys_platform=='darwin' and platform_machine=='arm64'
mlx-vlm>=0.3.4 ; sys_platform=='darwin' and platform_machine=='arm64'
mlx>=0.22.0,<0.31.2 ; sys_platform=='darwin' and platform_machine=='arm64'
mlx-lm>=0.21.5,<0.31.3 ; sys_platform=='darwin' and platform_machine=='arm64'
mlx-vlm>=0.3.4,<0.5.0 ; sys_platform=='darwin' and platform_machine=='arm64'
mlx-whisper ; sys_platform=='darwin' and platform_machine=='arm64'
f5-tts-mlx ; sys_platform=='darwin' and platform_machine=='arm64'
mlx-audio ; sys_platform=='darwin' and platform_machine=='arm64'
Expand Down
4 changes: 3 additions & 1 deletion xinference/core/virtual_env_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@
"huggingface-hub<1.0",
],
"mlx": [
"mlx-lm>=0.24.0",
"mlx>=0.22.0,<0.31.2",
"mlx-lm>=0.24.0,<0.31.3",
"mlx-vlm>=0.3.4,<0.5.0",
],
"llama.cpp": [
"xllamacpp>=0.2.6",
Expand Down
72 changes: 65 additions & 7 deletions xinference/model/image/ocr/mlx.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import logging
import platform
import sys
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union

import PIL.Image
Expand Down Expand Up @@ -52,6 +53,50 @@ def check_lib(cls):
return False, "MLX engine is only supported on Apple silicon Macs."
return super().check_lib()

@staticmethod
def _reset_mlx_vlm_generation_stream():
    """Install a fresh MLX stream as ``mlx_vlm.generate.generation_stream``.

    A new stream is created on ``mx.default_device()``; when ``mlx.core``
    exposes ``new_thread_local_stream`` that factory is used, otherwise
    ``new_stream`` is. The stream then replaces the ``generation_stream``
    attribute of the ``mlx_vlm.generate`` module.

    Returns:
        The newly created stream, or ``None`` when any step raised (the
        failure is logged at debug level and otherwise ignored).
    """
    try:
        import importlib

        import mlx.core as mx

        target_device = mx.default_device()
        # Pick the factory by name so the lookup order matches the
        # capability check: thread-local variant first, plain one second.
        factory_name = (
            "new_thread_local_stream"
            if hasattr(mx, "new_thread_local_stream")
            else "new_stream"
        )
        fresh_stream = getattr(mx, factory_name)(target_device)

        generate_module = importlib.import_module("mlx_vlm.generate")
        generate_module.generation_stream = fresh_stream
    except Exception:
        # Best effort only: callers treat ``None`` as "no stream available".
        logger.debug("Failed to reset mlx-vlm generation stream", exc_info=True)
        return None
    return fresh_stream

@staticmethod
@contextmanager
def _mlx_stream_context(stream):
    """Run the ``with`` body under ``mx.stream(stream)``.

    While the context is active, ``mx.eval`` and ``mx.async_eval`` are
    replaced by wrappers that enter ``mx.stream(stream)`` before delegating
    to the original functions. Both attributes are restored on exit, whether
    the body finished normally or raised.

    Args:
        stream: Object accepted by ``mx.stream``; passed through unchanged.
    """
    import mlx.core as mx

    def _pinned(wrapped):
        # Build a wrapper that forwards every call to ``wrapped`` from
        # inside the stream context.
        def _call_on_stream(*args, **kwargs):
            with mx.stream(stream):
                return wrapped(*args, **kwargs)

        return _call_on_stream

    # Capture both originals before touching either attribute.
    saved_eval, saved_async_eval = mx.eval, mx.async_eval

    mx.eval, mx.async_eval = _pinned(saved_eval), _pinned(saved_async_eval)
    try:
        with mx.stream(stream):
            yield
    finally:
        # Always undo the module-level patching.
        mx.eval, mx.async_eval = saved_eval, saved_async_eval

def load(self):
if sys.platform != "darwin" or platform.processor() != "arm":
raise RuntimeError("MLX OCR engine only works on Apple silicon Macs.")
Expand Down Expand Up @@ -319,14 +364,27 @@ def _generate_text(self, image: PIL.Image.Image, prompt: str, **kwargs) -> str:
tokenizer = processor.tokenizer
detokenizer.reset()
text_parts = []
stream = self._reset_mlx_vlm_generation_stream()

for token, _ in generate_step(
input_ids, self._model, pixel_values, mask, **gen_kwargs
):
if token == tokenizer.eos_token_id or token in stop_token_ids:
break
detokenizer.add_token(token)
text_parts.append(detokenizer.last_segment)
if stream is None:
token_iter = generate_step(
input_ids, self._model, pixel_values, mask, **gen_kwargs
)
for token, _ in token_iter:
if token == tokenizer.eos_token_id or token in stop_token_ids:
break
detokenizer.add_token(token)
text_parts.append(detokenizer.last_segment)
else:
with self._mlx_stream_context(stream):
token_iter = generate_step(
input_ids, self._model, pixel_values, mask, **gen_kwargs
)
for token, _ in token_iter:
if token == tokenizer.eos_token_id or token in stop_token_ids:
break
detokenizer.add_token(token)
text_parts.append(detokenizer.last_segment)
detokenizer.finalize()
text_parts.append(detokenizer.last_segment)
return "".join(text_parts)
Loading
Loading