@log_async(logger=logger)
async def ensure_worker(
    self, worker_address: str
) -> xo.ActorRefType["WorkerActor"]:
    """
    Make sure ``worker_address`` is present in the worker registry.

    A fresh actor reference is always resolved and stored, whether or not
    the address was registered before; only the debug message differs.

    Parameters
    ----------
    worker_address: str
        Address of the worker whose ``WorkerActor`` should be referenced.

    Returns
    -------
    The newly resolved worker actor reference.
    """
    from .worker import WorkerActor

    fresh_ref = await xo.actor_ref(
        address=worker_address, uid=WorkerActor.default_uid()
    )
    registry = self._worker_address_to_worker
    # Remember the prior state before overwriting, it only drives the log text.
    was_registered = worker_address in registry
    registry[worker_address] = fresh_ref
    if was_registered:
        logger.debug("Worker %s already registered, refreshed ref", worker_address)
    else:
        logger.debug("Worker %s has been added successfully", worker_address)
    return fresh_ref


@log_async(logger=logger)
async def restore_worker_models(
    self, worker_address: str, models: Dict[str, Dict[str, Any]]
):
    """
    Rebuild the supervisor bookkeeping for models a worker already runs.

    For every replica model uid in ``models`` this updates
    ``self._model_uid_to_replica_info`` (creating or growing the entry and
    resetting its round-robin scheduler) and records the worker in
    ``self._replica_model_uid_to_worker``. A worker is never recorded twice
    for the same replica; addresses are compared to detect duplicates.

    Parameters
    ----------
    worker_address: str
        Address of the worker that reported the models.
    models: Dict[str, Dict[str, Any]]
        Mapping keyed by replica model uid. Only the keys are read here.
        Nothing is done when the mapping is empty.
    """
    if not models:
        return

    worker_ref = await self.ensure_worker(worker_address)
    uid_to_worker = self._replica_model_uid_to_worker
    restored = 0
    for replica_model_uid in models:
        model_uid, rep_id = parse_replica_model_uid(replica_model_uid)
        # Negative replica ids are treated as replica 0.
        rep_id = max(rep_id, 0)
        needed = rep_id + 1

        info = self._model_uid_to_replica_info.get(model_uid)
        if info is None:
            info = ReplicaInfo(
                replica=needed,
                scheduler=itertools.cycle(range(needed)),
            )
            self._model_uid_to_replica_info[model_uid] = info
        elif info.replica < needed:
            # A higher replica id than known so far: widen the schedule.
            info.replica = needed
            info.scheduler = itertools.cycle(range(needed))

        # NOTE(review): indexing assumes replica_to_worker_refs yields an
        # (auto-created) list for unseen replica ids - confirm on ReplicaInfo.
        holders = info.replica_to_worker_refs[rep_id]
        if not any(ref.address == worker_ref.address for ref in holders):
            holders.append(worker_ref)

        # The stored value may be absent, a single ref, a tuple or a list.
        known = uid_to_worker.get(replica_model_uid)
        if known is None:
            uid_to_worker[replica_model_uid] = worker_ref
        elif isinstance(known, tuple):
            if not any(ref.address == worker_ref.address for ref in known):
                # Tuples are immutable, so replace with an extended list.
                uid_to_worker[replica_model_uid] = [*known, worker_ref]
        elif isinstance(known, list):
            if not any(ref.address == worker_ref.address for ref in known):
                known.append(worker_ref)
        elif known.address != worker_ref.address:
            uid_to_worker[replica_model_uid] = [known, worker_ref]
        restored += 1

    logger.info(
        "Restored %s model replicas for worker %s", restored, worker_address
    )
def _clear_supervisor_refs(self):
    """
    Forget every cached supervisor-side actor reference.

    Deliberately synchronous and lock-free so it can be used by code that
    already holds ``self._lock``.
    """
    self._supervisor_ref = None
    self._status_guard_ref = None
    self._event_collector_ref = None
    self._cache_tracker_ref = None
    self._progress_tracker_ref = None


async def _reset_supervisor_refs(self):
    """
    Drop the cached supervisor references while holding ``self._lock``.

    After this returns, the next ``get_supervisor_ref`` call reconnects
    from scratch.
    """
    async with self._lock:
        self._clear_supervisor_refs()


