Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 1 addition & 10 deletions livekit-agents/livekit/agents/voice/avatar/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,16 +93,6 @@ async def start(self, agent_session: AgentSession, room: rtc.Room) -> None:
"release resources when the job shuts down"
)

if agent_session._started and (audio_output := agent_session.output.audio) is not None:
logger.warning(
(
"AvatarSession.start() was called after AgentSession.start(); "
"the existing audio output may be replaced by the avatar. "
"Please start the avatar session before AgentSession.start() to avoid this."
),
extra={"audio_output": audio_output.label},
)

self._room = room
self._agent_session = agent_session
self._agent_session.on("conversation_item_added", self._on_conversation_item_added)
Expand All @@ -120,6 +110,7 @@ async def wait_for_join(self, *, timeout: float | None = 30.0) -> None:
``timeout`` seconds. Pass ``timeout=None`` to wait indefinitely.
"""
if self._wait_avatar_join_task is None:
# TODO(long): handle the case where this is called before the room is connected
return
if timeout is None:
await self._wait_avatar_join_task
Expand Down
142 changes: 128 additions & 14 deletions livekit-agents/livekit/agents/voice/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,6 @@ def __init__(
sample_rate: The sample rate required by the audio sink, if None, any sample rate is accepted
""" # noqa: E501
super().__init__()
self.__next_in_chain = next_in_chain
self._sample_rate = sample_rate
self.__label = label
self.__capturing = False
Expand All @@ -155,26 +154,39 @@ def __init__(
playback_position=0, interrupted=False
)

if self.next_in_chain:
self.next_in_chain.on(
"playback_finished",
lambda ev: self.on_playback_finished(
interrupted=ev.interrupted,
playback_position=ev.playback_position,
synchronized_transcript=ev.synchronized_transcript,
),
)
self.next_in_chain.on(
"playback_started", lambda ev: self.on_playback_started(created_at=ev.created_at)
)
# auto-wrap a bare leaf with a _AudioSinkProxy so the leaf can be
# hot-swapped later without disturbing wrappers above. wrappers that
# cache next_in_chain (e.g. _SyncedAudioOutput) cache the proxy, so
# their references stay valid across swaps
if (
next_in_chain is not None
and next_in_chain.next_in_chain is None
and not isinstance(next_in_chain, _AudioSinkProxy)
):
next_in_chain = _AudioSinkProxy(next_in_chain)

self._next_in_chain: AudioOutput | None = next_in_chain
if next_in_chain is not None:
next_in_chain.on("playback_finished", self._forward_next_playback_finished)
next_in_chain.on("playback_started", self._forward_next_playback_started)

def _forward_next_playback_finished(self, ev: PlaybackFinishedEvent) -> None:
    """Relay a downstream ``playback_finished`` event through this output.

    Registered as the ``playback_finished`` listener on ``next_in_chain``.
    It is a bound method so that the same callable can later be passed to
    ``off()`` when the downstream sink is replaced.

    Args:
        ev: The event emitted by the downstream output.
    """
    # copy the event fields in the same order the listener always read them
    finished_fields = {
        "interrupted": ev.interrupted,
        "playback_position": ev.playback_position,
        "synchronized_transcript": ev.synchronized_transcript,
    }
    self.on_playback_finished(**finished_fields)

def _forward_next_playback_started(self, ev: PlaybackStartedEvent) -> None:
    """Relay a downstream ``playback_started`` event through this output.

    Registered as the ``playback_started`` listener on ``next_in_chain``.

    Args:
        ev: The event emitted by the downstream output.
    """
    started_at = ev.created_at
    self.on_playback_started(created_at=started_at)

@property
def label(self) -> str:
    """Return the label this output was constructed with."""
    return self.__label

@property
def next_in_chain(self) -> AudioOutput | None:
    """Return the output that currently sits directly below this one.

    Returns:
        The downstream ``AudioOutput``, or ``None`` when nothing is attached.
    """
    # Read the single-underscore attribute: it is the one __init__ assigns
    # (after optionally wrapping a bare leaf in a proxy) and the one
    # _AudioSinkProxy.set_next_in_chain() rebinds on a swap. The old
    # name-mangled ``__next_in_chain`` would go stale after either.
    return self._next_in_chain

