Skip to content

Commit

Permalink
Use "--persistent-source-schema" in the snapshot generation script.
Browse files Browse the repository at this point in the history
  • Loading branch information
plypaul committed Jun 27, 2023
1 parent 0d22032 commit 3bf54fe
Showing 1 changed file with 57 additions and 11 deletions.
68 changes: 57 additions & 11 deletions metricflow/test/generate_snapshots.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,15 @@

import logging
import os
from dataclasses import dataclass
from typing import Optional, Sequence

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.implementations.base import FrozenBaseModel
from dbt_semantic_interfaces.pretty_print import pformat_big_objects

from metricflow.configuration.env_var import EnvironmentVariable
from metricflow.protocols.sql_client import SqlEngine

logger = logging.getLogger(__name__)

Expand All @@ -47,6 +50,12 @@ class MetricFlowTestCredentialSet(FrozenBaseModel): # noqa: D
engine_password: Optional[str]


@dataclass(frozen=True)
class MetricFlowTestConfiguration: # noqa: D
engine: SqlEngine
credential_set: MetricFlowTestCredentialSet


class MetricFlowTestCredentialSetForAllEngines(FrozenBaseModel): # noqa: D
duck_db: MetricFlowTestCredentialSet
redshift: MetricFlowTestCredentialSet
Expand All @@ -55,8 +64,29 @@ class MetricFlowTestCredentialSetForAllEngines(FrozenBaseModel): # noqa: D
databricks: MetricFlowTestCredentialSet

@property
def as_sequence(self) -> Sequence[MetricFlowTestCredentialSet]: # noqa: D
return (self.duck_db, self.redshift, self.snowflake, self.big_query, self.databricks)
def as_configurations(self) -> Sequence[MetricFlowTestConfiguration]: # noqa: D
return (
MetricFlowTestConfiguration(
engine=SqlEngine.DUCKDB,
credential_set=self.duck_db,
),
MetricFlowTestConfiguration(
engine=SqlEngine.REDSHIFT,
credential_set=self.redshift,
),
MetricFlowTestConfiguration(
engine=SqlEngine.SNOWFLAKE,
credential_set=self.snowflake,
),
MetricFlowTestConfiguration(
engine=SqlEngine.BIGQUERY,
credential_set=self.big_query,
),
MetricFlowTestConfiguration(
engine=SqlEngine.DATABRICKS,
credential_set=self.databricks,
),
)


SNAPSHOT_GENERATING_TEST_FILES = (
Expand Down Expand Up @@ -88,20 +118,34 @@ def run_command(command: str) -> None: # noqa: D
raise RuntimeError(f"Error running command: {command}")


def run_tests(engine_credential_set: MetricFlowTestCredentialSet, test_file_paths: Sequence[str]) -> None: # noqa: D
def run_tests(test_configuration: MetricFlowTestConfiguration, test_file_paths: Sequence[str]) -> None: # noqa: D
combined_paths = " ".join(test_file_paths)
if engine_credential_set.engine_url is None:
if test_configuration.credential_set.engine_url is None:
if "MF_SQL_ENGINE_URL" in os.environ:
del os.environ["MF_SQL_ENGINE_URL"]
else:
os.environ["MF_SQL_ENGINE_URL"] = engine_credential_set.engine_url
os.environ["MF_SQL_ENGINE_URL"] = test_configuration.credential_set.engine_url

if engine_credential_set.engine_password is None:
if test_configuration.credential_set.engine_password is None:
if "MF_SQL_ENGINE_PASSWORD" in os.environ:
del os.environ["MF_SQL_ENGINE_PASSWORD"]
else:
os.environ["MF_SQL_ENGINE_PASSWORD"] = engine_credential_set.engine_password
run_command(f"pytest -x -vv -n 4 --overwrite-snapshots {combined_paths}")
os.environ["MF_SQL_ENGINE_PASSWORD"] = test_configuration.credential_set.engine_password

if test_configuration.engine is SqlEngine.DUCKDB:
# Can't use --use-persistent-source-schema with duckdb since it's in memory.
run_command(f"pytest -x -vv -n 4 --overwrite-snapshots {combined_paths}")
elif (
test_configuration.engine is SqlEngine.REDSHIFT
or test_configuration.engine is SqlEngine.SNOWFLAKE
or test_configuration.engine is SqlEngine.BIGQUERY
or test_configuration.engine is SqlEngine.DATABRICKS
):
run_command(f"pytest -x -vv -n 4 --overwrite-snapshots --use-persistent-source-schema {combined_paths}")
elif test_configuration.engine is SqlEngine.POSTGRES:
raise NotImplementedError(f"{test_configuration.engine} is not yet supported in this script.")
else:
assert_values_exhausted(test_configuration.engine)


def run_cli() -> None: # noqa: D
Expand All @@ -123,9 +167,11 @@ def run_cli() -> None: # noqa: D
f"Running the following tests to generate snapshots:\n{pformat_big_objects(SNAPSHOT_GENERATING_TEST_FILES)}"
)

for credential_set in credential_sets.as_sequence:
logger.info(f"Running test for {credential_set.engine_url}")
run_tests(credential_set, SNAPSHOT_GENERATING_TEST_FILES)
for test_configuration in credential_sets.as_configurations:
logger.info(
f"Running tests for {test_configuration.engine} with URL: {test_configuration.credential_set.engine_url}"
)
run_tests(test_configuration, SNAPSHOT_GENERATING_TEST_FILES)


if __name__ == "__main__":
Expand Down

0 comments on commit 3bf54fe

Please sign in to comment.