From 381c8ae138178426f1d7b15937458f6a59fd2919 Mon Sep 17 00:00:00 2001 From: Philipp Rudiger Date: Wed, 20 Nov 2024 11:10:02 +0100 Subject: [PATCH] Refactor handling of outputs (#773) --- lumen/ai/agents.py | 15 ++- lumen/ai/coordinator.py | 2 +- lumen/ai/memory.py | 21 +++-- lumen/ai/models.py | 2 +- lumen/ai/prompts/SQLAgent/main.jinja2 | 1 + lumen/ai/ui.py | 130 ++++++++++++++++---------- lumen/ai/views.py | 45 +++++---- 7 files changed, 135 insertions(+), 81 deletions(-) diff --git a/lumen/ai/agents.py b/lumen/ai/agents.py index 2a1b84ae3..0cc67e368 100644 --- a/lumen/ai/agents.py +++ b/lumen/ai/agents.py @@ -358,13 +358,17 @@ def _render_lumen( component: Component, message: pn.chat.ChatMessage = None, render_output: bool = False, + title: str | None = None, **kwargs ): out = self._output_type( - component=component, render_output=render_output, **kwargs + component=component, render_output=render_output, title=title, **kwargs ) if 'outputs' in self._memory: - self._memory['outputs'].append(out) + # We have to create a new list to trigger an event + # since inplace updates will not trigger updates + # and won't allow diffing between old and new values + self._memory['outputs'] = self._memory['outputs']+[out] message_kwargs = dict(value=out, user=self.user) self.interface.stream(message=message, **message_kwargs, replace=True, max_width=self._max_width) @@ -724,7 +728,7 @@ async def respond( messages[-1]["content"] = re.sub(r"//[^/]+//", "", messages[-1]["content"]) sql_query = await self._create_valid_sql(messages, system_prompt, tables_to_source, step_title) pipeline = self._memory['current_pipeline'] - self._render_lumen(pipeline, spec=sql_query, render_output=render_output) + self._render_lumen(pipeline, spec=sql_query, render_output=render_output, title=step_title) return pipeline @@ -788,7 +792,7 @@ async def respond( print(f"{self.name} settled on spec: {spec!r}.") self._memory["current_view"] = dict(spec, type=self.view_type) view = self.view_type(pipeline=pipeline, **spec) - self._render_lumen(view, render_output=render_output) + self._render_lumen(view, render_output=render_output, title=step_title) return view @@ -979,6 +983,7 @@ async def respond( view, analysis=analysis, pipeline=pipeline, - render_output=render_output + render_output=render_output, + title=step_title ) return view diff --git a/lumen/ai/coordinator.py b/lumen/ai/coordinator.py index 4b94f5073..a4c411558 100644 --- a/lumen/ai/coordinator.py +++ b/lumen/ai/coordinator.py @@ -675,7 +675,7 @@ async def _compute_execution_graph(self, messages: list[Message], agents: dict[s istep.stream(f"The plan didn't account for {unmet_dependencies!r}", replace=True) else: planned = True - self._memory['current_plan'] = plan.title + self._memory['plan'] = plan istep.stream('\n\nHere are the steps:\n\n') for i, step in enumerate(plan.steps): istep.stream(f"{i+1}. {step.expert}: {step.instruction}\n") diff --git a/lumen/ai/memory.py b/lumen/ai/memory.py index 1d0c54adc..d9284ac48 100644 --- a/lumen/ai/memory.py +++ b/lumen/ai/memory.py @@ -14,13 +14,20 @@ def __init__(self, *args, **kwargs): self._callbacks = defaultdict(list) self._rx = {} - def __setitem__(self, key, value): - super().__setitem__(key, value) - self._trigger_update(key, value) + def __setitem__(self, key, new): + if key in self: + old = self[key] + else: + old = None + super().__setitem__(key, new) + self._trigger_update(key, old, new) def on_change(self, key, callback): self._callbacks[key].append(callback) + def remove_on_change(self, key, callback): + self._callbacks[key].remove(callback) + def rx(self, key): if key in self._rx: return self._rx[key] @@ -28,13 +35,13 @@ def rx(self, key): return rxp def trigger(self, key): - self._trigger_update(key, self[key]) + self._trigger_update(key, self[key], self[key]) - def _trigger_update(self, key, value): + def _trigger_update(self, key, old, new): for cb in self._callbacks[key]: - cb(value) + cb(key, old, new) if key in self._rx: - self._rx[key].rx.value = value + self._rx[key].rx.value = new memory = _Memory() diff --git a/lumen/ai/models.py b/lumen/ai/models.py index 046ef21e9..d24e12929 100644 --- a/lumen/ai/models.py +++ b/lumen/ai/models.py @@ -94,7 +94,7 @@ def make_plan_models(agent_names: list[str], tables: list[str]): "Step", expert=(Literal[agent_names], FieldInfo(description="The name of the expert to assign a task to.")), instruction=(str, FieldInfo(description="Instructions to the expert to assist in the task, and whether rendering is required.")), - title=(str, FieldInfo(description="Short title of the task to be performed; up to six words.")), + title=(str, FieldInfo(description="Short title of the task to be performed; up to three words.")), render_output=(bool, FieldInfo(description="Whether the output of the expert should be rendered. If the user wants to see the table, and the expert is SQL, then this should be `True`.")), ) extras = {} diff --git a/lumen/ai/prompts/SQLAgent/main.jinja2 b/lumen/ai/prompts/SQLAgent/main.jinja2 index 6d860f984..d6623bf71 100644 --- a/lumen/ai/prompts/SQLAgent/main.jinja2 +++ b/lumen/ai/prompts/SQLAgent/main.jinja2 @@ -27,6 +27,7 @@ Checklist - If it's a date column (excluding individual year/month/day integers) date, cast to date using appropriate syntax, e.g. CAST or TO_DATE - Use only `{{ dialect }}` syntax +- Try to pretty print the SQL output with newlines and indentation. {% if dialect == 'duckdb' %} - If the table name originally did not have `read_*` prefix, use the original table name - Use table names verbatim; e.g. if table is read_csv('table.csv') then use read_csv('table.csv') and not 'table' diff --git a/lumen/ai/ui.py b/lumen/ai/ui.py index 946313da9..1aac4d426 100644 --- a/lumen/ai/ui.py +++ b/lumen/ai/ui.py @@ -1,6 +1,7 @@ from __future__ import annotations from io import StringIO +from typing import TYPE_CHECKING import param @@ -30,7 +31,10 @@ from .coordinator import Coordinator, Planner from .export import export_notebook from .llm import Llm, OpenAI -from .memory import memory +from .memory import _Memory, memory + +if TYPE_CHECKING: + from .views import LumenOutput DataT = str | Source | Pipeline @@ -313,7 +317,7 @@ def _table_explorer(self): ) source_map = {} - def update_source_map(sources, init=False): + def update_source_map(_, __, sources, init=False): selected = list(table_select.value) deduplicate = len(sources) > 1 new = {} @@ -329,7 +333,7 @@ def update_source_map(sources, init=False): source_map.update(new) table_select.param.update(options=list(source_map), value=selected) memory.on_change('available_sources', update_source_map) - update_source_map(memory['available_sources'], init=True) + update_source_map(None, None, memory['available_sources'], init=True) controls = SourceControls(select_existing=False, name='Upload') tabs = Tabs(controls, sizing_mode='stretch_both', design=Material) @@ -362,59 +366,87 @@ def get_explorers(tables, load): sizing_mode='stretch_both', ) + def _add_exploration(self, title: str, memory: _Memory): + self._titles.append(title) + self._contexts.append(memory) + self._conversations.append(self._coordinator.interface.objects) + self._explorations.append((title, Column(name=title, sizing_mode='stretch_both', loading=True))) + self._notebook_export.filename = f"{title.replace(' ', '_')}.ipynb" + self._explorations.active = len(self._explorations)-1 + self._output.active = 1 + + def _add_outputs(self, exploration: Column, outputs: list[LumenOutput], memory: _Memory): + from panel_gwalker import GraphicWalker + if 'current_sql' in memory: + sql = memory["current_sql"] + sql_pane = Markdown( + f'```sql\n{sql}\n```', + margin=0, sizing_mode='stretch_width' + ) + if sql.count('\n') > 10: + sql_pane = Column(sql_pane, max_height=250, scroll='y-auto') + if len(exploration) and isinstance(exploration[0], Markdown): + exploration[0] = sql_pane + else: + exploration.insert(0, sql_pane) + + content = [] + if exploration.loading: + pipeline = memory['current_pipeline'] + content.append( + ('Overview', GraphicWalker( + pipeline.param.data, + kernel_computation=True, + tab='data', + sizing_mode='stretch_both' + )) + ) + content.extend([ + (out.title or type(out).__name__.replace('Output', ''), ParamMethod( + out.render, inplace=True, + sizing_mode='stretch_both' + )) for out in outputs + ]) + if exploration.loading: + tabs = Tabs(*content, active=len(outputs), dynamic=True) + exploration.append(tabs) + else: + tabs = exploration[-1] + tabs.extend(content) + tabs.active = len(tabs)-1 + def _wrap_callback(self, callback): async def wrapper(contents: list | str, user: str, instance: ChatInterface): if not self._explorations: prev_memory = memory else: prev_memory = self._contexts[self._explorations.active] + index = self._explorations.active if len(self._explorations) else -1 local_memory = prev_memory.clone() - prev_outputs = local_memory.get('outputs', []) - prev_pipeline = local_memory.get('current_pipeline') local_memory['outputs'] = outputs = [] - with self._coordinator.param.update(memory=local_memory): - await callback(contents, user, instance) - if not outputs: - prev_memory.update(local_memory) - return - from panel_gwalker import GraphicWalker - - title = local_memory['current_plan'] - pipeline = local_memory['current_pipeline'] - new = prev_pipeline is not pipeline - content = [] - if 'current_sql' in local_memory: - content.append(Markdown( - f'```sql\n{local_memory["current_sql"]}\n```', - margin=0 - )) - if not new: - outputs = prev_outputs + outputs - content.append( - Tabs( - ('Overview', GraphicWalker( - pipeline.param.data, - kernel_computation=True, - tab='data', - sizing_mode='stretch_both' - )), - *((type(out).__name__.replace('Output', ''), ParamMethod( - out.render, inplace=True, - sizing_mode='stretch_both' - )) for out in outputs), active=len(outputs), dynamic=True - ) - ) - if new: - self._conversations.append(self._coordinator.interface.objects) - self._explorations.append((title, Column(*content, name=title))) - self._contexts.append(local_memory) - self._titles.append(title) - self._notebook_export.filename = f"{title.replace(' ', '_')}.ipynb" - self._explorations.active = len(self._explorations)-1 - self._output.active = 1 - else: - tab = self._explorations.active - title = self._titles[tab] - self._explorations[tab] = Column(*content, name=title) + def render_plan(_, old, new): + nonlocal index + plan = local_memory['plan'] + if any(step.expert == 'SQLAgent' for step in plan.steps): + self._add_exploration(plan.title, local_memory) + index += 1 + local_memory.on_change('plan', render_plan) + + def render_output(_, old, new): + added = [out for out in new if out not in old] + exploration = self._explorations[index] + self._add_outputs(exploration, added, local_memory) + exploration.loading = False + outputs[:] = new + local_memory.on_change('outputs', render_output) + + try: + with self._coordinator.param.update(memory=local_memory): + await callback(contents, user, instance) + finally: + local_memory.remove_on_change('plan', render_plan) + local_memory.remove_on_change('outputs', render_output) + if not outputs: + prev_memory.update(local_memory) return wrapper diff --git a/lumen/ai/views.py b/lumen/ai/views.py index 4fb81a892..7d8bd614a 100644 --- a/lumen/ai/views.py +++ b/lumen/ai/views.py @@ -1,11 +1,17 @@ import asyncio +import traceback import panel as pn import param import yaml +from panel.layout import Column, Row, Tabs +from panel.pane import Alert from panel.param import ParamMethod from panel.viewable import Viewer +from panel.widgets import ( + Button, ButtonIcon, Checkbox, CodeEditor, LoadingSpinner, +) from param.parameterized import discard_events from lumen.ai.utils import get_data @@ -28,6 +34,8 @@ class LumenOutput(Viewer): spec = param.String(allow_None=True) + title = param.String(allow_None=True) + language = "yaml" def __init__(self, **params): @@ -35,19 +43,19 @@ def __init__(self, **params): component_spec = params['component'].to_spec() params['spec'] = yaml.dump(component_spec) super().__init__(**params) - code_editor = pn.widgets.CodeEditor( + code_editor = CodeEditor( value=self.param.spec, language=self.language, theme='tomorrow_night_bright', sizing_mode="stretch_both", on_keyup=False ) code_editor.link(self, bidirectional=True, value='spec') - copy_icon = pn.widgets.ButtonIcon( + copy_icon = ButtonIcon( icon="copy", active_icon="check", toggle_duration=1000 ) copy_icon.js_on_click( args={"code_editor": code_editor}, code="navigator.clipboard.writeText(code_editor.code);", ) - download_icon = pn.widgets.ButtonIcon( + download_icon = ButtonIcon( icon="download", active_icon="check", toggle_duration=1000 ) download_icon.js_on_click( @@ -64,14 +72,14 @@ def __init__(self, **params): a.parentNode.removeChild(a); //afterwards we remove the element again """, ) - icons = pn.Row(copy_icon, download_icon) - code_col = pn.Column(code_editor, icons, sizing_mode="stretch_both") + icons = Row(copy_icon, download_icon) + code_col = Column(code_editor, icons, sizing_mode="stretch_both") if self.render_output: - placeholder = pn.Column( + placeholder = Column( ParamMethod(self.render, inplace=True), sizing_mode="stretch_width" ) - self._main = pn.Tabs( + self._main = Tabs( ("Code", code_col), ("Output", placeholder), styles={'min-width': '100%', 'height': 'fit-content', 'min-height': '300px'}, @@ -103,7 +111,7 @@ async def _render_pipeline(self, pipeline): ) download_pane = download.__panel__() download_pane.sizing_mode = 'fixed' - controls = pn.Row( + controls = Row( download_pane, styles={'position': 'absolute', 'right': '40px', 'top': '-35px'} ) @@ -118,16 +126,16 @@ async def _render_pipeline(self, pipeline): if limited: def unlimit(e): sql_limit.limit = None if e.new else 1_000_000 - full_data = pn.widgets.Checkbox( + full_data = Checkbox( name='Full data', width=100, visible=limited ) full_data.param.watch(unlimit, 'value') controls.insert(0, full_data) - return pn.Column(controls, table) + return Column(controls, table) @param.depends('spec', 'active') async def render(self): - yield pn.indicators.LoadingSpinner( + yield LoadingSpinner( value=True, name="Rendering component...", height=50, width=50 ) @@ -138,7 +146,7 @@ async def render(self): yield self._last_output[self.spec] return elif self.component is None: - yield pn.pane.Alert( + yield Alert( "No component to render. Please complete the Config tab.", alert_type="warning", ) @@ -157,9 +165,8 @@ async def render(self): self._last_output[self.spec] = output yield output except Exception as e: - import traceback traceback.print_exc() - yield pn.pane.Alert( + yield Alert( f"```\n{e}\n```\nPlease press undo, edit the YAML, or continue chatting.", alert_type="danger", ) @@ -191,7 +198,7 @@ def __init__(self, **params): run_button.param.watch(self._rerun, 'clicks') self._main.insert(1, ('Config', controls)) else: - run_button = pn.widgets.Button( + run_button = Button( icon='player-play', name='Run', on_click=self._rerun, button_type='success', margin=(10, 0, 0 , 10) ) @@ -221,7 +228,7 @@ class SQLOutput(LumenOutput): @param.depends('spec', 'active') async def render(self): - yield pn.indicators.LoadingSpinner( + yield LoadingSpinner( value=True, name="Executing SQL query...", height=50, width=50 ) if self.active != 1: @@ -234,7 +241,9 @@ async def render(self): try: if self._rendered: - pipeline.source = pipeline.source.create_sql_expr_source(tables={pipeline.table: self.spec}) + pipeline.source = pipeline.source.create_sql_expr_source( + tables={pipeline.table: self.spec} + ) output = await self._render_pipeline(pipeline) self._rendered = True self._last_output.clear() @@ -243,7 +252,7 @@ async def render(self): except Exception as e: import traceback traceback.print_exc() - yield pn.pane.Alert( + yield Alert( f"```\n{e}\n```\nPlease press undo, edit the YAML, or continue chatting.", alert_type="danger", )