async def get_supervisor_ref(self, add_worker: bool = True) -> xo.ActorRefType:
    """
    Return the supervisor actor reference, connecting on first use.

    The whole connection sequence runs under ``self._lock``. On the first
    call (or after the refs were reset) it resolves the supervisor,
    optionally registers this worker, resolves the status guard, event
    collector and cache tracker actors, and records model versions with
    the cache tracker.

    If any step after caching the supervisor reference fails, all cached
    references are cleared again before the error propagates, so a later
    call retries the full sequence instead of returning a half-initialised
    connection.

    Parameters
    ----------
    add_worker: bool
        When true, call ``ensure_worker`` on the supervisor and, if this
        worker already holds models, ask the supervisor to restore them.

    Returns
    -------
    The supervisor actor reference.
    """
    from .supervisor import SupervisorActor

    # NOTE(review): the lock is held across remote calls; any code path that
    # re-enters this method (or _reset_supervisor_refs) while holding
    # self._lock would block forever because asyncio.Lock is not reentrant.
    async with self._lock:
        if self._supervisor_ref is not None:
            return self._supervisor_ref
        supervisor_ref = await xo.actor_ref(  # type: ignore
            address=self._supervisor_address, uid=SupervisorActor.default_uid()
        )
        # Prevent concurrent operations leads to double initialization, check again.
        if self._supervisor_ref is not None:
            return self._supervisor_ref
        self._supervisor_ref = supervisor_ref

        connected = False
        try:
            if add_worker:
                await self._supervisor_ref.ensure_worker(self.address)
                if len(self._model_uid_to_model) == 0:
                    logger.info("Connected to supervisor as a fresh worker")
                else:
                    # Best effort: a failed restore is logged, not fatal.
                    try:
                        models = await self.list_models()
                        await self._supervisor_ref.restore_worker_models(
                            self.address, models
                        )
                    except Exception:
                        logger.exception(
                            "Failed to restore worker models to supervisor"
                        )

            self._status_guard_ref = await xo.actor_ref(
                address=self._supervisor_address,
                uid=StatusGuardActor.default_uid(),
            )
            self._event_collector_ref = await xo.actor_ref(
                address=self._supervisor_address,
                uid=EventCollectorActor.default_uid(),
            )
            self._cache_tracker_ref = await xo.actor_ref(
                address=self._supervisor_address,
                uid=CacheTrackerActor.default_uid(),
            )
            self._progress_tracker_ref = None
            # cache_tracker is on supervisor
            from ..model.audio import get_audio_model_descriptions
            from ..model.embedding import get_embedding_model_descriptions
            from ..model.flexible import get_flexible_model_descriptions
            from ..model.image import get_image_model_descriptions
            from ..model.llm import get_llm_version_infos
            from ..model.rerank import get_rerank_model_descriptions
            from ..model.video import get_video_model_descriptions

            # record model version
            model_version_infos: Dict[str, List[Dict]] = {}  # type: ignore
            model_version_infos.update(get_llm_version_infos())
            model_version_infos.update(get_embedding_model_descriptions())
            model_version_infos.update(get_rerank_model_descriptions())
            model_version_infos.update(get_image_model_descriptions())
            model_version_infos.update(get_audio_model_descriptions())
            model_version_infos.update(get_video_model_descriptions())
            model_version_infos.update(get_flexible_model_descriptions())
            await self._cache_tracker_ref.record_model_version(
                model_version_infos, self.address
            )
            connected = True
        finally:
            if not connected:
                # Roll back so the early-return above cannot hand out a
                # partially initialised connection on the next call. The
                # lock-free helper is used because self._lock is held here.
                self._clear_supervisor_refs()
        return self._supervisor_ref
await supervisor_ref.report_worker_status(self.address, status) + try: + await asyncio.wait_for( + supervisor_ref.report_worker_status(self.address, status), + timeout=XINFERENCE_HEALTH_CHECK_TIMEOUT, + ) + except asyncio.TimeoutError: + logger.warning( + "report_worker_status timed out, will reset supervisor refs for retry" + ) + await self._reset_supervisor_refs() + raise + except Exception: + await self._reset_supervisor_refs() + raise async def _periodical_report_status(self): while True: