Skip to content

Commit

Permalink
Add pytest command to populate persistent source schema.
Browse files Browse the repository at this point in the history
  • Loading branch information
plypaul committed Jun 23, 2023
1 parent af275f8 commit d441f2a
Showing 1 changed file with 38 additions and 1 deletion.
39 changes: 38 additions & 1 deletion metricflow/test/source_schema_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dbt_semantic_interfaces.pretty_print import pformat_big_objects

from metricflow.protocols.sql_client import SqlClient
from metricflow.test.fixtures.setup_fixtures import MetricFlowTestSessionState
from metricflow.test.table_snapshot.table_snapshots import (
SqlTableSnapshotRepository,
SqlTableSnapshotRestorer,
Expand All @@ -27,7 +28,7 @@ def create_tables_listed_in_table_snapshot_repository(
[table_snapshot.table_name for table_snapshot in table_snapshot_repository.table_snapshots]
)
logger.info(
f"The following tables will be created if they don't exist in {schema_name}:\n"
f"The following tables will be created if they don't exist in {schema_name}:\n"
f"{pformat_big_objects(expected_table_names)}"
)
source_schema_table_names = sorted(sql_client.list_tables(schema_name=schema_name))
Expand All @@ -43,3 +44,39 @@ def create_tables_listed_in_table_snapshot_repository(
if table_snapshot.table_name in missing_table_names:
logger.info(f"Restoring: {table_snapshot.table_name}")
snapshot_restorer.restore(table_snapshot)


POPULATE_SOURCE_SCHEMA_SHELL_COMMAND = (
f"hatch -v run dev-env:pytest "
f"-vv "
f"--log-cli-level info "
f"--use-persistent-source-schema "
f"{__file__}::populate_source_schema"
)


def populate_source_schema(
mf_test_session_state: MetricFlowTestSessionState,
sql_client: SqlClient,
source_table_snapshot_repository: SqlTableSnapshotRepository,
) -> None:
"""Populate the source schema with the tables listed in table_snapshots.
This can be run via pytest when this file is specified because this function was whitelisted as a "test" in
pyproject.toml. However, because the filename does not begin with "test_", it's not normally collected and run. As
such, all parameters to this function are defined in fixtures.
"""
if not mf_test_session_state.use_persistent_source_schema:
raise ValueError("This should be run with the flag to enable use of the persistent source schema")

schema_name = mf_test_session_state.mf_source_schema

logger.info(f"Dropping schema {schema_name}")
sql_client.drop_schema(schema_name=schema_name, cascade=True)
logger.info(f"Creating schema {schema_name}")
sql_client.create_schema(schema_name=schema_name)
create_tables_listed_in_table_snapshot_repository(
sql_client=sql_client,
schema_name=schema_name,
table_snapshot_repository=source_table_snapshot_repository,
)

0 comments on commit d441f2a

Please sign in to comment.