Commit

[DEV-3656]: Pass snowflake connector to feast
Chekanin committed Dec 18, 2024
1 parent 08b0b61 commit 3c72945
Showing 14 changed files with 100 additions and 183 deletions.
69 changes: 22 additions & 47 deletions sdk/python/feast/infra/offline_stores/snowflake.py
@@ -82,44 +82,10 @@ class SnowflakeOfflineStoreConfig(FeastConfigBaseModel):
type: Literal["snowflake.offline"] = "snowflake.offline"
""" Offline store type selector """

config_path: Optional[str] = os.path.expanduser("~/.snowsql/config")
""" Snowflake snowsql config path -- absolute path required (Cant use ~)"""
snowflake_connection: Any

connection_name: Optional[str] = None
""" Snowflake connector connection name -- typically defined in ~/.snowflake/connections.toml """

account: Optional[str] = None
""" Snowflake deployment identifier -- drop .snowflakecomputing.com """

user: Optional[str] = None
""" Snowflake user name """

password: Optional[str] = None
""" Snowflake password """

role: Optional[str] = None
""" Snowflake role name """

warehouse: Optional[str] = None
""" Snowflake warehouse name """

authenticator: Optional[str] = None
""" Snowflake authenticator name """

private_key: Optional[str] = None
""" Snowflake private key file path"""

private_key_content: Optional[bytes] = None
""" Snowflake private key stored as bytes"""

private_key_passphrase: Optional[str] = None
""" Snowflake private key file passphrase"""

database: StrictStr
""" Snowflake database name """

schema_: Optional[str] = Field("PUBLIC", alias="schema")
""" Snowflake schema name """
temp_intermediate_schema: str
""" Target schema of temporary intermediate tables for historical requests """

storage_integration_name: Optional[str] = None
""" Storage integration name in snowflake """
@@ -279,6 +245,7 @@ def get_historical_features(
assert isinstance(config.offline_store, SnowflakeOfflineStoreConfig)
for fv in feature_views:
assert isinstance(fv.batch_source, SnowflakeSource)
temp_intermediate_schema = config.offline_store.temp_intermediate_schema

with GetSnowflakeConnection(config.offline_store) as conn:
snowflake_conn = conn
@@ -299,7 +266,12 @@ def query_generator() -> Iterator[str]:
def query_generator() -> Iterator[str]:
table_name = offline_utils.get_temp_entity_table_name()

_upload_entity_df(entity_df, snowflake_conn, config, table_name)
_upload_entity_df(
entity_df=entity_df,
snowflake_conn=snowflake_conn,
schema=temp_intermediate_schema,
table_name=table_name,
)

expected_join_keys = offline_utils.get_expected_join_keys(
project, feature_views, registry
@@ -323,7 +295,7 @@ def query_generator() -> Iterator[str]:
# Generate the Snowflake SQL query from the query context
query = offline_utils.build_point_in_time_query(
query_context,
left_table_query_string=table_name,
left_table_query_string=f"{temp_intermediate_schema}.{table_name}",
entity_df_event_timestamp_col=entity_df_event_timestamp_col,
entity_df_columns=entity_schema.keys(),
query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN,
@@ -646,24 +618,25 @@ def _get_entity_schema(
def _upload_entity_df(
entity_df: Union[pd.DataFrame, str],
snowflake_conn: SnowflakeConnection,
config: RepoConfig,
schema: str,
table_name: str,
) -> None:
if isinstance(entity_df, pd.DataFrame):
# Write the data from the DataFrame to the table
# Known issues with following entity data types: BINARY
write_pandas(
snowflake_conn,
entity_df,
table_name,
conn=snowflake_conn,
df=entity_df,
schema=schema,
table_name=table_name,
auto_create_table=True,
create_temp_table=True,
)

return None
elif isinstance(entity_df, str):
# If the entity_df is a string (SQL query), create a Snowflake table out of it,
query = f'CREATE TEMPORARY TABLE "{table_name}" AS ({entity_df})'
query = f'CREATE TEMPORARY TABLE {schema}.{table_name} AS ({entity_df})'
execute_snowflake_statement(snowflake_conn, query)

return None
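
For context, a condensed, self-contained sketch of what the rewritten upload helper does for the two supported entity_df types, assuming an open SnowflakeConnection and a target schema that already exists (the function name here is illustrative, not part of the commit):

    import pandas as pd
    from snowflake.connector import SnowflakeConnection
    from snowflake.connector.pandas_tools import write_pandas

    def upload_entity_df_sketch(entity_df, conn: SnowflakeConnection, schema: str, table_name: str) -> None:
        if isinstance(entity_df, pd.DataFrame):
            # DataFrame path: let the connector create a temporary table in the target schema.
            write_pandas(
                conn=conn,
                df=entity_df,
                table_name=table_name,
                schema=schema,
                auto_create_table=True,
                create_temp_table=True,
            )
        elif isinstance(entity_df, str):
            # SQL path: materialize the query's result as a temporary table in the same schema.
            conn.cursor().execute(f"CREATE TEMPORARY TABLE {schema}.{table_name} AS ({entity_df})")
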
@@ -714,6 +687,7 @@ def _get_entity_df_event_timestamp_range(
return entity_df_event_timestamp_range



MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN = """
/*
Compute a deterministic hash for the `left_table_query_string` that will be used throughout
@@ -728,13 +702,14 @@ def _get_entity_df_event_timestamp_range(
{% for entity in featureview.entities %}
CAST("{{entity}}" AS VARCHAR) ||
{% endfor %}
CAST("{{entity_df_event_timestamp_col}}" AS VARCHAR)
CAST("{{entity_df_event_timestamp_col}}" AS VARCHAR) ||
UUID_STRING()
) AS "{{featureview.name}}__entity_row_unique_id"
{% else %}
,CAST("{{entity_df_event_timestamp_col}}" AS VARCHAR) AS "{{featureview.name}}__entity_row_unique_id"
{% endif %}
{% endfor %}
FROM "{{ left_table_query_string }}"
FROM {{ left_table_query_string }}
),
{% for featureview in featureviews %}
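
The FROM clause loses its surrounding double quotes because the left table reference is now a schema-qualified string rather than a bare table identifier. Roughly, the value interpolated into the template looks like this (the schema name and the generated suffix are hypothetical):

    from feast.infra.offline_stores import offline_utils

    temp_intermediate_schema = "FEAST_TMP"                    # placeholder
    table_name = offline_utils.get_temp_entity_table_name()   # e.g. "feast_entity_df_1a2b3c"
    left_table_query_string = f"{temp_intermediate_schema}.{table_name}"
    # Rendered into the template as: FROM FEAST_TMP.feast_entity_df_1a2b3c
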
@@ -871,7 +846,7 @@ def _get_entity_df_event_timestamp_range(
SELECT "{{ final_output_feature_names | join('", "')}}"
FROM "entity_dataframe"
{% for featureview in featureviews %}
LEFT JOIN (
INNER JOIN (
SELECT
"{{featureview.name}}__entity_row_unique_id"
{% for feature in featureview.features %}
1 change: 0 additions & 1 deletion sdk/python/feast/infra/registry/s3.py
@@ -28,7 +28,6 @@ class S3RegistryConfig(RegistryConfig):
registry_type: StrictStr = "s3"
s3_resource: typing.Any


@field_validator("s3_resource")
@classmethod
def validate_s3_resource(cls, value: typing.Any) -> typing.Any:
48 changes: 12 additions & 36 deletions sdk/python/feast/infra/registry/snowflake.py
@@ -7,7 +7,7 @@
from threading import Lock
from typing import Any, Callable, List, Literal, Optional, Union, cast

from pydantic import ConfigDict, Field, StrictStr
from pydantic import ConfigDict, field_validator

import feast
from feast.base_feature_view import BaseFeatureView
@@ -79,45 +79,21 @@ class SnowflakeRegistryConfig(RegistryConfig):
type: Literal["snowflake.registry"] = "snowflake.registry"
""" Registry type selector """

config_path: Optional[str] = os.path.expanduser("~/.snowsql/config")
""" Snowflake snowsql config path -- absolute path required (Cant use ~)"""
snowflake_connection: Any

connection_name: Optional[str] = None
""" Snowflake connector connection name -- typically defined in ~/.snowflake/connections.toml """
temp_intermediate_schema: str
""" Target schema of temporary intermediate tables for historical requests """

account: Optional[str] = None
""" Snowflake deployment identifier -- drop .snowflakecomputing.com """

user: Optional[str] = None
""" Snowflake user name """

password: Optional[str] = None
""" Snowflake password """

role: Optional[str] = None
""" Snowflake role name """

warehouse: Optional[str] = None
""" Snowflake warehouse name """

authenticator: Optional[str] = None
""" Snowflake authenticator name """

private_key: Optional[str] = None
""" Snowflake private key file path"""

private_key_content: Optional[bytes] = None
""" Snowflake private key stored as bytes"""

private_key_passphrase: Optional[str] = None
""" Snowflake private key file passphrase"""
model_config = ConfigDict(populate_by_name=True)

database: StrictStr
""" Snowflake database name """
@field_validator("snowflake_connection")
@classmethod
def validate_snowflake_connection(cls, value: Any) -> Any:
from snowflake.connector import SnowflakeConnection

schema_: Optional[str] = Field("PUBLIC", alias="schema")
""" Snowflake schema name """
model_config = ConfigDict(populate_by_name=True)
if not isinstance(value, SnowflakeConnection):
raise ValueError("s3_resource must be an instance of boto3.resources.base.ServiceResource")
return value


class SnowflakeRegistry(BaseRegistry):
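
A quick illustration of the new validator when the registry config is built programmatically; the connection parameters are placeholders and the remaining RegistryConfig fields are assumed to keep their defaults:

    import snowflake.connector
    from feast.infra.registry.snowflake import SnowflakeRegistryConfig

    conn = snowflake.connector.connect(account="my_account", user="my_user", password="***")

    # Accepted: the value is a genuine SnowflakeConnection instance.
    registry_config = SnowflakeRegistryConfig(
        snowflake_connection=conn,
        temp_intermediate_schema="FEAST_TMP",
    )

    # Rejected: pydantic raises a ValidationError because a plain string is not a SnowflakeConnection.
    SnowflakeRegistryConfig(
        snowflake_connection="not-a-connection",
        temp_intermediate_schema="FEAST_TMP",
    )
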
55 changes: 2 additions & 53 deletions sdk/python/feast/infra/utils/snowflake/snowflake_utils.py
@@ -47,66 +47,15 @@ def __init__(self, config: Any, autocommit=True):
self.config = config
self.autocommit = autocommit

def __enter__(self):
def __enter__(self) -> SnowflakeConnection:
assert self.config.type in [
"snowflake.registry",
"snowflake.offline",
"snowflake.engine",
"snowflake.online",
]

if self.config.type not in _cache:
if self.config.type == "snowflake.registry":
config_header = "connections.feast_registry"
elif self.config.type == "snowflake.offline":
config_header = "connections.feast_offline_store"
if self.config.type == "snowflake.engine":
config_header = "connections.feast_batch_engine"
elif self.config.type == "snowflake.online":
config_header = "connections.feast_online_store"

config_dict = dict(self.config)

# read config file
config_reader = configparser.ConfigParser()
config_reader.read([config_dict["config_path"]])
kwargs: Dict[str, Any] = {}
if config_reader.has_section(config_header):
kwargs = dict(config_reader[config_header])

kwargs.update((k, v) for k, v in config_dict.items() if v is not None)

for k, v in kwargs.items():
if k in ["role", "warehouse", "database", "schema_"]:
kwargs[k] = f'"{v}"'

kwargs["schema"] = kwargs.pop("schema_")

# https://docs.snowflake.com/en/user-guide/python-connector-example.html#using-key-pair-authentication-key-pair-rotation
# https://docs.snowflake.com/en/user-guide/key-pair-auth.html#configuring-key-pair-authentication
if "private_key" in kwargs or "private_key_content" in kwargs:
kwargs["private_key"] = parse_private_key_path(
kwargs.get("private_key_passphrase"),
kwargs.get("private_key"),
kwargs.get("private_key_content"),
)

try:
_cache[self.config.type] = snowflake.connector.connect(
application="feast",
client_session_keep_alive=True,
autocommit=self.autocommit,
**kwargs,
)
_cache[self.config.type].cursor().execute(
"ALTER SESSION SET TIMEZONE = 'UTC'", _is_internal=True
)

except KeyError as e:
raise SnowflakeIncompleteConfig(e)

self.client = _cache[self.config.type]
return self.client
return self.config.snowflake_connection

def __exit__(self, exc_type, exc_val, exc_tb):
pass
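
Since __enter__ now simply returns the connection stored on the config, call sites keep the same shape as before. A small usage sketch, assuming repo_config is a parsed RepoConfig whose offline store is the Snowflake config shown above (the query is illustrative):

    from feast.infra.utils.snowflake.snowflake_utils import GetSnowflakeConnection

    with GetSnowflakeConnection(repo_config.offline_store) as conn:
        cur = conn.cursor()
        cur.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
        print(cur.fetchone())
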
18 changes: 9 additions & 9 deletions sdk/python/requirements/py3.10-ci-requirements.txt
@@ -1,6 +1,6 @@
# This file was autogenerated by uv via the following command:
# uv pip compile -p 3.10 --system --no-strip-extras setup.py --extra ci --output-file sdk/python/requirements/py3.10-ci-requirements.txt
aiobotocore==2.15.2
aiobotocore==2.16.0
# via feast (setup.py)
aiohappyeyeballs==2.4.4
# via aiohttp
@@ -69,19 +69,19 @@ bigtree==0.22.3
# via feast (setup.py)
bleach==6.2.0
# via nbconvert
boto3==1.35.36
boto3==1.35.81
# via
# feast (setup.py)
# moto
boto3-stubs[essential, s3]==1.35.82
boto3-stubs[essential, s3]==1.35.83
# via feast (setup.py)
botocore==1.35.36
botocore==1.35.81
# via
# aiobotocore
# boto3
# moto
# s3transfer
botocore-stubs==1.35.82
botocore-stubs==1.35.83
# via boto3-stubs
build==1.2.2.post1
# via
@@ -150,11 +150,11 @@ cryptography==42.0.8
# types-redis
cython==3.0.11
# via thriftpy2
dask[dataframe]==2024.12.0
dask[dataframe]==2024.12.1
# via
# feast (setup.py)
# dask-expr
dask-expr==1.1.20
dask-expr==1.1.21
# via dask
db-dtypes==1.3.1
# via google-cloud-bigquery
@@ -408,7 +408,7 @@ jupyter-core==5.7.2
# nbclient
# nbconvert
# nbformat
jupyter-events==0.10.0
jupyter-events==0.11.0
# via jupyter-server
jupyter-lsp==2.2.5
# via jupyterlab
@@ -490,7 +490,7 @@ mypy-boto3-lambda==1.35.68
# via boto3-stubs
mypy-boto3-rds==1.35.82
# via boto3-stubs
mypy-boto3-s3==1.35.76.post1
mypy-boto3-s3==1.35.81
# via boto3-stubs
mypy-boto3-sqs==1.35.0
# via boto3-stubs
4 changes: 2 additions & 2 deletions sdk/python/requirements/py3.10-requirements.txt
@@ -25,11 +25,11 @@ cloudpickle==3.1.0
# via dask
colorama==0.4.6
# via feast (setup.py)
dask[dataframe]==2024.12.0
dask[dataframe]==2024.12.1
# via
# feast (setup.py)
# dask-expr
dask-expr==1.1.20
dask-expr==1.1.21
# via dask
dill==0.3.9
# via feast (setup.py)
