Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor temporary table creation #29

Merged
merged 1 commit into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ extension-pkg-allow-list = [
markers = [
"mysql_live",
"postgres_live",
"sqlite_live",
]

[tool.mypy]
Expand Down
1 change: 1 addition & 0 deletions subsetter/config_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ class MultiplicityConfig(ForbidBaseModel):
filters: Dict[str, List[FilterConfig]] = {} # type: ignore
multiplicity: MultiplicityConfig = MultiplicityConfig()
infer_foreign_keys: Literal["none", "schema", "all"] = "none"
compact_keys: bool = False


class SubsetterConfig(ForbidBaseModel):
Expand Down
168 changes: 87 additions & 81 deletions subsetter/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import os
import re
import uuid
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple

import sqlalchemy as sa
Expand Down Expand Up @@ -38,54 +39,70 @@ def tqdm(x, **_):
DESTINATION_BUFFER_SIZE = 1024


# pylint: disable=too-many-ancestors,abstract-method
class TemporaryTable(Executable, ClauseElement):
    """
    Clause element that renders ``CREATE TEMPORARY TABLE <name> AS <select>``.

    Compiling the element has a side effect: the compile method that runs
    stores a ``sa.Table`` describing the temporary table in ``table_obj``.
    The attribute is only declared in ``__init__`` and does not exist until
    one of the compile methods has been invoked.
    """

    inherit_cache = True

    # Marker embedded in every generated temporary table name.
    TEMP_ID = "_tmp_subsetter_"

    def __init__(
        self, schema: str, name: str, select: sa.Select, index: int = 0
    ) -> None:
        self.schema = schema
        self.name = name
        self.select = select
        self.index = index
        # Annotation only; assigned during compilation.
        self.table_obj: sa.Table

    def _temporary_table_compile_generic(self, compiler: SQLCompiler, **_) -> str:
        """
        Render the statement with a schema-qualified table name of the form
        ``<name>_tmp_subsetter_<index>`` and record the matching ``sa.Table``.
        """
        quote = compiler.dialect.identifier_preparer.quote
        temp_name = "".join((self.name, self.TEMP_ID, str(self.index)))
        quoted_schema = quote(self.schema)
        quoted_name = quote(temp_name)
        select_sql = compiler.process(self.select)

        # Mirror the selected columns (name and type only) on the table object.
        columns = [
            sa.Column(col.name, col.type) for col in self.select.selected_columns
        ]
        self.table_obj = sa.Table(
            temp_name,
            sa.MetaData(),
            *columns,
            schema=self.schema,
        )
        return f"CREATE TEMPORARY TABLE {quoted_schema}.{quoted_name} AS {select_sql}"

    def _temporary_table_compile_no_schema(self, compiler: SQLCompiler, **_) -> str:
        """
        Render the statement without a schema qualifier.

        Postgres creates temporary tables in a special schema, so the source
        schema name is folded into the table name instead
        (``<schema>_tmp_subsetter_<index>_<name>``) to avoid collisions.
        """
        temp_name = "".join(
            (self.schema, self.TEMP_ID, f"{self.index}_", self.name)
        )
        quoted_name = compiler.dialect.identifier_preparer.quote(temp_name)
        select_sql = compiler.process(self.select)

        # Mirror the selected columns (name and type only) on the table object.
        columns = [
            sa.Column(col.name, col.type) for col in self.select.selected_columns
        ]
        self.table_obj = sa.Table(temp_name, sa.MetaData(), *columns)
        return f"CREATE TEMPORARY TABLE {quoted_name} AS {select_sql}"
def create_temporary_table(
    conn,
    schema: str,
    select: sa.Select,
    *,
    primary_key: Tuple[str, ...] = (),
) -> Tuple[sa.Table, int]:
    """
    Create a temporary table on the passed connection and populate it with the
    rows produced by the passed Select object.

    Parameters
    conn: The connection to create the temporary table within. Temporary tables
          are private to the connection that created them and are cleaned up
          after the connection is closed.
    schema: The schema to create the temporary table within. For some dialects
            temporary tables always exist in their own schema and this parameter
            will be ignored.
    select: The query that defines the temporary table. Its selected columns
            (name and type) become the table's columns and its rows are
            copied into the table.
    primary_key: If set will mark the set of columns passed as primary keys in
                 the temporary table. This tuple should match a subset of the
                 column names in the select query.

    Returns a tuple containing the generated table object and the number of rows that
    were inserted in the table.
    """
    # Some dialects can only create temporary tables in an implicit schema
    uses_implicit_schema = conn.engine.dialect.name in ("postgresql", "sqlite")
    temp_schema: Optional[str] = None if uses_implicit_schema else schema

    # A random suffix keeps each temporary table name unique.
    unique_suffix = str(uuid.uuid4()).replace("-", "_")
    temp_name = f"_tmp_subsetter_{unique_suffix}"

    # Mirror the select's columns, marking the requested ones as part of the
    # primary key.
    columns = [
        sa.Column(col.name, col.type, primary_key=col.name in primary_key)
        for col in select.selected_columns
    ]
    metadata = sa.MetaData()
    table_obj = sa.Table(
        temp_name,
        metadata,
        *columns,
        schema=temp_schema,
        prefixes=["TEMPORARY"],
    )

    try:
        metadata.create_all(conn)
    except Exception as exc:  # pylint: disable=broad-exception-caught
        # TODO: Is this still needed?
        #
        # Some client/server combinations report a read-only error even though the temporary
        # table creation actually succeeded. We'll just swallow the error here and if there
        # was a real issue it'll get flagged again when we query against it.
        if "--read-only" not in str(exc):
            raise

    # Copy data into the temporary table
    insert_stmt = table_obj.insert().from_select(list(table_obj.columns), select)
    result = conn.execute(insert_stmt)
    return table_obj, result.rowcount


# Register the TemporaryTable compile methods: the schema-less variant for
# postgresql and sqlite, the generic variant as the default.
compiles(TemporaryTable, "postgresql", "sqlite")(
    TemporaryTable._temporary_table_compile_no_schema
)
compiles(TemporaryTable)(TemporaryTable._temporary_table_compile_generic)


# pylint: disable=too-many-ancestors,abstract-method
Expand Down Expand Up @@ -718,59 +735,48 @@ def _materialize_tables(
) -> None:
materialization_order = self._materialization_order(meta, plan)
for schema, table_name, ref_count in materialization_order:
table = meta.tables[(schema, table_name)]
query = plan.queries[f"{schema}.{table_name}"]
ttbl = TemporaryTable(
schema, table_name, query.build(meta.sql_build_context())
)

LOGGER.info(
"Materializing sample for %s.%s",
schema,
table_name,
)
LOGGER.debug(
" Using statement %s",
str(ttbl.compile(dialect=conn.engine.dialect)).replace("\n", " "),
)

try:
result = conn.execute(ttbl)
except Exception as exc: # pylint: disable=broad-exception-caught
# Some client/server combinations report a read-only error even though the temporary
# table creation actually succeeded. We'll just swallow the error here and if there
# was a real issue it'll get flagged again when we query against it.
if "--read-only" not in str(exc):
raise
else:
meta.temp_tables[(schema, table_name, 0)] = ttbl.table_obj
LOGGER.info(
"Materialized %d rows for %s.%s in temporary table",
result.rowcount,
meta.temp_tables[(schema, table_name, 0)], rowcount = (
create_temporary_table(
conn,
schema,
table_name,
query.build(meta.sql_build_context()),
primary_key=table.primary_key,
)
)
LOGGER.info(
"Materialized %d rows for %s.%s in temporary table",
rowcount,
schema,
table_name,
)

if meta.supports_temp_reopen:
continue

# Create additional copies of the temporary table if needed. This is
# to work around an issue on mysql with reopening temporary tables.
for index in range(1, ref_count):
ttbl_copy = TemporaryTable(
schema, table_name, sa.select(ttbl.table_obj), index=index
)
try:
result = conn.execute(ttbl_copy)
except Exception as exc: # pylint: disable=broad-exception-caught
if "--read-only" not in str(exc):
raise
else:
meta.temp_tables[(schema, table_name, index)] = ttbl_copy.table_obj
LOGGER.info(
"Copied materialization of %s.%s",
meta.temp_tables[(schema, table_name, index)], _ = (
create_temporary_table(
conn,
schema,
table_name,
meta.temp_tables[(schema, table_name, 0)].select(),
primary_key=table.primary_key,
)
)
LOGGER.info(
"Copied materialization of %s.%s",
schema,
table_name,
)

def _copy_results(
self,
Expand Down
Loading