Skip to content

Commit

Permalink
AIP-72: Allow pushing and pulling XCom from Task Context
Browse files Browse the repository at this point in the history
Part of #44481
  • Loading branch information
kaxil committed Dec 24, 2024
1 parent aa30ed9 commit dc62eb8
Show file tree
Hide file tree
Showing 11 changed files with 230 additions and 35 deletions.
16 changes: 13 additions & 3 deletions airflow/api_fastapi/execution_api/routes/xcoms.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from typing import Annotated

from fastapi import Body, HTTPException, Query, status
from pydantic import Json

from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.common.router import AirflowRouter
Expand Down Expand Up @@ -92,7 +91,7 @@ def get_xcom(
)

try:
xcom_value = BaseXCom.deserialize_value(result)
xcom_value = BaseXCom.orm_deserialize_value(result)
except json.JSONDecodeError:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
Expand All @@ -118,7 +117,7 @@ def set_xcom(
task_id: str,
key: str,
value: Annotated[
Json,
str,
Body(
description="A JSON-formatted string representing the value to set for the XCom.",
openapi_examples={
Expand All @@ -142,6 +141,17 @@ def set_xcom(
map_index: Annotated[int, Query()] = -1,
):
"""Set an Airflow XCom."""
try:
json.loads(value)
except json.JSONDecodeError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"reason": "invalid_format",
"message": "XCom value is not a valid JSON-formatted string",
},
)

if not has_xcom_access(key, token):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
Expand Down
2 changes: 1 addition & 1 deletion airflow/utils/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ def lazy_mapping_from_context(source: Context) -> Mapping[str, Any]:
"""
if not isinstance(source, Context):
# Sometimes we are passed a plain dict (usually in tests, or in User's
# custom operators) -- be lienent about what we accept so we don't
# custom operators) -- be lenient about what we accept so we don't
# break anything for users.
return source

Expand Down
9 changes: 7 additions & 2 deletions task_sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,11 +207,16 @@ class XComOperations:
def __init__(self, client: Client):
self.client = client

def get(self, dag_id: str, run_id: str, task_id: str, key: str, map_index: int = -1) -> XComResponse:
def get(
self, dag_id: str, run_id: str, task_id: str, key: str, map_index: int | None = -1
) -> XComResponse:
"""Get a XCom value from the API server."""
# TODO: check if we need to use map_index as params in the uri
# ref: https://github.com/apache/airflow/blob/v2-10-stable/airflow/api_connexion/openapi/v1.yaml#L1785C1-L1785C81
resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params={"map_index": map_index})
params = {}
if map_index:
params.update({"map_index": map_index})
resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params)
return XComResponse.model_validate_json(resp.read())

def set(
Expand Down
2 changes: 1 addition & 1 deletion task_sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ class GetXCom(BaseModel):
dag_id: str
run_id: str
task_id: str
map_index: int = -1
map_index: int | None = None
type: Literal["GetXCom"] = "GetXCom"


Expand Down
3 changes: 1 addition & 2 deletions task_sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@
from airflow.sdk.execution_time.comms import (
ConnectionResult,
DeferTask,
ErrorResponse,
GetConnection,
GetVariable,
GetXCom,
Expand Down Expand Up @@ -719,7 +718,7 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger):
if isinstance(conn, ConnectionResponse):
conn_result = ConnectionResult.from_conn_response(conn)
resp = conn_result.model_dump_json(exclude_unset=True).encode()
elif isinstance(conn, ErrorResponse):
else:
resp = conn.model_dump_json().encode()
elif isinstance(msg, GetVariable):
var = self.client.variables.get(msg.key)
Expand Down
108 changes: 105 additions & 3 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import os
import sys
from collections.abc import Iterable
from datetime import datetime, timezone
from io import FileIO
from typing import TYPE_CHECKING, Annotated, Any, Generic, TextIO, TypeVar
Expand All @@ -33,12 +34,15 @@
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.execution_time.comms import (
DeferTask,
GetXCom,
RescheduleTask,
SetRenderedFields,
SetXCom,
StartupDetails,
TaskState,
ToSupervisor,
ToTask,
XComResult,
)
from airflow.sdk.execution_time.context import ConnectionAccessor

Expand Down Expand Up @@ -111,11 +115,109 @@ def get_template_context(self):
"ts_nodash_with_tz": ts_nodash_with_tz,
}
context.update(context_from_server)
# TODO: We should use/move TypeDict from airflow.utils.context.Context
return context

def xcom_pull(self, *args, **kwargs): ...
def xcom_pull(
self,
task_ids: str | None = None, # TODO: Simplify to a single task_id? (breaking change)
dag_id: str | None = None,
key: str = "return_value", # TODO: Make this a constant (``XCOM_RETURN_KEY``)
include_prior_dates: bool = False, # TODO: Add support for this
*,
map_indexes: int | None = None,
default: Any = None,
run_id: str | None = None,
) -> Any:
"""
Pull XComs that optionally meet certain criteria.
:param key: A key for the XCom. If provided, only XComs with matching
keys will be returned. The default key is ``'return_value'``, also
available as constant ``XCOM_RETURN_KEY``. This key is automatically
given to XComs returned by tasks (as opposed to being pushed
manually). To remove the filter, pass *None*.
:param task_ids: Only XComs from tasks with matching ids will be
pulled. Pass *None* to remove the filter.
:param dag_id: If provided, only pulls XComs from this DAG. If *None*
(default), the DAG of the calling task is used.
:param map_indexes: If provided, only pull XComs with matching indexes.
If *None* (default), this is inferred from the task(s) being pulled
(see below for details).
:param include_prior_dates: If False, only XComs from the current
logical_date are returned. If *True*, XComs from previous dates
are returned as well.
:param run_id: If provided, only pulls XComs from a DagRun w/a matching run_id.
If *None* (default), the run_id of the calling task is used.
When pulling one single task (``task_id`` is *None* or a str) without
specifying ``map_indexes``, the return value is inferred from whether
the specified task is mapped. If not, value from the one single task
instance is returned. If the task to pull is mapped, an iterator (not a
list) yielding XComs from mapped task instances is returned. In either
case, ``default`` (*None* if not specified) is returned if no matching
XComs are found.
When pulling multiple tasks (i.e. either ``task_id`` or ``map_index`` is
a non-str iterable), a list of matching XComs is returned. Elements in
the list is ordered by item ordering in ``task_id`` and ``map_index``.
"""
if dag_id is None:
dag_id = self.dag_id
if run_id is None:
run_id = self.run_id

if task_ids is None:
task_ids = self.task_id
elif not isinstance(task_ids, str) and isinstance(task_ids, Iterable):
# TODO: Handle multiple task_ids or remove support
raise NotImplementedError("Multiple task_ids are not supported yet")

if map_indexes is None:
map_indexes = self.map_index
elif isinstance(map_indexes, Iterable):
# TODO: Handle multiple map_indexes or remove support
raise NotImplementedError("Multiple map_indexes are not supported yet")

log = structlog.get_logger(logger_name="task")
SUPERVISOR_COMMS.send_request(
log=log,
msg=GetXCom(
key=key,
dag_id=dag_id,
task_id=task_ids,
run_id=run_id,
map_index=map_indexes,
),
)

msg = SUPERVISOR_COMMS.get_message()
if TYPE_CHECKING:
isinstance(msg, XComResult)

def xcom_push(self, *args, **kwargs): ...
value = msg.value
if value is not None:
return value
return default

def xcom_push(self, key: str, value: Any):
"""
Make an XCom available for tasks to pull.
:param key: Key to store the value under.
:param value: Value to store. Only be JSON-serializable may be used otherwise.
"""
log = structlog.get_logger(logger_name="task")
SUPERVISOR_COMMS.send_request(
log=log,
msg=SetXCom(
key=key,
value=value,
dag_id=self.dag_id,
task_id=self.task_id,
run_id=self.run_id,
),
)


def parse(what: StartupDetails) -> RuntimeTaskInstance:
Expand Down Expand Up @@ -188,7 +290,7 @@ def get_message(self) -> ReceiveMsgType:
def send_request(self, log: Logger, msg: SendMsgType):
encoded_msg = msg.model_dump_json().encode() + b"\n"

log.debug("Sending request", json=encoded_msg)
# log.debug("Sending request", json=encoded_msg)
self.request_socket.write(encoded_msg)


Expand Down
16 changes: 13 additions & 3 deletions task_sdk/tests/api/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,13 +313,23 @@ class TestXCOMOperations:
response parsing.
"""

def test_xcom_get_success(self):
@pytest.mark.parametrize(
"value",
[
pytest.param("value1", id="string-value"),
pytest.param({"key1": "value1"}, id="dict-value"),
pytest.param('{"key1": "value1"}', id="dict-str-value"),
pytest.param(["value1", "value2"], id="list-value"),
pytest.param({"key": "test_key", "value": {"key2": "value2"}}, id="nested-dict-value"),
],
)
def test_xcom_get_success(self, value):
# Simulate a successful response from the server when getting an xcom
def handle_request(request: httpx.Request) -> httpx.Response:
if request.url.path == "/xcoms/dag_id/run_id/task_id/key":
return httpx.Response(
status_code=201,
json={"key": "test_key", "value": "test_value"},
json={"key": "test_key", "value": value},
)
return httpx.Response(status_code=400, json={"detail": "Bad Request"})

Expand All @@ -332,7 +342,7 @@ def handle_request(request: httpx.Request) -> httpx.Response:
)
assert isinstance(result, XComResponse)
assert result.key == "test_key"
assert result.value == "test_value"
assert result.value == value

def test_xcom_get_success_with_map_index(self):
# Simulate a successful response from the server when getting an xcom with map_index passed
Expand Down
2 changes: 1 addition & 1 deletion task_sdk/tests/execution_time/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_getattr_connection(self):
) as mock_supervisor_comms:
mock_supervisor_comms.get_message.return_value = conn_result

# Fetch the connection; Triggers __getattr__
# Fetch the connection; triggers __getattr__
conn = accessor.mysql_conn

expected_conn = Connection(conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306)
Expand Down
2 changes: 1 addition & 1 deletion task_sdk/tests/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,7 @@ def watched_subprocess(self, mocker):
GetXCom(dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key"),
b'{"key":"test_key","value":"test_value","type":"XComResult"}\n',
"xcoms.get",
("test_dag", "test_run", "test_task", "test_key", -1),
("test_dag", "test_run", "test_task", "test_key", None),
XComResult(key="test_key", value="test_value"),
id="get_xcom",
),
Expand Down
Loading

0 comments on commit dc62eb8

Please sign in to comment.