Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add HTTP retry handling into task SDK api.client #45121

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions task_sdk/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ dependencies = [
"msgspec>=0.18.6",
"psutil>=6.1.0",
"structlog>=24.4.0",
"retryhttp>=1.2.0",
]
classifiers = [
"Framework :: Apache Airflow",
Expand Down
28 changes: 28 additions & 0 deletions task_sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

from __future__ import annotations

import logging
import os
import sys
import uuid
from http import HTTPStatus
Expand All @@ -26,6 +28,8 @@
import msgspec
import structlog
from pydantic import BaseModel
from retryhttp import retry, wait_retry_after
from tenacity import before_log, wait_random_exponential
from uuid6 import uuid7

from airflow.sdk import __version__
Expand Down Expand Up @@ -268,6 +272,15 @@ def noop_handler(request: httpx.Request) -> httpx.Response:
return httpx.Response(200, json={"text": "Hello, world!"})


# Config options for SDK how retries on HTTP requests should be handled
# Note: Given defaults make attempts after 1, 3, 7, 15, 31seconds, 1:03, 2:07, 3:37 and fails after 5:07min
# So far there is no other config facility in SDK we use ENV for the moment
# TODO: Consider these env variables while handling airflow confs in task sdk
API_RETRIES = int(os.getenv("AIRFLOW__WORKERS__API_RETRIES", 10))
API_RETRY_WAIT_MIN = float(os.getenv("AIRFLOW__WORKERS__API_RETRY_WAIT_MIN", 1.0))
API_RETRY_WAIT_MAX = float(os.getenv("AIRFLOW__WORKERS__API_RETRY_WAIT_MAX", 90.0))

jscheffl marked this conversation as resolved.
Show resolved Hide resolved

class Client(httpx.Client):
def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, **kwargs: Any):
if (not base_url) ^ dry_run:
Expand All @@ -289,6 +302,21 @@ def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, *
**kwargs,
)

_default_wait = wait_random_exponential(min=API_RETRY_WAIT_MIN, max=API_RETRY_WAIT_MAX)

@retry(
reraise=True,
max_attempt_number=API_RETRIES,
wait_server_errors=_default_wait,
wait_network_errors=_default_wait,
wait_timeouts=_default_wait,
wait_rate_limited=wait_retry_after(fallback=_default_wait), # No infinite timeout on HTTP 429
before_sleep=before_log(log, logging.WARNING),
)
def request(self, *args, **kwargs):
"""Implement a convenience for httpx.Client.request with a retry layer."""
return super().request(*args, **kwargs)

# We "group" or "namespace" operations by what they operate on, rather than a flat namespace with all
# methods on one object prefixed with the object type (`.task_instances.update` rather than
# `task_instance_update` etc.)
Expand Down
157 changes: 124 additions & 33 deletions task_sdk/tests/api/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

import json
from unittest import mock

import httpx
import pytest
Expand All @@ -30,18 +31,28 @@
from airflow.utils.state import TerminalTIState


class TestClient:
def test_error_parsing(self):
def handle_request(request: httpx.Request) -> httpx.Response:
"""
A transport handle that always returns errors
"""
def make_client(transport: httpx.MockTransport) -> Client:
"""Get a client with a custom transport"""
return Client(base_url="test://server", token="", transport=transport)

return httpx.Response(422, json={"detail": [{"loc": ["#0"], "msg": "err", "type": "required"}]})

client = Client(
base_url=None, dry_run=True, token="", mounts={"'http://": httpx.MockTransport(handle_request)}
)
def make_client_w_responses(responses: list[httpx.Response]) -> Client:
"""Helper fixture to create a mock client with custom responses."""

def handle_request(request: httpx.Request) -> httpx.Response:
return responses.pop(0)

return Client(
base_url=None, dry_run=True, token="", mounts={"'http://": httpx.MockTransport(handle_request)}
)


class TestClient:
def test_error_parsing(self):
responses = [
httpx.Response(422, json={"detail": [{"loc": ["#0"], "msg": "err", "type": "required"}]})
]
client = make_client_w_responses(responses)

