From f4fefeb64956b1e3c42d791dab22b3ba612d8f3f Mon Sep 17 00:00:00 2001
From: Nitzan Raz <nitz.raz@gmail.com>
Date: Tue, 10 Dec 2024 12:38:33 +0200
Subject: [PATCH] (feat): Support for text iterators in `AsyncElevenLabs`

Why:
Allowing the async client to utilize incoming text streams when generating voice.
This is very useful when feeding the realtime output of an LLM into the TTS.

Closes #344

What:
1. Copied `RealtimeTextToSpeechClient` and `text_chunker` into `AsyncRealtimeTextToSpeechClient` and `async_text_chunker`
   Most of the logic is intact, aside from async stuff
2. Added `AsyncRealtimeTextToSpeechClient` into `AsyncElevenLabs` just like `RealtimeTextToSpeechClient` is in `ElevenLabs`
3. Added rudimentary testing

The code is basically a copy-paste of what I found in the repo. We can rewrite it to be more elegant, but I figured parity with the sync code is more important.
---
 src/elevenlabs/client.py       |  53 +++++++++++---
 src/elevenlabs/realtime_tts.py | 122 ++++++++++++++++++++++++++++++++-
 tests/test_async_generation.py |  17 +++++
 3 files changed, 180 insertions(+), 12 deletions(-)

diff --git a/src/elevenlabs/client.py b/src/elevenlabs/client.py
index 69c28a4d..46699905 100644
--- a/src/elevenlabs/client.py
+++ b/src/elevenlabs/client.py
@@ -13,7 +13,7 @@
 from .types import Voice, VoiceSettings, \
   PronunciationDictionaryVersionLocator, Model
 from .environment import ElevenLabsEnvironment
-from .realtime_tts import RealtimeTextToSpeechClient
+from .realtime_tts import RealtimeTextToSpeechClient, AsyncRealtimeTextToSpeechClient
 from .types import OutputFormat
 from .text_to_speech.types.text_to_speech_stream_with_timestamps_response import TextToSpeechStreamWithTimestampsResponse
 
