From 15cb47f8c588137e1a9e7222abc8652b751e21cb Mon Sep 17 00:00:00 2001 From: Pankaj Koti Date: Tue, 23 Jan 2024 18:20:09 +0530 Subject: [PATCH 1/2] Deprecate DbtCloudRunJobOperatorAsync and DbtCloudJobRunSensorAsync This PR deprecates the operator DbtCloudRunJobOperatorAsync and the sensor DbtCloudJobRunSensorAsync from the dbt provider by proxying them to their Airflow OSS provider's counterpart. closes: #1414 --- astronomer/providers/dbt/cloud/hooks/dbt.py | 13 +- .../providers/dbt/cloud/operators/dbt.py | 87 +++------- astronomer/providers/dbt/cloud/sensors/dbt.py | 78 +++------ .../providers/dbt/cloud/triggers/dbt.py | 19 +- setup.cfg | 4 +- tests/dbt/cloud/operators/test_dbt.py | 162 +----------------- tests/dbt/cloud/sensors/test_dbt.py | 67 +------- 7 files changed, 79 insertions(+), 351 deletions(-) diff --git a/astronomer/providers/dbt/cloud/hooks/dbt.py b/astronomer/providers/dbt/cloud/hooks/dbt.py index 57b4e8fd8..42b41661e 100644 --- a/astronomer/providers/dbt/cloud/hooks/dbt.py +++ b/astronomer/providers/dbt/cloud/hooks/dbt.py @@ -42,9 +42,8 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: class DbtCloudHookAsync(BaseHook): """ - Interact with dbt Cloud using the V2 API. - - :param dbt_cloud_conn_id: The ID of the :ref:`dbt Cloud connection `. + This class is deprecated and will be removed in 2.0.0. + Use :class: `~airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook` instead. """ conn_name_attr = "dbt_cloud_conn_id" @@ -53,6 +52,14 @@ class DbtCloudHookAsync(BaseHook): hook_name = "dbt Cloud" def __init__(self, dbt_cloud_conn_id: str): + warnings.warn( + ( + "This class is deprecated. " + "Use `airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook` instead." + ), + DeprecationWarning, + stacklevel=2, + ) self.dbt_cloud_conn_id = dbt_cloud_conn_id async def get_headers_tenants_from_connection(self) -> Tuple[Dict[str, Any], str]: diff --git a/astronomer/providers/dbt/cloud/operators/dbt.py b/astronomer/providers/dbt/cloud/operators/dbt.py index b394e4c2e..6a8f48c4b 100644 --- a/astronomer/providers/dbt/cloud/operators/dbt.py +++ b/astronomer/providers/dbt/cloud/operators/dbt.py @@ -1,86 +1,33 @@ from __future__ import annotations -import time +import warnings from typing import Any from airflow import AirflowException from airflow.exceptions import AirflowFailException -from airflow.providers.dbt.cloud.hooks.dbt import ( - DbtCloudHook, - DbtCloudJobRunException, - DbtCloudJobRunStatus, - JobRunInfo, -) from airflow.providers.dbt.cloud.operators.dbt import DbtCloudRunJobOperator -from astronomer.providers.dbt.cloud.triggers.dbt import DbtCloudRunJobTrigger from astronomer.providers.utils.typing_compat import Context class DbtCloudRunJobOperatorAsync(DbtCloudRunJobOperator): """ - Executes a dbt Cloud job asynchronously. Trigger the dbt cloud job via worker to dbt and with run id in response - poll for the status in trigger. - - .. seealso:: - For more information on sync Operator DbtCloudRunJobOperator, take a look at the guide: - :ref:`howto/operator:DbtCloudRunJobOperator` - - :param dbt_cloud_conn_id: The connection ID for connecting to dbt Cloud. - :param job_id: The ID of a dbt Cloud job. - :param account_id: Optional. The ID of a dbt Cloud account. - :param trigger_reason: Optional Description of the reason to trigger the job. Dbt requires the trigger reason while - making an API. if it is not provided uses the default reasons. - :param steps_override: Optional. List of dbt commands to execute when triggering the job instead of those - configured in dbt Cloud. - :param schema_override: Optional. Override the destination schema in the configured target for this job. - :param timeout: Time in seconds to wait for a job run to reach a terminal status. Defaults to 7 days. - :param check_interval: Time in seconds to check on a job run's status. Defaults to 60 seconds. - :param additional_run_config: Optional. Any additional parameters that should be included in the API - request when triggering the job. - :return: The ID of the triggered dbt Cloud job run. + This class is deprecated. + Use :class: `~airflow.providers.dbt.cloud.operators.dbt.DbtCloudRunJobOperator` instead + and set `deferrable` param to `True` instead. """ - def execute(self, context: Context) -> Any: - """Submits a job which generates a run_id and gets deferred""" - if self.trigger_reason is None: - self.trigger_reason = ( - f"Triggered via Apache Airflow by task {self.task_id!r} in the {self.dag.dag_id} DAG." - ) - hook = DbtCloudHook(dbt_cloud_conn_id=self.dbt_cloud_conn_id) - trigger_job_response = hook.trigger_job_run( - account_id=self.account_id, - job_id=self.job_id, - cause=self.trigger_reason, - steps_override=self.steps_override, - schema_override=self.schema_override, - additional_run_config=self.additional_run_config, + def __init__(self, *args: Any, **kwargs: Any) -> None: + warnings.warn( + ( + "This class is deprecated. " + "Use `airflow.providers.dbt.cloud.operators.dbt.DbtCloudRunJobOperator` " + "and set `deferrable` param to `True` instead." + ), + DeprecationWarning, + stacklevel=2, ) - run_id = trigger_job_response.json()["data"]["id"] - job_run_url = trigger_job_response.json()["data"]["href"] - - context["ti"].xcom_push(key="job_run_url", value=job_run_url) - end_time = time.time() + self.timeout - - job_run_info = JobRunInfo(account_id=self.account_id, run_id=run_id) - job_run_status = hook.get_job_run_status(**job_run_info) - if not DbtCloudJobRunStatus.is_terminal(job_run_status): - self.defer( - timeout=self.execution_timeout, - trigger=DbtCloudRunJobTrigger( - conn_id=self.dbt_cloud_conn_id, - run_id=run_id, - end_time=end_time, - account_id=self.account_id, - poll_interval=self.check_interval, - ), - method_name="execute_complete", - ) - elif job_run_status == DbtCloudJobRunStatus.SUCCESS.value: - self.log.info("Job run %s has completed successfully.", str(run_id)) - return run_id - elif job_run_status in (DbtCloudJobRunStatus.CANCELLED.value, DbtCloudJobRunStatus.ERROR.value): - raise DbtCloudJobRunException(f"Job run {run_id} has failed or has been cancelled.") + super().__init__(*args, deferrable=True, **kwargs) def execute_complete(self, context: Context, event: dict[str, Any]) -> int: """ @@ -88,6 +35,12 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> int: Relies on trigger to throw an exception, otherwise it assumes execution was successful. """ + # We handle the case where the job run is cancelled a bit differently than the OSS operator. + # Essentially, we do not want to retry the task if the job run is cancelled, whereas the OSS operator will + # retry the task if the job run is cancelled. This has been specifically handled here differently based upon + # the feedback from a user. And hence, while we are deprecating this operator, we are not changing the behavior + # of the `execute_complete` method. We can check if the wider OSS community wants this behavior to be changed + # in the future as it is here, and then we can remove this override. if event["status"] == "cancelled": self.log.info("Job run %s has been cancelled.", str(event["run_id"])) self.log.info("Task will not be retried.") diff --git a/astronomer/providers/dbt/cloud/sensors/dbt.py b/astronomer/providers/dbt/cloud/sensors/dbt.py index c5e73930a..e6b1683e8 100644 --- a/astronomer/providers/dbt/cloud/sensors/dbt.py +++ b/astronomer/providers/dbt/cloud/sensors/dbt.py @@ -1,61 +1,35 @@ -import time -from typing import Any, Dict +from __future__ import annotations -from airflow.providers.dbt.cloud.sensors.dbt import DbtCloudJobRunSensor +import warnings +from typing import Any -from astronomer.providers.dbt.cloud.triggers.dbt import DbtCloudRunJobTrigger -from astronomer.providers.utils.sensor_util import poke, raise_error_or_skip_exception -from astronomer.providers.utils.typing_compat import Context +from airflow.providers.dbt.cloud.sensors.dbt import DbtCloudJobRunSensor class DbtCloudJobRunSensorAsync(DbtCloudJobRunSensor): """ - Checks the status of a dbt Cloud job run. - - .. seealso:: - For more information on sync Sensor DbtCloudJobRunSensor, take a look at the guide:: - :ref:`howto/operator:DbtCloudJobRunSensor` - - :param dbt_cloud_conn_id: The connection identifier for connecting to dbt Cloud. - :param run_id: The job run identifier. - :param account_id: The dbt Cloud account identifier. - :param timeout: Time in seconds to wait for a job run to reach a terminal status. Defaults to 7 days. + This class is deprecated. + Use :class: `~airflow.providers.dbt.cloud.sensors.dbt.DbtCloudJobRunSensor` instead + and set `deferrable` param to `True` instead. """ - def __init__( - self, - *, - poll_interval: float = 5, - timeout: float = 60 * 60 * 24 * 7, - **kwargs: Any, - ): - self.poll_interval = poll_interval - self.timeout = timeout - super().__init__(**kwargs) - - def execute(self, context: "Context") -> None: - """Defers trigger class to poll for state of the job run until it reaches a failure state or success state""" - if not poke(self, context): - end_time = time.time() + self.timeout - self.defer( - timeout=self.execution_timeout, - trigger=DbtCloudRunJobTrigger( - run_id=self.run_id, - conn_id=self.dbt_cloud_conn_id, - account_id=self.account_id, - poll_interval=self.poll_interval, - end_time=end_time, - ), - method_name="execute_complete", + def __init__(self, *args: Any, **kwargs: Any) -> None: + warnings.warn( + ( + "This class is deprecated. " + "Use `airflow.providers.dbt.cloud.sensors.dbt.DbtCloudJobRunSensor` " + "and set `deferrable` param to `True` instead." + ), + DeprecationWarning, + stacklevel=2, + ) + # TODO: Remove once deprecated + if kwargs.get("poll_interval"): + warnings.warn( + "Argument `poll_interval` is deprecated and will be removed " + "in a future release. Please use `poke_interval` instead.", + DeprecationWarning, + stacklevel=2, ) - - def execute_complete(self, context: "Context", event: Dict[str, Any]) -> int: - """ - Callback for when the trigger fires - returns immediately. - Relies on trigger to throw an exception, otherwise it assumes execution was - successful. - """ - if event["status"] in ["error", "cancelled"]: - raise_error_or_skip_exception(self.soft_fail, event["message"]) - self.log.info(event["message"]) - return int(event["run_id"]) + kwargs["poke_interval"] = kwargs.pop("poll_interval") + super().__init__(*args, deferrable=True, **kwargs) diff --git a/astronomer/providers/dbt/cloud/triggers/dbt.py b/astronomer/providers/dbt/cloud/triggers/dbt.py index 7b63beccf..93fa1791d 100644 --- a/astronomer/providers/dbt/cloud/triggers/dbt.py +++ b/astronomer/providers/dbt/cloud/triggers/dbt.py @@ -1,5 +1,6 @@ import asyncio import time +import warnings from typing import Any, AsyncIterator, Dict, Optional, Tuple from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudJobRunStatus @@ -10,14 +11,8 @@ class DbtCloudRunJobTrigger(BaseTrigger): """ - DbtCloudRunJobTrigger is triggered with run id and account id, makes async Http call to dbt and get the status - for the submitted job with run id in polling interval of time. - - :param conn_id: The connection identifier for connecting to Dbt. - :param run_id: The ID of a dbt Cloud job. - :param end_time: Time in seconds to wait for a job run to reach a terminal status. Defaults to 7 days. - :param account_id: The ID of a dbt Cloud account. - :param poll_interval: polling period in seconds to check for the status. + This class is deprecated and will be removed in 2.0.0. + Use :class: `~airflow.providers.dbt.triggers.dbt.DbtCloudRunJobTrigger` instead. """ def __init__( @@ -28,6 +23,14 @@ def __init__( poll_interval: float, account_id: Optional[int], ): + warnings.warn( + ( + "This class is deprecated. " + "Use `airflow.providers.dbt.triggers.dbt.DbtCloudRunJobTrigger` instead." + ), + DeprecationWarning, + stacklevel=2, + ) super().__init__() self.run_id = run_id self.account_id = account_id diff --git a/setup.cfg b/setup.cfg index 868385db7..ab161f7d5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -59,7 +59,7 @@ databricks = apache-airflow-providers-databricks>=2.2.0 databricks-sql-connector>=2.0.4;python_version>='3.10' dbt.cloud = - apache-airflow-providers-dbt-cloud>=2.1.0 + apache-airflow-providers-dbt-cloud>=3.5.1 google = apache-airflow-providers-google>=8.1.0 gcloud-aio-storage @@ -130,7 +130,7 @@ all = apache-airflow-providers-microsoft-azure>=8.5.1 asyncssh>=2.12.0 databricks-sql-connector>=2.0.4;python_version>='3.10' - apache-airflow-providers-dbt-cloud>=2.1.0 + apache-airflow-providers-dbt-cloud>=3.5.1 gcloud-aio-bigquery gcloud-aio-storage kubernetes_asyncio diff --git a/tests/dbt/cloud/operators/test_dbt.py b/tests/dbt/cloud/operators/test_dbt.py index ec78e655e..b124ea06a 100644 --- a/tests/dbt/cloud/operators/test_dbt.py +++ b/tests/dbt/cloud/operators/test_dbt.py @@ -1,18 +1,8 @@ -from datetime import datetime -from unittest import mock - -import pytest -from airflow.exceptions import AirflowException, AirflowFailException, TaskDeferred -from airflow.models import DAG, DagRun, TaskInstance -from airflow.providers.dbt.cloud.hooks.dbt import ( - DbtCloudJobRunException, - DbtCloudJobRunStatus, -) +from airflow.models import DAG +from airflow.providers.dbt.cloud.operators.dbt import DbtCloudRunJobOperator from airflow.utils import timezone -from airflow.utils.types import DagRunType from astronomer.providers.dbt.cloud.operators.dbt import DbtCloudRunJobOperatorAsync -from astronomer.providers.dbt.cloud.triggers.dbt import DbtCloudRunJobTrigger class TestDbtCloudRunJobOperatorAsync: @@ -24,36 +14,8 @@ class TestDbtCloudRunJobOperatorAsync: DEFAULT_DATE = timezone.datetime(2021, 1, 1) dag = DAG("test_dbt_cloud_job_run_op", start_date=DEFAULT_DATE) - def create_context(self, task): - execution_date = datetime(2022, 1, 1, 0, 0, 0) - dag_run = DagRun( - dag_id=self.dag.dag_id, - execution_date=execution_date, - run_id=DagRun.generate_run_id(DagRunType.MANUAL, execution_date), - ) - task_instance = TaskInstance(task=task) - task_instance.dag_run = dag_run - task_instance.dag_id = self.dag.dag_id - task_instance.xcom_push = mock.Mock() - return { - "dag": self.dag, - "run_id": dag_run.run_id, - "task": task, - "ti": task_instance, - "task_instance": task_instance, - } - - @mock.patch( - "airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_job_run_status", - return_value=DbtCloudJobRunStatus.SUCCESS.value, - ) - @mock.patch("astronomer.providers.dbt.cloud.operators.dbt.DbtCloudRunJobOperatorAsync.defer") - @mock.patch("airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_connection") - @mock.patch("airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.trigger_job_run") - def test_dbt_run_job_op_async_succeeded_before_deferred( - self, mock_trigger_job_run, mock_dbt_hook, mock_defer, mock_job_run_status - ): - dbt_op = DbtCloudRunJobOperatorAsync( + def test_init(self): + task = DbtCloudRunJobOperatorAsync( dbt_cloud_conn_id=self.CONN_ID, task_id=f"{self.TASK_ID}", job_id=self.DBT_RUN_ID, @@ -61,118 +23,6 @@ def test_dbt_run_job_op_async_succeeded_before_deferred( timeout=self.TIMEOUT, dag=self.dag, ) - dbt_op.execute(self.create_context(dbt_op)) - assert not mock_defer.called - - @pytest.mark.parametrize( - "status", (DbtCloudJobRunStatus.CANCELLED.value, DbtCloudJobRunStatus.ERROR.value) - ) - @mock.patch( - "airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_job_run_status", - ) - @mock.patch("astronomer.providers.dbt.cloud.operators.dbt.DbtCloudRunJobOperatorAsync.defer") - @mock.patch("airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_connection") - @mock.patch("airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.trigger_job_run") - def test_dbt_run_job_op_async_failed_before_deferred( - self, mock_trigger_job_run, mock_dbt_hook, mock_defer, mock_job_run_status, status - ): - mock_job_run_status.return_value = status - dbt_op = DbtCloudRunJobOperatorAsync( - dbt_cloud_conn_id=self.CONN_ID, - task_id=f"{self.TASK_ID}{status}", - job_id=self.DBT_RUN_ID, - check_interval=self.CHECK_INTERVAL, - timeout=self.TIMEOUT, - dag=self.dag, - ) - with pytest.raises(DbtCloudJobRunException): - dbt_op.execute(self.create_context(dbt_op)) - assert not mock_defer.called - - @pytest.mark.parametrize( - "status", - ( - DbtCloudJobRunStatus.QUEUED.value, - DbtCloudJobRunStatus.STARTING.value, - DbtCloudJobRunStatus.RUNNING.value, - ), - ) - @mock.patch( - "airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_job_run_status", - ) - @mock.patch("airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_connection") - @mock.patch("airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.trigger_job_run") - def test_dbt_run_job_op_async(self, mock_trigger_job_run, mock_dbt_hook, mock_job_run_status, status): - """ - Asserts that a task is deferred and an DbtCloudRunJobTrigger will be fired - when the DbtCloudRunJobOperatorAsync is provided with all required arguments - """ - mock_job_run_status.return_value = status - dbt_op = DbtCloudRunJobOperatorAsync( - dbt_cloud_conn_id=self.CONN_ID, - task_id=f"{self.TASK_ID}{status}", - job_id=self.DBT_RUN_ID, - check_interval=self.CHECK_INTERVAL, - timeout=self.TIMEOUT, - dag=self.dag, - ) - with pytest.raises(TaskDeferred) as exc: - dbt_op.execute(self.create_context(dbt_op)) - - assert isinstance(exc.value.trigger, DbtCloudRunJobTrigger), "Trigger is not a DbtCloudRunJobTrigger" - - def test_dbt_run_job_op_with_exception(self): - """Test DbtCloudRunJobOperatorAsync to raise exception""" - dbt_op = DbtCloudRunJobOperatorAsync( - dbt_cloud_conn_id=self.CONN_ID, - task_id=self.TASK_ID, - job_id=self.DBT_RUN_ID, - check_interval=self.CHECK_INTERVAL, - timeout=self.TIMEOUT, - ) - with pytest.raises(AirflowException): - dbt_op.execute_complete( - context=None, event={"status": "error", "message": "test failure message"} - ) - - def test_dbt_run_job_cancelled_exception(self, caplog): - """Test DbtCloudRunJobOperatorAsync to raise exception when job is cancelled""" - dbt_op = DbtCloudRunJobOperatorAsync( - dbt_cloud_conn_id=self.CONN_ID, - task_id=self.TASK_ID, - job_id=self.DBT_RUN_ID, - check_interval=self.CHECK_INTERVAL, - timeout=self.TIMEOUT, - ) - with pytest.raises(AirflowFailException) as exc: - dbt_op.execute_complete( - context=None, - event={ - "status": "cancelled", - "message": f"Job run {self.DBT_RUN_ID} has been cancelled.", - "run_id": self.DBT_RUN_ID, - }, - ) - assert f"Job run {self.DBT_RUN_ID} has been cancelled." in str(exc.value) - assert "Task will not be retried." in caplog.text - - @pytest.mark.parametrize( - "mock_event", - [ - ({"status": "success", "message": "Job run 48617 has completed successfully.", "run_id": 1234}), - ], - ) - def test_dbt_job_execute_complete(self, mock_event): - """Test DbtCloudRunJobOperatorAsync by mocking the success response and assert the log and return value""" - dbt_op = DbtCloudRunJobOperatorAsync( - dbt_cloud_conn_id=self.CONN_ID, - task_id=self.TASK_ID, - job_id=self.DBT_RUN_ID, - check_interval=self.CHECK_INTERVAL, - timeout=self.TIMEOUT, - ) - - with mock.patch.object(dbt_op.log, "info") as mock_log_info: - assert dbt_op.execute_complete(context=None, event=mock_event) == self.DBT_RUN_ID - mock_log_info.assert_called_with("Job run 48617 has completed successfully.") + assert isinstance(task, DbtCloudRunJobOperator) + assert task.deferrable is True diff --git a/tests/dbt/cloud/sensors/test_dbt.py b/tests/dbt/cloud/sensors/test_dbt.py index 305def612..e51d9e097 100644 --- a/tests/dbt/cloud/sensors/test_dbt.py +++ b/tests/dbt/cloud/sensors/test_dbt.py @@ -1,13 +1,6 @@ -from unittest import mock - -import pytest -from airflow import AirflowException -from airflow.exceptions import TaskDeferred +from airflow.providers.dbt.cloud.sensors.dbt import DbtCloudJobRunSensor from astronomer.providers.dbt.cloud.sensors.dbt import DbtCloudJobRunSensorAsync -from astronomer.providers.dbt.cloud.triggers.dbt import DbtCloudRunJobTrigger - -MODULE = "astronomer.providers.dbt.cloud.sensors.dbt" class TestDbtCloudJobRunSensorAsync: @@ -16,35 +9,7 @@ class TestDbtCloudJobRunSensorAsync: DBT_RUN_ID = 1234 TIMEOUT = 300 - @mock.patch(f"{MODULE}.DbtCloudJobRunSensorAsync.defer") - @mock.patch(f"{MODULE}.DbtCloudJobRunSensorAsync.poke", return_value=True) - def test_DbtCloudJobRunSensorAsync_async_finish_before_deferred(self, mock_poke, mock_defer, context): - """Assert task is not deferred when it receives a finish status before deferring""" - task = DbtCloudJobRunSensorAsync( - dbt_cloud_conn_id=self.CONN_ID, - task_id=self.TASK_ID, - run_id=self.DBT_RUN_ID, - timeout=self.TIMEOUT, - ) - task.execute(context) - - assert not mock_defer.called - - @mock.patch(f"{MODULE}.DbtCloudJobRunSensorAsync.poke", return_value=False) - def test_dbt_job_run_sensor_async(self, context): - """Assert execute method defer for Dbt cloud job run status sensors""" - task = DbtCloudJobRunSensorAsync( - dbt_cloud_conn_id=self.CONN_ID, - task_id=self.TASK_ID, - run_id=self.DBT_RUN_ID, - timeout=self.TIMEOUT, - ) - with pytest.raises(TaskDeferred) as exc: - task.execute(context) - assert isinstance(exc.value.trigger, DbtCloudRunJobTrigger), "Trigger is not a DbtCloudRunJobTrigger" - - def test_dbt_job_run_sensor_async_execute_complete_success(self): - """Assert execute_complete log success message when trigger fire with target status""" + def test_init(self): task = DbtCloudJobRunSensorAsync( dbt_cloud_conn_id=self.CONN_ID, task_id=self.TASK_ID, @@ -52,29 +17,5 @@ def test_dbt_job_run_sensor_async_execute_complete_success(self): timeout=self.TIMEOUT, ) - msg = f"Job run {self.DBT_RUN_ID} has completed successfully." - with mock.patch.object(task.log, "info") as mock_log_info: - task.execute_complete( - context={}, event={"status": "success", "message": msg, "run_id": self.DBT_RUN_ID} - ) - mock_log_info.assert_called_with(msg) - - @pytest.mark.parametrize( - "mock_status, mock_message", - [ - ("cancelled", "Job run 1234 has been cancelled."), - ("error", "Job run 1234 has failed."), - ], - ) - def test_dbt_job_run_sensor_async_execute_complete_failure(self, mock_status, mock_message): - """Assert execute_complete method to raise exception on the cancelled and error status""" - task = DbtCloudJobRunSensorAsync( - dbt_cloud_conn_id=self.CONN_ID, - task_id=self.TASK_ID, - run_id=self.DBT_RUN_ID, - timeout=self.TIMEOUT, - ) - with pytest.raises(AirflowException): - task.execute_complete( - context={}, event={"status": mock_status, "message": mock_message, "run_id": self.DBT_RUN_ID} - ) + assert isinstance(task, DbtCloudJobRunSensor) + assert task.deferrable is True From b0edbd9295cecbe3dc26cc08589b798ecf9d7033 Mon Sep 17 00:00:00 2001 From: Pankaj Koti Date: Tue, 23 Jan 2024 18:35:24 +0530 Subject: [PATCH 2/2] Add test for poll_interval deprecation warning --- tests/dbt/cloud/sensors/test_dbt.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/dbt/cloud/sensors/test_dbt.py b/tests/dbt/cloud/sensors/test_dbt.py index e51d9e097..eb929048f 100644 --- a/tests/dbt/cloud/sensors/test_dbt.py +++ b/tests/dbt/cloud/sensors/test_dbt.py @@ -1,3 +1,4 @@ +import pytest from airflow.providers.dbt.cloud.sensors.dbt import DbtCloudJobRunSensor from astronomer.providers.dbt.cloud.sensors.dbt import DbtCloudJobRunSensorAsync @@ -19,3 +20,15 @@ def test_init(self): assert isinstance(task, DbtCloudJobRunSensor) assert task.deferrable is True + + def test_poll_interval_deprecation_warning(self): + """Test DeprecationWarning for DbtCloudJobRunSensorAsync by setting param poll_interval""" + # TODO: Remove once deprecated + with pytest.warns(expected_warning=DeprecationWarning): + DbtCloudJobRunSensorAsync( + dbt_cloud_conn_id=self.CONN_ID, + task_id=self.TASK_ID, + run_id=self.DBT_RUN_ID, + timeout=self.TIMEOUT, + poll_interval=5.0, + )