Skip to content

Commit

Permalink
Fix bug in solver, add tests (#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
msg555 authored Apr 22, 2024
1 parent 5cea461 commit 64bc850
Show file tree
Hide file tree
Showing 16 changed files with 934 additions and 309 deletions.
10 changes: 9 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
# v0.2.1
# v0.3.0

- Added support for remapping schema and table names in sampling phase
- Patch support for JSON data types
- Improve live database testing
- Replaced a lot of custom dialect handling with sqlalchemy core api
- Plan output no longer indicates whether tables should be materialized. The
sampler now determines for itself whether it needs to materialize a table.
- Fixed significant bug in solver that prevented finding a valid plan in many
scenarios. Now planning will only fail if there is a forward cycle of foreign
keys.
- Added support in mysql dialects for sampling across multiple foreign keys to
the same table. Previously this would result in an error due to attempting to
reopen a temporary table which is not supported in mysql.

# v0.2.0

Expand Down
2 changes: 1 addition & 1 deletion subsetter/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def _main_plan(args):
stream=fplan,
default_flow_style=False,
width=2**20,
sort_keys=False,
sort_keys=True,
)
except IOError as exc:
LOGGER.error(
Expand Down
32 changes: 27 additions & 5 deletions subsetter/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import sqlalchemy as sa

from subsetter.common import parse_table_name
from subsetter.plan_model import SQLTableIdentifier
from subsetter.solver import reverse_graph

LOGGER = logging.getLogger(__name__)
Expand All @@ -15,7 +16,7 @@
# pylint: disable=modified-iterating-list


@dataclasses.dataclass(frozen=True)
@dataclasses.dataclass(frozen=True, order=True)
class ForeignKey:
columns: Tuple[str, ...]
dst_schema: str
Expand Down Expand Up @@ -59,10 +60,13 @@ def __init__(
self,
metadata_obj: sa.MetaData,
tables: Dict[Tuple[str, str], TableMetadata],
*,
supports_temp_reopen: bool = True,
) -> None:
self.metadata_obj = metadata_obj
self.tables = tables
self.temp_tables: Dict[Tuple[str, str], sa.Table] = {}
self.supports_temp_reopen = supports_temp_reopen
self.temp_tables: Dict[Tuple[str, str, int], sa.Table] = {}

@classmethod
def from_engine(
Expand Down Expand Up @@ -121,6 +125,7 @@ def from_engine(
)
for schema, table in table_queue
},
supports_temp_reopen=engine.dialect.name != "mysql",
),
table_queue[num_selected_tables:],
)
Expand Down Expand Up @@ -222,20 +227,37 @@ def compute_reverse_keys(self) -> None:

def as_graph(
self, *, ignore_tables: Optional[Set[Tuple[str, str]]] = None
) -> Dict[str, List[str]]:
) -> Dict[str, Set[str]]:
if ignore_tables is None:
ignore_tables = set()
return {
f"{table.schema}.{table.name}": [
f"{table.schema}.{table.name}": {
f"{fk.dst_schema}.{fk.dst_table}"
for fk in table.foreign_keys
if (fk.dst_schema, fk.dst_table) not in ignore_tables
and (fk.dst_schema, fk.dst_table) in self.tables
]
}
for table in self.tables.values()
if (table.schema, table.name) not in ignore_tables
}

def sql_build_context(self):
    """Create a resolver that maps table identifiers to SQLAlchemy tables.

    The returned callable takes an identifier and returns:

    * for a non-sampled identifier, ``table_obj`` of the matching entry in
      ``self.tables``;
    * for a sampled identifier, the entry of ``self.temp_tables`` keyed by
      ``(schema, table, index)``.

    When ``self.supports_temp_reopen`` is true the index is always 0.
    Otherwise the index is the number of times this particular resolver
    has already been asked for that sampled table, so every reference
    receives a different entry. Each call to this method starts with
    fresh counters.
    """
    # Per-resolver tally of sampled-table lookups, keyed by (schema, table).
    lookups_so_far = {}

    def _resolve(ident: SQLTableIdentifier) -> sa.Table:
        key = (ident.table_schema, ident.table_name)
        if not ident.sampled:
            return self.tables[key].table_obj

        slot = 0
        if not self.supports_temp_reopen:
            # Hand out the next unused slot; the tally is bumped before the
            # lookup, so it advances even when the lookup below fails.
            slot = lookups_so_far.get(key, 0)
            lookups_so_far[key] = slot + 1
        return self.temp_tables[(*key, slot)]

    return _resolve

