Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(feat): Support for text iterators in AsyncElevenLabs #358

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 42 additions & 11 deletions src/elevenlabs/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .types import Voice, VoiceSettings, \
PronunciationDictionaryVersionLocator, Model
from .environment import ElevenLabsEnvironment
from .realtime_tts import RealtimeTextToSpeechClient
from .realtime_tts import RealtimeTextToSpeechClient, AsyncRealtimeTextToSpeechClient
from .types import OutputFormat


Expand Down Expand Up @@ -257,6 +257,25 @@ class AsyncElevenLabs(AsyncBaseElevenLabs):
api_key="YOUR_API_KEY",
)
"""
def __init__(
self,
*,
base_url: typing.Optional[str] = None,
environment: ElevenLabsEnvironment = ElevenLabsEnvironment.PRODUCTION,
api_key: typing.Optional[str] = os.getenv("ELEVEN_API_KEY"),
timeout: typing.Optional[float] = None,
follow_redirects: typing.Optional[bool] = True,
httpx_client: typing.Optional[httpx.AsyncClient] = None
):
super().__init__(
base_url=base_url,
environment=environment,
api_key=api_key,
timeout=timeout,
follow_redirects=follow_redirects,
httpx_client=httpx_client,
)
self.text_to_speech = AsyncRealtimeTextToSpeechClient(client_wrapper=self._client_wrapper)

async def clone(
self,
Expand Down Expand Up @@ -383,16 +402,28 @@ async def generate(
model_id = model.model_id

if stream:
return self.text_to_speech.convert_as_stream(
voice_id=voice_id,
model_id=model_id,
voice_settings=voice_settings,
optimize_streaming_latency=optimize_streaming_latency,
output_format=output_format,
text=text,
request_options=request_options,
pronunciation_dictionary_locators=pronunciation_dictionary_locators
)
if isinstance(text, str):
return self.text_to_speech.convert_as_stream(
voice_id=voice_id,
model_id=model_id,
voice_settings=voice_settings,
optimize_streaming_latency=optimize_streaming_latency,
output_format=output_format,
text=text,
request_options=request_options,
pronunciation_dictionary_locators=pronunciation_dictionary_locators
)
elif isinstance(text, AsyncIterator):
return self.text_to_speech.convert_realtime( # type: ignore
voice_id=voice_id,
voice_settings=voice_settings,
output_format=output_format,
text=text,
request_options=request_options,
model_id=model_id
)
else:
raise ApiError(body="Text is neither a string nor an iterator.")
else:
if not isinstance(text, str):
raise ApiError(body="Text must be a string when stream is False.")
Expand Down
122 changes: 121 additions & 1 deletion src/elevenlabs/realtime_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@
import json
import base64
import websockets
import asyncio

from websockets.sync.client import connect
from websockets.client import connect as async_connect

from .core.api_error import ApiError
from .core.jsonable_encoder import jsonable_encoder
from .core.remove_none_from_dict import remove_none_from_dict
from .core.request_options import RequestOptions
from .types.voice_settings import VoiceSettings
from .text_to_speech.client import TextToSpeechClient
from .text_to_speech.client import TextToSpeechClient, AsyncTextToSpeechClient
from .types import OutputFormat

# this is used as the default value for optional parameters
Expand All @@ -37,6 +39,22 @@ def text_chunker(chunks: typing.Iterator[str]) -> typing.Iterator[str]:
if buffer != "":
yield buffer + " "

async def async_text_chunker(chunks: typing.AsyncIterator[str]) -> typing.AsyncIterator[str]:
"""Used during input streaming to chunk text blocks and set last char to space"""
splitters = (".", ",", "?", "!", ";", ":", "—", "-", "(", ")", "[", "]", "}", " ")
buffer = ""
async for text in chunks:
if buffer.endswith(splitters):
yield buffer if buffer.endswith(" ") else buffer + " "
buffer = text
elif text.startswith(splitters):
output = buffer + text[0]
yield output if output.endswith(" ") else output + " "
buffer = text[1:]
else:
buffer += text
if buffer != "":
yield buffer + " "

class RealtimeTextToSpeechClient(TextToSpeechClient):

Expand Down Expand Up @@ -137,3 +155,105 @@ def get_text() -> typing.Iterator[str]:
raise ApiError(body=data, status_code=ce.code)
elif ce.code != 1000:
raise ApiError(body=ce.reason, status_code=ce.code)


class AsyncRealtimeTextToSpeechClient(AsyncTextToSpeechClient):

async def convert_realtime(
self,
voice_id: str,
*,
text: typing.AsyncIterator[str],
model_id: typing.Optional[str] = OMIT,
output_format: typing.Optional[OutputFormat] = "mp3_44100_128",
voice_settings: typing.Optional[VoiceSettings] = OMIT,
request_options: typing.Optional[RequestOptions] = None,
) -> typing.AsyncIterator[bytes]:
"""
Converts text into speech using a voice of your choice and returns audio.

