Skip to content

Commit

Permalink
Fix failure due to latest wave rc provider release (#1522)
Browse files Browse the repository at this point in the history
  • Loading branch information
Lee-W authored May 6, 2024
1 parent 72df85b commit 17b0fe1
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 20 deletions.
20 changes: 11 additions & 9 deletions .circleci/integration-tests/Dockerfile.astro_cloud
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,25 @@ ENV JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64/

RUN apt-get update -y \
&& apt-get install -y \
git \
unzip \
git \
unzip \
&& apt-get install -y --no-install-recommends \
build-essential \
libsasl2-2 \
libsasl2-dev \
libsasl2-modules \
build-essential \
libsasl2-2 \
libsasl2-dev \
libsasl2-modules \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get install -y curl gnupg \
&& echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list \
&& curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add - \
&& apt-get update -y \
&& apt-get install -y \
google-cloud-sdk \
google-cloud-sdk-gke-gcloud-auth-plugin \
jq
google-cloud-sdk \
google-cloud-sdk-gke-gcloud-auth-plugin \
jq \
pkg-config libxml2-dev libxmlsec1-dev libxmlsec1-openssl


# Set Hive and Hadoop versions.
ENV HIVE_LIBRARY_VERSION=hive-2.3.9
Expand Down
2 changes: 1 addition & 1 deletion astronomer/providers/amazon/aws/example_dags/example_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
S3KeysUnchangedSensorAsync,
)

S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME", "test-bucket-astronomer-providers")
S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME", "test-astronomer-providers-bucket")
S3_BUCKET_KEY = os.getenv("S3_BUCKET_KEY", "test")
S3_BUCKET_KEY_LIST = os.getenv("S3_BUCKET_KEY_LIST", "test2")
S3_BUCKET_WILDCARD_KEY = os.getenv("S3_BUCKET_WILDCARD_KEY", "test*")
Expand Down
14 changes: 11 additions & 3 deletions astronomer/providers/snowflake/hooks/snowflake_sql_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import uuid
import warnings
from datetime import timedelta
from functools import cached_property
from pathlib import Path
from typing import Any

Expand Down Expand Up @@ -76,6 +77,13 @@ def __init__(
super().__init__(snowflake_conn_id, *args, **kwargs)
self.private_key: Any = None

@cached_property
def _get_conn_params(self) -> dict[str, str | None]:
# for apache-airflow-providers-snowflake<5.5.0
if callable(super()._get_conn_params):
return super()._get_conn_params() # type: ignore[no-any-return,operator]
return super()._get_conn_params

def get_private_key(self) -> None:
"""Gets the private key from snowflake connection"""
conn = self.get_connection(self.snowflake_conn_id)
Expand Down Expand Up @@ -127,7 +135,7 @@ def execute_query(
When executing the statement, Snowflake replaces placeholders (? and :name) in
the statement with these specified values.
"""
conn_config = self._get_conn_params()
conn_config = self._get_conn_params

req_id = uuid.uuid4()
url = "https://{}.snowflakecomputing.com/api/v2/statements".format(conn_config["account"])
Expand Down Expand Up @@ -171,7 +179,7 @@ def get_headers(self) -> dict[str, Any]:
"""Based on the private key, and with connection details JWT Token is generated and header is formed"""
if not self.private_key:
self.get_private_key()
conn_config = self._get_conn_params()
conn_config = self._get_conn_params

# Get the JWT token from the connection details and the private key
token = JWTGenerator(
Expand All @@ -197,7 +205,7 @@ def get_request_url_header_params(self, query_id: str) -> tuple[dict[str, Any],
:param query_id: statement handles query ids for the individual statements.
"""
conn_config = self._get_conn_params()
conn_config = self._get_conn_params
req_id = uuid.uuid4()
header = self.get_headers()
params = {"requestId": str(req_id), "page": 2, "pageSize": 10}
Expand Down
11 changes: 6 additions & 5 deletions dev/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ FROM ${IMAGE_NAME}
USER root
RUN apt-get update -y && apt-get install -y git
RUN apt-get install -y --no-install-recommends \
build-essential \
libsasl2-2 \
libsasl2-dev \
libsasl2-modules \
jq
build-essential \
libsasl2-2 \
libsasl2-dev \
libsasl2-modules \
jq \
pkg-config libxml2-dev libxmlsec1-dev libxmlsec1-openssl

COPY setup.cfg ${AIRFLOW_HOME}/astronomer_providers/setup.cfg
COPY pyproject.toml ${AIRFLOW_HOME}/astronomer_providers/pyproject.toml
Expand Down
23 changes: 21 additions & 2 deletions tests/snowflake/hooks/test_snowflake_sql_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,8 @@ def test_check_query_output_exception(self, mock_geturl_header_params, mock_requ
assert airflow_exception

@mock.patch(
"astronomer.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHookAsync._get_conn_params"
"astronomer.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHookAsync._get_conn_params",
new_callable=mock.PropertyMock,
)
@mock.patch("astronomer.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHookAsync.get_headers")
def test_get_request_url_header_params(self, mock_get_header, mock_conn_param):
Expand All @@ -250,7 +251,8 @@ def test_get_request_url_header_params(self, mock_get_header, mock_conn_param):
"astronomer.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHookAsync.get_private_key"
)
@mock.patch(
"astronomer.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHookAsync._get_conn_params"
"astronomer.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHookAsync._get_conn_params",
new_callable=mock.PropertyMock,
)
@mock.patch("astronomer.providers.snowflake.hooks.sql_api_generate_jwt.JWTGenerator.get_token")
def test_get_headers(self, mock_get_token, mock_conn_param, mock_private_key):
Expand All @@ -261,6 +263,23 @@ def test_get_headers(self, mock_get_token, mock_conn_param, mock_private_key):
result = hook.get_headers()
assert result == HEADERS

@mock.patch(
"airflow.providers.snowflake.hooks.snowflake.SnowflakeHook._get_conn_params",
new_callable=mock.PropertyMock,
)
def test__get_conn_params__with_property_upstream(self, mock_conn_param):
mock_conn_param.return_value = CONN_PARAMS

hook = SnowflakeSqlApiHookAsync(snowflake_conn_id="mock_conn_id")
assert hook._get_conn_params == CONN_PARAMS

@mock.patch("airflow.providers.snowflake.hooks.snowflake.SnowflakeHook._get_conn_params")
def test__get_conn_params__with_callable_upstream(self, mock_conn_param):
mock_conn_param.return_value = CONN_PARAMS

hook = SnowflakeSqlApiHookAsync(snowflake_conn_id="mock_conn_id")
assert hook._get_conn_params == CONN_PARAMS

@pytest.fixture()
def non_encrypted_temporary_private_key(self, tmp_path: Path) -> Path:
"""Encrypt the pem file from the path"""
Expand Down

0 comments on commit 17b0fe1

Please sign in to comment.