Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
launch.json
.venv/
__pycache__
voxcpm.egg-info
.DS_Store
./pretrained_models/
app_local.py
app_local.py
55 changes: 41 additions & 14 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import sys
import logging
import numpy as np
import torch
import gradio as gr
from typing import Optional, Tuple
from funasr import AutoModel
Expand All @@ -12,6 +11,7 @@
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import voxcpm
from voxcpm.model.utils import resolve_runtime_device

logging.basicConfig(
level=logging.INFO,
Expand Down Expand Up @@ -220,17 +220,14 @@
# ---------- Model ----------

class VoxCPMDemo:
def __init__(self, model_id: str = "openbmb/VoxCPM2") -> None:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Running on device: {self.device}")
def __init__(self, model_id: str = "openbmb/VoxCPM2", device: str = "auto") -> None:
self.device = resolve_runtime_device(device, "cuda")
logger.info(f"Running VoxCPM on device: {self.device}")
self.optimize = self.device.startswith("cuda")

self.asr_model_id = "iic/SenseVoiceSmall"
self.asr_model: Optional[AutoModel] = AutoModel(
model=self.asr_model_id,
disable_update=True,
log_level="DEBUG",
device="cuda:0" if self.device == "cuda" else "cpu",
)
self.asr_device = "cuda:0" if self.device.startswith("cuda") else "cpu"
self.asr_model: Optional[AutoModel] = None

self.voxcpm_model: Optional[voxcpm.VoxCPM] = None
self._model_id = model_id
Expand All @@ -239,14 +236,37 @@ def get_or_load_voxcpm(self) -> voxcpm.VoxCPM:
if self.voxcpm_model is not None:
return self.voxcpm_model
logger.info(f"Loading model: {self._model_id}")
self.voxcpm_model = voxcpm.VoxCPM.from_pretrained(self._model_id, optimize=True)
self.voxcpm_model = voxcpm.VoxCPM.from_pretrained(
self._model_id,
optimize=self.optimize,
device=self.device,
)
logger.info("Model loaded successfully.")
return self.voxcpm_model

def get_or_load_asr_model(self) -> AutoModel:
if self.asr_model is not None:
return self.asr_model
logger.info(
f"Loading ASR model: {self.asr_model_id} on device: {self.asr_device}"
)
self.asr_model = AutoModel(
model=self.asr_model_id,
disable_update=True,
log_level="DEBUG",
device=self.asr_device,
)
logger.info("ASR model loaded successfully.")
return self.asr_model

def prompt_wav_recognition(self, prompt_wav: Optional[str]) -> str:
if prompt_wav is None:
return ""
res = self.asr_model.generate(input=prompt_wav, language="auto", use_itn=True)
res = self.get_or_load_asr_model().generate(
input=prompt_wav,
language="auto",
use_itn=True,
)
return res[0]["text"].split("|>")[-1]

def _build_generate_kwargs(
Expand Down Expand Up @@ -487,8 +507,9 @@ def run_demo(
server_port: int = 8808,
show_error: bool = True,
model_id: str = "openbmb/VoxCPM2",
device: str = "auto",
):
demo = VoxCPMDemo(model_id=model_id)
demo = VoxCPMDemo(model_id=model_id, device=device)
interface = create_demo_interface(demo)
interface.queue(max_size=10, default_concurrency_limit=1).launch(
server_name=server_name,
Expand All @@ -508,5 +529,11 @@ def run_demo(
help="Local path or HuggingFace repo ID (default: openbmb/VoxCPM2)",
)
parser.add_argument("--port", type=int, default=8808, help="Server port")
parser.add_argument(
"--device",
type=str,
default="auto",
help="Runtime device: auto, cpu, mps, cuda, or cuda:N (default: auto)",
)
args = parser.parse_args()
run_demo(model_id=args.model_id, server_port=args.port)
run_demo(model_id=args.model_id, server_port=args.port, device=args.device)
10 changes: 4 additions & 6 deletions src/voxcpm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
import sys
from pathlib import Path

import soundfile as sf

from voxcpm.core import VoxCPM

DEFAULT_HF_MODEL_ID = "openbmb/VoxCPM2"

# -----------------------------
Expand Down Expand Up @@ -169,8 +173,6 @@ def validate_batch_args(args, parser):


def load_model(args):
from voxcpm.core import VoxCPM

print("Loading VoxCPM model...", file=sys.stderr)

zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get(
Expand Down Expand Up @@ -263,8 +265,6 @@ def _run_single(args, parser, *, text: str, output: str, prompt_text: str | None
and (args.prompt_audio is not None or args.reference_audio is not None),
)

import soundfile as sf

sf.write(str(output_path), audio_array, model.tts_model.sample_rate)

duration = len(audio_array) / model.tts_model.sample_rate
Expand Down Expand Up @@ -306,8 +306,6 @@ def cmd_validate(args, parser):


def cmd_batch(args, parser):
import soundfile as sf

input_file = require_file_exists(args.input, parser, "input file")
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
Expand Down