Refactoring in webchat app code
paulovcmedeiros committed Mar 2, 2024
1 parent 14224c8 commit 6c36fe4
Showing 4 changed files with 245 additions and 228 deletions.
247 changes: 29 additions & 218 deletions pyrobbot/app/app_page_templates.py
@@ -4,7 +4,6 @@
import contextlib
import datetime
import queue
import threading
import time
import uuid
from abc import ABC, abstractmethod
@@ -14,34 +13,27 @@
import streamlit as st
from audio_recorder_streamlit import audio_recorder
from loguru import logger
from PIL import Image
from pydub import AudioSegment
from pydub.exceptions import CouldntDecodeError
from streamlit.runtime.scriptrunner import add_script_run_ctx
from streamlit_mic_recorder import mic_recorder

from pyrobbot import GeneralDefinitions
from pyrobbot.chat_configs import VoiceChatConfigs
from pyrobbot.voice_chat import VoiceChat

if TYPE_CHECKING:
from .app_utils import (
AsyncReplier,
WebAppChat,
filter_page_info_from_queue,
get_avatar_images,
load_chime,
)

from pyrobbot.app.multipage import MultipageChatbotApp
if TYPE_CHECKING:
from .multipage import MultipageChatbotApp

# Sentinel object for when a chat is recovered from cache
_RecoveredChat = object()


class WebAppChat(VoiceChat):
"""A chat object for web apps."""

def __init__(self, **kwargs):
"""Initialize a new instance of the WebAppChat class."""
super().__init__(**kwargs)
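# Start the background threads that watch for TTS conversions and audio-history updates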
self.tts_conversion_watcher_thread.start()
self.handle_update_audio_history_thread.start()


class AppPage(ABC):
"""Abstract base class for a page within a streamlit application."""

@@ -127,9 +119,6 @@ def manual_switch_mic_recorder(self):
stop_prompt=red_square,
just_once=True,
use_container_width=True,
callback=None,
args=(),
kwargs={},
)

if recording is None:
@@ -297,7 +286,7 @@ def voice_output(self) -> bool:

def play_chime(self, chime_type: str = "correct-answer-tone", parent_element=None):
"""Sound a chime to send notificatons to the user."""
chime = _load_chime(chime_type)
chime = load_chime(chime_type)
self.render_custom_audio_player(
chime, hidden=True, autoplay=True, parent_element=parent_element
)
@@ -318,18 +307,15 @@ def direct_text_prompt(self):
placeholder = (
f"Send a message to {self.chat_obj.assistant_name} ({self.chat_obj.model})"
)
text_from_manual_audio_recorder = ""
with st.container():
left, right = st.columns([0.95, 0.05])
with left:
text_from_chat_input_widget = st.chat_input(placeholder=placeholder)
with right:
text_from_manual_audio_recorder = ""
if st.session_state.get("toggle_continuous_voice_input"):
st.empty()
else:
text_from_manual_audio_recorder = self.chat_obj.stt(
self.manual_switch_mic_recorder()
).text
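# When continuous voice input is off, fall back to the manual mic recorder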
if not st.session_state.get("toggle_continuous_voice_input"):
audio = self.manual_switch_mic_recorder()
text_from_manual_audio_recorder = self.chat_obj.stt(audio).text
return text_from_chat_input_widget or text_from_manual_audio_recorder

@property
@@ -350,11 +336,11 @@ def continuous_text_prompt(self):
self.play_chime()
with st.spinner(f"{self.chat_obj.assistant_name} is listening..."):
while True:
with self.parent.text_prompt_queue.mutex:
this_page_prompt_queue = filter_page_info_from_queue(
app_page=self, the_queue=self.parent.text_prompt_queue
)
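# Look for prompts addressed to this specific page in the shared queue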
with contextlib.suppress(queue.Empty):
with self.parent.text_prompt_queue.mutex:
this_page_prompt_queue = filter_page_info_from_queue(
app_page=self, the_queue=self.parent.text_prompt_queue
)
if prompt := this_page_prompt_queue.get_nowait()["text"]:
this_page_prompt_queue.task_done()
break
@@ -372,27 +358,23 @@ def _render_chatbot_page(self): # noqa: PLR0915
"""
self.chat_obj.reply_only_as_text = not self.voice_output
question_answer_chunks_queue = queue.Queue()
partial_audios_queue = queue.Queue()

self.render_title()
chat_msgs_container = st.container(height=600, border=False)
with chat_msgs_container:
self.render_chat_history()

# The inputs should be rendered after the chat history. There is a performance
# penalty otherwise, as rendering the history causes streamlit to rerun the
# entire page
direct_text_prompt = self.direct_text_prompt
continuous_stt_prompt = self.continuous_text_prompt
continuous_stt_prompt = "" if direct_text_prompt else self.continuous_text_prompt
prompt = direct_text_prompt or continuous_stt_prompt

if prompt:
logger.opt(colors=True).debug("<yellow>Recived prompt: {}</yellow>", prompt)
self.parent.reply_ongoing.set()

# Interrupt any ongoing reply in this page
with question_answer_chunks_queue.mutex:
question_answer_chunks_queue.queue.clear()
with partial_audios_queue.mutex:
partial_audios_queue.queue.clear()

if continuous_stt_prompt:
self.play_chime("option-select")
self.status_msg_container.success("Got your message!")
@@ -425,103 +407,18 @@ def _render_chatbot_page(self): # noqa: PLR0915

# Display (stream) assistant response in chat message container
with st.chat_message("assistant", avatar=self.avatars["assistant"]):
text_reply_container = st.empty()
audio_reply_container = st.empty()

# Create threads to process text and audio replies asynchronously
answer_question_thread = threading.Thread(
target=_put_chat_reply_chunks_in_queue,
args=(self.chat_obj, prompt, question_answer_chunks_queue),
)
play_partial_audios_thread = threading.Thread(
target=_play_queued_audios,
args=(
partial_audios_queue,
self.render_custom_audio_player,
self.status_msg_container,
),
daemon=False,
)
for thread in (
answer_question_thread,
play_partial_audios_thread,
):
add_script_run_ctx(thread)
thread.start()

# Render the reply
chunk = ""
full_response = ""
current_audio = AudioSegment.empty()
text_reply_container.markdown("▌")
self.status_msg_container.empty()
while (chunk is not None) or (current_audio is not None):
logger.trace("Waiting for text or audio chunks...")
# Render text
with contextlib.suppress(queue.Empty):
chunk = question_answer_chunks_queue.get_nowait()
if chunk is not None:
full_response += chunk
text_reply_container.markdown(full_response + "▌")
question_answer_chunks_queue.task_done()

# Render audio (if any)
with contextlib.suppress(queue.Empty):
current_audio = (
self.chat_obj.play_speech_queue.get_nowait()
)
self.chat_obj.play_speech_queue.task_done()
if current_audio is None:
partial_audios_queue.put(None)
else:
partial_audios_queue.put(current_audio.speech)

logger.opt(colors=True).debug(
"<yellow>Replied to user prompt '{}': {}</yellow>",
prompt,
full_response,
)
text_reply_container.caption(
datetime.datetime.now().replace(microsecond=0)
)
text_reply_container.markdown(full_response)

while play_partial_audios_thread.is_alive():
logger.trace(
"Waiting for partial audios to finish playing..."
)
time.sleep(0.1)

logger.debug("Getting path to full audio file...")
try:
full_audio_fpath = (
self.chat_obj.last_answer_full_audio_fpath.get(timeout=2)
)
except queue.Empty:
full_audio_fpath = None
logger.warning("Problem getting path to full audio file")
else:
logger.debug(
"Got path to full audio file: {}", full_audio_fpath
)
self.chat_obj.last_answer_full_audio_fpath.task_done()

# Process text and audio replies asynchronously
replier = AsyncReplier(self, prompt)
reply = replier.stream_text_and_audio_reply()
self.chat_history.append(
{
"role": "assistant",
"name": self.chat_obj.assistant_name,
"content": full_response,
"assistant_reply_audio_file": full_audio_fpath,
"content": reply["text"],
"assistant_reply_audio_file": reply["audio"],
}
)

if full_audio_fpath:
self.render_custom_audio_player(
full_audio_fpath,
parent_element=audio_reply_container,
autoplay=False,
)

# Reset the title according to the conversation's initial contents
min_history_len_for_summary = 3
if (
Expand Down Expand Up @@ -550,6 +447,7 @@ def _render_chatbot_page(self): # noqa: PLR0915
app_page=self, the_queue=self.parent.text_prompt_queue
)

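# Wait for the replier to finish its background work before clearing the reply flag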
replier.join()
self.parent.reply_ongoing.clear()

if continuous_stt_prompt and not self.parent.reply_ongoing.is_set():
@@ -580,90 +478,3 @@ def _trim_page_padding():
else:
self._render_chatbot_page()
logger.debug("Reached the end of the chatbot page.")


def filter_page_info_from_queue(app_page: AppPage, the_queue: queue.Queue):
"""Filter `app_page`'s data from `queue` inplace. Return queue of items in `app_page`.
**Use with original_queue.mutex!!**
Args:
app_page: The page whose entries should be removed.
the_queue: The queue to be filtered.
Returns:
queue.Queue: The queue with only the entries from `app_page`.
Example:
```
with the_queue.mutex:
this_page_data = remove_page_info_from_queue(app_page, the_queue)
```
"""
queue_with_only_entries_from_other_pages = queue.Queue()
items_from_page_queue = queue.Queue()
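# Drain the original queue, separating this page's entries from everyone else's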
while the_queue.queue:
original_queue_entry = the_queue.queue.popleft()
if original_queue_entry["page"].page_id == app_page.page_id:
items_from_page_queue.put(original_queue_entry)
else:
queue_with_only_entries_from_other_pages.put(original_queue_entry)

the_queue.queue = queue_with_only_entries_from_other_pages.queue
return items_from_page_queue


@st.cache_data
def get_avatar_images():
"""Return the avatar images for the assistant and the user."""
avatar_files_dir = GeneralDefinitions.APP_DIR / "data"
assistant_avatar_file_path = avatar_files_dir / "assistant_avatar.png"
user_avatar_file_path = avatar_files_dir / "user_avatar.png"
assistant_avatar_image = Image.open(assistant_avatar_file_path)
user_avatar_image = Image.open(user_avatar_file_path)

return {"assistant": assistant_avatar_image, "user": user_avatar_image}


@st.cache_data
def _load_chime(chime_type: str) -> AudioSegment:
"""Load a chime sound from the data directory."""
type2filename = {
"correct-answer-tone": "mixkit-correct-answer-tone-2870.wav",
"option-select": "mixkit-interface-option-select-2573.wav",
}

return AudioSegment.from_file(
GeneralDefinitions.APP_DIR / "data" / type2filename[chime_type],
format="wav",
)


def _put_chat_reply_chunks_in_queue(chat_obj, prompt, question_answer_chunks_queue):
"""Get chunks of the reply to the prompt and put them in the queue."""
for chunk in chat_obj.answer_question(prompt):
question_answer_chunks_queue.put(chunk.content)
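# A None sentinel tells consumers that the reply is complete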
question_answer_chunks_queue.put(None)


def _play_queued_audios(audios_queue, audio_player_rendering_function, parent_element):
"""Play queued audio segments."""
logger.debug("Playing queued audios...")
while True:
try:
audio = audios_queue.get()
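# A None entry marks the end of the queued audio stream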
if audio is None:
audios_queue.task_done()
break

logger.debug("Playing partial audio...")
audio_player_rendering_function(
audio, parent_element=parent_element, autoplay=True, hidden=True
)
parent_element.empty()
audios_queue.task_done()
except Exception as error: # noqa: BLE001
logger.opt(exception=True).debug("Error playing partial audio.")
logger.error(error)
break
logger.debug("Done playing queued audios.")