Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create primary key for new tables #176

Merged
merged 4 commits into from
Nov 25, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 90 additions & 38 deletions pipeline/resources/postgis_geopandas_io_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
# limitations under the License.

import csv
from contextlib import contextmanager
import re
from contextlib import closing, contextmanager
from io import StringIO
from typing import Any, Dict, Iterator, Optional, Sequence, cast

Expand All @@ -28,8 +29,6 @@
from sqlalchemy import create_engine
from sqlalchemy.engine import URL, Connection

# TODO figure out appropriate SQL injection mechanism while allowing dynamic table/column provision


@contextmanager
def connect_postgresql(config, schema='public') -> Iterator[Connection]:
Expand Down Expand Up @@ -65,28 +64,44 @@ class PostgreSQLPandasIOManager(ConfigurableIOManager): # type: ignore
def _config(self) -> Dict[str, Any]:
return self.dict()

def _assert_sql_safety(self, *expressions: str) -> None:
r"""
Raises a ValueError in case an expression contains characters other than
word or number chars or is of lenght 0 (^[\w\d]+$).
"""
for expression in expressions:
if not re.match(r'^[\w\d]+$', expression):
raise ValueError(f'Unexpected sql identifier {expression}')

def handle_output(self, context: OutputContext, obj: pandas.DataFrame):
schema, table = self._get_schema_table(context.asset_key)

if isinstance(obj, pandas.DataFrame):
with connect_postgresql(config=self._config) as con:
self._create_schema_if_not_exists(schema, con)
# just recreate table with empty frame (obj[:0]) and load later via copy_from
obj[:0].to_sql(
con=con,
name=table,
index=True,
schema=schema,
if_exists='replace',
)
table_exists = self._has_table(con, schema, table)
derhuerst marked this conversation as resolved.
Show resolved Hide resolved
if table_exists:
self._truncate_table(schema, table, con)
else:
# create table with empty frame (obj[:0]) and load later via copy_from
obj[:0].to_sql(
con=con,
name=table,
index=True,
schema=schema,
if_exists='replace',
)
# table was just created, create primary key (to_sql doesn't create these,
# though index=True suggests this)
self._create_primary_key(schema, table, obj.index.names, con)
obj.reset_index()
sio = StringIO()
obj.to_csv(sio, sep='\t', na_rep='', header=False)
sio.seek(0)
c = con.connection.cursor()
# ignore mypy attribute check, as postgres cursor has custom extension to DBAPICursor: copy_expert
c.copy_expert(f"COPY {schema}.{table} FROM STDIN WITH (FORMAT csv, DELIMITER '\t')", sio) # type: ignore[attr-defined]
con.connection.commit()
with closing(con.connection.cursor()) as c:
# ignore mypy attribute check, as postgres cursor has custom extension to DBAPICursor: copy_expert
c.copy_expert(f"COPY {schema}.{table} FROM STDIN WITH (FORMAT csv, DELIMITER '\t')", sio) # type: ignore[attr-defined]
con.connection.commit()
context.add_output_metadata({'num_rows': len(obj), 'table_name': f'{schema}.{table}'})
elif obj is None:
self.delete_asset(context)
Expand Down Expand Up @@ -125,15 +140,29 @@ def delete_asset(self, context: OutputContext):
else:
raise Exception('Deletion of not-partitioned assets not yet supported.')

def _delete_partition(self, schema, table, partition_col_name, partition_key, con):
def _delete_partition(self, schema: str, table: str, partition_col_name: str, partition_key: str, con: Connection):
try:
self._assert_sql_safety(schema, table, partition_col_name, partition_key)
with closing(con.connection.cursor()) as c:
c.execute(f"DELETE FROM {schema}.{table} WHERE {partition_col_name}='{partition_key}'")
except UndefinedTable:
# TODO log debug info, asset did not exist, so nothing to
pass

def _truncate_table(self, schema: str, table: str, con: Connection):
try:
c = con.connection.cursor()
c.execute(f"DELETE FROM {schema}.{table} WHERE {partition_col_name}='{partition_key}'")
self._assert_sql_safety(schema, table)
with closing(con.connection.cursor()) as c:
c.execute(f'TRUNCATE TABLE {schema}.{table}')
except UndefinedTable:
derhuerst marked this conversation as resolved.
Show resolved Hide resolved
# TODO log debug info, asset did not exist, so nothing to
con.connection.rollback()
pass

