Skip to content

Commit

Permalink
Support DuckDB spatial and shapefile uploads (#715)
Browse files Browse the repository at this point in the history
  • Loading branch information
philippjfr authored Sep 26, 2024
1 parent b777a29 commit f69fdb9
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 8 deletions.
4 changes: 2 additions & 2 deletions lumen/ai/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def _exception_handler(exception):
state.config.raise_with_notifications = True

@classmethod
def applies(cls) -> bool:
async def applies(cls) -> bool:
"""
Additional checks to determine if the agent should be used.
"""
Expand Down Expand Up @@ -464,7 +464,7 @@ class TableListAgent(LumenBaseAgent):
requires = param.List(default=["current_source"], readonly=True)

@classmethod
def applies(cls) -> bool:
async def applies(cls) -> bool:
source = memory.get("current_source")
if not source:
return True # source not loaded yet; always apply
Expand Down
4 changes: 2 additions & 2 deletions lumen/ai/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def download_notebook():

if "current_source" in memory and "available_sources" not in memory:
memory["available_sources"] = [memory["current_source"]]
elif "current_source" not in memory and "available_sources" in memory:
elif "current_source" not in memory and memory.get("available_sources"):
memory["current_source"] = memory["available_sources"][0]
elif "available_sources" not in memory:
memory["available_sources"] = []
Expand Down Expand Up @@ -343,7 +343,7 @@ async def _create_valid_agent(self, messages, system, agent_model, errors=None):
return out

async def _choose_agent(self, messages: list | str, agents: list[Agent]):
agents = [agent for agent in agents if agent.applies()]
agents = [agent for agent in agents if await agent.applies()]
agent_names = tuple(sagent.name[:-5] for sagent in agents)
if len(agent_names) == 0:
raise ValueError("No agents available to choose from.")
Expand Down
38 changes: 35 additions & 3 deletions lumen/ai/controls.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import io
import zipfile

import pandas as pd
import panel as pn
Expand Down Expand Up @@ -147,8 +148,12 @@ def _add_table(
file: io.BytesIO | io.StringIO,
table_controls: TableControls,
):
conn = duckdb_source._connection
extension = table_controls.extension
table = table_controls.table
sql_expr = f"SELECT * FROM {table}"
params = {}
conversion = None
if extension.endswith("csv"):
df = pd.read_csv(file, parse_dates=True)
elif extension.endswith(("parq", "parquet")):
Expand All @@ -158,14 +163,41 @@ def _add_table(
elif extension.endswith("xlsx"):
sheet = table_controls.sheet
df = pd.read_excel(file, sheet_name=sheet)
elif extension.endswith(('geojson', 'wkt', 'zip')):
if extension.endswith('zip'):
zf = zipfile.ZipFile(file)
if not any(f.filename.endswith('shp') for f in zf.filelist):
raise ValueError("Could not interpret zip file contents")
file.seek(0)
import geopandas as gpd
geo_df = gpd.read_file(file)
df = pd.DataFrame(geo_df)
df['geometry'] = geo_df['geometry'].to_wkb()
params['initializers'] = init = ["""
INSTALL spatial;
LOAD spatial;
"""]
conn.execute(init[0])
cols = ', '.join(f'"{c}"' for c in df.columns if c != 'geometry')
conversion = f'CREATE TEMP TABLE {table} AS SELECT {cols}, ST_GeomFromWKB(geometry) as geometry FROM {table}_temp'
else:
raise ValueError(f"Unsupported file extension: {extension}")

duckdb_source._connection.from_df(df).to_view(table)
duckdb_source.tables[table] = f"SELECT * FROM {table}"
duckdb_source.param.update(params)
df_rel = conn.from_df(df)
if conversion:
conn.register(f'{table}_temp', df_rel)
conn.execute(conversion)
conn.unregister(f'{table}_temp')
else:
df_rel.to_view(table)
duckdb_source.tables[table] = sql_expr
memory["current_source"] = duckdb_source
memory["current_table"] = table
memory["available_sources"].append(duckdb_source)
if "available_sources" in memory:
memory["available_sources"].append(duckdb_source)
else:
memory["available_sources"] = [duckdb_source]
self._last_table = table

@param.depends("add", watch=True)
Expand Down
11 changes: 10 additions & 1 deletion lumen/sources/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,16 @@ def get(self, table, **query):
sql_transforms = [SQLFilter(conditions=conditions)] + sql_transforms
for st in sql_transforms:
sql_expr = st.apply(sql_expr)
df = self._connection.execute(sql_expr).fetch_df(date_as_object=True)
rel = self._connection.execute(sql_expr)
has_geom = any(d[0] == 'geometry' and d[1] == 'BINARY' for d in rel.description)
df = rel.fetch_df(date_as_object=True)
if has_geom:
import geopandas as gpd
geom = self._connection.execute(
f'SELECT ST_AsWKB(geometry::GEOMETRY) as geometry FROM ({sql_expr})'
).fetch_df()
df['geometry'] = gpd.GeoSeries.from_wkb(geom.geometry.apply(bytes))
df = gpd.GeoDataFrame(df)
if not self.filter_in_sql:
df = Filter.apply_to(df, conditions=conditions)
return df
Expand Down

0 comments on commit f69fdb9

Please sign in to comment.