def __init__(
    self,
    *,
    base_url: typing.Optional[str] = None,
    environment: ElevenLabsEnvironment = ElevenLabsEnvironment.PRODUCTION,
    api_key: typing.Optional[str] = os.getenv("ELEVEN_API_KEY"),
    timeout: typing.Optional[float] = None,
    follow_redirects: typing.Optional[bool] = True,
    httpx_client: typing.Optional[httpx.AsyncClient] = None
):
    """
    Build the async client and install the realtime text-to-speech client.

    Every keyword argument is handed to the base class constructor
    unchanged. Afterwards `text_to_speech` is replaced by an
    `AsyncRealtimeTextToSpeechClient` built on the same client wrapper.

    Note: the `api_key` default is computed by `os.getenv` when this
    function is defined, not on each call.
    """
    # Gather the pass-through options once so the base call stays compact.
    base_client_options = dict(
        base_url=base_url,
        environment=environment,
        api_key=api_key,
        timeout=timeout,
        follow_redirects=follow_redirects,
        httpx_client=httpx_client,
    )
    super().__init__(**base_client_options)

    # Swap in the websocket-capable client, reusing the wrapper the base
    # constructor has just created.
    realtime_client = AsyncRealtimeTextToSpeechClient(client_wrapper=self._client_wrapper)
    self.text_to_speech = realtime_client
async def async_text_chunker(chunks: typing.AsyncIterator[str]) -> typing.AsyncIterator[str]:
    """
    Regroup an async stream of text pieces into space-terminated chunks.

    Pieces are accumulated until a splitter character is seen, either at
    the end of the accumulated text or at the start of the incoming piece.
    The accumulated text is then emitted with a trailing space, and any
    text left over after the stream ends is flushed the same way.
    """
    boundary_chars = (".", ",", "?", "!", ";", ":", "—", "-", "(", ")", "[", "]", "}", " ")
    pending = ""
    async for piece in chunks:
        if pending.endswith(boundary_chars):
            # The accumulated text already ends on a boundary: emit it and
            # start accumulating afresh from the new piece.
            ready, pending = pending, piece
        elif piece.startswith(boundary_chars):
            # The new piece opens with a boundary character: attach that one
            # character to the accumulated text and keep the remainder.
            ready, pending = pending + piece[:1], piece[1:]
        else:
            # No boundary yet, keep accumulating.
            pending += piece
            continue

        if not ready.endswith(" "):
            ready += " "
        yield ready

    # Flush the tail; a space is always appended here.
    if pending:
        yield pending + " "
