
Commit

Add support for "cancel" expressions
I.e., expressions that, when spoken while the assistant is replying, cause the reply
to be interrupted.
paulovcmedeiros committed Feb 2, 2024
1 parent e81d0b3 commit 95a6674
Showing 7 changed files with 394 additions and 225 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -163,6 +163,7 @@
 _ruff = "ruff check ."
 # Test-related tasks
 pytest = "pytest"
+test = ["pytest"]
 # Tasks to be run as pre-push checks
 pre-push-checks = ["lint", "pytest"]

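Note: assuming these entries sit under poethepoet's [tool.poe.tasks] table (inferred from the composite pre-push-checks task; the table header is not shown in this hunk), the new composite test task should be runnable as poe test, simply chaining the existing pytest task.
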
2 changes: 1 addition & 1 deletion pyrobbot/__init__.py
@@ -51,7 +51,7 @@ class GeneralDefinitions:
     # Location info
     try:
         IPINFO = ipinfo.getHandler().getDetails().all
-    except requests.exceptions.ReadTimeout:
+    except (requests.exceptions.ReadTimeout, requests.exceptions.ConnectionError):
         IPINFO = defaultdict(lambda: "unknown")
 
     @staticmethod
9 changes: 7 additions & 2 deletions pyrobbot/app/multipage.py
@@ -4,7 +4,7 @@
 from abc import ABC, abstractmethod, abstractproperty
 
 import streamlit as st
-from openai import OpenAI
+from openai import OpenAI, OpenAIError
 from pydantic import ValidationError
 
 from pyrobbot import GeneralConstants
@@ -138,7 +138,12 @@ def init_chat_credentials(self):
+ "Chats created with this key won't be visible to people using other keys.",
)

client = OpenAI()
try:
client = OpenAI()
except OpenAIError:
st.error("Failed to connect to OpenAI API. Please check your API key.")
return

if not client.api_key:
st.write(":red[You need to provide a key to use the chat]")

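For context, a minimal sketch (not part of this commit) of the failure the new guard handles: in openai-python v1, constructing a client with no resolvable API key raises OpenAIError, which the page now catches so it can show an error message instead of crashing.

import os

from openai import OpenAI, OpenAIError

os.environ.pop("OPENAI_API_KEY", None)  # simulate a missing key

try:
    client = OpenAI()  # raises OpenAIError when no API key can be resolved
except OpenAIError as error:
    print(f"Failed to create the OpenAI client: {error}")
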
6 changes: 6 additions & 0 deletions pyrobbot/chat_configs.py
@@ -181,11 +181,17 @@ class VoiceAssistantConfigs(BaseConfigModel):
     openai_tts_voice: Literal[
         "alloy", "echo", "fable", "onyx", "nova", "shimmer"
     ] = Field(default="onyx", description="Voice to use for OpenAI's TTS")
 
     exit_expressions: list[str] = Field(
         default=["bye-bye", "ok bye-bye", "okay bye-bye"],
         description="Expression(s) to use in order to exit the chat",
     )
 
+    cancel_expressions: list[str] = Field(
+        default=["ok", "cancel", "stop"],
+        description="Word(s) to use in order to cancel the current reply",
+    )
+
     inactivity_timeout_seconds: int = Field(
         default=1,
         gt=0,
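A rough sketch of how the new cancel_expressions option might be consumed. The helper below is hypothetical (the code that actually acts on these expressions is not part of this diff); it only illustrates the behaviour described in the commit message, i.e. interrupting the current reply when the transcribed speech matches one of the configured expressions.

def is_cancel_request(transcribed_text: str, cancel_expressions: list[str]) -> bool:
    """Hypothetical check: does the transcribed speech match a cancel expression?"""
    return transcribed_text.strip().lower() in (expr.lower() for expr in cancel_expressions)


# With the new defaults from VoiceAssistantConfigs:
print(is_cancel_request("Stop", ["ok", "cancel", "stop"]))  # True
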
193 changes: 193 additions & 0 deletions pyrobbot/sst_and_tts.py
@@ -0,0 +1,193 @@
"""Code related to speech-to-text and text-to-speech conversions."""
import io
import socket
from dataclasses import dataclass, field
from typing import Literal

import numpy as np
import speech_recognition as sr
from gtts import gTTS
from loguru import logger
from openai import OpenAI
from pydub import AudioSegment

from .general_utils import retry
from .tokens import TokenUsageDatabase


@dataclass
class SpeechAndTextConfigs:
    """Configs for speech-to-text and text-to-speech."""

    general_token_usage_db: TokenUsageDatabase
    token_usage_db: TokenUsageDatabase
    engine: Literal["openai", "google"] = "google"
    language: str = "en-US"
    timeout: int = 10


@dataclass
class SpeechToText(SpeechAndTextConfigs):
    """Class for converting speech to text."""

    speech: AudioSegment = None
    _text: str = field(init=False, default="")

    def __post_init__(self):
        if not self.speech:
            self.speech = AudioSegment.silent(duration=0)
        self.recogniser = sr.Recognizer()
        self.recogniser.operation_timeout = self.timeout

        wav_buffer = io.BytesIO()
        self.speech.export(wav_buffer, format="wav")
        wav_buffer.seek(0)
        with sr.AudioFile(wav_buffer) as source:
            self.audio_data = self.recogniser.listen(source)

    @property
    def text(self) -> str:
        """Return the text from the speech."""
        if not self._text:
            self._text = self._stt()
        return self._text

    def _stt(self) -> str:
        """Perform speech-to-text."""
        if not self.speech:
            logger.debug("No speech detected")
            return ""

        if self.engine == "openai":
            stt_function = self._stt_openai
            fallback_stt_function = self._stt_google
            fallback_name = "google"
        else:
            stt_function = self._stt_google
            fallback_stt_function = self._stt_openai
            fallback_name = "openai"

        logger.debug("Converting audio to text ({} STT)...", self.engine)
        try:
            rtn = stt_function()
        except (
            ConnectionResetError,
            socket.timeout,
            sr.exceptions.RequestError,
        ) as error:
            logger.error(error)
            logger.error(
                "Can't communicate with `{}` speech-to-text API right now",
                self.engine,
            )
            logger.warning("Trying to use `{}` STT instead", fallback_name)
            rtn = fallback_stt_function()
        except sr.exceptions.UnknownValueError:
            logger.opt(colors=True).debug("<yellow>Can't understand audio</yellow>")
            rtn = ""

        self._text = rtn.strip()

        return self._text

    @retry()
    def _stt_openai(self):
        """Perform speech-to-text using OpenAI's API."""
        wav_buffer = io.BytesIO(self.audio_data.get_wav_data())
        wav_buffer.name = "audio.wav"
        with wav_buffer as audio_file_buffer:
            transcript = OpenAI(timeout=self.timeout).audio.transcriptions.create(
                model="whisper-1",
                file=audio_file_buffer,
                language=self.language.split("-")[0],  # put in ISO-639-1 format
                prompt=f"The language is {self.language}. "
                "Do not transcribe if you think the audio is noise.",
            )

        for db in [
            self.general_token_usage_db,
            self.token_usage_db,
        ]:
            db.insert_data(
                model="whisper-1",
                n_input_tokens=int(np.ceil(self.speech.duration_seconds)),
            )

        return transcript.text

    def _stt_google(self):
        """Perform speech-to-text using Google's API."""
        return self.recogniser.recognize_google(
            audio_data=self.audio_data, language=self.language
        )


@dataclass
class TextToSpeech(SpeechAndTextConfigs):
    """Class for converting text to speech."""

    text: str = ""
    openai_tts_voice: str = ""
    _speech: AudioSegment = field(init=False, default=None)

    def __post_init__(self):
        self.text = self.text.strip()

    @property
    def speech(self) -> AudioSegment:
        """Return the speech from the text."""
        if not self._speech:
            self._speech = self._tts()
        return self._speech

    def set_sample_rate(self, sample_rate: int):
        """Set the sample rate of the speech."""
        self._speech = self._speech.set_frame_rate(sample_rate)

    def _tts(self):
        logger.debug("Running {} TTS on text '{}'", self.engine, self.text)
        rtn = self._tts_openai() if self.engine == "openai" else self._tts_google()
        logger.debug("Done with TTS for '{}'", self.text)

        return rtn

    def _tts_openai(self) -> AudioSegment:
        """Convert text to speech using OpenAI's TTS. Return an AudioSegment object."""
        client = OpenAI(timeout=self.timeout)

        openai_tts_model = "tts-1"

        @retry()
        def _create_speech(*args, **kwargs):
            for db in [
                self.general_token_usage_db,
                self.token_usage_db,
            ]:
                db.insert_data(model=openai_tts_model, n_input_tokens=len(self.text))
            return client.audio.speech.create(*args, **kwargs)

        response = _create_speech(
            input=self.text,
            model=openai_tts_model,
            voice=self.openai_tts_voice,
            response_format="mp3",
            timeout=self.timeout,
        )

        mp3_buffer = io.BytesIO()
        for mp3_stream_chunk in response.iter_bytes(chunk_size=4096):
            mp3_buffer.write(mp3_stream_chunk)
        mp3_buffer.seek(0)

        audio = AudioSegment.from_mp3(mp3_buffer)
        audio += 8  # Increase volume a bit
        return audio

    def _tts_google(self) -> AudioSegment:
        """Convert text to speech using Google's TTS. Return an AudioSegment object."""
        tts = gTTS(self.text, lang=self.language)
        mp3_buffer = io.BytesIO()
        tts.write_to_fp(mp3_buffer)
        mp3_buffer.seek(0)

        return AudioSegment.from_mp3(mp3_buffer)
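
A small usage sketch of the two new classes. The TokenUsageDatabase constructor argument and file path below are assumptions for illustration, not taken from this commit; language="en" is passed explicitly since recent gTTS versions expect a bare language code rather than the module's "en-US" default.

from pydub import AudioSegment

from pyrobbot.sst_and_tts import SpeechToText, TextToSpeech
from pyrobbot.tokens import TokenUsageDatabase

usage_db = TokenUsageDatabase(fpath="token_usage.db")  # assumed constructor signature

# Text-to-speech: the audio is synthesised lazily on first access of `.speech`
tts = TextToSpeech(
    general_token_usage_db=usage_db,
    token_usage_db=usage_db,
    engine="google",
    language="en",
    text="Hello there!",
)
audio: AudioSegment = tts.speech

# Speech-to-text: feed the synthesised audio back in and read `.text`
stt = SpeechToText(
    general_token_usage_db=usage_db,
    token_usage_db=usage_db,
    engine="google",
    language="en",
    speech=audio,
)
print(stt.text)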