From 8b63d85b1ceba87779bc464cebc0b9f81287c58d Mon Sep 17 00:00:00 2001 From: tlento Date: Tue, 27 Jun 2023 23:30:00 -0700 Subject: [PATCH 1/3] Move Snowflake queries onto the dbt AdapterBackedSqlClient It's time to enable Snowflake queries for our first class integration with dbt projects. This commit generalizes the initialization of the AdapterBackedSqlClient to allow for easy expansion to other supported adapter types, and demonstrates the capability by enabling support for Snowflake. Due to Snowflake's insistence on storing database objects with fully uppercase identifiers by default, and making queries against them case sensitive..... sometimes? Maybe? .... we need to update a lot of test assertion checks and a few of our accessors. The only production impact is on the list_dimensions method, which renders a column in a result dataframe that it selects by name. In the longer term we likely need a better story around identifier escaping for case sensitivity in Snowflake, but this appears to work. This commit was also tested via CLI queries against our Snowflake instance, and everything works as expected. --- .../unreleased/Features-20230627-234808.yaml | 6 ++ .github/workflows/cd-sql-engine-tests.yaml | 7 +-- Makefile | 4 ++ .../dbt_connectors/adapter_backed_client.py | 58 ++++++++++++++++--- metricflow/engine/metricflow_engine.py | 6 ++ metricflow/test/cli/test_cli.py | 5 +- metricflow/test/compare_df.py | 14 ++++- metricflow/test/execution/test_tasks.py | 4 +- .../metricflow_testing/profiles.yml | 12 ++++ .../test/fixtures/sql_client_fixtures.py | 10 +++- .../test/integration/test_write_to_table.py | 7 ++- .../test/sql_clients/test_sql_client.py | 16 ++++- .../test/table_snapshot/test_source_schema.py | 3 +- .../table_snapshot/test_table_snapshots.py | 8 ++- pyproject.toml | 18 ++++++ 15 files changed, 151 insertions(+), 27 deletions(-) create mode 100644 .changes/unreleased/Features-20230627-234808.yaml diff --git a/.changes/unreleased/Features-20230627-234808.yaml b/.changes/unreleased/Features-20230627-234808.yaml new file mode 100644 index 0000000000..63b488095f --- /dev/null +++ b/.changes/unreleased/Features-20230627-234808.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Enable Snowflake queries in dbt <-> MetricFlow CLI integration +time: 2023-06-27T23:48:08.287586-07:00 +custom: + Author: tlento + Issue: "579" diff --git a/.github/workflows/cd-sql-engine-tests.yaml b/.github/workflows/cd-sql-engine-tests.yaml index 480f21b866..a7f25a2bdc 100644 --- a/.github/workflows/cd-sql-engine-tests.yaml +++ b/.github/workflows/cd-sql-engine-tests.yaml @@ -21,7 +21,6 @@ jobs: name: Snowflake Tests runs-on: ubuntu-latest steps: - - name: Check-out the repo uses: actions/checkout@v3 @@ -33,6 +32,7 @@ jobs: mf_sql_engine_password: ${{ secrets.MF_SNOWFLAKE_PWD }} parallelism: ${{ env.EXTERNAL_ENGINE_TEST_PARALLELISM }} additional-pytest-options: ${{ env.ADDITIONAL_PYTEST_OPTIONS }} + make-target: "test-snowflake" redshift-tests: environment: DW_INTEGRATION_TESTS @@ -40,7 +40,6 @@ jobs: if: ${{ github.event.action != 'labeled' || github.event.label.name == 'run_mf_sql_engine_tests' }} runs-on: ubuntu-latest steps: - - name: Check-out the repo uses: actions/checkout@v3 @@ -59,7 +58,6 @@ jobs: if: ${{ github.event.action != 'labeled' || github.event.label.name == 'run_mf_sql_engine_tests' }} runs-on: ubuntu-latest steps: - - name: Check-out the repo uses: actions/checkout@v3 @@ -78,7 +76,6 @@ jobs: if: ${{ github.event.action != 'labeled' || github.event.label.name == 'run_mf_sql_engine_tests' }} 
runs-on: ubuntu-latest steps: - - name: Check-out the repo uses: actions/checkout@v3 @@ -97,7 +94,6 @@ jobs: if: ${{ github.event.action != 'labeled' || github.event.label.name == 'run_mf_sql_engine_tests' }} runs-on: ubuntu-latest steps: - - name: Check-out the repo uses: actions/checkout@v3 @@ -116,7 +112,6 @@ jobs: if: ${{ github.event_name != 'pull_request' && failure() }} runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - name: Slack Failure diff --git a/Makefile b/Makefile index 2f894ebea4..235ecb0630 100644 --- a/Makefile +++ b/Makefile @@ -30,6 +30,10 @@ test: test-postgresql: hatch -v run postgres-env:pytest -vv -n $(PARALLELISM) $(ADDITIONAL_PYTEST_OPTIONS) metricflow/test/ +.PHONY: test-snowflake +test-snowflake: + hatch -v run snowflake-env:pytest -vv -n $(PARALLELISM) $(ADDITIONAL_PYTEST_OPTIONS) metricflow/test/ + .PHONY: lint lint: hatch -v run dev-env:pre-commit run --all-files diff --git a/metricflow/cli/dbt_connectors/adapter_backed_client.py b/metricflow/cli/dbt_connectors/adapter_backed_client.py index b74d0b9d71..0676745920 100644 --- a/metricflow/cli/dbt_connectors/adapter_backed_client.py +++ b/metricflow/cli/dbt_connectors/adapter_backed_client.py @@ -1,5 +1,6 @@ from __future__ import annotations +import enum import logging import textwrap import time @@ -7,6 +8,7 @@ import pandas as pd from dbt.adapters.base.impl import BaseAdapter +from dbt_semantic_interfaces.enum_extension import assert_values_exhausted from dbt_semantic_interfaces.pretty_print import pformat_big_objects from metricflow.dataflow.sql_table import SqlTable @@ -16,6 +18,7 @@ from metricflow.protocols.sql_request import SqlJsonTag, SqlRequestId, SqlRequestTagSet from metricflow.random_id import random_id from metricflow.sql.render.postgres import PostgresSQLSqlQueryPlanRenderer +from metricflow.sql.render.snowflake import SnowflakeSqlQueryPlanRenderer from metricflow.sql.render.sql_plan_renderer import SqlQueryPlanRenderer from metricflow.sql.sql_bind_parameters import SqlBindParameters from metricflow.sql_clients.sql_statement_metadata import CombinedSqlTags, SqlStatementCommentMetadata @@ -23,6 +26,33 @@ logger = logging.getLogger(__name__) +class SupportedAdapterTypes(enum.Enum): + """Enumeration of supported dbt adapter types.""" + + POSTGRES = "postgres" + SNOWFLAKE = "snowflake" + + @property + def sql_engine_type(self) -> SqlEngine: + """Return the SqlEngine corresponding to the supported adapter type.""" + if self is SupportedAdapterTypes.POSTGRES: + return SqlEngine.POSTGRES + elif self is SupportedAdapterTypes.SNOWFLAKE: + return SqlEngine.SNOWFLAKE + else: + assert_values_exhausted(self) + + @property + def sql_query_plan_renderer(self) -> SqlQueryPlanRenderer: + """Return the SqlQueryPlanRenderer corresponding to the supported adapter type.""" + if self is SupportedAdapterTypes.POSTGRES: + return PostgresSQLSqlQueryPlanRenderer() + elif self is SupportedAdapterTypes.SNOWFLAKE: + return SnowflakeSqlQueryPlanRenderer() + else: + assert_values_exhausted(self) + + class AdapterBackedSqlClient: """SqlClient implementation which delegates database operations to a dbt BaseAdapter instance. @@ -38,15 +68,18 @@ def __init__(self, adapter: BaseAdapter): The dbt BaseAdapter should already be fully initialized, including all credential verification, and ready for use for establishing connections and issuing queries. """ - if adapter.type() != "postgres": - raise ValueError( - f"Received dbt adapter with unsupported type {adapter.type()}, but we only support postgres!" 
- ) self._adapter = adapter - # TODO: normalize from adapter.type() - self._sql_engine_type = SqlEngine.POSTGRES - # TODO: create factory based on SqlEngine type - self._sql_query_plan_renderer = PostgresSQLSqlQueryPlanRenderer() + try: + adapter_type = SupportedAdapterTypes(self._adapter.type()) + except ValueError as e: + raise ValueError( + f"Adapter type {self._adapter.type()} is not supported. Must be one " + f"of {[item.value for item in SupportedAdapterTypes]}." + ) from e + + self._sql_engine_type = adapter_type.sql_engine_type + self._sql_query_plan_renderer = adapter_type.sql_query_plan_renderer + logger.info(f"Initialized AdapterBackedSqlClient with dbt adapter type `{adapter_type.value}`") @property def sql_engine_type(self) -> SqlEngine: @@ -246,6 +279,11 @@ def _get_type_from_pandas_dtype(self, dtype: str) -> str: def list_tables(self, schema_name: str) -> Sequence[str]: """Get a list of the table names in a given schema. Only used in tutorials and tests.""" # TODO: Short term, make this work with as many engines as possible. Medium term, remove this altogether. + if self.sql_engine_type is SqlEngine.SNOWFLAKE: + # Snowflake likes capitalizing things, except when it doesn't. We can get away with this due to its + # limited scope of usage. + schema_name = schema_name.upper() + df = self.query( textwrap.dedent( f"""\ @@ -257,7 +295,9 @@ def list_tables(self, schema_name: str) -> Sequence[str]: if df.empty: return [] - # Lower casing table names for consistency between Snowflake and other clients. + # Lower casing table names and data frame names for consistency between Snowflake and other clients. + # As above, we can do this because it isn't used in any consequential situations. + df.columns = df.columns.str.lower() return [t.lower() for t in df["table_name"]] def table_exists(self, sql_table: SqlTable) -> bool: diff --git a/metricflow/engine/metricflow_engine.py b/metricflow/engine/metricflow_engine.py index 5a6e39a367..2fd65a0e4c 100644 --- a/metricflow/engine/metricflow_engine.py +++ b/metricflow/engine/metricflow_engine.py @@ -574,6 +574,12 @@ def get_dimension_values( # noqa: D result_dataframe = query_result.result_df if result_dataframe is None: return [] + if get_group_by_values not in result_dataframe.columns: + # Snowflake likes upper-casing things in result output, so we lower-case all names and + # see if we can get the value from there. 
+ get_group_by_values = get_group_by_values.lower() + result_dataframe.columns = result_dataframe.columns.str.lower() + return [str(val) for val in result_dataframe[get_group_by_values]] @log_call(module_name=__name__, telemetry_reporter=_telemetry_reporter) diff --git a/metricflow/test/cli/test_cli.py b/metricflow/test/cli/test_cli.py index 0a9540a6d4..54f49880f6 100644 --- a/metricflow/test/cli/test_cli.py +++ b/metricflow/test/cli/test_cli.py @@ -33,12 +33,15 @@ TIME_SPINE_TABLE, TRANSACTIONS_TABLE, ) +from metricflow.protocols.sql_client import SqlEngine from metricflow.test.fixtures.cli_fixtures import MetricFlowCliRunner def test_query(cli_runner: MetricFlowCliRunner) -> None: # noqa: D resp = cli_runner.run(query, args=["--metrics", "bookings", "--group-bys", "ds"]) - assert "bookings" in resp.output + # case insensitive matches are needed for snowflake due to the capitalization thing + engine_is_snowflake = cli_runner.cli_context.sql_client.sql_engine_type is SqlEngine.SNOWFLAKE + assert "bookings" in resp.output or ("bookings" in resp.output.lower() and engine_is_snowflake) assert resp.exit_code == 0 diff --git a/metricflow/test/compare_df.py b/metricflow/test/compare_df.py index e401430583..ce34fbd419 100644 --- a/metricflow/test/compare_df.py +++ b/metricflow/test/compare_df.py @@ -43,12 +43,24 @@ def _dataframes_contain_same_data( def assert_dataframes_equal( - actual: pd.DataFrame, expected: pd.DataFrame, sort_columns: bool = True, allow_empty: bool = False + actual: pd.DataFrame, + expected: pd.DataFrame, + sort_columns: bool = True, + allow_empty: bool = False, + compare_names_using_lowercase: bool = False, ) -> None: """Check that contents of DataFrames are the same. If sort_columns is set to false, value and column order needs to be the same. + If compare_names_using_lowercase is set to True, we copy the dataframes and lower-case their names. + This is useful for Snowflake query output comparisons. """ + if compare_names_using_lowercase: + actual = actual.copy() + expected = expected.copy() + actual.columns = actual.columns.str.lower() + expected.columns = expected.columns.str.lower() + if set(actual.columns) != set(expected.columns): raise ValueError( f"DataFrames do not contain the same columns. 
actual: {set(actual.columns)}, " diff --git a/metricflow/test/execution/test_tasks.py b/metricflow/test/execution/test_tasks.py index dda868d7f6..6c90330c15 100644 --- a/metricflow/test/execution/test_tasks.py +++ b/metricflow/test/execution/test_tasks.py @@ -9,7 +9,7 @@ SelectSqlQueryToTableTask, ) from metricflow.execution.executor import SequentialPlanExecutor -from metricflow.protocols.sql_client import SqlClient +from metricflow.protocols.sql_client import SqlClient, SqlEngine from metricflow.random_id import random_id from metricflow.sql.sql_bind_parameters import SqlBindParameters from metricflow.test.compare_df import assert_dataframes_equal @@ -32,6 +32,7 @@ def test_read_sql_task(sql_client: SqlClient) -> None: # noqa: D columns=["foo"], data=[(1,)], ), + compare_names_using_lowercase=sql_client.sql_engine_type is SqlEngine.SNOWFLAKE, ) @@ -56,5 +57,6 @@ def test_write_table_task(mf_test_session_state: MetricFlowTestSessionState, sql columns=["foo"], data=[(1,)], ), + compare_names_using_lowercase=sql_client.sql_engine_type is SqlEngine.SNOWFLAKE, ) sql_client.drop_table(output_table) diff --git a/metricflow/test/fixtures/dbt_projects/metricflow_testing/profiles.yml b/metricflow/test/fixtures/dbt_projects/metricflow_testing/profiles.yml index 6e44e30c4e..90223cf818 100644 --- a/metricflow/test/fixtures/dbt_projects/metricflow_testing/profiles.yml +++ b/metricflow/test/fixtures/dbt_projects/metricflow_testing/profiles.yml @@ -9,3 +9,15 @@ postgres: pass: "{{ env_var('MF_SQL_ENGINE_PASSWORD') }}" dbname: "{{ env_var('MF_SQL_ENGINE_DATABASE') }}" schema: "{{ env_var('MF_SQL_ENGINE_SCHEMA') }}" +snowflake: + target: dev + outputs: + dev: + type: snowflake + # The snowflake account is equivalent to the host value in the SqlAlchemy parsed URL + account: "{{ env_var('MF_SQL_ENGINE_HOST') }}" + user: "{{ env_var('MF_SQL_ENGINE_USER') }}" + password: "{{ env_var('MF_SQL_ENGINE_PASSWORD') }}" + warehouse: "{{ env_var('MF_SQL_ENGINE_WAREHOUSE') }}" + database: "{{ env_var('MF_SQL_ENGINE_DATABASE') }}" + schema: "{{ env_var('MF_SQL_ENGINE_SCHEMA') }}" diff --git a/metricflow/test/fixtures/sql_client_fixtures.py b/metricflow/test/fixtures/sql_client_fixtures.py index 12443b720b..59ba13421c 100644 --- a/metricflow/test/fixtures/sql_client_fixtures.py +++ b/metricflow/test/fixtures/sql_client_fixtures.py @@ -16,7 +16,6 @@ from metricflow.sql_clients.databricks import DatabricksSqlClient from metricflow.sql_clients.duckdb import DuckDbSqlClient from metricflow.sql_clients.redshift import RedshiftSqlClient -from metricflow.sql_clients.snowflake import SnowflakeSqlClient from metricflow.test.fixtures.setup_fixtures import MetricFlowTestSessionState, dialect_from_url logger = logging.getLogger(__name__) @@ -27,6 +26,7 @@ MF_SQL_ENGINE_PORT = "MF_SQL_ENGINE_PORT" MF_SQL_ENGINE_USER = "MF_SQL_ENGINE_USER" MF_SQL_ENGINE_SCHEMA = "MF_SQL_ENGINE_SCHEMA" +MF_SQL_ENGINE_WAREHOUSE = "MF_SQL_ENGINE_WAREHOUSE" def configure_test_env_from_url(url: str, schema: str) -> sqlalchemy.engine.URL: @@ -74,7 +74,13 @@ def make_test_sql_client(url: str, password: str, schema: str) -> SqlClient: if dialect == SqlDialect.REDSHIFT: return RedshiftSqlClient.from_connection_details(url, password) elif dialect == SqlDialect.SNOWFLAKE: - return SnowflakeSqlClient.from_connection_details(url, password) + parsed_url = configure_test_env_from_url(url, schema) + assert "warehouse" in parsed_url.normalized_query, "Sql engine URL params did not include Snowflake warehouse!" 
+ warehouses = parsed_url.normalized_query["warehouse"] + assert len(warehouses) == 1, f"Found more than 1 warehouse in Snowflake URL: `{warehouses}`" + os.environ[MF_SQL_ENGINE_WAREHOUSE] = warehouses[0] + __initialize_dbt() + return AdapterBackedSqlClient(adapter=get_adapter_by_type("snowflake")) elif dialect == SqlDialect.BIGQUERY: return BigQuerySqlClient.from_connection_details(url, password) elif dialect == SqlDialect.POSTGRESQL: diff --git a/metricflow/test/integration/test_write_to_table.py b/metricflow/test/integration/test_write_to_table.py index 4497c7c363..42a9258476 100644 --- a/metricflow/test/integration/test_write_to_table.py +++ b/metricflow/test/integration/test_write_to_table.py @@ -4,6 +4,7 @@ from metricflow.dataflow.sql_table import SqlTable from metricflow.engine.metricflow_engine import MetricFlowQueryRequest +from metricflow.protocols.sql_client import SqlEngine from metricflow.random_id import random_id from metricflow.test.compare_df import assert_dataframes_equal from metricflow.test.integration.conftest import IntegrationTestHelpers @@ -38,6 +39,10 @@ def test_write_to_table(it_helpers: IntegrationTestHelpers) -> None: # noqa: D """ ) ) - assert_dataframes_equal(actual=actual, expected=expected) + assert_dataframes_equal( + actual=actual, + expected=expected, + compare_names_using_lowercase=it_helpers.sql_client.sql_engine_type is SqlEngine.SNOWFLAKE, + ) finally: it_helpers.sql_client.drop_table(output_table) diff --git a/metricflow/test/sql_clients/test_sql_client.py b/metricflow/test/sql_clients/test_sql_client.py index 11fd6309b6..7b2cf1817a 100644 --- a/metricflow/test/sql_clients/test_sql_client.py +++ b/metricflow/test/sql_clients/test_sql_client.py @@ -26,7 +26,13 @@ def _select_x_as_y(x: int = 1, y: str = "y") -> str: # noqa: D return f"SELECT {x} AS {y}" -def _check_1col(df: pd.DataFrame, col: str = "y", vals: Set[Union[int, str]] = {1}) -> None: # noqa: D +def _check_1col(df: pd.DataFrame, col: str = "y", vals: Set[Union[int, str]] = {1}) -> None: + """Helper to check that 1 column has the same value and a case-insensitive matching name. + + We lower-case the names due to snowflake's tendency to capitalize things. This isn't ideal but it'll do for now. + """ + df.columns = df.columns.str.lower() + col = col.lower() assert isinstance(df, pd.DataFrame) assert df.shape == (len(vals), 1) assert df.columns.tolist() == [col] @@ -116,7 +122,11 @@ def test_create_table_from_dataframe( # noqa: D sql_client.create_table_from_dataframe(sql_table=sql_table, df=expected_df) actual_df = sql_client.query(f"SELECT * FROM {sql_table.sql}") - assert_dataframes_equal(actual=actual_df, expected=expected_df) + assert_dataframes_equal( + actual=actual_df, + expected=expected_df, + compare_names_using_lowercase=sql_client.sql_engine_type is SqlEngine.SNOWFLAKE, + ) def test_table_exists(mf_test_session_state: MetricFlowTestSessionState, sql_client: SqlClient) -> None: # noqa: D @@ -159,7 +169,7 @@ def test_dry_run_of_bad_query_raises_exception(sql_client: SqlClient) -> None: bad_stmt = "SELECT bad_col" # Tests that a bad query raises an exception. Different engines may raise different exceptions e.g. # ProgrammingError, OperationalError, google.api_core.exceptions.BadRequest, etc. 
- with pytest.raises(Exception, match=r"bad_col"): + with pytest.raises(Exception, match=r"(?i)bad_col"): sql_client.dry_run(bad_stmt) diff --git a/metricflow/test/table_snapshot/test_source_schema.py b/metricflow/test/table_snapshot/test_source_schema.py index aa6073f6eb..933b1f2ef9 100644 --- a/metricflow/test/table_snapshot/test_source_schema.py +++ b/metricflow/test/table_snapshot/test_source_schema.py @@ -6,7 +6,7 @@ import pytest from metricflow.dataflow.sql_table import SqlTable -from metricflow.protocols.sql_client import SqlClient +from metricflow.protocols.sql_client import SqlClient, SqlEngine from metricflow.test.compare_df import assert_dataframes_equal from metricflow.test.fixtures.setup_fixtures import MetricFlowTestSessionState from metricflow.test.fixtures.table_fixtures import CONFIGURED_SOURCE_TABLE_SNAPSHOT_REPOSITORY @@ -61,6 +61,7 @@ def test_validate_data_in_source_schema( assert_dataframes_equal( actual=actual_table_df, expected=expected_table_df, + compare_names_using_lowercase=sql_client.sql_engine_type is SqlEngine.SNOWFLAKE, ) except Exception as e: error_message = ( diff --git a/metricflow/test/table_snapshot/test_table_snapshots.py b/metricflow/test/table_snapshot/test_table_snapshots.py index 9e84be9da9..d28e48a961 100644 --- a/metricflow/test/table_snapshot/test_table_snapshots.py +++ b/metricflow/test/table_snapshot/test_table_snapshots.py @@ -7,7 +7,7 @@ import pytest from dbt_semantic_interfaces.test_utils import as_datetime -from metricflow.protocols.sql_client import SqlClient +from metricflow.protocols.sql_client import SqlClient, SqlEngine from metricflow.random_id import random_id from metricflow.test.compare_df import assert_dataframes_equal from metricflow.test.fixtures.setup_fixtures import MetricFlowTestSessionState @@ -68,7 +68,11 @@ def test_restore( snapshot_restorer.restore(table_snapshot) actual = sql_client.query(f"SELECT * FROM {schema_name}.{table_snapshot.table_name}") - assert_dataframes_equal(actual=actual, expected=table_snapshot.as_df) + assert_dataframes_equal( + actual=actual, + expected=table_snapshot.as_df, + compare_names_using_lowercase=sql_client.sql_engine_type is SqlEngine.SNOWFLAKE, + ) finally: sql_client.drop_schema(schema_name, cascade=True) diff --git a/pyproject.toml b/pyproject.toml index 0b38d6a439..13e853ad7d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,6 +81,10 @@ dev-packages = [ "types-python-dateutil", ] +dbt-snowflake = [ + "dbt-snowflake>=1.6.0b3", +] + [tool.hatch.build.targets.sdist] exclude = [ "/.github", @@ -119,6 +123,20 @@ features = [ "dev-packages", ] +[tool.hatch.envs.snowflake-env.env-vars] + MF_TEST_ADAPTER_TYPE="snowflake" + # Note - the snowflake URL and password should be set via environment secrets + # in the calling process + +[tool.hatch.envs.snowflake-env] +description = "Dev environment for working with Snowflake adapter" +# Install the dbt snowflake package as a pre-install extra, just as with postgres +features = [ + "dev-packages", + "dbt-snowflake", +] + + # Many deprecation warnings come from 3rd-party libraries and make the # output of pytest noisy. Since no action is going to be taken, hide those # warnings. From 0e17bbca86779b1ad19007fc58a5f778b253aed3 Mon Sep 17 00:00:00 2001 From: tlento Date: Tue, 27 Jun 2023 23:51:34 -0700 Subject: [PATCH 2/3] Align postgres hatch environment config with Snowflake Due to an incompatibility between dbt-snowflake 1.5 and 1.6, we had to specify the version. 
Given that versioning was necessary, the benefit of using pre-install was limited. As such, we make the relevant adapters environment-specific dependencies, and bring postgres in line. This will hopefully get streamlined into the bundle package in the future, but for now we can put up with manual version management on these adapter environments. --- pyproject.toml | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 13e853ad7d..25e4a42082 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,6 +81,10 @@ dev-packages = [ "types-python-dateutil", ] +dbt-postgres = [ + "dbt-postgres>=1.6.0.b6", +] + dbt-snowflake = [ "dbt-snowflake>=1.6.0b3", ] @@ -110,17 +114,9 @@ features = [ [tool.hatch.envs.postgres-env] description = "Dev environment for working with Postgres adapter" -# Install the dbt postgres package as a pre-install extra -# This helps us avoid having to do version pegs, although it relies on pip's -# current "pretend version conflicts aren't a problem" behavior for existing -# installations. In future, this will be updated to use an editable dependency -# on the dbt-metricflow bundle, which will allow for local version management -# for testing. -pre-install-commands = [ - "pip install dbt-postgres", -] features = [ "dev-packages", + "dbt-postgres", ] [tool.hatch.envs.snowflake-env.env-vars] From 30ec98a745f5c81e6f9639cb9d8636b5f8236262 Mon Sep 17 00:00:00 2001 From: tlento Date: Tue, 27 Jun 2023 23:55:03 -0700 Subject: [PATCH 3/3] Remove unused SQLAlchemy-backed Snowflake client This allows us to remove the direct dependencies on Snowflake's python connector and SQLAlchemy libraries. --- metricflow/sql_clients/snowflake.py | 247 ---------------------------- pyproject.toml | 2 - 2 files changed, 249 deletions(-) delete mode 100644 metricflow/sql_clients/snowflake.py diff --git a/metricflow/sql_clients/snowflake.py b/metricflow/sql_clients/snowflake.py deleted file mode 100644 index a4dff5a466..0000000000 --- a/metricflow/sql_clients/snowflake.py +++ /dev/null @@ -1,247 +0,0 @@ -from __future__ import annotations - -import json -import logging -import textwrap -import threading -import urllib.parse -from collections import OrderedDict -from contextlib import contextmanager -from typing import Dict, Iterator, Optional, Sequence, Set - -import pandas as pd -import sqlalchemy -from sqlalchemy.exc import ProgrammingError -from typing_extensions import override - -from metricflow.protocols.sql_client import SqlEngine -from metricflow.protocols.sql_request import ( - MF_EXTRA_TAGS_KEY, - MF_SYSTEM_TAGS_KEY, - JsonDict, - SqlJsonTag, - SqlRequestTagSet, -) -from metricflow.sql.render.snowflake import SnowflakeSqlQueryPlanRenderer -from metricflow.sql.render.sql_plan_renderer import SqlQueryPlanRenderer -from metricflow.sql.sql_bind_parameters import SqlBindParameters -from metricflow.sql_clients.common_client import SqlDialect, not_empty -from metricflow.sql_clients.sqlalchemy_dialect import SqlAlchemySqlClient - -logger = logging.getLogger(__name__) - - -class SnowflakeSqlClient(SqlAlchemySqlClient): - """Client for Snowflake. - - Note: By default, Snowflake uses uppercase for schema, table, and column - names. To create or access them as lowercase, you must use double quotes. - - It's also tricky trying to get tests / queries on Snowflake working with - https://docs.snowflake.com/en/sql-reference/parameters.html#quoted-identifiers-ignore-case enabled. 
- For example, when listing table names, all tables would be upper case with that setting (causing an issue where - semantic models would constantly be primed because the table names didn't match). - """ - - DEFAULT_LOGIN_TIMEOUT = 60 - DEFAULT_CLIENT_SESSION_KEEP_ALIVE = True - - @staticmethod - def _parse_url_query_params(url: str) -> Dict[str, str]: - """Gets the warehouse from the query parameters in the URL, throwing an exception if not set properly.""" - url_query_params: Dict[str, str] = {} - - parsed_url = urllib.parse.urlparse(url) - query_dict = urllib.parse.parse_qs(parsed_url.query) - - if "warehouse" not in query_dict: - raise ValueError(f"Missing warehouse in URL query: {url}") - - if len(query_dict["warehouse"]) > 1: - raise ValueError(f"Multiple warehouses in URL query: {url}") - - url_query_params["warehouse"] = query_dict["warehouse"][0] - - # optionally, role - if "role" not in query_dict: - return url_query_params - - if len(query_dict["role"]) > 1: - raise ValueError(f"Multiple roles in URL query: {url}") - - url_query_params["role"] = query_dict["role"][0] - return url_query_params - - @staticmethod - def from_connection_details(url: str, password: Optional[str]) -> SnowflakeSqlClient: # noqa: D - parsed_url = sqlalchemy.engine.make_url(url) - if parsed_url.drivername != SqlDialect.SNOWFLAKE.value: - raise ValueError(f"Invalid dialect in URL for Snowflake: {url}") - - if parsed_url.port: - raise ValueError(f"Snowflake URL should not have a port set: {url}") - - if not password: - raise ValueError(f"Password not supplied for {url}") - - SqlAlchemySqlClient.validate_query_params( - url=parsed_url, required_parameters={"warehouse"}, optional_parameters={"role"} - ) - - return SnowflakeSqlClient( - host=not_empty(parsed_url.host, "host", url), - username=not_empty(parsed_url.username, "username", url), - password=password, - database=not_empty(parsed_url.database, "database", url), - url_query_params=SnowflakeSqlClient._parse_url_query_params(url), - ) - - def __init__( # noqa: D - self, - database: str, - username: str, - password: str, - host: str, - url_query_params: Dict[str, str], - login_timeout: int = DEFAULT_LOGIN_TIMEOUT, - client_session_keep_alive: bool = DEFAULT_CLIENT_SESSION_KEEP_ALIVE, - ) -> None: - self._connection_url = SqlAlchemySqlClient.build_engine_url( - dialect=SqlDialect.SNOWFLAKE.value, - username=username, - password=password, - host=host, - database=database, - query=url_query_params, - ) - self._engine_lock = threading.Lock() - self._known_sessions_ids_lock = threading.Lock() - self._known_session_ids: Set[int] = set() - super().__init__( - engine=self._create_engine(login_timeout=login_timeout, client_session_keep_alive=client_session_keep_alive) - ) - - def _create_engine( - self, - login_timeout: int = DEFAULT_LOGIN_TIMEOUT, - client_session_keep_alive: bool = DEFAULT_CLIENT_SESSION_KEEP_ALIVE, - ) -> sqlalchemy.engine.Engine: # noqa: D - return sqlalchemy.create_engine( - self._connection_url, - pool_size=10, - max_overflow=10, - pool_pre_ping=False, - connect_args={"client_session_keep_alive": client_session_keep_alive, "login_timeout": login_timeout}, - ) - - @property - @override - def sql_engine_type(self) -> SqlEngine: - return SqlEngine.SNOWFLAKE - - @property - @override - def sql_query_plan_renderer(self) -> SqlQueryPlanRenderer: - return SnowflakeSqlQueryPlanRenderer() - - @contextmanager - def _engine_connection( - self, - engine: sqlalchemy.engine.Engine, - system_tags: SqlRequestTagSet = SqlRequestTagSet(), - extra_tags: 
SqlJsonTag = SqlJsonTag(), - ) -> Iterator[sqlalchemy.engine.Connection]: - """Context Manager for providing a configured connection. - - Snowflake allows setting a WEEK_START parameter on each session. This forces the value to be - 1, which means Monday. Future updates could parameterize this to read from some kind of - options construct, which the DBClient could read in at initialization and use here (for example). - At this time we hard-code the ISO standard. - """ - with super()._engine_connection(self._engine) as conn: - # WEEK_START 1 means Monday. - conn.execute("ALTER SESSION SET WEEK_START = 1;") - combined_tags: JsonDict = OrderedDict() - if system_tags.tag_dict: - combined_tags[MF_SYSTEM_TAGS_KEY] = system_tags.tag_dict - if extra_tags is not None: - combined_tags[MF_EXTRA_TAGS_KEY] = extra_tags.json_dict - - if combined_tags: - conn.execute( - sqlalchemy.text("ALTER SESSION SET QUERY_TAG = :query_tag"), - query_tag=json.dumps(combined_tags), - ) - results = conn.execute("SELECT CURRENT_SESSION()") - sessions = [] - for row in results: - sessions.append(row[0]) - assert len(sessions) == 1 - session = sessions[0] - with self._known_sessions_ids_lock: - self._known_session_ids.add(session) - yield conn - with self._known_sessions_ids_lock: - self._known_session_ids.remove(session) - - def _query( # noqa: D - self, - stmt: str, - bind_params: SqlBindParameters = SqlBindParameters(), - allow_re_auth: bool = True, - system_tags: SqlRequestTagSet = SqlRequestTagSet(), - extra_tags: SqlJsonTag = SqlJsonTag(), - ) -> pd.DataFrame: - with self._engine_connection(engine=self._engine, system_tags=system_tags, extra_tags=extra_tags) as conn: - try: - return pd.read_sql_query(sqlalchemy.text(stmt), conn, params=bind_params.param_dict) - except ProgrammingError as e: - if "Authentication token has expired" in str(e) and allow_re_auth: - logger.warning( - "Snowflake authentication token expired. Attempting to re-auth, then we'll re-run the query" - ) - with self._engine_lock: - self._engine.dispose() - self._engine = self._create_engine() - # this was our one chance to re-auth - return self._query(stmt, allow_re_auth=False, bind_params=bind_params) - raise e - - def _engine_specific_query_implementation( - self, - stmt: str, - bind_params: SqlBindParameters, - system_tags: SqlRequestTagSet = SqlRequestTagSet(), - extra_tags: SqlJsonTag = SqlJsonTag(), - ) -> pd.DataFrame: - return self._query( - stmt, - bind_params=bind_params, - system_tags=system_tags, - extra_tags=extra_tags, - ) - - def list_tables(self, schema_name: str) -> Sequence[str]: - """List tables using 'information_schema' instead of SHOW TABLES to sidestep 10K row limit. - - TODO: This and the previous implementation could have issues if Snowflake is configured with case-sensitivity. - """ - df = self.query( - textwrap.dedent( - """\ - SELECT table_name FROM information_schema.tables - WHERE table_schema = :schema_name - """ - ), - sql_bind_parameters=SqlBindParameters.create_from_dict({"schema_name": schema_name.upper()}), - ) - if df.empty: - return [] - - # Lower casing table names to be similar to other SQL clients. TBD on the implications of this. 
- return [t.lower() for t in df["table_name"]] - - def close(self) -> None: - """Snowflake will hang pytest if this is not done.""" - with self._engine_lock: - self._engine.dispose() diff --git a/pyproject.toml b/pyproject.toml index 25e4a42082..0bad587a19 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,8 +49,6 @@ dependencies = [ "requests~=2.27.1", "ruamel.yaml~=0.17.21", "rudder-sdk-python~=1.0.3", - "snowflake-connector-python>=2.7.8", - "snowflake-sqlalchemy~=1.4.3", "sqlalchemy-bigquery~=1.6.1", "sqlalchemy-redshift==0.8.1", "sqlalchemy2-stubs~=0.0.2a21",
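
---

A minimal sketch of the adapter-type dispatch introduced in PATCH 1/3, for reference. This assumes a Python shell where metricflow is importable; constructing a real AdapterBackedSqlClient additionally requires a fully initialized dbt adapter (the test fixtures obtain one via get_adapter_by_type after dbt project setup), so only the enum lookup is shown here:

    from metricflow.cli.dbt_connectors.adapter_backed_client import SupportedAdapterTypes
    from metricflow.protocols.sql_client import SqlEngine

    # SupportedAdapterTypes maps a dbt adapter type string onto the MetricFlow
    # engine and SQL query plan renderer for that engine.
    adapter_type = SupportedAdapterTypes("snowflake")
    assert adapter_type.sql_engine_type is SqlEngine.SNOWFLAKE
    print(type(adapter_type.sql_query_plan_renderer).__name__)  # SnowflakeSqlQueryPlanRenderer

    # AdapterBackedSqlClient(adapter) performs the same lookup on adapter.type()
    # and raises ValueError for any adapter type outside SupportedAdapterTypes.

End to end, the Snowflake-backed test suite added here runs through the new make test-snowflake target (wrapping the hatch snowflake-env environment), which is what the cd-sql-engine-tests workflow now invokes via its make-target input.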