-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* feat(amazon): deprecate BatchSensorAsync * feat(amazon): remove BatchSensorTrigger
- Loading branch information
Showing
4 changed files
with
24 additions
and
273 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,68 +1,22 @@ | ||
import warnings | ||
from datetime import timedelta | ||
from typing import Any, Dict | ||
|
||
from airflow.providers.amazon.aws.sensors.batch import BatchSensor | ||
|
||
from astronomer.providers.amazon.aws.triggers.batch import BatchSensorTrigger | ||
from astronomer.providers.utils.sensor_util import poke, raise_error_or_skip_exception | ||
from astronomer.providers.utils.typing_compat import Context | ||
|
||
|
||
class BatchSensorAsync(BatchSensor): | ||
""" | ||
Given a job ID of a Batch Job, poll for the job status asynchronously until it | ||
reaches a failure or a success state. | ||
If the job fails, the task will fail. | ||
.. see also:: | ||
For more information on how to use this sensor, take a look at the guide: | ||
:ref:`howto/sensor:BatchSensor` | ||
:param job_id: Batch job_id to check the state for | ||
:param aws_conn_id: aws connection to use, defaults to 'aws_default' | ||
:param region_name: region name to use in AWS Hook | ||
Override the region_name in connection (if provided) | ||
:param poll_interval: polling period in seconds to check for the status of the job | ||
This class is deprecated. | ||
Please use :class: `~airflow.providers.amazon.aws.sensors.batch.BatchSensor`. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
*, | ||
poll_interval: float = 5, | ||
**kwargs: Any, | ||
): | ||
# TODO: Remove once deprecated | ||
if poll_interval: | ||
self.poke_interval = poll_interval | ||
warnings.warn( | ||
"Argument `poll_interval` is deprecated and will be removed " | ||
"in a future release. Please use `poke_interval` instead.", | ||
DeprecationWarning, | ||
stacklevel=2, | ||
) | ||
super().__init__(**kwargs) | ||
|
||
def execute(self, context: Context) -> None: | ||
"""Defers trigger class to poll for state of the job run until it reaches a failure or a success state""" | ||
if not poke(self, context): | ||
self.defer( | ||
timeout=timedelta(seconds=self.timeout), | ||
trigger=BatchSensorTrigger( | ||
job_id=self.job_id, | ||
aws_conn_id=self.aws_conn_id, | ||
region_name=self.region_name, | ||
poke_interval=self.poke_interval, | ||
), | ||
method_name="execute_complete", | ||
) | ||
|
||
def execute_complete(self, context: Context, event: Dict[str, Any]) -> None: | ||
""" | ||
Callback for when the trigger fires - returns immediately. | ||
Relies on trigger to throw an exception, otherwise it assumes execution was | ||
successful. | ||
""" | ||
if "status" in event and event["status"] == "error": | ||
raise_error_or_skip_exception(self.soft_fail, event["message"]) | ||
self.log.info(event["message"]) | ||
def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] | ||
warnings.warn( | ||
( | ||
"This module is deprecated. " | ||
"Please use `airflow.providers.amazon.aws.sensors.batch.BatchSensor` " | ||
"and set deferrable to True instead." | ||
), | ||
DeprecationWarning, | ||
stacklevel=2, | ||
) | ||
return super().__init__(*args, deferrable=True, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,73 +1,17 @@ | ||
from unittest import mock | ||
|
||
import pytest | ||
from airflow.exceptions import AirflowException, TaskDeferred | ||
from airflow.providers.amazon.aws.sensors.batch import BatchSensor | ||
|
||
from astronomer.providers.amazon.aws.sensors.batch import BatchSensorAsync | ||
from astronomer.providers.amazon.aws.triggers.batch import BatchSensorTrigger | ||
|
||
MODULE = "astronomer.providers.amazon.aws.sensors.batch" | ||
|
||
|
||
class TestBatchSensorAsync: | ||
JOB_ID = "8ba9d676-4108-4474-9dca-8bbac1da9b19" | ||
AWS_CONN_ID = "airflow_test" | ||
REGION_NAME = "eu-west-1" | ||
TASK = BatchSensorAsync( | ||
task_id="task", | ||
job_id=JOB_ID, | ||
aws_conn_id=AWS_CONN_ID, | ||
region_name=REGION_NAME, | ||
) | ||
|
||
@mock.patch(f"{MODULE}.BatchSensorAsync.defer") | ||
@mock.patch(f"{MODULE}.BatchSensorAsync.poke", return_value=True) | ||
def test_batch_sensor_async_finish_before_deferred(self, mock_poke, mock_defer, context): | ||
"""Assert task is not deferred when it receives a finish status before deferring""" | ||
self.TASK.execute(context) | ||
assert not mock_defer.called | ||
|
||
@mock.patch(f"{MODULE}.BatchSensorAsync.poke", return_value=False) | ||
def test_batch_sensor_async(self, context): | ||
""" | ||
Asserts that a task is deferred and a BatchSensorTrigger will be fired | ||
when the BatchSensorAsync is executed. | ||
""" | ||
|
||
with pytest.raises(TaskDeferred) as exc: | ||
self.TASK.execute(context) | ||
assert isinstance(exc.value.trigger, BatchSensorTrigger), "Trigger is not a BatchSensorTrigger" | ||
|
||
def test_batch_sensor_async_execute_failure(self, context): | ||
"""Tests that an AirflowException is raised in case of error event""" | ||
|
||
with pytest.raises(AirflowException) as exc_info: | ||
self.TASK.execute_complete( | ||
context=None, event={"status": "error", "message": "test failure message"} | ||
) | ||
|
||
assert str(exc_info.value) == "test failure message" | ||
|
||
@pytest.mark.parametrize( | ||
"event", | ||
[{"status": "success", "message": f"AWS Batch job ({JOB_ID}) succeeded"}], | ||
) | ||
def test_batch_sensor_async_execute_complete(self, caplog, event): | ||
"""Tests that execute_complete method returns None and that it prints expected log""" | ||
|
||
with mock.patch.object(self.TASK.log, "info") as mock_log_info: | ||
assert self.TASK.execute_complete(context=None, event=event) is None | ||
|
||
mock_log_info.assert_called_with(event["message"]) | ||
|
||
def test_poll_interval_deprecation_warning(self): | ||
"""Test DeprecationWarning for BatchSensorAsync by setting param poll_interval""" | ||
# TODO: Remove once deprecated | ||
with pytest.warns(expected_warning=DeprecationWarning): | ||
BatchSensorAsync( | ||
task_id="task", | ||
job_id=self.JOB_ID, | ||
aws_conn_id=self.AWS_CONN_ID, | ||
region_name=self.REGION_NAME, | ||
poll_interval=5.0, | ||
) | ||
def test_init(self): | ||
task = BatchSensorAsync( | ||
task_id="task", | ||
job_id="8ba9d676-4108-4474-9dca-8bbac1da9b19", | ||
aws_conn_id="airflow_test", | ||
region_name="eu-west-1", | ||
) | ||
assert isinstance(task, BatchSensor) | ||
assert task.deferrable is True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters