From 4c7d9f10b6b05efc1bfc37028111ec9fd10c2f3d Mon Sep 17 00:00:00 2001 From: Holger Bruch Date: Mon, 12 Feb 2024 16:52:49 +0100 Subject: [PATCH 1/4] create primary key for new tables --- .../resources/postgis_geopandas_io_manager.py | 86 +++++++++++++------ 1 file changed, 61 insertions(+), 25 deletions(-) diff --git a/pipeline/resources/postgis_geopandas_io_manager.py b/pipeline/resources/postgis_geopandas_io_manager.py index 9fe011a..dd5a30b 100644 --- a/pipeline/resources/postgis_geopandas_io_manager.py +++ b/pipeline/resources/postgis_geopandas_io_manager.py @@ -71,14 +71,21 @@ def handle_output(self, context: OutputContext, obj: pandas.DataFrame): if isinstance(obj, pandas.DataFrame): with connect_postgresql(config=self._config) as con: self._create_schema_if_not_exists(schema, con) - # just recreate table with empty frame (obj[:0]) and load later via copy_from - obj[:0].to_sql( - con=con, - name=table, - index=True, - schema=schema, - if_exists='replace', - ) + table_exists = self._has_table(con, schema, table) + if table_exists: + self._truncate_table(schema, table, con) + else: + # create table with empty frame (obj[:0]) and load later via copy_from + obj[:0].to_sql( + con=con, + name=table, + index=True, + schema=schema, + if_exists='replace', + ) + # table was just created, create primary key (to_postgis doesn't create these, + # though index=True suggests this) + self._create_primary_key(schema, table, obj.index.names, con) obj.reset_index() sio = StringIO() obj.to_csv(sio, sep='\t', na_rep='', header=False) @@ -134,6 +141,19 @@ def _delete_partition(self, schema, table, partition_col_name, partition_key, co con.connection.rollback() pass + def _truncate_table(self, schema, table, con): + try: + c = con.connection.cursor() + c.execute(f'TRUNCATE TABLE {schema}.{table}') + except UndefinedTable: + # TODO log debug info, asset did not exist, so nothing to + con.connection.rollback() + pass + + def _create_primary_key(self, schema, table, keys, con): + with 
con.connection.cursor() as c: + c.execute(f'ALTER TABLE {schema}.{table} ADD PRIMARY KEY ({",".join(keys)})') + def _load_input( self, con: Connection, table: str, schema: str, columns: Optional[Sequence[str]], context: InputContext ) -> pandas.DataFrame: @@ -162,6 +182,15 @@ def _get_select_statement( col_str = ', '.join(columns) if columns else '*' return f'SELECT {col_str} FROM {schema}.{table}' + def _has_table(self, con, schema, table): + with con.connection.cursor() as c: + c.execute( + f"SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = '{schema}' AND table_name = '{table}'" + ) + if c.fetchone()[0] == 1: + return True + return False + # need mypy to ignore following line due to https://github.com/dagster-io/dagster/issues/17443 class PostGISGeoPandasIOManager(PostgreSQLPandasIOManager): # type: ignore @@ -173,25 +202,32 @@ def handle_output(self, context: OutputContext, obj: geopandas.GeoDataFrame): if isinstance(obj, geopandas.GeoDataFrame): with connect_postgresql(config=self._config) as con: self._create_schema_if_not_exists(schema, con) - if context.has_partition_key: - # add additional column (name? for now just partition) - # to the frame and initialize with partition_name - # (name could become part of metadata, ob may be contained already) - partition_col_name = self._get_partition_expr(context) - partition_key = context.partition_key - obj[partition_col_name] = partition_key - - # We leave other partions untouched, but need to delete data from this - # partition before we append again. - if_exists_action = 'append' - self._delete_partition(schema, table, partition_col_name, partition_key, con) - else: - # All data can be replaced (e.g. deleted before insertion). - # geopandas will take care of this. - if_exists_action = 'replace' + table_exists = self._has_table(con, schema, table) + if table_exists: + if context.has_partition_key: + # add additional column (name? 
for now just partition) + # to the frame and initialize with partition_name + # (name could become part of metadata, ob may be contained already) + partition_col_name = self._get_partition_expr(context) + partition_key = context.partition_key + obj[partition_col_name] = partition_key + + # We leave other partitions untouched, but need to delete data from this + # partition before we append again. + self._delete_partition(schema, table, partition_col_name, partition_key, con) + else: + # All data can be replaced (i.e. truncated before insertion). + # geopandas will take care of this. + self._truncate_table(schema, table, con) + obj.to_postgis( - con=con, name=table, index=True, schema=schema, if_exists=if_exists_action, chunksize=self.chunksize + con=con, name=table, index=True, schema=schema, if_exists='append', chunksize=self.chunksize ) + if not table_exists: + # table was just created, create primary key (to_postgis doesn't create these, + # though index=True suggests this) + self._create_primary_key(schema, table, obj.index.names, con) + con.connection.commit() context.add_output_metadata({'num_rows': len(obj), 'table_name': f'{schema}.{table}'}) else: super().handle_output(context, obj) From 489cd1799354cd3761ed8126936bb2e0149bd6dc Mon Sep 17 00:00:00 2001 From: Holger Bruch Date: Wed, 13 Nov 2024 09:57:43 +0100 Subject: [PATCH 2/4] Add additional typing information cursor handling explicit --- .../resources/postgis_geopandas_io_manager.py | 39 +++++++++---------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/pipeline/resources/postgis_geopandas_io_manager.py b/pipeline/resources/postgis_geopandas_io_manager.py index dd5a30b..9ebfd89 100644 --- a/pipeline/resources/postgis_geopandas_io_manager.py +++ b/pipeline/resources/postgis_geopandas_io_manager.py @@ -13,7 +13,7 @@ # limitations under the License. 
import csv -from contextlib import contextmanager +from contextlib import closing, contextmanager from io import StringIO from typing import Any, Dict, Iterator, Optional, Sequence, cast @@ -90,10 +90,10 @@ def handle_output(self, context: OutputContext, obj: pandas.DataFrame): sio = StringIO() obj.to_csv(sio, sep='\t', na_rep='', header=False) sio.seek(0) - c = con.connection.cursor() - # ignore mypy attribute check, as postgres cursor has custom extension to DBAPICursor: copy_expert - c.copy_expert(f"COPY {schema}.{table} FROM STDIN WITH (FORMAT csv, DELIMITER '\t')", sio) # type: ignore[attr-defined] - con.connection.commit() + with closing(con.connection.cursor()) as c: + # ignore mypy attribute check, as postgres cursor has custom extension to DBAPICursor: copy_expert + c.copy_expert(f"COPY {schema}.{table} FROM STDIN WITH (FORMAT csv, DELIMITER '\t')", sio) # type: ignore[attr-defined] + con.connection.commit() context.add_output_metadata({'num_rows': len(obj), 'table_name': f'{schema}.{table}'}) elif obj is None: self.delete_asset(context) @@ -132,26 +132,24 @@ def delete_asset(self, context: OutputContext): else: raise Exception('Deletion of not-partitioned assets not yet supported.') - def _delete_partition(self, schema, table, partition_col_name, partition_key, con): + def _delete_partition(self, schema: str, table: str, partition_col_name: str, partition_key: str, con: Connection): try: - c = con.connection.cursor() - c.execute(f"DELETE FROM {schema}.{table} WHERE {partition_col_name}='{partition_key}'") + with closing(con.connection.cursor()) as c: + c.execute(f"DELETE FROM {schema}.{table} WHERE {partition_col_name}='{partition_key}'") except UndefinedTable: # TODO log debug info, asset did not exist, so nothing to - con.connection.rollback() pass - def _truncate_table(self, schema, table, con): + def _truncate_table(self, schema: str, table: str, con: Connection): try: - c = con.connection.cursor() - c.execute(f'TRUNCATE TABLE {schema}.{table}') + 
with closing(con.connection.cursor()) as c: + c.execute(f'TRUNCATE TABLE {schema}.{table}') except UndefinedTable: # TODO log debug info, asset did not exist, so nothing to - con.connection.rollback() pass - def _create_primary_key(self, schema, table, keys, con): - with con.connection.cursor() as c: + def _create_primary_key(self, schema: str, table: str, keys: list[str], con: Connection): + with closing(con.connection.cursor()) as c: c.execute(f'ALTER TABLE {schema}.{table} ADD PRIMARY KEY ({",".join(keys)})') def _load_input( @@ -166,8 +164,8 @@ def _load_input( con=con, ) - def _create_schema_if_not_exists(self, schema, con): - with con.connection.cursor() as c: + def _create_schema_if_not_exists(self, schema: str, con: Connection): + with closing(con.connection.cursor()) as c: c.execute(f'CREATE SCHEMA IF NOT EXISTS {schema}') def _get_schema_table(self, asset_key): @@ -182,12 +180,13 @@ def _get_select_statement( col_str = ', '.join(columns) if columns else '*' return f'SELECT {col_str} FROM {schema}.{table}' - def _has_table(self, con, schema, table): - with con.connection.cursor() as c: + def _has_table(self, con: Connection, schema: str, table: str): + with closing(con.connection.cursor()) as c: c.execute( f"SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = '{schema}' AND table_name = '{table}'" ) - if c.fetchone()[0] == 1: + fetch_result = c.fetchone() + if fetch_result and fetch_result[0] == 1: return True return False From 8f47b18d71461e80a4764ef96f18479d03960536 Mon Sep 17 00:00:00 2001 From: Holger Bruch Date: Wed, 13 Nov 2024 09:59:24 +0100 Subject: [PATCH 3/4] Check strings used in sql statements for expected chars --- .../resources/postgis_geopandas_io_manager.py | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/pipeline/resources/postgis_geopandas_io_manager.py b/pipeline/resources/postgis_geopandas_io_manager.py index 9ebfd89..d3c62fb 100644 --- 
a/pipeline/resources/postgis_geopandas_io_manager.py
+++ b/pipeline/resources/postgis_geopandas_io_manager.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import csv
+import re
 from contextlib import closing, contextmanager
 from io import StringIO
 from typing import Any, Dict, Iterator, Optional, Sequence, cast
