
Merge pull request #631 from dbt-labs/cut-snowflake-to-dbt-adapters
tlento authored Jun 28, 2023
2 parents 8acdd22 + 30ec98a commit 9e3412e
Showing 16 changed files with 155 additions and 284 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20230627-234808.yaml
@@ -0,0 +1,6 @@
kind: Features
body: Enable Snowflake queries in dbt <-> MetricFlow CLI integration
time: 2023-06-27T23:48:08.287586-07:00
custom:
  Author: tlento
  Issue: "579"
7 changes: 1 addition & 6 deletions .github/workflows/cd-sql-engine-tests.yaml
@@ -21,7 +21,6 @@ jobs:
    name: Snowflake Tests
    runs-on: ubuntu-latest
    steps:

      - name: Check-out the repo
        uses: actions/checkout@v3

@@ -33,14 +32,14 @@ jobs:
          mf_sql_engine_password: ${{ secrets.MF_SNOWFLAKE_PWD }}
          parallelism: ${{ env.EXTERNAL_ENGINE_TEST_PARALLELISM }}
          additional-pytest-options: ${{ env.ADDITIONAL_PYTEST_OPTIONS }}
          make-target: "test-snowflake"

  redshift-tests:
    environment: DW_INTEGRATION_TESTS
    name: Redshift Tests
    if: ${{ github.event.action != 'labeled' || github.event.label.name == 'run_mf_sql_engine_tests' }}
    runs-on: ubuntu-latest
    steps:

      - name: Check-out the repo
        uses: actions/checkout@v3

@@ -59,7 +58,6 @@ jobs:
    if: ${{ github.event.action != 'labeled' || github.event.label.name == 'run_mf_sql_engine_tests' }}
    runs-on: ubuntu-latest
    steps:

      - name: Check-out the repo
        uses: actions/checkout@v3

@@ -78,7 +76,6 @@ jobs:
    if: ${{ github.event.action != 'labeled' || github.event.label.name == 'run_mf_sql_engine_tests' }}
    runs-on: ubuntu-latest
    steps:

      - name: Check-out the repo
        uses: actions/checkout@v3

@@ -97,7 +94,6 @@ jobs:
    if: ${{ github.event.action != 'labeled' || github.event.label.name == 'run_mf_sql_engine_tests' }}
    runs-on: ubuntu-latest
    steps:

      - name: Check-out the repo
        uses: actions/checkout@v3

@@ -116,7 +112,6 @@ jobs:
    if: ${{ github.event_name != 'pull_request' && failure() }}
    runs-on: ubuntu-latest
    steps:

      - uses: actions/checkout@v3

      - name: Slack Failure
4 changes: 4 additions & 0 deletions Makefile
@@ -30,6 +30,10 @@ test:
test-postgresql:
	hatch -v run postgres-env:pytest -vv -n $(PARALLELISM) $(ADDITIONAL_PYTEST_OPTIONS) metricflow/test/

.PHONY: test-snowflake
test-snowflake:
	hatch -v run snowflake-env:pytest -vv -n $(PARALLELISM) $(ADDITIONAL_PYTEST_OPTIONS) metricflow/test/

.PHONY: lint
lint:
	hatch -v run dev-env:pre-commit run --all-files
58 changes: 49 additions & 9 deletions metricflow/cli/dbt_connectors/adapter_backed_client.py
@@ -1,12 +1,14 @@
from __future__ import annotations

import enum
import logging
import textwrap
import time
from typing import Optional, Sequence

import pandas as pd
from dbt.adapters.base.impl import BaseAdapter
from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.pretty_print import pformat_big_objects

from metricflow.dataflow.sql_table import SqlTable
@@ -16,13 +18,41 @@
from metricflow.protocols.sql_request import SqlJsonTag, SqlRequestId, SqlRequestTagSet
from metricflow.random_id import random_id
from metricflow.sql.render.postgres import PostgresSQLSqlQueryPlanRenderer
from metricflow.sql.render.snowflake import SnowflakeSqlQueryPlanRenderer
from metricflow.sql.render.sql_plan_renderer import SqlQueryPlanRenderer
from metricflow.sql.sql_bind_parameters import SqlBindParameters
from metricflow.sql_clients.sql_statement_metadata import CombinedSqlTags, SqlStatementCommentMetadata

logger = logging.getLogger(__name__)


class SupportedAdapterTypes(enum.Enum):
    """Enumeration of supported dbt adapter types."""

    POSTGRES = "postgres"
    SNOWFLAKE = "snowflake"

    @property
    def sql_engine_type(self) -> SqlEngine:
        """Return the SqlEngine corresponding to the supported adapter type."""
        if self is SupportedAdapterTypes.POSTGRES:
            return SqlEngine.POSTGRES
        elif self is SupportedAdapterTypes.SNOWFLAKE:
            return SqlEngine.SNOWFLAKE
        else:
            assert_values_exhausted(self)

    @property
    def sql_query_plan_renderer(self) -> SqlQueryPlanRenderer:
        """Return the SqlQueryPlanRenderer corresponding to the supported adapter type."""
        if self is SupportedAdapterTypes.POSTGRES:
            return PostgresSQLSqlQueryPlanRenderer()
        elif self is SupportedAdapterTypes.SNOWFLAKE:
            return SnowflakeSqlQueryPlanRenderer()
        else:
            assert_values_exhausted(self)


class AdapterBackedSqlClient:
    """SqlClient implementation which delegates database operations to a dbt BaseAdapter instance.
@@ -38,15 +68,18 @@ def __init__(self, adapter: BaseAdapter):
        The dbt BaseAdapter should already be fully initialized, including all credential verification, and
        ready for use for establishing connections and issuing queries.
        """
-        if adapter.type() != "postgres":
-            raise ValueError(
-                f"Received dbt adapter with unsupported type {adapter.type()}, but we only support postgres!"
-            )
         self._adapter = adapter
-        # TODO: normalize from adapter.type()
-        self._sql_engine_type = SqlEngine.POSTGRES
-        # TODO: create factory based on SqlEngine type
-        self._sql_query_plan_renderer = PostgresSQLSqlQueryPlanRenderer()
+        try:
+            adapter_type = SupportedAdapterTypes(self._adapter.type())
+        except ValueError as e:
+            raise ValueError(
+                f"Adapter type {self._adapter.type()} is not supported. Must be one "
+                f"of {[item.value for item in SupportedAdapterTypes]}."
+            ) from e
+
+        self._sql_engine_type = adapter_type.sql_engine_type
+        self._sql_query_plan_renderer = adapter_type.sql_query_plan_renderer
+        logger.info(f"Initialized AdapterBackedSqlClient with dbt adapter type `{adapter_type.value}`")

    @property
    def sql_engine_type(self) -> SqlEngine:
@@ -246,6 +279,11 @@ def _get_type_from_pandas_dtype(self, dtype: str) -> str:
    def list_tables(self, schema_name: str) -> Sequence[str]:
        """Get a list of the table names in a given schema. Only used in tutorials and tests."""
        # TODO: Short term, make this work with as many engines as possible. Medium term, remove this altogether.
        if self.sql_engine_type is SqlEngine.SNOWFLAKE:
            # Snowflake likes capitalizing things, except when it doesn't. We can get away with this due to
            # this method's limited scope of usage.
            schema_name = schema_name.upper()

        df = self.query(
            textwrap.dedent(
                f"""\
@@ -257,7 +295,9 @@ def list_tables(self, schema_name: str) -> Sequence[str]:
        if df.empty:
            return []

        # Lower casing table names and data frame column names for consistency between Snowflake and other clients.
        # As above, we can do this because it isn't used in any consequential situations.
        df.columns = df.columns.str.lower()
        return [t.lower() for t in df["table_name"]]

    def table_exists(self, sql_table: SqlTable) -> bool:
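
The SupportedAdapterTypes enum above ties each adapter name string to its engine and renderer, and assert_values_exhausted turns any unhandled member into a type-checking error. Below is a minimal standalone sketch of that dispatch pattern, not part of this commit: the Renderer class and the local assert_values_exhausted helper are stand-ins for the real MetricFlow types.

from __future__ import annotations

import enum
from typing import NoReturn


def assert_values_exhausted(value: NoReturn) -> NoReturn:
    """Local stand-in for the helper imported from dbt_semantic_interfaces."""
    raise AssertionError(f"Unhandled enum value: {value}")


class Renderer:
    """Stand-in for an engine-specific SqlQueryPlanRenderer."""

    def __init__(self, dialect: str) -> None:
        self.dialect = dialect


class AdapterType(enum.Enum):
    POSTGRES = "postgres"
    SNOWFLAKE = "snowflake"

    @property
    def renderer(self) -> Renderer:
        # One branch per member; the else arm only type-checks if every member
        # is handled, so adding a new engine forces this dispatch to be updated.
        if self is AdapterType.POSTGRES:
            return Renderer("postgres")
        elif self is AdapterType.SNOWFLAKE:
            return Renderer("snowflake")
        else:
            assert_values_exhausted(self)


# Constructing the enum from an adapter's type string raises ValueError for
# unsupported adapters, which the client catches and re-raises with a message
# listing the supported values.
print(AdapterType("snowflake").renderer.dialect)  # snowflake
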
6 changes: 6 additions & 0 deletions metricflow/engine/metricflow_engine.py
@@ -574,6 +574,12 @@ def get_dimension_values( # noqa: D
        result_dataframe = query_result.result_df
        if result_dataframe is None:
            return []
        if get_group_by_values not in result_dataframe.columns:
            # Snowflake likes upper-casing things in result output, so we lower-case all names and
            # see if we can get the value from there.
            get_group_by_values = get_group_by_values.lower()
            result_dataframe.columns = result_dataframe.columns.str.lower()

        return [str(val) for val in result_dataframe[get_group_by_values]]

    @log_call(module_name=__name__, telemetry_reporter=_telemetry_reporter)
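
The fallback above reflects how Snowflake folds unquoted identifiers to upper case in result sets. A self-contained sketch of the same retry-with-lower-casing lookup follows; the frame contents and column names are made up for illustration.

import pandas as pd


def get_column_values(df: pd.DataFrame, column_name: str) -> list:
    # Try the requested column first; if it is missing (e.g. Snowflake returned
    # METRIC_TIME for metric_time), lower-case both sides and retry.
    if column_name not in df.columns:
        column_name = column_name.lower()
        df.columns = df.columns.str.lower()
    return [str(val) for val in df[column_name]]


# A Snowflake-style result frame whose column name came back upper-cased.
frame = pd.DataFrame({"METRIC_TIME": ["2023-06-01", "2023-06-02"]})
print(get_column_values(frame, "metric_time"))  # ['2023-06-01', '2023-06-02']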