From f69fdb98dc4bab78316d8e1e457979ec3ccdcb95 Mon Sep 17 00:00:00 2001 From: Philipp Rudiger Date: Thu, 26 Sep 2024 11:47:52 +0200 Subject: [PATCH] Support DuckDB spatial and shapefile uploads (#715) --- lumen/ai/agents.py | 4 ++-- lumen/ai/assistant.py | 4 ++-- lumen/ai/controls.py | 38 +++++++++++++++++++++++++++++++++++--- lumen/sources/duckdb.py | 11 ++++++++++- 4 files changed, 49 insertions(+), 8 deletions(-) diff --git a/lumen/ai/agents.py b/lumen/ai/agents.py index 186663108..acbadaa2d 100644 --- a/lumen/ai/agents.py +++ b/lumen/ai/agents.py @@ -99,7 +99,7 @@ def _exception_handler(exception): state.config.raise_with_notifications = True @classmethod - def applies(cls) -> bool: + async def applies(cls) -> bool: """ Additional checks to determine if the agent should be used. """ @@ -464,7 +464,7 @@ class TableListAgent(LumenBaseAgent): requires = param.List(default=["current_source"], readonly=True) @classmethod - def applies(cls) -> bool: + async def applies(cls) -> bool: source = memory.get("current_source") if not source: return True # source not loaded yet; always apply diff --git a/lumen/ai/assistant.py b/lumen/ai/assistant.py index f04f19011..38398d444 100644 --- a/lumen/ai/assistant.py +++ b/lumen/ai/assistant.py @@ -149,7 +149,7 @@ def download_notebook(): if "current_source" in memory and "available_sources" not in memory: memory["available_sources"] = [memory["current_source"]] - elif "current_source" not in memory and "available_sources" in memory: + elif "current_source" not in memory and memory.get("available_sources"): memory["current_source"] = memory["available_sources"][0] elif "available_sources" not in memory: memory["available_sources"] = [] @@ -343,7 +343,7 @@ async def _create_valid_agent(self, messages, system, agent_model, errors=None): return out async def _choose_agent(self, messages: list | str, agents: list[Agent]): - agents = [agent for agent in agents if agent.applies()] + agents = [agent for agent in agents if await agent.applies()] agent_names = tuple(sagent.name[:-5] for sagent in agents) if len(agent_names) == 0: raise ValueError("No agents available to choose from.") diff --git a/lumen/ai/controls.py b/lumen/ai/controls.py index d36703c37..db28a113c 100644 --- a/lumen/ai/controls.py +++ b/lumen/ai/controls.py @@ -1,4 +1,5 @@ import io +import zipfile import pandas as pd import panel as pn @@ -147,8 +148,12 @@ def _add_table( file: io.BytesIO | io.StringIO, table_controls: TableControls, ): + conn = duckdb_source._connection extension = table_controls.extension table = table_controls.table + sql_expr = f"SELECT * FROM {table}" + params = {} + conversion = None if extension.endswith("csv"): df = pd.read_csv(file, parse_dates=True) elif extension.endswith(("parq", "parquet")): @@ -158,14 +163,41 @@ def _add_table( elif extension.endswith("xlsx"): sheet = table_controls.sheet df = pd.read_excel(file, sheet_name=sheet) + elif extension.endswith(('geojson', 'wkt', 'zip')): + if extension.endswith('zip'): + zf = zipfile.ZipFile(file) + if not any(f.filename.endswith('shp') for f in zf.filelist): + raise ValueError("Could not interpret zip file contents") + file.seek(0) + import geopandas as gpd + geo_df = gpd.read_file(file) + df = pd.DataFrame(geo_df) + df['geometry'] = geo_df['geometry'].to_wkb() + params['initializers'] = init = [""" + INSTALL spatial; + LOAD spatial; + """] + conn.execute(init[0]) + cols = ', '.join(f'"{c}"' for c in df.columns if c != 'geometry') + conversion = f'CREATE TEMP TABLE {table} AS SELECT {cols}, ST_GeomFromWKB(geometry) as geometry FROM {table}_temp' else: raise ValueError(f"Unsupported file extension: {extension}") - duckdb_source._connection.from_df(df).to_view(table) - duckdb_source.tables[table] = f"SELECT * FROM {table}" + duckdb_source.param.update(params) + df_rel = conn.from_df(df) + if conversion: + conn.register(f'{table}_temp', df_rel) + conn.execute(conversion) + conn.unregister(f'{table}_temp') + else: + df_rel.to_view(table) + duckdb_source.tables[table] = sql_expr memory["current_source"] = duckdb_source memory["current_table"] = table - memory["available_sources"].append(duckdb_source) + if "available_sources" in memory: + memory["available_sources"].append(duckdb_source) + else: + memory["available_sources"] = [duckdb_source] self._last_table = table @param.depends("add", watch=True) diff --git a/lumen/sources/duckdb.py b/lumen/sources/duckdb.py index 47ae50579..c025225fb 100644 --- a/lumen/sources/duckdb.py +++ b/lumen/sources/duckdb.py @@ -258,7 +258,16 @@ def get(self, table, **query): sql_transforms = [SQLFilter(conditions=conditions)] + sql_transforms for st in sql_transforms: sql_expr = st.apply(sql_expr) - df = self._connection.execute(sql_expr).fetch_df(date_as_object=True) + rel = self._connection.execute(sql_expr) + has_geom = any(d[0] == 'geometry' and d[1] == 'BINARY' for d in rel.description) + df = rel.fetch_df(date_as_object=True) + if has_geom: + import geopandas as gpd + geom = self._connection.execute( + f'SELECT ST_AsWKB(geometry::GEOMETRY) as geometry FROM ({sql_expr})' + ).fetch_df() + df['geometry'] = gpd.GeoSeries.from_wkb(geom.geometry.apply(bytes)) + df = gpd.GeoDataFrame(df) if not self.filter_in_sql: df = Filter.apply_to(df, conditions=conditions) return df