Parameters:
- voice_id: str. Voice ID to be used, you can use https://api.elevenlabs.io/v1/voices to list all the available voices.

- text: typing.Iterator[str]. The text that will get converted into speech.

- model_id: typing.Optional[str]. Identifier of the model that will be used, you can query them using GET /v1/models. The model needs to have support for text to speech, you can check this using the can_do_text_to_speech property.

- voice_settings: typing.Optional[VoiceSettings]. Voice settings overriding stored setttings for the given voice. They are applied only on the given request.

- request_options: typing.Optional[RequestOptions]. Request-specific configuration.
---
from elevenlabs import PronunciationDictionaryVersionLocator, VoiceSettings
from elevenlabs.client import ElevenLabs

def get_text() -> typing.Iterator[str]:
yield "Hello, how are you?"
yield "I am fine, thank you."

client = ElevenLabs(
api_key="YOUR_API_KEY",
)
client.text_to_speech.convert_realtime(
voice_id="string",
text=get_text(),
model_id="string",
voice_settings=VoiceSettings(
stability=1.1,
similarity_boost=1.1,
style=1.1,
use_speaker_boost=True,
),
)
"""
async with async_connect(
urllib.parse.urljoin(
"wss://api.elevenlabs.io/",
f"v1/text-to-speech/{jsonable_encoder(voice_id)}/stream-input?model_id={model_id}&output_format={output_format}"
),
extra_headers=jsonable_encoder(
remove_none_from_dict(
{
**self._client_wrapper.get_headers(),
**(request_options.get("additional_headers", {}) if request_options is not None else {}),
}
)
)
) as socket:
try:
await socket.send(json.dumps(
dict(
text=" ",
try_trigger_generation=True,
voice_settings=voice_settings.dict() if voice_settings else None,
generation_config=dict(
chunk_length_schedule=[50],
),
)
))
except websockets.exceptions.ConnectionClosedError as ce:
raise ApiError(body=ce.reason, status_code=ce.code)

try:
async for text_chunk in async_text_chunker(text):
data = dict(text=text_chunk, try_trigger_generation=True)
await socket.send(json.dumps(data))
try:
async with asyncio.timeout(1e-4):
data = json.loads(await socket.recv())
if "audio" in data and data["audio"]:
yield base64.b64decode(data["audio"]) # type: ignore
except TimeoutError:
pass

await socket.send(json.dumps(dict(text="")))

while True:

data = json.loads(await socket.recv())
if "audio" in data and data["audio"]:
yield base64.b64decode(data["audio"]) # type: ignore
except websockets.exceptions.ConnectionClosed as ce:
if "message" in data:
raise ApiError(body=data, status_code=ce.code)
elif ce.code != 1000:
raise ApiError(body=ce.reason, status_code=ce.code)
17 changes: 17 additions & 0 deletions tests/test_async_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,20 @@ async def main():
if not IN_GITHUB:
play(out)
asyncio.run(main())

def test_generate_stream() -> None:
async def main():
async def text_stream():
yield "Hi there, I'm Eleven "
yield "I'm a text to speech API "

audio_stream = await async_client.generate(
text=text_stream(),
voice="Nicole",
model="eleven_monolingual_v1",
stream=True
)

if not IN_GITHUB:
stream(audio_stream) # type: ignore
asyncio.run(main())