diff --git a/astronomer/providers/google/cloud/hooks/dataproc.py b/astronomer/providers/google/cloud/hooks/dataproc.py
index 923e298f5..7babf529a 100644
--- a/astronomer/providers/google/cloud/hooks/dataproc.py
+++ b/astronomer/providers/google/cloud/hooks/dataproc.py
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 import warnings
-from typing import Any, Optional, Sequence, Tuple, Union
+from typing import Any, Sequence, Union
 
 from airflow.providers.google.common.consts import CLIENT_INFO
 from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
@@ -21,10 +23,25 @@ class DataprocHookAsync(GoogleBaseHook):
-    """Async Hook for Google Cloud Dataproc APIs"""
+    """Async Hook for Google Cloud Dataproc APIs
+
+    This class is deprecated and will be removed in 2.0.0.
+    Use :class:`~airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook` instead.
+    """
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        warnings.warn(
+            (
+                "This module is deprecated and will be removed in 2.0.0. "
+                "Please use `airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook`."
+            ),
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        super().__init__(*args, **kwargs)
 
     def get_cluster_client(
-        self, region: Optional[str] = None, location: Optional[str] = None
+        self, region: str | None = None, location: str | None = None
    ) -> ClusterControllerAsyncClient:
         """
         Get async cluster controller client for GCP Dataproc.
@@ -46,7 +63,7 @@ def get_cluster_client(
         )
 
     def get_job_client(
-        self, region: Optional[str] = None, location: Optional[str] = None
+        self, region: str | None = None, location: str | None = None
     ) -> JobControllerAsyncClient:
         """
         Get async job controller for GCP Dataproc.
@@ -73,7 +90,7 @@ async def get_cluster(
         cluster_name: str,
         project_id: str,
         retry: OptionalRetry = gapic_v1.method.DEFAULT,
-        metadata: Sequence[Tuple[str, str]] = (),
+        metadata: Sequence[tuple[str, str]] = (),
     ) -> clusters.Cluster:
         """
         Get a cluster details from GCP using `ClusterControllerAsyncClient`
@@ -100,10 +117,10 @@ async def get_job(
         job_id: str,
         project_id: str,
         timeout: float = 5,
-        region: Optional[str] = None,
-        location: Optional[str] = None,
+        region: str | None = None,
+        location: str | None = None,
         retry: OptionalRetry = gapic_v1.method.DEFAULT,
-        metadata: Sequence[Tuple[str, str]] = (),
+        metadata: Sequence[tuple[str, str]] = (),
     ) -> JobType:
         """
         Gets the resource representation for a job using `JobControllerAsyncClient`.
@@ -128,8 +145,8 @@ async def get_job(
         return job
 
     def _get_client_options_and_region(
-        self, region: Optional[str] = None, location: Optional[str] = None
-    ) -> Tuple[ClientOptions, Optional[str]]:
+        self, region: str | None = None, location: str | None = None
+    ) -> tuple[ClientOptions, str | None]:
         """
         Checks for location if present or not and creates a client options using the provided region/location
diff --git a/astronomer/providers/google/cloud/operators/dataproc.py b/astronomer/providers/google/cloud/operators/dataproc.py
index 6adcbcec8..fd455aa4b 100644
--- a/astronomer/providers/google/cloud/operators/dataproc.py
+++ b/astronomer/providers/google/cloud/operators/dataproc.py
@@ -1,69 +1,22 @@
 """This module contains Google Dataproc operators."""
 from __future__ import annotations
 
-import time
 import warnings
 from typing import Any
 
-from airflow.exceptions import AirflowException
-from airflow.providers.google.cloud.hooks.dataproc import DataprocHook
-from airflow.providers.google.cloud.links.dataproc import (
-    DATAPROC_CLUSTER_LINK,
-    DataprocLink,
-)
 from airflow.providers.google.cloud.operators.dataproc import (
     DataprocCreateClusterOperator,
     DataprocDeleteClusterOperator,
     DataprocSubmitJobOperator,
     DataprocUpdateClusterOperator,
 )
-from google.api_core.exceptions import AlreadyExists, NotFound
-from google.cloud.dataproc_v1 import Cluster
-
-from astronomer.providers.google.cloud.triggers.dataproc import (
-    DataprocCreateClusterTrigger,
-    DataprocDeleteClusterTrigger,
-)
-from astronomer.providers.utils.typing_compat import Context
 
 
 class DataprocCreateClusterOperatorAsync(DataprocCreateClusterOperator):
     """
-    Create a new cluster on Google Cloud Dataproc Asynchronously.
-
-    :param project_id: The ID of the google cloud project in which
-        to create the cluster. (templated)
-    :param cluster_name: Name of the cluster to create
-    :param labels: Labels that will be assigned to created cluster
-    :param cluster_config: Required. The cluster config to create.
-        If a dict is provided, it must be of the same form as the protobuf message
-        :class:`~google.cloud.dataproc_v1.types.ClusterConfig`
-    :param virtual_cluster_config: Optional. The virtual cluster config, used when creating a Dataproc
-        cluster that does not directly control the underlying compute resources, for example, when creating a
-        `Dataproc-on-GKE cluster
-        `
-    :param region: The specified region where the dataproc cluster is created.
-    :param delete_on_error: If true the cluster will be deleted if created with ERROR state. Default
-        value is true.
-    :param use_if_exists: If true use existing cluster
-    :param request_id: Optional. A unique id used to identify the request. If the server receives two
-        ``DeleteClusterRequest`` requests with the same id, then the second request will be ignored and the
-        first ``google.longrunning.Operation`` created and stored in the backend is returned.
-    :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
-        retried.
-    :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
-        ``retry`` is specified, the timeout applies to each individual attempt.
-    :param metadata: Additional metadata that is provided to the method.
-    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
-    :param impersonation_chain: Optional service account to impersonate using short-term
-        credentials, or chained list of accounts required to get the access_token
-        of the last account in the list, which will be impersonated in the request.
-        If set as a string, the account must grant the originating account
-        the Service Account Token Creator IAM role.
-        If set as a sequence, the identities from the list must grant
-        Service Account Token Creator IAM role to the directly preceding identity, with first
-        account from the list granting this role to the originating account (templated).
-    :param polling_interval: Time in seconds to sleep between checks of cluster status
+    This class is deprecated.
+    Please use :class:`~airflow.providers.google.cloud.operators.dataproc.DataprocCreateClusterOperator`
+    and set `deferrable` param to `True` instead.
     """
 
     def __init__(
@@ -72,98 +25,24 @@ def __init__(
         *,
         polling_interval: float = 5.0,
         **kwargs: Any,
     ):
-        super().__init__(**kwargs)
-        self.polling_interval = polling_interval
-
-    def execute(self, context: Context) -> Any:
-        """Call create cluster API and defer to DataprocCreateClusterTrigger to check the status"""
-        hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
-        DataprocLink.persist(
-            context=context, task_instance=self, url=DATAPROC_CLUSTER_LINK, resource=self.cluster_name
-        )
-        try:
-            hook.create_cluster(
-                region=self.region,
-                project_id=self.project_id,
-                cluster_name=self.cluster_name,
-                cluster_config=self.cluster_config,
-                labels=self.labels,
-                request_id=self.request_id,
-                retry=self.retry,
-                timeout=self.timeout,
-                metadata=self.metadata,
-            )
-        except AlreadyExists:
-            if not self.use_if_exists:
-                raise
-            self.log.info("Cluster already exists.")
-
-        cluster = hook.get_cluster(
-            project_id=self.project_id, region=self.region, cluster_name=self.cluster_name
+        warnings.warn(
+            (
+                "This module is deprecated and will be removed in 2.0.0. "
+                "Please use `airflow.providers.google.cloud.operators.dataproc.DataprocCreateClusterOperator` "
+                "and set `deferrable` param to `True` instead."
+            ),
+            DeprecationWarning,
+            stacklevel=2,
         )
-        if cluster.status.state == cluster.status.State.RUNNING:
-            self.log.info("Cluster created.")
-            return Cluster.to_dict(cluster)
-        else:
-            end_time: float = time.time() + self.timeout
-            self.defer(
-                trigger=DataprocCreateClusterTrigger(
-                    project_id=self.project_id,
-                    region=self.region,
-                    cluster_name=self.cluster_name,
-                    end_time=end_time,
-                    metadata=self.metadata,
-                    delete_on_error=self.delete_on_error,
-                    cluster_config=self.cluster_config,
-                    labels=self.labels,
-                    gcp_conn_id=self.gcp_conn_id,
-                    impersonation_chain=self.impersonation_chain,
-                    polling_interval=self.polling_interval,
-                ),
-                method_name="execute_complete",
-            )
-
-    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> Any:
-        """
-        Callback for when the trigger fires - returns immediately.
-        Relies on trigger to throw an exception, otherwise it assumes execution was
-        successful.
- """ - if event and event["status"] == "success": - self.log.info("Cluster created successfully \n %s", event["data"]) - return event["data"] - elif event and event["status"] == "error": - raise AirflowException(event["message"]) - raise AirflowException("No event received in trigger callback") + kwargs["polling_interval_seconds"] = polling_interval + super().__init__(deferrable=True, **kwargs) class DataprocDeleteClusterOperatorAsync(DataprocDeleteClusterOperator): """ - Delete a cluster on Google Cloud Dataproc Asynchronously. - - :param region: Required. The Cloud Dataproc region in which to handle the request (templated). - :param cluster_name: Required. The cluster name (templated). - :param project_id: Optional. The ID of the Google Cloud project that the cluster belongs to (templated). - :param cluster_uuid: Optional. Specifying the ``cluster_uuid`` means the RPC should fail - if cluster with specified UUID does not exist. - :param request_id: Optional. A unique id used to identify the request. If the server receives two - ``DeleteClusterRequest`` requests with the same id, then the second request will be ignored and the - first ``google.longrunning.Operation`` created and stored in the backend is returned. - :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be - retried. - :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if - ``retry`` is specified, the timeout applies to each individual attempt. - :param metadata: Additional metadata that is provided to the method. - :param gcp_conn_id: The connection ID to use connecting to Google Cloud. - :param impersonation_chain: Optional service account to impersonate using short-term - credentials, or chained list of accounts required to get the access_token - of the last account in the list, which will be impersonated in the request. - If set as a string, the account must grant the originating account - the Service Account Token Creator IAM role. - If set as a sequence, the identities from the list must grant - Service Account Token Creator IAM role to the directly preceding identity, with first - account from the list granting this role to the originating account (templated). - :param polling_interval: Time in seconds to sleep between checks of cluster status + This class is deprecated. + Please use :class: `~airflow.providers.google.cloud.operators.dataproc.DataprocDeleteClusterOperator` + and set `deferrable` param to `True` instead. 
""" def __init__( @@ -172,61 +51,19 @@ def __init__( polling_interval: float = 5.0, **kwargs: Any, ): - super().__init__(**kwargs) - self.polling_interval = polling_interval - if self.timeout is None: - self.timeout: float = 24 * 60 * 60 - - def execute(self, context: Context) -> None: - """Call delete cluster API and defer to wait for cluster to completely deleted""" - hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) - self.log.info("Deleting cluster: %s", self.cluster_name) - hook.delete_cluster( - project_id=self.project_id, - region=self.region, - cluster_name=self.cluster_name, - cluster_uuid=self.cluster_uuid, - request_id=self.request_id, - retry=self.retry, - metadata=self.metadata, - ) - - try: - hook.get_cluster(project_id=self.project_id, region=self.region, cluster_name=self.cluster_name) - except NotFound: - self.log.info("Cluster deleted.") - return - except Exception as e: - raise AirflowException(str(e)) - - end_time: float = time.time() + self.timeout - - self.defer( - trigger=DataprocDeleteClusterTrigger( - gcp_conn_id=self.gcp_conn_id, - project_id=self.project_id, - region=self.region, - cluster_name=self.cluster_name, - request_id=self.request_id, - retry=self.retry, - end_time=end_time, - metadata=self.metadata, - impersonation_chain=self.impersonation_chain, + warnings.warn( + ( + "This module is deprecated and will be removed in 2.0.0." + "Please use `airflow.providers.google.cloud.operators.dataproc.DataprocDeleteClusterOperator`" + "and set `deferrable` param to `True` instead." ), - method_name="execute_complete", + DeprecationWarning, + stacklevel=2, ) - - def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> Any: - """ - Callback for when the trigger fires - returns immediately. - Relies on trigger to throw an exception, otherwise it assumes execution was - successful. - """ - if event and event["status"] == "error": - raise AirflowException(event["message"]) - elif event is None: - raise AirflowException("No event received in trigger callback") - self.log.info("Cluster deleted.") + kwargs["polling_interval_seconds"] = polling_interval + super().__init__(deferrable=True, **kwargs) + if self.timeout is None: + self.timeout: float = 24 * 60 * 60 class DataprocSubmitJobOperatorAsync(DataprocSubmitJobOperator): @@ -251,42 +88,9 @@ def __init__(self, *args: Any, **kwargs: Any): class DataprocUpdateClusterOperatorAsync(DataprocUpdateClusterOperator): """ - Updates an existing cluster in a Google cloud platform project. - - :param region: Required. The Cloud Dataproc region in which to handle the request. - :param project_id: Optional. The ID of the Google Cloud project the cluster belongs to. - :param cluster_name: Required. The cluster name. - :param cluster: Required. The changes to the cluster. - If a dict is provided, it must be of the same form as the protobuf message - :class:`~google.cloud.dataproc_v1.types.Cluster` - :param update_mask: Required. Specifies the path, relative to ``Cluster``, of the field to update. For - example, to change the number of workers in a cluster to 5, the ``update_mask`` parameter would be - specified as ``config.worker_config.num_instances``, and the ``PATCH`` request body would specify the - new value. If a dict is provided, it must be of the same form as the protobuf message - :class:`~google.protobuf.field_mask_pb2.FieldMask` - :param graceful_decommission_timeout: Optional. Timeout for graceful YARN decommissioning. 
-        decommissioning allows removing nodes from the cluster without interrupting jobs in progress. Timeout
-        specifies how long to wait for jobs in progress to finish before forcefully removing nodes (and
-        potentially interrupting jobs). Default timeout is 0 (for forceful decommission), and the maximum
-        allowed timeout is 1 day.
-    :param request_id: Optional. A unique id used to identify the request. If the server receives two
-        ``UpdateClusterRequest`` requests with the same id, then the second request will be ignored and the
-        first ``google.longrunning.Operation`` created and stored in the backend is returned.
-    :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
-        retried.
-    :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
-        ``retry`` is specified, the timeout applies to each individual attempt.
-    :param metadata: Additional metadata that is provided to the method.
-    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
-    :param impersonation_chain: Optional service account to impersonate using short-term
-        credentials, or chained list of accounts required to get the access_token
-        of the last account in the list, which will be impersonated in the request.
-        If set as a string, the account must grant the originating account
-        the Service Account Token Creator IAM role.
-        If set as a sequence, the identities from the list must grant
-        Service Account Token Creator IAM role to the directly preceding identity, with first
-        account from the list granting this role to the originating account (templated).
-    :param polling_interval: Time in seconds to sleep between checks of cluster status
+    This class is deprecated.
+    Please use :class:`~airflow.providers.google.cloud.operators.dataproc.DataprocUpdateClusterOperator`
+    and set `deferrable` param to `True` instead.
     """
 
     def __init__(
@@ -295,61 +99,16 @@ def __init__(
         *,
         polling_interval: float = 5.0,
         **kwargs: Any,
     ):
-        super().__init__(**kwargs)
-        self.polling_interval = polling_interval
+        warnings.warn(
+            (
+                "This module is deprecated and will be removed in 2.0.0. "
+                "Please use `airflow.providers.google.cloud.operators.dataproc.DataprocUpdateClusterOperator` "
+                "and set `deferrable` param to `True` instead."
+            ),
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        kwargs["polling_interval_seconds"] = polling_interval
+        super().__init__(deferrable=True, **kwargs)
         if self.timeout is None:
             self.timeout: float = 24 * 60 * 60
-
-    def execute(self, context: Context) -> None:
-        """Call update cluster API and defer to wait for cluster update to complete"""
-        hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
-        # Save data required by extra links no matter what the cluster status will be
-        DataprocLink.persist(
-            context=context, task_instance=self, url=DATAPROC_CLUSTER_LINK, resource=self.cluster_name
-        )
-        self.log.info("Updating %s cluster.", self.cluster_name)
-        hook.update_cluster(
-            project_id=self.project_id,
-            region=self.region,
-            cluster_name=self.cluster_name,
-            cluster=self.cluster,
-            update_mask=self.update_mask,
-            graceful_decommission_timeout=self.graceful_decommission_timeout,
-            request_id=self.request_id,
-            retry=self.retry,
-            metadata=self.metadata,
-        )
-        cluster = hook.get_cluster(
-            project_id=self.project_id, region=self.region, cluster_name=self.cluster_name
-        )
-        if cluster.status.state == cluster.status.State.RUNNING:
-            self.log.info("Updated %s cluster.", self.cluster_name)
-        else:
-            end_time: float = time.time() + self.timeout
-
-            self.defer(
-                trigger=DataprocCreateClusterTrigger(
-                    project_id=self.project_id,
-                    region=self.region,
-                    cluster_name=self.cluster_name,
-                    end_time=end_time,
-                    metadata=self.metadata,
-                    gcp_conn_id=self.gcp_conn_id,
-                    impersonation_chain=self.impersonation_chain,
-                    polling_interval=self.polling_interval,
-                ),
-                method_name="execute_complete",
-            )
-
-    def execute_complete(self, context: Context, event: dict[str, Any]) -> Any:
-        """
-        Callback for when the trigger fires - returns immediately.
-        Relies on trigger to throw an exception, otherwise it assumes execution was
-        successful.
-        """
-        if event and event["status"] == "success":
-            self.log.info("Updated %s cluster.", event["cluster_name"])
-            return
-        if event and event["status"] == "error":
-            raise AirflowException(event["message"])
-        raise AirflowException("No event received in trigger callback")
diff --git a/astronomer/providers/google/cloud/triggers/dataproc.py b/astronomer/providers/google/cloud/triggers/dataproc.py
index 560140f29..d04475cd9 100644
--- a/astronomer/providers/google/cloud/triggers/dataproc.py
+++ b/astronomer/providers/google/cloud/triggers/dataproc.py
@@ -19,6 +19,9 @@ class DataprocCreateClusterTrigger(BaseTrigger):
     """
     Asynchronously check the status of a cluster
 
+    This class is deprecated and will be removed in 2.0.0.
+    Use :class:`~airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger` instead.
+
     :param project_id: The ID of the Google Cloud project the cluster belongs to
     :param region: The Cloud Dataproc region in which to handle the request
     :param cluster_name: The name of the cluster
@@ -52,6 +55,14 @@ def __init__(
         polling_interval: float = 5.0,
         **kwargs: Any,
     ):
+        warnings.warn(
+            (
+                "This module is deprecated and will be removed in 2.0.0. "
+                "Please use `airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger`."
+            ),
+            DeprecationWarning,
+            stacklevel=2,
+        )
         super().__init__(**kwargs)
         self.project_id = project_id
         self.region = region
@@ -194,6 +205,9 @@ class DataprocDeleteClusterTrigger(BaseTrigger):
     """
     Asynchronously check the status of a cluster
 
+    This class is deprecated and will be removed in 2.0.0.
+    Use :class:`~airflow.providers.google.cloud.triggers.dataproc.DataprocDeleteClusterTrigger` instead.
+
     :param cluster_name: The name of the cluster
     :param end_time: Time in second left to check the cluster status
     :param project_id: The ID of the Google Cloud project the cluster belongs to
@@ -223,6 +237,14 @@ def __init__(
         polling_interval: float = 5.0,
         **kwargs: Any,
     ):
+        warnings.warn(
+            (
+                "This module is deprecated and will be removed in 2.0.0. "
+                "Please use `airflow.providers.google.cloud.triggers.dataproc.DataprocDeleteClusterTrigger`."
+            ),
+            DeprecationWarning,
+            stacklevel=2,
+        )
         super().__init__(**kwargs)
         self.cluster_name = cluster_name
         self.end_time = end_time
diff --git a/setup.cfg b/setup.cfg
index 1e68e119c..e291bf37d 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -61,7 +61,7 @@ databricks =
 dbt.cloud =
     apache-airflow-providers-dbt-cloud>=3.5.1
 google =
-    apache-airflow-providers-google>=8.1.0
+    apache-airflow-providers-google>=10.14.0
     gcloud-aio-storage
     gcloud-aio-bigquery
 http =
@@ -123,7 +123,7 @@ all =
     apache-airflow-providers-apache-livy>=3.7.1
     apache-airflow-providers-cncf-kubernetes>=4
     apache-airflow-providers-databricks>=2.2.0
-    apache-airflow-providers-google>=8.1.0
+    apache-airflow-providers-google>=10.14.0
     apache-airflow-providers-http
     apache-airflow-providers-snowflake
     apache-airflow-providers-sftp
diff --git a/tests/google/cloud/operators/test_dataproc.py b/tests/google/cloud/operators/test_dataproc.py
index 329cd138e..a642f9738 100644
--- a/tests/google/cloud/operators/test_dataproc.py
+++ b/tests/google/cloud/operators/test_dataproc.py
@@ -1,12 +1,9 @@
-from unittest import mock
-
-import pytest
-from airflow.exceptions import AirflowException, TaskDeferred
-from airflow.providers.google.cloud.hooks.bigquery import NotFound
-from airflow.providers.google.cloud.operators.dataproc import DataprocSubmitJobOperator
-from google.api_core.exceptions import AlreadyExists
-from google.cloud import dataproc
-from google.cloud.dataproc_v1 import Cluster
+from airflow.providers.google.cloud.operators.dataproc import (
+    DataprocCreateClusterOperator,
+    DataprocDeleteClusterOperator,
+    DataprocSubmitJobOperator,
+    DataprocUpdateClusterOperator,
+)
 
 from astronomer.providers.google.cloud.operators.dataproc import (
     DataprocCreateClusterOperatorAsync,
@@ -14,11 +11,6 @@
     DataprocSubmitJobOperatorAsync,
     DataprocUpdateClusterOperatorAsync,
 )
-from astronomer.providers.google.cloud.triggers.dataproc import (
-    DataprocCreateClusterTrigger,
-    DataprocDeleteClusterTrigger,
-)
-from tests.utils.airflow_util import create_context
 
 TEST_PROJECT_ID = "test_project_id"
 TEST_CLUSTER_NAME = "test_cluster"
@@ -39,191 +31,25 @@
 
 
 class TestDataprocCreateClusterOperatorAsync:
-    @mock.patch(
-        "astronomer.providers.google.cloud.operators.dataproc.DataprocCreateClusterOperatorAsync.defer"
-    )
-    @mock.patch(f"{MODULE}.get_cluster")
-    @mock.patch(f"{MODULE}.create_cluster")
-    def test_dataproc_operator_create_cluster_execute_async_finish_before_defer(
-        self, mock_create_cluster, mock_get_cluster, mock_defer
-    ):
-        cluster = Cluster(
-            cluster_name="test_cluster",
-            status=dataproc.ClusterStatus(state=dataproc.ClusterStatus.State.RUNNING),
-        )
-        mock_create_cluster.return_value = cluster
-        mock_get_cluster.return_value = cluster
-        task = DataprocCreateClusterOperatorAsync(
-            task_id="task-id", cluster_name="test_cluster", region=TEST_REGION, project_id=TEST_PROJECT_ID
-        )
-        task.execute(create_context(task))
-        assert not mock_defer.called
-
-    @mock.patch(f"{MODULE}.get_cluster")
@mock.patch(f"{MODULE}.create_cluster") - def test_dataproc_operator_create_cluster_execute_async(self, mock_create_cluster, mock_get_cluster): - """ - Asserts that a task is deferred and a DataprocCreateClusterTrigger will be fired - when the DataprocCreateClusterOperatorAsync is executed. - """ - cluster = Cluster( - cluster_name="test_cluster", - status=dataproc.ClusterStatus(state=dataproc.ClusterStatus.State.CREATING), - ) - mock_create_cluster.return_value = cluster - mock_get_cluster.return_value = cluster - + def test_init(self): task = DataprocCreateClusterOperatorAsync( task_id="task-id", cluster_name="test_cluster", region=TEST_REGION, project_id=TEST_PROJECT_ID ) - with pytest.raises(TaskDeferred) as exc: - task.execute(create_context(task)) - assert isinstance( - exc.value.trigger, DataprocCreateClusterTrigger - ), "Trigger is not a DataprocCreateClusterTrigger" + assert isinstance(task, DataprocCreateClusterOperator) + assert task.deferrable is True - @mock.patch(f"{MODULE}.get_cluster") - @mock.patch(f"{MODULE}.create_cluster") - def test_dataproc_operator_create_cluster_execute_async_cluster_exist_exception( - self, mock_create_cluster, mock_get_cluster - ): - """ - Asserts that a task will raise exception when dataproc cluster already exist - and use_if_exists param is False - """ - mock_create_cluster.side_effect = AlreadyExists("Cluster already exist") - mock_get_cluster.return_value = Cluster( - cluster_name="test_cluster", - status=dataproc.ClusterStatus(state=dataproc.ClusterStatus.State.CREATING), - ) - task = DataprocCreateClusterOperatorAsync( +class TestDataprocDeleteClusterOperatorAsync: + def test_init(self): + task = DataprocDeleteClusterOperatorAsync( task_id="task-id", - cluster_name="test_cluster", - region=TEST_REGION, project_id=TEST_PROJECT_ID, - use_if_exists=False, - ) - with pytest.raises(AlreadyExists): - task.execute(create_context(task)) - - @mock.patch(f"{MODULE}.get_cluster") - @mock.patch(f"{MODULE}.create_cluster") - def test_dataproc_operator_create_cluster_execute_async_cluster_exist( - self, mock_create_cluster, mock_get_cluster - ): - """ - Asserts that a task is deferred and a DataprocCreateClusterTrigger will be fired - when the DataprocCreateClusterOperatorAsync is executed when dataproc cluster already exist. 
- """ - mock_create_cluster.side_effect = AlreadyExists("Cluster already exist") - mock_get_cluster.return_value = Cluster( - cluster_name="test_cluster", - status=dataproc.ClusterStatus(state=dataproc.ClusterStatus.State.CREATING), - ) - - task = DataprocCreateClusterOperatorAsync( - task_id="task-id", cluster_name="test_cluster", region=TEST_REGION, project_id=TEST_PROJECT_ID - ) - with pytest.raises(TaskDeferred) as exc: - task.execute(create_context(task)) - assert isinstance( - exc.value.trigger, DataprocCreateClusterTrigger - ), "Trigger is not a DataprocCreateClusterTrigger" - - def test_dataproc_operator_create_cluster_execute_complete_success(self, context): - """assert that execute_complete return cluster detail when task succeed""" - cluster = Cluster( - cluster_name="test_cluster", - status=dataproc.ClusterStatus(state=dataproc.ClusterStatus.State.CREATING), - ) - task = DataprocCreateClusterOperatorAsync( - task_id="task-id", cluster_name="test_cluster", region=TEST_REGION, project_id=TEST_PROJECT_ID - ) - cluster_details = task.execute_complete( - context=context, event={"status": "success", "data": cluster, "message": ""} - ) - assert cluster_details is not None - - @pytest.mark.parametrize( - "status", - [ - "error", - None, - ], - ) - def test_dataproc_operator_create_cluster_execute_complete_fail(self, status, context): - """assert that execute_complete raise exception when task fail""" - task = DataprocCreateClusterOperatorAsync( - task_id="task-id", cluster_name="test_cluster", region=TEST_REGION, project_id=TEST_PROJECT_ID + cluster_name=TEST_CLUSTER_NAME, + region=TEST_REGION, + timeout=None, ) - with pytest.raises(AirflowException): - task.execute_complete( - context=context, event={"status": status, "message": "fail to create cluster"} - ) - - -class TestDataprocDeleteClusterOperatorAsync: - OPERATOR = DataprocDeleteClusterOperatorAsync( - task_id="task-id", project_id=TEST_PROJECT_ID, cluster_name=TEST_CLUSTER_NAME, region=TEST_REGION - ) - - @mock.patch( - "astronomer.providers.google.cloud.operators.dataproc.DataprocDeleteClusterOperatorAsync.defer" - ) - @mock.patch(f"{MODULE}.get_cluster") - @mock.patch(f"{MODULE}.delete_cluster") - def test_dataproc_operator_create_cluster_execute_async_finish_before_defer( - self, mock_delete_cluster, mock_get_cluster, mock_defer, context - ): - mock_delete_cluster.return_value = {} - mock_get_cluster.side_effect = NotFound("test") - self.OPERATOR.execute(context) - assert not mock_defer.called - - @mock.patch( - "astronomer.providers.google.cloud.operators.dataproc.DataprocDeleteClusterOperatorAsync.defer" - ) - @mock.patch(f"{MODULE}.get_cluster") - @mock.patch(f"{MODULE}.delete_cluster") - def test_dataproc_operator_create_cluster_execute_async_unexpected_error_before_defer( - self, mock_delete_cluster, mock_get_cluster, mock_defer, context - ): - mock_delete_cluster.return_value = {} - mock_get_cluster.side_effect = Exception("Unexpected") - with pytest.raises(AirflowException): - self.OPERATOR.execute(context) - assert not mock_defer.called - - @mock.patch(f"{MODULE}.get_cluster") - @mock.patch(f"{MODULE}.delete_cluster") - def test_dataproc_delete_operator_execute_async(self, mock_delete_cluster, get_cluster, context): - """ - Asserts that a task is deferred and a DataprocDeleteClusterTrigger will be fired - when the DataprocDeleteClusterOperatorAsync is executed. 
- """ - mock_delete_cluster.return_value = {} - with pytest.raises(TaskDeferred) as exc: - self.OPERATOR.execute(context) - assert isinstance( - exc.value.trigger, DataprocDeleteClusterTrigger - ), "Trigger is not a DataprocDeleteClusterTrigger" - - def test_dataproc_delete_operator_execute_complete_success(self, context): - """assert that execute_complete execute without error when receive success signal from trigger""" - assert self.OPERATOR.execute_complete(context=context, event={"status": "success"}) is None - - @pytest.mark.parametrize( - "event", - [ - ({"status": "error", "message": "test failure message"}), - None, - ], - ) - def test_dataproc_delete_operator_execute_complete_exception(self, event, context): - """assert that execute_complete raise exception when receive error from trigger""" - with pytest.raises(AirflowException): - self.OPERATOR.execute_complete(context=context, event=event) + assert isinstance(task, DataprocDeleteClusterOperator) + assert task.deferrable is True class TestDataprocSubmitJobOperatorAsync: @@ -237,85 +63,15 @@ def test_init(self): class TestDataprocUpdateClusterOperatorAsync: - OPERATOR = DataprocUpdateClusterOperatorAsync( - task_id="task-id", - cluster_name="test_cluster", - region=TEST_REGION, - project_id=TEST_PROJECT_ID, - cluster={}, - graceful_decommission_timeout=30, - update_mask={}, - ) - - @mock.patch( - "astronomer.providers.google.cloud.operators.dataproc.DataprocUpdateClusterOperatorAsync.defer" - ) - @mock.patch("airflow.providers.google.cloud.links.dataproc.DataprocLink.persist") - @mock.patch(f"{MODULE}.get_cluster") - @mock.patch(f"{MODULE}.update_cluster") - def test_dataproc_operator_update_cluster_execute_async_finish_before_defer( - self, mock_update_cluster, mock_get_cluster, mock_persist, mock_defer, context - ): - mock_persist.return_value = {} - cluster = Cluster( - cluster_name="test_cluster", - status=dataproc.ClusterStatus(state=dataproc.ClusterStatus.State.RUNNING), - ) - mock_update_cluster.return_value = cluster - mock_get_cluster.return_value = cluster - DataprocCreateClusterOperatorAsync( - task_id="task-id", cluster_name="test_cluster", region=TEST_REGION, project_id=TEST_PROJECT_ID - ) - self.OPERATOR.execute(context) - assert not mock_defer.called - - @mock.patch("airflow.providers.google.cloud.links.dataproc.DataprocLink.persist") - @mock.patch(f"{MODULE}.get_cluster") - @mock.patch(f"{MODULE}.update_cluster") - def test_dataproc_operator_update_cluster_execute_async( - self, mock_update_cluster, mock_get_cluster, mock_persist, context - ): - """ - Asserts that a task is deferred and a DataprocCreateClusterTrigger will be fired - when the DataprocCreateClusterOperatorAsync is executed. 
- """ - mock_persist.return_value = {} - cluster = Cluster( - cluster_name="test_cluster", - status=dataproc.ClusterStatus(state=dataproc.ClusterStatus.State.CREATING), - ) - mock_update_cluster.return_value = cluster - mock_get_cluster.return_value = cluster - - with pytest.raises(TaskDeferred) as exc: - self.OPERATOR.execute(context) - assert isinstance( - exc.value.trigger, DataprocCreateClusterTrigger - ), "Trigger is not a DataprocCreateClusterTrigger" - - def test_dataproc_operator_update_cluster_execute_complete_success(self, context): - """assert that execute_complete return cluster detail when task succeed""" - cluster = Cluster( + def test_init(self): + task = DataprocUpdateClusterOperatorAsync( + task_id="task-id", cluster_name="test_cluster", - status=dataproc.ClusterStatus(state=dataproc.ClusterStatus.State.CREATING), - ) - - assert ( - self.OPERATOR.execute_complete( - context=context, event={"status": "success", "data": cluster, "cluster_name": "test_cluster"} - ) - is None + region=TEST_REGION, + project_id=TEST_PROJECT_ID, + cluster={}, + graceful_decommission_timeout=30, + update_mask={}, ) - - @pytest.mark.parametrize( - "event", - [ - {"status": "error", "message": ""}, - None, - ], - ) - def test_dataproc_operator_update_cluster_execute_complete_fail(self, event, context): - """assert that execute_complete raise exception when task fail""" - - with pytest.raises(AirflowException): - self.OPERATOR.execute_complete(context=context, event=event) + assert isinstance(task, DataprocUpdateClusterOperator) + assert task.deferrable is True