def output_graphviz(self, fout: TextIO) -> None:
def _dot_label(lbl: TableMetadata) -> str:
return f'"{str(lbl)}"'
Expand Down
67 changes: 41 additions & 26 deletions subsetter/plan_model.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
from typing import Any, Dict, List, Literal, Optional, Union
import itertools
from typing import Any, Callable, Dict, List, Literal, Optional, Union

import sqlalchemy as sa
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated

from subsetter.metadata import DatabaseMetadata

# pylint: disable=unused-argument


SQLBuildContext = Callable[["SQLTableIdentifier"], sa.Table]


class SQLTableIdentifier(BaseModel):
table_schema: str = Field(..., alias="schema")
table_name: str = Field(..., alias="table")
sampled: bool = False

model_config = ConfigDict(populate_by_name=True)

def build(self, meta: DatabaseMetadata) -> sa.Table:
if self.sampled:
return meta.temp_tables[(self.table_schema, self.table_name)]
return meta.tables[(self.table_schema, self.table_name)].table_obj
def build(self, context: SQLBuildContext) -> sa.Table:
return context(self)


SQLKnownOperator = Literal["<", ">", "=", "<>", "!=", "like", "not like"]
Expand All @@ -31,7 +31,7 @@ class SQLWhereClauseFalse(BaseModel):

model_config = ConfigDict(populate_by_name=True)

def build(self, meta: DatabaseMetadata, table_obj: sa.Table):
def build(self, context: SQLBuildContext, table_obj: sa.Table):
return sa.false()

def simplify(self) -> "SQLWhereClause":
Expand All @@ -43,7 +43,7 @@ class SQLWhereClauseTrue(BaseModel):

model_config = ConfigDict(populate_by_name=True)

def build(self, meta: DatabaseMetadata, table_obj: sa.Table):
def build(self, context: SQLBuildContext, table_obj: sa.Table):
return sa.true()

def simplify(self) -> "SQLWhereClause":
Expand All @@ -58,7 +58,7 @@ class SQLWhereClauseOperator(BaseModel):

model_config = ConfigDict(populate_by_name=True)

def build(self, meta: DatabaseMetadata, table_obj: sa.Table):
def build(self, context: SQLBuildContext, table_obj: sa.Table):
op = self.operator
column = table_obj.columns[self.column]
if op == "<":
Expand Down Expand Up @@ -87,12 +87,12 @@ class SQLWhereClauseIn(BaseModel):

model_config = ConfigDict(populate_by_name=True)

def build(self, meta: DatabaseMetadata, table_obj: sa.Table):
def build(self, context: SQLBuildContext, table_obj: sa.Table):
columns = sa.tuple_(*(table_obj.columns[col_name] for col_name in self.columns))
if isinstance(self.values, list):
clause = columns.in_(self.values)
else:
clause = columns.in_(self.values.build(meta))
clause = columns.in_(self.values.build(context))
if self.negated:
clause = sa.not_(clause)
return clause
Expand All @@ -114,10 +114,10 @@ class SQLWhereClauseAnd(BaseModel):

model_config = ConfigDict(populate_by_name=True)

def build(self, meta: DatabaseMetadata, table_obj: sa.Table):
def build(self, context: SQLBuildContext, table_obj: sa.Table):
if not self.conditions:
return sa.true()
return sa.and_(*(cond.build(meta, table_obj) for cond in self.conditions))
return sa.and_(*(cond.build(context, table_obj) for cond in self.conditions))

