diff --git a/playground/streaming/synthesizer/synthesize.py b/playground/streaming/synthesizer/synthesize.py
index b27153c4b0..f1101d8df3 100644
--- a/playground/streaming/synthesizer/synthesize.py
+++ b/playground/streaming/synthesizer/synthesize.py
@@ -2,8 +2,8 @@
 from vocode.streaming.models.message import BaseMessage
 from vocode.streaming.models.synthesizer import AzureSynthesizerConfig
-from vocode.streaming.output_device.base_output_device import BaseOutputDevice
-from vocode.streaming.output_device.speaker_output import SpeakerOutput
+from vocode.streaming.output_device.abstract_output_device import AbstractOutputDevice
+from vocode.streaming.output_device.blocking_speaker_output import BlockingSpeakerOutput
 from vocode.streaming.synthesizer.azure_synthesizer import AzureSynthesizer
 from vocode.streaming.synthesizer.base_synthesizer import BaseSynthesizer
 from vocode.streaming.utils import get_chunk_size_per_second
@@ -19,7 +19,7 @@
 async def speak(
     synthesizer: BaseSynthesizer,
-    output_device: BaseOutputDevice,
+    output_device: AbstractOutputDevice,
     message: BaseMessage,
 ):
     message_sent = message.text
@@ -58,7 +58,7 @@ async def speak(
     return message_sent, cut_off


 async def main():
-    speaker_output = SpeakerOutput.from_default_device()
+    speaker_output = BlockingSpeakerOutput.from_default_device()
     synthesizer = AzureSynthesizer(AzureSynthesizerConfig.from_output_device(speaker_output))
     try:
         while True:
diff --git a/quickstarts/streaming_conversation.py b/quickstarts/streaming_conversation.py
index 9a482299fc..54d774f665 100644
--- a/quickstarts/streaming_conversation.py
+++ b/quickstarts/streaming_conversation.py
@@ -49,7 +49,6 @@ async def main():
         speaker_output,
     ) = create_streaming_microphone_input_and_speaker_output(
         use_default_devices=False,
-        use_blocking_speaker_output=True,  # this moves the playback to a separate thread, set to False to use the main thread
     )

     conversation = StreamingConversation(
diff --git a/tests/fakedata/conversation.py b/tests/fakedata/conversation.py
index c538b45aef..2df0a8afaa 100644
--- a/tests/fakedata/conversation.py
+++ b/tests/fakedata/conversation.py
@@ -8,7 +8,7 @@
 from vocode.streaming.models.message import BaseMessage
 from vocode.streaming.models.synthesizer import PlayHtSynthesizerConfig, SynthesizerConfig
 from vocode.streaming.models.transcriber import DeepgramTranscriberConfig, TranscriberConfig
-from vocode.streaming.output_device.base_output_device import BaseOutputDevice
+from vocode.streaming.output_device.abstract_output_device import AbstractOutputDevice
 from vocode.streaming.streaming_conversation import StreamingConversation
 from vocode.streaming.synthesizer.base_synthesizer import BaseSynthesizer
 from vocode.streaming.telephony.constants import DEFAULT_CHUNK_SIZE, DEFAULT_SAMPLING_RATE
@@ -35,7 +35,7 @@
 )


-class DummyOutputDevice(BaseOutputDevice):
+class DummyOutputDevice(AbstractOutputDevice):
     def consume_nonblocking(self, chunk: bytes):
         pass
diff --git a/tests/streaming/test_streaming_conversation.py b/tests/streaming/test_streaming_conversation.py
index e38b6efadc..2df0f6eb4b 100644
--- a/tests/streaming/test_streaming_conversation.py
+++ b/tests/streaming/test_streaming_conversation.py
@@ -16,7 +16,7 @@
 from vocode.streaming.models.events import Sender
 from vocode.streaming.models.transcriber import Transcription
 from vocode.streaming.models.transcript import ActionStart, Message, Transcript
-from vocode.streaming.utils.worker import AsyncWorker
+from vocode.streaming.utils.worker import AbstractAsyncWorker


 class ShouldIgnoreUtteranceTestCase(BaseModel):
@@ -25,7 +25,7 @@ class ShouldIgnoreUtteranceTestCase(BaseModel):
     expected: bool


-async def _consume_worker_output(worker: AsyncWorker, timeout: float = 0.1):
+async def _consume_worker_output(worker: AbstractAsyncWorker, timeout: float = 0.1):
     try:
         return await asyncio.wait_for(worker.output_queue.get(), timeout=timeout)
     except asyncio.TimeoutError:
diff --git a/vocode/helpers.py b/vocode/helpers.py
index c7088f64d4..448b67da38 100644
--- a/vocode/helpers.py
+++ b/vocode/helpers.py
@@ -10,7 +10,6 @@
 from vocode.streaming.output_device.blocking_speaker_output import (
     BlockingSpeakerOutput as BlockingStreamingSpeakerOutput,
 )
-from vocode.streaming.output_device.speaker_output import SpeakerOutput as StreamingSpeakerOutput
 from vocode.turn_based.input_device.microphone_input import (
     MicrophoneInput as TurnBasedMicrophoneInput,
 )
@@ -31,15 +30,10 @@ def create_streaming_microphone_input_and_speaker_output(
     output_device_name: Optional[str] = None,
     mic_sampling_rate=None,
     speaker_sampling_rate=None,
-    use_blocking_speaker_output=False,
 ):
     return _create_microphone_input_and_speaker_output(
         microphone_class=StreamingMicrophoneInput,
-        speaker_class=(
-            BlockingStreamingSpeakerOutput
-            if use_blocking_speaker_output
-            else StreamingSpeakerOutput
-        ),
+        speaker_class=BlockingStreamingSpeakerOutput,
         use_default_devices=use_default_devices,
         input_device_name=input_device_name,
         output_device_name=output_device_name,
@@ -70,7 +64,6 @@ def _create_microphone_input_and_speaker_output(
     microphone_class: typing.Type[Union[StreamingMicrophoneInput, TurnBasedMicrophoneInput]],
     speaker_class: typing.Type[
         Union[
-            StreamingSpeakerOutput,
             BlockingStreamingSpeakerOutput,
             TurnBasedSpeakerOutput,
         ]
@@ -83,7 +76,7 @@ def _create_microphone_input_and_speaker_output(
 ) -> Union[
     Tuple[
         StreamingMicrophoneInput,
-        Union[StreamingSpeakerOutput, BlockingStreamingSpeakerOutput],
+        BlockingStreamingSpeakerOutput,
     ],
     Tuple[TurnBasedMicrophoneInput, TurnBasedSpeakerOutput],
 ]:
diff --git a/vocode/streaming/client_backend/conversation.py b/vocode/streaming/client_backend/conversation.py
index 5970c475bc..31b34c82b5 100644
--- a/vocode/streaming/client_backend/conversation.py
+++ b/vocode/streaming/client_backend/conversation.py
@@ -116,7 +116,7 @@ def __init__(
     async def handle_event(self, event: Event):
         if event.type == EventType.TRANSCRIPT:
             transcript_event = typing.cast(TranscriptEvent, event)
-            self.output_device.consume_transcript(transcript_event)
+            await self.output_device.send_transcript(transcript_event)
             # logger.debug(event.dict())

     def restart(self, output_device: WebsocketOutputDevice):
diff --git a/vocode/streaming/models/synthesizer.py b/vocode/streaming/models/synthesizer.py
index 6820543c2d..01cf1936b8 100644
--- a/vocode/streaming/models/synthesizer.py
+++ b/vocode/streaming/models/synthesizer.py
@@ -6,7 +6,7 @@
 from .audio import AudioEncoding, SamplingRate
 from .model import BaseModel, TypedModel
 from vocode.streaming.models.client_backend import OutputAudioConfig
-from vocode.streaming.output_device.base_output_device import BaseOutputDevice
+from vocode.streaming.output_device.abstract_output_device import AbstractOutputDevice
 from vocode.streaming.telephony.constants import DEFAULT_AUDIO_ENCODING, DEFAULT_SAMPLING_RATE
@@ -46,7 +46,7 @@ class Config:
         arbitrary_types_allowed = True

     @classmethod
-    def from_output_device(cls, output_device: BaseOutputDevice, **kwargs):
+    def from_output_device(cls, output_device: AbstractOutputDevice, **kwargs):
         return cls(
             sampling_rate=output_device.sampling_rate,
             audio_encoding=output_device.audio_encoding,
diff --git a/vocode/streaming/output_device/abstract_output_device.py b/vocode/streaming/output_device/abstract_output_device.py
new file mode 100644
index 0000000000..6626384535
--- /dev/null
+++ b/vocode/streaming/output_device/abstract_output_device.py
@@ -0,0 +1,22 @@
+from abc import abstractmethod
+import asyncio
+
+from vocode.streaming.models.audio import AudioEncoding
+from vocode.streaming.output_device.audio_chunk import AudioChunk
+from vocode.streaming.utils.worker import AbstractAsyncWorker, InterruptibleEvent
+
+
+class AbstractOutputDevice(AbstractAsyncWorker[InterruptibleEvent[AudioChunk]]):
+
+    def __init__(self, sampling_rate: int, audio_encoding: AudioEncoding):
+        super().__init__(input_queue=asyncio.Queue())
+        self.sampling_rate = sampling_rate
+        self.audio_encoding = audio_encoding
+
+    @abstractmethod
+    async def play(self, chunk: bytes):
+        pass
+
+    @abstractmethod
+    def interrupt(self):
+        pass
diff --git a/vocode/streaming/output_device/audio_chunk.py b/vocode/streaming/output_device/audio_chunk.py
new file mode 100644
index 0000000000..d58e081adf
--- /dev/null
+++ b/vocode/streaming/output_device/audio_chunk.py
@@ -0,0 +1,28 @@
+from dataclasses import dataclass, field
+from enum import Enum
+from uuid import UUID
+import uuid
+
+
+class ChunkState(int, Enum):
+    UNPLAYED = 0
+    PLAYED = 1
+    INTERRUPTED = 2
+
+
+@dataclass
+class AudioChunk:
+    data: bytes
+    state: ChunkState = ChunkState.UNPLAYED
+    chunk_id: UUID = field(default_factory=uuid.uuid4)
+
+    @staticmethod
+    def on_play():
+        pass
+
+    @staticmethod
+    def on_interrupt():
+        pass
+
+    def __hash__(self) -> int:
+        return hash(self.chunk_id)
diff --git a/vocode/streaming/output_device/base_output_device.py b/vocode/streaming/output_device/base_output_device.py
deleted file mode 100644
index 2ce90d5c2f..0000000000
--- a/vocode/streaming/output_device/base_output_device.py
+++ /dev/null
@@ -1,16 +0,0 @@
-from vocode.streaming.models.audio import AudioEncoding
-
-
-class BaseOutputDevice:
-    def __init__(self, sampling_rate: int, audio_encoding: AudioEncoding):
-        self.sampling_rate = sampling_rate
-        self.audio_encoding = audio_encoding
-
-    def start(self):
-        pass
-
-    def consume_nonblocking(self, chunk: bytes):
-        raise NotImplemented
-
-    def terminate(self):
-        pass
diff --git a/vocode/streaming/output_device/blocking_speaker_output.py b/vocode/streaming/output_device/blocking_speaker_output.py
index ba1515d9c1..da0c8d04a0 100644
--- a/vocode/streaming/output_device/blocking_speaker_output.py
+++ b/vocode/streaming/output_device/blocking_speaker_output.py
@@ -6,26 +6,21 @@
 import sounddevice as sd

 from vocode.streaming.models.audio import AudioEncoding
-from vocode.streaming.output_device.base_output_device import BaseOutputDevice
+from vocode.streaming.output_device.rate_limit_interruptions_output_device import (
+    RateLimitInterruptionsOutputDevice,
+)
 from vocode.streaming.utils.worker import ThreadAsyncWorker

+DEFAULT_SAMPLING_RATE = 44100

-class BlockingSpeakerOutput(BaseOutputDevice, ThreadAsyncWorker):
-    DEFAULT_SAMPLING_RATE = 44100

-    def __init__(
-        self,
-        device_info: dict,
-        sampling_rate: Optional[int] = None,
-        audio_encoding: AudioEncoding = AudioEncoding.LINEAR16,
-    ):
+
+class _PlaybackWorker(ThreadAsyncWorker[bytes]):
+
+    def __init__(self, *, device_info: dict, sampling_rate: int):
+        super().__init__(input_queue=asyncio.Queue())
+        self.sampling_rate = sampling_rate
         self.device_info = device_info
-        sampling_rate = sampling_rate or int(
-            self.device_info.get("default_samplerate", self.DEFAULT_SAMPLING_RATE)
-        )
-        self.input_queue: asyncio.Queue[bytes] = asyncio.Queue()
-        BaseOutputDevice.__init__(self, sampling_rate, audio_encoding)
-        ThreadAsyncWorker.__init__(self, self.input_queue)
+        self.input_queue.put_nowait(self.sampling_rate * b"\x00")
         self.stream = sd.OutputStream(
             channels=1,
             samplerate=self.sampling_rate,
@@ -33,12 +28,8 @@ def __init__(
             device=int(self.device_info["index"]),
         )
         self._ended = False
-        self.input_queue.put_nowait(self.sampling_rate * b"\x00")
         self.stream.start()

-    def start(self):
-        ThreadAsyncWorker.start(self)
-
     def _run_loop(self):
         while not self._ended:
             try:
@@ -47,10 +38,40 @@ def _run_loop(self):
             except queue.Empty:
                 continue

-    def consume_nonblocking(self, chunk):
-        ThreadAsyncWorker.consume_nonblocking(self, chunk)
-
     def terminate(self):
         self._ended = True
-        ThreadAsyncWorker.terminate(self)
+        super().terminate()
         self.stream.close()
+
+
+class BlockingSpeakerOutput(RateLimitInterruptionsOutputDevice):
+
+    def __init__(
+        self,
+        device_info: dict,
+        sampling_rate: Optional[int] = None,
+        audio_encoding: AudioEncoding = AudioEncoding.LINEAR16,
+    ):
+        sampling_rate = sampling_rate or int(
+            device_info.get("default_samplerate", DEFAULT_SAMPLING_RATE)
+        )
+        super().__init__(sampling_rate=sampling_rate, audio_encoding=audio_encoding)
+        self.playback_worker = _PlaybackWorker(device_info=device_info, sampling_rate=sampling_rate)
+
+    async def play(self, chunk):
+        self.playback_worker.consume_nonblocking(chunk)
+
+    def start(self) -> asyncio.Task:
+        self.playback_worker.start()
+        return super().start()
+
+    def terminate(self):
+        self.playback_worker.terminate()
+        super().terminate()
+
+    @classmethod
+    def from_default_device(
+        cls,
+        **kwargs,
+    ):
+        return cls(sd.query_devices(kind="output"), **kwargs)
diff --git a/vocode/streaming/output_device/file_output_device.py b/vocode/streaming/output_device/file_output_device.py
index d9ab22d8dd..59830da83e 100644
--- a/vocode/streaming/output_device/file_output_device.py
+++ b/vocode/streaming/output_device/file_output_device.py
@@ -4,7 +4,10 @@

 import numpy as np

-from .base_output_device import BaseOutputDevice
+from vocode.streaming.output_device.rate_limit_interruptions_output_device import (
+    RateLimitInterruptionsOutputDevice,
+)
+
 from vocode.streaming.models.audio import AudioEncoding
 from vocode.streaming.utils.worker import ThreadAsyncWorker
@@ -27,7 +30,7 @@ def terminate(self):
         self.wav.close()


-class FileOutputDevice(BaseOutputDevice):
+class FileOutputDevice(RateLimitInterruptionsOutputDevice):
     DEFAULT_SAMPLING_RATE = 44100

     def __init__(
@@ -47,9 +50,13 @@ def __init__(
         self.wav = wav
         self.thread_worker = FileWriterWorker(self.queue, wav)
+
+    def start(self) -> asyncio.Task:
         self.thread_worker.start()
+        return super().start()

-    def consume_nonblocking(self, chunk):
+    async def play(self, chunk: bytes):
+        # TODO (output device refactor): just dispatch out into a thread to write to the file per block, doesn't need a worker
         chunk_arr = np.frombuffer(chunk, dtype=np.int16)
         for i in range(0, chunk_arr.shape[0], self.blocksize):
             block = np.zeros(self.blocksize, dtype=np.int16)
@@ -59,3 +66,4 @@

     def terminate(self):
         self.thread_worker.terminate()
+        super().terminate()
diff --git a/vocode/streaming/output_device/rate_limit_interruptions_output_device.py b/vocode/streaming/output_device/rate_limit_interruptions_output_device.py
new file mode 100644
index 0000000000..76f060945c
--- /dev/null
+++ b/vocode/streaming/output_device/rate_limit_interruptions_output_device.py
@@ -0,0 +1,60 @@
+import asyncio
+import time
+
+from vocode.streaming.constants import PER_CHUNK_ALLOWANCE_SECONDS
+from vocode.streaming.models.audio import AudioEncoding
+from vocode.streaming.output_device.abstract_output_device import AbstractOutputDevice
+from vocode.streaming.output_device.audio_chunk import ChunkState
+from vocode.streaming.utils import get_chunk_size_per_second
+
+
+class RateLimitInterruptionsOutputDevice(AbstractOutputDevice):
+    def __init__(
+        self,
+        sampling_rate: int,
+        audio_encoding: AudioEncoding,
+        per_chunk_allowance_seconds: float = PER_CHUNK_ALLOWANCE_SECONDS,
+    ):
+        super().__init__(sampling_rate, audio_encoding)
+        self.per_chunk_allowance_seconds = per_chunk_allowance_seconds
+
+    async def _run_loop(self):
+        while True:
+            start_time = time.time()
+            try:
+                item = await self.input_queue.get()
+            except asyncio.CancelledError:
+                return
+
+            self.interruptible_event = item
+            audio_chunk = item.payload
+
+            if item.is_interrupted():
+                audio_chunk.on_interrupt()
+                audio_chunk.state = ChunkState.INTERRUPTED
+                continue
+
+            speech_length_seconds = (len(audio_chunk.data)) / get_chunk_size_per_second(
+                self.audio_encoding,
+                self.sampling_rate,
+            )
+            await self.play(audio_chunk.data)
+            audio_chunk.on_play()
+            audio_chunk.state = ChunkState.PLAYED
+            end_time = time.time()
+            await asyncio.sleep(
+                max(
+                    speech_length_seconds
+                    - (end_time - start_time)
+                    - self.per_chunk_allowance_seconds,
+                    0,
+                ),
+            )
+            self.interruptible_event.is_interruptible = False
+
+    def interrupt(self):
+        """
+        For conversations that use rate-limiting playback as above,
+        no custom logic is needed on interrupt: to end synthesis, we simply stop sending chunks.
+ """ + pass diff --git a/vocode/streaming/output_device/speaker_output.py b/vocode/streaming/output_device/speaker_output.py index 543dbbfb0c..b7861dbade 100644 --- a/vocode/streaming/output_device/speaker_output.py +++ b/vocode/streaming/output_device/speaker_output.py @@ -4,11 +4,13 @@ import numpy as np import sounddevice as sd -from .base_output_device import BaseOutputDevice +from .abstract_output_device import AbstractOutputDevice from vocode.streaming.models.audio import AudioEncoding +raise DeprecationWarning("Use BlockingSpeakerOutput instead") -class SpeakerOutput(BaseOutputDevice): + +class SpeakerOutput(AbstractOutputDevice): DEFAULT_SAMPLING_RATE = 44100 def __init__( diff --git a/vocode/streaming/output_device/twilio_output_device.py b/vocode/streaming/output_device/twilio_output_device.py index 0563f08cff..d3c34df017 100644 --- a/vocode/streaming/output_device/twilio_output_device.py +++ b/vocode/streaming/output_device/twilio_output_device.py @@ -3,63 +3,118 @@ import asyncio import base64 import json -from typing import Optional +from typing import Optional, Union +import uuid from fastapi import WebSocket +from pydantic import BaseModel -from vocode.streaming.output_device.base_output_device import BaseOutputDevice +from vocode.streaming.output_device.audio_chunk import AudioChunk, ChunkState +from vocode.streaming.output_device.abstract_output_device import AbstractOutputDevice from vocode.streaming.telephony.constants import DEFAULT_AUDIO_ENCODING, DEFAULT_SAMPLING_RATE from vocode.streaming.utils.create_task import asyncio_create_task_with_done_error_log +from vocode.streaming.utils.worker import InterruptibleEvent -class TwilioOutputDevice(BaseOutputDevice): +class ChunkFinishedMarkMessage(BaseModel): + chunk_id: str + + +MarkMessage = Union[ChunkFinishedMarkMessage] # space for more mark messages + + +class TwilioOutputDevice(AbstractOutputDevice): def __init__(self, ws: Optional[WebSocket] = None, stream_sid: Optional[str] = None): super().__init__(sampling_rate=DEFAULT_SAMPLING_RATE, audio_encoding=DEFAULT_AUDIO_ENCODING) self.ws = ws self.stream_sid = stream_sid self.active = True - self.queue: asyncio.Queue[str] = asyncio.Queue() - self.process_task = asyncio_create_task_with_done_error_log(self.process()) - async def process(self): - while self.active: - message = await self.queue.get() - await self.ws.send_text(message) + self.twilio_events_queue: asyncio.Queue[str] = asyncio.Queue() + self.mark_message_queue: asyncio.Queue[MarkMessage] = asyncio.Queue() + self.unprocessed_audio_chunks_queue: asyncio.Queue[InterruptibleEvent[AudioChunk]] = ( + asyncio.Queue() + ) + + def consume_nonblocking(self, item: InterruptibleEvent[AudioChunk]): + # TODO (output device refactor): think about when interrupted messages enter the queue + synchronicity with the clear message + if not item.is_interrupted(): + self._send_audio_chunk_and_mark(item.payload.data) + self.unprocessed_audio_chunks_queue.put_nowait(item) + + async def play(self, chunk: bytes): + """ + For Twilio, we send all of the audio chunks to be played at once, + and then consume the mark messages to know when to send the on_play / on_interrupt callbacks + """ + pass + + def interrupt(self): + self._send_clear_message() + + def enqueue_mark_message(self, mark_message: MarkMessage): + self.mark_message_queue.put_nowait(mark_message) + + async def _send_twilio_messages(self): + while True: + try: + twilio_event = await self.twilio_events_queue.get() + except asyncio.CancelledError: + return + + await 
self.ws.send_text(twilio_event) + + async def _process_mark_messages(self): + while True: + try: + mark_message = await self.mark_message_queue.get() + item = await self.unprocessed_audio_chunks_queue.get() + # TODO (output device refactor): cross reference chunk IDs between mark message and audio chunks? + except asyncio.CancelledError: + return + + self.interruptible_event = item + audio_chunk = item.payload + + if item.is_interrupted(): + audio_chunk.on_interrupt() + audio_chunk.state = ChunkState.INTERRUPTED + continue - def consume_nonblocking(self, chunk: bytes): - twilio_message = { + await self.play(audio_chunk.data) + audio_chunk.on_play() + audio_chunk.state = ChunkState.PLAYED + + self.interruptible_event.is_interruptible = False + + async def _run_loop(self): + send_twilio_messages_task = asyncio_create_task_with_done_error_log( + self._send_twilio_messages() + ) + process_mark_messages_task = asyncio_create_task_with_done_error_log( + self._process_mark_messages() + ) + await asyncio.gather(send_twilio_messages_task, process_mark_messages_task) + + def _send_audio_chunk_and_mark(self, chunk: bytes): + media_message = { "event": "media", "streamSid": self.stream_sid, "media": {"payload": base64.b64encode(chunk).decode("utf-8")}, } - self.queue.put_nowait(json.dumps(twilio_message)) - - def send_chunk_finished_mark(self, utterance_id, chunk_idx): - mark_message = { - "event": "mark", - "streamSid": self.stream_sid, - "mark": { - "name": f"chunk-{utterance_id}-{chunk_idx}", - }, - } - self.queue.put_nowait(json.dumps(mark_message)) - - def send_utterance_finished_mark(self, utterance_id): + self.twilio_events_queue.put_nowait(json.dumps(media_message)) mark_message = { "event": "mark", "streamSid": self.stream_sid, "mark": { - "name": f"utterance-{utterance_id}", + "name": str(uuid.uuid4()), }, } - self.queue.put_nowait(json.dumps(mark_message)) + self.twilio_events_queue.put_nowait(json.dumps(mark_message)) - def send_clear_message(self): + def _send_clear_message(self): clear_message = { "event": "clear", "streamSid": self.stream_sid, } - self.queue.put_nowait(json.dumps(clear_message)) - - def terminate(self): - self.process_task.cancel() + self.twilio_events_queue.put_nowait(json.dumps(clear_message)) diff --git a/vocode/streaming/output_device/vonage_output_device.py b/vocode/streaming/output_device/vonage_output_device.py index 6466303226..de32e13c9f 100644 --- a/vocode/streaming/output_device/vonage_output_device.py +++ b/vocode/streaming/output_device/vonage_output_device.py @@ -3,18 +3,19 @@ from fastapi import WebSocket -from vocode.streaming.output_device.base_output_device import BaseOutputDevice -from vocode.streaming.output_device.speaker_output import SpeakerOutput +from vocode.streaming.output_device.blocking_speaker_output import BlockingSpeakerOutput +from vocode.streaming.output_device.rate_limit_interruptions_output_device import ( + RateLimitInterruptionsOutputDevice, +) from vocode.streaming.telephony.constants import ( PCM_SILENCE_BYTE, VONAGE_AUDIO_ENCODING, VONAGE_CHUNK_SIZE, VONAGE_SAMPLING_RATE, ) -from vocode.streaming.utils.create_task import asyncio_create_task_with_done_error_log -class VonageOutputDevice(BaseOutputDevice): +class VonageOutputDevice(RateLimitInterruptionsOutputDevice): def __init__( self, ws: Optional[WebSocket] = None, @@ -22,28 +23,17 @@ def __init__( ): super().__init__(sampling_rate=VONAGE_SAMPLING_RATE, audio_encoding=VONAGE_AUDIO_ENCODING) self.ws = ws - self.active = True - self.queue: asyncio.Queue[bytes] = asyncio.Queue() - 
self.process_task = asyncio_create_task_with_done_error_log(self.process()) self.output_to_speaker = output_to_speaker if output_to_speaker: - self.output_speaker = SpeakerOutput.from_default_device( + self.output_speaker = BlockingSpeakerOutput.from_default_device( sampling_rate=VONAGE_SAMPLING_RATE, blocksize=VONAGE_CHUNK_SIZE // 2 ) - async def process(self): - while self.active: - chunk = await self.queue.get() - if self.output_to_speaker: - self.output_speaker.consume_nonblocking(chunk) - for i in range(0, len(chunk), VONAGE_CHUNK_SIZE): - subchunk = chunk[i : i + VONAGE_CHUNK_SIZE] - if len(subchunk) % 2 == 1: - subchunk += PCM_SILENCE_BYTE # pad with silence, Vonage goes crazy otherwise - await self.ws.send_bytes(subchunk) - - def consume_nonblocking(self, chunk: bytes): - self.queue.put_nowait(chunk) - - def terminate(self): - self.process_task.cancel() + async def play(self, chunk: bytes): + if self.output_to_speaker: + self.output_speaker.consume_nonblocking(chunk) + for i in range(0, len(chunk), VONAGE_CHUNK_SIZE): + subchunk = chunk[i : i + VONAGE_CHUNK_SIZE] + if len(subchunk) % 2 == 1: + subchunk += PCM_SILENCE_BYTE # pad with silence, Vonage goes crazy otherwise + await self.ws.send_bytes(subchunk) diff --git a/vocode/streaming/output_device/websocket_output_device.py b/vocode/streaming/output_device/websocket_output_device.py index ca0133c145..9b5ce1de4b 100644 --- a/vocode/streaming/output_device/websocket_output_device.py +++ b/vocode/streaming/output_device/websocket_output_device.py @@ -7,11 +7,12 @@ from vocode.streaming.models.audio import AudioEncoding from vocode.streaming.models.transcript import TranscriptEvent from vocode.streaming.models.websocket import AudioMessage, TranscriptMessage -from vocode.streaming.output_device.base_output_device import BaseOutputDevice -from vocode.streaming.utils.create_task import asyncio_create_task_with_done_error_log +from vocode.streaming.output_device.rate_limit_interruptions_output_device import ( + RateLimitInterruptionsOutputDevice, +) -class WebsocketOutputDevice(BaseOutputDevice): +class WebsocketOutputDevice(RateLimitInterruptionsOutputDevice): def __init__(self, ws: WebSocket, sampling_rate: int, audio_encoding: AudioEncoding): super().__init__(sampling_rate, audio_encoding) self.ws = ws @@ -20,25 +21,15 @@ def __init__(self, ws: WebSocket, sampling_rate: int, audio_encoding: AudioEncod def start(self): self.active = True - self.process_task = asyncio_create_task_with_done_error_log(self.process()) + return super().start() def mark_closed(self): self.active = False - async def process(self): - while self.active: - message = await self.queue.get() - await self.ws.send_text(message) + async def play(self, chunk: bytes): + await self.ws.send_text(AudioMessage.from_bytes(chunk).json()) - def consume_nonblocking(self, chunk: bytes): - if self.active: - audio_message = AudioMessage.from_bytes(chunk) - self.queue.put_nowait(audio_message.json()) - - def consume_transcript(self, event: TranscriptEvent): + async def send_transcript(self, event: TranscriptEvent): if self.active: transcript_message = TranscriptMessage.from_event(event) - self.queue.put_nowait(transcript_message.json()) - - def terminate(self): - self.process_task.cancel() + await self.ws.send_text(transcript_message.json()) diff --git a/vocode/streaming/streaming_conversation.py b/vocode/streaming/streaming_conversation.py index 8e47980166..ff8e89e0ac 100644 --- a/vocode/streaming/streaming_conversation.py +++ b/vocode/streaming/streaming_conversation.py @@ -1,6 
+1,7 @@ from __future__ import annotations import asyncio +from functools import partial import queue import random import re @@ -39,7 +40,6 @@ from vocode.streaming.constants import ( ALLOWED_IDLE_TIME, CHECK_HUMAN_PRESENT_MESSAGE_CHOICES, - PER_CHUNK_ALLOWANCE_SECONDS, TEXT_TO_SPEECH_CHUNK_SIZE_SECONDS, ) from vocode.streaming.models.actions import EndOfTurn @@ -48,7 +48,8 @@ from vocode.streaming.models.message import BaseMessage, BotBackchannel, LLMToken, SilenceMessage from vocode.streaming.models.transcriber import TranscriberConfig, Transcription from vocode.streaming.models.transcript import Message, Transcript, TranscriptCompleteEvent -from vocode.streaming.output_device.base_output_device import BaseOutputDevice +from vocode.streaming.output_device.audio_chunk import AudioChunk, ChunkState +from vocode.streaming.output_device.abstract_output_device import AbstractOutputDevice from vocode.streaming.synthesizer.base_synthesizer import ( BaseSynthesizer, FillerAudio, @@ -57,7 +58,11 @@ from vocode.streaming.synthesizer.input_streaming_synthesizer import InputStreamingSynthesizer from vocode.streaming.transcriber.base_transcriber import BaseTranscriber from vocode.streaming.transcriber.deepgram_transcriber import DeepgramTranscriber -from vocode.streaming.utils import create_conversation_id, get_chunk_size_per_second +from vocode.streaming.utils import ( + create_conversation_id, + enumerate_async_iter, + get_chunk_size_per_second, +) from vocode.streaming.utils.create_task import asyncio_create_task_with_done_error_log from vocode.streaming.utils.events_manager import EventsManager from vocode.streaming.utils.speed_manager import SpeedManager @@ -102,7 +107,7 @@ LOW_INTERRUPT_SENSITIVITY_BACKCHANNEL_UTTERANCE_LENGTH_THRESHOLD = 3 -OutputDeviceType = TypeVar("OutputDeviceType", bound=BaseOutputDevice) +OutputDeviceType = TypeVar("OutputDeviceType", bound=AbstractOutputDevice) class StreamingConversation(Generic[OutputDeviceType]): @@ -590,7 +595,6 @@ def __init__( synthesizer: BaseSynthesizer, speed_coefficient: float = 1.0, conversation_id: Optional[str] = None, - per_chunk_allowance_seconds: float = PER_CHUNK_ALLOWANCE_SECONDS, events_manager: Optional[EventsManager] = None, ): self.id = conversation_id or create_conversation_id() @@ -655,7 +659,6 @@ def __init__( self.events_manager = events_manager or EventsManager() self.events_task: Optional[asyncio.Task] = None - self.per_chunk_allowance_seconds = per_chunk_allowance_seconds self.transcript = Transcript() self.transcript.attach_events_manager(self.events_manager) @@ -823,6 +826,7 @@ def broadcast_interrupt(self): num_interrupts += 1 except queue.Empty: break + self.output_device.interrupt() self.agent.cancel_current_task() self.agent_responses_worker.cancel_current_task() if self.actions_worker: @@ -886,77 +890,92 @@ async def send_speech_to_output( Returns the message that was sent up to, and a flag if the message was cut off """ + seconds_spoken = 0.0 - async def get_chunks( - output_queue: asyncio.Queue[Optional[SynthesisResult.ChunkResult]], - chunk_generator: AsyncGenerator[SynthesisResult.ChunkResult, None], + def create_on_play_callback( + chunk_idx: int, + processed_event: asyncio.Event, ): - try: - async for chunk_result in chunk_generator: - await output_queue.put(chunk_result) - except asyncio.CancelledError: - pass - finally: - await output_queue.put(None) # sentinel + def _on_play(): + if chunk_idx == 0: + if started_event: + started_event.set() + if first_chunk_span: + self._track_first_chunk(first_chunk_span, 
synthesis_result) + + nonlocal seconds_spoken + + self.mark_last_action_timestamp() + + seconds_spoken += seconds_per_chunk + if transcript_message: + transcript_message.text = synthesis_result.get_message_up_to(seconds_spoken) + + processed_event.set() + + return _on_play + + def create_on_interrupt_callback( + chunk_idx: int, + processed_event: asyncio.Event, + ): + logged = False + + def _on_interrupt(): + nonlocal logged + if not logged: + logger.debug( + "Interrupted, stopping text to speech after {} chunks".format(chunk_idx), + ) + logged = True + processed_event.set() + + return _on_interrupt if self.transcriber.get_transcriber_config().mute_during_speech: logger.debug("Muting transcriber") self.transcriber.mute() - message_sent = message - cut_off = False - chunk_size = self._get_synthesizer_chunk_size(seconds_per_chunk) - chunk_idx = 0 - seconds_spoken = 0.0 logger.debug(f"Start sending speech {message} to output") first_chunk_span = self._maybe_create_first_chunk_span(synthesis_result, message) - chunk_queue: asyncio.Queue[Optional[SynthesisResult.ChunkResult]] = asyncio.Queue() - get_chunks_task = asyncio_create_task_with_done_error_log( - get_chunks(chunk_queue, synthesis_result.chunk_generator), - ) - first = True - while True: - chunk_result = await chunk_queue.get() - if chunk_result is None: - break - if first and first_chunk_span: - self._track_first_chunk(first_chunk_span, synthesis_result) - first = False - start_time = time.time() - speech_length_seconds = seconds_per_chunk * (len(chunk_result.chunk) / chunk_size) - seconds_spoken = chunk_idx * seconds_per_chunk - if stop_event.is_set(): - logger.debug( - "Interrupted, stopping text to speech after {} chunks".format(chunk_idx), - ) - message_sent = synthesis_result.get_message_up_to(seconds_spoken) - cut_off = True - break - if chunk_idx == 0: - if started_event: - started_event.set() - self.output_device.consume_nonblocking(chunk_result.chunk) - end_time = time.time() - await asyncio.sleep( - max( - speech_length_seconds - - (end_time - start_time) - - self.per_chunk_allowance_seconds, - 0, + audio_chunks: List[AudioChunk] = [] + processed_events: List[asyncio.Event] = [] + async for chunk_idx, chunk_result in enumerate_async_iter(synthesis_result.chunk_generator): + processed_event = asyncio.Event() + audio_chunk = AudioChunk( + data=chunk_result.chunk, + ) + audio_chunk.on_play = create_on_play_callback(chunk_idx, processed_event) + audio_chunk.on_interrupt = create_on_interrupt_callback(chunk_idx, processed_event) + self.output_device.consume_nonblocking( + InterruptibleEvent( + payload=audio_chunk, + is_interruptible=True, + interruption_event=stop_event, ), ) - self.mark_last_action_timestamp() - chunk_idx += 1 - seconds_spoken += seconds_per_chunk - if transcript_message: - transcript_message.text = synthesis_result.get_message_up_to(seconds_spoken) - get_chunks_task.cancel() + audio_chunks.append(audio_chunk) + processed_events.append(processed_event) + + # TODO (output device refactor): consider ramifications of asyncio.gather + await asyncio.gather(*(processed_event.wait() for processed_event in processed_events)) + + maybe_first_interrupted_audio_chunk = next( + ( + audio_chunk + for audio_chunk in audio_chunks + if audio_chunk.state == ChunkState.INTERRUPTED + ), + None, + ) + cut_off = maybe_first_interrupted_audio_chunk is not None + if self.transcriber.get_transcriber_config().mute_during_speech: logger.debug("Unmuting transcriber") self.transcriber.unmute() if transcript_message: - 
transcript_message.text = message_sent transcript_message.is_final = not cut_off + message_sent = transcript_message.text if transcript_message and cut_off else message if synthesis_result.synthesis_total_span: synthesis_result.synthesis_total_span.finish() return message_sent, cut_off diff --git a/vocode/streaming/telephony/conversation/abstract_phone_conversation.py b/vocode/streaming/telephony/conversation/abstract_phone_conversation.py index 308919a403..1e3cf4c0ff 100644 --- a/vocode/streaming/telephony/conversation/abstract_phone_conversation.py +++ b/vocode/streaming/telephony/conversation/abstract_phone_conversation.py @@ -49,7 +49,6 @@ def __init__( conversation_id: Optional[str] = None, events_manager: Optional[EventsManager] = None, speed_coefficient: float = 1.0, - per_chunk_allowance_seconds: float = 0.01, ): conversation_id = conversation_id or create_conversation_id() ctx_conversation_id.set(conversation_id) @@ -64,7 +63,6 @@ def __init__( agent_factory.create_agent(agent_config), synthesizer_factory.create_synthesizer(synthesizer_config), conversation_id=conversation_id, - per_chunk_allowance_seconds=per_chunk_allowance_seconds, events_manager=events_manager, speed_coefficient=speed_coefficient, ) diff --git a/vocode/streaming/telephony/conversation/mark_message_queue.py b/vocode/streaming/telephony/conversation/mark_message_queue.py deleted file mode 100644 index c4b17b9319..0000000000 --- a/vocode/streaming/telephony/conversation/mark_message_queue.py +++ /dev/null @@ -1,46 +0,0 @@ -import asyncio -from typing import Dict, Union - -from pydantic.v1 import BaseModel - - -class ChunkFinishedMarkMessage(BaseModel): - chunk_idx: int - - -class UtteranceFinishedMarkMessage(BaseModel): - pass - - -MarkMessage = Union[ChunkFinishedMarkMessage, UtteranceFinishedMarkMessage] - - -class MarkMessageQueue: - """A keyed asyncio.Queue for MarkMessage objects""" - - def __init__(self): - self.utterance_queues: Dict[str, asyncio.Queue[MarkMessage]] = {} - - def create_utterance_queue(self, utterance_id: str): - if utterance_id in self.utterance_queues: - raise ValueError(f"utterance_id {utterance_id} already exists") - self.utterance_queues[utterance_id] = asyncio.Queue() - - def put_nowait( - self, - utterance_id: str, - mark_message: MarkMessage, - ): - if utterance_id in self.utterance_queues: - self.utterance_queues[utterance_id].put_nowait(mark_message) - - async def get( - self, - utterance_id: str, - ) -> MarkMessage: - if utterance_id not in self.utterance_queues: - raise ValueError(f"utterance_id {utterance_id} not found") - return await self.utterance_queues[utterance_id].get() - - def delete_utterance_queue(self, utterance_id: str): - del self.utterance_queues[utterance_id] diff --git a/vocode/streaming/telephony/conversation/twilio_phone_conversation.py b/vocode/streaming/telephony/conversation/twilio_phone_conversation.py index 6145d53b05..f3c833726d 100644 --- a/vocode/streaming/telephony/conversation/twilio_phone_conversation.py +++ b/vocode/streaming/telephony/conversation/twilio_phone_conversation.py @@ -1,10 +1,8 @@ -import asyncio import base64 import json import os -import threading from enum import Enum -from typing import AsyncGenerator, Optional +from typing import Optional from fastapi import WebSocket from loguru import logger @@ -15,25 +13,17 @@ from vocode.streaming.models.synthesizer import SynthesizerConfig from vocode.streaming.models.telephony import PhoneCallDirection, TwilioConfig from vocode.streaming.models.transcriber import TranscriberConfig -from 
vocode.streaming.models.transcript import Message -from vocode.streaming.output_device.twilio_output_device import TwilioOutputDevice +from vocode.streaming.output_device.twilio_output_device import ( + ChunkFinishedMarkMessage, + TwilioOutputDevice, +) from vocode.streaming.synthesizer.abstract_factory import AbstractSynthesizerFactory -from vocode.streaming.synthesizer.base_synthesizer import SynthesisResult -from vocode.streaming.synthesizer.input_streaming_synthesizer import InputStreamingSynthesizer from vocode.streaming.telephony.client.twilio_client import TwilioClient from vocode.streaming.telephony.config_manager.base_config_manager import BaseConfigManager from vocode.streaming.telephony.conversation.abstract_phone_conversation import ( AbstractPhoneConversation, ) -from vocode.streaming.telephony.conversation.mark_message_queue import ( - ChunkFinishedMarkMessage, - MarkMessage, - MarkMessageQueue, - UtteranceFinishedMarkMessage, -) from vocode.streaming.transcriber.abstract_factory import AbstractTranscriberFactory -from vocode.streaming.utils import create_utterance_id -from vocode.streaming.utils.create_task import asyncio_create_task_with_done_error_log from vocode.streaming.utils.events_manager import EventsManager from vocode.streaming.utils.state_manager import TwilioPhoneConversationStateManager @@ -83,7 +73,6 @@ def __init__( synthesizer_factory=synthesizer_factory, speed_coefficient=speed_coefficient, ) - self.mark_message_queue: MarkMessageQueue = MarkMessageQueue() self.config_manager = config_manager self.twilio_config = twilio_config or TwilioConfig( account_sid=os.environ["TWILIO_ACCOUNT_SID"], @@ -140,135 +129,10 @@ async def _handle_ws_message(self, message) -> Optional[TwilioPhoneConversationW chunk = base64.b64decode(media["payload"]) self.receive_audio(chunk) if data["event"] == "mark": - mark_name = data["mark"]["name"] - if mark_name.startswith("chunk-"): - utterance_id, chunk_idx = mark_name.split("-")[1:] - self.mark_message_queue.put_nowait( - utterance_id=utterance_id, - mark_message=ChunkFinishedMarkMessage(chunk_idx=int(chunk_idx)), - ) - elif mark_name.startswith("utterance"): - utterance_id = mark_name.split("-")[1] - self.mark_message_queue.put_nowait( - utterance_id=utterance_id, - mark_message=UtteranceFinishedMarkMessage(), - ) + chunk_id = data["mark"]["name"] + self.output_device.enqueue_mark_message(ChunkFinishedMarkMessage(chunk_id=chunk_id)) elif data["event"] == "stop": logger.debug(f"Media WS: Received event 'stop': {message}") logger.debug("Stopping...") return TwilioPhoneConversationWebsocketAction.CLOSE_WEBSOCKET return None - - async def _send_chunks( - self, - utterance_id: str, - chunk_generator: AsyncGenerator[SynthesisResult.ChunkResult, None], - clear_message_lock: asyncio.Lock, - stop_event: threading.Event, - ): - chunk_idx = 0 - try: - async for chunk_result in chunk_generator: - async with clear_message_lock: - if stop_event.is_set(): - break - self.output_device.consume_nonblocking(chunk_result.chunk) - self.output_device.send_chunk_finished_mark(utterance_id, chunk_idx) - chunk_idx += 1 - except asyncio.CancelledError: - pass - finally: - logger.debug("Finished sending all chunks to Twilio") - self.output_device.send_utterance_finished_mark(utterance_id) - - async def send_speech_to_output( - self, - message: str, - synthesis_result: SynthesisResult, - stop_event: threading.Event, - seconds_per_chunk: float, - transcript_message: Optional[Message] = None, - started_event: Optional[threading.Event] = None, - ): - """In 
contrast with send_speech_to_output in the base class, this function uses mark messages - to support interruption - we send all chunks to the output device, and then wait for mark messages[0] - that indicate that each chunk has been played. This means that we don't need to depends on asyncio.sleep - to support interruptions. - - Once we receive an interruption signal: - - we send a clear message to Twilio to stop playing all queued audio - - based on the number of mark messages we've received back, we know how many chunks were played and can indicate on the transcript - - [0] https://www.twilio.com/docs/voice/twiml/stream#websocket-messages-to-twilio - """ - - if self.transcriber.get_transcriber_config().mute_during_speech: - logger.debug("Muting transcriber") - self.transcriber.mute() - message_sent = message - cut_off = False - chunk_idx = 0 - seconds_spoken = 0.0 - logger.debug(f"Start sending speech {message} to output") - - utterance_id = create_utterance_id() - self.mark_message_queue.create_utterance_queue(utterance_id) - - first_chunk_span = self._maybe_create_first_chunk_span(synthesis_result, message) - - clear_message_lock = asyncio.Lock() - - asyncio_create_task_with_done_error_log( - self._send_chunks( - utterance_id, - synthesis_result.chunk_generator, - clear_message_lock, - stop_event, - ), - ) - mark_event: MarkMessage - first = True - while True: - mark_event = await self.mark_message_queue.get(utterance_id) - if isinstance(mark_event, UtteranceFinishedMarkMessage): - break - if first and first_chunk_span: - self._track_first_chunk(first_chunk_span, synthesis_result) - first = False - seconds_spoken = mark_event.chunk_idx * seconds_per_chunk - # Lock here so that we check the stop event and send the clear message atomically - # w.r.t. 
the _send_chunks task which also checks the stop event - # Otherwise, we could send the clear message while _send_chunks is in the middle of sending a chunk - # and the synthesis wouldn't be cleared - async with clear_message_lock: - if stop_event.is_set(): - self.output_device.send_clear_message() - logger.debug( - "Interrupted, stopping text to speech after {} chunks".format(chunk_idx) - ) - message_sent = synthesis_result.get_message_up_to(seconds_spoken) - cut_off = True - break - if chunk_idx == 0: - if started_event: - started_event.set() - self.mark_last_action_timestamp() - chunk_idx += 1 - seconds_spoken += seconds_per_chunk - if transcript_message: - transcript_message.text = synthesis_result.get_message_up_to(seconds_spoken) - self.mark_message_queue.delete_utterance_queue(utterance_id) - if self.transcriber.get_transcriber_config().mute_during_speech: - logger.debug("Unmuting transcriber") - self.transcriber.unmute() - if transcript_message: - # For input streaming synthesizers, we have to buffer the message as it is streamed in - # What is said is federated fully by synthesis_result.get_message_up_to - if isinstance(self.synthesizer, InputStreamingSynthesizer): - message_sent = transcript_message.text - else: - transcript_message.text = message_sent - transcript_message.is_final = not cut_off - if synthesis_result.synthesis_total_span: - synthesis_result.synthesis_total_span.finish() - return message_sent, cut_off diff --git a/vocode/streaming/telephony/conversation/vonage_phone_conversation.py b/vocode/streaming/telephony/conversation/vonage_phone_conversation.py index 65adc73f4a..6460bee619 100644 --- a/vocode/streaming/telephony/conversation/vonage_phone_conversation.py +++ b/vocode/streaming/telephony/conversation/vonage_phone_conversation.py @@ -48,7 +48,6 @@ def __init__( events_manager: Optional[EventsManager] = None, output_to_speaker: bool = False, speed_coefficient: float = 1.0, - per_chunk_allowance_seconds: float = 0.01, noise_suppression: bool = False, ): self.speed_coefficient = speed_coefficient @@ -68,7 +67,6 @@ def __init__( transcriber_factory=transcriber_factory, agent_factory=agent_factory, synthesizer_factory=synthesizer_factory, - per_chunk_allowance_seconds=per_chunk_allowance_seconds, ) self.vonage_config = vonage_config self.telephony_client = VonageClient( diff --git a/vocode/streaming/transcriber/base_transcriber.py b/vocode/streaming/transcriber/base_transcriber.py index 4745b79a65..faec6314fd 100644 --- a/vocode/streaming/transcriber/base_transcriber.py +++ b/vocode/streaming/transcriber/base_transcriber.py @@ -8,7 +8,7 @@ from vocode.streaming.models.audio import AudioEncoding from vocode.streaming.models.transcriber import TranscriberConfig, Transcription from vocode.streaming.utils.speed_manager import SpeedManager -from vocode.streaming.utils.worker import AsyncWorker, ThreadAsyncWorker +from vocode.streaming.utils.worker import AbstractAsyncWorker, ThreadAsyncWorker TranscriberConfigType = TypeVar("TranscriberConfigType", bound=TranscriberConfig) @@ -58,16 +58,16 @@ def terminate(self): pass -class BaseAsyncTranscriber(AbstractTranscriber[TranscriberConfigType], AsyncWorker): +class BaseAsyncTranscriber(AbstractTranscriber[TranscriberConfigType], AbstractAsyncWorker): def __init__(self, transcriber_config: TranscriberConfigType): AbstractTranscriber.__init__(self, transcriber_config) - AsyncWorker.__init__(self, self.input_queue, self.output_queue) + AbstractAsyncWorker.__init__(self, self.input_queue, self.output_queue) async def 
_run_loop(self): raise NotImplementedError def terminate(self): - AsyncWorker.terminate(self) + AbstractAsyncWorker.terminate(self) class BaseThreadAsyncTranscriber(AbstractTranscriber[TranscriberConfigType], ThreadAsyncWorker): diff --git a/vocode/streaming/utils/__init__.py b/vocode/streaming/utils/__init__.py index 6f2585e090..31dd24998c 100644 --- a/vocode/streaming/utils/__init__.py +++ b/vocode/streaming/utils/__init__.py @@ -135,3 +135,12 @@ async def generate_from_async_iter_with_lookahead( if buffer and stream_length <= lookahead: yield buffer return + + +async def enumerate_async_iter( + async_iter: AsyncIterator[AsyncIteratorGenericType], +) -> AsyncGenerator[Tuple[int, AsyncIteratorGenericType], None]: + i = 0 + async for item in async_iter: + yield i, item + i += 1 diff --git a/vocode/streaming/utils/worker.py b/vocode/streaming/utils/worker.py index 80021b580f..da98658e7c 100644 --- a/vocode/streaming/utils/worker.py +++ b/vocode/streaming/utils/worker.py @@ -1,5 +1,6 @@ from __future__ import annotations +from abc import ABC, abstractmethod import asyncio import threading from typing import Any, Generic, Optional, TypeVar @@ -12,10 +13,24 @@ WorkerInputType = TypeVar("WorkerInputType") -class AsyncWorker(Generic[WorkerInputType]): +class AbstractWorker(Generic[WorkerInputType], ABC): + @abstractmethod + def consume_nonblocking(self, item: WorkerInputType): + pass + + @abstractmethod + def produce_nonblocking(self, item): + pass + + @abstractmethod + async def _run_loop(self): + pass + + +class AbstractAsyncWorker(AbstractWorker[WorkerInputType]): def __init__( self, - input_queue: asyncio.Queue, + input_queue: asyncio.Queue[WorkerInputType], output_queue: asyncio.Queue = asyncio.Queue(), ) -> None: self.worker_task: Optional[asyncio.Task] = None @@ -36,9 +51,6 @@ def consume_nonblocking(self, item: WorkerInputType): def produce_nonblocking(self, item): self.output_queue.put_nowait(item) - async def _run_loop(self): - raise NotImplementedError - def terminate(self): if self.worker_task: return self.worker_task.cancel() @@ -46,7 +58,7 @@ def terminate(self): return False -class ThreadAsyncWorker(AsyncWorker[WorkerInputType]): +class ThreadAsyncWorker(AbstractAsyncWorker[WorkerInputType]): def __init__( self, input_queue: asyncio.Queue[WorkerInputType], @@ -93,7 +105,7 @@ def terminate(self): return super().terminate() -class AsyncQueueWorker(AsyncWorker): +class AsyncQueueWorker(AbstractAsyncWorker): async def _run_loop(self): while True: try: @@ -177,7 +189,7 @@ def create_interruptible_agent_response_event( InterruptibleEventType = TypeVar("InterruptibleEventType", bound=InterruptibleEvent) -class InterruptibleWorker(AsyncWorker[InterruptibleEventType]): +class InterruptibleWorker(AbstractAsyncWorker[InterruptibleEventType]): def __init__( self, input_queue: asyncio.Queue[InterruptibleEventType], diff --git a/vocode/turn_based/output_device/base_output_device.py b/vocode/turn_based/output_device/abstract_output_device.py similarity index 55% rename from vocode/turn_based/output_device/base_output_device.py rename to vocode/turn_based/output_device/abstract_output_device.py index d54c0c7fd6..d111dd67a2 100644 --- a/vocode/turn_based/output_device/base_output_device.py +++ b/vocode/turn_based/output_device/abstract_output_device.py @@ -1,9 +1,12 @@ +from abc import ABC, abstractmethod from pydub import AudioSegment -class BaseOutputDevice: +class AbstractOutputDevice(ABC): + + @abstractmethod def send_audio(self, audio: AudioSegment) -> None: - raise NotImplementedError + pass 
def terminate(self): pass diff --git a/vocode/turn_based/output_device/speaker_output.py b/vocode/turn_based/output_device/speaker_output.py index a0b35dc4da..a3f748f9ed 100644 --- a/vocode/turn_based/output_device/speaker_output.py +++ b/vocode/turn_based/output_device/speaker_output.py @@ -4,10 +4,10 @@ import sounddevice as sd from pydub import AudioSegment -from vocode.turn_based.output_device.base_output_device import BaseOutputDevice +from vocode.turn_based.output_device.abstract_output_device import AbstractOutputDevice -class SpeakerOutput(BaseOutputDevice): +class SpeakerOutput(AbstractOutputDevice): DEFAULT_SAMPLING_RATE = 44100 def __init__( diff --git a/vocode/turn_based/turn_based_conversation.py b/vocode/turn_based/turn_based_conversation.py index faddb613b0..0eb9e506a8 100644 --- a/vocode/turn_based/turn_based_conversation.py +++ b/vocode/turn_based/turn_based_conversation.py @@ -2,7 +2,7 @@ from vocode.turn_based.agent.base_agent import BaseAgent from vocode.turn_based.input_device.base_input_device import BaseInputDevice -from vocode.turn_based.output_device.base_output_device import BaseOutputDevice +from vocode.turn_based.output_device.abstract_output_device import AbstractOutputDevice from vocode.turn_based.synthesizer.base_synthesizer import BaseSynthesizer from vocode.turn_based.transcriber.base_transcriber import BaseTranscriber @@ -14,7 +14,7 @@ def __init__( transcriber: BaseTranscriber, agent: BaseAgent, synthesizer: BaseSynthesizer, - output_device: BaseOutputDevice, + output_device: AbstractOutputDevice, ): self.input_device = input_device self.transcriber = transcriber
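Illustration (not part of the patch): a minimal sketch of how the new pieces compose. EchoOutputDevice is a hypothetical class invented for this example, and it assumes, as the worker classes above suggest, that AbstractAsyncWorker.start() schedules _run_loop() on the running event loop. A subclass supplies only play(); rate-limited pacing, chunk state transitions, and the on_play callback all come from RateLimitInterruptionsOutputDevice._run_loop.

import asyncio
import threading

from vocode.streaming.models.audio import AudioEncoding
from vocode.streaming.output_device.audio_chunk import AudioChunk, ChunkState
from vocode.streaming.output_device.rate_limit_interruptions_output_device import (
    RateLimitInterruptionsOutputDevice,
)
from vocode.streaming.utils.worker import InterruptibleEvent


class EchoOutputDevice(RateLimitInterruptionsOutputDevice):
    # Hypothetical device for illustration: play() is the only method a
    # subclass must provide; pacing and interruption bookkeeping are inherited.
    async def play(self, chunk: bytes):
        print(f"played {len(chunk)} bytes")


async def demo_play():
    output_device = EchoOutputDevice(sampling_rate=8000, audio_encoding=AudioEncoding.LINEAR16)
    output_device.start()  # assumed to schedule _run_loop() as an asyncio task

    played = asyncio.Event()
    audio_chunk = AudioChunk(data=b"\x00" * 16000)  # ~1s of 16-bit PCM silence at 8kHz
    audio_chunk.on_play = played.set  # instance attribute shadows the static placeholder
    output_device.consume_nonblocking(
        InterruptibleEvent(
            payload=audio_chunk,
            is_interruptible=True,
            interruption_event=threading.Event(),
        )
    )

    await played.wait()
    assert audio_chunk.state == ChunkState.PLAYED
    output_device.terminate()


asyncio.run(demo_play())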
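The interruption path, sketched under the same assumptions: broadcast_interrupt() (see the streaming_conversation.py hunk above) ends up setting the shared threading.Event carried by each InterruptibleEvent and then calls output_device.interrupt(). Any chunk whose event is set by the time it is dequeued drains as INTERRUPTED, firing on_interrupt instead of play(); send_speech_to_output derives cut_off from exactly this state.

async def demo_interrupt():
    # Reuses the hypothetical EchoOutputDevice from the sketch above.
    output_device = EchoOutputDevice(sampling_rate=8000, audio_encoding=AudioEncoding.LINEAR16)
    output_device.start()

    stop_event = threading.Event()
    interrupted = asyncio.Event()
    audio_chunk = AudioChunk(data=b"\x00" * 16000)
    audio_chunk.on_interrupt = interrupted.set

    # What broadcast_interrupt() achieves, in miniature: flip the shared event,
    # then let the device clear anything in flight (a no-op for rate-limited
    # devices, a "clear" message for TwilioOutputDevice). Setting the event
    # before the chunk is dequeued guarantees it drains as INTERRUPTED.
    stop_event.set()
    output_device.interrupt()

    output_device.consume_nonblocking(
        InterruptibleEvent(
            payload=audio_chunk,
            is_interruptible=True,
            interruption_event=stop_event,
        )
    )

    await interrupted.wait()
    assert audio_chunk.state == ChunkState.INTERRUPTED
    output_device.terminate()


asyncio.run(demo_interrupt())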