From 3b49823a388a9a4fd898125da0a37b8d35f3e688 Mon Sep 17 00:00:00 2001
From: Philipp Rudiger
Date: Fri, 8 Nov 2024 16:22:50 +0100
Subject: [PATCH] Simplify Agent.respond signature (#742)

Co-authored-by: Andrew Huang
---
 lumen/ai/agents.py    | 209 ++++++++++++++++++++++--------------------
 lumen/ai/assistant.py |  35 ++++---
 2 files changed, 135 insertions(+), 109 deletions(-)

diff --git a/lumen/ai/agents.py b/lumen/ai/agents.py
index b2276ad74..8227a1c44 100644
--- a/lumen/ai/agents.py
+++ b/lumen/ai/agents.py
@@ -4,7 +4,7 @@
 import re
 import textwrap
 
-from typing import Literal
+from typing import Any, Literal
 
 import pandas as pd
 import panel as pn
@@ -13,6 +13,7 @@
 from instructor.retry import InstructorRetryException
 from panel.chat import ChatInterface
+from panel.layout import Column
 from panel.viewable import Viewer
 from pydantic import BaseModel, create_model
 from pydantic.fields import FieldInfo
@@ -50,28 +51,40 @@ class Agent(Viewer):
     embeddings.
     """
 
-    debug = param.Boolean(default=False)
+    debug = param.Boolean(default=False, doc="""
+        Whether to enable verbose error reporting.""")
 
-    embeddings = param.ClassSelector(class_=Embeddings)
+    embeddings = param.ClassSelector(class_=Embeddings, doc="""
+        Embeddings object which is queried to provide additional context
+        before asking the LLM to respond.""")
 
-    interface = param.ClassSelector(class_=ChatInterface)
+    interface = param.ClassSelector(class_=ChatInterface, doc="""
+        The ChatInterface to report progress to.""")
 
-    llm = param.ClassSelector(class_=Llm)
+    llm = param.ClassSelector(class_=Llm, doc="""
+        The LLM implementation to query.""")
 
-    system_prompt = param.String()
+    steps_layout = param.ClassSelector(default=None, class_=Column, allow_None=True, doc="""
+        The layout progress updates will be streamed to.""")
 
-    response_model = param.ClassSelector(class_=BaseModel, is_instance=False)
+    provides = param.List(default=[], readonly=True, doc="""
+        List of context values this Agent provides to current working memory.""")
 
-    user = param.String(default="Agent")
+    requires = param.List(default=[], readonly=True, doc="""
+        List of context that this Agent requires to be in memory.""")
 
-    requires = param.List(default=[], readonly=True)
+    response_model = param.ClassSelector(class_=BaseModel, is_instance=False, doc="""
+        A Pydantic model determining the schema of the response.""")
 
-    provides = param.List(default=[], readonly=True)
+    system_prompt = param.String(doc="The system prompt.")
 
-    _steps_layout = param.ClassSelector(default=None, class_=pn.Column)
+    user = param.String(default="Agent", doc="""
+        The name of the user that will respond to the user query.""")
 
+    # Panel extensions this agent requires to be loaded
     _extensions = ()
 
+    # Maximum width of the output
     _max_width = 1200
 
     __abstract = True
@@ -100,13 +113,6 @@ def _exception_handler(exception):
         else:
             state.config.raise_with_notifications = True
 
-    @classmethod
-    async def applies(cls) -> bool:
-        """
-        Additional checks to determine if the agent should be used.
-        """
-        return True
-
     async def _interface_callback(self, contents: list | str, user: str, instance: ChatInterface):
         await self.respond(contents)
         self._retries_left = 1
@@ -180,30 +186,51 @@ async def _select_table(self, tables):
             self.interface.pop(-1)
         return tables
 
+    # Public API
+
+    @classmethod
+    async def applies(cls) -> bool:
+        """
+        Additional checks to determine if the agent should be used.
+ """ + return True + async def requirements(self, messages: list[Message]) -> list[str]: return self.requires async def respond( - self, messages: list[Message], - title: str = "", - render_output: bool = True, - steps_layout: pn.Column | None = None - ) -> None: - self._steps_layout = steps_layout - + self, + messages: list[Message], + render_output: bool = False, + step_title: str | None = None + ) -> Any: + """ + Provides a response to the user query. + + The type of the response may be a simple string or an object. + + Arguments + --------- + messages: list[Message] + The list of messages corresponding to the user query and any other + system messages to be included. + render_output: bool + Whether to render the output to the chat interface. + step_title: str | None + If the Agent response is part of a longer query this describes + the step currently being processed. + """ system_prompt = await self._system_prompt_with_context(messages) response = self.llm.stream( messages, system=system_prompt, response_model=self.response_model, field="output" ) - - if not render_output: - return - - message = None - async for output_chunk in response: - message = self.interface.stream( - output_chunk, replace=True, message=message, user=self.user, max_width=self._max_width - ) + if render_output: + message = None + async for output_chunk in response: + message = self.interface.stream( + output_chunk, replace=True, message=message, user=self.user, max_width=self._max_width + ) + return response class SourceAgent(Agent): @@ -223,18 +250,17 @@ class SourceAgent(Agent): _extensions = ('filedropper',) async def respond( - self, messages: list, - title: str = "", - render_output: bool = True, - steps_layout: pn.Column | None = None - ) -> None: - self._steps_layout = steps_layout - if not render_output: - return - + self, + messages: list[Message], + render_output: bool = False, + step_title: str | None = None + ) -> Any: source_controls = SourceControls(multiple=True, replace_controls=True, select_existing=False) - while not source_controls._add_button.clicks > 0: - await asyncio.sleep(0.05) + if render_output: + self.interface.send(source_controls, respond=False, user="SourceAgent") + while not source_controls._add_button.clicks > 0: + await asyncio.sleep(0.05) + return source_controls class ChatAgent(Agent): @@ -266,7 +292,7 @@ async def requirements(self, messages: list[Message], errors=None): available_sources = memory["available_sources"] _, tables_schema_str = await gather_table_sources(available_sources) - with self.interface.add_step(title="Checking if data is required", steps_layout=self._steps_layout) as step: + with self.interface.add_step(title="Checking if data is required", steps_layout=self.steps_layout) as step: response = self.llm.stream( messages, system=( @@ -371,12 +397,8 @@ def _render_lumen( self, component: Component, message: pn.chat.ChatMessage = None, - render_output: bool = True, **kwargs ): - if not render_output: - return - out = self._output_type(component=component, **kwargs) message_kwargs = dict(value=out, user=self.user) self.interface.stream(message=message, **message_kwargs, replace=True, max_width=self._max_width) @@ -411,20 +433,12 @@ def _use_table(self, event): async def respond( self, messages: list[Message], - title: str = "", - render_output: bool = True, - steps_layout: pn.Column | None = None - ) -> None: - self._steps_layout = steps_layout + render_output: bool = False, + step_title: str | None = None + ) -> Any: tables = [] for source in 
             tables += source.get_tables()
-        if not tables:
-            return
-
-        if not render_output:
-            return
-
         self._df = pd.DataFrame({"Table": tables})
         table_list = pn.widgets.Tabulator(
             self._df,
@@ -440,7 +454,9 @@
             }
         )
         table_list.on_click(self._use_table)
-        self.interface.stream(table_list, user="Lumen")
+        if render_output:
+            self.interface.stream(table_list, user="Lumen")
+        return table_list
 
 
 class SQLAgent(LumenBaseAgent):
@@ -480,7 +496,7 @@ async def _select_relevant_table(self, messages: list[Message]) -> tuple[str, Ba
         elif len(tables) == 1:
             table = tables[0]
         else:
-            with self.interface.add_step(title="Choosing the most relevant table...", steps_layout=self._steps_layout) as step:
+            with self.interface.add_step(title="Choosing the most relevant table...", steps_layout=self.steps_layout) as step:
                 closest_tables = memory.pop("closest_tables", [])
                 if closest_tables:
                     tables = closest_tables
@@ -534,7 +550,7 @@ async def _create_valid_sql(
             }
         ]
 
-        with self.interface.add_step(title=title or "SQL query", steps_layout=self._steps_layout) as step:
+        with self.interface.add_step(title=title or "SQL query", steps_layout=self.steps_layout) as step:
            response = self.llm.stream(messages, system=system, response_model=Sql)
            sql_query = None
            async for output in response:
@@ -615,7 +631,7 @@ async def _check_join_required(
         schema,
         table: str
     ):
-        with self.interface.add_step(title="Checking if join is required", steps_layout=self._steps_layout) as step:
+        with self.interface.add_step(title="Checking if join is required", steps_layout=self.steps_layout) as step:
             join_prompt = render_template(
                 "join_required.jinja2",
                 schema=yaml.dump(schema),
@@ -646,7 +662,7 @@ async def find_join_tables(self, messages: list):
             "find_joins.jinja2",
             available_tables=available_tables
         )
-        with self.interface.add_step(title="Determining tables required for join", steps_layout=self._steps_layout) as step:
+        with self.interface.add_step(title="Determining tables required for join", steps_layout=self.steps_layout) as step:
             output = await self.llm.invoke(
                 messages,
                 system=find_joins_prompt,
@@ -683,10 +699,9 @@ async def find_join_tables(self, messages: list):
     async def respond(
         self,
         messages: list[Message],
-        title: str = "",
-        render_output: bool = True,
-        steps_layout: pn.Column | None = None
-    ) -> None:
+        render_output: bool = False,
+        step_title: str | None = None
+    ) -> Any:
         """
         Steps:
         1. Retrieve the current source and table from memory.
         8. If a join is required, remove source/table prefixes from the last message.
         9. Construct the SQL query with `_create_valid_sql`.
         """
-        self._steps_layout = steps_layout
         table, source = await self._select_relevant_table(messages)
-
         if not hasattr(source, "get_sql_expr"):
             return None
@@ -745,9 +758,11 @@ async def respond(
         if join_required:
             # Remove source prefixes message, e.g. //<source>//
             messages[-1]["content"] = re.sub(r"//[^/]+//", "", messages[-1]["content"])
-        sql_query = await self._create_valid_sql(messages, system, tables_to_source, title)
+        sql_query = await self._create_valid_sql(messages, system, tables_to_source, step_title)
         pipeline = memory['current_pipeline']
-        self._render_lumen(pipeline, spec=sql_query, render_output=render_output)
+        if render_output:
+            self._render_lumen(pipeline, spec=sql_query)
+        return pipeline
 
 
 class BaseViewAgent(LumenBaseAgent):
@@ -762,11 +777,9 @@ async def _extract_spec(self, model: BaseModel):
     async def respond(
         self,
         messages: list[Message],
-        title: str = "",
-        render_output: bool = True,
-        steps_layout: pn.Column | None = None
-    ) -> None:
-        self._steps_layout = steps_layout
+        render_output: bool = False,
+        step_title: str | None = None
+    ) -> Any:
         pipeline = memory["current_pipeline"]
 
         # Write prompts
@@ -790,12 +803,17 @@
         spec = await self._extract_spec(output)
         chain_of_thought = spec.pop("chain_of_thought", None)
         if chain_of_thought:
-            with self.interface.add_step(title=title or "Generating view...", steps_layout=self._steps_layout) as step:
+            with self.interface.add_step(
+                title=step_title or "Generating view...",
+                steps_layout=self.steps_layout
+            ) as step:
                 step.stream(chain_of_thought)
         print(f"{self.name} settled on {spec=!r}.")
         memory["current_view"] = dict(spec, type=self.view_type)
         view = self.view_type(pipeline=pipeline, **spec)
-        self._render_lumen(view, render_output=render_output)
+        if render_output:
+            self._render_lumen(view)
+        return view
 
 
 class hvPlotAgent(BaseViewAgent):
@@ -862,11 +880,8 @@ class VegaLiteAgent(BaseViewAgent):
     If the user asks to plot, visualize or render the data, this is your best bet.
     """
 
-    system_prompt = param.String(
-        default="""
-        Generate the plot the user requested as a vega-lite specification.
-        """
-    )
+    system_prompt = param.String(default="""
+        Generate the plot the user requested as a vega-lite specification.""")
 
     view_type = VegaLiteView
 
@@ -933,12 +948,10 @@ async def _system_prompt_with_context(
     async def respond(
         self,
         messages: list[Message],
-        title: str = "",
-        render_output: bool = True,
-        steps_layout: pn.Column | None = None,
+        render_output: bool = False,
+        step_title: str | None = None,
         agents: list[Agent] | None = None
-    ) -> None:
-        self._steps_layout = steps_layout
+    ) -> Any:
         pipeline = memory['current_pipeline']
         analyses = {a.name: a for a in self.analyses if await a.applies(pipeline)}
         if not analyses:
@@ -952,7 +965,7 @@
             analyses = {analysis: analyses[analysis]}
 
         if len(analyses) > 1:
-            with self.interface.add_step(title="Choosing the most relevant analysis...", steps_layout=self._steps_layout) as step:
+            with self.interface.add_step(title="Choosing the most relevant analysis...", steps_layout=self.steps_layout) as step:
                 type_ = Literal[tuple(analyses)]
                 analysis_model = create_model(
                     "Analysis",
@@ -970,7 +983,7 @@
         else:
             analysis_name = next(iter(analyses))
 
-        with self.interface.add_step(title=title or "Creating view...", steps_layout=self._steps_layout) as step:
+        with self.interface.add_step(title=step_title or "Creating view...", steps_layout=self.steps_layout) as step:
             await asyncio.sleep(0.1)  # necessary to give it time to render before calling sync function...
            analysis_callable = analyses[analysis_name].instance(agents=agents)
@@ -1003,7 +1016,9 @@
             view = None
 
         analysis = memory["current_analysis"]
-        if view is None and analysis.autorun:
-            self.interface.stream('Failed to find an analysis that applies to this data')
-        else:
-            self._render_lumen(view, analysis=analysis, pipeline=memory['current_pipeline'], render_output=render_output)
+        if render_output:
+            if view is None and analysis.autorun:
+                self.interface.stream('Failed to find an analysis that applies to this data')
+            else:
+                self._render_lumen(view, analysis=analysis, pipeline=memory['current_pipeline'])
+        return view
diff --git a/lumen/ai/assistant.py b/lumen/ai/assistant.py
index 13ca339f5..7ee45123f 100644
--- a/lumen/ai/assistant.py
+++ b/lumen/ai/assistant.py
@@ -213,7 +213,10 @@ async def use_suggestion(event):
         else:
             print("No analysis agent found.")
             return
-        await agent.respond([{'role': 'user', 'content': contents}], agents=self.agents)
+        messages = [{'role': 'user', 'content': contents}]
+        await agent.respond(
+            messages, render_output=True, agents=self.agents
+        )
         await self._add_analysis_suggestions()
     else:
         self.interface.send(contents)
@@ -449,13 +452,15 @@ async def _get_agent_chain_link(self, messages: list[Message]) -> AgentChainLink
         if instruction:
             custom_messages.append({"role": "user", "content": instruction})
 
-        respond_kwargs = {}
         # attach the new steps to the existing steps--used when there is intermediate Lumen output
-        last_steps_message = self.interface.objects[-2]
-        if last_steps_message.user == "Assistant" and isinstance(last_steps_message.object, Card):
-            respond_kwargs["steps_layout"] = last_steps_message.object
-
-        await subagent.respond(custom_messages, title=title, render_output=render_output, **respond_kwargs)
+        steps_layout = None
+        for step_message in reversed(self.interface.objects[-5:]):
+            if step_message.user == "Assistant" and isinstance(step_message.object, Card):
+                steps_layout = step_message.object
+                break
+
+        with subagent.param.update(steps_layout=steps_layout):
+            await subagent.respond(custom_messages, step_title=title, render_output=render_output)
 
         step.stream(f"`{agent_name}` agent successfully completed the following task:\n\n- {instruction}", replace=True)
         step.success_title = f"{agent_name} agent successfully responded"
@@ -515,14 +520,20 @@ async def invoke(self, messages: list[Message]) -> str:
 
         print("\n\033[95mAGENT:\033[0m", agent, messages[-context_length:])
 
-        last_steps_message = self.interface.objects[-2]
-        respond_kwargs = {"title": title, "render_output": True}
         # attach the new steps to the existing steps--used when there is intermediate Lumen output
-        if last_steps_message.user == "Assistant" and isinstance(last_steps_message.object, Card):
-            respond_kwargs["steps_layout"] = last_steps_message.object
+        steps_layout = None
+        for step_message in reversed(self.interface.objects[-5:]):
+            if step_message.user == "Assistant" and isinstance(step_message.object, Card):
+                steps_layout = step_message.object
+                break
+
+        respond_kwargs = {}
         if isinstance(agent, AnalysisAgent):
             respond_kwargs["agents"] = self.agents
-        await agent.respond(messages[-context_length:], **respond_kwargs)
+        with agent.param.update(steps_layout=steps_layout):
+            await agent.respond(
+                messages[-context_length:], step_title=title, render_output=True, **respond_kwargs
+            )
         self._current_agent.object = "## No agent active"
         if "current_pipeline" in agent.provides:
            await self._add_analysis_suggestions()
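Usage sketch. The patch leaves the new contract implicit, so the following minimal example illustrates it under stated assumptions: `EchoAgent`, the `run` helper and the message dicts are hypothetical, while the `respond()` signature, the `steps_layout` parameter and the `param.update` pattern are taken directly from the diff above.

    from typing import Any

    from panel.layout import Column

    from lumen.ai.agents import Agent


    class EchoAgent(Agent):
        """Hypothetical agent that echoes the last user message back."""

        async def respond(
            self,
            messages: list[dict],  # the patch annotates this as list[Message]
            render_output: bool = False,
            step_title: str | None = None
        ) -> Any:
            # Progress steps now stream to the steps_layout *parameter*
            # instead of a per-call steps_layout argument.
            with self.interface.add_step(
                title=step_title or "Echoing...", steps_layout=self.steps_layout
            ) as step:
                reply = messages[-1]["content"]
                step.stream(reply)
            # Rendering is opt-in; the result is always returned.
            if render_output:
                self.interface.stream(reply, user=self.user)
            return reply


    async def run(agent: EchoAgent, steps_layout: Column | None) -> Any:
        # Caller side, mirroring assistant.py: steps_layout is scoped
        # temporarily via param.update rather than passed to respond().
        with agent.param.update(steps_layout=steps_layout):
            return await agent.respond(
                [{"role": "user", "content": "Hello"}],
                render_output=True,
                step_title="Echo step",
            )

Returning the result instead of unconditionally streaming it is what lets the orchestrator in assistant.py feed one agent's output into the next while every `respond()` keeps the same signature.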