Skip to content

Commit

Permalink
Deprecate BatchSensorAsync (#1391)
Browse files Browse the repository at this point in the history
* feat(amazon): deprecate BatchSensorAsync
* feat(amazon): remove BatchSensorTrigger
  • Loading branch information
Lee-W authored Dec 21, 2023
1 parent d6cbd8e commit 37f2d35
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 273 deletions.
72 changes: 13 additions & 59 deletions astronomer/providers/amazon/aws/sensors/batch.py
Original file line number Diff line number Diff line change
@@ -1,68 +1,22 @@
import warnings
from datetime import timedelta
from typing import Any, Dict

from airflow.providers.amazon.aws.sensors.batch import BatchSensor

from astronomer.providers.amazon.aws.triggers.batch import BatchSensorTrigger
from astronomer.providers.utils.sensor_util import poke, raise_error_or_skip_exception
from astronomer.providers.utils.typing_compat import Context


class BatchSensorAsync(BatchSensor):
"""
Given a job ID of a Batch Job, poll for the job status asynchronously until it
reaches a failure or a success state.
If the job fails, the task will fail.
.. seealso::
For more information on how to use this sensor, take a look at the guide:
:ref:`howto/sensor:BatchSensor`
:param job_id: Batch job_id to check the state for
:param aws_conn_id: aws connection to use, defaults to 'aws_default'
:param region_name: region name to use in AWS Hook
Override the region_name in connection (if provided)
:param poll_interval: polling period in seconds to check for the status of the job
This class is deprecated.
Please use :class:`~airflow.providers.amazon.aws.sensors.batch.BatchSensor`.
"""

def __init__(
    self,
    *,
    poll_interval: "float | None" = None,
    **kwargs: Any,
):
    """
    Create the sensor, accepting the deprecated ``poll_interval`` argument.

    :param poll_interval: deprecated alias of ``poke_interval``. When supplied, a
        ``DeprecationWarning`` is emitted and the value is forwarded to the parent
        constructor as ``poke_interval``. When omitted, nothing is warned and the
        parent's own ``poke_interval`` handling applies.
    :param kwargs: remaining keyword arguments, passed to the parent constructor
    """
    # TODO: Remove once deprecated
    # Compare against None: the previous truthy default made the warning fire on
    # every instantiation, even when the caller never passed `poll_interval`.
    if poll_interval is not None:
        warnings.warn(
            "Argument `poll_interval` is deprecated and will be removed "
            "in a future release. Please use `poke_interval` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        # NOTE(review): forwarded through kwargs instead of being assigned to
        # `self.poke_interval` before `super().__init__()`; this assumes the base
        # sensor constructor sets `poke_interval` itself -- confirm.
        kwargs["poke_interval"] = poll_interval
    super().__init__(**kwargs)

def execute(self, context: Context) -> None:
    """
    Run one synchronous poke and defer to ``BatchSensorTrigger`` if it is not satisfied.

    When ``poke`` returns a truthy value nothing else happens. Otherwise the task is
    deferred with ``execute_complete`` as the resume method.
    """
    if poke(self, context):
        return

    defer_timeout = timedelta(seconds=self.timeout)
    sensor_trigger = BatchSensorTrigger(
        job_id=self.job_id,
        aws_conn_id=self.aws_conn_id,
        region_name=self.region_name,
        poke_interval=self.poke_interval,
    )
    self.defer(
        timeout=defer_timeout,
        trigger=sensor_trigger,
        method_name="execute_complete",
    )

def execute_complete(self, context: Context, event: Dict[str, Any]) -> None:
    """
    Resume after the trigger fires.

    An event whose ``status`` is ``"error"`` is handed, together with ``soft_fail``,
    to ``raise_error_or_skip_exception``. In every other case the event message is
    logged at info level.
    """
    reported_status = event["status"] if "status" in event else None
    if reported_status == "error":
        raise_error_or_skip_exception(self.soft_fail, event["message"])
    self.log.info(event["message"])
def __init__(self, *args, **kwargs) -> None:  # type: ignore[no-untyped-def]
    """
    Emit a ``DeprecationWarning`` and initialise the parent sensor in deferrable mode.

    All positional and keyword arguments are passed to the parent constructor
    unchanged, with ``deferrable=True`` added.
    """
    warnings.warn(
        (
            "This module is deprecated. "
            "Please use `airflow.providers.amazon.aws.sensors.batch.BatchSensor` "
            "and set deferrable to True instead."
        ),
        DeprecationWarning,
        # stacklevel=2 attributes the warning to the code that instantiates the class.
        stacklevel=2,
    )
    # `__init__` must not return a value; call the parent without `return`.
    super().__init__(*args, deferrable=True, **kwargs)
62 changes: 0 additions & 62 deletions astronomer/providers/amazon/aws/triggers/batch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
from typing import Any, AsyncIterator, Dict, Optional, Tuple

from airflow.triggers.base import BaseTrigger, TriggerEvent
Expand Down Expand Up @@ -71,64 +70,3 @@ async def run(self) -> AsyncIterator["TriggerEvent"]:
yield TriggerEvent({"status": "error", "message": error_message})
except Exception as e:
yield TriggerEvent({"status": "error", "message": str(e)})


class BatchSensorTrigger(BaseTrigger):
    """
    Checks for the status of a submitted job_id to AWS Batch until it reaches a failure or a success state.

    BatchSensorTrigger is fired as deferred class with params to poll the job state in Triggerer

    :param job_id: the job ID, to poll for job completion or not
    :param aws_conn_id: connection id of AWS credentials / region name. If None,
        credential boto3 strategy will be used
    :param region_name: AWS region name to use
        Override the region_name in connection (if provided)
    :param poke_interval: polling period in seconds to check for the status of the job
    """

    def __init__(
        self,
        job_id: str,
        region_name: Optional[str],
        aws_conn_id: Optional[str] = "aws_default",
        poke_interval: float = 5,
    ):
        super().__init__()
        self.job_id = job_id
        self.aws_conn_id = aws_conn_id
        self.region_name = region_name
        self.poke_interval = poke_interval

    def serialize(self) -> Tuple[str, Dict[str, Any]]:
        """Serializes BatchSensorTrigger arguments and classpath."""
        return (
            "astronomer.providers.amazon.aws.triggers.batch.BatchSensorTrigger",
            {
                "job_id": self.job_id,
                "aws_conn_id": self.aws_conn_id,
                "region_name": self.region_name,
                "poke_interval": self.poke_interval,
            },
        )

    async def run(self) -> AsyncIterator["TriggerEvent"]:
        """
        Poll the Batch job description through ``BatchClientHookAsync`` until it is terminal.

        Yields exactly one ``TriggerEvent`` and then stops:

        * ``{"status": "success", ...}`` when the job status equals the hook's ``SUCCESS_STATE``
        * ``{"status": "error", ...}`` when it equals the hook's ``FAILURE_STATE``, or when
          any exception is raised while polling (the message is ``str(exception)``)

        Any other status causes a sleep of ``poke_interval`` seconds before the next poll.
        """
        # NOTE(review): `region_name` is stored and serialized but is not passed to the
        # hook here -- confirm whether that is intentional.
        hook = BatchClientHookAsync(job_id=self.job_id, aws_conn_id=self.aws_conn_id)
        try:
            while True:
                response = await hook.get_job_description(self.job_id)
                state = response["status"]
                if state == BatchClientHookAsync.SUCCESS_STATE:
                    success_message = f"{self.job_id} was completed successfully"
                    yield TriggerEvent({"status": "success", "message": success_message})
                    # Terminal state: stop here, otherwise the loop would poll again
                    # and yield the same event indefinitely.
                    return
                if state == BatchClientHookAsync.FAILURE_STATE:
                    error_message = f"{self.job_id} failed"
                    yield TriggerEvent({"status": "error", "message": error_message})
                    return
                await asyncio.sleep(self.poke_interval)
        except Exception as e:
            # Report polling errors as an error event instead of raising.
            yield TriggerEvent({"status": "error", "message": str(e)})
76 changes: 10 additions & 66 deletions tests/amazon/aws/sensors/test_batch_sensors.py
Original file line number Diff line number Diff line change
@@ -1,73 +1,17 @@
from unittest import mock

import pytest
from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.amazon.aws.sensors.batch import BatchSensor

from astronomer.providers.amazon.aws.sensors.batch import BatchSensorAsync
from astronomer.providers.amazon.aws.triggers.batch import BatchSensorTrigger

MODULE = "astronomer.providers.amazon.aws.sensors.batch"


class TestBatchSensorAsync:
    """Tests for ``BatchSensorAsync``."""

    JOB_ID = "8ba9d676-4108-4474-9dca-8bbac1da9b19"
    AWS_CONN_ID = "airflow_test"
    REGION_NAME = "eu-west-1"
    TASK = BatchSensorAsync(
        task_id="task",
        job_id=JOB_ID,
        aws_conn_id=AWS_CONN_ID,
        region_name=REGION_NAME,
    )

    @mock.patch(f"{MODULE}.BatchSensorAsync.defer")
    @mock.patch(f"{MODULE}.BatchSensorAsync.poke", return_value=True)
    def test_batch_sensor_async_finish_before_deferred(self, mock_poke, mock_defer, context):
        """Assert task is not deferred when it receives a finish status before deferring"""
        self.TASK.execute(context)
        assert not mock_defer.called

    @mock.patch(f"{MODULE}.BatchSensorAsync.poke", return_value=False)
    def test_batch_sensor_async(self, mock_poke, context):
        """
        Asserts that a task is deferred and a BatchSensorTrigger will be fired
        when the BatchSensorAsync is executed.
        """
        # `mock_poke` receives the object injected by `mock.patch`; without it the
        # mock would be bound to `context` and the fixture would never be used.
        with pytest.raises(TaskDeferred) as exc:
            self.TASK.execute(context)
        assert isinstance(exc.value.trigger, BatchSensorTrigger), "Trigger is not a BatchSensorTrigger"

    def test_batch_sensor_async_execute_failure(self, context):
        """Tests that an AirflowException is raised in case of error event"""

        with pytest.raises(AirflowException) as exc_info:
            self.TASK.execute_complete(
                context=None, event={"status": "error", "message": "test failure message"}
            )

        assert str(exc_info.value) == "test failure message"

    @pytest.mark.parametrize(
        "event",
        [{"status": "success", "message": f"AWS Batch job ({JOB_ID}) succeeded"}],
    )
    def test_batch_sensor_async_execute_complete(self, caplog, event):
        """Tests that execute_complete method returns None and that it prints expected log"""

        with mock.patch.object(self.TASK.log, "info") as mock_log_info:
            assert self.TASK.execute_complete(context=None, event=event) is None

        mock_log_info.assert_called_with(event["message"])

    def test_poll_interval_deprecation_warning(self):
        """Test DeprecationWarning for BatchSensorAsync by setting param poll_interval"""
        # TODO: Remove once deprecated
        with pytest.warns(expected_warning=DeprecationWarning):
            BatchSensorAsync(
                task_id="task",
                job_id=self.JOB_ID,
                aws_conn_id=self.AWS_CONN_ID,
                region_name=self.REGION_NAME,
                poll_interval=5.0,
            )

    def test_init(self):
        """Assert the sensor is a BatchSensor whose ``deferrable`` attribute is True"""
        task = BatchSensorAsync(
            task_id="task",
            job_id="8ba9d676-4108-4474-9dca-8bbac1da9b19",
            aws_conn_id="airflow_test",
            region_name="eu-west-1",
        )
        assert isinstance(task, BatchSensor)
        assert task.deferrable is True
87 changes: 1 addition & 86 deletions tests/amazon/aws/triggers/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@
import pytest
from airflow.triggers.base import TriggerEvent

from astronomer.providers.amazon.aws.triggers.batch import (
BatchOperatorTrigger,
BatchSensorTrigger,
)
from astronomer.providers.amazon.aws.triggers.batch import BatchOperatorTrigger

JOB_NAME = "51455483-c62c-48ac-9b88-53a6a725baa3"
JOB_ID = "8ba9d676-4108-4474-9dca-8bbac1da9b19"
Expand Down Expand Up @@ -94,85 +91,3 @@ async def test_batch_trigger_exception(self, mock_response):
task = [i async for i in self.TRIGGER.run()]
assert len(task) == 1
assert TriggerEvent({"status": "error", "message": "Test exception"}) in task


class TestBatchSensorTrigger:
    """Tests for ``BatchSensorTrigger`` serialization and its polling loop."""

    TRIGGER = BatchSensorTrigger(
        job_id=JOB_ID,
        region_name=REGION_NAME,
        aws_conn_id=AWS_CONN_ID,
        poke_interval=POKE_INTERVAL,
    )

    def test_batch_sensor_trigger_serialization(self):
        """
        Asserts that the BatchSensorTrigger correctly serializes its arguments
        and classpath.
        """

        classpath, kwargs = self.TRIGGER.serialize()
        assert classpath == "astronomer.providers.amazon.aws.triggers.batch.BatchSensorTrigger"
        assert kwargs == {
            "job_id": JOB_ID,
            "region_name": "eu-west-1",
            "aws_conn_id": "airflow_test",
            "poke_interval": POKE_INTERVAL,
        }

    @pytest.mark.asyncio
    @mock.patch("astronomer.providers.amazon.aws.hooks.batch_client.BatchClientHookAsync.get_job_description")
    async def test_batch_sensor_trigger_run(self, mock_response):
        """Trigger the BatchSensorTrigger and check if the task is in running state."""
        mock_response.return_value = {"status": "RUNNABLE"}

        task = asyncio.create_task(self.TRIGGER.run().__anext__())
        await asyncio.sleep(0.5)
        # TriggerEvent was not returned
        assert task.done() is False
        # Cancel and await the polling task instead of stopping the event loop, so
        # no pending task is left behind.
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            # Expected outcome of the cancellation requested above.
            pass

    @pytest.mark.asyncio
    @mock.patch("astronomer.providers.amazon.aws.hooks.batch_client.BatchClientHookAsync.get_job_description")
    async def test_batch_sensor_trigger_completed(self, mock_response):
        """Test if the success event is returned from trigger."""
        mock_response.return_value = {"status": "SUCCEEDED"}
        trigger = BatchSensorTrigger(
            job_id=JOB_ID,
            region_name=REGION_NAME,
            aws_conn_id=AWS_CONN_ID,
        )
        generator = trigger.run()
        actual_response = await generator.asend(None)
        assert (
            TriggerEvent({"status": "success", "message": f"{JOB_ID} was completed successfully"})
            == actual_response
        )

    @pytest.mark.asyncio
    @mock.patch("astronomer.providers.amazon.aws.hooks.batch_client.BatchClientHookAsync.get_job_description")
    async def test_batch_sensor_trigger_failure(self, mock_response):
        """Test if the failure event is returned from trigger."""
        mock_response.return_value = {"status": "FAILED"}
        trigger = BatchSensorTrigger(
            job_id=JOB_ID,
            region_name=REGION_NAME,
            aws_conn_id=AWS_CONN_ID,
        )
        generator = trigger.run()
        actual_response = await generator.asend(None)
        assert TriggerEvent({"status": "error", "message": f"{JOB_ID} failed"}) == actual_response

    @pytest.mark.asyncio
    @mock.patch("astronomer.providers.amazon.aws.hooks.batch_client.BatchClientHookAsync.get_job_description")
    async def test_batch_sensor_trigger_exception(self, mock_response):
        """Test if the exception is raised from trigger."""
        mock_response.side_effect = Exception("Test exception")
        trigger = BatchSensorTrigger(
            job_id=JOB_ID,
            region_name=REGION_NAME,
            aws_conn_id=AWS_CONN_ID,
        )
        task = [i async for i in trigger.run()]
        assert len(task) == 1

        assert TriggerEvent({"status": "error", "message": "Test exception"}) in task

0 comments on commit 37f2d35

Please sign in to comment.