diff --git a/app.py b/app.py index dba6fe3..9bf28c1 100644 --- a/app.py +++ b/app.py @@ -2,9 +2,11 @@ import re import sys import logging +import tempfile import numpy as np import torch import gradio as gr +import soundfile as sf from typing import Optional, Tuple from funasr import AutoModel from pathlib import Path @@ -235,6 +237,10 @@ def __init__(self, model_id: str = "openbmb/VoxCPM2") -> None: self.voxcpm_model: Optional[voxcpm.VoxCPM] = None self._model_id = model_id + # Cache for reference audio to avoid Gradio temp file cleanup issues + self._cached_audio_path: Optional[str] = None + self._cached_audio_data: Optional[Tuple[int, np.ndarray]] = None + def get_or_load_voxcpm(self) -> voxcpm.VoxCPM: if self.voxcpm_model is not None: return self.voxcpm_model @@ -243,6 +249,63 @@ def get_or_load_voxcpm(self) -> voxcpm.VoxCPM: logger.info("Model loaded successfully.") return self.voxcpm_model + def _get_stable_audio_path(self, audio_path: Optional[str]) -> Optional[str]: + """Get a stable audio path by caching audio data to avoid Gradio temp file cleanup issues. + + When Gradio provides a temp file path for uploaded/recorded audio, that file may be + cleaned up after the first generation attempt. This method reads the audio data and + saves it to a persistent temp file that survives multiple generation calls. + + Args: + audio_path: The audio file path provided by Gradio (may be a temp file) + + Returns: + A stable path to the cached audio file, or None if input is None + """ + if audio_path is None: + # Clear cache when no audio provided + self._cached_audio_path = None + self._cached_audio_data = None + return None + + # Check if we already have this audio cached + if self._cached_audio_path == audio_path and self._cached_audio_data is not None: + # Audio path matches cache, but verify the cached file still exists + cached_sr, cached_wav = self._cached_audio_data + cached_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") + try: + sf.write(cached_file.name, cached_wav, cached_sr) + logger.info(f"Using cached audio data, saved to: {cached_file.name}") + return cached_file.name + except Exception as e: + logger.warning(f"Failed to write cached audio: {e}") + cached_file.close() + os.unlink(cached_file.name) + + # Try to read the audio file + try: + audio_data, sr = sf.read(audio_path) + # Ensure mono + if len(audio_data.shape) > 1: + audio_data = audio_data.mean(axis=1) + audio_data = audio_data.astype(np.float32) + + # Cache the audio data + self._cached_audio_path = audio_path + self._cached_audio_data = (sr, audio_data) + + # Create a new stable temp file + cached_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") + sf.write(cached_file.name, audio_data, sr) + logger.info(f"Read audio from {audio_path}, cached to: {cached_file.name}") + return cached_file.name + except Exception as e: + logger.warning(f"Failed to read audio file {audio_path}: {e}") + # If the original file still exists, try using it directly + if os.path.exists(audio_path): + return audio_path + return None + def prompt_wav_recognition(self, prompt_wav: Optional[str]) -> str: if prompt_wav is None: return "" @@ -296,7 +359,8 @@ def generate_tts_audio( control = re.sub(r"[()()]", "", control).strip() final_text = f"({control}){text}" if control else text - audio_path = reference_wav_path_input if reference_wav_path_input else None + # Use cached audio path to avoid Gradio temp file cleanup issues + audio_path = self._get_stable_audio_path(reference_wav_path_input) prompt_text_clean = (prompt_text or "").strip() or None if audio_path and prompt_text_clean: @@ -368,7 +432,12 @@ def _run_asr_if_needed(checked, audio_path): return gr.update() try: logger.info("Running ASR on reference audio...") - asr_text = demo.prompt_wav_recognition(audio_path) + # Use cached audio path for ASR to avoid temp file cleanup issues + stable_audio_path = demo._get_stable_audio_path(audio_path) + if stable_audio_path is None: + logger.warning("Could not get stable audio path for ASR") + return gr.update(value="") + asr_text = demo.prompt_wav_recognition(stable_audio_path) logger.info(f"ASR result: {asr_text[:60]}...") return gr.update(value=asr_text) except Exception as e: diff --git a/src/voxcpm/model/voxcpm.py b/src/voxcpm/model/voxcpm.py index 445618b..592db3e 100644 --- a/src/voxcpm/model/voxcpm.py +++ b/src/voxcpm/model/voxcpm.py @@ -237,10 +237,12 @@ def optimize(self, disable: bool = False): self.residual_lm.forward_step = torch.compile( self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True ) + # Use default mode for feat_encoder to avoid CUDA graphs issues with dynamic shapes + # reduce-overhead mode requires fixed input shapes (CUDA graphs) self._feat_encoder_raw = self.feat_encoder - self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True) + self.feat_encoder = torch.compile(self.feat_encoder, mode="default") self.feat_decoder.estimator = torch.compile( - self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True + self.feat_decoder.estimator, mode="default" ) except Exception as e: print(f"Warning: torch.compile disabled - {e}", file=sys.stderr) diff --git a/src/voxcpm/model/voxcpm2.py b/src/voxcpm/model/voxcpm2.py index 90495f1..95683b2 100644 --- a/src/voxcpm/model/voxcpm2.py +++ b/src/voxcpm/model/voxcpm2.py @@ -285,10 +285,12 @@ def optimize(self, disable: bool = False): self.residual_lm.forward_step = torch.compile( self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True ) + # Use default mode for feat_encoder to avoid CUDA graphs issues with dynamic shapes + # reduce-overhead mode requires fixed input shapes (CUDA graphs) self._feat_encoder_raw = self.feat_encoder - self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True) + self.feat_encoder = torch.compile(self.feat_encoder, mode="default") self.feat_decoder.estimator = torch.compile( - self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True + self.feat_decoder.estimator, mode="default" ) except Exception as e: print(f"Warning: torch.compile disabled - {e}", file=sys.stderr)