Skip to content

Commit

Permalink
feat(providers/google): deprecate GCSObjectUpdateSensorAsync
Browse files Browse the repository at this point in the history
  • Loading branch information
Lee-W committed Jan 22, 2024
1 parent 54fe67c commit add7c91
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 125 deletions.
77 changes: 14 additions & 63 deletions astronomer/providers/google/cloud/sensors/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from __future__ import annotations

import warnings
from datetime import timedelta
from typing import Any

from airflow.providers.google.cloud.sensors.gcs import (
Expand All @@ -12,12 +11,6 @@
GCSUploadSessionCompleteSensor,
)

from astronomer.providers.google.cloud.triggers.gcs import (
GCSCheckBlobUpdateTimeTrigger,
)
from astronomer.providers.utils.sensor_util import poke, raise_error_or_skip_exception
from astronomer.providers.utils.typing_compat import Context


class GCSObjectExistenceSensorAsync(GCSObjectExistenceSensor):
"""
Expand Down Expand Up @@ -124,31 +117,14 @@ def __init__(

class GCSObjectUpdateSensorAsync(GCSObjectUpdateSensor):
"""
Async version to check if an object is updated in Google Cloud Storage
:param bucket: The Google Cloud Storage bucket where the object is.
:param object: The name of the object to download in the Google cloud
storage bucket.
:param ts_func: Callback for defining the update condition. The default callback
returns execution_date + schedule. The callback takes the context
as parameter.
:param google_cloud_conn_id: The connection ID to use when
connecting to Google Cloud Storage.
:param delegate_to: (Removed in apache-airflow-providers-google release 10.0.0, use impersonation_chain instead)
The account to impersonate using domain-wide delegation of authority, if any. For this to work, the service
account making the request must have domain-wide delegation enabled.
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
This class is deprecated.
Please use :class:`~airflow.providers.google.cloud.sensors.gcs.GCSObjectUpdateSensor`
and set `deferrable` param to `True` instead.
"""

def __init__(
self,
*args: Any,
polling_interval: float = 5,
**kwargs: Any,
) -> None:
Expand All @@ -161,38 +137,13 @@ def __init__(
DeprecationWarning,
stacklevel=2,
)
super().__init__(**kwargs)

def execute(self, context: Context) -> None:
"""Airflow runs this method on the worker and defers using the trigger."""
hook_params = {"impersonation_chain": self.impersonation_chain}
if hasattr(self, "delegate_to"):
hook_params["delegate_to"] = self.delegate_to

if not poke(self, context):
self.defer(
timeout=timedelta(seconds=self.timeout),
trigger=GCSCheckBlobUpdateTimeTrigger(
bucket=self.bucket,
object_name=self.object,
ts=self.ts_func(context),
poke_interval=self.poke_interval,
google_cloud_conn_id=self.google_cloud_conn_id,
hook_params=hook_params,
),
method_name="execute_complete",
)

def execute_complete(self, context: dict[str, Any], event: dict[str, str] | None = None) -> str: # type: ignore[return]
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
if event:
if event["status"] == "success":
self.log.info(
"Sensor checks update time for object %s in bucket : %s", self.object, self.bucket
)
return event["message"]
raise_error_or_skip_exception(self.soft_fail, event["message"])
warnings.warn(
(
"This module is deprecated and will be removed in 2.0.0. "
"Please use :class: `~airflow.providers.google.cloud.sensors.gcs.GCSObjectUpdateSensor` "
"and set `deferrable` param to `True` instead."
),
DeprecationWarning,
stacklevel=2,
)
super().__init__(*args, deferrable=True, **kwargs)
8 changes: 8 additions & 0 deletions astronomer/providers/google/cloud/triggers/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,14 @@ def __init__(
google_cloud_conn_id: str,
hook_params: dict[str, Any],
):
warnings.warn(
(
"This class is deprecated and will be removed in 2.0.0. "
"Use :class: `~airflow.providers.google.cloud.triggers.gcs.GCSUploadSessionTrigger` instead"
),
DeprecationWarning,
stacklevel=2,
)
super().__init__()
self.bucket = bucket
self.object_name = object_name
Expand Down
71 changes: 9 additions & 62 deletions tests/google/cloud/sensors/test_gcs.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from unittest import mock

import pytest
from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.google.cloud.sensors.gcs import (
GCSObjectExistenceSensor,
GCSObjectsWithPrefixExistenceSensor,
GCSObjectUpdateSensor,
GCSUploadSessionCompleteSensor,
)

Expand All @@ -14,10 +12,6 @@
GCSObjectUpdateSensorAsync,
GCSUploadSessionCompleteSensorAsync,
)
from astronomer.providers.google.cloud.triggers.gcs import (
GCSCheckBlobUpdateTimeTrigger,
)
from tests.utils.airflow_util import create_context

TEST_BUCKET = "TEST_BUCKET"
TEST_OBJECT = "TEST_OBJECT"
Expand Down Expand Up @@ -69,60 +63,13 @@ def test_init(self):


class TestGCSObjectUpdateSensorAsync:
OPERATOR = GCSObjectUpdateSensorAsync(
task_id="gcs-obj-update",
bucket=TEST_BUCKET,
object=TEST_OBJECT,
google_cloud_conn_id=TEST_GCP_CONN_ID,
)

@mock.patch(f"{MODULE}.GCSObjectUpdateSensorAsync.defer")
@mock.patch(f"{MODULE}.GCSObjectUpdateSensorAsync.poke", return_value=True)
def test_gcs_object_update_sensor_async_finish_before_deferred(self, mock_poke, mock_defer, context):
"""Assert task is not deferred when it receives a finish status before deferring"""
self.OPERATOR.execute(create_context(self.OPERATOR))
assert not mock_defer.called

@mock.patch(f"{MODULE}.GCSObjectUpdateSensorAsync.poke", return_value=False)
def test_gcs_object_update_sensor_async(self, context):
"""
Asserts that a task is deferred and a GCSCheckBlobUpdateTimeTrigger will be fired
when the GCSObjectUpdateSensorAsync is executed.
"""

with pytest.raises(TaskDeferred) as exc:
self.OPERATOR.execute(create_context(self.OPERATOR))
assert isinstance(
exc.value.trigger, GCSCheckBlobUpdateTimeTrigger
), "Trigger is not a GCSCheckBlobUpdateTimeTrigger"

def test_gcs_object_update_sensor_async_execute_failure(self, context):
"""Tests that an AirflowException is raised in case of error event"""

with pytest.raises(AirflowException):
self.OPERATOR.execute_complete(
context=context, event={"status": "error", "message": "test failure message"}
)

def test_gcs_object_update_sensor_async_execute_complete(self, context):
"""Asserts that logging occurs as expected"""

with mock.patch.object(self.OPERATOR.log, "info") as mock_log_info:
self.OPERATOR.execute_complete(
context=context, event={"status": "success", "message": "Job completed"}
)
mock_log_info.assert_called_with(
"Sensor checks update time for object %s in bucket : %s", TEST_OBJECT, TEST_BUCKET
def test_init(self):
task = GCSObjectUpdateSensorAsync(
task_id="gcs-obj-update",
bucket=TEST_BUCKET,
object=TEST_OBJECT,
google_cloud_conn_id=TEST_GCP_CONN_ID,
)

def test_poll_interval_deprecation_warning(self):
"""Test DeprecationWarning for GCSObjectUpdateSensorAsync by setting param polling_interval"""
# TODO: Remove once deprecated
with pytest.warns(expected_warning=DeprecationWarning):
GCSObjectUpdateSensorAsync(
task_id="task-id",
bucket=TEST_BUCKET,
object=TEST_OBJECT,
google_cloud_conn_id=TEST_GCP_CONN_ID,
polling_interval=5.0,
)
assert isinstance(task, GCSObjectUpdateSensor)
assert task.deferrable is True

0 comments on commit add7c91

Please sign in to comment.