def _create_primary_key(self, schema: str, table: str, keys: list[str], con: Connection):
with closing(con.connection.cursor()) as c:
self._assert_sql_safety(schema, table, **keys)
c.execute(f'ALTER TABLE {schema}.{table} ADD PRIMARY KEY ({",".join(keys)})')
derhuerst marked this conversation as resolved.
Show resolved Hide resolved

def _load_input(
self, con: Connection, table: str, schema: str, columns: Optional[Sequence[str]], context: InputContext
) -> pandas.DataFrame:
Expand All @@ -146,8 +175,9 @@ def _load_input(
con=con,
)

def _create_schema_if_not_exists(self, schema, con):
with con.connection.cursor() as c:
def _create_schema_if_not_exists(self, schema: str, con: Connection):
with closing(con.connection.cursor()) as c:
self._assert_sql_safety(schema)
c.execute(f'CREATE SCHEMA IF NOT EXISTS {schema}')

def _get_schema_table(self, asset_key):
Expand All @@ -159,9 +189,21 @@ def _get_select_statement(
schema: str,
columns: Optional[Sequence[str]],
):
self._assert_sql_safety(schema, table, **columns)
col_str = ', '.join(columns) if columns else '*'
return f'SELECT {col_str} FROM {schema}.{table}'

def _has_table(self, con: Connection, schema: str, table: str):
with closing(con.connection.cursor()) as c:
self._assert_sql_safety(schema, table)
c.execute(
f"SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = '{schema}' AND table_name = '{table}'"
)
fetch_result = c.fetchone()
if fetch_result and fetch_result[0] == 1:
return True
return False


# need mypy to ignore following line due to https://github.com/dagster-io/dagster/issues/17443
class PostGISGeoPandasIOManager(PostgreSQLPandasIOManager): # type: ignore
Expand All @@ -173,25 +215,35 @@ def handle_output(self, context: OutputContext, obj: geopandas.GeoDataFrame):
if isinstance(obj, geopandas.GeoDataFrame):
with connect_postgresql(config=self._config) as con:
self._create_schema_if_not_exists(schema, con)
if context.has_partition_key:
# add additional column (name? for now just partition)
# to the frame and initialize with partition_name
# (name could become part of metadata, ob may be contained already)
partition_col_name = self._get_partition_expr(context)
partition_key = context.partition_key
obj[partition_col_name] = partition_key

# We leave other partions untouched, but need to delete data from this
# partition before we append again.
if_exists_action = 'append'
self._delete_partition(schema, table, partition_col_name, partition_key, con)
else:
# All data can be replaced (e.g. deleted before insertion).
# geopandas will take care of this.
if_exists_action = 'replace'
table_exists = self._has_table(con, schema, table)
if table_exists:
if context.has_partition_key:
# add additional column (name? for now just partition)
# to the frame and initialize with partition_name
# (name could become part of metadata, ob may be contained already)
partition_col_name = self._get_partition_expr(context)
partition_key = context.partition_key
obj[partition_col_name] = partition_key

# We leave other partitions untouched, but need to delete data from this
# partition before we append again.
self._delete_partition(schema, table, partition_col_name, partition_key, con)
else:
# All data can be replaced (i.e. truncated before insertion).
# geopandas will take care of this.
self._truncate_table(schema, table, con)
# while writing standard pandas.DataFrames to sql tables is database agnostic and performed
# internally via SQLAlchemy, writing a GeoDataFrame requires explicitly using to_postgis. See
# https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_sql.html and
# https://geopandas.org/en/stable/docs/reference/api/geopandas.GeoDataFrame.to_postgis.html
obj.to_postgis(
con=con, name=table, index=True, schema=schema, if_exists=if_exists_action, chunksize=self.chunksize
con=con, name=table, index=True, schema=schema, if_exists='append', chunksize=self.chunksize
)
if not table_exists:
# table was just created, create primary key (to_postgis doesn't create these,
# though index=True suggests this)
self._create_primary_key(schema, table, obj.index.names, con)
con.connection.commit()
context.add_output_metadata({'num_rows': len(obj), 'table_name': f'{schema}.{table}'})
else:
super().handle_output(context, obj)
Expand Down