Skip to content

Commit

Permalink
Change http call to async
Browse files Browse the repository at this point in the history
  • Loading branch information
pankajastro committed Feb 15, 2024
1 parent 0733ab3 commit 2a80a27
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 8 deletions.
40 changes: 40 additions & 0 deletions astronomer/providers/core/hooks/astro.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from urllib.parse import quote

import requests
from aiohttp import ClientSession
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook

Expand Down Expand Up @@ -96,6 +97,24 @@ def get_dag_run(self, external_dag_id: str, dag_run_id: str) -> dict[str, Any] |
dr: dict[str, Any] = response.json()
return dr

async def get_a_dag_run(self, external_dag_id: str, dag_run_id: str) -> dict[str, Any] | None:
"""
Retrieves information about a specific DAG run.
:param external_dag_id: External ID of the DAG.
:param dag_run_id: ID of the DAG run.
"""
base_url, _ = self.get_conn()
dag_run_id = quote(dag_run_id)
path = f"/api/v1/dags/{external_dag_id}/dagRuns/{dag_run_id}"
url = f"{base_url}{path}"

async with ClientSession(headers=self._headers) as session:
async with session.get(url) as response:
response.raise_for_status()
dr: dict[str, Any] = await response.json()
return dr

def get_task_instance(
self, external_dag_id: str, dag_run_id: str, external_task_id: str
) -> dict[str, Any] | None:
Expand All @@ -114,3 +133,24 @@ def get_task_instance(
response.raise_for_status()
ti: dict[str, Any] = response.json()
return ti

async def get_a_task_instance(
self, external_dag_id: str, dag_run_id: str, external_task_id: str
) -> dict[str, Any] | None:
"""
Retrieves information about a specific task instance within a DAG run.
:param external_dag_id: External ID of the DAG.
:param dag_run_id: ID of the DAG run.
:param external_task_id: External ID of the task.
"""
base_url, _ = self.get_conn()
dag_run_id = quote(dag_run_id)
path = f"/api/v1/dags/{external_dag_id}/dagRuns/{dag_run_id}/taskInstances/{external_task_id}"
url = f"{base_url}{path}"

async with ClientSession(headers=self._headers) as session:
async with session.get(url) as response:
response.raise_for_status()
ti: dict[str, Any] = await response.json()
return ti
4 changes: 2 additions & 2 deletions astronomer/providers/core/triggers/astro.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
hook = AstroHook(self.astro_cloud_conn_id)
while True:
if self.external_task_id is not None:
task_instance = hook.get_task_instance(
task_instance = await hook.get_a_task_instance(
self.external_dag_id, self.dag_run_id, self.external_task_id
)
state = task_instance.get("state") if task_instance else None
Expand All @@ -68,7 +68,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
elif state in ("failed", "upstream_failed"):
yield TriggerEvent({"status": "failed"})
else:
dag_run = hook.get_dag_run(self.external_dag_id, self.dag_run_id)
dag_run = await hook.get_a_dag_run(self.external_dag_id, self.dag_run_id)
state = dag_run.get("state") if dag_run else None
if state == "success":
yield TriggerEvent({"status": "done"})
Expand Down
99 changes: 98 additions & 1 deletion tests/core/hooks/test_astro.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from unittest.mock import MagicMock, patch
from unittest import mock
from unittest.mock import MagicMock, Mock, patch

import pytest
from aioresponses import aioresponses
from airflow.exceptions import AirflowException

from astronomer.providers.core.hooks.astro import AstroHook
Expand Down Expand Up @@ -143,3 +145,98 @@ def test_get_task_instance(self, mock_requests_get, mock_get_connection):
# Assertions
mock_requests_get.assert_called_once()
assert result == {"task_instance_id": "456", "state": "success"}

@pytest.mark.asyncio
@mock.patch("astronomer.providers.core.hooks.astro.AstroHook._headers")
async def test_get_a_dag_run(self, mock_headers):
external_dag_id = "your_external_dag_id"
dag_run_id = "your_dag_run_id"
url = f"https://test.com/api/v1/dags/{external_dag_id}/dagRuns/{dag_run_id}"

# Mocking necessary objects
your_class_instance = AstroHook()
your_class_instance.get_conn = Mock(return_value=("https://test.com", "Test Token"))
mock_headers.return_value = {"accept": "application/json", "Authorization": "Bearer Token"}
response_data = {
"conf": {},
"dag_id": "my_dag",
"dag_run_id": "manual__2024-02-14T19:06:32.053905+00:00",
"data_interval_end": "2024-02-14T19:06:32.053905+00:00",
"data_interval_start": "2024-02-14T19:06:32.053905+00:00",
"end_date": "2024-02-14T19:16:33.987139+00:00",
"execution_date": "2024-02-14T19:06:32.053905+00:00",
"external_trigger": True,
"last_scheduling_decision": "2024-02-14T19:16:33.985973+00:00",
"logical_date": "2024-02-14T19:06:32.053905+00:00",
"note": None,
"run_type": "manual",
"start_date": "2024-02-14T19:06:33.004299+00:00",
"state": "success",
}

with aioresponses() as mock_session:
mock_session.get(
url,
headers=your_class_instance._headers,
status=200,
payload=response_data,
)

result = await your_class_instance.get_a_dag_run(external_dag_id, dag_run_id)

assert result == response_data

@pytest.mark.asyncio
@mock.patch("astronomer.providers.core.hooks.astro.AstroHook._headers")
async def test_get_a_task_instance(self, mock_headers):
external_dag_id = "your_external_dag_id"
dag_run_id = "your_dag_run_id"
external_task_id = "your_external_task_id"
url = f"https://test.com/api/v1/dags/{external_dag_id}/dagRuns/{dag_run_id}/taskInstances/{external_task_id}"

# Mocking necessary objects
your_class_instance = AstroHook()
your_class_instance.get_conn = Mock(return_value=("https://test.com", "Test Token"))
mock_headers.return_value = {"accept": "application/json", "Authorization": "Bearer Token"}
response_data = {
"dag_id": "my_dag",
"dag_run_id": "manual__2024-02-14T19:06:32.053905+00:00",
"duration": 600.233105,
"end_date": "2024-02-14T19:16:33.459676+00:00",
"execution_date": "2024-02-14T19:06:32.053905+00:00",
"executor_config": "{}",
"hostname": "d10fc8b0ad27",
"map_index": -1,
"max_tries": 0,
"note": None,
"operator": "_PythonDecoratedOperator",
"pid": 927,
"pool": "default_pool",
"pool_slots": 1,
"priority_weight": 1,
"queue": "default",
"queued_when": "2024-02-14T19:06:33.036108+00:00",
"rendered_fields": {"op_args": [], "op_kwargs": {}, "templates_dict": None},
"sla_miss": None,
"start_date": "2024-02-14T19:06:33.226571+00:00",
"state": "success",
"task_id": "my_python_function",
"trigger": None,
"triggerer_job": None,
"try_number": 1,
"unixname": "astro",
}

with aioresponses() as mock_session:
mock_session.get(
url,
headers=your_class_instance._headers,
status=200,
payload=response_data,
)

result = await your_class_instance.get_a_task_instance(
external_dag_id, dag_run_id, external_task_id
)

assert result == response_data
10 changes: 5 additions & 5 deletions tests/core/triggers/test_astro.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_serialize(self):
assert serialized_data == expected_result

@pytest.mark.asyncio
@patch("astronomer.providers.core.hooks.astro.AstroHook.get_task_instance")
@patch("astronomer.providers.core.hooks.astro.AstroHook.get_a_task_instance")
async def test_run_task_successful(self, mock_get_task_instance):
trigger = AstroDeploymentTrigger(
external_dag_id="external_dag_id",
Expand All @@ -51,7 +51,7 @@ async def test_run_task_successful(self, mock_get_task_instance):
assert actual == TriggerEvent({"status": "done"})

@pytest.mark.asyncio
@patch("astronomer.providers.core.hooks.astro.AstroHook.get_task_instance")
@patch("astronomer.providers.core.hooks.astro.AstroHook.get_a_task_instance")
async def test_run_task_failed(self, mock_get_task_instance):
trigger = AstroDeploymentTrigger(
external_dag_id="external_dag_id",
Expand All @@ -68,7 +68,7 @@ async def test_run_task_failed(self, mock_get_task_instance):
assert actual == TriggerEvent({"status": "failed"})

@pytest.mark.asyncio
@patch("astronomer.providers.core.hooks.astro.AstroHook.get_dag_run")
@patch("astronomer.providers.core.hooks.astro.AstroHook.get_a_dag_run")
async def test_run_dag_successful(self, mock_get_dag_run):
trigger = AstroDeploymentTrigger(
external_dag_id="external_dag_id",
Expand All @@ -85,7 +85,7 @@ async def test_run_dag_successful(self, mock_get_dag_run):
assert actual == TriggerEvent({"status": "done"})

@pytest.mark.asyncio
@patch("astronomer.providers.core.hooks.astro.AstroHook.get_dag_run")
@patch("astronomer.providers.core.hooks.astro.AstroHook.get_a_dag_run")
async def test_run_dag_failed(self, mock_get_dag_run):
trigger = AstroDeploymentTrigger(
external_dag_id="external_dag_id",
Expand All @@ -101,7 +101,7 @@ async def test_run_dag_failed(self, mock_get_dag_run):
assert actual == TriggerEvent({"status": "failed"})

@pytest.mark.asyncio
@patch("astronomer.providers.core.hooks.astro.AstroHook.get_dag_run")
@patch("astronomer.providers.core.hooks.astro.AstroHook.get_a_dag_run")
async def test_run_dag_wait(self, mock_get_dag_run):
trigger = AstroDeploymentTrigger(
external_dag_id="external_dag_id",
Expand Down

0 comments on commit 2a80a27

Please sign in to comment.