diff --git a/docs/parakeet-issues.md b/docs/parakeet-issues.md deleted file mode 100644 index 5e2570034..000000000 --- a/docs/parakeet-issues.md +++ /dev/null @@ -1,62 +0,0 @@ -# Parakeet STT Issues (2026-04-07) - -## Issue 1: macOS CoreML crash - -**Symptom**: Switching execution provider to CoreML gives `ONNXRuntimeError: model_path must not be empty` from `onnxruntime/core/optimizer/initializer.cc:45`. - -**Root cause**: The `NemoConformerAED` model class explicitly excludes CoreML in `onnx_asr/models/nemo.py:183`: -```python -def _get_excluded_providers() -> list[str]: - return [*TensorRtOptions.get_provider_names(), "CoreMLExecutionProvider"] -``` -But `NemoConformerTdt` (the model Parakeet uses — `nemo-parakeet-tdt-0.6b-v2/v3`) inherits from `NemoConformerRnnt` which has **no CoreML exclusion**. So CoreML is passed to the ONNX session, but the TDT model uses external data files (`onnx?data` pattern in `loader.py:134`) that CoreML can't handle. - -**Fix options**: -1. Exclude CoreML from TDT models (like AED does) — simplest, but removes the option -2. Catch the error in `parakeet.py:__load_model_inner()` and fall back to CPU with a toast message -3. Hide the CoreML option in the Client UI when the model variant doesn't support it (requires Core→Client communication of supported providers) - -## Issue 2: Windows CUDA doesn't reinitialize properly - -**Symptom**: Switching execution provider to CUDA downloads the model but STT doesn't work until full app restart. - -**Root cause (suspected)**: `del self.model` in `parakeet.py:101` drops the Python reference to the old ONNX session, but CUDA GPU memory isn't released immediately (Python GC is non-deterministic). When the new CUDA session tries to allocate GPU memory, the old one may still be holding it. - -**Fix options**: -1. Force `gc.collect()` after unloading the model before loading the new one -2. Add explicit ONNX session cleanup (call `self.model` internals to release sessions) -3. Add a brief delay between unload and reload for CUDA specifically - -## Broader architecture issues - -- `__load_model_inner()` has a generic `except Exception` that toasts an error but leaves `self.model = None`. No retry, no CPU fallback. -- No way for the user to recover without restarting the app if reload fails. -- The `_loading` flag prevents transcription during reload, but there's no timeout or progress feedback. -- The Preprocessor (`onnx_asr`) explicitly excludes CUDA — it always runs on CPU regardless of the selected provider. This is by design but may confuse users expecting full CUDA acceleration. - -## Key files - -### Core (wingman-ai) -- `providers/parakeet.py` — Parakeet provider, model loading, settings update -- `services/settings_service.py:144-148` — calls `parakeet.update_settings_async()` -- `api/interface.py:183-192` — `ParakeetSettings` dataclass - -### onnx_asr library (3rd party, in venv) -- `onnx_asr/loader.py:187-357` — `load_model()`, downloads files, creates ONNX sessions -- `onnx_asr/models/nemo.py:70-85` — `NemoConformerRnnt.__init__()` creates encoder/decoder sessions -- `onnx_asr/models/nemo.py:181-183` — `NemoConformerAED._get_excluded_providers()` excludes CoreML -- `onnx_asr/preprocessors/preprocessor.py` — Preprocessor, excludes CUDA -- `onnx_asr/onnx.py:81-101` — `update_onnx_providers()` filters provider list - -### Client (wingman-client) -- The execution provider dropdown is in the STT settings UI — may need to filter options per model/platform - -## Provider/model compatibility matrix - -| Provider | NemoConformerTdt (Parakeet) | NemoConformerAED | Preprocessor | -|----------|---------------------------|------------------|-------------| -| CPU | Yes | Yes | Yes | -| DirectML | Yes | Yes | Yes | -| CUDA | Yes | Yes | **Excluded** | -| CoreML | **Crashes** (should exclude) | **Excluded** | ? | -| TensorRT | Excluded | Excluded | Excluded | diff --git a/providers/edge.py b/providers/edge.py index 6b9966567..b20a8922a 100644 --- a/providers/edge.py +++ b/providers/edge.py @@ -1,10 +1,16 @@ from os import path +from typing import TYPE_CHECKING from edge_tts import Communicate +from api.enums import TtsProvider from api.interface import EdgeTtsConfig, SoundConfig +from providers.interfaces import TtsInterface, tts_provider from services.audio_player import AudioPlayer from services.file import get_writable_dir from services.printr import Printr +if TYPE_CHECKING: + from api.interface import WingmanConfig + RECORDING_PATH = "audio_output" OUTPUT_FILE: str = "edge_tts.mp3" @@ -48,3 +54,19 @@ async def __generate_speech( await communicate.save(file_path) return communicate, file_path + + +@tts_provider(TtsProvider.EDGE_TTS) +class EdgeTts(TtsInterface): + def __init__(self, config: "WingmanConfig"): + self._edge = Edge() + self._config = config + + async def play_audio(self, text, sound_config, audio_player, wingman_name): + await self._edge.play_audio( + text=text, + config=self._config.edge_tts, + sound_config=sound_config, + audio_player=audio_player, + wingman_name=wingman_name, + ) diff --git a/providers/elevenlabs.py b/providers/elevenlabs.py index ec045272a..bd73497a8 100644 --- a/providers/elevenlabs.py +++ b/providers/elevenlabs.py @@ -1,17 +1,21 @@ import asyncio -from typing import Callable, Optional +from typing import TYPE_CHECKING, Callable, Optional import requests from threading import Event, Thread import numpy as np import sounddevice as sd from elevenlabslib import User, GenerationOptions, PlaybackOptions, SFXOptions -from api.enums import LogType, SoundEffect, WingmanInitializationErrorType +from api.enums import LogType, SoundEffect, TtsProvider, WingmanInitializationErrorType from api.interface import ElevenlabsConfig, SoundConfig, WingmanInitializationError +from providers.interfaces import TtsInterface, tts_provider from services.audio_player import AudioPlayer from services.printr import Printr from services.sound_effects import get_sound_effects from services.websocket_user import WebSocketUser +if TYPE_CHECKING: + from api.interface import WingmanConfig + class ElevenLabs: def __init__(self, api_key: str, wingman_name: str): @@ -401,3 +405,20 @@ def get_available_models(self): def get_subscription_data(self): return self.user.get_subscription_data() + + +@tts_provider(TtsProvider.ELEVENLABS) +class ElevenLabsTts(TtsInterface): + def __init__(self, elevenlabs_instance: "ElevenLabs", config: "WingmanConfig"): + self._elevenlabs = elevenlabs_instance + self._config = config + + async def play_audio(self, text, sound_config, audio_player, wingman_name): + await self._elevenlabs.play_audio( + text=text, + config=self._config.elevenlabs, + sound_config=sound_config, + audio_player=audio_player, + wingman_name=wingman_name, + stream=self._config.elevenlabs.output_streaming, + ) diff --git a/providers/faster_whisper.py b/providers/faster_whisper.py index 842ce1aa2..34f619a30 100644 --- a/providers/faster_whisper.py +++ b/providers/faster_whisper.py @@ -1,16 +1,20 @@ from os import path import gc -from typing import Optional +from typing import TYPE_CHECKING, Optional from faster_whisper import WhisperModel -from api.enums import LogType +from api.enums import LogType, SttProvider from api.interface import ( FasterWhisperSettings, FasterWhisperTranscript, FasterWhisperSttConfig, WingmanInitializationError, ) +from providers.interfaces import SttInterface, Transcript, stt_provider from services.printr import Printr +if TYPE_CHECKING: + from api.interface import WingmanConfig + class FasterWhisper: def __init__(self, settings: FasterWhisperSettings): @@ -116,3 +120,34 @@ def transcribe( def validate(self, errors: list[WingmanInitializationError]): pass + + +@stt_provider(SttProvider.FASTER_WHISPER) +class FasterWhisperStt(SttInterface): + """Per-wingman adapter around the shared FasterWhisper singleton.""" + + def __init__(self, shared: "FasterWhisper", config: "WingmanConfig", wingman_name: str): + self._shared = shared + self._config = config + self._wingman_name = wingman_name + + async def transcribe(self, filename: str) -> Transcript | None: + hotwords: list[str] = [self._wingman_name] + default_hotwords = self._config.fasterwhisper.hotwords + if default_hotwords: + hotwords.extend(default_hotwords) + wingman_hotwords = self._config.fasterwhisper.additional_hotwords + if wingman_hotwords: + hotwords.extend(wingman_hotwords) + + result = self._shared.transcribe( + config=self._config.fasterwhisper, + filename=filename, + hotwords=list(set(hotwords)), + ) + if result is None: + return None + return Transcript( + text=result.text, + language=getattr(result, "language", None), + ) diff --git a/providers/google.py b/providers/google.py index 4bddeb65c..feadd2c53 100644 --- a/providers/google.py +++ b/providers/google.py @@ -1,9 +1,15 @@ import re +from typing import TYPE_CHECKING from google import genai from google.genai import types from openai import APIStatusError, OpenAI +from api.enums import ConversationProvider +from providers.interfaces import LlmInterface, llm_provider from services.printr import Printr +if TYPE_CHECKING: + from api.interface import WingmanConfig + printr = Printr() @@ -139,3 +145,16 @@ def get_available_models(self): if action == "generateContent": models.append(model) return models + + +@llm_provider(ConversationProvider.GOOGLE) +class GoogleLlm(LlmInterface): + def __init__(self, google_instance: "GoogleGenAI", config: "WingmanConfig"): + self._google = google_instance + self._config = config + + async def ask(self, messages, tools=None): + return self._google.ask( + messages=messages, tools=tools, + model=self._config.google.conversation_model, + ) diff --git a/providers/hume.py b/providers/hume.py index 902493e5e..eacc597e2 100644 --- a/providers/hume.py +++ b/providers/hume.py @@ -1,5 +1,6 @@ import base64 from os import path +from typing import TYPE_CHECKING import aiofiles from hume.client import AsyncHumeClient from hume.tts import ( @@ -7,17 +8,22 @@ PostedUtteranceVoiceWithId, PostedContextWithGenerationId, ) +from api.enums import TtsProvider from api.interface import ( HumeConfig, SoundConfig, VoiceInfo, WingmanInitializationError, ) +from providers.interfaces import TtsInterface, tts_provider from services.audio_player import AudioPlayer from services.file import get_writable_dir from services.printr import Printr from services.secret_keeper import SecretKeeper +if TYPE_CHECKING: + from api.interface import WingmanConfig + RECORDING_PATH = "audio_output" OUTPUT_FILE: str = "hume.mp3" @@ -106,3 +112,34 @@ async def __write_result_to_file(self, base64_encoded_audio: str): async with aiofiles.open(file_path, "wb") as f: await f.write(audio_data) return file_path + + +@tts_provider(TtsProvider.HUME) +class HumeTts(TtsInterface): + def __init__(self, hume_instance: "Hume", config: "WingmanConfig"): + self._hume = hume_instance + self._config = config + + async def play_audio(self, text, sound_config, audio_player, wingman_name): + try: + await self._hume.play_audio( + text=text, + config=self._config.hume, + sound_config=sound_config, + audio_player=audio_player, + wingman_name=wingman_name, + ) + except RuntimeError as e: + if "Event loop is closed" in str(e): + import asyncio + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + await self._hume.play_audio( + text=text, + config=self._config.hume, + sound_config=sound_config, + audio_player=audio_player, + wingman_name=wingman_name, + ) + else: + raise diff --git a/providers/interfaces.py b/providers/interfaces.py new file mode 100644 index 000000000..d1c1bf331 --- /dev/null +++ b/providers/interfaces.py @@ -0,0 +1,169 @@ +"""Unified provider interfaces for STT, TTS, and LLM providers. + +ABCs define required contracts — providers that don't implement them crash at instantiation. +Protocols define optional capabilities — check with isinstance() before calling. +Registration decorators map config enum values to provider classes for ProviderFactory lookup. +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING, Protocol, runtime_checkable + +from api.enums import ConversationProvider, SttProvider, TtsProvider + +if TYPE_CHECKING: + from openai.types.chat import ChatCompletion + + from api.interface import SoundConfig, WingmanInitializationError + from services.audio_player import AudioPlayer + + +# --------------------------------------------------------------------------- +# Unified return types +# --------------------------------------------------------------------------- + +@dataclass +class Transcript: + """Unified STT result. Every provider wraps its native result into this.""" + + text: str + language: str | None = None + confidence: float | None = None + + +# --------------------------------------------------------------------------- +# Core ABCs (required contracts) +# --------------------------------------------------------------------------- + +class SttInterface(ABC): + """Speech-to-text provider interface.""" + + @abstractmethod + async def transcribe(self, filename: str) -> Transcript | None: + """Transcribe an audio file to text.""" + ... + + +class TtsInterface(ABC): + """Text-to-speech provider interface.""" + + @abstractmethod + async def play_audio( + self, + text: str, + sound_config: "SoundConfig", + audio_player: "AudioPlayer", + wingman_name: str, + ) -> None: + """Synthesize speech and play it.""" + ... + + +class LlmInterface(ABC): + """Large language model provider interface.""" + + @abstractmethod + async def ask( + self, + messages: list[dict], + tools: list[dict] | None = None, + ) -> "ChatCompletion | None": + """Send messages to the LLM and get a completion.""" + ... + + +# --------------------------------------------------------------------------- +# Optional Protocols (capability checks) +# --------------------------------------------------------------------------- + +@runtime_checkable +class HasAvailableVoices(Protocol): + """Provider can enumerate available voices.""" + + async def get_available_voices(self, **kwargs) -> list: ... + + +@runtime_checkable +class HasAvailableModels(Protocol): + """Provider can enumerate available models.""" + + async def get_available_models(self) -> list: ... + + +@runtime_checkable +class HasLifecycle(Protocol): + """Provider has load/unload lifecycle for local models.""" + + async def load(self) -> None: ... + async def unload(self) -> None: ... + + +@runtime_checkable +class Validatable(Protocol): + """Provider can validate its configuration.""" + + async def validate(self, errors: "list[WingmanInitializationError]") -> None: ... + + +@runtime_checkable +class HasMinimalReasoning(Protocol): + """LLM provider supports reasoning effort tuning (e.g., O-series, Gemini).""" + + def get_minimal_reasoning_by_model(self, model_name: str) -> dict: ... + + +# --------------------------------------------------------------------------- +# Registration decorators + registries +# --------------------------------------------------------------------------- + +_STT_REGISTRY: dict[SttProvider, type[SttInterface]] = {} +_TTS_REGISTRY: dict[TtsProvider, type[TtsInterface]] = {} +_LLM_REGISTRY: dict[ConversationProvider, type[LlmInterface]] = {} + + +def stt_provider(*provider_enums: SttProvider): + """Register a class as the STT provider for given enum value(s).""" + + def decorator(cls): + for enum_val in provider_enums: + _STT_REGISTRY[enum_val] = cls + return cls + + return decorator + + +def tts_provider(*provider_enums: TtsProvider): + """Register a class as the TTS provider for given enum value(s).""" + + def decorator(cls): + for enum_val in provider_enums: + _TTS_REGISTRY[enum_val] = cls + return cls + + return decorator + + +def llm_provider(*provider_enums: ConversationProvider): + """Register a class as the LLM provider for given enum value(s).""" + + def decorator(cls): + for enum_val in provider_enums: + _LLM_REGISTRY[enum_val] = cls + return cls + + return decorator + + +def get_stt_class(provider: SttProvider) -> type[SttInterface] | None: + """Look up the registered STT provider class for an enum value.""" + return _STT_REGISTRY.get(provider) + + +def get_tts_class(provider: TtsProvider) -> type[TtsInterface] | None: + """Look up the registered TTS provider class for an enum value.""" + return _TTS_REGISTRY.get(provider) + + +def get_llm_class(provider: ConversationProvider) -> type[LlmInterface] | None: + """Look up the registered LLM provider class for an enum value.""" + return _LLM_REGISTRY.get(provider) diff --git a/providers/inworld.py b/providers/inworld.py index 4a641c7d7..dbe32768a 100644 --- a/providers/inworld.py +++ b/providers/inworld.py @@ -1,24 +1,28 @@ import base64 import json from os import path -from typing import Optional +from typing import TYPE_CHECKING, Optional import threading import queue import time import requests import aiofiles -from api.enums import LogType +from api.enums import LogType, TtsProvider from api.interface import ( SoundConfig, VoiceInfo, WingmanInitializationError, InworldConfig, ) +from providers.interfaces import TtsInterface, tts_provider from services.audio_player import AudioPlayer from services.file import get_writable_dir from services.printr import Printr from services.secret_keeper import SecretKeeper +if TYPE_CHECKING: + from api.interface import WingmanConfig + RECORDING_PATH = "audio_output" OUTPUT_FILE: str = "inworld.mp3" @@ -278,3 +282,19 @@ async def __write_result_to_file(self, base64_encoded_audio: str): async with aiofiles.open(file_path, "wb") as f: await f.write(audio_data) return file_path + + +@tts_provider(TtsProvider.INWORLD) +class InworldTts(TtsInterface): + def __init__(self, inworld_instance: "Inworld", config: "WingmanConfig"): + self._inworld = inworld_instance + self._config = config + + async def play_audio(self, text, sound_config, audio_player, wingman_name): + await self._inworld.play_audio( + text=text, + config=self._config.inworld, + sound_config=sound_config, + audio_player=audio_player, + wingman_name=wingman_name, + ) diff --git a/providers/open_ai.py b/providers/open_ai.py index a177fadc1..2c94baaa0 100644 --- a/providers/open_ai.py +++ b/providers/open_ai.py @@ -1,11 +1,11 @@ from abc import ABC, abstractmethod import json import re -from typing import Literal, Mapping, Union +from typing import TYPE_CHECKING, Literal, Mapping, Union import httpx from openai import NOT_GIVEN, NotGiven, Omit, OpenAI, APIStatusError, AzureOpenAI import azure.cognitiveservices.speech as speechsdk -from api.enums import AzureRegion, LogType +from api.enums import AzureRegion, ConversationProvider, LogType, SttProvider, TtsProvider from api.interface import ( AzureInstanceConfig, @@ -14,10 +14,17 @@ SoundConfig, VoiceInfo, ) +from providers.interfaces import ( + SttInterface, TtsInterface, LlmInterface, + Transcript, stt_provider, tts_provider, llm_provider, +) from services.audio_player import AudioPlayer from services.openai_utils import get_minimal_reasoning_by_model from services.printr import Printr +if TYPE_CHECKING: + from api.interface import WingmanConfig + printr = Printr() @@ -580,3 +587,252 @@ def buffer_callback(audio_buffer): printr.toast_error( "An unknown OpenAI-compatible TTS error has occurred." ) + + +@stt_provider(SttProvider.OPENAI) +class OpenAiStt(SttInterface): + """OpenAI Whisper STT via the unified interface.""" + + def __init__(self, openai_instance: "OpenAi"): + self._openai = openai_instance + + async def transcribe(self, filename: str) -> Transcript | None: + result = self._openai.transcribe(filename=filename) + if result is None: + return None + return Transcript(text=result.text) + + +@tts_provider(TtsProvider.OPENAI) +class OpenAiTts(TtsInterface): + """OpenAI TTS via the unified interface.""" + + def __init__(self, openai_instance: "OpenAi", config: "WingmanConfig"): + self._openai = openai_instance + self._config = config + + async def play_audio(self, text, sound_config, audio_player, wingman_name): + await self._openai.play_audio( + text=text, + voice=self._config.openai.tts_voice, + model=self._config.openai.tts_model, + speed=self._config.openai.tts_speed, + sound_config=sound_config, + audio_player=audio_player, + wingman_name=wingman_name, + stream=self._config.openai.output_streaming, + ) + + +@llm_provider(ConversationProvider.OPENAI) +class OpenAiLlm(LlmInterface): + """OpenAI chat completions via the unified interface.""" + + def __init__(self, openai_instance: "OpenAi", config: "WingmanConfig"): + self._openai = openai_instance + self._config = config + + async def ask(self, messages, tools=None): + return self._openai.ask( + messages=messages, + tools=tools, + model=self._config.openai.conversation_model, + ) + + +@stt_provider(SttProvider.GROQ) +class GroqStt(SttInterface): + """Groq Whisper STT (uses OpenAi-compatible client).""" + + def __init__(self, openai_instance: "OpenAi"): + self._openai = openai_instance + + async def transcribe(self, filename: str) -> Transcript | None: + result = self._openai.transcribe( + filename=filename, model="whisper-large-v3-turbo" + ) + if result is None: + return None + return Transcript(text=result.text) + + +@llm_provider(ConversationProvider.MISTRAL) +class MistralLlm(LlmInterface): + def __init__(self, openai_instance: "OpenAi", config: "WingmanConfig"): + self._openai = openai_instance + self._config = config + + async def ask(self, messages, tools=None): + return self._openai.ask( + messages=messages, tools=tools, + model=self._config.mistral.conversation_model, + ) + + +@llm_provider(ConversationProvider.GROQ) +class GroqLlm(LlmInterface): + def __init__(self, openai_instance: "OpenAi", config: "WingmanConfig"): + self._openai = openai_instance + self._config = config + + async def ask(self, messages, tools=None): + return self._openai.ask( + messages=messages, tools=tools, + model=self._config.groq.conversation_model, + ) + + +@llm_provider(ConversationProvider.CEREBRAS) +class CerebrasLlm(LlmInterface): + def __init__(self, openai_instance: "OpenAi", config: "WingmanConfig"): + self._openai = openai_instance + self._config = config + + async def ask(self, messages, tools=None): + return self._openai.ask( + messages=messages, tools=tools, + model=self._config.cerebras.conversation_model, + ) + + +@llm_provider(ConversationProvider.OPENROUTER) +class OpenRouterLlm(LlmInterface): + def __init__(self, openai_instance: "OpenAi", config: "WingmanConfig", + supports_tools: bool = False): + self._openai = openai_instance + self._config = config + self.supports_tools = supports_tools + + async def ask(self, messages, tools=None): + effective_tools = tools if self.supports_tools else None + return self._openai.ask( + messages=messages, tools=effective_tools, + model=self._config.openrouter.conversation_model, + ) + + +@llm_provider(ConversationProvider.LOCAL_LLM) +class LocalLlm(LlmInterface): + def __init__(self, openai_instance: "OpenAi | None", config: "WingmanConfig"): + self._openai = openai_instance + self._config = config + + async def ask(self, messages, tools=None): + if not self._openai: + raise RuntimeError( + f"Local LLM provider is not initialized. " + f"Please check your Local LLM endpoint configuration " + f"({self._config.local_llm.endpoint})." + ) + return self._openai.ask( + messages=messages, tools=tools, + model=self._config.local_llm.conversation_model, + ) + + +@llm_provider(ConversationProvider.PERPLEXITY) +class PerplexityLlm(LlmInterface): + def __init__(self, openai_instance: "OpenAi", config: "WingmanConfig"): + self._openai = openai_instance + self._config = config + + async def ask(self, messages, tools=None): + return self._openai.ask( + messages=messages, tools=tools, + model=self._config.perplexity.conversation_model.value, + ) + + +@stt_provider(SttProvider.AZURE) +class AzureWhisperStt(SttInterface): + def __init__(self, azure_instance: "OpenAiAzure", api_key: str, config: "WingmanConfig"): + self._azure = azure_instance + self._api_key = api_key + self._config = config + + async def transcribe(self, filename: str) -> Transcript | None: + result = self._azure.transcribe_whisper( + filename=filename, + api_key=self._api_key, + config=self._config.azure.whisper, + ) + if result is None: + return None + return Transcript(text=result.text) + + +@stt_provider(SttProvider.AZURE_SPEECH) +class AzureSpeechStt(SttInterface): + def __init__(self, azure_instance: "OpenAiAzure", api_key: str, config: "WingmanConfig"): + self._azure = azure_instance + self._api_key = api_key + self._config = config + + async def transcribe(self, filename: str) -> Transcript | None: + result = self._azure.transcribe_azure_speech( + filename=filename, + api_key=self._api_key, + config=self._config.azure.stt, + ) + if result is None: + return None + text = result.get("_text") if isinstance(result, dict) else result.text + return Transcript(text=text) if text else None + + +@tts_provider(TtsProvider.AZURE) +class AzureTts(TtsInterface): + def __init__(self, azure_instance: "OpenAiAzure", api_key: str, config: "WingmanConfig"): + self._azure = azure_instance + self._api_key = api_key + self._config = config + + async def play_audio(self, text, sound_config, audio_player, wingman_name): + await self._azure.play_audio( + text=text, + api_key=self._api_key, + config=self._config.azure.tts, + sound_config=sound_config, + audio_player=audio_player, + wingman_name=wingman_name, + ) + + +@llm_provider(ConversationProvider.AZURE) +class AzureLlm(LlmInterface): + def __init__(self, azure_instance: "OpenAiAzure", api_key: str, config: "WingmanConfig"): + self._azure = azure_instance + self._api_key = api_key + self._config = config + + async def ask(self, messages, tools=None): + return self._azure.ask( + messages=messages, + api_key=self._api_key, + config=self._config.azure.conversation, + tools=tools, + ) + + +@tts_provider(TtsProvider.OPENAI_COMPATIBLE) +class OpenAiCompatibleTtsAdapter(TtsInterface): + def __init__(self, tts_instance: "OpenAiCompatibleTts", config: "WingmanConfig"): + self._tts = tts_instance + self._config = config + + async def play_audio(self, text, sound_config, audio_player, wingman_name): + from openai import NOT_GIVEN + await self._tts.play_audio( + text=text, + voice=self._config.openai_compatible_tts.voice, + model=self._config.openai_compatible_tts.model, + speed=( + self._config.openai_compatible_tts.speed + if self._config.openai_compatible_tts.speed + else NOT_GIVEN + ), + sound_config=sound_config, + audio_player=audio_player, + wingman_name=wingman_name, + stream=self._config.openai_compatible_tts.output_streaming, + ) diff --git a/providers/parakeet.py b/providers/parakeet.py index d07aa4275..daef9aa67 100644 --- a/providers/parakeet.py +++ b/providers/parakeet.py @@ -1,19 +1,23 @@ import gc import platform import threading -from typing import Optional +from typing import TYPE_CHECKING, Optional import requests -from api.enums import LogType +from api.enums import LogType, SttProvider from api.interface import ( ParakeetSettings, ParakeetSttConfig, ParakeetTranscript, WingmanInitializationError, ) +from providers.interfaces import SttInterface, Transcript, stt_provider from services.printr import Printr +if TYPE_CHECKING: + from api.interface import WingmanConfig + EXECUTION_PROVIDER_MAP = { "cpu": ["CPUExecutionProvider"], @@ -74,30 +78,40 @@ def _load_model_inner(self, model_path: Optional[str] = None): if not providers: providers = ["CPUExecutionProvider"] - load_kwargs = {"providers": providers} + # Filter requested providers against what ONNX Runtime actually has + # available, so we know up front whether CUDA will really be used. + try: + import onnxruntime as ort + + available = set(ort.get_available_providers()) + except Exception: + available = None + + effective_providers = providers + if available is not None: + effective_providers = [p for p in providers if p in available] + if not effective_providers: + effective_providers = ["CPUExecutionProvider"] + + if ( + self.settings.execution_provider == "cuda" + and "CUDAExecutionProvider" not in available + ): + self.printr.print( + "Parakeet: CUDA requested but not available in this ONNX Runtime build. " + "Using CPU fallback. For CUDA support, install onnxruntime-gpu.", + server_only=True, + color=LogType.WARNING, + ) + + load_kwargs = {"providers": effective_providers} if model_path: load_kwargs["path"] = model_path self.model = onnx_asr.load_model(model_name, **load_kwargs) - # Check if requested CUDA provider was actually available - if self.settings.execution_provider == "cuda": - try: - import onnxruntime as ort - - available = ort.get_available_providers() - if "CUDAExecutionProvider" not in available: - self.printr.print( - "Parakeet: CUDA requested but not available in this ONNX Runtime build. " - "Using CPU fallback. For CUDA support, install onnxruntime-gpu.", - server_only=True, - color=LogType.WARNING, - ) - except Exception: - pass - self.printr.print( - f"Parakeet initialized with model '{model_name}' (providers: {providers}).", + f"Parakeet initialized with model '{model_name}' (providers: {effective_providers}).", server_only=True, color=LogType.POSITIVE, ) @@ -202,3 +216,21 @@ def _transcribe_remote(self, filename: str) -> Optional[ParakeetTranscript]: def validate(self, errors: list[WingmanInitializationError]): pass + + +@stt_provider(SttProvider.PARAKEET) +class ParakeetStt(SttInterface): + """Per-wingman adapter around the shared Parakeet singleton.""" + + def __init__(self, shared: "Parakeet", config: "WingmanConfig"): + self._shared = shared + self._config = config + + async def transcribe(self, filename: str) -> Transcript | None: + result = self._shared.transcribe( + config=self._config.parakeet, + filename=filename, + ) + if result is None: + return None + return Transcript(text=result.text) diff --git a/providers/pocket_tts.py b/providers/pocket_tts.py index e6de98749..5902a1333 100644 --- a/providers/pocket_tts.py +++ b/providers/pocket_tts.py @@ -3,11 +3,11 @@ import sys import glob import asyncio -from typing import Optional +from typing import TYPE_CHECKING, Optional import torch import torchaudio from pocket_tts import TTSModel -from api.enums import LogType +from api.enums import LogType, TtsProvider from api.interface import ( PocketTTSConfig, SoundConfig, @@ -15,11 +15,15 @@ WingmanInitializationError, VoiceInfo, ) +from providers.interfaces import TtsInterface, tts_provider from services.file import get_custom_voices_dir from services.audio_player import AudioPlayer from services.printr import Printr from providers.open_ai import OpenAiCompatibleTts +if TYPE_CHECKING: + from api.interface import WingmanConfig + MODELS_DIR = "pocket-tts-models" POCKET_TTS_VOICES_DIR = "embeddings" @@ -606,3 +610,21 @@ def _get_wingman_included_voices_dir(self) -> str: # Determine path to wingman included voices directory app_dir = self._get_app_dir() return os.path.join(app_dir, INCLUDED_VOICES_DIR) + + +@tts_provider(TtsProvider.POCKET_TTS) +class PocketTtsTts(TtsInterface): + """Per-wingman adapter around the shared PocketTTS singleton.""" + + def __init__(self, shared: "PocketTTS", config: "WingmanConfig"): + self._shared = shared + self._config = config + + async def play_audio(self, text, sound_config, audio_player, wingman_name): + await self._shared.play_audio( + text=text, + config=self._config.pocket_tts, + sound_config=sound_config, + audio_player=audio_player, + wingman_name=wingman_name, + ) diff --git a/providers/whispercpp.py b/providers/whispercpp.py index 81e3b3137..3041ef28d 100644 --- a/providers/whispercpp.py +++ b/providers/whispercpp.py @@ -1,13 +1,18 @@ +from typing import TYPE_CHECKING import requests -from api.enums import LogType +from api.enums import LogType, SttProvider from api.interface import ( WhispercppSettings, WhispercppSttConfig, WhispercppTranscript, WingmanInitializationError, ) +from providers.interfaces import SttInterface, Transcript, stt_provider from services.printr import Printr +if TYPE_CHECKING: + from api.interface import WingmanConfig + class Whispercpp: def __init__( @@ -91,3 +96,21 @@ def __is_server_running(self, timeout=5): return response.ok except Exception: return False + + +@stt_provider(SttProvider.WHISPERCPP) +class WhispercppStt(SttInterface): + """Per-wingman adapter around the shared Whispercpp singleton.""" + + def __init__(self, shared: "Whispercpp", config: "WingmanConfig"): + self._shared = shared + self._config = config + + async def transcribe(self, filename: str) -> Transcript | None: + result = self._shared.transcribe( + filename=filename, + config=self._config.whispercpp, + ) + if result is None: + return None + return Transcript(text=result.text) diff --git a/providers/wingman_pro.py b/providers/wingman_subscription.py similarity index 82% rename from providers/wingman_pro.py rename to providers/wingman_subscription.py index 25d2a7240..b162e3a7f 100644 --- a/providers/wingman_pro.py +++ b/providers/wingman_subscription.py @@ -1,8 +1,14 @@ -from typing import Optional +from typing import TYPE_CHECKING, Optional import openai import requests from openai.types.audio import Transcription -from api.enums import CommandTag, LogType +from api.enums import ( + CommandTag, + ConversationProvider, + LogType, + SttProvider, + TtsProvider, +) from api.interface import ( AzureSttConfig, AzureTtsConfig, @@ -11,13 +17,25 @@ VoiceInfo, WingmanProSettings, ) +from providers.interfaces import ( + SttInterface, + TtsInterface, + LlmInterface, + Transcript, + stt_provider, + tts_provider, + llm_provider, +) from services.audio_player import AudioPlayer from services.openai_utils import get_minimal_reasoning_by_model from services.printr import Printr from services.secret_keeper import SecretKeeper +if TYPE_CHECKING: + from api.interface import WingmanConfig + -class WingmanPro: +class WingmanSubscription: def __init__( self, wingman_name: str, settings: WingmanProSettings, timeout: int = 120 ): @@ -470,3 +488,81 @@ def __remove_nones(self, obj): ) else: return obj + + +# --------------------------------------------------------------------------- +# Adapter classes — bridge WingmanSubscription into unified provider interfaces +# --------------------------------------------------------------------------- + + +@stt_provider(SttProvider.WINGMAN_PRO) +class WingmanSubscriptionStt(SttInterface): + def __init__(self, ws_instance: "WingmanSubscription", config: "WingmanConfig"): + self._ws = ws_instance + self._config = config + + async def transcribe(self, filename: str) -> Transcript | None: + from api.enums import WingmanProSttProvider + if self._config.wingman_pro.stt_provider == WingmanProSttProvider.WHISPER: + result = self._ws.transcribe_whisper(filename=filename) + elif self._config.wingman_pro.stt_provider == WingmanProSttProvider.AZURE_SPEECH: + result = self._ws.transcribe_azure_speech( + filename=filename, config=self._config.azure.stt + ) + else: + return None + if result is None: + return None + # WingmanSubscription might return a dict instead of a real transcript object + text = result.get("_text") if isinstance(result, dict) else result.text + return Transcript(text=text) if text else None + + +@tts_provider(TtsProvider.WINGMAN_PRO) +class WingmanSubscriptionTts(TtsInterface): + def __init__(self, ws_instance: "WingmanSubscription", config: "WingmanConfig"): + self._ws = ws_instance + self._config = config + + async def play_audio(self, text, sound_config, audio_player, wingman_name): + from api.enums import WingmanProTtsProvider + if self._config.wingman_pro.tts_provider == WingmanProTtsProvider.OPENAI: + await self._ws.generate_openai_speech( + text=text, + voice=self._config.openai.tts_voice, + model=self._config.openai.tts_model, + speed=self._config.openai.tts_speed, + sound_config=sound_config, + audio_player=audio_player, + wingman_name=wingman_name, + ) + elif self._config.wingman_pro.tts_provider == WingmanProTtsProvider.AZURE: + await self._ws.generate_azure_speech( + text=text, + config=self._config.azure.tts, + sound_config=sound_config, + audio_player=audio_player, + wingman_name=wingman_name, + ) + elif self._config.wingman_pro.tts_provider == WingmanProTtsProvider.INWORLD: + await self._ws.generate_inworld_speech( + text=text, + config=self._config.inworld, + sound_config=sound_config, + audio_player=audio_player, + wingman_name=wingman_name, + ) + + +@llm_provider(ConversationProvider.WINGMAN_PRO) +class WingmanSubscriptionLlm(LlmInterface): + def __init__(self, ws_instance: "WingmanSubscription", config: "WingmanConfig"): + self._ws = ws_instance + self._config = config + + async def ask(self, messages, tools=None): + return self._ws.ask( + messages=messages, + deployment=self._config.wingman_pro.conversation_deployment, + tools=tools, + ) diff --git a/providers/x_ai.py b/providers/x_ai.py index 1706a1d99..d02f38c75 100644 --- a/providers/x_ai.py +++ b/providers/x_ai.py @@ -1,6 +1,12 @@ +from typing import TYPE_CHECKING from openai import OpenAI, APIStatusError +from api.enums import ConversationProvider +from providers.interfaces import LlmInterface, llm_provider from providers.open_ai import OpenAi +if TYPE_CHECKING: + from api.interface import WingmanConfig + class XAi(OpenAi): def _perform_ask( @@ -48,3 +54,16 @@ def _fix_tools(self, tools: list[dict[str, any]]) -> list[dict[str, any]]: } fixed_tools.append(fixed_tool) return fixed_tools + + +@llm_provider(ConversationProvider.XAI) +class XAiLlm(LlmInterface): + def __init__(self, xai_instance: "XAi", config: "WingmanConfig"): + self._xai = xai_instance + self._config = config + + async def ask(self, messages, tools=None): + return self._xai.ask( + messages=messages, tools=tools, + model=self._config.xai.conversation_model, + ) diff --git a/providers/xvasynth.py b/providers/xvasynth.py index a89392969..1c2815da2 100644 --- a/providers/xvasynth.py +++ b/providers/xvasynth.py @@ -3,13 +3,18 @@ import platform import subprocess import time +from typing import TYPE_CHECKING import requests -from api.enums import LogType +from api.enums import LogType, TtsProvider from api.interface import XVASynthSettings, XVASynthTtsConfig, SoundConfig +from providers.interfaces import TtsInterface, tts_provider from services.audio_player import AudioPlayer from services.file import get_writable_dir from services.printr import Printr +if TYPE_CHECKING: + from api.interface import WingmanConfig + RECORDING_PATH = "audio_output" OUTPUT_FILE = "xvasynth.wav" SYNTHESIZE_URL = "synthesize" @@ -214,3 +219,21 @@ def __is_server_running(self, timeout=10): return response.ok except Exception: return False + + +@tts_provider(TtsProvider.XVASYNTH) +class XVASynthTts(TtsInterface): + """Per-wingman adapter around the shared XVASynth singleton.""" + + def __init__(self, shared: "XVASynth", config: "WingmanConfig"): + self._shared = shared + self._config = config + + async def play_audio(self, text, sound_config, audio_player, wingman_name): + await self._shared.play_audio( + text=text, + config=self._config.xvasynth, + sound_config=sound_config, + audio_player=audio_player, + wingman_name=wingman_name, + ) diff --git a/services/benchmark.py b/services/benchmark.py index 3312dcbf8..685de8fe3 100644 --- a/services/benchmark.py +++ b/services/benchmark.py @@ -4,6 +4,13 @@ from services.printr import Printr +def format_ms(execution_time_ms: float) -> str: + """Format milliseconds as ``"1.2s"`` or ``"340ms"``.""" + if execution_time_ms >= 1000: + return f"{execution_time_ms/1000:.1f}s" + return f"{int(execution_time_ms)}ms" + + class Benchmark: def __init__(self, label: str): self.label = label @@ -53,13 +60,8 @@ def finish_snapshot(self) -> BenchmarkResult | None: def _create_benchmark_result(self, label: str, start_time: float): end_time = time.perf_counter() execution_time = (end_time - start_time) * 1000 # Convert to milliseconds - if execution_time >= 1000: - formatted_execution_time = f"{execution_time/1000:.1f}s" - else: - formatted_execution_time = f"{int(execution_time)}ms" - return BenchmarkResult( label=label, execution_time_ms=execution_time, - formatted_execution_time=formatted_execution_time, + formatted_execution_time=format_ms(execution_time), ) diff --git a/services/command_executor.py b/services/command_executor.py new file mode 100644 index 000000000..4d11054c2 --- /dev/null +++ b/services/command_executor.py @@ -0,0 +1,310 @@ +"""Command and action execution service. + +Handles instant-activation matching, command dispatch, and the +keyboard/mouse/joystick/audio action dispatcher extracted from Wingman. +""" + +import difflib +import random +import time +import traceback + +import keyboard.keyboard as keyboard +import mouse.mouse as mouse + +from api.enums import LogType +from api.interface import CommandConfig, WingmanConfig +from services.audio_library import AudioLibrary +from services.printr import Printr + +printr = Printr() + + +def _command_has_effective_actions(command: CommandConfig) -> bool: + """True if the command has at least one action the LLM can meaningfully trigger.""" + if command.is_system_command: + return True + if not command.actions: + return False + for action in command.actions: + if not action: + continue + if ( + action.keyboard is not None + or action.mouse is not None + or action.joystick is not None + or action.audio is not None + or action.write is not None + or action.wait is not None + ): + return True + return False + + +class CommandExecutor: + """Focused service for command lookup, instant activation, and action dispatch.""" + + def __init__( + self, + config: WingmanConfig, + audio_library: AudioLibrary, + wingman_name: str, + on_reset_history, # async callable: reset_conversation_history + on_add_forced_commands=None, # async callable: add_forced_assistant_command_calls + ): + self.config = config + self.audio_library = audio_library + self.wingman_name = wingman_name + self.on_reset_history = on_reset_history + self.on_add_forced_commands = on_add_forced_commands + + # ───────────────── Command lookup ─────────────────────────── # + + def get_command(self, command_name: str) -> CommandConfig | None: + if self.config.commands is None: + return None + command = next( + (item for item in self.config.commands if item.name == command_name), + None, + ) + return command + + def select_instant_command_response(self, command: CommandConfig) -> str | None: + command_responses = command.responses + if (command_responses is None) or (len(command_responses) == 0): + return None + return random.choice(command_responses) + + # ───────────────── Instant activation ─────────────────────── # + + async def try_instant_activation(self, transcript: str) -> tuple[str, bool]: + commands = await self._execute_instant_activation_command(transcript) + if commands: + if self.on_add_forced_commands is not None: + await self.on_add_forced_commands(commands) + responses = [] + for command in commands: + if command.responses: + responses.append(self.select_instant_command_response(command)) + + if len(responses) == len(commands): + responses = list(dict.fromkeys(responses)) + responses = [ + response + "." if not response.endswith(".") else response + for response in responses + ] + return " ".join(responses), True + + return None, True + + return None, False + + async def _execute_instant_activation_command( + self, transcript: str + ) -> list[CommandConfig] | None: + if not self.config.commands: + return None + try: + commands_by_instant_activation = {} + for command in self.config.commands: + if command.instant_activation: + for phrase in command.instant_activation: + if phrase.lower() in commands_by_instant_activation: + commands_by_instant_activation[phrase.lower()].append( + command + ) + else: + commands_by_instant_activation[phrase.lower()] = [command] + + phrase = difflib.get_close_matches( + transcript.lower(), + commands_by_instant_activation.keys(), + n=1, + cutoff=1, + ) + + if not phrase: + return None + + commands = commands_by_instant_activation[phrase[0]] + for command in commands: + await self.execute_command(command, True) + + return commands + except Exception as e: + await printr.print_async( + f"Error during instant activation in Wingman '{self.wingman_name}': {str(e)}", + color=LogType.ERROR, + ) + printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) + return None + + # ───────────────── Command execution ──────────────────────── # + + async def execute_command( + self, command: CommandConfig, is_instant=False + ) -> tuple[str | None, str]: + if not command: + return None, "Command not found" + + try: + if len(command.actions or []) == 0: + await printr.print_async( + f"No actions found for command: {command.name}", + color=LogType.WARNING, + ) + else: + await self.execute_action(command) + await printr.print_async( + f"Executed {'instant' if is_instant else 'AI'} command: {command.name}", + color=LogType.COMMAND, + ) + + if command.name == "ResetConversationHistory": + await self.on_reset_history() + await printr.print_async( + f"Executed command: {command.name}", color=LogType.COMMAND + ) + + return ( + self.select_instant_command_response(command), + command.additional_context or "OK", + ) + except Exception as e: + await printr.print_async( + f"Error executing command '{command.name}' for Wingman '{self.wingman_name}': {str(e)}", + color=LogType.ERROR, + ) + printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) + return None, "ERROR DURING PROCESSING" + + # ───────────────── Tool definition ───────────────────────── # + + def get_tool_definition(self) -> dict | None: + """Return the OpenAI-style execute_command tool definition, or None if no + eligible commands are configured.""" + if not self.config.commands: + return None + commands = [ + command.name + for command in self.config.commands + if (not command.force_instant_activation) + and _command_has_effective_actions(command) + ] + if not commands: + return None + return { + "type": "function", + "function": { + "name": "execute_command", + "description": "Executes a command", + "parameters": { + "type": "object", + "properties": { + "command_name": { + "type": "string", + "description": "The name of the command to execute", + "enum": commands, + }, + }, + "required": ["command_name"], + }, + }, + } + + # ───────────────── Action dispatch ────────────────────────── # + + async def execute_action(self, command: CommandConfig): + if not command or not command.actions: + return + + def contains_numpad_key(hotkey: str) -> bool: + if not hotkey: + return False + tokens = hotkey.lower().split("+") + return any(token.startswith("num ") for token in tokens) + + try: + for action in command.actions: + if action.keyboard: + if action.keyboard.hotkey_codes and not contains_numpad_key( + action.keyboard.hotkey + ): + code = action.keyboard.hotkey_codes + else: + code = action.keyboard.hotkey + + if action.keyboard.press == action.keyboard.release: + hold = action.keyboard.hold or 0.1 + if ( + action.keyboard.hotkey_codes + and len(action.keyboard.hotkey_codes) == 1 + and not contains_numpad_key(action.keyboard.hotkey) + ): + keyboard.direct_event( + action.keyboard.hotkey_codes[0], + 0 + (1 if action.keyboard.hotkey_extended else 0), + ) + time.sleep(hold) + keyboard.direct_event( + action.keyboard.hotkey_codes[0], + 2 + (1 if action.keyboard.hotkey_extended else 0), + ) + else: + keyboard.press(code) + time.sleep(hold) + keyboard.release(code) + else: + if ( + action.keyboard.hotkey_codes + and len(action.keyboard.hotkey_codes) == 1 + and not contains_numpad_key(action.keyboard.hotkey) + ): + keyboard.direct_event( + action.keyboard.hotkey_codes[0], + (0 if action.keyboard.press else 2) + + (1 if action.keyboard.hotkey_extended else 0), + ) + else: + keyboard.send( + code, + action.keyboard.press, + action.keyboard.release, + ) + + if action.mouse: + if action.mouse.move_to: + x, y = action.mouse.move_to + mouse.move(x, y) + + if action.mouse.move: + x, y = action.mouse.move + mouse.move(x, y, absolute=False, duration=0.5) + + if action.mouse.scroll: + mouse.wheel(action.mouse.scroll) + + if action.mouse.button: + if action.mouse.hold: + mouse.press(button=action.mouse.button) + time.sleep(action.mouse.hold) + mouse.release(button=action.mouse.button) + else: + mouse.click(button=action.mouse.button) + + if action.write: + keyboard.write(action.write) + + if action.wait: + time.sleep(action.wait) + + if action.audio: + await self.audio_library.handle_action( + action.audio, self.config.sound.volume + ) + except Exception as e: + await printr.print_async( + f"Error executing actions of command '{command.name}' for wingman '{self.wingman_name}': {str(e)}", + color=LogType.ERROR, + ) + printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) diff --git a/services/config_manager.py b/services/config_manager.py index dc2e80a31..fd186d8b5 100644 --- a/services/config_manager.py +++ b/services/config_manager.py @@ -1452,6 +1452,7 @@ def perform_hardware_scan(self, system_manager): if system_manager.is_cuda_available(): self.settings_config.voice_activation.fasterwhisper.device = "cuda" self.settings_config.voice_activation.fasterwhisper.compute_type = "auto" + self.settings_config.voice_activation.parakeet.execution_provider = "cuda" self.printr.print( f"- GPU detected: {system_manager.get_gpu_name()}", color=LogType.STARTUP, @@ -1460,7 +1461,7 @@ def perform_hardware_scan(self, system_manager): source_name=self.log_source_name, ) self.printr.print( - "- Auto-configured FasterWhisper to use CUDA", + "- Auto-configured FasterWhisper and Parakeet to use CUDA", color=LogType.STARTUP, server_only=True, source=LogSource.SYSTEM, @@ -1468,8 +1469,10 @@ def perform_hardware_scan(self, system_manager): ) changes = True else: + self.settings_config.voice_activation.fasterwhisper.device = "cpu" + self.settings_config.voice_activation.parakeet.execution_provider = "cpu" self.printr.print( - "- No NVIDIA GPU detected, keeping current STT settings", + "- No NVIDIA GPU detected, STT providers will use CPU", color=LogType.STARTUP, server_only=True, source=LogSource.SYSTEM, diff --git a/services/config_service.py b/services/config_service.py index c0f4c03df..e57f807e4 100644 --- a/services/config_service.py +++ b/services/config_service.py @@ -282,7 +282,7 @@ async def get_wingman_skills( wingman_name: str, ) -> list[WingmanSkillState]: """Get all skills with their enabled/disabled state for a specific wingman.""" - import sys + from services.platform_utils import normalize_platform try: # Get all available skills @@ -307,10 +307,7 @@ async def get_wingman_skills( discoverable_skills = wingman_config.discoverable_skills - # Get current platform for filtering - current_platform = sys.platform - platform_map = {"win32": "windows", "darwin": "darwin", "linux": "linux"} - normalized_platform = platform_map.get(current_platform, current_platform) + normalized_platform = normalize_platform() # Build response with enabled state result = [] diff --git a/services/context_builder.py b/services/context_builder.py new file mode 100644 index 000000000..937213635 --- /dev/null +++ b/services/context_builder.py @@ -0,0 +1,284 @@ +"""System prompt assembly from template, skill prompts, TTS prompts, memory injection.""" + +import re +from datetime import datetime +from typing import TYPE_CHECKING, Optional + +from api.enums import ( + LogSource, + LogType, + TtsProvider, + WingmanProTtsProvider, +) +from services.printr import Printr +from services.token_utils import count_tokens, truncate_to_tokens + +if TYPE_CHECKING: + from api.interface import SettingsConfig, WingmanConfig + from services.persistent_memory import PersistentMemoryService + from services.skill_registry import SkillRegistry + from skills.skill_base import Skill + +printr = Printr() + + +class ContextBuilder: + def __init__( + self, + config: "WingmanConfig", + settings: "SettingsConfig", + wingman_name: str, + ): + self._config = config + self._settings = settings + self._wingman_name = wingman_name + self._last_compiled_context: str = "" + self._memory_recall_notified: bool = False + + async def build( + self, + skills: list["Skill"], + skill_registry: "SkillRegistry", + conversation_summary: str, + persistent_memory_service: Optional["PersistentMemoryService"], + messages: list, + config_dir_name: Optional[str] = None, + ) -> str: + """Build the context and return it as a string. + + With progressive disclosure, only includes prompts from ACTIVATED skills. + Skill prompts are auto-generated from @tool descriptions if no custom prompt is set. + """ + skill_prompts = "" + active_skill_names = skill_registry.active_skill_names + + for skill in skills: + # Only include prompts from activated skills (in progressive mode) + if skill.name not in active_skill_names: + continue + + # Get custom prompt if set + prompt = await skill.get_prompt() + + # Auto-generate prompt from tool descriptions if no custom prompt + if not prompt: + tools_desc = skill.get_tools_description() + if tools_desc: + prompt = f"Available tools:\n{tools_desc}" + + if prompt: + skill_prompts += "\n\n" + skill.name + "\n\n" + prompt + + # Get TTS prompt based on active TTS provider and user preference + tts_prompt = "" + if self._config.features.tts_provider == TtsProvider.ELEVENLABS: + if ( + self._config.elevenlabs.use_tts_prompt + and self._config.elevenlabs.tts_prompt + ): + tts_prompt = self._config.elevenlabs.tts_prompt + elif self._config.features.tts_provider == TtsProvider.INWORLD or ( + self._config.features.tts_provider == TtsProvider.WINGMAN_PRO + and self._config.wingman_pro.tts_provider == WingmanProTtsProvider.INWORLD + ): + if self._config.inworld.use_tts_prompt and self._config.inworld.tts_prompt: + tts_prompt = self._config.inworld.tts_prompt + elif self._config.features.tts_provider == TtsProvider.OPENAI_COMPATIBLE: + if ( + self._config.openai_compatible_tts.use_tts_prompt + and self._config.openai_compatible_tts.tts_prompt + ): + tts_prompt = self._config.openai_compatible_tts.tts_prompt + + # Add TTS header only if there's a prompt + if tts_prompt: + tts_prompt = "# TEXT-TO-SPEECH\n" + tts_prompt + + # Build user context with environment metadata + user_context = self.build_user_context(config_dir_name=config_dir_name) + + # Sanity check: truncate if someone bypasses the client's 2048-token limit + MAX_BACKSTORY_TOKENS = 2048 + backstory = self._config.prompts.backstory + + if backstory and count_tokens(backstory) > MAX_BACKSTORY_TOKENS: + original_tokens = count_tokens(backstory) + backstory = truncate_to_tokens(backstory, MAX_BACKSTORY_TOKENS) + await printr.print_async( + f"[{self._wingman_name}] Backstory will be truncated to {MAX_BACKSTORY_TOKENS} tokens for conversations (is {original_tokens}). " + f"Your saved backstory is unchanged. Consider shortening it.", + color=LogType.WARNING, + source_name=self._wingman_name, + source=LogSource.SYSTEM, + ) + + # Build conversation summary section + conversation_summary_section = "" + if conversation_summary: + conversation_summary_section = ( + "# CONVERSATION SUMMARY\n" + "The following is a summary of earlier parts of this conversation. " + "Treat it as factual context — the user and you discussed these topics previously.\n\n" + + conversation_summary + ) + + # Persistent memory injection + persistent_memory_context = "" + if persistent_memory_service and messages: + # Use the most recent user message as the query + last_user_msg = "" + for msg in reversed(messages): + role = ( + msg.get("role") + if isinstance(msg, dict) + else getattr(msg, "role", None) + ) + raw_content = ( + msg.get("content", "") + if isinstance(msg, dict) + else getattr(msg, "content", "") + ) + # Extract plain text from multimodal content (images etc.) + content = self._extract_text_content(raw_content) if raw_content else "" + if role == "user" and content: + last_user_msg = content + break + if last_user_msg: + try: + persistent_memory_context = ( + await persistent_memory_service.build_memory_context( + last_user_msg + ) + ) + if persistent_memory_context and not self._memory_recall_notified: + self._memory_recall_notified = True + # Count restored fact lines (lines starting with "- ") + fact_count = sum( + 1 + for line in persistent_memory_context.splitlines() + if line.startswith("- ") + ) + if fact_count > 0: + await printr.print_async( + f"Memory: {fact_count} {'memory' if fact_count == 1 else 'memories'} recalled", + color=LogType.MEMORY, + source_name=self._wingman_name, + ) + except Exception: + pass # Don't let memory failures break conversation + + context = self._config.prompts.system_prompt.format( + backstory=backstory, + skills=skill_prompts, + ttsprompt=tts_prompt, + user_context=user_context, + conversation_summary=conversation_summary_section, + ) + + # If the system prompt template doesn't include {conversation_summary}, + # append the summary at the end so it's never lost. + if ( + conversation_summary_section + and "{conversation_summary}" not in self._config.prompts.system_prompt + ): + context += "\n\n" + conversation_summary_section + + # Append persistent memory context + if persistent_memory_context: + context += "\n\n" + persistent_memory_context + + # Persistent memory tool instructions + if persistent_memory_service: + context += ( + "\n\n# PERSISTENT MEMORY\n" + "You have persistent memory. Important facts and past conversation summaries " + "are provided in the [Memory] sections above (if any). " + "You can use the `memory_remember`, `memory_recall`, and `memory_forget` tools when the user " + "explicitly asks you to remember, recall, or forget something. " + "You don't need to use `memory_remember` for routine information — that is handled automatically." + ) + + self._last_compiled_context = context + return context + + def get_last_context(self) -> str: + """Return the last compiled system context (cached from the most recent LLM call).""" + return self._last_compiled_context + + def build_user_context( + self, + config_dir_name: Optional[str] = None, + ) -> str: + """Build user context metadata for the system prompt. + + Includes timezone, config context, username, and wingman name. + """ + context_parts = [] + backstory = self._config.prompts.backstory or "" + backstory_lower = backstory.lower() + + # Date and timezone information + try: + now = datetime.now().astimezone() + local_tz = now.tzinfo + tz_name = str(local_tz) + # Get UTC offset in a readable format + utc_offset = now.strftime("%z") + # Format as +HH:MM or -HH:MM + if len(utc_offset) >= 5: + utc_offset = f"{utc_offset[:3]}:{utc_offset[3:]}" + # Include current date for relative date references ("last Sunday", "tomorrow", etc.) + current_date = now.strftime( + "%A, %B %d, %Y" + ) # e.g., "Tuesday, December 09, 2025" + context_parts.append(f"- Current date: {current_date}") + context_parts.append(f"- Timezone: {tz_name} (UTC{utc_offset})") + except Exception: + context_parts.append("- Timezone: Unknown") + + # Config/context name (e.g., "Star Citizen", "Elite Dangerous") + # This helps the LLM understand which game/context tools are relevant for + if config_dir_name: + context_parts.append(f"- Active context: {config_dir_name}") + + # Username (only if not explicitly named in backstory) + if self._settings.user_name: + # Check if username is mentioned in backstory as a standalone word + name_pattern = r"\b" + re.escape(self._settings.user_name.lower()) + r"\b" + if not re.search(name_pattern, backstory_lower): + context_parts.append( + f"- User's name (default): {self._settings.user_name}" + ) + + # Wingman name - always include as it's useful context + # The system prompt already tells LLM to prioritize backstory names + if self._wingman_name: + context_parts.append(f"- Your name (default): {self._wingman_name}") + + if context_parts: + return "\n".join(context_parts) + return "No additional context available." + + async def add_context(self, messages: list, context: str) -> None: + """Insert the compiled context as the system message at the start of messages.""" + messages.insert(0, {"role": "system", "content": context}) + + def reset_memory_notification(self) -> None: + """Reset the memory recall notification flag (e.g., on conversation reset).""" + self._memory_recall_notified = False + + @staticmethod + def _extract_text_content(content) -> str: + """Extract text from message content, handling both string and multimodal list formats.""" + if isinstance(content, str): + return content + if isinstance(content, list): + # Multimodal content: extract text parts only + parts = [] + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + parts.append(part.get("text", "")) + elif isinstance(part, str): + parts.append(part) + return " ".join(parts) + return "" diff --git a/services/conversation_condenser.py b/services/conversation_condenser.py new file mode 100644 index 000000000..8197a4c33 --- /dev/null +++ b/services/conversation_condenser.py @@ -0,0 +1,585 @@ +"""Conversation condensation — summarizes history when it grows too large.""" + +import asyncio +from typing import TYPE_CHECKING + +from api.enums import LogType, LogSource +from services.file import get_prompt +from services.printr import Printr +from services.token_utils import count_tokens, truncate_to_tokens + +if TYPE_CHECKING: + from api.interface import WingmanConfig + from services.conversation_manager import ConversationManager + from services.local_ai_service import LocalAiService + from services.persistent_memory import PersistentMemoryService + +printr = Printr() + +_CONDENSE_TIMEOUT = 120.0 + + +class ConversationCondenser: + def __init__( + self, + conversation: "ConversationManager", + config: "WingmanConfig", + wingman_name: str, + ): + self._conversation = conversation + self._config = config + self._wingman_name = wingman_name + self._is_condensing = False + self._condense_task: asyncio.Task | None = None + self._support_token_ratio: float = 1.35 + + @property + def summary(self) -> str: + return self._conversation.conversation_summary + + @property + def is_condensing(self) -> bool: + return self._is_condensing + + def get_support_capacity(self, local_ai_service: "LocalAiService") -> int: + """Get the effective input capacity of the support model for a single pass. + + Returns the number of conversation tokens that can fit in one summarization + pass, accounting for system prompt, framing text, and output budget. + """ + system_prompt = get_prompt("condense-conversation") + budget = local_ai_service.get_token_budget(system_prompt) + # Subtract framing overhead (prefix/suffix around the conversation text) + framing_overhead = 80 + return max(0, budget.max_input_tokens - framing_overhead) + + async def maybe_condense(self, local_ai_service: "LocalAiService"): + """Check if condensation should run and fire it as a background task. + + Uses a token-based trigger: condenses when the conversation approaches + 70% of what the support model can handle in a single pass, so we + avoid chunking. Also has a message count safety cap. + + The cl100k_base token estimate is multiplied by _support_token_ratio + (calibrated from real support model usage) to account for tokenizer + differences between the estimation tokenizer and the actual model. + """ + if not self._config.features.condense_conversation: + return + if not local_ai_service or not local_ai_service.is_ready(): + return + if self._conversation.pending_tool_calls: + return # Never interrupt chained tool calls + if self._is_condensing: + return + + # Token-based trigger: condense when conversation reaches 70% of + # what the support model can handle in one pass. + # Apply _support_token_ratio to correct for tokenizer differences + # between cl100k_base (used for estimation) and the actual model. + capacity = self.get_support_capacity(local_ai_service) + cl100k_tokens = self._conversation.estimate_tokens() + conversation_tokens = int(cl100k_tokens * self._support_token_ratio) + token_trigger = conversation_tokens >= int(capacity * 0.7) + + # Message count safety cap + user_msg_count = sum( + 1 + for m in self._conversation.messages + if self._conversation.get_message_role(m) == "user" + ) + message_trigger = ( + user_msg_count >= self._config.features.condense_max_messages + ) + + if not token_trigger and not message_trigger: + return + + # Runs in background so user is never blocked. + # Store the task reference to prevent garbage collection mid-execution. + self._condense_task = asyncio.create_task( + self.condense(local_ai_service=local_ai_service) + ) + self._condense_task.add_done_callback( + lambda _: setattr(self, "_condense_task", None) + ) + + async def condense( + self, + local_ai_service: "LocalAiService", + persistent_memory_service: "PersistentMemoryService | None" = None, + background_tasks: set[asyncio.Task] | None = None, + force: bool = False, + ): + """Condense older conversation messages into a running summary using local AI. + + This preserves the most recent messages verbatim while summarizing older ones, + saving tokens without losing important context. Tool call/response pairs are + never split. + + Args: + local_ai_service: The local AI service to use for summarization. + persistent_memory_service: Optional service for extracting memories. + background_tasks: Optional set to track background tasks for memory extraction. + force: If True, skip the threshold check (used for manual trigger). + """ + if self._is_condensing: + await printr.print_async( + "Condensation skipped — already in progress.", + color=LogType.WARNING, + server_only=True, + source_name=self._wingman_name, + source=LogSource.WINGMAN, + ) + return + if not local_ai_service or not local_ai_service.is_ready(): + await printr.print_async( + "Condensation skipped — local AI service not available.", + color=LogType.WARNING, + server_only=True, + source_name=self._wingman_name, + source=LogSource.WINGMAN, + ) + return + + keep_recent = ( + self._config.features.condense_keep_recent + if not force + else min(self._config.features.condense_keep_recent, 2) + ) + total_msg_count = len(self._conversation.messages) + + # Need at least something to condense beyond what we keep + if total_msg_count <= keep_recent: + await printr.print_async( + f"Condensation skipped — only {total_msg_count} messages, need more than {keep_recent} to condense.", + color=LogType.LOCALMODEL, + source_name=self._wingman_name, + source=LogSource.WINGMAN, + ) + return + + self._is_condensing = True + _condensation_stats: dict = {} + + # Broadcast start + from api.commands import ConversationCondensationCommand + + if printr._connection_manager: + await printr._connection_manager.broadcast( + ConversationCondensationCommand( + wingman_name=self._wingman_name, + status="started", + ) + ) + + await printr.print_async( + "Conversation condensation started.", + color=LogType.INFO, + server_only=True, + source_name=self._wingman_name, + source=LogSource.WINGMAN, + ) + + try: + # Wait for any pending tool calls to finish + for _ in range(30): # max 15 seconds + if not self._conversation.pending_tool_calls: + break + await asyncio.sleep(0.5) + else: + await printr.print_async( + "Condensation aborted — tool calls still pending after 15s.", + color=LogType.WARNING, + server_only=True, + source_name=self._wingman_name, + source=LogSource.WINGMAN, + ) + return + + # Find the cutoff: keep the most recent `keep_recent` user messages + kept_user_count = 0 + cutoff_index = len(self._conversation.messages) + for i in range(len(self._conversation.messages) - 1, -1, -1): + if ( + self._conversation.get_message_role( + self._conversation.messages[i] + ) + == "user" + ): + kept_user_count += 1 + if kept_user_count == keep_recent: + cutoff_index = i + break + + if cutoff_index <= 0: + await printr.print_async( + f"Condensation skipped — cutoff_index={cutoff_index}, nothing to condense (kept_user_count={kept_user_count}, keep_recent={keep_recent}, total={len(self._conversation.messages)}).", + color=LogType.LOCALMODEL, + source_name=self._wingman_name, + source=LogSource.WINGMAN, + ) + return + + # Adjust cutoff forward to avoid orphaning tool responses + while cutoff_index < len(self._conversation.messages): + msg = self._conversation.messages[cutoff_index] + if self._conversation.get_message_role(msg) == "tool": + cutoff_index += 1 + else: + break + + if cutoff_index <= 0: + await printr.print_async( + "Condensation skipped — no messages to condense after tool adjustment.", + color=LogType.LOCALMODEL, + source_name=self._wingman_name, + source=LogSource.WINGMAN, + ) + return + + to_condense = self._conversation.messages[:cutoff_index] + + # Extract memories from messages about to be condensed (background, non-blocking) + if persistent_memory_service: + try: + task = asyncio.create_task( + persistent_memory_service.extract_memories( + to_condense, generate_summary=True + ) + ) + if background_tasks is not None: + background_tasks.add(task) + task.add_done_callback(background_tasks.discard) + except Exception: + pass # Don't let memory extraction block condensation + + condensed_text = self._conversation._messages_to_text(to_condense) + if not condensed_text.strip(): + await printr.print_async( + "Condensation skipped — messages produced no text content.", + color=LogType.LOCALMODEL, + source_name=self._wingman_name, + source=LogSource.WINGMAN, + ) + return + + # Estimate original token count + estimated_original_tokens = sum( + count_tokens(self._conversation._message_text_content(m)) + for m in to_condense + ) + + # Build the summarization prompt + existing_summary_section = "" + if self._conversation.conversation_summary: + existing_summary_section = ( + "EXISTING SUMMARY (incorporate and update — do not repeat verbatim):\n" + + self._conversation.conversation_summary + + "\n\n" + ) + + system_prompt = get_prompt("condense-conversation") + budget = local_ai_service.get_token_budget(system_prompt) + + user_prompt_prefix = ( + existing_summary_section + "CONVERSATION TO SUMMARIZE:\n" + ) + user_prompt_suffix = ( + "\n\n---\n" + "Now list every fact from the conversation above as bullet points.\n" + "Start from the FIRST message, end at the LAST. Include all names, preferences, and creative content. Never include secrets, API keys, credentials, passwords, or tokens:" + ) + prefix_suffix_tokens = count_tokens(user_prompt_prefix) + count_tokens( + user_prompt_suffix + ) + + # How much conversation text fits in one pass? + available_tokens = budget.max_input_tokens - prefix_suffix_tokens + + # Apply tokenizer ratio to decide if chunking is needed. + corrected_text_tokens = int( + count_tokens(condensed_text) * self._support_token_ratio + ) + corrected_available = int(available_tokens / self._support_token_ratio) + + if corrected_text_tokens > available_tokens: + # Chunk: summarize in segments, then merge + support_result = await asyncio.wait_for( + self._chunked_support( + condensed_text, + system_prompt, + existing_summary_section, + corrected_available, + local_ai_service, + ), + timeout=_CONDENSE_TIMEOUT, + ) + else: + user_prompt = ( + f"{user_prompt_prefix}{condensed_text}{user_prompt_suffix}" + ) + from services.skill_local_ai import SamplingPreset + + support_result = await asyncio.wait_for( + asyncio.get_event_loop().run_in_executor( + None, + lambda: local_ai_service.support( + text=user_prompt, + system_prompt=system_prompt, + preset=SamplingPreset.BALANCED, + ), + ), + timeout=_CONDENSE_TIMEOUT, + ) + + summary = support_result.text if support_result else None + + # Calibrate tokenizer ratio from real model usage + if support_result and support_result.prompt_tokens > 0: + cl100k_input = ( + budget.system_tokens + + count_tokens(condensed_text) + + prefix_suffix_tokens + ) + if cl100k_input > 0: + self._support_token_ratio = ( + support_result.prompt_tokens / cl100k_input + ) + + # Detect truncated output + if support_result and support_result.truncated: + await printr.print_async( + f"Condensation output was truncated (finish_reason=length). " + f"Model used {support_result.prompt_tokens} prompt tokens, " + f"generated {support_result.completion_tokens} tokens. " + f"Token ratio calibrated to {self._support_token_ratio:.2f}.", + color=LogType.WARNING, + server_only=True, + source_name=self._wingman_name, + source=LogSource.WINGMAN, + ) + + if not summary: + await printr.print_async( + "Conversation condensation failed — local AI returned no result.", + color=LogType.WARNING, + server_only=True, + source_name=self._wingman_name, + source=LogSource.WINGMAN, + ) + return + + # Clean pending tool calls being removed + for msg in to_condense: + if ( + self._conversation.get_message_role(msg) == "tool" + and msg.get("tool_call_id") + in self._conversation.pending_tool_calls + ): + self._conversation.pending_tool_calls.remove( + msg.get("tool_call_id") + ) + + # Replace old messages + del self._conversation.messages[:cutoff_index] + self._conversation.conversation_summary = summary + + estimated_summary_tokens = count_tokens(summary) + estimated_tokens_saved = max( + 0, estimated_original_tokens - estimated_summary_tokens + ) + + await printr.print_async( + f"Condensed {cutoff_index} messages into summary " + f"({len(summary)} chars, ~{estimated_summary_tokens} tokens). " + f"{len(self._conversation.messages)} messages remaining. " + f"~{estimated_tokens_saved} tokens saved. " + f"Token ratio: {self._support_token_ratio:.2f}.", + color=LogType.INFO, + server_only=True, + source_name=self._wingman_name, + source=LogSource.WINGMAN, + ) + + # Record stats for the broadcast in finally + _condensation_stats = { + "messages_condensed": cutoff_index, + "messages_remaining": len(self._conversation.messages), + "summary_length": len(summary), + "estimated_tokens_saved": estimated_tokens_saved, + "summary_text": summary, + } + + except asyncio.TimeoutError: + await printr.print_async( + "Condensation timed out — local model took too long.", + color=LogType.WARNING, + server_only=True, + source_name=self._wingman_name, + source=LogSource.WINGMAN, + ) + except Exception as e: + await printr.print_async( + f"Conversation condensation error: {e}", + color=LogType.ERROR, + server_only=True, + source_name=self._wingman_name, + source=LogSource.WINGMAN, + ) + finally: + self._is_condensing = False + # Always broadcast finished so the client UI doesn't get stuck. + # Include summary_text if condensation produced one (even if a + # later step failed), so the client can show the view-history button. + if printr._connection_manager: + try: + await printr._connection_manager.broadcast( + ConversationCondensationCommand( + wingman_name=self._wingman_name, + status="finished", + **_condensation_stats, + ) + ) + except Exception as e: + await printr.print_async( + f"Failed to broadcast condensation finish: {e}", + color=LogType.WARNING, + server_only=True, + source_name=self._wingman_name, + source=LogSource.WINGMAN, + ) + + async def _chunked_support( + self, + full_text: str, + system_prompt: str, + existing_summary_section: str, + chunk_max_tokens: int, + local_ai_service: "LocalAiService", + ) -> "SupportResult": + """Process text that exceeds the model's context window by chunking. + + Each chunk is processed independently, then results are merged into + one final summary. Returns a SupportResult from the merge step. + """ + from providers.llama_cpp_provider import SupportResult + from services.skill_local_ai import SamplingPreset + + budget = local_ai_service.get_token_budget(system_prompt) + + # Convert token budget to approximate char limit for splitting + # (splitting needs char positions; we use ~4 chars/token as a rough guide, + # then verify with count_tokens) + approx_chunk_chars = chunk_max_tokens * 4 + chunks = [] + remaining = full_text + while remaining: + if count_tokens(remaining) <= chunk_max_tokens: + chunks.append(remaining) + break + # Try to split at a newline boundary + split_at = remaining.rfind("\n", 0, approx_chunk_chars) + if split_at <= 0: + split_at = approx_chunk_chars + chunks.append(remaining[:split_at]) + remaining = remaining[split_at:].lstrip() + + loop = asyncio.get_event_loop() + chunk_summaries = [] + for i, chunk in enumerate(chunks): + user_prompt = ( + f"{existing_summary_section if i == 0 else ''}" + f"CONVERSATION TO SUMMARIZE (part {i + 1}/{len(chunks)}):\n{chunk}\n\n" + "---\nList every fact from the above as bullet points. Include all names, preferences, and creative content. Never include secrets, API keys, credentials, passwords, or tokens:" + ) + + # Safety: if chunk input exceeds budget, truncate chunk text + chunk_text_tokens = count_tokens(chunk) + prompt_overhead = count_tokens(user_prompt) - chunk_text_tokens + if prompt_overhead + chunk_text_tokens > budget.max_input_tokens: + safe_text_tokens = budget.max_input_tokens - prompt_overhead + if safe_text_tokens > 0: + chunk = truncate_to_tokens(chunk, safe_text_tokens) + user_prompt = ( + f"{existing_summary_section if i == 0 else ''}" + f"CONVERSATION TO SUMMARIZE (part {i + 1}/{len(chunks)}):\n{chunk}\n\n" + "---\nList every fact from the above as bullet points. Include all names, preferences, and creative content. Never include secrets, API keys, credentials, passwords, or tokens:" + ) + await printr.print_async( + f"Chunk {i + 1}/{len(chunks)} exceeded context budget, truncated to fit.", + color=LogType.WARNING, + server_only=True, + source_name=self._wingman_name, + source=LogSource.WINGMAN, + ) + + result = await loop.run_in_executor( + None, + lambda p=user_prompt: local_ai_service.support( + text=p, + system_prompt=system_prompt, + preset=SamplingPreset.BALANCED, + ), + ) + if result.text: + chunk_summaries.append(result.text) + + # Calibrate tokenizer ratio from real model usage. + cl100k_input = count_tokens(system_prompt) + count_tokens(user_prompt) + if result.prompt_tokens > 0 and cl100k_input > 0: + self._support_token_ratio = result.prompt_tokens / cl100k_input + + if result.truncated: + await printr.print_async( + f"Chunk {i + 1}/{len(chunks)} output truncated " + f"(prompt={result.prompt_tokens}, " + f"completion={result.completion_tokens}).", + color=LogType.WARNING, + server_only=True, + source_name=self._wingman_name, + source=LogSource.WINGMAN, + ) + + if not chunk_summaries: + return SupportResult(text=None) + if len(chunk_summaries) == 1: + return SupportResult(text=chunk_summaries[0]) + + # Merge all chunk summaries into one final summary + combined = "\n\n".join( + f"Part {i + 1}:\n{s}" for i, s in enumerate(chunk_summaries) + ) + merge_prompt = ( + f"{existing_summary_section}" + f"PARTIAL SUMMARIES TO MERGE:\n{combined}\n\n" + "Merge these into a single coherent summary. Keep all key facts:" + ) + + # Safety: truncate combined summaries if they exceed budget + if count_tokens(merge_prompt) > budget.max_input_tokens: + overhead = count_tokens(merge_prompt) - count_tokens(combined) + safe_combined = budget.max_input_tokens - overhead + if safe_combined > 0: + combined = truncate_to_tokens(combined, safe_combined) + merge_prompt = ( + f"{existing_summary_section}" + f"PARTIAL SUMMARIES TO MERGE:\n{combined}\n\n" + "Merge these into a single coherent summary. Keep all key facts:" + ) + await printr.print_async( + f"Merge input exceeded context budget, truncated to fit.", + color=LogType.WARNING, + server_only=True, + source_name=self._wingman_name, + source=LogSource.WINGMAN, + ) + + return await loop.run_in_executor( + None, + lambda: local_ai_service.support( + text=merge_prompt, + system_prompt=system_prompt, + preset=SamplingPreset.BALANCED, + ), + ) diff --git a/services/conversation_manager.py b/services/conversation_manager.py new file mode 100644 index 000000000..f999b5df6 --- /dev/null +++ b/services/conversation_manager.py @@ -0,0 +1,558 @@ +"""Manages the conversation message list, tool responses, and history cleanup.""" + +import json +import random +import uuid +from typing import TYPE_CHECKING, Callable, Mapping, Optional + +from openai.types.chat import ( + ChatCompletionMessage, + ChatCompletionMessageToolCall, + ParsedFunction, +) + +from api.enums import ConversationProvider, LogType +from services.printr import Printr +from services.token_utils import count_tokens, truncate_to_tokens + +if TYPE_CHECKING: + from api.interface import CommandConfig, SettingsConfig, WingmanConfig + +printr = Printr() + + +class ConversationManager: + """Owns the conversation message list, tool-response bookkeeping, and + history cleanup / trimming logic. + """ + + def __init__( + self, + config: "WingmanConfig", + settings: "SettingsConfig", + wingman_name: str, + ): + self._config = config + self._settings = settings + self._wingman_name = wingman_name + self.messages: list = [] + self.pending_tool_calls: list[str] = [] + self.conversation_summary: str = "" + + # Skill state — set once via set_skill_context(); avoids per-call kwargs + self._skills: list = [] + self._skill_registry = None + self._tool_skills: dict = {} + + def set_skill_context( + self, + skills: list, + skill_registry, + tool_skills: dict, + ) -> None: + """Wire current skill state into the manager. + + Called by WingmanSkillManager after init/enable/disable so per-call + methods don't need to receive it as arguments. + """ + self._skills = skills + self._skill_registry = skill_registry + self._tool_skills = tool_skills + + # ------------------------------------------------------------------ + # GPT / assistant response helpers + # ------------------------------------------------------------------ + + async def add_gpt_response( + self, + message, + tool_calls, + ) -> tuple[bool, bool]: + """Adds a message from GPT to the conversation history as well as + adding dummy tool responses for any tool calls. + + Args: + message (dict | ChatCompletionMessage): The message to add. + tool_calls (list): The tool calls associated with the message. + """ + # call skill hooks (only for prepared/activated skills) + for skill in self._skills: + if skill.is_prepared: + await skill.on_add_assistant_message( + message.content, message.tool_calls + ) + + # do not tamper with this message as it will lead to 400 errors! + self.messages.append(message) + + # adding dummy tool responses to prevent corrupted message history on parallel requests + # and checks if waiting response should be played + unique_tools = {} + is_waiting_response_needed = False + is_summarize_needed = False + + if tool_calls: + for tool_call in tool_calls: + if not tool_call.id: + continue + # adding a dummy tool response to get updated later + self.add_tool_response(tool_call, "Loading..", False) + + function_name = tool_call.function.name + + # Meta-tools (search_skills, activate_skill, etc.) always need a follow-up + # LLM call so it can use the newly activated tools + if self._skill_registry and self._skill_registry.is_meta_tool(function_name): + is_summarize_needed = True + elif function_name in self._tool_skills: + skill = self._tool_skills[function_name] + if await skill.is_waiting_response_needed(function_name): + is_waiting_response_needed = True + if await skill.is_summarize_needed(function_name): + is_summarize_needed = True + + unique_tools[function_name] = True + + if len(unique_tools) == 1 and "execute_command" in unique_tools: + is_waiting_response_needed = True + + return is_waiting_response_needed, is_summarize_needed + + # ------------------------------------------------------------------ + # Tool response management + # ------------------------------------------------------------------ + + def add_tool_response( + self, tool_call, response: str, completed: bool = True + ): + """Adds a tool response to the conversation history. + + Args: + tool_call (dict|ChatCompletionMessageToolCall): The tool call to add the dummy response for. + response (str): The response content. + completed (bool): Whether the tool call is complete. + """ + msg = {"role": "tool", "content": response} + if tool_call.id is not None: + msg["tool_call_id"] = tool_call.id + if tool_call.function.name is not None: + msg["name"] = tool_call.function.name + self.messages.append(msg) + + if tool_call.id and not completed: + self.pending_tool_calls.append(tool_call.id) + + async def update_tool_response(self, tool_call_id, response) -> bool: + """Updates a tool response in the conversation history. + + Args: + tool_call_id (str): The identifier of the tool call to update the response for. + response (str): The new response to set. + + Returns: + bool: True if the response was updated, False if the tool call was not found. + """ + if not tool_call_id: + return False + + index = len(self.messages) + + # go through message history to find and update the tool call + for message in reversed(self.messages): + index -= 1 + if ( + self.get_message_role(message) == "tool" + and message.get("tool_call_id") == tool_call_id + ): + message["content"] = str(response) + if tool_call_id in self.pending_tool_calls: + self.pending_tool_calls.remove(tool_call_id) + return True + + return False + + async def trim_tool_responses( + self, + max_tokens: int = 500, + is_condensing: bool = False, + ): + """Trim oversized tool responses in conversation history. + + Called after the LLM has finished processing a turn with tool calls. + The LLM already had full access to the data; this just prevents stale + bulk data from inflating the context on subsequent turns. + + If significant trimming occurs, broadcasts a condensation notification + so the client UI can display a summary indicator. + + Args: + max_tokens: Maximum token count per tool response before trimming. + is_condensing: Whether condensation is currently running (suppresses + broadcast to avoid interfering with its own cycle). + """ + total_tokens_saved = 0 + for msg in self.messages: + if self.get_message_role(msg) != "tool": + continue + content = msg.get("content", "") + if not content: + continue + token_count = count_tokens(content) + if token_count <= max_tokens: + continue + total_tokens_saved += token_count - max_tokens + trimmed = truncate_to_tokens(content, max_tokens) + msg["content"] = ( + f"{trimmed}\n\n[...trimmed from ~{token_count} to " + f"~{max_tokens} tokens for conversation history. " + f"Full response was processed.]" + ) + + # Notify the client when significant trimming occurs so the UI can + # show a "Show history" indicator explaining the token drop. + # Skip if condensation is already running to avoid interfering with + # its own started/finished broadcast cycle. + if ( + total_tokens_saved > 1000 + and printr._connection_manager + and not is_condensing + ): + from api.commands import ConversationCondensationCommand + + if self.conversation_summary: + summary = ( + f"{self.conversation_summary}\n\n---\n\n" + f"[Latest turn: tool responses trimmed — ~{total_tokens_saved:,} tokens saved]" + ) + else: + summary = ( + f"Tool responses were automatically trimmed after LLM processing.\n" + f"~{total_tokens_saved:,} tokens saved.\n\n" + f"The LLM had full access to the complete data when generating " + f"its response. Responses are trimmed afterwards to keep the " + f"conversation context efficient." + ) + + await printr._connection_manager.broadcast( + ConversationCondensationCommand( + wingman_name=self._wingman_name, + status="finished", + estimated_tokens_saved=total_tokens_saved, + summary_text=summary, + ) + ) + + # ------------------------------------------------------------------ + # User / assistant message management + # ------------------------------------------------------------------ + + async def add_user_message( + self, + content: str, + images: list[tuple[str, str]] = None, + condense_fn: Optional[Callable] = None, + ): + """Shortens the conversation history if needed and adds a user message to it. + + Args: + content (str): The message content to add. + images (list[tuple[str, str]]): Optional list of (base64_data, mime_type) tuples to attach. + condense_fn: Optional async callable invoked after cleanup (``_maybe_condense_history``). + """ + # call skill hooks (only for prepared/activated skills) + for skill in self._skills: + if skill.is_prepared: + await skill.on_add_user_message(content) + + if images: + msg_content = [] + for img_b64, mime in images: + msg_content.append({ + "type": "image_url", + "image_url": { + "url": f"data:{mime};base64,{img_b64}", + "detail": "auto", + }, + }) + msg_content.append({"type": "text", "text": content}) + msg = {"role": "user", "content": msg_content} + else: + msg = {"role": "user", "content": content} + await self.cleanup_history() + if condense_fn: + await condense_fn() + self.messages.append(msg) + + async def add_assistant_message( + self, content: str + ): + """Adds an assistant message to the conversation history. + + Args: + content (str): The message content to add. + """ + # call skill hooks (only for prepared/activated skills) + for skill in self._skills: + if skill.is_prepared: + await skill.on_add_assistant_message(content, []) + + msg = {"role": "assistant", "content": content} + self.messages.append(msg) + + async def add_forced_assistant_command_calls( + self, + commands: list["CommandConfig"], + ): + """Adds forced assistant command calls to the conversation history. + + Args: + commands (list[CommandConfig]): The commands to add. + """ + + if not commands: + return + + message = ChatCompletionMessage( + content="", + role="assistant", + tool_calls=[], + ) + tool_id_to_command = {} + for command in commands: + tool_id = None + if ( + self._config.features.conversation_provider + == ConversationProvider.OPENAI + ) or ( + self._config.features.conversation_provider + == ConversationProvider.WINGMAN_PRO + and "gpt" in self._config.wingman_pro.conversation_deployment.lower() + ): + tool_id = f"call_{str(uuid.uuid4()).replace('-', '')}" + elif ( + self._config.features.conversation_provider + == ConversationProvider.GOOGLE + ): + if ( + self._config.google.conversation_model.startswith("gemini-3") + or self._config.google.conversation_model == "gemini-flash-latest" + or self._config.google.conversation_model == "gemini-pro-latest" + or self._config.google.conversation_model + == "gemini-flash-lite-latest" + ): + # gemini 3+ (latest = 3+) needs a thought signature like this, but we cant fake it: + # { + # 'model_extra': { + # 'extra_content': { + # 'google': { + # 'thought_signature': 'EjQKMgFyyNp8mNe4bQmQhOua7gGMH0C9RubFWewy6BzYZJs5f4RqDb8CaiR4gjLxoM1iQqP4' + # } + # } + # } + # } + return + tool_id = f"function-call-{''.join(random.choices('0123456789', k=20))}" + + # early exit for unsupported providers/models + if not tool_id: + return + + tool_call = ChatCompletionMessageToolCall( + id=tool_id, + function=ParsedFunction( + name="execute_command", + arguments=json.dumps({"command_name": command.name}), + ), + type="function", + ) + message.tool_calls.append(tool_call) + tool_id_to_command[tool_id] = command + + await self.add_gpt_response( + message, message.tool_calls + ) + for tool_call in message.tool_calls: + command = tool_id_to_command[tool_call.id] + await self.update_tool_response( + tool_call.id, command.additional_context or "OK" + ) + + # ------------------------------------------------------------------ + # History cleanup and token estimation + # ------------------------------------------------------------------ + + async def cleanup_history(self): + """Cleans up the conversation history by removing messages that are too old.""" + remember_messages = self._config.features.remember_messages + + if remember_messages is None or len(self.messages) == 0: + return 0 # Configuration not set, nothing to delete. + + # Find the cutoff index where to end deletion, making sure to only count 'user' messages towards the limit starting with newest messages. + cutoff_index = len(self.messages) + user_message_count = 0 + for message in reversed(self.messages): + if self.get_message_role(message) == "user": + user_message_count += 1 + if user_message_count == remember_messages: + break # Found the cutoff point. + cutoff_index -= 1 + + # If messages below the keep limit, don't delete anything. + if user_message_count < remember_messages: + return 0 + + total_deleted_messages = cutoff_index # Messages to delete. + + # Remove the pending tool calls that are no longer needed. + for mesage in self.messages[:cutoff_index]: + if ( + self.get_message_role(mesage) == "tool" + and mesage.get("tool_call_id") in self.pending_tool_calls + ): + self.pending_tool_calls.remove(mesage.get("tool_call_id")) + if self._settings.debug_mode: + await printr.print_async( + f"Removing pending tool call {mesage.get('tool_call_id')} due to message history clean up.", + color=LogType.WARNING, + ) + + # Remove the messages before the cutoff index, exclusive of the system message. + del self.messages[:cutoff_index] + + # Optional debugging printout. + if self._settings.debug_mode and total_deleted_messages > 0: + await printr.print_async( + f"Deleted {total_deleted_messages} messages from the conversation history.", + color=LogType.WARNING, + ) + + return total_deleted_messages + + def estimate_tokens(self) -> int: + """Estimate the total token count of the current conversation history.""" + return sum(count_tokens(self._message_text_content(m)) for m in self.messages) + + # ------------------------------------------------------------------ + # Reset + # ------------------------------------------------------------------ + + async def reset(self): + """Resets the conversation message list and summary. + + Note: skill/MCP registry resets and memory extraction are the + responsibility of the caller (the Wingman) since ConversationManager + does not own those services. + """ + self.messages = [] + self.conversation_summary = "" + + # ------------------------------------------------------------------ + # Message serialisation helpers + # ------------------------------------------------------------------ + + def get_conversation_messages(self, strip_nulls: bool = True) -> list[dict]: + """Return the conversation messages as a list of plain dicts for debugging.""" + + def _strip_none(obj): + if isinstance(obj, dict): + return {k: _strip_none(v) for k, v in obj.items() if v is not None} + if isinstance(obj, list): + return [_strip_none(item) for item in obj] + return obj + + result = [] + for msg in self.messages: + if hasattr(msg, "model_dump"): + d = msg.model_dump() + else: + d = msg + if strip_nulls: + d = _strip_none(d) + result.append(d) + return result + + # ------------------------------------------------------------------ + # Text extraction / conversion helpers + # ------------------------------------------------------------------ + + def _extract_text_content(self, content) -> str: + """Extract text from message content, handling both string and multimodal list formats.""" + if isinstance(content, str): + return content + if isinstance(content, list): + # Multimodal content: extract text parts only + parts = [] + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + parts.append(part.get("text", "")) + elif isinstance(part, str): + parts.append(part) + return " ".join(parts) + return "" + + def _message_text_content(self, msg) -> str: + """Extract text content from a message for token estimation.""" + if isinstance(msg, Mapping): + return self._extract_text_content(msg.get("content", "")) or "" + elif hasattr(msg, "content"): + return self._extract_text_content(msg.content) or "" + return "" + + def _messages_to_text(self, messages: list) -> str: + """Convert a list of conversation messages to plain text for summarization.""" + lines = [] + for msg in messages: + role = self.get_message_role(msg) + content = "" + if isinstance(msg, Mapping): + content = self._extract_text_content(msg.get("content", "")) + elif hasattr(msg, "content"): + content = self._extract_text_content(msg.content) or "" + + if role == "user": + lines.append(f"User: {content}") + elif role == "assistant": + if content: + lines.append(f"Assistant: {content}") + # Include tool call info + tool_calls = None + if isinstance(msg, Mapping): + tool_calls = msg.get("tool_calls") + elif hasattr(msg, "tool_calls"): + tool_calls = msg.tool_calls + if tool_calls: + for tc in tool_calls: + fn = ( + tc.function + if hasattr(tc, "function") + else tc.get("function", {}) + ) + name = fn.name if hasattr(fn, "name") else fn.get("name", "?") + args = ( + fn.arguments + if hasattr(fn, "arguments") + else fn.get("arguments", "") + ) + lines.append(f" [Tool call: {name}({args})]") + elif role == "tool": + tool_name = ( + msg.get("name", "tool") if isinstance(msg, Mapping) else "tool" + ) + lines.append(f" [Tool result ({tool_name}): {content[:200]}]") + return "\n".join(lines) + + # ------------------------------------------------------------------ + # Role helper + # ------------------------------------------------------------------ + + def get_message_role(self, message) -> str: + """Helper method to get the role of the message regardless of its type.""" + if isinstance(message, Mapping): + return message.get("role") + elif hasattr(message, "role"): + return message.role + else: + raise TypeError( + f"Message is neither a mapping nor has a 'role' attribute: {message}" + ) diff --git a/services/instant_response_generator.py b/services/instant_response_generator.py new file mode 100644 index 000000000..b2f364fbc --- /dev/null +++ b/services/instant_response_generator.py @@ -0,0 +1,130 @@ +"""InstantResponseGenerator service. + +Generates a list of short generic filler phrases via an LLM call (used during +long tool-call turns) and provides a random non-repeating selection from that +list. +""" + +import json +import random +import traceback + +from api.enums import LogType +from services.printr import Printr + +printr = Printr() + + +class InstantResponseGenerator: + """Generates and serves generic instant-filler phrases. + + Construct once per Wingman, then call :meth:`generate` from + ``prepare()`` (wrapped in ``threaded_execution``) and + :meth:`get_random_filler` from the turn loop. + """ + + def __init__( + self, + wingman_name: str, + llm_call_fn, + get_context_fn, + ): + """ + Args: + wingman_name: Used for log messages. + llm_call_fn: Async callable matching ``actual_llm_call(messages, tools=None)`` + returning a ``ChatCompletion | None``. + get_context_fn: Async callable matching ``get_context() -> str``. + """ + self.wingman_name = wingman_name + self._llm_call = llm_call_fn + self._get_context = get_context_fn + self.instant_responses: list[str] = [] + self.last_used_instant_responses: list[int] = [] + + async def generate(self) -> None: + """Populate ``self.instant_responses`` via an LLM call. + + Called by ``Wingman.prepare()`` when + ``config.features.use_generic_instant_responses`` is enabled. + """ + context = await self._get_context() + messages = [ + { + "role": "system", + "content": """ + Generate a list in JSON format of at least 20 short direct text responses. + Make sure the response only contains the JSON, no additional text. + They must fit the described character in the given context by the user. + Every generated response must be generally usable in every situation. + Responses must show its still in progress and not in a finished state. + The user request this response is used on is unknown. Therefore it must be generic. + Good examples: + - "Processing..." + - "Stand by..." + + Bad examples: + - "Generating route..." (too specific) + - "I'm sorry, I can't do that." (too negative) + + Response example: + [ + "OK", + "Generating results...", + "Roger that!", + "Stand by..." + ] + """, + }, + {"role": "user", "content": context}, + ] + try: + completion = await self._llm_call(messages) + if completion is None: + return + if completion.choices[0].message.content: + retry_limit = 3 + retry_count = 1 + valid = False + while not valid and retry_count <= retry_limit: + try: + responses = json.loads(completion.choices[0].message.content) + valid = True + for response in responses: + if response not in self.instant_responses: + self.instant_responses.append(str(response)) + except json.JSONDecodeError: + messages.append(completion.choices[0].message) + messages.append( + { + "role": "user", + "content": "The response could not be parsed as JSON. Return only valid JSON with no additional text.", + } + ) + if retry_count <= retry_limit: + completion = await self._llm_call(messages) + retry_count += 1 + except Exception as e: + await printr.print_async( + f"Error while generating instant responses: {str(e)}", + color=LogType.ERROR, + ) + printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) + + def get_random_filler(self) -> str | None: + """Return a random non-recently-used filler phrase. + + Returns ``None`` when no responses have been generated yet. + """ + if not self.instant_responses: + return None + + if len(self.last_used_instant_responses) > 2: + self.last_used_instant_responses = self.last_used_instant_responses[-2:] + + random_index = random.randint(0, len(self.instant_responses) - 1) + while random_index in self.last_used_instant_responses: + random_index = random.randint(0, len(self.instant_responses) - 1) + + self.last_used_instant_responses.append(random_index) + return self.instant_responses[random_index] diff --git a/services/module_manager.py b/services/module_manager.py index 9771f2809..82031a566 100644 --- a/services/module_manager.py +++ b/services/module_manager.py @@ -12,21 +12,13 @@ SkillBase, SkillConfig, SkillToolInfo, - WingmanConfig, ) -from providers.faster_whisper import FasterWhisper -from providers.parakeet import Parakeet -from providers.whispercpp import Whispercpp -from providers.xvasynth import XVASynth -from services.audio_library import AudioLibrary -from services.audio_player import AudioPlayer -from services.file import get_writable_dir, get_custom_skills_dir +from services.file import get_custom_skills_dir from services.printr import Printr from skills.skill_base import Skill if TYPE_CHECKING: from wingmen.wingman import Wingman - from services.tower import Tower SKILLS_DIR = "skills" @@ -64,62 +56,6 @@ def get_module_name_and_path(module_string: str) -> tuple[str, str]: # module_path = path.join(module_path, module_name + ".py") return module_name, module_path - @staticmethod - def create_wingman_dynamically( - name: str, - config: WingmanConfig, - settings: SettingsConfig, - audio_player: AudioPlayer, - audio_library: AudioLibrary, - whispercpp: Whispercpp, - fasterwhisper: FasterWhisper, - parakeet: Parakeet, - xvasynth: XVASynth, - tower: "Tower", - ): - """Dynamically creates a Wingman instance from a module path and class name - - Args: - name (str): The name of the wingman. This is the key you gave it in the config, e.g. "atc" - config (WingmanConfig): All "general" config entries merged with the specific Wingman config settings. The Wingman takes precedence and overrides the general config. You can just add new keys to the config and they will be available here. - settings (SettingsConfig): The general user settings. - audio_player (AudioPlayer): The audio player handling the playback of audio files. - audio_library (AudioLibrary): The audio library handling the storage and retrieval of audio files. - whispercpp (Whispercpp): The Whispercpp provider for speech-to-text. - fasterwhisper (FasterWhisper): The FasterWhisper provider for speech-to-text. - parakeet (Parakeet): The Parakeet provider for speech-to-text via ONNX Runtime. - xvasynth (XVASynth): The XVASynth provider for text-to-speech. - tower (Tower): The Tower instance, that manages loaded Wingmen. - """ - - try: - # try to load from app dir first - module = import_module(config.custom_class.module) - except ModuleNotFoundError: - # split module into name and path - module_name, module_path = ModuleManager.get_module_name_and_path( - config.custom_class.module - ) - module_path = path.join(get_writable_dir(module_path), module_name + ".py") - # load from alternative absolute file path - spec = util.spec_from_file_location(module_name, module_path) - module = util.module_from_spec(spec) - spec.loader.exec_module(module) - DerivedWingmanClass = getattr(module, config.custom_class.name) - instance = DerivedWingmanClass( - name=name, - config=config, - settings=settings, - audio_player=audio_player, - audio_library=audio_library, - whispercpp=whispercpp, - fasterwhisper=fasterwhisper, - parakeet=parakeet, - xvasynth=xvasynth, - tower=tower, - ) - return instance - @staticmethod def load_skill( config: SkillConfig, settings: SettingsConfig, wingman: "Wingman" diff --git a/services/persistent_memory.py b/services/persistent_memory.py index 9a5762f1c..e73ef367b 100644 --- a/services/persistent_memory.py +++ b/services/persistent_memory.py @@ -586,6 +586,63 @@ def close(self) -> None: self._db.close() self._db = None + def get_tool_definitions(self) -> list[dict]: + """Return the OpenAI-style tool definitions for persistent memory operations. + These are exposed to the LLM whenever a PersistentMemoryService is active.""" + return [ + { + "type": "function", + "function": { + "name": "memory_remember", + "description": "Store an important fact or detail for future reference. Use when the user explicitly asks you to remember something.", + "parameters": { + "type": "object", + "properties": { + "text": { + "type": "string", + "description": "The fact or detail to remember.", + }, + }, + "required": ["text"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "memory_recall", + "description": "Search your memory for relevant information. Use when the user asks what you remember or know about a topic.", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "What to search for in memory.", + }, + }, + "required": ["query"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "memory_forget", + "description": "Remove a specific memory. Use when the user explicitly asks you to forget something.", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Description of the memory to forget.", + }, + }, + "required": ["query"], + }, + }, + }, + ] + # --- Private helpers --- def _find_duplicate(self, embedding: list[float]) -> MemoryEntry | None: diff --git a/services/platform_utils.py b/services/platform_utils.py new file mode 100644 index 000000000..8f859a155 --- /dev/null +++ b/services/platform_utils.py @@ -0,0 +1,14 @@ +"""Shared platform helpers.""" + +import sys + +_PLATFORM_MAP = {"win32": "windows", "darwin": "darwin", "linux": "linux"} + + +def normalize_platform(platform: str | None = None) -> str: + """Return a normalized platform name (``windows``/``darwin``/``linux``). + + Falls back to the raw ``sys.platform`` string if no mapping exists. + """ + raw = platform if platform is not None else sys.platform + return _PLATFORM_MAP.get(raw, raw) diff --git a/services/provider_factory.py b/services/provider_factory.py new file mode 100644 index 000000000..2e70ad351 --- /dev/null +++ b/services/provider_factory.py @@ -0,0 +1,334 @@ +"""Factory for creating provider instances from config. + +Reads the config enum values, looks up the registered adapter class from +the decorator registry, retrieves API keys via SecretKeeper, and instantiates +the provider. Each provider holds a reference to the live config object. +""" + +import traceback +from typing import TYPE_CHECKING + +from api.enums import ( + ConversationProvider, + ImageGenerationProvider, + SttProvider, + TtsProvider, + WingmanInitializationErrorType, +) +from api.interface import WingmanInitializationError +from providers.interfaces import ( + LlmInterface, + SttInterface, + TtsInterface, + Validatable, + get_llm_class, + get_stt_class, + get_tts_class, +) +from services.printr import Printr + +if TYPE_CHECKING: + from api.interface import SettingsConfig, WingmanConfig + from services.secret_keeper import SecretKeeper + +printr = Printr() + + +# Import all provider modules so their decorators run and populate the registries. +# These imports have no other side effects. +import providers.faster_whisper # noqa: F401 +import providers.parakeet # noqa: F401 +import providers.whispercpp # noqa: F401 +import providers.open_ai # noqa: F401 +import providers.google # noqa: F401 +import providers.x_ai # noqa: F401 +import providers.elevenlabs # noqa: F401 +import providers.edge # noqa: F401 +import providers.hume # noqa: F401 +import providers.inworld # noqa: F401 +import providers.pocket_tts # noqa: F401 +import providers.xvasynth # noqa: F401 +import providers.wingman_subscription # noqa: F401 + + +class ProviderFactory: + """Creates STT, TTS, and LLM provider instances from config.""" + + def __init__( + self, + config: "WingmanConfig", + settings: "SettingsConfig", + secret_keeper: "SecretKeeper", + shared_providers: dict, + wingman_name: str, + ): + self._config = config + self._settings = settings + self._secret_keeper = secret_keeper + self._shared = shared_providers + self._wingman_name = wingman_name + + async def _retrieve_secret( + self, requester: str, errors: list[WingmanInitializationError] + ) -> str | None: + """Retrieve an API key, adding to errors if missing.""" + secret = await self._secret_keeper.retrieve( + requester=requester, + key=requester, + prompt_if_missing=True, + ) + if not secret: + errors.append( + WingmanInitializationError( + wingman_name=self._wingman_name, + message=f"Missing API key for '{requester}'.", + error_type=WingmanInitializationErrorType.MISSING_SECRET, + ) + ) + return secret + + async def create_stt( + self, errors: list[WingmanInitializationError] + ) -> SttInterface | None: + """Create the STT provider from config.""" + stt_enum = self._config.features.stt_provider + # Shared singleton providers — wrap in adapter + if stt_enum == SttProvider.FASTER_WHISPER: + from providers.faster_whisper import FasterWhisperStt + return FasterWhisperStt( + shared=self._shared["fasterwhisper"], + config=self._config, + wingman_name=self._wingman_name, + ) + elif stt_enum == SttProvider.PARAKEET: + from providers.parakeet import ParakeetStt + return ParakeetStt(shared=self._shared["parakeet"], config=self._config) + elif stt_enum == SttProvider.WHISPERCPP: + from providers.whispercpp import WhispercppStt + return WhispercppStt(shared=self._shared["whispercpp"], config=self._config) + elif stt_enum == SttProvider.OPENAI: + api_key = await self._retrieve_secret("openai", errors) + if not api_key: + return None + from providers.open_ai import OpenAi, OpenAiStt + openai = OpenAi(api_key=api_key, organization=self._config.openai.organization) + return OpenAiStt(openai_instance=openai) + elif stt_enum == SttProvider.GROQ: + api_key = await self._retrieve_secret("groq", errors) + if not api_key: + return None + from providers.open_ai import OpenAi, GroqStt + groq = OpenAi(api_key=api_key, base_url=self._config.groq.endpoint) + return GroqStt(openai_instance=groq) + elif stt_enum == SttProvider.AZURE: + api_key = await self._retrieve_secret("azure", errors) + if not api_key: + return None + from providers.open_ai import OpenAiAzure, AzureWhisperStt + return AzureWhisperStt( + azure_instance=OpenAiAzure(), api_key=api_key, config=self._config + ) + elif stt_enum == SttProvider.AZURE_SPEECH: + api_key = await self._retrieve_secret("azure", errors) + if not api_key: + return None + from providers.open_ai import OpenAiAzure, AzureSpeechStt + return AzureSpeechStt( + azure_instance=OpenAiAzure(), api_key=api_key, config=self._config + ) + elif stt_enum == SttProvider.WINGMAN_PRO: + from providers.wingman_subscription import WingmanSubscription, WingmanSubscriptionStt + ws = WingmanSubscription( + wingman_name=self._wingman_name, + settings=self._settings.wingman_pro, + ) + return WingmanSubscriptionStt(ws_instance=ws, config=self._config) + return None + + async def create_tts( + self, errors: list[WingmanInitializationError] + ) -> TtsInterface | None: + """Create the TTS provider from config.""" + tts_enum = self._config.features.tts_provider + if tts_enum == TtsProvider.EDGE_TTS: + from providers.edge import EdgeTts + return EdgeTts(config=self._config) + elif tts_enum == TtsProvider.ELEVENLABS: + api_key = await self._retrieve_secret("elevenlabs", errors) + if not api_key: + return None + from providers.elevenlabs import ElevenLabs, ElevenLabsTts + elevenlabs = ElevenLabs(api_key=api_key, wingman_name=self._wingman_name) + return ElevenLabsTts(elevenlabs_instance=elevenlabs, config=self._config) + elif tts_enum == TtsProvider.HUME: + api_key = await self._retrieve_secret("hume", errors) + if not api_key: + return None + from providers.hume import Hume, HumeTts + hume = Hume(api_key=api_key, wingman_name=self._wingman_name) + return HumeTts(hume_instance=hume, config=self._config) + elif tts_enum == TtsProvider.INWORLD: + api_key = await self._retrieve_secret("inworld", errors) + if not api_key: + return None + from providers.inworld import Inworld, InworldTts + inworld = Inworld(api_key=api_key, wingman_name=self._wingman_name) + return InworldTts(inworld_instance=inworld, config=self._config) + elif tts_enum == TtsProvider.OPENAI: + api_key = await self._retrieve_secret("openai", errors) + if not api_key: + return None + from providers.open_ai import OpenAi, OpenAiTts + openai = OpenAi(api_key=api_key, organization=self._config.openai.organization) + return OpenAiTts(openai_instance=openai, config=self._config) + elif tts_enum == TtsProvider.OPENAI_COMPATIBLE: + api_key = await self._retrieve_secret("openai_compatible", errors) + # api_key might be optional for local endpoints + from providers.open_ai import OpenAiCompatibleTts, OpenAiCompatibleTtsAdapter + tts = OpenAiCompatibleTts( + api_key=api_key or "", + base_url=self._config.openai_compatible_tts.endpoint, + ) + return OpenAiCompatibleTtsAdapter(tts_instance=tts, config=self._config) + elif tts_enum == TtsProvider.AZURE: + api_key = await self._retrieve_secret("azure", errors) + if not api_key: + return None + from providers.open_ai import OpenAiAzure, AzureTts + return AzureTts( + azure_instance=OpenAiAzure(), api_key=api_key, config=self._config + ) + elif tts_enum == TtsProvider.XVASYNTH: + from providers.xvasynth import XVASynthTts + return XVASynthTts(shared=self._shared["xvasynth"], config=self._config) + elif tts_enum == TtsProvider.POCKET_TTS: + from providers.pocket_tts import PocketTtsTts + return PocketTtsTts(shared=self._shared["pocket_tts"], config=self._config) + elif tts_enum == TtsProvider.WINGMAN_PRO: + from providers.wingman_subscription import WingmanSubscription, WingmanSubscriptionTts + ws = WingmanSubscription( + wingman_name=self._wingman_name, + settings=self._settings.wingman_pro, + ) + return WingmanSubscriptionTts(ws_instance=ws, config=self._config) + return None + + async def create_llm( + self, errors: list[WingmanInitializationError] + ) -> LlmInterface | None: + """Create the LLM provider from config.""" + llm_enum = self._config.features.conversation_provider + if llm_enum == ConversationProvider.OPENAI: + api_key = await self._retrieve_secret("openai", errors) + if not api_key: + return None + from providers.open_ai import OpenAi, OpenAiLlm + openai = OpenAi(api_key=api_key, organization=self._config.openai.organization) + return OpenAiLlm(openai_instance=openai, config=self._config) + elif llm_enum == ConversationProvider.MISTRAL: + api_key = await self._retrieve_secret("mistral", errors) + if not api_key: + return None + from providers.open_ai import OpenAi, MistralLlm + mistral = OpenAi(api_key=api_key, base_url=self._config.mistral.endpoint) + return MistralLlm(openai_instance=mistral, config=self._config) + elif llm_enum == ConversationProvider.GROQ: + api_key = await self._retrieve_secret("groq", errors) + if not api_key: + return None + from providers.open_ai import OpenAi, GroqLlm + groq = OpenAi(api_key=api_key, base_url=self._config.groq.endpoint) + return GroqLlm(openai_instance=groq, config=self._config) + elif llm_enum == ConversationProvider.CEREBRAS: + api_key = await self._retrieve_secret("cerebras", errors) + if not api_key: + return None + from providers.open_ai import OpenAi, CerebrasLlm + cerebras = OpenAi(api_key=api_key, base_url=self._config.cerebras.endpoint) + return CerebrasLlm(openai_instance=cerebras, config=self._config) + elif llm_enum == ConversationProvider.GOOGLE: + api_key = await self._retrieve_secret("google", errors) + if not api_key: + return None + from providers.google import GoogleGenAI, GoogleLlm + google = GoogleGenAI(api_key=api_key) + return GoogleLlm(google_instance=google, config=self._config) + elif llm_enum == ConversationProvider.OPENROUTER: + api_key = await self._retrieve_secret("openrouter", errors) + if not api_key: + return None + from providers.open_ai import OpenAi, OpenRouterLlm + openrouter = OpenAi(api_key=api_key, base_url=self._config.openrouter.endpoint) + supports_tools = await self._check_openrouter_tool_support(api_key) + return OpenRouterLlm( + openai_instance=openrouter, config=self._config, + supports_tools=supports_tools, + ) + elif llm_enum == ConversationProvider.LOCAL_LLM: + from providers.open_ai import OpenAi, LocalLlm + local_llm = None + if self._config.local_llm.endpoint: + local_llm = OpenAi( + api_key="not-needed", + base_url=self._config.local_llm.endpoint, + ) + return LocalLlm(openai_instance=local_llm, config=self._config) + elif llm_enum == ConversationProvider.AZURE: + api_key = await self._retrieve_secret("azure", errors) + if not api_key: + return None + from providers.open_ai import OpenAiAzure, AzureLlm + return AzureLlm( + azure_instance=OpenAiAzure(), api_key=api_key, config=self._config + ) + elif llm_enum == ConversationProvider.WINGMAN_PRO: + from providers.wingman_subscription import WingmanSubscription, WingmanSubscriptionLlm + ws = WingmanSubscription( + wingman_name=self._wingman_name, + settings=self._settings.wingman_pro, + ) + return WingmanSubscriptionLlm(ws_instance=ws, config=self._config) + elif llm_enum == ConversationProvider.PERPLEXITY: + api_key = await self._retrieve_secret("perplexity", errors) + if not api_key: + return None + from providers.open_ai import OpenAi, PerplexityLlm + perplexity = OpenAi(api_key=api_key, base_url=self._config.perplexity.endpoint) + return PerplexityLlm(openai_instance=perplexity, config=self._config) + elif llm_enum == ConversationProvider.XAI: + api_key = await self._retrieve_secret("xai", errors) + if not api_key: + return None + from providers.x_ai import XAi, XAiLlm + xai = XAi(api_key=api_key, base_url=self._config.xai.endpoint) + return XAiLlm(xai_instance=xai, config=self._config) + return None + + async def _check_openrouter_tool_support(self, api_key: str) -> bool: + """Check if the configured OpenRouter model supports tools. + + Replicates the logic from OpenAiWingman.validate_and_set_openrouter(). + """ + try: + import asyncio + import requests + + model = self._config.openrouter.conversation_model + + def _fetch(): + return requests.get( + f"https://openrouter.ai/api/v1/models/{model}", + headers={"Authorization": f"Bearer {api_key}"}, + timeout=10, + ) + + response = await asyncio.to_thread(_fetch) + if response.status_code == 200: + result = response.json() + supported_params = result.get("data", {}).get( + "supported_parameters", [] + ) + return "tools" in supported_params + except Exception: + pass + return False diff --git a/services/skill_local_ai.py b/services/skill_local_ai.py index e03fe8c5e..cfaaa6119 100644 --- a/services/skill_local_ai.py +++ b/services/skill_local_ai.py @@ -16,7 +16,7 @@ if TYPE_CHECKING: from services.local_ai_service import TokenBudget - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext printr = Printr() @@ -109,7 +109,7 @@ class SkillLocalAI: 4. On success: convert internal types to facade types and return """ - def __init__(self, wingman: "OpenAiWingman"): + def __init__(self, wingman: "WingmanContext"): self._wingman = wingman # ── Availability ────────────────────────────────────────────── diff --git a/services/stt_provider_manager.py b/services/stt_provider_manager.py index ed26e76fc..5b72b7b0c 100644 --- a/services/stt_provider_manager.py +++ b/services/stt_provider_manager.py @@ -1,6 +1,5 @@ import asyncio import os -import platform from typing import Awaitable, Callable, Optional from api.enums import LogType, VoiceActivationSttProvider @@ -74,9 +73,6 @@ async def _initialize_parakeet( """Download and initialize Parakeet.""" pk_settings = self.settings_service.settings.voice_activation.parakeet - # Auto-detect CUDA and update execution_provider in settings - self._auto_detect_execution_provider(pk_settings) - # Download model variant = pk_settings.model_variant repo_id = PARAKEET_REPO_MAP.get(variant) @@ -142,41 +138,6 @@ async def _initialize_fasterwhisper( await self._health_check_fasterwhisper() - def _auto_detect_execution_provider(self, pk_settings): - """Auto-detect CUDA and set execution_provider if still on default (cpu).""" - if pk_settings.execution_provider != "cpu": - # User has manually set a non-default provider — respect it - self.printr.print( - f"Parakeet execution_provider already set to '{pk_settings.execution_provider}', skipping auto-detection.", - server_only=True, - color=LogType.INFO, - ) - return - - if platform.system() == "Darwin": - # macOS — always CPU (CoreML excluded for TDT models) - pk_settings.execution_provider = "cpu" - return - - if self.system_manager.is_cuda_available(): - pk_settings.execution_provider = "cuda" - gpu_name = self.system_manager.get_gpu_name() or "Unknown GPU" - self.printr.print( - f"CUDA detected ({gpu_name}). Setting Parakeet to CUDA execution provider.", - server_only=True, - color=LogType.POSITIVE, - ) - else: - pk_settings.execution_provider = "cpu" - self.printr.print( - "No CUDA available. Parakeet will use CPU execution provider.", - server_only=True, - color=LogType.INFO, - ) - - # Persist to settings - self.settings_service.save_settings_to_disk() - async def switch_provider(self, new_provider: VoiceActivationSttProvider): """Switch active STT provider. Unloads old, downloads + loads new.""" old_provider = self.active_provider diff --git a/services/threading_utils.py b/services/threading_utils.py new file mode 100644 index 000000000..a24dd674b --- /dev/null +++ b/services/threading_utils.py @@ -0,0 +1,42 @@ +"""Threading helpers used across wingmen and skills.""" + +import asyncio +import threading +import traceback + +from api.enums import LogType +from services.printr import Printr + +printr = Printr() + + +def threaded_execution(function, *args) -> threading.Thread | None: + """Run ``function`` in a fresh daemon thread. + + If ``function`` is a coroutine function, a new event loop is created + inside the thread to run it. Otherwise it is called directly. + + Returns the started thread, or ``None`` on failure. + """ + try: + + def start_thread(function, *args): + if asyncio.iscoroutinefunction(function): + new_loop = asyncio.new_event_loop() + asyncio.set_event_loop(new_loop) + new_loop.run_until_complete(function(*args)) + new_loop.close() + else: + function(*args) + + thread = threading.Thread(target=start_thread, args=(function, *args)) + thread.name = function.__name__ + thread.daemon = True + thread.start() + return thread + except Exception as e: + printr.print( + f"Error starting threaded execution: {str(e)}", color=LogType.ERROR + ) + printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) + return None diff --git a/services/tool_executor.py b/services/tool_executor.py new file mode 100644 index 000000000..b4cedcb66 --- /dev/null +++ b/services/tool_executor.py @@ -0,0 +1,518 @@ +"""Tool call dispatch -- routes function calls to memory, skills, MCP, commands, etc.""" + +import json +import time +import traceback +from typing import TYPE_CHECKING, Any, Callable, Awaitable + +from api.enums import LogType +from services.benchmark import Benchmark +from services.printr import Printr +from services.tool_response_cache import ToolResponseCompressor + +if TYPE_CHECKING: + from api.interface import CommandConfig, WingmanConfig, SettingsConfig + from services.capability_registry import CapabilityRegistry + from services.mcp_registry import McpRegistry + from services.persistent_memory import PersistentMemoryService + from services.skill_registry import SkillRegistry + from skills.skill_base import Skill + +printr = Printr() + + +class ToolExecutor: + """Dispatches tool calls to the appropriate handler (memory, skills, MCP, commands). + + This is a stateless dispatcher -- all mutable state (registries, services, callbacks) + is passed as parameters to ``handle_tool_calls`` / ``execute_by_function_call``. + """ + + def __init__( + self, + config: "WingmanConfig", + settings: "SettingsConfig", + wingman_name: str, + ): + self._config = config + self._settings = settings + self._wingman_name = wingman_name + self._tool_response_compressor = ToolResponseCompressor() + + # ------------------------------------------------------------------ + # fix_tool_calls (was _fix_tool_calls) + # ------------------------------------------------------------------ + + async def fix_tool_calls( + self, + tool_calls, + get_command_fn: Callable[[str], "CommandConfig | None"], + ): + """Fixes tool calls that have a command name as function name. + + Mistral sometimes returns the command name directly as the function name + instead of wrapping it in ``execute_command``. This method detects that + pattern and rewrites the tool call accordingly. + + Args: + tool_calls: The tool calls to fix. + get_command_fn: Callback ``(name) -> CommandConfig | None`` used to + check whether a string is a known command name. + + Returns: + list: The fixed tool calls. + """ + if tool_calls and len(tool_calls) > 0: + for tool_call in tool_calls: + function_name = tool_call.function.name + function_args = ( + tool_call.function.arguments + # Mistral returns a dict + if isinstance(tool_call.function.arguments, dict) + # OpenAI returns a string + else json.loads(tool_call.function.arguments) + ) + + # try to resolve function name to a command name + if (len(function_args) == 0 and get_command_fn(function_name)) or ( + len(function_args) == 1 + and "command_name" in function_args + and get_command_fn(function_args["command_name"]) + and function_name == function_args["command_name"] + ): + function_args["command_name"] = function_name + function_name = "execute_command" + + # update the tool call + tool_call.function.name = function_name + tool_call.function.arguments = json.dumps(function_args) + + if self._settings.debug_mode: + await printr.print_async( + "Applied command call fix.", color=LogType.WARNING + ) + + return tool_calls + + # ------------------------------------------------------------------ + # handle_tool_calls (was _handle_tool_calls) + # ------------------------------------------------------------------ + + async def handle_tool_calls( + self, + tool_calls, + *, + tool_skills: dict, + skill_registry: "SkillRegistry", + mcp_registry: "McpRegistry", + capability_registry: "CapabilityRegistry", + persistent_memory_service: "PersistentMemoryService | None", + get_command_fn: Callable[[str], "CommandConfig | None"], + execute_command_fn: Callable[["CommandConfig", bool], Awaitable[tuple]], + play_to_user_fn: Callable[[str], Awaitable[None]], + local_ai_service, + update_tool_response_fn: Callable[[str, str], Awaitable[bool]], + add_tool_response_fn: Callable, + pending_tool_calls: list, + ): + """Processes all the tool calls identified in the response message. + + Args: + tool_calls: The list of tool calls to process. + tool_skills: Mapping of tool name -> Skill instance. + skill_registry: For skill search/activation. + mcp_registry: For MCP tool execution. + capability_registry: For capability management. + persistent_memory_service: For memory tools (may be None). + get_command_fn: Callback to get a Command by name. + execute_command_fn: Callback to execute a Command. + play_to_user_fn: Callback to play audio response. + local_ai_service: For tool response compression. + update_tool_response_fn: Callback to update an existing tool + response in conversation history. + add_tool_response_fn: Callback to add a new tool response to + conversation history. + pending_tool_calls: List reference for tracking pending calls. + + Returns: + tuple: (instant_response, skill, tool_timings) where tool_timings + is a list of (label, time_ms) tuples. + """ + instant_response = None + function_response = "" + tool_timings: list[tuple[str, float]] = [] + + skill = None + + for tool_call in tool_calls: + try: + function_name = tool_call.function.name + function_args = ( + tool_call.function.arguments + # Mistral returns a dict + if isinstance(tool_call.function.arguments, dict) + # OpenAI returns a string + else json.loads(tool_call.function.arguments) + ) + + # Time the individual tool execution + tool_start = time.perf_counter() + ( + function_response, + instant_response, + skill, + tool_label, + ) = await self.execute_by_function_call( + function_name, + function_args, + tool_skills=tool_skills, + skill_registry=skill_registry, + mcp_registry=mcp_registry, + capability_registry=capability_registry, + persistent_memory_service=persistent_memory_service, + get_command_fn=get_command_fn, + execute_command_fn=execute_command_fn, + play_to_user_fn=play_to_user_fn, + ) + tool_time_ms = (time.perf_counter() - tool_start) * 1000 + + # Add timing if we got a label (actual tool execution, not meta-tool) + if tool_label: + tool_timings.append((tool_label, tool_time_ms)) + + # Compress large tool responses via local AI before the cloud LLM sees them + if ( + tool_call.id + and self._config.features.compress_tool_responses + and local_ai_service + and local_ai_service.is_ready() + and self._tool_response_compressor.should_compress( + str(function_response) + ) + ): + function_response = await self._tool_response_compressor.compress( + response_text=str(function_response), + local_ai_service=local_ai_service, + wingman_name=self._wingman_name, + tool_name=function_name, + ) + + if tool_call.id: + # updating the dummy tool response with the actual response + await update_tool_response_fn(tool_call.id, function_response) + else: + # adding a new tool response + add_tool_response_fn(tool_call, function_response) + except Exception as e: + if tool_call.id: + await update_tool_response_fn(tool_call.id, "Error") + else: + add_tool_response_fn(tool_call, "Error") + await printr.print_async( + f"Error while processing tool call: {str(e)}", color=LogType.ERROR + ) + printr.print( + traceback.format_exc(), color=LogType.ERROR, server_only=True + ) + return instant_response, skill, tool_timings + + # ------------------------------------------------------------------ + # execute_by_function_call (was execute_command_by_function_call) + # ------------------------------------------------------------------ + + async def execute_by_function_call( + self, + function_name: str, + function_args: dict[str, Any], + *, + tool_skills: dict, + skill_registry: "SkillRegistry", + mcp_registry: "McpRegistry", + capability_registry: "CapabilityRegistry", + persistent_memory_service: "PersistentMemoryService | None", + get_command_fn: Callable[[str], "CommandConfig | None"], + execute_command_fn: Callable[["CommandConfig", bool], Awaitable[tuple]], + play_to_user_fn: Callable[[str], Awaitable[None]], + ) -> tuple[str, str | None, "Skill | None", str | None]: + """Dispatches a single function call to the appropriate handler. + + Uses an OpenAI function call to execute a command. If it's an instant + activation_command, one of its responses will be played. + + Args: + function_name: The name of the function to be executed. + function_args: The arguments to pass to the function being executed. + tool_skills: Mapping of tool name -> Skill instance. + skill_registry: For skill search/activation and display names. + mcp_registry: For MCP tool execution and meta-tools. + capability_registry: For capability meta-tools. + persistent_memory_service: For memory tools (may be None). + get_command_fn: Callback ``(name) -> CommandConfig | None``. + execute_command_fn: Callback to execute a Command. + play_to_user_fn: Callback to play audio response. + + Returns: + A tuple containing: + - function_response (str): The text response or result obtained + after executing the function. + - instant_response (str | None): An immediate response or action + to be taken, if any (e.g., play audio). + - used_skill (Skill | None): The skill that was used, if any. + - tool_label (str | None): Label for benchmark timing + (e.g., "MCP: resolve-library-id"), or None for meta-tools. + """ + function_response = "" + instant_response = "" + used_skill = None + tool_label = None + + # ── 1. Persistent memory tools ────────────────────────────── + if ( + function_name in ("memory_remember", "memory_recall", "memory_forget") + and persistent_memory_service + ): + if function_name == "memory_remember": + text = function_args.get("text", "") + if text: + await persistent_memory_service.add_memory( + entry_type="fact", content=text + ) + function_response = f'I\'ll remember that: "{text}"' + await printr.print_async( + f"Memory stored: {text}", + color=LogType.MEMORY, + source_name=self._wingman_name, + ) + else: + function_response = "Nothing to remember -- no text provided." + + elif function_name == "memory_recall": + query = function_args.get("query", "") + if query: + results = await persistent_memory_service.search( + query, limit=10 + ) + if results: + lines = [f"- {r.content}" for r in results] + function_response = ( + "Here's what I remember:\n" + "\n".join(lines) + ) + else: + function_response = ( + "I don't have any memories matching that." + ) + else: + function_response = "No query provided for memory recall." + + elif function_name == "memory_forget": + query = function_args.get("query", "") + if query: + deleted = await persistent_memory_service.forget_by_query(query) + if deleted: + function_response = ( + f'Done -- I\'ve forgotten the memory related to "{query}".' + ) + else: + function_response = ( + "I couldn't find a memory closely matching that to forget." + ) + else: + function_response = "No query provided for memory forget." + + return function_response, None, None, f"💾 memory: {function_name}" + + # ── 2. Unified capability meta-tools ──────────────────────── + if capability_registry.is_meta_tool(function_name): + function_response, tools_changed = ( + await capability_registry.execute_meta_tool( + function_name, function_args + ) + ) + + # If a skill was activated, perform lazy validation + if tools_changed and function_name == "activate_capability": + capability_name = function_args.get("capability_name", "") + skill = skill_registry.get_skill_for_activation(capability_name) + if skill and skill.needs_activation(): + success, validation_msg = await skill.ensure_activated() + if not success: + # Validation failed -- deactivate the skill + skill_registry.deactivate_skill(capability_name) + function_response = validation_msg + tools_changed = False + await printr.print_async( + f"Skill activation failed: {capability_name}", + color=LogType.ERROR, + ) + else: + # Get display name for user-friendly message + display_name = skill_registry.get_skill_display_name( + capability_name + ) + await printr.print_async( + f"Skill activated: {display_name}", + color=LogType.SKILL, + ) + + return function_response, None, None, None # Meta-tool, no timing label + + # ── 3. Legacy skill meta-tools ────────────────────────────── + if skill_registry.is_meta_tool(function_name): + function_response, tools_changed = ( + await skill_registry.execute_meta_tool(function_name, function_args) + ) + + # If skill was activated, perform lazy validation + if tools_changed and function_name == "activate_skill": + skill_name = function_args.get("skill_name", "") + skill = skill_registry.get_skill_for_activation(skill_name) + if skill and skill.needs_activation(): + success, validation_msg = await skill.ensure_activated() + if not success: + # Validation failed -- deactivate the skill + skill_registry.deactivate_skill(skill_name) + function_response = validation_msg + tools_changed = False + await printr.print_async( + f"Skill activation failed: {skill_name}", + color=LogType.ERROR, + ) + else: + # Get display name for user-friendly message + display_name = skill_registry.get_skill_display_name( + skill_name + ) + await printr.print_async( + f"Skill activated: {display_name}", + color=LogType.SKILL, + ) + + return function_response, None, None, None # Meta-tool, no timing label + + # ── 4. MCP meta-tools ─────────────────────────────────────── + if mcp_registry.is_meta_tool(function_name): + function_response, tools_changed = ( + await mcp_registry.execute_meta_tool(function_name, function_args) + ) + return function_response, None, None, None # Meta-tool, no timing label + + # ── 5. MCP server tools (prefixed with mcp_) ──────────────── + if mcp_registry.is_mcp_tool(function_name): + connection = mcp_registry.get_connection_for_tool(function_name) + if connection: + display_name = connection.config.display_name + original_name = mcp_registry.get_original_tool_name(function_name) + tool_label = f"🌐 {display_name}: {original_name}" + + benchmark = Benchmark( + f"MCP '{connection.config.name}' - {original_name}" + ) + + # Always show simple 'called' message in UI + await printr.print_async( + f"{display_name}: called `{original_name}` with {function_args}", + color=LogType.MCP, + ) + + # Detailed 'calling' log only in terminal/log file + await printr.print_async( + f"{display_name}: calling `{original_name}` with {function_args}...", + color=LogType.MCP, + server_only=True, + ) + + try: + function_response = await mcp_registry.call_tool( + function_name, function_args + ) + except Exception as e: + await printr.print_async( + f"{display_name}: `{original_name}` failed - {str(e)}", + color=LogType.ERROR, + ) + printr.print( + traceback.format_exc(), + color=LogType.ERROR, + server_only=True, + ) + function_response = "ERROR DURING MCP TOOL EXECUTION" + finally: + # Detailed 'completed' with timing only in terminal/log file + await printr.print_async( + f"{display_name}: `{original_name}` completed", + color=LogType.MCP, + benchmark_result=benchmark.finish(), + server_only=not self._settings.debug_mode, + ) + + return function_response, None, None, tool_label + + # ── 6. Command execution ──────────────────────────────────── + if function_name == "execute_command": + # get the command based on the argument passed by the LLM + command = get_command_fn(function_args["command_name"]) + # execute the command + instant_response, function_response = await execute_command_fn( + command + ) + tool_label = ( + f"Command: {function_args.get('command_name', function_name)}" + ) + # if the command has responses, we have to play one of them + if instant_response: + await play_to_user_fn(instant_response) + + # ── 7. Skill tool execution ───────────────────────────────── + if function_name in tool_skills: + skill = tool_skills[function_name] + display_name = skill_registry.get_skill_display_name(skill.name) + tool_label = f"⚡ {display_name}: {function_name}" + + benchmark = Benchmark(f"Skill '{skill.name}' - {function_name}") + + # Always show simple 'called' message in UI + await printr.print_async( + f"{display_name}: called `{function_name}` with {function_args}", + color=LogType.SKILL, + skill_name=skill.name, + ) + + # Detailed 'calling' log only in terminal/log file + await printr.print_async( + f"{display_name}: calling `{function_name}` with {function_args}...", + color=LogType.SKILL, + skill_name=skill.name, + server_only=True, + ) + + try: + function_response, instant_response = await skill.execute_tool( + tool_name=function_name, + parameters=function_args, + benchmark=benchmark, + ) + used_skill = skill + if instant_response: + await play_to_user_fn(instant_response) + except Exception as e: + await printr.print_async( + f"{display_name}: `{function_name}` failed - {str(e)}", + color=LogType.ERROR, + ) + printr.print( + traceback.format_exc(), color=LogType.ERROR, server_only=True + ) + function_response = ( + "ERROR DURING PROCESSING" # hints to AI that there was an error + ) + instant_response = None + finally: + await printr.print_async( + f"{display_name}: `{function_name}` completed", + color=LogType.SKILL, + benchmark_result=benchmark.finish(), + skill_name=skill.name, + server_only=not self._settings.debug_mode, + ) + + return function_response, instant_response, used_skill, tool_label diff --git a/services/tower.py b/services/tower.py index fc8b3eaa5..8baea4a0d 100644 --- a/services/tower.py +++ b/services/tower.py @@ -16,7 +16,6 @@ from services.audio_player import AudioPlayer from services.audio_library import AudioLibrary from services.config_manager import ConfigManager -from services.module_manager import ModuleManager from services.printr import Printr from wingmen.open_ai_wingman import OpenAiWingman from wingmen.wingman import Wingman @@ -109,35 +108,19 @@ async def __instantiate_wingman( ): wingman = None try: - # it's a custom Wingman - if wingman_config.custom_class: - wingman = ModuleManager.create_wingman_dynamically( - name=wingman_name, - config=wingman_config, - settings=settings, - audio_player=self.audio_player, - audio_library=self.audio_library, - whispercpp=self.whispercpp, - fasterwhisper=self.fasterwhisper, - parakeet=self.parakeet, - xvasynth=self.xvasynth, - pocket_tts=self.pocket_tts, - tower=self, - ) - else: - wingman = OpenAiWingman( - name=wingman_name, - config=wingman_config, - settings=settings, - audio_player=self.audio_player, - audio_library=self.audio_library, - whispercpp=self.whispercpp, - fasterwhisper=self.fasterwhisper, - parakeet=self.parakeet, - xvasynth=self.xvasynth, - pocket_tts=self.pocket_tts, - tower=self, - ) + wingman = OpenAiWingman( + name=wingman_name, + config=wingman_config, + settings=settings, + audio_player=self.audio_player, + audio_library=self.audio_library, + whispercpp=self.whispercpp, + fasterwhisper=self.fasterwhisper, + parakeet=self.parakeet, + xvasynth=self.xvasynth, + pocket_tts=self.pocket_tts, + tower=self, + ) except FileNotFoundError as e: # pylint: disable=broad-except wingman_config.disabled = True self.disabled_wingmen.append(wingman_config) diff --git a/services/turn_metrics.py b/services/turn_metrics.py new file mode 100644 index 000000000..6aca93cce --- /dev/null +++ b/services/turn_metrics.py @@ -0,0 +1,108 @@ +"""Per-turn benchmark snapshot building and token-usage broadcasting.""" + +from api.enums import ConversationProvider +from api.interface import BenchmarkResult, WingmanConfig +from services.benchmark import Benchmark, format_ms +from services.printr import Printr +from services.token_utils import count_tokens + +printr = Printr() + + +class TurnMetrics: + """Focused service for per-turn benchmark snapshots and token-usage broadcast.""" + + def __init__( + self, + wingman_name: str, + config: WingmanConfig, + conversation, + ): + self.wingman_name = wingman_name + self.config = config + self.conversation = conversation + self.last_turn_prompt_tokens: int = 0 + self.last_turn_completion_tokens: int = 0 + + # ──────────────────────────── public API ─────────────────────────── # + + def add_benchmark_snapshot( + self, benchmark: Benchmark, label: str, execution_time_ms: float + ) -> None: + benchmark.snapshots.append( + BenchmarkResult( + label=label, + execution_time_ms=execution_time_ms, + formatted_execution_time=format_ms(execution_time_ms), + ) + ) + + def add_tool_execution_snapshot( + self, + benchmark: Benchmark, + total_time_ms: float, + tool_timings: list[tuple[str, float]], + ) -> None: + nested_snapshots = [ + BenchmarkResult( + label=label, + execution_time_ms=time_ms, + formatted_execution_time=format_ms(time_ms), + ) + for label, time_ms in tool_timings + ] + + benchmark.snapshots.append( + BenchmarkResult( + label="Tool Execution", + execution_time_ms=total_time_ms, + formatted_execution_time=format_ms(total_time_ms), + snapshots=nested_snapshots or None, + ) + ) + + async def broadcast_token_usage( + self, prompt_tokens: int, completion_tokens: int + ) -> None: + is_local = ( + self.config.features.conversation_provider == ConversationProvider.LOCAL_LLM + ) + + if is_local and prompt_tokens == 0: + prompt_tokens = sum( + count_tokens( + msg["content"] + if isinstance(msg.get("content"), str) + else str(msg.get("content", "")) + ) + for msg in self.conversation.messages + ) + if is_local and completion_tokens == 0 and self.conversation.messages: + last = self.conversation.messages[-1] + if last.get("role") == "assistant": + content = last.get("content", "") + completion_tokens = count_tokens( + content if isinstance(content, str) else str(content) + ) + + self.last_turn_prompt_tokens = prompt_tokens + self.last_turn_completion_tokens = completion_tokens + if prompt_tokens == 0 and completion_tokens == 0: + return + if not printr._connection_manager: + return + + from api.commands import ConversationTokenUsageCommand + + await printr._connection_manager.broadcast( + ConversationTokenUsageCommand( + wingman_name=self.wingman_name, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + is_local=is_local, + ) + ) + + def reset_token_counters(self) -> None: + self.last_turn_prompt_tokens = 0 + self.last_turn_completion_tokens = 0 diff --git a/services/voice_service.py b/services/voice_service.py index 27cc55825..634156151 100644 --- a/services/voice_service.py +++ b/services/voice_service.py @@ -18,7 +18,7 @@ from providers.hume import Hume from providers.inworld import Inworld from providers.open_ai import OpenAi, OpenAiAzure, OpenAiCompatibleTts -from providers.wingman_pro import WingmanPro +from providers.wingman_subscription import WingmanSubscription from providers.xvasynth import XVASynth from providers.pocket_tts import PocketTTS @@ -245,7 +245,7 @@ def get_azure_voices(self, api_key: str, region: AzureRegion, locale: str = ""): # GET /voices/azure/wingman-pro def get_wingman_pro_azure_voices(self, locale: str = ""): - wingman_pro = WingmanPro( + wingman_pro = WingmanSubscription( wingman_name="", settings=self.config_manager.settings_config.wingman_pro ) voices = wingman_pro.get_available_voices(locale=locale) @@ -258,7 +258,7 @@ def get_wingman_pro_azure_voices(self, locale: str = ""): def get_wingman_pro_inworld_voices( self, filter_language: str = None ) -> list[VoiceInfo]: - wingman_pro = WingmanPro( + wingman_pro = WingmanSubscription( wingman_name="", settings=self.config_manager.settings_config.wingman_pro ) voices = wingman_pro.get_available_inworld_voices( @@ -416,7 +416,7 @@ async def play_pocket_tts( async def play_wingman_pro_azure( self, text: str, config: AzureTtsConfig, sound_config: SoundConfig ): - wingman_pro = WingmanPro( + wingman_pro = WingmanSubscription( wingman_name="system", settings=self.config_manager.settings_config.wingman_pro, ) @@ -432,7 +432,7 @@ async def play_wingman_pro_azure( async def play_wingman_pro_openai( self, text: str, voice: str, model: str, speed: float, sound_config: SoundConfig ): - wingman_pro = WingmanPro( + wingman_pro = WingmanSubscription( wingman_name="system", settings=self.config_manager.settings_config.wingman_pro, ) @@ -453,7 +453,7 @@ async def play_wingman_pro_inworld( config: InworldConfig, sound_config: SoundConfig, ): - wingman_pro = WingmanPro( + wingman_pro = WingmanSubscription( wingman_name="system", settings=self.config_manager.settings_config.wingman_pro, ) diff --git a/services/wingman_mcp_manager.py b/services/wingman_mcp_manager.py new file mode 100644 index 000000000..5600bb4ce --- /dev/null +++ b/services/wingman_mcp_manager.py @@ -0,0 +1,282 @@ +"""WingmanMcpManager — owns all MCP discovery, connection, and lifecycle. + +Centralises MCP concerns (registry creation, secret injection, timeout +handling, enable/disable, parallel init) in one focused service. +""" + +import asyncio +import traceback +from typing import Callable + +from api.commands import McpStateChangedCommand +from api.enums import ( + LogSource, + LogType, + McpTransportType, + WingmanInitializationErrorType, +) +from api.interface import SettingsConfig, WingmanConfig, WingmanInitializationError +from services.mcp_client import McpClient +from services.mcp_registry import McpRegistry +from services.printr import Printr +from services.secret_keeper import SecretKeeper + +printr = Printr() + +_AUTH_HEADER_KEYS = {"authorization", "api-key", "x-api-key"} +_STDIO_DEFAULT_TIMEOUT = 60.0 +_HTTP_DEFAULT_TIMEOUT = 30.0 + + +class WingmanMcpManager: + """Manages MCP server discovery, connection, enable/disable, and teardown.""" + + def __init__( + self, + wingman_name: str, + mcp_client: McpClient, + secret_keeper: SecretKeeper, + get_mcp_config: Callable, # callable returning the central mcp_config (from tower) + settings: SettingsConfig, + config: WingmanConfig, + ): + self.wingman_name = wingman_name + self.mcp_client = mcp_client + self.secret_keeper = secret_keeper + self._get_mcp_config = get_mcp_config + self.settings = settings + self.config = config + + # Manager owns its registry and the broadcast callback + self.mcp_registry = McpRegistry( + mcp_client, + wingman_name=wingman_name, + on_state_changed=self._broadcast_mcp_state_changed, + ) + + # ─────────────────────────── Private ────────────────────────────────────── # + + async def _prepare_connection_params( + self, mcp_config, log_secret_found: bool = False + ) -> tuple[dict, float]: + """Build request headers (with secret-injected auth) and resolve the timeout.""" + headers: dict = {} + if mcp_config.headers: + headers.update(mcp_config.headers) + + secret_key = f"mcp_{mcp_config.name}" + api_key = await self.secret_keeper.retrieve( + requester=self.wingman_name, + key=secret_key, + prompt_if_missing=False, + ) + if api_key: + if log_secret_found: + printr.print( + f"MCP secret '{secret_key}' found ({len(api_key)} chars)", + color=LogType.INFO, + source_name=self.wingman_name, + server_only=True, + ) + if not any(k.lower() in _AUTH_HEADER_KEYS for k in headers.keys()): + headers["Authorization"] = f"Bearer {api_key}" + + default_timeout = ( + _STDIO_DEFAULT_TIMEOUT + if mcp_config.type == McpTransportType.STDIO + else _HTTP_DEFAULT_TIMEOUT + ) + timeout = ( + float(mcp_config.timeout) if mcp_config.timeout else default_timeout + ) + return headers, timeout + + def _broadcast_mcp_state_changed(self): + if printr._connection_manager: + printr.ensure_async( + printr._connection_manager.broadcast( + McpStateChangedCommand(wingman_name=self.wingman_name) + ) + ) + + # ─────────────────────────── Public API ─────────────────────────────────── # + + async def enable_mcp(self, mcp_name: str) -> tuple[bool, str]: + if not self.mcp_client.is_available: + return False, "MCP SDK not installed." + + if mcp_name in self.mcp_registry.get_connected_server_names(): + return True, f"MCP server '{mcp_name}' is already connected." + + central_mcp_config = self._get_mcp_config() + mcp_configs = central_mcp_config.servers if central_mcp_config else [] + + mcp_config = None + for cfg in mcp_configs: + if cfg.name == mcp_name: + mcp_config = cfg + break + + if not mcp_config: + return False, f"MCP server '{mcp_name}' not found in mcp.yaml." + + try: + headers, timeout = await self._prepare_connection_params(mcp_config) + + connection = await asyncio.wait_for( + self.mcp_registry.register_server( + config=mcp_config, + headers=headers if headers else None, + ), + timeout=timeout, + ) + + if connection.is_connected: + tool_count = len(connection.tools) + return True, f"MCP server '{mcp_name}' enabled with {tool_count} tools." + else: + error = connection.error or "Connection failed." + return False, f"MCP server '{mcp_name}' failed to connect: {error}" + + except asyncio.TimeoutError: + error_msg = f"Connection timed out ({int(timeout)}s)." + self.mcp_registry.set_server_error(mcp_name, error_msg) + return False, f"MCP server '{mcp_name}': {error_msg}" + + except Exception as e: + error_msg = f"Error enabling MCP '{mcp_name}': {str(e)}" + await printr.print_async(error_msg, color=LogType.ERROR) + printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) + return False, error_msg + + async def disable_mcp(self, mcp_name: str) -> tuple[bool, str]: + if mcp_name not in self.mcp_registry.get_connected_server_names(): + return True, f"MCP server '{mcp_name}' is already disconnected." + + try: + await self.mcp_registry.unregister_server(mcp_name) + return True, f"MCP server '{mcp_name}' disabled." + + except Exception as e: + error_msg = f"Error disabling MCP '{mcp_name}': {str(e)}" + await printr.print_async(error_msg, color=LogType.ERROR) + printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) + return False, error_msg + + async def init_mcps(self) -> list[WingmanInitializationError]: + errors = [] + + if not self.mcp_client.is_available: + printr.print( + f"[{self.wingman_name}] MCP SDK not installed, skipping MCP initialization.", + color=LogType.WARNING, + server_only=True, + ) + return errors + + await self.unload_mcps() + + central_mcp_config = self._get_mcp_config() + mcp_configs = central_mcp_config.servers if central_mcp_config else [] + if not mcp_configs: + return errors + + discoverable_mcps = self.config.discoverable_mcps + mcps_to_connect = [mcp for mcp in mcp_configs if mcp.name in discoverable_mcps] + + if not mcps_to_connect: + return errors + + async def connect_mcp(mcp_config): + local_errors = [] + try: + headers, timeout = await self._prepare_connection_params( + mcp_config, log_secret_found=True + ) + + try: + connection = await asyncio.wait_for( + self.mcp_registry.register_server( + config=mcp_config, + headers=headers if headers else None, + ), + timeout=timeout, + ) + except asyncio.TimeoutError: + error_msg = f"MCP '{mcp_config.display_name}' connection timed out ({int(timeout)}s)." + printr.print( + error_msg, + color=LogType.WARNING, + source_name=self.wingman_name, + server_only=True, + ) + local_errors.append( + WingmanInitializationError( + wingman_name=self.wingman_name, + message=error_msg, + error_type=WingmanInitializationErrorType.MCP_CONNECTION_FAILED, + ) + ) + return (False, None, local_errors) + + if connection.is_connected: + return ( + True, + f"{mcp_config.display_name} ({len(connection.tools)} tools)", + local_errors, + ) + else: + error_msg = f"MCP '{mcp_config.display_name}' failed to connect: {connection.error}" + local_errors.append( + WingmanInitializationError( + wingman_name=self.wingman_name, + message=error_msg, + error_type=WingmanInitializationErrorType.MCP_CONNECTION_FAILED, + ) + ) + return (False, None, local_errors) + + except Exception as e: + error_msg = f"MCP '{mcp_config.name}' initialization error: {str(e)}" + printr.print( + error_msg, + color=LogType.ERROR, + source_name=self.wingman_name, + server_only=True, + ) + printr.print( + traceback.format_exc(), color=LogType.ERROR, server_only=True + ) + local_errors.append( + WingmanInitializationError( + wingman_name=self.wingman_name, + message=error_msg, + error_type=WingmanInitializationErrorType.MCP_CONNECTION_FAILED, + ) + ) + return (False, None, local_errors) + + connection_tasks = [connect_mcp(mcp) for mcp in mcps_to_connect] + results = await asyncio.gather(*connection_tasks) + + connected_count = 0 + connected_names = [] + for success, connection_info, mcp_errors in results: + if success: + connected_count += 1 + connected_names.append(connection_info) + errors.extend(mcp_errors) + + if connected_count > 0: + await printr.print_async( + f"Discoverable MCP servers connected ({connected_count}): {', '.join(connected_names)}", + color=LogType.WINGMAN, + source=LogSource.WINGMAN, + source_name=self.wingman_name, + server_only=not self.settings.debug_mode, + ) + + return errors + + async def unload_mcps(self): + await self.mcp_registry.clear() diff --git a/services/wingman_skill_manager.py b/services/wingman_skill_manager.py new file mode 100644 index 000000000..98df17530 --- /dev/null +++ b/services/wingman_skill_manager.py @@ -0,0 +1,340 @@ +"""WingmanSkillManager — skill discovery, lifecycle, and state. + +Owns ``skills``, ``tool_skills``, and ``skill_tools``; the parent ``Wingman`` +exposes them as read-through properties. +""" + +import traceback +from typing import TYPE_CHECKING + +from api.interface import ( + SkillConfig, + SettingsConfig, + WingmanConfig, + WingmanInitializationError, +) +from api.enums import ( + LogSource, + LogType, + WingmanInitializationErrorType, +) +from services.module_manager import ModuleManager +from services.platform_utils import normalize_platform +from services.printr import Printr +from services.skill_registry import SkillRegistry +from skills.skill_base import Skill +from wingmen.wingman_context import WingmanContext + +if TYPE_CHECKING: + from wingmen.wingman import Wingman + +printr = Printr() + + +def _get_skill_folder_from_module(module: str) -> str: + """Extract folder name from module path like 'skills.star_head.main' -> 'star_head'""" + return module.replace(".main", "").replace(".", "/").split("/")[1] + + +class WingmanSkillManager: + """Manages skill discovery, loading, preparation, enable/disable, and teardown.""" + + def __init__( + self, + wingman: "Wingman", + config: WingmanConfig, + settings: SettingsConfig, + skill_registry: SkillRegistry, + ): + self._wingman = wingman + self.config = config + self.settings = settings + self.skill_registry = skill_registry + + # Manager owns skill state + self.skills: list[Skill] = [] + self.tool_skills: dict[str, Skill] = {} + self.skill_tools: list[dict] = [] + + # ──────────────────────────── Private helpers ─────────────────────────────── # + + def _sync_conversation_skill_context(self) -> None: + """Push current skill state into the conversation manager. + + Called after every mutation (init/enable/disable/unload) so that + ConversationManager does not need per-call skill kwargs. + """ + self._wingman.conversation.set_skill_context( + self.skills, self.skill_registry, self.tool_skills + ) + + def _build_user_skill_configs(self) -> dict[str, SkillConfig]: + """Map folder name → user SkillConfig for each entry in wingman config.""" + result: dict[str, SkillConfig] = {} + if self.config.skills: + for skill_config in self.config.skills: + folder_name = _get_skill_folder_from_module(skill_config.module) + result[folder_name] = skill_config + return result + + def _load_skill_config( + self, + skill_folder_name: str, + skill_config_path: str, + user_skill_configs: dict[str, SkillConfig], + ) -> SkillConfig | None: + """Read yaml config + merge user overrides; return SkillConfig or None.""" + skill_config_dict = ModuleManager.read_config(skill_config_path) + if not skill_config_dict: + return None + + if skill_folder_name in user_skill_configs: + user_config = user_skill_configs[skill_folder_name] + if user_config.custom_properties: + skill_config_dict["custom_properties"] = [ + prop.model_dump() for prop in user_config.custom_properties + ] + if user_config.prompt: + skill_config_dict["prompt"] = user_config.prompt + + return SkillConfig(**skill_config_dict) + + def _check_platform_supported( + self, skill_config: SkillConfig + ) -> tuple[bool, str]: + """Return (ok, reason) for whether the skill supports the current platform.""" + if not skill_config.platforms: + return True, "" + normalized = normalize_platform() + if normalized not in skill_config.platforms: + return ( + False, + f"Skill '{skill_config.name}' is not supported on {normalized}.", + ) + return True, "" + + def _instantiate_skill(self, skill_config: SkillConfig) -> Skill | None: + """Create WingmanContext, call ModuleManager.load_skill, bind helpers.""" + context = WingmanContext(self._wingman) + skill = ModuleManager.load_skill( + config=skill_config, + settings=self.settings, + wingman=context, + ) + if skill: + skill.threaded_execution = self._wingman.threaded_execution + return skill + + # ──────────────────────────── Public API ─────────────────────────────────── # + + async def init_skills(self) -> list[WingmanInitializationError]: + if self.skills: + await self.unload_skills() + + errors = [] + self.skills = [] + + user_skill_configs = self._build_user_skill_configs() + available_skills = ModuleManager.read_available_skill_configs() + discoverable_skills = self.config.discoverable_skills + + for ( + skill_folder_name, + skill_config_path, + _is_custom, + _is_local, + ) in available_skills: + try: + skill_config = self._load_skill_config( + skill_folder_name, skill_config_path, user_skill_configs + ) + if not skill_config: + continue + + if skill_config.name not in discoverable_skills: + continue + + ok, reason = self._check_platform_supported(skill_config) + if not ok: + printr.print( + f"Skipping skill - {reason}", + color=LogType.WARNING, + server_only=True, + ) + continue + + skill = self._instantiate_skill(skill_config) + if skill: + self.skills.append(skill) + await self.prepare_skill(skill) + + except Exception as e: + skill_name = skill_folder_name + error_msg = f"Error loading skill '{skill_name}': {str(e)}" + await printr.print_async( + error_msg, + color=LogType.ERROR, + ) + printr.print( + traceback.format_exc(), color=LogType.ERROR, server_only=True + ) + errors.append( + WingmanInitializationError( + wingman_name=self._wingman.name, + message=error_msg, + error_type=WingmanInitializationErrorType.SKILL_INITIALIZATION_FAILED, + ) + ) + + if self.skills: + skill_names = [s.config.name for s in self.skills] + await printr.print_async( + f"Discoverable skills ({len(skill_names)}): {', '.join(skill_names)}", + color=LogType.WINGMAN, + source=LogSource.WINGMAN, + source_name=self._wingman.name, + server_only=not self.settings.debug_mode, + ) + + self._sync_conversation_skill_context() + return errors + + async def prepare_skill(self, skill: Skill): + try: + for tool_name, tool in skill.get_tools(): + self.tool_skills[tool_name] = skill + self.skill_tools.append(tool) + + self.skill_registry.register_skill(skill) + + if skill.config.auto_activate: + success, message = await skill.ensure_activated() + if not success: + await printr.print_async( + f"Auto-activated skill '{skill.config.display_name}' failed to activate: {message}", + color=LogType.ERROR, + ) + except Exception as e: + await printr.print_async( + f"Error while preparing skill '{skill.name}': {str(e)}", + color=LogType.ERROR, + ) + printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) + + skill.llm_call = self._wingman.actual_llm_call + + async def unprepare_skill(self, skill: Skill): + try: + for tool_name, _ in skill.get_tools(): + self.tool_skills.pop(tool_name, None) + self.skill_tools = [ + t + for t in self.skill_tools + if t.get("function", {}).get("name") != tool_name + ] + self.skill_registry.unregister_skill(skill.name) + except Exception as e: + await printr.print_async( + f"Error while unpreparing skill '{skill.name}': {str(e)}", + color=LogType.ERROR, + ) + printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) + self._sync_conversation_skill_context() + + async def enable_skill(self, skill_name: str) -> tuple[bool, str]: + for existing_skill in self.skills: + if existing_skill.config.name == skill_name: + return True, f"Skill '{skill_name}' is already enabled." + + available_skills = ModuleManager.read_available_skill_configs() + user_skill_configs = self._build_user_skill_configs() + + for ( + skill_folder_name, + skill_config_path, + _is_custom, + _is_local, + ) in available_skills: + try: + skill_config = self._load_skill_config( + skill_folder_name, skill_config_path, user_skill_configs + ) + if not skill_config: + continue + + if skill_config.name != skill_name: + continue + + ok, reason = self._check_platform_supported(skill_config) + if not ok: + return False, reason + + skill = self._instantiate_skill(skill_config) + if skill: + self.skills.append(skill) + await self.prepare_skill(skill) + self._sync_conversation_skill_context() + + printr.print( + f"Skill '{skill_name}' activated (loaded and made discoverable).", + color=LogType.POSITIVE, + server_only=True, + ) + return True, f"Skill '{skill_name}' activated successfully." + + except Exception as e: + error_msg = f"Error activating skill '{skill_name}': {str(e)}" + await printr.print_async(error_msg, color=LogType.ERROR) + printr.print( + traceback.format_exc(), color=LogType.ERROR, server_only=True + ) + return False, error_msg + + return False, f"Skill '{skill_name}' not found." + + async def disable_skill(self, skill_name: str) -> tuple[bool, str]: + skill_to_remove = None + for skill in self.skills: + if skill.config.name == skill_name: + skill_to_remove = skill + break + + if not skill_to_remove: + return True, f"Skill '{skill_name}' is already deactivated." + + try: + await skill_to_remove.unload() + self.skills.remove(skill_to_remove) + await self.unprepare_skill(skill_to_remove) + + printr.print( + f"Skill '{skill_name}' deactivated (unloaded and removed from discoverable skills).", + color=LogType.WARNING, + server_only=True, + ) + return True, f"Skill '{skill_name}' deactivated successfully." + + except Exception as e: + error_msg = f"Error deactivating skill '{skill_name}': {str(e)}" + await printr.print_async(error_msg, color=LogType.ERROR) + printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) + return False, error_msg + + async def unload_skills(self): + for skill in self.skills: + if not skill.is_prepared: + continue + try: + await skill.unload() + except Exception as e: + await printr.print_async( + f"Error unloading skill '{skill.name}': {str(e)}", + color=LogType.ERROR, + ) + printr.print( + traceback.format_exc(), color=LogType.ERROR, server_only=True + ) + self.tool_skills = {} + self.skill_tools = [] + self.skill_registry.clear() + self._sync_conversation_skill_context() diff --git a/skills/AGENTS.md b/skills/AGENTS.md index 45fd37908..50cf74367 100644 --- a/skills/AGENTS.md +++ b/skills/AGENTS.md @@ -118,10 +118,10 @@ from api.interface import SettingsConfig, SkillConfig, WingmanInitializationErro from skills.skill_base import Skill, tool if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext class YourSkillName(Skill): - def __init__(self, config: SkillConfig, settings: SettingsConfig, wingman: "OpenAiWingman") -> None: + def __init__(self, config: SkillConfig, settings: SettingsConfig, wingman: "WingmanContext") -> None: super().__init__(config=config, settings=settings, wingman=wingman) async def validate(self) -> list[WingmanInitializationError]: diff --git a/skills/README.md b/skills/README.md index c9fcf6cde..fe4184f85 100644 --- a/skills/README.md +++ b/skills/README.md @@ -630,7 +630,7 @@ from api.interface import SettingsConfig, SkillConfig, WingmanInitializationErro from skills.skill_base import Skill, tool if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext class YourSkillName(Skill): @@ -640,7 +640,7 @@ class YourSkillName(Skill): self, config: SkillConfig, settings: SettingsConfig, - wingman: "OpenAiWingman", + wingman: "WingmanContext", ) -> None: super().__init__(config=config, settings=settings, wingman=wingman) # Initialize your skill here @@ -1766,7 +1766,7 @@ from skills.skill_base import Skill, tool from services.skill_local_ai import MemoryType if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext class GameStatsTracker(Skill): @@ -1776,7 +1776,7 @@ class GameStatsTracker(Skill): self, config: SkillConfig, settings: SettingsConfig, - wingman: "OpenAiWingman", + wingman: "WingmanContext", ) -> None: super().__init__(config=config, settings=settings, wingman=wingman) diff --git a/skills/api_request/main.py b/skills/api_request/main.py index 5a8a68d81..2bd98a9d7 100644 --- a/skills/api_request/main.py +++ b/skills/api_request/main.py @@ -12,7 +12,7 @@ from skills.skill_base import Skill, tool if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext DEFAULT_HEADERS = { "Strict-Transport-Security": "max-age=31536000; includeSubDomains", @@ -61,7 +61,7 @@ def __init__( self, config: SkillConfig, settings: SettingsConfig, - wingman: "OpenAiWingman", + wingman: "WingmanContext", ) -> None: self.default_headers = DEFAULT_HEADERS diff --git a/skills/ats_telemetry/main.py b/skills/ats_telemetry/main.py index 45b040932..625da298d 100644 --- a/skills/ats_telemetry/main.py +++ b/skills/ats_telemetry/main.py @@ -21,13 +21,13 @@ if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext class ATSTelemetry(Skill): def __init__( - self, config: SkillConfig, settings: SettingsConfig, wingman: "OpenAiWingman" + self, config: SkillConfig, settings: SettingsConfig, wingman: "WingmanContext" ) -> None: self.loaded = False self.already_initialized_telemetry = False diff --git a/skills/audio_device_changer/main.py b/skills/audio_device_changer/main.py index 02fc0dc5e..a9885fb9d 100644 --- a/skills/audio_device_changer/main.py +++ b/skills/audio_device_changer/main.py @@ -15,7 +15,7 @@ from skills.skill_base import Skill if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext class AudioDeviceChanger(Skill): @@ -25,7 +25,7 @@ def __init__( self, config: SkillConfig, settings: SettingsConfig, - wingman: "OpenAiWingman", + wingman: "WingmanContext", ) -> None: super().__init__(config=config, settings=settings, wingman=wingman) self.original_audio_device = settings.audio.output diff --git a/skills/auto_screenshot/main.py b/skills/auto_screenshot/main.py index 6b7328a99..053f75c80 100644 --- a/skills/auto_screenshot/main.py +++ b/skills/auto_screenshot/main.py @@ -14,7 +14,7 @@ from skills.skill_base import Skill, tool if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext class AutoScreenshot(Skill): @@ -22,7 +22,7 @@ def __init__( self, config: SkillConfig, settings: SettingsConfig, - wingman: "OpenAiWingman", + wingman: "WingmanContext", ) -> None: super().__init__(config=config, settings=settings, wingman=wingman) diff --git a/skills/control_windows/main.py b/skills/control_windows/main.py index 2b09aa32c..9a5ae1bed 100644 --- a/skills/control_windows/main.py +++ b/skills/control_windows/main.py @@ -11,7 +11,7 @@ import mouse.mouse as mouse if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext class ControlWindows(Skill): @@ -28,7 +28,7 @@ def __init__( self, config: SkillConfig, settings: SettingsConfig, - wingman: "OpenAiWingman", + wingman: "WingmanContext", ) -> None: super().__init__(config=config, settings=settings, wingman=wingman) diff --git a/skills/file_manager/main.py b/skills/file_manager/main.py index bb91ca215..8ebcc022d 100644 --- a/skills/file_manager/main.py +++ b/skills/file_manager/main.py @@ -10,7 +10,7 @@ from pdfminer.high_level import extract_text if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext DEFAULT_MAX_TEXT_SIZE = 24000 SUPPORTED_FILE_EXTENSIONS = [ @@ -101,7 +101,7 @@ class FileManager(Skill): def __init__( - self, config: SkillConfig, settings: SettingsConfig, wingman: "OpenAiWingman" + self, config: SkillConfig, settings: SettingsConfig, wingman: "WingmanContext" ) -> None: super().__init__(config=config, settings=settings, wingman=wingman) self.allowed_file_extensions = SUPPORTED_FILE_EXTENSIONS diff --git a/skills/hud/main.py b/skills/hud/main.py index 15df3583b..6b59f9e44 100644 --- a/skills/hud/main.py +++ b/skills/hud/main.py @@ -31,7 +31,7 @@ from hud_server.validation import validate_hud_settings if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext printr = Printr() @@ -50,7 +50,7 @@ def __init__( self, config: SkillConfig, settings: SettingsConfig, - wingman: "OpenAiWingman" + wingman: "WingmanContext" ) -> None: super().__init__(config=config, settings=settings, wingman=wingman) diff --git a/skills/image_generation/main.py b/skills/image_generation/main.py index 9eb4ffe39..8f9d5c695 100644 --- a/skills/image_generation/main.py +++ b/skills/image_generation/main.py @@ -7,7 +7,7 @@ from skills.skill_base import Skill, tool if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext class ImageGeneration(Skill): @@ -16,7 +16,7 @@ def __init__( self, config: SkillConfig, settings: SettingsConfig, - wingman: "OpenAiWingman", + wingman: "WingmanContext", ) -> None: super().__init__(config=config, settings=settings, wingman=wingman) self.image_path = self.get_generated_files_dir() diff --git a/skills/msfs2020_control/main.py b/skills/msfs2020_control/main.py index 8024e0320..d94ff5c58 100644 --- a/skills/msfs2020_control/main.py +++ b/skills/msfs2020_control/main.py @@ -22,13 +22,13 @@ from skills.msfs2020_control.command_matcher.command_matcher import CommandMatcher if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext class Msfs2020Control(Skill): def __init__( - self, config: SkillConfig, settings: SettingsConfig, wingman: "OpenAiWingman" + self, config: SkillConfig, settings: SettingsConfig, wingman: "WingmanContext" ) -> None: super().__init__(config=config, settings=settings, wingman=wingman) self.already_initialized_simconnect = False diff --git a/skills/quick_commands/main.py b/skills/quick_commands/main.py index d3d14609c..f6083b865 100644 --- a/skills/quick_commands/main.py +++ b/skills/quick_commands/main.py @@ -7,7 +7,7 @@ from skills.skill_base import Skill if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext class QuickCommands(Skill): @@ -16,7 +16,7 @@ def __init__( self, config: SkillConfig, settings: SettingsConfig, - wingman: "OpenAiWingman", + wingman: "WingmanContext", ) -> None: super().__init__(config=config, settings=settings, wingman=wingman) diff --git a/skills/radio_chatter/main.py b/skills/radio_chatter/main.py index ae933e709..3b777987c 100644 --- a/skills/radio_chatter/main.py +++ b/skills/radio_chatter/main.py @@ -22,7 +22,7 @@ from skills.skill_base import Skill, tool if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext class RadioChatter(Skill): @@ -31,7 +31,7 @@ def __init__( self, config: SkillConfig, settings: SettingsConfig, - wingman: "OpenAiWingman", + wingman: "WingmanContext", ) -> None: super().__init__(config=config, settings=settings, wingman=wingman) @@ -155,42 +155,8 @@ async def validate(self) -> list[WingmanInitializationError]: ) ) - # Initialize all providers - initiated_providers = [] - for voice in voices: - voice_provider = voice.provider - if voice_provider not in initiated_providers: - initiated_providers.append(voice_provider) - - if voice_provider == TtsProvider.OPENAI and not self.wingman.openai: - await self.wingman.validate_and_set_openai(errors) - elif ( - voice_provider == TtsProvider.AZURE - and not self.wingman.openai_azure - ): - await self.wingman.validate_and_set_azure(errors) - elif ( - voice_provider == TtsProvider.ELEVENLABS - and not self.wingman.elevenlabs - ): - await self.wingman.validate_and_set_elevenlabs(errors) - elif ( - voice_provider == TtsProvider.WINGMAN_PRO - and not self.wingman.wingman_pro - ): - await self.wingman.validate_and_set_wingman_pro() - elif ( - voice_provider == TtsProvider.INWORLD - and not self.wingman.inworld - ): - await self.wingman.validate_and_set_inworld(errors) - elif ( - voice_provider == TtsProvider.OPENAI_COMPATIBLE - and not self.wingman.openai_compatible_tts - ): - await self.wingman.validate_and_set_openai_compatible_tts( - errors - ) + # Provider initialization is handled by ProviderFactory via + # switch_tts_provider() at voice-switch time. No pre-init needed. return errors @@ -605,7 +571,7 @@ async def _switch_voice( f"Switching voice to {voice_name} ({voice_provider.value})" ) - self.wingman.config.features.tts_provider = voice_provider + await self.wingman.switch_tts_provider(voice_provider) async def _get_original_voice_setting(self) -> VoiceSelection: voice_provider = self.wingman.config.features.tts_provider diff --git a/skills/skill_base.py b/skills/skill_base.py index 235a4afaf..e3240119a 100644 --- a/skills/skill_base.py +++ b/skills/skill_base.py @@ -26,7 +26,7 @@ if TYPE_CHECKING: from services.skill_local_ai import SkillLocalAI - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext # Type mapping from Python types to JSON Schema types @@ -296,7 +296,7 @@ def __init__( self, config: SkillConfig, settings: SettingsConfig, - wingman: "OpenAiWingman", + wingman: "WingmanContext", ) -> None: self.config = config self.settings = settings diff --git a/skills/spotify/main.py b/skills/spotify/main.py index 750ee692c..ea18e1abf 100644 --- a/skills/spotify/main.py +++ b/skills/spotify/main.py @@ -8,7 +8,7 @@ from services.file import get_generated_files_dir if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext class Spotify(Skill): @@ -17,7 +17,7 @@ def __init__( self, config: SkillConfig, settings: SettingsConfig, - wingman: "OpenAiWingman", + wingman: "WingmanContext", ) -> None: super().__init__(config=config, settings=settings, wingman=wingman) diff --git a/skills/thinking_sound/main.py b/skills/thinking_sound/main.py index 99ca957cb..38ccbd4fd 100644 --- a/skills/thinking_sound/main.py +++ b/skills/thinking_sound/main.py @@ -10,7 +10,7 @@ from skills.skill_base import Skill if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext class ThinkingSound(Skill): @@ -20,7 +20,7 @@ def __init__( self, config: SkillConfig, settings: SettingsConfig, - wingman: "OpenAiWingman", + wingman: "WingmanContext", ) -> None: super().__init__(config=config, settings=settings, wingman=wingman) diff --git a/skills/timer/main.py b/skills/timer/main.py index 92da197f8..0bcd57f04 100644 --- a/skills/timer/main.py +++ b/skills/timer/main.py @@ -13,7 +13,7 @@ from skills.skill_base import Skill, tool if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext class ActualTimer: @@ -123,7 +123,7 @@ def __init__( self, config: SkillConfig, settings: SettingsConfig, - wingman: "OpenAiWingman", + wingman: "WingmanContext", ) -> None: super().__init__(config=config, settings=settings, wingman=wingman) diff --git a/skills/typing_assistant/main.py b/skills/typing_assistant/main.py index 78f115937..baf7a7d4c 100644 --- a/skills/typing_assistant/main.py +++ b/skills/typing_assistant/main.py @@ -6,7 +6,7 @@ import keyboard.keyboard as keyboard if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext class TypingAssistant(Skill): @@ -21,7 +21,7 @@ def __init__( self, config: SkillConfig, settings: SettingsConfig, - wingman: "OpenAiWingman", + wingman: "WingmanContext", ) -> None: super().__init__(config=config, settings=settings, wingman=wingman) diff --git a/skills/uexcorp/main.py b/skills/uexcorp/main.py index 64e5ee161..9d0a910a7 100644 --- a/skills/uexcorp/main.py +++ b/skills/uexcorp/main.py @@ -14,7 +14,7 @@ from skills.uexcorp.uexcorp.helper import Helper if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext class UEXCorp(Skill): @@ -24,7 +24,7 @@ def __init__( self, config: SkillConfig, settings: SettingsConfig, - wingman: "OpenAiWingman", + wingman: "WingmanContext", ) -> None: self.random_seed = uuid.uuid4() super().__init__(config=config, settings=settings, wingman=wingman) diff --git a/skills/uexcorp/uexcorp/handler/config_handler.py b/skills/uexcorp/uexcorp/handler/config_handler.py index 0ca067140..0beb347e8 100644 --- a/skills/uexcorp/uexcorp/handler/config_handler.py +++ b/skills/uexcorp/uexcorp/handler/config_handler.py @@ -6,7 +6,7 @@ from services.file import get_writable_dir if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext from skills.uexcorp.uexcorp.helper import Helper @@ -64,7 +64,7 @@ def __init__( helper: "Helper" ): self.__helper = helper - self.__wingman: "OpenAiWingman | None" = None + self.__wingman: "WingmanContext | None" = None self.__fine_config_path: str = get_writable_dir(os.path.join(self.__helper.get_data_path(), "config")) self.__api_url: str = "https://api.uexcorp.space/2.0" self.__api_use_key: bool = False @@ -324,8 +324,8 @@ def get_behavior_use_fasterwhisper_hotwords(self) -> bool: def set_behavior_update_fasterwhisper_hotwords(self, update: bool): self.__behavior_use_fasterwhisper_hotwords = update - def set_wingman(self, wingman: "OpenAiWingman"): + def set_wingman(self, wingman: "WingmanContext"): self.__wingman = wingman - def get_wingman(self) -> "OpenAiWingman": + def get_wingman(self) -> "WingmanContext": return self.__wingman \ No newline at end of file diff --git a/skills/uexcorp/uexcorp/helper.py b/skills/uexcorp/uexcorp/helper.py index 16a745c51..d9a11f6d8 100644 --- a/skills/uexcorp/uexcorp/helper.py +++ b/skills/uexcorp/uexcorp/helper.py @@ -16,7 +16,7 @@ if TYPE_CHECKING: from skills.uexcorp.uexcorp.handler.tool_handler import ToolHandler - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext printr = Printr() @@ -66,7 +66,7 @@ def __init__(self): self.__request_while_not_ready = False self.__wingman = None - def prepare(self, threaded_execution: callable, wingman: "OpenAiWingman"): + def prepare(self, threaded_execution: callable, wingman: "WingmanContext"): from skills.uexcorp.uexcorp.handler.tool_handler import ToolHandler self.__wingman = wingman @@ -325,7 +325,7 @@ def get_llm(self) -> Llm: def get_default_thread_ident(self) -> int: return self.__default_thread - def get_wingmen(self) -> "OpenAiWingman": + def get_wingmen(self) -> "WingmanContext": return self.__wingman def toast(self, message: str): diff --git a/skills/vision_ai/main.py b/skills/vision_ai/main.py index adc67b6f8..7cc348c74 100644 --- a/skills/vision_ai/main.py +++ b/skills/vision_ai/main.py @@ -8,7 +8,7 @@ from skills.skill_base import Skill, tool if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext class VisionAI(Skill): @@ -17,7 +17,7 @@ def __init__( self, config: SkillConfig, settings: SettingsConfig, - wingman: "OpenAiWingman", + wingman: "WingmanContext", ) -> None: super().__init__(config=config, settings=settings, wingman=wingman) diff --git a/skills/voice_changer/main.py b/skills/voice_changer/main.py index 5c7742f23..b172f361b 100644 --- a/skills/voice_changer/main.py +++ b/skills/voice_changer/main.py @@ -15,7 +15,7 @@ from skills.skill_base import Skill if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext class VoiceChanger(Skill): @@ -24,7 +24,7 @@ def __init__( self, config: SkillConfig, settings: SettingsConfig, - wingman: "OpenAiWingman", + wingman: "WingmanContext", ) -> None: super().__init__(config=config, settings=settings, wingman=wingman) @@ -46,45 +46,8 @@ async def validate(self) -> list[WingmanInitializationError]: voices: list[VoiceSelection] = self.retrieve_custom_property_value( "voice_changer_voices", errors ) - if voices and len(voices) > 0: - # Initialize all providers - initiated_providers = [] - - for voice in voices: - voice_provider = voice.provider - if voice_provider not in initiated_providers: - initiated_providers.append(voice_provider) - - # initiate provider - if voice_provider == TtsProvider.OPENAI and not self.wingman.openai: - await self.wingman.validate_and_set_openai(errors) - elif ( - voice_provider == TtsProvider.AZURE - and not self.wingman.openai_azure - ): - await self.wingman.validate_and_set_azure(errors) - elif ( - voice_provider == TtsProvider.ELEVENLABS - and not self.wingman.elevenlabs - ): - await self.wingman.validate_and_set_elevenlabs(errors) - elif ( - voice_provider == TtsProvider.WINGMAN_PRO - and not self.wingman.wingman_pro - ): - await self.wingman.validate_and_set_wingman_pro() - elif ( - voice_provider == TtsProvider.INWORLD - and not self.wingman.inworld - ): - await self.wingman.validate_and_set_inworld(errors) - elif ( - voice_provider == TtsProvider.OPENAI_COMPATIBLE - and not self.wingman.openai_compatible_tts - ): - await self.wingman.validate_and_set_openai_compatible_tts( - errors - ) + # Provider initialization is handled by ProviderFactory via + # switch_tts_provider() at voice-switch time. No pre-init needed. return errors @@ -256,7 +219,7 @@ async def _switch_voice(self, voices: list[VoiceSelection]) -> str: ) return f"Voice switching failed due to an unknown voice provider/subprovider. Provider: {voice_provider.value}" - self.wingman.config.features.tts_provider = voice_provider + await self.wingman.switch_tts_provider(voice_provider) if not provider_name: provider_name = voice_provider.value diff --git a/wingman_core.py b/wingman_core.py index 77715824a..362417c3c 100644 --- a/wingman_core.py +++ b/wingman_core.py @@ -59,7 +59,7 @@ from providers.llama_cpp_remote import LlamaCppRemote from providers.open_ai import OpenAi from providers.whispercpp import Whispercpp -from providers.wingman_pro import WingmanPro +from providers.wingman_subscription import WingmanSubscription from providers.xvasynth import XVASynth from providers.pocket_tts import PocketTTS from wingmen.open_ai_wingman import OpenAiWingman @@ -87,6 +87,7 @@ from services.image_processing import process_image, validate_image_mime from services.model_metadata import ModelMetadataService from services.audio_recorder import RECORDING_PATH, AudioRecorder +from services.threading_utils import threaded_execution from services.config_manager import ConfigManager from services.printr import Printr from services.secret_keeper import SecretKeeper @@ -146,6 +147,18 @@ def __init__( endpoint=self.send_text_to_wingman, tags=tags, ) + self.router.add_api_route( + methods=["POST"], + path="/start-recording-for-wingman", + endpoint=self.start_recording_for_wingman, + tags=tags, + ) + self.router.add_api_route( + methods=["POST"], + path="/stop-recording-for-wingman", + endpoint=self.stop_recording_for_wingman, + tags=tags, + ) self.router.add_api_route( methods=["POST"], path="/generate-greeting", @@ -1464,7 +1477,7 @@ def run_async_process(): text = None if provider == VoiceActivationSttProvider.WINGMAN_PRO: - wingman_pro = WingmanPro( + wingman_pro = WingmanSubscription( wingman_name="system", settings=self.settings_service.settings.wingman_pro, ) @@ -1719,6 +1732,51 @@ def get_startup_errors(self): async def stop_playback(self): await self.audio_player.stop_playback() + # POST /start-recording-for-wingman + async def start_recording_for_wingman(self, wingman_name: str): + """Start audio recording for a wingman (GUI mic toggle).""" + if not self.tower or self.active_recording["key"] != "": + return + + wingman = self.tower.get_wingman_by_name(wingman_name) + if not wingman: + return + + self.active_recording = dict(key="__gui__", wingman=wingman) + self.was_listening_before_ptt = self.is_listening + if ( + self.settings_service.settings.voice_activation.enabled + and self.is_listening + ): + self.start_voice_recognition(mute=True) + + self.audio_recorder.start_recording(wingman_name=wingman.name) + + # POST /stop-recording-for-wingman + async def stop_recording_for_wingman(self, wingman_name: str): + """Stop audio recording and process the result (GUI mic toggle).""" + if ( + not self.tower + or self.active_recording["key"] != "__gui__" + ): + return + + wingman = self.active_recording["wingman"] + recorded_audio_wav = self.audio_recorder.stop_recording( + wingman_name=wingman.name + ) + self.active_recording = {"key": "", "wingman": None} + + if ( + self.settings_service.settings.voice_activation.enabled + and not self.is_listening + and self.was_listening_before_ptt + ): + self.start_voice_recognition() + + if recorded_audio_wav and isinstance(wingman, Wingman): + threaded_execution(wingman.process, str(recorded_audio_wav)) + # POST /ask-wingman-conversation-provider async def ask_wingman_conversation_provider( self, wingman_name: str, text: str = Body(...) diff --git a/wingmen/open_ai_wingman.py b/wingmen/open_ai_wingman.py index 3203b96b7..be6af972e 100644 --- a/wingmen/open_ai_wingman.py +++ b/wingmen/open_ai_wingman.py @@ -1,3541 +1,11 @@ -import json -import time -import asyncio -import random -import traceback -import uuid -from datetime import datetime -from typing import ( - Mapping, - Optional, -) -from openai import APIConnectionError, NOT_GIVEN -from openai.types.chat import ( - ChatCompletion, - ChatCompletionMessage, - ChatCompletionMessageToolCall, - ParsedFunction, -) -import requests -from api.interface import ( - OpenRouterEndpointResult, - SettingsConfig, - SoundConfig, - WingmanInitializationError, - CommandConfig, -) -from api.enums import ( - ImageGenerationProvider, - LogType, - LogSource, - TtsProvider, - SttProvider, - ConversationProvider, - WingmanProSttProvider, - WingmanProTtsProvider, - WingmanInitializationErrorType, -) -from providers.edge import Edge -from providers.elevenlabs import ElevenLabs -from providers.google import GoogleGenAI -from providers.open_ai import OpenAi, OpenAiAzure, OpenAiCompatibleTts -from providers.hume import Hume -from providers.inworld import Inworld -from providers.open_ai import OpenAi, OpenAiAzure -from providers.x_ai import XAi -from providers.wingman_pro import WingmanPro -from api.commands import McpStateChangedCommand -from services.benchmark import Benchmark -from services.file import get_prompt -from services.token_utils import count_tokens, truncate_to_tokens -from services.markdown import cleanup_text -from services.printr import Printr -from services.skill_registry import SkillRegistry -from services.mcp_client import McpClient -from services.mcp_registry import McpRegistry -from services.capability_registry import CapabilityRegistry -from services.tool_response_cache import ToolResponseCompressor -from skills.skill_base import Skill -from wingmen.wingman import Wingman +"""Backward-compatibility shim. -printr = Printr() +``OpenAiWingman`` has been merged into :class:`wingmen.wingman.Wingman`. +This module re-exports ``Wingman`` under the old name so that existing +skills, custom wingmen, tower, and other code that imports +``from wingmen.open_ai_wingman import OpenAiWingman`` continues to work. +""" -# Max seconds to wait for the local support model during conversation condensation. -_CONDENSE_TIMEOUT = 120.0 +from wingmen.wingman import Wingman as OpenAiWingman # noqa: F401 - -class OpenAiWingman(Wingman): - """Our OpenAI Wingman base gives you everything you need to interact with OpenAI's various APIs. - - It transcribes speech to text using Whisper, uses the Completion API for conversation and implements the Tools API to execute functions. - """ - - AZURE_SERVICES = { - "tts": TtsProvider.AZURE, - "whisper": [SttProvider.AZURE, SttProvider.AZURE_SPEECH], - "conversation": ConversationProvider.AZURE, - } - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - self.edge_tts = Edge() - - # validate will set these: - self.openai: OpenAi | None = None - self.mistral: OpenAi | None = None - self.groq: OpenAi | None = None - self.cerebras: OpenAi | None = None - self.openrouter: OpenAi | None = None - self.openrouter_model_supports_tools = False - self.local_llm: OpenAi | None = None - self.openai_azure: OpenAiAzure | None = None - self.elevenlabs: ElevenLabs | None = None - self.openai_compatible_tts: OpenAiCompatibleTts | None = None - self.hume: Hume | None = None - self.inworld: Inworld | None = None - self.wingman_pro: WingmanPro | None = None - self.google: GoogleGenAI | None = None - self.perplexity: OpenAi | None = None - self.xai: XAi | None = None - - # tool queue - self.pending_tool_calls = [] - self.last_gpt_call = None - - # generated addional content - self.instant_responses = [] - self.last_used_instant_responses = [] - - self.messages = [] - self.conversation_summary: str = "" - self._last_compiled_context: str = "" - self._is_condensing = False - self._condense_task: asyncio.Task | None = None - self._last_prompt_tokens: int = 0 - """Last API-reported prompt_tokens from the conversation LLM.""" - self._support_token_ratio: float = 1.35 - """Calibration ratio: support-model tokens / cl100k_base tokens. - Starts at 1.35 (conservative — Qwen typically tokenizes English text - ~20-40% heavier than cl100k_base). Updated after the first support - call using real usage reported by the model's own tokenizer.""" - """The conversation history that is used for the GPT calls""" - - self.azure_api_keys = {key: None for key in self.AZURE_SERVICES} - - self.tool_skills: dict[str, Skill] = {} - self.skill_tools: list[dict] = [] - - # Progressive tool disclosure registry (MCP-inspired token optimization) - # Only meta-tools are sent to LLM initially; skills activated on-demand - self.skill_registry = SkillRegistry() - - # MCP (Model Context Protocol) support - # Allows connecting to external MCP servers that provide additional tools - self.mcp_client = McpClient(wingman_name=self.name) - self.mcp_registry = McpRegistry( - self.mcp_client, - wingman_name=self.name, - on_state_changed=self._broadcast_mcp_state_changed, - ) - - # Unified capability registry - combines skill and MCP discovery - # From the LLM's perspective, both are just "capabilities" - self.capability_registry = CapabilityRegistry( - self.skill_registry, self.mcp_registry - ) - - # Local AI service — set externally by WingmanCore if available - self.local_ai_service = None - - # Persistent memory — initialized lazily when local_ai_service is available - self.persistent_memory_service = None - self._memory_recall_notified = False - self._background_tasks: set[asyncio.Task] = set() - - # Stateless tool response compressor — summarizes large tool/MCP responses - # via local AI before the cloud LLM sees them - self._tool_response_compressor = ToolResponseCompressor() - - def _broadcast_mcp_state_changed(self): - """Broadcast MCP state change to UI via WebSocket.""" - if printr._connection_manager: - printr.ensure_async( - printr._connection_manager.broadcast( - McpStateChangedCommand(wingman_name=self.name) - ) - ) - - async def validate(self): - errors = await super().validate() - - try: - if self.uses_provider("whispercpp"): - self.whispercpp.validate(self.name, errors) - - if self.uses_provider("fasterwhisper"): - self.fasterwhisper.validate(errors) - - if self.uses_provider("parakeet"): - self.parakeet.validate(errors) - - if self.uses_provider("pocket_tts"): - self.pocket_tts.validate(errors) - - if self.uses_provider("openai"): - await self.validate_and_set_openai(errors) - - if self.uses_provider("mistral"): - await self.validate_and_set_mistral(errors) - - if self.uses_provider("groq"): - await self.validate_and_set_groq(errors) - - if self.uses_provider("cerebras"): - await self.validate_and_set_cerebras(errors) - - if self.uses_provider("google"): - await self.validate_and_set_google(errors) - - if self.uses_provider("openrouter"): - await self.validate_and_set_openrouter(errors) - - if self.uses_provider("local_llm"): - await self.validate_and_set_local_llm(errors) - - if self.uses_provider("elevenlabs"): - await self.validate_and_set_elevenlabs(errors) - - if self.uses_provider("openai_compatible"): - await self.validate_and_set_openai_compatible_tts(errors) - - if self.uses_provider("azure"): - await self.validate_and_set_azure(errors) - - if self.uses_provider("wingman_pro"): - await self.validate_and_set_wingman_pro() - - if self.uses_provider("perplexity"): - await self.validate_and_set_perplexity(errors) - - if self.uses_provider("xai"): - await self.validate_and_set_xai(errors) - - if self.uses_provider("hume"): - await self.validate_and_set_hume(errors) - - if self.uses_provider("inworld"): - await self.validate_and_set_inworld(errors) - - except Exception as e: - errors.append( - WingmanInitializationError( - wingman_name=self.name, - message=f"Error during provider validation: {str(e)}", - error_type=WingmanInitializationErrorType.UNKNOWN, - ) - ) - printr.print( - f"Error during provider validation: {str(e)}", - color=LogType.ERROR, - server_only=True, - ) - printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - - return errors - - def uses_provider(self, provider_type: str): - if provider_type == "openai": - return any( - [ - self.config.features.tts_provider == TtsProvider.OPENAI, - self.config.features.stt_provider == SttProvider.OPENAI, - self.config.features.conversation_provider - == ConversationProvider.OPENAI, - self.config.features.image_generation_provider - == ImageGenerationProvider.OPENAI, - ] - ) - elif provider_type == "mistral": - return any( - [ - self.config.features.conversation_provider - == ConversationProvider.MISTRAL, - ] - ) - elif provider_type == "groq": - return any( - [ - self.config.features.conversation_provider - == ConversationProvider.GROQ, - self.config.features.stt_provider == SttProvider.GROQ, - ] - ) - elif provider_type == "cerebras": - return any( - [ - self.config.features.conversation_provider - == ConversationProvider.CEREBRAS, - ] - ) - elif provider_type == "google": - return any( - [ - self.config.features.conversation_provider - == ConversationProvider.GOOGLE, - ] - ) - elif provider_type == "openrouter": - return any( - [ - self.config.features.conversation_provider - == ConversationProvider.OPENROUTER, - ] - ) - elif provider_type == "local_llm": - return any( - [ - self.config.features.conversation_provider - == ConversationProvider.LOCAL_LLM, - ] - ) - elif provider_type == "azure": - return any( - [ - self.config.features.tts_provider == TtsProvider.AZURE, - self.config.features.stt_provider == SttProvider.AZURE, - self.config.features.stt_provider == SttProvider.AZURE_SPEECH, - self.config.features.conversation_provider - == ConversationProvider.AZURE, - ] - ) - elif provider_type == "edge_tts": - return self.config.features.tts_provider == TtsProvider.EDGE_TTS - elif provider_type == "elevenlabs": - return self.config.features.tts_provider == TtsProvider.ELEVENLABS - elif provider_type == "openai_compatible": - return self.config.features.tts_provider == TtsProvider.OPENAI_COMPATIBLE - elif provider_type == "pocket_tts": - return self.config.features.tts_provider == TtsProvider.POCKET_TTS - elif provider_type == "hume": - return self.config.features.tts_provider == TtsProvider.HUME - elif provider_type == "inworld": - return self.config.features.tts_provider == TtsProvider.INWORLD - elif provider_type == "xvasynth": - return self.config.features.tts_provider == TtsProvider.XVASYNTH - elif provider_type == "whispercpp": - return self.config.features.stt_provider == SttProvider.WHISPERCPP - elif provider_type == "fasterwhisper": - return self.config.features.stt_provider == SttProvider.FASTER_WHISPER - elif provider_type == "parakeet": - return self.config.features.stt_provider == SttProvider.PARAKEET - elif provider_type == "wingman_pro": - return any( - [ - self.config.features.conversation_provider - == ConversationProvider.WINGMAN_PRO, - self.config.features.tts_provider == TtsProvider.WINGMAN_PRO, - self.config.features.stt_provider == SttProvider.WINGMAN_PRO, - self.config.features.image_generation_provider - == ImageGenerationProvider.WINGMAN_PRO, - ] - ) - elif provider_type == "perplexity": - return any( - [ - self.config.features.conversation_provider - == ConversationProvider.PERPLEXITY, - ] - ) - elif provider_type == "xai": - return any( - [ - self.config.features.conversation_provider - == ConversationProvider.XAI, - ] - ) - return False - - async def prepare(self): - try: - if self.config.features.use_generic_instant_responses: - printr.print( - "Generating AI instant responses...", - color=LogType.WARNING, - server_only=True, - ) - self.threaded_execution(self._generate_instant_responses) - except Exception as e: - await printr.print_async( - f"Error while preparing wingman '{self.name}': {str(e)}", - color=LogType.ERROR, - ) - printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - - def ensure_memory_initialized(self) -> bool: - """Initialize persistent memory service if not yet done. - - Returns True if the service is available after this call. - """ - # Tear down service if the toggle was turned off - if self.persistent_memory_service and not self.config.persistent_memory: - self.persistent_memory_service.close() - self.persistent_memory_service = None - return False - if self.persistent_memory_service: - return True - if self.config.persistent_memory and self.local_ai_service: - from services.persistent_memory import PersistentMemoryService - - self.persistent_memory_service = PersistentMemoryService( - wingman_name=self.name, - local_ai_service=self.local_ai_service, - ) - self.persistent_memory_service.initialize() - return True - return False - - async def unload(self): - """Extract memories, close clients, and close DB before unloading.""" - # Wait for any background memory extraction tasks to finish - if self._background_tasks: - await asyncio.gather(*self._background_tasks, return_exceptions=True) - self._background_tasks.clear() - - if self.persistent_memory_service: - from services.persistent_memory import MIN_MESSAGES_FOR_EXTRACTION - - if len(self.messages) >= MIN_MESSAGES_FOR_EXTRACTION: - try: - await self.persistent_memory_service.extract_memories( - self.messages, generate_summary=True - ) - except Exception: - pass - self.persistent_memory_service.close() - - if self.google: - await self.google.aclose() - self.google = None - - await super().unload() - - async def unload_skills(self): - await super().unload_skills() - self.tool_skills = {} - self.skill_tools = [] - self.skill_registry.clear() - - async def unload_mcps(self): - """Disconnect from all MCP servers.""" - await self.mcp_registry.clear() - - async def enable_mcp(self, mcp_name: str) -> tuple[bool, str]: - """Enable and connect to a single MCP server without reinitializing all MCPs. - - Args: - mcp_name: The name of the MCP server to enable - - Returns: - (success, message) tuple - """ - # Check if MCP SDK is available - if not self.mcp_client.is_available: - return False, "MCP SDK not installed." - - # Check if already connected - if mcp_name in self.mcp_registry.get_connected_server_names(): - return True, f"MCP server '{mcp_name}' is already connected." - - # Find the MCP config from central mcp.yaml - central_mcp_config = self.tower.config_manager.mcp_config - mcp_configs = central_mcp_config.servers if central_mcp_config else [] - - mcp_config = None - for cfg in mcp_configs: - if cfg.name == mcp_name: - mcp_config = cfg - break - - if not mcp_config: - return False, f"MCP server '{mcp_name}' not found in mcp.yaml." - - try: - # Build headers with secrets (same logic as init_mcps) - headers = {} - if mcp_config.headers: - headers.update(mcp_config.headers) - - # Check for API key in secrets - secret_key = f"mcp_{mcp_config.name}" - api_key = await self.secret_keeper.retrieve( - requester=self.name, - key=secret_key, - prompt_if_missing=False, - ) - if api_key: - if not any( - k.lower() in ["authorization", "api-key", "x-api-key"] - for k in headers.keys() - ): - headers["Authorization"] = f"Bearer {api_key}" - - # Connect with timeout - default_timeout = 60.0 if mcp_config.type.value == "stdio" else 30.0 - timeout = ( - float(mcp_config.timeout) if mcp_config.timeout else default_timeout - ) - - connection = await asyncio.wait_for( - self.mcp_registry.register_server( - config=mcp_config, - headers=headers if headers else None, - ), - timeout=timeout, - ) - - if connection.is_connected: - tool_count = len(connection.tools) - return True, f"MCP server '{mcp_name}' enabled with {tool_count} tools." - else: - error = connection.error or "Connection failed." - return False, f"MCP server '{mcp_name}' failed to connect: {error}" - - except asyncio.TimeoutError: - error_msg = f"Connection timed out ({int(timeout)}s)." - self.mcp_registry.set_server_error(mcp_name, error_msg) - return False, f"MCP server '{mcp_name}': {error_msg}" - - except Exception as e: - error_msg = f"Error enabling MCP '{mcp_name}': {str(e)}" - await printr.print_async(error_msg, color=LogType.ERROR) - printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - return False, error_msg - - async def disable_mcp(self, mcp_name: str) -> tuple[bool, str]: - """Disable and disconnect from a single MCP server without affecting other MCPs. - - Args: - mcp_name: The name of the MCP server to disable - - Returns: - (success, message) tuple - """ - # Check if the MCP is connected - if mcp_name not in self.mcp_registry.get_connected_server_names(): - return True, f"MCP server '{mcp_name}' is already disconnected." - - try: - await self.mcp_registry.unregister_server(mcp_name) - return True, f"MCP server '{mcp_name}' disabled." - - except Exception as e: - error_msg = f"Error disabling MCP '{mcp_name}': {str(e)}" - await printr.print_async(error_msg, color=LogType.ERROR) - printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - return False, error_msg - - async def prepare_skill(self, skill: Skill): - # prepare the skill and skill tools - try: - for tool_name, tool in skill.get_tools(): - self.tool_skills[tool_name] = skill - self.skill_tools.append(tool) - - # Register with the progressive disclosure registry - self.skill_registry.register_skill(skill) - - # Auto-activated skills need to be validated/prepared immediately - # so their hooks (like on_play_to_user) will work - if skill.config.auto_activate: - success, message = await skill.ensure_activated() - if not success: - await printr.print_async( - f"Auto-activated skill '{skill.config.display_name}' failed to activate: {message}", - color=LogType.ERROR, - ) - except Exception as e: - await printr.print_async( - f"Error while preparing skill '{skill.name}': {str(e)}", - color=LogType.ERROR, - ) - printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - - # init skill methods - skill.llm_call = self.actual_llm_call - - async def unprepare_skill(self, skill: Skill): - """Remove a skill's tools and registrations when it's disabled.""" - try: - # Remove tool mappings - for tool_name, _ in skill.get_tools(): - self.tool_skills.pop(tool_name, None) - # Remove from skill_tools list - self.skill_tools = [ - t - for t in self.skill_tools - if t.get("function", {}).get("name") != tool_name - ] - - # Unregister from the progressive disclosure registry - self.skill_registry.unregister_skill(skill.name) - except Exception as e: - await printr.print_async( - f"Error while unpreparing skill '{skill.name}': {str(e)}", - color=LogType.ERROR, - ) - printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - - async def init_mcps(self) -> list[WingmanInitializationError]: - """ - Initialize MCP (Model Context Protocol) server connections. - - Loads MCP servers from central mcp.yaml config, only connecting those in wingman's discoverable_mcps. - MCP servers provide external tools similar to skills. - - Returns: - list[WingmanInitializationError]: Errors encountered (non-fatal, wingman still loads) - """ - errors = [] - - # Check if MCP SDK is available - if not self.mcp_client.is_available: - printr.print( - f"[{self.name}] MCP SDK not installed, skipping MCP initialization.", - color=LogType.WARNING, - server_only=True, - ) - return errors - - # Disconnect existing MCP servers - await self.unload_mcps() - - # Get MCP configs from central mcp.yaml - central_mcp_config = self.tower.config_manager.mcp_config - mcp_configs = central_mcp_config.servers if central_mcp_config else [] - if not mcp_configs: - return errors - - # Get discoverable MCPs list (whitelist) from wingman config - discoverable_mcps = self.config.discoverable_mcps - - # Filter to only discoverable MCPs - mcps_to_connect = [mcp for mcp in mcp_configs if mcp.name in discoverable_mcps] - - if not mcps_to_connect: - return errors - - # Prepare connection tasks for parallel execution - async def connect_mcp(mcp_config): - """Connect to a single MCP server. Returns (success, connection_info, errors).""" - local_errors = [] - try: - # Build headers with secrets - headers = {} - if mcp_config.headers: - headers.update(mcp_config.headers) - - # Check for API key in secrets (using mcp_ prefix) - secret_key = f"mcp_{mcp_config.name}" - api_key = await self.secret_keeper.retrieve( - requester=self.name, - key=secret_key, - prompt_if_missing=False, - ) - if api_key: - printr.print( - f"MCP secret '{secret_key}' found ({len(api_key)} chars)", - color=LogType.INFO, - source_name=self.name, - server_only=True, - ) - if not any( - k.lower() in ["authorization", "api-key", "x-api-key"] - for k in headers.keys() - ): - headers["Authorization"] = f"Bearer {api_key}" - - # Connect with timeout - default_timeout = 60.0 if mcp_config.type.value == "stdio" else 30.0 - timeout = ( - float(mcp_config.timeout) if mcp_config.timeout else default_timeout - ) - - try: - connection = await asyncio.wait_for( - self.mcp_registry.register_server( - config=mcp_config, - headers=headers if headers else None, - ), - timeout=timeout, - ) - except asyncio.TimeoutError: - error_msg = f"MCP '{mcp_config.display_name}' connection timed out ({int(timeout)}s)." - printr.print( - error_msg, - color=LogType.WARNING, - source_name=self.name, - server_only=True, - ) - local_errors.append( - WingmanInitializationError( - wingman_name=self.name, - message=error_msg, - error_type=WingmanInitializationErrorType.MCP_CONNECTION_FAILED, - ) - ) - return (False, None, local_errors) - - if connection.is_connected: - return ( - True, - f"{mcp_config.display_name} ({len(connection.tools)} tools)", - local_errors, - ) - else: - error_msg = f"MCP '{mcp_config.display_name}' failed to connect: {connection.error}" - local_errors.append( - WingmanInitializationError( - wingman_name=self.name, - message=error_msg, - error_type=WingmanInitializationErrorType.MCP_CONNECTION_FAILED, - ) - ) - return (False, None, local_errors) - - except Exception as e: - error_msg = f"MCP '{mcp_config.name}' initialization error: {str(e)}" - printr.print( - error_msg, - color=LogType.ERROR, - source_name=self.name, - server_only=True, - ) - printr.print( - traceback.format_exc(), color=LogType.ERROR, server_only=True - ) - local_errors.append( - WingmanInitializationError( - wingman_name=self.name, - message=error_msg, - error_type=WingmanInitializationErrorType.MCP_CONNECTION_FAILED, - ) - ) - return (False, None, local_errors) - - # Connect to all MCPs in parallel - connection_tasks = [connect_mcp(mcp) for mcp in mcps_to_connect] - results = await asyncio.gather(*connection_tasks) - - # Collect results - connected_count = 0 - connected_names = [] - for success, connection_info, mcp_errors in results: - if success: - connected_count += 1 - connected_names.append(connection_info) - errors.extend(mcp_errors) - - # Log consolidated MCP status for this wingman - if connected_count > 0: - await printr.print_async( - f"Discoverable MCP servers connected ({connected_count}): {', '.join(connected_names)}", - color=LogType.WINGMAN, - source=LogSource.WINGMAN, - source_name=self.name, - server_only=not self.settings.debug_mode, - ) - - return errors - - async def validate_and_set_openai(self, errors: list[WingmanInitializationError]): - api_key = await self.retrieve_secret("openai", errors) - if api_key: - self.openai = OpenAi( - api_key=api_key, - organization=self.config.openai.organization, - base_url=self.config.openai.base_url, - ) - - async def validate_and_set_mistral(self, errors: list[WingmanInitializationError]): - api_key = await self.retrieve_secret("mistral", errors) - if api_key: - # TODO: maybe use their native client (or LangChain) instead of OpenAI(?) - self.mistral = OpenAi( - api_key=api_key, - organization=self.config.openai.organization, - base_url=self.config.mistral.endpoint, - ) - - async def validate_and_set_groq(self, errors: list[WingmanInitializationError]): - api_key = await self.retrieve_secret("groq", errors) - if api_key: - # TODO: maybe use their native client (or LangChain) instead of OpenAI(?) - self.groq = OpenAi( - api_key=api_key, - base_url=self.config.groq.endpoint, - ) - - async def validate_and_set_cerebras(self, errors: list[WingmanInitializationError]): - api_key = await self.retrieve_secret("cerebras", errors) - if api_key: - # TODO: maybe use their native client (or LangChain) instead of OpenAI(?) - self.cerebras = OpenAi( - api_key=api_key, - base_url=self.config.cerebras.endpoint, - ) - - async def validate_and_set_google(self, errors: list[WingmanInitializationError]): - api_key = await self.retrieve_secret("google", errors) - if api_key: - if self.google: - await self.google.aclose() - self.google = GoogleGenAI(api_key=api_key) - - async def validate_and_set_openrouter( - self, errors: list[WingmanInitializationError] - ): - api_key = await self.retrieve_secret("openrouter", errors) - - async def does_openrouter_model_support_tools(model_id: str): - if not model_id: - return False - response = requests.get( - url=f"https://openrouter.ai/api/v1/models/{model_id}/endpoints", - timeout=10, - ) - response.raise_for_status() - content = response.json() - result = OpenRouterEndpointResult(**content.get("data", {})) - supports_tools = any( - all( - p in (endpoint.supported_parameters or []) - for p in ["tools", "tool_choice"] - ) - for endpoint in result.endpoints - ) - if not supports_tools: - printr.print( - f"{self.name}: OpenRouter model {model_id} does not support tools, so they'll be omitted from calls.", - source=LogSource.WINGMAN, - source_name=self.name, - color=LogType.WARNING, - server_only=True, - ) - return supports_tools - - if api_key: - self.openrouter = OpenAi( - api_key=api_key, - base_url=self.config.openrouter.endpoint, - ) - self.openrouter_model_supports_tools = ( - await does_openrouter_model_support_tools( - self.config.openrouter.conversation_model - ) - ) - - async def validate_and_set_local_llm( - self, errors: list[WingmanInitializationError] - ): - api_key = await self.retrieve_secret("local_llm", errors, is_required=False) - self.local_llm = OpenAi( - api_key=api_key or "local", - base_url=self.config.local_llm.endpoint, - ) - - async def validate_and_set_elevenlabs( - self, errors: list[WingmanInitializationError] - ): - api_key = await self.retrieve_secret("elevenlabs", errors) - if api_key: - self.elevenlabs = ElevenLabs( - api_key=api_key, - wingman_name=self.name, - ) - self.elevenlabs.validate_config( - config=self.config.elevenlabs, errors=errors - ) - - async def validate_and_set_openai_compatible_tts( - self, errors: list[WingmanInitializationError] - ): - if ( - self.config.openai_compatible_tts.base_url - and self.config.openai_compatible_tts.api_key - ): - self.openai_compatible_tts = OpenAiCompatibleTts( - api_key=self.config.openai_compatible_tts.api_key, - base_url=self.config.openai_compatible_tts.base_url, - ) - printr.print( - f"Wingman {self.name}: Initialized OpenAI-compatible TTS with base URL {self.config.openai_compatible_tts.base_url} and API key {self.config.openai_compatible_tts.api_key}", - server_only=True, - ) - - async def validate_and_set_hume(self, errors: list[WingmanInitializationError]): - api_key = await self.retrieve_secret("hume", errors) - if api_key: - self.hume = Hume( - api_key=api_key, - wingman_name=self.name, - ) - self.hume.validate_config(config=self.config.hume, errors=errors) - - async def validate_and_set_inworld(self, errors: list[WingmanInitializationError]): - api_key = await self.retrieve_secret("inworld", errors) - if api_key: - self.inworld = Inworld( - api_key=api_key, - wingman_name=self.name, - ) - self.inworld.validate_config(config=self.config.inworld, errors=errors) - - async def validate_and_set_azure(self, errors: list[WingmanInitializationError]): - for key_type in self.AZURE_SERVICES: - if self.uses_provider("azure"): - api_key = await self.retrieve_secret(f"azure_{key_type}", errors) - if api_key: - self.azure_api_keys[key_type] = api_key - if len(errors) == 0: - self.openai_azure = OpenAiAzure() - - async def validate_and_set_wingman_pro(self): - self.wingman_pro = WingmanPro( - wingman_name=self.name, settings=self.settings.wingman_pro - ) - - async def validate_and_set_perplexity( - self, errors: list[WingmanInitializationError] - ): - api_key = await self.retrieve_secret("perplexity", errors) - if api_key: - self.perplexity = OpenAi( - api_key=api_key, - base_url=self.config.perplexity.endpoint, - ) - - async def validate_and_set_xai(self, errors: list[WingmanInitializationError]): - api_key = await self.retrieve_secret("xai", errors) - if api_key: - self.xai = XAi( - api_key=api_key, - base_url=self.config.xai.endpoint, - ) - - # overrides the base class method - async def update_settings(self, settings: SettingsConfig): - """Update the settings of the Wingman. This method should always be called when the user Settings have changed.""" - try: - await super().update_settings(settings) - - if self.uses_provider("wingman_pro"): - await self.validate_and_set_wingman_pro() - printr.print( - f"Wingman {self.name}: reinitialized Wingman Pro with new settings", - server_only=True, - ) - except Exception as e: - await printr.print_async( - f"Error while updating settings for wingman '{self.name}': {str(e)}", - color=LogType.ERROR, - ) - printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - - async def _generate_instant_responses(self) -> None: - """Generates general instant responses based on given context.""" - context = await self.get_context() - messages = [ - { - "role": "system", - "content": """ - Generate a list in JSON format of at least 20 short direct text responses. - Make sure the response only contains the JSON, no additional text. - They must fit the described character in the given context by the user. - Every generated response must be generally usable in every situation. - Responses must show its still in progress and not in a finished state. - The user request this response is used on is unknown. Therefore it must be generic. - Good examples: - - "Processing..." - - "Stand by..." - - Bad examples: - - "Generating route..." (too specific) - - "I'm sorry, I can't do that." (too negative) - - Response example: - [ - "OK", - "Generating results...", - "Roger that!", - "Stand by..." - ] - """, - }, - {"role": "user", "content": context}, - ] - try: - completion = await self.actual_llm_call(messages) - if completion is None: - return - if completion.choices[0].message.content: - retry_limit = 3 - retry_count = 1 - valid = False - while not valid and retry_count <= retry_limit: - try: - responses = json.loads(completion.choices[0].message.content) - valid = True - for response in responses: - if response not in self.instant_responses: - self.instant_responses.append(str(response)) - except json.JSONDecodeError: - messages.append(completion.choices[0].message) - messages.append( - { - "role": "user", - "content": "It was tried to handle the response in its entirety as a JSON string. Fix response to be a pure, valid JSON, it was not convertable.", - } - ) - if retry_count <= retry_limit: - completion = await self.actual_llm_call(messages) - retry_count += 1 - except Exception as e: - await printr.print_async( - f"Error while generating instant responses: {str(e)}", - color=LogType.ERROR, - ) - printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - - async def _transcribe(self, audio_input_wav: str) -> str | None: - """Transcribes the recorded audio to text using the OpenAI Whisper API. - - Args: - audio_input_wav (str): The path to the audio file that contains the user's speech. This is a recording of what you you said. - - Returns: - str | None: The transcript of the audio file or None if the transcription failed. - """ - transcript = None - - try: - if self.config.features.stt_provider == SttProvider.AZURE: - transcript = self.openai_azure.transcribe_whisper( - filename=audio_input_wav, - api_key=self.azure_api_keys["whisper"], - config=self.config.azure.whisper, - ) - elif self.config.features.stt_provider == SttProvider.AZURE_SPEECH: - transcript = self.openai_azure.transcribe_azure_speech( - filename=audio_input_wav, - api_key=self.azure_api_keys["tts"], - config=self.config.azure.stt, - ) - elif self.config.features.stt_provider == SttProvider.WHISPERCPP: - transcript = self.whispercpp.transcribe( - filename=audio_input_wav, config=self.config.whispercpp - ) - elif self.config.features.stt_provider == SttProvider.FASTER_WHISPER: - hotwords: list[str] = [] - # add my name - hotwords.append(self.name) - # add default hotwords - default_hotwords = self.config.fasterwhisper.hotwords - if default_hotwords and len(default_hotwords) > 0: - hotwords.extend(default_hotwords) - # and my additional hotwords - wingman_hotwords = self.config.fasterwhisper.additional_hotwords - if wingman_hotwords and len(wingman_hotwords) > 0: - hotwords.extend(wingman_hotwords) - - transcript = self.fasterwhisper.transcribe( - filename=audio_input_wav, - config=self.config.fasterwhisper, - hotwords=list(set(hotwords)), - ) - elif self.config.features.stt_provider == SttProvider.PARAKEET: - transcript = self.parakeet.transcribe( - config=self.config.parakeet, - filename=audio_input_wav, - ) - elif self.config.features.stt_provider == SttProvider.WINGMAN_PRO: - if ( - self.config.wingman_pro.stt_provider - == WingmanProSttProvider.WHISPER - ): - transcript = self.wingman_pro.transcribe_whisper( - filename=audio_input_wav - ) - elif ( - self.config.wingman_pro.stt_provider - == WingmanProSttProvider.AZURE_SPEECH - ): - transcript = self.wingman_pro.transcribe_azure_speech( - filename=audio_input_wav, config=self.config.azure.stt - ) - elif self.config.features.stt_provider == SttProvider.OPENAI: - transcript = self.openai.transcribe(filename=audio_input_wav) - elif self.config.features.stt_provider == SttProvider.GROQ: - transcript = self.groq.transcribe( - filename=audio_input_wav, model="whisper-large-v3-turbo" - ) - except Exception as e: - await printr.print_async( - f"Error during transcription using '{self.config.features.stt_provider}': {str(e)}", - color=LogType.ERROR, - ) - printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - - result = None - if transcript: - # Wingman Pro might returns a serialized dict instead of a real Azure Speech transcription object - result = ( - transcript.get("_text") - if isinstance(transcript, dict) - else transcript.text - ) - - return result - - async def _get_response_for_transcript( - self, transcript: str, benchmark: Benchmark, images: list[tuple[str, str]] = None - ) -> tuple[str | None, str | None, Skill | None, bool]: - """Gets the response for a given transcript. - - This function interprets the transcript, runs instant commands if triggered, - calls the OpenAI API when needed, processes any tool calls, and generates the final response. - - Args: - transcript (str): The user's spoken text transcribed. - - Returns: - tuple[str | None, str | None, Skill | None, bool]: A tuple containing the final response, the instant response (if any), the skill that was used, and a boolean indicating whether the current audio should be interrupted. - """ - self.ensure_memory_initialized() - - await self.add_user_message(transcript, images=images) - - benchmark.start_snapshot("Instant activation commands") - instant_response, instant_command_executed = await self._try_instant_activation( - transcript=transcript - ) - if instant_response: - await self.add_assistant_message(instant_response) - benchmark.finish_snapshot() - if ( - instant_response == "." - ): # thats for the "The UI should not give a response" option in commands - instant_response = None - return instant_response, instant_response, None, True - benchmark.finish_snapshot() - - # Track cumulative times for proper aggregation - llm_processing_time_ms = 0.0 - tool_execution_time_ms = 0.0 - tool_timings: list[tuple[str, float]] = ( - [] - ) # (label, time_ms) for individual tools - - # make a GPT call with the conversation history - # if an instant command got executed, prevent tool calls to avoid duplicate executions - llm_start = time.perf_counter() - completion = await self._llm_call(instant_command_executed is False) - llm_processing_time_ms += (time.perf_counter() - llm_start) * 1000 - - if completion is None: - self._add_benchmark_snapshot( - benchmark, "LLM Processing", llm_processing_time_ms - ) - return None, None, None, True - - response_message, tool_calls, usage = await self._process_completion( - completion, instant_command_executed is False - ) - - # Track token usage across the turn (last prompt_tokens, summed completion_tokens) - turn_prompt_tokens = usage[0] - turn_completion_tokens = usage[1] - self._last_prompt_tokens = turn_prompt_tokens - - # add message and dummy tool responses to conversation history - is_waiting_response_needed, is_summarize_needed = await self._add_gpt_response( - response_message, tool_calls - ) - interrupt = True # initial answer should be awaited if exists - - while tool_calls: - if is_waiting_response_needed: - message = None - if response_message.content: - message = response_message.content - elif self.instant_responses: - message = self._get_random_filler() - is_summarize_needed = True - if message: - self.threaded_execution(self.play_to_user, message, interrupt) - await printr.print_async( - f"{message}", - color=LogType.POSITIVE, - source=LogSource.WINGMAN, - source_name=self.name, - skill_name="", - ) - interrupt = False - else: - is_summarize_needed = True - else: - is_summarize_needed = True - - # Time tool execution and collect individual timings - tool_start = time.perf_counter() - instant_response, skill, iteration_timings = await self._handle_tool_calls( - tool_calls - ) - tool_execution_time_ms += (time.perf_counter() - tool_start) * 1000 - tool_timings.extend(iteration_timings) - - if instant_response: - await self._trim_tool_responses() - # Add snapshots before returning - self._add_benchmark_snapshot( - benchmark, "LLM Processing", llm_processing_time_ms - ) - if tool_execution_time_ms > 0: - self._add_tool_execution_snapshot( - benchmark, tool_execution_time_ms, tool_timings - ) - await self._broadcast_token_usage( - turn_prompt_tokens, turn_completion_tokens - ) - return None, instant_response, None, interrupt - - if is_summarize_needed: - # Time the follow-up LLM call - llm_start = time.perf_counter() - completion = await self._llm_call(True) - llm_processing_time_ms += (time.perf_counter() - llm_start) * 1000 - - if completion is None: - await self._trim_tool_responses() - self._add_benchmark_snapshot( - benchmark, "LLM Processing", llm_processing_time_ms - ) - if tool_execution_time_ms > 0: - self._add_tool_execution_snapshot( - benchmark, tool_execution_time_ms, tool_timings - ) - await self._broadcast_token_usage( - turn_prompt_tokens, turn_completion_tokens - ) - return None, None, None, True - - response_message, tool_calls, usage = await self._process_completion( - completion - ) - # Last call's prompt_tokens is most meaningful (includes full context) - turn_prompt_tokens = usage[0] - turn_completion_tokens += usage[1] - self._last_prompt_tokens = turn_prompt_tokens - - is_waiting_response_needed, is_summarize_needed = ( - await self._add_gpt_response(response_message, tool_calls) - ) - if tool_calls: - interrupt = False - elif is_waiting_response_needed: - await self._trim_tool_responses() - self._add_benchmark_snapshot( - benchmark, "LLM Processing", llm_processing_time_ms - ) - if tool_execution_time_ms > 0: - self._add_tool_execution_snapshot( - benchmark, tool_execution_time_ms, tool_timings - ) - await self._broadcast_token_usage( - turn_prompt_tokens, turn_completion_tokens - ) - return None, None, None, interrupt - - # Trim oversized tool responses now that the LLM has processed them - await self._trim_tool_responses() - - # Add final snapshots - self._add_benchmark_snapshot( - benchmark, "LLM Processing", llm_processing_time_ms - ) - if tool_execution_time_ms > 0: - self._add_tool_execution_snapshot( - benchmark, tool_execution_time_ms, tool_timings - ) - await self._broadcast_token_usage(turn_prompt_tokens, turn_completion_tokens) - return response_message.content, response_message.content, None, interrupt - - def _add_benchmark_snapshot( - self, benchmark: Benchmark, label: str, execution_time_ms: float - ): - """Add a snapshot with the given label and execution time.""" - if execution_time_ms >= 1000: - formatted_time = f"{execution_time_ms/1000:.1f}s" - else: - formatted_time = f"{int(execution_time_ms)}ms" - - from api.interface import BenchmarkResult - - benchmark.snapshots.append( - BenchmarkResult( - label=label, - execution_time_ms=execution_time_ms, - formatted_execution_time=formatted_time, - ) - ) - - def _add_tool_execution_snapshot( - self, - benchmark: Benchmark, - total_time_ms: float, - tool_timings: list[tuple[str, float]], - ): - """Add a tool execution snapshot with nested individual tool timings.""" - from api.interface import BenchmarkResult - - if total_time_ms >= 1000: - formatted_time = f"{total_time_ms/1000:.1f}s" - else: - formatted_time = f"{int(total_time_ms)}ms" - - # Create nested snapshots for individual tools - nested_snapshots = [] - for label, time_ms in tool_timings: - if time_ms >= 1000: - fmt = f"{time_ms/1000:.1f}s" - else: - fmt = f"{int(time_ms)}ms" - nested_snapshots.append( - BenchmarkResult( - label=label, - execution_time_ms=time_ms, - formatted_execution_time=fmt, - ) - ) - - benchmark.snapshots.append( - BenchmarkResult( - label="Tool Execution", - execution_time_ms=total_time_ms, - formatted_execution_time=formatted_time, - snapshots=nested_snapshots if nested_snapshots else None, - ) - ) - - async def _broadcast_token_usage(self, prompt_tokens: int, completion_tokens: int): - """Broadcast actual API-reported token usage to the client.""" - is_local = ( - self.config.features.conversation_provider == ConversationProvider.LOCAL_LLM - ) - - # Local providers (e.g. Ollama) often report 0 tokens — estimate from messages - if is_local and prompt_tokens == 0: - prompt_tokens = sum( - count_tokens( - msg["content"] - if isinstance(msg.get("content"), str) - else str(msg.get("content", "")) - ) - for msg in self.messages - ) - if is_local and completion_tokens == 0 and self.messages: - last = self.messages[-1] - if last.get("role") == "assistant": - content = last.get("content", "") - completion_tokens = count_tokens( - content if isinstance(content, str) else str(content) - ) - - self.last_turn_prompt_tokens = prompt_tokens - self.last_turn_completion_tokens = completion_tokens - if prompt_tokens == 0 and completion_tokens == 0: - return - if not printr._connection_manager: - return - - from api.commands import ConversationTokenUsageCommand - - await printr._connection_manager.broadcast( - ConversationTokenUsageCommand( - wingman_name=self.name, - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - is_local=is_local, - ) - ) - - def _get_random_filler(self): - # get last two used instant responses - if len(self.last_used_instant_responses) > 2: - self.last_used_instant_responses = self.last_used_instant_responses[-2:] - - # get a random instant response that was not used in the last two responses - random_index = random.randint(0, len(self.instant_responses) - 1) - while random_index in self.last_used_instant_responses: - random_index = random.randint(0, len(self.instant_responses) - 1) - - # add the index to the last used list and return - self.last_used_instant_responses.append(random_index) - return self.instant_responses[random_index] - - async def _fix_tool_calls(self, tool_calls): - """Fixes tool calls that have a command name as function name. - - Args: - tool_calls (list): The tool calls to fix. - - Returns: - list: The fixed tool calls. - """ - if tool_calls and len(tool_calls) > 0: - for tool_call in tool_calls: - function_name = tool_call.function.name - function_args = ( - tool_call.function.arguments - # Mistral returns a dict - if isinstance(tool_call.function.arguments, dict) - # OpenAI returns a string - else json.loads(tool_call.function.arguments) - ) - - # try to resolve function name to a command name - if (len(function_args) == 0 and self.get_command(function_name)) or ( - len(function_args) == 1 - and "command_name" in function_args - and self.get_command(function_args["command_name"]) - and function_name == function_args["command_name"] - ): - function_args["command_name"] = function_name - function_name = "execute_command" - - # update the tool call - tool_call.function.name = function_name - tool_call.function.arguments = json.dumps(function_args) - - if self.settings.debug_mode: - await printr.print_async( - "Applied command call fix.", color=LogType.WARNING - ) - - return tool_calls - - async def _add_gpt_response(self, message, tool_calls) -> (bool, bool): - """Adds a message from GPT to the conversation history as well as adding dummy tool responses for any tool calls. - - Args: - message (dict | ChatCompletionMessage): The message to add. - tool_calls (list): The tool calls associated with the message. - """ - # call skill hooks (only for prepared/activated skills) - for skill in self.skills: - if skill.is_prepared: - await skill.on_add_assistant_message( - message.content, message.tool_calls - ) - - # do not tamper with this message as it will lead to 400 errors! - self.messages.append(message) - - # adding dummy tool responses to prevent corrupted message history on parallel requests - # and checks if waiting response should be played - unique_tools = {} - is_waiting_response_needed = False - is_summarize_needed = False - - if tool_calls: - for tool_call in tool_calls: - if not tool_call.id: - continue - # adding a dummy tool response to get updated later - self._add_tool_response(tool_call, "Loading..", False) - - function_name = tool_call.function.name - - # Meta-tools (search_skills, activate_skill, etc.) always need a follow-up - # LLM call so it can use the newly activated tools - if self.skill_registry.is_meta_tool(function_name): - is_summarize_needed = True - elif function_name in self.tool_skills: - skill = self.tool_skills[function_name] - if await skill.is_waiting_response_needed(function_name): - is_waiting_response_needed = True - if await skill.is_summarize_needed(function_name): - is_summarize_needed = True - - unique_tools[function_name] = True - - if len(unique_tools) == 1 and "execute_command" in unique_tools: - is_waiting_response_needed = True - - return is_waiting_response_needed, is_summarize_needed - - def _add_tool_response(self, tool_call, response: str, completed: bool = True): - """Adds a tool response to the conversation history. - - Args: - tool_call (dict|ChatCompletionMessageToolCall): The tool call to add the dummy response for. - """ - msg = {"role": "tool", "content": response} - if tool_call.id is not None: - msg["tool_call_id"] = tool_call.id - if tool_call.function.name is not None: - msg["name"] = tool_call.function.name - self.messages.append(msg) - - if tool_call.id and not completed: - self.pending_tool_calls.append(tool_call.id) - - async def _update_tool_response(self, tool_call_id, response) -> bool: - """Updates a tool response in the conversation history. - - Args: - tool_call_id (str): The identifier of the tool call to update the response for. - response (str): The new response to set. - - Returns: - bool: True if the response was updated, False if the tool call was not found. - """ - if not tool_call_id: - return False - - index = len(self.messages) - - # go through message history to find and update the tool call - for message in reversed(self.messages): - index -= 1 - if ( - self.__get_message_role(message) == "tool" - and message.get("tool_call_id") == tool_call_id - ): - message["content"] = str(response) - if tool_call_id in self.pending_tool_calls: - self.pending_tool_calls.remove(tool_call_id) - return True - - return False - - async def _trim_tool_responses(self, max_tokens: int = 500): - """Trim oversized tool responses in conversation history. - - Called after the LLM has finished processing a turn with tool calls. - The LLM already had full access to the data; this just prevents stale - bulk data from inflating the context on subsequent turns. - - If significant trimming occurs, broadcasts a condensation notification - so the client UI can display a summary indicator. - """ - total_tokens_saved = 0 - for msg in self.messages: - if self.__get_message_role(msg) != "tool": - continue - content = msg.get("content", "") - if not content: - continue - token_count = count_tokens(content) - if token_count <= max_tokens: - continue - total_tokens_saved += token_count - max_tokens - trimmed = truncate_to_tokens(content, max_tokens) - msg["content"] = ( - f"{trimmed}\n\n[...trimmed from ~{token_count} to " - f"~{max_tokens} tokens for conversation history. " - f"Full response was processed.]" - ) - - # Notify the client when significant trimming occurs so the UI can - # show a "Show history" indicator explaining the token drop. - # Skip if condensation is already running to avoid interfering with - # its own started/finished broadcast cycle. - if ( - total_tokens_saved > 1000 - and printr._connection_manager - and not self._is_condensing - ): - from api.commands import ConversationCondensationCommand - - if self.conversation_summary: - summary = ( - f"{self.conversation_summary}\n\n---\n\n" - f"[Latest turn: tool responses trimmed — ~{total_tokens_saved:,} tokens saved]" - ) - else: - summary = ( - f"Tool responses were automatically trimmed after LLM processing.\n" - f"~{total_tokens_saved:,} tokens saved.\n\n" - f"The LLM had full access to the complete data when generating " - f"its response. Responses are trimmed afterwards to keep the " - f"conversation context efficient." - ) - - await printr._connection_manager.broadcast( - ConversationCondensationCommand( - wingman_name=self.name, - status="finished", - estimated_tokens_saved=total_tokens_saved, - summary_text=summary, - ) - ) - - async def add_user_message(self, content: str, images: list[tuple[str, str]] = None): - """Shortens the conversation history if needed and adds a user message to it. - - Args: - content (str): The message content to add. - images (list[tuple[str, str]]): Optional list of (base64_data, mime_type) tuples to attach. - """ - self._memory_recall_notified = False - - # call skill hooks (only for prepared/activated skills) - for skill in self.skills: - if skill.is_prepared: - await skill.on_add_user_message(content) - - if images: - msg_content = [] - for img_b64, mime in images: - msg_content.append({ - "type": "image_url", - "image_url": { - "url": f"data:{mime};base64,{img_b64}", - "detail": "auto", - }, - }) - msg_content.append({"type": "text", "text": content}) - msg = {"role": "user", "content": msg_content} - else: - msg = {"role": "user", "content": content} - await self._cleanup_conversation_history() - await self._maybe_condense_history() - self.messages.append(msg) - - async def add_assistant_message(self, content: str): - """Adds an assistant message to the conversation history. - - Args: - content (str): The message content to add. - """ - # call skill hooks (only for prepared/activated skills) - for skill in self.skills: - if skill.is_prepared: - await skill.on_add_assistant_message(content, []) - - msg = {"role": "assistant", "content": content} - self.messages.append(msg) - - async def add_forced_assistant_command_calls(self, commands: list[CommandConfig]): - """Adds forced assistant command calls to the conversation history. - - Args: - commands (list[CommandConfig]): The commands to add. - """ - - if not commands: - return - - message = ChatCompletionMessage( - content="", - role="assistant", - tool_calls=[], - ) - tool_id_to_command = {} - for command in commands: - tool_id = None - if ( - self.config.features.conversation_provider - == ConversationProvider.OPENAI - ) or ( - self.config.features.conversation_provider - == ConversationProvider.WINGMAN_PRO - and "gpt" in self.config.wingman_pro.conversation_deployment.lower() - ): - tool_id = f"call_{str(uuid.uuid4()).replace('-', '')}" - elif ( - self.config.features.conversation_provider - == ConversationProvider.GOOGLE - ): - if ( - self.config.google.conversation_model.startswith("gemini-3") - or self.config.google.conversation_model == "gemini-flash-latest" - or self.config.google.conversation_model == "gemini-pro-latest" - or self.config.google.conversation_model - == "gemini-flash-lite-latest" - ): - # gemini 3+ (latest = 3+) needs a thought signature like this, but we cant fake it: - # { - # 'model_extra': { - # 'extra_content': { - # 'google': { - # 'thought_signature': 'EjQKMgFyyNp8mNe4bQmQhOua7gGMH0C9RubFWewy6BzYZJs5f4RqDb8CaiR4gjLxoM1iQqP4' - # } - # } - # } - # } - return - tool_id = f"function-call-{''.join(random.choices('0123456789', k=20))}" - - # early exit for unsupported providers/models - if not tool_id: - return - - tool_call = ChatCompletionMessageToolCall( - id=tool_id, - function=ParsedFunction( - name="execute_command", - arguments=json.dumps({"command_name": command.name}), - ), - type="function", - ) - message.tool_calls.append(tool_call) - tool_id_to_command[tool_id] = command - - await self._add_gpt_response(message, message.tool_calls) - for tool_call in message.tool_calls: - command = tool_id_to_command[tool_call.id] - await self._update_tool_response( - tool_call.id, command.additional_context or "OK" - ) - - async def _cleanup_conversation_history(self): - """Cleans up the conversation history by removing messages that are too old.""" - remember_messages = self.config.features.remember_messages - - if remember_messages is None or len(self.messages) == 0: - return 0 # Configuration not set, nothing to delete. - - # Find the cutoff index where to end deletion, making sure to only count 'user' messages towards the limit starting with newest messages. - cutoff_index = len(self.messages) - user_message_count = 0 - for message in reversed(self.messages): - if self.__get_message_role(message) == "user": - user_message_count += 1 - if user_message_count == remember_messages: - break # Found the cutoff point. - cutoff_index -= 1 - - # If messages below the keep limit, don't delete anything. - if user_message_count < remember_messages: - return 0 - - total_deleted_messages = cutoff_index # Messages to delete. - - # Remove the pending tool calls that are no longer needed. - for mesage in self.messages[:cutoff_index]: - if ( - self.__get_message_role(mesage) == "tool" - and mesage.get("tool_call_id") in self.pending_tool_calls - ): - self.pending_tool_calls.remove(mesage.get("tool_call_id")) - if self.settings.debug_mode: - await printr.print_async( - f"Removing pending tool call {mesage.get('tool_call_id')} due to message history clean up.", - color=LogType.WARNING, - ) - - # Remove the messages before the cutoff index, exclusive of the system message. - del self.messages[:cutoff_index] - - # Optional debugging printout. - if self.settings.debug_mode and total_deleted_messages > 0: - await printr.print_async( - f"Deleted {total_deleted_messages} messages from the conversation history.", - color=LogType.WARNING, - ) - - return total_deleted_messages - - def _estimate_conversation_tokens(self) -> int: - """Estimate the total token count of the current conversation history.""" - return sum(count_tokens(self._message_text_content(m)) for m in self.messages) - - def _get_support_capacity(self) -> int: - """Get the effective input capacity of the support model for a single pass. - - Returns the number of conversation tokens that can fit in one summarization - pass, accounting for system prompt, framing text, and output budget. - """ - system_prompt = get_prompt("condense-conversation") - budget = self.local_ai_service.get_token_budget(system_prompt) - # Subtract framing overhead (prefix/suffix around the conversation text) - framing_overhead = 80 - return max(0, budget.max_input_tokens - framing_overhead) - - async def _maybe_condense_history(self): - """Check if condensation should run and fire it as a background task. - - Uses a token-based trigger: condenses when the conversation approaches - 70% of what the support model can handle in a single pass, so we - avoid chunking. Also has a message count safety cap. - - The cl100k_base token estimate is multiplied by _support_token_ratio - (calibrated from real support model usage) to account for tokenizer - differences between the estimation tokenizer and the actual model. - """ - if not self.config.features.condense_conversation: - return - if not self.local_ai_service or not self.local_ai_service.is_ready(): - return - if self.pending_tool_calls: - return # Never interrupt chained tool calls - if self._is_condensing: - return - - # Token-based trigger: condense when conversation reaches 70% of - # what the support model can handle in one pass. - # Apply _support_token_ratio to correct for tokenizer differences - # between cl100k_base (used for estimation) and the actual model. - capacity = self._get_support_capacity() - cl100k_tokens = self._estimate_conversation_tokens() - conversation_tokens = int(cl100k_tokens * self._support_token_ratio) - token_trigger = conversation_tokens >= int(capacity * 0.7) - - # Message count safety cap - user_msg_count = sum( - 1 for m in self.messages if self.__get_message_role(m) == "user" - ) - message_trigger = user_msg_count >= self.config.features.condense_max_messages - - if not token_trigger and not message_trigger: - return - - # Runs in background so user is never blocked. - # Store the task reference to prevent garbage collection mid-execution. - self._condense_task = asyncio.create_task(self._condense_history()) - self._condense_task.add_done_callback( - lambda _: setattr(self, "_condense_task", None) - ) - - async def _condense_history(self, force: bool = False): - """Condense older conversation messages into a running summary using local AI. - - This preserves the most recent messages verbatim while summarizing older ones, - saving tokens without losing important context. Tool call/response pairs are - never split. - - Args: - force: If True, skip the threshold check (used for manual trigger). - """ - if self._is_condensing: - await printr.print_async( - "Condensation skipped — already in progress.", - color=LogType.WARNING, - server_only=True, - source_name=self.name, - source=LogSource.WINGMAN, - ) - return - if not self.local_ai_service or not self.local_ai_service.is_ready(): - await printr.print_async( - "Condensation skipped — local AI service not available.", - color=LogType.WARNING, - server_only=True, - source_name=self.name, - source=LogSource.WINGMAN, - ) - return - - keep_recent = ( - self.config.features.condense_keep_recent - if not force - else min(self.config.features.condense_keep_recent, 2) - ) - total_msg_count = len(self.messages) - - # Need at least something to condense beyond what we keep - if total_msg_count <= keep_recent: - await printr.print_async( - f"Condensation skipped — only {total_msg_count} messages, need more than {keep_recent} to condense.", - color=LogType.LOCALMODEL, - source_name=self.name, - source=LogSource.WINGMAN, - ) - return - - self._is_condensing = True - _condensation_stats: dict = {} - - # Broadcast start - from api.commands import ConversationCondensationCommand - - if printr._connection_manager: - await printr._connection_manager.broadcast( - ConversationCondensationCommand( - wingman_name=self.name, - status="started", - ) - ) - - await printr.print_async( - "Conversation condensation started.", - color=LogType.INFO, - server_only=True, - source_name=self.name, - source=LogSource.WINGMAN, - ) - - try: - # Wait for any pending tool calls to finish - for _ in range(30): # max 15 seconds - if not self.pending_tool_calls: - break - await asyncio.sleep(0.5) - else: - await printr.print_async( - "Condensation aborted — tool calls still pending after 15s.", - color=LogType.WARNING, - server_only=True, - source_name=self.name, - source=LogSource.WINGMAN, - ) - return - - # Find the cutoff: keep the most recent `keep_recent` user messages - kept_user_count = 0 - cutoff_index = len(self.messages) - for i in range(len(self.messages) - 1, -1, -1): - if self.__get_message_role(self.messages[i]) == "user": - kept_user_count += 1 - if kept_user_count == keep_recent: - cutoff_index = i - break - - if cutoff_index <= 0: - await printr.print_async( - f"Condensation skipped — cutoff_index={cutoff_index}, nothing to condense (kept_user_count={kept_user_count}, keep_recent={keep_recent}, total={len(self.messages)}).", - color=LogType.LOCALMODEL, - source_name=self.name, - source=LogSource.WINGMAN, - ) - return - - # Adjust cutoff forward to avoid orphaning tool responses - while cutoff_index < len(self.messages): - msg = self.messages[cutoff_index] - if self.__get_message_role(msg) == "tool": - cutoff_index += 1 - else: - break - - if cutoff_index <= 0: - await printr.print_async( - "Condensation skipped — no messages to condense after tool adjustment.", - color=LogType.LOCALMODEL, - source_name=self.name, - source=LogSource.WINGMAN, - ) - return - - to_condense = self.messages[:cutoff_index] - - # Extract memories from messages about to be condensed (background, non-blocking) - if self.persistent_memory_service: - try: - task = asyncio.create_task( - self.persistent_memory_service.extract_memories( - to_condense, generate_summary=True - ) - ) - self._background_tasks.add(task) - task.add_done_callback(self._background_tasks.discard) - except Exception: - pass # Don't let memory extraction block condensation - - condensed_text = self._messages_to_text(to_condense) - if not condensed_text.strip(): - await printr.print_async( - "Condensation skipped — messages produced no text content.", - color=LogType.LOCALMODEL, - source_name=self.name, - source=LogSource.WINGMAN, - ) - return - - # Estimate original token count - estimated_original_tokens = sum( - count_tokens(self._message_text_content(m)) for m in to_condense - ) - - # Build the summarization prompt - existing_summary_section = "" - if self.conversation_summary: - existing_summary_section = ( - "EXISTING SUMMARY (incorporate and update — do not repeat verbatim):\n" - + self.conversation_summary - + "\n\n" - ) - - system_prompt = get_prompt("condense-conversation") - budget = self.local_ai_service.get_token_budget(system_prompt) - - user_prompt_prefix = ( - existing_summary_section + "CONVERSATION TO SUMMARIZE:\n" - ) - user_prompt_suffix = ( - "\n\n---\n" - "Now list every fact from the conversation above as bullet points.\n" - "Start from the FIRST message, end at the LAST. Include all secrets, names, preferences, and creative content:" - ) - prefix_suffix_tokens = count_tokens(user_prompt_prefix) + count_tokens( - user_prompt_suffix - ) - - # How much conversation text fits in one pass? - available_tokens = budget.max_input_tokens - prefix_suffix_tokens - - # Apply tokenizer ratio to decide if chunking is needed. - corrected_text_tokens = int( - count_tokens(condensed_text) * self._support_token_ratio - ) - corrected_available = int(available_tokens / self._support_token_ratio) - - if corrected_text_tokens > available_tokens: - # Chunk: summarize in segments, then merge - support_result = await asyncio.wait_for( - self._chunked_support( - condensed_text, - system_prompt, - existing_summary_section, - corrected_available, - ), - timeout=_CONDENSE_TIMEOUT, - ) - else: - user_prompt = ( - f"{user_prompt_prefix}{condensed_text}{user_prompt_suffix}" - ) - from services.skill_local_ai import SamplingPreset - - support_result = await asyncio.wait_for( - asyncio.get_event_loop().run_in_executor( - None, - lambda: self.local_ai_service.support( - text=user_prompt, - system_prompt=system_prompt, - preset=SamplingPreset.BALANCED, - ), - ), - timeout=_CONDENSE_TIMEOUT, - ) - - summary = support_result.text if support_result else None - - # Calibrate tokenizer ratio from real model usage - if support_result and support_result.prompt_tokens > 0: - cl100k_input = ( - budget.system_tokens - + count_tokens(condensed_text) - + prefix_suffix_tokens - ) - if cl100k_input > 0: - self._support_token_ratio = ( - support_result.prompt_tokens / cl100k_input - ) - - # Detect truncated output - if support_result and support_result.truncated: - await printr.print_async( - f"Condensation output was truncated (finish_reason=length). " - f"Model used {support_result.prompt_tokens} prompt tokens, " - f"generated {support_result.completion_tokens} tokens. " - f"Token ratio calibrated to {self._support_token_ratio:.2f}.", - color=LogType.WARNING, - server_only=True, - source_name=self.name, - source=LogSource.WINGMAN, - ) - - if not summary: - await printr.print_async( - "Conversation condensation failed — local AI returned no result.", - color=LogType.WARNING, - server_only=True, - source_name=self.name, - source=LogSource.WINGMAN, - ) - return - - # Clean pending tool calls being removed - for msg in to_condense: - if ( - self.__get_message_role(msg) == "tool" - and msg.get("tool_call_id") in self.pending_tool_calls - ): - self.pending_tool_calls.remove(msg.get("tool_call_id")) - - # Replace old messages - del self.messages[:cutoff_index] - self.conversation_summary = summary - - estimated_summary_tokens = count_tokens(summary) - estimated_tokens_saved = max( - 0, estimated_original_tokens - estimated_summary_tokens - ) - - await printr.print_async( - f"Condensed {cutoff_index} messages into summary " - f"({len(summary)} chars, ~{estimated_summary_tokens} tokens). " - f"{len(self.messages)} messages remaining. " - f"~{estimated_tokens_saved} tokens saved. " - f"Token ratio: {self._support_token_ratio:.2f}.", - color=LogType.INFO, - server_only=True, - source_name=self.name, - source=LogSource.WINGMAN, - ) - - # Record stats for the broadcast in finally - _condensation_stats = { - "messages_condensed": cutoff_index, - "messages_remaining": len(self.messages), - "summary_length": len(summary), - "estimated_tokens_saved": estimated_tokens_saved, - "summary_text": summary, - } - - except asyncio.TimeoutError: - await printr.print_async( - "Condensation timed out — local model took too long.", - color=LogType.WARNING, - server_only=True, - source_name=self.name, - source=LogSource.WINGMAN, - ) - except Exception as e: - await printr.print_async( - f"Conversation condensation error: {e}", - color=LogType.ERROR, - server_only=True, - source_name=self.name, - source=LogSource.WINGMAN, - ) - finally: - self._is_condensing = False - # Always broadcast finished so the client UI doesn't get stuck. - # Include summary_text if condensation produced one (even if a - # later step failed), so the client can show the view-history button. - if printr._connection_manager: - try: - await printr._connection_manager.broadcast( - ConversationCondensationCommand( - wingman_name=self.name, - status="finished", - **_condensation_stats, - ) - ) - except Exception as e: - await printr.print_async( - f"Failed to broadcast condensation finish: {e}", - color=LogType.WARNING, - server_only=True, - source_name=self.name, - source=LogSource.WINGMAN, - ) - - async def _chunked_support( - self, - full_text: str, - system_prompt: str, - existing_summary_section: str, - chunk_max_tokens: int, - ) -> "SupportResult": - """Process text that exceeds the model's context window by chunking. - - Each chunk is processed independently, then results are merged into - one final summary. Returns a SupportResult from the merge step. - """ - from providers.llama_cpp_provider import SupportResult - from services.skill_local_ai import SamplingPreset - - budget = self.local_ai_service.get_token_budget(system_prompt) - - # Convert token budget to approximate char limit for splitting - # (splitting needs char positions; we use ~4 chars/token as a rough guide, - # then verify with count_tokens) - approx_chunk_chars = chunk_max_tokens * 4 - chunks = [] - remaining = full_text - while remaining: - if count_tokens(remaining) <= chunk_max_tokens: - chunks.append(remaining) - break - # Try to split at a newline boundary - split_at = remaining.rfind("\n", 0, approx_chunk_chars) - if split_at <= 0: - split_at = approx_chunk_chars - chunks.append(remaining[:split_at]) - remaining = remaining[split_at:].lstrip() - - loop = asyncio.get_event_loop() - chunk_summaries = [] - for i, chunk in enumerate(chunks): - user_prompt = ( - f"{existing_summary_section if i == 0 else ''}" - f"CONVERSATION TO SUMMARIZE (part {i + 1}/{len(chunks)}):\n{chunk}\n\n" - "---\nList every fact from the above as bullet points. Include all secrets, names, preferences, and creative content:" - ) - - # Safety: if chunk input exceeds budget, truncate chunk text - chunk_text_tokens = count_tokens(chunk) - prompt_overhead = count_tokens(user_prompt) - chunk_text_tokens - if prompt_overhead + chunk_text_tokens > budget.max_input_tokens: - safe_text_tokens = budget.max_input_tokens - prompt_overhead - if safe_text_tokens > 0: - chunk = truncate_to_tokens(chunk, safe_text_tokens) - user_prompt = ( - f"{existing_summary_section if i == 0 else ''}" - f"CONVERSATION TO SUMMARIZE (part {i + 1}/{len(chunks)}):\n{chunk}\n\n" - "---\nList every fact from the above as bullet points. Include all secrets, names, preferences, and creative content:" - ) - await printr.print_async( - f"Chunk {i + 1}/{len(chunks)} exceeded context budget, truncated to fit.", - color=LogType.WARNING, - server_only=True, - source_name=self.name, - source=LogSource.WINGMAN, - ) - - result = await loop.run_in_executor( - None, - lambda p=user_prompt: self.local_ai_service.support( - text=p, - system_prompt=system_prompt, - preset=SamplingPreset.BALANCED, - ), - ) - if result.text: - chunk_summaries.append(result.text) - - # Calibrate tokenizer ratio from real model usage. - cl100k_input = count_tokens(system_prompt) + count_tokens(user_prompt) - if result.prompt_tokens > 0 and cl100k_input > 0: - self._support_token_ratio = result.prompt_tokens / cl100k_input - - if result.truncated: - await printr.print_async( - f"Chunk {i + 1}/{len(chunks)} output truncated " - f"(prompt={result.prompt_tokens}, " - f"completion={result.completion_tokens}).", - color=LogType.WARNING, - server_only=True, - source_name=self.name, - source=LogSource.WINGMAN, - ) - - if not chunk_summaries: - return SupportResult(text=None) - if len(chunk_summaries) == 1: - return SupportResult(text=chunk_summaries[0]) - - # Merge all chunk summaries into one final summary - combined = "\n\n".join( - f"Part {i + 1}:\n{s}" for i, s in enumerate(chunk_summaries) - ) - merge_prompt = ( - f"{existing_summary_section}" - f"PARTIAL SUMMARIES TO MERGE:\n{combined}\n\n" - "Merge these into a single coherent summary. Keep all key facts:" - ) - - # Safety: truncate combined summaries if they exceed budget - if count_tokens(merge_prompt) > budget.max_input_tokens: - overhead = count_tokens(merge_prompt) - count_tokens(combined) - safe_combined = budget.max_input_tokens - overhead - if safe_combined > 0: - combined = truncate_to_tokens(combined, safe_combined) - merge_prompt = ( - f"{existing_summary_section}" - f"PARTIAL SUMMARIES TO MERGE:\n{combined}\n\n" - "Merge these into a single coherent summary. Keep all key facts:" - ) - await printr.print_async( - f"Merge input exceeded context budget, truncated to fit.", - color=LogType.WARNING, - server_only=True, - source_name=self.name, - source=LogSource.WINGMAN, - ) - - return await loop.run_in_executor( - None, - lambda: self.local_ai_service.support( - text=merge_prompt, - system_prompt=system_prompt, - preset=SamplingPreset.BALANCED, - ), - ) - - def _extract_text_content(self, content) -> str: - """Extract text from message content, handling both string and multimodal list formats.""" - if isinstance(content, str): - return content - if isinstance(content, list): - # Multimodal content: extract text parts only - parts = [] - for part in content: - if isinstance(part, dict) and part.get("type") == "text": - parts.append(part.get("text", "")) - elif isinstance(part, str): - parts.append(part) - return " ".join(parts) - return "" - - def _message_text_content(self, msg) -> str: - """Extract text content from a message for token estimation.""" - if isinstance(msg, Mapping): - return self._extract_text_content(msg.get("content", "")) or "" - elif hasattr(msg, "content"): - return self._extract_text_content(msg.content) or "" - return "" - - def _messages_to_text(self, messages: list) -> str: - """Convert a list of conversation messages to plain text for summarization.""" - lines = [] - for msg in messages: - role = self.__get_message_role(msg) - content = "" - if isinstance(msg, Mapping): - content = self._extract_text_content(msg.get("content", "")) - elif hasattr(msg, "content"): - content = self._extract_text_content(msg.content) or "" - - if role == "user": - lines.append(f"User: {content}") - elif role == "assistant": - if content: - lines.append(f"Assistant: {content}") - # Include tool call info - tool_calls = None - if isinstance(msg, Mapping): - tool_calls = msg.get("tool_calls") - elif hasattr(msg, "tool_calls"): - tool_calls = msg.tool_calls - if tool_calls: - for tc in tool_calls: - fn = ( - tc.function - if hasattr(tc, "function") - else tc.get("function", {}) - ) - name = fn.name if hasattr(fn, "name") else fn.get("name", "?") - args = ( - fn.arguments - if hasattr(fn, "arguments") - else fn.get("arguments", "") - ) - lines.append(f" [Tool call: {name}({args})]") - elif role == "tool": - tool_name = ( - msg.get("name", "tool") if isinstance(msg, Mapping) else "tool" - ) - lines.append(f" [Tool result ({tool_name}): {content[:200]}]") - return "\n".join(lines) - - async def reset_conversation_history(self): - """Resets the conversation history and skill activation state. - - When the conversation is reset, the LLM loses all memory of which skills - were activated and why. So we must also reset the skill registry and MCP - registry to ensure the progressive disclosure state matches the LLM's memory. - """ - # Extract memories before clearing (if enabled and enough messages) - if self.persistent_memory_service and len(self.messages) >= 4: - try: - await self.persistent_memory_service.extract_memories( - self.messages, generate_summary=True - ) - except Exception: - pass # Don't let extraction failure prevent reset - - self.messages = [] - self.conversation_summary = "" - self._last_prompt_tokens = 0 - # Keep _support_token_ratio — it's model-specific, not conversation-specific - self.skill_registry.reset_activations() - self.mcp_registry.reset_activations() - - async def _try_instant_activation(self, transcript: str) -> (str, bool): - """Tries to execute an instant activation command if present in the transcript. - - Args: - transcript (str): The transcript to check for an instant activation command. - - Returns: - tuple[str, bool]: A tuple containing the response to the instant command and a boolean indicating whether an instant command was executed. - """ - commands = await self._execute_instant_activation_command(transcript) - if commands: - await self.add_forced_assistant_command_calls(commands) - responses = [] - for command in commands: - if command.responses: - responses.append(self._select_instant_command_response(command)) - - if len(responses) == len(commands): - # clear duplicates - responses = list(dict.fromkeys(responses)) - responses = [ - response + "." if not response.endswith(".") else response - for response in responses - ] - return " ".join(responses), True - - return None, True - - return None, False - - async def get_context(self): - """Build the context and inserts it into the messages. - - With progressive disclosure, only includes prompts from ACTIVATED skills. - Skill prompts are auto-generated from @tool descriptions if no custom prompt is set. - """ - skill_prompts = "" - active_skill_names = self.skill_registry.active_skill_names - - for skill in self.skills: - # Only include prompts from activated skills (in progressive mode) - if skill.name not in active_skill_names: - continue - - # Get custom prompt if set - prompt = await skill.get_prompt() - - # Auto-generate prompt from tool descriptions if no custom prompt - if not prompt: - tools_desc = skill.get_tools_description() - if tools_desc: - prompt = f"Available tools:\n{tools_desc}" - - if prompt: - skill_prompts += "\n\n" + skill.name + "\n\n" + prompt - - # Get TTS prompt based on active TTS provider and user preference - tts_prompt = "" - if self.config.features.tts_provider == TtsProvider.ELEVENLABS: - if ( - self.config.elevenlabs.use_tts_prompt - and self.config.elevenlabs.tts_prompt - ): - tts_prompt = self.config.elevenlabs.tts_prompt - elif self.config.features.tts_provider == TtsProvider.INWORLD or ( - self.config.features.tts_provider == TtsProvider.WINGMAN_PRO - and self.config.wingman_pro.tts_provider == WingmanProTtsProvider.INWORLD - ): - if self.config.inworld.use_tts_prompt and self.config.inworld.tts_prompt: - tts_prompt = self.config.inworld.tts_prompt - elif self.config.features.tts_provider == TtsProvider.OPENAI_COMPATIBLE: - if ( - self.config.openai_compatible_tts.use_tts_prompt - and self.config.openai_compatible_tts.tts_prompt - ): - tts_prompt = self.config.openai_compatible_tts.tts_prompt - - # Add TTS header only if there's a prompt - if tts_prompt: - tts_prompt = "# TEXT-TO-SPEECH\n" + tts_prompt - - # Build user context with environment metadata - user_context = self._build_user_context() - - # Sanity check: truncate if someone bypasses the client's 2048-token limit - MAX_BACKSTORY_TOKENS = 2048 - backstory = self.config.prompts.backstory - - if backstory and count_tokens(backstory) > MAX_BACKSTORY_TOKENS: - original_tokens = count_tokens(backstory) - backstory = truncate_to_tokens(backstory, MAX_BACKSTORY_TOKENS) - await printr.print_async( - f"[{self.name}] Backstory will be truncated to {MAX_BACKSTORY_TOKENS} tokens for conversations (is {original_tokens}). " - f"Your saved backstory is unchanged. Consider shortening it.", - color=LogType.WARNING, - source_name=self.name, - source=LogSource.SYSTEM, - ) - - # Build conversation summary section - conversation_summary = "" - if self.conversation_summary: - conversation_summary = ( - "# CONVERSATION SUMMARY\n" - "The following is a summary of earlier parts of this conversation. " - "Treat it as factual context — the user and you discussed these topics previously.\n\n" - + self.conversation_summary - ) - - # Persistent memory injection - persistent_memory_context = "" - if self.persistent_memory_service and self.messages: - # Use the most recent user message as the query - last_user_msg = "" - for msg in reversed(self.messages): - role = msg.get("role") if isinstance(msg, dict) else getattr(msg, "role", None) - raw_content = msg.get("content", "") if isinstance(msg, dict) else getattr(msg, "content", "") - # Extract plain text from multimodal content (images etc.) - content = self._extract_text_content(raw_content) if raw_content else "" - if role == "user" and content: - last_user_msg = content - break - if last_user_msg: - try: - persistent_memory_context = await self.persistent_memory_service.build_memory_context(last_user_msg) - if persistent_memory_context and not self._memory_recall_notified: - self._memory_recall_notified = True - # Count restored fact lines (lines starting with "- ") - fact_count = sum( - 1 for line in persistent_memory_context.splitlines() - if line.startswith("- ") - ) - if fact_count > 0: - await printr.print_async( - f"Memory recalled: {fact_count} relevant {'memory' if fact_count == 1 else 'memories'} loaded.", - color=LogType.MEMORY, - source_name=self.name, - ) - except Exception: - pass # Don't let memory failures break conversation - - context = self.config.prompts.system_prompt.format( - backstory=backstory, - skills=skill_prompts, - ttsprompt=tts_prompt, - user_context=user_context, - conversation_summary=conversation_summary, - ) - - # If the system prompt template doesn't include {conversation_summary}, - # append the summary at the end so it's never lost. - if ( - conversation_summary - and "{conversation_summary}" not in self.config.prompts.system_prompt - ): - context += "\n\n" + conversation_summary - - # Append persistent memory context - if persistent_memory_context: - context += "\n\n" + persistent_memory_context - - # Persistent memory tool instructions - if self.persistent_memory_service: - context += ( - "\n\n# PERSISTENT MEMORY\n" - "You have persistent memory. Important facts and past conversation summaries " - "are provided in the [Memory] sections above (if any). " - "You can use the `remember`, `recall`, and `forget` tools when the user " - "explicitly asks you to remember, recall, or forget something. " - "You don't need to use `remember` for routine information — that is handled automatically." - ) - - self._last_compiled_context = context - return context - - def get_last_context(self) -> str: - """Return the last compiled system context (cached from the most recent LLM call).""" - return self._last_compiled_context - - def get_conversation_messages(self, strip_nulls: bool = True) -> list[dict]: - """Return the conversation messages as a list of plain dicts for debugging.""" - - def _strip_none(obj): - if isinstance(obj, dict): - return {k: _strip_none(v) for k, v in obj.items() if v is not None} - if isinstance(obj, list): - return [_strip_none(item) for item in obj] - return obj - - result = [] - for msg in self.messages: - if hasattr(msg, "model_dump"): - d = msg.model_dump() - else: - d = msg - if strip_nulls: - d = _strip_none(d) - result.append(d) - return result - - def _build_user_context(self) -> str: - """Build user context metadata for the system prompt. - - Includes timezone, config context, username, and wingman name. - """ - context_parts = [] - backstory = self.config.prompts.backstory or "" - backstory_lower = backstory.lower() - - # Date and timezone information - try: - now = datetime.now().astimezone() - local_tz = now.tzinfo - tz_name = str(local_tz) - # Get UTC offset in a readable format - utc_offset = now.strftime("%z") - # Format as +HH:MM or -HH:MM - if len(utc_offset) >= 5: - utc_offset = f"{utc_offset[:3]}:{utc_offset[3:]}" - # Include current date for relative date references ("last Sunday", "tomorrow", etc.) - current_date = now.strftime( - "%A, %B %d, %Y" - ) # e.g., "Tuesday, December 09, 2025" - context_parts.append(f"- Current date: {current_date}") - context_parts.append(f"- Timezone: {tz_name} (UTC{utc_offset})") - except Exception: - context_parts.append("- Timezone: Unknown") - - # Config/context name (e.g., "Star Citizen", "Elite Dangerous") - # This helps the LLM understand which game/context tools are relevant for - if self.tower and self.tower.config_dir and self.tower.config_dir.name: - context_parts.append(f"- Active context: {self.tower.config_dir.name}") - - # Username (only if not explicitly named in backstory) - if self.settings.user_name: - # Check if username is mentioned in backstory as a standalone word - import re - - name_pattern = r"\b" + re.escape(self.settings.user_name.lower()) + r"\b" - if not re.search(name_pattern, backstory_lower): - context_parts.append( - f"- User's name (default): {self.settings.user_name}" - ) - - # Wingman name - always include as it's useful context - # The system prompt already tells LLM to prioritize backstory names - if self.name: - context_parts.append(f"- Your name (default): {self.name}") - - if context_parts: - return "\n".join(context_parts) - return "No additional context available." - - async def add_context(self, messages): - context = await self.get_context() - - messages.insert(0, {"role": "system", "content": context}) - - async def generate_image(self, text: str) -> str: - """ - Generates an image from the provided text configured provider. - """ - - if ( - self.config.features.image_generation_provider - == ImageGenerationProvider.WINGMAN_PRO - ): - try: - return await self.wingman_pro.generate_image(text) - except Exception as e: - await printr.print_async( - f"Error during image generation: {str(e)}", color=LogType.ERROR - ) - printr.print( - traceback.format_exc(), color=LogType.ERROR, server_only=True - ) - - return "" - - async def actual_llm_call(self, messages, tools: list[dict] = None): - """ - Perform the actual LLM call with the messages provided. - """ - - try: - completion = None - if self.config.features.conversation_provider == ConversationProvider.AZURE: - completion = self.openai_azure.ask( - messages=messages, - api_key=self.azure_api_keys["conversation"], - config=self.config.azure.conversation, - tools=tools, - ) - elif ( - self.config.features.conversation_provider - == ConversationProvider.OPENAI - ): - completion = self.openai.ask( - messages=messages, - tools=tools, - model=self.config.openai.conversation_model, - ) - elif ( - self.config.features.conversation_provider - == ConversationProvider.MISTRAL - ): - completion = self.mistral.ask( - messages=messages, - tools=tools, - model=self.config.mistral.conversation_model, - ) - elif ( - self.config.features.conversation_provider == ConversationProvider.GROQ - ): - completion = self.groq.ask( - messages=messages, - tools=tools, - model=self.config.groq.conversation_model, - ) - elif ( - self.config.features.conversation_provider - == ConversationProvider.CEREBRAS - ): - completion = self.cerebras.ask( - messages=messages, - tools=tools, - model=self.config.cerebras.conversation_model, - ) - elif ( - self.config.features.conversation_provider - == ConversationProvider.GOOGLE - ): - completion = self.google.ask( - messages=messages, - tools=tools, - model=self.config.google.conversation_model, - ) - elif ( - self.config.features.conversation_provider - == ConversationProvider.OPENROUTER - ): - # OpenRouter throws an error if the model doesn't support tools but we send some - if self.openrouter_model_supports_tools: - completion = self.openrouter.ask( - messages=messages, - tools=tools, - model=self.config.openrouter.conversation_model, - ) - else: - completion = self.openrouter.ask( - messages=messages, - model=self.config.openrouter.conversation_model, - ) - elif ( - self.config.features.conversation_provider - == ConversationProvider.LOCAL_LLM - ): - if not self.local_llm: - raise RuntimeError( - f"Local LLM provider is not initialized. " - f"Please check your Local LLM endpoint configuration ({self.config.local_llm.endpoint})." - ) - completion = self.local_llm.ask( - messages=messages, - tools=tools, - model=self.config.local_llm.conversation_model, - ) - elif ( - self.config.features.conversation_provider - == ConversationProvider.WINGMAN_PRO - ): - completion = self.wingman_pro.ask( - messages=messages, - deployment=self.config.wingman_pro.conversation_deployment, - tools=tools, - ) - elif ( - self.config.features.conversation_provider - == ConversationProvider.PERPLEXITY - ): - completion = self.perplexity.ask( - messages=messages, - tools=tools, - model=self.config.perplexity.conversation_model.value, - ) - elif self.config.features.conversation_provider == ConversationProvider.XAI: - completion = self.xai.ask( - messages=messages, - tools=tools, - model=self.config.xai.conversation_model, - ) - except APIConnectionError as e: - provider = self.config.features.conversation_provider.value - # Dig out the underlying cause for a useful message - cause = e.__cause__ - detail = str(cause) if cause else str(e) - message = ( - f"Could not connect to {provider}: {detail}" - ) - await printr.print_async( - message, - color=LogType.ERROR, - source=LogSource.WINGMAN, - source_name=self.name, - ) - printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - return None - except Exception as e: - await printr.print_async( - f"Error during LLM call: {str(e)}", - color=LogType.ERROR, - source=LogSource.WINGMAN, - source_name=self.name, - ) - printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - return None - - return completion - - async def _llm_call(self, allow_tool_calls: bool = True): - """Makes the primary LLM call with the conversation history and tools enabled. - - Returns: - The LLM completion object or None if the call fails. - """ - - # save request time for later comparison - thiscall = time.time() - self.last_gpt_call = thiscall - - # build tools - tools = self.build_tools() if allow_tool_calls else None - - if self.settings.debug_mode: - await printr.print_async( - f"Calling LLM with {(len(self.messages))} messages (excluding context) and {len(tools) if tools else 0} tools.", - color=LogType.INFO, - ) - - messages = self.messages.copy() - await self.add_context(messages) - - # DEBUG: Print compiled context (dev-only, remove before release) - # if messages and messages[0].get("role") == "system": - # print("\n" + "=" * 80) - # print("COMPILED CONTEXT:") - # print("=" * 80) - # print(messages[0].get("content", "")) - # print("=" * 80 + "\n") - - completion = await self.actual_llm_call(messages, tools) - - # if request isnt most recent, ignore the response - if self.last_gpt_call != thiscall: - await printr.print_async( - "LLM call was cancelled due to a new call.", color=LogType.WARNING - ) - return None - - return completion - - async def _process_completion( - self, completion: ChatCompletion, allow_tool_calls: bool = True - ): - """Processes the completion returned by the LLM call. - - Args: - completion: The completion object from an OpenAI call. - - Returns: - A tuple containing the message response, tool calls, and usage (prompt_tokens, completion_tokens) from the completion. - """ - - response_message = completion.choices[0].message - - content = response_message.content - if content is None: - response_message.content = "" - - # remove hallucinated tools, if none were allowed - if not allow_tool_calls: - response_message.tool_calls = None - - # temporary fix for tool calls that have a command name as function name - if response_message.tool_calls: - response_message.tool_calls = await self._fix_tool_calls( - response_message.tool_calls - ) - - # Extract token usage from the API response - prompt_tokens = 0 - completion_tokens = 0 - if completion.usage: - prompt_tokens = completion.usage.prompt_tokens or 0 - completion_tokens = completion.usage.completion_tokens or 0 - - return ( - response_message, - response_message.tool_calls, - (prompt_tokens, completion_tokens), - ) - - async def _handle_tool_calls(self, tool_calls): - """Processes all the tool calls identified in the response message. - - Args: - tool_calls: The list of tool calls to process. - - Returns: - tuple: (instant_response, skill, tool_timings) where tool_timings is a list of (label, time_ms) tuples. - """ - instant_response = None - function_response = "" - tool_timings: list[tuple[str, float]] = [] - - skill = None - - for tool_call in tool_calls: - try: - function_name = tool_call.function.name - function_args = ( - tool_call.function.arguments - # Mistral returns a dict - if isinstance(tool_call.function.arguments, dict) - # OpenAI returns a string - else json.loads(tool_call.function.arguments) - ) - - # Time the individual tool execution - tool_start = time.perf_counter() - ( - function_response, - instant_response, - skill, - tool_label, - ) = await self.execute_command_by_function_call( - function_name, function_args - ) - tool_time_ms = (time.perf_counter() - tool_start) * 1000 - - # Add timing if we got a label (actual tool execution, not meta-tool) - if tool_label: - tool_timings.append((tool_label, tool_time_ms)) - - # Compress large tool responses via local AI before the cloud LLM sees them - if ( - tool_call.id - and self.config.features.compress_tool_responses - and self.local_ai_service - and self.local_ai_service.is_ready() - and self._tool_response_compressor.should_compress( - str(function_response) - ) - ): - function_response = await self._tool_response_compressor.compress( - response_text=str(function_response), - local_ai_service=self.local_ai_service, - wingman_name=self.name, - tool_name=function_name, - ) - - if tool_call.id: - # updating the dummy tool response with the actual response - await self._update_tool_response(tool_call.id, function_response) - else: - # adding a new tool response - self._add_tool_response(tool_call, function_response) - except Exception as e: - self._add_tool_response(tool_call, "Error") - await printr.print_async( - f"Error while processing tool call: {str(e)}", color=LogType.ERROR - ) - printr.print( - traceback.format_exc(), color=LogType.ERROR, server_only=True - ) - return instant_response, skill, tool_timings - - async def execute_command_by_function_call( - self, function_name: str, function_args: dict[str, any] - ) -> tuple[str, str | None, Skill | None, str | None]: - """ - Uses an OpenAI function call to execute a command. If it's an instant activation_command, one if its responses will be played. - - Args: - function_name (str): The name of the function to be executed. - function_args (dict[str, any]): The arguments to pass to the function being executed. - - Returns: - A tuple containing: - - function_response (str): The text response or result obtained after executing the function. - - instant_response (str): An immediate response or action to be taken, if any (e.g., play audio). - - used_skill (Skill): The skill that was used, if any. - - tool_label (str): Label for benchmark timing (e.g., "MCP: resolve-library-id"), or None for meta-tools. - """ - function_response = "" - instant_response = "" - used_skill = None - tool_label = None - - # Handle persistent memory tools - if function_name in ("memory_remember", "memory_recall", "memory_forget") and self.persistent_memory_service: - if function_name == "memory_remember": - text = function_args.get("text", "") - if text: - await self.persistent_memory_service.add_memory( - entry_type="fact", content=text - ) - function_response = f"I'll remember that: \"{text}\"" - await printr.print_async( - f"Memory stored: {text}", - color=LogType.MEMORY, - source_name=self.name, - ) - else: - function_response = "Nothing to remember — no text provided." - - elif function_name == "memory_recall": - query = function_args.get("query", "") - if query: - results = await self.persistent_memory_service.search(query, limit=10) - if results: - lines = [f"- {r.content}" for r in results] - function_response = "Here's what I remember:\n" + "\n".join(lines) - else: - function_response = "I don't have any memories matching that." - else: - function_response = "No query provided for memory recall." - - elif function_name == "memory_forget": - query = function_args.get("query", "") - if query: - deleted = await self.persistent_memory_service.forget_by_query(query) - if deleted: - function_response = f"Done — I've forgotten the memory related to \"{query}\"." - else: - function_response = "I couldn't find a memory closely matching that to forget." - else: - function_response = "No query provided for memory forget." - - return function_response, None, None, f"💾 memory: {function_name}" - - # Handle unified capability meta-tools (activate_capability, list_active_capabilities) - if self.capability_registry.is_meta_tool(function_name): - function_response, tools_changed = ( - await self.capability_registry.execute_meta_tool( - function_name, function_args - ) - ) - - # If a skill was activated, perform lazy validation - if tools_changed and function_name == "activate_capability": - capability_name = function_args.get("capability_name", "") - skill = self.skill_registry.get_skill_for_activation(capability_name) - if skill and skill.needs_activation(): - success, validation_msg = await skill.ensure_activated() - if not success: - # Validation failed - deactivate the skill - self.skill_registry.deactivate_skill(capability_name) - function_response = validation_msg - tools_changed = False - await printr.print_async( - f"Skill activation failed: {capability_name}", - color=LogType.ERROR, - ) - else: - # Get display name for user-friendly message - display_name = self.skill_registry.get_skill_display_name( - capability_name - ) - await printr.print_async( - f"Skill activated: {display_name}", - color=LogType.SKILL, - ) - - return function_response, None, None, None # Meta-tool, no timing label - - # Handle legacy meta-tools for progressive skill discovery/activation - # These are kept for backward compatibility but shouldn't be called - if self.skill_registry.is_meta_tool(function_name): - function_response, tools_changed = ( - await self.skill_registry.execute_meta_tool( - function_name, function_args - ) - ) - - # If skill was activated, perform lazy validation - if tools_changed and function_name == "activate_skill": - skill_name = function_args.get("skill_name", "") - skill = self.skill_registry.get_skill_for_activation(skill_name) - if skill and skill.needs_activation(): - success, validation_msg = await skill.ensure_activated() - if not success: - # Validation failed - deactivate the skill - self.skill_registry.deactivate_skill(skill_name) - function_response = validation_msg - tools_changed = False - await printr.print_async( - f"Skill activation failed: {skill_name}", - color=LogType.ERROR, - ) - else: - # Get display name for user-friendly message - display_name = self.skill_registry.get_skill_display_name( - skill_name - ) - await printr.print_async( - f"Skill activated: {display_name}", - color=LogType.SKILL, - ) - - return function_response, None, None, None # Meta-tool, no timing label - - # Handle MCP meta-tools for server discovery/activation - if self.mcp_registry.is_meta_tool(function_name): - function_response, tools_changed = ( - await self.mcp_registry.execute_meta_tool(function_name, function_args) - ) - return function_response, None, None, None # Meta-tool, no timing label - - # Handle MCP server tools (prefixed with mcp_) - if self.mcp_registry.is_mcp_tool(function_name): - connection = self.mcp_registry.get_connection_for_tool(function_name) - if connection: - display_name = connection.config.display_name - original_name = self.mcp_registry.get_original_tool_name(function_name) - tool_label = f"🌐 {display_name}: {original_name}" - - benchmark = Benchmark( - f"MCP '{connection.config.name}' - {original_name}" - ) - - # Always show simple 'called' message in UI so users know the wingman is working - await printr.print_async( - f"{display_name}: called `{original_name}` with {function_args}", - color=LogType.MCP, - ) - - # Detailed 'calling' log only in terminal/log file - await printr.print_async( - f"{display_name}: calling `{original_name}` with {function_args}...", - color=LogType.MCP, - server_only=True, - ) - - try: - function_response = await self.mcp_registry.call_tool( - function_name, function_args - ) - except Exception as e: - await printr.print_async( - f"{display_name}: `{original_name}` failed - {str(e)}", - color=LogType.ERROR, - ) - printr.print( - traceback.format_exc(), color=LogType.ERROR, server_only=True - ) - function_response = "ERROR DURING MCP TOOL EXECUTION" - finally: - # Detailed 'completed' with timing only in terminal/log file (or UI if debug) - await printr.print_async( - f"{display_name}: `{original_name}` completed", - color=LogType.MCP, - benchmark_result=benchmark.finish(), - server_only=not self.settings.debug_mode, - ) - - return function_response, None, None, tool_label - - # Handle command calls - if function_name == "execute_command": - # get the command based on the argument passed by the LLM - command = self.get_command(function_args["command_name"]) - # execute the command - instant_response, function_response = await self._execute_command(command) - tool_label = f"Command: {function_args.get('command_name', function_name)}" - # if the command has responses, we have to play one of them - if instant_response: - await self.play_to_user(instant_response) - - # Go through the skills and check if the function name matches any of the tools - if function_name in self.tool_skills: - skill = self.tool_skills[function_name] - display_name = self.skill_registry.get_skill_display_name(skill.name) - tool_label = f"⚡ {display_name}: {function_name}" - - benchmark = Benchmark(f"Skill '{skill.name}' - {function_name}") - - # Always show simple 'called' message in UI so users know the wingman is working - await printr.print_async( - f"{display_name}: called `{function_name}` with {function_args}", - color=LogType.SKILL, - skill_name=skill.name, - ) - - # Detailed 'calling' log only in terminal/log file - await printr.print_async( - f"{display_name}: calling `{function_name}` with {function_args}...", - color=LogType.SKILL, - skill_name=skill.name, - server_only=True, - ) - - try: - function_response, instant_response = await skill.execute_tool( - tool_name=function_name, - parameters=function_args, - benchmark=benchmark, - ) - used_skill = skill - if instant_response: - await self.play_to_user(instant_response) - except Exception as e: - await printr.print_async( - f"{display_name}: `{function_name}` failed - {str(e)}", - color=LogType.ERROR, - ) - printr.print( - traceback.format_exc(), color=LogType.ERROR, server_only=True - ) - function_response = ( - "ERROR DURING PROCESSING" # hints to AI that there was an error - ) - instant_response = None - finally: - await printr.print_async( - f"{display_name}: `{function_name}` completed", - color=LogType.SKILL, - benchmark_result=benchmark.finish(), - skill_name=skill.name, - server_only=not self.settings.debug_mode, - ) - - return function_response, instant_response, used_skill, tool_label - - async def play_to_user( - self, - text: str, - no_interrupt: bool = False, - sound_config: Optional[SoundConfig] = None, - ): - """Plays audio to the user using the configured TTS Provider (default: OpenAI TTS). - Also adds sound effects if enabled in the configuration. - - Args: - text (str): The text to play as audio. - """ - if sound_config: - printr.print( - "Using custom sound config for playback", LogType.INFO, server_only=True - ) - else: - sound_config = self.config.sound - - # remove Markdown, links, emotes and code blocks - text, contains_links, contains_code_blocks = cleanup_text(text) - - # wait for audio player to finish playing - if no_interrupt and self.audio_player.is_playing: - while self.audio_player.is_playing: - await asyncio.sleep(0.1) - - # call skill hooks (only for prepared/activated skills) - changed_text = text - for skill in self.skills: - if skill.is_prepared: - changed_text = await skill.on_play_to_user(text, sound_config) - if changed_text != text: - printr.print( - f"Skill '{skill.config.display_name}' modified the text to: '{changed_text}'", - LogType.INFO, - ) - text = changed_text - - if sound_config.volume == 0.0: - printr.print( - "Volume modifier is set to 0. Skipping TTS processing.", - LogType.WARNING, - server_only=True, - ) - return - - if "{SKIP-TTS}" in text: - printr.print( - "Skip TTS phrase found in input. Skipping TTS processing.", - LogType.WARNING, - server_only=True, - ) - return - - try: - if self.config.features.tts_provider == TtsProvider.EDGE_TTS: - await self.edge_tts.play_audio( - text=text, - config=self.config.edge_tts, - sound_config=sound_config, - audio_player=self.audio_player, - wingman_name=self.name, - ) - elif self.config.features.tts_provider == TtsProvider.ELEVENLABS: - await self.elevenlabs.play_audio( - text=text, - config=self.config.elevenlabs, - sound_config=sound_config, - audio_player=self.audio_player, - wingman_name=self.name, - stream=self.config.elevenlabs.output_streaming, - ) - elif self.config.features.tts_provider == TtsProvider.HUME: - try: - await self.hume.play_audio( - text=text, - config=self.config.hume, - sound_config=sound_config, - audio_player=self.audio_player, - wingman_name=self.name, - ) - except RuntimeError as e: - if "Event loop is closed" in str(e): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - await self.hume.play_audio( - text=text, - config=self.config.hume, - sound_config=sound_config, - audio_player=self.audio_player, - wingman_name=self.name, - ) - elif self.config.features.tts_provider == TtsProvider.INWORLD: - await self.inworld.play_audio( - text=text, - config=self.config.inworld, - sound_config=sound_config, - audio_player=self.audio_player, - wingman_name=self.name, - ) - elif self.config.features.tts_provider == TtsProvider.AZURE: - await self.openai_azure.play_audio( - text=text, - api_key=self.azure_api_keys["tts"], - config=self.config.azure.tts, - sound_config=sound_config, - audio_player=self.audio_player, - wingman_name=self.name, - ) - elif self.config.features.tts_provider == TtsProvider.XVASYNTH: - await self.xvasynth.play_audio( - text=text, - config=self.config.xvasynth, - sound_config=sound_config, - audio_player=self.audio_player, - wingman_name=self.name, - ) - elif self.config.features.tts_provider == TtsProvider.OPENAI: - await self.openai.play_audio( - text=text, - voice=self.config.openai.tts_voice, - model=self.config.openai.tts_model, - speed=self.config.openai.tts_speed, - sound_config=sound_config, - audio_player=self.audio_player, - wingman_name=self.name, - stream=self.config.openai.output_streaming, - ) - elif self.config.features.tts_provider == TtsProvider.OPENAI_COMPATIBLE: - await self.openai_compatible_tts.play_audio( - text=text, - voice=self.config.openai_compatible_tts.voice, - model=self.config.openai_compatible_tts.model, - speed=( - self.config.openai_compatible_tts.speed - if self.config.openai_compatible_tts.speed #!= 1.0 - else NOT_GIVEN - ), - sound_config=sound_config, - audio_player=self.audio_player, - wingman_name=self.name, - stream=self.config.openai_compatible_tts.output_streaming, - ) - elif self.config.features.tts_provider == TtsProvider.POCKET_TTS: - await self.pocket_tts.play_audio( - text=text, - config=self.config.pocket_tts, - sound_config=sound_config, - audio_player=self.audio_player, - wingman_name=self.name, - ) - elif self.config.features.tts_provider == TtsProvider.WINGMAN_PRO: - if self.config.wingman_pro.tts_provider == WingmanProTtsProvider.OPENAI: - await self.wingman_pro.generate_openai_speech( - text=text, - voice=self.config.openai.tts_voice, - model=self.config.openai.tts_model, - speed=self.config.openai.tts_speed, - sound_config=sound_config, - audio_player=self.audio_player, - wingman_name=self.name, - ) - elif ( - self.config.wingman_pro.tts_provider == WingmanProTtsProvider.AZURE - ): - await self.wingman_pro.generate_azure_speech( - text=text, - config=self.config.azure.tts, - sound_config=sound_config, - audio_player=self.audio_player, - wingman_name=self.name, - ) - elif ( - self.config.wingman_pro.tts_provider - == WingmanProTtsProvider.INWORLD - ): - await self.wingman_pro.generate_inworld_speech( - text=text, - config=self.config.inworld, - sound_config=sound_config, - audio_player=self.audio_player, - wingman_name=self.name, - ) - else: - printr.toast_error( - f"Unsupported TTS provider: {self.config.features.tts_provider}" - ) - except Exception as e: - await printr.print_async( - f"Error during TTS playback: {str(e)}", color=LogType.ERROR - ) - printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - - async def _execute_command( - self, command: CommandConfig, is_instant=False - ) -> tuple[str | None, str]: - """Executes a command by delegating to the Wingman base implementation. - - Returns: - tuple[str | None, str]: A 2-tuple of: - - Instant response (str) to play immediately, or None if there is no instant response. - - Function/tool response (str) to feed back to the LLM. - """ - return await super()._execute_command(command, is_instant) - - def build_tools(self) -> list[dict]: - """ - Builds tools for the LLM call. - - In progressive mode: Returns meta-tools (search_skills, activate_skill) plus - tools from activated skills only. - - In legacy mode: Returns all skill tools. - - Returns: - list[dict]: A list of tool descriptors in OpenAI format. - """ - - def _command_has_effective_actions(command: CommandConfig) -> bool: - if command.is_system_command: - return True - - if not command.actions: - return False - - for action in command.actions: - if not action: - continue - if ( - action.keyboard is not None - or action.mouse is not None - or action.joystick is not None - or action.audio is not None - or action.write is not None - or action.wait is not None - ): - return True - - return False - - commands = [ - command.name - for command in self.config.commands - if (not command.force_instant_activation) - and _command_has_effective_actions(command) - ] - tools: list[dict] = [] - if commands: - tools.append( - { - "type": "function", - "function": { - "name": "execute_command", - "description": "Executes a command", - "parameters": { - "type": "object", - "properties": { - "command_name": { - "type": "string", - "description": "The name of the command to execute", - "enum": commands, - }, - }, - "required": ["command_name"], - }, - }, - } - ) - - # Unified capability discovery: single activate_capability meta-tool - # Combines skills and MCP servers - LLM doesn't need to know the difference - for _, tool in self.capability_registry.get_meta_tools(): - tools.append(tool) - - # Add tools from activated capabilities (both skills and MCPs) - for _, tool in self.skill_registry.get_active_tools(): - tools.append(tool) - - for _, tool in self.mcp_registry.get_active_tools(): - tools.append(tool) - - # Persistent memory tools — auto-injected when enabled - if self.persistent_memory_service: - tools.append({ - "type": "function", - "function": { - "name": "memory_remember", - "description": "Store an important fact or detail for future reference. Use when the user explicitly asks you to remember something.", - "parameters": { - "type": "object", - "properties": { - "text": { - "type": "string", - "description": "The fact or detail to remember.", - }, - }, - "required": ["text"], - }, - }, - }) - tools.append({ - "type": "function", - "function": { - "name": "memory_recall", - "description": "Search your memory for relevant information. Use when the user asks what you remember or know about a topic.", - "parameters": { - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "What to search for in memory.", - }, - }, - "required": ["query"], - }, - }, - }) - tools.append({ - "type": "function", - "function": { - "name": "memory_forget", - "description": "Remove a specific memory. Use when the user explicitly asks you to forget something.", - "parameters": { - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "Description of the memory to forget.", - }, - }, - "required": ["query"], - }, - }, - }) - - return tools - - def __get_message_role(self, message): - """Helper method to get the role of the message regardless of its type.""" - if isinstance(message, Mapping): - return message.get("role") - elif hasattr(message, "role"): - return message.role - else: - raise TypeError( - f"Message is neither a mapping nor has a 'role' attribute: {message}" - ) +__all__ = ["OpenAiWingman"] diff --git a/wingmen/wingman.py b/wingmen/wingman.py index 82b96d8a5..ee9c41d94 100644 --- a/wingmen/wingman.py +++ b/wingmen/wingman.py @@ -1,18 +1,23 @@ -import traceback -from copy import deepcopy -import random +"""Unified Wingman class. + +Merges the former base ``Wingman`` and its only subclass ``OpenAiWingman`` +into a single class that delegates to extracted services and provider +interfaces for STT, TTS, and LLM. +""" + import time -import difflib import asyncio +import traceback import threading +from copy import deepcopy from typing import ( Any, Dict, Optional, TYPE_CHECKING, ) -import keyboard.keyboard as keyboard -import mouse.mouse as mouse +from openai import APIConnectionError +from openai.types.chat import ChatCompletion from api.interface import ( CommandConfig, SettingsConfig, @@ -23,21 +28,36 @@ ) from api.enums import ( CommandTag, + ConversationProvider, + ImageGenerationProvider, LogSource, LogType, + SttProvider, + TtsProvider, WingmanInitializationErrorType, ) -from providers.faster_whisper import FasterWhisper -from providers.parakeet import Parakeet -from providers.whispercpp import Whispercpp -from providers.xvasynth import XVASynth -from providers.pocket_tts import PocketTTS +from providers.interfaces import LlmInterface, SttInterface, TtsInterface from services.audio_player import AudioPlayer from services.benchmark import Benchmark -from services.module_manager import ModuleManager +from services.markdown import cleanup_text from services.secret_keeper import SecretKeeper from services.printr import Printr from services.audio_library import AudioLibrary +from services.conversation_manager import ConversationManager +from services.conversation_condenser import ConversationCondenser +from services.context_builder import ContextBuilder +from services.command_executor import CommandExecutor +from services.tool_executor import ToolExecutor +from services.provider_factory import ProviderFactory +from services.skill_registry import SkillRegistry +from services.threading_utils import threaded_execution +from services.mcp_client import McpClient +from services.capability_registry import CapabilityRegistry +from services.wingman_mcp_manager import WingmanMcpManager +from services.wingman_skill_manager import WingmanSkillManager, _get_skill_folder_from_module +from services.tool_response_cache import ToolResponseCompressor +from services.turn_metrics import TurnMetrics +from services.instant_response_generator import InstantResponseGenerator from skills.skill_base import Skill if TYPE_CHECKING: @@ -46,17 +66,24 @@ printr = Printr() -def _get_skill_folder_from_module(module: str) -> str: - """Extract folder name from module path like 'skills.star_head.main' -> 'star_head'""" - return module.replace(".main", "").replace(".", "/").split("/")[1] - - class Wingman: - """The "highest" Wingman base class in the chain. It does some very basic things but is meant to be 'virtual', and so are most its methods, so you'll probably never instantiate it directly. + """Unified Wingman class. + + Handles lifecycle, process loop, audio, command execution, config + save/load, skill management, provider routing, conversation management, + tool execution, context building, and condensation. - Instead, you'll create a custom wingman that inherits from this (or a another subclass of it) and override its methods if needed. + Providers are resolved via :class:`ProviderFactory` into three + interface slots: ``stt``, ``tts``, ``llm``. Heavy orchestration logic + is delegated to extracted service objects. """ + AZURE_SERVICES = { + "tts": None, # kept for potential future use + "whisper": None, + "conversation": None, + } + def __init__( self, name: str, @@ -64,81 +91,140 @@ def __init__( settings: SettingsConfig, audio_player: AudioPlayer, audio_library: AudioLibrary, - whispercpp: Whispercpp, - fasterwhisper: FasterWhisper, - parakeet: Parakeet, - xvasynth: XVASynth, - pocket_tts: PocketTTS, - tower: "Tower", + whispercpp=None, + fasterwhisper=None, + parakeet=None, + xvasynth=None, + pocket_tts=None, + tower: "Tower" = None, ): - """The constructor of the Wingman class. You can override it in your custom wingman. - - Args: - name (str): The name of the wingman. This is the key you gave it in the config, e.g. "atc" - config (WingmanConfig): All "general" config entries merged with the specific Wingman config settings. The Wingman takes precedence and overrides the general config. You can just add new keys to the config and they will be available here. - """ - self.config = config - """All "general" config entries merged with the specific Wingman config settings. The Wingman takes precedence and overrides the general config. You can just add new keys to the config and they will be available here.""" - self.settings = settings - """The general user settings.""" + self.name = name + self.audio_player = audio_player + self.audio_library = audio_library + self.tower = tower self.secret_keeper = SecretKeeper() - """A service that allows you to store and retrieve secrets like API keys. It can prompt the user for secrets if necessary.""" self.secret_keeper.secret_events.subscribe( "secrets_saved", self.handle_secret_saved ) - self.name = name - """The name of the wingman. This is the key you gave it in the config, e.g. "atc".""" - - self.audio_player = audio_player - """A service that allows you to play audio files and add sound effects to them.""" + # Shared provider singletons (passed from Tower) + self._shared_providers = { + "whispercpp": whispercpp, + "fasterwhisper": fasterwhisper, + "parakeet": parakeet, + "xvasynth": xvasynth, + "pocket_tts": pocket_tts, + } + + # --- Provider interface slots (populated by validate → ProviderFactory) --- + self.stt: SttInterface | None = None + self.tts: TtsInterface | None = None + self.llm: LlmInterface | None = None + + # --- Extracted services --- + self.conversation = ConversationManager(config, settings, name) + self.condenser = ConversationCondenser(self.conversation, config, name) + self.context_builder = ContextBuilder(config, settings, name) + self.tool_executor = ToolExecutor(config, settings, name) + self.command_executor = CommandExecutor( + config=config, + audio_library=audio_library, + wingman_name=name, + on_reset_history=self.reset_conversation_history, + on_add_forced_commands=self.conversation.add_forced_assistant_command_calls, + ) - self.audio_library = audio_library - """A service that allows you to play and manage audio files from the audio library.""" + # --- Metrics service --- + self.metrics = TurnMetrics( + wingman_name=name, + config=config, + conversation=self.conversation, + ) self.execution_start: None | float = None - """Used for benchmarking executon times. The timer is (re-)started whenever the process function starts.""" - self.whispercpp = whispercpp - """A class that handles the communication with the Whispercpp server for transcription.""" + # --- Skills --- + self.skill_registry = SkillRegistry() + self.skill_manager = WingmanSkillManager( + wingman=self, + config=config, + settings=settings, + skill_registry=self.skill_registry, + ) - self.fasterwhisper = fasterwhisper - """A class that handles local transcriptions using FasterWhisper.""" + # --- MCP --- + self.mcp_client = McpClient(wingman_name=self.name) + self.mcp_manager = WingmanMcpManager( + wingman_name=self.name, + mcp_client=self.mcp_client, + secret_keeper=self.secret_keeper, + get_mcp_config=lambda: self.tower.config_manager.mcp_config if self.tower else None, + settings=self.settings, + config=self.config, + ) - self.parakeet = parakeet - """A class that handles local transcriptions using NVIDIA Parakeet TDT via ONNX Runtime.""" + # --- Unified capability registry --- + self.capability_registry = CapabilityRegistry( + self.skill_registry, self.mcp_registry + ) - self.xvasynth = xvasynth - """A class that handles the communication with the XVASynth server for TTS.""" + # --- Local AI / persistent memory --- + self.local_ai_service = None + self.persistent_memory_service = None + self._memory_recall_notified = False + self._background_tasks: set[asyncio.Task] = set() + self._tool_response_compressor = ToolResponseCompressor() + + # --- Image generation (lazy) --- + self._image_subscription = None + + # --- Instant response generator --- + self.instant_response_generator = InstantResponseGenerator( + wingman_name=name, + llm_call_fn=self.actual_llm_call, + get_context_fn=self.get_context, + ) - self.pocket_tts = pocket_tts - """A class that handles the communication with the PocketTTS server for TTS.""" + # --- Conversation state --- + self.last_gpt_call = None - self.tower = tower - """The Tower instance that manages all Wingmen in the same config dir.""" + # ──────────────────────────────── Backward-compat properties ──────────────── # - self.last_turn_prompt_tokens: int = 0 - self.last_turn_completion_tokens: int = 0 + @property + def mcp_registry(self): + """Backward-compat: many callers access wingman.mcp_registry directly.""" + return self.mcp_manager.mcp_registry - self.skills: list[Skill] = [] + @property + def skills(self): + return self.skill_manager.skills + + @property + def tool_skills(self): + return self.skill_manager.tool_skills + + @property + def skill_tools(self): + return self.skill_manager.skill_tools + + # ──────────────────────────────── Record keys ─────────────────────────────── # def get_record_key(self) -> str | int: - """Returns the activation or "push-to-talk" key for this Wingman.""" return self.config.record_key_codes or self.config.record_key def get_record_mouse_button(self) -> str: - """Returns the activation or "push-to-talk" mouse button for this Wingman.""" return self.config.record_mouse_button def get_record_joystick_button(self) -> str: - """Returns the activation or "push-to-talk" joystick button for this Wingman.""" if not self.config.record_joystick_button: return None return f"{self.config.record_joystick_button.guid}{self.config.record_joystick_button.button}" + # ──────────────────────────────── Secrets ──────────────────────────────────── # + async def handle_secret_saved(self, _secrets: Dict[str, Any]): await printr.print_async( text="Secret saved", @@ -147,26 +233,7 @@ async def handle_secret_saved(self, _secrets: Dict[str, Any]): ) await self.validate() - # ──────────────────────────────────── Hooks ─────────────────────────────────── # - - async def validate(self) -> list[WingmanInitializationError]: - """Use this function to validate params and config before the Wingman is started. - If you add new config sections or entries to your custom wingman, you should validate them here. - - It's a good idea to collect all errors from the base class and not to swallow them first. - - If you return MISSING_SECRET errors, the user will be asked for them. - If you return other errors, your Wingman will not be loaded by Tower. - - Returns: - list[WingmanInitializationError]: A list of errors or an empty list if everything is okay. - """ - return [] - async def retrieve_secret(self, secret_name, errors, is_required=True): - """Use this method to retrieve secrets like API keys from the SecretKeeper. - If the key is missing, the user will be prompted to enter it. - """ try: api_key = await self.secret_keeper.retrieve( requester=self.name, @@ -184,7 +251,7 @@ async def retrieve_secret(self, secret_name, errors, is_required=True): ) except Exception as e: printr.print( - f"Error retrieving secret ''{secret_name}: {e}", + f"Error retrieving secret '{secret_name}': {e}", color=LogType.ERROR, server_only=True, ) @@ -201,340 +268,137 @@ async def retrieve_secret(self, secret_name, errors, is_required=True): return api_key - async def prepare(self): - """This method is called only once when the Wingman is instantiated by Tower. - It is run AFTER validate() and AFTER init_skills() so you can access validated params safely here. + # ──────────────────────────────── Validate ─────────────────────────────────── # - You can override it if you need to load async data from an API or file.""" - - async def unload(self): - """This method is called when the Wingman is unloaded by Tower. You can override it if you need to clean up resources.""" - # Unsubscribe from secret events to prevent duplicate handlers - self.secret_keeper.secret_events.unsubscribe( - "secrets_saved", self.handle_secret_saved - ) - await self.unload_skills() + async def validate(self) -> list[WingmanInitializationError]: + errors: list[WingmanInitializationError] = [] - async def unload_skills(self): - """Call this to trigger unload for skills that were actually prepared/used.""" - for skill in self.skills: - # Only unload skills that were actually prepared (activated) - # Skills that were never used don't need cleanup - if not skill.is_prepared: - continue - try: - await skill.unload() - except Exception as e: - await printr.print_async( - f"Error unloading skill '{skill.name}': {str(e)}", - color=LogType.ERROR, - ) - printr.print( - traceback.format_exc(), color=LogType.ERROR, server_only=True + try: + factory = ProviderFactory( + config=self.config, + settings=self.settings, + secret_keeper=self.secret_keeper, + shared_providers=self._shared_providers, + wingman_name=self.name, + ) + self.stt = await factory.create_stt(errors) + self.tts = await factory.create_tts(errors) + self.llm = await factory.create_llm(errors) + except Exception as e: + errors.append( + WingmanInitializationError( + wingman_name=self.name, + message=f"Error during provider validation: {str(e)}", + error_type=WingmanInitializationErrorType.UNKNOWN, ) + ) + printr.print( + f"Error during provider validation: {str(e)}", + color=LogType.ERROR, + server_only=True, + ) + printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - async def init_skills(self) -> list[WingmanInitializationError]: - """Load all available skills with lazy validation. - - Skills are loaded but NOT validated during init. Validation happens - on first activation via the SkillRegistry. User config overrides from - self.config.skills are merged with default configs. - - Platform-incompatible skills are skipped entirely. - """ - import sys - - current_platform = sys.platform # 'win32', 'darwin', 'linux' - platform_map = {"win32": "windows", "darwin": "darwin", "linux": "linux"} - normalized_platform = platform_map.get(current_platform, current_platform) - - if self.skills: - await self.unload_skills() - - errors = [] - self.skills = [] - - # Build a lookup of user config overrides by skill folder name - # The key must be the folder name (e.g., 'star_head') not the class name (e.g., 'StarHead') - user_skill_configs: dict[str, "SkillConfig"] = {} - if self.config.skills: - for skill_config in self.config.skills: - folder_name = _get_skill_folder_from_module(skill_config.module) - user_skill_configs[folder_name] = skill_config - - # Get all available skill configs - available_skills = ModuleManager.read_available_skill_configs() + return errors - # Get discoverable skills list (whitelist) - discoverable_skills = self.config.discoverable_skills + # ──────────────────────────────── Lifecycle ─────────────────────────────────── # - for ( - skill_folder_name, - skill_config_path, - _is_custom, - _is_local, - ) in available_skills: - try: - # Load default skill config first to get the display name - skill_config_dict = ModuleManager.read_config(skill_config_path) - if not skill_config_dict: - continue - - # Import SkillConfig here to avoid circular imports - from api.interface import SkillConfig - - # Check if user has overrides for this skill - if skill_folder_name in user_skill_configs: - # Merge user overrides into default config - user_config = user_skill_configs[skill_folder_name] - # User config takes precedence - merge custom_properties especially - if user_config.custom_properties: - skill_config_dict["custom_properties"] = [ - prop.model_dump() for prop in user_config.custom_properties - ] - if user_config.prompt: - skill_config_dict["prompt"] = user_config.prompt - - skill_config = SkillConfig(**skill_config_dict) - - # Check if skill is discoverable for this wingman (whitelist - must be in list) - if skill_config.name not in discoverable_skills: - continue - - # Check platform compatibility BEFORE loading the module - if skill_config.platforms: - if normalized_platform not in skill_config.platforms: - printr.print( - f"Skipping skill '{skill_config.name}' - not supported on {normalized_platform}", - color=LogType.WARNING, - server_only=True, - ) - continue - - # Load the skill module - skill = ModuleManager.load_skill( - config=skill_config, - settings=self.settings, - wingman=self, - ) - if skill: - # Set up skill methods - skill.threaded_execution = self.threaded_execution - - # Add to skills list WITHOUT validation - # Validation will happen lazily on first activation - self.skills.append(skill) - await self.prepare_skill(skill) - - except Exception as e: - skill_name = skill_folder_name - error_msg = f"Error loading skill '{skill_name}': {str(e)}" - await printr.print_async( - error_msg, - color=LogType.ERROR, - ) + async def prepare(self): + try: + if self.config.features.use_generic_instant_responses: printr.print( - traceback.format_exc(), color=LogType.ERROR, server_only=True - ) - errors.append( - WingmanInitializationError( - wingman_name=self.name, - message=error_msg, - error_type=WingmanInitializationErrorType.SKILL_INITIALIZATION_FAILED, - ) + "Generating AI instant responses...", + color=LogType.WARNING, + server_only=True, ) - - # Log summary of discoverable skills for this wingman - if self.skills: - skill_names = [s.config.name for s in self.skills] + self.threaded_execution(self.instant_response_generator.generate) + except Exception as e: await printr.print_async( - f"Discoverable skills ({len(skill_names)}): {', '.join(skill_names)}", - color=LogType.WINGMAN, - source=LogSource.WINGMAN, - source_name=self.name, - server_only=not self.settings.debug_mode, + f"Error while preparing wingman '{self.name}': {str(e)}", + color=LogType.ERROR, ) + printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - return errors - - async def prepare_skill(self, skill: Skill): - """This method is called only once when the Skill is instantiated. - It is run AFTER validate() so you can access validated params safely here. - - You can override it if you need to react on data of this skill.""" + async def unload(self): + # Wait for any background memory extraction tasks to finish + if self._background_tasks: + await asyncio.gather(*self._background_tasks, return_exceptions=True) + self._background_tasks.clear() + + if self.persistent_memory_service: + from services.persistent_memory import MIN_MESSAGES_FOR_EXTRACTION + + if len(self.conversation.messages) >= MIN_MESSAGES_FOR_EXTRACTION: + try: + await self.persistent_memory_service.extract_memories( + self.conversation.messages, generate_summary=True + ) + except Exception: + pass + self.persistent_memory_service.close() - async def unprepare_skill(self, skill: Skill): - """Remove a skill's registration. Called when a skill is disabled. + # Unsubscribe from secret events to prevent duplicate handlers + self.secret_keeper.secret_events.unsubscribe( + "secrets_saved", self.handle_secret_saved + ) + await self.unload_skills() - Override in subclass to clean up skill-specific registrations.""" - pass + async def unload_skills(self): + return await self.skill_manager.unload_skills() - async def enable_skill(self, skill_name: str) -> tuple[bool, str]: - """Enable a single skill without reinitializing all skills. + async def unload_mcps(self): + return await self.mcp_manager.unload_mcps() - Args: - skill_name: The display name of the skill to enable + # ──────────────────────────────── Memory ──────────────────────────────────── # - Returns: - (success, message) tuple - """ - import sys + def ensure_memory_initialized(self) -> bool: + if self.persistent_memory_service and not self.config.persistent_memory: + self.persistent_memory_service.close() + self.persistent_memory_service = None + return False + if self.persistent_memory_service: + return True + if self.config.persistent_memory and self.local_ai_service: + from services.persistent_memory import PersistentMemoryService - current_platform = sys.platform - platform_map = {"win32": "windows", "darwin": "darwin", "linux": "linux"} - normalized_platform = platform_map.get(current_platform, current_platform) + self.persistent_memory_service = PersistentMemoryService( + wingman_name=self.name, + local_ai_service=self.local_ai_service, + ) + self.persistent_memory_service.initialize() + return True + return False - # Check if skill is already enabled - for existing_skill in self.skills: - if existing_skill.config.name == skill_name: - return True, f"Skill '{skill_name}' is already enabled." + # ──────────────────────────────── MCP (forwarding) ─────────────────────────── # - # Find the skill config - available_skills = ModuleManager.read_available_skill_configs() + async def enable_mcp(self, mcp_name: str) -> tuple[bool, str]: + return await self.mcp_manager.enable_mcp(mcp_name) - # Build user config lookup by skill folder name - user_skill_configs: dict[str, "SkillConfig"] = {} - if self.config.skills: - for skill_config in self.config.skills: - folder_name = _get_skill_folder_from_module(skill_config.module) - user_skill_configs[folder_name] = skill_config - - for ( - skill_folder_name, - skill_config_path, - _is_custom, - _is_local, - ) in available_skills: - try: - skill_config_dict = ModuleManager.read_config(skill_config_path) - if not skill_config_dict: - continue - - from api.interface import SkillConfig - - # Apply user overrides - if skill_folder_name in user_skill_configs: - user_config = user_skill_configs[skill_folder_name] - if user_config.custom_properties: - skill_config_dict["custom_properties"] = [ - prop.model_dump() for prop in user_config.custom_properties - ] - if user_config.prompt: - skill_config_dict["prompt"] = user_config.prompt - - skill_config = SkillConfig(**skill_config_dict) - - if skill_config.name != skill_name: - continue - - # Check platform compatibility - if skill_config.platforms: - if normalized_platform not in skill_config.platforms: - return ( - False, - f"Skill '{skill_name}' is not supported on {normalized_platform}.", - ) + async def disable_mcp(self, mcp_name: str) -> tuple[bool, str]: + return await self.mcp_manager.disable_mcp(mcp_name) - # Load and register the skill - skill = ModuleManager.load_skill( - config=skill_config, - settings=self.settings, - wingman=self, - ) - if skill: - skill.threaded_execution = self.threaded_execution - self.skills.append(skill) - await self.prepare_skill(skill) + async def init_mcps(self) -> list[WingmanInitializationError]: + return await self.mcp_manager.init_mcps() - printr.print( - f"Skill '{skill_name}' activated (loaded and made discoverable).", - color=LogType.POSITIVE, - server_only=True, - ) - return True, f"Skill '{skill_name}' activated successfully." + # ──────────────────────────────── Skills (forwarding) ─────────────────────── # - except Exception as e: - error_msg = f"Error activating skill '{skill_name}': {str(e)}" - await printr.print_async(error_msg, color=LogType.ERROR) - printr.print( - traceback.format_exc(), color=LogType.ERROR, server_only=True - ) - return False, error_msg + async def init_skills(self) -> list[WingmanInitializationError]: + return await self.skill_manager.init_skills() - return False, f"Skill '{skill_name}' not found." + async def enable_skill(self, skill_name: str) -> tuple[bool, str]: + return await self.skill_manager.enable_skill(skill_name) async def disable_skill(self, skill_name: str) -> tuple[bool, str]: - """Disable a single skill without reinitializing all skills. - - Args: - skill_name: The display name of the skill to disable - - Returns: - (success, message) tuple - """ - # Find the skill in our list - skill_to_remove = None - for skill in self.skills: - if skill.config.name == skill_name: - skill_to_remove = skill - break - - if not skill_to_remove: - return True, f"Skill '{skill_name}' is already deactivated." - - try: - # Unload the skill (cleanup resources, unsubscribe events) - await skill_to_remove.unload() - - # Remove from skill list - self.skills.remove(skill_to_remove) + return await self.skill_manager.disable_skill(skill_name) - # Remove skill-specific registrations (tools, registry, etc.) - await self.unprepare_skill(skill_to_remove) - - printr.print( - f"Skill '{skill_name}' deactivated (unloaded and removed from discoverable skills).", - color=LogType.WARNING, - server_only=True, - ) - return True, f"Skill '{skill_name}' deactivated successfully." - - except Exception as e: - error_msg = f"Error deactivating skill '{skill_name}': {str(e)}" - await printr.print_async(error_msg, color=LogType.ERROR) - printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - return False, error_msg - - async def reset_conversation_history(self): - """This function is called when the user triggers the ResetConversationHistory command. - It's a global command that should be implemented by every Wingman that keeps a message history. - """ - - # ──────────────────────────── The main processing loop ──────────────────────────── # + # ──────────────────────────── The main processing loop ──────────────────────── # async def process(self, audio_input_wav: str = None, transcript: str = None, images: list[tuple[str, str]] = None): - """The main method that gets called when the wingman is activated. This method controls what your wingman actually does and you can override it if you want to. - - The base implementation here triggers the transcription and processing of the given audio input. - If you don't need even transcription, you can just override this entire process method. If you want transcription but then do something in addition, you can override the listed hooks. - - Async so you can do async processing, e.g. send a request to an API. - - Args: - audio_input_wav (str): The path to the audio file that contains the user's speech. This is a recording of what you you said. - - Hooks: - - async _transcribe: transcribe the audio to text - - async _get_response_for_transcript: process the transcript and return a text response - - async play_to_user: do something with the response, e.g. play it as audio - """ - try: process_result = None benchmark_transcribe = None if not transcript: - # transcribe the audio. benchmark_transcribe = Benchmark(label="Voice transcription") transcript = await self._transcribe(audio_input_wav) @@ -554,9 +418,6 @@ async def process(self, audio_input_wav: str = None, transcript: str = None, ima additional_data=additional_data, ) - # Further process the transcript. - # Return a string that is the "answer" to your passed transcript. - benchmark_llm = Benchmark(label="Command/AI Processing") process_result, instant_response, skill, interrupt = ( await self._get_response_for_transcript( @@ -568,13 +429,12 @@ async def process(self, audio_input_wav: str = None, transcript: str = None, ima if actual_response: token_usage = None - if self.last_turn_prompt_tokens or self.last_turn_completion_tokens: + if self.metrics.last_turn_prompt_tokens or self.metrics.last_turn_completion_tokens: token_usage = ( - self.last_turn_prompt_tokens, - self.last_turn_completion_tokens, + self.metrics.last_turn_prompt_tokens, + self.metrics.last_turn_completion_tokens, ) - self.last_turn_prompt_tokens = 0 - self.last_turn_completion_tokens = 0 + self.metrics.reset_token_counters() await printr.print_async( f"{actual_response}", color=LogType.POSITIVE, @@ -588,8 +448,6 @@ async def process(self, audio_input_wav: str = None, transcript: str = None, ima if process_result: if self.settings.streamer_mode: self.tower.save_last_message(self.name, process_result) - - # the last step in the chain. You'll probably want to play the response to the user as audio using a TTS provider or mechanism of your choice. await self.play_to_user(str(process_result), not interrupt) except Exception as e: await printr.print_async( @@ -598,342 +456,533 @@ async def process(self, audio_input_wav: str = None, transcript: str = None, ima ) printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - # ───────────────── virtual methods / hooks ───────────────── # + # ───────────────── Transcription ───────────────── # async def _transcribe(self, audio_input_wav: str) -> str | None: - """Transcribes the audio to text. You can override this method if you want to use a different transcription service. - - Args: - audio_input_wav (str): The path to the audio file that contains the user's speech. This is a recording of what you you said. - - Returns: - str | None: The transcript of the audio file and the detected language as locale (if determined). - """ + if not self.stt: + return None + try: + transcript = await self.stt.transcribe(filename=audio_input_wav) + if transcript: + # Wingman Pro might return a serialized dict instead of a real object + if isinstance(transcript, dict): + return transcript.get("_text") + return transcript.text + except Exception as e: + await printr.print_async( + f"Error during transcription using '{self.config.features.stt_provider}': {str(e)}", + color=LogType.ERROR, + ) + printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) return None + # ───────────────── Response orchestration ───────────────── # + async def _get_response_for_transcript( self, transcript: str, benchmark: Benchmark, images: list[tuple[str, str]] = None - ) -> tuple[str | None, str | None, Skill | None, bool | None]: - """Processes the transcript and return a response as text. This where you'll do most of your work. - Pass the transcript to AI providers and build a conversation. Call commands or APIs. Play temporary results to the user etc. + ) -> tuple[str | None, str | None, Skill | None, bool]: + self.ensure_memory_initialized() + await self.add_user_message(transcript, images=images) - Args: - transcript (str): The user's spoken text transcribed as text. + benchmark.start_snapshot("Instant activation commands") + instant_response, instant_command_executed = await self.command_executor.try_instant_activation( + transcript=transcript + ) + if instant_response: + await self.conversation.add_assistant_message(instant_response) + benchmark.finish_snapshot() + if instant_response == ".": + instant_response = None + return instant_response, instant_response, None, True + benchmark.finish_snapshot() + + llm_processing_time_ms = 0.0 + tool_execution_time_ms = 0.0 + tool_timings: list[tuple[str, float]] = [] + + llm_start = time.perf_counter() + completion = await self._llm_call(instant_command_executed is False) + llm_processing_time_ms += (time.perf_counter() - llm_start) * 1000 + + if completion is None: + self.metrics.add_benchmark_snapshot( + benchmark, "LLM Processing", llm_processing_time_ms + ) + return None, None, None, True - Returns: - A tuple of strings representing the response to a function call and/or an instant response. - """ - return "", "", None, None + response_message, tool_calls, usage = await self._process_completion( + completion, instant_command_executed is False + ) - async def play_to_user( - self, - text: str, - no_interrupt: bool = False, - sound_config: Optional[SoundConfig] = None, - ): - """You'll probably want to play the response to the user as audio using a TTS provider or mechanism of your choice. + turn_prompt_tokens = usage[0] + turn_completion_tokens = usage[1] - Args: - text (str): The response of your _get_response_for_transcript. This is usually the "response" from conversation with the AI. - no_interrupt (bool): prevent interrupting the audio playback - sound_config (SoundConfig): An optional sound configuration to use for the playback. If unset, the Wingman's sound config is used. - """ - pass + is_waiting_response_needed, is_summarize_needed = await self.conversation.add_gpt_response( + response_message, tool_calls + ) + interrupt = True - # ───────────────────────────────── Commands ─────────────────────────────── # + while tool_calls: + if is_waiting_response_needed: + message = None + if response_message.content: + message = response_message.content + else: + filler = self.instant_response_generator.get_random_filler() + if filler: + message = filler + is_summarize_needed = True + if message: + self.threaded_execution(self.play_to_user, message, not interrupt) + await printr.print_async( + f"{message}", + color=LogType.POSITIVE, + source=LogSource.WINGMAN, + source_name=self.name, + skill_name="", + ) + interrupt = False + else: + is_summarize_needed = True + else: + is_summarize_needed = True - def get_command(self, command_name: str) -> CommandConfig | None: - """Extracts the command with the given name + tool_start = time.perf_counter() + instant_response, skill, iteration_timings = await self._handle_tool_calls( + tool_calls + ) + tool_execution_time_ms += (time.perf_counter() - tool_start) * 1000 + tool_timings.extend(iteration_timings) - Args: - command_name (str): the name of the command you used in the config + if instant_response: + await self.conversation.trim_tool_responses(max_tokens=500, is_condensing=self.condenser.is_condensing) + self.metrics.add_benchmark_snapshot( + benchmark, "LLM Processing", llm_processing_time_ms + ) + if tool_execution_time_ms > 0: + self.metrics.add_tool_execution_snapshot( + benchmark, tool_execution_time_ms, tool_timings + ) + await self.metrics.broadcast_token_usage( + turn_prompt_tokens, turn_completion_tokens + ) + return None, instant_response, None, interrupt - Returns: - {}: The command object from the config - """ - if self.config.commands is None: - return None + if is_summarize_needed: + llm_start = time.perf_counter() + completion = await self._llm_call(True) + llm_processing_time_ms += (time.perf_counter() - llm_start) * 1000 - command = next( - (item for item in self.config.commands if item.name == command_name), - None, - ) - return command + if completion is None: + await self.conversation.trim_tool_responses(max_tokens=500, is_condensing=self.condenser.is_condensing) + self.metrics.add_benchmark_snapshot( + benchmark, "LLM Processing", llm_processing_time_ms + ) + if tool_execution_time_ms > 0: + self.metrics.add_tool_execution_snapshot( + benchmark, tool_execution_time_ms, tool_timings + ) + await self.metrics.broadcast_token_usage( + turn_prompt_tokens, turn_completion_tokens + ) + return None, None, None, True + + response_message, tool_calls, usage = await self._process_completion( + completion + ) + turn_prompt_tokens = usage[0] + turn_completion_tokens += usage[1] - def _select_instant_command_response(self, command: CommandConfig) -> str | None: - """Returns one of the configured responses of the command. This base implementation returns a random one. + is_waiting_response_needed, is_summarize_needed = ( + await self.conversation.add_gpt_response(response_message, tool_calls) + ) + if tool_calls: + interrupt = False + elif is_waiting_response_needed: + await self.conversation.trim_tool_responses(max_tokens=500, is_condensing=self.condenser.is_condensing) + self.metrics.add_benchmark_snapshot( + benchmark, "LLM Processing", llm_processing_time_ms + ) + if tool_execution_time_ms > 0: + self.metrics.add_tool_execution_snapshot( + benchmark, tool_execution_time_ms, tool_timings + ) + await self.metrics.broadcast_token_usage( + turn_prompt_tokens, turn_completion_tokens + ) + return None, None, None, interrupt + + await self.conversation.trim_tool_responses(max_tokens=500, is_condensing=self.condenser.is_condensing) + + self.metrics.add_benchmark_snapshot( + benchmark, "LLM Processing", llm_processing_time_ms + ) + if tool_execution_time_ms > 0: + self.metrics.add_tool_execution_snapshot( + benchmark, tool_execution_time_ms, tool_timings + ) + await self.metrics.broadcast_token_usage(turn_prompt_tokens, turn_completion_tokens) + return response_message.content, response_message.content, None, interrupt - Args: - command (dict): The command object from the config + # ───────────────── LLM call ───────────────── # - Returns: - str: A random response from the command's responses list in the config. - """ - command_responses = command.responses - if (command_responses is None) or (len(command_responses) == 0): + async def actual_llm_call(self, messages, tools: list[dict] = None): + if not self.llm: + await printr.print_async( + f"No LLM provider configured for wingman '{self.name}'.", + color=LogType.ERROR, + source=LogSource.WINGMAN, + source_name=self.name, + ) return None - return random.choice(command_responses) + try: + completion = await self.llm.ask(messages=messages, tools=tools) + except APIConnectionError as e: + provider = self.config.features.conversation_provider.value + cause = e.__cause__ + detail = str(cause) if cause else str(e) + message = f"Could not connect to {provider}: {detail}" + await printr.print_async( + message, + color=LogType.ERROR, + source=LogSource.WINGMAN, + source_name=self.name, + ) + printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) + return None + except Exception as e: + await printr.print_async( + f"Error during LLM call: {str(e)}", + color=LogType.ERROR, + source=LogSource.WINGMAN, + source_name=self.name, + ) + printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) + return None - async def _execute_instant_activation_command( - self, transcript: str - ) -> list[CommandConfig] | None: - """Uses a fuzzy string matching algorithm to match the transcript to a configured instant_activation command and executes it immediately. + return completion - Args: - transcript (text): What the user said, transcripted to text. Needs to be similar to one of the defined instant_activation phrases to work. + async def _llm_call(self, allow_tool_calls: bool = True): + thiscall = time.time() + self.last_gpt_call = thiscall - Returns: - {} | None: The executed instant_activation command. - """ + tools = self.build_tools() if allow_tool_calls else None - try: - # create list with phrases pointing to commands - commands_by_instant_activation = {} - for command in self.config.commands: - if command.instant_activation: - for phrase in command.instant_activation: - if phrase.lower() in commands_by_instant_activation: - commands_by_instant_activation[phrase.lower()].append( - command - ) - else: - commands_by_instant_activation[phrase.lower()] = [command] - - # find best matching phrase - phrase = difflib.get_close_matches( - transcript.lower(), - commands_by_instant_activation.keys(), - n=1, - cutoff=1, + if self.settings.debug_mode: + await printr.print_async( + f"Calling LLM with {len(self.conversation.messages)} messages (excluding context) and {len(tools) if tools else 0} tools.", + color=LogType.INFO, ) - # if no phrase found, return None - if not phrase: - return None + messages = self.conversation.messages.copy() + await self.add_context(messages) - # execute all commands for the phrase - commands = commands_by_instant_activation[phrase[0]] - for command in commands: - await self._execute_command(command, True) + completion = await self.actual_llm_call(messages, tools) - # return the executed command - return commands - except Exception as e: + if self.last_gpt_call != thiscall: await printr.print_async( - f"Error during instant activation in Wingman '{self.name}': {str(e)}", - color=LogType.ERROR, + "LLM call was cancelled due to a new call.", color=LogType.WARNING ) - printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) return None - async def _execute_command( - self, command: CommandConfig, is_instant=False - ) -> tuple[str | None, str]: - """Triggers the execution of a command. This base implementation executes the keypresses defined in the command. + return completion - Args: - command (dict): The command object from the config to execute + async def _process_completion( + self, completion: ChatCompletion, allow_tool_calls: bool = True + ): + response_message = completion.choices[0].message - Returns: - tuple[str | None, str]: A 2-tuple of: - - Instant response (str) to play immediately, or None if there is no instant response. - - Function/tool response (str) to feed back to the LLM (uses command's additional_context, - falls back to "OK", or an error string on failure). - """ + content = response_message.content + if content is None: + response_message.content = "" - if not command: - return None, "Command not found" + if not allow_tool_calls: + response_message.tool_calls = None - try: - if len(command.actions or []) == 0: - await printr.print_async( - f"No actions found for command: {command.name}", - color=LogType.WARNING, - ) - else: - await self.execute_action(command) - await printr.print_async( - f"Executed {'instant' if is_instant else 'AI'} command: {command.name}", - color=LogType.COMMAND, - ) + if response_message.tool_calls: + response_message.tool_calls = await self.tool_executor.fix_tool_calls( + response_message.tool_calls, self.command_executor.get_command + ) - # handle the global special commands: - if command.name == "ResetConversationHistory": - await self.reset_conversation_history() - await printr.print_async( - f"Executed command: {command.name}", color=LogType.COMMAND + prompt_tokens = 0 + completion_tokens = 0 + if completion.usage: + prompt_tokens = completion.usage.prompt_tokens or 0 + completion_tokens = completion.usage.completion_tokens or 0 + + return ( + response_message, + response_message.tool_calls, + (prompt_tokens, completion_tokens), + ) + + # ───────────────── Tool calls ───────────────── # + + async def _handle_tool_calls(self, tool_calls): + return await self.tool_executor.handle_tool_calls( + tool_calls, + tool_skills=self.tool_skills, + skill_registry=self.skill_registry, + mcp_registry=self.mcp_registry, + capability_registry=self.capability_registry, + persistent_memory_service=self.persistent_memory_service, + get_command_fn=self.command_executor.get_command, + execute_command_fn=self.command_executor.execute_command, + play_to_user_fn=self.play_to_user, + local_ai_service=self.local_ai_service, + update_tool_response_fn=self.conversation.update_tool_response, + add_tool_response_fn=self.conversation.add_tool_response, + pending_tool_calls=self.conversation.pending_tool_calls, + ) + + async def execute_command_by_function_call( + self, function_name: str, function_args: dict[str, Any] + ) -> tuple[str, str | None, Skill | None, str | None]: + """Public API kept for backward compatibility with skills.""" + return await self.tool_executor.execute_by_function_call( + function_name, + function_args, + tool_skills=self.tool_skills, + skill_registry=self.skill_registry, + mcp_registry=self.mcp_registry, + capability_registry=self.capability_registry, + persistent_memory_service=self.persistent_memory_service, + get_command_fn=self.command_executor.get_command, + execute_command_fn=self.command_executor.execute_command, + play_to_user_fn=self.play_to_user, + ) + + # ───────────────── Conversation delegation ───────────────── # + + async def add_user_message(self, content: str, images: list[tuple[str, str]] = None): + """Thin wrapper: resets memory-recall state then delegates to ConversationManager.""" + self._memory_recall_notified = False + self.context_builder.reset_memory_notification() + await self.conversation.add_user_message( + content, + images=images, + condense_fn=lambda: self.condenser.maybe_condense(self.local_ai_service), + ) + + async def reset_conversation_history(self): + if self.persistent_memory_service and len(self.conversation.messages) >= 4: + try: + await self.persistent_memory_service.extract_memories( + self.conversation.messages, generate_summary=True ) + except Exception: + pass - return ( - self._select_instant_command_response(command), - command.additional_context or "OK", - ) - except Exception as e: - await printr.print_async( - f"Error executing command '{command.name}' for Wingman '{self.name}': {str(e)}", - color=LogType.ERROR, + await self.conversation.reset() + self.skill_registry.reset_activations() + self.mcp_registry.reset_activations() + + def get_conversation_messages(self, strip_nulls: bool = True) -> list[dict]: + return self.conversation.get_conversation_messages(strip_nulls=strip_nulls) + + # ─── Backward-compat: expose messages directly for skills/services that access it ─── # + + @property + def messages(self) -> list: + return self.conversation.messages + + @messages.setter + def messages(self, value: list): + self.conversation.messages = value + + @property + def conversation_summary(self) -> str: + return self.conversation.conversation_summary + + @conversation_summary.setter + def conversation_summary(self, value: str): + self.conversation.conversation_summary = value + + @property + def pending_tool_calls(self) -> list: + return self.conversation.pending_tool_calls + + @property + def _is_condensing(self) -> bool: + return self.condenser.is_condensing + + # ───────────────── Context ───────────────── # + + async def get_context(self): + config_dir_name = None + if self.tower and self.tower.config_dir and self.tower.config_dir.name: + config_dir_name = self.tower.config_dir.name + + return await self.context_builder.build( + skills=self.skills, + skill_registry=self.skill_registry, + conversation_summary=self.conversation.conversation_summary, + persistent_memory_service=self.persistent_memory_service, + messages=self.conversation.messages, + config_dir_name=config_dir_name, + ) + + def get_last_context(self) -> str: + return self.context_builder.get_last_context() + + async def add_context(self, messages): + context = await self.get_context() + messages.insert(0, {"role": "system", "content": context}) + + # ───────────────── TTS / play_to_user ───────────────── # + + async def play_to_user( + self, + text: str, + no_interrupt: bool = False, + sound_config: Optional[SoundConfig] = None, + ): + if sound_config: + printr.print( + "Using custom sound config for playback", LogType.INFO, server_only=True ) - printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - return None, "ERROR DURING PROCESSING" + else: + sound_config = self.config.sound - async def execute_action(self, command: CommandConfig): - """Executes the actions defined in the command (in order). + text, contains_links, contains_code_blocks = cleanup_text(text) - Args: - command (dict): The command object from the config to execute - """ - if not command or not command.actions: - return + if no_interrupt and self.audio_player.is_playing: + while self.audio_player.is_playing: + await asyncio.sleep(0.1) - def contains_numpad_key(hotkey: str) -> bool: - """Check if the hotkey string contains a numpad key anywhere in the chord. + changed_text = text + for skill in self.skills: + if skill.is_prepared: + changed_text = await skill.on_play_to_user(text, sound_config) + if changed_text != text: + printr.print( + f"Skill '{skill.config.display_name}' modified the text to: '{changed_text}'", + LogType.INFO, + ) + text = changed_text - Args: - hotkey: The hotkey string (e.g., 'num 1', 'ctrl+num 1', 'alt+num 2') + if sound_config.volume == 0.0: + printr.print( + "Volume modifier is set to 0. Skipping TTS processing.", + LogType.WARNING, + server_only=True, + ) + return - Returns: - True if any token in the chord is a numpad key (num 0 - num 9) - """ - if not hotkey: - return False - tokens = hotkey.lower().split("+") - return any(token.startswith("num ") for token in tokens) + if "{SKIP-TTS}" in text: + printr.print( + "Skip TTS phrase found in input. Skipping TTS processing.", + LogType.WARNING, + server_only=True, + ) + return + + if not self.tts: + printr.print( + f"No TTS provider configured for wingman '{self.name}'.", + LogType.WARNING, + server_only=True, + ) + return try: - for action in command.actions: - if action.keyboard: - if action.keyboard.hotkey_codes and not contains_numpad_key( - action.keyboard.hotkey - ): - code = action.keyboard.hotkey_codes - else: - code = action.keyboard.hotkey - - if action.keyboard.press == action.keyboard.release: - # compressed key events - hold = action.keyboard.hold or 0.1 - if ( - action.keyboard.hotkey_codes - and len(action.keyboard.hotkey_codes) == 1 - and not contains_numpad_key(action.keyboard.hotkey) - ): - keyboard.direct_event( - action.keyboard.hotkey_codes[0], - 0 + (1 if action.keyboard.hotkey_extended else 0), - ) - time.sleep(hold) - keyboard.direct_event( - action.keyboard.hotkey_codes[0], - 2 + (1 if action.keyboard.hotkey_extended else 0), - ) - else: - keyboard.press(code) - time.sleep(hold) - keyboard.release(code) - else: - # single key events - if ( - action.keyboard.hotkey_codes - and len(action.keyboard.hotkey_codes) == 1 - and not contains_numpad_key(action.keyboard.hotkey) - ): - keyboard.direct_event( - action.keyboard.hotkey_codes[0], - (0 if action.keyboard.press else 2) - + (1 if action.keyboard.hotkey_extended else 0), - ) - else: - keyboard.send( - code, - action.keyboard.press, - action.keyboard.release, - ) - - if action.mouse: - if action.mouse.move_to: - x, y = action.mouse.move_to - mouse.move(x, y) - - if action.mouse.move: - x, y = action.mouse.move - mouse.move(x, y, absolute=False, duration=0.5) - - if action.mouse.scroll: - mouse.wheel(action.mouse.scroll) - - if action.mouse.button: - if action.mouse.hold: - mouse.press(button=action.mouse.button) - time.sleep(action.mouse.hold) - mouse.release(button=action.mouse.button) - else: - mouse.click(button=action.mouse.button) - - if action.write: - keyboard.write(action.write) - - if action.wait: - time.sleep(action.wait) - - if action.audio: - await self.audio_library.handle_action( - action.audio, self.config.sound.volume - ) + await self.tts.play_audio( + text=text, + sound_config=sound_config, + audio_player=self.audio_player, + wingman_name=self.name, + ) except Exception as e: await printr.print_async( - f"Error executing actions of command '{command.name}' for wingman '{self.name}': {str(e)}", - color=LogType.ERROR, + f"Error during TTS playback: {str(e)}", color=LogType.ERROR ) printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - def threaded_execution(self, function, *args) -> threading.Thread | None: - """Execute a function in a separate thread.""" - try: + # ───────────────── Image generation ───────────────── # - def start_thread(function, *args): - if asyncio.iscoroutinefunction(function): - new_loop = asyncio.new_event_loop() - asyncio.set_event_loop(new_loop) - new_loop.run_until_complete(function(*args)) - new_loop.close() - else: - function(*args) + async def generate_image(self, text: str) -> str: + if ( + self.config.features.image_generation_provider + != ImageGenerationProvider.WINGMAN_PRO + ): + return "" + try: + if self._image_subscription is None: + from providers.wingman_subscription import WingmanSubscription - thread = threading.Thread(target=start_thread, args=(function, *args)) - thread.name = function.__name__ - thread.daemon = True # Mark as daemon so it dies when main process exits - thread.start() - return thread + self._image_subscription = WingmanSubscription( + wingman_name=self.name, settings=self.settings.wingman_pro + ) + return await self._image_subscription.generate_image(text) except Exception as e: - printr.print( - f"Error starting threaded execution: {str(e)}", color=LogType.ERROR + await printr.print_async( + f"Error during image generation: {str(e)}", color=LogType.ERROR ) printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - return None + return "" - async def update_config( - self, config: WingmanConfig, skip_config_validation: bool = True - ) -> bool: - """Update the config of the Wingman. + # ───────────────── Build tools ───────────────── # + + def build_tools(self) -> list[dict]: + """Assemble the full tool list for LLM calls.""" + tools: list[dict] = [] + + command_tool = self.command_executor.get_tool_definition() + if command_tool: + tools.append(command_tool) + + for _, tool in self.capability_registry.get_meta_tools(): + tools.append(tool) + + for _, tool in self.skill_registry.get_active_tools(): + tools.append(tool) + + for _, tool in self.mcp_registry.get_active_tools(): + tools.append(tool) + + if self.persistent_memory_service: + tools.extend(self.persistent_memory_service.get_tool_definitions()) + + return tools + + # ───────────────── Backward-compat delegation ────────────── # - This method should always be called if the config of the Wingman has changed. + def get_command(self, command_name: str) -> CommandConfig | None: + """Backward-compat: delegate to command_executor.""" + return self.command_executor.get_command(command_name) + + async def execute_action(self, command: CommandConfig): + """Backward-compat: delegate to command_executor.""" + await self.command_executor.execute_action(command) + + # ───────────────── Threading ─────────────────────────────── # + + def threaded_execution(self, function, *args) -> threading.Thread | None: + return threaded_execution(function, *args) - Args: - config: The new wingman configuration - skip_config_validation: If False, validate the config and rollback on error + # ───────────────── Config management ─────────────────────── # - Returns: - True if config was updated successfully, False otherwise - """ + async def update_config( + self, config: WingmanConfig, skip_config_validation: bool = True + ) -> bool: try: if not skip_config_validation: old_config = deepcopy(self.config) self.config = config - # Propagate skill config changes to loaded skills + # Propagate to all services that hold a config reference + self.command_executor.config = config + self.conversation._config = config + self.condenser._config = config + self.context_builder._config = config + self.tool_executor._config = config + self.metrics.config = config + self.mcp_manager.config = config + self.skill_manager.config = config + await self._update_skill_configs(config) if not skip_config_validation: @@ -944,7 +993,16 @@ async def update_config( error.error_type != WingmanInitializationErrorType.MISSING_SECRET ): + # Roll back config on all services self.config = old_config + self.command_executor.config = old_config + self.conversation._config = old_config + self.condenser._config = old_config + self.context_builder._config = old_config + self.tool_executor._config = old_config + self.metrics.config = old_config + self.mcp_manager.config = old_config + self.skill_manager.config = old_config return False return True @@ -957,15 +1015,9 @@ async def update_config( return False async def _update_skill_configs(self, wingman_config: WingmanConfig) -> None: - """Propagate skill config changes to loaded skills. - - When the wingman config changes (e.g., user updates custom_properties for a skill), - we need to update the SkillConfig on each loaded skill instance so they see the new values. - """ if not self.skills or not wingman_config.skills: return - # Build lookup of new skill configs by folder name new_skill_configs: dict[str, "SkillConfig"] = {} for skill_config in wingman_config.skills: try: @@ -979,9 +1031,7 @@ async def _update_skill_configs(self, wingman_config: WingmanConfig) -> None: continue new_skill_configs[folder_name] = skill_config - # Update each loaded skill if its config changed for skill in self.skills: - # Get the folder name for this skill try: skill_folder = _get_skill_folder_from_module(skill.config.module) except Exception: @@ -997,51 +1047,56 @@ async def _update_skill_configs(self, wingman_config: WingmanConfig) -> None: fields_set = getattr(user_override, "model_fields_set", None) if fields_set is None: - # Pydantic v1 fallback fields_set = getattr(user_override, "__fields_set__", set()) - # Create updated config by copying current and applying overrides - # This preserves all default values while applying user overrides updated_config = deepcopy(skill.config) - # Apply overrides even if they're explicitly empty. - # This allows users to clear custom properties/prompt in the UI. if "custom_properties" in fields_set: updated_config.custom_properties = user_override.custom_properties if "prompt" in fields_set: updated_config.prompt = user_override.prompt - # Let the skill handle the config update (will compare old vs new) await skill.update_config(updated_config) + async def update_settings(self, settings: SettingsConfig): + try: + self.settings = settings + + # Propagate to all services that hold a settings reference + self.conversation._settings = settings + self.context_builder._settings = settings + self.tool_executor._settings = settings + self.mcp_manager.settings = settings + self.skill_manager.settings = settings + + for skill in self.skills: + skill.settings = settings + + # Re-create Wingman Pro provider when settings change + # (subscription settings might have been updated) + uses_wingman_pro = any([ + self.config.features.conversation_provider == ConversationProvider.WINGMAN_PRO, + self.config.features.tts_provider == TtsProvider.WINGMAN_PRO, + self.config.features.stt_provider == SttProvider.WINGMAN_PRO, + self.config.features.image_generation_provider == ImageGenerationProvider.WINGMAN_PRO, + ]) + if uses_wingman_pro: + await self.validate() + printr.print( + f"Wingman {self.name}: reinitialized providers with new settings", + server_only=True, + ) + + printr.print(f"Wingman {self.name}'s settings changed", server_only=True) + except Exception as e: + await printr.print_async( + f"Error while updating settings for wingman '{self.name}': {str(e)}", + color=LogType.ERROR, + ) + printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) + async def save_config(self): - """Save the config of the Wingman.""" self.tower.save_wingman(self.name) async def save_commands(self): - """Save only the commands section of this wingman's config. - - This performs a partial YAML update - only the commands field is modified - in the config file, avoiding full config serialization. This is much safer - than save_config() for command-only changes as it won't accidentally - overwrite other fields. - - Use this instead of save_config() when you only changed command definitions, - instant_activation phrases, or other command-related fields. - - Example use cases: - - QuickCommands learning instant activation phrases - - Skills dynamically adding/modifying commands - - Skills updating command responses or actions - """ self.tower.save_wingman_commands(self.name) - - async def update_settings(self, settings: SettingsConfig): - """Update the settings of the Wingman. This method should always be called when the user Settings have changed.""" - self.settings = settings - - # Propagate settings changes to already-loaded skills - for skill in self.skills: - skill.settings = settings - - printr.print(f"Wingman {self.name}'s settings changed", server_only=True) diff --git a/wingmen/wingman_context.py b/wingmen/wingman_context.py new file mode 100644 index 000000000..37a5e2e69 --- /dev/null +++ b/wingmen/wingman_context.py @@ -0,0 +1,162 @@ +"""Controlled interface for skills. Limits what plugins can access.""" + +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from openai.types.chat import ChatCompletion + from api.enums import TtsProvider + from api.interface import SoundConfig, WingmanConfig, SettingsConfig + from services.audio_player import AudioPlayer + from wingmen.wingman import Wingman + + +class WingmanContext: + """What skills see. Controlled API surface — no internal leaking.""" + + def __init__(self, wingman: "Wingman"): + self._wingman = wingman + + # --- Properties --- + + @property + def name(self) -> str: + return self._wingman.name + + @property + def config(self) -> "WingmanConfig": + return self._wingman.config + + @property + def settings(self) -> "SettingsConfig": + return self._wingman.settings + + @property + def audio_player(self) -> "AudioPlayer": + return self._wingman.audio_player + + @property + def tower(self): + return self._wingman.tower + + @property + def secret_keeper(self): + return self._wingman.secret_keeper + + # --- Conversation --- + + async def llm_call(self, messages: list[dict], tools: list[dict] | None = None) -> "ChatCompletion | None": + """Make an LLM call. Replaces actual_llm_call().""" + return await self._wingman.actual_llm_call(messages, tools) + + def get_conversation_history(self) -> list[dict]: + """Get a shallow copy of the conversation history. + + Note: message objects are shared with the live conversation state. + Do not mutate individual messages. + """ + return list(self._wingman.conversation.messages) + + async def add_user_message(self, content: str): + await self._wingman.add_user_message(content) + + async def add_assistant_message(self, content: str): + await self._wingman.conversation.add_assistant_message(content) + + async def reset_conversation_history(self): + await self._wingman.reset_conversation_history() + + # --- Audio --- + + async def play_to_user(self, text: str, no_interrupt: bool = False, + sound_config: "Optional[SoundConfig]" = None): + await self._wingman.play_to_user(text, no_interrupt, sound_config) + + # --- Image generation --- + + async def generate_image(self, text: str) -> str: + return await self._wingman.generate_image(text) + + # --- Secrets --- + + async def retrieve_secret(self, secret_name: str, errors: list = None) -> str | None: + return await self._wingman.retrieve_secret(secret_name, errors or []) + + # --- Utilities --- + + def threaded_execution(self, func, *args): + self._wingman.threaded_execution(func, *args) + + async def get_context(self) -> str: + return await self._wingman.context_builder.build( + skills=self._wingman.skills, + skill_registry=self._wingman.skill_registry, + conversation_summary=self._wingman.condenser.summary, + persistent_memory_service=self._wingman.persistent_memory_service, + messages=self._wingman.conversation.messages, + config_dir_name=self._wingman.tower.config_dir.name if self._wingman.tower and self._wingman.tower.config_dir and self._wingman.tower.config_dir.name else None, + ) + + # --- Provider switching (for voice_changer and similar) --- + + async def switch_tts_provider(self, provider: "TtsProvider", + errors: list = None) -> bool: + """Hot-swap the TTS provider at runtime. + + Updates config.features.tts_provider and creates a new TTS instance. + Used by voice_changer skill. + """ + from services.provider_factory import ProviderFactory + old_provider = self._wingman.config.features.tts_provider + self._wingman.config.features.tts_provider = provider + factory = ProviderFactory( + config=self._wingman.config, + settings=self._wingman.settings, + secret_keeper=self._wingman.secret_keeper, + shared_providers=self._wingman._shared_providers, + wingman_name=self._wingman.name, + ) + _errors = errors or [] + new_tts = await factory.create_tts(_errors) + if new_tts: + self._wingman.tts = new_tts + return True + # Roll back config on failure + self._wingman.config.features.tts_provider = old_provider + return False + + # --- Commands --- + + def get_command(self, command_name: str): + """Delegate to command_executor for skills that need direct command lookup.""" + return self._wingman.command_executor.get_command(command_name) + + # --- Backward compatibility (temporary) --- + # These provide access to registries that some skills currently use. + # They should be replaced with proper facade methods in a future iteration. + + @property + def tool_skills(self) -> dict: + # tool_skills lives directly on the Wingman instance, not on tool_executor + return self._wingman.tool_skills + + @property + def mcp_registry(self): + return self._wingman.mcp_registry + + @property + def skill_registry(self): + return self._wingman.skill_registry + + # Expose messages property for backward compat (quick_commands reads it) + @property + def messages(self) -> list: + return self._wingman.conversation.messages + + # Expose local AI services for SkillLocalAI facade + @property + def local_ai_service(self): + return self._wingman.local_ai_service + + @property + def persistent_memory_service(self): + return self._wingman.persistent_memory_service