Skip to content

Commit

Permalink
Refactor temporary table creation
Browse files Browse the repository at this point in the history
  • Loading branch information
msg555 committed Oct 7, 2024
1 parent bb994b4 commit ed92913
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 81 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ extension-pkg-allow-list = [
markers = [
"mysql_live",
"postgres_live",
"sqlite_live",
]

[tool.mypy]
Expand Down
1 change: 1 addition & 0 deletions subsetter/config_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ class MultiplicityConfig(ForbidBaseModel):
filters: Dict[str, List[FilterConfig]] = {} # type: ignore
multiplicity: MultiplicityConfig = MultiplicityConfig()
infer_foreign_keys: Literal["none", "schema", "all"] = "none"
compact_keys: bool = False


class SubsetterConfig(ForbidBaseModel):
Expand Down
168 changes: 87 additions & 81 deletions subsetter/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import os
import re
import uuid
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple

import sqlalchemy as sa
Expand Down Expand Up @@ -38,54 +39,70 @@ def tqdm(x, **_):
DESTINATION_BUFFER_SIZE = 1024


# pylint: disable=too-many-ancestors,abstract-method
class TemporaryTable(Executable, ClauseElement):
inherit_cache = True

TEMP_ID = "_tmp_subsetter_"

def __init__(
self, schema: str, name: str, select: sa.Select, index: int = 0
) -> None:
self.schema = schema
self.name = name
self.select = select
self.index = index
self.table_obj: sa.Table

def _temporary_table_compile_generic(self, compiler: SQLCompiler, **_) -> str:
name = self.name + self.TEMP_ID + str(self.index)
schema_enc = compiler.dialect.identifier_preparer.quote(self.schema)
name_enc = compiler.dialect.identifier_preparer.quote(name)
select_stmt = compiler.process(self.select)
self.table_obj = sa.Table(
name,
sa.MetaData(),
*(sa.Column(col.name, col.type) for col in self.select.selected_columns),
schema=self.schema,
)
return f"CREATE TEMPORARY TABLE {schema_enc}.{name_enc} AS {select_stmt}"

def _temporary_table_compile_no_schema(self, compiler: SQLCompiler, **_) -> str:
"""
Postgres creates temporary tables in a special schema. We make the table
name incorporate the schema name to compensate and avoid collisions.
"""
name = self.schema + self.TEMP_ID + f"{self.index}_" + self.name
name_enc = compiler.dialect.identifier_preparer.quote(name)
select_stmt = compiler.process(self.select)
self.table_obj = sa.Table(
name,
sa.MetaData(),
*(sa.Column(col.name, col.type) for col in self.select.selected_columns),
)
return f"CREATE TEMPORARY TABLE {name_enc} AS {select_stmt}"
def create_temporary_table(
conn,
schema: str,
select: sa.Select,
*,
primary_key: Tuple[str, ...] = (),
) -> Tuple[sa.Table, int]:
"""
Create a temporary table on the passed connection generated by the passed
Select object. This method will return a
Parameters
conn: The connection to create the temporary table within. Temporary tables
are private to the connection that created them and are cleaned up
after the connection is closed.
schema: The schema to create the temporary table within. For some dialects
temporary tables always exist in their own schema and this parameter
will be ignored.
primary_key: If set will mark the set of columns passed as primary keys in
the temporary table. This tuple should match a subset of the
column names in the select query.
Returns a tuple containing the generated table object and the number of rows that
were inserted in the table.
"""
dialect = conn.engine.dialect

# Some dialects can only create temporary tables in an implicit schema
temp_schema: Optional[str] = schema
if dialect.name in ("postgresql", "sqlite"):
temp_schema = None

temp_name = f"_tmp_subsetter_{str(uuid.uuid4()).replace('-', '_')}"

# Create the temporary table from the select statement. Mark the requested
# columns as part of the primary key.
metadata = sa.MetaData()
table_obj = sa.Table(
temp_name,
metadata,
schema=temp_schema,
prefixes=["TEMPORARY"],
*(
sa.Column(col.name, col.type, primary_key=col.name in primary_key)
for col in select.selected_columns
),
)
try:
metadata.create_all(conn)
except Exception as exc: # pylint: disable=broad-exception-caught
# TODO: Is this still needed?
#
# Some client/server combinations report a read-only error even though the temporary
# table creation actually succeeded. We'll just swallow the error here and if there
# was a real issue it'll get flagged again when we query against it.
if "--read-only" not in str(exc):
raise

# Copy data into the temporary table
result = conn.execute(
table_obj.insert().from_select(list(table_obj.columns), select)
)

compiles(TemporaryTable, "postgresql", "sqlite")(
TemporaryTable._temporary_table_compile_no_schema
)
compiles(TemporaryTable)(TemporaryTable._temporary_table_compile_generic)
return table_obj, result.rowcount


# pylint: disable=too-many-ancestors,abstract-method
Expand Down Expand Up @@ -718,59 +735,48 @@ def _materialize_tables(
) -> None:
materialization_order = self._materialization_order(meta, plan)
for schema, table_name, ref_count in materialization_order:
table = meta.tables[(schema, table_name)]
query = plan.queries[f"{schema}.{table_name}"]
ttbl = TemporaryTable(
schema, table_name, query.build(meta.sql_build_context())
)

LOGGER.info(
"Materializing sample for %s.%s",
schema,
table_name,
)
LOGGER.debug(
" Using statement %s",
str(ttbl.compile(dialect=conn.engine.dialect)).replace("\n", " "),
)

try:
result = conn.execute(ttbl)
except Exception as exc: # pylint: disable=broad-exception-caught
# Some client/server combinations report a read-only error even though the temporary
# table creation actually succeeded. We'll just swallow the error here and if there
# was a real issue it'll get flagged again when we query against it.
if "--read-only" not in str(exc):
raise
else:
meta.temp_tables[(schema, table_name, 0)] = ttbl.table_obj
LOGGER.info(
"Materialized %d rows for %s.%s in temporary table",
result.rowcount,
meta.temp_tables[(schema, table_name, 0)], rowcount = (
create_temporary_table(
conn,
schema,
table_name,
query.build(meta.sql_build_context()),
primary_key=table.primary_key,
)
)
LOGGER.info(
"Materialized %d rows for %s.%s in temporary table",
rowcount,
schema,
table_name,
)

if meta.supports_temp_reopen:
continue

# Create additional copies of the temporary table if needed. This is
# to work around an issue on mysql with reopening temporary tables.
for index in range(1, ref_count):
ttbl_copy = TemporaryTable(
schema, table_name, sa.select(ttbl.table_obj), index=index
)
try:
result = conn.execute(ttbl_copy)
except Exception as exc: # pylint: disable=broad-exception-caught
if "--read-only" not in str(exc):
raise
else:
meta.temp_tables[(schema, table_name, index)] = ttbl_copy.table_obj
LOGGER.info(
"Copied materialization of %s.%s",
meta.temp_tables[(schema, table_name, index)], _ = (
create_temporary_table(
conn,
schema,
table_name,
meta.temp_tables[(schema, table_name, 0)].select(),
primary_key=table.primary_key,
)
)
LOGGER.info(
"Copied materialization of %s.%s",
schema,
table_name,
)

def _copy_results(
self,
Expand Down

0 comments on commit ed92913

Please sign in to comment.