diff --git a/.changes/unreleased/Features-20230627-234808.yaml b/.changes/unreleased/Features-20230627-234808.yaml
new file mode 100644
index 0000000000..63b488095f
--- /dev/null
+++ b/.changes/unreleased/Features-20230627-234808.yaml
@@ -0,0 +1,6 @@
+kind: Features
+body: Enable Snowflake queries in dbt <-> MetricFlow CLI integration
+time: 2023-06-27T23:48:08.287586-07:00
+custom:
+  Author: tlento
+  Issue: "579"
diff --git a/.github/workflows/cd-sql-engine-tests.yaml b/.github/workflows/cd-sql-engine-tests.yaml
index 480f21b866..a7f25a2bdc 100644
--- a/.github/workflows/cd-sql-engine-tests.yaml
+++ b/.github/workflows/cd-sql-engine-tests.yaml
@@ -21,7 +21,6 @@ jobs:
     name: Snowflake Tests
     runs-on: ubuntu-latest
     steps:
-
       - name: Check-out the repo
         uses: actions/checkout@v3

@@ -33,6 +32,7 @@
          mf_sql_engine_password: ${{ secrets.MF_SNOWFLAKE_PWD }}
          parallelism: ${{ env.EXTERNAL_ENGINE_TEST_PARALLELISM }}
          additional-pytest-options: ${{ env.ADDITIONAL_PYTEST_OPTIONS }}
+         make-target: "test-snowflake"

   redshift-tests:
     environment: DW_INTEGRATION_TESTS
@@ -40,7 +40,6 @@
     if: ${{ github.event.action != 'labeled' || github.event.label.name == 'run_mf_sql_engine_tests' }}
     runs-on: ubuntu-latest
     steps:
-
       - name: Check-out the repo
         uses: actions/checkout@v3

@@ -59,7 +58,6 @@
     if: ${{ github.event.action != 'labeled' || github.event.label.name == 'run_mf_sql_engine_tests' }}
     runs-on: ubuntu-latest
     steps:
-
       - name: Check-out the repo
         uses: actions/checkout@v3

@@ -78,7 +76,6 @@
     if: ${{ github.event.action != 'labeled' || github.event.label.name == 'run_mf_sql_engine_tests' }}
     runs-on: ubuntu-latest
     steps:
-
       - name: Check-out the repo
         uses: actions/checkout@v3

@@ -97,7 +94,6 @@
     if: ${{ github.event.action != 'labeled' || github.event.label.name == 'run_mf_sql_engine_tests' }}
     runs-on: ubuntu-latest
     steps:
-
       - name: Check-out the repo
         uses: actions/checkout@v3

@@ -116,7 +112,6 @@
     if: ${{ github.event_name != 'pull_request' && failure() }}
     runs-on: ubuntu-latest
     steps:
-
       - uses: actions/checkout@v3

       - name: Slack Failure
diff --git a/Makefile b/Makefile
index 2f894ebea4..235ecb0630 100644
--- a/Makefile
+++ b/Makefile
@@ -30,6 +30,10 @@ test:
 test-postgresql:
 	hatch -v run postgres-env:pytest -vv -n $(PARALLELISM) $(ADDITIONAL_PYTEST_OPTIONS) metricflow/test/

+.PHONY: test-snowflake
+test-snowflake:
+	hatch -v run snowflake-env:pytest -vv -n $(PARALLELISM) $(ADDITIONAL_PYTEST_OPTIONS) metricflow/test/
+
 .PHONY: lint
 lint:
 	hatch -v run dev-env:pre-commit run --all-files
diff --git a/metricflow/cli/dbt_connectors/adapter_backed_client.py b/metricflow/cli/dbt_connectors/adapter_backed_client.py
index b74d0b9d71..0676745920 100644
--- a/metricflow/cli/dbt_connectors/adapter_backed_client.py
+++ b/metricflow/cli/dbt_connectors/adapter_backed_client.py
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import enum
 import logging
 import textwrap
 import time
@@ -7,6 +8,7 @@
 import pandas as pd
 from dbt.adapters.base.impl import BaseAdapter
+from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
 from dbt_semantic_interfaces.pretty_print import pformat_big_objects

 from metricflow.dataflow.sql_table import SqlTable
@@ -16,6 +18,7 @@
 from metricflow.protocols.sql_request import SqlJsonTag, SqlRequestId, SqlRequestTagSet
 from metricflow.random_id import random_id
 from metricflow.sql.render.postgres import PostgresSQLSqlQueryPlanRenderer
+from metricflow.sql.render.snowflake import SnowflakeSqlQueryPlanRenderer
 from metricflow.sql.render.sql_plan_renderer import SqlQueryPlanRenderer
 from metricflow.sql.sql_bind_parameters import SqlBindParameters
 from metricflow.sql_clients.sql_statement_metadata import CombinedSqlTags, SqlStatementCommentMetadata
@@ -23,6 +26,33 @@
 logger = logging.getLogger(__name__)


+class SupportedAdapterTypes(enum.Enum):
+    """Enumeration of supported dbt adapter types."""
+
+    POSTGRES = "postgres"
+    SNOWFLAKE = "snowflake"
+
+    @property
+    def sql_engine_type(self) -> SqlEngine:
+        """Return the SqlEngine corresponding to the supported adapter type."""
+        if self is SupportedAdapterTypes.POSTGRES:
+            return SqlEngine.POSTGRES
+        elif self is SupportedAdapterTypes.SNOWFLAKE:
+            return SqlEngine.SNOWFLAKE
+        else:
+            assert_values_exhausted(self)
+
+    @property
+    def sql_query_plan_renderer(self) -> SqlQueryPlanRenderer:
+        """Return the SqlQueryPlanRenderer corresponding to the supported adapter type."""
+        if self is SupportedAdapterTypes.POSTGRES:
+            return PostgresSQLSqlQueryPlanRenderer()
+        elif self is SupportedAdapterTypes.SNOWFLAKE:
+            return SnowflakeSqlQueryPlanRenderer()
+        else:
+            assert_values_exhausted(self)
+
+
 class AdapterBackedSqlClient:
     """SqlClient implementation which delegates database operations to a dbt BaseAdapter instance.

@@ -38,15 +68,18 @@ def __init__(self, adapter: BaseAdapter):
         The dbt BaseAdapter should already be fully initialized, including all credential verification, and
         ready for use for establishing connections and issuing queries.
         """
-        if adapter.type() != "postgres":
-            raise ValueError(
-                f"Received dbt adapter with unsupported type {adapter.type()}, but we only support postgres!"
-            )
         self._adapter = adapter
-        # TODO: normalize from adapter.type()
-        self._sql_engine_type = SqlEngine.POSTGRES
-        # TODO: create factory based on SqlEngine type
-        self._sql_query_plan_renderer = PostgresSQLSqlQueryPlanRenderer()
+        try:
+            adapter_type = SupportedAdapterTypes(self._adapter.type())
+        except ValueError as e:
+            raise ValueError(
+                f"Adapter type {self._adapter.type()} is not supported. Must be one "
+                f"of {[item.value for item in SupportedAdapterTypes]}."
+            ) from e
+
+        self._sql_engine_type = adapter_type.sql_engine_type
+        self._sql_query_plan_renderer = adapter_type.sql_query_plan_renderer
+        logger.info(f"Initialized AdapterBackedSqlClient with dbt adapter type `{adapter_type.value}`")

     @property
     def sql_engine_type(self) -> SqlEngine:
@@ -246,6 +279,11 @@ def _get_type_from_pandas_dtype(self, dtype: str) -> str:
     def list_tables(self, schema_name: str) -> Sequence[str]:
         """Get a list of the table names in a given schema. Only used in tutorials and tests."""
         # TODO: Short term, make this work with as many engines as possible. Medium term, remove this altogether.
+        if self.sql_engine_type is SqlEngine.SNOWFLAKE:
+            # Snowflake likes capitalizing things, except when it doesn't. We can get away with this due to its
+            # limited scope of usage.
+            schema_name = schema_name.upper()
+
         df = self.query(
             textwrap.dedent(
                 f"""\
@@ -257,7 +295,9 @@
         if df.empty:
             return []

-        # Lower casing table names for consistency between Snowflake and other clients.
+        # Lower casing table names and data frame names for consistency between Snowflake and other clients.
+        # As above, we can do this because it isn't used in any consequential situations.
+        df.columns = df.columns.str.lower()
         return [t.lower() for t in df["table_name"]]

     def table_exists(self, sql_table: SqlTable) -> bool:
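The `SupportedAdapterTypes` enum above doubles as a small factory: `AdapterBackedSqlClient.__init__` looks up the adapter's `type()` string in the enum and derives both the engine type and the query plan renderer from it, with `assert_values_exhausted` keeping the two properties in sync as new adapters are added. A minimal sketch of that lookup (illustrative only, not part of the patch; `"bigquery"` is just an example of a currently unsupported type):

```python
from metricflow.cli.dbt_connectors.adapter_backed_client import SupportedAdapterTypes
from metricflow.protocols.sql_client import SqlEngine

# Resolve a dbt adapter type string the same way __init__ does.
adapter_type = SupportedAdapterTypes("snowflake")
assert adapter_type.sql_engine_type is SqlEngine.SNOWFLAKE
renderer = adapter_type.sql_query_plan_renderer  # SnowflakeSqlQueryPlanRenderer instance

# Unsupported adapter types raise ValueError, which __init__ re-raises with the
# full list of supported values in the message.
try:
    SupportedAdapterTypes("bigquery")
except ValueError:
    pass
```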
diff --git a/metricflow/engine/metricflow_engine.py b/metricflow/engine/metricflow_engine.py
index 5a6e39a367..2fd65a0e4c 100644
--- a/metricflow/engine/metricflow_engine.py
+++ b/metricflow/engine/metricflow_engine.py
@@ -574,6 +574,12 @@ def get_dimension_values(  # noqa: D
         result_dataframe = query_result.result_df
         if result_dataframe is None:
             return []
+        if get_group_by_values not in result_dataframe.columns:
+            # Snowflake likes upper-casing things in result output, so we lower-case all names and
+            # see if we can get the value from there.
+            get_group_by_values = get_group_by_values.lower()
+            result_dataframe.columns = result_dataframe.columns.str.lower()
+
         return [str(val) for val in result_dataframe[get_group_by_values]]

     @log_call(module_name=__name__, telemetry_reporter=_telemetry_reporter)
diff --git a/metricflow/test/cli/test_cli.py b/metricflow/test/cli/test_cli.py
index 0a9540a6d4..54f49880f6 100644
--- a/metricflow/test/cli/test_cli.py
+++ b/metricflow/test/cli/test_cli.py
@@ -33,12 +33,15 @@
     TIME_SPINE_TABLE,
     TRANSACTIONS_TABLE,
 )
+from metricflow.protocols.sql_client import SqlEngine
 from metricflow.test.fixtures.cli_fixtures import MetricFlowCliRunner


 def test_query(cli_runner: MetricFlowCliRunner) -> None:  # noqa: D
     resp = cli_runner.run(query, args=["--metrics", "bookings", "--group-bys", "ds"])
-    assert "bookings" in resp.output
+    # Case-insensitive matches are needed for Snowflake, since it upper-cases identifiers in query output.
+    engine_is_snowflake = cli_runner.cli_context.sql_client.sql_engine_type is SqlEngine.SNOWFLAKE
+    assert "bookings" in resp.output or ("bookings" in resp.output.lower() and engine_is_snowflake)
     assert resp.exit_code == 0
diff --git a/metricflow/test/compare_df.py b/metricflow/test/compare_df.py
index e401430583..ce34fbd419 100644
--- a/metricflow/test/compare_df.py
+++ b/metricflow/test/compare_df.py
@@ -43,12 +43,24 @@ def _dataframes_contain_same_data(


 def assert_dataframes_equal(
-    actual: pd.DataFrame, expected: pd.DataFrame, sort_columns: bool = True, allow_empty: bool = False
+    actual: pd.DataFrame,
+    expected: pd.DataFrame,
+    sort_columns: bool = True,
+    allow_empty: bool = False,
+    compare_names_using_lowercase: bool = False,
 ) -> None:
     """Check that contents of DataFrames are the same.

     If sort_columns is set to false, value and column order needs to be the same.
+    If compare_names_using_lowercase is set to True, we copy the dataframes and lower-case their column names.
+    This is useful for Snowflake query output comparisons.
     """
+    if compare_names_using_lowercase:
+        actual = actual.copy()
+        expected = expected.copy()
+        actual.columns = actual.columns.str.lower()
+        expected.columns = expected.columns.str.lower()
+
     if set(actual.columns) != set(expected.columns):
         raise ValueError(
             f"DataFrames do not contain the same columns. actual: {set(actual.columns)}, "
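The new `compare_names_using_lowercase` flag is what lets the Snowflake-backed tests below reuse the shared comparison helper even though Snowflake returns unquoted column names upper-cased. A small illustration of the intended behavior (hypothetical data, not part of the patch):

```python
import pandas as pd

from metricflow.test.compare_df import assert_dataframes_equal

# Snowflake may hand back a column named "FOO" where the expected frame uses "foo".
actual = pd.DataFrame(columns=["FOO"], data=[(1,)])
expected = pd.DataFrame(columns=["foo"], data=[(1,)])

# Without the flag this raises ValueError because the column sets differ; with it, both
# frames are copied, their column names lower-cased, and the data comparison proceeds.
assert_dataframes_equal(actual=actual, expected=expected, compare_names_using_lowercase=True)
```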
diff --git a/metricflow/test/execution/test_tasks.py b/metricflow/test/execution/test_tasks.py
index dda868d7f6..6c90330c15 100644
--- a/metricflow/test/execution/test_tasks.py
+++ b/metricflow/test/execution/test_tasks.py
@@ -9,7 +9,7 @@
     SelectSqlQueryToTableTask,
 )
 from metricflow.execution.executor import SequentialPlanExecutor
-from metricflow.protocols.sql_client import SqlClient
+from metricflow.protocols.sql_client import SqlClient, SqlEngine
 from metricflow.random_id import random_id
 from metricflow.sql.sql_bind_parameters import SqlBindParameters
 from metricflow.test.compare_df import assert_dataframes_equal
@@ -32,6 +32,7 @@ def test_read_sql_task(sql_client: SqlClient) -> None:  # noqa: D
             columns=["foo"],
             data=[(1,)],
         ),
+        compare_names_using_lowercase=sql_client.sql_engine_type is SqlEngine.SNOWFLAKE,
     )


@@ -56,5 +57,6 @@ def test_write_table_task(mf_test_session_state: MetricFlowTestSessionState, sql
             columns=["foo"],
             data=[(1,)],
         ),
+        compare_names_using_lowercase=sql_client.sql_engine_type is SqlEngine.SNOWFLAKE,
     )
     sql_client.drop_table(output_table)
diff --git a/metricflow/test/fixtures/dbt_projects/metricflow_testing/profiles.yml b/metricflow/test/fixtures/dbt_projects/metricflow_testing/profiles.yml
index 6e44e30c4e..90223cf818 100644
--- a/metricflow/test/fixtures/dbt_projects/metricflow_testing/profiles.yml
+++ b/metricflow/test/fixtures/dbt_projects/metricflow_testing/profiles.yml
@@ -9,3 +9,15 @@ postgres:
       pass: "{{ env_var('MF_SQL_ENGINE_PASSWORD') }}"
       dbname: "{{ env_var('MF_SQL_ENGINE_DATABASE') }}"
       schema: "{{ env_var('MF_SQL_ENGINE_SCHEMA') }}"
+snowflake:
+  target: dev
+  outputs:
+    dev:
+      type: snowflake
+      # The snowflake account is equivalent to the host value in the SqlAlchemy parsed URL
+      account: "{{ env_var('MF_SQL_ENGINE_HOST') }}"
+      user: "{{ env_var('MF_SQL_ENGINE_USER') }}"
+      password: "{{ env_var('MF_SQL_ENGINE_PASSWORD') }}"
+      warehouse: "{{ env_var('MF_SQL_ENGINE_WAREHOUSE') }}"
+      database: "{{ env_var('MF_SQL_ENGINE_DATABASE') }}"
+      schema: "{{ env_var('MF_SQL_ENGINE_SCHEMA') }}"
diff --git a/metricflow/test/fixtures/sql_client_fixtures.py b/metricflow/test/fixtures/sql_client_fixtures.py
index 12443b720b..59ba13421c 100644
--- a/metricflow/test/fixtures/sql_client_fixtures.py
+++ b/metricflow/test/fixtures/sql_client_fixtures.py
@@ -16,7 +16,6 @@
 from metricflow.sql_clients.databricks import DatabricksSqlClient
 from metricflow.sql_clients.duckdb import DuckDbSqlClient
 from metricflow.sql_clients.redshift import RedshiftSqlClient
-from metricflow.sql_clients.snowflake import SnowflakeSqlClient
 from metricflow.test.fixtures.setup_fixtures import MetricFlowTestSessionState, dialect_from_url

 logger = logging.getLogger(__name__)
@@ -27,6 +26,7 @@
 MF_SQL_ENGINE_PORT = "MF_SQL_ENGINE_PORT"
 MF_SQL_ENGINE_USER = "MF_SQL_ENGINE_USER"
 MF_SQL_ENGINE_SCHEMA = "MF_SQL_ENGINE_SCHEMA"
+MF_SQL_ENGINE_WAREHOUSE = "MF_SQL_ENGINE_WAREHOUSE"


 def configure_test_env_from_url(url: str, schema: str) -> sqlalchemy.engine.URL:
@@ -74,7 +74,13 @@ def make_test_sql_client(url: str, password: str, schema: str) -> SqlClient:
     if dialect == SqlDialect.REDSHIFT:
         return RedshiftSqlClient.from_connection_details(url, password)
     elif dialect == SqlDialect.SNOWFLAKE:
-        return SnowflakeSqlClient.from_connection_details(url, password)
+        parsed_url = configure_test_env_from_url(url, schema)
+        assert "warehouse" in parsed_url.normalized_query, "Sql engine URL params did not include Snowflake warehouse!"
+        warehouses = parsed_url.normalized_query["warehouse"]
+        assert len(warehouses) == 1, f"Found more than 1 warehouse in Snowflake URL: `{warehouses}`"
+        os.environ[MF_SQL_ENGINE_WAREHOUSE] = warehouses[0]
+        __initialize_dbt()
+        return AdapterBackedSqlClient(adapter=get_adapter_by_type("snowflake"))
     elif dialect == SqlDialect.BIGQUERY:
         return BigQuerySqlClient.from_connection_details(url, password)
     elif dialect == SqlDialect.POSTGRESQL:
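To tie the pieces together: the Snowflake branch above pulls the warehouse out of the test engine URL's query string and exports it as MF_SQL_ENGINE_WAREHOUSE, which the new profiles.yml entry then reads via env_var(). A rough sketch of that parsing step, using a made-up URL (assumes SQLAlchemy 1.4+, where URL.normalized_query maps each query parameter to a tuple of values):

```python
import os

from sqlalchemy.engine import make_url

# Made-up URL for illustration; in CI the real value is supplied by the test secrets.
parsed_url = make_url("snowflake://my_user@my_account/my_db?warehouse=my_wh")

# normalized_query returns each query parameter as a tuple, so a single warehouse
# parameter arrives as a one-element tuple.
warehouses = parsed_url.normalized_query["warehouse"]
assert len(warehouses) == 1
os.environ["MF_SQL_ENGINE_WAREHOUSE"] = warehouses[0]
```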
diff --git a/metricflow/test/integration/test_write_to_table.py b/metricflow/test/integration/test_write_to_table.py
index 4497c7c363..42a9258476 100644
--- a/metricflow/test/integration/test_write_to_table.py
+++ b/metricflow/test/integration/test_write_to_table.py
@@ -4,6 +4,7 @@

 from metricflow.dataflow.sql_table import SqlTable
 from metricflow.engine.metricflow_engine import MetricFlowQueryRequest
+from metricflow.protocols.sql_client import SqlEngine
 from metricflow.random_id import random_id
 from metricflow.test.compare_df import assert_dataframes_equal
 from metricflow.test.integration.conftest import IntegrationTestHelpers
@@ -38,6 +39,10 @@ def test_write_to_table(it_helpers: IntegrationTestHelpers) -> None:  # noqa: D
                 """
             )
         )
-        assert_dataframes_equal(actual=actual, expected=expected)
+        assert_dataframes_equal(
+            actual=actual,
+            expected=expected,
+            compare_names_using_lowercase=it_helpers.sql_client.sql_engine_type is SqlEngine.SNOWFLAKE,
+        )
     finally:
         it_helpers.sql_client.drop_table(output_table)
diff --git a/metricflow/test/sql_clients/test_sql_client.py b/metricflow/test/sql_clients/test_sql_client.py
index 11fd6309b6..7b2cf1817a 100644
--- a/metricflow/test/sql_clients/test_sql_client.py
+++ b/metricflow/test/sql_clients/test_sql_client.py
@@ -26,7 +26,13 @@ def _select_x_as_y(x: int = 1, y: str = "y") -> str:  # noqa: D
     return f"SELECT {x} AS {y}"


-def _check_1col(df: pd.DataFrame, col: str = "y", vals: Set[Union[int, str]] = {1}) -> None:  # noqa: D
+def _check_1col(df: pd.DataFrame, col: str = "y", vals: Set[Union[int, str]] = {1}) -> None:
+    """Helper to check that 1 column has the same value and a case-insensitive matching name.
+
+    We lower-case the names due to snowflake's tendency to capitalize things. This isn't ideal but it'll do for now.
+    """
+    df.columns = df.columns.str.lower()
+    col = col.lower()
     assert isinstance(df, pd.DataFrame)
     assert df.shape == (len(vals), 1)
     assert df.columns.tolist() == [col]
@@ -116,7 +122,11 @@ def test_create_table_from_dataframe(  # noqa: D
     sql_client.create_table_from_dataframe(sql_table=sql_table, df=expected_df)

     actual_df = sql_client.query(f"SELECT * FROM {sql_table.sql}")
-    assert_dataframes_equal(actual=actual_df, expected=expected_df)
+    assert_dataframes_equal(
+        actual=actual_df,
+        expected=expected_df,
+        compare_names_using_lowercase=sql_client.sql_engine_type is SqlEngine.SNOWFLAKE,
+    )


 def test_table_exists(mf_test_session_state: MetricFlowTestSessionState, sql_client: SqlClient) -> None:  # noqa: D
@@ -159,7 +169,7 @@ def test_dry_run_of_bad_query_raises_exception(sql_client: SqlClient) -> None:
     bad_stmt = "SELECT bad_col"
     # Tests that a bad query raises an exception. Different engines may raise different exceptions e.g.
     # ProgrammingError, OperationalError, google.api_core.exceptions.BadRequest, etc.
-    with pytest.raises(Exception, match=r"bad_col"):
+    with pytest.raises(Exception, match=r"(?i)bad_col"):
         sql_client.dry_run(bad_stmt)
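The added `(?i)` flag makes the `pytest.raises` match case-insensitive, because Snowflake tends to report the offending identifier upper-cased (e.g. `BAD_COL`) while other engines preserve the original casing. A quick illustration with made-up error strings (`pytest.raises(match=...)` applies `re.search` to the exception text):

```python
import re

# Illustrative error messages only; real engine output varies.
snowflake_error = "SQL compilation error: invalid identifier 'BAD_COL'"
postgres_error = 'column "bad_col" does not exist'

# One case-insensitive pattern covers both spellings.
assert re.search(r"(?i)bad_col", snowflake_error)
assert re.search(r"(?i)bad_col", postgres_error)
```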
diff --git a/metricflow/test/table_snapshot/test_source_schema.py b/metricflow/test/table_snapshot/test_source_schema.py
index aa6073f6eb..933b1f2ef9 100644
--- a/metricflow/test/table_snapshot/test_source_schema.py
+++ b/metricflow/test/table_snapshot/test_source_schema.py
@@ -6,7 +6,7 @@
 import pytest

 from metricflow.dataflow.sql_table import SqlTable
-from metricflow.protocols.sql_client import SqlClient
+from metricflow.protocols.sql_client import SqlClient, SqlEngine
 from metricflow.test.compare_df import assert_dataframes_equal
 from metricflow.test.fixtures.setup_fixtures import MetricFlowTestSessionState
 from metricflow.test.fixtures.table_fixtures import CONFIGURED_SOURCE_TABLE_SNAPSHOT_REPOSITORY
@@ -61,6 +61,7 @@ def test_validate_data_in_source_schema(
         assert_dataframes_equal(
             actual=actual_table_df,
             expected=expected_table_df,
+            compare_names_using_lowercase=sql_client.sql_engine_type is SqlEngine.SNOWFLAKE,
         )
     except Exception as e:
         error_message = (
diff --git a/metricflow/test/table_snapshot/test_table_snapshots.py b/metricflow/test/table_snapshot/test_table_snapshots.py
index 9e84be9da9..d28e48a961 100644
--- a/metricflow/test/table_snapshot/test_table_snapshots.py
+++ b/metricflow/test/table_snapshot/test_table_snapshots.py
@@ -7,7 +7,7 @@
 import pytest
 from dbt_semantic_interfaces.test_utils import as_datetime

-from metricflow.protocols.sql_client import SqlClient
+from metricflow.protocols.sql_client import SqlClient, SqlEngine
 from metricflow.random_id import random_id
 from metricflow.test.compare_df import assert_dataframes_equal
 from metricflow.test.fixtures.setup_fixtures import MetricFlowTestSessionState
@@ -68,7 +68,11 @@ def test_restore(
         snapshot_restorer.restore(table_snapshot)

         actual = sql_client.query(f"SELECT * FROM {schema_name}.{table_snapshot.table_name}")
-        assert_dataframes_equal(actual=actual, expected=table_snapshot.as_df)
+        assert_dataframes_equal(
+            actual=actual,
+            expected=table_snapshot.as_df,
+            compare_names_using_lowercase=sql_client.sql_engine_type is SqlEngine.SNOWFLAKE,
+        )
     finally:
         sql_client.drop_schema(schema_name, cascade=True)
diff --git a/pyproject.toml b/pyproject.toml
index 0b38d6a439..13e853ad7d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -81,6 +81,10 @@ dev-packages = [
   "types-python-dateutil",
 ]

+dbt-snowflake = [
+  "dbt-snowflake>=1.6.0b3",
+]
+
 [tool.hatch.build.targets.sdist]
 exclude = [
   "/.github",
@@ -119,6 +123,20 @@ features = [
   "dev-packages",
 ]

+[tool.hatch.envs.snowflake-env.env-vars]
+  MF_TEST_ADAPTER_TYPE="snowflake"
+  # Note - the snowflake URL and password should be set via environment secrets
+  # in the calling process
+
+[tool.hatch.envs.snowflake-env]
+description = "Dev environment for working with Snowflake adapter"
+# Install the dbt snowflake package as a pre-install extra, just as with postgres
+features = [
+  "dev-packages",
+  "dbt-snowflake",
+]
+
+
 # Many deprecation warnings come from 3rd-party libraries and make the
 # output of pytest noisy. Since no action is going to be taken, hide those
 # warnings.