@@ -28,8 +29,6 @@
 from sqlalchemy import create_engine
 from sqlalchemy.engine import URL, Connection
 
-# TODO figure out appropriate SQL injection mechanism while allowing dynamic table/column provision
-
 
 @contextmanager
 def connect_postgresql(config, schema='public') -> Iterator[Connection]:
@@ -65,6 +64,15 @@ class PostgreSQLPandasIOManager(ConfigurableIOManager):  # type: ignore
     def _config(self) -> Dict[str, Any]:
         return self.dict()
 
+    def _assert_sql_safety(self, *expressions: str) -> None:
+        r"""
+        Raises a ValueError in case an expression contains characters other than
+        word or number chars or is of length 0 (^[\w\d]+$).
+        """
+        for expression in expressions:
+            if not re.match(r'^[\w\d]+$', expression):
+                raise ValueError(f'Unexpected sql identifier {expression}')
+
     def handle_output(self, context: OutputContext, obj: pandas.DataFrame):
         schema, table = self._get_schema_table(context.asset_key)
 
@@ -83,7 +91,7 @@ def handle_output(self, context: OutputContext, obj: pandas.DataFrame):
                         schema=schema,
                         if_exists='replace',
                     )
-                    # table was just created, create primary key (to_postgis doesn't create these,
+                    # table was just created, create primary key (to_sql doesn't create these,
                     # though index=True suggests this)
                     self._create_primary_key(schema, table, obj.index.names, con)
             obj.reset_index()
