From 6d97129afeacbcc929aefb22d1eed056e1420709 Mon Sep 17 00:00:00 2001 From: qinxuye Date: Sun, 17 May 2026 01:07:31 +0800 Subject: [PATCH 01/13] feat: batch text-only MLX VLM requests --- xinference/model/llm/mlx/core.py | 154 +++++++++++++++++---- xinference/model/llm/mlx/tests/test_mlx.py | 52 +++++++ 2 files changed, 178 insertions(+), 28 deletions(-) diff --git a/xinference/model/llm/mlx/core.py b/xinference/model/llm/mlx/core.py index b4bcf621e1..27de87c992 100644 --- a/xinference/model/llm/mlx/core.py +++ b/xinference/model/llm/mlx/core.py @@ -63,33 +63,70 @@ logger = logging.getLogger(__name__) +class _MLXLogitsModelAdapter: + """Expose mlx-vlm language models with the logits API expected by mlx-lm.""" + + def __init__(self, model: Any): + self._model = model + + def __getattr__(self, name: str) -> Any: + return getattr(self._model, name) + + def __call__(self, *args, **kwargs): + self._reset_position_cache_for_prefill(args, kwargs) + outputs = self._model(*args, **kwargs) + return getattr(outputs, "logits", outputs) + + def _reset_position_cache_for_prefill(self, args, kwargs): + cache = kwargs.get("cache") + if cache is None or not cache: + return + + cache_idx = getattr(getattr(self._model, "model", None), "fa_idx", 0) + if cache_idx >= len(cache) or cache[cache_idx] is None: + return + + offset = cache[cache_idx].offset + if not isinstance(offset, int): + try: + offset = (offset if offset.ndim == 0 else offset[0]).item() + except AttributeError: + return + if offset != 0: + return + + if args: + inputs = args[0] + else: + inputs = kwargs.get("inputs") + if getattr(inputs, "shape", None) is None or inputs.shape[-1] <= 1: + return + + for attr in ("_rope_deltas", "_position_ids"): + if hasattr(self._model, attr): + setattr(self._model, attr, None) + + class MLXBatchModel: """Wrapper around MLX-LM BatchGenerator for continuous batching.""" - # Class-level storage for multiple batch generators keyed by (temperature, top_p) - _batch_generators: Dict[Tuple[float, float], Dict[str, Any]] = ( - {} - ) # (temp, top_p) -> {'generator': BatchGenerator, 'queues': dict, 'pending': dict, 'active': set, 'task': task} - _model_ref = None - _tokenizer_ref = None - _batch_size = 4 - _max_context_length = 2048 - _stop_tokens: set = set() _lock: Optional[asyncio.Lock] = None # Will be initialized lazily _mlx_lm_version: Optional[str] = None # Cache mlx-lm version def __init__( self, model, tokenizer, batch_size: int = 4, max_context_length: int = 2048 ): - # Store references for creating new generators on demand - MLXBatchModel._model_ref = model - MLXBatchModel._tokenizer_ref = tokenizer - MLXBatchModel._batch_size = batch_size - MLXBatchModel._max_context_length = max_context_length - eos_token_ids = tokenizer.eos_token_ids + self._batch_generators: Dict[Tuple[float, float], Dict[str, Any]] = {} + self._model_ref = model + self._tokenizer_ref = tokenizer + self._batch_size = batch_size + self._max_context_length = max_context_length + eos_token_ids = getattr(tokenizer, "eos_token_ids", None) + if eos_token_ids is None: + eos_token_ids = getattr(tokenizer, "eos_token_id", None) if isinstance(eos_token_ids, int): eos_token_ids = [eos_token_ids] - MLXBatchModel._stop_tokens = set(eos_token_ids or []) + self._stop_tokens = set(eos_token_ids or []) @staticmethod def _get_lock() -> asyncio.Lock: @@ -132,7 +169,7 @@ def _get_or_create_generator(self, temperature: float, top_p: float): """Get or create a BatchGenerator for the given sampling parameters.""" key = (round(temperature, 6), round(top_p, 6)) - if key not in MLXBatchModel._batch_generators: + if key not in self._batch_generators: logger.info( f"Creating new BatchGenerator for temperature={temperature}, top_p={top_p}" ) @@ -144,15 +181,15 @@ def _get_or_create_generator(self, temperature: float, top_p: float): # Create batch generator batch_generator = BatchGenerator( - model=MLXBatchModel._model_ref, - max_tokens=MLXBatchModel._max_context_length, - stop_tokens=MLXBatchModel._stop_tokens, + model=self._model_ref, + max_tokens=self._max_context_length, + stop_tokens=self._stop_tokens, sampler=sampler, - completion_batch_size=MLXBatchModel._batch_size, + completion_batch_size=self._batch_size, prefill_batch_size=1, # Use 1 for lowest latency ) - MLXBatchModel._batch_generators[key] = { + self._batch_generators[key] = { "generator": batch_generator, "queues": {}, # uid -> asyncio.Queue "pending": {}, # uid -> list of results @@ -160,7 +197,7 @@ def _get_or_create_generator(self, temperature: float, top_p: float): "task": None, } - return MLXBatchModel._batch_generators[key] + return self._batch_generators[key] def _ensure_background_worker(self, gen_dict): """Ensure background worker is running for this generator.""" @@ -242,8 +279,8 @@ async def generate_stream( self._ensure_background_worker(gen_dict) # Prepare request - assert MLXBatchModel._tokenizer_ref is not None - prompt_tokens = MLXBatchModel._tokenizer_ref.encode(prompt) + assert self._tokenizer_ref is not None + prompt_tokens = self._tokenizer_ref.encode(prompt) input_echo_len = len(prompt_tokens) # Create queue first @@ -301,8 +338,8 @@ async def generate_stream( token_count += 1 # Decode current token - assert MLXBatchModel._tokenizer_ref is not None - token_text = MLXBatchModel._tokenizer_ref.decode( + assert self._tokenizer_ref is not None + token_text = self._tokenizer_ref.decode( [result.token], skip_special_tokens=skip_special_tokens ) token_text = token_text.strip("�") @@ -1302,7 +1339,12 @@ def generate( generate_config: Optional[MLXGenerateConfig] = None, from_chat: bool = False, ) -> Union[Completion, Iterator[CompletionChunk]]: - """Generate method for vision models (not using continuous batching).""" + """Generate method for vision models. + + Text-only requests can use mlx-lm continuous batching through the + underlying language model. Requests with images stay on mlx-vlm because + their image tensors are not generally batch compatible. + """ generate_config = self._sanitize_generate_config(generate_config) logger.debug(f"[MLXVisionModel] generation params: {generate_config}") @@ -1310,6 +1352,10 @@ def generate( assert self._tokenizer is not None stream = generate_config.get("stream", False) + if self._batch_model is not None and self._is_text_only_prompt(prompt): + batch_result = self._generate_text_only_with_batch(prompt, generate_config) + if batch_result is not None: + return batch_result if stream: # _generate_stream yields (chunk, usage) tuples; unwrap for the caller @@ -1352,6 +1398,48 @@ def _unwrap(): ), ) + @staticmethod + def _is_text_only_prompt(prompt: Union[str, Dict[str, Any]]) -> bool: + if not isinstance(prompt, dict): + return True + multi_modal_data = prompt.get("multi_modal_data") + if not multi_modal_data: + return True + images = ( + multi_modal_data.get("image") + if isinstance(multi_modal_data, dict) + else None + ) + return not images + + def _generate_text_only_with_batch( + self, prompt: Union[str, Dict[str, Any]], generate_config: MLXGenerateConfig + ) -> Optional[Union[Completion, Iterator[CompletionChunk]]]: + if self._loop is None or not self._loop.is_running(): + return None + + prompt_text = prompt.get("prompt", "") if isinstance(prompt, dict) else prompt + future = asyncio.run_coroutine_threadsafe( + self.async_generate(prompt_text, generate_config), self._loop + ) + result = future.result() + if not hasattr(result, "__aiter__"): + return result + + async_generator = result + + def _sync_completion_chunks(): + while True: + try: + next_future = asyncio.run_coroutine_threadsafe( + async_generator.__anext__(), self._loop + ) + yield next_future.result() + except StopAsyncIteration: + break + + return _sync_completion_chunks() + def wait_for_load(self): """Override parent wait_for_load to skip MLXBatchModel creation.""" if self._loading_thread: @@ -1406,6 +1494,16 @@ def load(self): config.update(self._model_config) self._context_length = get_context_length_from_config(config) + language_model = getattr(self._model, "language_model", None) + if language_model is not None: + batch_size = self._model_config.get("batch_size", 4) + self._batch_model = MLXBatchModel( + model=_MLXLogitsModelAdapter(language_model), + tokenizer=self._tokenizer, + batch_size=batch_size, + max_context_length=self._context_length, + ) + def _generate_stream_inner(self, **kwargs): import mlx.core as mx diff --git a/xinference/model/llm/mlx/tests/test_mlx.py b/xinference/model/llm/mlx/tests/test_mlx.py index 53304414ba..030ff8f670 100644 --- a/xinference/model/llm/mlx/tests/test_mlx.py +++ b/xinference/model/llm/mlx/tests/test_mlx.py @@ -146,6 +146,58 @@ def test_load_mlx_vision(setup): assert len(completion["choices"][0]["message"]["content"]) != 0 +@pytest.mark.skipif( + sys.platform != "darwin" or platform.processor() != "arm", + reason="MLX only works for Apple silicon chip", +) +def test_mlx_vision_text_only_parallel_inference(setup): + """Test MLX VLM text-only requests can use continuous batching.""" + endpoint, _ = setup + client = Client(endpoint) + + model_uid = client.launch_model( + model_name="qwen2-vl-instruct", + model_engine="MLX", + model_size_in_billions=2, + model_format="mlx", + quantization="4bit", + ) + assert len(client.list_models()) == 1 + model = client.get_model(model_uid) + + thread1 = InferenceThread("write a poem.", {"stream": True}, model) + thread2 = InferenceThread("中国的首都是哪里?", {"stream": False}, model) + thread3 = InferenceThread("介绍一下Python。", {"stream": True}, model) + + thread1.start() + thread2.start() + thread3.start() + + result1 = thread1.join() + result2 = thread2.join() + result3 = thread3.join() + + assert result1 is not None + assert result2 is not None + assert result3 is not None + + assert "choices" in result1 + assert len(result1["choices"]) > 0 + assert "delta" in result1["choices"][0] + assert result1["choices"][0]["finish_reason"] in ["stop", "length"] + + assert "choices" in result2 + assert len(result2["choices"]) > 0 + assert "message" in result2["choices"][0] + assert "content" in result2["choices"][0]["message"] + assert len(result2["choices"][0]["message"]["content"]) > 0 + + assert "choices" in result3 + assert len(result3["choices"]) > 0 + assert "delta" in result3["choices"][0] + assert result3["choices"][0]["finish_reason"] in ["stop", "length"] + + @pytest.mark.skipif( sys.platform != "darwin" or platform.processor() != "arm", reason="MLX only works for Apple silicon chip", From 27de5d48a4c2b136a309d650e116ebdab30b44ad Mon Sep 17 00:00:00 2001 From: qinxuye Date: Sun, 17 May 2026 01:12:09 +0800 Subject: [PATCH 02/13] fix: simplify MLX VLM text-only prompt check --- xinference/model/llm/mlx/core.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/xinference/model/llm/mlx/core.py b/xinference/model/llm/mlx/core.py index 27de87c992..43b56f9d86 100644 --- a/xinference/model/llm/mlx/core.py +++ b/xinference/model/llm/mlx/core.py @@ -1403,14 +1403,9 @@ def _is_text_only_prompt(prompt: Union[str, Dict[str, Any]]) -> bool: if not isinstance(prompt, dict): return True multi_modal_data = prompt.get("multi_modal_data") - if not multi_modal_data: + if not isinstance(multi_modal_data, dict): return True - images = ( - multi_modal_data.get("image") - if isinstance(multi_modal_data, dict) - else None - ) - return not images + return not multi_modal_data.get("image") def _generate_text_only_with_batch( self, prompt: Union[str, Dict[str, Any]], generate_config: MLXGenerateConfig From 75941df71b861ce11cf816694a18195552dac7a2 Mon Sep 17 00:00:00 2001 From: qinxuye Date: Sun, 17 May 2026 01:17:12 +0800 Subject: [PATCH 03/13] fix: guard MLX VLM batching against other modalities --- xinference/model/llm/mlx/core.py | 4 +++- xinference/model/llm/mlx/tests/test_mlx.py | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/xinference/model/llm/mlx/core.py b/xinference/model/llm/mlx/core.py index 43b56f9d86..cec82d96dc 100644 --- a/xinference/model/llm/mlx/core.py +++ b/xinference/model/llm/mlx/core.py @@ -1405,7 +1405,9 @@ def _is_text_only_prompt(prompt: Union[str, Dict[str, Any]]) -> bool: multi_modal_data = prompt.get("multi_modal_data") if not isinstance(multi_modal_data, dict): return True - return not multi_modal_data.get("image") + return not any( + multi_modal_data.get(modality) for modality in ("image", "video", "audio") + ) def _generate_text_only_with_batch( self, prompt: Union[str, Dict[str, Any]], generate_config: MLXGenerateConfig diff --git a/xinference/model/llm/mlx/tests/test_mlx.py b/xinference/model/llm/mlx/tests/test_mlx.py index 030ff8f670..b9a4524827 100644 --- a/xinference/model/llm/mlx/tests/test_mlx.py +++ b/xinference/model/llm/mlx/tests/test_mlx.py @@ -21,6 +21,7 @@ import pytest from .....client import Client +from ..core import MLXVisionModel class InferenceThread(threading.Thread): @@ -66,6 +67,23 @@ def join(self, timeout=None): return self._result +def test_mlx_vision_text_only_prompt_detection(): + assert MLXVisionModel._is_text_only_prompt("hello") + assert MLXVisionModel._is_text_only_prompt({"prompt": "hello"}) + assert MLXVisionModel._is_text_only_prompt( + {"prompt": "hello", "multi_modal_data": {}} + ) + assert not MLXVisionModel._is_text_only_prompt( + {"prompt": "hello", "multi_modal_data": {"image": "image"}} + ) + assert not MLXVisionModel._is_text_only_prompt( + {"prompt": "hello", "multi_modal_data": {"video": "video"}} + ) + assert not MLXVisionModel._is_text_only_prompt( + {"prompt": "hello", "multi_modal_data": {"audio": "audio"}} + ) + + @pytest.mark.skipif( sys.platform != "darwin" or platform.processor() != "arm", reason="MLX only works for Apple silicon chip", From c4d46256e42f1869c0285442645143f1a62ac770 Mon Sep 17 00:00:00 2001 From: qinxuye Date: Sun, 17 May 2026 01:25:47 +0800 Subject: [PATCH 04/13] fix: isolate MLX batch locks per instance --- xinference/model/llm/mlx/core.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/xinference/model/llm/mlx/core.py b/xinference/model/llm/mlx/core.py index cec82d96dc..81923d4641 100644 --- a/xinference/model/llm/mlx/core.py +++ b/xinference/model/llm/mlx/core.py @@ -110,13 +110,13 @@ def _reset_position_cache_for_prefill(self, args, kwargs): class MLXBatchModel: """Wrapper around MLX-LM BatchGenerator for continuous batching.""" - _lock: Optional[asyncio.Lock] = None # Will be initialized lazily _mlx_lm_version: Optional[str] = None # Cache mlx-lm version def __init__( self, model, tokenizer, batch_size: int = 4, max_context_length: int = 2048 ): self._batch_generators: Dict[Tuple[float, float], Dict[str, Any]] = {} + self._lock: Optional[asyncio.Lock] = None self._model_ref = model self._tokenizer_ref = tokenizer self._batch_size = batch_size @@ -128,12 +128,11 @@ def __init__( eos_token_ids = [eos_token_ids] self._stop_tokens = set(eos_token_ids or []) - @staticmethod - def _get_lock() -> asyncio.Lock: + def _get_lock(self) -> asyncio.Lock: """Get or create the async lock.""" - if MLXBatchModel._lock is None: - MLXBatchModel._lock = asyncio.Lock() - return MLXBatchModel._lock + if self._lock is None: + self._lock = asyncio.Lock() + return self._lock @staticmethod def _get_mlx_lm_version() -> str: @@ -228,7 +227,7 @@ async def _background_worker(self, gen_dict): continue # Distribute results to respective request queues - async with MLXBatchModel._get_lock(): + async with self._get_lock(): for result in batch_results: if result.uid in gen_dict["queues"]: queue = gen_dict["queues"][result.uid] @@ -303,7 +302,7 @@ async def generate_stream( ) # Add to active requests set - async with MLXBatchModel._get_lock(): + async with self._get_lock(): gen_dict["active"].add(inserted_uid) gen_dict["queues"][inserted_uid] = queue @@ -377,7 +376,7 @@ async def generate_stream( except asyncio.TimeoutError: # Check if request is still in active set - async with MLXBatchModel._get_lock(): + async with self._get_lock(): is_active = inserted_uid in gen_dict["active"] if not is_active: @@ -402,7 +401,7 @@ async def generate_stream( ) finally: # Clean up queue using the correct uid - async with MLXBatchModel._get_lock(): + async with self._get_lock(): if inserted_uid in gen_dict["queues"]: del gen_dict["queues"][inserted_uid] logger.debug(f"Cleaned up queue for uid {inserted_uid}") From bb7481c3c8ec58dade83bb951667883adb5a6030 Mon Sep 17 00:00:00 2001 From: qinxuye Date: Sun, 17 May 2026 02:29:27 +0800 Subject: [PATCH 05/13] test: cap MLX VLM text-only generation length --- xinference/model/llm/mlx/tests/test_mlx.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/xinference/model/llm/mlx/tests/test_mlx.py b/xinference/model/llm/mlx/tests/test_mlx.py index b9a4524827..19dee8d4ee 100644 --- a/xinference/model/llm/mlx/tests/test_mlx.py +++ b/xinference/model/llm/mlx/tests/test_mlx.py @@ -158,7 +158,7 @@ def test_load_mlx_vision(setup): # test no image messages = [{"role": "user", "content": "write a poem."}] - completion = model.chat(messages) + completion = model.chat(messages, generate_config={"max_tokens": 32}) assert "content" in completion["choices"][0]["message"] assert "content" in completion["choices"][0]["message"] assert len(completion["choices"][0]["message"]["content"]) != 0 @@ -183,9 +183,15 @@ def test_mlx_vision_text_only_parallel_inference(setup): assert len(client.list_models()) == 1 model = client.get_model(model_uid) - thread1 = InferenceThread("write a poem.", {"stream": True}, model) - thread2 = InferenceThread("中国的首都是哪里?", {"stream": False}, model) - thread3 = InferenceThread("介绍一下Python。", {"stream": True}, model) + thread1 = InferenceThread( + "write a poem.", {"stream": True, "max_tokens": 32}, model + ) + thread2 = InferenceThread( + "中国的首都是哪里?", {"stream": False, "max_tokens": 32}, model + ) + thread3 = InferenceThread( + "介绍一下Python。", {"stream": True, "max_tokens": 32}, model + ) thread1.start() thread2.start() From 82790e419d005fd50c406c6f5f85742c3b0d6a57 Mon Sep 17 00:00:00 2001 From: qinxuye Date: Sun, 17 May 2026 12:30:17 +0800 Subject: [PATCH 06/13] fix: reset MLX batch stream in worker thread --- xinference/model/llm/mlx/core.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/xinference/model/llm/mlx/core.py b/xinference/model/llm/mlx/core.py index 81923d4641..2d346bfdbf 100644 --- a/xinference/model/llm/mlx/core.py +++ b/xinference/model/llm/mlx/core.py @@ -206,6 +206,7 @@ def _ensure_background_worker(self, gen_dict): async def _background_worker(self, gen_dict): """Background worker that continuously calls next() and distributes results.""" + self._reset_mlx_lm_generation_stream() key = f"temp={gen_dict['generator'].sampler}" logger.info(f"Starting BatchGenerator background worker for {key}") batch_generator = gen_dict["generator"] @@ -258,8 +259,31 @@ async def _background_worker(self, gen_dict): except Exception as e: logger.error(f"Error in background worker: {e}", exc_info=True) + await self._fail_active_requests(gen_dict, e) await asyncio.sleep(0.01) + @staticmethod + def _reset_mlx_lm_generation_stream(): + try: + import importlib + + import mlx.core as mx + + mlx_lm_generate = importlib.import_module("mlx_lm.generate") + + mlx_lm_generate.generation_stream = mx.new_stream(mx.gpu) + except Exception: + logger.debug("Failed to reset mlx-lm generation stream", exc_info=True) + + async def _fail_active_requests(self, gen_dict, exc: Exception): + async with self._get_lock(): + for uid, queue in list(gen_dict["queues"].items()): + try: + queue.put_nowait(exc) + except asyncio.QueueFull: + logger.warning(f"Queue full for uid {uid} while failing request") + gen_dict["active"].clear() + async def generate_stream( self, prompt: str, @@ -331,6 +355,8 @@ async def generate_stream( try: # Wait for result with short timeout (so we can check if request is still active) result = await asyncio.wait_for(queue.get(), timeout=0.5) + if isinstance(result, Exception): + raise result if result.token is not None: generated_tokens.append(result.token) From ab2e7fb4ef2bcbaf6e4eae3d7e567a79c8733e07 Mon Sep 17 00:00:00 2001 From: qinxuye Date: Sun, 17 May 2026 13:18:19 +0800 Subject: [PATCH 07/13] fix: bind MLX batch stream before batch ops --- xinference/model/llm/mlx/core.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/xinference/model/llm/mlx/core.py b/xinference/model/llm/mlx/core.py index 2d346bfdbf..c6898b3c02 100644 --- a/xinference/model/llm/mlx/core.py +++ b/xinference/model/llm/mlx/core.py @@ -169,6 +169,7 @@ def _get_or_create_generator(self, temperature: float, top_p: float): key = (round(temperature, 6), round(top_p, 6)) if key not in self._batch_generators: + self._reset_mlx_lm_generation_stream() logger.info( f"Creating new BatchGenerator for temperature={temperature}, top_p={top_p}" ) @@ -213,6 +214,7 @@ async def _background_worker(self, gen_dict): while True: try: + self._reset_mlx_lm_generation_stream() # Get next batch of results for ALL active requests # Use different API based on mlx-lm version if MLXBatchModel._is_new_mlx_lm(): @@ -271,7 +273,8 @@ def _reset_mlx_lm_generation_stream(): mlx_lm_generate = importlib.import_module("mlx_lm.generate") - mlx_lm_generate.generation_stream = mx.new_stream(mx.gpu) + device = mx.default_device() + mlx_lm_generate.generation_stream = mx.new_stream(device) except Exception: logger.debug("Failed to reset mlx-lm generation stream", exc_info=True) @@ -311,6 +314,7 @@ async def generate_stream( # Insert prompt into batch to get the real uid # Use different API based on mlx-lm version + self._reset_mlx_lm_generation_stream() if MLXBatchModel._is_new_mlx_lm(): # New API (mlx-lm >= 0.31.2): max_tokens should be a list request_ids = batch_generator.insert( From 22b59067c55b7151723ddc87a47610616e854766 Mon Sep 17 00:00:00 2001 From: qinxuye Date: Sun, 17 May 2026 15:18:01 +0800 Subject: [PATCH 08/13] fix: use default MLX stream for batch generation --- xinference/model/llm/mlx/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xinference/model/llm/mlx/core.py b/xinference/model/llm/mlx/core.py index c6898b3c02..8523aaa5ca 100644 --- a/xinference/model/llm/mlx/core.py +++ b/xinference/model/llm/mlx/core.py @@ -274,7 +274,7 @@ def _reset_mlx_lm_generation_stream(): mlx_lm_generate = importlib.import_module("mlx_lm.generate") device = mx.default_device() - mlx_lm_generate.generation_stream = mx.new_stream(device) + mlx_lm_generate.generation_stream = mx.default_stream(device) except Exception: logger.debug("Failed to reset mlx-lm generation stream", exc_info=True) From b96c439f7031d43125eaef533673f7c81e2d0f93 Mon Sep 17 00:00:00 2001 From: qinxuye Date: Sun, 17 May 2026 18:36:07 +0800 Subject: [PATCH 09/13] fix: avoid mlx-lm batch stream wrapper --- xinference/model/llm/mlx/core.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/xinference/model/llm/mlx/core.py b/xinference/model/llm/mlx/core.py index 8523aaa5ca..d62924a9cc 100644 --- a/xinference/model/llm/mlx/core.py +++ b/xinference/model/llm/mlx/core.py @@ -218,8 +218,8 @@ async def _background_worker(self, gen_dict): # Get next batch of results for ALL active requests # Use different API based on mlx-lm version if MLXBatchModel._is_new_mlx_lm(): - # New API (mlx-lm >= 0.31.2): use next_generated() - batch_results = batch_generator.next_generated() + # New API (mlx-lm >= 0.31.2): return generated tokens only. + batch_results = self._next_generated(batch_generator) else: # Old API (mlx-lm < 0.31.2): use next() and unpack batch_results = batch_generator.next() @@ -264,6 +264,17 @@ async def _background_worker(self, gen_dict): await self._fail_active_requests(gen_dict, e) await asyncio.sleep(0.01) + @staticmethod + def _next_generated(batch_generator): + if not hasattr(batch_generator, "_next"): + return batch_generator.next_generated() + + while True: + prompt_responses, generation_responses = batch_generator._next() + if not generation_responses and prompt_responses: + continue + return generation_responses + @staticmethod def _reset_mlx_lm_generation_stream(): try: From db38608f0c21118631a3f8dd7b8000938b94128c Mon Sep 17 00:00:00 2001 From: qinxuye Date: Sun, 17 May 2026 20:42:29 +0800 Subject: [PATCH 10/13] fix: run MLX generation on default stream --- xinference/model/llm/mlx/core.py | 65 ++++++++++++++++++-------------- 1 file changed, 37 insertions(+), 28 deletions(-) diff --git a/xinference/model/llm/mlx/core.py b/xinference/model/llm/mlx/core.py index d62924a9cc..181dc97d02 100644 --- a/xinference/model/llm/mlx/core.py +++ b/xinference/model/llm/mlx/core.py @@ -269,11 +269,14 @@ def _next_generated(batch_generator): if not hasattr(batch_generator, "_next"): return batch_generator.next_generated() - while True: - prompt_responses, generation_responses = batch_generator._next() - if not generation_responses and prompt_responses: - continue - return generation_responses + import mlx.core as mx + + with mx.stream(mx.default_stream(mx.default_device())): + while True: + prompt_responses, generation_responses = batch_generator._next() + if not generation_responses and prompt_responses: + continue + return generation_responses @staticmethod def _reset_mlx_lm_generation_stream(): @@ -1564,29 +1567,35 @@ def _generate_stream_inner(self, **kwargs): detokenizer.reset() tic = time.perf_counter() try: - for n, (token, logprobs) in enumerate( - generate_step(input_ids, self._model, pixel_values, mask, **kwargs), - ): - if n == 0: - prompt_time = time.perf_counter() - tic - prompt_tps = len(input_ids) / prompt_time - tic = time.perf_counter() - if token == tokenizer.eos_token_id or token in stop_token_ids: - break - detokenizer.add_token(token) - - # Yield the last segment if streaming - yield GenerationResponse( - text=detokenizer.last_segment, - token=token, - logprobs=logprobs, - from_draft=False, - prompt_tokens=len(input_ids), - prompt_tps=prompt_tps, - generation_tokens=n + 1, - generation_tps=(n + 1) / (time.perf_counter() - tic), - peak_memory=mx.metal.get_peak_memory() / 1e9, - ) + with mx.stream(mx.default_stream(mx.default_device())): + token = None + logprobs = None + for n, (token, logprobs) in enumerate( + generate_step(input_ids, self._model, pixel_values, mask, **kwargs), + ): + if n == 0: + prompt_time = time.perf_counter() - tic + prompt_tps = len(input_ids) / prompt_time + tic = time.perf_counter() + if token == tokenizer.eos_token_id or token in stop_token_ids: + break + detokenizer.add_token(token) + + # Yield the last segment if streaming + yield GenerationResponse( + text=detokenizer.last_segment, + token=token, + logprobs=logprobs, + from_draft=False, + prompt_tokens=len(input_ids), + prompt_tps=prompt_tps, + generation_tokens=n + 1, + generation_tps=(n + 1) / (time.perf_counter() - tic), + peak_memory=mx.metal.get_peak_memory() / 1e9, + ) + + if token is None: + return detokenizer.finalize() yield GenerationResponse( From fe3f39f04237052fbe7ead049224198c5b3a2b1a Mon Sep 17 00:00:00 2001 From: qinxuye Date: Mon, 18 May 2026 18:48:43 +0800 Subject: [PATCH 11/13] fix: pin mlx deps for metal stream regression --- .github/workflows/python.yaml | 6 +- setup.cfg | 5 +- xinference/core/virtual_env_manager.py | 4 +- xinference/model/image/ocr/mlx.py | 72 +++++++++++++++++++--- xinference/model/llm/mlx/core.py | 82 ++++++++++++++++++++++---- 5 files changed, 145 insertions(+), 24 deletions(-) diff --git a/.github/workflows/python.yaml b/.github/workflows/python.yaml index c7f56c584f..265b9afccb 100644 --- a/.github/workflows/python.yaml +++ b/.github/workflows/python.yaml @@ -119,9 +119,9 @@ jobs: pip install "xllamacpp>=0.2.0" if [ "$MODULE" == "metal" ]; then conda install -c conda-forge "ffmpeg<7" - pip install "mlx>=0.22.0" - pip install mlx-lm - pip install "mlx-vlm>=0.3.4" + pip install "mlx>=0.22.0,<0.31.2" + pip install "mlx-lm<0.31.3" + pip install "mlx-vlm>=0.3.4,<0.5.0" pip install mlx-whisper pip install f5-tts-mlx pip install qwen-vl-utils!=0.0.9 diff --git a/setup.cfg b/setup.cfg index 22afc3e0c7..8a8abcf0aa 100644 --- a/setup.cfg +++ b/setup.cfg @@ -138,8 +138,9 @@ vllm = sglang = sglang[srt]>=0.4.2.post4 ; sys_platform=='linux' mlx = - mlx-lm>=0.21.5 ; sys_platform=='darwin' and platform_machine=='arm64' - mlx-vlm>=0.3.4 ; sys_platform=='darwin' and platform_machine=='arm64' + mlx>=0.22.0,<0.31.2 ; sys_platform=='darwin' and platform_machine=='arm64' + mlx-lm>=0.21.5,<0.31.3 ; sys_platform=='darwin' and platform_machine=='arm64' + mlx-vlm>=0.3.4,<0.5.0 ; sys_platform=='darwin' and platform_machine=='arm64' mlx-whisper ; sys_platform=='darwin' and platform_machine=='arm64' f5-tts-mlx ; sys_platform=='darwin' and platform_machine=='arm64' mlx-audio ; sys_platform=='darwin' and platform_machine=='arm64' diff --git a/xinference/core/virtual_env_manager.py b/xinference/core/virtual_env_manager.py index aa36fb232a..fb04a34a7d 100644 --- a/xinference/core/virtual_env_manager.py +++ b/xinference/core/virtual_env_manager.py @@ -56,7 +56,9 @@ "huggingface-hub<1.0", ], "mlx": [ - "mlx-lm>=0.24.0", + "mlx>=0.22.0,<0.31.2", + "mlx-lm>=0.24.0,<0.31.3", + "mlx-vlm>=0.3.4,<0.5.0", ], "llama.cpp": [ "xllamacpp>=0.2.6", diff --git a/xinference/model/image/ocr/mlx.py b/xinference/model/image/ocr/mlx.py index 203a6161b0..5071f77d19 100644 --- a/xinference/model/image/ocr/mlx.py +++ b/xinference/model/image/ocr/mlx.py @@ -15,6 +15,7 @@ import logging import platform import sys +from contextlib import contextmanager from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union import PIL.Image @@ -52,6 +53,50 @@ def check_lib(cls): return False, "MLX engine is only supported on Apple silicon Macs." return super().check_lib() + @staticmethod + def _reset_mlx_vlm_generation_stream(): + try: + import importlib + + import mlx.core as mx + + device = mx.default_device() + if hasattr(mx, "new_thread_local_stream"): + stream = mx.new_thread_local_stream(device) + else: + stream = mx.new_stream(device) + mlx_vlm_generate = importlib.import_module("mlx_vlm.generate") + setattr(mlx_vlm_generate, "generation_stream", stream) + return stream + except Exception: + logger.debug("Failed to reset mlx-vlm generation stream", exc_info=True) + return None + + @staticmethod + @contextmanager + def _mlx_stream_context(stream): + import mlx.core as mx + + original_eval = mx.eval + original_async_eval = mx.async_eval + + def _eval_on_stream(*args, **kwargs): + with mx.stream(stream): + return original_eval(*args, **kwargs) + + def _async_eval_on_stream(*args, **kwargs): + with mx.stream(stream): + return original_async_eval(*args, **kwargs) + + mx.eval = _eval_on_stream + mx.async_eval = _async_eval_on_stream + try: + with mx.stream(stream): + yield + finally: + mx.eval = original_eval + mx.async_eval = original_async_eval + def load(self): if sys.platform != "darwin" or platform.processor() != "arm": raise RuntimeError("MLX OCR engine only works on Apple silicon Macs.") @@ -319,14 +364,27 @@ def _generate_text(self, image: PIL.Image.Image, prompt: str, **kwargs) -> str: tokenizer = processor.tokenizer detokenizer.reset() text_parts = [] + stream = self._reset_mlx_vlm_generation_stream() - for token, _ in generate_step( - input_ids, self._model, pixel_values, mask, **gen_kwargs - ): - if token == tokenizer.eos_token_id or token in stop_token_ids: - break - detokenizer.add_token(token) - text_parts.append(detokenizer.last_segment) + if stream is None: + token_iter = generate_step( + input_ids, self._model, pixel_values, mask, **gen_kwargs + ) + for token, _ in token_iter: + if token == tokenizer.eos_token_id or token in stop_token_ids: + break + detokenizer.add_token(token) + text_parts.append(detokenizer.last_segment) + else: + with self._mlx_stream_context(stream): + token_iter = generate_step( + input_ids, self._model, pixel_values, mask, **gen_kwargs + ) + for token, _ in token_iter: + if token == tokenizer.eos_token_id or token in stop_token_ids: + break + detokenizer.add_token(token) + text_parts.append(detokenizer.last_segment) detokenizer.finalize() text_parts.append(detokenizer.last_segment) return "".join(text_parts) diff --git a/xinference/model/llm/mlx/core.py b/xinference/model/llm/mlx/core.py index 181dc97d02..896e7fc2ea 100644 --- a/xinference/model/llm/mlx/core.py +++ b/xinference/model/llm/mlx/core.py @@ -14,6 +14,7 @@ import asyncio import concurrent.futures +import contextlib import importlib import logging import pathlib @@ -169,7 +170,7 @@ def _get_or_create_generator(self, temperature: float, top_p: float): key = (round(temperature, 6), round(top_p, 6)) if key not in self._batch_generators: - self._reset_mlx_lm_generation_stream() + stream = self._reset_mlx_lm_generation_stream() logger.info( f"Creating new BatchGenerator for temperature={temperature}, top_p={top_p}" ) @@ -178,6 +179,16 @@ def _get_or_create_generator(self, temperature: float, top_p: float): # Create sampler with specific settings sampler = make_sampler(temp=temperature, top_p=top_p) + batch_generator_kwargs = {} + try: + import inspect + + if "stream" in inspect.signature(BatchGenerator.__init__).parameters: + batch_generator_kwargs["stream"] = stream + except Exception: + logger.debug( + "Failed to inspect BatchGenerator signature", exc_info=True + ) # Create batch generator batch_generator = BatchGenerator( @@ -187,6 +198,7 @@ def _get_or_create_generator(self, temperature: float, top_p: float): sampler=sampler, completion_batch_size=self._batch_size, prefill_batch_size=1, # Use 1 for lowest latency + **batch_generator_kwargs, ) self._batch_generators[key] = { @@ -266,12 +278,18 @@ async def _background_worker(self, gen_dict): @staticmethod def _next_generated(batch_generator): - if not hasattr(batch_generator, "_next"): + if hasattr(batch_generator, "next_generated"): return batch_generator.next_generated() import mlx.core as mx - with mx.stream(mx.default_stream(mx.default_device())): + stream = getattr(batch_generator, "stream", None) or getattr( + batch_generator, "_stream", None + ) + if stream is None: + stream = MLXBatchModel._new_mlx_thread_local_stream() + + with mx.stream(stream): while True: prompt_responses, generation_responses = batch_generator._next() if not generation_responses and prompt_responses: @@ -279,18 +297,57 @@ def _next_generated(batch_generator): return generation_responses @staticmethod - def _reset_mlx_lm_generation_stream(): + def _new_mlx_thread_local_stream(): + import mlx.core as mx + + device = mx.default_device() + if hasattr(mx, "new_thread_local_stream"): + return mx.new_thread_local_stream(device) + return mx.new_stream(device) + + @staticmethod + def _reset_generation_stream(module_name: str): try: import importlib - import mlx.core as mx + generate_module = importlib.import_module(module_name) + stream = MLXBatchModel._new_mlx_thread_local_stream() + setattr(generate_module, "generation_stream", stream) + return stream + except Exception: + logger.debug( + "Failed to reset %s generation stream", module_name, exc_info=True + ) + return None - mlx_lm_generate = importlib.import_module("mlx_lm.generate") + @staticmethod + def _reset_mlx_lm_generation_stream(): + return MLXBatchModel._reset_generation_stream("mlx_lm.generate") - device = mx.default_device() - mlx_lm_generate.generation_stream = mx.default_stream(device) - except Exception: - logger.debug("Failed to reset mlx-lm generation stream", exc_info=True) + @staticmethod + @contextlib.contextmanager + def _mlx_stream_context(stream): + import mlx.core as mx + + original_eval = mx.eval + original_async_eval = mx.async_eval + + def _eval_on_stream(*args, **kwargs): + with mx.stream(stream): + return original_eval(*args, **kwargs) + + def _async_eval_on_stream(*args, **kwargs): + with mx.stream(stream): + return original_async_eval(*args, **kwargs) + + mx.eval = _eval_on_stream + mx.async_eval = _async_eval_on_stream + try: + with mx.stream(stream): + yield + finally: + mx.eval = original_eval + mx.async_eval = original_async_eval async def _fail_active_requests(self, gen_dict, exc: Exception): async with self._get_lock(): @@ -1566,8 +1623,11 @@ def _generate_stream_inner(self, **kwargs): detokenizer.reset() tic = time.perf_counter() + stream = MLXBatchModel._reset_generation_stream("mlx_vlm.generate") + if stream is None: + stream = MLXBatchModel._new_mlx_thread_local_stream() try: - with mx.stream(mx.default_stream(mx.default_device())): + with MLXBatchModel._mlx_stream_context(stream): token = None logprobs = None for n, (token, logprobs) in enumerate( From e0ee88e2b5b1f1b1c61e4f739444d199b2894d5f Mon Sep 17 00:00:00 2001 From: qinxuye Date: Tue, 19 May 2026 11:27:33 +0800 Subject: [PATCH 12/13] test: cap mlx parallel generation --- xinference/model/llm/mlx/tests/test_mlx.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/xinference/model/llm/mlx/tests/test_mlx.py b/xinference/model/llm/mlx/tests/test_mlx.py index 19dee8d4ee..e980d406df 100644 --- a/xinference/model/llm/mlx/tests/test_mlx.py +++ b/xinference/model/llm/mlx/tests/test_mlx.py @@ -242,9 +242,13 @@ def test_mlx_parallel_inference(setup): model = client.get_model(model_uid) # Test parallel streaming and non-streaming requests - thread1 = InferenceThread("1+1等于几?", {"stream": True}, model) - thread2 = InferenceThread("中国的首都是哪里?", {"stream": False}, model) - thread3 = InferenceThread("介绍一下Python。", {"stream": True}, model) + thread1 = InferenceThread("1+1等于几?", {"stream": True, "max_tokens": 32}, model) + thread2 = InferenceThread( + "中国的首都是哪里?", {"stream": False, "max_tokens": 32}, model + ) + thread3 = InferenceThread( + "介绍一下Python。", {"stream": True, "max_tokens": 32}, model + ) # Start all threads thread1.start() From 9717c8b7cca30ab312afa8c9b789eabe43eff67d Mon Sep 17 00:00:00 2001 From: qinxuye Date: Tue, 19 May 2026 11:46:24 +0800 Subject: [PATCH 13/13] fix: avoid idle mlx batch generation --- xinference/model/llm/mlx/core.py | 34 ++++++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/xinference/model/llm/mlx/core.py b/xinference/model/llm/mlx/core.py index 896e7fc2ea..6b96b32ca9 100644 --- a/xinference/model/llm/mlx/core.py +++ b/xinference/model/llm/mlx/core.py @@ -213,7 +213,8 @@ def _get_or_create_generator(self, temperature: float, top_p: float): def _ensure_background_worker(self, gen_dict): """Ensure background worker is running for this generator.""" - if gen_dict["task"] is None: + task = gen_dict["task"] + if task is None or task.done(): loop = asyncio.get_event_loop() gen_dict["task"] = loop.create_task(self._background_worker(gen_dict)) @@ -226,6 +227,13 @@ async def _background_worker(self, gen_dict): while True: try: + async with self._get_lock(): + has_active_requests = bool(gen_dict["active"]) + + if not has_active_requests: + await asyncio.sleep(0.001) + continue + self._reset_mlx_lm_generation_stream() # Get next batch of results for ALL active requests # Use different API based on mlx-lm version @@ -268,6 +276,11 @@ async def _background_worker(self, gen_dict): f"Request {result.uid} finished, removed from active set" ) + has_active_requests = bool(gen_dict["active"]) + + if not has_active_requests: + self._synchronize_mlx() + # Small delay to prevent busy waiting await asyncio.sleep(0.0001) @@ -305,6 +318,19 @@ def _new_mlx_thread_local_stream(): return mx.new_thread_local_stream(device) return mx.new_stream(device) + @staticmethod + def _synchronize_mlx(): + try: + import mlx.core as mx + + synchronize = getattr(mx, "synchronize", None) + if synchronize is not None: + synchronize() + else: + mx.eval(mx.array([0])) + except Exception: + logger.debug("Failed to synchronize MLX", exc_info=True) + @staticmethod def _reset_generation_stream(module_name: str): try: @@ -372,9 +398,6 @@ async def generate_stream( gen_dict = self._get_or_create_generator(temperature, top_p) batch_generator = gen_dict["generator"] - # Ensure background worker is running for this generator - self._ensure_background_worker(gen_dict) - # Prepare request assert self._tokenizer_ref is not None prompt_tokens = self._tokenizer_ref.encode(prompt) @@ -419,6 +442,9 @@ async def generate_stream( ) del gen_dict["pending"][inserted_uid] + # Ensure background worker is running after the request is visible. + self._ensure_background_worker(gen_dict) + try: # Track generated text generated_tokens = []