Refactor handling of outputs (#773)
philippjfr authored Nov 20, 2024
1 parent 9b7e759 commit 381c8ae
Showing 7 changed files with 135 additions and 81 deletions.
15 changes: 10 additions & 5 deletions lumen/ai/agents.py
@@ -358,13 +358,17 @@ def _render_lumen(
         component: Component,
         message: pn.chat.ChatMessage = None,
         render_output: bool = False,
+        title: str | None = None,
         **kwargs
     ):
         out = self._output_type(
-            component=component, render_output=render_output, **kwargs
+            component=component, render_output=render_output, title=title, **kwargs
         )
         if 'outputs' in self._memory:
-            self._memory['outputs'].append(out)
+            # We have to create a new list to trigger an event
+            # since inplace updates will not trigger updates
+            # and won't allow diffing between old and new values
+            self._memory['outputs'] = self._memory['outputs']+[out]
         message_kwargs = dict(value=out, user=self.user)
         self.interface.stream(message=message, **message_kwargs, replace=True, max_width=self._max_width)

@@ -724,7 +728,7 @@ async def respond(
         messages[-1]["content"] = re.sub(r"//[^/]+//", "", messages[-1]["content"])
         sql_query = await self._create_valid_sql(messages, system_prompt, tables_to_source, step_title)
         pipeline = self._memory['current_pipeline']
-        self._render_lumen(pipeline, spec=sql_query, render_output=render_output)
+        self._render_lumen(pipeline, spec=sql_query, render_output=render_output, title=step_title)
         return pipeline


@@ -788,7 +792,7 @@ async def respond(
         print(f"{self.name} settled on spec: {spec!r}.")
         self._memory["current_view"] = dict(spec, type=self.view_type)
         view = self.view_type(pipeline=pipeline, **spec)
-        self._render_lumen(view, render_output=render_output)
+        self._render_lumen(view, render_output=render_output, title=step_title)
         return view


@@ -979,6 +983,7 @@ async def respond(
             view,
             analysis=analysis,
             pipeline=pipeline,
-            render_output=render_output
+            render_output=render_output,
+            title=step_title
         )
         return view
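Context for the new comment in _render_lumen: _Memory only fires change callbacks from __setitem__, so appending to the stored list in place is invisible to subscribers, and old and new would alias the same object. A minimal sketch of the pattern, assuming the updated _Memory from lumen/ai/memory.py later in this commit (the keys and values here are purely illustrative):

from lumen.ai.memory import _Memory

memory = _Memory()
memory['outputs'] = []

def on_outputs(key, old, new):
    # the new callback signature passes (key, old, new) so subscribers can diff the lists
    print(f"{key}: {len(old)} -> {len(new)} outputs")

memory.on_change('outputs', on_outputs)

memory['outputs'].append('table-view')            # in-place mutation: __setitem__ never runs, no callback
memory['outputs'] = memory['outputs'] + ['plot']  # rebuilt list: callback fires with distinct old/new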
2 changes: 1 addition & 1 deletion lumen/ai/coordinator.py
@@ -675,7 +675,7 @@ async def _compute_execution_graph(self, messages: list[Message], agents: dict[s
                     istep.stream(f"The plan didn't account for {unmet_dependencies!r}", replace=True)
                 else:
                     planned = True
-                    self._memory['current_plan'] = plan.title
+                    self._memory['plan'] = plan
                     istep.stream('\n\nHere are the steps:\n\n')
                     for i, step in enumerate(plan.steps):
                         istep.stream(f"{i+1}. {step.expert}: {step.instruction}\n")
21 changes: 14 additions & 7 deletions lumen/ai/memory.py
@@ -14,27 +14,34 @@ def __init__(self, *args, **kwargs):
         self._callbacks = defaultdict(list)
         self._rx = {}
 
-    def __setitem__(self, key, value):
-        super().__setitem__(key, value)
-        self._trigger_update(key, value)
+    def __setitem__(self, key, new):
+        if key in self:
+            old = self[key]
+        else:
+            old = None
+        super().__setitem__(key, new)
+        self._trigger_update(key, old, new)
 
     def on_change(self, key, callback):
         self._callbacks[key].append(callback)
 
+    def remove_on_change(self, key, callback):
+        self._callbacks[key].remove(callback)
+
     def rx(self, key):
         if key in self._rx:
             return self._rx[key]
         self._rx[key] = rxp = param.rx(self[key])
         return rxp
 
     def trigger(self, key):
-        self._trigger_update(key, self[key])
+        self._trigger_update(key, self[key], self[key])
 
-    def _trigger_update(self, key, value):
+    def _trigger_update(self, key, old, new):
         for cb in self._callbacks[key]:
-            cb(value)
+            cb(key, old, new)
         if key in self._rx:
-            self._rx[key].rx.value = value
+            self._rx[key].rx.value = new
 
 
 memory = _Memory()
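Taken together, the changes above mean subscribers now receive (key, old, new) and can be detached again, which is what the per-request wiring in lumen/ai/ui.py below relies on. A rough usage sketch against the module-level singleton (the 'current_sql' value is illustrative):

from lumen.ai.memory import memory

def log_change(key, old, new):
    # receives the key plus both values, matching the new _trigger_update contract
    print(f"{key!r}: {old!r} -> {new!r}")

memory.on_change('current_sql', log_change)
memory['current_sql'] = 'SELECT 1'     # fires log_change with old=None, new='SELECT 1'
memory.trigger('current_sql')          # re-fires with old == new
memory.remove_on_change('current_sql', log_change)

sql_rx = memory.rx('current_sql')      # param reactive expression kept in sync by _trigger_update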
2 changes: 1 addition & 1 deletion lumen/ai/models.py
@@ -94,7 +94,7 @@ def make_plan_models(agent_names: list[str], tables: list[str]):
         "Step",
         expert=(Literal[agent_names], FieldInfo(description="The name of the expert to assign a task to.")),
         instruction=(str, FieldInfo(description="Instructions to the expert to assist in the task, and whether rendering is required.")),
-        title=(str, FieldInfo(description="Short title of the task to be performed; up to six words.")),
+        title=(str, FieldInfo(description="Short title of the task to be performed; up to three words.")),
         render_output=(bool, FieldInfo(description="Whether the output of the expert should be rendered. If the user wants to see the table, and the expert is SQL, then this should be `True`.")),
     )
     extras = {}
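For reference, the tuples above are the (annotation, FieldInfo) pairs accepted by pydantic's create_model, so the tightened Step model is roughly equivalent to this standalone sketch (the agent names and shortened descriptions are illustrative, not the exact lumen definitions):

from typing import Literal

from pydantic import create_model
from pydantic.fields import FieldInfo

Step = create_model(
    "Step",
    expert=(Literal["SQLAgent", "hvPlotAgent"], FieldInfo(description="The name of the expert to assign a task to.")),
    instruction=(str, FieldInfo(description="Instructions to the expert to assist in the task.")),
    title=(str, FieldInfo(description="Short title of the task to be performed; up to three words.")),
    render_output=(bool, FieldInfo(description="Whether the output of the expert should be rendered.")),
)

step = Step(expert="SQLAgent", instruction="Count rows per species", title="Species counts", render_output=True)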
1 change: 1 addition & 0 deletions lumen/ai/prompts/SQLAgent/main.jinja2
@@ -27,6 +27,7 @@ Checklist
 - If it's a date column (excluding individual year/month/day integers) date, cast to date using appropriate syntax, e.g.
   CAST or TO_DATE
 - Use only `{{ dialect }}` syntax
+- Try to pretty print the SQL output with newlines and indentation.
 {% if dialect == 'duckdb' %}
 - If the table name originally did not have `read_*` prefix, use the original table name
 - Use table names verbatim; e.g. if table is read_csv('table.csv') then use read_csv('table.csv') and not 'table'
130 changes: 81 additions & 49 deletions lumen/ai/ui.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 from io import StringIO
+from typing import TYPE_CHECKING
 
 import param
 
@@ -30,7 +31,10 @@
 from .coordinator import Coordinator, Planner
 from .export import export_notebook
 from .llm import Llm, OpenAI
-from .memory import memory
+from .memory import _Memory, memory
+
+if TYPE_CHECKING:
+    from .views import LumenOutput
 
 DataT = str | Source | Pipeline
 
@@ -313,7 +317,7 @@ def _table_explorer(self):
         )
 
         source_map = {}
-        def update_source_map(sources, init=False):
+        def update_source_map(_, __, sources, init=False):
             selected = list(table_select.value)
             deduplicate = len(sources) > 1
             new = {}
@@ -329,7 +333,7 @@ def update_source_map(sources, init=False):
             source_map.update(new)
             table_select.param.update(options=list(source_map), value=selected)
         memory.on_change('available_sources', update_source_map)
-        update_source_map(memory['available_sources'], init=True)
+        update_source_map(None, None, memory['available_sources'], init=True)
 
         controls = SourceControls(select_existing=False, name='Upload')
         tabs = Tabs(controls, sizing_mode='stretch_both', design=Material)
@@ -362,59 +366,87 @@ def get_explorers(tables, load):
             sizing_mode='stretch_both',
         )
 
+    def _add_exploration(self, title: str, memory: _Memory):
+        self._titles.append(title)
+        self._contexts.append(memory)
+        self._conversations.append(self._coordinator.interface.objects)
+        self._explorations.append((title, Column(name=title, sizing_mode='stretch_both', loading=True)))
+        self._notebook_export.filename = f"{title.replace(' ', '_')}.ipynb"
+        self._explorations.active = len(self._explorations)-1
+        self._output.active = 1
+
+    def _add_outputs(self, exploration: Column, outputs: list[LumenOutput], memory: _Memory):
+        from panel_gwalker import GraphicWalker
+        if 'current_sql' in memory:
+            sql = memory["current_sql"]
+            sql_pane = Markdown(
+                f'```sql\n{sql}\n```',
+                margin=0, sizing_mode='stretch_width'
+            )
+            if sql.count('\n') > 10:
+                sql_pane = Column(sql_pane, max_height=250, scroll='y-auto')
+            if len(exploration) and isinstance(exploration[0], Markdown):
+                exploration[0] = sql_pane
+            else:
+                exploration.insert(0, sql_pane)
+
+        content = []
+        if exploration.loading:
+            pipeline = memory['current_pipeline']
+            content.append(
+                ('Overview', GraphicWalker(
+                    pipeline.param.data,
+                    kernel_computation=True,
+                    tab='data',
+                    sizing_mode='stretch_both'
+                ))
+            )
+        content.extend([
+            (out.title or type(out).__name__.replace('Output', ''), ParamMethod(
+                out.render, inplace=True,
+                sizing_mode='stretch_both'
+            )) for out in outputs
+        ])
+        if exploration.loading:
+            tabs = Tabs(*content, active=len(outputs), dynamic=True)
+            exploration.append(tabs)
+        else:
+            tabs = exploration[-1]
+            tabs.extend(content)
+            tabs.active = len(tabs)-1
+
     def _wrap_callback(self, callback):
         async def wrapper(contents: list | str, user: str, instance: ChatInterface):
             if not self._explorations:
                 prev_memory = memory
             else:
                 prev_memory = self._contexts[self._explorations.active]
+            index = self._explorations.active if len(self._explorations) else -1
             local_memory = prev_memory.clone()
-            prev_outputs = local_memory.get('outputs', [])
-            prev_pipeline = local_memory.get('current_pipeline')
             local_memory['outputs'] = outputs = []
-            with self._coordinator.param.update(memory=local_memory):
-                await callback(contents, user, instance)
-            if not outputs:
-                prev_memory.update(local_memory)
-                return
 
-            from panel_gwalker import GraphicWalker
-
-            title = local_memory['current_plan']
-            pipeline = local_memory['current_pipeline']
-            new = prev_pipeline is not pipeline
-            content = []
-            if 'current_sql' in local_memory:
-                content.append(Markdown(
-                    f'```sql\n{local_memory["current_sql"]}\n```',
-                    margin=0
-                ))
-            if not new:
-                outputs = prev_outputs + outputs
-            content.append(
-                Tabs(
-                    ('Overview', GraphicWalker(
-                        pipeline.param.data,
-                        kernel_computation=True,
-                        tab='data',
-                        sizing_mode='stretch_both'
-                    )),
-                    *((type(out).__name__.replace('Output', ''), ParamMethod(
-                        out.render, inplace=True,
-                        sizing_mode='stretch_both'
-                    )) for out in outputs), active=len(outputs), dynamic=True
-                )
-            )
-            if new:
-                self._conversations.append(self._coordinator.interface.objects)
-                self._explorations.append((title, Column(*content, name=title)))
-                self._contexts.append(local_memory)
-                self._titles.append(title)
-                self._notebook_export.filename = f"{title.replace(' ', '_')}.ipynb"
-                self._explorations.active = len(self._explorations)-1
-                self._output.active = 1
-            else:
-                tab = self._explorations.active
-                title = self._titles[tab]
-                self._explorations[tab] = Column(*content, name=title)
+            def render_plan(_, old, new):
+                nonlocal index
+                plan = local_memory['plan']
+                if any(step.expert == 'SQLAgent' for step in plan.steps):
+                    self._add_exploration(plan.title, local_memory)
+                    index += 1
+            local_memory.on_change('plan', render_plan)
+
+            def render_output(_, old, new):
+                added = [out for out in new if out not in old]
+                exploration = self._explorations[index]
+                self._add_outputs(exploration, added, local_memory)
+                exploration.loading = False
+                outputs[:] = new
+            local_memory.on_change('outputs', render_output)
+
+            try:
+                with self._coordinator.param.update(memory=local_memory):
+                    await callback(contents, user, instance)
+            finally:
+                local_memory.remove_on_change('plan', render_plan)
+                local_memory.remove_on_change('outputs', render_output)
+                if not outputs:
+                    prev_memory.update(local_memory)
         return wrapper
