Skip to content

Commit

Permalink
Exploring ways to keep audios in chat history
Browse files Browse the repository at this point in the history
  • Loading branch information
paulovcmedeiros committed Feb 23, 2024
1 parent b4f3981 commit 54fdd96
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 9 deletions.
18 changes: 12 additions & 6 deletions pyrobbot/app/app_page_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,16 @@ def __init__(self, **kwargs):
super().__init__(**kwargs)
self.tts_conversion_watcher_thread.start()

def speak(self, tts: TextToSpeech):
def speak(self, tts: TextToSpeech, autoplay: bool = True):
"""Autoplay an audio segment in the streamlit app."""
# Adapted from: <https://discuss.streamlit.io/t/
# how-to-play-an-audio-file-automatically-generated-using-text-to-speech-
# in-streamlit/33201/2>
autoplay = "true" if autoplay else "false"
data = tts.speech.export(format="mp3").read()
b64 = base64.b64encode(data).decode()
md = f"""
<audio controls autoplay="true">
<audio controls autoplay="{autoplay}">
<source src="data:audio/mp3;base64,{b64}" type="audio/mp3">
</audio>
"""
Expand Down Expand Up @@ -201,6 +202,8 @@ def render_chat_history(self):
with contextlib.suppress(KeyError):
st.caption(message["timestamp"])
st.markdown(message["content"])
with contextlib.suppress(KeyError):
st.audio(message["assistant_reply_audio_file"], format="audio/mp3")

def render_cost_estimate_page(self):
"""Render the estimated costs information in the chat."""
Expand Down Expand Up @@ -262,15 +265,18 @@ def _render_chatbot_page(self):
with st.empty():
st.markdown("▌")
full_response = ""
# When the chat object answers the user's question, it will
# put the response in the tts queue, then in the play speech
# queue, asynchronously
for chunk in self.chat_obj.answer_question(prompt):
full_response += chunk.content
st.markdown(full_response + "▌")
st.caption(datetime.datetime.now().replace(microsecond=0))
st.markdown(full_response)
if mic_input:
while not self.chat_obj.play_speech_queue.empty():
self.chat_obj.speak(self.chat_obj.play_speech_queue.get())
self.chat_obj.play_speech_queue.task_done()
while not self.chat_obj.play_speech_queue.empty():
self.chat_obj.speak(self.chat_obj.play_speech_queue.get())
self.chat_obj.play_speech_queue.task_done()

prompt = None

self.chat_history.append(
Expand Down
6 changes: 4 additions & 2 deletions pyrobbot/chat_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def add_to_history(self, msg_list: list[dict]):

def load_history(self) -> list[dict]:
"""Load the chat history."""
selected_columns = ["timestamp", "message_exchange"]
selected_columns = ["timestamp", "message_exchange", "assistant_reply_audio_file"]
messages_df = self.database.get_messages_dataframe()[selected_columns]

# Convert unix timestamps to datetime objs at the local timezone
Expand All @@ -60,9 +60,11 @@ def load_history(self) -> list[dict]:
)

msg_exchanges = messages_df["message_exchange"].apply(ast.literal_eval).tolist()
# Add timestamps to messages
# Add timestamps and path to eventual audio files to messages
for i_msg_exchange, timestamp in enumerate(messages_df["timestamp"]):
msg_exchanges[i_msg_exchange][0]["timestamp"] = timestamp
path = messages_df["assistant_reply_audio_file"].iloc[i_msg_exchange]
msg_exchanges[i_msg_exchange][1]["assistant_reply_audio_file"] = path

return list(itertools.chain.from_iterable(msg_exchanges))

Expand Down
19 changes: 18 additions & 1 deletion pyrobbot/embeddings_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def create(self):
timestamp INTEGER NOT NULL,
chat_model TEXT NOT NULL,
message_exchange TEXT NOT NULL,
assistant_reply_audio_file TEXT,
embedding TEXT
)
"""
Expand All @@ -65,7 +66,10 @@ def create(self):
conn.execute(
"""
CREATE TRIGGER IF NOT EXISTS prevent_messages_modification
BEFORE UPDATE ON messages
BEFORE UPDATE OF
timestamp, chat_model, message_exchange, embedding
ON
messages
BEGIN
SELECT RAISE(FAIL, 'modification not allowed');
END;
Expand Down Expand Up @@ -126,6 +130,19 @@ def insert_message_exchange(self, chat_model, message_exchange, embedding):
conn.execute(sql, (timestamp, chat_model, message_exchange, embedding))
conn.close()

