Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def get_dump_config(self) -> tuple[bool, str]:
2. on_start() → calls connect_to_avatar() → starts audio loop
3. Audio arrives → sample rate checked → queued → calls send_audio_to_avatar()
4. flush command → calls interrupt_avatar()
5. tts_audio_end event → calls send_eof_to_avatar() (if reason=1)
5. tts_audio_end event → calls send_eof_to_avatar()
6. on_stop() → calls disconnect_from_avatar() → cleanup

You don't need to override on_init/on_start/on_stop!
Expand Down Expand Up @@ -227,7 +227,7 @@ async def send_eof_to_avatar(self) -> None:

Called automatically in two scenarios:
1. When drain command is received (manual trigger)
2. When tts_audio_end event arrives with reason=1 (TTS completion)
2. When tts_audio_end event arrives

Example:
async def send_eof_to_avatar(self) -> None:
Expand Down Expand Up @@ -384,22 +384,22 @@ async def on_data(self, ten_env: AsyncTenEnv, data: Data) -> None:

if data_name == "tts_audio_end":
json_str, _ = data.get_property_to_json(None)
reason = None
request_id = "unknown"
if json_str:
payload = json.loads(json_str)
reason = payload.get("reason")
request_id = payload.get("request_id", "unknown")
ten_env.log_info(
f"{self.LOG_PREFIX} tts_audio_end: "
f"reason={reason}, request_id={request_id}"
)
ten_env.log_info(
f"{self.LOG_PREFIX} tts_audio_end: "
f"reason={reason}, request_id={request_id}"
)

# reason=1 means TTS generation complete
if reason == 1:
ten_env.log_info(
f"{self.LOG_PREFIX} TTS complete "
f"(request_id={request_id}), sending EOF"
)
await self._on_tts_audio_end(ten_env)
ten_env.log_info(
f"{self.LOG_PREFIX} TTS audio ended "
f"(request_id={request_id}), sending EOF"
)
await self._on_tts_audio_end(ten_env)

# ========================================================================
# AUDIO HANDLING - Managed by base class
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
#
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from time import time
from typing import TypedDict

from agora_token_builder import RtcTokenBuilder
from ten_runtime import AsyncTenEnv
from ten_ai_base.config import BaseConfig
from spatius import new_avatar_session, AgoraEgressConfig
Expand All @@ -23,6 +25,7 @@ class SpatiusParams(TypedDict, total=False):
agora_uid: str
agora_token: str
agora_appid: str
agora_appcert: str
agora_channel: str
region: str
sample_rate: int | str
Expand All @@ -40,12 +43,14 @@ class SpatiusConfig(BaseConfig):
agora_uid: str = ""
agora_token: str = ""
agora_appid: str = ""
agora_appcert: str = ""
agora_channel: str = ""

region: str = ""
sample_rate: int = 24000
session_expire_minutes: int = 30

channel: str = ""
params: SpatiusParams = field(default_factory=dict)

dump: bool = False
Expand All @@ -71,9 +76,15 @@ def update_params(self) -> None:
if "agora_appid" in self.params:
self.agora_appid = self.params["agora_appid"]

if "agora_appcert" in self.params:
self.agora_appcert = self.params["agora_appcert"]

if "agora_channel" in self.params:
self.agora_channel = self.params["agora_channel"]

if self._has_value(self.channel):
self.agora_channel = self.channel

if "region" in self.params:
self.region = self.params["region"]

Expand All @@ -92,7 +103,6 @@ def validate_params(self) -> None:
"params.spatius_app_id": self.spatius_app_id,
"params.spatius_avatar_id": self.spatius_avatar_id,
"params.agora_uid": self.agora_uid,
"params.agora_token": self.agora_token,
"params.agora_appid": self.agora_appid,
"params.agora_channel": self.agora_channel,
}
Expand All @@ -107,6 +117,14 @@ def validate_params(self) -> None:
f"Missing required fields: {', '.join(missing_fields)}"
)

if not self._has_value(self.agora_token) and not self._has_value(
self.agora_appcert
):
raise ValueError(
"Either params.agora_token or params.agora_appcert "
"must be provided"
)

