diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 1b7bf73..55ff605 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -22,7 +22,7 @@ jobs:
       # sudo apt update
       # sudo apt install -y pkg-config mysql-server
       - name: Install package deps
-        run: python -m pip install -e .[mysql,postgres]
+        run: python -m pip install -e .[all]
       - name: Install dev deps
         run: python -m pip install -r requirements-dev.txt
       - name: Lint
diff --git a/CHANGELOG.md b/CHANGELOG.md
index d413dea..1f4d43a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,11 @@
 - Improve error messaging
 - Update output formatting
 - Add support for creating output schema if it does not exist
+- Removed the `normalize_foreign_keys` option, which is no longer needed
+- Like constraints in targets are now 'and'ed together, as with all other
+  constraints, instead of 'or'ed
+- Changed optional dependency names to match sqlalchemy's extras
+- Added sqlite support (mostly to help with tests)
 
 # v0.3.0
 
diff --git a/Dockerfile b/Dockerfile
index 9fc69a8..4bf8752 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -5,7 +5,7 @@
 RUN pip install -U pip tqdm
 
 WORKDIR /subsetter
 COPY . ./
-RUN python3 -m pip install -e .[mysql,postgres]
+RUN python3 -m pip install -e .[all]
 
 RUN adduser -S ctruser
diff --git a/planner_config.example.yaml b/planner_config.example.yaml
index ba97222..d5927c7 100644
--- a/planner_config.example.yaml
+++ b/planner_config.example.yaml
@@ -20,8 +20,8 @@ targets:
     amount: 100
     # Additional possible filters shown below. Multiple filters can be provided
     # and the results will be intersected together (except all which overrides
-    # everything). Note that additional rows may be included beyond what is
-    # specified here if needed when following foreign keys.
+    # everything). Additional rows will only be sampled if a row from another
+    # targeted table depends on them.
     # all: true
     # percent: 5.0
@@ -43,14 +43,11 @@ select:
   - db2.gadgets
   - db2.gizmos-*
 
-# Add additional constraints for some tables. The planner does not attempt to
-# verify that these constraints will not break foreign key relationships. In
-# general it's always safe to apply constraints to tables that have no incoming
-# foreign key constraints among selected tables. This config file does not
-# accept arbitrary SQL; however you can manually modify the SQL in the generated
-# plan with arbitrary SQL.
+# Add additional constraints for some tables. Constraints can only be applied
+# to tables where filtering rows would not cause foreign key constraints to be
+# violated.
 table_constraints:
-  db1.user-data:
+  db1.user_data:
     - column: action_date
       operator: '>'
       value: '2023-07-01'
@@ -84,12 +81,3 @@ extra_fks:
 # matches the name of a primary key column (that is unique within the database)
 # should function as a foreign key to that table.
 infer_foreign_keys: false
-
-# If set to true the subsetter will automatically attempt to normalize some
-# foreign key relationships. In particular if there are foreign key
-# relationships A->B, A->C, B->C then the subsetter will assume that the
-# relationship A->C is redundant and ignore it. Without this assumption this
-# sort of relationship triangle cannot be sampled with a "one table, one query"
-# strategy. Note that this does not currently attempt to normalize chains that
-# involve more than three tables.
-normalize_foreign_keys: false
diff --git a/sampler_config.example.yaml b/sampler_config.example.yaml
index be36337..a92328f 100644
--- a/sampler_config.example.yaml
+++ b/sampler_config.example.yaml
@@ -7,6 +7,15 @@ source:
   username: my_user # overridden by SUBSET_SOURCE_USERNAME
   password: my_s3cret # overridden by SUBSET_SOURCE_PASSWORD
   # database: my_dbname # overridden by SUBSET_SOURCE_DATABASE (if needed)
+
+  # For sqlite the file named by 'database' will be mounted as the 'main'
+  # schema. You can mount additional databases using the 'sqlite_databases'
+  # mapping:
+  #
+  # sqlite_databases:
+  #   foo: /path/to/foo.db
+  #   bar: /path/to/bar.db
+
   session_sqls: # Set any additional session variables; e.g.
     - SET @@session.max_statement_time=0
     - SET @@session.net_read_timeout=3600
@@ -14,6 +23,10 @@ source:
     - SET @@session.wait_timeout=28800
     - SET @@session.innodb_lock_wait_timeout=3600
 
+  # Set the transaction isolation level. Defaults to READ COMMITTED for all
+  # engines other than sqlite, which defaults to SERIALIZABLE.
+  isolation_level: "REPEATABLE READ"
+
 # Optionally specify the source database. This can also be passed on the command
 # line or through environment variables.
 output:
diff --git a/setup.cfg b/setup.cfg
index 5430994..e666be6 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -31,11 +31,18 @@
 install_requires =
     typing-extensions
 
 [options.extras_require]
-mysql =
-    pymysql ~= 1.0
-postgres =
-    psycopg2-binary ~= 2.0
+all =
+    sqlalchemy[pymysql,postgresql_psycopg2binary]
+
+pymysql =
+    sqlalchemy[pymysql]
+
+postgresql =
+    sqlalchemy[postgresql]
+
+postgresql_psycopg2binary =
+    sqlalchemy[postgresql_psycopg2binary]
 
 [options.package_data]
 subsetter =
diff --git a/subsetter/common.py b/subsetter/common.py
index b9ce23a..c932c87 100644
--- a/subsetter/common.py
+++ b/subsetter/common.py
@@ -8,6 +8,7 @@
 DatabaseDialect = Literal[
     "mysql",
     "postgres",
+    "sqlite",
 ]
 
 LOGGER = logging.getLogger(__name__)
@@ -37,9 +38,9 @@ def database_url(
     dialect = dialect or os.getenv(f"{env_prefix}DIALECT", None)  # type: ignore
     host = host or os.getenv(f"{env_prefix}HOST", "localhost")
     port = port or int(os.getenv(f"{env_prefix}PORT", "0"))
-    database = database or os.getenv(f"{env_prefix}DATABASE", "")
-    username = username or os.environ[f"{env_prefix}USERNAME"]
-    password = os.environ[f"{env_prefix}PASSWORD"] if password is None else password
+    database = database or os.getenv(f"{env_prefix}DATABASE", None)
+    username = username or os.getenv(f"{env_prefix}USERNAME", None)
+    password = password or os.getenv(f"{env_prefix}PASSWORD", None)
 
     if dialect is None:
         dialect = DEFAULT_DIALECT
@@ -57,13 +58,15 @@ def database_url(
         port = 5432
         if database:
             extra_kwargs["database"] = database
+    elif dialect == "sqlite":
+        return sa.engine.URL.create(drivername="sqlite", database=database)
     else:
         raise ValueError(f"Unsupported SQL dialect {dialect!r}")
 
     return sa.engine.URL.create(
         drivername=drivername,
         host=host,
-        port=port or 3306,
+        port=port,
         username=username,
         password=password,
         **extra_kwargs,
@@ -86,7 +89,8 @@ class DatabaseConfig(BaseModel):
     username: Optional[str] = None
     password: Optional[str] = None
     session_sqls: List[str] = []
-    isolation_level: IsolationLevel = "READ COMMITTED"
+    sqlite_databases: Optional[Dict[str, str]] = {}
+    isolation_level: Optional[IsolationLevel] = None
 
     def database_url(
         self,
@@ -106,16 +110,32 @@ class DatabaseConfig(BaseModel):
     def database_engine(
         self,
         env_prefix: Optional[str] = None,
     ) -> sa.engine.Engine:
+        if self.isolation_level:
+            isolation_level = self.isolation_level
+        elif self.dialect == "sqlite":
"sqlite": + isolation_level = "SERIALIZABLE" + else: + isolation_level = "READ COMMITTED" engine = sa.create_engine( self.database_url(env_prefix=env_prefix), - isolation_level=self.isolation_level, + isolation_level=isolation_level, pool_pre_ping=True, ) @sa.event.listens_for(engine, "connect") def _set_session_sqls(dbapi_connection, _): - with dbapi_connection.cursor() as cursor: + cursor = dbapi_connection.cursor() + try: + if self.dialect == "sqlite": + for db_alias, db_file in self.sqlite_databases.items(): + escaped_db_file = db_file.replace("'", "''") + cursor.execute( + f"ATTACH DATABASE '{escaped_db_file}' as {db_alias}" + ) + for session_sql in self.session_sqls: cursor.execute(session_sql) + finally: + cursor.close() return engine diff --git a/subsetter/metadata.py b/subsetter/metadata.py index 27722be..ff6bf98 100644 --- a/subsetter/metadata.py +++ b/subsetter/metadata.py @@ -2,7 +2,7 @@ import dataclasses import logging from fnmatch import fnmatch -from typing import Dict, List, Optional, Set, TextIO, Tuple +from typing import Dict, List, Optional, Set, Tuple import sqlalchemy as sa @@ -167,37 +167,6 @@ def infer_missing_foreign_keys(self) -> None: ) table.foreign_keys.append(fk) - def normalize_foreign_keys(self) -> None: - """ - If table A has a foreign key to table B and they both share a foreign - key on the same column in table C, remove the foreign key from table A - assuming it is redundant. - """ - fk_sets = { - table_key: { - (fk.dst_schema, fk.dst_table, fk.dst_columns) - for fk in table.foreign_keys - } - for table_key, table in self.tables.items() - } - for table in self.tables.values(): - child_fk_sets = set() - for fk in table.foreign_keys: - child_fk_sets |= fk_sets[(fk.dst_schema, fk.dst_table)] - fk_out = [] - for fk in table.foreign_keys: - if (fk.dst_schema, fk.dst_table, fk.dst_columns) not in child_fk_sets: - fk_out.append(fk) - else: - LOGGER.info( - "Normalizing foreign key, removed %s->%s.%s on %r", - table, - fk.dst_schema, - fk.dst_table, - fk.columns, - ) - table.foreign_keys = fk_out - def toposort(self) -> List[TableMetadata]: return [ # type: ignore self.tables[parse_table_name(u)] for u in toposort(self.as_graph()) @@ -247,20 +216,3 @@ def _context(ident: SQLTableIdentifier) -> sa.Table: return self.tables[(ident.table_schema, ident.table_name)].table_obj return _context - - def output_graphviz(self, fout: TextIO) -> None: - def _dot_label(lbl: TableMetadata) -> str: - return f'"{str(lbl)}"' - - fout.write("digraph {\n") - for table in self.tables.values(): - fout.write(" ") - fout.write(_dot_label(table)) - fout.write(" -> {") - - deps = { - self.tables[(fk.dst_schema, fk.dst_table)] for fk in table.foreign_keys - } - fout.write(", ".join(_dot_label(dep) for dep in deps)) - fout.write("}\n") - fout.write("}\n") diff --git a/subsetter/planner.py b/subsetter/planner.py index 9e0d67e..3856578 100644 --- a/subsetter/planner.py +++ b/subsetter/planner.py @@ -59,10 +59,14 @@ class ColumnConstraint(BaseModel): ignore_fks: List[IgnoreFKConfig] = [] extra_fks: List[ExtraFKConfig] = [] infer_foreign_keys: bool = False - normalize_foreign_keys: bool = False class Planner: + """ + Class responsible for taking in a plan configuration and a source database + schema and producing a subsetting strategy. 
+ """ + def __init__(self, config: PlannerConfig) -> None: self.config = config self.engine = self.config.source.database_engine(env_prefix="SUBSET_SOURCE_") @@ -92,10 +96,11 @@ def plan(self) -> SubsetPlan: extra_table[1], ) + return self._plan_internal() + + def _plan_internal(self) -> SubsetPlan: if self.config.infer_foreign_keys: self.meta.infer_missing_foreign_keys() - if self.config.normalize_foreign_keys: - self.meta.normalize_foreign_keys() self._remove_ignore_fks() self._add_extra_fks() self._check_ignore_tables() @@ -280,16 +285,23 @@ def _plan_table( assert not foreign_keys if target.all_: rev_foreign_keys.clear() - LOGGER.debug("Targetting %s and sampling from %s", table, rev_foreign_keys) + if rev_foreign_keys: + LOGGER.debug( + "Sampling %s as union of target parameters and references from %s", + table, + [f"{fk.dst_schema}.{fk.dst_table}" for fk in rev_foreign_keys], + ) + else: + LOGGER.debug("Targetting %s", table) elif foreign_keys: LOGGER.debug( - "Reverse sampling %s from %s", + "Sampling %s as intersection of references from %s", table, [f"{fk.dst_schema}.{fk.dst_table}" for fk in foreign_keys], ) else: LOGGER.debug( - "Sampling %s from %s", + "Sampling %s as union of references from %s", table, [f"{fk.dst_schema}.{fk.dst_table}" for fk in rev_foreign_keys], ) @@ -327,17 +339,26 @@ def _plan_table( f"{table.schema}.{table.name}", [] ) conf_constraints_sql: List[SQLWhereClause] = [] - all_columns = {column.name for column in table.table_obj.columns} + if conf_constraints and rev_foreign_keys: + raise ValueError( + f"Cannot apply table constraints to {table} without violating " + "foreign key constraints of previously sampled tables", + ) + for conf_constraint in conf_constraints: - if conf_constraint.column in all_columns: - conf_constraints_sql.append( - SQLWhereClauseOperator( - type_="operator", - operator=conf_constraint.operator, - column=conf_constraint.column, - value=conf_constraint.value, - ) + if conf_constraint.column not in table.table_obj.columns: + raise ValueError( + "Table {table} has no column {conf_constraint.column!r} for table constraint", + ) + + conf_constraints_sql.append( + SQLWhereClauseOperator( + type_="operator", + operator=conf_constraint.operator, + column=conf_constraint.column, + value=conf_constraint.value, ) + ) # Calculate initial foreign-key / config constraint statement statements: List[SQLStatementSelect] = [ @@ -376,19 +397,14 @@ def _plan_table( ) for column, patterns in target.like.items(): - target_constraints.append( - SQLWhereClauseOr( - type_="or", - conditions=[ - SQLWhereClauseOperator( - type_="operator", - operator="like", - column=column, - value=pattern, - ) - for pattern in patterns - ], + target_constraints.extend( + SQLWhereClauseOperator( + type_="operator", + operator="like", + column=column, + value=pattern, ) + for pattern in patterns ) for column, in_list in target.in_.items(): diff --git a/subsetter/sampler.py b/subsetter/sampler.py index 84c135f..5eecf39 100644 --- a/subsetter/sampler.py +++ b/subsetter/sampler.py @@ -60,7 +60,7 @@ def _temporary_table_compile_generic(self, compiler, **_) -> str: ) return f"CREATE TEMPORARY TABLE {schema_enc}.{name_enc} AS {select_stmt}" - def _temporary_table_compile_postgres(self, compiler, **_) -> str: + def _temporary_table_compile_no_schema(self, compiler, **_) -> str: """ Postgres creates temporary tables in a special schema. We make the table name incorporate the schema name to compensate and avoid collisions. 
@@ -76,7 +76,9 @@
         return f"CREATE TEMPORARY TABLE {name_enc} AS {select_stmt}"
 
 
-compiles(TemporaryTable, "postgresql")(TemporaryTable._temporary_table_compile_postgres)
+compiles(TemporaryTable, "postgresql", "sqlite")(
+    TemporaryTable._temporary_table_compile_no_schema
+)
 compiles(TemporaryTable)(TemporaryTable._temporary_table_compile_generic)
diff --git a/tests/dataset_manager.py b/tests/dataset_manager.py
index 5477801..9a37ffe 100644
--- a/tests/dataset_manager.py
+++ b/tests/dataset_manager.py
@@ -94,7 +94,7 @@ def _col_spec(col: Union[str, ColumnDescriptor]) -> sa.Column:
 
 
 def apply_dataset(db_config: DatabaseConfig, dataset: TestDataset) -> None:
-    engine = sa.create_engine(db_config.database_url())
+    engine = db_config.database_engine()
 
     schemas = set()
     metadata = sa.MetaData()
@@ -105,14 +105,17 @@ def apply_dataset(db_config: DatabaseConfig, dataset: TestDataset) -> None:
         schemas.add(schema)
         schemas.add(schema + "_out")
 
-    with engine.connect() as conn:
-        for schema in schemas:
-            try:
-                conn.execute(sa.schema.DropSchema(schema, cascade=True, if_exists=True))
-            except sa.exc.ProgrammingError:
-                conn.execute(sa.schema.DropSchema(schema, if_exists=True))
-            conn.execute(sa.schema.CreateSchema(schema))
-        conn.commit()
+    if db_config.dialect != "sqlite":
+        with engine.connect() as conn:
+            for schema in schemas:
+                try:
+                    conn.execute(
+                        sa.schema.DropSchema(schema, cascade=True, if_exists=True)
+                    )
+                except sa.exc.ProgrammingError:
+                    conn.execute(sa.schema.DropSchema(schema, if_exists=True))
+                conn.execute(sa.schema.CreateSchema(schema))
+            conn.commit()
 
     metadata.create_all(engine)
 
@@ -124,7 +127,7 @@ def apply_dataset(db_config: DatabaseConfig, dataset: TestDataset) -> None:
 
 
 def get_rows(db_config, schema: str, table: str) -> List[Dict[str, Any]]:
-    engine = sa.create_engine(db_config.database_url())
+    engine = db_config.database_engine()
     with engine.connect() as conn:
         metadata_obj = sa.MetaData()
         table_obj = sa.Table(table, metadata_obj, schema=schema, autoload_with=conn)
diff --git a/tests/test_live.py b/tests/test_live.py
index 5f4742d..846c503 100644
--- a/tests/test_live.py
+++ b/tests/test_live.py
@@ -1,4 +1,5 @@
 import os
+import tempfile
 
 import pytest
 
@@ -31,6 +32,24 @@ def db_config_postgres(request):
     )
 
 
+@pytest.fixture
+def sqlite_init_db():
+    with tempfile.NamedTemporaryFile(suffix=".db") as tf1:
+        with tempfile.NamedTemporaryFile(suffix=".db") as tf2:
+            yield tf1.name, tf2.name
+
+
+def db_config_sqlite(request):
+    db1, db2 = request.getfixturevalue("sqlite_init_db")
+    return DatabaseConfig(
+        dialect="sqlite",
+        sqlite_databases={
+            "test": db1,
+            "test_out": db2,
+        },
+    )
+
+
 DATABASE_CONFIGURATIONS = [
     pytest.param(
         db_config_mysql,
@@ -48,6 +67,14 @@ def db_config_postgres(request):
         ],
         id="postgres",
     ),
+    pytest.param(
+        db_config_sqlite,
+        marks=[
+            pytest.mark.usefixtures("sqlite_init_db"),
+            pytest.mark.sqlite_live,
+        ],
+        id="sqlite",
+    ),
 ]
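
Taken together, the config changes above let a sqlite source be described entirely in the sampler config. A minimal sketch, not part of the diff, mirroring the comments added to sampler_config.example.yaml; the file paths and the `reporting` alias are illustrative:

```yaml
source:
  dialect: sqlite
  # Mounted as the 'main' schema.
  database: /path/to/main.db
  # Each entry is attached on connect as an extra schema named by its key.
  sqlite_databases:
    reporting: /path/to/reporting.db
  # Optional; sqlite falls back to SERIALIZABLE when unset.
  isolation_level: "SERIALIZABLE"
```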
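And a minimal Python sketch of how the new `DatabaseConfig` fields drive engine creation, modeled on the test fixtures above; the `/tmp` paths are illustrative, and `PRAGMA database_list` is standard sqlite rather than anything added by this diff:

```python
import sqlalchemy as sa

from subsetter.common import DatabaseConfig

# Illustrative paths; each key becomes a schema name visible to the subsetter.
config = DatabaseConfig(
    dialect="sqlite",
    sqlite_databases={
        "test": "/tmp/test.db",
        "test_out": "/tmp/test_out.db",
    },
)

# database_engine() defaults to SERIALIZABLE isolation for sqlite and runs
# ATTACH DATABASE for each configured alias on every new connection.
engine = config.database_engine()
with engine.connect() as conn:
    # Lists 'main' plus the attached 'test' and 'test_out' schemas.
    for row in conn.execute(sa.text("PRAGMA database_list")):
        print(row)
```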