-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(amazon): deprecate RedshiftDataOperatorAsync, RedshiftSQLOperatorAsync
- Loading branch information
Showing
6 changed files
with
81 additions
and
319 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
94 changes: 17 additions & 77 deletions
94
astronomer/providers/amazon/aws/operators/redshift_sql.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,90 +1,30 @@ | ||
from typing import Any, cast | ||
import warnings | ||
from typing import Any | ||
|
||
from airflow.exceptions import AirflowException | ||
from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator | ||
|
||
try: | ||
from airflow.providers.amazon.aws.operators.redshift_sql import RedshiftSQLOperator | ||
except ImportError: # pragma: no cover | ||
# For apache-airflow-providers-amazon > 6.0.0 | ||
# currently added type: ignore[no-redef, attr-defined] and pragma: no cover because this import | ||
# path won't be available in current setup | ||
from airflow.providers.common.sql.operators.sql import ( | ||
SQLExecuteQueryOperator as RedshiftSQLOperator, | ||
) | ||
|
||
from astronomer.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook | ||
from astronomer.providers.amazon.aws.triggers.redshift_sql import RedshiftSQLTrigger | ||
from astronomer.providers.utils.typing_compat import Context | ||
|
||
|
||
class RedshiftSQLOperatorAsync(RedshiftSQLOperator): | ||
class RedshiftSQLOperatorAsync(RedshiftDataOperator): | ||
""" | ||
Executes SQL Statements against an Amazon Redshift cluster" | ||
:param sql: the SQL code to be executed as a single string, or | ||
a list of str (sql statements), or a reference to a template file. | ||
Template references are recognized by str ending in '.sql' | ||
:param redshift_conn_id: reference to Amazon Redshift connection id | ||
:param parameters: (optional) the parameters to render the SQL query with. | ||
:param autocommit: if True, each command is automatically committed. | ||
(default value: False) | ||
This class is deprecated. | ||
Please use :class:`~airflow.providers.amazon.aws.operators.redshift_data.RedshiftDataOperator` | ||
and set `deferrable` param to `True` instead. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
*, | ||
redshift_conn_id: str = "redshift_default", | ||
poll_interval: float = 5, | ||
poll_interval: int = 5, | ||
**kwargs: Any, | ||
) -> None: | ||
self.redshift_conn_id = redshift_conn_id | ||
self.poll_interval = poll_interval | ||
if self.__class__.__base__.__name__ == "RedshiftSQLOperator": # type: ignore[union-attr] | ||
# It's better to do str check of the parent class name because currently RedshiftSQLOperator | ||
# is deprecated and in future OSS RedshiftSQLOperator may be removed | ||
super().__init__(**kwargs) | ||
else: | ||
super().__init__(conn_id=redshift_conn_id, **kwargs) # pragma: no cover | ||
|
||
def execute(self, context: Context) -> None: | ||
""" | ||
Makes a sync call to RedshiftDataHook and execute the query and gets back the query_ids list and | ||
defers trigger to poll for the status for the query executed | ||
""" | ||
redshift_data_hook = RedshiftDataHook(aws_conn_id=self.redshift_conn_id) | ||
query_ids, response = redshift_data_hook.execute_query(sql=cast(str, self.sql), params=self.params) | ||
if response.get("status") == "error": | ||
self.execute_complete(cast(Context, {}), response) | ||
return | ||
context["ti"].xcom_push(key="return_value", value=query_ids) | ||
|
||
if redshift_data_hook.queries_are_completed(query_ids, context): | ||
self.log.info("%s completed successfully.", self.task_id) | ||
return | ||
|
||
self.defer( | ||
timeout=self.execution_timeout, | ||
trigger=RedshiftSQLTrigger( | ||
task_id=self.task_id, | ||
polling_period_seconds=self.poll_interval, | ||
aws_conn_id=self.redshift_conn_id, | ||
query_ids=query_ids, | ||
warnings.warn( | ||
( | ||
"This module is deprecated and will be removed in 2.0.0." | ||
"Please use `airflow.providers.amazon.aws.operators.redshift_data.RedshiftDataOperator`" | ||
"and set `deferrable` param to `True` instead." | ||
), | ||
method_name="execute_complete", | ||
DeprecationWarning, | ||
stacklevel=2, | ||
) | ||
|
||
def execute_complete(self, context: Context, event: Any = None) -> None: | ||
""" | ||
Callback for when the trigger fires - returns immediately. | ||
Relies on trigger to throw an exception, otherwise it assumes execution was | ||
successful. | ||
""" | ||
if event: | ||
if "status" in event and event["status"] == "error": | ||
raise AirflowException(event["message"]) | ||
elif "status" in event and event["status"] == "success": | ||
self.log.info("%s completed successfully.", self.task_id) | ||
return | ||
else: | ||
self.log.info("%s completed successfully.", self.task_id) | ||
return | ||
kwargs["poll_interval"] = poll_interval | ||
super().__init__(deferrable=True, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,101 +1,16 @@ | ||
from unittest import mock | ||
|
||
import pytest | ||
from airflow.exceptions import AirflowException, TaskDeferred | ||
from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator | ||
|
||
from astronomer.providers.amazon.aws.operators.redshift_data import ( | ||
RedshiftDataOperatorAsync, | ||
) | ||
from astronomer.providers.amazon.aws.triggers.redshift_data import RedshiftDataTrigger | ||
from tests.utils.airflow_util import create_context | ||
|
||
|
||
class TestRedshiftDataOperatorAsync: | ||
DATABASE_NAME = "TEST_DATABASE" | ||
TASK_ID = "fetch_data" | ||
SQL_QUERY = "select * from any" | ||
TASK = RedshiftDataOperatorAsync( | ||
task_id=TASK_ID, | ||
sql=SQL_QUERY, | ||
database=DATABASE_NAME, | ||
) | ||
|
||
@mock.patch("astronomer.providers.amazon.aws.operators.redshift_data.RedshiftDataOperatorAsync.defer") | ||
@mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn") | ||
@mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query") | ||
def test_redshift_data_op_async_finished_before_deferred(self, mock_execute, mock_conn, mock_defer): | ||
mock_execute.return_value = ["test_query_id"], {} | ||
mock_conn.describe_statement.return_value = { | ||
"Status": "FINISHED", | ||
} | ||
self.TASK.execute(create_context(self.TASK)) | ||
assert not mock_defer.called | ||
|
||
@mock.patch("astronomer.providers.amazon.aws.operators.redshift_data.RedshiftDataOperatorAsync.defer") | ||
@mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn") | ||
@mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query") | ||
def test_redshift_data_op_async_aborted_before_deferred(self, mock_execute, mock_conn, mock_defer): | ||
mock_execute.return_value = ["test_query_id"], {} | ||
mock_conn.describe_statement.return_value = {"Status": "ABORTED"} | ||
|
||
with pytest.raises(AirflowException): | ||
self.TASK.execute(create_context(self.TASK)) | ||
|
||
assert not mock_defer.called | ||
|
||
@mock.patch("astronomer.providers.amazon.aws.operators.redshift_data.RedshiftDataOperatorAsync.defer") | ||
@mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn") | ||
@mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query") | ||
def test_redshift_data_op_async_failed_before_deferred(self, mock_execute, mock_conn, mock_defer): | ||
mock_execute.return_value = ["test_query_id"], {} | ||
mock_conn.describe_statement.return_value = { | ||
"Status": "FAILED", | ||
"QueryString": "test query", | ||
"Error": "test error", | ||
} | ||
|
||
with pytest.raises(AirflowException): | ||
self.TASK.execute(create_context(self.TASK)) | ||
|
||
assert not mock_defer.called | ||
|
||
@pytest.mark.parametrize("status", ("SUBMITTED", "PICKED", "STARTED")) | ||
@mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn") | ||
@mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query") | ||
def test_redshift_data_op_async(self, mock_execute, mock_conn, status): | ||
mock_execute.return_value = ["test_query_id"], {} | ||
mock_conn.describe_statement.return_value = {"Status": status} | ||
|
||
with pytest.raises(TaskDeferred) as exc: | ||
self.TASK.execute(create_context(self.TASK)) | ||
assert isinstance(exc.value.trigger, RedshiftDataTrigger), "Trigger is not a RedshiftDataTrigger" | ||
|
||
@mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query") | ||
def test_redshift_data_op_async_execute_query_error(self, mock_execute, context): | ||
mock_execute.return_value = [], {"status": "error", "message": "Test exception"} | ||
with pytest.raises(AirflowException): | ||
self.TASK.execute(context) | ||
|
||
def test_redshift_data_op_async_execute_failure(self, context): | ||
"""Tests that an AirflowException is raised in case of error event""" | ||
|
||
with pytest.raises(AirflowException): | ||
self.TASK.execute_complete( | ||
context=None, event={"status": "error", "message": "test failure message"} | ||
) | ||
|
||
@pytest.mark.parametrize( | ||
"event", | ||
[None, {"status": "success", "message": "Job completed"}], | ||
) | ||
def test_redshift_data_op_async_execute_complete(self, event): | ||
"""Asserts that logging occurs as expected""" | ||
|
||
if not event: | ||
with pytest.raises(AirflowException) as exception_info: | ||
self.TASK.execute_complete(context=None, event=None) | ||
assert exception_info.value.args[0] == "Did not receive valid event from the trigerrer" | ||
else: | ||
with mock.patch.object(self.TASK.log, "info") as mock_log_info: | ||
self.TASK.execute_complete(context=None, event=event) | ||
mock_log_info.assert_called_with("%s completed successfully.", self.TASK_ID) | ||
def test_init(self): | ||
task = RedshiftDataOperatorAsync( | ||
task_id="fetch_data", | ||
sql="select * from any", | ||
database="TEST_DATABASE", | ||
) | ||
assert isinstance(task, RedshiftDataOperator) | ||
assert task.deferrable is True |
Oops, something went wrong.