Skip to content

Commit

Permalink
Various tweaks for assistant, agents and utilities (#728)
Browse files Browse the repository at this point in the history
  • Loading branch information
philippjfr authored Oct 23, 2024
1 parent 17763d0 commit 08915e2
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 65 deletions.
22 changes: 16 additions & 6 deletions lumen/ai/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,9 @@ async def invoke(self, messages: list | str):

class SourceAgent(Agent):
"""
The SourceAgent allows a user to upload datasets.
The SourceAgent allows a user to upload new datasets.
Use this if the user is requesting to add a dataset or you think
Only use this if the user is requesting to add a dataset or you think
additional information is required to solve the user query.
"""

Expand Down Expand Up @@ -352,7 +352,6 @@ def _render_lumen(self, component: Component, message: pn.chat.ChatMessage = Non
self.interface.stream(message=message, **message_kwargs, replace=True, max_width=self._max_width)



class TableAgent(LumenBaseAgent):
"""
Displays a single table / dataset. Does not discuss.
Expand Down Expand Up @@ -479,8 +478,9 @@ def _use_table(self, event):
self.interface.send(f"Show the table: {table!r}")

async def answer(self, messages: list | str):
source = memory["current_source"]
tables = source.get_tables()
tables = []
for source in memory['available_sources']:
tables += source.get_tables()
if not tables:
return

Expand Down Expand Up @@ -720,7 +720,17 @@ async def answer(self, messages: list | str):
table_schema = schema
else:
table_schema = await get_schema(source, source_table, include_min_max=False)
table_schemas[source_table] = {

# Look up underlying table name
table_name = source_table
if (
'tables' in source.param and
isinstance(source.tables, dict) and
'select ' not in source.tables[table_name].lower()
):
table_name = source.tables[table_name]

table_schemas[table_name] = {
"schema": yaml.dump(table_schema),
"sql": source.get_sql_expr(source_table)
}
Expand Down
109 changes: 67 additions & 42 deletions lumen/ai/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import re

from io import StringIO
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING

import param
import yaml
Expand Down Expand Up @@ -402,7 +402,7 @@ async def _get_agent(self, messages: list | str):
for subagent, deps, instruction in agent_chain[:-1]:
agent_name = type(subagent).name.replace('Agent', '')
with self.interface.add_step(title=f"Querying {agent_name} agent...") as step:
step.stream(f"Assistant decided the {agent_name!r} will provide {', '.join(deps)}.")
step.stream(f"`{agent_name}` agent is working on the following task:\n\n{instruction}")
self._current_agent.object = f"## **Current Agent**: {agent_name}"
custom_messages = messages.copy()
if isinstance(subagent, SQLAgent):
Expand All @@ -417,7 +417,8 @@ async def _get_agent(self, messages: list | str):
if instruction:
custom_messages.append({"role": "user", "content": instruction})
await subagent.answer(custom_messages)
step.success_title = f"{agent_name} agent responded"
step.stream(f"`{agent_name}` agent successfully completed the following task:\n\n- {instruction}", replace=True)
step.success_title = f"{agent_name} agent successfully responded"
return selected

def _serialize(self, obj, exclude_passwords=True):
Expand Down Expand Up @@ -490,53 +491,89 @@ class PlanningAssistant(Assistant):
instead of simply resolving the dependencies step-by-step.
"""

@classmethod
async def _lookup_schemas(
cls,
tables: dict[str, Source],
requested: list[str],
provided: list[str],
cache: dict[str, dict] | None = None
) -> str:
cache = cache or {}
to_query, queries = [], []
for table in requested:
if table in provided or table in cache:
continue
to_query.append(table)
queries.append(get_schema(tables[table], table, limit=3))
for table, schema in zip(to_query, await asyncio.gather(*queries)):
cache[table] = schema
schema_info = ''
for table in requested:
if table in provided:
continue
provided.append(table)
schema_info += f'- {table}: {cache[table]}\n\n'
return schema_info

async def _make_plan(
self,
user_msg: dict[str, Any],
messages: list,
agents: dict[str, Agent],
tables: dict[str, Source],
unmet_dependencies: set[str],
reason_model: type[BaseModel],
plan_model: type[BaseModel],
step: ChatStep
):
step: ChatStep,
schemas: dict[str, dict] | None = None
) -> BaseModel:
user_msg = messages[-1]
info = ''
reasoning = None
requested_tables, provided_tables = [], []
requested, provided = [], []
if 'current_table' in memory:
requested_tables.append(memory['current_table'])
requested.append(memory['current_table'])
elif len(tables) == 1:
requested_tables.append(next(iter(tables)))
while reasoning is None or requested_tables:
# Add context of table schemas
schemas = []
requested = getattr(reasoning, 'tables', requested_tables)
for table in requested:
if table in provided_tables:
continue
provided_tables.append(table)
schemas.append(get_schema(tables[table], table, limit=3))
for table, schema in zip(requested, await asyncio.gather(*schemas)):
info += f'- {table}: {schema}\n\n'
requested.append(next(iter(tables)))
while reasoning is None or requested:
info += await self._lookup_schemas(tables, requested, provided, cache=schemas)
available = [t for t in tables if t not in provided]
system = render_template(
'plan_agent.jinja2', agents=list(agents.values()), current_agent=self._current_agent.object,
unmet_dependencies=unmet_dependencies, memory=memory, table_info=info, tables=list(tables)
unmet_dependencies=unmet_dependencies, memory=memory, table_info=info, tables=available
)
async for reasoning in self.llm.stream(
messages=messages,
system=system,
response_model=reason_model,
):
step.stream(reasoning.chain_of_thought, replace=True)
requested_tables = [t for t in reasoning.tables if t and t not in provided_tables]
if requested_tables:
continue
requested = [
t for t in getattr(reasoning, 'tables', [])
if t and t not in provided
]
new_msg = dict(role=user_msg['role'], content=f"<user query>{user_msg['content']}</user query> {reasoning.chain_of_thought}")
messages = messages[:-1] + [new_msg]
plan = await self._fill_model(messages, system, plan_model)
return plan

async def _resolve_plan(self, plan, agents, messages):
step = plan.steps[-1]
subagent = agents[step.expert]
unmet_dependencies = {
r for r in await subagent.requirements(messages) if r not in memory
}
agent_chain = [(subagent, unmet_dependencies, step.instruction)]
for step in plan.steps[:-1][::-1]:
subagent = agents[step.expert]
requires = set(await subagent.requirements(messages))
unmet_dependencies = {
dep for dep in (unmet_dependencies | requires)
if dep not in subagent.provides and dep not in memory
}
agent_chain.append((subagent, subagent.provides, step.instruction))
return agent_chain, unmet_dependencies

async def _resolve_dependencies(self, messages: list, agents: dict[str, Agent]) -> list[tuple(Agent, any)]:
agent_names = tuple(sagent.name[:-5] for sagent in agents.values())
tables = {}
Expand All @@ -547,29 +584,17 @@ async def _resolve_dependencies(self, messages: list, agents: dict[str, Agent])
reason_model, plan_model = make_plan_models(agent_names, list(tables))
planned = False
unmet_dependencies = set()
user_msg = messages[-1]
schemas = {}
with self.interface.add_step(title="Planning how to solve user query...", user="Assistant") as istep:
while not planned or unmet_dependencies:
while not planned:
plan = await self._make_plan(
user_msg, messages, agents, tables, unmet_dependencies, reason_model, plan_model, istep
messages, agents, tables, unmet_dependencies, reason_model, plan_model, istep, schemas
)
step = plan.steps[-1]
subagent = agents[step.expert]
unmet_dependencies = {
r for r in await subagent.requirements(messages) if r not in memory
}
agent_chain = [(subagent, unmet_dependencies, step.instruction)]
for step in plan.steps[:-1][::-1]:
subagent = agents[step.expert]
requires = set(await subagent.requirements(messages))
unmet_dependencies = {
dep for dep in (unmet_dependencies | requires)
if dep not in subagent.provides and dep not in memory
}
agent_chain.append((subagent, subagent.provides, step.instruction))
agent_chain, unmet_dependencies = await self._resolve_plan(plan, agents, messages)
if unmet_dependencies:
istep.stream(f"The plan didn't account for {unmet_dependencies!r}", replace=True)
planned = True
else:
planned = True
istep.stream('\n\nHere are the steps:\n\n')
for i, step in enumerate(plan.steps):
istep.stream(f"{i+1}. {step.expert}: {step.instruction}\n")
Expand Down
19 changes: 8 additions & 11 deletions lumen/ai/prompts/plan_agent.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,17 @@ Ensure that you provide each expert some context to ensure they do not repeat pr
Currently you have the following information available to you:
{% for item in memory.keys() %}
- {{ item }}
{% endfor %}
And have access to the following data tables:
{% for table in tables %}
- {{ table }}
{% endfor %}
{%- endfor %}
{% if table_info %}
In order to make an informed decision here are schemas for the most relevant tables:
In order to make an informed decision here are schemas for the most relevant tables (note that these schemas are computed on a subset of data):
{{ table_info }}
Do not request any additional tables.
{% endif %}
{%- if tables %}
Additionally the following tables are available and you may request to look at them before revising your plan:
{% for table in tables %}
- {{ table }}
{% endfor %}
{%- endif -%}
Here's the choice of experts and their uses:
{% for agent in agents %}
- `{{ agent.name[:-5] }}`
Expand All @@ -31,10 +32,6 @@ Here's the choice of experts and their uses:
Description: {{ agent.__doc__.strip().split() | join(' ') }}
{% endfor %}

{% if not table_info %}
If you do not think you can solve the problem given the current information provide a list of requested tables.
{% endif %}

{% if unmet_dependencies %}
Note that a previous plan was unsuccessful because it did not satisfy the following required pieces of information: {{ unmet_dependencies }}
{% endif %}
19 changes: 13 additions & 6 deletions lumen/ai/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import inspect
import math
import time

from functools import wraps
Expand Down Expand Up @@ -140,10 +141,13 @@ async def get_schema(
if "max" in spec:
spec.pop("max")

if not include_enum:
for field, spec in schema.items():
if "enum" in spec:
spec.pop("enum")
for field, spec in schema.items():
if "enum" not in spec:
continue
elif not include_enum:
spec.pop("enum")
elif "limit" in get_kwargs:
spec["enum"].append("...")

if count and include_count:
spec["count"] = count
Expand Down Expand Up @@ -174,7 +178,7 @@ def get_data_sync():

async def describe_data(df: pd.DataFrame) -> str:
def format_float(num):
if pd.isna(num):
if pd.isna(num) or math.isinf(num):
return num
# if is integer, round to 0 decimals
if num == int(num):
Expand Down Expand Up @@ -209,7 +213,10 @@ def describe_data_sync(df):
for col in df.select_dtypes(include=["object"]).columns:
if col not in df_describe_dict:
df_describe_dict[col] = {}
df_describe_dict[col]["nunique"] = df[col].nunique()
try:
df_describe_dict[col]["nunique"] = df[col].nunique()
except Exception:
df_describe_dict[col]["nunique"] = 'unknown'
try:
df_describe_dict[col]["lengths"] = {
"max": df[col].str.len().max(),
Expand Down

0 comments on commit 08915e2

Please sign in to comment.