[draft] Support AWS plugin for TTS and STT #1302

Draft
wants to merge 18 commits into
base: main
Choose a base branch
from
3 changes: 2 additions & 1 deletion .github/workflows/ci.yml
@@ -81,4 +81,5 @@ jobs:
-p livekit.plugins.anthropic \
-p livekit.plugins.fal \
-p livekit.plugins.playai \
-p livekit.plugins.assemblyai
-p livekit.plugins.assemblyai \
-p livekit.plugins.aws
2 changes: 2 additions & 0 deletions .github/workflows/tests.yml
@@ -114,6 +114,8 @@ jobs:
PLAYHT_USER_ID: ${{ secrets.PLAYHT_USER_ID }}
GOOGLE_APPLICATION_CREDENTIALS: google.json
PYTEST_ADDOPTS: "--color=yes"
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
working-directory: tests
run: |
echo "$GOOGLE_CREDENTIALS_JSON" > google.json
3 changes: 2 additions & 1 deletion livekit-plugins/install_local.sh
@@ -19,4 +19,5 @@ pip install \
"${SCRIPT_DIR}/livekit-plugins-rag" \
"${SCRIPT_DIR}/livekit-plugins-playai" \
"${SCRIPT_DIR}/livekit-plugins-silero" \
"${SCRIPT_DIR}/livekit-plugins-turn-detector"
"${SCRIPT_DIR}/livekit-plugins-turn-detector" \
"${SCRIPT_DIR}/livekit-plugins-aws"
13 changes: 13 additions & 0 deletions livekit-plugins/livekit-plugins-aws/README.md
@@ -0,0 +1,13 @@
# LiveKit Plugins AWS

Agent Framework plugin for AWS services. Currently supports STT and TTS.

## Installation

```bash
pip install livekit-plugins-aws
```

## Pre-requisites

You'll need an AWS access key ID, a secret access key, and a deployment region. These can be passed as arguments or set via the environment variables `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, and `AWS_DEFAULT_REGION`, respectively.
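Once those are set, the classes added in this PR can be constructed directly. A minimal sketch (the `STT` arguments shown are the defaults from this diff; the `TTS` constructor signature isn't part of this excerpt, so calling it with no arguments is an assumption):

```python
from livekit.plugins import aws

# Credentials are resolved from the environment (AWS_ACCESS_KEY_ID /
# AWS_SECRET_ACCESS_KEY) when api_key/api_secret are not passed explicitly.
transcriber = aws.STT(
    speech_region="us-east-1",  # default region in this diff
    language="en-US",
    sample_rate=48000,
)

synthesizer = aws.TTS()  # assumed defaults; tts.py is not shown in this excerpt
```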
29 changes: 29 additions & 0 deletions livekit-plugins/livekit-plugins-aws/livekit/plugins/aws/__init__.py
@@ -0,0 +1,29 @@
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .stt import STT, SpeechStream
from .tts import TTS, ChunkedStream
from .version import __version__

__all__ = ["STT", "SpeechStream", "TTS", "ChunkedStream", "__version__"]

from livekit.agents import Plugin


class AWSPlugin(Plugin):
def __init__(self) -> None:
super().__init__(__name__, __version__, __package__)


Plugin.register_plugin(AWSPlugin())
77 changes: 77 additions & 0 deletions livekit-plugins/livekit-plugins-aws/livekit/plugins/aws/_utils.py
@@ -0,0 +1,77 @@
from __future__ import annotations

import os
from typing import Literal

import boto3


def _get_aws_credentials(
api_key: str | None, api_secret: str | None, region: str | None
) -> tuple[str, str]:
region = region or os.environ.get("AWS_DEFAULT_REGION")
if not region:
raise ValueError(
"AWS region must be set, either via the `region` argument or the AWS_DEFAULT_REGION environment variable."
)

# If API key and secret are provided, create a session with them
if api_key and api_secret:
session = boto3.Session(
aws_access_key_id=api_key,
aws_secret_access_key=api_secret,
region_name=region,
)
else:
session = boto3.Session(region_name=region)

credentials = session.get_credentials()
if not credentials or not credentials.access_key or not credentials.secret_key:
raise ValueError("No valid AWS credentials found.")
return credentials.access_key, credentials.secret_key


TTS_SPEECH_ENGINE = Literal["standard", "neural", "long-form", "generative"]
TTS_LANGUAGE = Literal[
"arb",
"cmn-CN",
"cy-GB",
"da-DK",
"de-DE",
"en-AU",
"en-GB",
"en-GB-WLS",
"en-IN",
"en-US",
"es-ES",
"es-MX",
"es-US",
"fr-CA",
"fr-FR",
"is-IS",
"it-IT",
"ja-JP",
"hi-IN",
"ko-KR",
"nb-NO",
"nl-NL",
"pl-PL",
"pt-BR",
"pt-PT",
"ro-RO",
"ru-RU",
"sv-SE",
"tr-TR",
"en-NZ",
"en-ZA",
"ca-ES",
"de-AT",
"yue-CN",
"ar-AE",
"fi-FI",
"en-IE",
"nl-BE",
"fr-BE",
"cs-CZ",
"de-CH",
]

TTS_OUTPUT_FORMAT = Literal["mp3", "pcm"]
3 changes: 3 additions & 0 deletions livekit-plugins/livekit-plugins-aws/livekit/plugins/aws/log.py
@@ -0,0 +1,3 @@
import logging

logger = logging.getLogger("livekit.plugins.aws")
Empty file.
234 changes: 234 additions & 0 deletions livekit-plugins/livekit-plugins-aws/livekit/plugins/aws/stt.py
@@ -0,0 +1,234 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import asyncio
from dataclasses import dataclass
from typing import Optional

from amazon_transcribe.client import TranscribeStreamingClient
from amazon_transcribe.model import TranscriptEvent, TranscriptResultStream
from livekit import rtc
from livekit.agents import (
DEFAULT_API_CONNECT_OPTIONS,
APIConnectOptions,
stt,
utils,
)

from ._utils import _get_aws_credentials
from .log import logger


@dataclass
class STTOptions:
speech_region: str
sample_rate: int
language: str
encoding: str
vocabulary_name: Optional[str]
session_id: Optional[str]
vocab_filter_method: Optional[str]
vocab_filter_name: Optional[str]
show_speaker_label: Optional[bool]
enable_channel_identification: Optional[bool]
number_of_channels: Optional[int]
enable_partial_results_stabilization: Optional[bool]
partial_results_stability: Optional[str]
language_model_name: Optional[str]


class STT(stt.STT):
def __init__(
self,
*,
speech_region: str = "us-east-1",
api_key: str | None = None,
api_secret: str | None = None,
sample_rate: int = 48000,
language: str = "en-US",
encoding: str = "pcm",
vocabulary_name: Optional[str] = None,
session_id: Optional[str] = None,
vocab_filter_method: Optional[str] = None,
vocab_filter_name: Optional[str] = None,
show_speaker_label: Optional[bool] = None,
enable_channel_identification: Optional[bool] = None,
number_of_channels: Optional[int] = None,
enable_partial_results_stabilization: Optional[bool] = None,
partial_results_stability: Optional[str] = None,
language_model_name: Optional[str] = None,
):
super().__init__(
capabilities=stt.STTCapabilities(streaming=True, interim_results=True)
)

self._api_key, self._api_secret = _get_aws_credentials(
api_key, api_secret, speech_region
)
self._config = STTOptions(
speech_region=speech_region,
language=language,
sample_rate=sample_rate,
encoding=encoding,
vocabulary_name=vocabulary_name,
session_id=session_id,
vocab_filter_method=vocab_filter_method,
vocab_filter_name=vocab_filter_name,
show_speaker_label=show_speaker_label,
enable_channel_identification=enable_channel_identification,
number_of_channels=number_of_channels,
enable_partial_results_stabilization=enable_partial_results_stabilization,
partial_results_stability=partial_results_stability,
language_model_name=language_model_name,
)

async def _recognize_impl(
self,
*,
buffer: utils.AudioBuffer,
language: str | None = None,
) -> stt.SpeechEvent:
raise NotImplementedError(
"Amazon Transcribe does not support single frame recognition"
)

def stream(
self,
*,
language: str | None = None,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
) -> "SpeechStream":
return SpeechStream(
stt=self,
conn_options=conn_options,
opts=self._config,
)


class SpeechStream(stt.SpeechStream):
def __init__(
self,
stt: STT,
opts: STTOptions,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
) -> None:
super().__init__(
stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate
)
self._opts = opts
self._client = TranscribeStreamingClient(region=self._opts.speech_region)

async def _run(self) -> None:
try:
# AWS Transcribe requires an async generator when calling start_stream_transcription
stream = await self._client.start_stream_transcription(
language_code=self._opts.language,
media_sample_rate_hz=self._opts.sample_rate,
media_encoding=self._opts.encoding,
vocabulary_name=self._opts.vocabulary_name,
session_id=self._opts.session_id,
vocab_filter_method=self._opts.vocab_filter_method,
vocab_filter_name=self._opts.vocab_filter_name,
show_speaker_label=self._opts.show_speaker_label,
enable_channel_identification=self._opts.enable_channel_identification,
number_of_channels=self._opts.number_of_channels,
enable_partial_results_stabilization=self._opts.enable_partial_results_stabilization,
partial_results_stability=self._opts.partial_results_stability,
language_model_name=self._opts.language_model_name,
)

# convert the input channel into the async generator Transcribe expects
async def input_generator():
try:
async for frame in self._input_ch:
if isinstance(frame, rtc.AudioFrame):
await stream.input_stream.send_audio_event(
audio_chunk=frame.data.tobytes()
)
await stream.input_stream.end_stream()
except Exception as e:
logger.exception(f"an error occurred while streaming inputs: {e}")

# run the audio input and the transcript handler concurrently
handler = TranscriptEventHandler(stream.output_stream, self._event_ch)
await asyncio.gather(input_generator(), handler.handle_events())
except Exception as e:
logger.exception(f"an error occurred while streaming inputs: {e}")


def _streaming_recognize_response_to_speech_data(
resp,
) -> stt.SpeechData:
# `resp` is a single result from the Transcribe result stream
data = stt.SpeechData(
language="en-US",
start_time=resp.start_time,
end_time=resp.end_time,
confidence=0.0,
text=resp.alternatives[0].transcript,
)

return data


class TranscriptEventHandler:
def __init__(
self,
transcript_result_stream: TranscriptResultStream,
event_ch: utils.aio.Chan[stt.SpeechEvent],
):
self._transcript_result_stream = transcript_result_stream
self._event_ch = event_ch

async def handle_events(self):
"""Process generic incoming events from Amazon Transcribe
and delegate to appropriate sub-handlers.
"""
async for event in self._transcript_result_stream:
if isinstance(event, TranscriptEvent):
await self.handle_transcript_event(event)

async def handle_transcript_event(self, transcript_event: TranscriptEvent):
# Convert Transcribe results into LiveKit speech events and forward them downstream.
stream = transcript_event.transcript.results
for resp in stream:
if resp.start_time == 0.0:
self._event_ch.send_nowait(
stt.SpeechEvent(type=stt.SpeechEventType.START_OF_SPEECH)
)

if resp.end_time > 0.0:
if resp.is_partial:
self._event_ch.send_nowait(
stt.SpeechEvent(
type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
alternatives=[
_streaming_recognize_response_to_speech_data(resp)
],
)
)

else:
self._event_ch.send_nowait(
stt.SpeechEvent(
type=stt.SpeechEventType.FINAL_TRANSCRIPT,
alternatives=[
_streaming_recognize_response_to_speech_data(resp)
],
)
)

if not resp.is_partial:
self._event_ch.send_nowait(
stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH)
)
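As a reviewer aid, here is a rough end-to-end sketch of how the new `SpeechStream` would be driven. It is illustrative only and not part of the diff; it assumes the `push_frame` / `end_input` / async-iteration interface of the base `stt.SpeechStream` in `livekit.agents`, plus a hypothetical list of 48 kHz `rtc.AudioFrame`s and AWS credentials in the environment.

```python
import asyncio

from livekit import rtc
from livekit.agents import stt
from livekit.plugins import aws


async def transcribe_frames(frames: list[rtc.AudioFrame]) -> None:
    # hypothetical driver: push PCM frames in, print final transcripts out
    transcriber = aws.STT(speech_region="us-east-1")  # credentials resolved from the environment
    stream = transcriber.stream()

    async def feed() -> None:
        for frame in frames:
            stream.push_frame(frame)
        stream.end_input()  # tells the stream the audio input is finished

    async def read() -> None:
        async for event in stream:
            if event.type == stt.SpeechEventType.FINAL_TRANSCRIPT:
                print(event.alternatives[0].text)

    await asyncio.gather(feed(), read())
```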