def simplify(self) -> "SQLWhereClause":
simp_conditions: List["SQLWhereClause"] = [
Expand All @@ -127,6 +127,14 @@ def simplify(self) -> "SQLWhereClause":
simp_condition := condition.simplify(), SQLWhereClauseTrue
)
]
simp_conditions = list(
itertools.chain(
*(
cond.conditions if isinstance(cond, SQLWhereClauseAnd) else (cond,)
for cond in simp_conditions
)
)
)
if any(
isinstance(condition, SQLWhereClauseFalse) for condition in simp_conditions
):
Expand All @@ -144,10 +152,10 @@ class SQLWhereClauseOr(BaseModel):

model_config = ConfigDict(populate_by_name=True)

def build(self, meta: DatabaseMetadata, table_obj: sa.Table):
def build(self, context: SQLBuildContext, table_obj: sa.Table):
if not self.conditions:
return sa.false()
return sa.or_(*(cond.build(meta, table_obj) for cond in self.conditions))
return sa.or_(*(cond.build(context, table_obj) for cond in self.conditions))

def simplify(self) -> "SQLWhereClause":
simp_conditions: List["SQLWhereClause"] = [
Expand All @@ -157,6 +165,14 @@ def simplify(self) -> "SQLWhereClause":
simp_condition := condition.simplify(), SQLWhereClauseFalse
)
]
simp_conditions = list(
itertools.chain(
*(
cond.conditions if isinstance(cond, SQLWhereClauseOr) else (cond,)
for cond in simp_conditions
)
)
)
if any(
isinstance(condition, SQLWhereClauseTrue) for condition in simp_conditions
):
Expand All @@ -174,7 +190,7 @@ class SQLWhereClauseRandom(BaseModel):

model_config = ConfigDict(populate_by_name=True)

def build(self, meta: DatabaseMetadata, table_obj: sa.Table):
def build(self, context: SQLBuildContext, table_obj: sa.Table):
# pylint: disable=not-callable
return sa.func.random() < self.threshold

Expand All @@ -193,7 +209,7 @@ class SQLWhereClauseSQL(BaseModel):

model_config = ConfigDict(populate_by_name=True)

def build(self, meta: DatabaseMetadata, table_obj: sa.Table):
def build(self, context: SQLBuildContext, table_obj: sa.Table):
clause = sa.text(self.sql)
if self.sql_params is not None:
clause = clause.bindparams(**self.sql_params)
Expand Down Expand Up @@ -227,8 +243,8 @@ class SQLStatementSelect(BaseModel):

model_config = ConfigDict(populate_by_name=True)

def build(self, meta: DatabaseMetadata):
table_obj = self.from_.build(meta)
def build(self, context: SQLBuildContext):
table_obj = self.from_.build(context)

if self.columns:
# pylint: disable=not-an-iterable
Expand All @@ -237,7 +253,7 @@ def build(self, meta: DatabaseMetadata):
stmt = sa.select(table_obj)

if self.where:
stmt = stmt.where(self.where.build(meta, table_obj))
stmt = stmt.where(self.where.build(context, table_obj))

if self.limit is not None:
# pylint: disable=not-callable
Expand Down Expand Up @@ -269,8 +285,8 @@ class SQLStatementUnion(BaseModel):
model_config = ConfigDict(populate_by_name=True)
# TODO: Assert statements is not empty

def build(self, meta: DatabaseMetadata):
return sa.union(*(statement.build(meta) for statement in self.statements))
def build(self, context: SQLBuildContext):
return sa.union(*(statement.build(context) for statement in self.statements))

def simplify(self) -> "SQLStatement":
simp_statements = [
Expand Down Expand Up @@ -326,16 +342,15 @@ class SQLTableQuery(BaseModel):
sql: Optional[str] = None
sql_params: Optional[Dict[str, Any]] = None
statement: Optional[SQLStatement] = None
materialize: bool = False

def build(self, meta: DatabaseMetadata):
def build(self, context: SQLBuildContext):
if self.sql is not None:
stmt = sa.text(self.sql)
if self.sql_params is not None:
stmt = stmt.bindparams(**self.sql_params)
return stmt
if self.statement is not None:
return self.statement.build(meta)
return self.statement.build(context)
raise ValueError("One of 'sql' or 'select' must be set")


Expand Down
Loading

0 comments on commit 64bc850

Please sign in to comment.