Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 71 additions & 17 deletions livekit-agents/livekit/agents/voice/transcription/synchronizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@

STANDARD_SPEECH_RATE = 3.83 # hyphens (syllables) per second

# max time aclose() waits for the forwarding/speaking-rate tasks to drain before
# cancelling them, so a stalled downstream output can't deadlock segment rotation
_SEGMENT_ACLOSE_TIMEOUT = 5.0


@dataclass
class _TextSyncOptions:
Expand Down Expand Up @@ -132,6 +136,7 @@ class _SegmentSynchronizerImpl:

def __init__(self, options: _TextSyncOptions, *, next_in_chain: io.TextOutput | None) -> None:
self._opts = options
self._id = utils.shortuuid("SSI_") # to correlate warnings to a specific impl
self._text_data = _TextData(word_stream=self._opts.word_tokenizer.stream())
self._audio_data = _AudioData(sr_stream=self._opts.speaking_rate_detector.stream())

Expand Down Expand Up @@ -159,6 +164,10 @@ def __init__(self, options: _TextSyncOptions, *, next_in_chain: io.TextOutput |
self._playback_completed = False
self._interrupted = False

@property
def id(self) -> str:
return self._id

@property
def closed(self) -> bool:
return self._close_future.done()
Expand All @@ -173,12 +182,16 @@ def text_input_ended(self) -> bool:

def on_playback_started(self, start_time: float) -> None:
if self.closed:
logger.warning("_SegmentSynchronizerImpl.on_playback_started called after close")
logger.warning(
"_SegmentSynchronizerImpl.on_playback_started called after close",
extra={"impl_id": self._id},
)
return

if self._start_fut.is_set():
logger.warning(
"_SegmentSynchronizerImpl.on_playback_started called after start_fut is set"
"_SegmentSynchronizerImpl.on_playback_started called after start_fut is set",
extra={"impl_id": self._id},
)
return

Expand All @@ -187,15 +200,21 @@ def on_playback_started(self, start_time: float) -> None:

def push_audio(self, frame: rtc.AudioFrame) -> None:
if self.closed:
logger.warning("_SegmentSynchronizerImpl.push_audio called after close")
logger.warning(
"_SegmentSynchronizerImpl.push_audio called after close",
extra={"impl_id": self._id},
)
return

self._audio_data.sr_stream.push_frame(frame)
self._audio_data.pushed_duration += frame.duration

def end_audio_input(self) -> None:
if self.closed:
logger.warning("_SegmentSynchronizerImpl.end_audio_input called after close")
logger.warning(
"_SegmentSynchronizerImpl.end_audio_input called after close",
extra={"impl_id": self._id},
)
return

self._audio_data.done = True
Expand All @@ -204,7 +223,9 @@ def end_audio_input(self) -> None:

def push_text(self, text: str) -> None:
if self.closed:
logger.warning("_SegmentSynchronizerImpl.push_text called after close")
logger.warning(
"_SegmentSynchronizerImpl.push_text called after close", extra={"impl_id": self._id}
)
return

start_time, end_time = None, None
Expand All @@ -224,7 +245,10 @@ def push_text(self, text: str) -> None:

def end_text_input(self) -> None:
if self.closed:
logger.warning("_SegmentSynchronizerImpl.end_text_input called after close")
logger.warning(
"_SegmentSynchronizerImpl.end_text_input called after close",
extra={"impl_id": self._id},
)
return

self._text_data.done = True
Expand All @@ -234,7 +258,9 @@ def end_text_input(self) -> None:

def pause(self) -> None:
if self.closed:
logger.warning("_SegmentSynchronizerImpl.pause called after close")
logger.warning(
"_SegmentSynchronizerImpl.pause called after close", extra={"impl_id": self._id}
)
return

if self._paused_wall_time is None:
Expand All @@ -243,7 +269,9 @@ def pause(self) -> None:

def resume(self) -> None:
if self.closed:
logger.warning("_SegmentSynchronizerImpl.resume called after close")
logger.warning(
"_SegmentSynchronizerImpl.resume called after close", extra={"impl_id": self._id}
)
return

if self._paused_wall_time is not None:
Expand Down Expand Up @@ -275,14 +303,21 @@ def _reestimate_speed(self) -> None:

def mark_playback_finished(self, *, playback_position: float, interrupted: bool) -> None:
if self.closed:
logger.warning("_SegmentSynchronizerImpl.playback_finished called after close")
logger.warning(
"_SegmentSynchronizerImpl.playback_finished called after close",
extra={"impl_id": self._id},
)
return

self._interrupted = interrupted
if not self._text_data.done or not self._audio_data.done:
logger.warning(
"_SegmentSynchronizerImpl.playback_finished called before text/audio input is done",
extra={"text_done": self._text_data.done, "audio_done": self._audio_data.done},
extra={
"impl_id": self._id,
"text_done": self._text_data.done,
"audio_done": self._audio_data.done,
},
)
return

Expand Down Expand Up @@ -403,8 +438,18 @@ async def aclose(self) -> None:
self._output_enabled_ev.set()
await self._text_data.word_stream.aclose()
await self._audio_data.sr_stream.aclose()
await self._capture_atask
await self._speaking_rate_atask

# bound the drain of the forwarding/speaking-rate tasks
_, pending = await asyncio.wait(
[self._capture_atask, self._speaking_rate_atask],
timeout=_SEGMENT_ACLOSE_TIMEOUT,
)
if pending:
logger.warning(
"_SegmentSynchronizerImpl.aclose timed out draining tasks, cancelling them",
extra={"impl_id": self._id},
)
await utils.aio.cancel_and_wait(*pending)


class TranscriptSynchronizer:
Expand Down Expand Up @@ -502,10 +547,14 @@ async def _rotate_segment_task(self, old_task: asyncio.Task[None] | None) -> Non
with contextlib.suppress(Exception):
await old_task

old_impl = self._impl
try:
await self._impl.aclose()
await old_impl.aclose()
except Exception:
logger.exception("failed to close segment synchronizer impl during rotation")
logger.exception(
"failed to close segment synchronizer impl during rotation",
extra={"impl_id": old_impl.id},
)

# always create a new impl even if aclose() failed, to avoid leaving
# self._impl pointing to a closed impl which causes the agent to get stuck
Expand All @@ -522,7 +571,10 @@ def rotate_segment(self) -> None:
return

if self._rotate_segment_atask and not self._rotate_segment_atask.done():
logger.warning("rotate_segment called while previous segment is still being rotated")
logger.warning(
"rotate_segment called while previous segment is still being rotated",
extra={"impl_id": self._impl.id},
)

self._rotate_segment_atask = asyncio.create_task(
self._rotate_segment_task(self._rotate_segment_atask)
Expand Down Expand Up @@ -578,7 +630,8 @@ async def capture_frame(self, frame: rtc.AudioFrame) -> None:
if self._synchronizer._impl.audio_input_ended:
# this should not happen if `on_playback_finished` is called after each flush
logger.warning(
"_SegmentSynchronizerImpl audio marked as ended in capture audio, rotating segment"
"_SegmentSynchronizerImpl audio marked as ended in capture audio, rotating segment",
extra={"impl_id": self._synchronizer._impl.id},
)
self._synchronizer.rotate_segment()
await self._synchronizer.barrier()
Expand Down Expand Up @@ -690,7 +743,8 @@ async def capture_text(self, text: str) -> None:
if self._synchronizer._impl.text_input_ended:
# this should not happen if `on_playback_finished` is called after each flush
logger.warning(
"_SegmentSynchronizerImpl text marked as ended in capture text, rotating segment"
"_SegmentSynchronizerImpl text marked as ended in capture text, rotating segment",
extra={"impl_id": self._synchronizer._impl.id},
)
self._synchronizer.rotate_segment()
await self._synchronizer.barrier()
Expand Down
Loading