def update_last_message_exchange_with_audio(self, assistant_reply_audio_file: Path):
    """Attach an audio file path to the latest row of the 'messages' table.

    Only the `assistant_reply_audio_file` column of the row with the highest
    rowid is written. If the table has no rows, the UPDATE matches nothing.

    Args:
        assistant_reply_audio_file (Path): Path to the assistant's reply audio file.
    """
    sql = (
        "UPDATE messages SET assistant_reply_audio_file = ? WHERE "
        "rowid = (SELECT MAX(rowid) FROM messages);"
    )
    conn = sqlite3.connect(self.db_path)
    try:
        # `with conn` commits on success and rolls back on error; it does
        # not close the connection.
        with conn:
            conn.execute(sql, (assistant_reply_audio_file.as_posix(),))
    finally:
        # Release the connection even when the UPDATE fails.
        conn.close()

def get_messages_dataframe(self):
"""Retrieve msg exchanges from the `messages` table. Return them as a DataFrame.
Expand Down
30 changes: 30 additions & 0 deletions pyrobbot/voice_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ def __init__(self, configs: VoiceChatConfigs = default_configs, **kwargs):
self.interrupt_reply = threading.Event()
self.exit_chat = threading.Event()

self.current_answer_audios_queue = queue.Queue()

def start(self):
"""Start the chat."""
# ruff: noqa: T201
Expand Down Expand Up @@ -158,6 +160,9 @@ def answer_question(self, question: str):
"""Answer a question."""
logger.debug("{}> Getting response to '{}'...", self.assistant_name, question)
sentence_for_tts = ""
with self.current_answer_audios_queue.mutex:
self.current_answer_audios_queue.queue.clear()

for answer_chunk in self.respond_user_prompt(prompt=question):
if self.interrupt_reply.is_set() or self.exit_chat.is_set():
raise StopIteration
Expand All @@ -181,6 +186,22 @@ def answer_question(self, question: str):
if sentence_for_tts and not self.reply_only_as_text:
self.tts_conversion_queue.put(sentence_for_tts)

# Merge all AudioSegments in self.current_answer_audios_queue into a single one
self.tts_conversion_queue.join()
merged_audio = AudioSegment.empty()
while not self.current_answer_audios_queue.empty():
merged_audio += self.current_answer_audios_queue.get()
self.current_answer_audios_queue.task_done()

# Save the combined audio as an mp3 file in the cache directory
audio_file_path = self.audio_cache_dir() / f"{datetime.now().isoformat()}.mp3"
merged_audio.export(audio_file_path, format="mp3")

# Update the chat history with the audio file path
self.context_handler.database.update_last_message_exchange_with_audio(
assistant_reply_audio_file=audio_file_path
)

def speak(self, tts: TextToSpeech):
"""Reproduce audio from a pygame Sound object."""
tts.set_sample_rate(self.sample_rate)
Expand Down Expand Up @@ -371,6 +392,9 @@ def handle_tts_queue(self, text_queue: queue.Queue):

# Dispatch the audio to be played
self.play_speech_queue.put(tts)

# Keep track of audios played for the current answer
self.current_answer_audios_queue.put(tts.speech)
except Exception as error: # noqa: PERF203, BLE001
logger.opt(exception=True).debug(error)
logger.error(error)
Expand All @@ -388,6 +412,12 @@ def get_sound_file(self, wav_buffer: io.BytesIO, mode: str = "r"):
subtype="PCM_16",
)

def audio_cache_dir(self):
    """Return the 'audio_files' directory under the cache dir, creating it if needed."""
    audio_dir = self.cache_dir / "audio_files"
    # Create missing parents too; an already existing directory is not an error.
    audio_dir.mkdir(exist_ok=True, parents=True)
    return audio_dir

def _assistant_still_replying(self):
"""Check if the assistant is still talking."""
return (
Expand Down

0 comments on commit 54fdd96

Please sign in to comment.