diff --git a/astronomer/providers/microsoft/azure/hooks/data_factory.py b/astronomer/providers/microsoft/azure/hooks/data_factory.py index 268be4dea..64622015b 100644 --- a/astronomer/providers/microsoft/azure/hooks/data_factory.py +++ b/astronomer/providers/microsoft/azure/hooks/data_factory.py @@ -2,6 +2,7 @@ from __future__ import annotations import inspect +import warnings from functools import wraps from typing import Any, TypeVar, Union, cast @@ -68,13 +69,20 @@ async def bind_argument(arg: Any, default_key: str) -> None: class AzureDataFactoryHookAsync(AzureDataFactoryHook): """ - An Async Hook connects to Azure DataFactory to perform pipeline operations. - - :param azure_data_factory_conn_id: The :ref:`Azure Data Factory connection id`. + This class is deprecated and will be removed in 2.0.0. + Use :class: `~airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook` instead. """ def __init__(self, azure_data_factory_conn_id: str): """Initialize the hook instance.""" + warnings.warn( + ( + "This class is deprecated and will be removed in 2.0.0." + "Use :class: `~airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook` instead." + ), + DeprecationWarning, + stacklevel=2, + ) self._async_conn: DataFactoryManagementClient | None = None self.conn_id = azure_data_factory_conn_id super().__init__(azure_data_factory_conn_id=azure_data_factory_conn_id) diff --git a/astronomer/providers/microsoft/azure/sensors/data_factory.py b/astronomer/providers/microsoft/azure/sensors/data_factory.py index 0d820846b..ac9413ac6 100644 --- a/astronomer/providers/microsoft/azure/sensors/data_factory.py +++ b/astronomer/providers/microsoft/azure/sensors/data_factory.py @@ -1,68 +1,29 @@ import warnings -from datetime import timedelta -from typing import Any, Dict +from typing import Any -from airflow.providers.microsoft.azure.sensors.data_factory import ( - AzureDataFactoryPipelineRunStatusSensor, -) - -from astronomer.providers.microsoft.azure.triggers.data_factory import ( - ADFPipelineRunStatusSensorTrigger, -) -from astronomer.providers.utils.sensor_util import poke, raise_error_or_skip_exception -from astronomer.providers.utils.typing_compat import Context +from airflow.providers.microsoft.azure.sensors.data_factory import AzureDataFactoryPipelineRunStatusSensor class AzureDataFactoryPipelineRunStatusSensorAsync(AzureDataFactoryPipelineRunStatusSensor): """ - Checks the status of a pipeline run. - - :param azure_data_factory_conn_id: The connection identifier for connecting to Azure Data Factory. - :param run_id: The pipeline run identifier. - :param resource_group_name: The resource group name. - :param factory_name: The data factory name. - :param poll_interval: polling period in seconds to check for the status + This class is deprecated. + Use :class: `~airflow.providers.microsoft.azure.sensors.data_factory.AzureDataFactoryPipelineRunStatusSensor` + instead and set `deferrable` param to `True` instead. """ def __init__( self, - *, + *args: Any, poll_interval: float = 5, **kwargs: Any, ): # TODO: Remove once deprecated if poll_interval: - self.poke_interval = poll_interval + kwargs["poke_interval"] = poll_interval warnings.warn( "Argument `poll_interval` is deprecated and will be removed " "in a future release. Please use `poke_interval` instead.", DeprecationWarning, stacklevel=2, ) - super().__init__(**kwargs) - - def execute(self, context: Context) -> None: - """Defers trigger class to poll for state of the job run until it reaches a failure state or success state""" - if not poke(self, context): - self.defer( - timeout=timedelta(seconds=self.timeout), - trigger=ADFPipelineRunStatusSensorTrigger( - run_id=self.run_id, - azure_data_factory_conn_id=self.azure_data_factory_conn_id, - resource_group_name=self.resource_group_name, - factory_name=self.factory_name, - poke_interval=self.poke_interval, - ), - method_name="execute_complete", - ) - - def execute_complete(self, context: Context, event: Dict[str, str]) -> None: - """ - Callback for when the trigger fires - returns immediately. - Relies on trigger to throw an exception, otherwise it assumes execution was - successful. - """ - if event: - if event["status"] == "error": - raise_error_or_skip_exception(self.soft_fail, event["message"]) - self.log.info(event["message"]) + super().__init__(*args, deferrable=True, **kwargs) diff --git a/astronomer/providers/microsoft/azure/triggers/data_factory.py b/astronomer/providers/microsoft/azure/triggers/data_factory.py index 1628ddd6b..29b727bf3 100644 --- a/astronomer/providers/microsoft/azure/triggers/data_factory.py +++ b/astronomer/providers/microsoft/azure/triggers/data_factory.py @@ -14,14 +14,8 @@ class ADFPipelineRunStatusSensorTrigger(BaseTrigger): """ - ADFPipelineRunStatusSensorTrigger is fired as deferred class with params to run the task in trigger worker, when - ADF Pipeline is running - - :param run_id: The pipeline run identifier. - :param azure_data_factory_conn_id: The connection identifier for connecting to Azure Data Factory. - :param poke_interval: polling period in seconds to check for the status - :param resource_group_name: The resource group name. - :param factory_name: The data factory name. + This class is deprecated and will be removed in 2.0.0. + Use :class: `~airflow.providers.microsoft.azure.triggers.data_factory.ADFPipelineRunStatusSensorTrigger` instead. """ def __init__( diff --git a/tests/microsoft/azure/sensors/test_data_factory.py b/tests/microsoft/azure/sensors/test_data_factory.py index da1810614..1ca574bc0 100644 --- a/tests/microsoft/azure/sensors/test_data_factory.py +++ b/tests/microsoft/azure/sensors/test_data_factory.py @@ -1,62 +1,24 @@ -from unittest import mock - import pytest -from airflow.exceptions import AirflowException, TaskDeferred +from airflow.providers.microsoft.azure.sensors.data_factory import AzureDataFactoryPipelineRunStatusSensor from astronomer.providers.microsoft.azure.sensors.data_factory import ( AzureDataFactoryPipelineRunStatusSensorAsync, ) -from astronomer.providers.microsoft.azure.triggers.data_factory import ( - ADFPipelineRunStatusSensorTrigger, -) -from tests.utils.airflow_util import create_context - -MODULE = "astronomer.providers.microsoft.azure.sensors.data_factory" class TestAzureDataFactoryPipelineRunStatusSensorAsync: RUN_ID = "7f8c6c72-c093-11ec-a83d-0242ac120007" - SENSOR = AzureDataFactoryPipelineRunStatusSensorAsync( - task_id="pipeline_run_sensor_async", - run_id=RUN_ID, - factory_name="factory_name", - resource_group_name="resource_group_name", - ) - - @mock.patch(f"{MODULE}.AzureDataFactoryPipelineRunStatusSensorAsync.defer") - @mock.patch(f"{MODULE}.AzureDataFactoryPipelineRunStatusSensorAsync.poke", return_value=True) - def test_adf_pipeline_status_sensor_async_finish_before_deferred( - self, - mock_poke, - mock_defer, - ): - """Assert task is not deferred when it receives a finish status before deferring""" - self.SENSOR.execute(create_context(self.SENSOR)) - assert not mock_defer.called - - @mock.patch(f"{MODULE}.AzureDataFactoryPipelineRunStatusSensorAsync.poke", return_value=False) - def test_adf_pipeline_status_sensor_async(self, mock_poke): - """Assert execute method defer for Azure Data factory pipeline run status sensor""" - - with pytest.raises(TaskDeferred) as exc: - self.SENSOR.execute(create_context(self.SENSOR)) - assert isinstance( - exc.value.trigger, ADFPipelineRunStatusSensorTrigger - ), "Trigger is not a ADFPipelineRunStatusSensorTrigger" - - def test_adf_pipeline_status_sensor_execute_complete_success(self): - """Assert execute_complete log success message when trigger fire with target status""" - - msg = f"Pipeline run {self.RUN_ID} has been succeeded." - with mock.patch.object(self.SENSOR.log, "info") as mock_log_info: - self.SENSOR.execute_complete(context={}, event={"status": "success", "message": msg}) - mock_log_info.assert_called_with(msg) - def test_adf_pipeline_status_sensor_execute_complete_failure(self): - """Assert execute_complete method fail""" + def test_init(self): + task = AzureDataFactoryPipelineRunStatusSensorAsync( + task_id="pipeline_run_sensor_async", + run_id=self.RUN_ID, + factory_name="factory_name", + resource_group_name="resource_group_name", + ) - with pytest.raises(AirflowException): - self.SENSOR.execute_complete(context={}, event={"status": "error", "message": ""}) + assert isinstance(task, AzureDataFactoryPipelineRunStatusSensor) + assert task.deferrable is True def test_poll_interval_deprecation_warning(self): """Test DeprecationWarning for AzureDataFactoryPipelineRunStatusSensorAsync by setting param poll_interval"""