Skip to content

Commit

Permalink
Add test to validate tables in the source schema.
Browse files Browse the repository at this point in the history
  • Loading branch information
plypaul committed Jun 23, 2023
1 parent d441f2a commit fcaf819
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 2 deletions.
9 changes: 8 additions & 1 deletion metricflow/test/fixtures/table_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,16 @@
logger = logging.getLogger(__name__)


# Tests should normally obtain the repository through the fixture. This module-level instance exists for the
# places where fixtures are not available, e.g. while building the argument values of a parameterized test.
# The snapshots are read from the "source_table_snapshots" directory located next to this file.
CONFIGURED_SOURCE_TABLE_SNAPSHOT_REPOSITORY = SqlTableSnapshotRepository(
    Path(os.path.dirname(__file__)) / "source_table_snapshots"
)


@pytest.fixture(scope="session")
def source_table_snapshot_repository() -> SqlTableSnapshotRepository:
    """Return the module-level source table snapshot repository.

    The fixture hands out CONFIGURED_SOURCE_TABLE_SNAPSHOT_REPOSITORY rather than building a second repository,
    so code that reads the constant directly (e.g. test parameterization) and code that uses this fixture see
    the same object.
    """
    return CONFIGURED_SOURCE_TABLE_SNAPSHOT_REPOSITORY


@pytest.fixture(scope="session", autouse=True)
Expand Down
6 changes: 5 additions & 1 deletion metricflow/test/table_snapshot/table_snapshots.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,11 @@ def restore(self, table_snapshot: SqlTableSnapshot, overwrite: bool = False) ->
)


class TableSnapshotException(Exception):
    """Common base class for the table snapshot exceptions defined in this module.

    Catching this type also catches TableSnapshotParseException, since that class derives from it.
    """


class TableSnapshotParseException(TableSnapshotException):
    """Parse-specific table snapshot error.

    NOTE(review): the code that raises this is outside the visible section; the name suggests it signals a
    failure while parsing a snapshot definition - confirm against the raising site.
    """


Expand Down
72 changes: 72 additions & 0 deletions metricflow/test/table_snapshot/test_source_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from __future__ import annotations

import logging
import warnings

import pytest

from metricflow.dataflow.sql_table import SqlTable
from metricflow.protocols.sql_client import SqlClient
from metricflow.test.compare_df import assert_dataframes_equal
from metricflow.test.fixtures.setup_fixtures import MetricFlowTestSessionState
from metricflow.test.fixtures.table_fixtures import CONFIGURED_SOURCE_TABLE_SNAPSHOT_REPOSITORY
from metricflow.test.source_schema_tools import POPULATE_SOURCE_SCHEMA_SHELL_COMMAND
from metricflow.test.table_snapshot.table_snapshots import (
SqlTableSnapshotRepository,
TableSnapshotException,
)

logger = logging.getLogger(__name__)


@pytest.mark.parametrize(
    argnames="table_name",
    argvalues=tuple(
        snapshot.table_name for snapshot in CONFIGURED_SOURCE_TABLE_SNAPSHOT_REPOSITORY.table_snapshots
    ),
    ids=lambda table_name: f"table_name={table_name}",
)
def test_validate_data_in_source_schema(
    mf_test_session_state: MetricFlowTestSessionState,
    sql_client: SqlClient,
    source_table_snapshot_repository: SqlTableSnapshotRepository,
    table_name: str,
    create_source_tables: None,
) -> None:
    """Check that the named table in the source schema holds the data described by its snapshot.

    The test is parameterized over every table name in the configured snapshot repository. It is skipped unless
    the session uses a persistent source schema. For the snapshot matching `table_name`, the snapshot's
    dataframe is compared with the result of a `SELECT *` on the same-named table in the source schema.

    Raises:
        TableSnapshotException: if anything goes wrong while building the table reference, loading the
            snapshot, querying the table, or comparing the dataframes. The original error is chained as the
            cause, and the same message is also emitted as a warning.
    """
    if not mf_test_session_state.use_persistent_source_schema:
        pytest.skip("Skipping as this session is running without the persistent source schema flag.")

    schema_name = mf_test_session_state.mf_source_schema

    # Select the snapshot(s) whose name equals the parameterized table name.
    matching_table_snapshots = tuple(
        filter(
            lambda candidate: candidate.table_name == table_name,
            source_table_snapshot_repository.table_snapshots,
        )
    )

    assert (
        len(matching_table_snapshots) == 1
    ), f"Did not get exactly one matching table snapshot for table name {table_name}. Got {matching_table_snapshots}"

    for table_snapshot in matching_table_snapshots:
        try:
            sql_table = SqlTable(schema_name=schema_name, table_name=table_snapshot.table_name)
            # Keyword arguments are evaluated left to right: the snapshot dataframe is loaded first and the
            # table is queried second.
            assert_dataframes_equal(
                expected=table_snapshot.as_df,
                actual=sql_client.query(f"SELECT * FROM {sql_table.sql}"),
            )
        except Exception as err:
            error_message = (
                f"Error verifying that a table corresponding to {table_snapshot} exists in the persistent source "
                f"schema {schema_name}. Try re-populating with: {POPULATE_SOURCE_SCHEMA_SHELL_COMMAND}"
            )
            # Also surface the message as a warning so that it stands out among many test failures.
            warnings.warn(error_message)
            raise TableSnapshotException(error_message) from err

0 comments on commit fcaf819

Please sign in to comment.