Skip to content

Commit

Permalink
Refactor for more reusability
Browse files Browse the repository at this point in the history
  • Loading branch information
paulovcmedeiros committed Feb 23, 2024
1 parent 02b0a7e commit b4f3981
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 89 deletions.
61 changes: 37 additions & 24 deletions pyrobbot/app/app_page_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,19 @@
import base64
import contextlib
import datetime
import time
import uuid
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

import streamlit as st
from audiorecorder import audiorecorder
from PIL import Image
from pydub import AudioSegment

from pyrobbot import GeneralDefinitions
from pyrobbot.chat import Chat
from pyrobbot.chat_configs import ChatOptions
from pyrobbot.chat_configs import VoiceChatConfigs
from pyrobbot.sst_and_tts import TextToSpeech
from pyrobbot.voice_chat import VoiceChat

if TYPE_CHECKING:
from pyrobbot.app.multipage import MultipageChatbotApp
Expand All @@ -30,19 +31,28 @@
_RecoveredChat = object()


def autoplay_audio(audio: AudioSegment):
"""Autoplay an audio segment in the streamlit app."""
    # Adapted from: <https://discuss.streamlit.io/t/
# how-to-play-an-audio-file-automatically-generated-using-text-to-speech-
# in-streamlit/33201/2>
data = audio.export(format="mp3").read()
b64 = base64.b64encode(data).decode()
md = f"""
class WebAppChat(VoiceChat):
    """A chat object for web apps.

    Extends `VoiceChat` and starts the TTS-conversion watcher thread at init,
    so speech is prepared in the background while the streamlit page renders.
    """

    def __init__(self, **kwargs):
        """Initialize a new instance of the WebAppChat class.

        Args:
            **kwargs: Forwarded verbatim to `VoiceChat.__init__`.
        """
        super().__init__(**kwargs)
        # Begin converting queued replies to speech immediately, so that audio
        # is ready by the time the page needs to play it.
        self.tts_conversion_watcher_thread.start()

    def speak(self, tts: TextToSpeech):
        """Autoplay an audio segment in the streamlit app.

        Args:
            tts (TextToSpeech): Object holding the speech (a pydub
                `AudioSegment` in `tts.speech`) to be played.
        """
        # Adapted from: <https://discuss.streamlit.io/t/
        # how-to-play-an-audio-file-automatically-generated-using-text-to-speech-
        # in-streamlit/33201/2>
        data = tts.speech.export(format="mp3").read()
        b64 = base64.b64encode(data).decode()
        md = f"""
        <audio controls autoplay="true">
        <source src="data:audio/mp3;base64,{b64}" type="audio/mp3">
        </audio>
        """
        st.markdown(md, unsafe_allow_html=True)
        # Block for the duration of the audio so consecutive speech segments
        # don't overlap each other.
        time.sleep(tts.speech.duration_seconds)


class AppPage(ABC):
Expand Down Expand Up @@ -115,15 +125,15 @@ class ChatBotPage(AppPage):
def __init__(
self,
parent: "MultipageChatbotApp",
chat_obj: Chat = None,
chat_obj: WebAppChat = None,
sidebar_title: str = "",
page_title: str = "",
):
"""Initialize new instance of the ChatBotPage class with an optional Chat object.
"""Initialize new instance of the ChatBotPage class with an opt WebAppChat object.
Args:
parent (MultipageChatbotApp): The parent app of the page.
chat_obj (Chat): The chat object. Defaults to None.
chat_obj (WebAppChat): The chat object. Defaults to None.
sidebar_title (str): The sidebar title for the chatbot page.
Defaults to an empty string.
page_title (str): The title for the chatbot page.
Expand All @@ -139,29 +149,29 @@ def __init__(
self.avatars = {"assistant": _ASSISTANT_AVATAR_IMAGE, "user": _USER_AVATAR_IMAGE}

@property
def chat_configs(self) -> ChatOptions:
def chat_configs(self) -> VoiceChatConfigs:
"""Return the configs used for the page's chat object."""
if "chat_configs" not in self.state:
self.state["chat_configs"] = self.parent.state["chat_configs"]
return self.state["chat_configs"]

