From 37bf01f90f39bb33026629c1c708d00e892d662e Mon Sep 17 00:00:00 2001 From: Daniel McKnight <34697904+NeonDaniel@users.noreply.github.com> Date: Wed, 8 Nov 2023 09:11:20 -0800 Subject: [PATCH] Threaded input handling and multi-session support (#31) * Prevent sending an input until the previous response has been handled This would ideally use a queue but that will require using a different UI since the gradio ChatBot expects each input to return a value synchronously Relates to #26 * Implement gradio State to track a session ID Update handling so TTS responses are attached to a specific browser session * Implement session-specific profile settings * Add remaining user profile params to UI --------- Co-authored-by: Daniel McKnight --- neon_iris/client.py | 27 ++++++--- neon_iris/web_client.py | 128 ++++++++++++++++++++++++++++++---------- 2 files changed, 115 insertions(+), 40 deletions(-) diff --git a/neon_iris/client.py b/neon_iris/client.py index 3c9bdc5..cb637f7 100644 --- a/neon_iris/client.py +++ b/neon_iris/client.py @@ -38,6 +38,7 @@ from typing import Optional from uuid import uuid4 from ovos_bus_client.message import Message +from ovos_utils.json_helper import merge_dict from pika.exceptions import StreamLostError from neon_utils.configuration_utils import get_neon_user_config from neon_utils.mq_utils import NeonMQHandler @@ -228,27 +229,31 @@ def _clear_audio_cache(): def send_utterance(self, utterance: str, lang: str = "en-us", username: Optional[str] = None, - user_profiles: Optional[list] = None): + user_profiles: Optional[list] = None, + context: Optional[dict] = None): """ Optionally override this to queue text inputs or do any pre-parsing :param utterance: utterance to submit to skills module :param lang: language code associated with request :param username: username associated with request :param user_profiles: user profiles expecting a response + :param context: Optional dict context to add to emitted message """ - self._send_utterance(utterance, lang, username, user_profiles) + self._send_utterance(utterance, lang, username, user_profiles, context) def send_audio(self, audio_file: str, lang: str = "en-us", username: Optional[str] = None, - user_profiles: Optional[list] = None): + user_profiles: Optional[list] = None, + context: Optional[dict] = None): """ Optionally override this to queue audio inputs or do any pre-parsing :param audio_file: path to audio file to send to speech module :param lang: language code associated with request :param username: username associated with request :param user_profiles: user profiles expecting a response + :param context: Optional dict context to add to emitted message """ - self._send_audio(audio_file, lang, username, user_profiles) + self._send_audio(audio_file, lang, username, user_profiles, context) def _build_message(self, msg_type: str, data: dict, username: Optional[str] = None, @@ -267,7 +272,9 @@ def _build_message(self, msg_type: str, data: dict, }) def _send_utterance(self, utterance: str, lang: str, - username: str, user_profiles: list): + username: str, user_profiles: list, + context: Optional[dict] = None): + context = context or dict() username = username or self.default_username user_profiles = user_profiles or [self.user_config] message = self._build_message("recognizer_loop:utterance", @@ -275,11 +282,14 @@ def _send_utterance(self, utterance: str, lang: str, "lang": lang}, username, user_profiles) serialized = {"msg_type": message.msg_type, "data": message.data, - "context": message.context} + "context": merge_dict(message.context, context, + new_only=True)} self._send_serialized_message(serialized) def _send_audio(self, audio_file: str, lang: str, - username: str, user_profiles: list): + username: str, user_profiles: list, + context: Optional[dict] = None): + context = context or dict() audio_data = encode_file_to_base64_string(audio_file) message = self._build_message("neon.audio_input", {"lang": lang, @@ -289,7 +299,8 @@ def _send_audio(self, audio_file: str, lang: str, username, user_profiles) serialized = {"msg_type": message.msg_type, "data": message.data, - "context": message.context} + "context": merge_dict(message.context, context, + new_only=True)} self._send_serialized_message(serialized) def _send_serialized_message(self, serialized: dict): diff --git a/neon_iris/web_client.py b/neon_iris/web_client.py index 7e6f92c..c8ae560 100644 --- a/neon_iris/web_client.py +++ b/neon_iris/web_client.py @@ -27,7 +27,8 @@ from os import makedirs from os.path import isfile, join, isdir from time import time -from typing import List, Optional +from typing import List, Optional, Dict +from uuid import uuid4 import gradio @@ -35,6 +36,8 @@ from ovos_bus_client import Message from ovos_config import Configuration from ovos_utils import LOG +from ovos_utils.json_helper import merge_dict + from neon_utils.file_utils import decode_base64_string_to_file from ovos_utils.xdg_utils import xdg_data_home @@ -50,12 +53,15 @@ def __init__(self, lang: str = None): NeonAIClient.__init__(self, config.get("MQ")) self._await_response = Event() self._response = None - self._current_tts = None + self._current_tts = dict() + self._profiles: Dict[str, dict] = dict() self._audio_path = join(xdg_data_home(), "iris", "stt") if not isdir(self._audio_path): makedirs(self._audio_path) self.default_lang = lang or self.config.get('default_lang') self.chat_ui = gradio.Blocks() + LOG.name = "iris" + LOG.init(self.config.get("logs")) @property def lang(self): @@ -69,24 +75,52 @@ def supported_languages(self) -> List[str]: """ return self.config.get('languages') or [self.default_lang] - def update_profile(self, stt_lang: str, tts_lang: str, tts_lang_2: str): + def _start_session(self): + sid = uuid4().hex + self._current_tts[sid] = None + self._profiles[sid] = self.user_config + self._profiles[sid]['user']['username'] = sid + return sid + + def update_profile(self, stt_lang: str, tts_lang: str, tts_lang_2: str, + time: int, date: str, uom: str, city: str, state: str, + country: str, first: str, middle: str, last: str, + pref_name: str, email: str, session_id: str): """ Callback to handle user settings changes from the web UI """ - # TODO: Per-client config. The current method of referencing - # `self._user_config` means every user shares one configuration which - # does not scale. This client should probably override the - # `self.user_config` property and implement a method for storing user - # configuration in cookies or similar. + location_dict = dict() + if any((city, state, country)): + from neon_utils.location_utils import get_coordinates, get_timezone + try: + location_dict = {"city": city, "state": state, + "country": country} + lat, lon = get_coordinates(location_dict) + location_dict["lat"] = lat + location_dict["lng"] = lon + location_dict["tz"], location_dict["utc"] = get_timezone(lat, + lon) + LOG.debug(f"Got location update: {location_dict}") + except Exception as e: + LOG.exception(e) + profile_update = {"speech": {"stt_language": stt_lang, "tts_language": tts_lang, - "secondary_tts_language": tts_lang_2}} - from neon_utils.user_utils import apply_local_user_profile_updates - apply_local_user_profile_updates(profile_update, self._user_config) + "secondary_tts_language": tts_lang_2}, + "units": {"time": time, "date": date, "measure": uom}, + "location": location_dict, + "user": {"first_name": first, "middle_name": middle, + "last_name": last, + "preferred_name": pref_name, "email": email}} + old_profile = self._profiles.get(session_id) or self.user_config + self._profiles[session_id] = merge_dict(old_profile, profile_update) + LOG.info(f"Updated profile for: {session_id}") + return session_id def send_audio(self, audio_file: str, lang: str = "en-us", username: Optional[str] = None, - user_profiles: Optional[list] = None): + user_profiles: Optional[list] = None, + context: Optional[dict] = None): """ @param audio_file: path to wav audio file to send to speech module @param lang: language code associated with request @@ -95,7 +129,7 @@ def send_audio(self, audio_file: str, lang: str = "en-us", """ # TODO: Audio conversion is really slow here. check ovos-stt-http-server audio_file = self.convert_audio(audio_file) - self._send_audio(audio_file, lang, username, user_profiles) + self._send_audio(audio_file, lang, username, user_profiles, context) def convert_audio(self, audio_file: str, target_sr=16000, target_channels=1, dtype='int16') -> str: @@ -128,29 +162,37 @@ def on_user_input(self, utterance: str, *args, **kwargs) -> str: @param utterance: String utterance submitted by the user @returns: String response from Neon (or "ERROR") """ - # TODO: This should probably queue with a separate iterator thread + LOG.debug(f"Input received") + if not self._await_response.wait(30): + LOG.error("Previous response not completed after 30 seconds") LOG.debug(f"args={args}|kwargs={kwargs}") self._await_response.clear() self._response = None + gradio_id = args[2] if utterance: LOG.info(f"Sending utterance: {utterance} with lang: {self.lang}") - self.send_utterance(utterance, self.lang) + self.send_utterance(utterance, self.lang, username=gradio_id, + user_profiles=[self._profiles[gradio_id]], + context={"gradio": {"session": gradio_id}}) else: LOG.info(f"Sending audio: {args[1]} with lang: {self.lang}") - self.send_audio(args[1], self.lang) + self.send_audio(args[1], self.lang, username=gradio_id, + user_profiles=[self._profiles[gradio_id]], + context={"gradio": {"session": gradio_id}}) self._await_response.wait(30) self._response = self._response or "ERROR" LOG.info(f"Got response={self._response}") return self._response - def play_tts(self): + def play_tts(self, session_id: str): LOG.info(f"Playing most recent TTS file {self._current_tts}") - return self._current_tts + return self._current_tts.get(session_id), session_id def run(self): """ Blocking method to start the web server """ + self._await_response.set() title = self.config.get("webui_title", "Neon AI") description = self.config.get("webui_description", "Chat With Neon") chatbot = self.config.get("webui_chatbot_label") or description @@ -164,6 +206,8 @@ def run(self): textbox = gradio.Textbox(placeholder=placeholder) with self.chat_ui as blocks: + client_session = gradio.State(self._start_session()) + client_session.attach_load_event(self._start_session, None) # Define primary UI audio_input = gradio.Audio(source="microphone", type="filepath", @@ -171,7 +215,7 @@ def run(self): gradio.ChatInterface(self.on_user_input, chatbot=chatbot, textbox=textbox, - additional_inputs=[audio_input], + additional_inputs=[audio_input, client_session], title=title, retry_btn=None, undo_btn=None, ) @@ -179,7 +223,8 @@ def run(self): label="Neon's Response") tts_button = gradio.Button("Play TTS") tts_button.click(self.play_tts, - outputs=[tts_audio]) + inputs=[client_session], + outputs=[tts_audio, client_session]) # Define settings UI with gradio.Row(): with gradio.Column(): @@ -193,18 +238,36 @@ def run(self): choices=[None] + self.supported_languages, value=None) - submit = gradio.Button("Update User Settings") with gradio.Column(): - # TODO: Unit settings - pass + time_format = gradio.Radio(label="Time Format", + choices=[12, 24], + value=12) + date_format = gradio.Radio(label="Date Format", + choices=["MDY", "YMD", "DMY", + "YDM"], + value="MDY") + unit_of_measure = gradio.Radio(label="Units of Measure", + choices=["imperial", + "metric"], + value="imperial") with gradio.Column(): - # TODO: Location settings - pass + city = gradio.Textbox(label="City") + state = gradio.Textbox(label="State") + country = gradio.Textbox(label="Country") with gradio.Column(): - # TODO Name settings - pass + first_name = gradio.Textbox(label="First Name") + middle_name = gradio.Textbox(label="Middle Name") + last_name = gradio.Textbox(label="Last Name") + pref_name = gradio.Textbox(label="Preferred Name") + email_addr = gradio.Textbox(label="Email Address") + # TODO: DoB, pic, about, phone? + submit = gradio.Button("Update User Settings") submit.click(self.update_profile, - inputs=[stt_lang, tts_lang, tts_lang_2]) + inputs=[stt_lang, tts_lang, tts_lang_2, time_format, + date_format, unit_of_measure, city, state, + country, first_name, middle_name, last_name, + pref_name, email_addr, client_session], + outputs=[client_session]) blocks.launch(server_name=address, server_port=port) def handle_klat_response(self, message: Message): @@ -213,19 +276,20 @@ def handle_klat_response(self, message: Message): audio in all requested languages. @param message: Neon response message """ - LOG.debug(f"Response_data={message.data}") + LOG.debug(f"gradio context={message.context['gradio']}") resp_data = message.data["responses"] files = [] sentences = [] + session = message.context['gradio']['session'] for lang, response in resp_data.items(): sentences.append(response.get("sentence")) if response.get("audio"): for gender, data in response["audio"].items(): filepath = "/".join([self.audio_cache_dir] + response[gender].split('/')[-4:]) - # TODO: This only plays the most recent, so it doesn't support - # multiple languages - self._current_tts = filepath + # TODO: This only plays the most recent, so it doesn't + # support multiple languages or multi-utterance responses + self._current_tts[session] = filepath files.append(filepath) if not isfile(filepath): decode_base64_string_to_file(data, filepath)