diff --git a/pyproject.toml b/pyproject.toml index 33f2876..a769519 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,7 @@ extension-pkg-allow-list = [ markers = [ "mysql_live", "postgres_live", + "sqlite_live", ] [tool.mypy] diff --git a/subsetter/config_model.py b/subsetter/config_model.py index 8306de5..ee16c08 100644 --- a/subsetter/config_model.py +++ b/subsetter/config_model.py @@ -82,6 +82,7 @@ class MultiplicityConfig(ForbidBaseModel): filters: Dict[str, List[FilterConfig]] = {} # type: ignore multiplicity: MultiplicityConfig = MultiplicityConfig() infer_foreign_keys: Literal["none", "schema", "all"] = "none" + compact_keys: bool = False class SubsetterConfig(ForbidBaseModel): diff --git a/subsetter/sampler.py b/subsetter/sampler.py index 5a46c7c..8f7b685 100644 --- a/subsetter/sampler.py +++ b/subsetter/sampler.py @@ -4,6 +4,7 @@ import logging import os import re +import uuid from typing import Any, Dict, Iterable, List, Optional, Set, Tuple import sqlalchemy as sa @@ -38,54 +39,70 @@ def tqdm(x, **_): DESTINATION_BUFFER_SIZE = 1024 -# pylint: disable=too-many-ancestors,abstract-method -class TemporaryTable(Executable, ClauseElement): - inherit_cache = True - - TEMP_ID = "_tmp_subsetter_" - - def __init__( - self, schema: str, name: str, select: sa.Select, index: int = 0 - ) -> None: - self.schema = schema - self.name = name - self.select = select - self.index = index - self.table_obj: sa.Table - - def _temporary_table_compile_generic(self, compiler: SQLCompiler, **_) -> str: - name = self.name + self.TEMP_ID + str(self.index) - schema_enc = compiler.dialect.identifier_preparer.quote(self.schema) - name_enc = compiler.dialect.identifier_preparer.quote(name) - select_stmt = compiler.process(self.select) - self.table_obj = sa.Table( - name, - sa.MetaData(), - *(sa.Column(col.name, col.type) for col in self.select.selected_columns), - schema=self.schema, - ) - return f"CREATE TEMPORARY TABLE {schema_enc}.{name_enc} AS {select_stmt}" - - def _temporary_table_compile_no_schema(self, compiler: SQLCompiler, **_) -> str: - """ - Postgres creates temporary tables in a special schema. We make the table - name incorporate the schema name to compensate and avoid collisions. - """ - name = self.schema + self.TEMP_ID + f"{self.index}_" + self.name - name_enc = compiler.dialect.identifier_preparer.quote(name) - select_stmt = compiler.process(self.select) - self.table_obj = sa.Table( - name, - sa.MetaData(), - *(sa.Column(col.name, col.type) for col in self.select.selected_columns), - ) - return f"CREATE TEMPORARY TABLE {name_enc} AS {select_stmt}" +def create_temporary_table( + conn, + schema: str, + select: sa.Select, + *, + primary_key: Tuple[str, ...] = (), +) -> Tuple[sa.Table, int]: + """ + Create a temporary table on the passed connection generated by the passed + Select object. This method will return a + + Parameters + conn: The connection to create the temporary table within. Temporary tables + are private to the connection that created them and are cleaned up + after the connection is closed. + schema: The schema to create the temporary table within. For some dialects + temporary tables always exist in their own schema and this parameter + will be ignored. + primary_key: If set will mark the set of columns passed as primary keys in + the temporary table. This tuple should match a subset of the + column names in the select query. + + Returns a tuple containing the generated table object and the number of rows that + were inserted in the table. + """ + dialect = conn.engine.dialect + + # Some dialects can only create temporary tables in an implicit schema + temp_schema: Optional[str] = schema + if dialect.name in ("postgresql", "sqlite"): + temp_schema = None + + temp_name = f"_tmp_subsetter_{str(uuid.uuid4()).replace('-', '_')}" + + # Create the temporary table from the select statement. Mark the requested + # columns as part of the primary key. + metadata = sa.MetaData() + table_obj = sa.Table( + temp_name, + metadata, + schema=temp_schema, + prefixes=["TEMPORARY"], + *( + sa.Column(col.name, col.type, primary_key=col.name in primary_key) + for col in select.selected_columns + ), + ) + try: + metadata.create_all(conn) + except Exception as exc: # pylint: disable=broad-exception-caught + # TODO: Is this still needed? + # + # Some client/server combinations report a read-only error even though the temporary + # table creation actually succeeded. We'll just swallow the error here and if there + # was a real issue it'll get flagged again when we query against it. + if "--read-only" not in str(exc): + raise + # Copy data into the temporary table + result = conn.execute( + table_obj.insert().from_select(list(table_obj.columns), select) + ) -compiles(TemporaryTable, "postgresql", "sqlite")( - TemporaryTable._temporary_table_compile_no_schema -) -compiles(TemporaryTable)(TemporaryTable._temporary_table_compile_generic) + return table_obj, result.rowcount # pylint: disable=too-many-ancestors,abstract-method @@ -718,37 +735,28 @@ def _materialize_tables( ) -> None: materialization_order = self._materialization_order(meta, plan) for schema, table_name, ref_count in materialization_order: + table = meta.tables[(schema, table_name)] query = plan.queries[f"{schema}.{table_name}"] - ttbl = TemporaryTable( - schema, table_name, query.build(meta.sql_build_context()) - ) LOGGER.info( "Materializing sample for %s.%s", schema, table_name, ) - LOGGER.debug( - " Using statement %s", - str(ttbl.compile(dialect=conn.engine.dialect)).replace("\n", " "), - ) - - try: - result = conn.execute(ttbl) - except Exception as exc: # pylint: disable=broad-exception-caught - # Some client/server combinations report a read-only error even though the temporary - # table creation actually succeeded. We'll just swallow the error here and if there - # was a real issue it'll get flagged again when we query against it. - if "--read-only" not in str(exc): - raise - else: - meta.temp_tables[(schema, table_name, 0)] = ttbl.table_obj - LOGGER.info( - "Materialized %d rows for %s.%s in temporary table", - result.rowcount, + meta.temp_tables[(schema, table_name, 0)], rowcount = ( + create_temporary_table( + conn, schema, - table_name, + query.build(meta.sql_build_context()), + primary_key=table.primary_key, ) + ) + LOGGER.info( + "Materialized %d rows for %s.%s in temporary table", + rowcount, + schema, + table_name, + ) if meta.supports_temp_reopen: continue @@ -756,21 +764,19 @@ def _materialize_tables( # Create additional copies of the temporary table if needed. This is # to work around an issue on mysql with reopening temporary tables. for index in range(1, ref_count): - ttbl_copy = TemporaryTable( - schema, table_name, sa.select(ttbl.table_obj), index=index - ) - try: - result = conn.execute(ttbl_copy) - except Exception as exc: # pylint: disable=broad-exception-caught - if "--read-only" not in str(exc): - raise - else: - meta.temp_tables[(schema, table_name, index)] = ttbl_copy.table_obj - LOGGER.info( - "Copied materialization of %s.%s", + meta.temp_tables[(schema, table_name, index)], _ = ( + create_temporary_table( + conn, schema, - table_name, + meta.temp_tables[(schema, table_name, 0)].select(), + primary_key=table.primary_key, ) + ) + LOGGER.info( + "Copied materialization of %s.%s", + schema, + table_name, + ) def _copy_results( self,