Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 126 additions & 28 deletions xinference/model/llm/mlx/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,33 +63,70 @@
logger = logging.getLogger(__name__)


class _MLXLogitsModelAdapter:
    """Expose mlx-vlm language models with the logits API expected by mlx-lm.

    Calling the adapter forwards the call to the wrapped model and returns the
    ``logits`` attribute of the result when it has one, otherwise the result
    itself. Any attribute the adapter does not define is looked up on the
    wrapped model.
    """

    def __init__(self, model: Any):
        self._model = model

    def __getattr__(self, name: str) -> Any:
        # __getattr__ only runs after normal lookup has failed. If "_model"
        # itself is missing (instance built without __init__, e.g. through
        # copy/pickle or __new__), delegating would re-enter this method
        # forever, so report the attribute as absent instead.
        if name == "_model":
            raise AttributeError(name)
        return getattr(self._model, name)

    def __call__(self, *args, **kwargs):
        """Call the wrapped model and return its logits (or the raw output)."""
        self._reset_position_cache_for_prefill(args, kwargs)
        outputs = self._model(*args, **kwargs)
        return getattr(outputs, "logits", outputs)

    def _reset_position_cache_for_prefill(self, args, kwargs):
        """Clear cached position attributes before a call on an empty cache.

        The ``_rope_deltas`` and ``_position_ids`` attributes of the wrapped
        model are set to ``None`` only when all of the following hold:

        * a non-empty ``cache`` keyword argument was passed;
        * the cache entry at index ``fa_idx`` (read from the wrapped model's
          ``model`` attribute, default 0) exists and its ``offset`` is 0;
        * the inputs (first positional argument, else the ``inputs`` keyword)
          have a ``shape`` whose last dimension is greater than 1.

        NOTE(review): presumably these attributes hold position state left
        over from a previous request -- confirm against mlx-vlm.
        """
        cache = kwargs.get("cache")
        if cache is None or not cache:
            return

        cache_idx = getattr(getattr(self._model, "model", None), "fa_idx", 0)
        if cache_idx >= len(cache) or cache[cache_idx] is None:
            return

        offset = cache[cache_idx].offset
        if not isinstance(offset, int):
            # Non-int offsets are expected to be array-like: take the scalar,
            # or the first element when the offset has dimensions.
            try:
                offset = (offset if offset.ndim == 0 else offset[0]).item()
            except AttributeError:
                return
        if offset != 0:
            return

        if args:
            inputs = args[0]
        else:
            inputs = kwargs.get("inputs")
        if getattr(inputs, "shape", None) is None or inputs.shape[-1] <= 1:
            return

        for attr in ("_rope_deltas", "_position_ids"):
            if hasattr(self._model, attr):
                setattr(self._model, attr, None)


class MLXBatchModel:
"""Wrapper around MLX-LM BatchGenerator for continuous batching."""

# Class-level storage for multiple batch generators keyed by (temperature, top_p)
_batch_generators: Dict[Tuple[float, float], Dict[str, Any]] = (
{}
) # (temp, top_p) -> {'generator': BatchGenerator, 'queues': dict, 'pending': dict, 'active': set, 'task': task}
_model_ref = None
_tokenizer_ref = None
_batch_size = 4
_max_context_length = 2048
_stop_tokens: set = set()
_lock: Optional[asyncio.Lock] = None # Will be initialized lazily
Comment thread
qinxuye marked this conversation as resolved.
Outdated
_mlx_lm_version: Optional[str] = None # Cache mlx-lm version

def __init__(
self, model, tokenizer, batch_size: int = 4, max_context_length: int = 2048
):
# Store references for creating new generators on demand
MLXBatchModel._model_ref = model
MLXBatchModel._tokenizer_ref = tokenizer
MLXBatchModel._batch_size = batch_size
MLXBatchModel._max_context_length = max_context_length
eos_token_ids = tokenizer.eos_token_ids
self._batch_generators: Dict[Tuple[float, float], Dict[str, Any]] = {}
self._model_ref = model
self._tokenizer_ref = tokenizer
self._batch_size = batch_size
self._max_context_length = max_context_length
eos_token_ids = getattr(tokenizer, "eos_token_ids", None)
if eos_token_ids is None:
eos_token_ids = getattr(tokenizer, "eos_token_id", None)
if isinstance(eos_token_ids, int):
eos_token_ids = [eos_token_ids]
MLXBatchModel._stop_tokens = set(eos_token_ids or [])
self._stop_tokens = set(eos_token_ids or [])

@staticmethod
def _get_lock() -> asyncio.Lock:
Expand Down Expand Up @@ -132,7 +169,7 @@ def _get_or_create_generator(self, temperature: float, top_p: float):
"""Get or create a BatchGenerator for the given sampling parameters."""
key = (round(temperature, 6), round(top_p, 6))

