
Support AWS plugin for TTS, STT and LLM #1302

Open
wants to merge 49 commits into base: main

Commits:
23efbf4  base files for AWS plugins  (dasxran, Jul 14, 2024)
70fcfc3  aws workflow  (dasxran, Jul 14, 2024)
fa9a9f5  rename aws folder  (dasxran, Jul 14, 2024)
7282f51  Polly TTS  (dasxran, Jul 14, 2024)
7d07511  Update test_tts with aws  (dasxran, Jul 15, 2024)
485e410  Merge branch 'livekit:main' into dex/aws  (dasxran, Jul 20, 2024)
c1676ed  Merge branch 'livekit:main' into dex/aws  (dasxran, Jul 28, 2024)
b19b2c8  aws transcribe  (dasxran, Jul 28, 2024)
32e8f32  livekit-agents v0.8.0  (dasxran, Jul 30, 2024)
0d40699  Merge remote-tracking branch 'dasxran/dex/aws' into aws-tts-stt  (jayeshp19, Dec 9, 2024)
8d7210f  ruff check & yml changes  (jayeshp19, Dec 9, 2024)
0d84256  setup changes  (jayeshp19, Dec 9, 2024)
9c0c373  Merge branch 'main' of https://github.com/livekit/agents into aws-tts…  (jayeshp19, Dec 11, 2024)
b2b5614  Merge branch 'main' of https://github.com/livekit/agents into aws-tts…  (jayeshp19, Dec 26, 2024)
c2d26fc  updates  (jayeshp19, Dec 26, 2024)
ca7d609  updates  (jayeshp19, Dec 26, 2024)
84741d9  updates  (jayeshp19, Dec 26, 2024)
660d783  updates  (jayeshp19, Dec 26, 2024)
81e2ebd  updates  (jayeshp19, Jan 13, 2025)
93f8b19  updates  (jayeshp19, Jan 13, 2025)
69b286b  updates  (jayeshp19, Jan 13, 2025)
9d52b78  updates  (jayeshp19, Jan 13, 2025)
71f450d  updates  (jayeshp19, Jan 13, 2025)
dbf090d  updates  (jayeshp19, Jan 13, 2025)
0f8b2a3  updates  (jayeshp19, Jan 13, 2025)
121cc02  updates  (jayeshp19, Jan 13, 2025)
a7dbac9  updates  (jayeshp19, Jan 13, 2025)
98d9882  updates  (jayeshp19, Jan 13, 2025)
856ba87  updates  (jayeshp19, Jan 13, 2025)
599e197  updates  (jayeshp19, Jan 13, 2025)
651390a  updates  (jayeshp19, Jan 13, 2025)
c1127f3  debug  (jayeshp19, Jan 15, 2025)
941ef3f  debug  (jayeshp19, Jan 17, 2025)
55768ae  Merge branch 'main' of https://github.com/livekit/agents into aws-tts…  (jayeshp19, Jan 18, 2025)
5af4053  changeset  (jayeshp19, Jan 20, 2025)
33c787d  Merge branch 'main' of https://github.com/livekit/agents into aws-tts…  (jayeshp19, Jan 20, 2025)
23de909  updates  (jayeshp19, Jan 20, 2025)
6b8e706  updates  (jayeshp19, Jan 20, 2025)
1b56f02  Merge branch 'main' of https://github.com/livekit/agents into aws-tts…  (jayeshp19, Jan 23, 2025)
827a09c  updates  (jayeshp19, Jan 25, 2025)
fa2efdd  supprot bedrock llms  (jayeshp19, Jan 27, 2025)
c735e47  ruff  (jayeshp19, Jan 27, 2025)
e8557e1  updates  (jayeshp19, Jan 27, 2025)
3e4869b  updates  (jayeshp19, Jan 27, 2025)
c46554c  updates  (jayeshp19, Jan 27, 2025)
1a07bb6  updates  (jayeshp19, Jan 27, 2025)
c6687d6  updates  (jayeshp19, Jan 27, 2025)
c9dabff  updates  (jayeshp19, Jan 27, 2025)
fe986ab  update test llm  (jayeshp19, Feb 3, 2025)
5 changes: 5 additions & 0 deletions .changeset/slow-keys-invite.md
@@ -0,0 +1,5 @@
---
"livekit-plugins-aws": minor
---

initial release
3 changes: 2 additions & 1 deletion .github/workflows/ci.yml
@@ -82,4 +82,5 @@ jobs:
-p livekit.plugins.fal \
-p livekit.plugins.playai \
-p livekit.plugins.assemblyai \
-p livekit.plugins.rime
-p livekit.plugins.rime \
-p livekit.plugins.aws
2 changes: 2 additions & 0 deletions .github/workflows/tests.yml
@@ -116,6 +116,8 @@ jobs:
RIME_API_KEY: ${{ secrets.RIME_API_KEY }}
GOOGLE_APPLICATION_CREDENTIALS: google.json
PYTEST_ADDOPTS: "--color=yes"
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
working-directory: tests
run: |
echo "$GOOGLE_CREDENTIALS_JSON" > google.json
3 changes: 2 additions & 1 deletion livekit-plugins/install_local.sh
@@ -20,4 +20,5 @@ pip install \
"${SCRIPT_DIR}/livekit-plugins-playai" \
"${SCRIPT_DIR}/livekit-plugins-silero" \
"${SCRIPT_DIR}/livekit-plugins-turn-detector" \
"${SCRIPT_DIR}/livekit-plugins-rime"
"${SCRIPT_DIR}/livekit-plugins-rime" \
"${SCRIPT_DIR}/livekit-plugins-aws"
1 change: 1 addition & 0 deletions livekit-plugins/livekit-plugins-aws/CHANGELOG.md
@@ -0,0 +1 @@
# livekit-plugins-aws
13 changes: 13 additions & 0 deletions livekit-plugins/livekit-plugins-aws/README.md
@@ -0,0 +1,13 @@
# LiveKit Plugins AWS

Agent Framework plugin for AWS services. Currently supports STT and TTS.

## Installation

```bash
pip install livekit-plugins-aws
```

## Pre-requisites

You'll need an AWS access key ID, a secret access key, and a deployment region. They can be set as the environment variables `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, and `AWS_DEFAULT_REGION`, respectively.
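For example, the credentials can be exported in the shell before starting an agent (the values below are placeholders, not real credentials):

```shell
# Placeholder values - substitute credentials from your AWS account
export AWS_ACCESS_KEY_ID="your-access-key-id"
export AWS_SECRET_ACCESS_KEY="your-secret-access-key"
export AWS_DEFAULT_REGION="us-east-1"
```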
29 changes: 29 additions & 0 deletions livekit-plugins/livekit-plugins-aws/livekit/plugins/aws/__init__.py
@@ -0,0 +1,29 @@
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .stt import STT, SpeechStream
from .tts import TTS, ChunkedStream
from .version import __version__

__all__ = ["STT", "SpeechStream", "TTS", "ChunkedStream", "__version__"]

from livekit.agents import Plugin


class AWSPlugin(Plugin):
    def __init__(self) -> None:
        super().__init__(__name__, __version__, __package__)


Plugin.register_plugin(AWSPlugin())
29 changes: 29 additions & 0 deletions livekit-plugins/livekit-plugins-aws/livekit/plugins/aws/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import os
from typing import Optional

import boto3  # type: ignore


def _get_aws_credentials(
    api_key: Optional[str], api_secret: Optional[str], region: Optional[str]
):
    region = region or os.environ.get("AWS_DEFAULT_REGION")
    if not region:
        raise ValueError(
            "AWS_DEFAULT_REGION must be set using the argument or by setting the AWS_DEFAULT_REGION environment variable."
        )

    # If an API key and secret are provided, create a session with them
    if api_key and api_secret:
        session = boto3.Session(
            aws_access_key_id=api_key,
            aws_secret_access_key=api_secret,
            region_name=region,
        )
    else:
        session = boto3.Session(region_name=region)

    credentials = session.get_credentials()
    if not credentials or not credentials.access_key or not credentials.secret_key:
        raise ValueError("No valid AWS credentials found.")
    return credentials.access_key, credentials.secret_key

Member: Is this making network calls?

Collaborator (Author): boto3.Session() doesn't make network calls, but session.get_credentials() does when credentials aren't cached. We call it during initialization, so it's a one-time operation.
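The credential-resolution precedence above (an explicit argument wins, otherwise the environment is consulted) can be illustrated with a boto3-free sketch; `resolve_region` is a hypothetical helper for illustration, not part of the plugin:

```python
import os
from typing import Optional


def resolve_region(region: Optional[str]) -> str:
    # Mirrors the plugin's precedence: an explicit argument wins,
    # otherwise fall back to the AWS_DEFAULT_REGION environment variable.
    region = region or os.environ.get("AWS_DEFAULT_REGION")
    if not region:
        raise ValueError("AWS_DEFAULT_REGION must be set")
    return region


os.environ["AWS_DEFAULT_REGION"] = "us-west-2"
print(resolve_region("eu-central-1"))  # explicit argument wins: eu-central-1
print(resolve_region(None))            # falls back to the environment: us-west-2
```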
3 changes: 3 additions & 0 deletions livekit-plugins/livekit-plugins-aws/livekit/plugins/aws/log.py
@@ -0,0 +1,3 @@
import logging

logger = logging.getLogger("livekit.plugins.aws")
48 changes: 48 additions & 0 deletions livekit-plugins/livekit-plugins-aws/livekit/plugins/aws/models.py
@@ -0,0 +1,48 @@
from typing import Literal

TTS_SPEECH_ENGINE = Literal["standard", "neural", "long-form", "generative"]
TTS_LANGUAGE = Literal[
    "arb",
    "cmn-CN",
    "cy-GB",
    "da-DK",
    "de-DE",
    "en-AU",
    "en-GB",
    "en-GB-WLS",
    "en-IN",
    "en-US",
    "es-ES",
    "es-MX",
    "es-US",
    "fr-CA",
    "fr-FR",
    "is-IS",
    "it-IT",
    "ja-JP",
    "hi-IN",
    "ko-KR",
    "nb-NO",
    "nl-NL",
    "pl-PL",
    "pt-BR",
    "pt-PT",
    "ro-RO",
    "ru-RU",
    "sv-SE",
    "tr-TR",
    "en-NZ",
    "en-ZA",
    "ca-ES",
    "de-AT",
    "yue-CN",
    "ar-AE",
    "fi-FI",
    "en-IE",
    "nl-BE",
    "fr-BE",
    "cs-CZ",
    "de-CH",
]

TTS_OUTPUT_FORMAT = Literal["pcm", "mp3"]
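Since these aliases are `typing.Literal` types, they can double as runtime validators via `typing.get_args`; a small sketch (`validate_output_format` is illustrative, not a plugin API, and the Literal is restated here so the snippet is self-contained):

```python
from typing import Literal, get_args

# Mirrors TTS_OUTPUT_FORMAT from models.py
TTS_OUTPUT_FORMAT = Literal["pcm", "mp3"]


def validate_output_format(fmt: str) -> str:
    # get_args() turns the Literal into a plain tuple of allowed values
    if fmt not in get_args(TTS_OUTPUT_FORMAT):
        raise ValueError(f"unsupported output format: {fmt!r}")
    return fmt


print(validate_output_format("pcm"))  # pcm
```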
Empty file.
218 changes: 218 additions & 0 deletions livekit-plugins/livekit-plugins-aws/livekit/plugins/aws/stt.py
@@ -0,0 +1,218 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import asyncio
from dataclasses import dataclass
from typing import Optional

from amazon_transcribe.client import TranscribeStreamingClient
from amazon_transcribe.model import Result, TranscriptEvent
from livekit import rtc
from livekit.agents import (
DEFAULT_API_CONNECT_OPTIONS,
APIConnectOptions,
stt,
utils,
)

from ._utils import _get_aws_credentials
from .log import logger


@dataclass
class STTOptions:
    speech_region: str
    sample_rate: int
    language: str
    encoding: str
    vocabulary_name: Optional[str]
    session_id: Optional[str]
    vocab_filter_method: Optional[str]
    vocab_filter_name: Optional[str]
    show_speaker_label: Optional[bool]
    enable_channel_identification: Optional[bool]
    number_of_channels: Optional[int]
    enable_partial_results_stabilization: Optional[bool]
    partial_results_stability: Optional[str]
    language_model_name: Optional[str]

class STT(stt.STT):
    def __init__(
        self,
        *,
        speech_region: str = "us-east-1",
        api_key: str | None = None,
        api_secret: str | None = None,
        sample_rate: int = 48000,
        language: str = "en-US",
        encoding: str = "pcm",
        vocabulary_name: Optional[str] = None,
        session_id: Optional[str] = None,
        vocab_filter_method: Optional[str] = None,
        vocab_filter_name: Optional[str] = None,
        show_speaker_label: Optional[bool] = None,
        enable_channel_identification: Optional[bool] = None,
        number_of_channels: Optional[int] = None,
        enable_partial_results_stabilization: Optional[bool] = None,
        partial_results_stability: Optional[str] = None,
        language_model_name: Optional[str] = None,
    ):
        super().__init__(
            capabilities=stt.STTCapabilities(streaming=True, interim_results=True)
        )

        self._api_key, self._api_secret = _get_aws_credentials(
            api_key, api_secret, speech_region
        )
        self._config = STTOptions(
            speech_region=speech_region,
            language=language,
            sample_rate=sample_rate,
            encoding=encoding,
            vocabulary_name=vocabulary_name,
            session_id=session_id,
            vocab_filter_method=vocab_filter_method,
            vocab_filter_name=vocab_filter_name,
            show_speaker_label=show_speaker_label,
            enable_channel_identification=enable_channel_identification,
            number_of_channels=number_of_channels,
            enable_partial_results_stabilization=enable_partial_results_stabilization,
            partial_results_stability=partial_results_stability,
            language_model_name=language_model_name,
        )

    async def _recognize_impl(
        self,
        buffer: utils.AudioBuffer,
        *,
        language: str | None,
        conn_options: APIConnectOptions,
    ) -> stt.SpeechEvent:
        raise NotImplementedError(
            "Amazon Transcribe does not support single frame recognition"
        )

    def stream(
        self,
        *,
        language: str | None = None,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> "SpeechStream":
        return SpeechStream(
            stt=self,
            conn_options=conn_options,
            opts=self._config,
        )


class SpeechStream(stt.SpeechStream):
    def __init__(
        self,
        stt: STT,
        opts: STTOptions,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> None:
        super().__init__(
            stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate
        )
        self._opts = opts
        self._client = TranscribeStreamingClient(region=self._opts.speech_region)

    async def _run(self) -> None:
        stream = await self._client.start_stream_transcription(
            language_code=self._opts.language,
            media_sample_rate_hz=self._opts.sample_rate,
            media_encoding=self._opts.encoding,
            vocabulary_name=self._opts.vocabulary_name,
            session_id=self._opts.session_id,
            vocab_filter_method=self._opts.vocab_filter_method,
            vocab_filter_name=self._opts.vocab_filter_name,
            show_speaker_label=self._opts.show_speaker_label,
            enable_channel_identification=self._opts.enable_channel_identification,
            number_of_channels=self._opts.number_of_channels,
            enable_partial_results_stabilization=self._opts.enable_partial_results_stabilization,
            partial_results_stability=self._opts.partial_results_stability,
            language_model_name=self._opts.language_model_name,
        )

        @utils.log_exceptions(logger=logger)
        async def input_generator():
            async for frame in self._input_ch:
                if isinstance(frame, rtc.AudioFrame):
                    await stream.input_stream.send_audio_event(
                        audio_chunk=frame.data.tobytes()
                    )
            await stream.input_stream.end_stream()

        @utils.log_exceptions(logger=logger)
        async def handle_transcript_events():
            async for event in stream.output_stream:
                if isinstance(event, TranscriptEvent):
                    self._process_transcript_event(event)

        tasks = [
            asyncio.create_task(input_generator()),
            asyncio.create_task(handle_transcript_events()),
        ]
        try:
            await asyncio.gather(*tasks)
        finally:
            await utils.aio.gracefully_cancel(*tasks)

    def _process_transcript_event(self, transcript_event: TranscriptEvent):
        stream = transcript_event.transcript.results
        for resp in stream:
            # note: a start_time of 0.0 is falsy, so the check must use
            # `is not None` rather than a plain truthiness test
            if resp.start_time is not None and resp.start_time == 0.0:
                self._event_ch.send_nowait(
                    stt.SpeechEvent(type=stt.SpeechEventType.START_OF_SPEECH)
                )

            if resp.end_time and resp.end_time > 0.0:
                if resp.is_partial:
                    self._event_ch.send_nowait(
                        stt.SpeechEvent(
                            type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
                            alternatives=[
                                _streaming_recognize_response_to_speech_data(resp)
                            ],
                        )
                    )
                else:
                    self._event_ch.send_nowait(
                        stt.SpeechEvent(
                            type=stt.SpeechEventType.FINAL_TRANSCRIPT,
                            alternatives=[
                                _streaming_recognize_response_to_speech_data(resp)
                            ],
                        )
                    )

            if not resp.is_partial:
                self._event_ch.send_nowait(
                    stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH)
                )


def _streaming_recognize_response_to_speech_data(resp: Result) -> stt.SpeechData:
    data = stt.SpeechData(
        language="en-US",
        start_time=resp.start_time if resp.start_time else 0.0,
        end_time=resp.end_time if resp.end_time else 0.0,
        confidence=0.0,
        text=resp.alternatives[0].transcript if resp.alternatives else "",
    )

    return data
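The event mapping in `_process_transcript_event` (a result starting at time 0.0 implies START_OF_SPEECH, partial results become interim transcripts, final results become final transcripts followed by END_OF_SPEECH) can be exercised in isolation with a stub result; `FakeResult` and `classify` below are illustrative stand-ins, not plugin or SDK APIs:

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FakeResult:
    # Minimal stand-in for amazon_transcribe.model.Result
    start_time: Optional[float]
    end_time: Optional[float]
    is_partial: bool


def classify(resp: FakeResult) -> List[str]:
    """Return the event types the plugin would emit for one result."""
    events: List[str] = []
    if resp.start_time is not None and resp.start_time == 0.0:
        events.append("START_OF_SPEECH")
    if resp.end_time and resp.end_time > 0.0:
        events.append("INTERIM_TRANSCRIPT" if resp.is_partial else "FINAL_TRANSCRIPT")
    if not resp.is_partial:
        events.append("END_OF_SPEECH")
    return events


print(classify(FakeResult(0.0, 1.2, is_partial=True)))
# ['START_OF_SPEECH', 'INTERIM_TRANSCRIPT']
print(classify(FakeResult(0.5, 2.0, is_partial=False)))
# ['FINAL_TRANSCRIPT', 'END_OF_SPEECH']
```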