class AstroHook(BaseHook):
    """
    Hook that registers the "Astro Cloud" connection type with Airflow.

    The class only describes how the connection form is rendered: every standard
    field except the password is hidden, and the password field is relabelled as
    the Astro Cloud API token. ``get_conn`` is a stub and returns ``None``.
    """

    conn_name_attr = "astro_cloud_conn_id"
    default_conn_name = "astro_cloud_default"
    conn_type = "Astro Cloud"
    hook_name = "Astro Cloud"

    @classmethod
    def get_connection_form_widgets(cls) -> dict[str, Any]:
        """Return the extra connection-form widgets; this hook defines none."""
        return dict()

    @classmethod
    def get_ui_field_behaviour(cls) -> dict[str, Any]:
        """Return the connection-form customisation (hidden fields, labels, placeholders)."""
        hidden = ["host", "login", "port", "schema", "extra"]
        labels = {"password": "Astro Cloud API Token"}
        hints = {"password": "ey...xz.ey...fq.tw...ap"}
        return {
            "hidden_fields": hidden,
            "relabeling": labels,
            "placeholders": hints,
        }

    def get_conn(self) -> Any:
        """Return ``None``; no client object is created by this hook."""
        return None
class ExternalDeploymentSensor(BaseSensorOperator):
    """
    Look up the web server URL of an Astro deployment through the Astro platform API.

    :param astro_cloud_conn_id: Airflow connection whose password holds the Astro API token.
    :param deployment_id: Target deployment id. When ``None``, the ``ASTRO_DEPLOYMENT_ID``
        environment variable is used.
    :param workspace_id: Target workspace id. When ``None``, the ``ASTRO_WORKSPACE_ID``
        environment variable is used.
    :param organization_id: Target organization id. When ``None``, the
        ``ASTRO_ORGANIZATION_ID`` environment variable is used.
    """

    # Upper bound (seconds) on the platform API call; without it a stalled
    # connection would block the task indefinitely.
    _REQUEST_TIMEOUT_SECONDS = 30

    def __init__(
        self,
        astro_cloud_conn_id: str = "astro_cloud_default",
        deployment_id: str | None = None,
        workspace_id: str | None = None,
        organization_id: str | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self._astro_cloud_conn_id = astro_cloud_conn_id
        self._deployment_id = deployment_id
        self._workspace_id = workspace_id
        self._organization_id = organization_id
        # Both values are fetched lazily and cached on first use.
        self._astro_api_token: str | None = None
        self._deployment_details: dict[str, Any] | None = None

    def _resolve_identifiers(self) -> None:
        """
        Fill in any id that was not passed explicitly from the environment.

        :raises KeyError: if an id is missing and its environment variable is not set.
        """
        if self._deployment_id is None:
            self.log.info("No deployment id configured. Using current deployment id.")
            self._deployment_id = os.environ["ASTRO_DEPLOYMENT_ID"]

        if self._workspace_id is None:
            self.log.info("No workspace id configured. Using current workspace id.")
            self._workspace_id = os.environ["ASTRO_WORKSPACE_ID"]

        if self._organization_id is None:
            self.log.info("No organization id configured. Using current organization id.")
            self._organization_id = os.environ["ASTRO_ORGANIZATION_ID"]

    @property
    def target_deployment_url(self) -> str:
        """
        Return the ``https://`` web server URL of the configured deployment.

        The deployment details are fetched once from the Astro platform API and cached.
        https://docs.astronomer.io/astro/api/platform-api-reference#tag/Deployment/operation/GetDeployment

        :raises requests.HTTPError: if the platform API answers with an error status.
        """
        if self._deployment_details is None:
            # Resolve ids here as well, so this property never builds a URL
            # containing a literal "None" when it is accessed directly.
            self._resolve_identifiers()
            get_deployment_url = (
                "https://api.astronomer.io/platform/v1beta1/"
                f"organizations/{self._organization_id}/deployments/{self._deployment_id}"
            )
            response = requests.get(
                url=get_deployment_url,
                headers={"Authorization": f"Bearer {self._api_token}"},
                timeout=self._REQUEST_TIMEOUT_SECONDS,
            )
            # Check the status before reporting success. The raw response body
            # is deliberately not logged.
            response.raise_for_status()
            self.log.info("Fetched deployment details.")
            self._deployment_details = response.json()

        # "webServerUrl" carries a trailing "?orgId=..." query; keep only the part before it.
        target_deployment_url, _, _ = self._deployment_details["webServerUrl"].partition("?orgId=")
        return f"https://{target_deployment_url}"

    @property
    def target_deployment_rest_api_url(self) -> str:
        """
        Resolve missing ids from the environment and return the deployment URL.

        NOTE(review): the value returned is the same as ``target_deployment_url``;
        no REST API path (such as ``/api/v1``) is appended here. Confirm whether
        callers expect one.
        """
        self._resolve_identifiers()
        return self.target_deployment_url

    @property
    def _api_token(self) -> str:
        """
        Return the Astro API token, cached in memory after the first connection lookup.

        :raises ValueError: if the Airflow connection has no password set.
        """
        if self._astro_api_token is None:
            conn = BaseHook.get_connection(conn_id=self._astro_cloud_conn_id)
            if not conn.password:
                raise ValueError(
                    f"Connection {self._astro_cloud_conn_id!r} does not contain an Astro API token."
                )
            self._astro_api_token = conn.password
            # Never log the token itself: task logs are readable by other users.
            self.log.info("Cached Astro API token from Airflow connection.")
        return self._astro_api_token

    def execute(self, context: Context) -> Any:
        """Resolve the target deployment URL and write it to the task log."""
        self.log.info("target_deployment_url %s", self.target_deployment_rest_api_url)