if key not in MLXBatchModel._batch_generators:
if key not in self._batch_generators:
logger.info(
f"Creating new BatchGenerator for temperature={temperature}, top_p={top_p}"
)
Expand All @@ -144,23 +181,23 @@ def _get_or_create_generator(self, temperature: float, top_p: float):

# Create batch generator
batch_generator = BatchGenerator(
model=MLXBatchModel._model_ref,
max_tokens=MLXBatchModel._max_context_length,
stop_tokens=MLXBatchModel._stop_tokens,
model=self._model_ref,
max_tokens=self._max_context_length,
stop_tokens=self._stop_tokens,
sampler=sampler,
completion_batch_size=MLXBatchModel._batch_size,
completion_batch_size=self._batch_size,
prefill_batch_size=1, # Use 1 for lowest latency
)

MLXBatchModel._batch_generators[key] = {
self._batch_generators[key] = {
"generator": batch_generator,
"queues": {}, # uid -> asyncio.Queue
"pending": {}, # uid -> list of results
"active": set(), # active uids
"task": None,
}

return MLXBatchModel._batch_generators[key]
return self._batch_generators[key]

def _ensure_background_worker(self, gen_dict):
"""Ensure background worker is running for this generator."""
Expand Down Expand Up @@ -242,8 +279,8 @@ async def generate_stream(
self._ensure_background_worker(gen_dict)

# Prepare request
assert MLXBatchModel._tokenizer_ref is not None
prompt_tokens = MLXBatchModel._tokenizer_ref.encode(prompt)
assert self._tokenizer_ref is not None
prompt_tokens = self._tokenizer_ref.encode(prompt)
input_echo_len = len(prompt_tokens)

# Create queue first
Expand Down Expand Up @@ -301,8 +338,8 @@ async def generate_stream(
token_count += 1

# Decode current token
assert MLXBatchModel._tokenizer_ref is not None
token_text = MLXBatchModel._tokenizer_ref.decode(
assert self._tokenizer_ref is not None
token_text = self._tokenizer_ref.decode(
[result.token], skip_special_tokens=skip_special_tokens
)
token_text = token_text.strip("�")
Expand Down Expand Up @@ -1302,14 +1339,23 @@ def generate(
generate_config: Optional[MLXGenerateConfig] = None,
from_chat: bool = False,
) -> Union[Completion, Iterator[CompletionChunk]]:
"""Generate method for vision models (not using continuous batching)."""
"""Generate method for vision models.

Text-only requests can use mlx-lm continuous batching through the
underlying language model. Requests with images stay on mlx-vlm because
their image tensors are not generally batch compatible.
"""
generate_config = self._sanitize_generate_config(generate_config)
logger.debug(f"[MLXVisionModel] generation params: {generate_config}")

assert self._model is not None
assert self._tokenizer is not None

stream = generate_config.get("stream", False)
if self._batch_model is not None and self._is_text_only_prompt(prompt):
batch_result = self._generate_text_only_with_batch(prompt, generate_config)
if batch_result is not None:
return batch_result

if stream:
# _generate_stream yields (chunk, usage) tuples; unwrap for the caller
Expand Down Expand Up @@ -1352,6 +1398,48 @@ def _unwrap():
),
)

@staticmethod
def _is_text_only_prompt(prompt: Union[str, Dict[str, Any]]) -> bool:
    """Return True when ``prompt`` carries no multimodal payload.

    A prompt that is not a dict is always text-only. A dict prompt is
    text-only when its ``"multi_modal_data"`` entry is missing or empty, or
    when it is a dict whose values are all empty.

    Any other non-empty payload counts as multimodal. This covers modalities
    other than ``"image"`` and payloads that are not dicts, so such requests
    are not classified as text-only.
    """
    if not isinstance(prompt, dict):
        return True
    multi_modal_data = prompt.get("multi_modal_data")
    if not multi_modal_data:
        return True
    if not isinstance(multi_modal_data, dict):
        # Non-empty payload of an unexpected shape: do not assume it can be
        # ignored.
        return False
    return not any(multi_modal_data.values())
Comment thread
qinxuye marked this conversation as resolved.
Outdated

def _generate_text_only_with_batch(
    self, prompt: Union[str, Dict[str, Any]], generate_config: MLXGenerateConfig
) -> Optional[Union[Completion, Iterator[CompletionChunk]]]:
    """Run a text-only request through ``self.async_generate`` on ``self._loop``.

    The coroutine is submitted to the event loop from the calling thread and
    its result is awaited synchronously. If the result is an async iterator,
    it is wrapped in a synchronous generator that pulls one item at a time
    from the loop.

    Returns:
        ``None`` when the batch path cannot be used, so the caller can fall
        back: either the loop is missing or not running, or this method is
        being called from the thread that runs the loop. Otherwise returns
        the result of ``async_generate``, or a synchronous iterator over it.
    """
    loop = self._loop
    if loop is None or not loop.is_running():
        return None

    # Blocking on a future owned by the loop that is running in *this* thread
    # would never return; run_coroutine_threadsafe is meant to be called from
    # a different thread than the loop's.
    try:
        current_loop = asyncio.get_running_loop()
    except RuntimeError:
        current_loop = None
    if current_loop is loop:
        return None

    prompt_text = prompt.get("prompt", "") if isinstance(prompt, dict) else prompt
    future = asyncio.run_coroutine_threadsafe(
        self.async_generate(prompt_text, generate_config), loop
    )
    result = future.result()
    if not hasattr(result, "__aiter__"):
        return result

    async_generator = result

    def _sync_completion_chunks():
        exhausted = False
        try:
            while True:
                next_future = asyncio.run_coroutine_threadsafe(
                    async_generator.__anext__(), loop
                )
                try:
                    chunk = next_future.result()
                except StopAsyncIteration:
                    exhausted = True
                    break
                yield chunk
        finally:
            # If the consumer stops early (or an error interrupts iteration),
            # close the async generator on its own loop so its cleanup code
            # runs. This is fire-and-forget: waiting here could block when
            # finalisation happens on the loop thread.
            if not exhausted:
                aclose = getattr(async_generator, "aclose", None)
                if aclose is not None and loop.is_running():
                    asyncio.run_coroutine_threadsafe(aclose(), loop)

    return _sync_completion_chunks()

def wait_for_load(self):
"""Override parent wait_for_load to skip MLXBatchModel creation."""
if self._loading_thread:
Expand Down Expand Up @@ -1406,6 +1494,16 @@ def load(self):
config.update(self._model_config)
self._context_length = get_context_length_from_config(config)

language_model = getattr(self._model, "language_model", None)
if language_model is not None:
batch_size = self._model_config.get("batch_size", 4)
self._batch_model = MLXBatchModel(
model=_MLXLogitsModelAdapter(language_model),
tokenizer=self._tokenizer,
batch_size=batch_size,
max_context_length=self._context_length,
)

def _generate_stream_inner(self, **kwargs):
import mlx.core as mx

Expand Down
52 changes: 52 additions & 0 deletions xinference/model/llm/mlx/tests/test_mlx.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,58 @@ def test_load_mlx_vision(setup):
assert len(completion["choices"][0]["message"]["content"]) != 0


@pytest.mark.skipif(
    sys.platform != "darwin" or platform.processor() != "arm",
    reason="MLX only works for Apple silicon chip",
)
def test_mlx_vision_text_only_parallel_inference(setup):
    """Check that concurrent text-only requests to an MLX VLM all complete.

    Two streaming requests and one non-streaming request are issued from
    separate threads against the same launched model.
    """
    endpoint, _ = setup
    client = Client(endpoint)

    model_uid = client.launch_model(
        model_name="qwen2-vl-instruct",
        model_engine="MLX",
        model_size_in_billions=2,
        model_format="mlx",
        quantization="4bit",
    )
    assert len(client.list_models()) == 1
    model = client.get_model(model_uid)

    def _assert_final_stream_chunk(chunk):
        # A streaming request yields a chunk with a delta and a finish reason.
        assert "choices" in chunk
        assert len(chunk["choices"]) > 0
        assert "delta" in chunk["choices"][0]
        assert chunk["choices"][0]["finish_reason"] in ["stop", "length"]

    requests = [
        ("write a poem.", {"stream": True}),
        ("中国的首都是哪里?", {"stream": False}),
        ("介绍一下Python。", {"stream": True}),
    ]
    workers = [InferenceThread(text, config, model) for text, config in requests]

    # Start every request before waiting on any of them so they overlap.
    for worker in workers:
        worker.start()
    first_stream, completion, second_stream = [worker.join() for worker in workers]

    for outcome in (first_stream, completion, second_stream):
        assert outcome is not None

    _assert_final_stream_chunk(first_stream)

    assert "choices" in completion
    assert len(completion["choices"]) > 0
    first_choice = completion["choices"][0]
    assert "message" in first_choice
    assert "content" in first_choice["message"]
    assert len(first_choice["message"]["content"]) > 0

    _assert_final_stream_chunk(second_stream)


@pytest.mark.skipif(
sys.platform != "darwin" or platform.processor() != "arm",
reason="MLX only works for Apple silicon chip",
Expand Down
Loading