@@ -342,6 +342,30 @@ class AsyncElevenLabs(AsyncBaseElevenLabs):
         api_key="YOUR_API_KEY",
     )
     """
+    def __init__(
+        self,
+        *,
+        base_url: typing.Optional[str] = None,
+        environment: ElevenLabsEnvironment = ElevenLabsEnvironment.PRODUCTION,
+        api_key: typing.Optional[str] = os.getenv("ELEVEN_API_KEY"),
+        timeout: typing.Optional[float] = None,
+        follow_redirects: typing.Optional[bool] = True,
+        httpx_client: typing.Optional[httpx.AsyncClient] = None
+    ):
+        """
+        Build the generated async client, then replace ``text_to_speech`` with
+        the realtime-capable subclass so async text iterators can be streamed.
+        """
+        super().__init__(
+            base_url=base_url,
+            environment=environment,
+            api_key=api_key,
+            timeout=timeout,
+            follow_redirects=follow_redirects,
+            httpx_client=httpx_client,
+        )
+        # Mirrors ElevenLabs, which installs RealtimeTextToSpeechClient the same way.
+        self.text_to_speech = AsyncRealtimeTextToSpeechClient(client_wrapper=self._client_wrapper)
 
     async def clone(
       self,
@@ -468,16 +487,28 @@ async def generate(
             model_id = model.model_id
         
         if stream:
-            return self.text_to_speech.convert_as_stream(
-                voice_id=voice_id,
-                model_id=model_id,
-                voice_settings=voice_settings,
-                optimize_streaming_latency=optimize_streaming_latency,
-                output_format=output_format,
-                text=text,
-                request_options=request_options,
-                pronunciation_dictionary_locators=pronunciation_dictionary_locators
-            )
+            if isinstance(text, str):
+                return self.text_to_speech.convert_as_stream(
+                    voice_id=voice_id,
+                    model_id=model_id,
+                    voice_settings=voice_settings,
+                    optimize_streaming_latency=optimize_streaming_latency,
+                    output_format=output_format,
+                    text=text,
+                    request_options=request_options,
+                    pronunciation_dictionary_locators=pronunciation_dictionary_locators
+                )
+            elif isinstance(text, typing.AsyncIterator):
+                return self.text_to_speech.convert_realtime(  # type: ignore
+                    voice_id=voice_id,
+                    voice_settings=voice_settings,
+                    output_format=output_format,
+                    text=text,
+                    request_options=request_options,
+                    model_id=model_id
+                )
+            else:
+                raise ApiError(body="Text is neither a string nor an async iterator.")
         else:
             if not isinstance(text, str):
                 raise ApiError(body="Text must be a string when stream is False.")
diff --git a/src/elevenlabs/realtime_tts.py b/src/elevenlabs/realtime_tts.py
index f8680de1..920c0ac4 100644
--- a/src/elevenlabs/realtime_tts.py
+++ b/src/elevenlabs/realtime_tts.py
@@ -5,15 +5,17 @@
 import json
 import base64
 import websockets
+import asyncio
 
 from websockets.sync.client import connect
+from websockets.client import connect as async_connect
 
 from .core.api_error import ApiError
 from .core.jsonable_encoder import jsonable_encoder
 from .core.remove_none_from_dict import remove_none_from_dict
 from .core.request_options import RequestOptions
 from .types.voice_settings import VoiceSettings
-from .text_to_speech.client import TextToSpeechClient
+from .text_to_speech.client import TextToSpeechClient, AsyncTextToSpeechClient
 from .types import OutputFormat
 from .text_to_speech.types.text_to_speech_stream_with_timestamps_response import TextToSpeechStreamWithTimestampsResponse
 
@@ -38,6 +40,22 @@ def text_chunker(chunks: typing.Iterator[str]) -> typing.Iterator[str]:
     if buffer != "":
         yield buffer + " "
 
+async def async_text_chunker(chunks: typing.AsyncIterator[str]) -> typing.AsyncIterator[str]:
+    """Used during input streaming to chunk text blocks and set last char to space"""
+    splitters = (".", ",", "?", "!", ";", ":", "—", "-", "(", ")", "[", "]", "}", " ")
+    buffer = ""
+    async for text in chunks:
+        if buffer.endswith(splitters):
+            yield buffer if buffer.endswith(" ") else buffer + " "
+            buffer = text
+        elif text.startswith(splitters):
+            output = buffer + text[0]
+            yield output if output.endswith(" ") else output + " "
+            buffer = text[1:]
+        else:
+            buffer += text
+    if buffer != "":
+        yield buffer + " "
 
 class RealtimeTextToSpeechClient(TextToSpeechClient):
 
@@ -139,6 +157,7 @@ def get_text() -> typing.Iterator[str]:
                 elif ce.code != 1000:
                     raise ApiError(body=ce.reason, status_code=ce.code)
 
+
     def convert_realtime_full(
         self,
         voice_id: str,
@@ -235,3 +254,108 @@ def get_text() -> typing.Iterator[str]:
                 elif ce.code != 1000:
                     raise ApiError(body=ce.reason, status_code=ce.code)
 
+
+class AsyncRealtimeTextToSpeechClient(AsyncTextToSpeechClient):
+
+    async def convert_realtime(
+        self,
+        voice_id: str,
+        *,
+        text: typing.AsyncIterator[str],
+        model_id: typing.Optional[str] = OMIT,
+        output_format: typing.Optional[OutputFormat] = "mp3_44100_128",
+        voice_settings: typing.Optional[VoiceSettings] = OMIT,
+        request_options: typing.Optional[RequestOptions] = None,
+    ) -> typing.AsyncIterator[bytes]:
+        """
+        Converts a stream of text into speech using a voice of your choice and yields audio as it is generated.
+
+        Parameters:
+            - voice_id: str. Voice ID to be used, you can use https://api.elevenlabs.io/v1/voices to list all the available voices.
+
+            - text: typing.AsyncIterator[str]. The text that will get converted into speech.
+
+            - model_id: typing.Optional[str]. Identifier of the model that will be used, you can query them using GET /v1/models. The model needs to have support for text to speech, you can check this using the can_do_text_to_speech property.
+
+            - voice_settings: typing.Optional[VoiceSettings]. Voice settings overriding stored settings for the given voice. They are applied only on the given request.
+
+            - request_options: typing.Optional[RequestOptions]. Request-specific configuration.
+        ---
+        from elevenlabs import VoiceSettings
+        from elevenlabs.client import AsyncElevenLabs
+
+        async def get_text() -> typing.AsyncIterator[str]:
+            yield "Hello, how are you?"
+            yield "I am fine, thank you."
+
+        client = AsyncElevenLabs(
+            api_key="YOUR_API_KEY",
+        )
+        async for audio in client.text_to_speech.convert_realtime(
+            voice_id="string",
+            text=get_text(),
+            model_id="string",
+            voice_settings=VoiceSettings(
+                stability=1.1,
+                similarity_boost=1.1,
+                style=1.1,
+                use_speaker_boost=True,
+            ),
+        ):
+            ...
+        """
+        async with async_connect(
+            urllib.parse.urljoin(
+              "wss://api.elevenlabs.io/",
+              f"v1/text-to-speech/{jsonable_encoder(voice_id)}/stream-input?model_id={model_id}&output_format={output_format}"
+            ),
+            extra_headers=jsonable_encoder(
+                remove_none_from_dict(
+                    {
+                        **self._client_wrapper.get_headers(),
+                        **(request_options.get("additional_headers", {}) if request_options is not None else {}),
+                    }
+                )
+            )
+        ) as socket:
+            # Opening message: a single space plus voice settings/config (mirrors the sync client).
+            try:
+                await socket.send(json.dumps(
+                    dict(
+                        text=" ",
+                        try_trigger_generation=True,
+                        voice_settings=voice_settings.dict() if voice_settings else None,
+                        generation_config=dict(
+                            chunk_length_schedule=[50],
+                        ),
+                    )
+                ))
+            except websockets.exceptions.ConnectionClosedError as ce:
+                raise ApiError(body=ce.reason, status_code=ce.code)
+
+            # Last payload seen from the server; consulted in the except handler below.
+            data: typing.Dict[str, typing.Any] = {}
+            try:
+                async for text_chunk in async_text_chunker(text):
+                    await socket.send(json.dumps(dict(text=text_chunk, try_trigger_generation=True)))
+                    try:
+                        # Opportunistically drain audio that is already available without blocking input.
+                        data = json.loads(
+                            await asyncio.wait_for(socket.recv(), timeout=1e-4))
+                        if "audio" in data and data["audio"]:
+                            yield base64.b64decode(data["audio"])  # type: ignore
+                    except asyncio.TimeoutError:
+                        pass
+
+                # An empty text message tells the server the input stream is finished.
+                await socket.send(json.dumps(dict(text="")))
+
+                while True:
+                    data = json.loads(await socket.recv())
+                    if "audio" in data and data["audio"]:
+                        yield base64.b64decode(data["audio"])  # type: ignore
+            except websockets.exceptions.ConnectionClosed as ce:
+                if "message" in data:
+                    raise ApiError(body=data, status_code=ce.code)
+                elif ce.code != 1000:
+                    raise ApiError(body=ce.reason, status_code=ce.code)
diff --git a/tests/test_async_generation.py b/tests/test_async_generation.py
index 1ed11ab6..92020336 100644
--- a/tests/test_async_generation.py
+++ b/tests/test_async_generation.py
@@ -29,3 +29,24 @@ async def main():
         if not IN_GITHUB:
             play(out)
     asyncio.run(main())
+
+def test_generate_stream() -> None:
+    async def main():
+        async def text_stream():
+            yield "Hi there, I'm Eleven "
+            yield "I'm a text to speech API "
+
+        audio_stream = await async_client.generate(
+            text=text_stream(),
+            voice="Nicole",
+            model="eleven_monolingual_v1",
+            stream=True
+        )
+
+        # Consume the async stream so the websocket path is exercised even on CI.
+        chunks = [chunk async for chunk in audio_stream]
+        assert chunks and all(isinstance(chunk, bytes) for chunk in chunks)
+
+        if not IN_GITHUB:
+            play(b"".join(chunks))
+    asyncio.run(main())