Skip to content

Commit

Permalink
Add ExternalDeploymentSensor
Browse files Browse the repository at this point in the history
  • Loading branch information
pankajastro committed Feb 6, 2024
1 parent 57be9e7 commit b181c21
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 6 deletions.
19 changes: 19 additions & 0 deletions astronomer/providers/core/example_dags/example_astro.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from datetime import datetime

from airflow import DAG

from astronomer.providers.core.sensors.external_task import ExternalDeploymentSensor

# Example DAG holding a single ExternalDeploymentSensor task.
# The DAG object is created first and handed to the sensor explicitly rather
# than relying on a ``with`` block to attach the task.
dag = DAG(
    dag_id="example_astro_task",
    start_date=datetime(2022, 1, 1),
    schedule=None,
    catchup=False,
    tags=["example", "async", "core"],
)

ExternalDeploymentSensor(
    dag=dag,
    task_id="test",
    deployment_id="clpccxlbs45772d7yz84be4ykx",
    workspace_id="cll0nk0c3003u01kd092pghag",
    organization_id="cll0nj92h00iu01j51htnafwh",
)
Empty file.
29 changes: 29 additions & 0 deletions astronomer/providers/core/hooks/astro.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import Any

from airflow.hooks.base import BaseHook


class AstroHook(BaseHook):
    """
    Hook describing the "Astro Cloud" connection type.

    Only the connection's password field is exposed in the connection form;
    it is relabelled to hold the Astro Cloud API token.
    """

    hook_name = "Astro Cloud"
    conn_type = "Astro Cloud"
    default_conn_name = "astro_cloud_default"
    conn_name_attr = "astro_cloud_conn_id"

    @classmethod
    def get_connection_form_widgets(cls) -> dict[str, Any]:
        """Return the additional connection form widgets; this hook defines none."""
        no_extra_widgets: dict[str, Any] = {}
        return no_extra_widgets

    @classmethod
    def get_ui_field_behaviour(cls) -> dict[str, Any]:
        """Return the form customisation: hidden fields, the password label and its placeholder."""
        unused_fields = ["host", "login", "port", "schema", "extra"]
        token_label = {"password": "Astro Cloud API Token"}
        token_placeholder = {"password": "ey...xz.ey...fq.tw...ap"}
        return {
            "hidden_fields": unused_fields,
            "relabeling": token_label,
            "placeholders": token_placeholder,
        }

    def get_conn(self) -> Any:
        """Placeholder: no client object is built, ``None`` is returned."""
        return None
89 changes: 84 additions & 5 deletions astronomer/providers/core/sensors/external_task.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
from __future__ import annotations

import datetime
import os
import warnings
from typing import TYPE_CHECKING, Any

from airflow.sensors.base import BaseSensorOperator
from airflow.providers.http.hooks.http import HttpHook
from airflow.providers.http.sensors.http import HttpSensor
from airflow.sensors.external_task import ExternalTaskSensor
from airflow.utils.session import provide_session

