Skip to content

Commit

Permalink
Do text-to-speech over completion chunks
Browse files Browse the repository at this point in the history
So that the assistant can start talking while data is still bein
streamed.
  • Loading branch information
paulovcmedeiros committed Nov 16, 2023
2 parents 109c7ae + 7046c74 commit aa7ef16
Show file tree
Hide file tree
Showing 10 changed files with 149 additions and 75 deletions.
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
license = "MIT"
name = "pyrobbot"
readme = "README.md"
version = "0.3.0"
version = "0.3.1"

[build-system]
build-backend = "poetry.core.masonry.api"
Expand All @@ -31,6 +31,7 @@
streamlit = "^1.28.0"
tiktoken = "^0.5.1"
# Text to speech
chime = "^0.7.0"
gtts = "^2.4.0"
pydub = "^0.25.1"
pygame = "^2.5.2"
Expand Down Expand Up @@ -60,6 +61,7 @@
##################
# Linter configs #
##################
pytest-xdist = "^3.4.0"

[tool.black]
line-length = 90
Expand Down Expand Up @@ -130,7 +132,7 @@
##################

[tool.pytest.ini_options]
addopts = "-v --cache-clear --failed-first --cov-report=term-missing --cov-report=term:skip-covered --cov-report=xml:.coverage.xml --cov=./"
addopts = "-v --cache-clear -n auto --failed-first --cov-report=term-missing --cov-report=term:skip-covered --cov-report=xml:.coverage.xml --cov=./"
log_cli_level = "INFO"
testpaths = ["tests/smoke", "tests/unit"]

Expand Down
9 changes: 9 additions & 0 deletions pyrobbot/argparse_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,15 @@ def _populate_parser_from_pydantic_model(parser, model):
for key in _argarse2pydantic
if _argarse2pydantic[key](field_name) is not None
}

if args_opts.get("type") == bool:
if args_opts.get("default") is True:
args_opts["action"] = "store_false"
else:
args_opts["action"] = "store_true"
args_opts.pop("default", None)
args_opts.pop("type", None)

args_opts["required"] = field.is_required()
if "help" in args_opts:
args_opts["help"] = f"{args_opts['help']} (default: %(default)s)"
Expand Down
32 changes: 14 additions & 18 deletions pyrobbot/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import json
import shutil
import uuid
from collections import defaultdict

from loguru import logger

Expand All @@ -23,9 +22,10 @@ class Chat(AlternativeConstructors):
responses.
"""

_initial_greeting_translations = defaultdict(lambda: defaultdict(str))
_initial_greeting_translations = {} # map language:translation
default_configs = ChatOptions()

def __init__(self, configs: ChatOptions = None):
def __init__(self, configs: ChatOptions = default_configs):
"""Initializes a chat instance.
Args:
Expand All @@ -37,9 +37,6 @@ def __init__(self, configs: ChatOptions = None):
self.id = str(uuid.uuid4())
self.initial_openai_key_hash = GeneralConstants.openai_key_hash()

if configs is None:
configs = ChatOptions()

self._passed_configs = configs
for field in self._passed_configs.model_fields:
setattr(self, field, self._passed_configs[field])
Expand Down Expand Up @@ -157,26 +154,24 @@ def load_history(self):
@property
def initial_greeting(self):
"""Return the initial greeting for the chat."""
default_greeting = f"Hi! I'm {self.assistant_name}. How can I assist you today?"
try:
passed_greeting = self._initial_greeting.strip()
except AttributeError:
passed_greeting = ""

if not passed_greeting:
self._initial_greeting = (
f"Hello! I'm {self.assistant_name}. How can I assist you today?"
)
self._initial_greeting = default_greeting