def on_playback_started(self, *, created_at: float) -> None:
self.emit("playback_started", PlaybackStartedEvent(created_at=created_at))
Expand Down Expand Up @@ -275,6 +287,87 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}(label={self.label!r}, next={self.next_in_chain!r})"


class _AudioSinkProxy(AudioOutput):
    """Swappable terminator for a chain of audio output wrappers.

    Wrappers higher in the chain keep a reference to this proxy rather than to
    the real sink. The real sink is stored as ``next_in_chain`` and can be
    exchanged through :meth:`set_next_in_chain`, so those wrappers never need
    to be rebuilt. With no sink attached (``next_in_chain`` is None) frames are
    not forwarded anywhere and :meth:`flush` reports the playback as finished
    itself, so upstream wrappers are not left waiting.

    This is the only ``AudioOutput`` whose ``next_in_chain`` is meant to change
    after construction; other subclasses fix theirs when they are built.
    """

    def __init__(self, next_in_chain: AudioOutput | None = None) -> None:
        """Create the proxy, optionally wiring an initial downstream sink.

        Args:
            next_in_chain: The sink to forward to, or None to start detached.
        """
        # the base class is always built without a downstream; the sink is
        # installed via set_next_in_chain() so listener wiring lives in one place
        super().__init__(
            label="AudioSinkProxy",
            capabilities=AudioOutputCapabilities(pause=True),
            next_in_chain=None,
        )
        # tracks whether the wrapper above has attached this proxy;
        # set_next_in_chain() mirrors that state onto sinks that are swapped
        self._attached = False
        if next_in_chain is None:
            return
        self.set_next_in_chain(next_in_chain)

    def on_attached(self) -> None:
        """Record that the proxy is attached, then defer to the base class."""
        self._attached = True
        super().on_attached()

    def on_detached(self) -> None:
        """Record that the proxy is detached, then defer to the base class."""
        self._attached = False
        super().on_detached()

    def set_next_in_chain(self, new: AudioOutput | None) -> None:
        """Swap the downstream sink.

        Playback listeners are moved from the previous sink to the new one,
        and while the proxy is attached the previous sink receives
        ``on_detached()`` and the new one ``on_attached()``.

        Args:
            new: The sink to forward to from now on, or None to detach.
        """
        previous = self._next_in_chain
        if previous is new:
            return

        forwarders = (
            ("playback_finished", self._forward_next_playback_finished),
            ("playback_started", self._forward_next_playback_started),
        )

        if previous is not None:
            for event, handler in forwarders:
                previous.off(event, handler)
            if self._attached:
                previous.on_detached()

        self._next_in_chain = new
        if new is None:
            return

        for event, handler in forwarders:
            new.on(event, handler)
        if self._attached:
            new.on_attached()

    @property
    def sample_rate(self) -> int | None:
        """Sample rate of the current sink, or None when detached."""
        sink = self.next_in_chain
        return sink.sample_rate if sink else None

    @property
    def can_pause(self) -> bool:
        """True when detached; otherwise whatever the current sink reports."""
        sink = self.next_in_chain
        if not sink:
            return True
        return sink.can_pause

    async def capture_frame(self, frame: rtc.AudioFrame) -> None:
        """Run the base-class capture, then hand the frame to the sink if any."""
        await super().capture_frame(frame)
        sink = self.next_in_chain
        if sink:
            await sink.capture_frame(frame)

    def flush(self) -> None:
        """Flush the sink, or report playback as finished when detached."""
        super().flush()
        sink = self.next_in_chain
        if sink:
            sink.flush()
            return
        # no real sink; synthesize a playback_finished
        self.on_playback_finished(playback_position=0.0, interrupted=True)

    def clear_buffer(self) -> None:
        """Forward ``clear_buffer`` to the sink; does nothing when detached."""
        sink = self.next_in_chain
        if sink:
            sink.clear_buffer()
Comment thread
longcw marked this conversation as resolved.
Outdated


class TextOutput(ABC):
def __init__(self, *, label: str, next_in_chain: TextOutput | None) -> None:
self.__label = label
Expand Down Expand Up @@ -568,6 +661,27 @@ def audio(self, sink: AudioOutput | None) -> None:
else:
self._audio_sink.on_detached()