with pytest.raises(ServerResponseError) as err:
client.get("http://error")
Expand All @@ -53,39 +64,92 @@ def handle_request(request: httpx.Request) -> httpx.Response:
]

def test_error_parsing_plain_text(self):
def handle_request(request: httpx.Request) -> httpx.Response:
"""
A transport handle that always returns errors
"""

return httpx.Response(422, content=b"Internal Server Error")

client = Client(
base_url=None, dry_run=True, token="", mounts={"'http://": httpx.MockTransport(handle_request)}
)
responses = [httpx.Response(422, content=b"Internal Server Error")]
client = make_client_w_responses(responses)

with pytest.raises(httpx.HTTPStatusError) as err:
client.get("http://error")
assert not isinstance(err.value, ServerResponseError)

def test_error_parsing_other_json(self):
def handle_request(request: httpx.Request) -> httpx.Response:
# Some other json than an error body.
return httpx.Response(404, json={"detail": "Not found"})

client = Client(
base_url=None, dry_run=True, token="", mounts={"'http://": httpx.MockTransport(handle_request)}
)
responses = [httpx.Response(404, json={"detail": "Not found"})]
client = make_client_w_responses(responses)

with pytest.raises(ServerResponseError) as err:
client.get("http://error")
assert err.value.args == ("Not found",)
assert err.value.detail is None

@mock.patch("time.sleep", return_value=None)
def test_retry_handling_unrecoverable_error(self, mock_sleep):
responses: list[httpx.Response] = [
*[httpx.Response(500, text="Internal Server Error")] * 11,
httpx.Response(200, json={"detail": "Recovered from error - but will fail before"}),
httpx.Response(400, json={"detail": "Should not get here"}),
]
client = make_client_w_responses(responses)

jscheffl marked this conversation as resolved.
Show resolved Hide resolved
def make_client(transport: httpx.MockTransport) -> Client:
"""Get a client with a custom transport"""
return Client(base_url="test://server", token="", transport=transport)
with pytest.raises(httpx.HTTPStatusError) as err:
client.get("http://error")
assert not isinstance(err.value, ServerResponseError)
assert len(responses) == 3
assert mock_sleep.call_count == 9

@mock.patch("time.sleep", return_value=None)
def test_retry_handling_recovered(self, mock_sleep):
responses: list[httpx.Response] = [
*[httpx.Response(500, text="Internal Server Error")] * 3,
httpx.Response(200, json={"detail": "Recovered from error"}),
httpx.Response(400, json={"detail": "Should not get here"}),
]
client = make_client_w_responses(responses)

response = client.get("http://error")
assert response.status_code == 200
assert len(responses) == 1
assert mock_sleep.call_count == 3

@mock.patch("time.sleep", return_value=None)
def test_retry_handling_overload(self, mock_sleep):
responses: list[httpx.Response] = [
httpx.Response(429, text="I am really busy atm, please back-off", headers={"Retry-After": "37"}),
httpx.Response(200, json={"detail": "Recovered from error"}),
httpx.Response(400, json={"detail": "Should not get here"}),
]
client = make_client_w_responses(responses)

response = client.get("http://error")
assert response.status_code == 200
assert len(responses) == 1
assert mock_sleep.call_count == 1
assert mock_sleep.call_args[0][0] == 37

@mock.patch("time.sleep", return_value=None)
def test_retry_handling_non_retry_error(self, mock_sleep):
responses: list[httpx.Response] = [
httpx.Response(422, json={"detail": "Somehow this is a bad request"}),
httpx.Response(400, json={"detail": "Should not get here"}),
]
client = make_client_w_responses(responses)

with pytest.raises(ServerResponseError) as err:
client.get("http://error")
assert len(responses) == 1
assert mock_sleep.call_count == 0
assert err.value.args == ("Somehow this is a bad request",)

@mock.patch("time.sleep", return_value=None)
def test_retry_handling_ok(self, mock_sleep):
responses: list[httpx.Response] = [
httpx.Response(200, json={"detail": "Recovered from error"}),
httpx.Response(400, json={"detail": "Should not get here"}),
]
client = make_client_w_responses(responses)

response = client.get("http://error")
assert response.status_code == 200
assert len(responses) == 1
assert mock_sleep.call_count == 0