class AsyncRealtimeTextToSpeechClient(AsyncTextToSpeechClient):
    """Async text-to-speech client that can stream input text over a websocket."""

    async def convert_realtime(
        self,
        voice_id: str,
        *,
        text: typing.AsyncIterator[str],
        model_id: typing.Optional[str] = OMIT,
        output_format: typing.Optional[OutputFormat] = "mp3_44100_128",
        voice_settings: typing.Optional[VoiceSettings] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> typing.AsyncIterator[bytes]:
        """
        Converts a stream of text into speech using a voice of your choice and yields audio chunks.

        Parameters:
            - voice_id: str. Voice ID to be used, you can use https://api.elevenlabs.io/v1/voices to list all the available voices.

            - text: typing.AsyncIterator[str]. The text that will get converted into speech.

            - model_id: typing.Optional[str]. Identifier of the model that will be used, you can query them using GET /v1/models. The model needs to have support for text to speech, you can check this using the can_do_text_to_speech property. Left out of the request when not provided.

            - output_format: typing.Optional[OutputFormat]. Output format of the generated audio. Left out of the request when None.

            - voice_settings: typing.Optional[VoiceSettings]. Voice settings overriding stored settings for the given voice. They are applied only on the given request.

            - request_options: typing.Optional[RequestOptions]. Request-specific configuration.

        Raises:
            ApiError: if the websocket is closed abnormally, or closed after the
            server sent a payload containing a "message" field.
        ---
        import asyncio
        import typing

        from elevenlabs import VoiceSettings
        from elevenlabs.client import AsyncElevenLabs

        async def get_text() -> typing.AsyncIterator[str]:
            yield "Hello, how are you?"
            yield "I am fine, thank you."

        client = AsyncElevenLabs(
            api_key="YOUR_API_KEY",
        )

        async def main() -> None:
            async for audio_chunk in client.text_to_speech.convert_realtime(
                voice_id="string",
                text=get_text(),
                model_id="string",
                voice_settings=VoiceSettings(
                    stability=1.1,
                    similarity_boost=1.1,
                    style=1.1,
                    use_speaker_boost=True,
                ),
            ):
                ...

        asyncio.run(main())
        """
        # How long to wait for audio after each text chunk before sending more
        # text. Kept tiny so reading never stalls the outgoing stream.
        poll_timeout_seconds = 1e-4

        # Only forward query parameters the caller actually supplied. The OMIT
        # sentinel and None must not be rendered into the URL as text.
        query = {
            name: value
            for name, value in (("model_id", model_id), ("output_format", output_format))
            if value is not None and value is not OMIT
        }
        url = urllib.parse.urljoin(
            "wss://api.elevenlabs.io/",
            f"v1/text-to-speech/{jsonable_encoder(voice_id)}/stream-input",
        )
        if query:
            url = f"{url}?{urllib.parse.urlencode(query)}"

        headers = jsonable_encoder(
            remove_none_from_dict(
                {
                    **self._client_wrapper.get_headers(),
                    **(request_options.get("additional_headers", {}) if request_options is not None else {}),
                }
            )
        )

        # OMIT is truthy, so test for it explicitly before calling .dict().
        if voice_settings is None or voice_settings is OMIT:
            settings_payload = None
        else:
            settings_payload = voice_settings.dict()

        async with async_connect(url, extra_headers=headers) as socket:
            try:
                # Opening message: a single space plus the generation settings.
                await socket.send(json.dumps(
                    dict(
                        text=" ",
                        try_trigger_generation=True,
                        voice_settings=settings_payload,
                        generation_config=dict(
                            chunk_length_schedule=[50],
                        ),
                    )
                ))
            except websockets.exceptions.ConnectionClosedError as ce:
                raise ApiError(body=ce.reason, status_code=ce.code) from ce

            # Last payload received from the server. Kept separate from the
            # outgoing payloads so the close handler below always inspects a
            # server message and never hits an unbound or unrelated variable.
            last_message: typing.Dict[str, typing.Any] = {}
            try:
                async for text_chunk in async_text_chunker(text):
                    await socket.send(json.dumps(dict(text=text_chunk, try_trigger_generation=True)))
                    try:
                        # asyncio.wait_for works on every supported Python
                        # version, while asyncio.timeout() needs 3.11+.
                        # NOTE(review): relies on recv() being safe to cancel,
                        # as the websockets documentation states.
                        raw = await asyncio.wait_for(socket.recv(), timeout=poll_timeout_seconds)
                    except asyncio.TimeoutError:
                        # No audio ready yet, carry on feeding text.
                        continue
                    last_message = json.loads(raw)
                    if "audio" in last_message and last_message["audio"]:
                        yield base64.b64decode(last_message["audio"])  # type: ignore

                # An empty text message tells the server the input is finished.
                await socket.send(json.dumps(dict(text="")))

                # Drain the remaining audio. The loop ends when the server
                # closes the connection and recv() raises ConnectionClosed.
                while True:
                    last_message = json.loads(await socket.recv())
                    if "audio" in last_message and last_message["audio"]:
                        yield base64.b64decode(last_message["audio"])  # type: ignore
            except websockets.exceptions.ConnectionClosed as ce:
                if "message" in last_message:
                    raise ApiError(body=last_message, status_code=ce.code) from ce
                elif ce.code != 1000:
                    # Close code 1000 is a normal closure and is not an error.
                    raise ApiError(body=ce.reason, status_code=ce.code) from ce