@chat_configs.setter
def chat_configs(self, value: ChatOptions):
self.state["chat_configs"] = ChatOptions.model_validate(value)
def chat_configs(self, value: VoiceChatConfigs):
self.state["chat_configs"] = VoiceChatConfigs.model_validate(value)
if "chat_obj" in self.state:
del self.state["chat_obj"]

@property
def chat_obj(self) -> Chat:
def chat_obj(self) -> WebAppChat:
"""Return the chat object responsible for the queries on this page."""
if "chat_obj" not in self.state:
self.chat_obj = Chat(
self.chat_obj = WebAppChat(
configs=self.chat_configs, openai_client=self.parent.openai_client
)
return self.state["chat_obj"]

@chat_obj.setter
def chat_obj(self, new_chat_obj: Chat):
def chat_obj(self, new_chat_obj: WebAppChat):
current_chat = self.state.get("chat_obj")
if current_chat:
current_chat.save_cache()
Expand Down Expand Up @@ -221,6 +231,7 @@ def _render_chatbot_page(self):
)

mic_input = st.session_state.get("toggle_mic_input", False)
self.chat_obj.reply_only_as_text = not mic_input
prompt = (
self.state.pop("recorded_prompt", None)
if mic_input
Expand Down Expand Up @@ -251,13 +262,15 @@ def _render_chatbot_page(self):
with st.empty():
st.markdown("▌")
full_response = ""
for chunk in self.chat_obj.respond_user_prompt(prompt):
full_response += chunk
for chunk in self.chat_obj.answer_question(prompt):
full_response += chunk.content
st.markdown(full_response + "▌")
st.caption(datetime.datetime.now().replace(microsecond=0))
st.markdown(full_response)
if mic_input:
autoplay_audio(self.chat_obj.tts(full_response).speech)
while not self.chat_obj.play_speech_queue.empty():
self.chat_obj.speak(self.chat_obj.play_speech_queue.get())
self.chat_obj.play_speech_queue.task_done()
prompt = None

self.chat_history.append(
Expand Down
31 changes: 17 additions & 14 deletions pyrobbot/app/multipage.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,16 @@
from pydantic import ValidationError

from pyrobbot import GeneralDefinitions
from pyrobbot.app.app_page_templates import (
from pyrobbot.chat_configs import VoiceChatConfigs
from pyrobbot.openai_utils import OpenAiClientWrapper

from .app_page_templates import (
_ASSISTANT_AVATAR_IMAGE,
AppPage,
ChatBotPage,
WebAppChat,
_RecoveredChat,
)
from pyrobbot.chat import Chat
from pyrobbot.chat_configs import ChatOptions
from pyrobbot.openai_utils import OpenAiClientWrapper


class AbstractMultipageApp(ABC):
Expand Down Expand Up @@ -137,15 +138,17 @@ def openai_client(self) -> OpenAiClientWrapper:
return self.state["openai_client"]

@property
def chat_configs(self) -> ChatOptions:
def chat_configs(self) -> VoiceChatConfigs:
"""Return the configs used for the page's chat object."""
if "chat_configs" not in self.state:
try:
chat_options_file_path = sys.argv[-1]
self.state["chat_configs"] = ChatOptions.from_file(chat_options_file_path)
self.state["chat_configs"] = VoiceChatConfigs.from_file(
chat_options_file_path
)
except (FileNotFoundError, JSONDecodeError):
logger.warning("Could not retrieve cli args. Using default chat options.")
self.state["chat_configs"] = ChatOptions()
self.state["chat_configs"] = VoiceChatConfigs()
return self.state["chat_configs"]

def create_api_key_element(self):
Expand Down Expand Up @@ -208,9 +211,9 @@ def handle_ui_page_selection(self):

# Present the user with the model and instructions fields first
field_names = ["model", "ai_instructions", "context_model"]
field_names += list(ChatOptions.model_fields)
field_names += list(VoiceChatConfigs.model_fields)
field_names = list(dict.fromkeys(field_names))
model_fields = {k: ChatOptions.model_fields[k] for k in field_names}
model_fields = {k: VoiceChatConfigs.model_fields[k] for k in field_names}

updates_to_chat_configs = self._handle_chat_configs_value_selection(
current_chat_configs, model_fields
Expand All @@ -219,7 +222,7 @@ def handle_ui_page_selection(self):
if updates_to_chat_configs:
new_chat_configs = current_chat_configs.model_dump()
new_chat_configs.update(updates_to_chat_configs)
new_chat = Chat.from_dict(new_chat_configs)
new_chat = WebAppChat.from_dict(new_chat_configs)
self.selected_page.chat_obj = new_chat

def render(self, **kwargs):
Expand Down Expand Up @@ -264,7 +267,7 @@ def render(self, **kwargs):
self.state["saved_chats_reloaded"] = True
for cache_dir_path in self.openai_client.saved_chat_cache_paths:
try:
chat = Chat.from_cache(
chat = WebAppChat.from_cache(
cache_dir=cache_dir_path, openai_client=self.openai_client
)
except ValidationError:
Expand Down Expand Up @@ -357,9 +360,9 @@ def _handle_chat_configs_value_selection(self, current_chat_configs, model_field
updates_to_chat_configs = {}
for field_name, field in model_fields.items():
title = field_name.replace("_", " ").title()
choices = ChatOptions.get_allowed_values(field=field_name)
description = ChatOptions.get_description(field=field_name)
field_type = ChatOptions.get_type(field=field_name)
choices = VoiceChatConfigs.get_allowed_values(field=field_name)
description = VoiceChatConfigs.get_description(field=field_name)
field_type = VoiceChatConfigs.get_type(field=field_name)

# Check if the field is frozen and disable corresponding UI element if so
chat_started = self.selected_page.state.get("chat_started", False)
Expand Down
41 changes: 33 additions & 8 deletions pyrobbot/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Optional

import openai
from attr import dataclass
from loguru import logger
from pydub import AudioSegment
from tzlocal import get_localzone
Expand All @@ -24,6 +25,14 @@
from .tokens import PRICE_PER_K_TOKENS_EMBEDDINGS, TokenUsageDatabase


@dataclass
class AssistantResponseChunk:
    """A chunk of the assistant's response."""

    # NOTE(review): the `dataclass` decorator is imported from `attr` (see the
    # file's imports), a legacy alias — `dataclasses.dataclass` is likely what
    # was intended; confirm before changing.
    # The text content of this chunk of the response.
    content: str
    # Tag for the chunk: "code" when the chunk is inside a code block delimited
    # by the chat's code marker, "text" otherwise (see yield_response_from_msg).
    chunk_type: str = "text"


class Chat(AlternativeConstructors):
"""Manages conversations with an AI chat model.
Expand Down Expand Up @@ -52,6 +61,8 @@ def __init__(
self.id = str(uuid.uuid4())
logger.debug("Init chat {}", self.id)

self._code_marker = "\uE001" # TEST

self._passed_configs = configs
for field in self._passed_configs.model_fields:
setattr(self, field, self._passed_configs[field])
Expand All @@ -74,13 +85,14 @@ def __init__(
@property
def base_directive(self):
"""Return the base directive for the LLM."""
code_marker = self._code_marker
local_datetime = datetime.now(get_localzone()).isoformat(timespec="seconds")
msg_content = (
f"Your name is {self.assistant_name}. Your model is {self.model}\n"
f"You are a helpful assistant to {self.username}\n"
f"You have internet access\n"
+ "\n".join([f"{instruct.strip(' .')}." for instruct in self.ai_instructions])
+ "\n"
f"You MUST ALWAYS write {code_marker} before AND after code blocks. Example: "
f"```foo ... ``` MUST become {code_marker}```foo ... ```{code_marker}\n"
f"The current city is {GeneralDefinitions.IPINFO['city']} in "
f"{GeneralDefinitions.IPINFO['country_name']}\n"
f"The local datetime is {local_datetime}\n"
Expand All @@ -95,6 +107,7 @@ def base_directive(self):
" > Do *NOT* apologise nor say you are sorry nor give any excuses.\n"
" > Do *NOT* ask for permission to lookup online.\n"
" > STATE CLEARLY that you will look it up online.\n"
"\n".join([f"{instruct.strip(' .')}." for instruct in self.ai_instructions])
)
return {"role": "system", "name": self.system_name, "content": msg_content}

Expand Down Expand Up @@ -223,22 +236,34 @@ def respond_system_prompt(
self, prompt: str, add_to_history=False, skip_check=True, **kwargs
):
"""Respond to a system prompt."""
yield from self._respond_prompt(
for response_chunk in self._respond_prompt(
prompt=prompt,
role="system",
add_to_history=add_to_history,
skip_check=skip_check,
**kwargs,
)
):
yield response_chunk.content

def yield_response_from_msg(
    self, prompt_msg: dict, add_to_history: bool = True, **kwargs
):
    """Yield response from a prompt message.

    Args:
        prompt_msg (dict): The prompt message to respond to.
        add_to_history (bool): Whether to record the exchange in the chat
            history. Defaults to True.
        **kwargs: Extra keyword args forwarded to `_yield_response_from_msg`.

    Yields:
        AssistantResponseChunk: Response chunks, tagged with chunk_type "code"
            while inside a code block delimited by `self._code_marker` and
            "text" otherwise. On failure, a single chunk with the failure
            message is yielded instead.
    """
    code_marker = self._code_marker
    try:
        inside_code_block = False
        for answer_chunk in self._yield_response_from_msg(
            prompt_msg=prompt_msg, add_to_history=add_to_history, **kwargs
        ):
            code_marker_detected = code_marker in answer_chunk
            # Toggle the in-code-block state on chunks containing the marker;
            # hold the current state otherwise.
            # NOTE(review): assumes at most one marker per chunk — a chunk
            # carrying both opening and closing markers would be mis-tagged.
            inside_code_block = (code_marker_detected and not inside_code_block) or (
                inside_code_block and not code_marker_detected
            )
            yield AssistantResponseChunk(
                content=answer_chunk.strip(code_marker),
                chunk_type="code" if inside_code_block else "text",
            )

    except (ReachedMaxNumberOfAttemptsError, openai.OpenAIError) as error:
        # Surface failures to the caller as a normal chunk instead of raising.
        yield self.response_failure_message(error=error)

Expand All @@ -253,7 +278,7 @@ def start(self):
continue
print(f"{self.assistant_name}> ", end="", flush=True)
for chunk in self.respond_user_prompt(prompt=question):
print(chunk, end="", flush=True)
print(chunk.content, end="", flush=True)
print()
print()
except (KeyboardInterrupt, EOFError):
Expand Down Expand Up @@ -288,7 +313,7 @@ def response_failure_message(self, error: Optional[Exception] = None):
msg += f" The reason seems to be: {error} "
msg += "Please check your connection or OpenAI API key."
logger.opt(exception=True).debug(error)
return msg
return AssistantResponseChunk(msg)

def stt(self, speech: AudioSegment):
"""Convert audio to text."""
Expand Down
3 changes: 3 additions & 0 deletions pyrobbot/chat_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,9 @@ class VoiceAssistantConfigs(BaseConfigModel):
frame_duration: Literal[10, 20, 30] = Field(
default=30, description="Frame duration for audio recording, in milliseconds."
)
reply_only_as_text: Optional[bool] = Field(
default=None, description="Reply only as text. The assistant will not speak."
)
skip_initial_greeting: Optional[bool] = Field(
default=None, description="Skip initial greeting."
)
Expand Down
Loading

0 comments on commit b4f3981

Please sign in to comment.