import requests
from airflow.hooks.base import BaseHook
from astronomer.providers.core.triggers.external_task import (
DagStateTrigger,
ExternalDeploymentTaskTrigger,
Expand All @@ -23,9 +27,9 @@

class ExternalTaskSensorAsync(ExternalTaskSensor): # noqa: D101
def __init__(
self,
poke_interval: float = 5.0,
**kwargs: Any,
self,
poke_interval: float = 5.0,
**kwargs: Any,
) -> None:
warnings.warn(
(
Expand All @@ -47,7 +51,7 @@ def execute(self, context: Context) -> None:
# Defer to our trigger
if not poke(self, context):
if (
not self.external_task_id
not self.external_task_id
): # Tempting to explicitly check for None, but this captures falsely values
self.defer(
timeout=datetime.timedelta(seconds=self.timeout),
Expand Down Expand Up @@ -77,7 +81,7 @@ def execute(self, context: Context) -> None:

@provide_session
def execute_complete( # type: ignore[override]
self, context: Context, session: Session, event: dict[str, Any] | None = None
self, context: Context, session: Session, event: dict[str, Any] | None = None
) -> None:
"""Verifies that there is a success status for each task via execution date."""
execution_dates = self.get_execution_dates(context)
Expand All @@ -102,6 +106,81 @@ def get_execution_dates(self, context: Context) -> list[datetime.datetime]:
return execution_dates


class ExternalDeploymentSensor(BaseSensorOperator):
    """
    Look up the web server URL of an Astro deployment through the Astro platform API.

    :param astro_cloud_conn_id: Airflow connection whose password holds the Astro API token.
    :param deployment_id: Target deployment id; when ``None`` the ``ASTRO_DEPLOYMENT_ID``
        environment variable is used.
    :param workspace_id: Target workspace id; when ``None`` the ``ASTRO_WORKSPACE_ID``
        environment variable is used.
    :param organization_id: Target organization id; when ``None`` the
        ``ASTRO_ORGANIZATION_ID`` environment variable is used.
    """

    # Upper bound (seconds) for the platform API call so that an unresponsive
    # endpoint cannot block the worker slot indefinitely.
    _REQUEST_TIMEOUT_SECONDS = 30.0

    def __init__(
        self,
        astro_cloud_conn_id: str = "astro_cloud_default",
        deployment_id: str | None = None,
        workspace_id: str | None = None,
        organization_id: str | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self._astro_cloud_conn_id = astro_cloud_conn_id
        self._deployment_id = deployment_id
        self._workspace_id = workspace_id
        self._organization_id = organization_id
        # Both values are fetched lazily and cached on the instance.
        self._astro_api_token: str | None = None
        self._deployment_details: dict[str, Any] | None = None

    def _resolve_identifiers(self) -> None:
        """
        Fill in ids that were not passed to the constructor from the environment.

        :raises KeyError: if an id is missing and its environment variable is not set.
        """
        if self._deployment_id is None:
            self.log.info("No deployment id configured. Using current deployment id.")
            self._deployment_id = os.environ["ASTRO_DEPLOYMENT_ID"]

        if self._workspace_id is None:
            self.log.info("No workspace id configured. Using current workspace id.")
            self._workspace_id = os.environ["ASTRO_WORKSPACE_ID"]

        if self._organization_id is None:
            self.log.info("No organization id configured. Using current organization id.")
            self._organization_id = os.environ["ASTRO_ORGANIZATION_ID"]

    @property
    def target_deployment_url(self) -> str:
        """
        Get the URL of the specified deployment.

        The deployment details are fetched once and cached on the instance.
        https://docs.astronomer.io/astro/api/platform-api-reference#tag/Deployment/operation/GetDeployment

        :raises requests.HTTPError: if the platform API answers with an error status.
        """
        if self._deployment_details is None:
            # Resolve the ids here as well, so that reading this property directly
            # never builds a URL containing the literal string "None".
            self._resolve_identifiers()
            get_deployment_url = (
                "https://api.astronomer.io/platform/v1beta1/organizations/"
                f"{self._organization_id}/deployments/{self._deployment_id}"
            )
            response = requests.get(
                url=get_deployment_url,
                headers={"Authorization": f"Bearer {self._api_token}"},
                timeout=self._REQUEST_TIMEOUT_SECONDS,
            )
            # Check the status before claiming success in the log.
            response.raise_for_status()
            self.log.info("Fetched deployment details.")
            # The full payload is verbose, so it is only emitted at debug level.
            self.log.debug(response.text)
            self._deployment_details = response.json()
        # Drop the "?orgId=..." query suffix and prepend the scheme.
        target_deployment_url, _, _ = self._deployment_details["webServerUrl"].partition("?orgId=")
        return f"https://{target_deployment_url}"

    @property
    def target_deployment_rest_api_url(self) -> str:
        """
        Return the deployment's REST API URL. Example:
        https://clkvh3b46003m01kbalgwwdcy.astronomer.run/dd8od9mt/api/v1
        Example URL structure returned by Astro API:
        clkvh3b46003m01kbalgwwdcy.astronomer.run/dd8od9mt/api/v1

        NOTE(review): the returned value is ``target_deployment_url`` unchanged, i.e. it is
        derived from the ``webServerUrl`` field - confirm it really carries the ``/api/v1`` path.
        """
        self._resolve_identifiers()
        return self.target_deployment_url

    @property
    def _api_token(self) -> str:
        """
        Cache the Astro API token in memory to avoid re-fetching multiple times from Airflow connection.
        """
        if self._astro_api_token is None:
            conn = BaseHook.get_connection(conn_id=self._astro_cloud_conn_id)
            self._astro_api_token = conn.password
            # Never write the token itself to the task log: it is a credential.
            self.log.info("Cached Astro API token from Airflow connection.")
        return self._astro_api_token

    def execute(self, context: Context) -> Any:
        """Resolve the target deployment URL and report it in the task log."""
        self.log.info("target_deployment_url %s", self.target_deployment_rest_api_url)


class ExternalDeploymentTaskSensorAsync(HttpSensor):
"""
External deployment task sensor Make HTTP call and poll for the response state of externally
Expand Down
7 changes: 6 additions & 1 deletion astronomer/providers/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ def get_provider_info() -> Dict[str, Any]:
"description": "Apache Airflow Providers containing Deferrable Operators & Sensors from Astronomer",
"versions": "1.18.4",
# Optional.
"hook-class-names": [],
"connection-types": [
{
"hook-class-name": "astronomer.providers.core.hooks.astro.AstroHook",
"connection-type": "Astro Cloud"
}
],
"extra-links": [],
}

0 comments on commit b181c21

Please sign in to comment.