Merge pull request #176 from mobidata-bw/primary-key
Create primary key for new tables
hbruch authored Nov 25, 2024
2 parents 8eba1bb + 0e0022f commit 1c265ae
Showing 1 changed file with 90 additions and 38 deletions.
128 changes: 90 additions & 38 deletions pipeline/resources/postgis_geopandas_io_manager.py
@@ -13,7 +13,8 @@
# limitations under the License.

import csv
from contextlib import contextmanager
import re
from contextlib import closing, contextmanager
from io import StringIO
from typing import Any, Dict, Iterator, Optional, Sequence, cast

@@ -28,8 +29,6 @@
from sqlalchemy import create_engine
from sqlalchemy.engine import URL, Connection

# TODO figure out an appropriate SQL-injection prevention mechanism while allowing dynamic table/column provision


@contextmanager
def connect_postgresql(config, schema='public') -> Iterator[Connection]:
@@ -65,28 +64,44 @@ class PostgreSQLPandasIOManager(ConfigurableIOManager): # type: ignore
def _config(self) -> Dict[str, Any]:
return self.dict()

def _assert_sql_safety(self, *expressions: str) -> None:
r"""
Raises a ValueError in case an expression contains characters other than
word or number chars or is of lenght 0 (^[\w\d]+$).
"""
for expression in expressions:
if not re.match(r'^[\w\d]+$', expression):
raise ValueError(f'Unexpected sql identifier {expression}')
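# e.g. _assert_sql_safety('public', 'stations') passes, while
# _assert_sql_safety('stations; DROP TABLE x') raises ValueError ('stations' is a
# hypothetical table name). This guard is the injection barrier for the
# identifiers interpolated into SQL strings below.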

def handle_output(self, context: OutputContext, obj: pandas.DataFrame):
schema, table = self._get_schema_table(context.asset_key)

if isinstance(obj, pandas.DataFrame):
with connect_postgresql(config=self._config) as con:
self._create_schema_if_not_exists(schema, con)
# just recreate table with empty frame (obj[:0]) and load later via copy_from
obj[:0].to_sql(
con=con,
name=table,
index=True,
schema=schema,
if_exists='replace',
)
table_exists = self._has_table(con, schema, table)
if table_exists:
self._truncate_table(schema, table, con)
else:
# create table with an empty frame (obj[:0]) and load the data later via COPY FROM STDIN
obj[:0].to_sql(
con=con,
name=table,
index=True,
schema=schema,
if_exists='replace',
)
# table was just created; create a primary key (to_sql doesn't create one,
# even though index=True suggests it would)
self._create_primary_key(schema, table, obj.index.names, con)
obj.reset_index()
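# (reset_index() without assignment is a no-op; to_csv below writes the index columns anyway)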
sio = StringIO()
obj.to_csv(sio, sep='\t', na_rep='', header=False)
sio.seek(0)
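# bulk-load the frame as tab-separated values via COPY FROM STDIN,
# which is considerably faster than row-wise INSERTs through to_sql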
c = con.connection.cursor()
# ignore mypy attribute check, as postgres cursor has custom extension to DBAPICursor: copy_expert
c.copy_expert(f"COPY {schema}.{table} FROM STDIN WITH (FORMAT csv, DELIMITER '\t')", sio) # type: ignore[attr-defined]
con.connection.commit()
with closing(con.connection.cursor()) as c:
# ignore mypy attribute check, as postgres cursor has custom extension to DBAPICursor: copy_expert
c.copy_expert(f"COPY {schema}.{table} FROM STDIN WITH (FORMAT csv, DELIMITER '\t')", sio) # type: ignore[attr-defined]
con.connection.commit()
context.add_output_metadata({'num_rows': len(obj), 'table_name': f'{schema}.{table}'})
elif obj is None:
self.delete_asset(context)
Expand Down Expand Up @@ -125,15 +140,29 @@ def delete_asset(self, context: OutputContext):
else:
raise Exception('Deletion of non-partitioned assets is not yet supported.')

def _delete_partition(self, schema, table, partition_col_name, partition_key, con):
def _delete_partition(self, schema: str, table: str, partition_col_name: str, partition_key: str, con: Connection):
try:
self._assert_sql_safety(schema, table, partition_col_name, partition_key)
with closing(con.connection.cursor()) as c:
c.execute(f"DELETE FROM {schema}.{table} WHERE {partition_col_name}='{partition_key}'")
except UndefinedTable:
# TODO log debug info; the asset did not exist, so there is nothing to delete
pass

def _truncate_table(self, schema: str, table: str, con: Connection):
try:
c = con.connection.cursor()
c.execute(f"DELETE FROM {schema}.{table} WHERE {partition_col_name}='{partition_key}'")
self._assert_sql_safety(schema, table)
with closing(con.connection.cursor()) as c:
c.execute(f'TRUNCATE TABLE {schema}.{table}')
except UndefinedTable:
# TODO log debug info; the asset did not exist, so there is nothing to truncate.
# Roll back so the failed TRUNCATE doesn't leave the connection in an aborted transaction.
con.connection.rollback()
pass

def _create_primary_key(self, schema: str, table: str, keys: list[str], con: Connection):
with closing(con.connection.cursor()) as c:
self._assert_sql_safety(schema, table, *keys)
c.execute(f'ALTER TABLE {schema}.{table} ADD PRIMARY KEY ({",".join(keys)})')
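# e.g. for keys ['id'] on a hypothetical table public.stations this issues:
# ALTER TABLE public.stations ADD PRIMARY KEY (id)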

def _load_input(
self, con: Connection, table: str, schema: str, columns: Optional[Sequence[str]], context: InputContext
) -> pandas.DataFrame:
@@ -146,8 +175,9 @@ def _load_input(
con=con,
)

def _create_schema_if_not_exists(self, schema, con):
with con.connection.cursor() as c:
def _create_schema_if_not_exists(self, schema: str, con: Connection):
with closing(con.connection.cursor()) as c:
self._assert_sql_safety(schema)
c.execute(f'CREATE SCHEMA IF NOT EXISTS {schema}')

def _get_schema_table(self, asset_key):
@@ -159,9 +189,21 @@ def _get_select_statement(
schema: str,
columns: Optional[Sequence[str]],
):
self._assert_sql_safety(schema, table, *(columns or []))
col_str = ', '.join(columns) if columns else '*'
return f'SELECT {col_str} FROM {schema}.{table}'

def _has_table(self, con: Connection, schema: str, table: str):
with closing(con.connection.cursor()) as c:
self._assert_sql_safety(schema, table)
c.execute(
f"SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = '{schema}' AND table_name = '{table}'"
)
fetch_result = c.fetchone()
if fetch_result and fetch_result[0] == 1:
return True
return False
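# exactly one matching row in information_schema.tables means the table
# already exists in the given schema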


# need mypy to ignore following line due to https://github.com/dagster-io/dagster/issues/17443
class PostGISGeoPandasIOManager(PostgreSQLPandasIOManager): # type: ignore
@@ -173,25 +215,35 @@ def handle_output(self, context: OutputContext, obj: geopandas.GeoDataFrame):
if isinstance(obj, geopandas.GeoDataFrame):
with connect_postgresql(config=self._config) as con:
self._create_schema_if_not_exists(schema, con)
if context.has_partition_key:
# add an additional column (name? for now just 'partition')
# to the frame and initialize it with the partition key
# (the name could become part of the metadata; obj may contain it already)
partition_col_name = self._get_partition_expr(context)
partition_key = context.partition_key
obj[partition_col_name] = partition_key

# We leave other partitions untouched, but need to delete data from this
# partition before we append again.
if_exists_action = 'append'
self._delete_partition(schema, table, partition_col_name, partition_key, con)
else:
# All data can be replaced (e.g. deleted before insertion).
# geopandas will take care of this.
if_exists_action = 'replace'
table_exists = self._has_table(con, schema, table)
if table_exists:
if context.has_partition_key:
# add an additional column (name? for now just 'partition')
# to the frame and initialize it with the partition key
# (the name could become part of the metadata; obj may contain it already)
partition_col_name = self._get_partition_expr(context)
partition_key = context.partition_key
obj[partition_col_name] = partition_key

# We leave other partitions untouched, but need to delete data from this
# partition before we append again.
self._delete_partition(schema, table, partition_col_name, partition_key, con)
else:
# All data can be replaced, i.e. the table is truncated before we insert again.
self._truncate_table(schema, table, con)
# while writing standard pandas.DataFrames to sql tables is database agnostic and performed
# internally via SQLAlchemy, writing a GeoDataFrame requires explicitly using to_postgis. See
# https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_sql.html and
# https://geopandas.org/en/stable/docs/reference/api/geopandas.GeoDataFrame.to_postgis.html
obj.to_postgis(
con=con, name=table, index=True, schema=schema, if_exists=if_exists_action, chunksize=self.chunksize
con=con, name=table, index=True, schema=schema, if_exists='append', chunksize=self.chunksize
)
if not table_exists:
# table was just created; create a primary key (to_postgis doesn't create one,
# even though index=True suggests it would)
self._create_primary_key(schema, table, obj.index.names, con)
con.connection.commit()
context.add_output_metadata({'num_rows': len(obj), 'table_name': f'{schema}.{table}'})
else:
super().handle_output(context, obj)
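
For context, a minimal usage sketch of the changed behavior (not part of this diff; the asset name, credentials, and IO manager config fields below are hypothetical): with this commit, the index columns of a DataFrame returned by an asset become the primary key of a freshly created table, while later materializations truncate and re-fill the existing table.

import pandas
from dagster import Definitions, asset

from pipeline.resources.postgis_geopandas_io_manager import PostgreSQLPandasIOManager


@asset(io_manager_key='pg_io_manager')
def stations() -> pandas.DataFrame:
    # 'station_id' is the frame's index, so on first materialization the new
    # table gets PRIMARY KEY (station_id); subsequent runs truncate and re-fill it
    frame = pandas.DataFrame({'station_id': [1, 2], 'name': ['Main St', 'Harbour']})
    return frame.set_index('station_id')


defs = Definitions(
    assets=[stations],
    resources={
        # config fields below are hypothetical placeholders
        'pg_io_manager': PostgreSQLPandasIOManager(
            host='localhost', port=5432, user='postgres', password='secret', database='mobidata'
        ),
    },
)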