translated_greeting = type(self)._initial_greeting_translations[ # noqa: SLF001
self._initial_greeting
][self.language]
if not translated_greeting:
translated_greeting = self._translate(self._initial_greeting)
type(self)._initial_greeting_translations[ # noqa: SLF001
self._initial_greeting
][self.language] = translated_greeting
if passed_greeting or self.language != "en":
translation_cache = type(self)._initial_greeting_translations # noqa: SLF001
translated_greeting = translation_cache.get(self.language)
if not translated_greeting:
translated_greeting = self._translate(self._initial_greeting)
translation_cache[self.language] = translated_greeting
self._initial_greeting = translated_greeting

return translated_greeting
return self._initial_greeting

@initial_greeting.setter
def initial_greeting(self, value: str):
Expand Down Expand Up @@ -280,6 +275,7 @@ def _respond_prompt(self, prompt: str, role: str, **kwargs):

def _translate(self, text):
lang = self.language
logger.debug("Processing translation of '{}' to '{}'...", text, lang)
translation_prompt = f"Translate the text between triple quotes to {lang}. "
translation_prompt += "DO NOT WRITE ANYTHING ELSE. Only the translation. "
translation_prompt += f"If the text is already in {lang}, then just repeat "
Expand Down
13 changes: 10 additions & 3 deletions pyrobbot/chat_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ class ChatOptions(OpenAiApiCallOptions):
)
private_mode: Optional[bool] = Field(
default=None,
description="Toggle private mode. If set to `True`, the chat will not "
description="Toggle private mode. If this flag is used, the chat will not "
+ "be logged and the chat history will not be saved.",
)
api_connection_max_n_attempts: int = Field(
Expand All @@ -174,15 +174,19 @@ class VoiceAssistantConfigs(BaseConfigModel):
openai_tts_voice: Literal[
"alloy", "echo", "fable", "onyx", "nova", "shimmer"
] = Field(default="onyx", description="Voice to use for OpenAI's TTS")
exit_expressions: list[str] = Field(
default=["bye-bye", "ok bye-bye", "okay bye-bye"],
description="Expression(s) to use in order to exit the chat",
)

inactivity_timeout_seconds: int = Field(
default=2,
default=1,
gt=0,
description="How much time user should be inactive "
"for the assistant to stop listening",
)
speech_likelihood_threshold: float = Field(
default=0.85,
default=0.5,
ge=0.0,
le=1.0,
description="Accept audio as speech if the likelihood is above this threshold",
Expand All @@ -195,6 +199,9 @@ class VoiceAssistantConfigs(BaseConfigModel):
frame_duration: Literal[10, 20, 30] = Field(
default=30, description="Frame duration for audio recording, in milliseconds."
)
skip_initial_greeting: Optional[bool] = Field(
default=None, description="Skip initial greeting."
)


class VoiceChatConfigs(ChatOptions, VoiceAssistantConfigs):
Expand Down
2 changes: 1 addition & 1 deletion pyrobbot/command_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .chat import Chat
from .chat_configs import ChatOptions
from .openai_utils import CannotConnectToApiError
from .text_to_speech import VoiceChat
from .voice_chat import VoiceChat


def voice_chat(args):
Expand Down
6 changes: 2 additions & 4 deletions pyrobbot/general_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ def from_dict(cls, configs: dict):
Returns:
cls: An instance of Chat initialized with the given configurations.
"""
dummy = cls()
return cls(configs=dummy.configs.model_validate(configs))
return cls(configs=cls.default_configs.model_validate(configs))

@classmethod
def from_cli_args(cls, cli_args):
Expand All @@ -37,11 +36,10 @@ def from_cli_args(cls, cli_args):
Returns:
cls: An instance of the class initialized with CLI-specified configurations.
"""
dummy = cls()
chat_opts = {
k: v
for k, v in vars(cli_args).items()
if k in dummy.configs.model_fields and v is not None
if k in cls.default_configs.model_fields and v is not None
}
return cls.from_dict(chat_opts)

Expand Down
Loading

0 comments on commit aa7ef16

Please sign in to comment.