diff --git a/.gitignore b/.gitignore index d397292..f7fa981 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ launch.json +.venv/ __pycache__ voxcpm.egg-info .DS_Store ./pretrained_models/ -app_local.py \ No newline at end of file +app_local.py diff --git a/app.py b/app.py index dba6fe3..95eac94 100644 --- a/app.py +++ b/app.py @@ -3,7 +3,6 @@ import sys import logging import numpy as np -import torch import gradio as gr from typing import Optional, Tuple from funasr import AutoModel @@ -12,6 +11,7 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false" import voxcpm +from voxcpm.model.utils import resolve_runtime_device logging.basicConfig( level=logging.INFO, @@ -220,17 +220,14 @@ # ---------- Model ---------- class VoxCPMDemo: - def __init__(self, model_id: str = "openbmb/VoxCPM2") -> None: - self.device = "cuda" if torch.cuda.is_available() else "cpu" - logger.info(f"Running on device: {self.device}") + def __init__(self, model_id: str = "openbmb/VoxCPM2", device: str = "auto") -> None: + self.device = resolve_runtime_device(device, "cuda") + logger.info(f"Running VoxCPM on device: {self.device}") + self.optimize = self.device.startswith("cuda") self.asr_model_id = "iic/SenseVoiceSmall" - self.asr_model: Optional[AutoModel] = AutoModel( - model=self.asr_model_id, - disable_update=True, - log_level="DEBUG", - device="cuda:0" if self.device == "cuda" else "cpu", - ) + self.asr_device = "cuda:0" if self.device.startswith("cuda") else "cpu" + self.asr_model: Optional[AutoModel] = None self.voxcpm_model: Optional[voxcpm.VoxCPM] = None self._model_id = model_id @@ -239,14 +236,37 @@ def get_or_load_voxcpm(self) -> voxcpm.VoxCPM: if self.voxcpm_model is not None: return self.voxcpm_model logger.info(f"Loading model: {self._model_id}") - self.voxcpm_model = voxcpm.VoxCPM.from_pretrained(self._model_id, optimize=True) + self.voxcpm_model = voxcpm.VoxCPM.from_pretrained( + self._model_id, + optimize=self.optimize, + device=self.device, + ) logger.info("Model loaded successfully.") return self.voxcpm_model + def get_or_load_asr_model(self) -> AutoModel: + if self.asr_model is not None: + return self.asr_model + logger.info( + f"Loading ASR model: {self.asr_model_id} on device: {self.asr_device}" + ) + self.asr_model = AutoModel( + model=self.asr_model_id, + disable_update=True, + log_level="DEBUG", + device=self.asr_device, + ) + logger.info("ASR model loaded successfully.") + return self.asr_model + def prompt_wav_recognition(self, prompt_wav: Optional[str]) -> str: if prompt_wav is None: return "" - res = self.asr_model.generate(input=prompt_wav, language="auto", use_itn=True) + res = self.get_or_load_asr_model().generate( + input=prompt_wav, + language="auto", + use_itn=True, + ) return res[0]["text"].split("|>")[-1] def _build_generate_kwargs( @@ -487,8 +507,9 @@ def run_demo( server_port: int = 8808, show_error: bool = True, model_id: str = "openbmb/VoxCPM2", + device: str = "auto", ): - demo = VoxCPMDemo(model_id=model_id) + demo = VoxCPMDemo(model_id=model_id, device=device) interface = create_demo_interface(demo) interface.queue(max_size=10, default_concurrency_limit=1).launch( server_name=server_name, @@ -508,5 +529,11 @@ def run_demo( help="Local path or HuggingFace repo ID (default: openbmb/VoxCPM2)", ) parser.add_argument("--port", type=int, default=8808, help="Server port") + parser.add_argument( + "--device", + type=str, + default="auto", + help="Runtime device: auto, cpu, mps, cuda, or cuda:N (default: auto)", + ) args = parser.parse_args() - run_demo(model_id=args.model_id, server_port=args.port) + run_demo(model_id=args.model_id, server_port=args.port, device=args.device) diff --git a/src/voxcpm/cli.py b/src/voxcpm/cli.py index f6d40d6..2ccf5bb 100644 --- a/src/voxcpm/cli.py +++ b/src/voxcpm/cli.py @@ -11,6 +11,10 @@ import sys from pathlib import Path +import soundfile as sf + +from voxcpm.core import VoxCPM + DEFAULT_HF_MODEL_ID = "openbmb/VoxCPM2" # ----------------------------- @@ -169,8 +173,6 @@ def validate_batch_args(args, parser): def load_model(args): - from voxcpm.core import VoxCPM - print("Loading VoxCPM model...", file=sys.stderr) zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get( @@ -263,8 +265,6 @@ def _run_single(args, parser, *, text: str, output: str, prompt_text: str | None and (args.prompt_audio is not None or args.reference_audio is not None), ) - import soundfile as sf - sf.write(str(output_path), audio_array, model.tts_model.sample_rate) duration = len(audio_array) / model.tts_model.sample_rate @@ -306,8 +306,6 @@ def cmd_validate(args, parser): def cmd_batch(args, parser): - import soundfile as sf - input_file = require_file_exists(args.input, parser, "input file") output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True)