class TestTaskInstanceOperations:
jscheffl marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -95,7 +159,8 @@ class TestTaskInstanceOperations:
response parsing.
"""

def test_task_instance_start(self, make_ti_context):
@mock.patch("time.sleep", return_value=None) # To have retries not slowing down tests
def test_task_instance_start(self, mock_sleep, make_ti_context):
# Simulate a successful response from the server that starts a task
ti_id = uuid6.uuid7()
start_date = "2024-10-31T12:00:00Z"
Expand All @@ -105,7 +170,14 @@ def test_task_instance_start(self, make_ti_context):
run_type="manual",
)

# ...including a validation that retry really works
call_count = 0

def handle_request(request: httpx.Request) -> httpx.Response:
nonlocal call_count
call_count += 1
if call_count < 4:
return httpx.Response(status_code=500, json={"detail": "Internal Server Error"})
if request.url.path == f"/task-instances/{ti_id}/run":
actual_body = json.loads(request.read())
assert actual_body["pid"] == 100
Expand All @@ -120,6 +192,7 @@ def handle_request(request: httpx.Request) -> httpx.Response:
client = make_client(transport=httpx.MockTransport(handle_request))
resp = client.task_instances.start(ti_id, 100, start_date)
assert resp == ti_context
assert call_count == 4

@pytest.mark.parametrize("state", [state for state in TerminalTIState])
def test_task_instance_finish(self, state):
Expand Down Expand Up @@ -245,9 +318,17 @@ class TestVariableOperations:
response parsing.
"""

def test_variable_get_success(self):
@mock.patch("time.sleep", return_value=None) # To have retries not slowing down tests
def test_variable_get_success(self, mock_sleep):
# Simulate a successful response from the server with a variable
# ...including a validation that retry really works
call_count = 0

def handle_request(request: httpx.Request) -> httpx.Response:
nonlocal call_count
call_count += 1
if call_count < 2:
return httpx.Response(status_code=500, json={"detail": "Internal Server Error"})
if request.url.path == "/variables/test_key":
return httpx.Response(
status_code=200,
Expand All @@ -261,6 +342,7 @@ def handle_request(request: httpx.Request) -> httpx.Response:
assert isinstance(result, VariableResponse)
assert result.key == "test_key"
assert result.value == "test_value"
assert call_count == 2

def test_variable_not_found(self):
# Simulate a 404 response from the server
Expand Down Expand Up @@ -323,9 +405,17 @@ class TestXCOMOperations:
pytest.param({"key": "test_key", "value": {"key2": "value2"}}, id="nested-dict-value"),
],
)
def test_xcom_get_success(self, value):
@mock.patch("time.sleep", return_value=None) # To have retries not slowing down tests
def test_xcom_get_success(self, mock_sleep, value):
# Simulate a successful response from the server when getting an xcom
# ...including a validation that retry really works
call_count = 0

def handle_request(request: httpx.Request) -> httpx.Response:
nonlocal call_count
call_count += 1
if call_count < 3:
return httpx.Response(status_code=500, json={"detail": "Internal Server Error"})
if request.url.path == "/xcoms/dag_id/run_id/task_id/key":
return httpx.Response(
status_code=201,
Expand All @@ -343,6 +433,7 @@ def handle_request(request: httpx.Request) -> httpx.Response:
assert isinstance(result, XComResponse)
assert result.key == "test_key"
assert result.value == value
assert call_count == 3

def test_xcom_get_success_with_map_index(self):
# Simulate a successful response from the server when getting an xcom with map_index passed
Expand Down
2 changes: 2 additions & 0 deletions tests/cli/commands/remote_commands/test_task_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,8 @@ def test_cli_run_no_local_no_raw_runs_executor(self, dag_maker):
mock.patch(
"airflow.executors.executor_loader.ExecutorLoader.get_default_executor"
) as get_default_mock,
mock.patch("airflow.executors.local_executor.SimpleQueue"), # Prevent a task being queued
mock.patch("airflow.executors.local_executor.LocalExecutor.end"),
):
EmptyOperator(task_id="task1")
EmptyOperator(task_id="task2", executor="foo_executor_alias")
Expand Down