def set_audio_sink(self, sink: AudioOutput | None, *, preserve_wrappers: bool = False) -> None:
    """Install ``sink`` as the audio sink at the bottom of the chain.

    With ``preserve_wrappers=True`` the current chain is walked from
    ``self._audio_sink`` downwards. The first :class:`_AudioSinkProxy`
    found has its downstream replaced by ``sink``, which leaves wrappers
    such as :class:`TranscriptSynchronizer` and
    :class:`RecorderAudioOutput` in place. If the walk finds no proxy
    (no wrappers, or the chain has not been set up yet) the call falls
    back to ``self.audio = sink``.

    With the default ``preserve_wrappers=False`` the call is exactly
    equivalent to ``self.audio = sink``.

    Args:
        sink: The new sink, or None to remove the current one.
        preserve_wrappers: Whether to swap below an existing proxy instead
            of replacing the whole chain.
    """
    # an empty walk (node is None) sends us straight to the fallback below
    node = self._audio_sink if preserve_wrappers else None
    while node is not None:
        if isinstance(node, _AudioSinkProxy):
            node.set_next_in_chain(sink)
            return
        node = node.next_in_chain
    self.audio = sink

@property
def transcription(self) -> TextOutput | None:
return self._transcription_sink
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,6 @@ def __init__(
super().__init__(
label="RecorderIO",
next_in_chain=audio_output,
sample_rate=audio_output.sample_rate if audio_output else None,
# TODO: support pause
capabilities=io.AudioOutputCapabilities(pause=True), # depends on the next_in_chain
)
Expand All @@ -364,6 +363,12 @@ def __init__(
self.__current_pause_start: float | None = None
self.__pause_wall_times: list[tuple[float, float]] = []

@property
def sample_rate(self) -> int | None:
if self._sample_rate is not None:
return self._sample_rate
return self.next_in_chain.sample_rate if self.next_in_chain else None

@property
def started_wall_time(self) -> float | None:
return self.__started_time
Expand Down
28 changes: 17 additions & 11 deletions livekit-agents/livekit/agents/voice/transcription/synchronizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ async def _rotate_segment_task(self, old_task: asyncio.Task[None] | None) -> Non
# always create a new impl even if aclose() failed, to avoid leaving
# self._impl pointing to a closed impl which causes the agent to get stuck
self._impl = _SegmentSynchronizerImpl(
options=self._opts, next_in_chain=self._text_output._next_in_chain
options=self._opts, next_in_chain=self._text_output.next_in_chain
)

# apply the current pause state to the new impl
Expand Down Expand Up @@ -545,19 +545,24 @@ def __init__(
super().__init__(
label="TranscriptSynchronizer",
next_in_chain=next_in_chain,
sample_rate=next_in_chain.sample_rate,
capabilities=io.AudioOutputCapabilities(pause=True),
)
self._next_in_chain: io.AudioOutput = next_in_chain # redefined for better typing
self._synchronizer = synchronizer
self._pushed_duration: float = 0.0

@property
def sample_rate(self) -> int | None:
if self._sample_rate is not None:
return self._sample_rate
return self.next_in_chain.sample_rate if self.next_in_chain else None
Comment on lines +553 to +557

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚩 sample_rate is now dynamic instead of fixed at construction time

Both _SyncedAudioOutput (synchronizer.py:553-557) and RecorderAudioOutput (recorder_io.py:366-370) changed from setting sample_rate at construction time (passing it to super().__init__) to computing it dynamically via a property that delegates to self.next_in_chain.sample_rate. This is intentional for the hot-swap use case — after swapping the leaf sink, the sample rate should reflect the new sink's requirements. However, in generation.py:418-428, the resampler is created lazily on the first frame and never recreated. If the sample rate changes after the first frame (e.g., from a hot-swap), the resampler won't be updated. This is a pre-existing limitation, not introduced by this PR, but it becomes more relevant now that hot-swapping is supported.

Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.


async def capture_frame(self, frame: rtc.AudioFrame) -> None:
# using barrier() on capture should be sufficient, flush() must not be called if
# capture_frame isn't completed
await self._synchronizer.barrier()

await self._next_in_chain.capture_frame(frame) # passthrough audio
if self.next_in_chain:
await self.next_in_chain.capture_frame(frame) # passthrough audio
await super().capture_frame(frame)
self._pushed_duration += frame.duration

Expand Down Expand Up @@ -587,7 +592,8 @@ async def capture_frame(self, frame: rtc.AudioFrame) -> None:

def flush(self) -> None:
super().flush()
self._next_in_chain.flush()
if self.next_in_chain:
self.next_in_chain.flush()

if not self._synchronizer.enabled:
return
Expand All @@ -600,7 +606,8 @@ def flush(self) -> None:
self._synchronizer._impl.end_audio_input()

def clear_buffer(self) -> None:
    """Forward ``clear_buffer`` to the downstream output, if there is one.

    The downstream is looked up through the ``next_in_chain`` property on
    every call and cleared exactly once. When nothing is attached the call
    is a no-op instead of dereferencing a missing sink.
    """
    downstream = self.next_in_chain
    if downstream:
        downstream.clear_buffer()

# this is going to be automatically called by the next_in_chain
def on_playback_started(self, *, created_at: float) -> None:
Expand Down Expand Up @@ -663,7 +670,6 @@ def __init__(
self, synchronizer: TranscriptSynchronizer, *, next_in_chain: io.TextOutput | None
) -> None:
super().__init__(label="TranscriptSynchronizer", next_in_chain=next_in_chain)
self._next_in_chain: io.TextOutput | None = next_in_chain
self._synchronizer = synchronizer
self._capturing = False

Expand All @@ -682,8 +688,8 @@ async def capture_text(self, text: str) -> None:
"still active; transcription sync is disabled. This usually means "
"session.output.audio was replaced after AgentSession.start()."
)
if self._next_in_chain:
await self._next_in_chain.capture_text(text)
if self.next_in_chain:
await self.next_in_chain.capture_text(text)
return

self._capturing = True
Expand All @@ -699,8 +705,8 @@ async def capture_text(self, text: str) -> None:

def flush(self) -> None:
if not self._synchronizer.enabled: # passthrough text if the synchronizer is disabled
if self._next_in_chain:
self._next_in_chain.flush()
if self.next_in_chain:
self.next_in_chain.flush()
return

if not self._capturing:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,12 @@ async def start(
)
self.session_id = session_details.get("sessionId")

agent_session.output.audio = DataStreamAudioOutput(
room=room,
destination_identity=self._avatar_participant_identity,
sample_rate=SAMPLE_RATE,
wait_remote_track=rtc.TrackKind.KIND_VIDEO,
agent_session.output.set_audio_sink(
DataStreamAudioOutput(
room=room,
destination_identity=self._avatar_participant_identity,
sample_rate=SAMPLE_RATE,
wait_remote_track=rtc.TrackKind.KIND_VIDEO,
),
preserve_wrappers=True,
)
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,12 @@ async def start(
},
)

agent_session.output.audio = DataStreamAudioOutput(
room=room,
destination_identity=self._avatar_participant_identity,
sample_rate=SAMPLE_RATE,
wait_remote_track=rtc.TrackKind.KIND_VIDEO,
agent_session.output.set_audio_sink(
DataStreamAudioOutput(
room=room,
destination_identity=self._avatar_participant_identity,
sample_rate=SAMPLE_RATE,
wait_remote_track=rtc.TrackKind.KIND_VIDEO,
),
preserve_wrappers=True,
)
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,14 @@ async def _shutdown_session() -> None:
)
session_task_mapping[room.name] = self.conversation_id

agent_session.output.audio = DataStreamAudioOutput(
room=room,
destination_identity="listener",
sample_rate=SAMPLE_RATE,
# wait_remote_track=rtc.TrackKind.KIND_VIDEO,
agent_session.output.set_audio_sink(
DataStreamAudioOutput(
room=room,
destination_identity="listener",
sample_rate=SAMPLE_RATE,
# wait_remote_track=rtc.TrackKind.KIND_VIDEO,
),
preserve_wrappers=True,
)
except AvatarTalkException as e:
logger.error(e)
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,13 @@ async def start(
logger.debug("starting avatar session")
await self._start_agent(livekit_url, livekit_token)

agent_session.output.audio = DataStreamAudioOutput(
room=room,
destination_identity=self._avatar_participant_identity,
wait_remote_track=rtc.TrackKind.KIND_VIDEO,
agent_session.output.set_audio_sink(
DataStreamAudioOutput(
room=room,
destination_identity=self._avatar_participant_identity,
wait_remote_track=rtc.TrackKind.KIND_VIDEO,
),
preserve_wrappers=True,
)

async def _start_agent(self, livekit_url: str, livekit_token: str) -> None:
Expand Down
Loading
Loading