Skip to content

Commit

Permalink
Auto generate output schema
Browse files Browse the repository at this point in the history
  • Loading branch information
msg555 committed May 6, 2024
1 parent ef1873b commit 3d4dc8c
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 5 deletions.
4 changes: 2 additions & 2 deletions sampler_config.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ output:
# Each replacement is attempted in the order it is listed. The example below
# can be used to add the suffix "_out" to the schema.
remap:
- search: "^(\w+)\."
replace: "\1_out."
- search: "^(\\w+)\\."
replace: "\\1_out."

# Alternative configuration to output JSON files within a directory
# output:
Expand Down
13 changes: 12 additions & 1 deletion subsetter/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,13 @@ def _add_sample_args(parser, *, subset_action: bool = False):
default=False,
help="Truncate existing output before sampling",
)
parser.add_argument(
"--create",
action="store_const",
const=True,
default=False,
help="Create tables in destination from source if missing",
)
output_parsers = parser.add_subparsers(
dest="output",
required=False,
Expand Down Expand Up @@ -269,7 +276,11 @@ def _main_sample(args):
)
sys.exit(1)

Sampler(_get_sample_config(args)).sample(plan, truncate=args.truncate)
Sampler(_get_sample_config(args)).sample(
plan,
truncate=args.truncate,
create=args.create,
)


def _main_subset(args):
Expand Down
5 changes: 5 additions & 0 deletions subsetter/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,11 @@ def from_engine(
table_queue[num_selected_tables:],
)

def track_new_table(self, table_obj: sa.Table) -> None:
    """Register ``table_obj`` in ``self.tables`` under its (schema, name) key.

    The table is wrapped in a ``TableMetadata`` before being stored. Any
    entry already stored under the same key is replaced.

    Raises:
        ValueError: If ``table_obj.schema`` is ``None``.
    """
    schema = table_obj.schema
    if schema is None:
        raise ValueError("Table schema must be set")

    key = (schema, table_obj.name)
    self.tables[key] = TableMetadata(table_obj)

def infer_missing_foreign_keys(self) -> None:
pk_map: Dict[Tuple[str, Tuple[str, ...]], Optional[TableMetadata]] = {}
for table in self.tables.values():
Expand Down
83 changes: 81 additions & 2 deletions subsetter/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import os
import re
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union
from typing import Any, Dict, Iterable, List, Literal, Optional, Set, Tuple, Union

import sqlalchemy as sa
from pydantic import BaseModel, Field
Expand Down Expand Up @@ -141,6 +141,9 @@ def output_result_set(
def truncate(self) -> None:
"""Delete any existing data that could interfere with output destination"""

def create(self, source_meta: DatabaseMetadata) -> None:
    """Create any missing tables in destination from the source schema.

    The body of this implementation is empty, so calling it creates
    nothing.

    Args:
        source_meta: Metadata describing the source database.
    """

@abc.abstractmethod
def insert_order(self) -> List[str]:
"""Return the order to insert data that respects foreign key relationships"""
Expand Down Expand Up @@ -215,6 +218,74 @@ def __init__(self, config: DatabaseOutputConfig, tables: List[str]) -> None:
close_backward=True,
)

def create(self, source_meta: DatabaseMetadata) -> None:
    """Create destination tables that do not exist yet, modelled on the source.

    For every entry of ``self.table_remap`` whose remapped (schema, name) is
    not already present in ``self.meta.tables``, a new table definition is
    built from the matching source table. Columns are copied first (name,
    type, nullability and primary-key flag), then unique, check and foreign
    key constraints. Only foreign key constraints keep their source name;
    unique and check constraints are created without a name.

    All new tables are emitted with a single ``create_all`` call and then
    registered on ``self.meta`` through ``track_new_table``.

    Args:
        source_meta: Metadata of the source database; tables are looked up
            by the parsed source table name.
    """
    dest_metadata = sa.MetaData()

    # Source table name -> destination table object (pre-existing or new).
    dest_by_source = {}
    # Destination table objects that still have to be created.
    pending = set()

    for dest_name, source_name in self.table_remap.items():
        dest_schema, dest_table_name = parse_table_name(dest_name)
        dest_key = (dest_schema, dest_table_name)

        if dest_key in self.meta.tables:
            # Already present in the destination: reuse the existing object.
            dest_by_source[source_name] = self.meta.tables[dest_key].table_obj
            continue

        source_table = source_meta.tables[parse_table_name(source_name)].table_obj
        copied_columns = [
            sa.Column(
                source_col.name,
                source_col.type,
                nullable=source_col.nullable,
                primary_key=source_col.primary_key,
            )
            for source_col in source_table.columns
        ]
        new_table = sa.Table(
            dest_table_name,
            dest_metadata,
            *copied_columns,
            schema=dest_schema,
        )
        dest_by_source[source_name] = new_table
        pending.add(new_table)

    def _to_dest_columns(source_cols: Iterable[sa.Column]) -> List[sa.Column]:
        # Translate source columns into the same-named destination columns.
        mapped = []
        for source_col in source_cols:
            owner = source_col.table
            dest_table = dest_by_source[f"{owner.schema}.{owner.name}"]
            mapped.append(dest_table.columns[source_col.name])
        return mapped

    # Copy table constraints including foreign key constraints. Only tables
    # created here are touched; pre-existing destination tables are left
    # untouched.
    newly_defined = [
        (source_name, dest_table)
        for source_name, dest_table in dest_by_source.items()
        if dest_table in pending
    ]
    for source_name, dest_table in newly_defined:
        source_table = source_meta.tables[parse_table_name(source_name)].table_obj
        for constraint in source_table.constraints:
            if isinstance(constraint, sa.UniqueConstraint):
                unique_cols = _to_dest_columns(constraint.columns)
                dest_table.append_constraint(sa.UniqueConstraint(*unique_cols))
            if isinstance(constraint, sa.CheckConstraint):
                check = sa.CheckConstraint(constraint.sqltext)
                dest_table.append_constraint(check)
            if isinstance(constraint, sa.ForeignKeyConstraint):
                local_cols = _to_dest_columns(constraint.columns)
                referred_cols = _to_dest_columns(
                    elem.column for elem in constraint.elements
                )
                dest_table.append_constraint(
                    sa.ForeignKeyConstraint(
                        local_cols,
                        referred_cols,
                        name=constraint.name,
                    )
                )

    if not pending:
        return

    LOGGER.info("Creating %d tables in destination", len(pending))
    dest_metadata.create_all(bind=self.engine)
    for new_table in pending:
        self.meta.track_new_table(new_table)

def truncate(self) -> None:
for schema, table_name in self.additional_tables:
LOGGER.info(
Expand Down Expand Up @@ -325,7 +396,13 @@ def __init__(self, config: SamplerConfig) -> None:
env_prefix="SUBSET_SOURCE_"
)

def sample(self, plan: SubsetPlan, *, truncate: bool = False) -> None:
def sample(
self,
plan: SubsetPlan,
*,
truncate: bool = False,
create: bool = False,
) -> None:
meta, _ = DatabaseMetadata.from_engine(self.source_engine, list(plan.queries))
if self.config.multiplicity.infer_foreign_keys:
meta.infer_missing_foreign_keys()
Expand All @@ -336,6 +413,8 @@ def sample(self, plan: SubsetPlan, *, truncate: bool = False) -> None:
)

output = SamplerOutput.from_config(self.config.output, list(plan.queries))
if create:
output.create(meta)
insert_order = output.insert_order()
if truncate:
output.truncate()
Expand Down

0 comments on commit 3d4dc8c

Please sign in to comment.