From f44e62bb419f5c24006f7c26ed022909166880bc Mon Sep 17 00:00:00 2001
From: "David L. Qiu"
Date: Tue, 31 Dec 2024 12:50:53 -0800
Subject: [PATCH] Reply gracefully when chat model is not selected (#1183)

* add *.chat files to gitignore

* gracefully handle new messages without a selected chat model

* pre-commit
---
 .gitignore                           |  1 +
 .../jupyter_ai/chat_handlers/base.py | 24 +++++++++++++------
 2 files changed, 18 insertions(+), 7 deletions(-)

diff --git a/.gitignore b/.gitignore
index 0f1f752b4..0a70ee9a6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -138,3 +138,4 @@ dev.sh
 # Version files are auto-generated by Hatchling and should not be committed to
 # the source repo.
 packages/**/_version.py
+*.chat
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
index 327ff5965..07fad3304 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
@@ -171,16 +171,26 @@ async def on_message(self, message: Message):
         """
         Method which receives a human message, calls `self.get_llm_chain()`,
         and processes the message via `self.process_message()`, calling
-        `self.handle_exc()` when an exception is raised. This method is called
-        by RootChatHandler when it routes a human message to this chat handler.
+        `self.handle_exc()` when an exception is raised.
+
+        This is the method called directly in response to new chat messages.
         """
-        lm_provider_klass = self.config_manager.lm_provider
+        ChatModelProvider = self.config_manager.lm_provider
+
+        # first, ensure a chat model is configured
+        if not ChatModelProvider:
+            # TODO: update this message to be more useful once we improve
+            # ease-of-access to the Jupyter AI settings.
+            self.reply(
+                "To use Jupyter AI, please select a chat model first in the Jupyter AI settings."
+            )
+            return
 
         # ensure the current slash command is supported
         if self.routing_type.routing_method == "slash_command":
             routing_type = cast(SlashCommandRoutingType, self.routing_type)
             slash_command = "/" + routing_type.slash_id if routing_type.slash_id else ""
-            if slash_command in lm_provider_klass.unsupported_slash_commands:
+            if slash_command in ChatModelProvider.unsupported_slash_commands:
                 self.reply(
                     "Sorry, the selected language model does not support this slash command.",
                 )
@@ -188,10 +198,10 @@ async def on_message(self, message: Message):
 
         # check whether the configured LLM can support a request at this time.
         if self.uses_llm and BaseChatHandler._requests_count > 0:
-            lm_provider_params = self.config_manager.lm_provider_params
-            lm_provider = lm_provider_klass(**lm_provider_params)
+            chat_model_args = self.config_manager.lm_provider_params
+            chat_model = ChatModelProvider(**chat_model_args)
 
-            if not lm_provider.allows_concurrency:
+            if not chat_model.allows_concurrency:
                 self.reply(
                     "The currently selected language model can process only one request at a time. Please wait for me to reply before sending another question.",
                     message,