if self.sample_rate <= 0:
raise ValueError("sample_rate must be greater than 0")

Expand All @@ -118,6 +136,25 @@ def validate_params(self) -> None:
except ValueError as exc:
raise ValueError("params.agora_uid must be an integer") from exc

def resolve_agora_token(self) -> str:
"""Return configured Agora token or generate one from app cert."""
if self._has_value(self.agora_token):
return self.agora_token

privilege_expired_ts = int(time()) + (self.session_expire_minutes * 60)
return RtcTokenBuilder.buildTokenWithUid(
self.agora_appid,
self.agora_appcert,
self.agora_channel,
int(self.agora_uid),
1,
privilege_expired_ts,
)

@staticmethod
def _has_value(value: str) -> bool:
return bool(value and value.strip())


class SpatiusAvatarExtension(AsyncAvatarBaseExtension):
"""
Expand Down Expand Up @@ -174,6 +211,7 @@ async def validate_config(self, ten_env: AsyncTenEnv) -> bool:
f"agora_uid={self.config.agora_uid}, "
f"agora_token={self._masked_agora_token()}, "
f"agora_appid={self.config.agora_appid}, "
f"agora_appcert={self._masked_agora_appcert()}, "
f"agora_channel={self.config.agora_channel}, "
f"sample_rate={self.config.sample_rate}, "
"session_expire_minutes="
Expand All @@ -193,10 +231,20 @@ def _masked_api_key(self) -> str:

def _masked_agora_token(self) -> str:
"""Return a redacted Agora token for logs."""
if not self.config.agora_token:
return "(generated from app cert)"
if len(self.config.agora_token) <= 4:
return "(short)"
return f"***{self.config.agora_token[-4:]}"

def _masked_agora_appcert(self) -> str:
"""Return a redacted Agora app certificate for logs."""
if not self.config.agora_appcert:
return "(empty)"
if len(self.config.agora_appcert) <= 4:
return "(short)"
return f"***{self.config.agora_appcert[-4:]}"

def _region(self) -> str:
"""Return the configured Spatius region."""
return (self.config.region or "").strip()
Expand All @@ -213,9 +261,10 @@ async def connect_to_avatar(self, ten_env: AsyncTenEnv) -> None:

# Create avatar session using spatius with Agora egress.
avatar_uid = int(self.config.agora_uid)
agora_token = self.config.resolve_agora_token()
agora_egress = AgoraEgressConfig(
channel_name=self.config.agora_channel,
token=self.config.agora_token,
token=agora_token,
uid=avatar_uid,
publisher_id=self.config.agora_uid,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"type": "extension",
"name": "spatius_avatar_python",
"version": "0.1.1",
"version": "0.1.2",
"dependencies": [
{
"type": "system",
Expand Down Expand Up @@ -32,6 +32,24 @@
"dump_path": {
"type": "string"
},
"channel": {
"type": "string"
},
"agora_uid": {
"type": "string"
},
"agora_token": {
"type": "string"
},
"agora_appid": {
"type": "string"
},
"agora_appcert": {
"type": "string"
},
"agora_channel": {
"type": "string"
},
"params": {
"type": "object",
"properties": {
Expand All @@ -53,6 +71,9 @@
"agora_appid": {
"type": "string"
},
"agora_appcert": {
"type": "string"
},
"agora_channel": {
"type": "string"
},
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
{
"dump": false,
"dump_path": "",
"channel": "",
"agora_uid": "",
"agora_token": "",
"agora_appid": "",
"agora_appcert": "",
"agora_channel": "",
"params": {
"spatius_api_key": "${env:SPATIUS_API_KEY|}",
"spatius_app_id": "${env:SPATIUS_APP_ID|}",
"spatius_avatar_id": "",
"agora_uid": "",
"agora_token": "",
"agora_appid": "${env:AGORA_APP_ID|}",
"agora_appcert": "${env:AGORA_APP_CERTIFICATE|}",
"agora_channel": "",
"region": "",
"sample_rate": 24000,
Expand Down
Loading