From b777a2933ebd66f4b8fdfa037728c6eebee73eb2 Mon Sep 17 00:00:00 2001 From: Andrew <15331990+ahuang11@users.noreply.github.com> Date: Wed, 25 Sep 2024 09:46:47 -0700 Subject: [PATCH] Async data calls (#714) --- lumen/ai/agents.py | 60 +++++++++-------- lumen/ai/analysis.py | 12 ++-- lumen/ai/assistant.py | 10 +-- lumen/ai/utils.py | 153 ++++++++++++++++++++++++------------------ lumen/ai/views.py | 11 +-- 5 files changed, 140 insertions(+), 106 deletions(-) diff --git a/lumen/ai/agents.py b/lumen/ai/agents.py index de7cc3f43..186663108 100644 --- a/lumen/ai/agents.py +++ b/lumen/ai/agents.py @@ -39,7 +39,8 @@ ) from .translate import param_to_pydantic from .utils import ( - clean_sql, describe_data, get_schema, render_template, retry_llm_output, + clean_sql, describe_data, get_data, get_pipeline, get_schema, + render_template, retry_llm_output, ) from .views import AnalysisOutput, LumenOutput, SQLOutput @@ -274,7 +275,7 @@ async def _system_prompt_with_context( context = f"Available tables: {', '.join(closest_tables)}" else: memory["current_table"] = table = memory.get("current_table", tables[0]) - schema = get_schema(memory["current_source"], table) + schema = await get_schema(memory["current_source"], table) if schema: context = f"{table} with schema: {schema}" @@ -389,7 +390,7 @@ async def answer(self, messages: list | str): for table in source.get_tables(): tables_to_source[table] = source if isinstance(source, DuckDBSource) and source.ephemeral: - schema = get_schema(source, table, include_min_max=False, include_enum=True, limit=1) + schema = await get_schema(source, table, include_min_max=False, include_enum=True, limit=1) tables_schema_str += f"### {table}\nSchema:\n```yaml\n{yaml.dump(schema)}```\n" else: tables_schema_str += f"### {table}\n" @@ -435,12 +436,12 @@ async def answer(self, messages: list | str): get_kwargs['sql_transforms'] = [SQLLimit(limit=1_000_000)] memory["current_source"] = source memory["current_table"] = table - memory["current_pipeline"] = pipeline = Pipeline( + memory["current_pipeline"] = pipeline = await get_pipeline( source=source, table=table, **get_kwargs ) - df = pipeline.data + df = await get_data(pipeline) if len(df) > 0: - memory["current_data"] = describe_data(df) + memory["current_data"] = await describe_data(df) if self.debug: print(f"{self.name} thinks that the user is talking about {table=!r}.") return pipeline @@ -581,7 +582,7 @@ async def _create_valid_sql(self, messages, system, tables_to_source, errors=Non # Get validated query sql_query = sql_expr_source.tables[expr_slug] sql_transforms = [SQLLimit(limit=1_000_000)] - pipeline = Pipeline( + pipeline = await get_pipeline( source=sql_expr_source, table=expr_slug, sql_transforms=sql_transforms ) except InstructorRetryException as e: @@ -605,9 +606,9 @@ async def _create_valid_sql(self, messages, system, tables_to_source, errors=Non step.status = "failed" raise e - df = pipeline.data + df = await get_data(pipeline) if len(df) > 0: - memory["current_data"] = describe_data(df) + memory["current_data"] = await describe_data(df) memory["available_sources"].append(sql_expr_source) memory["current_source"] = sql_expr_source @@ -701,7 +702,7 @@ async def answer(self, messages: list | str): if not hasattr(source, "get_sql_expr"): return None - schema = get_schema(source, table, include_min_max=False) + schema = await get_schema(source, table, include_min_max=False) join_required = await self.check_join_required(messages, schema, table) if join_required: tables_to_source = await self.find_join_tables(messages) @@ -713,7 +714,7 @@ async def answer(self, messages: list | str): if source_table == table: table_schema = schema else: - table_schema = get_schema(source, source_table, include_min_max=False) + table_schema = await get_schema(source, source_table, include_min_max=False) table_schemas[source_table] = { "schema": yaml.dump(table_schema), "sql": source.get_sql_expr(source_table) @@ -754,13 +755,14 @@ async def answer(self, messages: list | str) -> Transform: if "current_pipeline" in memory: pipeline = memory["current_pipeline"] else: - pipeline = Pipeline( + pipeline = await get_pipeline( source=memory["current_source"], table=memory["current_table"], ) memory["current_pipeline"] = pipeline - pipeline._update_data(force=True) - memory["current_data"] = describe_data(pipeline.data) + await asyncio.to_thread(pipeline._update_data, force=True) + data = await get_data(pipeline) + memory["current_data"] = await describe_data(data) return pipeline async def invoke(self, messages: list | str): @@ -867,7 +869,7 @@ async def _construct_transform( self, messages: list | str, transform: type[Transform], system_prompt: str ) -> Transform: excluded = transform._internal_params + ["controls", "type"] - schema = get_schema(memory["current_pipeline"]) + schema = await get_schema(memory["current_pipeline"]) table = memory["current_table"] model = param_to_pydantic(transform, excluded=excluded, schema=schema)[ transform.__name__ @@ -912,8 +914,9 @@ async def answer(self, messages: list | str) -> Transform: else: pipeline.add_transform(transform) - pipeline._update_data(force=True) - memory["current_data"] = describe_data(pipeline.data) + await asyncio.to_thread(pipeline._update_data, force=True) + data = await get_data(pipeline) + memory["current_data"] = await describe_data(data) return pipeline async def invoke(self, messages: list | str): @@ -927,7 +930,7 @@ class BaseViewAgent(LumenBaseAgent): provides = param.List(default=["current_plot"], readonly=True) - def _extract_spec(self, model: BaseModel): + async def _extract_spec(self, model: BaseModel): return dict(model) async def answer(self, messages: list | str) -> hvPlotUIView: @@ -935,7 +938,7 @@ async def answer(self, messages: list | str) -> hvPlotUIView: # Write prompts system_prompt = await self._system_prompt_with_context(messages) - schema = get_schema(pipeline, include_min_max=False) + schema = await get_schema(pipeline, include_min_max=False) view_prompt = render_template( "plot_agent.jinja2", schema=yaml.dump(schema), @@ -951,7 +954,7 @@ async def answer(self, messages: list | str) -> hvPlotUIView: system=system_prompt + view_prompt, response_model=self._get_model(schema), ) - spec = self._extract_spec(output) + spec = await self._extract_spec(output) chain_of_thought = spec.pop("chain_of_thought") with self.interface.add_step(title="Generating view...") as step: step.stream(chain_of_thought) @@ -1002,7 +1005,7 @@ def _get_model(cls, schema): }) return model[cls.view_type.__name__] - def _extract_spec(self, model): + async def _extract_spec(self, model): pipeline = memory["current_pipeline"] spec = { key: val for key, val in dict(model).items() @@ -1014,7 +1017,8 @@ def _extract_spec(self, model): # Add defaults spec["responsive"] = True - if len(pipeline.data) > 20000 and spec["kind"] in ("line", "scatter", "points"): + data = await get_data(pipeline) + if len(data) > 20000 and spec["kind"] in ("line", "scatter", "points"): spec["rasterize"] = True spec["cnorm"] = "log" return spec @@ -1039,7 +1043,7 @@ class VegaLiteAgent(BaseViewAgent): def _get_model(cls, schema): return VegaLiteSpec - def _extract_spec(self, model): + async def _extract_spec(self, model): vega_spec = json.loads(model.json_spec) if "$schema" not in vega_spec: vega_spec["$schema"] = "https://vega.github.io/schema/vega-lite/v5.json" @@ -1092,7 +1096,7 @@ async def _system_prompt_with_context( async def answer(self, messages: list | str, agents: list[Agent] | None = None): pipeline = memory['current_pipeline'] - analyses = {a.name: a for a in self.analyses if a.applies(pipeline)} + analyses = {a.name: a for a in self.analyses if await a.applies(pipeline)} if not analyses: print("NONE found...") return None @@ -1125,8 +1129,10 @@ async def answer(self, messages: list | str, agents: list[Agent] | None = None): with self.interface.add_step(title="Creating view...", user="Assistant") as step: await asyncio.sleep(0.1) # necessary to give it time to render before calling sync function... analysis_callable = analyses[analysis_name].instance(agents=agents) + + data = await get_data(pipeline) for field in analysis_callable._field_params: - analysis_callable.param[field].objects = list(pipeline.data.columns) + analysis_callable.param[field].objects = list(data.columns) memory["current_analysis"] = analysis_callable if analysis_callable.autorun: @@ -1143,8 +1149,8 @@ async def answer(self, messages: list | str, agents: list[Agent] | None = None): # Ensure current_data reflects processed pipeline if pipeline is not memory['current_pipeline']: pipeline = memory['current_pipeline'] - if len(pipeline.data) > 0: - memory["current_data"] = describe_data(pipeline.data) + if len(data) > 0: + memory["current_data"] = await describe_data(data) yaml_spec = yaml.dump(spec) step.stream(f"Generated view\n```yaml\n{yaml_spec}\n```") step.success_title = "Generated view" diff --git a/lumen/ai/analysis.py b/lumen/ai/analysis.py index 45c07c9b0..5be18b2bb 100644 --- a/lumen/ai/analysis.py +++ b/lumen/ai/analysis.py @@ -1,10 +1,11 @@ import panel as pn import param +from lumen.ai.utils import get_data + from ..base import Component from .controls import SourceControls from .memory import memory -from .utils import get_schema class Analysis(param.ParameterizedFunction): @@ -34,13 +35,14 @@ class Analysis(param.ParameterizedFunction): _field_params = [] @classmethod - def applies(cls, pipeline) -> bool: + async def applies(cls, pipeline) -> bool: applies = True + data = await get_data(pipeline) for col in cls.columns: if isinstance(col, tuple): - applies &= any(c in pipeline.data.columns for c in col) + applies &= any(c in data.columns for c in col) else: - applies &= col in pipeline.data.columns + applies &= col in data.columns return applies def controls(self): @@ -80,7 +82,7 @@ def controls(self): table = memory.get("current_table") self._previous_source = source self._previous_table = table - columns = list(get_schema(source, table=table).keys()) + columns = list(source.get_schema(table).keys()) index_col = pn.widgets.AutocompleteInput.from_param( self.param.index_col, options=columns, name="Join on", placeholder="Start typing column name", search_strategy="includes", diff --git a/lumen/ai/assistant.py b/lumen/ai/assistant.py index 501a3c6d3..f04f19011 100644 --- a/lumen/ai/assistant.py +++ b/lumen/ai/assistant.py @@ -186,7 +186,7 @@ async def use_suggestion(event): else: return await agent.invoke([{'role': 'user', 'content': contents}], agents=self.agents) - self._add_analysis_suggestions() + await self._add_analysis_suggestions() else: self.interface.send(contents) @@ -233,13 +233,13 @@ async def run_demo(event): self.interface.param.watch(hide_suggestions, "objects") return message - def _add_analysis_suggestions(self): + async def _add_analysis_suggestions(self): pipeline = memory['current_pipeline'] current_analysis = memory.get("current_analysis") allow_consecutive = getattr(current_analysis, '_consecutive_calls', True) applicable_analyses = [] for analysis in self._analyses: - if analysis.applies(pipeline) and (allow_consecutive or analysis is not type(current_analysis)): + if await analysis.applies(pipeline) and (allow_consecutive or analysis is not type(current_analysis)): applicable_analyses.append(analysis) self._add_suggestions_to_footer( [f"Apply {analysis.__name__}" for analysis in applicable_analyses], @@ -263,7 +263,7 @@ async def _invalidate_memory(self, messages): raise KeyError(f'Table {table} could not be found in available sources.') try: - spec = get_schema(source, table=table, include_count=True) + spec = await get_schema(source, table=table, include_count=True) except Exception: # If the selected table cannot be fetched we should invalidate it spec = None @@ -482,7 +482,7 @@ async def invoke(self, messages: list | str) -> str: await agent.invoke(messages[-context_length:], **kwargs) self._current_agent.object = "## No agent active" if "current_pipeline" in agent.provides: - self._add_analysis_suggestions() + await self._add_analysis_suggestions() print("\033[92mDONE\033[0m", "\n\n") def controls(self): diff --git a/lumen/ai/utils.py b/lumen/ai/utils.py index 30d2e29e4..cb9877e3a 100644 --- a/lumen/ai/utils.py +++ b/lumen/ai/utils.py @@ -97,7 +97,7 @@ def format_schema(schema): return formatted -def get_schema( +async def get_schema( source: Source | Pipeline, table: str | None = None, include_min_max: bool = True, @@ -106,11 +106,11 @@ def get_schema( **get_kwargs ): if isinstance(source, Pipeline): - schema = source.get_schema() + schema = await asyncio.to_thread(source.get_schema) else: if "limit" not in get_kwargs: get_kwargs["limit"] = 100 - schema = source.get_schema(table, **get_kwargs) + schema = await asyncio.to_thread(source.get_schema, table, **get_kwargs) schema = dict(schema) # first pop regardless to prevent @@ -146,7 +146,27 @@ def get_schema( return schema -def describe_data(df: pd.DataFrame) -> str: +async def get_pipeline(**kwargs): + """ + A wrapper be able to use asyncio.to_thread and not + block the main thread when calling Pipeline + """ + def get_pipeline_sync(): + return Pipeline(**kwargs) + return await asyncio.to_thread(get_pipeline_sync) + + +async def get_data(pipeline): + """ + A wrapper be able to use asyncio.to_thread and not + block the main thread when calling pipeline.data + """ + def get_data_sync(): + return pipeline.data + return await asyncio.to_thread(get_data_sync) + + +async def describe_data(df: pd.DataFrame) -> str: def format_float(num): if pd.isna(num): return num @@ -158,67 +178,70 @@ def format_float(num): else: return f"{num:.1e}" # Exponential notation with two decimals - size = df.size - shape = df.shape - if size < 250: - return df - - is_summarized = False - if shape[0] > 5000: - is_summarized = True - df = df.sample(5000) - - df = df.sort_index() - - for col in df.columns: - if isinstance(df[col].iloc[0], pd.Timestamp): - df[col] = pd.to_datetime(df[col]) - - describe_df = df.describe(percentiles=[]) - columns_to_drop = ["min", "max"] # present if any numeric - columns_to_drop = [col for col in columns_to_drop if col in describe_df.columns] - df_describe_dict = describe_df.drop(columns=columns_to_drop).to_dict() - - for col in df.select_dtypes(include=["object"]).columns: - if col not in df_describe_dict: - df_describe_dict[col] = {} - df_describe_dict[col]["nunique"] = df[col].nunique() - try: - df_describe_dict[col]["lengths"] = { - "max": df[col].str.len().max(), - "min": df[col].str.len().min(), - "mean": float(df[col].str.len().mean()), - } - except AttributeError: - pass - - for col in df.columns: - if col not in df_describe_dict: - df_describe_dict[col] = {} - df_describe_dict[col]["nulls"] = int(df[col].isnull().sum()) - - # select datetime64 columns - for col in df.select_dtypes(include=["datetime64"]).columns: - for key in df_describe_dict[col]: - df_describe_dict[col][key] = str(df_describe_dict[col][key]) - df[col] = df[col].astype(str) # shorten output - - # select all numeric columns and round - for col in df.select_dtypes(include=["int64", "float64"]).columns: - for key in df_describe_dict[col]: - df_describe_dict[col][key] = format_float(df_describe_dict[col][key]) - - for col in df.select_dtypes(include=["float64"]).columns: - df[col] = df[col].apply(format_float) - - data = { - "summary": { - "total_table_cells": size, - "total_shape": shape, - "is_summarized": is_summarized, - }, - "stats": df_describe_dict, - } + def describe_data_sync(df): + size = df.size + shape = df.shape + if size < 250: + return df + + is_summarized = False + if shape[0] > 5000: + is_summarized = True + df = df.sample(5000) + + df = df.sort_index() + + for col in df.columns: + if isinstance(df[col].iloc[0], pd.Timestamp): + df[col] = pd.to_datetime(df[col]) + + describe_df = df.describe(percentiles=[]) + columns_to_drop = ["min", "max"] # present if any numeric + columns_to_drop = [col for col in columns_to_drop if col in describe_df.columns] + df_describe_dict = describe_df.drop(columns=columns_to_drop).to_dict() + + for col in df.select_dtypes(include=["object"]).columns: + if col not in df_describe_dict: + df_describe_dict[col] = {} + df_describe_dict[col]["nunique"] = df[col].nunique() + try: + df_describe_dict[col]["lengths"] = { + "max": df[col].str.len().max(), + "min": df[col].str.len().min(), + "mean": float(df[col].str.len().mean()), + } + except AttributeError: + pass + + for col in df.columns: + if col not in df_describe_dict: + df_describe_dict[col] = {} + df_describe_dict[col]["nulls"] = int(df[col].isnull().sum()) + + # select datetime64 columns + for col in df.select_dtypes(include=["datetime64"]).columns: + for key in df_describe_dict[col]: + df_describe_dict[col][key] = str(df_describe_dict[col][key]) + df[col] = df[col].astype(str) # shorten output + + # select all numeric columns and round + for col in df.select_dtypes(include=["int64", "float64"]).columns: + for key in df_describe_dict[col]: + df_describe_dict[col][key] = format_float(df_describe_dict[col][key]) + + for col in df.select_dtypes(include=["float64"]).columns: + df[col] = df[col].apply(format_float) + + return { + "summary": { + "total_table_cells": size, + "total_shape": shape, + "is_summarized": is_summarized, + }, + "stats": df_describe_dict, + } + + data = asyncio.to_thread(describe_data_sync, df) return data diff --git a/lumen/ai/views.py b/lumen/ai/views.py index 7b7bb179d..046715107 100644 --- a/lumen/ai/views.py +++ b/lumen/ai/views.py @@ -7,6 +7,8 @@ from panel.viewable import Viewer from param.parameterized import discard_events +from lumen.ai.utils import get_data + from ..base import Component from ..dashboard import load_yaml from ..downloads import Download @@ -75,7 +77,7 @@ def __init__(self, **params): ] self._last_output = {} - def _render_pipeline(self, pipeline): + async def _render_pipeline(self, pipeline): table = Table( pipeline=pipeline, pagination='remote', min_height=500, sizing_mode="stretch_both", stylesheets=[ @@ -104,7 +106,8 @@ def _render_pipeline(self, pipeline): else: sql_limit = None if sql_limit: - limited = len(pipeline.data) == sql_limit.limit + data = await get_data(pipeline) + limited = len(data) == sql_limit.limit if limited: def unlimit(e): sql_limit.limit = None if e.new else 1_000_000 @@ -139,7 +142,7 @@ async def _render_component(self): yaml_spec = load_yaml(self.spec) self.component = type(self.component).from_spec(yaml_spec) if isinstance(self.component, Pipeline): - output = self._render_pipeline(self.component) + output = await self._render_pipeline(self.component) else: output = self.component.__panel__() self._rendered = True @@ -225,7 +228,7 @@ async def _render_component(self): try: if self._rendered: pipeline.source = pipeline.source.create_sql_expr_source(tables={pipeline.table: self.spec}) - output = self._render_pipeline(pipeline) + output = await self._render_pipeline(pipeline) self._rendered = True self._last_output.clear() self._last_output[self.spec] = output