From 951c761b7e7f82a39411b2dc3b7c88e76585b6a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Hopst=C3=A4tter?= Date: Thu, 9 Apr 2026 10:47:02 +0200 Subject: [PATCH 01/34] feat: add unified provider interfaces, protocols, and registration decorators --- providers/interfaces.py | 169 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 169 insertions(+) create mode 100644 providers/interfaces.py diff --git a/providers/interfaces.py b/providers/interfaces.py new file mode 100644 index 00000000..d1c1bf33 --- /dev/null +++ b/providers/interfaces.py @@ -0,0 +1,169 @@ +"""Unified provider interfaces for STT, TTS, and LLM providers. + +ABCs define required contracts — providers that don't implement them crash at instantiation. +Protocols define optional capabilities — check with isinstance() before calling. +Registration decorators map config enum values to provider classes for ProviderFactory lookup. +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING, Protocol, runtime_checkable + +from api.enums import ConversationProvider, SttProvider, TtsProvider + +if TYPE_CHECKING: + from openai.types.chat import ChatCompletion + + from api.interface import SoundConfig, WingmanInitializationError + from services.audio_player import AudioPlayer + + +# --------------------------------------------------------------------------- +# Unified return types +# --------------------------------------------------------------------------- + +@dataclass +class Transcript: + """Unified STT result. Every provider wraps its native result into this.""" + + text: str + language: str | None = None + confidence: float | None = None + + +# --------------------------------------------------------------------------- +# Core ABCs (required contracts) +# --------------------------------------------------------------------------- + +class SttInterface(ABC): + """Speech-to-text provider interface.""" + + @abstractmethod + async def transcribe(self, filename: str) -> Transcript | None: + """Transcribe an audio file to text.""" + ... + + +class TtsInterface(ABC): + """Text-to-speech provider interface.""" + + @abstractmethod + async def play_audio( + self, + text: str, + sound_config: "SoundConfig", + audio_player: "AudioPlayer", + wingman_name: str, + ) -> None: + """Synthesize speech and play it.""" + ... + + +class LlmInterface(ABC): + """Large language model provider interface.""" + + @abstractmethod + async def ask( + self, + messages: list[dict], + tools: list[dict] | None = None, + ) -> "ChatCompletion | None": + """Send messages to the LLM and get a completion.""" + ... + + +# --------------------------------------------------------------------------- +# Optional Protocols (capability checks) +# --------------------------------------------------------------------------- + +@runtime_checkable +class HasAvailableVoices(Protocol): + """Provider can enumerate available voices.""" + + async def get_available_voices(self, **kwargs) -> list: ... + + +@runtime_checkable +class HasAvailableModels(Protocol): + """Provider can enumerate available models.""" + + async def get_available_models(self) -> list: ... + + +@runtime_checkable +class HasLifecycle(Protocol): + """Provider has load/unload lifecycle for local models.""" + + async def load(self) -> None: ... + async def unload(self) -> None: ... + + +@runtime_checkable +class Validatable(Protocol): + """Provider can validate its configuration.""" + + async def validate(self, errors: "list[WingmanInitializationError]") -> None: ... + + +@runtime_checkable +class HasMinimalReasoning(Protocol): + """LLM provider supports reasoning effort tuning (e.g., O-series, Gemini).""" + + def get_minimal_reasoning_by_model(self, model_name: str) -> dict: ... + + +# --------------------------------------------------------------------------- +# Registration decorators + registries +# --------------------------------------------------------------------------- + +_STT_REGISTRY: dict[SttProvider, type[SttInterface]] = {} +_TTS_REGISTRY: dict[TtsProvider, type[TtsInterface]] = {} +_LLM_REGISTRY: dict[ConversationProvider, type[LlmInterface]] = {} + + +def stt_provider(*provider_enums: SttProvider): + """Register a class as the STT provider for given enum value(s).""" + + def decorator(cls): + for enum_val in provider_enums: + _STT_REGISTRY[enum_val] = cls + return cls + + return decorator + + +def tts_provider(*provider_enums: TtsProvider): + """Register a class as the TTS provider for given enum value(s).""" + + def decorator(cls): + for enum_val in provider_enums: + _TTS_REGISTRY[enum_val] = cls + return cls + + return decorator + + +def llm_provider(*provider_enums: ConversationProvider): + """Register a class as the LLM provider for given enum value(s).""" + + def decorator(cls): + for enum_val in provider_enums: + _LLM_REGISTRY[enum_val] = cls + return cls + + return decorator + + +def get_stt_class(provider: SttProvider) -> type[SttInterface] | None: + """Look up the registered STT provider class for an enum value.""" + return _STT_REGISTRY.get(provider) + + +def get_tts_class(provider: TtsProvider) -> type[TtsInterface] | None: + """Look up the registered TTS provider class for an enum value.""" + return _TTS_REGISTRY.get(provider) + + +def get_llm_class(provider: ConversationProvider) -> type[LlmInterface] | None: + """Look up the registered LLM provider class for an enum value.""" + return _LLM_REGISTRY.get(provider) From f356a54ead7bf6dc898a4dc8270ee0b1e31c19e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Hopst=C3=A4tter?= Date: Thu, 9 Apr 2026 10:53:21 +0200 Subject: [PATCH 02/34] feat: add STT/TTS/LLM adapter classes to all providers (Tasks 2-13) Co-Authored-By: Claude Opus 4.6 --- providers/edge.py | 22 +++ providers/elevenlabs.py | 25 +++- providers/faster_whisper.py | 39 +++++- providers/google.py | 19 +++ providers/hume.py | 37 +++++ providers/inworld.py | 24 +++- providers/open_ai.py | 260 +++++++++++++++++++++++++++++++++++- providers/parakeet.py | 26 +++- providers/pocket_tts.py | 26 +++- providers/whispercpp.py | 25 +++- providers/x_ai.py | 19 +++ providers/xvasynth.py | 25 +++- 12 files changed, 533 insertions(+), 14 deletions(-) diff --git a/providers/edge.py b/providers/edge.py index 6b996656..b20a8922 100644 --- a/providers/edge.py +++ b/providers/edge.py @@ -1,10 +1,16 @@ from os import path +from typing import TYPE_CHECKING from edge_tts import Communicate +from api.enums import TtsProvider from api.interface import EdgeTtsConfig, SoundConfig +from providers.interfaces import TtsInterface, tts_provider from services.audio_player import AudioPlayer from services.file import get_writable_dir from services.printr import Printr +if TYPE_CHECKING: + from api.interface import WingmanConfig + RECORDING_PATH = "audio_output" OUTPUT_FILE: str = "edge_tts.mp3" @@ -48,3 +54,19 @@ async def __generate_speech( await communicate.save(file_path) return communicate, file_path + + +@tts_provider(TtsProvider.EDGE_TTS) +class EdgeTts(TtsInterface): + def __init__(self, config: "WingmanConfig"): + self._edge = Edge() + self._config = config + + async def play_audio(self, text, sound_config, audio_player, wingman_name): + await self._edge.play_audio( + text=text, + config=self._config.edge_tts, + sound_config=sound_config, + audio_player=audio_player, + wingman_name=wingman_name, + ) diff --git a/providers/elevenlabs.py b/providers/elevenlabs.py index ec045272..bd73497a 100644 --- a/providers/elevenlabs.py +++ b/providers/elevenlabs.py @@ -1,17 +1,21 @@ import asyncio -from typing import Callable, Optional +from typing import TYPE_CHECKING, Callable, Optional import requests from threading import Event, Thread import numpy as np import sounddevice as sd from elevenlabslib import User, GenerationOptions, PlaybackOptions, SFXOptions -from api.enums import LogType, SoundEffect, WingmanInitializationErrorType +from api.enums import LogType, SoundEffect, TtsProvider, WingmanInitializationErrorType from api.interface import ElevenlabsConfig, SoundConfig, WingmanInitializationError +from providers.interfaces import TtsInterface, tts_provider from services.audio_player import AudioPlayer from services.printr import Printr from services.sound_effects import get_sound_effects from services.websocket_user import WebSocketUser +if TYPE_CHECKING: + from api.interface import WingmanConfig + class ElevenLabs: def __init__(self, api_key: str, wingman_name: str): @@ -401,3 +405,20 @@ def get_available_models(self): def get_subscription_data(self): return self.user.get_subscription_data() + + +@tts_provider(TtsProvider.ELEVENLABS) +class ElevenLabsTts(TtsInterface): + def __init__(self, elevenlabs_instance: "ElevenLabs", config: "WingmanConfig"): + self._elevenlabs = elevenlabs_instance + self._config = config + + async def play_audio(self, text, sound_config, audio_player, wingman_name): + await self._elevenlabs.play_audio( + text=text, + config=self._config.elevenlabs, + sound_config=sound_config, + audio_player=audio_player, + wingman_name=wingman_name, + stream=self._config.elevenlabs.output_streaming, + ) diff --git a/providers/faster_whisper.py b/providers/faster_whisper.py index 842ce1aa..34f619a3 100644 --- a/providers/faster_whisper.py +++ b/providers/faster_whisper.py @@ -1,16 +1,20 @@ from os import path import gc -from typing import Optional +from typing import TYPE_CHECKING, Optional from faster_whisper import WhisperModel -from api.enums import LogType +from api.enums import LogType, SttProvider from api.interface import ( FasterWhisperSettings, FasterWhisperTranscript, FasterWhisperSttConfig, WingmanInitializationError, ) +from providers.interfaces import SttInterface, Transcript, stt_provider from services.printr import Printr +if TYPE_CHECKING: + from api.interface import WingmanConfig + class FasterWhisper: def __init__(self, settings: FasterWhisperSettings): @@ -116,3 +120,34 @@ def transcribe( def validate(self, errors: list[WingmanInitializationError]): pass + + +@stt_provider(SttProvider.FASTER_WHISPER) +class FasterWhisperStt(SttInterface): + """Per-wingman adapter around the shared FasterWhisper singleton.""" + + def __init__(self, shared: "FasterWhisper", config: "WingmanConfig", wingman_name: str): + self._shared = shared + self._config = config + self._wingman_name = wingman_name + + async def transcribe(self, filename: str) -> Transcript | None: + hotwords: list[str] = [self._wingman_name] + default_hotwords = self._config.fasterwhisper.hotwords + if default_hotwords: + hotwords.extend(default_hotwords) + wingman_hotwords = self._config.fasterwhisper.additional_hotwords + if wingman_hotwords: + hotwords.extend(wingman_hotwords) + + result = self._shared.transcribe( + config=self._config.fasterwhisper, + filename=filename, + hotwords=list(set(hotwords)), + ) + if result is None: + return None + return Transcript( + text=result.text, + language=getattr(result, "language", None), + ) diff --git a/providers/google.py b/providers/google.py index 4bddeb65..feadd2c5 100644 --- a/providers/google.py +++ b/providers/google.py @@ -1,9 +1,15 @@ import re +from typing import TYPE_CHECKING from google import genai from google.genai import types from openai import APIStatusError, OpenAI +from api.enums import ConversationProvider +from providers.interfaces import LlmInterface, llm_provider from services.printr import Printr +if TYPE_CHECKING: + from api.interface import WingmanConfig + printr = Printr() @@ -139,3 +145,16 @@ def get_available_models(self): if action == "generateContent": models.append(model) return models + + +@llm_provider(ConversationProvider.GOOGLE) +class GoogleLlm(LlmInterface): + def __init__(self, google_instance: "GoogleGenAI", config: "WingmanConfig"): + self._google = google_instance + self._config = config + + async def ask(self, messages, tools=None): + return self._google.ask( + messages=messages, tools=tools, + model=self._config.google.conversation_model, + ) diff --git a/providers/hume.py b/providers/hume.py index 902493e5..eacc597e 100644 --- a/providers/hume.py +++ b/providers/hume.py @@ -1,5 +1,6 @@ import base64 from os import path +from typing import TYPE_CHECKING import aiofiles from hume.client import AsyncHumeClient from hume.tts import ( @@ -7,17 +8,22 @@ PostedUtteranceVoiceWithId, PostedContextWithGenerationId, ) +from api.enums import TtsProvider from api.interface import ( HumeConfig, SoundConfig, VoiceInfo, WingmanInitializationError, ) +from providers.interfaces import TtsInterface, tts_provider from services.audio_player import AudioPlayer from services.file import get_writable_dir from services.printr import Printr from services.secret_keeper import SecretKeeper +if TYPE_CHECKING: + from api.interface import WingmanConfig + RECORDING_PATH = "audio_output" OUTPUT_FILE: str = "hume.mp3" @@ -106,3 +112,34 @@ async def __write_result_to_file(self, base64_encoded_audio: str): async with aiofiles.open(file_path, "wb") as f: await f.write(audio_data) return file_path + + +@tts_provider(TtsProvider.HUME) +class HumeTts(TtsInterface): + def __init__(self, hume_instance: "Hume", config: "WingmanConfig"): + self._hume = hume_instance + self._config = config + + async def play_audio(self, text, sound_config, audio_player, wingman_name): + try: + await self._hume.play_audio( + text=text, + config=self._config.hume, + sound_config=sound_config, + audio_player=audio_player, + wingman_name=wingman_name, + ) + except RuntimeError as e: + if "Event loop is closed" in str(e): + import asyncio + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + await self._hume.play_audio( + text=text, + config=self._config.hume, + sound_config=sound_config, + audio_player=audio_player, + wingman_name=wingman_name, + ) + else: + raise diff --git a/providers/inworld.py b/providers/inworld.py index 4a641c7d..dbe32768 100644 --- a/providers/inworld.py +++ b/providers/inworld.py @@ -1,24 +1,28 @@ import base64 import json from os import path -from typing import Optional +from typing import TYPE_CHECKING, Optional import threading import queue import time import requests import aiofiles -from api.enums import LogType +from api.enums import LogType, TtsProvider from api.interface import ( SoundConfig, VoiceInfo, WingmanInitializationError, InworldConfig, ) +from providers.interfaces import TtsInterface, tts_provider from services.audio_player import AudioPlayer from services.file import get_writable_dir from services.printr import Printr from services.secret_keeper import SecretKeeper +if TYPE_CHECKING: + from api.interface import WingmanConfig + RECORDING_PATH = "audio_output" OUTPUT_FILE: str = "inworld.mp3" @@ -278,3 +282,19 @@ async def __write_result_to_file(self, base64_encoded_audio: str): async with aiofiles.open(file_path, "wb") as f: await f.write(audio_data) return file_path + + +@tts_provider(TtsProvider.INWORLD) +class InworldTts(TtsInterface): + def __init__(self, inworld_instance: "Inworld", config: "WingmanConfig"): + self._inworld = inworld_instance + self._config = config + + async def play_audio(self, text, sound_config, audio_player, wingman_name): + await self._inworld.play_audio( + text=text, + config=self._config.inworld, + sound_config=sound_config, + audio_player=audio_player, + wingman_name=wingman_name, + ) diff --git a/providers/open_ai.py b/providers/open_ai.py index a177fadc..2c94baaa 100644 --- a/providers/open_ai.py +++ b/providers/open_ai.py @@ -1,11 +1,11 @@ from abc import ABC, abstractmethod import json import re -from typing import Literal, Mapping, Union +from typing import TYPE_CHECKING, Literal, Mapping, Union import httpx from openai import NOT_GIVEN, NotGiven, Omit, OpenAI, APIStatusError, AzureOpenAI import azure.cognitiveservices.speech as speechsdk -from api.enums import AzureRegion, LogType +from api.enums import AzureRegion, ConversationProvider, LogType, SttProvider, TtsProvider from api.interface import ( AzureInstanceConfig, @@ -14,10 +14,17 @@ SoundConfig, VoiceInfo, ) +from providers.interfaces import ( + SttInterface, TtsInterface, LlmInterface, + Transcript, stt_provider, tts_provider, llm_provider, +) from services.audio_player import AudioPlayer from services.openai_utils import get_minimal_reasoning_by_model from services.printr import Printr +if TYPE_CHECKING: + from api.interface import WingmanConfig + printr = Printr() @@ -580,3 +587,252 @@ def buffer_callback(audio_buffer): printr.toast_error( "An unknown OpenAI-compatible TTS error has occurred." ) + + +@stt_provider(SttProvider.OPENAI) +class OpenAiStt(SttInterface): + """OpenAI Whisper STT via the unified interface.""" + + def __init__(self, openai_instance: "OpenAi"): + self._openai = openai_instance + + async def transcribe(self, filename: str) -> Transcript | None: + result = self._openai.transcribe(filename=filename) + if result is None: + return None + return Transcript(text=result.text) + + +@tts_provider(TtsProvider.OPENAI) +class OpenAiTts(TtsInterface): + """OpenAI TTS via the unified interface.""" + + def __init__(self, openai_instance: "OpenAi", config: "WingmanConfig"): + self._openai = openai_instance + self._config = config + + async def play_audio(self, text, sound_config, audio_player, wingman_name): + await self._openai.play_audio( + text=text, + voice=self._config.openai.tts_voice, + model=self._config.openai.tts_model, + speed=self._config.openai.tts_speed, + sound_config=sound_config, + audio_player=audio_player, + wingman_name=wingman_name, + stream=self._config.openai.output_streaming, + ) + + +@llm_provider(ConversationProvider.OPENAI) +class OpenAiLlm(LlmInterface): + """OpenAI chat completions via the unified interface.""" + + def __init__(self, openai_instance: "OpenAi", config: "WingmanConfig"): + self._openai = openai_instance + self._config = config + + async def ask(self, messages, tools=None): + return self._openai.ask( + messages=messages, + tools=tools, + model=self._config.openai.conversation_model, + ) + + +@stt_provider(SttProvider.GROQ) +class GroqStt(SttInterface): + """Groq Whisper STT (uses OpenAi-compatible client).""" + + def __init__(self, openai_instance: "OpenAi"): + self._openai = openai_instance + + async def transcribe(self, filename: str) -> Transcript | None: + result = self._openai.transcribe( + filename=filename, model="whisper-large-v3-turbo" + ) + if result is None: + return None + return Transcript(text=result.text) + + +@llm_provider(ConversationProvider.MISTRAL) +class MistralLlm(LlmInterface): + def __init__(self, openai_instance: "OpenAi", config: "WingmanConfig"): + self._openai = openai_instance + self._config = config + + async def ask(self, messages, tools=None): + return self._openai.ask( + messages=messages, tools=tools, + model=self._config.mistral.conversation_model, + ) + + +@llm_provider(ConversationProvider.GROQ) +class GroqLlm(LlmInterface): + def __init__(self, openai_instance: "OpenAi", config: "WingmanConfig"): + self._openai = openai_instance + self._config = config + + async def ask(self, messages, tools=None): + return self._openai.ask( + messages=messages, tools=tools, + model=self._config.groq.conversation_model, + ) + + +@llm_provider(ConversationProvider.CEREBRAS) +class CerebrasLlm(LlmInterface): + def __init__(self, openai_instance: "OpenAi", config: "WingmanConfig"): + self._openai = openai_instance + self._config = config + + async def ask(self, messages, tools=None): + return self._openai.ask( + messages=messages, tools=tools, + model=self._config.cerebras.conversation_model, + ) + + +@llm_provider(ConversationProvider.OPENROUTER) +class OpenRouterLlm(LlmInterface): + def __init__(self, openai_instance: "OpenAi", config: "WingmanConfig", + supports_tools: bool = False): + self._openai = openai_instance + self._config = config + self.supports_tools = supports_tools + + async def ask(self, messages, tools=None): + effective_tools = tools if self.supports_tools else None + return self._openai.ask( + messages=messages, tools=effective_tools, + model=self._config.openrouter.conversation_model, + ) + + +@llm_provider(ConversationProvider.LOCAL_LLM) +class LocalLlm(LlmInterface): + def __init__(self, openai_instance: "OpenAi | None", config: "WingmanConfig"): + self._openai = openai_instance + self._config = config + + async def ask(self, messages, tools=None): + if not self._openai: + raise RuntimeError( + f"Local LLM provider is not initialized. " + f"Please check your Local LLM endpoint configuration " + f"({self._config.local_llm.endpoint})." + ) + return self._openai.ask( + messages=messages, tools=tools, + model=self._config.local_llm.conversation_model, + ) + + +@llm_provider(ConversationProvider.PERPLEXITY) +class PerplexityLlm(LlmInterface): + def __init__(self, openai_instance: "OpenAi", config: "WingmanConfig"): + self._openai = openai_instance + self._config = config + + async def ask(self, messages, tools=None): + return self._openai.ask( + messages=messages, tools=tools, + model=self._config.perplexity.conversation_model.value, + ) + + +@stt_provider(SttProvider.AZURE) +class AzureWhisperStt(SttInterface): + def __init__(self, azure_instance: "OpenAiAzure", api_key: str, config: "WingmanConfig"): + self._azure = azure_instance + self._api_key = api_key + self._config = config + + async def transcribe(self, filename: str) -> Transcript | None: + result = self._azure.transcribe_whisper( + filename=filename, + api_key=self._api_key, + config=self._config.azure.whisper, + ) + if result is None: + return None + return Transcript(text=result.text) + + +@stt_provider(SttProvider.AZURE_SPEECH) +class AzureSpeechStt(SttInterface): + def __init__(self, azure_instance: "OpenAiAzure", api_key: str, config: "WingmanConfig"): + self._azure = azure_instance + self._api_key = api_key + self._config = config + + async def transcribe(self, filename: str) -> Transcript | None: + result = self._azure.transcribe_azure_speech( + filename=filename, + api_key=self._api_key, + config=self._config.azure.stt, + ) + if result is None: + return None + text = result.get("_text") if isinstance(result, dict) else result.text + return Transcript(text=text) if text else None + + +@tts_provider(TtsProvider.AZURE) +class AzureTts(TtsInterface): + def __init__(self, azure_instance: "OpenAiAzure", api_key: str, config: "WingmanConfig"): + self._azure = azure_instance + self._api_key = api_key + self._config = config + + async def play_audio(self, text, sound_config, audio_player, wingman_name): + await self._azure.play_audio( + text=text, + api_key=self._api_key, + config=self._config.azure.tts, + sound_config=sound_config, + audio_player=audio_player, + wingman_name=wingman_name, + ) + + +@llm_provider(ConversationProvider.AZURE) +class AzureLlm(LlmInterface): + def __init__(self, azure_instance: "OpenAiAzure", api_key: str, config: "WingmanConfig"): + self._azure = azure_instance + self._api_key = api_key + self._config = config + + async def ask(self, messages, tools=None): + return self._azure.ask( + messages=messages, + api_key=self._api_key, + config=self._config.azure.conversation, + tools=tools, + ) + + +@tts_provider(TtsProvider.OPENAI_COMPATIBLE) +class OpenAiCompatibleTtsAdapter(TtsInterface): + def __init__(self, tts_instance: "OpenAiCompatibleTts", config: "WingmanConfig"): + self._tts = tts_instance + self._config = config + + async def play_audio(self, text, sound_config, audio_player, wingman_name): + from openai import NOT_GIVEN + await self._tts.play_audio( + text=text, + voice=self._config.openai_compatible_tts.voice, + model=self._config.openai_compatible_tts.model, + speed=( + self._config.openai_compatible_tts.speed + if self._config.openai_compatible_tts.speed + else NOT_GIVEN + ), + sound_config=sound_config, + audio_player=audio_player, + wingman_name=wingman_name, + stream=self._config.openai_compatible_tts.output_streaming, + ) diff --git a/providers/parakeet.py b/providers/parakeet.py index d07aa427..58850193 100644 --- a/providers/parakeet.py +++ b/providers/parakeet.py @@ -1,19 +1,23 @@ import gc import platform import threading -from typing import Optional +from typing import TYPE_CHECKING, Optional import requests -from api.enums import LogType +from api.enums import LogType, SttProvider from api.interface import ( ParakeetSettings, ParakeetSttConfig, ParakeetTranscript, WingmanInitializationError, ) +from providers.interfaces import SttInterface, Transcript, stt_provider from services.printr import Printr +if TYPE_CHECKING: + from api.interface import WingmanConfig + EXECUTION_PROVIDER_MAP = { "cpu": ["CPUExecutionProvider"], @@ -202,3 +206,21 @@ def _transcribe_remote(self, filename: str) -> Optional[ParakeetTranscript]: def validate(self, errors: list[WingmanInitializationError]): pass + + +@stt_provider(SttProvider.PARAKEET) +class ParakeetStt(SttInterface): + """Per-wingman adapter around the shared Parakeet singleton.""" + + def __init__(self, shared: "Parakeet", config: "WingmanConfig"): + self._shared = shared + self._config = config + + async def transcribe(self, filename: str) -> Transcript | None: + result = self._shared.transcribe( + config=self._config.parakeet, + filename=filename, + ) + if result is None: + return None + return Transcript(text=result.text) diff --git a/providers/pocket_tts.py b/providers/pocket_tts.py index e6de9874..5902a133 100644 --- a/providers/pocket_tts.py +++ b/providers/pocket_tts.py @@ -3,11 +3,11 @@ import sys import glob import asyncio -from typing import Optional +from typing import TYPE_CHECKING, Optional import torch import torchaudio from pocket_tts import TTSModel -from api.enums import LogType +from api.enums import LogType, TtsProvider from api.interface import ( PocketTTSConfig, SoundConfig, @@ -15,11 +15,15 @@ WingmanInitializationError, VoiceInfo, ) +from providers.interfaces import TtsInterface, tts_provider from services.file import get_custom_voices_dir from services.audio_player import AudioPlayer from services.printr import Printr from providers.open_ai import OpenAiCompatibleTts +if TYPE_CHECKING: + from api.interface import WingmanConfig + MODELS_DIR = "pocket-tts-models" POCKET_TTS_VOICES_DIR = "embeddings" @@ -606,3 +610,21 @@ def _get_wingman_included_voices_dir(self) -> str: # Determine path to wingman included voices directory app_dir = self._get_app_dir() return os.path.join(app_dir, INCLUDED_VOICES_DIR) + + +@tts_provider(TtsProvider.POCKET_TTS) +class PocketTtsTts(TtsInterface): + """Per-wingman adapter around the shared PocketTTS singleton.""" + + def __init__(self, shared: "PocketTTS", config: "WingmanConfig"): + self._shared = shared + self._config = config + + async def play_audio(self, text, sound_config, audio_player, wingman_name): + await self._shared.play_audio( + text=text, + config=self._config.pocket_tts, + sound_config=sound_config, + audio_player=audio_player, + wingman_name=wingman_name, + ) diff --git a/providers/whispercpp.py b/providers/whispercpp.py index 81e3b313..3041ef28 100644 --- a/providers/whispercpp.py +++ b/providers/whispercpp.py @@ -1,13 +1,18 @@ +from typing import TYPE_CHECKING import requests -from api.enums import LogType +from api.enums import LogType, SttProvider from api.interface import ( WhispercppSettings, WhispercppSttConfig, WhispercppTranscript, WingmanInitializationError, ) +from providers.interfaces import SttInterface, Transcript, stt_provider from services.printr import Printr +if TYPE_CHECKING: + from api.interface import WingmanConfig + class Whispercpp: def __init__( @@ -91,3 +96,21 @@ def __is_server_running(self, timeout=5): return response.ok except Exception: return False + + +@stt_provider(SttProvider.WHISPERCPP) +class WhispercppStt(SttInterface): + """Per-wingman adapter around the shared Whispercpp singleton.""" + + def __init__(self, shared: "Whispercpp", config: "WingmanConfig"): + self._shared = shared + self._config = config + + async def transcribe(self, filename: str) -> Transcript | None: + result = self._shared.transcribe( + filename=filename, + config=self._config.whispercpp, + ) + if result is None: + return None + return Transcript(text=result.text) diff --git a/providers/x_ai.py b/providers/x_ai.py index 1706a1d9..d02f38c7 100644 --- a/providers/x_ai.py +++ b/providers/x_ai.py @@ -1,6 +1,12 @@ +from typing import TYPE_CHECKING from openai import OpenAI, APIStatusError +from api.enums import ConversationProvider +from providers.interfaces import LlmInterface, llm_provider from providers.open_ai import OpenAi +if TYPE_CHECKING: + from api.interface import WingmanConfig + class XAi(OpenAi): def _perform_ask( @@ -48,3 +54,16 @@ def _fix_tools(self, tools: list[dict[str, any]]) -> list[dict[str, any]]: } fixed_tools.append(fixed_tool) return fixed_tools + + +@llm_provider(ConversationProvider.XAI) +class XAiLlm(LlmInterface): + def __init__(self, xai_instance: "XAi", config: "WingmanConfig"): + self._xai = xai_instance + self._config = config + + async def ask(self, messages, tools=None): + return self._xai.ask( + messages=messages, tools=tools, + model=self._config.xai.conversation_model, + ) diff --git a/providers/xvasynth.py b/providers/xvasynth.py index a8939296..1c2815da 100644 --- a/providers/xvasynth.py +++ b/providers/xvasynth.py @@ -3,13 +3,18 @@ import platform import subprocess import time +from typing import TYPE_CHECKING import requests -from api.enums import LogType +from api.enums import LogType, TtsProvider from api.interface import XVASynthSettings, XVASynthTtsConfig, SoundConfig +from providers.interfaces import TtsInterface, tts_provider from services.audio_player import AudioPlayer from services.file import get_writable_dir from services.printr import Printr +if TYPE_CHECKING: + from api.interface import WingmanConfig + RECORDING_PATH = "audio_output" OUTPUT_FILE = "xvasynth.wav" SYNTHESIZE_URL = "synthesize" @@ -214,3 +219,21 @@ def __is_server_running(self, timeout=10): return response.ok except Exception: return False + + +@tts_provider(TtsProvider.XVASYNTH) +class XVASynthTts(TtsInterface): + """Per-wingman adapter around the shared XVASynth singleton.""" + + def __init__(self, shared: "XVASynth", config: "WingmanConfig"): + self._shared = shared + self._config = config + + async def play_audio(self, text, sound_config, audio_player, wingman_name): + await self._shared.play_audio( + text=text, + config=self._config.xvasynth, + sound_config=sound_config, + audio_player=audio_player, + wingman_name=wingman_name, + ) From 2c556a00a452efc3a4f4c9791c4708af1fe5224c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Hopst=C3=A4tter?= Date: Thu, 9 Apr 2026 10:56:37 +0200 Subject: [PATCH 03/34] feat: rename WingmanPro to WingmanSubscription, add STT/TTS/LLM adapters Co-Authored-By: Claude Sonnet 4.6 --- ...wingman_pro.py => wingman_subscription.py} | 102 +++++++++++++++++- services/voice_service.py | 12 +-- wingman_core.py | 4 +- wingmen/open_ai_wingman.py | 6 +- 4 files changed, 110 insertions(+), 14 deletions(-) rename providers/{wingman_pro.py => wingman_subscription.py} (82%) diff --git a/providers/wingman_pro.py b/providers/wingman_subscription.py similarity index 82% rename from providers/wingman_pro.py rename to providers/wingman_subscription.py index 25d2a724..b162e3a7 100644 --- a/providers/wingman_pro.py +++ b/providers/wingman_subscription.py @@ -1,8 +1,14 @@ -from typing import Optional +from typing import TYPE_CHECKING, Optional import openai import requests from openai.types.audio import Transcription -from api.enums import CommandTag, LogType +from api.enums import ( + CommandTag, + ConversationProvider, + LogType, + SttProvider, + TtsProvider, +) from api.interface import ( AzureSttConfig, AzureTtsConfig, @@ -11,13 +17,25 @@ VoiceInfo, WingmanProSettings, ) +from providers.interfaces import ( + SttInterface, + TtsInterface, + LlmInterface, + Transcript, + stt_provider, + tts_provider, + llm_provider, +) from services.audio_player import AudioPlayer from services.openai_utils import get_minimal_reasoning_by_model from services.printr import Printr from services.secret_keeper import SecretKeeper +if TYPE_CHECKING: + from api.interface import WingmanConfig + -class WingmanPro: +class WingmanSubscription: def __init__( self, wingman_name: str, settings: WingmanProSettings, timeout: int = 120 ): @@ -470,3 +488,81 @@ def __remove_nones(self, obj): ) else: return obj + + +# --------------------------------------------------------------------------- +# Adapter classes — bridge WingmanSubscription into unified provider interfaces +# --------------------------------------------------------------------------- + + +@stt_provider(SttProvider.WINGMAN_PRO) +class WingmanSubscriptionStt(SttInterface): + def __init__(self, ws_instance: "WingmanSubscription", config: "WingmanConfig"): + self._ws = ws_instance + self._config = config + + async def transcribe(self, filename: str) -> Transcript | None: + from api.enums import WingmanProSttProvider + if self._config.wingman_pro.stt_provider == WingmanProSttProvider.WHISPER: + result = self._ws.transcribe_whisper(filename=filename) + elif self._config.wingman_pro.stt_provider == WingmanProSttProvider.AZURE_SPEECH: + result = self._ws.transcribe_azure_speech( + filename=filename, config=self._config.azure.stt + ) + else: + return None + if result is None: + return None + # WingmanSubscription might return a dict instead of a real transcript object + text = result.get("_text") if isinstance(result, dict) else result.text + return Transcript(text=text) if text else None + + +@tts_provider(TtsProvider.WINGMAN_PRO) +class WingmanSubscriptionTts(TtsInterface): + def __init__(self, ws_instance: "WingmanSubscription", config: "WingmanConfig"): + self._ws = ws_instance + self._config = config + + async def play_audio(self, text, sound_config, audio_player, wingman_name): + from api.enums import WingmanProTtsProvider + if self._config.wingman_pro.tts_provider == WingmanProTtsProvider.OPENAI: + await self._ws.generate_openai_speech( + text=text, + voice=self._config.openai.tts_voice, + model=self._config.openai.tts_model, + speed=self._config.openai.tts_speed, + sound_config=sound_config, + audio_player=audio_player, + wingman_name=wingman_name, + ) + elif self._config.wingman_pro.tts_provider == WingmanProTtsProvider.AZURE: + await self._ws.generate_azure_speech( + text=text, + config=self._config.azure.tts, + sound_config=sound_config, + audio_player=audio_player, + wingman_name=wingman_name, + ) + elif self._config.wingman_pro.tts_provider == WingmanProTtsProvider.INWORLD: + await self._ws.generate_inworld_speech( + text=text, + config=self._config.inworld, + sound_config=sound_config, + audio_player=audio_player, + wingman_name=wingman_name, + ) + + +@llm_provider(ConversationProvider.WINGMAN_PRO) +class WingmanSubscriptionLlm(LlmInterface): + def __init__(self, ws_instance: "WingmanSubscription", config: "WingmanConfig"): + self._ws = ws_instance + self._config = config + + async def ask(self, messages, tools=None): + return self._ws.ask( + messages=messages, + deployment=self._config.wingman_pro.conversation_deployment, + tools=tools, + ) diff --git a/services/voice_service.py b/services/voice_service.py index 27cc5582..63415615 100644 --- a/services/voice_service.py +++ b/services/voice_service.py @@ -18,7 +18,7 @@ from providers.hume import Hume from providers.inworld import Inworld from providers.open_ai import OpenAi, OpenAiAzure, OpenAiCompatibleTts -from providers.wingman_pro import WingmanPro +from providers.wingman_subscription import WingmanSubscription from providers.xvasynth import XVASynth from providers.pocket_tts import PocketTTS @@ -245,7 +245,7 @@ def get_azure_voices(self, api_key: str, region: AzureRegion, locale: str = ""): # GET /voices/azure/wingman-pro def get_wingman_pro_azure_voices(self, locale: str = ""): - wingman_pro = WingmanPro( + wingman_pro = WingmanSubscription( wingman_name="", settings=self.config_manager.settings_config.wingman_pro ) voices = wingman_pro.get_available_voices(locale=locale) @@ -258,7 +258,7 @@ def get_wingman_pro_azure_voices(self, locale: str = ""): def get_wingman_pro_inworld_voices( self, filter_language: str = None ) -> list[VoiceInfo]: - wingman_pro = WingmanPro( + wingman_pro = WingmanSubscription( wingman_name="", settings=self.config_manager.settings_config.wingman_pro ) voices = wingman_pro.get_available_inworld_voices( @@ -416,7 +416,7 @@ async def play_pocket_tts( async def play_wingman_pro_azure( self, text: str, config: AzureTtsConfig, sound_config: SoundConfig ): - wingman_pro = WingmanPro( + wingman_pro = WingmanSubscription( wingman_name="system", settings=self.config_manager.settings_config.wingman_pro, ) @@ -432,7 +432,7 @@ async def play_wingman_pro_azure( async def play_wingman_pro_openai( self, text: str, voice: str, model: str, speed: float, sound_config: SoundConfig ): - wingman_pro = WingmanPro( + wingman_pro = WingmanSubscription( wingman_name="system", settings=self.config_manager.settings_config.wingman_pro, ) @@ -453,7 +453,7 @@ async def play_wingman_pro_inworld( config: InworldConfig, sound_config: SoundConfig, ): - wingman_pro = WingmanPro( + wingman_pro = WingmanSubscription( wingman_name="system", settings=self.config_manager.settings_config.wingman_pro, ) diff --git a/wingman_core.py b/wingman_core.py index 77715824..e5d0e366 100644 --- a/wingman_core.py +++ b/wingman_core.py @@ -59,7 +59,7 @@ from providers.llama_cpp_remote import LlamaCppRemote from providers.open_ai import OpenAi from providers.whispercpp import Whispercpp -from providers.wingman_pro import WingmanPro +from providers.wingman_subscription import WingmanSubscription from providers.xvasynth import XVASynth from providers.pocket_tts import PocketTTS from wingmen.open_ai_wingman import OpenAiWingman @@ -1464,7 +1464,7 @@ def run_async_process(): text = None if provider == VoiceActivationSttProvider.WINGMAN_PRO: - wingman_pro = WingmanPro( + wingman_pro = WingmanSubscription( wingman_name="system", settings=self.settings_service.settings.wingman_pro, ) diff --git a/wingmen/open_ai_wingman.py b/wingmen/open_ai_wingman.py index 3203b96b..c7832ef3 100644 --- a/wingmen/open_ai_wingman.py +++ b/wingmen/open_ai_wingman.py @@ -43,7 +43,7 @@ from providers.inworld import Inworld from providers.open_ai import OpenAi, OpenAiAzure from providers.x_ai import XAi -from providers.wingman_pro import WingmanPro +from providers.wingman_subscription import WingmanSubscription from api.commands import McpStateChangedCommand from services.benchmark import Benchmark from services.file import get_prompt @@ -94,7 +94,7 @@ def __init__(self, *args, **kwargs): self.openai_compatible_tts: OpenAiCompatibleTts | None = None self.hume: Hume | None = None self.inworld: Inworld | None = None - self.wingman_pro: WingmanPro | None = None + self.wingman_pro: WingmanSubscription | None = None self.google: GoogleGenAI | None = None self.perplexity: OpenAi | None = None self.xai: XAi | None = None @@ -901,7 +901,7 @@ async def validate_and_set_azure(self, errors: list[WingmanInitializationError]) self.openai_azure = OpenAiAzure() async def validate_and_set_wingman_pro(self): - self.wingman_pro = WingmanPro( + self.wingman_pro = WingmanSubscription( wingman_name=self.name, settings=self.settings.wingman_pro ) From 9af3436c94e28143d565316355be60cb58407157 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Hopst=C3=A4tter?= Date: Thu, 9 Apr 2026 10:57:33 +0200 Subject: [PATCH 04/34] =?UTF-8?q?milestone:=20M2=20complete=20=E2=80=94=20?= =?UTF-8?q?all=20provider=20adapters=20created?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit From a29afc72b6642f6f15f109640bf63098da8ad92d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Hopst=C3=A4tter?= Date: Thu, 9 Apr 2026 11:00:32 +0200 Subject: [PATCH 05/34] =?UTF-8?q?feat:=20add=20ProviderFactory=20=E2=80=94?= =?UTF-8?q?=20creates=20providers=20from=20config=20using=20decorator=20re?= =?UTF-8?q?gistry?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Sonnet 4.6 --- services/provider_factory.py | 328 +++++++++++++++++++++++++++++++++++ 1 file changed, 328 insertions(+) create mode 100644 services/provider_factory.py diff --git a/services/provider_factory.py b/services/provider_factory.py new file mode 100644 index 00000000..404de76f --- /dev/null +++ b/services/provider_factory.py @@ -0,0 +1,328 @@ +"""Factory for creating provider instances from config. + +Reads the config enum values, looks up the registered adapter class from +the decorator registry, retrieves API keys via SecretKeeper, and instantiates +the provider. Each provider holds a reference to the live config object. +""" + +import traceback +from typing import TYPE_CHECKING + +from api.enums import ( + ConversationProvider, + ImageGenerationProvider, + SttProvider, + TtsProvider, + WingmanInitializationErrorType, +) +from api.interface import WingmanInitializationError +from providers.interfaces import ( + LlmInterface, + SttInterface, + TtsInterface, + Validatable, + get_llm_class, + get_stt_class, + get_tts_class, +) +from services.printr import Printr + +if TYPE_CHECKING: + from api.interface import SettingsConfig, WingmanConfig + from services.secret_keeper import SecretKeeper + +printr = Printr() + + +# Import all provider modules so their decorators run and populate the registries. +# These imports have no other side effects. +import providers.faster_whisper # noqa: F401 +import providers.parakeet # noqa: F401 +import providers.whispercpp # noqa: F401 +import providers.open_ai # noqa: F401 +import providers.google # noqa: F401 +import providers.x_ai # noqa: F401 +import providers.elevenlabs # noqa: F401 +import providers.edge # noqa: F401 +import providers.hume # noqa: F401 +import providers.inworld # noqa: F401 +import providers.pocket_tts # noqa: F401 +import providers.xvasynth # noqa: F401 +import providers.wingman_subscription # noqa: F401 + + +class ProviderFactory: + """Creates STT, TTS, and LLM provider instances from config.""" + + def __init__( + self, + config: "WingmanConfig", + settings: "SettingsConfig", + secret_keeper: "SecretKeeper", + shared_providers: dict, + wingman_name: str, + ): + self._config = config + self._settings = settings + self._secret_keeper = secret_keeper + self._shared = shared_providers + self._wingman_name = wingman_name + + async def _retrieve_secret( + self, requester: str, errors: list[WingmanInitializationError] + ) -> str | None: + """Retrieve an API key, adding to errors if missing.""" + secret = await self._secret_keeper.retrieve( + requester=requester, + key=requester, + prompt_if_missing=True, + ) + if not secret: + errors.append( + WingmanInitializationError( + wingman_name=self._wingman_name, + message=f"Missing API key for '{requester}'.", + error_type=WingmanInitializationErrorType.MISSING_SECRET, + ) + ) + return secret + + async def create_stt( + self, errors: list[WingmanInitializationError] + ) -> SttInterface | None: + """Create the STT provider from config.""" + stt_enum = self._config.features.stt_provider + # Shared singleton providers — wrap in adapter + if stt_enum == SttProvider.FASTER_WHISPER: + from providers.faster_whisper import FasterWhisperStt + return FasterWhisperStt( + shared=self._shared["fasterwhisper"], + config=self._config, + wingman_name=self._wingman_name, + ) + elif stt_enum == SttProvider.PARAKEET: + from providers.parakeet import ParakeetStt + return ParakeetStt(shared=self._shared["parakeet"], config=self._config) + elif stt_enum == SttProvider.WHISPERCPP: + from providers.whispercpp import WhispercppStt + return WhispercppStt(shared=self._shared["whispercpp"], config=self._config) + elif stt_enum == SttProvider.OPENAI: + api_key = await self._retrieve_secret("openai", errors) + if not api_key: + return None + from providers.open_ai import OpenAi, OpenAiStt + openai = OpenAi(api_key=api_key, organization=self._config.openai.organization) + return OpenAiStt(openai_instance=openai) + elif stt_enum == SttProvider.GROQ: + api_key = await self._retrieve_secret("groq", errors) + if not api_key: + return None + from providers.open_ai import OpenAi, GroqStt + groq = OpenAi(api_key=api_key, base_url=self._config.groq.endpoint) + return GroqStt(openai_instance=groq) + elif stt_enum == SttProvider.AZURE: + api_key = await self._retrieve_secret("azure", errors) + if not api_key: + return None + from providers.open_ai import OpenAiAzure, AzureWhisperStt + return AzureWhisperStt( + azure_instance=OpenAiAzure(), api_key=api_key, config=self._config + ) + elif stt_enum == SttProvider.AZURE_SPEECH: + api_key = await self._retrieve_secret("azure", errors) + if not api_key: + return None + from providers.open_ai import OpenAiAzure, AzureSpeechStt + return AzureSpeechStt( + azure_instance=OpenAiAzure(), api_key=api_key, config=self._config + ) + elif stt_enum == SttProvider.WINGMAN_PRO: + from providers.wingman_subscription import WingmanSubscription, WingmanSubscriptionStt + ws = WingmanSubscription( + wingman_name=self._wingman_name, + settings=self._settings.wingman_pro, + ) + return WingmanSubscriptionStt(ws_instance=ws, config=self._config) + return None + + async def create_tts( + self, errors: list[WingmanInitializationError] + ) -> TtsInterface | None: + """Create the TTS provider from config.""" + tts_enum = self._config.features.tts_provider + if tts_enum == TtsProvider.EDGE_TTS: + from providers.edge import EdgeTts + return EdgeTts(config=self._config) + elif tts_enum == TtsProvider.ELEVENLABS: + api_key = await self._retrieve_secret("elevenlabs", errors) + if not api_key: + return None + from providers.elevenlabs import ElevenLabs, ElevenLabsTts + elevenlabs = ElevenLabs(api_key=api_key, wingman_name=self._wingman_name) + return ElevenLabsTts(elevenlabs_instance=elevenlabs, config=self._config) + elif tts_enum == TtsProvider.HUME: + api_key = await self._retrieve_secret("hume", errors) + if not api_key: + return None + from providers.hume import Hume, HumeTts + hume = Hume(api_key=api_key, wingman_name=self._wingman_name) + return HumeTts(hume_instance=hume, config=self._config) + elif tts_enum == TtsProvider.INWORLD: + api_key = await self._retrieve_secret("inworld", errors) + if not api_key: + return None + from providers.inworld import Inworld, InworldTts + inworld = Inworld(api_key=api_key, wingman_name=self._wingman_name) + return InworldTts(inworld_instance=inworld, config=self._config) + elif tts_enum == TtsProvider.OPENAI: + api_key = await self._retrieve_secret("openai", errors) + if not api_key: + return None + from providers.open_ai import OpenAi, OpenAiTts + openai = OpenAi(api_key=api_key, organization=self._config.openai.organization) + return OpenAiTts(openai_instance=openai, config=self._config) + elif tts_enum == TtsProvider.OPENAI_COMPATIBLE: + api_key = await self._retrieve_secret("openai_compatible", errors) + # api_key might be optional for local endpoints + from providers.open_ai import OpenAiCompatibleTts, OpenAiCompatibleTtsAdapter + tts = OpenAiCompatibleTts( + api_key=api_key or "", + base_url=self._config.openai_compatible_tts.endpoint, + ) + return OpenAiCompatibleTtsAdapter(tts_instance=tts, config=self._config) + elif tts_enum == TtsProvider.AZURE: + api_key = await self._retrieve_secret("azure", errors) + if not api_key: + return None + from providers.open_ai import OpenAiAzure, AzureTts + return AzureTts( + azure_instance=OpenAiAzure(), api_key=api_key, config=self._config + ) + elif tts_enum == TtsProvider.XVASYNTH: + from providers.xvasynth import XVASynthTts + return XVASynthTts(shared=self._shared["xvasynth"], config=self._config) + elif tts_enum == TtsProvider.POCKET_TTS: + from providers.pocket_tts import PocketTtsTts + return PocketTtsTts(shared=self._shared["pocket_tts"], config=self._config) + elif tts_enum == TtsProvider.WINGMAN_PRO: + from providers.wingman_subscription import WingmanSubscription, WingmanSubscriptionTts + ws = WingmanSubscription( + wingman_name=self._wingman_name, + settings=self._settings.wingman_pro, + ) + return WingmanSubscriptionTts(ws_instance=ws, config=self._config) + return None + + async def create_llm( + self, errors: list[WingmanInitializationError] + ) -> LlmInterface | None: + """Create the LLM provider from config.""" + llm_enum = self._config.features.conversation_provider + if llm_enum == ConversationProvider.OPENAI: + api_key = await self._retrieve_secret("openai", errors) + if not api_key: + return None + from providers.open_ai import OpenAi, OpenAiLlm + openai = OpenAi(api_key=api_key, organization=self._config.openai.organization) + return OpenAiLlm(openai_instance=openai, config=self._config) + elif llm_enum == ConversationProvider.MISTRAL: + api_key = await self._retrieve_secret("mistral", errors) + if not api_key: + return None + from providers.open_ai import OpenAi, MistralLlm + mistral = OpenAi(api_key=api_key, base_url=self._config.mistral.endpoint) + return MistralLlm(openai_instance=mistral, config=self._config) + elif llm_enum == ConversationProvider.GROQ: + api_key = await self._retrieve_secret("groq", errors) + if not api_key: + return None + from providers.open_ai import OpenAi, GroqLlm + groq = OpenAi(api_key=api_key, base_url=self._config.groq.endpoint) + return GroqLlm(openai_instance=groq, config=self._config) + elif llm_enum == ConversationProvider.CEREBRAS: + api_key = await self._retrieve_secret("cerebras", errors) + if not api_key: + return None + from providers.open_ai import OpenAi, CerebrasLlm + cerebras = OpenAi(api_key=api_key, base_url=self._config.cerebras.endpoint) + return CerebrasLlm(openai_instance=cerebras, config=self._config) + elif llm_enum == ConversationProvider.GOOGLE: + api_key = await self._retrieve_secret("google", errors) + if not api_key: + return None + from providers.google import GoogleGenAI, GoogleLlm + google = GoogleGenAI(api_key=api_key) + return GoogleLlm(google_instance=google, config=self._config) + elif llm_enum == ConversationProvider.OPENROUTER: + api_key = await self._retrieve_secret("openrouter", errors) + if not api_key: + return None + from providers.open_ai import OpenAi, OpenRouterLlm + openrouter = OpenAi(api_key=api_key, base_url=self._config.openrouter.endpoint) + supports_tools = await self._check_openrouter_tool_support(api_key) + return OpenRouterLlm( + openai_instance=openrouter, config=self._config, + supports_tools=supports_tools, + ) + elif llm_enum == ConversationProvider.LOCAL_LLM: + from providers.open_ai import OpenAi, LocalLlm + local_llm = None + if self._config.local_llm.endpoint: + local_llm = OpenAi( + api_key="not-needed", + base_url=self._config.local_llm.endpoint, + ) + return LocalLlm(openai_instance=local_llm, config=self._config) + elif llm_enum == ConversationProvider.AZURE: + api_key = await self._retrieve_secret("azure", errors) + if not api_key: + return None + from providers.open_ai import OpenAiAzure, AzureLlm + return AzureLlm( + azure_instance=OpenAiAzure(), api_key=api_key, config=self._config + ) + elif llm_enum == ConversationProvider.WINGMAN_PRO: + from providers.wingman_subscription import WingmanSubscription, WingmanSubscriptionLlm + ws = WingmanSubscription( + wingman_name=self._wingman_name, + settings=self._settings.wingman_pro, + ) + return WingmanSubscriptionLlm(ws_instance=ws, config=self._config) + elif llm_enum == ConversationProvider.PERPLEXITY: + api_key = await self._retrieve_secret("perplexity", errors) + if not api_key: + return None + from providers.open_ai import OpenAi, PerplexityLlm + perplexity = OpenAi(api_key=api_key, base_url=self._config.perplexity.endpoint) + return PerplexityLlm(openai_instance=perplexity, config=self._config) + elif llm_enum == ConversationProvider.XAI: + api_key = await self._retrieve_secret("xai", errors) + if not api_key: + return None + from providers.x_ai import XAi, XAiLlm + xai = XAi(api_key=api_key, base_url=self._config.xai.endpoint) + return XAiLlm(xai_instance=xai, config=self._config) + return None + + async def _check_openrouter_tool_support(self, api_key: str) -> bool: + """Check if the configured OpenRouter model supports tools. + + Replicates the logic from OpenAiWingman.validate_and_set_openrouter(). + """ + try: + import requests + model = self._config.openrouter.conversation_model + response = requests.get( + f"https://openrouter.ai/api/v1/models/{model}", + headers={"Authorization": f"Bearer {api_key}"}, + timeout=10, + ) + if response.status_code == 200: + result = response.json() + supported_params = result.get("data", {}).get( + "supported_parameters", [] + ) + return "tools" in supported_params + except Exception: + pass + return False From 62dd24538c3df9800c7088df2737ca404c9c372f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Hopst=C3=A4tter?= Date: Thu, 9 Apr 2026 11:05:37 +0200 Subject: [PATCH 06/34] feat: extract ConversationManager from OpenAiWingman Move all message-list management (add/update/trim tool responses, user/ assistant message helpers, history cleanup, token estimation, reset, and text-extraction utilities) into a standalone ConversationManager class. The original OpenAiWingman is unchanged; wiring happens in M5. Co-Authored-By: Claude Opus 4.6 --- services/conversation_manager.py | 557 +++++++++++++++++++++++++++++++ 1 file changed, 557 insertions(+) create mode 100644 services/conversation_manager.py diff --git a/services/conversation_manager.py b/services/conversation_manager.py new file mode 100644 index 00000000..e4293d6b --- /dev/null +++ b/services/conversation_manager.py @@ -0,0 +1,557 @@ +"""Manages the conversation message list, tool responses, and history cleanup.""" + +import json +import random +import uuid +from typing import TYPE_CHECKING, Callable, Mapping, Optional + +from openai.types.chat import ( + ChatCompletionMessage, + ChatCompletionMessageToolCall, + ParsedFunction, +) + +from api.enums import ConversationProvider, LogType, LogSource +from services.printr import Printr +from services.token_utils import count_tokens, truncate_to_tokens + +if TYPE_CHECKING: + from api.interface import CommandConfig, SettingsConfig, WingmanConfig + from skills.skill_base import Skill + +printr = Printr() + + +class ConversationManager: + """Owns the conversation message list, tool-response bookkeeping, and + history cleanup / trimming logic. + + Extracted from ``OpenAiWingman`` — all behaviour is identical, only the + home module changed. + """ + + def __init__( + self, + config: "WingmanConfig", + settings: "SettingsConfig", + wingman_name: str, + ): + self._config = config + self._settings = settings + self._wingman_name = wingman_name + self.messages: list = [] + self.pending_tool_calls: list[str] = [] + self.conversation_summary: str = "" + + # ------------------------------------------------------------------ + # GPT / assistant response helpers + # ------------------------------------------------------------------ + + async def add_gpt_response( + self, + message, + tool_calls, + skills: list["Skill"] = None, + skill_registry=None, + tool_skills: dict = None, + ) -> tuple[bool, bool]: + """Adds a message from GPT to the conversation history as well as + adding dummy tool responses for any tool calls. + + Args: + message (dict | ChatCompletionMessage): The message to add. + tool_calls (list): The tool calls associated with the message. + skills: List of active skills for hook invocation. + skill_registry: The skill registry (used for ``is_meta_tool``). + tool_skills: Mapping of function-name -> Skill instance. + """ + # call skill hooks (only for prepared/activated skills) + for skill in (skills or []): + if skill.is_prepared: + await skill.on_add_assistant_message( + message.content, message.tool_calls + ) + + # do not tamper with this message as it will lead to 400 errors! + self.messages.append(message) + + # adding dummy tool responses to prevent corrupted message history on parallel requests + # and checks if waiting response should be played + unique_tools = {} + is_waiting_response_needed = False + is_summarize_needed = False + + if tool_calls: + for tool_call in tool_calls: + if not tool_call.id: + continue + # adding a dummy tool response to get updated later + self.add_tool_response(tool_call, "Loading..", False) + + function_name = tool_call.function.name + + # Meta-tools (search_skills, activate_skill, etc.) always need a follow-up + # LLM call so it can use the newly activated tools + if skill_registry and skill_registry.is_meta_tool(function_name): + is_summarize_needed = True + elif tool_skills and function_name in tool_skills: + skill = tool_skills[function_name] + if await skill.is_waiting_response_needed(function_name): + is_waiting_response_needed = True + if await skill.is_summarize_needed(function_name): + is_summarize_needed = True + + unique_tools[function_name] = True + + if len(unique_tools) == 1 and "execute_command" in unique_tools: + is_waiting_response_needed = True + + return is_waiting_response_needed, is_summarize_needed + + # ------------------------------------------------------------------ + # Tool response management + # ------------------------------------------------------------------ + + def add_tool_response( + self, tool_call, response: str, completed: bool = True + ): + """Adds a tool response to the conversation history. + + Args: + tool_call (dict|ChatCompletionMessageToolCall): The tool call to add the dummy response for. + response (str): The response content. + completed (bool): Whether the tool call is complete. + """ + msg = {"role": "tool", "content": response} + if tool_call.id is not None: + msg["tool_call_id"] = tool_call.id + if tool_call.function.name is not None: + msg["name"] = tool_call.function.name + self.messages.append(msg) + + if tool_call.id and not completed: + self.pending_tool_calls.append(tool_call.id) + + async def update_tool_response(self, tool_call_id, response) -> bool: + """Updates a tool response in the conversation history. + + Args: + tool_call_id (str): The identifier of the tool call to update the response for. + response (str): The new response to set. + + Returns: + bool: True if the response was updated, False if the tool call was not found. + """ + if not tool_call_id: + return False + + index = len(self.messages) + + # go through message history to find and update the tool call + for message in reversed(self.messages): + index -= 1 + if ( + self.get_message_role(message) == "tool" + and message.get("tool_call_id") == tool_call_id + ): + message["content"] = str(response) + if tool_call_id in self.pending_tool_calls: + self.pending_tool_calls.remove(tool_call_id) + return True + + return False + + async def trim_tool_responses( + self, + max_tokens: int = 500, + is_condensing: bool = False, + ): + """Trim oversized tool responses in conversation history. + + Called after the LLM has finished processing a turn with tool calls. + The LLM already had full access to the data; this just prevents stale + bulk data from inflating the context on subsequent turns. + + If significant trimming occurs, broadcasts a condensation notification + so the client UI can display a summary indicator. + + Args: + max_tokens: Maximum token count per tool response before trimming. + is_condensing: Whether condensation is currently running (suppresses + broadcast to avoid interfering with its own cycle). + """ + total_tokens_saved = 0 + for msg in self.messages: + if self.get_message_role(msg) != "tool": + continue + content = msg.get("content", "") + if not content: + continue + token_count = count_tokens(content) + if token_count <= max_tokens: + continue + total_tokens_saved += token_count - max_tokens + trimmed = truncate_to_tokens(content, max_tokens) + msg["content"] = ( + f"{trimmed}\n\n[...trimmed from ~{token_count} to " + f"~{max_tokens} tokens for conversation history. " + f"Full response was processed.]" + ) + + # Notify the client when significant trimming occurs so the UI can + # show a "Show history" indicator explaining the token drop. + # Skip if condensation is already running to avoid interfering with + # its own started/finished broadcast cycle. + if ( + total_tokens_saved > 1000 + and printr._connection_manager + and not is_condensing + ): + from api.commands import ConversationCondensationCommand + + if self.conversation_summary: + summary = ( + f"{self.conversation_summary}\n\n---\n\n" + f"[Latest turn: tool responses trimmed — ~{total_tokens_saved:,} tokens saved]" + ) + else: + summary = ( + f"Tool responses were automatically trimmed after LLM processing.\n" + f"~{total_tokens_saved:,} tokens saved.\n\n" + f"The LLM had full access to the complete data when generating " + f"its response. Responses are trimmed afterwards to keep the " + f"conversation context efficient." + ) + + await printr._connection_manager.broadcast( + ConversationCondensationCommand( + wingman_name=self._wingman_name, + status="finished", + estimated_tokens_saved=total_tokens_saved, + summary_text=summary, + ) + ) + + # ------------------------------------------------------------------ + # User / assistant message management + # ------------------------------------------------------------------ + + async def add_user_message( + self, + content: str, + images: list[tuple[str, str]] = None, + skills: list["Skill"] = None, + condense_fn: Optional[Callable] = None, + ): + """Shortens the conversation history if needed and adds a user message to it. + + Args: + content (str): The message content to add. + images (list[tuple[str, str]]): Optional list of (base64_data, mime_type) tuples to attach. + skills: List of active skills for hook invocation. + condense_fn: Optional async callable invoked after cleanup (``_maybe_condense_history``). + """ + # call skill hooks (only for prepared/activated skills) + for skill in (skills or []): + if skill.is_prepared: + await skill.on_add_user_message(content) + + if images: + msg_content = [] + for img_b64, mime in images: + msg_content.append({ + "type": "image_url", + "image_url": { + "url": f"data:{mime};base64,{img_b64}", + "detail": "auto", + }, + }) + msg_content.append({"type": "text", "text": content}) + msg = {"role": "user", "content": msg_content} + else: + msg = {"role": "user", "content": content} + await self.cleanup_history() + if condense_fn: + await condense_fn() + self.messages.append(msg) + + async def add_assistant_message( + self, content: str, skills: list["Skill"] = None + ): + """Adds an assistant message to the conversation history. + + Args: + content (str): The message content to add. + skills: List of active skills for hook invocation. + """ + # call skill hooks (only for prepared/activated skills) + for skill in (skills or []): + if skill.is_prepared: + await skill.on_add_assistant_message(content, []) + + msg = {"role": "assistant", "content": content} + self.messages.append(msg) + + async def add_forced_assistant_command_calls( + self, + commands: list["CommandConfig"], + skills: list["Skill"] = None, + skill_registry=None, + tool_skills: dict = None, + ): + """Adds forced assistant command calls to the conversation history. + + Args: + commands (list[CommandConfig]): The commands to add. + skills: List of active skills for hook invocation. + skill_registry: The skill registry (used for ``is_meta_tool``). + tool_skills: Mapping of function-name -> Skill instance. + """ + + if not commands: + return + + message = ChatCompletionMessage( + content="", + role="assistant", + tool_calls=[], + ) + tool_id_to_command = {} + for command in commands: + tool_id = None + if ( + self._config.features.conversation_provider + == ConversationProvider.OPENAI + ) or ( + self._config.features.conversation_provider + == ConversationProvider.WINGMAN_PRO + and "gpt" in self._config.wingman_pro.conversation_deployment.lower() + ): + tool_id = f"call_{str(uuid.uuid4()).replace('-', '')}" + elif ( + self._config.features.conversation_provider + == ConversationProvider.GOOGLE + ): + if ( + self._config.google.conversation_model.startswith("gemini-3") + or self._config.google.conversation_model == "gemini-flash-latest" + or self._config.google.conversation_model == "gemini-pro-latest" + or self._config.google.conversation_model + == "gemini-flash-lite-latest" + ): + # gemini 3+ (latest = 3+) needs a thought signature like this, but we cant fake it: + # { + # 'model_extra': { + # 'extra_content': { + # 'google': { + # 'thought_signature': 'EjQKMgFyyNp8mNe4bQmQhOua7gGMH0C9RubFWewy6BzYZJs5f4RqDb8CaiR4gjLxoM1iQqP4' + # } + # } + # } + # } + return + tool_id = f"function-call-{''.join(random.choices('0123456789', k=20))}" + + # early exit for unsupported providers/models + if not tool_id: + return + + tool_call = ChatCompletionMessageToolCall( + id=tool_id, + function=ParsedFunction( + name="execute_command", + arguments=json.dumps({"command_name": command.name}), + ), + type="function", + ) + message.tool_calls.append(tool_call) + tool_id_to_command[tool_id] = command + + await self.add_gpt_response( + message, message.tool_calls, skills, skill_registry, tool_skills + ) + for tool_call in message.tool_calls: + command = tool_id_to_command[tool_call.id] + await self.update_tool_response( + tool_call.id, command.additional_context or "OK" + ) + + # ------------------------------------------------------------------ + # History cleanup and token estimation + # ------------------------------------------------------------------ + + async def cleanup_history(self): + """Cleans up the conversation history by removing messages that are too old.""" + remember_messages = self._config.features.remember_messages + + if remember_messages is None or len(self.messages) == 0: + return 0 # Configuration not set, nothing to delete. + + # Find the cutoff index where to end deletion, making sure to only count 'user' messages towards the limit starting with newest messages. + cutoff_index = len(self.messages) + user_message_count = 0 + for message in reversed(self.messages): + if self.get_message_role(message) == "user": + user_message_count += 1 + if user_message_count == remember_messages: + break # Found the cutoff point. + cutoff_index -= 1 + + # If messages below the keep limit, don't delete anything. + if user_message_count < remember_messages: + return 0 + + total_deleted_messages = cutoff_index # Messages to delete. + + # Remove the pending tool calls that are no longer needed. + for mesage in self.messages[:cutoff_index]: + if ( + self.get_message_role(mesage) == "tool" + and mesage.get("tool_call_id") in self.pending_tool_calls + ): + self.pending_tool_calls.remove(mesage.get("tool_call_id")) + if self._settings.debug_mode: + await printr.print_async( + f"Removing pending tool call {mesage.get('tool_call_id')} due to message history clean up.", + color=LogType.WARNING, + ) + + # Remove the messages before the cutoff index, exclusive of the system message. + del self.messages[:cutoff_index] + + # Optional debugging printout. + if self._settings.debug_mode and total_deleted_messages > 0: + await printr.print_async( + f"Deleted {total_deleted_messages} messages from the conversation history.", + color=LogType.WARNING, + ) + + return total_deleted_messages + + def estimate_tokens(self) -> int: + """Estimate the total token count of the current conversation history.""" + return sum(count_tokens(self._message_text_content(m)) for m in self.messages) + + # ------------------------------------------------------------------ + # Reset + # ------------------------------------------------------------------ + + async def reset(self): + """Resets the conversation message list and summary. + + Note: skill/MCP registry resets and memory extraction are the + responsibility of the caller (the Wingman) since ConversationManager + does not own those services. + """ + self.messages = [] + self.conversation_summary = "" + + # ------------------------------------------------------------------ + # Message serialisation helpers + # ------------------------------------------------------------------ + + def get_conversation_messages(self, strip_nulls: bool = True) -> list[dict]: + """Return the conversation messages as a list of plain dicts for debugging.""" + + def _strip_none(obj): + if isinstance(obj, dict): + return {k: _strip_none(v) for k, v in obj.items() if v is not None} + if isinstance(obj, list): + return [_strip_none(item) for item in obj] + return obj + + result = [] + for msg in self.messages: + if hasattr(msg, "model_dump"): + d = msg.model_dump() + else: + d = msg + if strip_nulls: + d = _strip_none(d) + result.append(d) + return result + + # ------------------------------------------------------------------ + # Text extraction / conversion helpers + # ------------------------------------------------------------------ + + def _extract_text_content(self, content) -> str: + """Extract text from message content, handling both string and multimodal list formats.""" + if isinstance(content, str): + return content + if isinstance(content, list): + # Multimodal content: extract text parts only + parts = [] + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + parts.append(part.get("text", "")) + elif isinstance(part, str): + parts.append(part) + return " ".join(parts) + return "" + + def _message_text_content(self, msg) -> str: + """Extract text content from a message for token estimation.""" + if isinstance(msg, Mapping): + return self._extract_text_content(msg.get("content", "")) or "" + elif hasattr(msg, "content"): + return self._extract_text_content(msg.content) or "" + return "" + + def _messages_to_text(self, messages: list) -> str: + """Convert a list of conversation messages to plain text for summarization.""" + lines = [] + for msg in messages: + role = self.get_message_role(msg) + content = "" + if isinstance(msg, Mapping): + content = self._extract_text_content(msg.get("content", "")) + elif hasattr(msg, "content"): + content = self._extract_text_content(msg.content) or "" + + if role == "user": + lines.append(f"User: {content}") + elif role == "assistant": + if content: + lines.append(f"Assistant: {content}") + # Include tool call info + tool_calls = None + if isinstance(msg, Mapping): + tool_calls = msg.get("tool_calls") + elif hasattr(msg, "tool_calls"): + tool_calls = msg.tool_calls + if tool_calls: + for tc in tool_calls: + fn = ( + tc.function + if hasattr(tc, "function") + else tc.get("function", {}) + ) + name = fn.name if hasattr(fn, "name") else fn.get("name", "?") + args = ( + fn.arguments + if hasattr(fn, "arguments") + else fn.get("arguments", "") + ) + lines.append(f" [Tool call: {name}({args})]") + elif role == "tool": + tool_name = ( + msg.get("name", "tool") if isinstance(msg, Mapping) else "tool" + ) + lines.append(f" [Tool result ({tool_name}): {content[:200]}]") + return "\n".join(lines) + + # ------------------------------------------------------------------ + # Role helper + # ------------------------------------------------------------------ + + def get_message_role(self, message) -> str: + """Helper method to get the role of the message regardless of its type.""" + if isinstance(message, Mapping): + return message.get("role") + elif hasattr(message, "role"): + return message.role + else: + raise TypeError( + f"Message is neither a mapping nor has a 'role' attribute: {message}" + ) From 89fd65521d02ea0ba8fcf0158449477b4ec6e88e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Hopst=C3=A4tter?= Date: Thu, 9 Apr 2026 11:09:10 +0200 Subject: [PATCH 07/34] feat: extract ConversationCondenser from OpenAiWingman Co-Authored-By: Claude Opus 4.6 --- services/conversation_condenser.py | 585 +++++++++++++++++++++++++++++ 1 file changed, 585 insertions(+) create mode 100644 services/conversation_condenser.py diff --git a/services/conversation_condenser.py b/services/conversation_condenser.py new file mode 100644 index 00000000..e7fb1888 --- /dev/null +++ b/services/conversation_condenser.py @@ -0,0 +1,585 @@ +"""Conversation condensation — summarizes history when it grows too large.""" + +import asyncio +from typing import TYPE_CHECKING + +from api.enums import LogType, LogSource +from services.file import get_prompt +from services.printr import Printr +from services.token_utils import count_tokens, truncate_to_tokens + +if TYPE_CHECKING: + from api.interface import WingmanConfig + from services.conversation_manager import ConversationManager + from services.local_ai_service import LocalAiService + from services.persistent_memory import PersistentMemoryService + +printr = Printr() + +_CONDENSE_TIMEOUT = 120.0 + + +class ConversationCondenser: + def __init__( + self, + conversation: "ConversationManager", + config: "WingmanConfig", + wingman_name: str, + ): + self._conversation = conversation + self._config = config + self._wingman_name = wingman_name + self._is_condensing = False + self._condense_task: asyncio.Task | None = None + self._support_token_ratio: float = 1.35 + + @property + def summary(self) -> str: + return self._conversation.conversation_summary + + @property + def is_condensing(self) -> bool: + return self._is_condensing + + def get_support_capacity(self, local_ai_service: "LocalAiService") -> int: + """Get the effective input capacity of the support model for a single pass. + + Returns the number of conversation tokens that can fit in one summarization + pass, accounting for system prompt, framing text, and output budget. + """ + system_prompt = get_prompt("condense-conversation") + budget = local_ai_service.get_token_budget(system_prompt) + # Subtract framing overhead (prefix/suffix around the conversation text) + framing_overhead = 80 + return max(0, budget.max_input_tokens - framing_overhead) + + async def maybe_condense(self, local_ai_service: "LocalAiService"): + """Check if condensation should run and fire it as a background task. + + Uses a token-based trigger: condenses when the conversation approaches + 70% of what the support model can handle in a single pass, so we + avoid chunking. Also has a message count safety cap. + + The cl100k_base token estimate is multiplied by _support_token_ratio + (calibrated from real support model usage) to account for tokenizer + differences between the estimation tokenizer and the actual model. + """ + if not self._config.features.condense_conversation: + return + if not local_ai_service or not local_ai_service.is_ready(): + return + if self._conversation.pending_tool_calls: + return # Never interrupt chained tool calls + if self._is_condensing: + return + + # Token-based trigger: condense when conversation reaches 70% of + # what the support model can handle in one pass. + # Apply _support_token_ratio to correct for tokenizer differences + # between cl100k_base (used for estimation) and the actual model. + capacity = self.get_support_capacity(local_ai_service) + cl100k_tokens = self._conversation.estimate_tokens() + conversation_tokens = int(cl100k_tokens * self._support_token_ratio) + token_trigger = conversation_tokens >= int(capacity * 0.7) + + # Message count safety cap + user_msg_count = sum( + 1 + for m in self._conversation.messages + if self._conversation.get_message_role(m) == "user" + ) + message_trigger = ( + user_msg_count >= self._config.features.condense_max_messages + ) + + if not token_trigger and not message_trigger: + return + + # Runs in background so user is never blocked. + # Store the task reference to prevent garbage collection mid-execution. + self._condense_task = asyncio.create_task( + self.condense(local_ai_service=local_ai_service) + ) + self._condense_task.add_done_callback( + lambda _: setattr(self, "_condense_task", None) + ) + + async def condense( + self, + local_ai_service: "LocalAiService", + persistent_memory_service: "PersistentMemoryService | None" = None, + background_tasks: set[asyncio.Task] | None = None, + force: bool = False, + ): + """Condense older conversation messages into a running summary using local AI. + + This preserves the most recent messages verbatim while summarizing older ones, + saving tokens without losing important context. Tool call/response pairs are + never split. + + Args: + local_ai_service: The local AI service to use for summarization. + persistent_memory_service: Optional service for extracting memories. + background_tasks: Optional set to track background tasks for memory extraction. + force: If True, skip the threshold check (used for manual trigger). + """ + if self._is_condensing: + await printr.print_async( + "Condensation skipped — already in progress.", + color=LogType.WARNING, + server_only=True, + source_name=self._wingman_name, + source=LogSource.WINGMAN, + ) + return + if not local_ai_service or not local_ai_service.is_ready(): + await printr.print_async( + "Condensation skipped — local AI service not available.", + color=LogType.WARNING, + server_only=True, + source_name=self._wingman_name, + source=LogSource.WINGMAN, + ) + return + + keep_recent = ( + self._config.features.condense_keep_recent + if not force + else min(self._config.features.condense_keep_recent, 2) + ) + total_msg_count = len(self._conversation.messages) + + # Need at least something to condense beyond what we keep + if total_msg_count <= keep_recent: + await printr.print_async( + f"Condensation skipped — only {total_msg_count} messages, need more than {keep_recent} to condense.", + color=LogType.LOCALMODEL, + source_name=self._wingman_name, + source=LogSource.WINGMAN, + ) + return + + self._is_condensing = True + _condensation_stats: dict = {} + + # Broadcast start + from api.commands import ConversationCondensationCommand + + if printr._connection_manager: + await printr._connection_manager.broadcast( + ConversationCondensationCommand( + wingman_name=self._wingman_name, + status="started", + ) + ) + + await printr.print_async( + "Conversation condensation started.", + color=LogType.INFO, + server_only=True, + source_name=self._wingman_name, + source=LogSource.WINGMAN, + ) + + try: + # Wait for any pending tool calls to finish + for _ in range(30): # max 15 seconds + if not self._conversation.pending_tool_calls: + break + await asyncio.sleep(0.5) + else: + await printr.print_async( + "Condensation aborted — tool calls still pending after 15s.", + color=LogType.WARNING, + server_only=True, + source_name=self._wingman_name, + source=LogSource.WINGMAN, + ) + return + + # Find the cutoff: keep the most recent `keep_recent` user messages + kept_user_count = 0 + cutoff_index = len(self._conversation.messages) + for i in range(len(self._conversation.messages) - 1, -1, -1): + if ( + self._conversation.get_message_role( + self._conversation.messages[i] + ) + == "user" + ): + kept_user_count += 1 + if kept_user_count == keep_recent: + cutoff_index = i + break + + if cutoff_index <= 0: + await printr.print_async( + f"Condensation skipped — cutoff_index={cutoff_index}, nothing to condense (kept_user_count={kept_user_count}, keep_recent={keep_recent}, total={len(self._conversation.messages)}).", + color=LogType.LOCALMODEL, + source_name=self._wingman_name, + source=LogSource.WINGMAN, + ) + return + + # Adjust cutoff forward to avoid orphaning tool responses + while cutoff_index < len(self._conversation.messages): + msg = self._conversation.messages[cutoff_index] + if self._conversation.get_message_role(msg) == "tool": + cutoff_index += 1 + else: + break + + if cutoff_index <= 0: + await printr.print_async( + "Condensation skipped — no messages to condense after tool adjustment.", + color=LogType.LOCALMODEL, + source_name=self._wingman_name, + source=LogSource.WINGMAN, + ) + return + + to_condense = self._conversation.messages[:cutoff_index] + + # Extract memories from messages about to be condensed (background, non-blocking) + if persistent_memory_service: + try: + task = asyncio.create_task( + persistent_memory_service.extract_memories( + to_condense, generate_summary=True + ) + ) + if background_tasks is not None: + background_tasks.add(task) + task.add_done_callback(background_tasks.discard) + except Exception: + pass # Don't let memory extraction block condensation + + condensed_text = self._conversation._messages_to_text(to_condense) + if not condensed_text.strip(): + await printr.print_async( + "Condensation skipped — messages produced no text content.", + color=LogType.LOCALMODEL, + source_name=self._wingman_name, + source=LogSource.WINGMAN, + ) + return + + # Estimate original token count + estimated_original_tokens = sum( + count_tokens(self._conversation._message_text_content(m)) + for m in to_condense + ) + + # Build the summarization prompt + existing_summary_section = "" + if self._conversation.conversation_summary: + existing_summary_section = ( + "EXISTING SUMMARY (incorporate and update — do not repeat verbatim):\n" + + self._conversation.conversation_summary + + "\n\n" + ) + + system_prompt = get_prompt("condense-conversation") + budget = local_ai_service.get_token_budget(system_prompt) + + user_prompt_prefix = ( + existing_summary_section + "CONVERSATION TO SUMMARIZE:\n" + ) + user_prompt_suffix = ( + "\n\n---\n" + "Now list every fact from the conversation above as bullet points.\n" + "Start from the FIRST message, end at the LAST. Include all secrets, names, preferences, and creative content:" + ) + prefix_suffix_tokens = count_tokens(user_prompt_prefix) + count_tokens( + user_prompt_suffix + ) + + # How much conversation text fits in one pass? + available_tokens = budget.max_input_tokens - prefix_suffix_tokens + + # Apply tokenizer ratio to decide if chunking is needed. + corrected_text_tokens = int( + count_tokens(condensed_text) * self._support_token_ratio + ) + corrected_available = int(available_tokens / self._support_token_ratio) + + if corrected_text_tokens > available_tokens: + # Chunk: summarize in segments, then merge + support_result = await asyncio.wait_for( + self._chunked_support( + condensed_text, + system_prompt, + existing_summary_section, + corrected_available, + local_ai_service, + ), + timeout=_CONDENSE_TIMEOUT, + ) + else: + user_prompt = ( + f"{user_prompt_prefix}{condensed_text}{user_prompt_suffix}" + ) + from services.skill_local_ai import SamplingPreset + + support_result = await asyncio.wait_for( + asyncio.get_event_loop().run_in_executor( + None, + lambda: local_ai_service.support( + text=user_prompt, + system_prompt=system_prompt, + preset=SamplingPreset.BALANCED, + ), + ), + timeout=_CONDENSE_TIMEOUT, + ) + + summary = support_result.text if support_result else None + + # Calibrate tokenizer ratio from real model usage + if support_result and support_result.prompt_tokens > 0: + cl100k_input = ( + budget.system_tokens + + count_tokens(condensed_text) + + prefix_suffix_tokens + ) + if cl100k_input > 0: + self._support_token_ratio = ( + support_result.prompt_tokens / cl100k_input + ) + + # Detect truncated output + if support_result and support_result.truncated: + await printr.print_async( + f"Condensation output was truncated (finish_reason=length). " + f"Model used {support_result.prompt_tokens} prompt tokens, " + f"generated {support_result.completion_tokens} tokens. " + f"Token ratio calibrated to {self._support_token_ratio:.2f}.", + color=LogType.WARNING, + server_only=True, + source_name=self._wingman_name, + source=LogSource.WINGMAN, + ) + + if not summary: + await printr.print_async( + "Conversation condensation failed — local AI returned no result.", + color=LogType.WARNING, + server_only=True, + source_name=self._wingman_name, + source=LogSource.WINGMAN, + ) + return + + # Clean pending tool calls being removed + for msg in to_condense: + if ( + self._conversation.get_message_role(msg) == "tool" + and msg.get("tool_call_id") + in self._conversation.pending_tool_calls + ): + self._conversation.pending_tool_calls.remove( + msg.get("tool_call_id") + ) + + # Replace old messages + del self._conversation.messages[:cutoff_index] + self._conversation.conversation_summary = summary + + estimated_summary_tokens = count_tokens(summary) + estimated_tokens_saved = max( + 0, estimated_original_tokens - estimated_summary_tokens + ) + + await printr.print_async( + f"Condensed {cutoff_index} messages into summary " + f"({len(summary)} chars, ~{estimated_summary_tokens} tokens). " + f"{len(self._conversation.messages)} messages remaining. " + f"~{estimated_tokens_saved} tokens saved. " + f"Token ratio: {self._support_token_ratio:.2f}.", + color=LogType.INFO, + server_only=True, + source_name=self._wingman_name, + source=LogSource.WINGMAN, + ) + + # Record stats for the broadcast in finally + _condensation_stats = { + "messages_condensed": cutoff_index, + "messages_remaining": len(self._conversation.messages), + "summary_length": len(summary), + "estimated_tokens_saved": estimated_tokens_saved, + "summary_text": summary, + } + + except asyncio.TimeoutError: + await printr.print_async( + "Condensation timed out — local model took too long.", + color=LogType.WARNING, + server_only=True, + source_name=self._wingman_name, + source=LogSource.WINGMAN, + ) + except Exception as e: + await printr.print_async( + f"Conversation condensation error: {e}", + color=LogType.ERROR, + server_only=True, + source_name=self._wingman_name, + source=LogSource.WINGMAN, + ) + finally: + self._is_condensing = False + # Always broadcast finished so the client UI doesn't get stuck. + # Include summary_text if condensation produced one (even if a + # later step failed), so the client can show the view-history button. + if printr._connection_manager: + try: + await printr._connection_manager.broadcast( + ConversationCondensationCommand( + wingman_name=self._wingman_name, + status="finished", + **_condensation_stats, + ) + ) + except Exception as e: + await printr.print_async( + f"Failed to broadcast condensation finish: {e}", + color=LogType.WARNING, + server_only=True, + source_name=self._wingman_name, + source=LogSource.WINGMAN, + ) + + async def _chunked_support( + self, + full_text: str, + system_prompt: str, + existing_summary_section: str, + chunk_max_tokens: int, + local_ai_service: "LocalAiService", + ) -> "SupportResult": + """Process text that exceeds the model's context window by chunking. + + Each chunk is processed independently, then results are merged into + one final summary. Returns a SupportResult from the merge step. + """ + from providers.llama_cpp_provider import SupportResult + from services.skill_local_ai import SamplingPreset + + budget = local_ai_service.get_token_budget(system_prompt) + + # Convert token budget to approximate char limit for splitting + # (splitting needs char positions; we use ~4 chars/token as a rough guide, + # then verify with count_tokens) + approx_chunk_chars = chunk_max_tokens * 4 + chunks = [] + remaining = full_text + while remaining: + if count_tokens(remaining) <= chunk_max_tokens: + chunks.append(remaining) + break + # Try to split at a newline boundary + split_at = remaining.rfind("\n", 0, approx_chunk_chars) + if split_at <= 0: + split_at = approx_chunk_chars + chunks.append(remaining[:split_at]) + remaining = remaining[split_at:].lstrip() + + loop = asyncio.get_event_loop() + chunk_summaries = [] + for i, chunk in enumerate(chunks): + user_prompt = ( + f"{existing_summary_section if i == 0 else ''}" + f"CONVERSATION TO SUMMARIZE (part {i + 1}/{len(chunks)}):\n{chunk}\n\n" + "---\nList every fact from the above as bullet points. Include all secrets, names, preferences, and creative content:" + ) + + # Safety: if chunk input exceeds budget, truncate chunk text + chunk_text_tokens = count_tokens(chunk) + prompt_overhead = count_tokens(user_prompt) - chunk_text_tokens + if prompt_overhead + chunk_text_tokens > budget.max_input_tokens: + safe_text_tokens = budget.max_input_tokens - prompt_overhead + if safe_text_tokens > 0: + chunk = truncate_to_tokens(chunk, safe_text_tokens) + user_prompt = ( + f"{existing_summary_section if i == 0 else ''}" + f"CONVERSATION TO SUMMARIZE (part {i + 1}/{len(chunks)}):\n{chunk}\n\n" + "---\nList every fact from the above as bullet points. Include all secrets, names, preferences, and creative content:" + ) + await printr.print_async( + f"Chunk {i + 1}/{len(chunks)} exceeded context budget, truncated to fit.", + color=LogType.WARNING, + server_only=True, + source_name=self._wingman_name, + source=LogSource.WINGMAN, + ) + + result = await loop.run_in_executor( + None, + lambda p=user_prompt: local_ai_service.support( + text=p, + system_prompt=system_prompt, + preset=SamplingPreset.BALANCED, + ), + ) + if result.text: + chunk_summaries.append(result.text) + + # Calibrate tokenizer ratio from real model usage. + cl100k_input = count_tokens(system_prompt) + count_tokens(user_prompt) + if result.prompt_tokens > 0 and cl100k_input > 0: + self._support_token_ratio = result.prompt_tokens / cl100k_input + + if result.truncated: + await printr.print_async( + f"Chunk {i + 1}/{len(chunks)} output truncated " + f"(prompt={result.prompt_tokens}, " + f"completion={result.completion_tokens}).", + color=LogType.WARNING, + server_only=True, + source_name=self._wingman_name, + source=LogSource.WINGMAN, + ) + + if not chunk_summaries: + return SupportResult(text=None) + if len(chunk_summaries) == 1: + return SupportResult(text=chunk_summaries[0]) + + # Merge all chunk summaries into one final summary + combined = "\n\n".join( + f"Part {i + 1}:\n{s}" for i, s in enumerate(chunk_summaries) + ) + merge_prompt = ( + f"{existing_summary_section}" + f"PARTIAL SUMMARIES TO MERGE:\n{combined}\n\n" + "Merge these into a single coherent summary. Keep all key facts:" + ) + + # Safety: truncate combined summaries if they exceed budget + if count_tokens(merge_prompt) > budget.max_input_tokens: + overhead = count_tokens(merge_prompt) - count_tokens(combined) + safe_combined = budget.max_input_tokens - overhead + if safe_combined > 0: + combined = truncate_to_tokens(combined, safe_combined) + merge_prompt = ( + f"{existing_summary_section}" + f"PARTIAL SUMMARIES TO MERGE:\n{combined}\n\n" + "Merge these into a single coherent summary. Keep all key facts:" + ) + await printr.print_async( + f"Merge input exceeded context budget, truncated to fit.", + color=LogType.WARNING, + server_only=True, + source_name=self._wingman_name, + source=LogSource.WINGMAN, + ) + + return await loop.run_in_executor( + None, + lambda: local_ai_service.support( + text=merge_prompt, + system_prompt=system_prompt, + preset=SamplingPreset.BALANCED, + ), + ) From 6ea37a4658d713119854bee82f15c9deb7663760 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Hopst=C3=A4tter?= Date: Thu, 9 Apr 2026 11:12:24 +0200 Subject: [PATCH 08/34] feat: extract ContextBuilder from OpenAiWingman Co-Authored-By: Claude Opus 4.6 --- services/context_builder.py | 289 ++++++++++++++++++++++++++++++++++++ 1 file changed, 289 insertions(+) create mode 100644 services/context_builder.py diff --git a/services/context_builder.py b/services/context_builder.py new file mode 100644 index 00000000..8b4a97a8 --- /dev/null +++ b/services/context_builder.py @@ -0,0 +1,289 @@ +"""System prompt assembly from template, skill prompts, TTS prompts, memory injection.""" + +import re +from datetime import datetime +from typing import TYPE_CHECKING, Optional + +from api.enums import ( + LogSource, + LogType, + TtsProvider, + WingmanProTtsProvider, +) +from services.printr import Printr +from services.token_utils import count_tokens, truncate_to_tokens + +if TYPE_CHECKING: + from api.interface import SettingsConfig, WingmanConfig + from services.persistent_memory import PersistentMemoryService + from services.skill_registry import SkillRegistry + from skills.skill_base import Skill + +printr = Printr() + + +class ContextBuilder: + def __init__( + self, + config: "WingmanConfig", + settings: "SettingsConfig", + wingman_name: str, + ): + self._config = config + self._settings = settings + self._wingman_name = wingman_name + self._last_compiled_context: str = "" + self._memory_recall_notified: bool = False + + async def build( + self, + skills: list["Skill"], + skill_registry: "SkillRegistry", + conversation_summary: str, + persistent_memory_service: Optional["PersistentMemoryService"], + messages: list, + config_dir_name: Optional[str] = None, + ) -> str: + """Build the context and return it as a string. + + With progressive disclosure, only includes prompts from ACTIVATED skills. + Skill prompts are auto-generated from @tool descriptions if no custom prompt is set. + """ + skill_prompts = "" + active_skill_names = skill_registry.active_skill_names + + for skill in skills: + # Only include prompts from activated skills (in progressive mode) + if skill.name not in active_skill_names: + continue + + # Get custom prompt if set + prompt = await skill.get_prompt() + + # Auto-generate prompt from tool descriptions if no custom prompt + if not prompt: + tools_desc = skill.get_tools_description() + if tools_desc: + prompt = f"Available tools:\n{tools_desc}" + + if prompt: + skill_prompts += "\n\n" + skill.name + "\n\n" + prompt + + # Get TTS prompt based on active TTS provider and user preference + tts_prompt = "" + if self._config.features.tts_provider == TtsProvider.ELEVENLABS: + if ( + self._config.elevenlabs.use_tts_prompt + and self._config.elevenlabs.tts_prompt + ): + tts_prompt = self._config.elevenlabs.tts_prompt + elif self._config.features.tts_provider == TtsProvider.INWORLD or ( + self._config.features.tts_provider == TtsProvider.WINGMAN_PRO + and self._config.wingman_pro.tts_provider == WingmanProTtsProvider.INWORLD + ): + if self._config.inworld.use_tts_prompt and self._config.inworld.tts_prompt: + tts_prompt = self._config.inworld.tts_prompt + elif self._config.features.tts_provider == TtsProvider.OPENAI_COMPATIBLE: + if ( + self._config.openai_compatible_tts.use_tts_prompt + and self._config.openai_compatible_tts.tts_prompt + ): + tts_prompt = self._config.openai_compatible_tts.tts_prompt + + # Add TTS header only if there's a prompt + if tts_prompt: + tts_prompt = "# TEXT-TO-SPEECH\n" + tts_prompt + + # Build user context with environment metadata + user_context = self.build_user_context(config_dir_name=config_dir_name) + + # Sanity check: truncate if someone bypasses the client's 2048-token limit + MAX_BACKSTORY_TOKENS = 2048 + backstory = self._config.prompts.backstory + + if backstory and count_tokens(backstory) > MAX_BACKSTORY_TOKENS: + original_tokens = count_tokens(backstory) + backstory = truncate_to_tokens(backstory, MAX_BACKSTORY_TOKENS) + await printr.print_async( + f"[{self._wingman_name}] Backstory will be truncated to {MAX_BACKSTORY_TOKENS} tokens for conversations (is {original_tokens}). " + f"Your saved backstory is unchanged. Consider shortening it.", + color=LogType.WARNING, + source_name=self._wingman_name, + source=LogSource.SYSTEM, + ) + + # Build conversation summary section + conversation_summary_section = "" + if conversation_summary: + conversation_summary_section = ( + "# CONVERSATION SUMMARY\n" + "The following is a summary of earlier parts of this conversation. " + "Treat it as factual context — the user and you discussed these topics previously.\n\n" + + conversation_summary + ) + + # Persistent memory injection + persistent_memory_context = "" + if persistent_memory_service and messages: + # Use the most recent user message as the query + last_user_msg = "" + for msg in reversed(messages): + role = ( + msg.get("role") + if isinstance(msg, dict) + else getattr(msg, "role", None) + ) + raw_content = ( + msg.get("content", "") + if isinstance(msg, dict) + else getattr(msg, "content", "") + ) + # Extract plain text from multimodal content (images etc.) + content = ( + self._extract_text_content(raw_content) if raw_content else "" + ) + if role == "user" and content: + last_user_msg = content + break + if last_user_msg: + try: + persistent_memory_context = ( + await persistent_memory_service.build_memory_context( + last_user_msg + ) + ) + if ( + persistent_memory_context + and not self._memory_recall_notified + ): + self._memory_recall_notified = True + # Count restored fact lines (lines starting with "- ") + fact_count = sum( + 1 + for line in persistent_memory_context.splitlines() + if line.startswith("- ") + ) + if fact_count > 0: + await printr.print_async( + f"Memory recalled: {fact_count} relevant {'memory' if fact_count == 1 else 'memories'} loaded.", + color=LogType.MEMORY, + source_name=self._wingman_name, + ) + except Exception: + pass # Don't let memory failures break conversation + + context = self._config.prompts.system_prompt.format( + backstory=backstory, + skills=skill_prompts, + ttsprompt=tts_prompt, + user_context=user_context, + conversation_summary=conversation_summary_section, + ) + + # If the system prompt template doesn't include {conversation_summary}, + # append the summary at the end so it's never lost. + if ( + conversation_summary_section + and "{conversation_summary}" not in self._config.prompts.system_prompt + ): + context += "\n\n" + conversation_summary_section + + # Append persistent memory context + if persistent_memory_context: + context += "\n\n" + persistent_memory_context + + # Persistent memory tool instructions + if persistent_memory_service: + context += ( + "\n\n# PERSISTENT MEMORY\n" + "You have persistent memory. Important facts and past conversation summaries " + "are provided in the [Memory] sections above (if any). " + "You can use the `remember`, `recall`, and `forget` tools when the user " + "explicitly asks you to remember, recall, or forget something. " + "You don't need to use `remember` for routine information — that is handled automatically." + ) + + self._last_compiled_context = context + return context + + def get_last_context(self) -> str: + """Return the last compiled system context (cached from the most recent LLM call).""" + return self._last_compiled_context + + def build_user_context( + self, + config_dir_name: Optional[str] = None, + ) -> str: + """Build user context metadata for the system prompt. + + Includes timezone, config context, username, and wingman name. + """ + context_parts = [] + backstory = self._config.prompts.backstory or "" + backstory_lower = backstory.lower() + + # Date and timezone information + try: + now = datetime.now().astimezone() + local_tz = now.tzinfo + tz_name = str(local_tz) + # Get UTC offset in a readable format + utc_offset = now.strftime("%z") + # Format as +HH:MM or -HH:MM + if len(utc_offset) >= 5: + utc_offset = f"{utc_offset[:3]}:{utc_offset[3:]}" + # Include current date for relative date references ("last Sunday", "tomorrow", etc.) + current_date = now.strftime( + "%A, %B %d, %Y" + ) # e.g., "Tuesday, December 09, 2025" + context_parts.append(f"- Current date: {current_date}") + context_parts.append(f"- Timezone: {tz_name} (UTC{utc_offset})") + except Exception: + context_parts.append("- Timezone: Unknown") + + # Config/context name (e.g., "Star Citizen", "Elite Dangerous") + # This helps the LLM understand which game/context tools are relevant for + if config_dir_name: + context_parts.append(f"- Active context: {config_dir_name}") + + # Username (only if not explicitly named in backstory) + if self._settings.user_name: + # Check if username is mentioned in backstory as a standalone word + name_pattern = r"\b" + re.escape(self._settings.user_name.lower()) + r"\b" + if not re.search(name_pattern, backstory_lower): + context_parts.append( + f"- User's name (default): {self._settings.user_name}" + ) + + # Wingman name - always include as it's useful context + # The system prompt already tells LLM to prioritize backstory names + if self._wingman_name: + context_parts.append(f"- Your name (default): {self._wingman_name}") + + if context_parts: + return "\n".join(context_parts) + return "No additional context available." + + async def add_context(self, messages: list, context: str) -> None: + """Insert the compiled context as the system message at the start of messages.""" + messages.insert(0, {"role": "system", "content": context}) + + def reset_memory_notification(self) -> None: + """Reset the memory recall notification flag (e.g., on conversation reset).""" + self._memory_recall_notified = False + + @staticmethod + def _extract_text_content(content) -> str: + """Extract text from message content, handling both string and multimodal list formats.""" + if isinstance(content, str): + return content + if isinstance(content, list): + # Multimodal content: extract text parts only + parts = [] + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + parts.append(part.get("text", "")) + elif isinstance(part, str): + parts.append(part) + return " ".join(parts) + return "" From dc1f8ec4516b88a4a1c35692e7a9c73ea5b0adc6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Hopst=C3=A4tter?= Date: Thu, 9 Apr 2026 11:18:04 +0200 Subject: [PATCH 09/34] feat: extract ToolExecutor from OpenAiWingman Co-Authored-By: Claude Opus 4.6 --- services/tool_executor.py | 515 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 515 insertions(+) create mode 100644 services/tool_executor.py diff --git a/services/tool_executor.py b/services/tool_executor.py new file mode 100644 index 00000000..bebdadb0 --- /dev/null +++ b/services/tool_executor.py @@ -0,0 +1,515 @@ +"""Tool call dispatch -- routes function calls to memory, skills, MCP, commands, etc.""" + +import json +import time +import traceback +from typing import TYPE_CHECKING, Callable, Awaitable + +from api.enums import LogType +from services.benchmark import Benchmark +from services.printr import Printr +from services.tool_response_cache import ToolResponseCompressor + +if TYPE_CHECKING: + from api.interface import CommandConfig, WingmanConfig, SettingsConfig + from services.capability_registry import CapabilityRegistry + from services.mcp_registry import McpRegistry + from services.persistent_memory import PersistentMemoryService + from services.skill_registry import SkillRegistry + from skills.skill_base import Skill + +printr = Printr() + + +class ToolExecutor: + """Dispatches tool calls to the appropriate handler (memory, skills, MCP, commands). + + This is a stateless dispatcher -- all mutable state (registries, services, callbacks) + is passed as parameters to ``handle_tool_calls`` / ``execute_by_function_call``. + """ + + def __init__( + self, + config: "WingmanConfig", + settings: "SettingsConfig", + wingman_name: str, + ): + self._config = config + self._settings = settings + self._wingman_name = wingman_name + self._tool_response_compressor = ToolResponseCompressor() + + # ------------------------------------------------------------------ + # fix_tool_calls (was _fix_tool_calls) + # ------------------------------------------------------------------ + + async def fix_tool_calls( + self, + tool_calls, + get_command_fn: Callable[[str], "CommandConfig | None"], + ): + """Fixes tool calls that have a command name as function name. + + Mistral sometimes returns the command name directly as the function name + instead of wrapping it in ``execute_command``. This method detects that + pattern and rewrites the tool call accordingly. + + Args: + tool_calls: The tool calls to fix. + get_command_fn: Callback ``(name) -> CommandConfig | None`` used to + check whether a string is a known command name. + + Returns: + list: The fixed tool calls. + """ + if tool_calls and len(tool_calls) > 0: + for tool_call in tool_calls: + function_name = tool_call.function.name + function_args = ( + tool_call.function.arguments + # Mistral returns a dict + if isinstance(tool_call.function.arguments, dict) + # OpenAI returns a string + else json.loads(tool_call.function.arguments) + ) + + # try to resolve function name to a command name + if (len(function_args) == 0 and get_command_fn(function_name)) or ( + len(function_args) == 1 + and "command_name" in function_args + and get_command_fn(function_args["command_name"]) + and function_name == function_args["command_name"] + ): + function_args["command_name"] = function_name + function_name = "execute_command" + + # update the tool call + tool_call.function.name = function_name + tool_call.function.arguments = json.dumps(function_args) + + if self._settings.debug_mode: + await printr.print_async( + "Applied command call fix.", color=LogType.WARNING + ) + + return tool_calls + + # ------------------------------------------------------------------ + # handle_tool_calls (was _handle_tool_calls) + # ------------------------------------------------------------------ + + async def handle_tool_calls( + self, + tool_calls, + *, + tool_skills: dict, + skill_registry: "SkillRegistry", + mcp_registry: "McpRegistry", + capability_registry: "CapabilityRegistry", + persistent_memory_service: "PersistentMemoryService | None", + get_command_fn: Callable[[str], "CommandConfig | None"], + execute_command_fn: Callable[["CommandConfig", bool], Awaitable[tuple]], + play_to_user_fn: Callable[[str], Awaitable[None]], + local_ai_service, + update_tool_response_fn: Callable[[str, str], Awaitable[bool]], + add_tool_response_fn: Callable, + pending_tool_calls: list, + ): + """Processes all the tool calls identified in the response message. + + Args: + tool_calls: The list of tool calls to process. + tool_skills: Mapping of tool name -> Skill instance. + skill_registry: For skill search/activation. + mcp_registry: For MCP tool execution. + capability_registry: For capability management. + persistent_memory_service: For memory tools (may be None). + get_command_fn: Callback to get a Command by name. + execute_command_fn: Callback to execute a Command. + play_to_user_fn: Callback to play audio response. + local_ai_service: For tool response compression. + update_tool_response_fn: Callback to update an existing tool + response in conversation history. + add_tool_response_fn: Callback to add a new tool response to + conversation history. + pending_tool_calls: List reference for tracking pending calls. + + Returns: + tuple: (instant_response, skill, tool_timings) where tool_timings + is a list of (label, time_ms) tuples. + """ + instant_response = None + function_response = "" + tool_timings: list[tuple[str, float]] = [] + + skill = None + + for tool_call in tool_calls: + try: + function_name = tool_call.function.name + function_args = ( + tool_call.function.arguments + # Mistral returns a dict + if isinstance(tool_call.function.arguments, dict) + # OpenAI returns a string + else json.loads(tool_call.function.arguments) + ) + + # Time the individual tool execution + tool_start = time.perf_counter() + ( + function_response, + instant_response, + skill, + tool_label, + ) = await self.execute_by_function_call( + function_name, + function_args, + tool_skills=tool_skills, + skill_registry=skill_registry, + mcp_registry=mcp_registry, + capability_registry=capability_registry, + persistent_memory_service=persistent_memory_service, + get_command_fn=get_command_fn, + execute_command_fn=execute_command_fn, + play_to_user_fn=play_to_user_fn, + ) + tool_time_ms = (time.perf_counter() - tool_start) * 1000 + + # Add timing if we got a label (actual tool execution, not meta-tool) + if tool_label: + tool_timings.append((tool_label, tool_time_ms)) + + # Compress large tool responses via local AI before the cloud LLM sees them + if ( + tool_call.id + and self._config.features.compress_tool_responses + and local_ai_service + and local_ai_service.is_ready() + and self._tool_response_compressor.should_compress( + str(function_response) + ) + ): + function_response = await self._tool_response_compressor.compress( + response_text=str(function_response), + local_ai_service=local_ai_service, + wingman_name=self._wingman_name, + tool_name=function_name, + ) + + if tool_call.id: + # updating the dummy tool response with the actual response + await update_tool_response_fn(tool_call.id, function_response) + else: + # adding a new tool response + add_tool_response_fn(tool_call, function_response) + except Exception as e: + add_tool_response_fn(tool_call, "Error") + await printr.print_async( + f"Error while processing tool call: {str(e)}", color=LogType.ERROR + ) + printr.print( + traceback.format_exc(), color=LogType.ERROR, server_only=True + ) + return instant_response, skill, tool_timings + + # ------------------------------------------------------------------ + # execute_by_function_call (was execute_command_by_function_call) + # ------------------------------------------------------------------ + + async def execute_by_function_call( + self, + function_name: str, + function_args: dict[str, any], + *, + tool_skills: dict, + skill_registry: "SkillRegistry", + mcp_registry: "McpRegistry", + capability_registry: "CapabilityRegistry", + persistent_memory_service: "PersistentMemoryService | None", + get_command_fn: Callable[[str], "CommandConfig | None"], + execute_command_fn: Callable[["CommandConfig", bool], Awaitable[tuple]], + play_to_user_fn: Callable[[str], Awaitable[None]], + ) -> tuple[str, str | None, "Skill | None", str | None]: + """Dispatches a single function call to the appropriate handler. + + Uses an OpenAI function call to execute a command. If it's an instant + activation_command, one of its responses will be played. + + Args: + function_name: The name of the function to be executed. + function_args: The arguments to pass to the function being executed. + tool_skills: Mapping of tool name -> Skill instance. + skill_registry: For skill search/activation and display names. + mcp_registry: For MCP tool execution and meta-tools. + capability_registry: For capability meta-tools. + persistent_memory_service: For memory tools (may be None). + get_command_fn: Callback ``(name) -> CommandConfig | None``. + execute_command_fn: Callback to execute a Command. + play_to_user_fn: Callback to play audio response. + + Returns: + A tuple containing: + - function_response (str): The text response or result obtained + after executing the function. + - instant_response (str | None): An immediate response or action + to be taken, if any (e.g., play audio). + - used_skill (Skill | None): The skill that was used, if any. + - tool_label (str | None): Label for benchmark timing + (e.g., "MCP: resolve-library-id"), or None for meta-tools. + """ + function_response = "" + instant_response = "" + used_skill = None + tool_label = None + + # ── 1. Persistent memory tools ────────────────────────────── + if ( + function_name in ("memory_remember", "memory_recall", "memory_forget") + and persistent_memory_service + ): + if function_name == "memory_remember": + text = function_args.get("text", "") + if text: + await persistent_memory_service.add_memory( + entry_type="fact", content=text + ) + function_response = f'I\'ll remember that: "{text}"' + await printr.print_async( + f"Memory stored: {text}", + color=LogType.MEMORY, + source_name=self._wingman_name, + ) + else: + function_response = "Nothing to remember -- no text provided." + + elif function_name == "memory_recall": + query = function_args.get("query", "") + if query: + results = await persistent_memory_service.search( + query, limit=10 + ) + if results: + lines = [f"- {r.content}" for r in results] + function_response = ( + "Here's what I remember:\n" + "\n".join(lines) + ) + else: + function_response = ( + "I don't have any memories matching that." + ) + else: + function_response = "No query provided for memory recall." + + elif function_name == "memory_forget": + query = function_args.get("query", "") + if query: + deleted = await persistent_memory_service.forget_by_query(query) + if deleted: + function_response = ( + f'Done -- I\'ve forgotten the memory related to "{query}".' + ) + else: + function_response = ( + "I couldn't find a memory closely matching that to forget." + ) + else: + function_response = "No query provided for memory forget." + + return function_response, None, None, f"💾 memory: {function_name}" + + # ── 2. Unified capability meta-tools ──────────────────────── + if capability_registry.is_meta_tool(function_name): + function_response, tools_changed = ( + await capability_registry.execute_meta_tool( + function_name, function_args + ) + ) + + # If a skill was activated, perform lazy validation + if tools_changed and function_name == "activate_capability": + capability_name = function_args.get("capability_name", "") + skill = skill_registry.get_skill_for_activation(capability_name) + if skill and skill.needs_activation(): + success, validation_msg = await skill.ensure_activated() + if not success: + # Validation failed -- deactivate the skill + skill_registry.deactivate_skill(capability_name) + function_response = validation_msg + tools_changed = False + await printr.print_async( + f"Skill activation failed: {capability_name}", + color=LogType.ERROR, + ) + else: + # Get display name for user-friendly message + display_name = skill_registry.get_skill_display_name( + capability_name + ) + await printr.print_async( + f"Skill activated: {display_name}", + color=LogType.SKILL, + ) + + return function_response, None, None, None # Meta-tool, no timing label + + # ── 3. Legacy skill meta-tools ────────────────────────────── + if skill_registry.is_meta_tool(function_name): + function_response, tools_changed = ( + await skill_registry.execute_meta_tool(function_name, function_args) + ) + + # If skill was activated, perform lazy validation + if tools_changed and function_name == "activate_skill": + skill_name = function_args.get("skill_name", "") + skill = skill_registry.get_skill_for_activation(skill_name) + if skill and skill.needs_activation(): + success, validation_msg = await skill.ensure_activated() + if not success: + # Validation failed -- deactivate the skill + skill_registry.deactivate_skill(skill_name) + function_response = validation_msg + tools_changed = False + await printr.print_async( + f"Skill activation failed: {skill_name}", + color=LogType.ERROR, + ) + else: + # Get display name for user-friendly message + display_name = skill_registry.get_skill_display_name( + skill_name + ) + await printr.print_async( + f"Skill activated: {display_name}", + color=LogType.SKILL, + ) + + return function_response, None, None, None # Meta-tool, no timing label + + # ── 4. MCP meta-tools ─────────────────────────────────────── + if mcp_registry.is_meta_tool(function_name): + function_response, tools_changed = ( + await mcp_registry.execute_meta_tool(function_name, function_args) + ) + return function_response, None, None, None # Meta-tool, no timing label + + # ── 5. MCP server tools (prefixed with mcp_) ──────────────── + if mcp_registry.is_mcp_tool(function_name): + connection = mcp_registry.get_connection_for_tool(function_name) + if connection: + display_name = connection.config.display_name + original_name = mcp_registry.get_original_tool_name(function_name) + tool_label = f"🌐 {display_name}: {original_name}" + + benchmark = Benchmark( + f"MCP '{connection.config.name}' - {original_name}" + ) + + # Always show simple 'called' message in UI + await printr.print_async( + f"{display_name}: called `{original_name}` with {function_args}", + color=LogType.MCP, + ) + + # Detailed 'calling' log only in terminal/log file + await printr.print_async( + f"{display_name}: calling `{original_name}` with {function_args}...", + color=LogType.MCP, + server_only=True, + ) + + try: + function_response = await mcp_registry.call_tool( + function_name, function_args + ) + except Exception as e: + await printr.print_async( + f"{display_name}: `{original_name}` failed - {str(e)}", + color=LogType.ERROR, + ) + printr.print( + traceback.format_exc(), + color=LogType.ERROR, + server_only=True, + ) + function_response = "ERROR DURING MCP TOOL EXECUTION" + finally: + # Detailed 'completed' with timing only in terminal/log file + await printr.print_async( + f"{display_name}: `{original_name}` completed", + color=LogType.MCP, + benchmark_result=benchmark.finish(), + server_only=not self._settings.debug_mode, + ) + + return function_response, None, None, tool_label + + # ── 6. Command execution ──────────────────────────────────── + if function_name == "execute_command": + # get the command based on the argument passed by the LLM + command = get_command_fn(function_args["command_name"]) + # execute the command + instant_response, function_response = await execute_command_fn( + command + ) + tool_label = ( + f"Command: {function_args.get('command_name', function_name)}" + ) + # if the command has responses, we have to play one of them + if instant_response: + await play_to_user_fn(instant_response) + + # ── 7. Skill tool execution ───────────────────────────────── + if function_name in tool_skills: + skill = tool_skills[function_name] + display_name = skill_registry.get_skill_display_name(skill.name) + tool_label = f"⚡ {display_name}: {function_name}" + + benchmark = Benchmark(f"Skill '{skill.name}' - {function_name}") + + # Always show simple 'called' message in UI + await printr.print_async( + f"{display_name}: called `{function_name}` with {function_args}", + color=LogType.SKILL, + skill_name=skill.name, + ) + + # Detailed 'calling' log only in terminal/log file + await printr.print_async( + f"{display_name}: calling `{function_name}` with {function_args}...", + color=LogType.SKILL, + skill_name=skill.name, + server_only=True, + ) + + try: + function_response, instant_response = await skill.execute_tool( + tool_name=function_name, + parameters=function_args, + benchmark=benchmark, + ) + used_skill = skill + if instant_response: + await play_to_user_fn(instant_response) + except Exception as e: + await printr.print_async( + f"{display_name}: `{function_name}` failed - {str(e)}", + color=LogType.ERROR, + ) + printr.print( + traceback.format_exc(), color=LogType.ERROR, server_only=True + ) + function_response = ( + "ERROR DURING PROCESSING" # hints to AI that there was an error + ) + instant_response = None + finally: + await printr.print_async( + f"{display_name}: `{function_name}` completed", + color=LogType.SKILL, + benchmark_result=benchmark.finish(), + skill_name=skill.name, + server_only=not self._settings.debug_mode, + ) + + return function_response, instant_response, used_skill, tool_label From 286fa9794cf6a3bfa3790f1a70a5879a266855e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Hopst=C3=A4tter?= Date: Thu, 9 Apr 2026 11:18:49 +0200 Subject: [PATCH 10/34] =?UTF-8?q?milestone:=20M4=20complete=20=E2=80=94=20?= =?UTF-8?q?all=20services=20extracted=20(ConversationManager,=20Conversati?= =?UTF-8?q?onCondenser,=20ContextBuilder,=20ToolExecutor)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit From e247f57056977a08825b2589846cbf1b5bb4ae09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Hopst=C3=A4tter?= Date: Thu, 9 Apr 2026 11:54:39 +0200 Subject: [PATCH 11/34] feat: merge OpenAiWingman into Wingman, wire factory + services Absorb all orchestration from OpenAiWingman into the unified Wingman class. Provider routing is now handled by ProviderFactory which creates SttInterface, TtsInterface, and LlmInterface instances. Conversation management, condensation, context building, and tool execution are delegated to the extracted services. open_ai_wingman.py is reduced to a backward-compatibility shim that re-exports Wingman as OpenAiWingman so existing skills, tower, and custom wingmen continue to work without modification. Co-Authored-By: Claude Opus 4.6 --- wingmen/open_ai_wingman.py | 3546 +----------------------------------- wingmen/wingman.py | 1610 ++++++++++++---- 2 files changed, 1307 insertions(+), 3849 deletions(-) diff --git a/wingmen/open_ai_wingman.py b/wingmen/open_ai_wingman.py index c7832ef3..be6af972 100644 --- a/wingmen/open_ai_wingman.py +++ b/wingmen/open_ai_wingman.py @@ -1,3541 +1,11 @@ -import json -import time -import asyncio -import random -import traceback -import uuid -from datetime import datetime -from typing import ( - Mapping, - Optional, -) -from openai import APIConnectionError, NOT_GIVEN -from openai.types.chat import ( - ChatCompletion, - ChatCompletionMessage, - ChatCompletionMessageToolCall, - ParsedFunction, -) -import requests -from api.interface import ( - OpenRouterEndpointResult, - SettingsConfig, - SoundConfig, - WingmanInitializationError, - CommandConfig, -) -from api.enums import ( - ImageGenerationProvider, - LogType, - LogSource, - TtsProvider, - SttProvider, - ConversationProvider, - WingmanProSttProvider, - WingmanProTtsProvider, - WingmanInitializationErrorType, -) -from providers.edge import Edge -from providers.elevenlabs import ElevenLabs -from providers.google import GoogleGenAI -from providers.open_ai import OpenAi, OpenAiAzure, OpenAiCompatibleTts -from providers.hume import Hume -from providers.inworld import Inworld -from providers.open_ai import OpenAi, OpenAiAzure -from providers.x_ai import XAi -from providers.wingman_subscription import WingmanSubscription -from api.commands import McpStateChangedCommand -from services.benchmark import Benchmark -from services.file import get_prompt -from services.token_utils import count_tokens, truncate_to_tokens -from services.markdown import cleanup_text -from services.printr import Printr -from services.skill_registry import SkillRegistry -from services.mcp_client import McpClient -from services.mcp_registry import McpRegistry -from services.capability_registry import CapabilityRegistry -from services.tool_response_cache import ToolResponseCompressor -from skills.skill_base import Skill -from wingmen.wingman import Wingman +"""Backward-compatibility shim. -printr = Printr() +``OpenAiWingman`` has been merged into :class:`wingmen.wingman.Wingman`. +This module re-exports ``Wingman`` under the old name so that existing +skills, custom wingmen, tower, and other code that imports +``from wingmen.open_ai_wingman import OpenAiWingman`` continues to work. +""" -# Max seconds to wait for the local support model during conversation condensation. -_CONDENSE_TIMEOUT = 120.0 +from wingmen.wingman import Wingman as OpenAiWingman # noqa: F401 - -class OpenAiWingman(Wingman): - """Our OpenAI Wingman base gives you everything you need to interact with OpenAI's various APIs. - - It transcribes speech to text using Whisper, uses the Completion API for conversation and implements the Tools API to execute functions. - """ - - AZURE_SERVICES = { - "tts": TtsProvider.AZURE, - "whisper": [SttProvider.AZURE, SttProvider.AZURE_SPEECH], - "conversation": ConversationProvider.AZURE, - } - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - self.edge_tts = Edge() - - # validate will set these: - self.openai: OpenAi | None = None - self.mistral: OpenAi | None = None - self.groq: OpenAi | None = None - self.cerebras: OpenAi | None = None - self.openrouter: OpenAi | None = None - self.openrouter_model_supports_tools = False - self.local_llm: OpenAi | None = None - self.openai_azure: OpenAiAzure | None = None - self.elevenlabs: ElevenLabs | None = None - self.openai_compatible_tts: OpenAiCompatibleTts | None = None - self.hume: Hume | None = None - self.inworld: Inworld | None = None - self.wingman_pro: WingmanSubscription | None = None - self.google: GoogleGenAI | None = None - self.perplexity: OpenAi | None = None - self.xai: XAi | None = None - - # tool queue - self.pending_tool_calls = [] - self.last_gpt_call = None - - # generated addional content - self.instant_responses = [] - self.last_used_instant_responses = [] - - self.messages = [] - self.conversation_summary: str = "" - self._last_compiled_context: str = "" - self._is_condensing = False - self._condense_task: asyncio.Task | None = None - self._last_prompt_tokens: int = 0 - """Last API-reported prompt_tokens from the conversation LLM.""" - self._support_token_ratio: float = 1.35 - """Calibration ratio: support-model tokens / cl100k_base tokens. - Starts at 1.35 (conservative — Qwen typically tokenizes English text - ~20-40% heavier than cl100k_base). Updated after the first support - call using real usage reported by the model's own tokenizer.""" - """The conversation history that is used for the GPT calls""" - - self.azure_api_keys = {key: None for key in self.AZURE_SERVICES} - - self.tool_skills: dict[str, Skill] = {} - self.skill_tools: list[dict] = [] - - # Progressive tool disclosure registry (MCP-inspired token optimization) - # Only meta-tools are sent to LLM initially; skills activated on-demand - self.skill_registry = SkillRegistry() - - # MCP (Model Context Protocol) support - # Allows connecting to external MCP servers that provide additional tools - self.mcp_client = McpClient(wingman_name=self.name) - self.mcp_registry = McpRegistry( - self.mcp_client, - wingman_name=self.name, - on_state_changed=self._broadcast_mcp_state_changed, - ) - - # Unified capability registry - combines skill and MCP discovery - # From the LLM's perspective, both are just "capabilities" - self.capability_registry = CapabilityRegistry( - self.skill_registry, self.mcp_registry - ) - - # Local AI service — set externally by WingmanCore if available - self.local_ai_service = None - - # Persistent memory — initialized lazily when local_ai_service is available - self.persistent_memory_service = None - self._memory_recall_notified = False - self._background_tasks: set[asyncio.Task] = set() - - # Stateless tool response compressor — summarizes large tool/MCP responses - # via local AI before the cloud LLM sees them - self._tool_response_compressor = ToolResponseCompressor() - - def _broadcast_mcp_state_changed(self): - """Broadcast MCP state change to UI via WebSocket.""" - if printr._connection_manager: - printr.ensure_async( - printr._connection_manager.broadcast( - McpStateChangedCommand(wingman_name=self.name) - ) - ) - - async def validate(self): - errors = await super().validate() - - try: - if self.uses_provider("whispercpp"): - self.whispercpp.validate(self.name, errors) - - if self.uses_provider("fasterwhisper"): - self.fasterwhisper.validate(errors) - - if self.uses_provider("parakeet"): - self.parakeet.validate(errors) - - if self.uses_provider("pocket_tts"): - self.pocket_tts.validate(errors) - - if self.uses_provider("openai"): - await self.validate_and_set_openai(errors) - - if self.uses_provider("mistral"): - await self.validate_and_set_mistral(errors) - - if self.uses_provider("groq"): - await self.validate_and_set_groq(errors) - - if self.uses_provider("cerebras"): - await self.validate_and_set_cerebras(errors) - - if self.uses_provider("google"): - await self.validate_and_set_google(errors) - - if self.uses_provider("openrouter"): - await self.validate_and_set_openrouter(errors) - - if self.uses_provider("local_llm"): - await self.validate_and_set_local_llm(errors) - - if self.uses_provider("elevenlabs"): - await self.validate_and_set_elevenlabs(errors) - - if self.uses_provider("openai_compatible"): - await self.validate_and_set_openai_compatible_tts(errors) - - if self.uses_provider("azure"): - await self.validate_and_set_azure(errors) - - if self.uses_provider("wingman_pro"): - await self.validate_and_set_wingman_pro() - - if self.uses_provider("perplexity"): - await self.validate_and_set_perplexity(errors) - - if self.uses_provider("xai"): - await self.validate_and_set_xai(errors) - - if self.uses_provider("hume"): - await self.validate_and_set_hume(errors) - - if self.uses_provider("inworld"): - await self.validate_and_set_inworld(errors) - - except Exception as e: - errors.append( - WingmanInitializationError( - wingman_name=self.name, - message=f"Error during provider validation: {str(e)}", - error_type=WingmanInitializationErrorType.UNKNOWN, - ) - ) - printr.print( - f"Error during provider validation: {str(e)}", - color=LogType.ERROR, - server_only=True, - ) - printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - - return errors - - def uses_provider(self, provider_type: str): - if provider_type == "openai": - return any( - [ - self.config.features.tts_provider == TtsProvider.OPENAI, - self.config.features.stt_provider == SttProvider.OPENAI, - self.config.features.conversation_provider - == ConversationProvider.OPENAI, - self.config.features.image_generation_provider - == ImageGenerationProvider.OPENAI, - ] - ) - elif provider_type == "mistral": - return any( - [ - self.config.features.conversation_provider - == ConversationProvider.MISTRAL, - ] - ) - elif provider_type == "groq": - return any( - [ - self.config.features.conversation_provider - == ConversationProvider.GROQ, - self.config.features.stt_provider == SttProvider.GROQ, - ] - ) - elif provider_type == "cerebras": - return any( - [ - self.config.features.conversation_provider - == ConversationProvider.CEREBRAS, - ] - ) - elif provider_type == "google": - return any( - [ - self.config.features.conversation_provider - == ConversationProvider.GOOGLE, - ] - ) - elif provider_type == "openrouter": - return any( - [ - self.config.features.conversation_provider - == ConversationProvider.OPENROUTER, - ] - ) - elif provider_type == "local_llm": - return any( - [ - self.config.features.conversation_provider - == ConversationProvider.LOCAL_LLM, - ] - ) - elif provider_type == "azure": - return any( - [ - self.config.features.tts_provider == TtsProvider.AZURE, - self.config.features.stt_provider == SttProvider.AZURE, - self.config.features.stt_provider == SttProvider.AZURE_SPEECH, - self.config.features.conversation_provider - == ConversationProvider.AZURE, - ] - ) - elif provider_type == "edge_tts": - return self.config.features.tts_provider == TtsProvider.EDGE_TTS - elif provider_type == "elevenlabs": - return self.config.features.tts_provider == TtsProvider.ELEVENLABS - elif provider_type == "openai_compatible": - return self.config.features.tts_provider == TtsProvider.OPENAI_COMPATIBLE - elif provider_type == "pocket_tts": - return self.config.features.tts_provider == TtsProvider.POCKET_TTS - elif provider_type == "hume": - return self.config.features.tts_provider == TtsProvider.HUME - elif provider_type == "inworld": - return self.config.features.tts_provider == TtsProvider.INWORLD - elif provider_type == "xvasynth": - return self.config.features.tts_provider == TtsProvider.XVASYNTH - elif provider_type == "whispercpp": - return self.config.features.stt_provider == SttProvider.WHISPERCPP - elif provider_type == "fasterwhisper": - return self.config.features.stt_provider == SttProvider.FASTER_WHISPER - elif provider_type == "parakeet": - return self.config.features.stt_provider == SttProvider.PARAKEET - elif provider_type == "wingman_pro": - return any( - [ - self.config.features.conversation_provider - == ConversationProvider.WINGMAN_PRO, - self.config.features.tts_provider == TtsProvider.WINGMAN_PRO, - self.config.features.stt_provider == SttProvider.WINGMAN_PRO, - self.config.features.image_generation_provider - == ImageGenerationProvider.WINGMAN_PRO, - ] - ) - elif provider_type == "perplexity": - return any( - [ - self.config.features.conversation_provider - == ConversationProvider.PERPLEXITY, - ] - ) - elif provider_type == "xai": - return any( - [ - self.config.features.conversation_provider - == ConversationProvider.XAI, - ] - ) - return False - - async def prepare(self): - try: - if self.config.features.use_generic_instant_responses: - printr.print( - "Generating AI instant responses...", - color=LogType.WARNING, - server_only=True, - ) - self.threaded_execution(self._generate_instant_responses) - except Exception as e: - await printr.print_async( - f"Error while preparing wingman '{self.name}': {str(e)}", - color=LogType.ERROR, - ) - printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - - def ensure_memory_initialized(self) -> bool: - """Initialize persistent memory service if not yet done. - - Returns True if the service is available after this call. - """ - # Tear down service if the toggle was turned off - if self.persistent_memory_service and not self.config.persistent_memory: - self.persistent_memory_service.close() - self.persistent_memory_service = None - return False - if self.persistent_memory_service: - return True - if self.config.persistent_memory and self.local_ai_service: - from services.persistent_memory import PersistentMemoryService - - self.persistent_memory_service = PersistentMemoryService( - wingman_name=self.name, - local_ai_service=self.local_ai_service, - ) - self.persistent_memory_service.initialize() - return True - return False - - async def unload(self): - """Extract memories, close clients, and close DB before unloading.""" - # Wait for any background memory extraction tasks to finish - if self._background_tasks: - await asyncio.gather(*self._background_tasks, return_exceptions=True) - self._background_tasks.clear() - - if self.persistent_memory_service: - from services.persistent_memory import MIN_MESSAGES_FOR_EXTRACTION - - if len(self.messages) >= MIN_MESSAGES_FOR_EXTRACTION: - try: - await self.persistent_memory_service.extract_memories( - self.messages, generate_summary=True - ) - except Exception: - pass - self.persistent_memory_service.close() - - if self.google: - await self.google.aclose() - self.google = None - - await super().unload() - - async def unload_skills(self): - await super().unload_skills() - self.tool_skills = {} - self.skill_tools = [] - self.skill_registry.clear() - - async def unload_mcps(self): - """Disconnect from all MCP servers.""" - await self.mcp_registry.clear() - - async def enable_mcp(self, mcp_name: str) -> tuple[bool, str]: - """Enable and connect to a single MCP server without reinitializing all MCPs. - - Args: - mcp_name: The name of the MCP server to enable - - Returns: - (success, message) tuple - """ - # Check if MCP SDK is available - if not self.mcp_client.is_available: - return False, "MCP SDK not installed." - - # Check if already connected - if mcp_name in self.mcp_registry.get_connected_server_names(): - return True, f"MCP server '{mcp_name}' is already connected." - - # Find the MCP config from central mcp.yaml - central_mcp_config = self.tower.config_manager.mcp_config - mcp_configs = central_mcp_config.servers if central_mcp_config else [] - - mcp_config = None - for cfg in mcp_configs: - if cfg.name == mcp_name: - mcp_config = cfg - break - - if not mcp_config: - return False, f"MCP server '{mcp_name}' not found in mcp.yaml." - - try: - # Build headers with secrets (same logic as init_mcps) - headers = {} - if mcp_config.headers: - headers.update(mcp_config.headers) - - # Check for API key in secrets - secret_key = f"mcp_{mcp_config.name}" - api_key = await self.secret_keeper.retrieve( - requester=self.name, - key=secret_key, - prompt_if_missing=False, - ) - if api_key: - if not any( - k.lower() in ["authorization", "api-key", "x-api-key"] - for k in headers.keys() - ): - headers["Authorization"] = f"Bearer {api_key}" - - # Connect with timeout - default_timeout = 60.0 if mcp_config.type.value == "stdio" else 30.0 - timeout = ( - float(mcp_config.timeout) if mcp_config.timeout else default_timeout - ) - - connection = await asyncio.wait_for( - self.mcp_registry.register_server( - config=mcp_config, - headers=headers if headers else None, - ), - timeout=timeout, - ) - - if connection.is_connected: - tool_count = len(connection.tools) - return True, f"MCP server '{mcp_name}' enabled with {tool_count} tools." - else: - error = connection.error or "Connection failed." - return False, f"MCP server '{mcp_name}' failed to connect: {error}" - - except asyncio.TimeoutError: - error_msg = f"Connection timed out ({int(timeout)}s)." - self.mcp_registry.set_server_error(mcp_name, error_msg) - return False, f"MCP server '{mcp_name}': {error_msg}" - - except Exception as e: - error_msg = f"Error enabling MCP '{mcp_name}': {str(e)}" - await printr.print_async(error_msg, color=LogType.ERROR) - printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - return False, error_msg - - async def disable_mcp(self, mcp_name: str) -> tuple[bool, str]: - """Disable and disconnect from a single MCP server without affecting other MCPs. - - Args: - mcp_name: The name of the MCP server to disable - - Returns: - (success, message) tuple - """ - # Check if the MCP is connected - if mcp_name not in self.mcp_registry.get_connected_server_names(): - return True, f"MCP server '{mcp_name}' is already disconnected." - - try: - await self.mcp_registry.unregister_server(mcp_name) - return True, f"MCP server '{mcp_name}' disabled." - - except Exception as e: - error_msg = f"Error disabling MCP '{mcp_name}': {str(e)}" - await printr.print_async(error_msg, color=LogType.ERROR) - printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - return False, error_msg - - async def prepare_skill(self, skill: Skill): - # prepare the skill and skill tools - try: - for tool_name, tool in skill.get_tools(): - self.tool_skills[tool_name] = skill - self.skill_tools.append(tool) - - # Register with the progressive disclosure registry - self.skill_registry.register_skill(skill) - - # Auto-activated skills need to be validated/prepared immediately - # so their hooks (like on_play_to_user) will work - if skill.config.auto_activate: - success, message = await skill.ensure_activated() - if not success: - await printr.print_async( - f"Auto-activated skill '{skill.config.display_name}' failed to activate: {message}", - color=LogType.ERROR, - ) - except Exception as e: - await printr.print_async( - f"Error while preparing skill '{skill.name}': {str(e)}", - color=LogType.ERROR, - ) - printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - - # init skill methods - skill.llm_call = self.actual_llm_call - - async def unprepare_skill(self, skill: Skill): - """Remove a skill's tools and registrations when it's disabled.""" - try: - # Remove tool mappings - for tool_name, _ in skill.get_tools(): - self.tool_skills.pop(tool_name, None) - # Remove from skill_tools list - self.skill_tools = [ - t - for t in self.skill_tools - if t.get("function", {}).get("name") != tool_name - ] - - # Unregister from the progressive disclosure registry - self.skill_registry.unregister_skill(skill.name) - except Exception as e: - await printr.print_async( - f"Error while unpreparing skill '{skill.name}': {str(e)}", - color=LogType.ERROR, - ) - printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - - async def init_mcps(self) -> list[WingmanInitializationError]: - """ - Initialize MCP (Model Context Protocol) server connections. - - Loads MCP servers from central mcp.yaml config, only connecting those in wingman's discoverable_mcps. - MCP servers provide external tools similar to skills. - - Returns: - list[WingmanInitializationError]: Errors encountered (non-fatal, wingman still loads) - """ - errors = [] - - # Check if MCP SDK is available - if not self.mcp_client.is_available: - printr.print( - f"[{self.name}] MCP SDK not installed, skipping MCP initialization.", - color=LogType.WARNING, - server_only=True, - ) - return errors - - # Disconnect existing MCP servers - await self.unload_mcps() - - # Get MCP configs from central mcp.yaml - central_mcp_config = self.tower.config_manager.mcp_config - mcp_configs = central_mcp_config.servers if central_mcp_config else [] - if not mcp_configs: - return errors - - # Get discoverable MCPs list (whitelist) from wingman config - discoverable_mcps = self.config.discoverable_mcps - - # Filter to only discoverable MCPs - mcps_to_connect = [mcp for mcp in mcp_configs if mcp.name in discoverable_mcps] - - if not mcps_to_connect: - return errors - - # Prepare connection tasks for parallel execution - async def connect_mcp(mcp_config): - """Connect to a single MCP server. Returns (success, connection_info, errors).""" - local_errors = [] - try: - # Build headers with secrets - headers = {} - if mcp_config.headers: - headers.update(mcp_config.headers) - - # Check for API key in secrets (using mcp_ prefix) - secret_key = f"mcp_{mcp_config.name}" - api_key = await self.secret_keeper.retrieve( - requester=self.name, - key=secret_key, - prompt_if_missing=False, - ) - if api_key: - printr.print( - f"MCP secret '{secret_key}' found ({len(api_key)} chars)", - color=LogType.INFO, - source_name=self.name, - server_only=True, - ) - if not any( - k.lower() in ["authorization", "api-key", "x-api-key"] - for k in headers.keys() - ): - headers["Authorization"] = f"Bearer {api_key}" - - # Connect with timeout - default_timeout = 60.0 if mcp_config.type.value == "stdio" else 30.0 - timeout = ( - float(mcp_config.timeout) if mcp_config.timeout else default_timeout - ) - - try: - connection = await asyncio.wait_for( - self.mcp_registry.register_server( - config=mcp_config, - headers=headers if headers else None, - ), - timeout=timeout, - ) - except asyncio.TimeoutError: - error_msg = f"MCP '{mcp_config.display_name}' connection timed out ({int(timeout)}s)." - printr.print( - error_msg, - color=LogType.WARNING, - source_name=self.name, - server_only=True, - ) - local_errors.append( - WingmanInitializationError( - wingman_name=self.name, - message=error_msg, - error_type=WingmanInitializationErrorType.MCP_CONNECTION_FAILED, - ) - ) - return (False, None, local_errors) - - if connection.is_connected: - return ( - True, - f"{mcp_config.display_name} ({len(connection.tools)} tools)", - local_errors, - ) - else: - error_msg = f"MCP '{mcp_config.display_name}' failed to connect: {connection.error}" - local_errors.append( - WingmanInitializationError( - wingman_name=self.name, - message=error_msg, - error_type=WingmanInitializationErrorType.MCP_CONNECTION_FAILED, - ) - ) - return (False, None, local_errors) - - except Exception as e: - error_msg = f"MCP '{mcp_config.name}' initialization error: {str(e)}" - printr.print( - error_msg, - color=LogType.ERROR, - source_name=self.name, - server_only=True, - ) - printr.print( - traceback.format_exc(), color=LogType.ERROR, server_only=True - ) - local_errors.append( - WingmanInitializationError( - wingman_name=self.name, - message=error_msg, - error_type=WingmanInitializationErrorType.MCP_CONNECTION_FAILED, - ) - ) - return (False, None, local_errors) - - # Connect to all MCPs in parallel - connection_tasks = [connect_mcp(mcp) for mcp in mcps_to_connect] - results = await asyncio.gather(*connection_tasks) - - # Collect results - connected_count = 0 - connected_names = [] - for success, connection_info, mcp_errors in results: - if success: - connected_count += 1 - connected_names.append(connection_info) - errors.extend(mcp_errors) - - # Log consolidated MCP status for this wingman - if connected_count > 0: - await printr.print_async( - f"Discoverable MCP servers connected ({connected_count}): {', '.join(connected_names)}", - color=LogType.WINGMAN, - source=LogSource.WINGMAN, - source_name=self.name, - server_only=not self.settings.debug_mode, - ) - - return errors - - async def validate_and_set_openai(self, errors: list[WingmanInitializationError]): - api_key = await self.retrieve_secret("openai", errors) - if api_key: - self.openai = OpenAi( - api_key=api_key, - organization=self.config.openai.organization, - base_url=self.config.openai.base_url, - ) - - async def validate_and_set_mistral(self, errors: list[WingmanInitializationError]): - api_key = await self.retrieve_secret("mistral", errors) - if api_key: - # TODO: maybe use their native client (or LangChain) instead of OpenAI(?) - self.mistral = OpenAi( - api_key=api_key, - organization=self.config.openai.organization, - base_url=self.config.mistral.endpoint, - ) - - async def validate_and_set_groq(self, errors: list[WingmanInitializationError]): - api_key = await self.retrieve_secret("groq", errors) - if api_key: - # TODO: maybe use their native client (or LangChain) instead of OpenAI(?) - self.groq = OpenAi( - api_key=api_key, - base_url=self.config.groq.endpoint, - ) - - async def validate_and_set_cerebras(self, errors: list[WingmanInitializationError]): - api_key = await self.retrieve_secret("cerebras", errors) - if api_key: - # TODO: maybe use their native client (or LangChain) instead of OpenAI(?) - self.cerebras = OpenAi( - api_key=api_key, - base_url=self.config.cerebras.endpoint, - ) - - async def validate_and_set_google(self, errors: list[WingmanInitializationError]): - api_key = await self.retrieve_secret("google", errors) - if api_key: - if self.google: - await self.google.aclose() - self.google = GoogleGenAI(api_key=api_key) - - async def validate_and_set_openrouter( - self, errors: list[WingmanInitializationError] - ): - api_key = await self.retrieve_secret("openrouter", errors) - - async def does_openrouter_model_support_tools(model_id: str): - if not model_id: - return False - response = requests.get( - url=f"https://openrouter.ai/api/v1/models/{model_id}/endpoints", - timeout=10, - ) - response.raise_for_status() - content = response.json() - result = OpenRouterEndpointResult(**content.get("data", {})) - supports_tools = any( - all( - p in (endpoint.supported_parameters or []) - for p in ["tools", "tool_choice"] - ) - for endpoint in result.endpoints - ) - if not supports_tools: - printr.print( - f"{self.name}: OpenRouter model {model_id} does not support tools, so they'll be omitted from calls.", - source=LogSource.WINGMAN, - source_name=self.name, - color=LogType.WARNING, - server_only=True, - ) - return supports_tools - - if api_key: - self.openrouter = OpenAi( - api_key=api_key, - base_url=self.config.openrouter.endpoint, - ) - self.openrouter_model_supports_tools = ( - await does_openrouter_model_support_tools( - self.config.openrouter.conversation_model - ) - ) - - async def validate_and_set_local_llm( - self, errors: list[WingmanInitializationError] - ): - api_key = await self.retrieve_secret("local_llm", errors, is_required=False) - self.local_llm = OpenAi( - api_key=api_key or "local", - base_url=self.config.local_llm.endpoint, - ) - - async def validate_and_set_elevenlabs( - self, errors: list[WingmanInitializationError] - ): - api_key = await self.retrieve_secret("elevenlabs", errors) - if api_key: - self.elevenlabs = ElevenLabs( - api_key=api_key, - wingman_name=self.name, - ) - self.elevenlabs.validate_config( - config=self.config.elevenlabs, errors=errors - ) - - async def validate_and_set_openai_compatible_tts( - self, errors: list[WingmanInitializationError] - ): - if ( - self.config.openai_compatible_tts.base_url - and self.config.openai_compatible_tts.api_key - ): - self.openai_compatible_tts = OpenAiCompatibleTts( - api_key=self.config.openai_compatible_tts.api_key, - base_url=self.config.openai_compatible_tts.base_url, - ) - printr.print( - f"Wingman {self.name}: Initialized OpenAI-compatible TTS with base URL {self.config.openai_compatible_tts.base_url} and API key {self.config.openai_compatible_tts.api_key}", - server_only=True, - ) - - async def validate_and_set_hume(self, errors: list[WingmanInitializationError]): - api_key = await self.retrieve_secret("hume", errors) - if api_key: - self.hume = Hume( - api_key=api_key, - wingman_name=self.name, - ) - self.hume.validate_config(config=self.config.hume, errors=errors) - - async def validate_and_set_inworld(self, errors: list[WingmanInitializationError]): - api_key = await self.retrieve_secret("inworld", errors) - if api_key: - self.inworld = Inworld( - api_key=api_key, - wingman_name=self.name, - ) - self.inworld.validate_config(config=self.config.inworld, errors=errors) - - async def validate_and_set_azure(self, errors: list[WingmanInitializationError]): - for key_type in self.AZURE_SERVICES: - if self.uses_provider("azure"): - api_key = await self.retrieve_secret(f"azure_{key_type}", errors) - if api_key: - self.azure_api_keys[key_type] = api_key - if len(errors) == 0: - self.openai_azure = OpenAiAzure() - - async def validate_and_set_wingman_pro(self): - self.wingman_pro = WingmanSubscription( - wingman_name=self.name, settings=self.settings.wingman_pro - ) - - async def validate_and_set_perplexity( - self, errors: list[WingmanInitializationError] - ): - api_key = await self.retrieve_secret("perplexity", errors) - if api_key: - self.perplexity = OpenAi( - api_key=api_key, - base_url=self.config.perplexity.endpoint, - ) - - async def validate_and_set_xai(self, errors: list[WingmanInitializationError]): - api_key = await self.retrieve_secret("xai", errors) - if api_key: - self.xai = XAi( - api_key=api_key, - base_url=self.config.xai.endpoint, - ) - - # overrides the base class method - async def update_settings(self, settings: SettingsConfig): - """Update the settings of the Wingman. This method should always be called when the user Settings have changed.""" - try: - await super().update_settings(settings) - - if self.uses_provider("wingman_pro"): - await self.validate_and_set_wingman_pro() - printr.print( - f"Wingman {self.name}: reinitialized Wingman Pro with new settings", - server_only=True, - ) - except Exception as e: - await printr.print_async( - f"Error while updating settings for wingman '{self.name}': {str(e)}", - color=LogType.ERROR, - ) - printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - - async def _generate_instant_responses(self) -> None: - """Generates general instant responses based on given context.""" - context = await self.get_context() - messages = [ - { - "role": "system", - "content": """ - Generate a list in JSON format of at least 20 short direct text responses. - Make sure the response only contains the JSON, no additional text. - They must fit the described character in the given context by the user. - Every generated response must be generally usable in every situation. - Responses must show its still in progress and not in a finished state. - The user request this response is used on is unknown. Therefore it must be generic. - Good examples: - - "Processing..." - - "Stand by..." - - Bad examples: - - "Generating route..." (too specific) - - "I'm sorry, I can't do that." (too negative) - - Response example: - [ - "OK", - "Generating results...", - "Roger that!", - "Stand by..." - ] - """, - }, - {"role": "user", "content": context}, - ] - try: - completion = await self.actual_llm_call(messages) - if completion is None: - return - if completion.choices[0].message.content: - retry_limit = 3 - retry_count = 1 - valid = False - while not valid and retry_count <= retry_limit: - try: - responses = json.loads(completion.choices[0].message.content) - valid = True - for response in responses: - if response not in self.instant_responses: - self.instant_responses.append(str(response)) - except json.JSONDecodeError: - messages.append(completion.choices[0].message) - messages.append( - { - "role": "user", - "content": "It was tried to handle the response in its entirety as a JSON string. Fix response to be a pure, valid JSON, it was not convertable.", - } - ) - if retry_count <= retry_limit: - completion = await self.actual_llm_call(messages) - retry_count += 1 - except Exception as e: - await printr.print_async( - f"Error while generating instant responses: {str(e)}", - color=LogType.ERROR, - ) - printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - - async def _transcribe(self, audio_input_wav: str) -> str | None: - """Transcribes the recorded audio to text using the OpenAI Whisper API. - - Args: - audio_input_wav (str): The path to the audio file that contains the user's speech. This is a recording of what you you said. - - Returns: - str | None: The transcript of the audio file or None if the transcription failed. - """ - transcript = None - - try: - if self.config.features.stt_provider == SttProvider.AZURE: - transcript = self.openai_azure.transcribe_whisper( - filename=audio_input_wav, - api_key=self.azure_api_keys["whisper"], - config=self.config.azure.whisper, - ) - elif self.config.features.stt_provider == SttProvider.AZURE_SPEECH: - transcript = self.openai_azure.transcribe_azure_speech( - filename=audio_input_wav, - api_key=self.azure_api_keys["tts"], - config=self.config.azure.stt, - ) - elif self.config.features.stt_provider == SttProvider.WHISPERCPP: - transcript = self.whispercpp.transcribe( - filename=audio_input_wav, config=self.config.whispercpp - ) - elif self.config.features.stt_provider == SttProvider.FASTER_WHISPER: - hotwords: list[str] = [] - # add my name - hotwords.append(self.name) - # add default hotwords - default_hotwords = self.config.fasterwhisper.hotwords - if default_hotwords and len(default_hotwords) > 0: - hotwords.extend(default_hotwords) - # and my additional hotwords - wingman_hotwords = self.config.fasterwhisper.additional_hotwords - if wingman_hotwords and len(wingman_hotwords) > 0: - hotwords.extend(wingman_hotwords) - - transcript = self.fasterwhisper.transcribe( - filename=audio_input_wav, - config=self.config.fasterwhisper, - hotwords=list(set(hotwords)), - ) - elif self.config.features.stt_provider == SttProvider.PARAKEET: - transcript = self.parakeet.transcribe( - config=self.config.parakeet, - filename=audio_input_wav, - ) - elif self.config.features.stt_provider == SttProvider.WINGMAN_PRO: - if ( - self.config.wingman_pro.stt_provider - == WingmanProSttProvider.WHISPER - ): - transcript = self.wingman_pro.transcribe_whisper( - filename=audio_input_wav - ) - elif ( - self.config.wingman_pro.stt_provider - == WingmanProSttProvider.AZURE_SPEECH - ): - transcript = self.wingman_pro.transcribe_azure_speech( - filename=audio_input_wav, config=self.config.azure.stt - ) - elif self.config.features.stt_provider == SttProvider.OPENAI: - transcript = self.openai.transcribe(filename=audio_input_wav) - elif self.config.features.stt_provider == SttProvider.GROQ: - transcript = self.groq.transcribe( - filename=audio_input_wav, model="whisper-large-v3-turbo" - ) - except Exception as e: - await printr.print_async( - f"Error during transcription using '{self.config.features.stt_provider}': {str(e)}", - color=LogType.ERROR, - ) - printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - - result = None - if transcript: - # Wingman Pro might returns a serialized dict instead of a real Azure Speech transcription object - result = ( - transcript.get("_text") - if isinstance(transcript, dict) - else transcript.text - ) - - return result - - async def _get_response_for_transcript( - self, transcript: str, benchmark: Benchmark, images: list[tuple[str, str]] = None - ) -> tuple[str | None, str | None, Skill | None, bool]: - """Gets the response for a given transcript. - - This function interprets the transcript, runs instant commands if triggered, - calls the OpenAI API when needed, processes any tool calls, and generates the final response. - - Args: - transcript (str): The user's spoken text transcribed. - - Returns: - tuple[str | None, str | None, Skill | None, bool]: A tuple containing the final response, the instant response (if any), the skill that was used, and a boolean indicating whether the current audio should be interrupted. - """ - self.ensure_memory_initialized() - - await self.add_user_message(transcript, images=images) - - benchmark.start_snapshot("Instant activation commands") - instant_response, instant_command_executed = await self._try_instant_activation( - transcript=transcript - ) - if instant_response: - await self.add_assistant_message(instant_response) - benchmark.finish_snapshot() - if ( - instant_response == "." - ): # thats for the "The UI should not give a response" option in commands - instant_response = None - return instant_response, instant_response, None, True - benchmark.finish_snapshot() - - # Track cumulative times for proper aggregation - llm_processing_time_ms = 0.0 - tool_execution_time_ms = 0.0 - tool_timings: list[tuple[str, float]] = ( - [] - ) # (label, time_ms) for individual tools - - # make a GPT call with the conversation history - # if an instant command got executed, prevent tool calls to avoid duplicate executions - llm_start = time.perf_counter() - completion = await self._llm_call(instant_command_executed is False) - llm_processing_time_ms += (time.perf_counter() - llm_start) * 1000 - - if completion is None: - self._add_benchmark_snapshot( - benchmark, "LLM Processing", llm_processing_time_ms - ) - return None, None, None, True - - response_message, tool_calls, usage = await self._process_completion( - completion, instant_command_executed is False - ) - - # Track token usage across the turn (last prompt_tokens, summed completion_tokens) - turn_prompt_tokens = usage[0] - turn_completion_tokens = usage[1] - self._last_prompt_tokens = turn_prompt_tokens - - # add message and dummy tool responses to conversation history - is_waiting_response_needed, is_summarize_needed = await self._add_gpt_response( - response_message, tool_calls - ) - interrupt = True # initial answer should be awaited if exists - - while tool_calls: - if is_waiting_response_needed: - message = None - if response_message.content: - message = response_message.content - elif self.instant_responses: - message = self._get_random_filler() - is_summarize_needed = True - if message: - self.threaded_execution(self.play_to_user, message, interrupt) - await printr.print_async( - f"{message}", - color=LogType.POSITIVE, - source=LogSource.WINGMAN, - source_name=self.name, - skill_name="", - ) - interrupt = False - else: - is_summarize_needed = True - else: - is_summarize_needed = True - - # Time tool execution and collect individual timings - tool_start = time.perf_counter() - instant_response, skill, iteration_timings = await self._handle_tool_calls( - tool_calls - ) - tool_execution_time_ms += (time.perf_counter() - tool_start) * 1000 - tool_timings.extend(iteration_timings) - - if instant_response: - await self._trim_tool_responses() - # Add snapshots before returning - self._add_benchmark_snapshot( - benchmark, "LLM Processing", llm_processing_time_ms - ) - if tool_execution_time_ms > 0: - self._add_tool_execution_snapshot( - benchmark, tool_execution_time_ms, tool_timings - ) - await self._broadcast_token_usage( - turn_prompt_tokens, turn_completion_tokens - ) - return None, instant_response, None, interrupt - - if is_summarize_needed: - # Time the follow-up LLM call - llm_start = time.perf_counter() - completion = await self._llm_call(True) - llm_processing_time_ms += (time.perf_counter() - llm_start) * 1000 - - if completion is None: - await self._trim_tool_responses() - self._add_benchmark_snapshot( - benchmark, "LLM Processing", llm_processing_time_ms - ) - if tool_execution_time_ms > 0: - self._add_tool_execution_snapshot( - benchmark, tool_execution_time_ms, tool_timings - ) - await self._broadcast_token_usage( - turn_prompt_tokens, turn_completion_tokens - ) - return None, None, None, True - - response_message, tool_calls, usage = await self._process_completion( - completion - ) - # Last call's prompt_tokens is most meaningful (includes full context) - turn_prompt_tokens = usage[0] - turn_completion_tokens += usage[1] - self._last_prompt_tokens = turn_prompt_tokens - - is_waiting_response_needed, is_summarize_needed = ( - await self._add_gpt_response(response_message, tool_calls) - ) - if tool_calls: - interrupt = False - elif is_waiting_response_needed: - await self._trim_tool_responses() - self._add_benchmark_snapshot( - benchmark, "LLM Processing", llm_processing_time_ms - ) - if tool_execution_time_ms > 0: - self._add_tool_execution_snapshot( - benchmark, tool_execution_time_ms, tool_timings - ) - await self._broadcast_token_usage( - turn_prompt_tokens, turn_completion_tokens - ) - return None, None, None, interrupt - - # Trim oversized tool responses now that the LLM has processed them - await self._trim_tool_responses() - - # Add final snapshots - self._add_benchmark_snapshot( - benchmark, "LLM Processing", llm_processing_time_ms - ) - if tool_execution_time_ms > 0: - self._add_tool_execution_snapshot( - benchmark, tool_execution_time_ms, tool_timings - ) - await self._broadcast_token_usage(turn_prompt_tokens, turn_completion_tokens) - return response_message.content, response_message.content, None, interrupt - - def _add_benchmark_snapshot( - self, benchmark: Benchmark, label: str, execution_time_ms: float - ): - """Add a snapshot with the given label and execution time.""" - if execution_time_ms >= 1000: - formatted_time = f"{execution_time_ms/1000:.1f}s" - else: - formatted_time = f"{int(execution_time_ms)}ms" - - from api.interface import BenchmarkResult - - benchmark.snapshots.append( - BenchmarkResult( - label=label, - execution_time_ms=execution_time_ms, - formatted_execution_time=formatted_time, - ) - ) - - def _add_tool_execution_snapshot( - self, - benchmark: Benchmark, - total_time_ms: float, - tool_timings: list[tuple[str, float]], - ): - """Add a tool execution snapshot with nested individual tool timings.""" - from api.interface import BenchmarkResult - - if total_time_ms >= 1000: - formatted_time = f"{total_time_ms/1000:.1f}s" - else: - formatted_time = f"{int(total_time_ms)}ms" - - # Create nested snapshots for individual tools - nested_snapshots = [] - for label, time_ms in tool_timings: - if time_ms >= 1000: - fmt = f"{time_ms/1000:.1f}s" - else: - fmt = f"{int(time_ms)}ms" - nested_snapshots.append( - BenchmarkResult( - label=label, - execution_time_ms=time_ms, - formatted_execution_time=fmt, - ) - ) - - benchmark.snapshots.append( - BenchmarkResult( - label="Tool Execution", - execution_time_ms=total_time_ms, - formatted_execution_time=formatted_time, - snapshots=nested_snapshots if nested_snapshots else None, - ) - ) - - async def _broadcast_token_usage(self, prompt_tokens: int, completion_tokens: int): - """Broadcast actual API-reported token usage to the client.""" - is_local = ( - self.config.features.conversation_provider == ConversationProvider.LOCAL_LLM - ) - - # Local providers (e.g. Ollama) often report 0 tokens — estimate from messages - if is_local and prompt_tokens == 0: - prompt_tokens = sum( - count_tokens( - msg["content"] - if isinstance(msg.get("content"), str) - else str(msg.get("content", "")) - ) - for msg in self.messages - ) - if is_local and completion_tokens == 0 and self.messages: - last = self.messages[-1] - if last.get("role") == "assistant": - content = last.get("content", "") - completion_tokens = count_tokens( - content if isinstance(content, str) else str(content) - ) - - self.last_turn_prompt_tokens = prompt_tokens - self.last_turn_completion_tokens = completion_tokens - if prompt_tokens == 0 and completion_tokens == 0: - return - if not printr._connection_manager: - return - - from api.commands import ConversationTokenUsageCommand - - await printr._connection_manager.broadcast( - ConversationTokenUsageCommand( - wingman_name=self.name, - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - is_local=is_local, - ) - ) - - def _get_random_filler(self): - # get last two used instant responses - if len(self.last_used_instant_responses) > 2: - self.last_used_instant_responses = self.last_used_instant_responses[-2:] - - # get a random instant response that was not used in the last two responses - random_index = random.randint(0, len(self.instant_responses) - 1) - while random_index in self.last_used_instant_responses: - random_index = random.randint(0, len(self.instant_responses) - 1) - - # add the index to the last used list and return - self.last_used_instant_responses.append(random_index) - return self.instant_responses[random_index] - - async def _fix_tool_calls(self, tool_calls): - """Fixes tool calls that have a command name as function name. - - Args: - tool_calls (list): The tool calls to fix. - - Returns: - list: The fixed tool calls. - """ - if tool_calls and len(tool_calls) > 0: - for tool_call in tool_calls: - function_name = tool_call.function.name - function_args = ( - tool_call.function.arguments - # Mistral returns a dict - if isinstance(tool_call.function.arguments, dict) - # OpenAI returns a string - else json.loads(tool_call.function.arguments) - ) - - # try to resolve function name to a command name - if (len(function_args) == 0 and self.get_command(function_name)) or ( - len(function_args) == 1 - and "command_name" in function_args - and self.get_command(function_args["command_name"]) - and function_name == function_args["command_name"] - ): - function_args["command_name"] = function_name - function_name = "execute_command" - - # update the tool call - tool_call.function.name = function_name - tool_call.function.arguments = json.dumps(function_args) - - if self.settings.debug_mode: - await printr.print_async( - "Applied command call fix.", color=LogType.WARNING - ) - - return tool_calls - - async def _add_gpt_response(self, message, tool_calls) -> (bool, bool): - """Adds a message from GPT to the conversation history as well as adding dummy tool responses for any tool calls. - - Args: - message (dict | ChatCompletionMessage): The message to add. - tool_calls (list): The tool calls associated with the message. - """ - # call skill hooks (only for prepared/activated skills) - for skill in self.skills: - if skill.is_prepared: - await skill.on_add_assistant_message( - message.content, message.tool_calls - ) - - # do not tamper with this message as it will lead to 400 errors! - self.messages.append(message) - - # adding dummy tool responses to prevent corrupted message history on parallel requests - # and checks if waiting response should be played - unique_tools = {} - is_waiting_response_needed = False - is_summarize_needed = False - - if tool_calls: - for tool_call in tool_calls: - if not tool_call.id: - continue - # adding a dummy tool response to get updated later - self._add_tool_response(tool_call, "Loading..", False) - - function_name = tool_call.function.name - - # Meta-tools (search_skills, activate_skill, etc.) always need a follow-up - # LLM call so it can use the newly activated tools - if self.skill_registry.is_meta_tool(function_name): - is_summarize_needed = True - elif function_name in self.tool_skills: - skill = self.tool_skills[function_name] - if await skill.is_waiting_response_needed(function_name): - is_waiting_response_needed = True - if await skill.is_summarize_needed(function_name): - is_summarize_needed = True - - unique_tools[function_name] = True - - if len(unique_tools) == 1 and "execute_command" in unique_tools: - is_waiting_response_needed = True - - return is_waiting_response_needed, is_summarize_needed - - def _add_tool_response(self, tool_call, response: str, completed: bool = True): - """Adds a tool response to the conversation history. - - Args: - tool_call (dict|ChatCompletionMessageToolCall): The tool call to add the dummy response for. - """ - msg = {"role": "tool", "content": response} - if tool_call.id is not None: - msg["tool_call_id"] = tool_call.id - if tool_call.function.name is not None: - msg["name"] = tool_call.function.name - self.messages.append(msg) - - if tool_call.id and not completed: - self.pending_tool_calls.append(tool_call.id) - - async def _update_tool_response(self, tool_call_id, response) -> bool: - """Updates a tool response in the conversation history. - - Args: - tool_call_id (str): The identifier of the tool call to update the response for. - response (str): The new response to set. - - Returns: - bool: True if the response was updated, False if the tool call was not found. - """ - if not tool_call_id: - return False - - index = len(self.messages) - - # go through message history to find and update the tool call - for message in reversed(self.messages): - index -= 1 - if ( - self.__get_message_role(message) == "tool" - and message.get("tool_call_id") == tool_call_id - ): - message["content"] = str(response) - if tool_call_id in self.pending_tool_calls: - self.pending_tool_calls.remove(tool_call_id) - return True - - return False - - async def _trim_tool_responses(self, max_tokens: int = 500): - """Trim oversized tool responses in conversation history. - - Called after the LLM has finished processing a turn with tool calls. - The LLM already had full access to the data; this just prevents stale - bulk data from inflating the context on subsequent turns. - - If significant trimming occurs, broadcasts a condensation notification - so the client UI can display a summary indicator. - """ - total_tokens_saved = 0 - for msg in self.messages: - if self.__get_message_role(msg) != "tool": - continue - content = msg.get("content", "") - if not content: - continue - token_count = count_tokens(content) - if token_count <= max_tokens: - continue - total_tokens_saved += token_count - max_tokens - trimmed = truncate_to_tokens(content, max_tokens) - msg["content"] = ( - f"{trimmed}\n\n[...trimmed from ~{token_count} to " - f"~{max_tokens} tokens for conversation history. " - f"Full response was processed.]" - ) - - # Notify the client when significant trimming occurs so the UI can - # show a "Show history" indicator explaining the token drop. - # Skip if condensation is already running to avoid interfering with - # its own started/finished broadcast cycle. - if ( - total_tokens_saved > 1000 - and printr._connection_manager - and not self._is_condensing - ): - from api.commands import ConversationCondensationCommand - - if self.conversation_summary: - summary = ( - f"{self.conversation_summary}\n\n---\n\n" - f"[Latest turn: tool responses trimmed — ~{total_tokens_saved:,} tokens saved]" - ) - else: - summary = ( - f"Tool responses were automatically trimmed after LLM processing.\n" - f"~{total_tokens_saved:,} tokens saved.\n\n" - f"The LLM had full access to the complete data when generating " - f"its response. Responses are trimmed afterwards to keep the " - f"conversation context efficient." - ) - - await printr._connection_manager.broadcast( - ConversationCondensationCommand( - wingman_name=self.name, - status="finished", - estimated_tokens_saved=total_tokens_saved, - summary_text=summary, - ) - ) - - async def add_user_message(self, content: str, images: list[tuple[str, str]] = None): - """Shortens the conversation history if needed and adds a user message to it. - - Args: - content (str): The message content to add. - images (list[tuple[str, str]]): Optional list of (base64_data, mime_type) tuples to attach. - """ - self._memory_recall_notified = False - - # call skill hooks (only for prepared/activated skills) - for skill in self.skills: - if skill.is_prepared: - await skill.on_add_user_message(content) - - if images: - msg_content = [] - for img_b64, mime in images: - msg_content.append({ - "type": "image_url", - "image_url": { - "url": f"data:{mime};base64,{img_b64}", - "detail": "auto", - }, - }) - msg_content.append({"type": "text", "text": content}) - msg = {"role": "user", "content": msg_content} - else: - msg = {"role": "user", "content": content} - await self._cleanup_conversation_history() - await self._maybe_condense_history() - self.messages.append(msg) - - async def add_assistant_message(self, content: str): - """Adds an assistant message to the conversation history. - - Args: - content (str): The message content to add. - """ - # call skill hooks (only for prepared/activated skills) - for skill in self.skills: - if skill.is_prepared: - await skill.on_add_assistant_message(content, []) - - msg = {"role": "assistant", "content": content} - self.messages.append(msg) - - async def add_forced_assistant_command_calls(self, commands: list[CommandConfig]): - """Adds forced assistant command calls to the conversation history. - - Args: - commands (list[CommandConfig]): The commands to add. - """ - - if not commands: - return - - message = ChatCompletionMessage( - content="", - role="assistant", - tool_calls=[], - ) - tool_id_to_command = {} - for command in commands: - tool_id = None - if ( - self.config.features.conversation_provider - == ConversationProvider.OPENAI - ) or ( - self.config.features.conversation_provider - == ConversationProvider.WINGMAN_PRO - and "gpt" in self.config.wingman_pro.conversation_deployment.lower() - ): - tool_id = f"call_{str(uuid.uuid4()).replace('-', '')}" - elif ( - self.config.features.conversation_provider - == ConversationProvider.GOOGLE - ): - if ( - self.config.google.conversation_model.startswith("gemini-3") - or self.config.google.conversation_model == "gemini-flash-latest" - or self.config.google.conversation_model == "gemini-pro-latest" - or self.config.google.conversation_model - == "gemini-flash-lite-latest" - ): - # gemini 3+ (latest = 3+) needs a thought signature like this, but we cant fake it: - # { - # 'model_extra': { - # 'extra_content': { - # 'google': { - # 'thought_signature': 'EjQKMgFyyNp8mNe4bQmQhOua7gGMH0C9RubFWewy6BzYZJs5f4RqDb8CaiR4gjLxoM1iQqP4' - # } - # } - # } - # } - return - tool_id = f"function-call-{''.join(random.choices('0123456789', k=20))}" - - # early exit for unsupported providers/models - if not tool_id: - return - - tool_call = ChatCompletionMessageToolCall( - id=tool_id, - function=ParsedFunction( - name="execute_command", - arguments=json.dumps({"command_name": command.name}), - ), - type="function", - ) - message.tool_calls.append(tool_call) - tool_id_to_command[tool_id] = command - - await self._add_gpt_response(message, message.tool_calls) - for tool_call in message.tool_calls: - command = tool_id_to_command[tool_call.id] - await self._update_tool_response( - tool_call.id, command.additional_context or "OK" - ) - - async def _cleanup_conversation_history(self): - """Cleans up the conversation history by removing messages that are too old.""" - remember_messages = self.config.features.remember_messages - - if remember_messages is None or len(self.messages) == 0: - return 0 # Configuration not set, nothing to delete. - - # Find the cutoff index where to end deletion, making sure to only count 'user' messages towards the limit starting with newest messages. - cutoff_index = len(self.messages) - user_message_count = 0 - for message in reversed(self.messages): - if self.__get_message_role(message) == "user": - user_message_count += 1 - if user_message_count == remember_messages: - break # Found the cutoff point. - cutoff_index -= 1 - - # If messages below the keep limit, don't delete anything. - if user_message_count < remember_messages: - return 0 - - total_deleted_messages = cutoff_index # Messages to delete. - - # Remove the pending tool calls that are no longer needed. - for mesage in self.messages[:cutoff_index]: - if ( - self.__get_message_role(mesage) == "tool" - and mesage.get("tool_call_id") in self.pending_tool_calls - ): - self.pending_tool_calls.remove(mesage.get("tool_call_id")) - if self.settings.debug_mode: - await printr.print_async( - f"Removing pending tool call {mesage.get('tool_call_id')} due to message history clean up.", - color=LogType.WARNING, - ) - - # Remove the messages before the cutoff index, exclusive of the system message. - del self.messages[:cutoff_index] - - # Optional debugging printout. - if self.settings.debug_mode and total_deleted_messages > 0: - await printr.print_async( - f"Deleted {total_deleted_messages} messages from the conversation history.", - color=LogType.WARNING, - ) - - return total_deleted_messages - - def _estimate_conversation_tokens(self) -> int: - """Estimate the total token count of the current conversation history.""" - return sum(count_tokens(self._message_text_content(m)) for m in self.messages) - - def _get_support_capacity(self) -> int: - """Get the effective input capacity of the support model for a single pass. - - Returns the number of conversation tokens that can fit in one summarization - pass, accounting for system prompt, framing text, and output budget. - """ - system_prompt = get_prompt("condense-conversation") - budget = self.local_ai_service.get_token_budget(system_prompt) - # Subtract framing overhead (prefix/suffix around the conversation text) - framing_overhead = 80 - return max(0, budget.max_input_tokens - framing_overhead) - - async def _maybe_condense_history(self): - """Check if condensation should run and fire it as a background task. - - Uses a token-based trigger: condenses when the conversation approaches - 70% of what the support model can handle in a single pass, so we - avoid chunking. Also has a message count safety cap. - - The cl100k_base token estimate is multiplied by _support_token_ratio - (calibrated from real support model usage) to account for tokenizer - differences between the estimation tokenizer and the actual model. - """ - if not self.config.features.condense_conversation: - return - if not self.local_ai_service or not self.local_ai_service.is_ready(): - return - if self.pending_tool_calls: - return # Never interrupt chained tool calls - if self._is_condensing: - return - - # Token-based trigger: condense when conversation reaches 70% of - # what the support model can handle in one pass. - # Apply _support_token_ratio to correct for tokenizer differences - # between cl100k_base (used for estimation) and the actual model. - capacity = self._get_support_capacity() - cl100k_tokens = self._estimate_conversation_tokens() - conversation_tokens = int(cl100k_tokens * self._support_token_ratio) - token_trigger = conversation_tokens >= int(capacity * 0.7) - - # Message count safety cap - user_msg_count = sum( - 1 for m in self.messages if self.__get_message_role(m) == "user" - ) - message_trigger = user_msg_count >= self.config.features.condense_max_messages - - if not token_trigger and not message_trigger: - return - - # Runs in background so user is never blocked. - # Store the task reference to prevent garbage collection mid-execution. - self._condense_task = asyncio.create_task(self._condense_history()) - self._condense_task.add_done_callback( - lambda _: setattr(self, "_condense_task", None) - ) - - async def _condense_history(self, force: bool = False): - """Condense older conversation messages into a running summary using local AI. - - This preserves the most recent messages verbatim while summarizing older ones, - saving tokens without losing important context. Tool call/response pairs are - never split. - - Args: - force: If True, skip the threshold check (used for manual trigger). - """ - if self._is_condensing: - await printr.print_async( - "Condensation skipped — already in progress.", - color=LogType.WARNING, - server_only=True, - source_name=self.name, - source=LogSource.WINGMAN, - ) - return - if not self.local_ai_service or not self.local_ai_service.is_ready(): - await printr.print_async( - "Condensation skipped — local AI service not available.", - color=LogType.WARNING, - server_only=True, - source_name=self.name, - source=LogSource.WINGMAN, - ) - return - - keep_recent = ( - self.config.features.condense_keep_recent - if not force - else min(self.config.features.condense_keep_recent, 2) - ) - total_msg_count = len(self.messages) - - # Need at least something to condense beyond what we keep - if total_msg_count <= keep_recent: - await printr.print_async( - f"Condensation skipped — only {total_msg_count} messages, need more than {keep_recent} to condense.", - color=LogType.LOCALMODEL, - source_name=self.name, - source=LogSource.WINGMAN, - ) - return - - self._is_condensing = True - _condensation_stats: dict = {} - - # Broadcast start - from api.commands import ConversationCondensationCommand - - if printr._connection_manager: - await printr._connection_manager.broadcast( - ConversationCondensationCommand( - wingman_name=self.name, - status="started", - ) - ) - - await printr.print_async( - "Conversation condensation started.", - color=LogType.INFO, - server_only=True, - source_name=self.name, - source=LogSource.WINGMAN, - ) - - try: - # Wait for any pending tool calls to finish - for _ in range(30): # max 15 seconds - if not self.pending_tool_calls: - break - await asyncio.sleep(0.5) - else: - await printr.print_async( - "Condensation aborted — tool calls still pending after 15s.", - color=LogType.WARNING, - server_only=True, - source_name=self.name, - source=LogSource.WINGMAN, - ) - return - - # Find the cutoff: keep the most recent `keep_recent` user messages - kept_user_count = 0 - cutoff_index = len(self.messages) - for i in range(len(self.messages) - 1, -1, -1): - if self.__get_message_role(self.messages[i]) == "user": - kept_user_count += 1 - if kept_user_count == keep_recent: - cutoff_index = i - break - - if cutoff_index <= 0: - await printr.print_async( - f"Condensation skipped — cutoff_index={cutoff_index}, nothing to condense (kept_user_count={kept_user_count}, keep_recent={keep_recent}, total={len(self.messages)}).", - color=LogType.LOCALMODEL, - source_name=self.name, - source=LogSource.WINGMAN, - ) - return - - # Adjust cutoff forward to avoid orphaning tool responses - while cutoff_index < len(self.messages): - msg = self.messages[cutoff_index] - if self.__get_message_role(msg) == "tool": - cutoff_index += 1 - else: - break - - if cutoff_index <= 0: - await printr.print_async( - "Condensation skipped — no messages to condense after tool adjustment.", - color=LogType.LOCALMODEL, - source_name=self.name, - source=LogSource.WINGMAN, - ) - return - - to_condense = self.messages[:cutoff_index] - - # Extract memories from messages about to be condensed (background, non-blocking) - if self.persistent_memory_service: - try: - task = asyncio.create_task( - self.persistent_memory_service.extract_memories( - to_condense, generate_summary=True - ) - ) - self._background_tasks.add(task) - task.add_done_callback(self._background_tasks.discard) - except Exception: - pass # Don't let memory extraction block condensation - - condensed_text = self._messages_to_text(to_condense) - if not condensed_text.strip(): - await printr.print_async( - "Condensation skipped — messages produced no text content.", - color=LogType.LOCALMODEL, - source_name=self.name, - source=LogSource.WINGMAN, - ) - return - - # Estimate original token count - estimated_original_tokens = sum( - count_tokens(self._message_text_content(m)) for m in to_condense - ) - - # Build the summarization prompt - existing_summary_section = "" - if self.conversation_summary: - existing_summary_section = ( - "EXISTING SUMMARY (incorporate and update — do not repeat verbatim):\n" - + self.conversation_summary - + "\n\n" - ) - - system_prompt = get_prompt("condense-conversation") - budget = self.local_ai_service.get_token_budget(system_prompt) - - user_prompt_prefix = ( - existing_summary_section + "CONVERSATION TO SUMMARIZE:\n" - ) - user_prompt_suffix = ( - "\n\n---\n" - "Now list every fact from the conversation above as bullet points.\n" - "Start from the FIRST message, end at the LAST. Include all secrets, names, preferences, and creative content:" - ) - prefix_suffix_tokens = count_tokens(user_prompt_prefix) + count_tokens( - user_prompt_suffix - ) - - # How much conversation text fits in one pass? - available_tokens = budget.max_input_tokens - prefix_suffix_tokens - - # Apply tokenizer ratio to decide if chunking is needed. - corrected_text_tokens = int( - count_tokens(condensed_text) * self._support_token_ratio - ) - corrected_available = int(available_tokens / self._support_token_ratio) - - if corrected_text_tokens > available_tokens: - # Chunk: summarize in segments, then merge - support_result = await asyncio.wait_for( - self._chunked_support( - condensed_text, - system_prompt, - existing_summary_section, - corrected_available, - ), - timeout=_CONDENSE_TIMEOUT, - ) - else: - user_prompt = ( - f"{user_prompt_prefix}{condensed_text}{user_prompt_suffix}" - ) - from services.skill_local_ai import SamplingPreset - - support_result = await asyncio.wait_for( - asyncio.get_event_loop().run_in_executor( - None, - lambda: self.local_ai_service.support( - text=user_prompt, - system_prompt=system_prompt, - preset=SamplingPreset.BALANCED, - ), - ), - timeout=_CONDENSE_TIMEOUT, - ) - - summary = support_result.text if support_result else None - - # Calibrate tokenizer ratio from real model usage - if support_result and support_result.prompt_tokens > 0: - cl100k_input = ( - budget.system_tokens - + count_tokens(condensed_text) - + prefix_suffix_tokens - ) - if cl100k_input > 0: - self._support_token_ratio = ( - support_result.prompt_tokens / cl100k_input - ) - - # Detect truncated output - if support_result and support_result.truncated: - await printr.print_async( - f"Condensation output was truncated (finish_reason=length). " - f"Model used {support_result.prompt_tokens} prompt tokens, " - f"generated {support_result.completion_tokens} tokens. " - f"Token ratio calibrated to {self._support_token_ratio:.2f}.", - color=LogType.WARNING, - server_only=True, - source_name=self.name, - source=LogSource.WINGMAN, - ) - - if not summary: - await printr.print_async( - "Conversation condensation failed — local AI returned no result.", - color=LogType.WARNING, - server_only=True, - source_name=self.name, - source=LogSource.WINGMAN, - ) - return - - # Clean pending tool calls being removed - for msg in to_condense: - if ( - self.__get_message_role(msg) == "tool" - and msg.get("tool_call_id") in self.pending_tool_calls - ): - self.pending_tool_calls.remove(msg.get("tool_call_id")) - - # Replace old messages - del self.messages[:cutoff_index] - self.conversation_summary = summary - - estimated_summary_tokens = count_tokens(summary) - estimated_tokens_saved = max( - 0, estimated_original_tokens - estimated_summary_tokens - ) - - await printr.print_async( - f"Condensed {cutoff_index} messages into summary " - f"({len(summary)} chars, ~{estimated_summary_tokens} tokens). " - f"{len(self.messages)} messages remaining. " - f"~{estimated_tokens_saved} tokens saved. " - f"Token ratio: {self._support_token_ratio:.2f}.", - color=LogType.INFO, - server_only=True, - source_name=self.name, - source=LogSource.WINGMAN, - ) - - # Record stats for the broadcast in finally - _condensation_stats = { - "messages_condensed": cutoff_index, - "messages_remaining": len(self.messages), - "summary_length": len(summary), - "estimated_tokens_saved": estimated_tokens_saved, - "summary_text": summary, - } - - except asyncio.TimeoutError: - await printr.print_async( - "Condensation timed out — local model took too long.", - color=LogType.WARNING, - server_only=True, - source_name=self.name, - source=LogSource.WINGMAN, - ) - except Exception as e: - await printr.print_async( - f"Conversation condensation error: {e}", - color=LogType.ERROR, - server_only=True, - source_name=self.name, - source=LogSource.WINGMAN, - ) - finally: - self._is_condensing = False - # Always broadcast finished so the client UI doesn't get stuck. - # Include summary_text if condensation produced one (even if a - # later step failed), so the client can show the view-history button. - if printr._connection_manager: - try: - await printr._connection_manager.broadcast( - ConversationCondensationCommand( - wingman_name=self.name, - status="finished", - **_condensation_stats, - ) - ) - except Exception as e: - await printr.print_async( - f"Failed to broadcast condensation finish: {e}", - color=LogType.WARNING, - server_only=True, - source_name=self.name, - source=LogSource.WINGMAN, - ) - - async def _chunked_support( - self, - full_text: str, - system_prompt: str, - existing_summary_section: str, - chunk_max_tokens: int, - ) -> "SupportResult": - """Process text that exceeds the model's context window by chunking. - - Each chunk is processed independently, then results are merged into - one final summary. Returns a SupportResult from the merge step. - """ - from providers.llama_cpp_provider import SupportResult - from services.skill_local_ai import SamplingPreset - - budget = self.local_ai_service.get_token_budget(system_prompt) - - # Convert token budget to approximate char limit for splitting - # (splitting needs char positions; we use ~4 chars/token as a rough guide, - # then verify with count_tokens) - approx_chunk_chars = chunk_max_tokens * 4 - chunks = [] - remaining = full_text - while remaining: - if count_tokens(remaining) <= chunk_max_tokens: - chunks.append(remaining) - break - # Try to split at a newline boundary - split_at = remaining.rfind("\n", 0, approx_chunk_chars) - if split_at <= 0: - split_at = approx_chunk_chars - chunks.append(remaining[:split_at]) - remaining = remaining[split_at:].lstrip() - - loop = asyncio.get_event_loop() - chunk_summaries = [] - for i, chunk in enumerate(chunks): - user_prompt = ( - f"{existing_summary_section if i == 0 else ''}" - f"CONVERSATION TO SUMMARIZE (part {i + 1}/{len(chunks)}):\n{chunk}\n\n" - "---\nList every fact from the above as bullet points. Include all secrets, names, preferences, and creative content:" - ) - - # Safety: if chunk input exceeds budget, truncate chunk text - chunk_text_tokens = count_tokens(chunk) - prompt_overhead = count_tokens(user_prompt) - chunk_text_tokens - if prompt_overhead + chunk_text_tokens > budget.max_input_tokens: - safe_text_tokens = budget.max_input_tokens - prompt_overhead - if safe_text_tokens > 0: - chunk = truncate_to_tokens(chunk, safe_text_tokens) - user_prompt = ( - f"{existing_summary_section if i == 0 else ''}" - f"CONVERSATION TO SUMMARIZE (part {i + 1}/{len(chunks)}):\n{chunk}\n\n" - "---\nList every fact from the above as bullet points. Include all secrets, names, preferences, and creative content:" - ) - await printr.print_async( - f"Chunk {i + 1}/{len(chunks)} exceeded context budget, truncated to fit.", - color=LogType.WARNING, - server_only=True, - source_name=self.name, - source=LogSource.WINGMAN, - ) - - result = await loop.run_in_executor( - None, - lambda p=user_prompt: self.local_ai_service.support( - text=p, - system_prompt=system_prompt, - preset=SamplingPreset.BALANCED, - ), - ) - if result.text: - chunk_summaries.append(result.text) - - # Calibrate tokenizer ratio from real model usage. - cl100k_input = count_tokens(system_prompt) + count_tokens(user_prompt) - if result.prompt_tokens > 0 and cl100k_input > 0: - self._support_token_ratio = result.prompt_tokens / cl100k_input - - if result.truncated: - await printr.print_async( - f"Chunk {i + 1}/{len(chunks)} output truncated " - f"(prompt={result.prompt_tokens}, " - f"completion={result.completion_tokens}).", - color=LogType.WARNING, - server_only=True, - source_name=self.name, - source=LogSource.WINGMAN, - ) - - if not chunk_summaries: - return SupportResult(text=None) - if len(chunk_summaries) == 1: - return SupportResult(text=chunk_summaries[0]) - - # Merge all chunk summaries into one final summary - combined = "\n\n".join( - f"Part {i + 1}:\n{s}" for i, s in enumerate(chunk_summaries) - ) - merge_prompt = ( - f"{existing_summary_section}" - f"PARTIAL SUMMARIES TO MERGE:\n{combined}\n\n" - "Merge these into a single coherent summary. Keep all key facts:" - ) - - # Safety: truncate combined summaries if they exceed budget - if count_tokens(merge_prompt) > budget.max_input_tokens: - overhead = count_tokens(merge_prompt) - count_tokens(combined) - safe_combined = budget.max_input_tokens - overhead - if safe_combined > 0: - combined = truncate_to_tokens(combined, safe_combined) - merge_prompt = ( - f"{existing_summary_section}" - f"PARTIAL SUMMARIES TO MERGE:\n{combined}\n\n" - "Merge these into a single coherent summary. Keep all key facts:" - ) - await printr.print_async( - f"Merge input exceeded context budget, truncated to fit.", - color=LogType.WARNING, - server_only=True, - source_name=self.name, - source=LogSource.WINGMAN, - ) - - return await loop.run_in_executor( - None, - lambda: self.local_ai_service.support( - text=merge_prompt, - system_prompt=system_prompt, - preset=SamplingPreset.BALANCED, - ), - ) - - def _extract_text_content(self, content) -> str: - """Extract text from message content, handling both string and multimodal list formats.""" - if isinstance(content, str): - return content - if isinstance(content, list): - # Multimodal content: extract text parts only - parts = [] - for part in content: - if isinstance(part, dict) and part.get("type") == "text": - parts.append(part.get("text", "")) - elif isinstance(part, str): - parts.append(part) - return " ".join(parts) - return "" - - def _message_text_content(self, msg) -> str: - """Extract text content from a message for token estimation.""" - if isinstance(msg, Mapping): - return self._extract_text_content(msg.get("content", "")) or "" - elif hasattr(msg, "content"): - return self._extract_text_content(msg.content) or "" - return "" - - def _messages_to_text(self, messages: list) -> str: - """Convert a list of conversation messages to plain text for summarization.""" - lines = [] - for msg in messages: - role = self.__get_message_role(msg) - content = "" - if isinstance(msg, Mapping): - content = self._extract_text_content(msg.get("content", "")) - elif hasattr(msg, "content"): - content = self._extract_text_content(msg.content) or "" - - if role == "user": - lines.append(f"User: {content}") - elif role == "assistant": - if content: - lines.append(f"Assistant: {content}") - # Include tool call info - tool_calls = None - if isinstance(msg, Mapping): - tool_calls = msg.get("tool_calls") - elif hasattr(msg, "tool_calls"): - tool_calls = msg.tool_calls - if tool_calls: - for tc in tool_calls: - fn = ( - tc.function - if hasattr(tc, "function") - else tc.get("function", {}) - ) - name = fn.name if hasattr(fn, "name") else fn.get("name", "?") - args = ( - fn.arguments - if hasattr(fn, "arguments") - else fn.get("arguments", "") - ) - lines.append(f" [Tool call: {name}({args})]") - elif role == "tool": - tool_name = ( - msg.get("name", "tool") if isinstance(msg, Mapping) else "tool" - ) - lines.append(f" [Tool result ({tool_name}): {content[:200]}]") - return "\n".join(lines) - - async def reset_conversation_history(self): - """Resets the conversation history and skill activation state. - - When the conversation is reset, the LLM loses all memory of which skills - were activated and why. So we must also reset the skill registry and MCP - registry to ensure the progressive disclosure state matches the LLM's memory. - """ - # Extract memories before clearing (if enabled and enough messages) - if self.persistent_memory_service and len(self.messages) >= 4: - try: - await self.persistent_memory_service.extract_memories( - self.messages, generate_summary=True - ) - except Exception: - pass # Don't let extraction failure prevent reset - - self.messages = [] - self.conversation_summary = "" - self._last_prompt_tokens = 0 - # Keep _support_token_ratio — it's model-specific, not conversation-specific - self.skill_registry.reset_activations() - self.mcp_registry.reset_activations() - - async def _try_instant_activation(self, transcript: str) -> (str, bool): - """Tries to execute an instant activation command if present in the transcript. - - Args: - transcript (str): The transcript to check for an instant activation command. - - Returns: - tuple[str, bool]: A tuple containing the response to the instant command and a boolean indicating whether an instant command was executed. - """ - commands = await self._execute_instant_activation_command(transcript) - if commands: - await self.add_forced_assistant_command_calls(commands) - responses = [] - for command in commands: - if command.responses: - responses.append(self._select_instant_command_response(command)) - - if len(responses) == len(commands): - # clear duplicates - responses = list(dict.fromkeys(responses)) - responses = [ - response + "." if not response.endswith(".") else response - for response in responses - ] - return " ".join(responses), True - - return None, True - - return None, False - - async def get_context(self): - """Build the context and inserts it into the messages. - - With progressive disclosure, only includes prompts from ACTIVATED skills. - Skill prompts are auto-generated from @tool descriptions if no custom prompt is set. - """ - skill_prompts = "" - active_skill_names = self.skill_registry.active_skill_names - - for skill in self.skills: - # Only include prompts from activated skills (in progressive mode) - if skill.name not in active_skill_names: - continue - - # Get custom prompt if set - prompt = await skill.get_prompt() - - # Auto-generate prompt from tool descriptions if no custom prompt - if not prompt: - tools_desc = skill.get_tools_description() - if tools_desc: - prompt = f"Available tools:\n{tools_desc}" - - if prompt: - skill_prompts += "\n\n" + skill.name + "\n\n" + prompt - - # Get TTS prompt based on active TTS provider and user preference - tts_prompt = "" - if self.config.features.tts_provider == TtsProvider.ELEVENLABS: - if ( - self.config.elevenlabs.use_tts_prompt - and self.config.elevenlabs.tts_prompt - ): - tts_prompt = self.config.elevenlabs.tts_prompt - elif self.config.features.tts_provider == TtsProvider.INWORLD or ( - self.config.features.tts_provider == TtsProvider.WINGMAN_PRO - and self.config.wingman_pro.tts_provider == WingmanProTtsProvider.INWORLD - ): - if self.config.inworld.use_tts_prompt and self.config.inworld.tts_prompt: - tts_prompt = self.config.inworld.tts_prompt - elif self.config.features.tts_provider == TtsProvider.OPENAI_COMPATIBLE: - if ( - self.config.openai_compatible_tts.use_tts_prompt - and self.config.openai_compatible_tts.tts_prompt - ): - tts_prompt = self.config.openai_compatible_tts.tts_prompt - - # Add TTS header only if there's a prompt - if tts_prompt: - tts_prompt = "# TEXT-TO-SPEECH\n" + tts_prompt - - # Build user context with environment metadata - user_context = self._build_user_context() - - # Sanity check: truncate if someone bypasses the client's 2048-token limit - MAX_BACKSTORY_TOKENS = 2048 - backstory = self.config.prompts.backstory - - if backstory and count_tokens(backstory) > MAX_BACKSTORY_TOKENS: - original_tokens = count_tokens(backstory) - backstory = truncate_to_tokens(backstory, MAX_BACKSTORY_TOKENS) - await printr.print_async( - f"[{self.name}] Backstory will be truncated to {MAX_BACKSTORY_TOKENS} tokens for conversations (is {original_tokens}). " - f"Your saved backstory is unchanged. Consider shortening it.", - color=LogType.WARNING, - source_name=self.name, - source=LogSource.SYSTEM, - ) - - # Build conversation summary section - conversation_summary = "" - if self.conversation_summary: - conversation_summary = ( - "# CONVERSATION SUMMARY\n" - "The following is a summary of earlier parts of this conversation. " - "Treat it as factual context — the user and you discussed these topics previously.\n\n" - + self.conversation_summary - ) - - # Persistent memory injection - persistent_memory_context = "" - if self.persistent_memory_service and self.messages: - # Use the most recent user message as the query - last_user_msg = "" - for msg in reversed(self.messages): - role = msg.get("role") if isinstance(msg, dict) else getattr(msg, "role", None) - raw_content = msg.get("content", "") if isinstance(msg, dict) else getattr(msg, "content", "") - # Extract plain text from multimodal content (images etc.) - content = self._extract_text_content(raw_content) if raw_content else "" - if role == "user" and content: - last_user_msg = content - break - if last_user_msg: - try: - persistent_memory_context = await self.persistent_memory_service.build_memory_context(last_user_msg) - if persistent_memory_context and not self._memory_recall_notified: - self._memory_recall_notified = True - # Count restored fact lines (lines starting with "- ") - fact_count = sum( - 1 for line in persistent_memory_context.splitlines() - if line.startswith("- ") - ) - if fact_count > 0: - await printr.print_async( - f"Memory recalled: {fact_count} relevant {'memory' if fact_count == 1 else 'memories'} loaded.", - color=LogType.MEMORY, - source_name=self.name, - ) - except Exception: - pass # Don't let memory failures break conversation - - context = self.config.prompts.system_prompt.format( - backstory=backstory, - skills=skill_prompts, - ttsprompt=tts_prompt, - user_context=user_context, - conversation_summary=conversation_summary, - ) - - # If the system prompt template doesn't include {conversation_summary}, - # append the summary at the end so it's never lost. - if ( - conversation_summary - and "{conversation_summary}" not in self.config.prompts.system_prompt - ): - context += "\n\n" + conversation_summary - - # Append persistent memory context - if persistent_memory_context: - context += "\n\n" + persistent_memory_context - - # Persistent memory tool instructions - if self.persistent_memory_service: - context += ( - "\n\n# PERSISTENT MEMORY\n" - "You have persistent memory. Important facts and past conversation summaries " - "are provided in the [Memory] sections above (if any). " - "You can use the `remember`, `recall`, and `forget` tools when the user " - "explicitly asks you to remember, recall, or forget something. " - "You don't need to use `remember` for routine information — that is handled automatically." - ) - - self._last_compiled_context = context - return context - - def get_last_context(self) -> str: - """Return the last compiled system context (cached from the most recent LLM call).""" - return self._last_compiled_context - - def get_conversation_messages(self, strip_nulls: bool = True) -> list[dict]: - """Return the conversation messages as a list of plain dicts for debugging.""" - - def _strip_none(obj): - if isinstance(obj, dict): - return {k: _strip_none(v) for k, v in obj.items() if v is not None} - if isinstance(obj, list): - return [_strip_none(item) for item in obj] - return obj - - result = [] - for msg in self.messages: - if hasattr(msg, "model_dump"): - d = msg.model_dump() - else: - d = msg - if strip_nulls: - d = _strip_none(d) - result.append(d) - return result - - def _build_user_context(self) -> str: - """Build user context metadata for the system prompt. - - Includes timezone, config context, username, and wingman name. - """ - context_parts = [] - backstory = self.config.prompts.backstory or "" - backstory_lower = backstory.lower() - - # Date and timezone information - try: - now = datetime.now().astimezone() - local_tz = now.tzinfo - tz_name = str(local_tz) - # Get UTC offset in a readable format - utc_offset = now.strftime("%z") - # Format as +HH:MM or -HH:MM - if len(utc_offset) >= 5: - utc_offset = f"{utc_offset[:3]}:{utc_offset[3:]}" - # Include current date for relative date references ("last Sunday", "tomorrow", etc.) - current_date = now.strftime( - "%A, %B %d, %Y" - ) # e.g., "Tuesday, December 09, 2025" - context_parts.append(f"- Current date: {current_date}") - context_parts.append(f"- Timezone: {tz_name} (UTC{utc_offset})") - except Exception: - context_parts.append("- Timezone: Unknown") - - # Config/context name (e.g., "Star Citizen", "Elite Dangerous") - # This helps the LLM understand which game/context tools are relevant for - if self.tower and self.tower.config_dir and self.tower.config_dir.name: - context_parts.append(f"- Active context: {self.tower.config_dir.name}") - - # Username (only if not explicitly named in backstory) - if self.settings.user_name: - # Check if username is mentioned in backstory as a standalone word - import re - - name_pattern = r"\b" + re.escape(self.settings.user_name.lower()) + r"\b" - if not re.search(name_pattern, backstory_lower): - context_parts.append( - f"- User's name (default): {self.settings.user_name}" - ) - - # Wingman name - always include as it's useful context - # The system prompt already tells LLM to prioritize backstory names - if self.name: - context_parts.append(f"- Your name (default): {self.name}") - - if context_parts: - return "\n".join(context_parts) - return "No additional context available." - - async def add_context(self, messages): - context = await self.get_context() - - messages.insert(0, {"role": "system", "content": context}) - - async def generate_image(self, text: str) -> str: - """ - Generates an image from the provided text configured provider. - """ - - if ( - self.config.features.image_generation_provider - == ImageGenerationProvider.WINGMAN_PRO - ): - try: - return await self.wingman_pro.generate_image(text) - except Exception as e: - await printr.print_async( - f"Error during image generation: {str(e)}", color=LogType.ERROR - ) - printr.print( - traceback.format_exc(), color=LogType.ERROR, server_only=True - ) - - return "" - - async def actual_llm_call(self, messages, tools: list[dict] = None): - """ - Perform the actual LLM call with the messages provided. - """ - - try: - completion = None - if self.config.features.conversation_provider == ConversationProvider.AZURE: - completion = self.openai_azure.ask( - messages=messages, - api_key=self.azure_api_keys["conversation"], - config=self.config.azure.conversation, - tools=tools, - ) - elif ( - self.config.features.conversation_provider - == ConversationProvider.OPENAI - ): - completion = self.openai.ask( - messages=messages, - tools=tools, - model=self.config.openai.conversation_model, - ) - elif ( - self.config.features.conversation_provider - == ConversationProvider.MISTRAL - ): - completion = self.mistral.ask( - messages=messages, - tools=tools, - model=self.config.mistral.conversation_model, - ) - elif ( - self.config.features.conversation_provider == ConversationProvider.GROQ - ): - completion = self.groq.ask( - messages=messages, - tools=tools, - model=self.config.groq.conversation_model, - ) - elif ( - self.config.features.conversation_provider - == ConversationProvider.CEREBRAS - ): - completion = self.cerebras.ask( - messages=messages, - tools=tools, - model=self.config.cerebras.conversation_model, - ) - elif ( - self.config.features.conversation_provider - == ConversationProvider.GOOGLE - ): - completion = self.google.ask( - messages=messages, - tools=tools, - model=self.config.google.conversation_model, - ) - elif ( - self.config.features.conversation_provider - == ConversationProvider.OPENROUTER - ): - # OpenRouter throws an error if the model doesn't support tools but we send some - if self.openrouter_model_supports_tools: - completion = self.openrouter.ask( - messages=messages, - tools=tools, - model=self.config.openrouter.conversation_model, - ) - else: - completion = self.openrouter.ask( - messages=messages, - model=self.config.openrouter.conversation_model, - ) - elif ( - self.config.features.conversation_provider - == ConversationProvider.LOCAL_LLM - ): - if not self.local_llm: - raise RuntimeError( - f"Local LLM provider is not initialized. " - f"Please check your Local LLM endpoint configuration ({self.config.local_llm.endpoint})." - ) - completion = self.local_llm.ask( - messages=messages, - tools=tools, - model=self.config.local_llm.conversation_model, - ) - elif ( - self.config.features.conversation_provider - == ConversationProvider.WINGMAN_PRO - ): - completion = self.wingman_pro.ask( - messages=messages, - deployment=self.config.wingman_pro.conversation_deployment, - tools=tools, - ) - elif ( - self.config.features.conversation_provider - == ConversationProvider.PERPLEXITY - ): - completion = self.perplexity.ask( - messages=messages, - tools=tools, - model=self.config.perplexity.conversation_model.value, - ) - elif self.config.features.conversation_provider == ConversationProvider.XAI: - completion = self.xai.ask( - messages=messages, - tools=tools, - model=self.config.xai.conversation_model, - ) - except APIConnectionError as e: - provider = self.config.features.conversation_provider.value - # Dig out the underlying cause for a useful message - cause = e.__cause__ - detail = str(cause) if cause else str(e) - message = ( - f"Could not connect to {provider}: {detail}" - ) - await printr.print_async( - message, - color=LogType.ERROR, - source=LogSource.WINGMAN, - source_name=self.name, - ) - printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - return None - except Exception as e: - await printr.print_async( - f"Error during LLM call: {str(e)}", - color=LogType.ERROR, - source=LogSource.WINGMAN, - source_name=self.name, - ) - printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - return None - - return completion - - async def _llm_call(self, allow_tool_calls: bool = True): - """Makes the primary LLM call with the conversation history and tools enabled. - - Returns: - The LLM completion object or None if the call fails. - """ - - # save request time for later comparison - thiscall = time.time() - self.last_gpt_call = thiscall - - # build tools - tools = self.build_tools() if allow_tool_calls else None - - if self.settings.debug_mode: - await printr.print_async( - f"Calling LLM with {(len(self.messages))} messages (excluding context) and {len(tools) if tools else 0} tools.", - color=LogType.INFO, - ) - - messages = self.messages.copy() - await self.add_context(messages) - - # DEBUG: Print compiled context (dev-only, remove before release) - # if messages and messages[0].get("role") == "system": - # print("\n" + "=" * 80) - # print("COMPILED CONTEXT:") - # print("=" * 80) - # print(messages[0].get("content", "")) - # print("=" * 80 + "\n") - - completion = await self.actual_llm_call(messages, tools) - - # if request isnt most recent, ignore the response - if self.last_gpt_call != thiscall: - await printr.print_async( - "LLM call was cancelled due to a new call.", color=LogType.WARNING - ) - return None - - return completion - - async def _process_completion( - self, completion: ChatCompletion, allow_tool_calls: bool = True - ): - """Processes the completion returned by the LLM call. - - Args: - completion: The completion object from an OpenAI call. - - Returns: - A tuple containing the message response, tool calls, and usage (prompt_tokens, completion_tokens) from the completion. - """ - - response_message = completion.choices[0].message - - content = response_message.content - if content is None: - response_message.content = "" - - # remove hallucinated tools, if none were allowed - if not allow_tool_calls: - response_message.tool_calls = None - - # temporary fix for tool calls that have a command name as function name - if response_message.tool_calls: - response_message.tool_calls = await self._fix_tool_calls( - response_message.tool_calls - ) - - # Extract token usage from the API response - prompt_tokens = 0 - completion_tokens = 0 - if completion.usage: - prompt_tokens = completion.usage.prompt_tokens or 0 - completion_tokens = completion.usage.completion_tokens or 0 - - return ( - response_message, - response_message.tool_calls, - (prompt_tokens, completion_tokens), - ) - - async def _handle_tool_calls(self, tool_calls): - """Processes all the tool calls identified in the response message. - - Args: - tool_calls: The list of tool calls to process. - - Returns: - tuple: (instant_response, skill, tool_timings) where tool_timings is a list of (label, time_ms) tuples. - """ - instant_response = None - function_response = "" - tool_timings: list[tuple[str, float]] = [] - - skill = None - - for tool_call in tool_calls: - try: - function_name = tool_call.function.name - function_args = ( - tool_call.function.arguments - # Mistral returns a dict - if isinstance(tool_call.function.arguments, dict) - # OpenAI returns a string - else json.loads(tool_call.function.arguments) - ) - - # Time the individual tool execution - tool_start = time.perf_counter() - ( - function_response, - instant_response, - skill, - tool_label, - ) = await self.execute_command_by_function_call( - function_name, function_args - ) - tool_time_ms = (time.perf_counter() - tool_start) * 1000 - - # Add timing if we got a label (actual tool execution, not meta-tool) - if tool_label: - tool_timings.append((tool_label, tool_time_ms)) - - # Compress large tool responses via local AI before the cloud LLM sees them - if ( - tool_call.id - and self.config.features.compress_tool_responses - and self.local_ai_service - and self.local_ai_service.is_ready() - and self._tool_response_compressor.should_compress( - str(function_response) - ) - ): - function_response = await self._tool_response_compressor.compress( - response_text=str(function_response), - local_ai_service=self.local_ai_service, - wingman_name=self.name, - tool_name=function_name, - ) - - if tool_call.id: - # updating the dummy tool response with the actual response - await self._update_tool_response(tool_call.id, function_response) - else: - # adding a new tool response - self._add_tool_response(tool_call, function_response) - except Exception as e: - self._add_tool_response(tool_call, "Error") - await printr.print_async( - f"Error while processing tool call: {str(e)}", color=LogType.ERROR - ) - printr.print( - traceback.format_exc(), color=LogType.ERROR, server_only=True - ) - return instant_response, skill, tool_timings - - async def execute_command_by_function_call( - self, function_name: str, function_args: dict[str, any] - ) -> tuple[str, str | None, Skill | None, str | None]: - """ - Uses an OpenAI function call to execute a command. If it's an instant activation_command, one if its responses will be played. - - Args: - function_name (str): The name of the function to be executed. - function_args (dict[str, any]): The arguments to pass to the function being executed. - - Returns: - A tuple containing: - - function_response (str): The text response or result obtained after executing the function. - - instant_response (str): An immediate response or action to be taken, if any (e.g., play audio). - - used_skill (Skill): The skill that was used, if any. - - tool_label (str): Label for benchmark timing (e.g., "MCP: resolve-library-id"), or None for meta-tools. - """ - function_response = "" - instant_response = "" - used_skill = None - tool_label = None - - # Handle persistent memory tools - if function_name in ("memory_remember", "memory_recall", "memory_forget") and self.persistent_memory_service: - if function_name == "memory_remember": - text = function_args.get("text", "") - if text: - await self.persistent_memory_service.add_memory( - entry_type="fact", content=text - ) - function_response = f"I'll remember that: \"{text}\"" - await printr.print_async( - f"Memory stored: {text}", - color=LogType.MEMORY, - source_name=self.name, - ) - else: - function_response = "Nothing to remember — no text provided." - - elif function_name == "memory_recall": - query = function_args.get("query", "") - if query: - results = await self.persistent_memory_service.search(query, limit=10) - if results: - lines = [f"- {r.content}" for r in results] - function_response = "Here's what I remember:\n" + "\n".join(lines) - else: - function_response = "I don't have any memories matching that." - else: - function_response = "No query provided for memory recall." - - elif function_name == "memory_forget": - query = function_args.get("query", "") - if query: - deleted = await self.persistent_memory_service.forget_by_query(query) - if deleted: - function_response = f"Done — I've forgotten the memory related to \"{query}\"." - else: - function_response = "I couldn't find a memory closely matching that to forget." - else: - function_response = "No query provided for memory forget." - - return function_response, None, None, f"💾 memory: {function_name}" - - # Handle unified capability meta-tools (activate_capability, list_active_capabilities) - if self.capability_registry.is_meta_tool(function_name): - function_response, tools_changed = ( - await self.capability_registry.execute_meta_tool( - function_name, function_args - ) - ) - - # If a skill was activated, perform lazy validation - if tools_changed and function_name == "activate_capability": - capability_name = function_args.get("capability_name", "") - skill = self.skill_registry.get_skill_for_activation(capability_name) - if skill and skill.needs_activation(): - success, validation_msg = await skill.ensure_activated() - if not success: - # Validation failed - deactivate the skill - self.skill_registry.deactivate_skill(capability_name) - function_response = validation_msg - tools_changed = False - await printr.print_async( - f"Skill activation failed: {capability_name}", - color=LogType.ERROR, - ) - else: - # Get display name for user-friendly message - display_name = self.skill_registry.get_skill_display_name( - capability_name - ) - await printr.print_async( - f"Skill activated: {display_name}", - color=LogType.SKILL, - ) - - return function_response, None, None, None # Meta-tool, no timing label - - # Handle legacy meta-tools for progressive skill discovery/activation - # These are kept for backward compatibility but shouldn't be called - if self.skill_registry.is_meta_tool(function_name): - function_response, tools_changed = ( - await self.skill_registry.execute_meta_tool( - function_name, function_args - ) - ) - - # If skill was activated, perform lazy validation - if tools_changed and function_name == "activate_skill": - skill_name = function_args.get("skill_name", "") - skill = self.skill_registry.get_skill_for_activation(skill_name) - if skill and skill.needs_activation(): - success, validation_msg = await skill.ensure_activated() - if not success: - # Validation failed - deactivate the skill - self.skill_registry.deactivate_skill(skill_name) - function_response = validation_msg - tools_changed = False - await printr.print_async( - f"Skill activation failed: {skill_name}", - color=LogType.ERROR, - ) - else: - # Get display name for user-friendly message - display_name = self.skill_registry.get_skill_display_name( - skill_name - ) - await printr.print_async( - f"Skill activated: {display_name}", - color=LogType.SKILL, - ) - - return function_response, None, None, None # Meta-tool, no timing label - - # Handle MCP meta-tools for server discovery/activation - if self.mcp_registry.is_meta_tool(function_name): - function_response, tools_changed = ( - await self.mcp_registry.execute_meta_tool(function_name, function_args) - ) - return function_response, None, None, None # Meta-tool, no timing label - - # Handle MCP server tools (prefixed with mcp_) - if self.mcp_registry.is_mcp_tool(function_name): - connection = self.mcp_registry.get_connection_for_tool(function_name) - if connection: - display_name = connection.config.display_name - original_name = self.mcp_registry.get_original_tool_name(function_name) - tool_label = f"🌐 {display_name}: {original_name}" - - benchmark = Benchmark( - f"MCP '{connection.config.name}' - {original_name}" - ) - - # Always show simple 'called' message in UI so users know the wingman is working - await printr.print_async( - f"{display_name}: called `{original_name}` with {function_args}", - color=LogType.MCP, - ) - - # Detailed 'calling' log only in terminal/log file - await printr.print_async( - f"{display_name}: calling `{original_name}` with {function_args}...", - color=LogType.MCP, - server_only=True, - ) - - try: - function_response = await self.mcp_registry.call_tool( - function_name, function_args - ) - except Exception as e: - await printr.print_async( - f"{display_name}: `{original_name}` failed - {str(e)}", - color=LogType.ERROR, - ) - printr.print( - traceback.format_exc(), color=LogType.ERROR, server_only=True - ) - function_response = "ERROR DURING MCP TOOL EXECUTION" - finally: - # Detailed 'completed' with timing only in terminal/log file (or UI if debug) - await printr.print_async( - f"{display_name}: `{original_name}` completed", - color=LogType.MCP, - benchmark_result=benchmark.finish(), - server_only=not self.settings.debug_mode, - ) - - return function_response, None, None, tool_label - - # Handle command calls - if function_name == "execute_command": - # get the command based on the argument passed by the LLM - command = self.get_command(function_args["command_name"]) - # execute the command - instant_response, function_response = await self._execute_command(command) - tool_label = f"Command: {function_args.get('command_name', function_name)}" - # if the command has responses, we have to play one of them - if instant_response: - await self.play_to_user(instant_response) - - # Go through the skills and check if the function name matches any of the tools - if function_name in self.tool_skills: - skill = self.tool_skills[function_name] - display_name = self.skill_registry.get_skill_display_name(skill.name) - tool_label = f"⚡ {display_name}: {function_name}" - - benchmark = Benchmark(f"Skill '{skill.name}' - {function_name}") - - # Always show simple 'called' message in UI so users know the wingman is working - await printr.print_async( - f"{display_name}: called `{function_name}` with {function_args}", - color=LogType.SKILL, - skill_name=skill.name, - ) - - # Detailed 'calling' log only in terminal/log file - await printr.print_async( - f"{display_name}: calling `{function_name}` with {function_args}...", - color=LogType.SKILL, - skill_name=skill.name, - server_only=True, - ) - - try: - function_response, instant_response = await skill.execute_tool( - tool_name=function_name, - parameters=function_args, - benchmark=benchmark, - ) - used_skill = skill - if instant_response: - await self.play_to_user(instant_response) - except Exception as e: - await printr.print_async( - f"{display_name}: `{function_name}` failed - {str(e)}", - color=LogType.ERROR, - ) - printr.print( - traceback.format_exc(), color=LogType.ERROR, server_only=True - ) - function_response = ( - "ERROR DURING PROCESSING" # hints to AI that there was an error - ) - instant_response = None - finally: - await printr.print_async( - f"{display_name}: `{function_name}` completed", - color=LogType.SKILL, - benchmark_result=benchmark.finish(), - skill_name=skill.name, - server_only=not self.settings.debug_mode, - ) - - return function_response, instant_response, used_skill, tool_label - - async def play_to_user( - self, - text: str, - no_interrupt: bool = False, - sound_config: Optional[SoundConfig] = None, - ): - """Plays audio to the user using the configured TTS Provider (default: OpenAI TTS). - Also adds sound effects if enabled in the configuration. - - Args: - text (str): The text to play as audio. - """ - if sound_config: - printr.print( - "Using custom sound config for playback", LogType.INFO, server_only=True - ) - else: - sound_config = self.config.sound - - # remove Markdown, links, emotes and code blocks - text, contains_links, contains_code_blocks = cleanup_text(text) - - # wait for audio player to finish playing - if no_interrupt and self.audio_player.is_playing: - while self.audio_player.is_playing: - await asyncio.sleep(0.1) - - # call skill hooks (only for prepared/activated skills) - changed_text = text - for skill in self.skills: - if skill.is_prepared: - changed_text = await skill.on_play_to_user(text, sound_config) - if changed_text != text: - printr.print( - f"Skill '{skill.config.display_name}' modified the text to: '{changed_text}'", - LogType.INFO, - ) - text = changed_text - - if sound_config.volume == 0.0: - printr.print( - "Volume modifier is set to 0. Skipping TTS processing.", - LogType.WARNING, - server_only=True, - ) - return - - if "{SKIP-TTS}" in text: - printr.print( - "Skip TTS phrase found in input. Skipping TTS processing.", - LogType.WARNING, - server_only=True, - ) - return - - try: - if self.config.features.tts_provider == TtsProvider.EDGE_TTS: - await self.edge_tts.play_audio( - text=text, - config=self.config.edge_tts, - sound_config=sound_config, - audio_player=self.audio_player, - wingman_name=self.name, - ) - elif self.config.features.tts_provider == TtsProvider.ELEVENLABS: - await self.elevenlabs.play_audio( - text=text, - config=self.config.elevenlabs, - sound_config=sound_config, - audio_player=self.audio_player, - wingman_name=self.name, - stream=self.config.elevenlabs.output_streaming, - ) - elif self.config.features.tts_provider == TtsProvider.HUME: - try: - await self.hume.play_audio( - text=text, - config=self.config.hume, - sound_config=sound_config, - audio_player=self.audio_player, - wingman_name=self.name, - ) - except RuntimeError as e: - if "Event loop is closed" in str(e): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - await self.hume.play_audio( - text=text, - config=self.config.hume, - sound_config=sound_config, - audio_player=self.audio_player, - wingman_name=self.name, - ) - elif self.config.features.tts_provider == TtsProvider.INWORLD: - await self.inworld.play_audio( - text=text, - config=self.config.inworld, - sound_config=sound_config, - audio_player=self.audio_player, - wingman_name=self.name, - ) - elif self.config.features.tts_provider == TtsProvider.AZURE: - await self.openai_azure.play_audio( - text=text, - api_key=self.azure_api_keys["tts"], - config=self.config.azure.tts, - sound_config=sound_config, - audio_player=self.audio_player, - wingman_name=self.name, - ) - elif self.config.features.tts_provider == TtsProvider.XVASYNTH: - await self.xvasynth.play_audio( - text=text, - config=self.config.xvasynth, - sound_config=sound_config, - audio_player=self.audio_player, - wingman_name=self.name, - ) - elif self.config.features.tts_provider == TtsProvider.OPENAI: - await self.openai.play_audio( - text=text, - voice=self.config.openai.tts_voice, - model=self.config.openai.tts_model, - speed=self.config.openai.tts_speed, - sound_config=sound_config, - audio_player=self.audio_player, - wingman_name=self.name, - stream=self.config.openai.output_streaming, - ) - elif self.config.features.tts_provider == TtsProvider.OPENAI_COMPATIBLE: - await self.openai_compatible_tts.play_audio( - text=text, - voice=self.config.openai_compatible_tts.voice, - model=self.config.openai_compatible_tts.model, - speed=( - self.config.openai_compatible_tts.speed - if self.config.openai_compatible_tts.speed #!= 1.0 - else NOT_GIVEN - ), - sound_config=sound_config, - audio_player=self.audio_player, - wingman_name=self.name, - stream=self.config.openai_compatible_tts.output_streaming, - ) - elif self.config.features.tts_provider == TtsProvider.POCKET_TTS: - await self.pocket_tts.play_audio( - text=text, - config=self.config.pocket_tts, - sound_config=sound_config, - audio_player=self.audio_player, - wingman_name=self.name, - ) - elif self.config.features.tts_provider == TtsProvider.WINGMAN_PRO: - if self.config.wingman_pro.tts_provider == WingmanProTtsProvider.OPENAI: - await self.wingman_pro.generate_openai_speech( - text=text, - voice=self.config.openai.tts_voice, - model=self.config.openai.tts_model, - speed=self.config.openai.tts_speed, - sound_config=sound_config, - audio_player=self.audio_player, - wingman_name=self.name, - ) - elif ( - self.config.wingman_pro.tts_provider == WingmanProTtsProvider.AZURE - ): - await self.wingman_pro.generate_azure_speech( - text=text, - config=self.config.azure.tts, - sound_config=sound_config, - audio_player=self.audio_player, - wingman_name=self.name, - ) - elif ( - self.config.wingman_pro.tts_provider - == WingmanProTtsProvider.INWORLD - ): - await self.wingman_pro.generate_inworld_speech( - text=text, - config=self.config.inworld, - sound_config=sound_config, - audio_player=self.audio_player, - wingman_name=self.name, - ) - else: - printr.toast_error( - f"Unsupported TTS provider: {self.config.features.tts_provider}" - ) - except Exception as e: - await printr.print_async( - f"Error during TTS playback: {str(e)}", color=LogType.ERROR - ) - printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - - async def _execute_command( - self, command: CommandConfig, is_instant=False - ) -> tuple[str | None, str]: - """Executes a command by delegating to the Wingman base implementation. - - Returns: - tuple[str | None, str]: A 2-tuple of: - - Instant response (str) to play immediately, or None if there is no instant response. - - Function/tool response (str) to feed back to the LLM. - """ - return await super()._execute_command(command, is_instant) - - def build_tools(self) -> list[dict]: - """ - Builds tools for the LLM call. - - In progressive mode: Returns meta-tools (search_skills, activate_skill) plus - tools from activated skills only. - - In legacy mode: Returns all skill tools. - - Returns: - list[dict]: A list of tool descriptors in OpenAI format. - """ - - def _command_has_effective_actions(command: CommandConfig) -> bool: - if command.is_system_command: - return True - - if not command.actions: - return False - - for action in command.actions: - if not action: - continue - if ( - action.keyboard is not None - or action.mouse is not None - or action.joystick is not None - or action.audio is not None - or action.write is not None - or action.wait is not None - ): - return True - - return False - - commands = [ - command.name - for command in self.config.commands - if (not command.force_instant_activation) - and _command_has_effective_actions(command) - ] - tools: list[dict] = [] - if commands: - tools.append( - { - "type": "function", - "function": { - "name": "execute_command", - "description": "Executes a command", - "parameters": { - "type": "object", - "properties": { - "command_name": { - "type": "string", - "description": "The name of the command to execute", - "enum": commands, - }, - }, - "required": ["command_name"], - }, - }, - } - ) - - # Unified capability discovery: single activate_capability meta-tool - # Combines skills and MCP servers - LLM doesn't need to know the difference - for _, tool in self.capability_registry.get_meta_tools(): - tools.append(tool) - - # Add tools from activated capabilities (both skills and MCPs) - for _, tool in self.skill_registry.get_active_tools(): - tools.append(tool) - - for _, tool in self.mcp_registry.get_active_tools(): - tools.append(tool) - - # Persistent memory tools — auto-injected when enabled - if self.persistent_memory_service: - tools.append({ - "type": "function", - "function": { - "name": "memory_remember", - "description": "Store an important fact or detail for future reference. Use when the user explicitly asks you to remember something.", - "parameters": { - "type": "object", - "properties": { - "text": { - "type": "string", - "description": "The fact or detail to remember.", - }, - }, - "required": ["text"], - }, - }, - }) - tools.append({ - "type": "function", - "function": { - "name": "memory_recall", - "description": "Search your memory for relevant information. Use when the user asks what you remember or know about a topic.", - "parameters": { - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "What to search for in memory.", - }, - }, - "required": ["query"], - }, - }, - }) - tools.append({ - "type": "function", - "function": { - "name": "memory_forget", - "description": "Remove a specific memory. Use when the user explicitly asks you to forget something.", - "parameters": { - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "Description of the memory to forget.", - }, - }, - "required": ["query"], - }, - }, - }) - - return tools - - def __get_message_role(self, message): - """Helper method to get the role of the message regardless of its type.""" - if isinstance(message, Mapping): - return message.get("role") - elif hasattr(message, "role"): - return message.role - else: - raise TypeError( - f"Message is neither a mapping nor has a 'role' attribute: {message}" - ) +__all__ = ["OpenAiWingman"] diff --git a/wingmen/wingman.py b/wingmen/wingman.py index 82b96d8a..de3f603d 100644 --- a/wingmen/wingman.py +++ b/wingmen/wingman.py @@ -1,16 +1,26 @@ -import traceback -from copy import deepcopy -import random +"""Unified Wingman class. + +Merges the former base ``Wingman`` and its only subclass ``OpenAiWingman`` +into a single class that delegates to extracted services and provider +interfaces for STT, TTS, and LLM. +""" + +import json import time -import difflib import asyncio +import random +import traceback import threading +from copy import deepcopy +import difflib from typing import ( Any, Dict, Optional, TYPE_CHECKING, ) +from openai import APIConnectionError +from openai.types.chat import ChatCompletion import keyboard.keyboard as keyboard import mouse.mouse as mouse from api.interface import ( @@ -23,21 +33,34 @@ ) from api.enums import ( CommandTag, + ConversationProvider, + ImageGenerationProvider, LogSource, LogType, + SttProvider, + TtsProvider, WingmanInitializationErrorType, ) -from providers.faster_whisper import FasterWhisper -from providers.parakeet import Parakeet -from providers.whispercpp import Whispercpp -from providers.xvasynth import XVASynth -from providers.pocket_tts import PocketTTS +from api.commands import McpStateChangedCommand +from providers.interfaces import LlmInterface, SttInterface, TtsInterface from services.audio_player import AudioPlayer from services.benchmark import Benchmark +from services.markdown import cleanup_text from services.module_manager import ModuleManager from services.secret_keeper import SecretKeeper from services.printr import Printr from services.audio_library import AudioLibrary +from services.conversation_manager import ConversationManager +from services.conversation_condenser import ConversationCondenser +from services.context_builder import ContextBuilder +from services.tool_executor import ToolExecutor +from services.provider_factory import ProviderFactory +from services.skill_registry import SkillRegistry +from services.mcp_client import McpClient +from services.mcp_registry import McpRegistry +from services.capability_registry import CapabilityRegistry +from services.tool_response_cache import ToolResponseCompressor +from services.token_utils import count_tokens from skills.skill_base import Skill if TYPE_CHECKING: @@ -52,11 +75,23 @@ def _get_skill_folder_from_module(module: str) -> str: class Wingman: - """The "highest" Wingman base class in the chain. It does some very basic things but is meant to be 'virtual', and so are most its methods, so you'll probably never instantiate it directly. + """Unified Wingman class. - Instead, you'll create a custom wingman that inherits from this (or a another subclass of it) and override its methods if needed. + Handles lifecycle, process loop, audio, command execution, config + save/load, skill management, provider routing, conversation management, + tool execution, context building, and condensation. + + Providers are resolved via :class:`ProviderFactory` into three + interface slots: ``stt``, ``tts``, ``llm``. Heavy orchestration logic + is delegated to extracted service objects. """ + AZURE_SERVICES = { + "tts": None, # kept for potential future use + "whisper": None, + "conversation": None, + } + def __init__( self, name: str, @@ -64,81 +99,104 @@ def __init__( settings: SettingsConfig, audio_player: AudioPlayer, audio_library: AudioLibrary, - whispercpp: Whispercpp, - fasterwhisper: FasterWhisper, - parakeet: Parakeet, - xvasynth: XVASynth, - pocket_tts: PocketTTS, - tower: "Tower", + whispercpp=None, + fasterwhisper=None, + parakeet=None, + xvasynth=None, + pocket_tts=None, + tower: "Tower" = None, ): - """The constructor of the Wingman class. You can override it in your custom wingman. - - Args: - name (str): The name of the wingman. This is the key you gave it in the config, e.g. "atc" - config (WingmanConfig): All "general" config entries merged with the specific Wingman config settings. The Wingman takes precedence and overrides the general config. You can just add new keys to the config and they will be available here. - """ - self.config = config - """All "general" config entries merged with the specific Wingman config settings. The Wingman takes precedence and overrides the general config. You can just add new keys to the config and they will be available here.""" - self.settings = settings - """The general user settings.""" + self.name = name + self.audio_player = audio_player + self.audio_library = audio_library + self.tower = tower self.secret_keeper = SecretKeeper() - """A service that allows you to store and retrieve secrets like API keys. It can prompt the user for secrets if necessary.""" self.secret_keeper.secret_events.subscribe( "secrets_saved", self.handle_secret_saved ) - self.name = name - """The name of the wingman. This is the key you gave it in the config, e.g. "atc".""" - - self.audio_player = audio_player - """A service that allows you to play audio files and add sound effects to them.""" - - self.audio_library = audio_library - """A service that allows you to play and manage audio files from the audio library.""" - - self.execution_start: None | float = None - """Used for benchmarking executon times. The timer is (re-)started whenever the process function starts.""" - + # Shared provider singletons (passed from Tower) + self._shared_providers = { + "whispercpp": whispercpp, + "fasterwhisper": fasterwhisper, + "parakeet": parakeet, + "xvasynth": xvasynth, + "pocket_tts": pocket_tts, + } + + # --- Provider interface slots (populated by validate → ProviderFactory) --- + self.stt: SttInterface | None = None + self.tts: TtsInterface | None = None + self.llm: LlmInterface | None = None + + # --- Backward-compat: keep old attributes for custom wingmen / skills --- self.whispercpp = whispercpp - """A class that handles the communication with the Whispercpp server for transcription.""" - self.fasterwhisper = fasterwhisper - """A class that handles local transcriptions using FasterWhisper.""" - self.parakeet = parakeet - """A class that handles local transcriptions using NVIDIA Parakeet TDT via ONNX Runtime.""" - self.xvasynth = xvasynth - """A class that handles the communication with the XVASynth server for TTS.""" - self.pocket_tts = pocket_tts - """A class that handles the communication with the PocketTTS server for TTS.""" - self.tower = tower - """The Tower instance that manages all Wingmen in the same config dir.""" + # --- Extracted services --- + self.conversation = ConversationManager(config, settings, name) + self.condenser = ConversationCondenser(self.conversation, config, name) + self.context_builder = ContextBuilder(config, settings, name) + self.tool_executor = ToolExecutor(config, settings, name) + # --- Token tracking --- self.last_turn_prompt_tokens: int = 0 self.last_turn_completion_tokens: int = 0 + self.execution_start: None | float = None + self._last_prompt_tokens: int = 0 + # --- Skills --- self.skills: list[Skill] = [] + self.tool_skills: dict[str, Skill] = {} + self.skill_tools: list[dict] = [] + self.skill_registry = SkillRegistry() + + # --- MCP --- + self.mcp_client = McpClient(wingman_name=self.name) + self.mcp_registry = McpRegistry( + self.mcp_client, + wingman_name=self.name, + on_state_changed=self._broadcast_mcp_state_changed, + ) + + # --- Unified capability registry --- + self.capability_registry = CapabilityRegistry( + self.skill_registry, self.mcp_registry + ) + + # --- Local AI / persistent memory --- + self.local_ai_service = None + self.persistent_memory_service = None + self._memory_recall_notified = False + self._background_tasks: set[asyncio.Task] = set() + self._tool_response_compressor = ToolResponseCompressor() + + # --- Conversation state --- + self.last_gpt_call = None + self.instant_responses = [] + self.last_used_instant_responses = [] + + # ──────────────────────────────── Record keys ─────────────────────────────── # def get_record_key(self) -> str | int: - """Returns the activation or "push-to-talk" key for this Wingman.""" return self.config.record_key_codes or self.config.record_key def get_record_mouse_button(self) -> str: - """Returns the activation or "push-to-talk" mouse button for this Wingman.""" return self.config.record_mouse_button def get_record_joystick_button(self) -> str: - """Returns the activation or "push-to-talk" joystick button for this Wingman.""" if not self.config.record_joystick_button: return None return f"{self.config.record_joystick_button.guid}{self.config.record_joystick_button.button}" + # ──────────────────────────────── Secrets ──────────────────────────────────── # + async def handle_secret_saved(self, _secrets: Dict[str, Any]): await printr.print_async( text="Secret saved", @@ -147,26 +205,7 @@ async def handle_secret_saved(self, _secrets: Dict[str, Any]): ) await self.validate() - # ──────────────────────────────────── Hooks ─────────────────────────────────── # - - async def validate(self) -> list[WingmanInitializationError]: - """Use this function to validate params and config before the Wingman is started. - If you add new config sections or entries to your custom wingman, you should validate them here. - - It's a good idea to collect all errors from the base class and not to swallow them first. - - If you return MISSING_SECRET errors, the user will be asked for them. - If you return other errors, your Wingman will not be loaded by Tower. - - Returns: - list[WingmanInitializationError]: A list of errors or an empty list if everything is okay. - """ - return [] - async def retrieve_secret(self, secret_name, errors, is_required=True): - """Use this method to retrieve secrets like API keys from the SecretKeeper. - If the key is missing, the user will be prompted to enter it. - """ try: api_key = await self.secret_keeper.retrieve( requester=self.name, @@ -184,7 +223,7 @@ async def retrieve_secret(self, secret_name, errors, is_required=True): ) except Exception as e: printr.print( - f"Error retrieving secret ''{secret_name}: {e}", + f"Error retrieving secret '{secret_name}': {e}", color=LogType.ERROR, server_only=True, ) @@ -201,14 +240,75 @@ async def retrieve_secret(self, secret_name, errors, is_required=True): return api_key - async def prepare(self): - """This method is called only once when the Wingman is instantiated by Tower. - It is run AFTER validate() and AFTER init_skills() so you can access validated params safely here. + # ──────────────────────────────── Validate ─────────────────────────────────── # + + async def validate(self) -> list[WingmanInitializationError]: + errors: list[WingmanInitializationError] = [] + + try: + factory = ProviderFactory( + config=self.config, + settings=self.settings, + secret_keeper=self.secret_keeper, + shared_providers=self._shared_providers, + wingman_name=self.name, + ) + self.stt = await factory.create_stt(errors) + self.tts = await factory.create_tts(errors) + self.llm = await factory.create_llm(errors) + except Exception as e: + errors.append( + WingmanInitializationError( + wingman_name=self.name, + message=f"Error during provider validation: {str(e)}", + error_type=WingmanInitializationErrorType.UNKNOWN, + ) + ) + printr.print( + f"Error during provider validation: {str(e)}", + color=LogType.ERROR, + server_only=True, + ) + printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) + + return errors + + # ──────────────────────────────── Lifecycle ─────────────────────────────────── # - You can override it if you need to load async data from an API or file.""" + async def prepare(self): + try: + if self.config.features.use_generic_instant_responses: + printr.print( + "Generating AI instant responses...", + color=LogType.WARNING, + server_only=True, + ) + self.threaded_execution(self._generate_instant_responses) + except Exception as e: + await printr.print_async( + f"Error while preparing wingman '{self.name}': {str(e)}", + color=LogType.ERROR, + ) + printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) async def unload(self): - """This method is called when the Wingman is unloaded by Tower. You can override it if you need to clean up resources.""" + # Wait for any background memory extraction tasks to finish + if self._background_tasks: + await asyncio.gather(*self._background_tasks, return_exceptions=True) + self._background_tasks.clear() + + if self.persistent_memory_service: + from services.persistent_memory import MIN_MESSAGES_FOR_EXTRACTION + + if len(self.conversation.messages) >= MIN_MESSAGES_FOR_EXTRACTION: + try: + await self.persistent_memory_service.extract_memories( + self.conversation.messages, generate_summary=True + ) + except Exception: + pass + self.persistent_memory_service.close() + # Unsubscribe from secret events to prevent duplicate handlers self.secret_keeper.secret_events.unsubscribe( "secrets_saved", self.handle_secret_saved @@ -216,10 +316,7 @@ async def unload(self): await self.unload_skills() async def unload_skills(self): - """Call this to trigger unload for skills that were actually prepared/used.""" for skill in self.skills: - # Only unload skills that were actually prepared (activated) - # Skills that were never used don't need cleanup if not skill.is_prepared: continue try: @@ -232,19 +329,270 @@ async def unload_skills(self): printr.print( traceback.format_exc(), color=LogType.ERROR, server_only=True ) + self.tool_skills = {} + self.skill_tools = [] + self.skill_registry.clear() - async def init_skills(self) -> list[WingmanInitializationError]: - """Load all available skills with lazy validation. + async def unload_mcps(self): + await self.mcp_registry.clear() + + # ──────────────────────────────── Memory ──────────────────────────────────── # + + def ensure_memory_initialized(self) -> bool: + if self.persistent_memory_service and not self.config.persistent_memory: + self.persistent_memory_service.close() + self.persistent_memory_service = None + return False + if self.persistent_memory_service: + return True + if self.config.persistent_memory and self.local_ai_service: + from services.persistent_memory import PersistentMemoryService + + self.persistent_memory_service = PersistentMemoryService( + wingman_name=self.name, + local_ai_service=self.local_ai_service, + ) + self.persistent_memory_service.initialize() + return True + return False + + # ──────────────────────────────── MCP ──────────────────────────────────────── # + + def _broadcast_mcp_state_changed(self): + if printr._connection_manager: + printr.ensure_async( + printr._connection_manager.broadcast( + McpStateChangedCommand(wingman_name=self.name) + ) + ) + + async def enable_mcp(self, mcp_name: str) -> tuple[bool, str]: + if not self.mcp_client.is_available: + return False, "MCP SDK not installed." + + if mcp_name in self.mcp_registry.get_connected_server_names(): + return True, f"MCP server '{mcp_name}' is already connected." + + central_mcp_config = self.tower.config_manager.mcp_config + mcp_configs = central_mcp_config.servers if central_mcp_config else [] + + mcp_config = None + for cfg in mcp_configs: + if cfg.name == mcp_name: + mcp_config = cfg + break + + if not mcp_config: + return False, f"MCP server '{mcp_name}' not found in mcp.yaml." + + try: + headers = {} + if mcp_config.headers: + headers.update(mcp_config.headers) + + secret_key = f"mcp_{mcp_config.name}" + api_key = await self.secret_keeper.retrieve( + requester=self.name, + key=secret_key, + prompt_if_missing=False, + ) + if api_key: + if not any( + k.lower() in ["authorization", "api-key", "x-api-key"] + for k in headers.keys() + ): + headers["Authorization"] = f"Bearer {api_key}" + + default_timeout = 60.0 if mcp_config.type.value == "stdio" else 30.0 + timeout = ( + float(mcp_config.timeout) if mcp_config.timeout else default_timeout + ) + + connection = await asyncio.wait_for( + self.mcp_registry.register_server( + config=mcp_config, + headers=headers if headers else None, + ), + timeout=timeout, + ) + + if connection.is_connected: + tool_count = len(connection.tools) + return True, f"MCP server '{mcp_name}' enabled with {tool_count} tools." + else: + error = connection.error or "Connection failed." + return False, f"MCP server '{mcp_name}' failed to connect: {error}" + + except asyncio.TimeoutError: + error_msg = f"Connection timed out ({int(timeout)}s)." + self.mcp_registry.set_server_error(mcp_name, error_msg) + return False, f"MCP server '{mcp_name}': {error_msg}" - Skills are loaded but NOT validated during init. Validation happens - on first activation via the SkillRegistry. User config overrides from - self.config.skills are merged with default configs. + except Exception as e: + error_msg = f"Error enabling MCP '{mcp_name}': {str(e)}" + await printr.print_async(error_msg, color=LogType.ERROR) + printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) + return False, error_msg - Platform-incompatible skills are skipped entirely. - """ + async def disable_mcp(self, mcp_name: str) -> tuple[bool, str]: + if mcp_name not in self.mcp_registry.get_connected_server_names(): + return True, f"MCP server '{mcp_name}' is already disconnected." + + try: + await self.mcp_registry.unregister_server(mcp_name) + return True, f"MCP server '{mcp_name}' disabled." + + except Exception as e: + error_msg = f"Error disabling MCP '{mcp_name}': {str(e)}" + await printr.print_async(error_msg, color=LogType.ERROR) + printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) + return False, error_msg + + async def init_mcps(self) -> list[WingmanInitializationError]: + errors = [] + + if not self.mcp_client.is_available: + printr.print( + f"[{self.name}] MCP SDK not installed, skipping MCP initialization.", + color=LogType.WARNING, + server_only=True, + ) + return errors + + await self.unload_mcps() + + central_mcp_config = self.tower.config_manager.mcp_config + mcp_configs = central_mcp_config.servers if central_mcp_config else [] + if not mcp_configs: + return errors + + discoverable_mcps = self.config.discoverable_mcps + mcps_to_connect = [mcp for mcp in mcp_configs if mcp.name in discoverable_mcps] + + if not mcps_to_connect: + return errors + + async def connect_mcp(mcp_config): + local_errors = [] + try: + headers = {} + if mcp_config.headers: + headers.update(mcp_config.headers) + + secret_key = f"mcp_{mcp_config.name}" + api_key = await self.secret_keeper.retrieve( + requester=self.name, + key=secret_key, + prompt_if_missing=False, + ) + if api_key: + printr.print( + f"MCP secret '{secret_key}' found ({len(api_key)} chars)", + color=LogType.INFO, + source_name=self.name, + server_only=True, + ) + if not any( + k.lower() in ["authorization", "api-key", "x-api-key"] + for k in headers.keys() + ): + headers["Authorization"] = f"Bearer {api_key}" + + default_timeout = 60.0 if mcp_config.type.value == "stdio" else 30.0 + timeout = ( + float(mcp_config.timeout) if mcp_config.timeout else default_timeout + ) + + try: + connection = await asyncio.wait_for( + self.mcp_registry.register_server( + config=mcp_config, + headers=headers if headers else None, + ), + timeout=timeout, + ) + except asyncio.TimeoutError: + error_msg = f"MCP '{mcp_config.display_name}' connection timed out ({int(timeout)}s)." + printr.print( + error_msg, + color=LogType.WARNING, + source_name=self.name, + server_only=True, + ) + local_errors.append( + WingmanInitializationError( + wingman_name=self.name, + message=error_msg, + error_type=WingmanInitializationErrorType.MCP_CONNECTION_FAILED, + ) + ) + return (False, None, local_errors) + + if connection.is_connected: + return ( + True, + f"{mcp_config.display_name} ({len(connection.tools)} tools)", + local_errors, + ) + else: + error_msg = f"MCP '{mcp_config.display_name}' failed to connect: {connection.error}" + local_errors.append( + WingmanInitializationError( + wingman_name=self.name, + message=error_msg, + error_type=WingmanInitializationErrorType.MCP_CONNECTION_FAILED, + ) + ) + return (False, None, local_errors) + + except Exception as e: + error_msg = f"MCP '{mcp_config.name}' initialization error: {str(e)}" + printr.print( + error_msg, + color=LogType.ERROR, + source_name=self.name, + server_only=True, + ) + printr.print( + traceback.format_exc(), color=LogType.ERROR, server_only=True + ) + local_errors.append( + WingmanInitializationError( + wingman_name=self.name, + message=error_msg, + error_type=WingmanInitializationErrorType.MCP_CONNECTION_FAILED, + ) + ) + return (False, None, local_errors) + + connection_tasks = [connect_mcp(mcp) for mcp in mcps_to_connect] + results = await asyncio.gather(*connection_tasks) + + connected_count = 0 + connected_names = [] + for success, connection_info, mcp_errors in results: + if success: + connected_count += 1 + connected_names.append(connection_info) + errors.extend(mcp_errors) + + if connected_count > 0: + await printr.print_async( + f"Discoverable MCP servers connected ({connected_count}): {', '.join(connected_names)}", + color=LogType.WINGMAN, + source=LogSource.WINGMAN, + source_name=self.name, + server_only=not self.settings.debug_mode, + ) + + return errors + + # ──────────────────────────────── Skills ───────────────────────────────────── # + + async def init_skills(self) -> list[WingmanInitializationError]: import sys - current_platform = sys.platform # 'win32', 'darwin', 'linux' + current_platform = sys.platform platform_map = {"win32": "windows", "darwin": "darwin", "linux": "linux"} normalized_platform = platform_map.get(current_platform, current_platform) @@ -254,18 +602,13 @@ async def init_skills(self) -> list[WingmanInitializationError]: errors = [] self.skills = [] - # Build a lookup of user config overrides by skill folder name - # The key must be the folder name (e.g., 'star_head') not the class name (e.g., 'StarHead') user_skill_configs: dict[str, "SkillConfig"] = {} if self.config.skills: for skill_config in self.config.skills: folder_name = _get_skill_folder_from_module(skill_config.module) user_skill_configs[folder_name] = skill_config - # Get all available skill configs available_skills = ModuleManager.read_available_skill_configs() - - # Get discoverable skills list (whitelist) discoverable_skills = self.config.discoverable_skills for ( @@ -275,19 +618,14 @@ async def init_skills(self) -> list[WingmanInitializationError]: _is_local, ) in available_skills: try: - # Load default skill config first to get the display name skill_config_dict = ModuleManager.read_config(skill_config_path) if not skill_config_dict: continue - # Import SkillConfig here to avoid circular imports from api.interface import SkillConfig - # Check if user has overrides for this skill if skill_folder_name in user_skill_configs: - # Merge user overrides into default config user_config = user_skill_configs[skill_folder_name] - # User config takes precedence - merge custom_properties especially if user_config.custom_properties: skill_config_dict["custom_properties"] = [ prop.model_dump() for prop in user_config.custom_properties @@ -297,11 +635,9 @@ async def init_skills(self) -> list[WingmanInitializationError]: skill_config = SkillConfig(**skill_config_dict) - # Check if skill is discoverable for this wingman (whitelist - must be in list) if skill_config.name not in discoverable_skills: continue - # Check platform compatibility BEFORE loading the module if skill_config.platforms: if normalized_platform not in skill_config.platforms: printr.print( @@ -311,18 +647,13 @@ async def init_skills(self) -> list[WingmanInitializationError]: ) continue - # Load the skill module skill = ModuleManager.load_skill( config=skill_config, settings=self.settings, wingman=self, ) if skill: - # Set up skill methods skill.threaded_execution = self.threaded_execution - - # Add to skills list WITHOUT validation - # Validation will happen lazily on first activation self.skills.append(skill) await self.prepare_skill(skill) @@ -344,7 +675,6 @@ async def init_skills(self) -> list[WingmanInitializationError]: ) ) - # Log summary of discoverable skills for this wingman if self.skills: skill_names = [s.config.name for s in self.skills] await printr.print_async( @@ -358,41 +688,58 @@ async def init_skills(self) -> list[WingmanInitializationError]: return errors async def prepare_skill(self, skill: Skill): - """This method is called only once when the Skill is instantiated. - It is run AFTER validate() so you can access validated params safely here. + try: + for tool_name, tool in skill.get_tools(): + self.tool_skills[tool_name] = skill + self.skill_tools.append(tool) - You can override it if you need to react on data of this skill.""" + self.skill_registry.register_skill(skill) - async def unprepare_skill(self, skill: Skill): - """Remove a skill's registration. Called when a skill is disabled. - - Override in subclass to clean up skill-specific registrations.""" - pass + if skill.config.auto_activate: + success, message = await skill.ensure_activated() + if not success: + await printr.print_async( + f"Auto-activated skill '{skill.config.display_name}' failed to activate: {message}", + color=LogType.ERROR, + ) + except Exception as e: + await printr.print_async( + f"Error while preparing skill '{skill.name}': {str(e)}", + color=LogType.ERROR, + ) + printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - async def enable_skill(self, skill_name: str) -> tuple[bool, str]: - """Enable a single skill without reinitializing all skills. + skill.llm_call = self.actual_llm_call - Args: - skill_name: The display name of the skill to enable + async def unprepare_skill(self, skill: Skill): + try: + for tool_name, _ in skill.get_tools(): + self.tool_skills.pop(tool_name, None) + self.skill_tools = [ + t + for t in self.skill_tools + if t.get("function", {}).get("name") != tool_name + ] + self.skill_registry.unregister_skill(skill.name) + except Exception as e: + await printr.print_async( + f"Error while unpreparing skill '{skill.name}': {str(e)}", + color=LogType.ERROR, + ) + printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - Returns: - (success, message) tuple - """ + async def enable_skill(self, skill_name: str) -> tuple[bool, str]: import sys current_platform = sys.platform platform_map = {"win32": "windows", "darwin": "darwin", "linux": "linux"} normalized_platform = platform_map.get(current_platform, current_platform) - # Check if skill is already enabled for existing_skill in self.skills: if existing_skill.config.name == skill_name: return True, f"Skill '{skill_name}' is already enabled." - # Find the skill config available_skills = ModuleManager.read_available_skill_configs() - - # Build user config lookup by skill folder name user_skill_configs: dict[str, "SkillConfig"] = {} if self.config.skills: for skill_config in self.config.skills: @@ -412,7 +759,6 @@ async def enable_skill(self, skill_name: str) -> tuple[bool, str]: from api.interface import SkillConfig - # Apply user overrides if skill_folder_name in user_skill_configs: user_config = user_skill_configs[skill_folder_name] if user_config.custom_properties: @@ -427,7 +773,6 @@ async def enable_skill(self, skill_name: str) -> tuple[bool, str]: if skill_config.name != skill_name: continue - # Check platform compatibility if skill_config.platforms: if normalized_platform not in skill_config.platforms: return ( @@ -435,7 +780,6 @@ async def enable_skill(self, skill_name: str) -> tuple[bool, str]: f"Skill '{skill_name}' is not supported on {normalized_platform}.", ) - # Load and register the skill skill = ModuleManager.load_skill( config=skill_config, settings=self.settings, @@ -464,15 +808,6 @@ async def enable_skill(self, skill_name: str) -> tuple[bool, str]: return False, f"Skill '{skill_name}' not found." async def disable_skill(self, skill_name: str) -> tuple[bool, str]: - """Disable a single skill without reinitializing all skills. - - Args: - skill_name: The display name of the skill to disable - - Returns: - (success, message) tuple - """ - # Find the skill in our list skill_to_remove = None for skill in self.skills: if skill.config.name == skill_name: @@ -483,13 +818,8 @@ async def disable_skill(self, skill_name: str) -> tuple[bool, str]: return True, f"Skill '{skill_name}' is already deactivated." try: - # Unload the skill (cleanup resources, unsubscribe events) await skill_to_remove.unload() - - # Remove from skill list self.skills.remove(skill_to_remove) - - # Remove skill-specific registrations (tools, registry, etc.) await self.unprepare_skill(skill_to_remove) printr.print( @@ -505,36 +835,14 @@ async def disable_skill(self, skill_name: str) -> tuple[bool, str]: printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) return False, error_msg - async def reset_conversation_history(self): - """This function is called when the user triggers the ResetConversationHistory command. - It's a global command that should be implemented by every Wingman that keeps a message history. - """ - - # ──────────────────────────── The main processing loop ──────────────────────────── # + # ──────────────────────────── The main processing loop ──────────────────────── # async def process(self, audio_input_wav: str = None, transcript: str = None, images: list[tuple[str, str]] = None): - """The main method that gets called when the wingman is activated. This method controls what your wingman actually does and you can override it if you want to. - - The base implementation here triggers the transcription and processing of the given audio input. - If you don't need even transcription, you can just override this entire process method. If you want transcription but then do something in addition, you can override the listed hooks. - - Async so you can do async processing, e.g. send a request to an API. - - Args: - audio_input_wav (str): The path to the audio file that contains the user's speech. This is a recording of what you you said. - - Hooks: - - async _transcribe: transcribe the audio to text - - async _get_response_for_transcript: process the transcript and return a text response - - async play_to_user: do something with the response, e.g. play it as audio - """ - try: process_result = None benchmark_transcribe = None if not transcript: - # transcribe the audio. benchmark_transcribe = Benchmark(label="Voice transcription") transcript = await self._transcribe(audio_input_wav) @@ -554,9 +862,6 @@ async def process(self, audio_input_wav: str = None, transcript: str = None, ima additional_data=additional_data, ) - # Further process the transcript. - # Return a string that is the "answer" to your passed transcript. - benchmark_llm = Benchmark(label="Command/AI Processing") process_result, instant_response, skill, interrupt = ( await self._get_response_for_transcript( @@ -588,8 +893,6 @@ async def process(self, audio_input_wav: str = None, transcript: str = None, ima if process_result: if self.settings.streamer_mode: self.tower.save_last_message(self.name, process_result) - - # the last step in the chain. You'll probably want to play the response to the user as audio using a TTS provider or mechanism of your choice. await self.play_to_user(str(process_result), not interrupt) except Exception as e: await printr.print_async( @@ -598,33 +901,411 @@ async def process(self, audio_input_wav: str = None, transcript: str = None, ima ) printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - # ───────────────── virtual methods / hooks ───────────────── # + # ───────────────── Transcription ───────────────── # async def _transcribe(self, audio_input_wav: str) -> str | None: - """Transcribes the audio to text. You can override this method if you want to use a different transcription service. - - Args: - audio_input_wav (str): The path to the audio file that contains the user's speech. This is a recording of what you you said. - - Returns: - str | None: The transcript of the audio file and the detected language as locale (if determined). - """ + if not self.stt: + return None + try: + transcript = await self.stt.transcribe(filename=audio_input_wav) + if transcript: + # Wingman Pro might return a serialized dict instead of a real object + if isinstance(transcript, dict): + return transcript.get("_text") + return transcript.text + except Exception as e: + await printr.print_async( + f"Error during transcription using '{self.config.features.stt_provider}': {str(e)}", + color=LogType.ERROR, + ) + printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) return None + # ───────────────── Response orchestration ───────────────── # + async def _get_response_for_transcript( self, transcript: str, benchmark: Benchmark, images: list[tuple[str, str]] = None - ) -> tuple[str | None, str | None, Skill | None, bool | None]: - """Processes the transcript and return a response as text. This where you'll do most of your work. - Pass the transcript to AI providers and build a conversation. Call commands or APIs. Play temporary results to the user etc. + ) -> tuple[str | None, str | None, Skill | None, bool]: + self.ensure_memory_initialized() + + await self.add_user_message(transcript, images=images) + + benchmark.start_snapshot("Instant activation commands") + instant_response, instant_command_executed = await self._try_instant_activation( + transcript=transcript + ) + if instant_response: + await self.add_assistant_message(instant_response) + benchmark.finish_snapshot() + if instant_response == ".": + instant_response = None + return instant_response, instant_response, None, True + benchmark.finish_snapshot() + + llm_processing_time_ms = 0.0 + tool_execution_time_ms = 0.0 + tool_timings: list[tuple[str, float]] = [] + + llm_start = time.perf_counter() + completion = await self._llm_call(instant_command_executed is False) + llm_processing_time_ms += (time.perf_counter() - llm_start) * 1000 + + if completion is None: + self._add_benchmark_snapshot( + benchmark, "LLM Processing", llm_processing_time_ms + ) + return None, None, None, True + + response_message, tool_calls, usage = await self._process_completion( + completion, instant_command_executed is False + ) + + turn_prompt_tokens = usage[0] + turn_completion_tokens = usage[1] + self._last_prompt_tokens = turn_prompt_tokens + + is_waiting_response_needed, is_summarize_needed = await self._add_gpt_response( + response_message, tool_calls + ) + interrupt = True + + while tool_calls: + if is_waiting_response_needed: + message = None + if response_message.content: + message = response_message.content + elif self.instant_responses: + message = self._get_random_filler() + is_summarize_needed = True + if message: + self.threaded_execution(self.play_to_user, message, interrupt) + await printr.print_async( + f"{message}", + color=LogType.POSITIVE, + source=LogSource.WINGMAN, + source_name=self.name, + skill_name="", + ) + interrupt = False + else: + is_summarize_needed = True + else: + is_summarize_needed = True + + tool_start = time.perf_counter() + instant_response, skill, iteration_timings = await self._handle_tool_calls( + tool_calls + ) + tool_execution_time_ms += (time.perf_counter() - tool_start) * 1000 + tool_timings.extend(iteration_timings) + + if instant_response: + await self._trim_tool_responses() + self._add_benchmark_snapshot( + benchmark, "LLM Processing", llm_processing_time_ms + ) + if tool_execution_time_ms > 0: + self._add_tool_execution_snapshot( + benchmark, tool_execution_time_ms, tool_timings + ) + await self._broadcast_token_usage( + turn_prompt_tokens, turn_completion_tokens + ) + return None, instant_response, None, interrupt + + if is_summarize_needed: + llm_start = time.perf_counter() + completion = await self._llm_call(True) + llm_processing_time_ms += (time.perf_counter() - llm_start) * 1000 + + if completion is None: + await self._trim_tool_responses() + self._add_benchmark_snapshot( + benchmark, "LLM Processing", llm_processing_time_ms + ) + if tool_execution_time_ms > 0: + self._add_tool_execution_snapshot( + benchmark, tool_execution_time_ms, tool_timings + ) + await self._broadcast_token_usage( + turn_prompt_tokens, turn_completion_tokens + ) + return None, None, None, True + + response_message, tool_calls, usage = await self._process_completion( + completion + ) + turn_prompt_tokens = usage[0] + turn_completion_tokens += usage[1] + self._last_prompt_tokens = turn_prompt_tokens + + is_waiting_response_needed, is_summarize_needed = ( + await self._add_gpt_response(response_message, tool_calls) + ) + if tool_calls: + interrupt = False + elif is_waiting_response_needed: + await self._trim_tool_responses() + self._add_benchmark_snapshot( + benchmark, "LLM Processing", llm_processing_time_ms + ) + if tool_execution_time_ms > 0: + self._add_tool_execution_snapshot( + benchmark, tool_execution_time_ms, tool_timings + ) + await self._broadcast_token_usage( + turn_prompt_tokens, turn_completion_tokens + ) + return None, None, None, interrupt + + await self._trim_tool_responses() + + self._add_benchmark_snapshot( + benchmark, "LLM Processing", llm_processing_time_ms + ) + if tool_execution_time_ms > 0: + self._add_tool_execution_snapshot( + benchmark, tool_execution_time_ms, tool_timings + ) + await self._broadcast_token_usage(turn_prompt_tokens, turn_completion_tokens) + return response_message.content, response_message.content, None, interrupt + + # ───────────────── LLM call ───────────────── # + + async def actual_llm_call(self, messages, tools: list[dict] = None): + if not self.llm: + await printr.print_async( + f"No LLM provider configured for wingman '{self.name}'.", + color=LogType.ERROR, + source=LogSource.WINGMAN, + source_name=self.name, + ) + return None + + try: + completion = await self.llm.ask(messages=messages, tools=tools) + except APIConnectionError as e: + provider = self.config.features.conversation_provider.value + cause = e.__cause__ + detail = str(cause) if cause else str(e) + message = f"Could not connect to {provider}: {detail}" + await printr.print_async( + message, + color=LogType.ERROR, + source=LogSource.WINGMAN, + source_name=self.name, + ) + printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) + return None + except Exception as e: + await printr.print_async( + f"Error during LLM call: {str(e)}", + color=LogType.ERROR, + source=LogSource.WINGMAN, + source_name=self.name, + ) + printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) + return None + + return completion + + async def _llm_call(self, allow_tool_calls: bool = True): + thiscall = time.time() + self.last_gpt_call = thiscall + + tools = self.build_tools() if allow_tool_calls else None + + if self.settings.debug_mode: + await printr.print_async( + f"Calling LLM with {len(self.conversation.messages)} messages (excluding context) and {len(tools) if tools else 0} tools.", + color=LogType.INFO, + ) + messages = self.conversation.messages.copy() + await self.add_context(messages) - Args: - transcript (str): The user's spoken text transcribed as text. + completion = await self.actual_llm_call(messages, tools) - Returns: - A tuple of strings representing the response to a function call and/or an instant response. - """ - return "", "", None, None + if self.last_gpt_call != thiscall: + await printr.print_async( + "LLM call was cancelled due to a new call.", color=LogType.WARNING + ) + return None + + return completion + + async def _process_completion( + self, completion: ChatCompletion, allow_tool_calls: bool = True + ): + response_message = completion.choices[0].message + + content = response_message.content + if content is None: + response_message.content = "" + + if not allow_tool_calls: + response_message.tool_calls = None + + if response_message.tool_calls: + response_message.tool_calls = await self.tool_executor.fix_tool_calls( + response_message.tool_calls, self.get_command + ) + + prompt_tokens = 0 + completion_tokens = 0 + if completion.usage: + prompt_tokens = completion.usage.prompt_tokens or 0 + completion_tokens = completion.usage.completion_tokens or 0 + + return ( + response_message, + response_message.tool_calls, + (prompt_tokens, completion_tokens), + ) + + # ───────────────── Tool calls ───────────────── # + + async def _handle_tool_calls(self, tool_calls): + return await self.tool_executor.handle_tool_calls( + tool_calls, + tool_skills=self.tool_skills, + skill_registry=self.skill_registry, + mcp_registry=self.mcp_registry, + capability_registry=self.capability_registry, + persistent_memory_service=self.persistent_memory_service, + get_command_fn=self.get_command, + execute_command_fn=self._execute_command, + play_to_user_fn=self.play_to_user, + local_ai_service=self.local_ai_service, + update_tool_response_fn=self.conversation.update_tool_response, + add_tool_response_fn=self.conversation.add_tool_response, + pending_tool_calls=self.conversation.pending_tool_calls, + ) + + async def execute_command_by_function_call( + self, function_name: str, function_args: dict[str, any] + ) -> tuple[str, str | None, Skill | None, str | None]: + """Public API kept for backward compatibility with skills.""" + return await self.tool_executor.execute_by_function_call( + function_name, + function_args, + tool_skills=self.tool_skills, + skill_registry=self.skill_registry, + mcp_registry=self.mcp_registry, + capability_registry=self.capability_registry, + persistent_memory_service=self.persistent_memory_service, + get_command_fn=self.get_command, + execute_command_fn=self._execute_command, + play_to_user_fn=self.play_to_user, + ) + + # ───────────────── Conversation delegation ───────────────── # + + async def _add_gpt_response(self, message, tool_calls) -> tuple[bool, bool]: + return await self.conversation.add_gpt_response( + message, + tool_calls, + skills=self.skills, + skill_registry=self.skill_registry, + tool_skills=self.tool_skills, + ) + + async def _trim_tool_responses(self, max_tokens: int = 500): + await self.conversation.trim_tool_responses( + max_tokens=max_tokens, + is_condensing=self.condenser.is_condensing, + ) + + async def add_user_message(self, content: str, images: list[tuple[str, str]] = None): + self._memory_recall_notified = False + self.context_builder.reset_memory_notification() + await self.conversation.add_user_message( + content, + images=images, + skills=self.skills, + condense_fn=lambda: self.condenser.maybe_condense(self.local_ai_service), + ) + + async def add_assistant_message(self, content: str): + await self.conversation.add_assistant_message( + content, skills=self.skills + ) + + async def add_forced_assistant_command_calls(self, commands: list[CommandConfig]): + await self.conversation.add_forced_assistant_command_calls( + commands, + skills=self.skills, + skill_registry=self.skill_registry, + tool_skills=self.tool_skills, + ) + + async def reset_conversation_history(self): + if self.persistent_memory_service and len(self.conversation.messages) >= 4: + try: + await self.persistent_memory_service.extract_memories( + self.conversation.messages, generate_summary=True + ) + except Exception: + pass + + await self.conversation.reset() + self._last_prompt_tokens = 0 + self.skill_registry.reset_activations() + self.mcp_registry.reset_activations() + + def get_conversation_messages(self, strip_nulls: bool = True) -> list[dict]: + return self.conversation.get_conversation_messages(strip_nulls=strip_nulls) + + # ─── Backward-compat: expose messages directly for skills/services that access it ─── # + + @property + def messages(self) -> list: + return self.conversation.messages + + @messages.setter + def messages(self, value: list): + self.conversation.messages = value + + @property + def conversation_summary(self) -> str: + return self.conversation.conversation_summary + + @conversation_summary.setter + def conversation_summary(self, value: str): + self.conversation.conversation_summary = value + + @property + def pending_tool_calls(self) -> list: + return self.conversation.pending_tool_calls + + @property + def _is_condensing(self) -> bool: + return self.condenser.is_condensing + + # ───────────────── Context ───────────────── # + + async def get_context(self): + config_dir_name = None + if self.tower and self.tower.config_dir and self.tower.config_dir.name: + config_dir_name = self.tower.config_dir.name + + return await self.context_builder.build( + skills=self.skills, + skill_registry=self.skill_registry, + conversation_summary=self.conversation.conversation_summary, + persistent_memory_service=self.persistent_memory_service, + messages=self.conversation.messages, + config_dir_name=config_dir_name, + ) + + def get_last_context(self) -> str: + return self.context_builder.get_last_context() + + async def add_context(self, messages): + context = await self.get_context() + messages.insert(0, {"role": "system", "content": context}) + + # ───────────────── TTS / play_to_user ───────────────── # async def play_to_user( self, @@ -632,29 +1313,406 @@ async def play_to_user( no_interrupt: bool = False, sound_config: Optional[SoundConfig] = None, ): - """You'll probably want to play the response to the user as audio using a TTS provider or mechanism of your choice. + if sound_config: + printr.print( + "Using custom sound config for playback", LogType.INFO, server_only=True + ) + else: + sound_config = self.config.sound - Args: - text (str): The response of your _get_response_for_transcript. This is usually the "response" from conversation with the AI. - no_interrupt (bool): prevent interrupting the audio playback - sound_config (SoundConfig): An optional sound configuration to use for the playback. If unset, the Wingman's sound config is used. - """ - pass + text, contains_links, contains_code_blocks = cleanup_text(text) - # ───────────────────────────────── Commands ─────────────────────────────── # + if no_interrupt and self.audio_player.is_playing: + while self.audio_player.is_playing: + await asyncio.sleep(0.1) - def get_command(self, command_name: str) -> CommandConfig | None: - """Extracts the command with the given name + changed_text = text + for skill in self.skills: + if skill.is_prepared: + changed_text = await skill.on_play_to_user(text, sound_config) + if changed_text != text: + printr.print( + f"Skill '{skill.config.display_name}' modified the text to: '{changed_text}'", + LogType.INFO, + ) + text = changed_text - Args: - command_name (str): the name of the command you used in the config + if sound_config.volume == 0.0: + printr.print( + "Volume modifier is set to 0. Skipping TTS processing.", + LogType.WARNING, + server_only=True, + ) + return - Returns: - {}: The command object from the config - """ + if "{SKIP-TTS}" in text: + printr.print( + "Skip TTS phrase found in input. Skipping TTS processing.", + LogType.WARNING, + server_only=True, + ) + return + + if not self.tts: + printr.print( + f"No TTS provider configured for wingman '{self.name}'.", + LogType.WARNING, + server_only=True, + ) + return + + try: + await self.tts.play_audio( + text=text, + sound_config=sound_config, + audio_player=self.audio_player, + wingman_name=self.name, + ) + except Exception as e: + await printr.print_async( + f"Error during TTS playback: {str(e)}", color=LogType.ERROR + ) + printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) + + # ───────────────── Image generation ───────────────── # + + async def generate_image(self, text: str) -> str: + if ( + self.config.features.image_generation_provider + == ImageGenerationProvider.WINGMAN_PRO + ): + try: + from providers.wingman_subscription import WingmanSubscription + + wingman_pro = WingmanSubscription( + wingman_name=self.name, settings=self.settings.wingman_pro + ) + return await wingman_pro.generate_image(text) + except Exception as e: + await printr.print_async( + f"Error during image generation: {str(e)}", color=LogType.ERROR + ) + printr.print( + traceback.format_exc(), color=LogType.ERROR, server_only=True + ) + return "" + + # ───────────────── Instant activation ───────────────── # + + async def _try_instant_activation(self, transcript: str) -> tuple[str, bool]: + commands = await self._execute_instant_activation_command(transcript) + if commands: + await self.add_forced_assistant_command_calls(commands) + responses = [] + for command in commands: + if command.responses: + responses.append(self._select_instant_command_response(command)) + + if len(responses) == len(commands): + responses = list(dict.fromkeys(responses)) + responses = [ + response + "." if not response.endswith(".") else response + for response in responses + ] + return " ".join(responses), True + + return None, True + + return None, False + + # ───────────────── Instant responses ───────────────── # + + async def _generate_instant_responses(self) -> None: + context = await self.get_context() + messages = [ + { + "role": "system", + "content": """ + Generate a list in JSON format of at least 20 short direct text responses. + Make sure the response only contains the JSON, no additional text. + They must fit the described character in the given context by the user. + Every generated response must be generally usable in every situation. + Responses must show its still in progress and not in a finished state. + The user request this response is used on is unknown. Therefore it must be generic. + Good examples: + - "Processing..." + - "Stand by..." + + Bad examples: + - "Generating route..." (too specific) + - "I'm sorry, I can't do that." (too negative) + + Response example: + [ + "OK", + "Generating results...", + "Roger that!", + "Stand by..." + ] + """, + }, + {"role": "user", "content": context}, + ] + try: + completion = await self.actual_llm_call(messages) + if completion is None: + return + if completion.choices[0].message.content: + retry_limit = 3 + retry_count = 1 + valid = False + while not valid and retry_count <= retry_limit: + try: + responses = json.loads(completion.choices[0].message.content) + valid = True + for response in responses: + if response not in self.instant_responses: + self.instant_responses.append(str(response)) + except json.JSONDecodeError: + messages.append(completion.choices[0].message) + messages.append( + { + "role": "user", + "content": "It was tried to handle the response in its entirety as a JSON string. Fix response to be a pure, valid JSON, it was not convertable.", + } + ) + if retry_count <= retry_limit: + completion = await self.actual_llm_call(messages) + retry_count += 1 + except Exception as e: + await printr.print_async( + f"Error while generating instant responses: {str(e)}", + color=LogType.ERROR, + ) + printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) + + def _get_random_filler(self): + if len(self.last_used_instant_responses) > 2: + self.last_used_instant_responses = self.last_used_instant_responses[-2:] + + random_index = random.randint(0, len(self.instant_responses) - 1) + while random_index in self.last_used_instant_responses: + random_index = random.randint(0, len(self.instant_responses) - 1) + + self.last_used_instant_responses.append(random_index) + return self.instant_responses[random_index] + + # ───────────────── Benchmarks / token usage ───────────────── # + + def _add_benchmark_snapshot( + self, benchmark: Benchmark, label: str, execution_time_ms: float + ): + if execution_time_ms >= 1000: + formatted_time = f"{execution_time_ms/1000:.1f}s" + else: + formatted_time = f"{int(execution_time_ms)}ms" + + from api.interface import BenchmarkResult + + benchmark.snapshots.append( + BenchmarkResult( + label=label, + execution_time_ms=execution_time_ms, + formatted_execution_time=formatted_time, + ) + ) + + def _add_tool_execution_snapshot( + self, + benchmark: Benchmark, + total_time_ms: float, + tool_timings: list[tuple[str, float]], + ): + from api.interface import BenchmarkResult + + if total_time_ms >= 1000: + formatted_time = f"{total_time_ms/1000:.1f}s" + else: + formatted_time = f"{int(total_time_ms)}ms" + + nested_snapshots = [] + for label, time_ms in tool_timings: + if time_ms >= 1000: + fmt = f"{time_ms/1000:.1f}s" + else: + fmt = f"{int(time_ms)}ms" + nested_snapshots.append( + BenchmarkResult( + label=label, + execution_time_ms=time_ms, + formatted_execution_time=fmt, + ) + ) + + benchmark.snapshots.append( + BenchmarkResult( + label="Tool Execution", + execution_time_ms=total_time_ms, + formatted_execution_time=formatted_time, + snapshots=nested_snapshots if nested_snapshots else None, + ) + ) + + async def _broadcast_token_usage(self, prompt_tokens: int, completion_tokens: int): + is_local = ( + self.config.features.conversation_provider == ConversationProvider.LOCAL_LLM + ) + + if is_local and prompt_tokens == 0: + prompt_tokens = sum( + count_tokens( + msg["content"] + if isinstance(msg.get("content"), str) + else str(msg.get("content", "")) + ) + for msg in self.conversation.messages + ) + if is_local and completion_tokens == 0 and self.conversation.messages: + last = self.conversation.messages[-1] + if last.get("role") == "assistant": + content = last.get("content", "") + completion_tokens = count_tokens( + content if isinstance(content, str) else str(content) + ) + + self.last_turn_prompt_tokens = prompt_tokens + self.last_turn_completion_tokens = completion_tokens + if prompt_tokens == 0 and completion_tokens == 0: + return + if not printr._connection_manager: + return + + from api.commands import ConversationTokenUsageCommand + + await printr._connection_manager.broadcast( + ConversationTokenUsageCommand( + wingman_name=self.name, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + is_local=is_local, + ) + ) + + # ───────────────── Build tools ───────────────── # + + def build_tools(self) -> list[dict]: + def _command_has_effective_actions(command: CommandConfig) -> bool: + if command.is_system_command: + return True + if not command.actions: + return False + for action in command.actions: + if not action: + continue + if ( + action.keyboard is not None + or action.mouse is not None + or action.joystick is not None + or action.audio is not None + or action.write is not None + or action.wait is not None + ): + return True + return False + + commands = [ + command.name + for command in self.config.commands + if (not command.force_instant_activation) + and _command_has_effective_actions(command) + ] + tools: list[dict] = [] + if commands: + tools.append( + { + "type": "function", + "function": { + "name": "execute_command", + "description": "Executes a command", + "parameters": { + "type": "object", + "properties": { + "command_name": { + "type": "string", + "description": "The name of the command to execute", + "enum": commands, + }, + }, + "required": ["command_name"], + }, + }, + } + ) + + for _, tool in self.capability_registry.get_meta_tools(): + tools.append(tool) + + for _, tool in self.skill_registry.get_active_tools(): + tools.append(tool) + + for _, tool in self.mcp_registry.get_active_tools(): + tools.append(tool) + + if self.persistent_memory_service: + tools.append({ + "type": "function", + "function": { + "name": "memory_remember", + "description": "Store an important fact or detail for future reference. Use when the user explicitly asks you to remember something.", + "parameters": { + "type": "object", + "properties": { + "text": { + "type": "string", + "description": "The fact or detail to remember.", + }, + }, + "required": ["text"], + }, + }, + }) + tools.append({ + "type": "function", + "function": { + "name": "memory_recall", + "description": "Search your memory for relevant information. Use when the user asks what you remember or know about a topic.", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "What to search for in memory.", + }, + }, + "required": ["query"], + }, + }, + }) + tools.append({ + "type": "function", + "function": { + "name": "memory_forget", + "description": "Remove a specific memory. Use when the user explicitly asks you to forget something.", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Description of the memory to forget.", + }, + }, + "required": ["query"], + }, + }, + }) + + return tools + + # ───────────────── Commands ─────────────────────────────── # + + def get_command(self, command_name: str) -> CommandConfig | None: if self.config.commands is None: return None - command = next( (item for item in self.config.commands if item.name == command_name), None, @@ -662,34 +1720,15 @@ def get_command(self, command_name: str) -> CommandConfig | None: return command def _select_instant_command_response(self, command: CommandConfig) -> str | None: - """Returns one of the configured responses of the command. This base implementation returns a random one. - - Args: - command (dict): The command object from the config - - Returns: - str: A random response from the command's responses list in the config. - """ command_responses = command.responses if (command_responses is None) or (len(command_responses) == 0): return None - return random.choice(command_responses) async def _execute_instant_activation_command( self, transcript: str ) -> list[CommandConfig] | None: - """Uses a fuzzy string matching algorithm to match the transcript to a configured instant_activation command and executes it immediately. - - Args: - transcript (text): What the user said, transcripted to text. Needs to be similar to one of the defined instant_activation phrases to work. - - Returns: - {} | None: The executed instant_activation command. - """ - try: - # create list with phrases pointing to commands commands_by_instant_activation = {} for command in self.config.commands: if command.instant_activation: @@ -701,7 +1740,6 @@ async def _execute_instant_activation_command( else: commands_by_instant_activation[phrase.lower()] = [command] - # find best matching phrase phrase = difflib.get_close_matches( transcript.lower(), commands_by_instant_activation.keys(), @@ -709,16 +1747,13 @@ async def _execute_instant_activation_command( cutoff=1, ) - # if no phrase found, return None if not phrase: return None - # execute all commands for the phrase commands = commands_by_instant_activation[phrase[0]] for command in commands: await self._execute_command(command, True) - # return the executed command return commands except Exception as e: await printr.print_async( @@ -731,18 +1766,6 @@ async def _execute_instant_activation_command( async def _execute_command( self, command: CommandConfig, is_instant=False ) -> tuple[str | None, str]: - """Triggers the execution of a command. This base implementation executes the keypresses defined in the command. - - Args: - command (dict): The command object from the config to execute - - Returns: - tuple[str | None, str]: A 2-tuple of: - - Instant response (str) to play immediately, or None if there is no instant response. - - Function/tool response (str) to feed back to the LLM (uses command's additional_context, - falls back to "OK", or an error string on failure). - """ - if not command: return None, "Command not found" @@ -759,7 +1782,6 @@ async def _execute_command( color=LogType.COMMAND, ) - # handle the global special commands: if command.name == "ResetConversationHistory": await self.reset_conversation_history() await printr.print_async( @@ -779,23 +1801,10 @@ async def _execute_command( return None, "ERROR DURING PROCESSING" async def execute_action(self, command: CommandConfig): - """Executes the actions defined in the command (in order). - - Args: - command (dict): The command object from the config to execute - """ if not command or not command.actions: return def contains_numpad_key(hotkey: str) -> bool: - """Check if the hotkey string contains a numpad key anywhere in the chord. - - Args: - hotkey: The hotkey string (e.g., 'num 1', 'ctrl+num 1', 'alt+num 2') - - Returns: - True if any token in the chord is a numpad key (num 0 - num 9) - """ if not hotkey: return False tokens = hotkey.lower().split("+") @@ -812,7 +1821,6 @@ def contains_numpad_key(hotkey: str) -> bool: code = action.keyboard.hotkey if action.keyboard.press == action.keyboard.release: - # compressed key events hold = action.keyboard.hold or 0.1 if ( action.keyboard.hotkey_codes @@ -833,7 +1841,6 @@ def contains_numpad_key(hotkey: str) -> bool: time.sleep(hold) keyboard.release(code) else: - # single key events if ( action.keyboard.hotkey_codes and len(action.keyboard.hotkey_codes) == 1 @@ -888,8 +1895,9 @@ def contains_numpad_key(hotkey: str) -> bool: ) printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) + # ───────────────── Threading ─────────────────────────────── # + def threaded_execution(self, function, *args) -> threading.Thread | None: - """Execute a function in a separate thread.""" try: def start_thread(function, *args): @@ -903,7 +1911,7 @@ def start_thread(function, *args): thread = threading.Thread(target=start_thread, args=(function, *args)) thread.name = function.__name__ - thread.daemon = True # Mark as daemon so it dies when main process exits + thread.daemon = True thread.start() return thread except Exception as e: @@ -913,27 +1921,17 @@ def start_thread(function, *args): printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) return None + # ───────────────── Config management ─────────────────────── # + async def update_config( self, config: WingmanConfig, skip_config_validation: bool = True ) -> bool: - """Update the config of the Wingman. - - This method should always be called if the config of the Wingman has changed. - - Args: - config: The new wingman configuration - skip_config_validation: If False, validate the config and rollback on error - - Returns: - True if config was updated successfully, False otherwise - """ try: if not skip_config_validation: old_config = deepcopy(self.config) self.config = config - # Propagate skill config changes to loaded skills await self._update_skill_configs(config) if not skip_config_validation: @@ -957,15 +1955,9 @@ async def update_config( return False async def _update_skill_configs(self, wingman_config: WingmanConfig) -> None: - """Propagate skill config changes to loaded skills. - - When the wingman config changes (e.g., user updates custom_properties for a skill), - we need to update the SkillConfig on each loaded skill instance so they see the new values. - """ if not self.skills or not wingman_config.skills: return - # Build lookup of new skill configs by folder name new_skill_configs: dict[str, "SkillConfig"] = {} for skill_config in wingman_config.skills: try: @@ -979,9 +1971,7 @@ async def _update_skill_configs(self, wingman_config: WingmanConfig) -> None: continue new_skill_configs[folder_name] = skill_config - # Update each loaded skill if its config changed for skill in self.skills: - # Get the folder name for this skill try: skill_folder = _get_skill_folder_from_module(skill.config.module) except Exception: @@ -997,51 +1987,49 @@ async def _update_skill_configs(self, wingman_config: WingmanConfig) -> None: fields_set = getattr(user_override, "model_fields_set", None) if fields_set is None: - # Pydantic v1 fallback fields_set = getattr(user_override, "__fields_set__", set()) - # Create updated config by copying current and applying overrides - # This preserves all default values while applying user overrides updated_config = deepcopy(skill.config) - # Apply overrides even if they're explicitly empty. - # This allows users to clear custom properties/prompt in the UI. if "custom_properties" in fields_set: updated_config.custom_properties = user_override.custom_properties if "prompt" in fields_set: updated_config.prompt = user_override.prompt - # Let the skill handle the config update (will compare old vs new) await skill.update_config(updated_config) + async def update_settings(self, settings: SettingsConfig): + try: + self.settings = settings + + for skill in self.skills: + skill.settings = settings + + # Re-create Wingman Pro provider when settings change + # (subscription settings might have been updated) + uses_wingman_pro = any([ + self.config.features.conversation_provider == ConversationProvider.WINGMAN_PRO, + self.config.features.tts_provider == TtsProvider.WINGMAN_PRO, + self.config.features.stt_provider == SttProvider.WINGMAN_PRO, + self.config.features.image_generation_provider == ImageGenerationProvider.WINGMAN_PRO, + ]) + if uses_wingman_pro: + await self.validate() + printr.print( + f"Wingman {self.name}: reinitialized providers with new settings", + server_only=True, + ) + + printr.print(f"Wingman {self.name}'s settings changed", server_only=True) + except Exception as e: + await printr.print_async( + f"Error while updating settings for wingman '{self.name}': {str(e)}", + color=LogType.ERROR, + ) + printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) + async def save_config(self): - """Save the config of the Wingman.""" self.tower.save_wingman(self.name) async def save_commands(self): - """Save only the commands section of this wingman's config. - - This performs a partial YAML update - only the commands field is modified - in the config file, avoiding full config serialization. This is much safer - than save_config() for command-only changes as it won't accidentally - overwrite other fields. - - Use this instead of save_config() when you only changed command definitions, - instant_activation phrases, or other command-related fields. - - Example use cases: - - QuickCommands learning instant activation phrases - - Skills dynamically adding/modifying commands - - Skills updating command responses or actions - """ self.tower.save_wingman_commands(self.name) - - async def update_settings(self, settings: SettingsConfig): - """Update the settings of the Wingman. This method should always be called when the user Settings have changed.""" - self.settings = settings - - # Propagate settings changes to already-loaded skills - for skill in self.skills: - skill.settings = settings - - printr.print(f"Wingman {self.name}'s settings changed", server_only=True) From d1131548c30fb48e07b1225fd014813219d6d3d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Hopst=C3=A4tter?= Date: Thu, 9 Apr 2026 11:56:39 +0200 Subject: [PATCH 12/34] =?UTF-8?q?milestone:=20M5=20complete=20=E2=80=94=20?= =?UTF-8?q?Wingman=20merged,=20factory=20+=20services=20wired,=20backward-?= =?UTF-8?q?compat=20shim=20in=20place?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit From 8ab5166580ae1ba63312367c90456c2304e21c4c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Hopst=C3=A4tter?= Date: Thu, 9 Apr 2026 11:58:25 +0200 Subject: [PATCH 13/34] feat: add WingmanContext facade for skill API surface control Co-Authored-By: Claude Sonnet 4.6 --- wingmen/wingman_context.py | 140 +++++++++++++++++++++++++++++++++++++ 1 file changed, 140 insertions(+) create mode 100644 wingmen/wingman_context.py diff --git a/wingmen/wingman_context.py b/wingmen/wingman_context.py new file mode 100644 index 00000000..5d080290 --- /dev/null +++ b/wingmen/wingman_context.py @@ -0,0 +1,140 @@ +"""Controlled interface for skills. Limits what plugins can access.""" + +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from openai.types.chat import ChatCompletion + from api.enums import TtsProvider + from api.interface import SoundConfig, WingmanConfig, SettingsConfig + from services.audio_player import AudioPlayer + from wingmen.wingman import Wingman + + +class WingmanContext: + """What skills see. Controlled API surface — no internal leaking.""" + + def __init__(self, wingman: "Wingman"): + self._wingman = wingman + + # --- Properties --- + + @property + def name(self) -> str: + return self._wingman.name + + @property + def config(self) -> "WingmanConfig": + return self._wingman.config + + @property + def settings(self) -> "SettingsConfig": + return self._wingman.settings + + @property + def audio_player(self) -> "AudioPlayer": + return self._wingman.audio_player + + @property + def tower(self): + return self._wingman.tower + + @property + def secret_keeper(self): + return self._wingman.secret_keeper + + # --- Conversation --- + + async def llm_call(self, messages: list[dict], tools: list[dict] | None = None) -> "ChatCompletion | None": + """Make an LLM call. Replaces actual_llm_call().""" + return await self._wingman.actual_llm_call(messages, tools) + + def get_conversation_history(self) -> list[dict]: + """Get a read-only copy of the conversation history.""" + return self._wingman.conversation.get_history() + + async def add_user_message(self, content: str): + await self._wingman.conversation.add_user_message(content, skills=self._wingman.skills) + + async def add_assistant_message(self, content: str): + await self._wingman.conversation.add_assistant_message(content, skills=self._wingman.skills) + + async def reset_conversation_history(self): + await self._wingman.reset_conversation_history() + + # --- Audio --- + + async def play_to_user(self, text: str, no_interrupt: bool = False, + sound_config: "Optional[SoundConfig]" = None): + await self._wingman.play_to_user(text, no_interrupt, sound_config) + + # --- Image generation --- + + async def generate_image(self, text: str) -> str: + return await self._wingman.generate_image(text) + + # --- Secrets --- + + async def retrieve_secret(self, requester: str, errors: list = None) -> str | None: + return await self._wingman.retrieve_secret(requester, errors or []) + + # --- Utilities --- + + def threaded_execution(self, func, *args): + self._wingman.threaded_execution(func, *args) + + async def get_context(self) -> str: + return await self._wingman.context_builder.build( + skills=self._wingman.skills, + skill_registry=self._wingman.skill_registry, + conversation_summary=self._wingman.condenser.summary, + persistent_memory_service=self._wingman.persistent_memory_service, + messages=self._wingman.conversation.messages, + config_dir_name=self._wingman.tower.config_dir.name if self._wingman.tower else None, + ) + + # --- Provider switching (for voice_changer and similar) --- + + async def switch_tts_provider(self, provider: "TtsProvider", + errors: list = None) -> bool: + """Hot-swap the TTS provider at runtime. + + Updates config.features.tts_provider and creates a new TTS instance. + Used by voice_changer skill. + """ + from services.provider_factory import ProviderFactory + self._wingman.config.features.tts_provider = provider + factory = ProviderFactory( + config=self._wingman.config, + settings=self._wingman.settings, + secret_keeper=self._wingman.secret_keeper, + shared_providers=self._wingman._shared_providers, + wingman_name=self._wingman.name, + ) + _errors = errors or [] + new_tts = await factory.create_tts(_errors) + if new_tts: + self._wingman.tts = new_tts + return True + return False + + # --- Backward compatibility (temporary) --- + # These provide access to registries that some skills currently use. + # They should be replaced with proper facade methods in a future iteration. + + @property + def tool_skills(self) -> dict: + # tool_skills lives directly on the Wingman instance, not on tool_executor + return self._wingman.tool_skills + + @property + def mcp_registry(self): + return self._wingman.mcp_registry + + @property + def skill_registry(self): + return self._wingman.skill_registry + + # Expose messages property for backward compat (quick_commands reads it) + @property + def messages(self) -> list: + return self._wingman.conversation.messages From bf30c0009953c581d43a54cf7d1b058a85790d54 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Hopst=C3=A4tter?= Date: Thu, 9 Apr 2026 12:08:52 +0200 Subject: [PATCH 14/34] refactor: skills receive WingmanContext facade instead of raw Wingman MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit All skills now receive WingmanContext (controlled API surface) instead of the raw Wingman instance. voice_changer and radio_chatter provider pre-initialization removed — ProviderFactory handles it at switch time via switch_tts_provider(). Co-Authored-By: Claude Opus 4.6 --- services/skill_local_ai.py | 4 +- skills/api_request/main.py | 4 +- skills/ats_telemetry/main.py | 4 +- skills/audio_device_changer/main.py | 4 +- skills/auto_screenshot/main.py | 4 +- skills/control_windows/main.py | 4 +- skills/file_manager/main.py | 4 +- skills/hud/main.py | 4 +- skills/image_generation/main.py | 4 +- skills/msfs2020_control/main.py | 4 +- skills/quick_commands/main.py | 4 +- skills/radio_chatter/main.py | 44 ++--------------- skills/skill_base.py | 4 +- skills/spotify/main.py | 4 +- skills/thinking_sound/main.py | 4 +- skills/timer/main.py | 4 +- skills/typing_assistant/main.py | 4 +- skills/uexcorp/main.py | 4 +- .../uexcorp/uexcorp/handler/config_handler.py | 8 ++-- skills/uexcorp/uexcorp/helper.py | 6 +-- skills/vision_ai/main.py | 4 +- skills/voice_changer/main.py | 47 ++----------------- wingmen/wingman.py | 7 ++- wingmen/wingman_context.py | 9 ++++ 24 files changed, 67 insertions(+), 126 deletions(-) diff --git a/services/skill_local_ai.py b/services/skill_local_ai.py index e03fe8c5..cfaaa611 100644 --- a/services/skill_local_ai.py +++ b/services/skill_local_ai.py @@ -16,7 +16,7 @@ if TYPE_CHECKING: from services.local_ai_service import TokenBudget - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext printr = Printr() @@ -109,7 +109,7 @@ class SkillLocalAI: 4. On success: convert internal types to facade types and return """ - def __init__(self, wingman: "OpenAiWingman"): + def __init__(self, wingman: "WingmanContext"): self._wingman = wingman # ── Availability ────────────────────────────────────────────── diff --git a/skills/api_request/main.py b/skills/api_request/main.py index 5a8a68d8..2bd98a9d 100644 --- a/skills/api_request/main.py +++ b/skills/api_request/main.py @@ -12,7 +12,7 @@ from skills.skill_base import Skill, tool if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext DEFAULT_HEADERS = { "Strict-Transport-Security": "max-age=31536000; includeSubDomains", @@ -61,7 +61,7 @@ def __init__( self, config: SkillConfig, settings: SettingsConfig, - wingman: "OpenAiWingman", + wingman: "WingmanContext", ) -> None: self.default_headers = DEFAULT_HEADERS diff --git a/skills/ats_telemetry/main.py b/skills/ats_telemetry/main.py index 45b04093..625da298 100644 --- a/skills/ats_telemetry/main.py +++ b/skills/ats_telemetry/main.py @@ -21,13 +21,13 @@ if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext class ATSTelemetry(Skill): def __init__( - self, config: SkillConfig, settings: SettingsConfig, wingman: "OpenAiWingman" + self, config: SkillConfig, settings: SettingsConfig, wingman: "WingmanContext" ) -> None: self.loaded = False self.already_initialized_telemetry = False diff --git a/skills/audio_device_changer/main.py b/skills/audio_device_changer/main.py index 02fc0dc5..a9885fb9 100644 --- a/skills/audio_device_changer/main.py +++ b/skills/audio_device_changer/main.py @@ -15,7 +15,7 @@ from skills.skill_base import Skill if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext class AudioDeviceChanger(Skill): @@ -25,7 +25,7 @@ def __init__( self, config: SkillConfig, settings: SettingsConfig, - wingman: "OpenAiWingman", + wingman: "WingmanContext", ) -> None: super().__init__(config=config, settings=settings, wingman=wingman) self.original_audio_device = settings.audio.output diff --git a/skills/auto_screenshot/main.py b/skills/auto_screenshot/main.py index 6b7328a9..053f75c8 100644 --- a/skills/auto_screenshot/main.py +++ b/skills/auto_screenshot/main.py @@ -14,7 +14,7 @@ from skills.skill_base import Skill, tool if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext class AutoScreenshot(Skill): @@ -22,7 +22,7 @@ def __init__( self, config: SkillConfig, settings: SettingsConfig, - wingman: "OpenAiWingman", + wingman: "WingmanContext", ) -> None: super().__init__(config=config, settings=settings, wingman=wingman) diff --git a/skills/control_windows/main.py b/skills/control_windows/main.py index 2b09aa32..9a5ae1be 100644 --- a/skills/control_windows/main.py +++ b/skills/control_windows/main.py @@ -11,7 +11,7 @@ import mouse.mouse as mouse if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext class ControlWindows(Skill): @@ -28,7 +28,7 @@ def __init__( self, config: SkillConfig, settings: SettingsConfig, - wingman: "OpenAiWingman", + wingman: "WingmanContext", ) -> None: super().__init__(config=config, settings=settings, wingman=wingman) diff --git a/skills/file_manager/main.py b/skills/file_manager/main.py index bb91ca21..8ebcc022 100644 --- a/skills/file_manager/main.py +++ b/skills/file_manager/main.py @@ -10,7 +10,7 @@ from pdfminer.high_level import extract_text if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext DEFAULT_MAX_TEXT_SIZE = 24000 SUPPORTED_FILE_EXTENSIONS = [ @@ -101,7 +101,7 @@ class FileManager(Skill): def __init__( - self, config: SkillConfig, settings: SettingsConfig, wingman: "OpenAiWingman" + self, config: SkillConfig, settings: SettingsConfig, wingman: "WingmanContext" ) -> None: super().__init__(config=config, settings=settings, wingman=wingman) self.allowed_file_extensions = SUPPORTED_FILE_EXTENSIONS diff --git a/skills/hud/main.py b/skills/hud/main.py index 15df3583..6b59f9e4 100644 --- a/skills/hud/main.py +++ b/skills/hud/main.py @@ -31,7 +31,7 @@ from hud_server.validation import validate_hud_settings if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext printr = Printr() @@ -50,7 +50,7 @@ def __init__( self, config: SkillConfig, settings: SettingsConfig, - wingman: "OpenAiWingman" + wingman: "WingmanContext" ) -> None: super().__init__(config=config, settings=settings, wingman=wingman) diff --git a/skills/image_generation/main.py b/skills/image_generation/main.py index 9eb4ffe3..8f9d5c69 100644 --- a/skills/image_generation/main.py +++ b/skills/image_generation/main.py @@ -7,7 +7,7 @@ from skills.skill_base import Skill, tool if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext class ImageGeneration(Skill): @@ -16,7 +16,7 @@ def __init__( self, config: SkillConfig, settings: SettingsConfig, - wingman: "OpenAiWingman", + wingman: "WingmanContext", ) -> None: super().__init__(config=config, settings=settings, wingman=wingman) self.image_path = self.get_generated_files_dir() diff --git a/skills/msfs2020_control/main.py b/skills/msfs2020_control/main.py index 8024e032..d94ff5c5 100644 --- a/skills/msfs2020_control/main.py +++ b/skills/msfs2020_control/main.py @@ -22,13 +22,13 @@ from skills.msfs2020_control.command_matcher.command_matcher import CommandMatcher if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext class Msfs2020Control(Skill): def __init__( - self, config: SkillConfig, settings: SettingsConfig, wingman: "OpenAiWingman" + self, config: SkillConfig, settings: SettingsConfig, wingman: "WingmanContext" ) -> None: super().__init__(config=config, settings=settings, wingman=wingman) self.already_initialized_simconnect = False diff --git a/skills/quick_commands/main.py b/skills/quick_commands/main.py index d3d14609..f6083b86 100644 --- a/skills/quick_commands/main.py +++ b/skills/quick_commands/main.py @@ -7,7 +7,7 @@ from skills.skill_base import Skill if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext class QuickCommands(Skill): @@ -16,7 +16,7 @@ def __init__( self, config: SkillConfig, settings: SettingsConfig, - wingman: "OpenAiWingman", + wingman: "WingmanContext", ) -> None: super().__init__(config=config, settings=settings, wingman=wingman) diff --git a/skills/radio_chatter/main.py b/skills/radio_chatter/main.py index ae933e70..3b777987 100644 --- a/skills/radio_chatter/main.py +++ b/skills/radio_chatter/main.py @@ -22,7 +22,7 @@ from skills.skill_base import Skill, tool if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext class RadioChatter(Skill): @@ -31,7 +31,7 @@ def __init__( self, config: SkillConfig, settings: SettingsConfig, - wingman: "OpenAiWingman", + wingman: "WingmanContext", ) -> None: super().__init__(config=config, settings=settings, wingman=wingman) @@ -155,42 +155,8 @@ async def validate(self) -> list[WingmanInitializationError]: ) ) - # Initialize all providers - initiated_providers = [] - for voice in voices: - voice_provider = voice.provider - if voice_provider not in initiated_providers: - initiated_providers.append(voice_provider) - - if voice_provider == TtsProvider.OPENAI and not self.wingman.openai: - await self.wingman.validate_and_set_openai(errors) - elif ( - voice_provider == TtsProvider.AZURE - and not self.wingman.openai_azure - ): - await self.wingman.validate_and_set_azure(errors) - elif ( - voice_provider == TtsProvider.ELEVENLABS - and not self.wingman.elevenlabs - ): - await self.wingman.validate_and_set_elevenlabs(errors) - elif ( - voice_provider == TtsProvider.WINGMAN_PRO - and not self.wingman.wingman_pro - ): - await self.wingman.validate_and_set_wingman_pro() - elif ( - voice_provider == TtsProvider.INWORLD - and not self.wingman.inworld - ): - await self.wingman.validate_and_set_inworld(errors) - elif ( - voice_provider == TtsProvider.OPENAI_COMPATIBLE - and not self.wingman.openai_compatible_tts - ): - await self.wingman.validate_and_set_openai_compatible_tts( - errors - ) + # Provider initialization is handled by ProviderFactory via + # switch_tts_provider() at voice-switch time. No pre-init needed. return errors @@ -605,7 +571,7 @@ async def _switch_voice( f"Switching voice to {voice_name} ({voice_provider.value})" ) - self.wingman.config.features.tts_provider = voice_provider + await self.wingman.switch_tts_provider(voice_provider) async def _get_original_voice_setting(self) -> VoiceSelection: voice_provider = self.wingman.config.features.tts_provider diff --git a/skills/skill_base.py b/skills/skill_base.py index 235a4afa..e3240119 100644 --- a/skills/skill_base.py +++ b/skills/skill_base.py @@ -26,7 +26,7 @@ if TYPE_CHECKING: from services.skill_local_ai import SkillLocalAI - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext # Type mapping from Python types to JSON Schema types @@ -296,7 +296,7 @@ def __init__( self, config: SkillConfig, settings: SettingsConfig, - wingman: "OpenAiWingman", + wingman: "WingmanContext", ) -> None: self.config = config self.settings = settings diff --git a/skills/spotify/main.py b/skills/spotify/main.py index 750ee692..ea18e1ab 100644 --- a/skills/spotify/main.py +++ b/skills/spotify/main.py @@ -8,7 +8,7 @@ from services.file import get_generated_files_dir if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext class Spotify(Skill): @@ -17,7 +17,7 @@ def __init__( self, config: SkillConfig, settings: SettingsConfig, - wingman: "OpenAiWingman", + wingman: "WingmanContext", ) -> None: super().__init__(config=config, settings=settings, wingman=wingman) diff --git a/skills/thinking_sound/main.py b/skills/thinking_sound/main.py index 99ca957c..38ccbd4f 100644 --- a/skills/thinking_sound/main.py +++ b/skills/thinking_sound/main.py @@ -10,7 +10,7 @@ from skills.skill_base import Skill if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext class ThinkingSound(Skill): @@ -20,7 +20,7 @@ def __init__( self, config: SkillConfig, settings: SettingsConfig, - wingman: "OpenAiWingman", + wingman: "WingmanContext", ) -> None: super().__init__(config=config, settings=settings, wingman=wingman) diff --git a/skills/timer/main.py b/skills/timer/main.py index 92da197f..0bcd57f0 100644 --- a/skills/timer/main.py +++ b/skills/timer/main.py @@ -13,7 +13,7 @@ from skills.skill_base import Skill, tool if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext class ActualTimer: @@ -123,7 +123,7 @@ def __init__( self, config: SkillConfig, settings: SettingsConfig, - wingman: "OpenAiWingman", + wingman: "WingmanContext", ) -> None: super().__init__(config=config, settings=settings, wingman=wingman) diff --git a/skills/typing_assistant/main.py b/skills/typing_assistant/main.py index 78f11593..baf7a7d4 100644 --- a/skills/typing_assistant/main.py +++ b/skills/typing_assistant/main.py @@ -6,7 +6,7 @@ import keyboard.keyboard as keyboard if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext class TypingAssistant(Skill): @@ -21,7 +21,7 @@ def __init__( self, config: SkillConfig, settings: SettingsConfig, - wingman: "OpenAiWingman", + wingman: "WingmanContext", ) -> None: super().__init__(config=config, settings=settings, wingman=wingman) diff --git a/skills/uexcorp/main.py b/skills/uexcorp/main.py index 64e5ee16..9d0a910a 100644 --- a/skills/uexcorp/main.py +++ b/skills/uexcorp/main.py @@ -14,7 +14,7 @@ from skills.uexcorp.uexcorp.helper import Helper if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext class UEXCorp(Skill): @@ -24,7 +24,7 @@ def __init__( self, config: SkillConfig, settings: SettingsConfig, - wingman: "OpenAiWingman", + wingman: "WingmanContext", ) -> None: self.random_seed = uuid.uuid4() super().__init__(config=config, settings=settings, wingman=wingman) diff --git a/skills/uexcorp/uexcorp/handler/config_handler.py b/skills/uexcorp/uexcorp/handler/config_handler.py index 0ca06714..0beb347e 100644 --- a/skills/uexcorp/uexcorp/handler/config_handler.py +++ b/skills/uexcorp/uexcorp/handler/config_handler.py @@ -6,7 +6,7 @@ from services.file import get_writable_dir if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext from skills.uexcorp.uexcorp.helper import Helper @@ -64,7 +64,7 @@ def __init__( helper: "Helper" ): self.__helper = helper - self.__wingman: "OpenAiWingman | None" = None + self.__wingman: "WingmanContext | None" = None self.__fine_config_path: str = get_writable_dir(os.path.join(self.__helper.get_data_path(), "config")) self.__api_url: str = "https://api.uexcorp.space/2.0" self.__api_use_key: bool = False @@ -324,8 +324,8 @@ def get_behavior_use_fasterwhisper_hotwords(self) -> bool: def set_behavior_update_fasterwhisper_hotwords(self, update: bool): self.__behavior_use_fasterwhisper_hotwords = update - def set_wingman(self, wingman: "OpenAiWingman"): + def set_wingman(self, wingman: "WingmanContext"): self.__wingman = wingman - def get_wingman(self) -> "OpenAiWingman": + def get_wingman(self) -> "WingmanContext": return self.__wingman \ No newline at end of file diff --git a/skills/uexcorp/uexcorp/helper.py b/skills/uexcorp/uexcorp/helper.py index 16a745c5..d9a11f6d 100644 --- a/skills/uexcorp/uexcorp/helper.py +++ b/skills/uexcorp/uexcorp/helper.py @@ -16,7 +16,7 @@ if TYPE_CHECKING: from skills.uexcorp.uexcorp.handler.tool_handler import ToolHandler - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext printr = Printr() @@ -66,7 +66,7 @@ def __init__(self): self.__request_while_not_ready = False self.__wingman = None - def prepare(self, threaded_execution: callable, wingman: "OpenAiWingman"): + def prepare(self, threaded_execution: callable, wingman: "WingmanContext"): from skills.uexcorp.uexcorp.handler.tool_handler import ToolHandler self.__wingman = wingman @@ -325,7 +325,7 @@ def get_llm(self) -> Llm: def get_default_thread_ident(self) -> int: return self.__default_thread - def get_wingmen(self) -> "OpenAiWingman": + def get_wingmen(self) -> "WingmanContext": return self.__wingman def toast(self, message: str): diff --git a/skills/vision_ai/main.py b/skills/vision_ai/main.py index adc67b6f..7cc348c7 100644 --- a/skills/vision_ai/main.py +++ b/skills/vision_ai/main.py @@ -8,7 +8,7 @@ from skills.skill_base import Skill, tool if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext class VisionAI(Skill): @@ -17,7 +17,7 @@ def __init__( self, config: SkillConfig, settings: SettingsConfig, - wingman: "OpenAiWingman", + wingman: "WingmanContext", ) -> None: super().__init__(config=config, settings=settings, wingman=wingman) diff --git a/skills/voice_changer/main.py b/skills/voice_changer/main.py index 5c7742f2..b172f361 100644 --- a/skills/voice_changer/main.py +++ b/skills/voice_changer/main.py @@ -15,7 +15,7 @@ from skills.skill_base import Skill if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman_context import WingmanContext class VoiceChanger(Skill): @@ -24,7 +24,7 @@ def __init__( self, config: SkillConfig, settings: SettingsConfig, - wingman: "OpenAiWingman", + wingman: "WingmanContext", ) -> None: super().__init__(config=config, settings=settings, wingman=wingman) @@ -46,45 +46,8 @@ async def validate(self) -> list[WingmanInitializationError]: voices: list[VoiceSelection] = self.retrieve_custom_property_value( "voice_changer_voices", errors ) - if voices and len(voices) > 0: - # Initialize all providers - initiated_providers = [] - - for voice in voices: - voice_provider = voice.provider - if voice_provider not in initiated_providers: - initiated_providers.append(voice_provider) - - # initiate provider - if voice_provider == TtsProvider.OPENAI and not self.wingman.openai: - await self.wingman.validate_and_set_openai(errors) - elif ( - voice_provider == TtsProvider.AZURE - and not self.wingman.openai_azure - ): - await self.wingman.validate_and_set_azure(errors) - elif ( - voice_provider == TtsProvider.ELEVENLABS - and not self.wingman.elevenlabs - ): - await self.wingman.validate_and_set_elevenlabs(errors) - elif ( - voice_provider == TtsProvider.WINGMAN_PRO - and not self.wingman.wingman_pro - ): - await self.wingman.validate_and_set_wingman_pro() - elif ( - voice_provider == TtsProvider.INWORLD - and not self.wingman.inworld - ): - await self.wingman.validate_and_set_inworld(errors) - elif ( - voice_provider == TtsProvider.OPENAI_COMPATIBLE - and not self.wingman.openai_compatible_tts - ): - await self.wingman.validate_and_set_openai_compatible_tts( - errors - ) + # Provider initialization is handled by ProviderFactory via + # switch_tts_provider() at voice-switch time. No pre-init needed. return errors @@ -256,7 +219,7 @@ async def _switch_voice(self, voices: list[VoiceSelection]) -> str: ) return f"Voice switching failed due to an unknown voice provider/subprovider. Provider: {voice_provider.value}" - self.wingman.config.features.tts_provider = voice_provider + await self.wingman.switch_tts_provider(voice_provider) if not provider_name: provider_name = voice_provider.value diff --git a/wingmen/wingman.py b/wingmen/wingman.py index de3f603d..fdcb74a5 100644 --- a/wingmen/wingman.py +++ b/wingmen/wingman.py @@ -62,6 +62,7 @@ from services.tool_response_cache import ToolResponseCompressor from services.token_utils import count_tokens from skills.skill_base import Skill +from wingmen.wingman_context import WingmanContext if TYPE_CHECKING: from services.tower import Tower @@ -647,10 +648,11 @@ async def init_skills(self) -> list[WingmanInitializationError]: ) continue + context = WingmanContext(self) skill = ModuleManager.load_skill( config=skill_config, settings=self.settings, - wingman=self, + wingman=context, ) if skill: skill.threaded_execution = self.threaded_execution @@ -780,10 +782,11 @@ async def enable_skill(self, skill_name: str) -> tuple[bool, str]: f"Skill '{skill_name}' is not supported on {normalized_platform}.", ) + context = WingmanContext(self) skill = ModuleManager.load_skill( config=skill_config, settings=self.settings, - wingman=self, + wingman=context, ) if skill: skill.threaded_execution = self.threaded_execution diff --git a/wingmen/wingman_context.py b/wingmen/wingman_context.py index 5d080290..7298a544 100644 --- a/wingmen/wingman_context.py +++ b/wingmen/wingman_context.py @@ -138,3 +138,12 @@ def skill_registry(self): @property def messages(self) -> list: return self._wingman.conversation.messages + + # Expose local AI services for SkillLocalAI facade + @property + def local_ai_service(self): + return self._wingman.local_ai_service + + @property + def persistent_memory_service(self): + return self._wingman.persistent_memory_service From 06902bdea1aeb214dcd5cc2209c3b7265b8bfa62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Hopst=C3=A4tter?= Date: Thu, 9 Apr 2026 12:10:54 +0200 Subject: [PATCH 15/34] chore: remove custom wingman class support --- services/module_manager.py | 66 +------------------------------------- services/tower.py | 43 ++++++++----------------- 2 files changed, 14 insertions(+), 95 deletions(-) diff --git a/services/module_manager.py b/services/module_manager.py index 9771f280..82031a56 100644 --- a/services/module_manager.py +++ b/services/module_manager.py @@ -12,21 +12,13 @@ SkillBase, SkillConfig, SkillToolInfo, - WingmanConfig, ) -from providers.faster_whisper import FasterWhisper -from providers.parakeet import Parakeet -from providers.whispercpp import Whispercpp -from providers.xvasynth import XVASynth -from services.audio_library import AudioLibrary -from services.audio_player import AudioPlayer -from services.file import get_writable_dir, get_custom_skills_dir +from services.file import get_custom_skills_dir from services.printr import Printr from skills.skill_base import Skill if TYPE_CHECKING: from wingmen.wingman import Wingman - from services.tower import Tower SKILLS_DIR = "skills" @@ -64,62 +56,6 @@ def get_module_name_and_path(module_string: str) -> tuple[str, str]: # module_path = path.join(module_path, module_name + ".py") return module_name, module_path - @staticmethod - def create_wingman_dynamically( - name: str, - config: WingmanConfig, - settings: SettingsConfig, - audio_player: AudioPlayer, - audio_library: AudioLibrary, - whispercpp: Whispercpp, - fasterwhisper: FasterWhisper, - parakeet: Parakeet, - xvasynth: XVASynth, - tower: "Tower", - ): - """Dynamically creates a Wingman instance from a module path and class name - - Args: - name (str): The name of the wingman. This is the key you gave it in the config, e.g. "atc" - config (WingmanConfig): All "general" config entries merged with the specific Wingman config settings. The Wingman takes precedence and overrides the general config. You can just add new keys to the config and they will be available here. - settings (SettingsConfig): The general user settings. - audio_player (AudioPlayer): The audio player handling the playback of audio files. - audio_library (AudioLibrary): The audio library handling the storage and retrieval of audio files. - whispercpp (Whispercpp): The Whispercpp provider for speech-to-text. - fasterwhisper (FasterWhisper): The FasterWhisper provider for speech-to-text. - parakeet (Parakeet): The Parakeet provider for speech-to-text via ONNX Runtime. - xvasynth (XVASynth): The XVASynth provider for text-to-speech. - tower (Tower): The Tower instance, that manages loaded Wingmen. - """ - - try: - # try to load from app dir first - module = import_module(config.custom_class.module) - except ModuleNotFoundError: - # split module into name and path - module_name, module_path = ModuleManager.get_module_name_and_path( - config.custom_class.module - ) - module_path = path.join(get_writable_dir(module_path), module_name + ".py") - # load from alternative absolute file path - spec = util.spec_from_file_location(module_name, module_path) - module = util.module_from_spec(spec) - spec.loader.exec_module(module) - DerivedWingmanClass = getattr(module, config.custom_class.name) - instance = DerivedWingmanClass( - name=name, - config=config, - settings=settings, - audio_player=audio_player, - audio_library=audio_library, - whispercpp=whispercpp, - fasterwhisper=fasterwhisper, - parakeet=parakeet, - xvasynth=xvasynth, - tower=tower, - ) - return instance - @staticmethod def load_skill( config: SkillConfig, settings: SettingsConfig, wingman: "Wingman" diff --git a/services/tower.py b/services/tower.py index fc8b3eaa..8baea4a0 100644 --- a/services/tower.py +++ b/services/tower.py @@ -16,7 +16,6 @@ from services.audio_player import AudioPlayer from services.audio_library import AudioLibrary from services.config_manager import ConfigManager -from services.module_manager import ModuleManager from services.printr import Printr from wingmen.open_ai_wingman import OpenAiWingman from wingmen.wingman import Wingman @@ -109,35 +108,19 @@ async def __instantiate_wingman( ): wingman = None try: - # it's a custom Wingman - if wingman_config.custom_class: - wingman = ModuleManager.create_wingman_dynamically( - name=wingman_name, - config=wingman_config, - settings=settings, - audio_player=self.audio_player, - audio_library=self.audio_library, - whispercpp=self.whispercpp, - fasterwhisper=self.fasterwhisper, - parakeet=self.parakeet, - xvasynth=self.xvasynth, - pocket_tts=self.pocket_tts, - tower=self, - ) - else: - wingman = OpenAiWingman( - name=wingman_name, - config=wingman_config, - settings=settings, - audio_player=self.audio_player, - audio_library=self.audio_library, - whispercpp=self.whispercpp, - fasterwhisper=self.fasterwhisper, - parakeet=self.parakeet, - xvasynth=self.xvasynth, - pocket_tts=self.pocket_tts, - tower=self, - ) + wingman = OpenAiWingman( + name=wingman_name, + config=wingman_config, + settings=settings, + audio_player=self.audio_player, + audio_library=self.audio_library, + whispercpp=self.whispercpp, + fasterwhisper=self.fasterwhisper, + parakeet=self.parakeet, + xvasynth=self.xvasynth, + pocket_tts=self.pocket_tts, + tower=self, + ) except FileNotFoundError as e: # pylint: disable=broad-except wingman_config.disabled = True self.disabled_wingmen.append(wingman_config) From 4e0f09909e4534093cc36996a6b8be228f15754a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Hopst=C3=A4tter?= Date: Thu, 9 Apr 2026 12:11:16 +0200 Subject: [PATCH 16/34] docs: update skill documentation for unified Wingman + WingmanContext --- skills/AGENTS.md | 4 ++-- skills/README.md | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/skills/AGENTS.md b/skills/AGENTS.md index 45fd3790..5b0551b1 100644 --- a/skills/AGENTS.md +++ b/skills/AGENTS.md @@ -118,10 +118,10 @@ from api.interface import SettingsConfig, SkillConfig, WingmanInitializationErro from skills.skill_base import Skill, tool if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman import Wingman class YourSkillName(Skill): - def __init__(self, config: SkillConfig, settings: SettingsConfig, wingman: "OpenAiWingman") -> None: + def __init__(self, config: SkillConfig, settings: SettingsConfig, wingman: "Wingman") -> None: super().__init__(config=config, settings=settings, wingman=wingman) async def validate(self) -> list[WingmanInitializationError]: diff --git a/skills/README.md b/skills/README.md index c9fcf6cd..bbd0ba3d 100644 --- a/skills/README.md +++ b/skills/README.md @@ -630,7 +630,7 @@ from api.interface import SettingsConfig, SkillConfig, WingmanInitializationErro from skills.skill_base import Skill, tool if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman import Wingman class YourSkillName(Skill): @@ -640,7 +640,7 @@ class YourSkillName(Skill): self, config: SkillConfig, settings: SettingsConfig, - wingman: "OpenAiWingman", + wingman: "Wingman", ) -> None: super().__init__(config=config, settings=settings, wingman=wingman) # Initialize your skill here @@ -1766,7 +1766,7 @@ from skills.skill_base import Skill, tool from services.skill_local_ai import MemoryType if TYPE_CHECKING: - from wingmen.open_ai_wingman import OpenAiWingman + from wingmen.wingman import Wingman class GameStatsTracker(Skill): @@ -1776,7 +1776,7 @@ class GameStatsTracker(Skill): self, config: SkillConfig, settings: SettingsConfig, - wingman: "OpenAiWingman", + wingman: "Wingman", ) -> None: super().__init__(config=config, settings=settings, wingman=wingman) From 9dfac2f9d790d22fba33750bce693ed0fa20959a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Hopst=C3=A4tter?= Date: Thu, 9 Apr 2026 12:11:49 +0200 Subject: [PATCH 17/34] =?UTF-8?q?milestone:=20M6=20complete=20=E2=80=94=20?= =?UTF-8?q?skill=20facade,=20custom=20wingman=20removal,=20full=20refactor?= =?UTF-8?q?ing=20done?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit From be33a26c4666b86f80fa31fe383362359b66e7fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Hopst=C3=A4tter?= Date: Fri, 10 Apr 2026 10:09:49 +0200 Subject: [PATCH 18/34] fix: unify Parakeet and FasterWhisper CUDA auto-detect Parakeet's runtime auto-detect re-flipped manual CPU choices to CUDA on every startup and persisted "cuda" even when ONNX Runtime fell back to CPU. Move Parakeet into the same once-per-version hardware scan as FasterWhisper so manual choices stick after the first run. Also fix the misleading "Parakeet initialized with providers: [...]" log that printed the requested provider list instead of what was actually loadable, making CPU fallback look like CUDA success. Co-Authored-By: Claude Opus 4.6 --- providers/parakeet.py | 46 +++++++++++++++++++------------- services/config_manager.py | 7 +++-- services/stt_provider_manager.py | 39 --------------------------- 3 files changed, 33 insertions(+), 59 deletions(-) diff --git a/providers/parakeet.py b/providers/parakeet.py index 58850193..daef9aa6 100644 --- a/providers/parakeet.py +++ b/providers/parakeet.py @@ -78,30 +78,40 @@ def _load_model_inner(self, model_path: Optional[str] = None): if not providers: providers = ["CPUExecutionProvider"] - load_kwargs = {"providers": providers} + # Filter requested providers against what ONNX Runtime actually has + # available, so we know up front whether CUDA will really be used. + try: + import onnxruntime as ort + + available = set(ort.get_available_providers()) + except Exception: + available = None + + effective_providers = providers + if available is not None: + effective_providers = [p for p in providers if p in available] + if not effective_providers: + effective_providers = ["CPUExecutionProvider"] + + if ( + self.settings.execution_provider == "cuda" + and "CUDAExecutionProvider" not in available + ): + self.printr.print( + "Parakeet: CUDA requested but not available in this ONNX Runtime build. " + "Using CPU fallback. For CUDA support, install onnxruntime-gpu.", + server_only=True, + color=LogType.WARNING, + ) + + load_kwargs = {"providers": effective_providers} if model_path: load_kwargs["path"] = model_path self.model = onnx_asr.load_model(model_name, **load_kwargs) - # Check if requested CUDA provider was actually available - if self.settings.execution_provider == "cuda": - try: - import onnxruntime as ort - - available = ort.get_available_providers() - if "CUDAExecutionProvider" not in available: - self.printr.print( - "Parakeet: CUDA requested but not available in this ONNX Runtime build. " - "Using CPU fallback. For CUDA support, install onnxruntime-gpu.", - server_only=True, - color=LogType.WARNING, - ) - except Exception: - pass - self.printr.print( - f"Parakeet initialized with model '{model_name}' (providers: {providers}).", + f"Parakeet initialized with model '{model_name}' (providers: {effective_providers}).", server_only=True, color=LogType.POSITIVE, ) diff --git a/services/config_manager.py b/services/config_manager.py index dc2e80a3..fd186d8b 100644 --- a/services/config_manager.py +++ b/services/config_manager.py @@ -1452,6 +1452,7 @@ def perform_hardware_scan(self, system_manager): if system_manager.is_cuda_available(): self.settings_config.voice_activation.fasterwhisper.device = "cuda" self.settings_config.voice_activation.fasterwhisper.compute_type = "auto" + self.settings_config.voice_activation.parakeet.execution_provider = "cuda" self.printr.print( f"- GPU detected: {system_manager.get_gpu_name()}", color=LogType.STARTUP, @@ -1460,7 +1461,7 @@ def perform_hardware_scan(self, system_manager): source_name=self.log_source_name, ) self.printr.print( - "- Auto-configured FasterWhisper to use CUDA", + "- Auto-configured FasterWhisper and Parakeet to use CUDA", color=LogType.STARTUP, server_only=True, source=LogSource.SYSTEM, @@ -1468,8 +1469,10 @@ def perform_hardware_scan(self, system_manager): ) changes = True else: + self.settings_config.voice_activation.fasterwhisper.device = "cpu" + self.settings_config.voice_activation.parakeet.execution_provider = "cpu" self.printr.print( - "- No NVIDIA GPU detected, keeping current STT settings", + "- No NVIDIA GPU detected, STT providers will use CPU", color=LogType.STARTUP, server_only=True, source=LogSource.SYSTEM, diff --git a/services/stt_provider_manager.py b/services/stt_provider_manager.py index ed26e76f..5b72b7b0 100644 --- a/services/stt_provider_manager.py +++ b/services/stt_provider_manager.py @@ -1,6 +1,5 @@ import asyncio import os -import platform from typing import Awaitable, Callable, Optional from api.enums import LogType, VoiceActivationSttProvider @@ -74,9 +73,6 @@ async def _initialize_parakeet( """Download and initialize Parakeet.""" pk_settings = self.settings_service.settings.voice_activation.parakeet - # Auto-detect CUDA and update execution_provider in settings - self._auto_detect_execution_provider(pk_settings) - # Download model variant = pk_settings.model_variant repo_id = PARAKEET_REPO_MAP.get(variant) @@ -142,41 +138,6 @@ async def _initialize_fasterwhisper( await self._health_check_fasterwhisper() - def _auto_detect_execution_provider(self, pk_settings): - """Auto-detect CUDA and set execution_provider if still on default (cpu).""" - if pk_settings.execution_provider != "cpu": - # User has manually set a non-default provider — respect it - self.printr.print( - f"Parakeet execution_provider already set to '{pk_settings.execution_provider}', skipping auto-detection.", - server_only=True, - color=LogType.INFO, - ) - return - - if platform.system() == "Darwin": - # macOS — always CPU (CoreML excluded for TDT models) - pk_settings.execution_provider = "cpu" - return - - if self.system_manager.is_cuda_available(): - pk_settings.execution_provider = "cuda" - gpu_name = self.system_manager.get_gpu_name() or "Unknown GPU" - self.printr.print( - f"CUDA detected ({gpu_name}). Setting Parakeet to CUDA execution provider.", - server_only=True, - color=LogType.POSITIVE, - ) - else: - pk_settings.execution_provider = "cpu" - self.printr.print( - "No CUDA available. Parakeet will use CPU execution provider.", - server_only=True, - color=LogType.INFO, - ) - - # Persist to settings - self.settings_service.save_settings_to_disk() - async def switch_provider(self, new_provider: VoiceActivationSttProvider): """Switch active STT provider. Unloads old, downloads + loads new.""" old_provider = self.active_provider From c09e3e9313d3f2665554e494945dc2fb3c4be6e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Hopst=C3=A4tter?= Date: Fri, 10 Apr 2026 10:10:14 +0200 Subject: [PATCH 19/34] cleanup --- docs/parakeet-issues.md | 62 ----------------------------------------- 1 file changed, 62 deletions(-) delete mode 100644 docs/parakeet-issues.md diff --git a/docs/parakeet-issues.md b/docs/parakeet-issues.md deleted file mode 100644 index 5e257003..00000000 --- a/docs/parakeet-issues.md +++ /dev/null @@ -1,62 +0,0 @@ -# Parakeet STT Issues (2026-04-07) - -## Issue 1: macOS CoreML crash - -**Symptom**: Switching execution provider to CoreML gives `ONNXRuntimeError: model_path must not be empty` from `onnxruntime/core/optimizer/initializer.cc:45`. - -**Root cause**: The `NemoConformerAED` model class explicitly excludes CoreML in `onnx_asr/models/nemo.py:183`: -```python -def _get_excluded_providers() -> list[str]: - return [*TensorRtOptions.get_provider_names(), "CoreMLExecutionProvider"] -``` -But `NemoConformerTdt` (the model Parakeet uses — `nemo-parakeet-tdt-0.6b-v2/v3`) inherits from `NemoConformerRnnt` which has **no CoreML exclusion**. So CoreML is passed to the ONNX session, but the TDT model uses external data files (`onnx?data` pattern in `loader.py:134`) that CoreML can't handle. - -**Fix options**: -1. Exclude CoreML from TDT models (like AED does) — simplest, but removes the option -2. Catch the error in `parakeet.py:__load_model_inner()` and fall back to CPU with a toast message -3. Hide the CoreML option in the Client UI when the model variant doesn't support it (requires Core→Client communication of supported providers) - -## Issue 2: Windows CUDA doesn't reinitialize properly - -**Symptom**: Switching execution provider to CUDA downloads the model but STT doesn't work until full app restart. - -**Root cause (suspected)**: `del self.model` in `parakeet.py:101` drops the Python reference to the old ONNX session, but CUDA GPU memory isn't released immediately (Python GC is non-deterministic). When the new CUDA session tries to allocate GPU memory, the old one may still be holding it. - -**Fix options**: -1. Force `gc.collect()` after unloading the model before loading the new one -2. Add explicit ONNX session cleanup (call `self.model` internals to release sessions) -3. Add a brief delay between unload and reload for CUDA specifically - -## Broader architecture issues - -- `__load_model_inner()` has a generic `except Exception` that toasts an error but leaves `self.model = None`. No retry, no CPU fallback. -- No way for the user to recover without restarting the app if reload fails. -- The `_loading` flag prevents transcription during reload, but there's no timeout or progress feedback. -- The Preprocessor (`onnx_asr`) explicitly excludes CUDA — it always runs on CPU regardless of the selected provider. This is by design but may confuse users expecting full CUDA acceleration. - -## Key files - -### Core (wingman-ai) -- `providers/parakeet.py` — Parakeet provider, model loading, settings update -- `services/settings_service.py:144-148` — calls `parakeet.update_settings_async()` -- `api/interface.py:183-192` — `ParakeetSettings` dataclass - -### onnx_asr library (3rd party, in venv) -- `onnx_asr/loader.py:187-357` — `load_model()`, downloads files, creates ONNX sessions -- `onnx_asr/models/nemo.py:70-85` — `NemoConformerRnnt.__init__()` creates encoder/decoder sessions -- `onnx_asr/models/nemo.py:181-183` — `NemoConformerAED._get_excluded_providers()` excludes CoreML -- `onnx_asr/preprocessors/preprocessor.py` — Preprocessor, excludes CUDA -- `onnx_asr/onnx.py:81-101` — `update_onnx_providers()` filters provider list - -### Client (wingman-client) -- The execution provider dropdown is in the STT settings UI — may need to filter options per model/platform - -## Provider/model compatibility matrix - -| Provider | NemoConformerTdt (Parakeet) | NemoConformerAED | Preprocessor | -|----------|---------------------------|------------------|-------------| -| CPU | Yes | Yes | Yes | -| DirectML | Yes | Yes | Yes | -| CUDA | Yes | Yes | **Excluded** | -| CoreML | **Crashes** (should exclude) | **Excluded** | ? | -| TensorRT | Excluded | Excluded | Excluded | From af2985bcb2516c7143cc6f454602e8321ac71599 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Hopst=C3=A4tter?= Date: Fri, 10 Apr 2026 10:18:24 +0200 Subject: [PATCH 20/34] refactor(wingman): extract CommandExecutor service Move command/action execution out of the unified Wingman class into a focused service. Covers get_command, instant activation matching, command execution, and the keyboard/mouse/joystick action dispatcher. Part of the post-merge cleanup to make wingman.py more navigable. Co-Authored-By: Claude Opus 4.6 --- services/command_executor.py | 253 +++++++++++++++++++++++++++++++++++ wingmen/wingman.py | 231 +++----------------------------- wingmen/wingman_context.py | 6 + 3 files changed, 279 insertions(+), 211 deletions(-) create mode 100644 services/command_executor.py diff --git a/services/command_executor.py b/services/command_executor.py new file mode 100644 index 00000000..79570783 --- /dev/null +++ b/services/command_executor.py @@ -0,0 +1,253 @@ +"""Command and action execution service. + +Handles instant-activation matching, command dispatch, and the +keyboard/mouse/joystick/audio action dispatcher extracted from Wingman. +""" + +import difflib +import random +import time +import traceback + +import keyboard.keyboard as keyboard +import mouse.mouse as mouse + +from api.enums import LogType +from api.interface import CommandConfig, WingmanConfig +from services.audio_library import AudioLibrary +from services.printr import Printr + +printr = Printr() + + +class CommandExecutor: + """Focused service for command lookup, instant activation, and action dispatch.""" + + def __init__( + self, + config: WingmanConfig, + audio_library: AudioLibrary, + wingman_name: str, + on_reset_history, # async callable: reset_conversation_history + on_add_forced_commands=None, # async callable: add_forced_assistant_command_calls + ): + self.config = config + self.audio_library = audio_library + self.wingman_name = wingman_name + self.on_reset_history = on_reset_history + self.on_add_forced_commands = on_add_forced_commands + + # ───────────────── Command lookup ─────────────────────────── # + + def get_command(self, command_name: str) -> CommandConfig | None: + if self.config.commands is None: + return None + command = next( + (item for item in self.config.commands if item.name == command_name), + None, + ) + return command + + def select_instant_command_response(self, command: CommandConfig) -> str | None: + command_responses = command.responses + if (command_responses is None) or (len(command_responses) == 0): + return None + return random.choice(command_responses) + + # ───────────────── Instant activation ─────────────────────── # + + async def try_instant_activation(self, transcript: str) -> tuple[str, bool]: + commands = await self._execute_instant_activation_command(transcript) + if commands: + if self.on_add_forced_commands is not None: + await self.on_add_forced_commands(commands) + responses = [] + for command in commands: + if command.responses: + responses.append(self.select_instant_command_response(command)) + + if len(responses) == len(commands): + responses = list(dict.fromkeys(responses)) + responses = [ + response + "." if not response.endswith(".") else response + for response in responses + ] + return " ".join(responses), True + + return None, True + + return None, False + + async def _execute_instant_activation_command( + self, transcript: str + ) -> list[CommandConfig] | None: + try: + commands_by_instant_activation = {} + for command in self.config.commands: + if command.instant_activation: + for phrase in command.instant_activation: + if phrase.lower() in commands_by_instant_activation: + commands_by_instant_activation[phrase.lower()].append( + command + ) + else: + commands_by_instant_activation[phrase.lower()] = [command] + + phrase = difflib.get_close_matches( + transcript.lower(), + commands_by_instant_activation.keys(), + n=1, + cutoff=1, + ) + + if not phrase: + return None + + commands = commands_by_instant_activation[phrase[0]] + for command in commands: + await self.execute_command(command, True) + + return commands + except Exception as e: + await printr.print_async( + f"Error during instant activation in Wingman '{self.wingman_name}': {str(e)}", + color=LogType.ERROR, + ) + printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) + return None + + # ───────────────── Command execution ──────────────────────── # + + async def execute_command( + self, command: CommandConfig, is_instant=False + ) -> tuple[str | None, str]: + if not command: + return None, "Command not found" + + try: + if len(command.actions or []) == 0: + await printr.print_async( + f"No actions found for command: {command.name}", + color=LogType.WARNING, + ) + else: + await self.execute_action(command) + await printr.print_async( + f"Executed {'instant' if is_instant else 'AI'} command: {command.name}", + color=LogType.COMMAND, + ) + + if command.name == "ResetConversationHistory": + await self.on_reset_history() + await printr.print_async( + f"Executed command: {command.name}", color=LogType.COMMAND + ) + + return ( + self.select_instant_command_response(command), + command.additional_context or "OK", + ) + except Exception as e: + await printr.print_async( + f"Error executing command '{command.name}' for Wingman '{self.wingman_name}': {str(e)}", + color=LogType.ERROR, + ) + printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) + return None, "ERROR DURING PROCESSING" + + # ───────────────── Action dispatch ────────────────────────── # + + async def execute_action(self, command: CommandConfig): + if not command or not command.actions: + return + + def contains_numpad_key(hotkey: str) -> bool: + if not hotkey: + return False + tokens = hotkey.lower().split("+") + return any(token.startswith("num ") for token in tokens) + + try: + for action in command.actions: + if action.keyboard: + if action.keyboard.hotkey_codes and not contains_numpad_key( + action.keyboard.hotkey + ): + code = action.keyboard.hotkey_codes + else: + code = action.keyboard.hotkey + + if action.keyboard.press == action.keyboard.release: + hold = action.keyboard.hold or 0.1 + if ( + action.keyboard.hotkey_codes + and len(action.keyboard.hotkey_codes) == 1 + and not contains_numpad_key(action.keyboard.hotkey) + ): + keyboard.direct_event( + action.keyboard.hotkey_codes[0], + 0 + (1 if action.keyboard.hotkey_extended else 0), + ) + time.sleep(hold) + keyboard.direct_event( + action.keyboard.hotkey_codes[0], + 2 + (1 if action.keyboard.hotkey_extended else 0), + ) + else: + keyboard.press(code) + time.sleep(hold) + keyboard.release(code) + else: + if ( + action.keyboard.hotkey_codes + and len(action.keyboard.hotkey_codes) == 1 + and not contains_numpad_key(action.keyboard.hotkey) + ): + keyboard.direct_event( + action.keyboard.hotkey_codes[0], + (0 if action.keyboard.press else 2) + + (1 if action.keyboard.hotkey_extended else 0), + ) + else: + keyboard.send( + code, + action.keyboard.press, + action.keyboard.release, + ) + + if action.mouse: + if action.mouse.move_to: + x, y = action.mouse.move_to + mouse.move(x, y) + + if action.mouse.move: + x, y = action.mouse.move + mouse.move(x, y, absolute=False, duration=0.5) + + if action.mouse.scroll: + mouse.wheel(action.mouse.scroll) + + if action.mouse.button: + if action.mouse.hold: + mouse.press(button=action.mouse.button) + time.sleep(action.mouse.hold) + mouse.release(button=action.mouse.button) + else: + mouse.click(button=action.mouse.button) + + if action.write: + keyboard.write(action.write) + + if action.wait: + time.sleep(action.wait) + + if action.audio: + await self.audio_library.handle_action( + action.audio, self.config.sound.volume + ) + except Exception as e: + await printr.print_async( + f"Error executing actions of command '{command.name}' for wingman '{self.wingman_name}': {str(e)}", + color=LogType.ERROR, + ) + printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) diff --git a/wingmen/wingman.py b/wingmen/wingman.py index fdcb74a5..4ecf5d7a 100644 --- a/wingmen/wingman.py +++ b/wingmen/wingman.py @@ -53,6 +53,7 @@ from services.conversation_manager import ConversationManager from services.conversation_condenser import ConversationCondenser from services.context_builder import ContextBuilder +from services.command_executor import CommandExecutor from services.tool_executor import ToolExecutor from services.provider_factory import ProviderFactory from services.skill_registry import SkillRegistry @@ -145,6 +146,13 @@ def __init__( self.condenser = ConversationCondenser(self.conversation, config, name) self.context_builder = ContextBuilder(config, settings, name) self.tool_executor = ToolExecutor(config, settings, name) + self.command_executor = CommandExecutor( + config=config, + audio_library=audio_library, + wingman_name=name, + on_reset_history=self.reset_conversation_history, + on_add_forced_commands=self.add_forced_assistant_command_calls, + ) # --- Token tracking --- self.last_turn_prompt_tokens: int = 0 @@ -934,7 +942,7 @@ async def _get_response_for_transcript( await self.add_user_message(transcript, images=images) benchmark.start_snapshot("Instant activation commands") - instant_response, instant_command_executed = await self._try_instant_activation( + instant_response, instant_command_executed = await self.command_executor.try_instant_activation( transcript=transcript ) if instant_response: @@ -1151,7 +1159,7 @@ async def _process_completion( if response_message.tool_calls: response_message.tool_calls = await self.tool_executor.fix_tool_calls( - response_message.tool_calls, self.get_command + response_message.tool_calls, self.command_executor.get_command ) prompt_tokens = 0 @@ -1176,8 +1184,8 @@ async def _handle_tool_calls(self, tool_calls): mcp_registry=self.mcp_registry, capability_registry=self.capability_registry, persistent_memory_service=self.persistent_memory_service, - get_command_fn=self.get_command, - execute_command_fn=self._execute_command, + get_command_fn=self.command_executor.get_command, + execute_command_fn=self.command_executor.execute_command, play_to_user_fn=self.play_to_user, local_ai_service=self.local_ai_service, update_tool_response_fn=self.conversation.update_tool_response, @@ -1197,8 +1205,8 @@ async def execute_command_by_function_call( mcp_registry=self.mcp_registry, capability_registry=self.capability_registry, persistent_memory_service=self.persistent_memory_service, - get_command_fn=self.get_command, - execute_command_fn=self._execute_command, + get_command_fn=self.command_executor.get_command, + execute_command_fn=self.command_executor.execute_command, play_to_user_fn=self.play_to_user, ) @@ -1400,29 +1408,6 @@ async def generate_image(self, text: str) -> str: ) return "" - # ───────────────── Instant activation ───────────────── # - - async def _try_instant_activation(self, transcript: str) -> tuple[str, bool]: - commands = await self._execute_instant_activation_command(transcript) - if commands: - await self.add_forced_assistant_command_calls(commands) - responses = [] - for command in commands: - if command.responses: - responses.append(self._select_instant_command_response(command)) - - if len(responses) == len(commands): - responses = list(dict.fromkeys(responses)) - responses = [ - response + "." if not response.endswith(".") else response - for response in responses - ] - return " ".join(responses), True - - return None, True - - return None, False - # ───────────────── Instant responses ───────────────── # async def _generate_instant_responses(self) -> None: @@ -1711,192 +1696,15 @@ def _command_has_effective_actions(command: CommandConfig) -> bool: return tools - # ───────────────── Commands ─────────────────────────────── # + # ───────────────── Backward-compat delegation ────────────── # def get_command(self, command_name: str) -> CommandConfig | None: - if self.config.commands is None: - return None - command = next( - (item for item in self.config.commands if item.name == command_name), - None, - ) - return command - - def _select_instant_command_response(self, command: CommandConfig) -> str | None: - command_responses = command.responses - if (command_responses is None) or (len(command_responses) == 0): - return None - return random.choice(command_responses) - - async def _execute_instant_activation_command( - self, transcript: str - ) -> list[CommandConfig] | None: - try: - commands_by_instant_activation = {} - for command in self.config.commands: - if command.instant_activation: - for phrase in command.instant_activation: - if phrase.lower() in commands_by_instant_activation: - commands_by_instant_activation[phrase.lower()].append( - command - ) - else: - commands_by_instant_activation[phrase.lower()] = [command] - - phrase = difflib.get_close_matches( - transcript.lower(), - commands_by_instant_activation.keys(), - n=1, - cutoff=1, - ) - - if not phrase: - return None - - commands = commands_by_instant_activation[phrase[0]] - for command in commands: - await self._execute_command(command, True) - - return commands - except Exception as e: - await printr.print_async( - f"Error during instant activation in Wingman '{self.name}': {str(e)}", - color=LogType.ERROR, - ) - printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - return None - - async def _execute_command( - self, command: CommandConfig, is_instant=False - ) -> tuple[str | None, str]: - if not command: - return None, "Command not found" - - try: - if len(command.actions or []) == 0: - await printr.print_async( - f"No actions found for command: {command.name}", - color=LogType.WARNING, - ) - else: - await self.execute_action(command) - await printr.print_async( - f"Executed {'instant' if is_instant else 'AI'} command: {command.name}", - color=LogType.COMMAND, - ) - - if command.name == "ResetConversationHistory": - await self.reset_conversation_history() - await printr.print_async( - f"Executed command: {command.name}", color=LogType.COMMAND - ) - - return ( - self._select_instant_command_response(command), - command.additional_context or "OK", - ) - except Exception as e: - await printr.print_async( - f"Error executing command '{command.name}' for Wingman '{self.name}': {str(e)}", - color=LogType.ERROR, - ) - printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - return None, "ERROR DURING PROCESSING" + """Backward-compat: delegate to command_executor.""" + return self.command_executor.get_command(command_name) async def execute_action(self, command: CommandConfig): - if not command or not command.actions: - return - - def contains_numpad_key(hotkey: str) -> bool: - if not hotkey: - return False - tokens = hotkey.lower().split("+") - return any(token.startswith("num ") for token in tokens) - - try: - for action in command.actions: - if action.keyboard: - if action.keyboard.hotkey_codes and not contains_numpad_key( - action.keyboard.hotkey - ): - code = action.keyboard.hotkey_codes - else: - code = action.keyboard.hotkey - - if action.keyboard.press == action.keyboard.release: - hold = action.keyboard.hold or 0.1 - if ( - action.keyboard.hotkey_codes - and len(action.keyboard.hotkey_codes) == 1 - and not contains_numpad_key(action.keyboard.hotkey) - ): - keyboard.direct_event( - action.keyboard.hotkey_codes[0], - 0 + (1 if action.keyboard.hotkey_extended else 0), - ) - time.sleep(hold) - keyboard.direct_event( - action.keyboard.hotkey_codes[0], - 2 + (1 if action.keyboard.hotkey_extended else 0), - ) - else: - keyboard.press(code) - time.sleep(hold) - keyboard.release(code) - else: - if ( - action.keyboard.hotkey_codes - and len(action.keyboard.hotkey_codes) == 1 - and not contains_numpad_key(action.keyboard.hotkey) - ): - keyboard.direct_event( - action.keyboard.hotkey_codes[0], - (0 if action.keyboard.press else 2) - + (1 if action.keyboard.hotkey_extended else 0), - ) - else: - keyboard.send( - code, - action.keyboard.press, - action.keyboard.release, - ) - - if action.mouse: - if action.mouse.move_to: - x, y = action.mouse.move_to - mouse.move(x, y) - - if action.mouse.move: - x, y = action.mouse.move - mouse.move(x, y, absolute=False, duration=0.5) - - if action.mouse.scroll: - mouse.wheel(action.mouse.scroll) - - if action.mouse.button: - if action.mouse.hold: - mouse.press(button=action.mouse.button) - time.sleep(action.mouse.hold) - mouse.release(button=action.mouse.button) - else: - mouse.click(button=action.mouse.button) - - if action.write: - keyboard.write(action.write) - - if action.wait: - time.sleep(action.wait) - - if action.audio: - await self.audio_library.handle_action( - action.audio, self.config.sound.volume - ) - except Exception as e: - await printr.print_async( - f"Error executing actions of command '{command.name}' for wingman '{self.name}': {str(e)}", - color=LogType.ERROR, - ) - printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) + """Backward-compat: delegate to command_executor.""" + await self.command_executor.execute_action(command) # ───────────────── Threading ─────────────────────────────── # @@ -1934,6 +1742,7 @@ async def update_config( old_config = deepcopy(self.config) self.config = config + self.command_executor.config = config await self._update_skill_configs(config) diff --git a/wingmen/wingman_context.py b/wingmen/wingman_context.py index 7298a544..06f94608 100644 --- a/wingmen/wingman_context.py +++ b/wingmen/wingman_context.py @@ -117,6 +117,12 @@ async def switch_tts_provider(self, provider: "TtsProvider", return True return False + # --- Commands --- + + def get_command(self, command_name: str): + """Delegate to command_executor for skills that need direct command lookup.""" + return self._wingman.command_executor.get_command(command_name) + # --- Backward compatibility (temporary) --- # These provide access to registries that some skills currently use. # They should be replaced with proper facade methods in a future iteration. From a1f98de6ff68103f24f531923432510289a611eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Hopst=C3=A4tter?= Date: Fri, 10 Apr 2026 10:23:20 +0200 Subject: [PATCH 21/34] refactor(wingman): extract WingmanMcpManager service Move MCP discovery, enable/disable, and connection orchestration out of the unified Wingman class. The manager now owns the McpRegistry and the secret/header/timeout handling for connecting to servers. Co-Authored-By: Claude Opus 4.6 --- services/wingman_mcp_manager.py | 284 ++++++++++++++++++++++++++++++++ wingmen/wingman.py | 249 +++------------------------- 2 files changed, 303 insertions(+), 230 deletions(-) create mode 100644 services/wingman_mcp_manager.py diff --git a/services/wingman_mcp_manager.py b/services/wingman_mcp_manager.py new file mode 100644 index 00000000..9db10ac6 --- /dev/null +++ b/services/wingman_mcp_manager.py @@ -0,0 +1,284 @@ +"""WingmanMcpManager — owns all MCP discovery, connection, and lifecycle. + +Extracted from ``wingmen/wingman.py`` so that MCP concerns (registry creation, +secret injection, timeout handling, enable/disable, parallel init) live in one +focused service. Wingman delegates to this manager and exposes 1-line +forwarding methods for backward compatibility with external callers. +""" + +import asyncio +import traceback +from typing import Callable + +from api.commands import McpStateChangedCommand +from api.enums import LogSource, LogType, WingmanInitializationErrorType +from api.interface import SettingsConfig, WingmanConfig, WingmanInitializationError +from services.mcp_client import McpClient +from services.mcp_registry import McpRegistry +from services.printr import Printr +from services.secret_keeper import SecretKeeper + +printr = Printr() + + +class WingmanMcpManager: + """Manages MCP server discovery, connection, enable/disable, and teardown.""" + + def __init__( + self, + wingman_name: str, + mcp_client: McpClient, + secret_keeper: SecretKeeper, + get_mcp_config: Callable, # callable returning the central mcp_config (from tower) + settings: SettingsConfig, + config: WingmanConfig, + ): + self.wingman_name = wingman_name + self.mcp_client = mcp_client + self.secret_keeper = secret_keeper + self._get_mcp_config = get_mcp_config + self.settings = settings + self.config = config + + # Manager owns its registry and the broadcast callback + self.mcp_registry = McpRegistry( + mcp_client, + wingman_name=wingman_name, + on_state_changed=self._broadcast_mcp_state_changed, + ) + + # ─────────────────────────── Private ────────────────────────────────────── # + + def _broadcast_mcp_state_changed(self): + if printr._connection_manager: + printr.ensure_async( + printr._connection_manager.broadcast( + McpStateChangedCommand(wingman_name=self.wingman_name) + ) + ) + + # ─────────────────────────── Public API ─────────────────────────────────── # + + async def enable_mcp(self, mcp_name: str) -> tuple[bool, str]: + if not self.mcp_client.is_available: + return False, "MCP SDK not installed." + + if mcp_name in self.mcp_registry.get_connected_server_names(): + return True, f"MCP server '{mcp_name}' is already connected." + + central_mcp_config = self._get_mcp_config() + mcp_configs = central_mcp_config.servers if central_mcp_config else [] + + mcp_config = None + for cfg in mcp_configs: + if cfg.name == mcp_name: + mcp_config = cfg + break + + if not mcp_config: + return False, f"MCP server '{mcp_name}' not found in mcp.yaml." + + try: + headers = {} + if mcp_config.headers: + headers.update(mcp_config.headers) + + secret_key = f"mcp_{mcp_config.name}" + api_key = await self.secret_keeper.retrieve( + requester=self.wingman_name, + key=secret_key, + prompt_if_missing=False, + ) + if api_key: + if not any( + k.lower() in ["authorization", "api-key", "x-api-key"] + for k in headers.keys() + ): + headers["Authorization"] = f"Bearer {api_key}" + + default_timeout = 60.0 if mcp_config.type.value == "stdio" else 30.0 + timeout = ( + float(mcp_config.timeout) if mcp_config.timeout else default_timeout + ) + + connection = await asyncio.wait_for( + self.mcp_registry.register_server( + config=mcp_config, + headers=headers if headers else None, + ), + timeout=timeout, + ) + + if connection.is_connected: + tool_count = len(connection.tools) + return True, f"MCP server '{mcp_name}' enabled with {tool_count} tools." + else: + error = connection.error or "Connection failed." + return False, f"MCP server '{mcp_name}' failed to connect: {error}" + + except asyncio.TimeoutError: + error_msg = f"Connection timed out ({int(timeout)}s)." + self.mcp_registry.set_server_error(mcp_name, error_msg) + return False, f"MCP server '{mcp_name}': {error_msg}" + + except Exception as e: + error_msg = f"Error enabling MCP '{mcp_name}': {str(e)}" + await printr.print_async(error_msg, color=LogType.ERROR) + printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) + return False, error_msg + + async def disable_mcp(self, mcp_name: str) -> tuple[bool, str]: + if mcp_name not in self.mcp_registry.get_connected_server_names(): + return True, f"MCP server '{mcp_name}' is already disconnected." + + try: + await self.mcp_registry.unregister_server(mcp_name) + return True, f"MCP server '{mcp_name}' disabled." + + except Exception as e: + error_msg = f"Error disabling MCP '{mcp_name}': {str(e)}" + await printr.print_async(error_msg, color=LogType.ERROR) + printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) + return False, error_msg + + async def init_mcps(self) -> list[WingmanInitializationError]: + errors = [] + + if not self.mcp_client.is_available: + printr.print( + f"[{self.wingman_name}] MCP SDK not installed, skipping MCP initialization.", + color=LogType.WARNING, + server_only=True, + ) + return errors + + await self.unload_mcps() + + central_mcp_config = self._get_mcp_config() + mcp_configs = central_mcp_config.servers if central_mcp_config else [] + if not mcp_configs: + return errors + + discoverable_mcps = self.config.discoverable_mcps + mcps_to_connect = [mcp for mcp in mcp_configs if mcp.name in discoverable_mcps] + + if not mcps_to_connect: + return errors + + async def connect_mcp(mcp_config): + local_errors = [] + try: + headers = {} + if mcp_config.headers: + headers.update(mcp_config.headers) + + secret_key = f"mcp_{mcp_config.name}" + api_key = await self.secret_keeper.retrieve( + requester=self.wingman_name, + key=secret_key, + prompt_if_missing=False, + ) + if api_key: + printr.print( + f"MCP secret '{secret_key}' found ({len(api_key)} chars)", + color=LogType.INFO, + source_name=self.wingman_name, + server_only=True, + ) + if not any( + k.lower() in ["authorization", "api-key", "x-api-key"] + for k in headers.keys() + ): + headers["Authorization"] = f"Bearer {api_key}" + + default_timeout = 60.0 if mcp_config.type.value == "stdio" else 30.0 + timeout = ( + float(mcp_config.timeout) if mcp_config.timeout else default_timeout + ) + + try: + connection = await asyncio.wait_for( + self.mcp_registry.register_server( + config=mcp_config, + headers=headers if headers else None, + ), + timeout=timeout, + ) + except asyncio.TimeoutError: + error_msg = f"MCP '{mcp_config.display_name}' connection timed out ({int(timeout)}s)." + printr.print( + error_msg, + color=LogType.WARNING, + source_name=self.wingman_name, + server_only=True, + ) + local_errors.append( + WingmanInitializationError( + wingman_name=self.wingman_name, + message=error_msg, + error_type=WingmanInitializationErrorType.MCP_CONNECTION_FAILED, + ) + ) + return (False, None, local_errors) + + if connection.is_connected: + return ( + True, + f"{mcp_config.display_name} ({len(connection.tools)} tools)", + local_errors, + ) + else: + error_msg = f"MCP '{mcp_config.display_name}' failed to connect: {connection.error}" + local_errors.append( + WingmanInitializationError( + wingman_name=self.wingman_name, + message=error_msg, + error_type=WingmanInitializationErrorType.MCP_CONNECTION_FAILED, + ) + ) + return (False, None, local_errors) + + except Exception as e: + error_msg = f"MCP '{mcp_config.name}' initialization error: {str(e)}" + printr.print( + error_msg, + color=LogType.ERROR, + source_name=self.wingman_name, + server_only=True, + ) + printr.print( + traceback.format_exc(), color=LogType.ERROR, server_only=True + ) + local_errors.append( + WingmanInitializationError( + wingman_name=self.wingman_name, + message=error_msg, + error_type=WingmanInitializationErrorType.MCP_CONNECTION_FAILED, + ) + ) + return (False, None, local_errors) + + connection_tasks = [connect_mcp(mcp) for mcp in mcps_to_connect] + results = await asyncio.gather(*connection_tasks) + + connected_count = 0 + connected_names = [] + for success, connection_info, mcp_errors in results: + if success: + connected_count += 1 + connected_names.append(connection_info) + errors.extend(mcp_errors) + + if connected_count > 0: + await printr.print_async( + f"Discoverable MCP servers connected ({connected_count}): {', '.join(connected_names)}", + color=LogType.WINGMAN, + source=LogSource.WINGMAN, + source_name=self.wingman_name, + server_only=not self.settings.debug_mode, + ) + + return errors + + async def unload_mcps(self): + await self.mcp_registry.clear() diff --git a/wingmen/wingman.py b/wingmen/wingman.py index 4ecf5d7a..16bb934b 100644 --- a/wingmen/wingman.py +++ b/wingmen/wingman.py @@ -41,7 +41,6 @@ TtsProvider, WingmanInitializationErrorType, ) -from api.commands import McpStateChangedCommand from providers.interfaces import LlmInterface, SttInterface, TtsInterface from services.audio_player import AudioPlayer from services.benchmark import Benchmark @@ -58,8 +57,8 @@ from services.provider_factory import ProviderFactory from services.skill_registry import SkillRegistry from services.mcp_client import McpClient -from services.mcp_registry import McpRegistry from services.capability_registry import CapabilityRegistry +from services.wingman_mcp_manager import WingmanMcpManager from services.tool_response_cache import ToolResponseCompressor from services.token_utils import count_tokens from skills.skill_base import Skill @@ -168,10 +167,13 @@ def __init__( # --- MCP --- self.mcp_client = McpClient(wingman_name=self.name) - self.mcp_registry = McpRegistry( - self.mcp_client, + self.mcp_manager = WingmanMcpManager( wingman_name=self.name, - on_state_changed=self._broadcast_mcp_state_changed, + mcp_client=self.mcp_client, + secret_keeper=self.secret_keeper, + get_mcp_config=lambda: self.tower.config_manager.mcp_config if self.tower else None, + settings=self.settings, + config=self.config, ) # --- Unified capability registry --- @@ -191,6 +193,13 @@ def __init__( self.instant_responses = [] self.last_used_instant_responses = [] + # ──────────────────────────────── Backward-compat properties ──────────────── # + + @property + def mcp_registry(self): + """Backward-compat: many callers access wingman.mcp_registry directly.""" + return self.mcp_manager.mcp_registry + # ──────────────────────────────── Record keys ─────────────────────────────── # def get_record_key(self) -> str | int: @@ -343,7 +352,7 @@ async def unload_skills(self): self.skill_registry.clear() async def unload_mcps(self): - await self.mcp_registry.clear() + return await self.mcp_manager.unload_mcps() # ──────────────────────────────── Memory ──────────────────────────────────── # @@ -365,236 +374,16 @@ def ensure_memory_initialized(self) -> bool: return True return False - # ──────────────────────────────── MCP ──────────────────────────────────────── # - - def _broadcast_mcp_state_changed(self): - if printr._connection_manager: - printr.ensure_async( - printr._connection_manager.broadcast( - McpStateChangedCommand(wingman_name=self.name) - ) - ) + # ──────────────────────────────── MCP (forwarding) ─────────────────────────── # async def enable_mcp(self, mcp_name: str) -> tuple[bool, str]: - if not self.mcp_client.is_available: - return False, "MCP SDK not installed." - - if mcp_name in self.mcp_registry.get_connected_server_names(): - return True, f"MCP server '{mcp_name}' is already connected." - - central_mcp_config = self.tower.config_manager.mcp_config - mcp_configs = central_mcp_config.servers if central_mcp_config else [] - - mcp_config = None - for cfg in mcp_configs: - if cfg.name == mcp_name: - mcp_config = cfg - break - - if not mcp_config: - return False, f"MCP server '{mcp_name}' not found in mcp.yaml." - - try: - headers = {} - if mcp_config.headers: - headers.update(mcp_config.headers) - - secret_key = f"mcp_{mcp_config.name}" - api_key = await self.secret_keeper.retrieve( - requester=self.name, - key=secret_key, - prompt_if_missing=False, - ) - if api_key: - if not any( - k.lower() in ["authorization", "api-key", "x-api-key"] - for k in headers.keys() - ): - headers["Authorization"] = f"Bearer {api_key}" - - default_timeout = 60.0 if mcp_config.type.value == "stdio" else 30.0 - timeout = ( - float(mcp_config.timeout) if mcp_config.timeout else default_timeout - ) - - connection = await asyncio.wait_for( - self.mcp_registry.register_server( - config=mcp_config, - headers=headers if headers else None, - ), - timeout=timeout, - ) - - if connection.is_connected: - tool_count = len(connection.tools) - return True, f"MCP server '{mcp_name}' enabled with {tool_count} tools." - else: - error = connection.error or "Connection failed." - return False, f"MCP server '{mcp_name}' failed to connect: {error}" - - except asyncio.TimeoutError: - error_msg = f"Connection timed out ({int(timeout)}s)." - self.mcp_registry.set_server_error(mcp_name, error_msg) - return False, f"MCP server '{mcp_name}': {error_msg}" - - except Exception as e: - error_msg = f"Error enabling MCP '{mcp_name}': {str(e)}" - await printr.print_async(error_msg, color=LogType.ERROR) - printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - return False, error_msg + return await self.mcp_manager.enable_mcp(mcp_name) async def disable_mcp(self, mcp_name: str) -> tuple[bool, str]: - if mcp_name not in self.mcp_registry.get_connected_server_names(): - return True, f"MCP server '{mcp_name}' is already disconnected." - - try: - await self.mcp_registry.unregister_server(mcp_name) - return True, f"MCP server '{mcp_name}' disabled." - - except Exception as e: - error_msg = f"Error disabling MCP '{mcp_name}': {str(e)}" - await printr.print_async(error_msg, color=LogType.ERROR) - printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - return False, error_msg + return await self.mcp_manager.disable_mcp(mcp_name) async def init_mcps(self) -> list[WingmanInitializationError]: - errors = [] - - if not self.mcp_client.is_available: - printr.print( - f"[{self.name}] MCP SDK not installed, skipping MCP initialization.", - color=LogType.WARNING, - server_only=True, - ) - return errors - - await self.unload_mcps() - - central_mcp_config = self.tower.config_manager.mcp_config - mcp_configs = central_mcp_config.servers if central_mcp_config else [] - if not mcp_configs: - return errors - - discoverable_mcps = self.config.discoverable_mcps - mcps_to_connect = [mcp for mcp in mcp_configs if mcp.name in discoverable_mcps] - - if not mcps_to_connect: - return errors - - async def connect_mcp(mcp_config): - local_errors = [] - try: - headers = {} - if mcp_config.headers: - headers.update(mcp_config.headers) - - secret_key = f"mcp_{mcp_config.name}" - api_key = await self.secret_keeper.retrieve( - requester=self.name, - key=secret_key, - prompt_if_missing=False, - ) - if api_key: - printr.print( - f"MCP secret '{secret_key}' found ({len(api_key)} chars)", - color=LogType.INFO, - source_name=self.name, - server_only=True, - ) - if not any( - k.lower() in ["authorization", "api-key", "x-api-key"] - for k in headers.keys() - ): - headers["Authorization"] = f"Bearer {api_key}" - - default_timeout = 60.0 if mcp_config.type.value == "stdio" else 30.0 - timeout = ( - float(mcp_config.timeout) if mcp_config.timeout else default_timeout - ) - - try: - connection = await asyncio.wait_for( - self.mcp_registry.register_server( - config=mcp_config, - headers=headers if headers else None, - ), - timeout=timeout, - ) - except asyncio.TimeoutError: - error_msg = f"MCP '{mcp_config.display_name}' connection timed out ({int(timeout)}s)." - printr.print( - error_msg, - color=LogType.WARNING, - source_name=self.name, - server_only=True, - ) - local_errors.append( - WingmanInitializationError( - wingman_name=self.name, - message=error_msg, - error_type=WingmanInitializationErrorType.MCP_CONNECTION_FAILED, - ) - ) - return (False, None, local_errors) - - if connection.is_connected: - return ( - True, - f"{mcp_config.display_name} ({len(connection.tools)} tools)", - local_errors, - ) - else: - error_msg = f"MCP '{mcp_config.display_name}' failed to connect: {connection.error}" - local_errors.append( - WingmanInitializationError( - wingman_name=self.name, - message=error_msg, - error_type=WingmanInitializationErrorType.MCP_CONNECTION_FAILED, - ) - ) - return (False, None, local_errors) - - except Exception as e: - error_msg = f"MCP '{mcp_config.name}' initialization error: {str(e)}" - printr.print( - error_msg, - color=LogType.ERROR, - source_name=self.name, - server_only=True, - ) - printr.print( - traceback.format_exc(), color=LogType.ERROR, server_only=True - ) - local_errors.append( - WingmanInitializationError( - wingman_name=self.name, - message=error_msg, - error_type=WingmanInitializationErrorType.MCP_CONNECTION_FAILED, - ) - ) - return (False, None, local_errors) - - connection_tasks = [connect_mcp(mcp) for mcp in mcps_to_connect] - results = await asyncio.gather(*connection_tasks) - - connected_count = 0 - connected_names = [] - for success, connection_info, mcp_errors in results: - if success: - connected_count += 1 - connected_names.append(connection_info) - errors.extend(mcp_errors) - - if connected_count > 0: - await printr.print_async( - f"Discoverable MCP servers connected ({connected_count}): {', '.join(connected_names)}", - color=LogType.WINGMAN, - source=LogSource.WINGMAN, - source_name=self.name, - server_only=not self.settings.debug_mode, - ) - - return errors + return await self.mcp_manager.init_mcps() # ──────────────────────────────── Skills ───────────────────────────────────── # From f6ef6ef4898442638b0c0e10e4cd9a97bdc7c446 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Hopst=C3=A4tter?= Date: Fri, 10 Apr 2026 10:28:58 +0200 Subject: [PATCH 22/34] refactor(wingman): extract WingmanSkillManager service Move skill discovery, init, enable/disable, and prepare/unprepare logic out of the unified Wingman class. Dedupe the ~50 lines of yaml-loading + override-merging + platform-checking shared between init_skills and enable_skill into private helpers. Co-Authored-By: Claude Opus 4.6 --- services/wingman_skill_manager.py | 337 ++++++++++++++++++++++++++++++ wingmen/wingman.py | 293 +++----------------------- 2 files changed, 361 insertions(+), 269 deletions(-) create mode 100644 services/wingman_skill_manager.py diff --git a/services/wingman_skill_manager.py b/services/wingman_skill_manager.py new file mode 100644 index 00000000..71365db7 --- /dev/null +++ b/services/wingman_skill_manager.py @@ -0,0 +1,337 @@ +"""WingmanSkillManager — skill discovery, lifecycle, and state. + +Extracted from ``wingmen/wingman.py``. Owns ``skills``, ``tool_skills``, and +``skill_tools``; the parent ``Wingman`` exposes them as read-through properties. +""" + +import sys +import traceback +from typing import TYPE_CHECKING + +from api.interface import ( + SkillConfig, + SettingsConfig, + WingmanConfig, + WingmanInitializationError, +) +from api.enums import ( + LogSource, + LogType, + WingmanInitializationErrorType, +) +from services.module_manager import ModuleManager +from services.printr import Printr +from services.skill_registry import SkillRegistry +from skills.skill_base import Skill +from wingmen.wingman_context import WingmanContext + +if TYPE_CHECKING: + from wingmen.wingman import Wingman + +printr = Printr() + + +def _get_skill_folder_from_module(module: str) -> str: + """Extract folder name from module path like 'skills.star_head.main' -> 'star_head'""" + return module.replace(".main", "").replace(".", "/").split("/")[1] + + +_PLATFORM_MAP = {"win32": "windows", "darwin": "darwin", "linux": "linux"} + + +class WingmanSkillManager: + """Manages skill discovery, loading, preparation, enable/disable, and teardown.""" + + def __init__( + self, + wingman: "Wingman", + config: WingmanConfig, + settings: SettingsConfig, + skill_registry: SkillRegistry, + ): + self._wingman = wingman + self.config = config + self.settings = settings + self.skill_registry = skill_registry + + # Manager owns skill state + self.skills: list[Skill] = [] + self.tool_skills: dict[str, Skill] = {} + self.skill_tools: list[dict] = [] + + # ──────────────────────────── Private helpers ─────────────────────────────── # + + def _build_user_skill_configs(self) -> dict[str, SkillConfig]: + """Map folder name → user SkillConfig for each entry in wingman config.""" + result: dict[str, SkillConfig] = {} + if self.config.skills: + for skill_config in self.config.skills: + folder_name = _get_skill_folder_from_module(skill_config.module) + result[folder_name] = skill_config + return result + + def _load_skill_config( + self, + skill_folder_name: str, + skill_config_path: str, + user_skill_configs: dict[str, SkillConfig], + ) -> SkillConfig | None: + """Read yaml config + merge user overrides; return SkillConfig or None.""" + skill_config_dict = ModuleManager.read_config(skill_config_path) + if not skill_config_dict: + return None + + if skill_folder_name in user_skill_configs: + user_config = user_skill_configs[skill_folder_name] + if user_config.custom_properties: + skill_config_dict["custom_properties"] = [ + prop.model_dump() for prop in user_config.custom_properties + ] + if user_config.prompt: + skill_config_dict["prompt"] = user_config.prompt + + return SkillConfig(**skill_config_dict) + + def _check_platform_supported(self, skill_config: SkillConfig) -> bool: + """Return True if the skill supports the current platform (or has no restriction).""" + if not skill_config.platforms: + return True + normalized = _PLATFORM_MAP.get(sys.platform, sys.platform) + if normalized not in skill_config.platforms: + printr.print( + f"Skipping skill '{skill_config.name}' - not supported on {normalized}", + color=LogType.WARNING, + server_only=True, + ) + return False + return True + + def _check_platform_supported_with_message( + self, skill_config: SkillConfig + ) -> tuple[bool, str]: + """Like _check_platform_supported but returns (ok, reason) for enable_skill.""" + if not skill_config.platforms: + return True, "" + normalized = _PLATFORM_MAP.get(sys.platform, sys.platform) + if normalized not in skill_config.platforms: + return ( + False, + f"Skill '{skill_config.name}' is not supported on {normalized}.", + ) + return True, "" + + def _instantiate_skill(self, skill_config: SkillConfig) -> Skill | None: + """Create WingmanContext, call ModuleManager.load_skill, bind helpers.""" + context = WingmanContext(self._wingman) + skill = ModuleManager.load_skill( + config=skill_config, + settings=self.settings, + wingman=context, + ) + if skill: + skill.threaded_execution = self._wingman.threaded_execution + return skill + + # ──────────────────────────── Public API ─────────────────────────────────── # + + async def init_skills(self) -> list[WingmanInitializationError]: + if self.skills: + await self.unload_skills() + + errors = [] + self.skills = [] + + user_skill_configs = self._build_user_skill_configs() + available_skills = ModuleManager.read_available_skill_configs() + discoverable_skills = self.config.discoverable_skills + + for ( + skill_folder_name, + skill_config_path, + _is_custom, + _is_local, + ) in available_skills: + try: + skill_config = self._load_skill_config( + skill_folder_name, skill_config_path, user_skill_configs + ) + if not skill_config: + continue + + if skill_config.name not in discoverable_skills: + continue + + if not self._check_platform_supported(skill_config): + continue + + skill = self._instantiate_skill(skill_config) + if skill: + self.skills.append(skill) + await self.prepare_skill(skill) + + except Exception as e: + skill_name = skill_folder_name + error_msg = f"Error loading skill '{skill_name}': {str(e)}" + await printr.print_async( + error_msg, + color=LogType.ERROR, + ) + printr.print( + traceback.format_exc(), color=LogType.ERROR, server_only=True + ) + errors.append( + WingmanInitializationError( + wingman_name=self._wingman.name, + message=error_msg, + error_type=WingmanInitializationErrorType.SKILL_INITIALIZATION_FAILED, + ) + ) + + if self.skills: + skill_names = [s.config.name for s in self.skills] + await printr.print_async( + f"Discoverable skills ({len(skill_names)}): {', '.join(skill_names)}", + color=LogType.WINGMAN, + source=LogSource.WINGMAN, + source_name=self._wingman.name, + server_only=not self.settings.debug_mode, + ) + + return errors + + async def prepare_skill(self, skill: Skill): + try: + for tool_name, tool in skill.get_tools(): + self.tool_skills[tool_name] = skill + self.skill_tools.append(tool) + + self.skill_registry.register_skill(skill) + + if skill.config.auto_activate: + success, message = await skill.ensure_activated() + if not success: + await printr.print_async( + f"Auto-activated skill '{skill.config.display_name}' failed to activate: {message}", + color=LogType.ERROR, + ) + except Exception as e: + await printr.print_async( + f"Error while preparing skill '{skill.name}': {str(e)}", + color=LogType.ERROR, + ) + printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) + + skill.llm_call = self._wingman.actual_llm_call + + async def unprepare_skill(self, skill: Skill): + try: + for tool_name, _ in skill.get_tools(): + self.tool_skills.pop(tool_name, None) + self.skill_tools = [ + t + for t in self.skill_tools + if t.get("function", {}).get("name") != tool_name + ] + self.skill_registry.unregister_skill(skill.name) + except Exception as e: + await printr.print_async( + f"Error while unpreparing skill '{skill.name}': {str(e)}", + color=LogType.ERROR, + ) + printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) + + async def enable_skill(self, skill_name: str) -> tuple[bool, str]: + for existing_skill in self.skills: + if existing_skill.config.name == skill_name: + return True, f"Skill '{skill_name}' is already enabled." + + available_skills = ModuleManager.read_available_skill_configs() + user_skill_configs = self._build_user_skill_configs() + + for ( + skill_folder_name, + skill_config_path, + _is_custom, + _is_local, + ) in available_skills: + try: + skill_config = self._load_skill_config( + skill_folder_name, skill_config_path, user_skill_configs + ) + if not skill_config: + continue + + if skill_config.name != skill_name: + continue + + ok, reason = self._check_platform_supported_with_message(skill_config) + if not ok: + return False, reason + + skill = self._instantiate_skill(skill_config) + if skill: + self.skills.append(skill) + await self.prepare_skill(skill) + + printr.print( + f"Skill '{skill_name}' activated (loaded and made discoverable).", + color=LogType.POSITIVE, + server_only=True, + ) + return True, f"Skill '{skill_name}' activated successfully." + + except Exception as e: + error_msg = f"Error activating skill '{skill_name}': {str(e)}" + await printr.print_async(error_msg, color=LogType.ERROR) + printr.print( + traceback.format_exc(), color=LogType.ERROR, server_only=True + ) + return False, error_msg + + return False, f"Skill '{skill_name}' not found." + + async def disable_skill(self, skill_name: str) -> tuple[bool, str]: + skill_to_remove = None + for skill in self.skills: + if skill.config.name == skill_name: + skill_to_remove = skill + break + + if not skill_to_remove: + return True, f"Skill '{skill_name}' is already deactivated." + + try: + await skill_to_remove.unload() + self.skills.remove(skill_to_remove) + await self.unprepare_skill(skill_to_remove) + + printr.print( + f"Skill '{skill_name}' deactivated (unloaded and removed from discoverable skills).", + color=LogType.WARNING, + server_only=True, + ) + return True, f"Skill '{skill_name}' deactivated successfully." + + except Exception as e: + error_msg = f"Error deactivating skill '{skill_name}': {str(e)}" + await printr.print_async(error_msg, color=LogType.ERROR) + printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) + return False, error_msg + + async def unload_skills(self): + for skill in self.skills: + if not skill.is_prepared: + continue + try: + await skill.unload() + except Exception as e: + await printr.print_async( + f"Error unloading skill '{skill.name}': {str(e)}", + color=LogType.ERROR, + ) + printr.print( + traceback.format_exc(), color=LogType.ERROR, server_only=True + ) + self.tool_skills = {} + self.skill_tools = [] + self.skill_registry.clear() diff --git a/wingmen/wingman.py b/wingmen/wingman.py index 16bb934b..0a69b5b2 100644 --- a/wingmen/wingman.py +++ b/wingmen/wingman.py @@ -45,7 +45,6 @@ from services.audio_player import AudioPlayer from services.benchmark import Benchmark from services.markdown import cleanup_text -from services.module_manager import ModuleManager from services.secret_keeper import SecretKeeper from services.printr import Printr from services.audio_library import AudioLibrary @@ -59,10 +58,10 @@ from services.mcp_client import McpClient from services.capability_registry import CapabilityRegistry from services.wingman_mcp_manager import WingmanMcpManager +from services.wingman_skill_manager import WingmanSkillManager, _get_skill_folder_from_module from services.tool_response_cache import ToolResponseCompressor from services.token_utils import count_tokens from skills.skill_base import Skill -from wingmen.wingman_context import WingmanContext if TYPE_CHECKING: from services.tower import Tower @@ -70,11 +69,6 @@ printr = Printr() -def _get_skill_folder_from_module(module: str) -> str: - """Extract folder name from module path like 'skills.star_head.main' -> 'star_head'""" - return module.replace(".main", "").replace(".", "/").split("/")[1] - - class Wingman: """Unified Wingman class. @@ -160,10 +154,13 @@ def __init__( self._last_prompt_tokens: int = 0 # --- Skills --- - self.skills: list[Skill] = [] - self.tool_skills: dict[str, Skill] = {} - self.skill_tools: list[dict] = [] self.skill_registry = SkillRegistry() + self.skill_manager = WingmanSkillManager( + wingman=self, + config=config, + settings=settings, + skill_registry=self.skill_registry, + ) # --- MCP --- self.mcp_client = McpClient(wingman_name=self.name) @@ -200,6 +197,18 @@ def mcp_registry(self): """Backward-compat: many callers access wingman.mcp_registry directly.""" return self.mcp_manager.mcp_registry + @property + def skills(self): + return self.skill_manager.skills + + @property + def tool_skills(self): + return self.skill_manager.tool_skills + + @property + def skill_tools(self): + return self.skill_manager.skill_tools + # ──────────────────────────────── Record keys ─────────────────────────────── # def get_record_key(self) -> str | int: @@ -334,22 +343,7 @@ async def unload(self): await self.unload_skills() async def unload_skills(self): - for skill in self.skills: - if not skill.is_prepared: - continue - try: - await skill.unload() - except Exception as e: - await printr.print_async( - f"Error unloading skill '{skill.name}': {str(e)}", - color=LogType.ERROR, - ) - printr.print( - traceback.format_exc(), color=LogType.ERROR, server_only=True - ) - self.tool_skills = {} - self.skill_tools = [] - self.skill_registry.clear() + return await self.skill_manager.unload_skills() async def unload_mcps(self): return await self.mcp_manager.unload_mcps() @@ -385,255 +379,16 @@ async def disable_mcp(self, mcp_name: str) -> tuple[bool, str]: async def init_mcps(self) -> list[WingmanInitializationError]: return await self.mcp_manager.init_mcps() - # ──────────────────────────────── Skills ───────────────────────────────────── # + # ──────────────────────────────── Skills (forwarding) ─────────────────────── # async def init_skills(self) -> list[WingmanInitializationError]: - import sys - - current_platform = sys.platform - platform_map = {"win32": "windows", "darwin": "darwin", "linux": "linux"} - normalized_platform = platform_map.get(current_platform, current_platform) - - if self.skills: - await self.unload_skills() - - errors = [] - self.skills = [] - - user_skill_configs: dict[str, "SkillConfig"] = {} - if self.config.skills: - for skill_config in self.config.skills: - folder_name = _get_skill_folder_from_module(skill_config.module) - user_skill_configs[folder_name] = skill_config - - available_skills = ModuleManager.read_available_skill_configs() - discoverable_skills = self.config.discoverable_skills - - for ( - skill_folder_name, - skill_config_path, - _is_custom, - _is_local, - ) in available_skills: - try: - skill_config_dict = ModuleManager.read_config(skill_config_path) - if not skill_config_dict: - continue - - from api.interface import SkillConfig - - if skill_folder_name in user_skill_configs: - user_config = user_skill_configs[skill_folder_name] - if user_config.custom_properties: - skill_config_dict["custom_properties"] = [ - prop.model_dump() for prop in user_config.custom_properties - ] - if user_config.prompt: - skill_config_dict["prompt"] = user_config.prompt - - skill_config = SkillConfig(**skill_config_dict) - - if skill_config.name not in discoverable_skills: - continue - - if skill_config.platforms: - if normalized_platform not in skill_config.platforms: - printr.print( - f"Skipping skill '{skill_config.name}' - not supported on {normalized_platform}", - color=LogType.WARNING, - server_only=True, - ) - continue - - context = WingmanContext(self) - skill = ModuleManager.load_skill( - config=skill_config, - settings=self.settings, - wingman=context, - ) - if skill: - skill.threaded_execution = self.threaded_execution - self.skills.append(skill) - await self.prepare_skill(skill) - - except Exception as e: - skill_name = skill_folder_name - error_msg = f"Error loading skill '{skill_name}': {str(e)}" - await printr.print_async( - error_msg, - color=LogType.ERROR, - ) - printr.print( - traceback.format_exc(), color=LogType.ERROR, server_only=True - ) - errors.append( - WingmanInitializationError( - wingman_name=self.name, - message=error_msg, - error_type=WingmanInitializationErrorType.SKILL_INITIALIZATION_FAILED, - ) - ) - - if self.skills: - skill_names = [s.config.name for s in self.skills] - await printr.print_async( - f"Discoverable skills ({len(skill_names)}): {', '.join(skill_names)}", - color=LogType.WINGMAN, - source=LogSource.WINGMAN, - source_name=self.name, - server_only=not self.settings.debug_mode, - ) - - return errors - - async def prepare_skill(self, skill: Skill): - try: - for tool_name, tool in skill.get_tools(): - self.tool_skills[tool_name] = skill - self.skill_tools.append(tool) - - self.skill_registry.register_skill(skill) - - if skill.config.auto_activate: - success, message = await skill.ensure_activated() - if not success: - await printr.print_async( - f"Auto-activated skill '{skill.config.display_name}' failed to activate: {message}", - color=LogType.ERROR, - ) - except Exception as e: - await printr.print_async( - f"Error while preparing skill '{skill.name}': {str(e)}", - color=LogType.ERROR, - ) - printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - - skill.llm_call = self.actual_llm_call - - async def unprepare_skill(self, skill: Skill): - try: - for tool_name, _ in skill.get_tools(): - self.tool_skills.pop(tool_name, None) - self.skill_tools = [ - t - for t in self.skill_tools - if t.get("function", {}).get("name") != tool_name - ] - self.skill_registry.unregister_skill(skill.name) - except Exception as e: - await printr.print_async( - f"Error while unpreparing skill '{skill.name}': {str(e)}", - color=LogType.ERROR, - ) - printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) + return await self.skill_manager.init_skills() async def enable_skill(self, skill_name: str) -> tuple[bool, str]: - import sys - - current_platform = sys.platform - platform_map = {"win32": "windows", "darwin": "darwin", "linux": "linux"} - normalized_platform = platform_map.get(current_platform, current_platform) - - for existing_skill in self.skills: - if existing_skill.config.name == skill_name: - return True, f"Skill '{skill_name}' is already enabled." - - available_skills = ModuleManager.read_available_skill_configs() - user_skill_configs: dict[str, "SkillConfig"] = {} - if self.config.skills: - for skill_config in self.config.skills: - folder_name = _get_skill_folder_from_module(skill_config.module) - user_skill_configs[folder_name] = skill_config - - for ( - skill_folder_name, - skill_config_path, - _is_custom, - _is_local, - ) in available_skills: - try: - skill_config_dict = ModuleManager.read_config(skill_config_path) - if not skill_config_dict: - continue - - from api.interface import SkillConfig - - if skill_folder_name in user_skill_configs: - user_config = user_skill_configs[skill_folder_name] - if user_config.custom_properties: - skill_config_dict["custom_properties"] = [ - prop.model_dump() for prop in user_config.custom_properties - ] - if user_config.prompt: - skill_config_dict["prompt"] = user_config.prompt - - skill_config = SkillConfig(**skill_config_dict) - - if skill_config.name != skill_name: - continue - - if skill_config.platforms: - if normalized_platform not in skill_config.platforms: - return ( - False, - f"Skill '{skill_name}' is not supported on {normalized_platform}.", - ) - - context = WingmanContext(self) - skill = ModuleManager.load_skill( - config=skill_config, - settings=self.settings, - wingman=context, - ) - if skill: - skill.threaded_execution = self.threaded_execution - self.skills.append(skill) - await self.prepare_skill(skill) - - printr.print( - f"Skill '{skill_name}' activated (loaded and made discoverable).", - color=LogType.POSITIVE, - server_only=True, - ) - return True, f"Skill '{skill_name}' activated successfully." - - except Exception as e: - error_msg = f"Error activating skill '{skill_name}': {str(e)}" - await printr.print_async(error_msg, color=LogType.ERROR) - printr.print( - traceback.format_exc(), color=LogType.ERROR, server_only=True - ) - return False, error_msg - - return False, f"Skill '{skill_name}' not found." + return await self.skill_manager.enable_skill(skill_name) async def disable_skill(self, skill_name: str) -> tuple[bool, str]: - skill_to_remove = None - for skill in self.skills: - if skill.config.name == skill_name: - skill_to_remove = skill - break - - if not skill_to_remove: - return True, f"Skill '{skill_name}' is already deactivated." - - try: - await skill_to_remove.unload() - self.skills.remove(skill_to_remove) - await self.unprepare_skill(skill_to_remove) - - printr.print( - f"Skill '{skill_name}' deactivated (unloaded and removed from discoverable skills).", - color=LogType.WARNING, - server_only=True, - ) - return True, f"Skill '{skill_name}' deactivated successfully." - - except Exception as e: - error_msg = f"Error deactivating skill '{skill_name}': {str(e)}" - await printr.print_async(error_msg, color=LogType.ERROR) - printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - return False, error_msg + return await self.skill_manager.disable_skill(skill_name) # ──────────────────────────── The main processing loop ──────────────────────── # From 46e5987b2d6e82751dc4ebbd24d8629494097f66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Hopst=C3=A4tter?= Date: Fri, 10 Apr 2026 10:42:24 +0200 Subject: [PATCH 23/34] refactor(wingman): wire skill context into ConversationManager once Drop the per-call skills/skill_registry/tool_skills kwargs from ConversationManager methods. WingmanSkillManager now pushes the current skill state into the conversation via set_skill_context() after every init/enable/disable, eliminating the wrapper cluster in Wingman that was just forwarding skill state on every call. Co-Authored-By: Claude Opus 4.6 --- services/conversation_manager.py | 52 +++++++++++++++++-------------- services/wingman_skill_manager.py | 14 +++++++++ wingmen/wingman.py | 46 ++++++--------------------- wingmen/wingman_context.py | 4 +-- 4 files changed, 53 insertions(+), 63 deletions(-) diff --git a/services/conversation_manager.py b/services/conversation_manager.py index e4293d6b..eb8e0ac8 100644 --- a/services/conversation_manager.py +++ b/services/conversation_manager.py @@ -17,7 +17,6 @@ if TYPE_CHECKING: from api.interface import CommandConfig, SettingsConfig, WingmanConfig - from skills.skill_base import Skill printr = Printr() @@ -43,6 +42,26 @@ def __init__( self.pending_tool_calls: list[str] = [] self.conversation_summary: str = "" + # Skill state — set once via set_skill_context(); avoids per-call kwargs + self._skills: list = [] + self._skill_registry = None + self._tool_skills: dict = {} + + def set_skill_context( + self, + skills: list, + skill_registry, + tool_skills: dict, + ) -> None: + """Wire current skill state into the manager. + + Called by WingmanSkillManager after init/enable/disable so per-call + methods don't need to receive it as arguments. + """ + self._skills = skills + self._skill_registry = skill_registry + self._tool_skills = tool_skills + # ------------------------------------------------------------------ # GPT / assistant response helpers # ------------------------------------------------------------------ @@ -51,9 +70,6 @@ async def add_gpt_response( self, message, tool_calls, - skills: list["Skill"] = None, - skill_registry=None, - tool_skills: dict = None, ) -> tuple[bool, bool]: """Adds a message from GPT to the conversation history as well as adding dummy tool responses for any tool calls. @@ -61,12 +77,9 @@ async def add_gpt_response( Args: message (dict | ChatCompletionMessage): The message to add. tool_calls (list): The tool calls associated with the message. - skills: List of active skills for hook invocation. - skill_registry: The skill registry (used for ``is_meta_tool``). - tool_skills: Mapping of function-name -> Skill instance. """ # call skill hooks (only for prepared/activated skills) - for skill in (skills or []): + for skill in self._skills: if skill.is_prepared: await skill.on_add_assistant_message( message.content, message.tool_calls @@ -92,10 +105,10 @@ async def add_gpt_response( # Meta-tools (search_skills, activate_skill, etc.) always need a follow-up # LLM call so it can use the newly activated tools - if skill_registry and skill_registry.is_meta_tool(function_name): + if self._skill_registry and self._skill_registry.is_meta_tool(function_name): is_summarize_needed = True - elif tool_skills and function_name in tool_skills: - skill = tool_skills[function_name] + elif function_name in self._tool_skills: + skill = self._tool_skills[function_name] if await skill.is_waiting_response_needed(function_name): is_waiting_response_needed = True if await skill.is_summarize_needed(function_name): @@ -240,7 +253,6 @@ async def add_user_message( self, content: str, images: list[tuple[str, str]] = None, - skills: list["Skill"] = None, condense_fn: Optional[Callable] = None, ): """Shortens the conversation history if needed and adds a user message to it. @@ -248,11 +260,10 @@ async def add_user_message( Args: content (str): The message content to add. images (list[tuple[str, str]]): Optional list of (base64_data, mime_type) tuples to attach. - skills: List of active skills for hook invocation. condense_fn: Optional async callable invoked after cleanup (``_maybe_condense_history``). """ # call skill hooks (only for prepared/activated skills) - for skill in (skills or []): + for skill in self._skills: if skill.is_prepared: await skill.on_add_user_message(content) @@ -276,16 +287,15 @@ async def add_user_message( self.messages.append(msg) async def add_assistant_message( - self, content: str, skills: list["Skill"] = None + self, content: str ): """Adds an assistant message to the conversation history. Args: content (str): The message content to add. - skills: List of active skills for hook invocation. """ # call skill hooks (only for prepared/activated skills) - for skill in (skills or []): + for skill in self._skills: if skill.is_prepared: await skill.on_add_assistant_message(content, []) @@ -295,17 +305,11 @@ async def add_assistant_message( async def add_forced_assistant_command_calls( self, commands: list["CommandConfig"], - skills: list["Skill"] = None, - skill_registry=None, - tool_skills: dict = None, ): """Adds forced assistant command calls to the conversation history. Args: commands (list[CommandConfig]): The commands to add. - skills: List of active skills for hook invocation. - skill_registry: The skill registry (used for ``is_meta_tool``). - tool_skills: Mapping of function-name -> Skill instance. """ if not commands: @@ -368,7 +372,7 @@ async def add_forced_assistant_command_calls( tool_id_to_command[tool_id] = command await self.add_gpt_response( - message, message.tool_calls, skills, skill_registry, tool_skills + message, message.tool_calls ) for tool_call in message.tool_calls: command = tool_id_to_command[tool_call.id] diff --git a/services/wingman_skill_manager.py b/services/wingman_skill_manager.py index 71365db7..1e66dc7f 100644 --- a/services/wingman_skill_manager.py +++ b/services/wingman_skill_manager.py @@ -61,6 +61,16 @@ def __init__( # ──────────────────────────── Private helpers ─────────────────────────────── # + def _sync_conversation_skill_context(self) -> None: + """Push current skill state into the conversation manager. + + Called after every mutation (init/enable/disable/unload) so that + ConversationManager does not need per-call skill kwargs. + """ + self._wingman.conversation.set_skill_context( + self.skills, self.skill_registry, self.tool_skills + ) + def _build_user_skill_configs(self) -> dict[str, SkillConfig]: """Map folder name → user SkillConfig for each entry in wingman config.""" result: dict[str, SkillConfig] = {} @@ -197,6 +207,7 @@ async def init_skills(self) -> list[WingmanInitializationError]: server_only=not self.settings.debug_mode, ) + self._sync_conversation_skill_context() return errors async def prepare_skill(self, skill: Skill): @@ -239,6 +250,7 @@ async def unprepare_skill(self, skill: Skill): color=LogType.ERROR, ) printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) + self._sync_conversation_skill_context() async def enable_skill(self, skill_name: str) -> tuple[bool, str]: for existing_skill in self.skills: @@ -272,6 +284,7 @@ async def enable_skill(self, skill_name: str) -> tuple[bool, str]: if skill: self.skills.append(skill) await self.prepare_skill(skill) + self._sync_conversation_skill_context() printr.print( f"Skill '{skill_name}' activated (loaded and made discoverable).", @@ -335,3 +348,4 @@ async def unload_skills(self): self.tool_skills = {} self.skill_tools = [] self.skill_registry.clear() + self._sync_conversation_skill_context() diff --git a/wingmen/wingman.py b/wingmen/wingman.py index 0a69b5b2..2eafaff7 100644 --- a/wingmen/wingman.py +++ b/wingmen/wingman.py @@ -144,7 +144,7 @@ def __init__( audio_library=audio_library, wingman_name=name, on_reset_history=self.reset_conversation_history, - on_add_forced_commands=self.add_forced_assistant_command_calls, + on_add_forced_commands=self.conversation.add_forced_assistant_command_calls, ) # --- Token tracking --- @@ -490,7 +490,7 @@ async def _get_response_for_transcript( transcript=transcript ) if instant_response: - await self.add_assistant_message(instant_response) + await self.conversation.add_assistant_message(instant_response) benchmark.finish_snapshot() if instant_response == ".": instant_response = None @@ -519,7 +519,7 @@ async def _get_response_for_transcript( turn_completion_tokens = usage[1] self._last_prompt_tokens = turn_prompt_tokens - is_waiting_response_needed, is_summarize_needed = await self._add_gpt_response( + is_waiting_response_needed, is_summarize_needed = await self.conversation.add_gpt_response( response_message, tool_calls ) interrupt = True @@ -555,7 +555,7 @@ async def _get_response_for_transcript( tool_timings.extend(iteration_timings) if instant_response: - await self._trim_tool_responses() + await self.conversation.trim_tool_responses(max_tokens=500, is_condensing=self.condenser.is_condensing) self._add_benchmark_snapshot( benchmark, "LLM Processing", llm_processing_time_ms ) @@ -574,7 +574,7 @@ async def _get_response_for_transcript( llm_processing_time_ms += (time.perf_counter() - llm_start) * 1000 if completion is None: - await self._trim_tool_responses() + await self.conversation.trim_tool_responses(max_tokens=500, is_condensing=self.condenser.is_condensing) self._add_benchmark_snapshot( benchmark, "LLM Processing", llm_processing_time_ms ) @@ -595,12 +595,12 @@ async def _get_response_for_transcript( self._last_prompt_tokens = turn_prompt_tokens is_waiting_response_needed, is_summarize_needed = ( - await self._add_gpt_response(response_message, tool_calls) + await self.conversation.add_gpt_response(response_message, tool_calls) ) if tool_calls: interrupt = False elif is_waiting_response_needed: - await self._trim_tool_responses() + await self.conversation.trim_tool_responses(max_tokens=500, is_condensing=self.condenser.is_condensing) self._add_benchmark_snapshot( benchmark, "LLM Processing", llm_processing_time_ms ) @@ -613,7 +613,7 @@ async def _get_response_for_transcript( ) return None, None, None, interrupt - await self._trim_tool_responses() + await self.conversation.trim_tool_responses(max_tokens=500, is_condensing=self.condenser.is_condensing) self._add_benchmark_snapshot( benchmark, "LLM Processing", llm_processing_time_ms @@ -756,44 +756,16 @@ async def execute_command_by_function_call( # ───────────────── Conversation delegation ───────────────── # - async def _add_gpt_response(self, message, tool_calls) -> tuple[bool, bool]: - return await self.conversation.add_gpt_response( - message, - tool_calls, - skills=self.skills, - skill_registry=self.skill_registry, - tool_skills=self.tool_skills, - ) - - async def _trim_tool_responses(self, max_tokens: int = 500): - await self.conversation.trim_tool_responses( - max_tokens=max_tokens, - is_condensing=self.condenser.is_condensing, - ) - async def add_user_message(self, content: str, images: list[tuple[str, str]] = None): + """Thin wrapper: resets memory-recall state then delegates to ConversationManager.""" self._memory_recall_notified = False self.context_builder.reset_memory_notification() await self.conversation.add_user_message( content, images=images, - skills=self.skills, condense_fn=lambda: self.condenser.maybe_condense(self.local_ai_service), ) - async def add_assistant_message(self, content: str): - await self.conversation.add_assistant_message( - content, skills=self.skills - ) - - async def add_forced_assistant_command_calls(self, commands: list[CommandConfig]): - await self.conversation.add_forced_assistant_command_calls( - commands, - skills=self.skills, - skill_registry=self.skill_registry, - tool_skills=self.tool_skills, - ) - async def reset_conversation_history(self): if self.persistent_memory_service and len(self.conversation.messages) >= 4: try: diff --git a/wingmen/wingman_context.py b/wingmen/wingman_context.py index 06f94608..742073cd 100644 --- a/wingmen/wingman_context.py +++ b/wingmen/wingman_context.py @@ -53,10 +53,10 @@ def get_conversation_history(self) -> list[dict]: return self._wingman.conversation.get_history() async def add_user_message(self, content: str): - await self._wingman.conversation.add_user_message(content, skills=self._wingman.skills) + await self._wingman.conversation.add_user_message(content) async def add_assistant_message(self, content: str): - await self._wingman.conversation.add_assistant_message(content, skills=self._wingman.skills) + await self._wingman.conversation.add_assistant_message(content) async def reset_conversation_history(self): await self._wingman.reset_conversation_history() From c6646d25132a8cde0928e1ba3fea2e3afe2022ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Hopst=C3=A4tter?= Date: Fri, 10 Apr 2026 10:45:38 +0200 Subject: [PATCH 24/34] refactor(wingman): extract TurnMetrics service Move per-turn benchmark snapshot building and token-usage broadcast into a focused service. Local-LLM token-count fallback and the formatting helpers move with them. Co-Authored-By: Claude Opus 4.6 --- services/turn_metrics.py | 128 +++++++++++++++++++++++++++++++++++ wingmen/wingman.py | 142 +++++++-------------------------------- 2 files changed, 153 insertions(+), 117 deletions(-) create mode 100644 services/turn_metrics.py diff --git a/services/turn_metrics.py b/services/turn_metrics.py new file mode 100644 index 00000000..ab7e6701 --- /dev/null +++ b/services/turn_metrics.py @@ -0,0 +1,128 @@ +"""Per-turn benchmark snapshot building and token-usage broadcasting.""" + +from api.enums import ConversationProvider +from api.interface import BenchmarkResult, WingmanConfig +from services.benchmark import Benchmark +from services.printr import Printr +from services.token_utils import count_tokens + +printr = Printr() + + +class TurnMetrics: + """Focused service for per-turn benchmark snapshots and token-usage broadcast. + + Owns the two token counters that were previously tracked directly on + :class:`~wingmen.wingman.Wingman` and provides the formatting helpers + that used to live there as private methods. + """ + + def __init__( + self, + wingman_name: str, + config: WingmanConfig, + conversation, + ): + self.wingman_name = wingman_name + self.config = config + self.conversation = conversation + self.last_turn_prompt_tokens: int = 0 + self.last_turn_completion_tokens: int = 0 + + # ──────────────────────────── public API ─────────────────────────── # + + def add_benchmark_snapshot( + self, benchmark: Benchmark, label: str, execution_time_ms: float + ) -> None: + if execution_time_ms >= 1000: + formatted_time = f"{execution_time_ms/1000:.1f}s" + else: + formatted_time = f"{int(execution_time_ms)}ms" + + benchmark.snapshots.append( + BenchmarkResult( + label=label, + execution_time_ms=execution_time_ms, + formatted_execution_time=formatted_time, + ) + ) + + def add_tool_execution_snapshot( + self, + benchmark: Benchmark, + total_time_ms: float, + tool_timings: list[tuple[str, float]], + ) -> None: + if total_time_ms >= 1000: + formatted_time = f"{total_time_ms/1000:.1f}s" + else: + formatted_time = f"{int(total_time_ms)}ms" + + nested_snapshots = [] + for label, time_ms in tool_timings: + if time_ms >= 1000: + fmt = f"{time_ms/1000:.1f}s" + else: + fmt = f"{int(time_ms)}ms" + nested_snapshots.append( + BenchmarkResult( + label=label, + execution_time_ms=time_ms, + formatted_execution_time=fmt, + ) + ) + + benchmark.snapshots.append( + BenchmarkResult( + label="Tool Execution", + execution_time_ms=total_time_ms, + formatted_execution_time=formatted_time, + snapshots=nested_snapshots if nested_snapshots else None, + ) + ) + + async def broadcast_token_usage( + self, prompt_tokens: int, completion_tokens: int + ) -> None: + is_local = ( + self.config.features.conversation_provider == ConversationProvider.LOCAL_LLM + ) + + if is_local and prompt_tokens == 0: + prompt_tokens = sum( + count_tokens( + msg["content"] + if isinstance(msg.get("content"), str) + else str(msg.get("content", "")) + ) + for msg in self.conversation.messages + ) + if is_local and completion_tokens == 0 and self.conversation.messages: + last = self.conversation.messages[-1] + if last.get("role") == "assistant": + content = last.get("content", "") + completion_tokens = count_tokens( + content if isinstance(content, str) else str(content) + ) + + self.last_turn_prompt_tokens = prompt_tokens + self.last_turn_completion_tokens = completion_tokens + if prompt_tokens == 0 and completion_tokens == 0: + return + if not printr._connection_manager: + return + + from api.commands import ConversationTokenUsageCommand + + await printr._connection_manager.broadcast( + ConversationTokenUsageCommand( + wingman_name=self.wingman_name, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + is_local=is_local, + ) + ) + + def reset_token_counters(self) -> None: + self.last_turn_prompt_tokens = 0 + self.last_turn_completion_tokens = 0 diff --git a/wingmen/wingman.py b/wingmen/wingman.py index 2eafaff7..eb9d5e22 100644 --- a/wingmen/wingman.py +++ b/wingmen/wingman.py @@ -60,7 +60,7 @@ from services.wingman_mcp_manager import WingmanMcpManager from services.wingman_skill_manager import WingmanSkillManager, _get_skill_folder_from_module from services.tool_response_cache import ToolResponseCompressor -from services.token_utils import count_tokens +from services.turn_metrics import TurnMetrics from skills.skill_base import Skill if TYPE_CHECKING: @@ -147,9 +147,13 @@ def __init__( on_add_forced_commands=self.conversation.add_forced_assistant_command_calls, ) - # --- Token tracking --- - self.last_turn_prompt_tokens: int = 0 - self.last_turn_completion_tokens: int = 0 + # --- Metrics service --- + self.metrics = TurnMetrics( + wingman_name=name, + config=config, + conversation=self.conversation, + ) + self.execution_start: None | float = None self._last_prompt_tokens: int = 0 @@ -428,13 +432,12 @@ async def process(self, audio_input_wav: str = None, transcript: str = None, ima if actual_response: token_usage = None - if self.last_turn_prompt_tokens or self.last_turn_completion_tokens: + if self.metrics.last_turn_prompt_tokens or self.metrics.last_turn_completion_tokens: token_usage = ( - self.last_turn_prompt_tokens, - self.last_turn_completion_tokens, + self.metrics.last_turn_prompt_tokens, + self.metrics.last_turn_completion_tokens, ) - self.last_turn_prompt_tokens = 0 - self.last_turn_completion_tokens = 0 + self.metrics.reset_token_counters() await printr.print_async( f"{actual_response}", color=LogType.POSITIVE, @@ -506,7 +509,7 @@ async def _get_response_for_transcript( llm_processing_time_ms += (time.perf_counter() - llm_start) * 1000 if completion is None: - self._add_benchmark_snapshot( + self.metrics.add_benchmark_snapshot( benchmark, "LLM Processing", llm_processing_time_ms ) return None, None, None, True @@ -556,14 +559,14 @@ async def _get_response_for_transcript( if instant_response: await self.conversation.trim_tool_responses(max_tokens=500, is_condensing=self.condenser.is_condensing) - self._add_benchmark_snapshot( + self.metrics.add_benchmark_snapshot( benchmark, "LLM Processing", llm_processing_time_ms ) if tool_execution_time_ms > 0: - self._add_tool_execution_snapshot( + self.metrics.add_tool_execution_snapshot( benchmark, tool_execution_time_ms, tool_timings ) - await self._broadcast_token_usage( + await self.metrics.broadcast_token_usage( turn_prompt_tokens, turn_completion_tokens ) return None, instant_response, None, interrupt @@ -575,14 +578,14 @@ async def _get_response_for_transcript( if completion is None: await self.conversation.trim_tool_responses(max_tokens=500, is_condensing=self.condenser.is_condensing) - self._add_benchmark_snapshot( + self.metrics.add_benchmark_snapshot( benchmark, "LLM Processing", llm_processing_time_ms ) if tool_execution_time_ms > 0: - self._add_tool_execution_snapshot( + self.metrics.add_tool_execution_snapshot( benchmark, tool_execution_time_ms, tool_timings ) - await self._broadcast_token_usage( + await self.metrics.broadcast_token_usage( turn_prompt_tokens, turn_completion_tokens ) return None, None, None, True @@ -601,28 +604,28 @@ async def _get_response_for_transcript( interrupt = False elif is_waiting_response_needed: await self.conversation.trim_tool_responses(max_tokens=500, is_condensing=self.condenser.is_condensing) - self._add_benchmark_snapshot( + self.metrics.add_benchmark_snapshot( benchmark, "LLM Processing", llm_processing_time_ms ) if tool_execution_time_ms > 0: - self._add_tool_execution_snapshot( + self.metrics.add_tool_execution_snapshot( benchmark, tool_execution_time_ms, tool_timings ) - await self._broadcast_token_usage( + await self.metrics.broadcast_token_usage( turn_prompt_tokens, turn_completion_tokens ) return None, None, None, interrupt await self.conversation.trim_tool_responses(max_tokens=500, is_condensing=self.condenser.is_condensing) - self._add_benchmark_snapshot( + self.metrics.add_benchmark_snapshot( benchmark, "LLM Processing", llm_processing_time_ms ) if tool_execution_time_ms > 0: - self._add_tool_execution_snapshot( + self.metrics.add_tool_execution_snapshot( benchmark, tool_execution_time_ms, tool_timings ) - await self._broadcast_token_usage(turn_prompt_tokens, turn_completion_tokens) + await self.metrics.broadcast_token_usage(turn_prompt_tokens, turn_completion_tokens) return response_message.content, response_message.content, None, interrupt # ───────────────── LLM call ───────────────── # @@ -1001,101 +1004,6 @@ def _get_random_filler(self): self.last_used_instant_responses.append(random_index) return self.instant_responses[random_index] - # ───────────────── Benchmarks / token usage ───────────────── # - - def _add_benchmark_snapshot( - self, benchmark: Benchmark, label: str, execution_time_ms: float - ): - if execution_time_ms >= 1000: - formatted_time = f"{execution_time_ms/1000:.1f}s" - else: - formatted_time = f"{int(execution_time_ms)}ms" - - from api.interface import BenchmarkResult - - benchmark.snapshots.append( - BenchmarkResult( - label=label, - execution_time_ms=execution_time_ms, - formatted_execution_time=formatted_time, - ) - ) - - def _add_tool_execution_snapshot( - self, - benchmark: Benchmark, - total_time_ms: float, - tool_timings: list[tuple[str, float]], - ): - from api.interface import BenchmarkResult - - if total_time_ms >= 1000: - formatted_time = f"{total_time_ms/1000:.1f}s" - else: - formatted_time = f"{int(total_time_ms)}ms" - - nested_snapshots = [] - for label, time_ms in tool_timings: - if time_ms >= 1000: - fmt = f"{time_ms/1000:.1f}s" - else: - fmt = f"{int(time_ms)}ms" - nested_snapshots.append( - BenchmarkResult( - label=label, - execution_time_ms=time_ms, - formatted_execution_time=fmt, - ) - ) - - benchmark.snapshots.append( - BenchmarkResult( - label="Tool Execution", - execution_time_ms=total_time_ms, - formatted_execution_time=formatted_time, - snapshots=nested_snapshots if nested_snapshots else None, - ) - ) - - async def _broadcast_token_usage(self, prompt_tokens: int, completion_tokens: int): - is_local = ( - self.config.features.conversation_provider == ConversationProvider.LOCAL_LLM - ) - - if is_local and prompt_tokens == 0: - prompt_tokens = sum( - count_tokens( - msg["content"] - if isinstance(msg.get("content"), str) - else str(msg.get("content", "")) - ) - for msg in self.conversation.messages - ) - if is_local and completion_tokens == 0 and self.conversation.messages: - last = self.conversation.messages[-1] - if last.get("role") == "assistant": - content = last.get("content", "") - completion_tokens = count_tokens( - content if isinstance(content, str) else str(content) - ) - - self.last_turn_prompt_tokens = prompt_tokens - self.last_turn_completion_tokens = completion_tokens - if prompt_tokens == 0 and completion_tokens == 0: - return - if not printr._connection_manager: - return - - from api.commands import ConversationTokenUsageCommand - - await printr._connection_manager.broadcast( - ConversationTokenUsageCommand( - wingman_name=self.name, - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - is_local=is_local, - ) - ) # ───────────────── Build tools ───────────────── # From 6dc98c6448d11fad3d90df2b831990099bf98d0d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Hopst=C3=A4tter?= Date: Fri, 10 Apr 2026 10:48:39 +0200 Subject: [PATCH 25/34] refactor(wingman): extract InstantResponseGenerator service Move generic instant-filler generation (the LLM-driven phrase list used during long tool calls) into a focused service. The JSON-retry loop and the random-non-repeating selection logic move with it. Co-Authored-By: Claude Opus 4.6 --- services/instant_response_generator.py | 130 +++++++++++++++++++++++++ wingmen/wingman.py | 98 +++---------------- 2 files changed, 144 insertions(+), 84 deletions(-) create mode 100644 services/instant_response_generator.py diff --git a/services/instant_response_generator.py b/services/instant_response_generator.py new file mode 100644 index 00000000..2a9856a3 --- /dev/null +++ b/services/instant_response_generator.py @@ -0,0 +1,130 @@ +"""InstantResponseGenerator service. + +Generates a list of short generic filler phrases via an LLM call (used during +long tool-call turns) and provides a random non-repeating selection from that +list. +""" + +import json +import random +import traceback + +from api.enums import LogType +from services.printr import Printr + +printr = Printr() + + +class InstantResponseGenerator: + """Generates and serves generic instant-filler phrases. + + Construct once per Wingman, then call :meth:`generate` from + ``prepare()`` (wrapped in ``threaded_execution``) and + :meth:`get_random_filler` from the turn loop. + """ + + def __init__( + self, + wingman_name: str, + llm_call_fn, + get_context_fn, + ): + """ + Args: + wingman_name: Used for log messages. + llm_call_fn: Async callable matching ``actual_llm_call(messages, tools=None)`` + returning a ``ChatCompletion | None``. + get_context_fn: Async callable matching ``get_context() -> str``. + """ + self.wingman_name = wingman_name + self._llm_call = llm_call_fn + self._get_context = get_context_fn + self.instant_responses: list[str] = [] + self.last_used_instant_responses: list[int] = [] + + async def generate(self) -> None: + """Populate ``self.instant_responses`` via an LLM call. + + Called by ``Wingman.prepare()`` when + ``config.features.use_generic_instant_responses`` is enabled. + """ + context = await self._get_context() + messages = [ + { + "role": "system", + "content": """ + Generate a list in JSON format of at least 20 short direct text responses. + Make sure the response only contains the JSON, no additional text. + They must fit the described character in the given context by the user. + Every generated response must be generally usable in every situation. + Responses must show its still in progress and not in a finished state. + The user request this response is used on is unknown. Therefore it must be generic. + Good examples: + - "Processing..." + - "Stand by..." + + Bad examples: + - "Generating route..." (too specific) + - "I'm sorry, I can't do that." (too negative) + + Response example: + [ + "OK", + "Generating results...", + "Roger that!", + "Stand by..." + ] + """, + }, + {"role": "user", "content": context}, + ] + try: + completion = await self._llm_call(messages) + if completion is None: + return + if completion.choices[0].message.content: + retry_limit = 3 + retry_count = 1 + valid = False + while not valid and retry_count <= retry_limit: + try: + responses = json.loads(completion.choices[0].message.content) + valid = True + for response in responses: + if response not in self.instant_responses: + self.instant_responses.append(str(response)) + except json.JSONDecodeError: + messages.append(completion.choices[0].message) + messages.append( + { + "role": "user", + "content": "It was tried to handle the response in its entirety as a JSON string. Fix response to be a pure, valid JSON, it was not convertable.", + } + ) + if retry_count <= retry_limit: + completion = await self._llm_call(messages) + retry_count += 1 + except Exception as e: + await printr.print_async( + f"Error while generating instant responses: {str(e)}", + color=LogType.ERROR, + ) + printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) + + def get_random_filler(self) -> str | None: + """Return a random non-recently-used filler phrase. + + Returns ``None`` when no responses have been generated yet. + """ + if not self.instant_responses: + return None + + if len(self.last_used_instant_responses) > 2: + self.last_used_instant_responses = self.last_used_instant_responses[-2:] + + random_index = random.randint(0, len(self.instant_responses) - 1) + while random_index in self.last_used_instant_responses: + random_index = random.randint(0, len(self.instant_responses) - 1) + + self.last_used_instant_responses.append(random_index) + return self.instant_responses[random_index] diff --git a/wingmen/wingman.py b/wingmen/wingman.py index eb9d5e22..d7459f95 100644 --- a/wingmen/wingman.py +++ b/wingmen/wingman.py @@ -61,6 +61,7 @@ from services.wingman_skill_manager import WingmanSkillManager, _get_skill_folder_from_module from services.tool_response_cache import ToolResponseCompressor from services.turn_metrics import TurnMetrics +from services.instant_response_generator import InstantResponseGenerator from skills.skill_base import Skill if TYPE_CHECKING: @@ -189,10 +190,15 @@ def __init__( self._background_tasks: set[asyncio.Task] = set() self._tool_response_compressor = ToolResponseCompressor() + # --- Instant response generator --- + self.instant_response_generator = InstantResponseGenerator( + wingman_name=name, + llm_call_fn=self.actual_llm_call, + get_context_fn=self.get_context, + ) + # --- Conversation state --- self.last_gpt_call = None - self.instant_responses = [] - self.last_used_instant_responses = [] # ──────────────────────────────── Backward-compat properties ──────────────── # @@ -314,7 +320,7 @@ async def prepare(self): color=LogType.WARNING, server_only=True, ) - self.threaded_execution(self._generate_instant_responses) + self.threaded_execution(self.instant_response_generator.generate) except Exception as e: await printr.print_async( f"Error while preparing wingman '{self.name}': {str(e)}", @@ -532,9 +538,11 @@ async def _get_response_for_transcript( message = None if response_message.content: message = response_message.content - elif self.instant_responses: - message = self._get_random_filler() - is_summarize_needed = True + else: + filler = self.instant_response_generator.get_random_filler() + if filler: + message = filler + is_summarize_needed = True if message: self.threaded_execution(self.play_to_user, message, interrupt) await printr.print_async( @@ -927,84 +935,6 @@ async def generate_image(self, text: str) -> str: ) return "" - # ───────────────── Instant responses ───────────────── # - - async def _generate_instant_responses(self) -> None: - context = await self.get_context() - messages = [ - { - "role": "system", - "content": """ - Generate a list in JSON format of at least 20 short direct text responses. - Make sure the response only contains the JSON, no additional text. - They must fit the described character in the given context by the user. - Every generated response must be generally usable in every situation. - Responses must show its still in progress and not in a finished state. - The user request this response is used on is unknown. Therefore it must be generic. - Good examples: - - "Processing..." - - "Stand by..." - - Bad examples: - - "Generating route..." (too specific) - - "I'm sorry, I can't do that." (too negative) - - Response example: - [ - "OK", - "Generating results...", - "Roger that!", - "Stand by..." - ] - """, - }, - {"role": "user", "content": context}, - ] - try: - completion = await self.actual_llm_call(messages) - if completion is None: - return - if completion.choices[0].message.content: - retry_limit = 3 - retry_count = 1 - valid = False - while not valid and retry_count <= retry_limit: - try: - responses = json.loads(completion.choices[0].message.content) - valid = True - for response in responses: - if response not in self.instant_responses: - self.instant_responses.append(str(response)) - except json.JSONDecodeError: - messages.append(completion.choices[0].message) - messages.append( - { - "role": "user", - "content": "It was tried to handle the response in its entirety as a JSON string. Fix response to be a pure, valid JSON, it was not convertable.", - } - ) - if retry_count <= retry_limit: - completion = await self.actual_llm_call(messages) - retry_count += 1 - except Exception as e: - await printr.print_async( - f"Error while generating instant responses: {str(e)}", - color=LogType.ERROR, - ) - printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - - def _get_random_filler(self): - if len(self.last_used_instant_responses) > 2: - self.last_used_instant_responses = self.last_used_instant_responses[-2:] - - random_index = random.randint(0, len(self.instant_responses) - 1) - while random_index in self.last_used_instant_responses: - random_index = random.randint(0, len(self.instant_responses) - 1) - - self.last_used_instant_responses.append(random_index) - return self.instant_responses[random_index] - - # ───────────────── Build tools ───────────────── # def build_tools(self) -> list[dict]: From 93927c3059801c0501e930ea5d03ef57e3d28fb4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Hopst=C3=A4tter?= Date: Fri, 10 Apr 2026 10:50:45 +0200 Subject: [PATCH 26/34] refactor(wingman): move tool definitions to their owners build_tools() becomes a 17-line orchestrator. The execute_command tool definition lives in CommandExecutor; the persistent-memory tool definitions live in PersistentMemoryService. Wingman just asks each source for its tools. Co-Authored-By: Claude Opus 4.6 --- services/command_executor.py | 51 +++++++++++++++++ services/persistent_memory.py | 57 +++++++++++++++++++ wingmen/wingman.py | 103 ++-------------------------------- 3 files changed, 114 insertions(+), 97 deletions(-) diff --git a/services/command_executor.py b/services/command_executor.py index 79570783..ff95ff66 100644 --- a/services/command_executor.py +++ b/services/command_executor.py @@ -155,6 +155,57 @@ async def execute_command( printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) return None, "ERROR DURING PROCESSING" + # ───────────────── Tool definition ───────────────────────── # + + def get_tool_definition(self) -> dict | None: + """Return the OpenAI-style execute_command tool definition, or None if no + eligible commands are configured.""" + def _command_has_effective_actions(command: CommandConfig) -> bool: + if command.is_system_command: + return True + if not command.actions: + return False + for action in command.actions: + if not action: + continue + if ( + action.keyboard is not None + or action.mouse is not None + or action.joystick is not None + or action.audio is not None + or action.write is not None + or action.wait is not None + ): + return True + return False + + commands = [ + command.name + for command in self.config.commands + if (not command.force_instant_activation) + and _command_has_effective_actions(command) + ] + if not commands: + return None + return { + "type": "function", + "function": { + "name": "execute_command", + "description": "Executes a command", + "parameters": { + "type": "object", + "properties": { + "command_name": { + "type": "string", + "description": "The name of the command to execute", + "enum": commands, + }, + }, + "required": ["command_name"], + }, + }, + } + # ───────────────── Action dispatch ────────────────────────── # async def execute_action(self, command: CommandConfig): diff --git a/services/persistent_memory.py b/services/persistent_memory.py index 9a5762f1..e73ef367 100644 --- a/services/persistent_memory.py +++ b/services/persistent_memory.py @@ -586,6 +586,63 @@ def close(self) -> None: self._db.close() self._db = None + def get_tool_definitions(self) -> list[dict]: + """Return the OpenAI-style tool definitions for persistent memory operations. + These are exposed to the LLM whenever a PersistentMemoryService is active.""" + return [ + { + "type": "function", + "function": { + "name": "memory_remember", + "description": "Store an important fact or detail for future reference. Use when the user explicitly asks you to remember something.", + "parameters": { + "type": "object", + "properties": { + "text": { + "type": "string", + "description": "The fact or detail to remember.", + }, + }, + "required": ["text"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "memory_recall", + "description": "Search your memory for relevant information. Use when the user asks what you remember or know about a topic.", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "What to search for in memory.", + }, + }, + "required": ["query"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "memory_forget", + "description": "Remove a specific memory. Use when the user explicitly asks you to forget something.", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Description of the memory to forget.", + }, + }, + "required": ["query"], + }, + }, + }, + ] + # --- Private helpers --- def _find_duplicate(self, embedding: list[float]) -> MemoryEntry | None: diff --git a/wingmen/wingman.py b/wingmen/wingman.py index d7459f95..2bf90f8e 100644 --- a/wingmen/wingman.py +++ b/wingmen/wingman.py @@ -938,53 +938,12 @@ async def generate_image(self, text: str) -> str: # ───────────────── Build tools ───────────────── # def build_tools(self) -> list[dict]: - def _command_has_effective_actions(command: CommandConfig) -> bool: - if command.is_system_command: - return True - if not command.actions: - return False - for action in command.actions: - if not action: - continue - if ( - action.keyboard is not None - or action.mouse is not None - or action.joystick is not None - or action.audio is not None - or action.write is not None - or action.wait is not None - ): - return True - return False - - commands = [ - command.name - for command in self.config.commands - if (not command.force_instant_activation) - and _command_has_effective_actions(command) - ] + """Assemble the full tool list for LLM calls.""" tools: list[dict] = [] - if commands: - tools.append( - { - "type": "function", - "function": { - "name": "execute_command", - "description": "Executes a command", - "parameters": { - "type": "object", - "properties": { - "command_name": { - "type": "string", - "description": "The name of the command to execute", - "enum": commands, - }, - }, - "required": ["command_name"], - }, - }, - } - ) + + command_tool = self.command_executor.get_tool_definition() + if command_tool: + tools.append(command_tool) for _, tool in self.capability_registry.get_meta_tools(): tools.append(tool) @@ -996,57 +955,7 @@ def _command_has_effective_actions(command: CommandConfig) -> bool: tools.append(tool) if self.persistent_memory_service: - tools.append({ - "type": "function", - "function": { - "name": "memory_remember", - "description": "Store an important fact or detail for future reference. Use when the user explicitly asks you to remember something.", - "parameters": { - "type": "object", - "properties": { - "text": { - "type": "string", - "description": "The fact or detail to remember.", - }, - }, - "required": ["text"], - }, - }, - }) - tools.append({ - "type": "function", - "function": { - "name": "memory_recall", - "description": "Search your memory for relevant information. Use when the user asks what you remember or know about a topic.", - "parameters": { - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "What to search for in memory.", - }, - }, - "required": ["query"], - }, - }, - }) - tools.append({ - "type": "function", - "function": { - "name": "memory_forget", - "description": "Remove a specific memory. Use when the user explicitly asks you to forget something.", - "parameters": { - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "Description of the memory to forget.", - }, - }, - "required": ["query"], - }, - }, - }) + tools.extend(self.persistent_memory_service.get_tool_definitions()) return tools From cd28a8da9acc81b2a7450d0729e823c77a6803f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Hopst=C3=A4tter?= Date: Fri, 10 Apr 2026 10:51:49 +0200 Subject: [PATCH 27/34] refactor(wingman): cache image-gen subscription instance Stop instantiating a fresh WingmanSubscription on every generate_image call. The subscription is created lazily on first use and reused. Co-Authored-By: Claude Opus 4.6 --- wingmen/wingman.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/wingmen/wingman.py b/wingmen/wingman.py index 2bf90f8e..b144582a 100644 --- a/wingmen/wingman.py +++ b/wingmen/wingman.py @@ -190,6 +190,9 @@ def __init__( self._background_tasks: set[asyncio.Task] = set() self._tool_response_compressor = ToolResponseCompressor() + # --- Image generation (lazy) --- + self._image_subscription = None + # --- Instant response generator --- self.instant_response_generator = InstantResponseGenerator( wingman_name=name, @@ -917,22 +920,22 @@ async def play_to_user( async def generate_image(self, text: str) -> str: if ( self.config.features.image_generation_provider - == ImageGenerationProvider.WINGMAN_PRO + != ImageGenerationProvider.WINGMAN_PRO ): - try: + return "" + try: + if self._image_subscription is None: from providers.wingman_subscription import WingmanSubscription - wingman_pro = WingmanSubscription( + self._image_subscription = WingmanSubscription( wingman_name=self.name, settings=self.settings.wingman_pro ) - return await wingman_pro.generate_image(text) - except Exception as e: - await printr.print_async( - f"Error during image generation: {str(e)}", color=LogType.ERROR - ) - printr.print( - traceback.format_exc(), color=LogType.ERROR, server_only=True - ) + return await self._image_subscription.generate_image(text) + except Exception as e: + await printr.print_async( + f"Error during image generation: {str(e)}", color=LogType.ERROR + ) + printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) return "" # ───────────────── Build tools ───────────────── # From 8e2487727c81a8ba4af6e69ce12f1b718e670102 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Hopst=C3=A4tter?= Date: Fri, 10 Apr 2026 10:52:41 +0200 Subject: [PATCH 28/34] refactor(wingman): move threaded_execution to threading_utils Wingman.threaded_execution is now a one-line wrapper around the free function in services/threading_utils.py. The wrapper is kept because WingmanSkillManager binds it onto every skill instance. Co-Authored-By: Claude Opus 4.6 --- services/threading_utils.py | 42 +++++++++++++++++++++++++++++++++++++ wingmen/wingman.py | 24 ++------------------- 2 files changed, 44 insertions(+), 22 deletions(-) create mode 100644 services/threading_utils.py diff --git a/services/threading_utils.py b/services/threading_utils.py new file mode 100644 index 00000000..a24dd674 --- /dev/null +++ b/services/threading_utils.py @@ -0,0 +1,42 @@ +"""Threading helpers used across wingmen and skills.""" + +import asyncio +import threading +import traceback + +from api.enums import LogType +from services.printr import Printr + +printr = Printr() + + +def threaded_execution(function, *args) -> threading.Thread | None: + """Run ``function`` in a fresh daemon thread. + + If ``function`` is a coroutine function, a new event loop is created + inside the thread to run it. Otherwise it is called directly. + + Returns the started thread, or ``None`` on failure. + """ + try: + + def start_thread(function, *args): + if asyncio.iscoroutinefunction(function): + new_loop = asyncio.new_event_loop() + asyncio.set_event_loop(new_loop) + new_loop.run_until_complete(function(*args)) + new_loop.close() + else: + function(*args) + + thread = threading.Thread(target=start_thread, args=(function, *args)) + thread.name = function.__name__ + thread.daemon = True + thread.start() + return thread + except Exception as e: + printr.print( + f"Error starting threaded execution: {str(e)}", color=LogType.ERROR + ) + printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) + return None diff --git a/wingmen/wingman.py b/wingmen/wingman.py index b144582a..adb88ccb 100644 --- a/wingmen/wingman.py +++ b/wingmen/wingman.py @@ -55,6 +55,7 @@ from services.tool_executor import ToolExecutor from services.provider_factory import ProviderFactory from services.skill_registry import SkillRegistry +from services.threading_utils import threaded_execution from services.mcp_client import McpClient from services.capability_registry import CapabilityRegistry from services.wingman_mcp_manager import WingmanMcpManager @@ -975,28 +976,7 @@ async def execute_action(self, command: CommandConfig): # ───────────────── Threading ─────────────────────────────── # def threaded_execution(self, function, *args) -> threading.Thread | None: - try: - - def start_thread(function, *args): - if asyncio.iscoroutinefunction(function): - new_loop = asyncio.new_event_loop() - asyncio.set_event_loop(new_loop) - new_loop.run_until_complete(function(*args)) - new_loop.close() - else: - function(*args) - - thread = threading.Thread(target=start_thread, args=(function, *args)) - thread.name = function.__name__ - thread.daemon = True - thread.start() - return thread - except Exception as e: - printr.print( - f"Error starting threaded execution: {str(e)}", color=LogType.ERROR - ) - printr.print(traceback.format_exc(), color=LogType.ERROR, server_only=True) - return None + return threaded_execution(function, *args) # ───────────────── Config management ─────────────────────── # From da8a5e90be402f9d117f5d6979cf8085abbb4975 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Hopst=C3=A4tter?= Date: Fri, 10 Apr 2026 13:50:11 +0200 Subject: [PATCH 29/34] simpliy code --- services/benchmark.py | 14 ++-- services/command_executor.py | 44 +++++++------ services/config_service.py | 7 +- services/conversation_manager.py | 5 +- services/platform_utils.py | 14 ++++ services/turn_metrics.py | 44 ++++--------- services/wingman_mcp_manager.py | 102 +++++++++++++++--------------- services/wingman_skill_manager.py | 39 ++++-------- wingmen/wingman.py | 16 ----- wingmen/wingman_context.py | 2 +- 10 files changed, 127 insertions(+), 160 deletions(-) create mode 100644 services/platform_utils.py diff --git a/services/benchmark.py b/services/benchmark.py index 3312dcbf..685de8fe 100644 --- a/services/benchmark.py +++ b/services/benchmark.py @@ -4,6 +4,13 @@ from services.printr import Printr +def format_ms(execution_time_ms: float) -> str: + """Format milliseconds as ``"1.2s"`` or ``"340ms"``.""" + if execution_time_ms >= 1000: + return f"{execution_time_ms/1000:.1f}s" + return f"{int(execution_time_ms)}ms" + + class Benchmark: def __init__(self, label: str): self.label = label @@ -53,13 +60,8 @@ def finish_snapshot(self) -> BenchmarkResult | None: def _create_benchmark_result(self, label: str, start_time: float): end_time = time.perf_counter() execution_time = (end_time - start_time) * 1000 # Convert to milliseconds - if execution_time >= 1000: - formatted_execution_time = f"{execution_time/1000:.1f}s" - else: - formatted_execution_time = f"{int(execution_time)}ms" - return BenchmarkResult( label=label, execution_time_ms=execution_time, - formatted_execution_time=formatted_execution_time, + formatted_execution_time=format_ms(execution_time), ) diff --git a/services/command_executor.py b/services/command_executor.py index ff95ff66..4d11054c 100644 --- a/services/command_executor.py +++ b/services/command_executor.py @@ -20,6 +20,27 @@ printr = Printr() +def _command_has_effective_actions(command: CommandConfig) -> bool: + """True if the command has at least one action the LLM can meaningfully trigger.""" + if command.is_system_command: + return True + if not command.actions: + return False + for action in command.actions: + if not action: + continue + if ( + action.keyboard is not None + or action.mouse is not None + or action.joystick is not None + or action.audio is not None + or action.write is not None + or action.wait is not None + ): + return True + return False + + class CommandExecutor: """Focused service for command lookup, instant activation, and action dispatch.""" @@ -81,6 +102,8 @@ async def try_instant_activation(self, transcript: str) -> tuple[str, bool]: async def _execute_instant_activation_command( self, transcript: str ) -> list[CommandConfig] | None: + if not self.config.commands: + return None try: commands_by_instant_activation = {} for command in self.config.commands: @@ -160,25 +183,8 @@ async def execute_command( def get_tool_definition(self) -> dict | None: """Return the OpenAI-style execute_command tool definition, or None if no eligible commands are configured.""" - def _command_has_effective_actions(command: CommandConfig) -> bool: - if command.is_system_command: - return True - if not command.actions: - return False - for action in command.actions: - if not action: - continue - if ( - action.keyboard is not None - or action.mouse is not None - or action.joystick is not None - or action.audio is not None - or action.write is not None - or action.wait is not None - ): - return True - return False - + if not self.config.commands: + return None commands = [ command.name for command in self.config.commands diff --git a/services/config_service.py b/services/config_service.py index c0f4c03d..e57f807e 100644 --- a/services/config_service.py +++ b/services/config_service.py @@ -282,7 +282,7 @@ async def get_wingman_skills( wingman_name: str, ) -> list[WingmanSkillState]: """Get all skills with their enabled/disabled state for a specific wingman.""" - import sys + from services.platform_utils import normalize_platform try: # Get all available skills @@ -307,10 +307,7 @@ async def get_wingman_skills( discoverable_skills = wingman_config.discoverable_skills - # Get current platform for filtering - current_platform = sys.platform - platform_map = {"win32": "windows", "darwin": "darwin", "linux": "linux"} - normalized_platform = platform_map.get(current_platform, current_platform) + normalized_platform = normalize_platform() # Build response with enabled state result = [] diff --git a/services/conversation_manager.py b/services/conversation_manager.py index eb8e0ac8..f999b5df 100644 --- a/services/conversation_manager.py +++ b/services/conversation_manager.py @@ -11,7 +11,7 @@ ParsedFunction, ) -from api.enums import ConversationProvider, LogType, LogSource +from api.enums import ConversationProvider, LogType from services.printr import Printr from services.token_utils import count_tokens, truncate_to_tokens @@ -24,9 +24,6 @@ class ConversationManager: """Owns the conversation message list, tool-response bookkeeping, and history cleanup / trimming logic. - - Extracted from ``OpenAiWingman`` — all behaviour is identical, only the - home module changed. """ def __init__( diff --git a/services/platform_utils.py b/services/platform_utils.py new file mode 100644 index 00000000..8f859a15 --- /dev/null +++ b/services/platform_utils.py @@ -0,0 +1,14 @@ +"""Shared platform helpers.""" + +import sys + +_PLATFORM_MAP = {"win32": "windows", "darwin": "darwin", "linux": "linux"} + + +def normalize_platform(platform: str | None = None) -> str: + """Return a normalized platform name (``windows``/``darwin``/``linux``). + + Falls back to the raw ``sys.platform`` string if no mapping exists. + """ + raw = platform if platform is not None else sys.platform + return _PLATFORM_MAP.get(raw, raw) diff --git a/services/turn_metrics.py b/services/turn_metrics.py index ab7e6701..6aca93cc 100644 --- a/services/turn_metrics.py +++ b/services/turn_metrics.py @@ -2,7 +2,7 @@ from api.enums import ConversationProvider from api.interface import BenchmarkResult, WingmanConfig -from services.benchmark import Benchmark +from services.benchmark import Benchmark, format_ms from services.printr import Printr from services.token_utils import count_tokens @@ -10,12 +10,7 @@ class TurnMetrics: - """Focused service for per-turn benchmark snapshots and token-usage broadcast. - - Owns the two token counters that were previously tracked directly on - :class:`~wingmen.wingman.Wingman` and provides the formatting helpers - that used to live there as private methods. - """ + """Focused service for per-turn benchmark snapshots and token-usage broadcast.""" def __init__( self, @@ -34,16 +29,11 @@ def __init__( def add_benchmark_snapshot( self, benchmark: Benchmark, label: str, execution_time_ms: float ) -> None: - if execution_time_ms >= 1000: - formatted_time = f"{execution_time_ms/1000:.1f}s" - else: - formatted_time = f"{int(execution_time_ms)}ms" - benchmark.snapshots.append( BenchmarkResult( label=label, execution_time_ms=execution_time_ms, - formatted_execution_time=formatted_time, + formatted_execution_time=format_ms(execution_time_ms), ) ) @@ -53,31 +43,21 @@ def add_tool_execution_snapshot( total_time_ms: float, tool_timings: list[tuple[str, float]], ) -> None: - if total_time_ms >= 1000: - formatted_time = f"{total_time_ms/1000:.1f}s" - else: - formatted_time = f"{int(total_time_ms)}ms" - - nested_snapshots = [] - for label, time_ms in tool_timings: - if time_ms >= 1000: - fmt = f"{time_ms/1000:.1f}s" - else: - fmt = f"{int(time_ms)}ms" - nested_snapshots.append( - BenchmarkResult( - label=label, - execution_time_ms=time_ms, - formatted_execution_time=fmt, - ) + nested_snapshots = [ + BenchmarkResult( + label=label, + execution_time_ms=time_ms, + formatted_execution_time=format_ms(time_ms), ) + for label, time_ms in tool_timings + ] benchmark.snapshots.append( BenchmarkResult( label="Tool Execution", execution_time_ms=total_time_ms, - formatted_execution_time=formatted_time, - snapshots=nested_snapshots if nested_snapshots else None, + formatted_execution_time=format_ms(total_time_ms), + snapshots=nested_snapshots or None, ) ) diff --git a/services/wingman_mcp_manager.py b/services/wingman_mcp_manager.py index 9db10ac6..5600bb4c 100644 --- a/services/wingman_mcp_manager.py +++ b/services/wingman_mcp_manager.py @@ -1,9 +1,7 @@ """WingmanMcpManager — owns all MCP discovery, connection, and lifecycle. -Extracted from ``wingmen/wingman.py`` so that MCP concerns (registry creation, -secret injection, timeout handling, enable/disable, parallel init) live in one -focused service. Wingman delegates to this manager and exposes 1-line -forwarding methods for backward compatibility with external callers. +Centralises MCP concerns (registry creation, secret injection, timeout +handling, enable/disable, parallel init) in one focused service. """ import asyncio @@ -11,7 +9,12 @@ from typing import Callable from api.commands import McpStateChangedCommand -from api.enums import LogSource, LogType, WingmanInitializationErrorType +from api.enums import ( + LogSource, + LogType, + McpTransportType, + WingmanInitializationErrorType, +) from api.interface import SettingsConfig, WingmanConfig, WingmanInitializationError from services.mcp_client import McpClient from services.mcp_registry import McpRegistry @@ -20,6 +23,10 @@ printr = Printr() +_AUTH_HEADER_KEYS = {"authorization", "api-key", "x-api-key"} +_STDIO_DEFAULT_TIMEOUT = 60.0 +_HTTP_DEFAULT_TIMEOUT = 30.0 + class WingmanMcpManager: """Manages MCP server discovery, connection, enable/disable, and teardown.""" @@ -49,6 +56,41 @@ def __init__( # ─────────────────────────── Private ────────────────────────────────────── # + async def _prepare_connection_params( + self, mcp_config, log_secret_found: bool = False + ) -> tuple[dict, float]: + """Build request headers (with secret-injected auth) and resolve the timeout.""" + headers: dict = {} + if mcp_config.headers: + headers.update(mcp_config.headers) + + secret_key = f"mcp_{mcp_config.name}" + api_key = await self.secret_keeper.retrieve( + requester=self.wingman_name, + key=secret_key, + prompt_if_missing=False, + ) + if api_key: + if log_secret_found: + printr.print( + f"MCP secret '{secret_key}' found ({len(api_key)} chars)", + color=LogType.INFO, + source_name=self.wingman_name, + server_only=True, + ) + if not any(k.lower() in _AUTH_HEADER_KEYS for k in headers.keys()): + headers["Authorization"] = f"Bearer {api_key}" + + default_timeout = ( + _STDIO_DEFAULT_TIMEOUT + if mcp_config.type == McpTransportType.STDIO + else _HTTP_DEFAULT_TIMEOUT + ) + timeout = ( + float(mcp_config.timeout) if mcp_config.timeout else default_timeout + ) + return headers, timeout + def _broadcast_mcp_state_changed(self): if printr._connection_manager: printr.ensure_async( @@ -79,27 +121,7 @@ async def enable_mcp(self, mcp_name: str) -> tuple[bool, str]: return False, f"MCP server '{mcp_name}' not found in mcp.yaml." try: - headers = {} - if mcp_config.headers: - headers.update(mcp_config.headers) - - secret_key = f"mcp_{mcp_config.name}" - api_key = await self.secret_keeper.retrieve( - requester=self.wingman_name, - key=secret_key, - prompt_if_missing=False, - ) - if api_key: - if not any( - k.lower() in ["authorization", "api-key", "x-api-key"] - for k in headers.keys() - ): - headers["Authorization"] = f"Bearer {api_key}" - - default_timeout = 60.0 if mcp_config.type.value == "stdio" else 30.0 - timeout = ( - float(mcp_config.timeout) if mcp_config.timeout else default_timeout - ) + headers, timeout = await self._prepare_connection_params(mcp_config) connection = await asyncio.wait_for( self.mcp_registry.register_server( @@ -168,32 +190,8 @@ async def init_mcps(self) -> list[WingmanInitializationError]: async def connect_mcp(mcp_config): local_errors = [] try: - headers = {} - if mcp_config.headers: - headers.update(mcp_config.headers) - - secret_key = f"mcp_{mcp_config.name}" - api_key = await self.secret_keeper.retrieve( - requester=self.wingman_name, - key=secret_key, - prompt_if_missing=False, - ) - if api_key: - printr.print( - f"MCP secret '{secret_key}' found ({len(api_key)} chars)", - color=LogType.INFO, - source_name=self.wingman_name, - server_only=True, - ) - if not any( - k.lower() in ["authorization", "api-key", "x-api-key"] - for k in headers.keys() - ): - headers["Authorization"] = f"Bearer {api_key}" - - default_timeout = 60.0 if mcp_config.type.value == "stdio" else 30.0 - timeout = ( - float(mcp_config.timeout) if mcp_config.timeout else default_timeout + headers, timeout = await self._prepare_connection_params( + mcp_config, log_secret_found=True ) try: diff --git a/services/wingman_skill_manager.py b/services/wingman_skill_manager.py index 1e66dc7f..98df1753 100644 --- a/services/wingman_skill_manager.py +++ b/services/wingman_skill_manager.py @@ -1,10 +1,9 @@ """WingmanSkillManager — skill discovery, lifecycle, and state. -Extracted from ``wingmen/wingman.py``. Owns ``skills``, ``tool_skills``, and -``skill_tools``; the parent ``Wingman`` exposes them as read-through properties. +Owns ``skills``, ``tool_skills``, and ``skill_tools``; the parent ``Wingman`` +exposes them as read-through properties. """ -import sys import traceback from typing import TYPE_CHECKING @@ -20,6 +19,7 @@ WingmanInitializationErrorType, ) from services.module_manager import ModuleManager +from services.platform_utils import normalize_platform from services.printr import Printr from services.skill_registry import SkillRegistry from skills.skill_base import Skill @@ -36,9 +36,6 @@ def _get_skill_folder_from_module(module: str) -> str: return module.replace(".main", "").replace(".", "/").split("/")[1] -_PLATFORM_MAP = {"win32": "windows", "darwin": "darwin", "linux": "linux"} - - class WingmanSkillManager: """Manages skill discovery, loading, preparation, enable/disable, and teardown.""" @@ -102,27 +99,13 @@ def _load_skill_config( return SkillConfig(**skill_config_dict) - def _check_platform_supported(self, skill_config: SkillConfig) -> bool: - """Return True if the skill supports the current platform (or has no restriction).""" - if not skill_config.platforms: - return True - normalized = _PLATFORM_MAP.get(sys.platform, sys.platform) - if normalized not in skill_config.platforms: - printr.print( - f"Skipping skill '{skill_config.name}' - not supported on {normalized}", - color=LogType.WARNING, - server_only=True, - ) - return False - return True - - def _check_platform_supported_with_message( + def _check_platform_supported( self, skill_config: SkillConfig ) -> tuple[bool, str]: - """Like _check_platform_supported but returns (ok, reason) for enable_skill.""" + """Return (ok, reason) for whether the skill supports the current platform.""" if not skill_config.platforms: return True, "" - normalized = _PLATFORM_MAP.get(sys.platform, sys.platform) + normalized = normalize_platform() if normalized not in skill_config.platforms: return ( False, @@ -171,7 +154,13 @@ async def init_skills(self) -> list[WingmanInitializationError]: if skill_config.name not in discoverable_skills: continue - if not self._check_platform_supported(skill_config): + ok, reason = self._check_platform_supported(skill_config) + if not ok: + printr.print( + f"Skipping skill - {reason}", + color=LogType.WARNING, + server_only=True, + ) continue skill = self._instantiate_skill(skill_config) @@ -276,7 +265,7 @@ async def enable_skill(self, skill_name: str) -> tuple[bool, str]: if skill_config.name != skill_name: continue - ok, reason = self._check_platform_supported_with_message(skill_config) + ok, reason = self._check_platform_supported(skill_config) if not ok: return False, reason diff --git a/wingmen/wingman.py b/wingmen/wingman.py index adb88ccb..685c19a1 100644 --- a/wingmen/wingman.py +++ b/wingmen/wingman.py @@ -5,14 +5,11 @@ interfaces for STT, TTS, and LLM. """ -import json import time import asyncio -import random import traceback import threading from copy import deepcopy -import difflib from typing import ( Any, Dict, @@ -21,8 +18,6 @@ ) from openai import APIConnectionError from openai.types.chat import ChatCompletion -import keyboard.keyboard as keyboard -import mouse.mouse as mouse from api.interface import ( CommandConfig, SettingsConfig, @@ -129,13 +124,6 @@ def __init__( self.tts: TtsInterface | None = None self.llm: LlmInterface | None = None - # --- Backward-compat: keep old attributes for custom wingmen / skills --- - self.whispercpp = whispercpp - self.fasterwhisper = fasterwhisper - self.parakeet = parakeet - self.xvasynth = xvasynth - self.pocket_tts = pocket_tts - # --- Extracted services --- self.conversation = ConversationManager(config, settings, name) self.condenser = ConversationCondenser(self.conversation, config, name) @@ -157,7 +145,6 @@ def __init__( ) self.execution_start: None | float = None - self._last_prompt_tokens: int = 0 # --- Skills --- self.skill_registry = SkillRegistry() @@ -530,7 +517,6 @@ async def _get_response_for_transcript( turn_prompt_tokens = usage[0] turn_completion_tokens = usage[1] - self._last_prompt_tokens = turn_prompt_tokens is_waiting_response_needed, is_summarize_needed = await self.conversation.add_gpt_response( response_message, tool_calls @@ -607,7 +593,6 @@ async def _get_response_for_transcript( ) turn_prompt_tokens = usage[0] turn_completion_tokens += usage[1] - self._last_prompt_tokens = turn_prompt_tokens is_waiting_response_needed, is_summarize_needed = ( await self.conversation.add_gpt_response(response_message, tool_calls) @@ -791,7 +776,6 @@ async def reset_conversation_history(self): pass await self.conversation.reset() - self._last_prompt_tokens = 0 self.skill_registry.reset_activations() self.mcp_registry.reset_activations() diff --git a/wingmen/wingman_context.py b/wingmen/wingman_context.py index 742073cd..b518c588 100644 --- a/wingmen/wingman_context.py +++ b/wingmen/wingman_context.py @@ -50,7 +50,7 @@ async def llm_call(self, messages: list[dict], tools: list[dict] | None = None) def get_conversation_history(self) -> list[dict]: """Get a read-only copy of the conversation history.""" - return self._wingman.conversation.get_history() + return list(self._wingman.conversation.messages) async def add_user_message(self, content: str): await self._wingman.conversation.add_user_message(content) From af30289b8e942cb1ea1c4c2041696a769b6bfe33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Hopst=C3=A4tter?= Date: Mon, 13 Apr 2026 02:04:00 +0200 Subject: [PATCH 30/34] shorten message --- services/context_builder.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/services/context_builder.py b/services/context_builder.py index 8b4a97a8..79afc2c6 100644 --- a/services/context_builder.py +++ b/services/context_builder.py @@ -139,9 +139,7 @@ async def build( else getattr(msg, "content", "") ) # Extract plain text from multimodal content (images etc.) - content = ( - self._extract_text_content(raw_content) if raw_content else "" - ) + content = self._extract_text_content(raw_content) if raw_content else "" if role == "user" and content: last_user_msg = content break @@ -152,10 +150,7 @@ async def build( last_user_msg ) ) - if ( - persistent_memory_context - and not self._memory_recall_notified - ): + if persistent_memory_context and not self._memory_recall_notified: self._memory_recall_notified = True # Count restored fact lines (lines starting with "- ") fact_count = sum( @@ -165,7 +160,7 @@ async def build( ) if fact_count > 0: await printr.print_async( - f"Memory recalled: {fact_count} relevant {'memory' if fact_count == 1 else 'memories'} loaded.", + f"Memory: {fact_count} {'memory' if fact_count == 1 else 'memories'} recalled", color=LogType.MEMORY, source_name=self._wingman_name, ) From 4c7abec0f1ffcad08e4b090d1d6de1c031012b65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Hopst=C3=A4tter?= Date: Mon, 13 Apr 2026 11:55:00 +0200 Subject: [PATCH 31/34] add endpoint to manually trigger STT --- wingman_core.py | 58 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/wingman_core.py b/wingman_core.py index e5d0e366..362417c3 100644 --- a/wingman_core.py +++ b/wingman_core.py @@ -87,6 +87,7 @@ from services.image_processing import process_image, validate_image_mime from services.model_metadata import ModelMetadataService from services.audio_recorder import RECORDING_PATH, AudioRecorder +from services.threading_utils import threaded_execution from services.config_manager import ConfigManager from services.printr import Printr from services.secret_keeper import SecretKeeper @@ -146,6 +147,18 @@ def __init__( endpoint=self.send_text_to_wingman, tags=tags, ) + self.router.add_api_route( + methods=["POST"], + path="/start-recording-for-wingman", + endpoint=self.start_recording_for_wingman, + tags=tags, + ) + self.router.add_api_route( + methods=["POST"], + path="/stop-recording-for-wingman", + endpoint=self.stop_recording_for_wingman, + tags=tags, + ) self.router.add_api_route( methods=["POST"], path="/generate-greeting", @@ -1719,6 +1732,51 @@ def get_startup_errors(self): async def stop_playback(self): await self.audio_player.stop_playback() + # POST /start-recording-for-wingman + async def start_recording_for_wingman(self, wingman_name: str): + """Start audio recording for a wingman (GUI mic toggle).""" + if not self.tower or self.active_recording["key"] != "": + return + + wingman = self.tower.get_wingman_by_name(wingman_name) + if not wingman: + return + + self.active_recording = dict(key="__gui__", wingman=wingman) + self.was_listening_before_ptt = self.is_listening + if ( + self.settings_service.settings.voice_activation.enabled + and self.is_listening + ): + self.start_voice_recognition(mute=True) + + self.audio_recorder.start_recording(wingman_name=wingman.name) + + # POST /stop-recording-for-wingman + async def stop_recording_for_wingman(self, wingman_name: str): + """Stop audio recording and process the result (GUI mic toggle).""" + if ( + not self.tower + or self.active_recording["key"] != "__gui__" + ): + return + + wingman = self.active_recording["wingman"] + recorded_audio_wav = self.audio_recorder.stop_recording( + wingman_name=wingman.name + ) + self.active_recording = {"key": "", "wingman": None} + + if ( + self.settings_service.settings.voice_activation.enabled + and not self.is_listening + and self.was_listening_before_ptt + ): + self.start_voice_recognition() + + if recorded_audio_wav and isinstance(wingman, Wingman): + threaded_execution(wingman.process, str(recorded_audio_wav)) + # POST /ask-wingman-conversation-provider async def ask_wingman_conversation_provider( self, wingman_name: str, text: str = Body(...) From 3fefb8c1ca690738e1c97f0e7f0a745cf34f78bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Hopst=C3=A4tter?= Date: Mon, 13 Apr 2026 16:07:43 +0200 Subject: [PATCH 32/34] fix: address Copilot PR review feedback (#387) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix inverted no_interrupt boolean in tool-call loop play_to_user call - Propagate config/settings updates to all extracted services (not just command_executor/skills) - Roll back tts_provider config on failed provider creation - Guard config_dir access in WingmanContext.get_context() matching Wingman's own guard - Fix persistent memory tool names in context instructions (remember→memory_remember, etc.) - Update skill docs to use WingmanContext type instead of Wingman Co-Authored-By: Claude Opus 4.6 --- services/context_builder.py | 4 ++-- skills/AGENTS.md | 4 ++-- skills/README.md | 8 ++++---- wingmen/wingman.py | 24 +++++++++++++++++++++++- wingmen/wingman_context.py | 5 ++++- 5 files changed, 35 insertions(+), 10 deletions(-) diff --git a/services/context_builder.py b/services/context_builder.py index 79afc2c6..93721363 100644 --- a/services/context_builder.py +++ b/services/context_builder.py @@ -193,9 +193,9 @@ async def build( "\n\n# PERSISTENT MEMORY\n" "You have persistent memory. Important facts and past conversation summaries " "are provided in the [Memory] sections above (if any). " - "You can use the `remember`, `recall`, and `forget` tools when the user " + "You can use the `memory_remember`, `memory_recall`, and `memory_forget` tools when the user " "explicitly asks you to remember, recall, or forget something. " - "You don't need to use `remember` for routine information — that is handled automatically." + "You don't need to use `memory_remember` for routine information — that is handled automatically." ) self._last_compiled_context = context diff --git a/skills/AGENTS.md b/skills/AGENTS.md index 5b0551b1..50cf7436 100644 --- a/skills/AGENTS.md +++ b/skills/AGENTS.md @@ -118,10 +118,10 @@ from api.interface import SettingsConfig, SkillConfig, WingmanInitializationErro from skills.skill_base import Skill, tool if TYPE_CHECKING: - from wingmen.wingman import Wingman + from wingmen.wingman_context import WingmanContext class YourSkillName(Skill): - def __init__(self, config: SkillConfig, settings: SettingsConfig, wingman: "Wingman") -> None: + def __init__(self, config: SkillConfig, settings: SettingsConfig, wingman: "WingmanContext") -> None: super().__init__(config=config, settings=settings, wingman=wingman) async def validate(self) -> list[WingmanInitializationError]: diff --git a/skills/README.md b/skills/README.md index bbd0ba3d..fe4184f8 100644 --- a/skills/README.md +++ b/skills/README.md @@ -630,7 +630,7 @@ from api.interface import SettingsConfig, SkillConfig, WingmanInitializationErro from skills.skill_base import Skill, tool if TYPE_CHECKING: - from wingmen.wingman import Wingman + from wingmen.wingman_context import WingmanContext class YourSkillName(Skill): @@ -640,7 +640,7 @@ class YourSkillName(Skill): self, config: SkillConfig, settings: SettingsConfig, - wingman: "Wingman", + wingman: "WingmanContext", ) -> None: super().__init__(config=config, settings=settings, wingman=wingman) # Initialize your skill here @@ -1766,7 +1766,7 @@ from skills.skill_base import Skill, tool from services.skill_local_ai import MemoryType if TYPE_CHECKING: - from wingmen.wingman import Wingman + from wingmen.wingman_context import WingmanContext class GameStatsTracker(Skill): @@ -1776,7 +1776,7 @@ class GameStatsTracker(Skill): self, config: SkillConfig, settings: SettingsConfig, - wingman: "Wingman", + wingman: "WingmanContext", ) -> None: super().__init__(config=config, settings=settings, wingman=wingman) diff --git a/wingmen/wingman.py b/wingmen/wingman.py index 685c19a1..5d21c1da 100644 --- a/wingmen/wingman.py +++ b/wingmen/wingman.py @@ -534,7 +534,7 @@ async def _get_response_for_transcript( message = filler is_summarize_needed = True if message: - self.threaded_execution(self.play_to_user, message, interrupt) + self.threaded_execution(self.play_to_user, message, not interrupt) await printr.print_async( f"{message}", color=LogType.POSITIVE, @@ -972,7 +972,15 @@ async def update_config( old_config = deepcopy(self.config) self.config = config + + # Propagate to all services that hold a config reference self.command_executor.config = config + self.conversation._config = config + self.condenser._config = config + self.context_builder._config = config + self.tool_executor._config = config + self.metrics.config = config + self.mcp_manager.config = config await self._update_skill_configs(config) @@ -984,7 +992,15 @@ async def update_config( error.error_type != WingmanInitializationErrorType.MISSING_SECRET ): + # Roll back config on all services self.config = old_config + self.command_executor.config = old_config + self.conversation._config = old_config + self.condenser._config = old_config + self.context_builder._config = old_config + self.tool_executor._config = old_config + self.metrics.config = old_config + self.mcp_manager.config = old_config return False return True @@ -1044,6 +1060,12 @@ async def update_settings(self, settings: SettingsConfig): try: self.settings = settings + # Propagate to all services that hold a settings reference + self.conversation._settings = settings + self.context_builder._settings = settings + self.tool_executor._settings = settings + self.mcp_manager.settings = settings + for skill in self.skills: skill.settings = settings diff --git a/wingmen/wingman_context.py b/wingmen/wingman_context.py index b518c588..4292976e 100644 --- a/wingmen/wingman_context.py +++ b/wingmen/wingman_context.py @@ -89,7 +89,7 @@ async def get_context(self) -> str: conversation_summary=self._wingman.condenser.summary, persistent_memory_service=self._wingman.persistent_memory_service, messages=self._wingman.conversation.messages, - config_dir_name=self._wingman.tower.config_dir.name if self._wingman.tower else None, + config_dir_name=self._wingman.tower.config_dir.name if self._wingman.tower and self._wingman.tower.config_dir and self._wingman.tower.config_dir.name else None, ) # --- Provider switching (for voice_changer and similar) --- @@ -102,6 +102,7 @@ async def switch_tts_provider(self, provider: "TtsProvider", Used by voice_changer skill. """ from services.provider_factory import ProviderFactory + old_provider = self._wingman.config.features.tts_provider self._wingman.config.features.tts_provider = provider factory = ProviderFactory( config=self._wingman.config, @@ -115,6 +116,8 @@ async def switch_tts_provider(self, provider: "TtsProvider", if new_tts: self._wingman.tts = new_tts return True + # Roll back config on failure + self._wingman.config.features.tts_provider = old_provider return False # --- Commands --- From a5049ce3f11caf9fbeb1087d7b93e1e8c5a4fead Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Hopst=C3=A4tter?= Date: Mon, 13 Apr 2026 16:19:23 +0200 Subject: [PATCH 33/34] fix: address second round of Copilot PR review feedback (#387) - Fix duplicate tool response in exception path (update existing instead of adding new) - Propagate config/settings to WingmanSkillManager in update_config/update_settings - Delegate WingmanContext.add_user_message to Wingman wrapper (preserves memory reset + condensation) - Clarify get_conversation_history docstring (shallow copy, not truly read-only) - Remove "include all secrets" from condensation prompt to prevent leaking sensitive data - Fix typo and improve JSON retry prompt in instant response generator Co-Authored-By: Claude Opus 4.6 --- services/conversation_condenser.py | 2 +- services/instant_response_generator.py | 2 +- services/tool_executor.py | 5 ++++- wingmen/wingman.py | 3 +++ wingmen/wingman_context.py | 8 ++++++-- 5 files changed, 15 insertions(+), 5 deletions(-) diff --git a/services/conversation_condenser.py b/services/conversation_condenser.py index e7fb1888..c25a3c3b 100644 --- a/services/conversation_condenser.py +++ b/services/conversation_condenser.py @@ -288,7 +288,7 @@ async def condense( user_prompt_suffix = ( "\n\n---\n" "Now list every fact from the conversation above as bullet points.\n" - "Start from the FIRST message, end at the LAST. Include all secrets, names, preferences, and creative content:" + "Start from the FIRST message, end at the LAST. Include all names, preferences, and creative content. Never include secrets, API keys, credentials, passwords, or tokens:" ) prefix_suffix_tokens = count_tokens(user_prompt_prefix) + count_tokens( user_prompt_suffix diff --git a/services/instant_response_generator.py b/services/instant_response_generator.py index 2a9856a3..b2f364fb 100644 --- a/services/instant_response_generator.py +++ b/services/instant_response_generator.py @@ -98,7 +98,7 @@ async def generate(self) -> None: messages.append( { "role": "user", - "content": "It was tried to handle the response in its entirety as a JSON string. Fix response to be a pure, valid JSON, it was not convertable.", + "content": "The response could not be parsed as JSON. Return only valid JSON with no additional text.", } ) if retry_count <= retry_limit: diff --git a/services/tool_executor.py b/services/tool_executor.py index bebdadb0..00f00edb 100644 --- a/services/tool_executor.py +++ b/services/tool_executor.py @@ -204,7 +204,10 @@ async def handle_tool_calls( # adding a new tool response add_tool_response_fn(tool_call, function_response) except Exception as e: - add_tool_response_fn(tool_call, "Error") + if tool_call.id: + await update_tool_response_fn(tool_call.id, "Error") + else: + add_tool_response_fn(tool_call, "Error") await printr.print_async( f"Error while processing tool call: {str(e)}", color=LogType.ERROR ) diff --git a/wingmen/wingman.py b/wingmen/wingman.py index 5d21c1da..ab428708 100644 --- a/wingmen/wingman.py +++ b/wingmen/wingman.py @@ -981,6 +981,7 @@ async def update_config( self.tool_executor._config = config self.metrics.config = config self.mcp_manager.config = config + self.skill_manager.config = config await self._update_skill_configs(config) @@ -1001,6 +1002,7 @@ async def update_config( self.tool_executor._config = old_config self.metrics.config = old_config self.mcp_manager.config = old_config + self.skill_manager.config = old_config return False return True @@ -1065,6 +1067,7 @@ async def update_settings(self, settings: SettingsConfig): self.context_builder._settings = settings self.tool_executor._settings = settings self.mcp_manager.settings = settings + self.skill_manager.settings = settings for skill in self.skills: skill.settings = settings diff --git a/wingmen/wingman_context.py b/wingmen/wingman_context.py index 4292976e..7dcff181 100644 --- a/wingmen/wingman_context.py +++ b/wingmen/wingman_context.py @@ -49,11 +49,15 @@ async def llm_call(self, messages: list[dict], tools: list[dict] | None = None) return await self._wingman.actual_llm_call(messages, tools) def get_conversation_history(self) -> list[dict]: - """Get a read-only copy of the conversation history.""" + """Get a shallow copy of the conversation history. + + Note: message objects are shared with the live conversation state. + Do not mutate individual messages. + """ return list(self._wingman.conversation.messages) async def add_user_message(self, content: str): - await self._wingman.conversation.add_user_message(content) + await self._wingman.add_user_message(content) async def add_assistant_message(self, content: str): await self._wingman.conversation.add_assistant_message(content) From 5e50d6b958433d4c07fa78144a43474dfe77560f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Hopst=C3=A4tter?= Date: Mon, 13 Apr 2026 16:27:19 +0200 Subject: [PATCH 34/34] fix: address third round of Copilot PR review feedback (#387) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix chunked condensation prompt to also exclude secrets (missed in round 2) - Fix dict[str, any] → dict[str, Any] type annotations in tool_executor and wingman - Rename WingmanContext.retrieve_secret param from requester to secret_name to match API - Move blocking requests.get to asyncio.to_thread in provider_factory OpenRouter check Co-Authored-By: Claude Opus 4.6 --- services/conversation_condenser.py | 4 ++-- services/provider_factory.py | 16 +++++++++++----- services/tool_executor.py | 4 ++-- wingmen/wingman.py | 2 +- wingmen/wingman_context.py | 4 ++-- 5 files changed, 18 insertions(+), 12 deletions(-) diff --git a/services/conversation_condenser.py b/services/conversation_condenser.py index c25a3c3b..8197a4c3 100644 --- a/services/conversation_condenser.py +++ b/services/conversation_condenser.py @@ -491,7 +491,7 @@ async def _chunked_support( user_prompt = ( f"{existing_summary_section if i == 0 else ''}" f"CONVERSATION TO SUMMARIZE (part {i + 1}/{len(chunks)}):\n{chunk}\n\n" - "---\nList every fact from the above as bullet points. Include all secrets, names, preferences, and creative content:" + "---\nList every fact from the above as bullet points. Include all names, preferences, and creative content. Never include secrets, API keys, credentials, passwords, or tokens:" ) # Safety: if chunk input exceeds budget, truncate chunk text @@ -504,7 +504,7 @@ async def _chunked_support( user_prompt = ( f"{existing_summary_section if i == 0 else ''}" f"CONVERSATION TO SUMMARIZE (part {i + 1}/{len(chunks)}):\n{chunk}\n\n" - "---\nList every fact from the above as bullet points. Include all secrets, names, preferences, and creative content:" + "---\nList every fact from the above as bullet points. Include all names, preferences, and creative content. Never include secrets, API keys, credentials, passwords, or tokens:" ) await printr.print_async( f"Chunk {i + 1}/{len(chunks)} exceeded context budget, truncated to fit.", diff --git a/services/provider_factory.py b/services/provider_factory.py index 404de76f..2e70ad35 100644 --- a/services/provider_factory.py +++ b/services/provider_factory.py @@ -310,13 +310,19 @@ async def _check_openrouter_tool_support(self, api_key: str) -> bool: Replicates the logic from OpenAiWingman.validate_and_set_openrouter(). """ try: + import asyncio import requests + model = self._config.openrouter.conversation_model - response = requests.get( - f"https://openrouter.ai/api/v1/models/{model}", - headers={"Authorization": f"Bearer {api_key}"}, - timeout=10, - ) + + def _fetch(): + return requests.get( + f"https://openrouter.ai/api/v1/models/{model}", + headers={"Authorization": f"Bearer {api_key}"}, + timeout=10, + ) + + response = await asyncio.to_thread(_fetch) if response.status_code == 200: result = response.json() supported_params = result.get("data", {}).get( diff --git a/services/tool_executor.py b/services/tool_executor.py index 00f00edb..b4cedcb6 100644 --- a/services/tool_executor.py +++ b/services/tool_executor.py @@ -3,7 +3,7 @@ import json import time import traceback -from typing import TYPE_CHECKING, Callable, Awaitable +from typing import TYPE_CHECKING, Any, Callable, Awaitable from api.enums import LogType from services.benchmark import Benchmark @@ -223,7 +223,7 @@ async def handle_tool_calls( async def execute_by_function_call( self, function_name: str, - function_args: dict[str, any], + function_args: dict[str, Any], *, tool_skills: dict, skill_registry: "SkillRegistry", diff --git a/wingmen/wingman.py b/wingmen/wingman.py index ab428708..ee9c41d9 100644 --- a/wingmen/wingman.py +++ b/wingmen/wingman.py @@ -738,7 +738,7 @@ async def _handle_tool_calls(self, tool_calls): ) async def execute_command_by_function_call( - self, function_name: str, function_args: dict[str, any] + self, function_name: str, function_args: dict[str, Any] ) -> tuple[str, str | None, Skill | None, str | None]: """Public API kept for backward compatibility with skills.""" return await self.tool_executor.execute_by_function_call( diff --git a/wingmen/wingman_context.py b/wingmen/wingman_context.py index 7dcff181..37a5e2e6 100644 --- a/wingmen/wingman_context.py +++ b/wingmen/wingman_context.py @@ -78,8 +78,8 @@ async def generate_image(self, text: str) -> str: # --- Secrets --- - async def retrieve_secret(self, requester: str, errors: list = None) -> str | None: - return await self._wingman.retrieve_secret(requester, errors or []) + async def retrieve_secret(self, secret_name: str, errors: list = None) -> str | None: + return await self._wingman.retrieve_secret(secret_name, errors or []) # --- Utilities ---