Move Snowflake queries onto the dbt AdapterBackedSqlClient
It's time to enable Snowflake queries for our first-class integration
with dbt projects.

This commit generalizes the initialization of the AdapterBackedSqlClient
to allow for easy expansion to other supported adapter types, and
demonstrates the capability by enabling support for Snowflake.

Snowflake insists on storing database objects with fully uppercase
identifiers by default, and makes queries against them case-sensitive
in some contexts but not others, so we need to update a lot of test
assertion checks and a few of our accessors.

The only production impact is on the list_dimensions method, which
selects a column by name from a result dataframe and renders it. In
the longer term we likely need a better story around identifier
escaping for case sensitivity in Snowflake, but this approach appears
to work.
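
As a rough sketch of the fallback (the column name here is illustrative
only, not the real accessor code):

    import pandas as pd

    # Snowflake may hand back upper-cased column labels, so when the
    # requested name is missing we lower-case both the name and the
    # dataframe's labels before selecting.
    df = pd.DataFrame({"DIM_NAME": ["a", "b"]})  # Snowflake-style label
    requested = "dim_name"
    if requested not in df.columns:
        requested = requested.lower()
        df.columns = df.columns.str.lower()
    assert [str(v) for v in df[requested]] == ["a", "b"]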

This commit was also tested via CLI queries against our Snowflake
instance, and everything works as expected.
tlento committed Jun 28, 2023
1 parent ee92470 commit 8b63d85
Showing 15 changed files with 151 additions and 27 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20230627-234808.yaml
@@ -0,0 +1,6 @@
kind: Features
body: Enable Snowflake queries in dbt <-> MetricFlow CLI integration
time: 2023-06-27T23:48:08.287586-07:00
custom:
Author: tlento
Issue: "579"
7 changes: 1 addition & 6 deletions .github/workflows/cd-sql-engine-tests.yaml
@@ -21,7 +21,6 @@ jobs:
name: Snowflake Tests
runs-on: ubuntu-latest
steps:

- name: Check-out the repo
uses: actions/checkout@v3

@@ -33,14 +32,14 @@ jobs:
mf_sql_engine_password: ${{ secrets.MF_SNOWFLAKE_PWD }}
parallelism: ${{ env.EXTERNAL_ENGINE_TEST_PARALLELISM }}
additional-pytest-options: ${{ env.ADDITIONAL_PYTEST_OPTIONS }}
make-target: "test-snowflake"

redshift-tests:
environment: DW_INTEGRATION_TESTS
name: Redshift Tests
if: ${{ github.event.action != 'labeled' || github.event.label.name == 'run_mf_sql_engine_tests' }}
runs-on: ubuntu-latest
steps:

- name: Check-out the repo
uses: actions/checkout@v3

@@ -59,7 +58,6 @@ jobs:
if: ${{ github.event.action != 'labeled' || github.event.label.name == 'run_mf_sql_engine_tests' }}
runs-on: ubuntu-latest
steps:

- name: Check-out the repo
uses: actions/checkout@v3

@@ -78,7 +76,6 @@ jobs:
if: ${{ github.event.action != 'labeled' || github.event.label.name == 'run_mf_sql_engine_tests' }}
runs-on: ubuntu-latest
steps:

- name: Check-out the repo
uses: actions/checkout@v3

@@ -97,7 +94,6 @@ jobs:
if: ${{ github.event.action != 'labeled' || github.event.label.name == 'run_mf_sql_engine_tests' }}
runs-on: ubuntu-latest
steps:

- name: Check-out the repo
uses: actions/checkout@v3

@@ -116,7 +112,6 @@ jobs:
if: ${{ github.event_name != 'pull_request' && failure() }}
runs-on: ubuntu-latest
steps:

- uses: actions/checkout@v3

- name: Slack Failure
4 changes: 4 additions & 0 deletions Makefile
@@ -30,6 +30,10 @@ test:
test-postgresql:
hatch -v run postgres-env:pytest -vv -n $(PARALLELISM) $(ADDITIONAL_PYTEST_OPTIONS) metricflow/test/

.PHONY: test-snowflake
test-snowflake:
hatch -v run snowflake-env:pytest -vv -n $(PARALLELISM) $(ADDITIONAL_PYTEST_OPTIONS) metricflow/test/

.PHONY: lint
lint:
hatch -v run dev-env:pre-commit run --all-files
58 changes: 49 additions & 9 deletions metricflow/cli/dbt_connectors/adapter_backed_client.py
@@ -1,12 +1,14 @@
from __future__ import annotations

import enum
import logging
import textwrap
import time
from typing import Optional, Sequence

import pandas as pd
from dbt.adapters.base.impl import BaseAdapter
from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.pretty_print import pformat_big_objects

from metricflow.dataflow.sql_table import SqlTable
@@ -16,13 +18,41 @@
from metricflow.protocols.sql_request import SqlJsonTag, SqlRequestId, SqlRequestTagSet
from metricflow.random_id import random_id
from metricflow.sql.render.postgres import PostgresSQLSqlQueryPlanRenderer
from metricflow.sql.render.snowflake import SnowflakeSqlQueryPlanRenderer
from metricflow.sql.render.sql_plan_renderer import SqlQueryPlanRenderer
from metricflow.sql.sql_bind_parameters import SqlBindParameters
from metricflow.sql_clients.sql_statement_metadata import CombinedSqlTags, SqlStatementCommentMetadata

logger = logging.getLogger(__name__)


class SupportedAdapterTypes(enum.Enum):
"""Enumeration of supported dbt adapter types."""

POSTGRES = "postgres"
SNOWFLAKE = "snowflake"

@property
def sql_engine_type(self) -> SqlEngine:
"""Return the SqlEngine corresponding to the supported adapter type."""
if self is SupportedAdapterTypes.POSTGRES:
return SqlEngine.POSTGRES
elif self is SupportedAdapterTypes.SNOWFLAKE:
return SqlEngine.SNOWFLAKE
else:
assert_values_exhausted(self)

@property
def sql_query_plan_renderer(self) -> SqlQueryPlanRenderer:
"""Return the SqlQueryPlanRenderer corresponding to the supported adapter type."""
if self is SupportedAdapterTypes.POSTGRES:
return PostgresSQLSqlQueryPlanRenderer()
elif self is SupportedAdapterTypes.SNOWFLAKE:
return SnowflakeSqlQueryPlanRenderer()
else:
assert_values_exhausted(self)


class AdapterBackedSqlClient:
"""SqlClient implementation which delegates database operations to a dbt BaseAdapter instance.
@@ -38,15 +68,18 @@ def __init__(self, adapter: BaseAdapter):
The dbt BaseAdapter should already be fully initialized, including all credential verification, and
ready for use for establishing connections and issuing queries.
"""
if adapter.type() != "postgres":
raise ValueError(
f"Received dbt adapter with unsupported type {adapter.type()}, but we only support postgres!"
)
self._adapter = adapter
# TODO: normalize from adapter.type()
self._sql_engine_type = SqlEngine.POSTGRES
# TODO: create factory based on SqlEngine type
self._sql_query_plan_renderer = PostgresSQLSqlQueryPlanRenderer()
try:
adapter_type = SupportedAdapterTypes(self._adapter.type())
except ValueError as e:
raise ValueError(
f"Adapter type {self._adapter.type()} is not supported. Must be one "
f"of {[item.value for item in SupportedAdapterTypes]}."
) from e

self._sql_engine_type = adapter_type.sql_engine_type
self._sql_query_plan_renderer = adapter_type.sql_query_plan_renderer
logger.info(f"Initialized AdapterBackedSqlClient with dbt adapter type `{adapter_type.value}`")

@property
def sql_engine_type(self) -> SqlEngine:
@@ -246,6 +279,11 @@ def _get_type_from_pandas_dtype(self, dtype: str) -> str:
def list_tables(self, schema_name: str) -> Sequence[str]:
"""Get a list of the table names in a given schema. Only used in tutorials and tests."""
# TODO: Short term, make this work with as many engines as possible. Medium term, remove this altogether.
if self.sql_engine_type is SqlEngine.SNOWFLAKE:
# Snowflake likes capitalizing things, except when it doesn't. We can get away with this due to its
# limited scope of usage.
schema_name = schema_name.upper()

df = self.query(
textwrap.dedent(
f"""\
@@ -257,7 +295,9 @@ def list_tables(self, schema_name: str) -> Sequence[str]:
if df.empty:
return []

# Lower casing table names for consistency between Snowflake and other clients.
# Lower casing table names and data frame names for consistency between Snowflake and other clients.
# As above, we can do this because it isn't used in any consequential situations.
df.columns = df.columns.str.lower()
return [t.lower() for t in df["table_name"]]

def table_exists(self, sql_table: SqlTable) -> bool:
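
A hedged usage sketch of the new initialization path (this assumes dbt
has already been bootstrapped so the adapter is registered with working
credentials, as the test fixtures below arrange):

    from dbt.adapters.factory import get_adapter_by_type

    from metricflow.cli.dbt_connectors.adapter_backed_client import AdapterBackedSqlClient
    from metricflow.protocols.sql_client import SqlEngine

    # The adapter must be fully initialized before wrapping it; any
    # adapter type outside SupportedAdapterTypes raises ValueError.
    adapter = get_adapter_by_type("snowflake")
    sql_client = AdapterBackedSqlClient(adapter=adapter)
    assert sql_client.sql_engine_type is SqlEngine.SNOWFLAKE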
6 changes: 6 additions & 0 deletions metricflow/engine/metricflow_engine.py
@@ -574,6 +574,12 @@ def get_dimension_values( # noqa: D
result_dataframe = query_result.result_df
if result_dataframe is None:
return []
if get_group_by_values not in result_dataframe.columns:
# Snowflake likes upper-casing things in result output, so we lower-case all names and
# see if we can get the value from there.
get_group_by_values = get_group_by_values.lower()
result_dataframe.columns = result_dataframe.columns.str.lower()

return [str(val) for val in result_dataframe[get_group_by_values]]

@log_call(module_name=__name__, telemetry_reporter=_telemetry_reporter)
5 changes: 4 additions & 1 deletion metricflow/test/cli/test_cli.py
@@ -33,12 +33,15 @@
TIME_SPINE_TABLE,
TRANSACTIONS_TABLE,
)
from metricflow.protocols.sql_client import SqlEngine
from metricflow.test.fixtures.cli_fixtures import MetricFlowCliRunner


def test_query(cli_runner: MetricFlowCliRunner) -> None: # noqa: D
resp = cli_runner.run(query, args=["--metrics", "bookings", "--group-bys", "ds"])
assert "bookings" in resp.output
# case insensitive matches are needed for snowflake due to the capitalization thing
engine_is_snowflake = cli_runner.cli_context.sql_client.sql_engine_type is SqlEngine.SNOWFLAKE
assert "bookings" in resp.output or ("bookings" in resp.output.lower() and engine_is_snowflake)
assert resp.exit_code == 0


14 changes: 13 additions & 1 deletion metricflow/test/compare_df.py
@@ -43,12 +43,24 @@ def _dataframes_contain_same_data(


def assert_dataframes_equal(
actual: pd.DataFrame, expected: pd.DataFrame, sort_columns: bool = True, allow_empty: bool = False
actual: pd.DataFrame,
expected: pd.DataFrame,
sort_columns: bool = True,
allow_empty: bool = False,
compare_names_using_lowercase: bool = False,
) -> None:
"""Check that contents of DataFrames are the same.
If sort_columns is set to false, value and column order needs to be the same.
If compare_names_using_lowercase is set to True, we copy the dataframes and lower-case their names.
This is useful for Snowflake query output comparisons.
"""
if compare_names_using_lowercase:
actual = actual.copy()
expected = expected.copy()
actual.columns = actual.columns.str.lower()
expected.columns = expected.columns.str.lower()

if set(actual.columns) != set(expected.columns):
raise ValueError(
f"DataFrames do not contain the same columns. actual: {set(actual.columns)}, "
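
A minimal usage sketch of the new flag (hypothetical one-column frames,
mirroring the test call sites below):

    import pandas as pd

    from metricflow.test.compare_df import assert_dataframes_equal

    actual = pd.DataFrame(columns=["FOO"], data=[(1,)])  # Snowflake-style casing
    expected = pd.DataFrame(columns=["foo"], data=[(1,)])
    assert_dataframes_equal(
        actual=actual,
        expected=expected,
        compare_names_using_lowercase=True,
    )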
4 changes: 3 additions & 1 deletion metricflow/test/execution/test_tasks.py
@@ -9,7 +9,7 @@
SelectSqlQueryToTableTask,
)
from metricflow.execution.executor import SequentialPlanExecutor
from metricflow.protocols.sql_client import SqlClient
from metricflow.protocols.sql_client import SqlClient, SqlEngine
from metricflow.random_id import random_id
from metricflow.sql.sql_bind_parameters import SqlBindParameters
from metricflow.test.compare_df import assert_dataframes_equal
@@ -32,6 +32,7 @@ def test_read_sql_task(sql_client: SqlClient) -> None: # noqa: D
columns=["foo"],
data=[(1,)],
),
compare_names_using_lowercase=sql_client.sql_engine_type is SqlEngine.SNOWFLAKE,
)


@@ -56,5 +57,6 @@ def test_write_table_task(mf_test_session_state: MetricFlowTestSessionState, sql
columns=["foo"],
data=[(1,)],
),
compare_names_using_lowercase=sql_client.sql_engine_type is SqlEngine.SNOWFLAKE,
)
sql_client.drop_table(output_table)
(dbt profiles file; file path not shown in this view)
@@ -9,3 +9,15 @@ postgres:
pass: "{{ env_var('MF_SQL_ENGINE_PASSWORD') }}"
dbname: "{{ env_var('MF_SQL_ENGINE_DATABASE') }}"
schema: "{{ env_var('MF_SQL_ENGINE_SCHEMA') }}"
snowflake:
target: dev
outputs:
dev:
type: snowflake
# The snowflake account is equivalent to the host value in the SqlAlchemy parsed URL
account: "{{ env_var('MF_SQL_ENGINE_HOST') }}"
user: "{{ env_var('MF_SQL_ENGINE_USER') }}"
password: "{{ env_var('MF_SQL_ENGINE_PASSWORD') }}"
warehouse: "{{ env_var('MF_SQL_ENGINE_WAREHOUSE') }}"
database: "{{ env_var('MF_SQL_ENGINE_DATABASE') }}"
schema: "{{ env_var('MF_SQL_ENGINE_SCHEMA') }}"
10 changes: 8 additions & 2 deletions metricflow/test/fixtures/sql_client_fixtures.py
@@ -16,7 +16,6 @@
from metricflow.sql_clients.databricks import DatabricksSqlClient
from metricflow.sql_clients.duckdb import DuckDbSqlClient
from metricflow.sql_clients.redshift import RedshiftSqlClient
from metricflow.sql_clients.snowflake import SnowflakeSqlClient
from metricflow.test.fixtures.setup_fixtures import MetricFlowTestSessionState, dialect_from_url

logger = logging.getLogger(__name__)
@@ -27,6 +26,7 @@
MF_SQL_ENGINE_PORT = "MF_SQL_ENGINE_PORT"
MF_SQL_ENGINE_USER = "MF_SQL_ENGINE_USER"
MF_SQL_ENGINE_SCHEMA = "MF_SQL_ENGINE_SCHEMA"
MF_SQL_ENGINE_WAREHOUSE = "MF_SQL_ENGINE_WAREHOUSE"


def configure_test_env_from_url(url: str, schema: str) -> sqlalchemy.engine.URL:
@@ -74,7 +74,13 @@ def make_test_sql_client(url: str, password: str, schema: str) -> SqlClient:
if dialect == SqlDialect.REDSHIFT:
return RedshiftSqlClient.from_connection_details(url, password)
elif dialect == SqlDialect.SNOWFLAKE:
return SnowflakeSqlClient.from_connection_details(url, password)
parsed_url = configure_test_env_from_url(url, schema)
assert "warehouse" in parsed_url.normalized_query, "Sql engine URL params did not include Snowflake warehouse!"
warehouses = parsed_url.normalized_query["warehouse"]
assert len(warehouses) == 1, f"Found more than 1 warehouse in Snowflake URL: `{warehouses}`"
os.environ[MF_SQL_ENGINE_WAREHOUSE] = warehouses[0]
__initialize_dbt()
return AdapterBackedSqlClient(adapter=get_adapter_by_type("snowflake"))
elif dialect == SqlDialect.BIGQUERY:
return BigQuerySqlClient.from_connection_details(url, password)
elif dialect == SqlDialect.POSTGRESQL:
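
For reference, a sketch of the URL shape the Snowflake branch expects
(placeholder values are hypothetical; the warehouse must appear exactly
once as a query parameter):

    import sqlalchemy

    url = sqlalchemy.engine.make_url(
        "snowflake://my_user@my_account/my_database?warehouse=my_warehouse"
    )
    warehouses = url.normalized_query["warehouse"]
    assert len(warehouses) == 1 and warehouses[0] == "my_warehouse"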
7 changes: 6 additions & 1 deletion metricflow/test/integration/test_write_to_table.py
@@ -4,6 +4,7 @@

from metricflow.dataflow.sql_table import SqlTable
from metricflow.engine.metricflow_engine import MetricFlowQueryRequest
from metricflow.protocols.sql_client import SqlEngine
from metricflow.random_id import random_id
from metricflow.test.compare_df import assert_dataframes_equal
from metricflow.test.integration.conftest import IntegrationTestHelpers
@@ -38,6 +39,10 @@ def test_write_to_table(it_helpers: IntegrationTestHelpers) -> None: # noqa: D
"""
)
)
assert_dataframes_equal(actual=actual, expected=expected)
assert_dataframes_equal(
actual=actual,
expected=expected,
compare_names_using_lowercase=it_helpers.sql_client.sql_engine_type is SqlEngine.SNOWFLAKE,
)
finally:
it_helpers.sql_client.drop_table(output_table)
16 changes: 13 additions & 3 deletions metricflow/test/sql_clients/test_sql_client.py
@@ -26,7 +26,13 @@ def _select_x_as_y(x: int = 1, y: str = "y") -> str: # noqa: D
return f"SELECT {x} AS {y}"


def _check_1col(df: pd.DataFrame, col: str = "y", vals: Set[Union[int, str]] = {1}) -> None: # noqa: D
def _check_1col(df: pd.DataFrame, col: str = "y", vals: Set[Union[int, str]] = {1}) -> None:
"""Helper to check that 1 column has the same value and a case-insensitive matching name.
We lower-case the names due to snowflake's tendency to capitalize things. This isn't ideal but it'll do for now.
"""
df.columns = df.columns.str.lower()
col = col.lower()
assert isinstance(df, pd.DataFrame)
assert df.shape == (len(vals), 1)
assert df.columns.tolist() == [col]
@@ -116,7 +122,11 @@ def test_create_table_from_dataframe( # noqa: D
sql_client.create_table_from_dataframe(sql_table=sql_table, df=expected_df)

actual_df = sql_client.query(f"SELECT * FROM {sql_table.sql}")
assert_dataframes_equal(actual=actual_df, expected=expected_df)
assert_dataframes_equal(
actual=actual_df,
expected=expected_df,
compare_names_using_lowercase=sql_client.sql_engine_type is SqlEngine.SNOWFLAKE,
)


def test_table_exists(mf_test_session_state: MetricFlowTestSessionState, sql_client: SqlClient) -> None: # noqa: D
@@ -159,7 +169,7 @@ def test_dry_run_of_bad_query_raises_exception(sql_client: SqlClient) -> None:
bad_stmt = "SELECT bad_col"
# Tests that a bad query raises an exception. Different engines may raise different exceptions e.g.
# ProgrammingError, OperationalError, google.api_core.exceptions.BadRequest, etc.
with pytest.raises(Exception, match=r"bad_col"):
with pytest.raises(Exception, match=r"(?i)bad_col"):
sql_client.dry_run(bad_stmt)


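
The (?i) flag makes the pytest.raises pattern case-insensitive (pytest's
match uses re.search); a quick illustration against a hypothetical
Snowflake-style error message:

    import re

    # Snowflake tends to report the missing identifier upper-cased.
    assert re.search(r"(?i)bad_col", "invalid identifier 'BAD_COL'")
    assert re.search(r"(?i)bad_col", "column bad_col does not exist")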
3 changes: 2 additions & 1 deletion metricflow/test/table_snapshot/test_source_schema.py
@@ -6,7 +6,7 @@
import pytest

from metricflow.dataflow.sql_table import SqlTable
from metricflow.protocols.sql_client import SqlClient
from metricflow.protocols.sql_client import SqlClient, SqlEngine
from metricflow.test.compare_df import assert_dataframes_equal
from metricflow.test.fixtures.setup_fixtures import MetricFlowTestSessionState
from metricflow.test.fixtures.table_fixtures import CONFIGURED_SOURCE_TABLE_SNAPSHOT_REPOSITORY
@@ -61,6 +61,7 @@ def test_validate_data_in_source_schema(
assert_dataframes_equal(
actual=actual_table_df,
expected=expected_table_df,
compare_names_using_lowercase=sql_client.sql_engine_type is SqlEngine.SNOWFLAKE,
)
except Exception as e:
error_message = (