@@ -134,6 +142,7 @@ def delete_asset(self, context: OutputContext):
 
     def _delete_partition(self, schema: str, table: str, partition_col_name: str, partition_key: str, con: Connection):
         try:
+            self._assert_sql_safety(schema, table, partition_col_name, partition_key)
             with closing(con.connection.cursor()) as c:
                 c.execute(f"DELETE FROM {schema}.{table} WHERE 
{partition_col_name}='{partition_key}'")
         except UndefinedTable:
@@ -142,6 +151,7 @@ def _delete_partition(self, schema: str, table: str, partition_col_name: str, pa
 
     def _truncate_table(self, schema: str, table: str, con: Connection):
         try:
+            self._assert_sql_safety(schema, table)
             with closing(con.connection.cursor()) as c:
                 c.execute(f'TRUNCATE TABLE {schema}.{table}')
         except UndefinedTable:
@@ -150,6 +160,7 @@ def _truncate_table(self, schema: str, table: str, con: Connection):
 
     def _create_primary_key(self, schema: str, table: str, keys: list[str], con: Connection):
         with closing(con.connection.cursor()) as c:
+            self._assert_sql_safety(schema, table, *keys)
             c.execute(f'ALTER TABLE {schema}.{table} ADD PRIMARY KEY ({",".join(keys)})')
 
     def _load_input(
@@ -166,6 +177,7 @@ def _load_input(
 
     def _create_schema_if_not_exists(self, schema: str, con: Connection):
         with closing(con.connection.cursor()) as c:
+            self._assert_sql_safety(schema)
             c.execute(f'CREATE SCHEMA IF NOT EXISTS {schema}')
 
     def _get_schema_table(self, asset_key):
@@ -177,11 +189,13 @@ def _get_select_statement(
         schema: str,
         columns: Optional[Sequence[str]],
     ):
+        self._assert_sql_safety(schema, table, *(columns or ()))
         col_str = ', '.join(columns) if columns else '*'
         return f'SELECT {col_str} FROM {schema}.{table}'
 
     def _has_table(self, con: Connection, schema: str, table: str):
         with closing(con.connection.cursor()) as c:
+            self._assert_sql_safety(schema, table)
             c.execute(
                 f"SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = '{schema}' AND table_name = '{table}'"
             )

From 0e0022f2742410cf8b086ee8ab1ecfc7ce4ae90a Mon Sep 17 00:00:00 2001
From: Holger Bruch
Date: Wed, 13 Nov 2024 09:59:51 +0100
Subject: [PATCH 4/4] Add comment explaining to_sql/to_postgis

---
 pipeline/resources/postgis_geopandas_io_manager.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/pipeline/resources/postgis_geopandas_io_manager.py b/pipeline/resources/postgis_geopandas_io_manager.py
index d3c62fb..3591d82 
100644 --- a/pipeline/resources/postgis_geopandas_io_manager.py +++ b/pipeline/resources/postgis_geopandas_io_manager.py @@ -232,7 +232,10 @@ def handle_output(self, context: OutputContext, obj: geopandas.GeoDataFrame): # All data can be replaced (i.e. truncated before insertion). # geopandas will take care of this. self._truncate_table(schema, table, con) - + # while writing standard pandas.DataFrames to sql tables is database agnostic and performed + # internally via SQLAlchemy, writing a GeoDataFrame requires explicitly using to_postgis. See + # https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_sql.html and + # https://geopandas.org/en/stable/docs/reference/api/geopandas.GeoDataFrame.to_postgis.html obj.to_postgis( con=con, name=table, index=True, schema=schema, if_exists='append', chunksize=self.chunksize )