From dc62eb839d299ee67a1f917fed92b282bd7633b1 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Thu, 19 Dec 2024 18:58:20 +0530 Subject: [PATCH] AIP-72: Allow pushing and pulling XCom from Task Context Part of https://github.com/apache/airflow/issues/44481 --- .../api_fastapi/execution_api/routes/xcoms.py | 16 ++- airflow/utils/context.py | 2 +- task_sdk/src/airflow/sdk/api/client.py | 9 +- .../src/airflow/sdk/execution_time/comms.py | 2 +- .../airflow/sdk/execution_time/supervisor.py | 3 +- .../airflow/sdk/execution_time/task_runner.py | 108 +++++++++++++++++- task_sdk/tests/api/test_client.py | 16 ++- task_sdk/tests/execution_time/test_context.py | 2 +- .../tests/execution_time/test_supervisor.py | 2 +- .../execution_api/routes/test_xcoms.py | 83 +++++++++++--- tests/models/test_xcom.py | 22 +++- 11 files changed, 230 insertions(+), 35 deletions(-) diff --git a/airflow/api_fastapi/execution_api/routes/xcoms.py b/airflow/api_fastapi/execution_api/routes/xcoms.py index 93285eb3a74f4..faacd543fca2b 100644 --- a/airflow/api_fastapi/execution_api/routes/xcoms.py +++ b/airflow/api_fastapi/execution_api/routes/xcoms.py @@ -22,7 +22,6 @@ from typing import Annotated from fastapi import Body, HTTPException, Query, status -from pydantic import Json from airflow.api_fastapi.common.db.common import SessionDep from airflow.api_fastapi.common.router import AirflowRouter @@ -92,7 +91,7 @@ def get_xcom( ) try: - xcom_value = BaseXCom.deserialize_value(result) + xcom_value = BaseXCom.orm_deserialize_value(result) except json.JSONDecodeError: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -118,7 +117,7 @@ def set_xcom( task_id: str, key: str, value: Annotated[ - Json, + str, Body( description="A JSON-formatted string representing the value to set for the XCom.", openapi_examples={ @@ -142,6 +141,17 @@ def set_xcom( map_index: Annotated[int, Query()] = -1, ): """Set an Airflow XCom.""" + try: + json.loads(value) + except json.JSONDecodeError: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={ + "reason": "invalid_format", + "message": "XCom value is not a valid JSON-formatted string", + }, + ) + if not has_xcom_access(key, token): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, diff --git a/airflow/utils/context.py b/airflow/utils/context.py index c6cf2db498532..28bcd2fe6701d 100644 --- a/airflow/utils/context.py +++ b/airflow/utils/context.py @@ -477,7 +477,7 @@ def lazy_mapping_from_context(source: Context) -> Mapping[str, Any]: """ if not isinstance(source, Context): # Sometimes we are passed a plain dict (usually in tests, or in User's - # custom operators) -- be lienent about what we accept so we don't + # custom operators) -- be lenient about what we accept so we don't # break anything for users. return source diff --git a/task_sdk/src/airflow/sdk/api/client.py b/task_sdk/src/airflow/sdk/api/client.py index da91c2bd98dd2..8fd052bb9de55 100644 --- a/task_sdk/src/airflow/sdk/api/client.py +++ b/task_sdk/src/airflow/sdk/api/client.py @@ -207,11 +207,16 @@ class XComOperations: def __init__(self, client: Client): self.client = client - def get(self, dag_id: str, run_id: str, task_id: str, key: str, map_index: int = -1) -> XComResponse: + def get( + self, dag_id: str, run_id: str, task_id: str, key: str, map_index: int | None = -1 + ) -> XComResponse: """Get a XCom value from the API server.""" # TODO: check if we need to use map_index as params in the uri # ref: https://github.com/apache/airflow/blob/v2-10-stable/airflow/api_connexion/openapi/v1.yaml#L1785C1-L1785C81 - resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params={"map_index": map_index}) + params = {} + if map_index: + params.update({"map_index": map_index}) + resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params) return XComResponse.model_validate_json(resp.read()) def set( diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py b/task_sdk/src/airflow/sdk/execution_time/comms.py index 97fadcafc409e..b90787ca4cfc9 100644 --- a/task_sdk/src/airflow/sdk/execution_time/comms.py +++ b/task_sdk/src/airflow/sdk/execution_time/comms.py @@ -167,7 +167,7 @@ class GetXCom(BaseModel): dag_id: str run_id: str task_id: str - map_index: int = -1 + map_index: int | None = None type: Literal["GetXCom"] = "GetXCom" diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py index 73bc446a28df8..811d1ce86a60d 100644 --- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py @@ -62,7 +62,6 @@ from airflow.sdk.execution_time.comms import ( ConnectionResult, DeferTask, - ErrorResponse, GetConnection, GetVariable, GetXCom, @@ -719,7 +718,7 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): if isinstance(conn, ConnectionResponse): conn_result = ConnectionResult.from_conn_response(conn) resp = conn_result.model_dump_json(exclude_unset=True).encode() - elif isinstance(conn, ErrorResponse): + else: resp = conn.model_dump_json().encode() elif isinstance(msg, GetVariable): var = self.client.variables.get(msg.key) diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index 48dd3ecbfcd67..031c9297ac5d9 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -21,6 +21,7 @@ import os import sys +from collections.abc import Iterable from datetime import datetime, timezone from io import FileIO from typing import TYPE_CHECKING, Annotated, Any, Generic, TextIO, TypeVar @@ -33,12 +34,15 @@ from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.execution_time.comms import ( DeferTask, + GetXCom, RescheduleTask, SetRenderedFields, + SetXCom, StartupDetails, TaskState, ToSupervisor, ToTask, + XComResult, ) from airflow.sdk.execution_time.context import ConnectionAccessor @@ -111,11 +115,109 @@ def get_template_context(self): "ts_nodash_with_tz": ts_nodash_with_tz, } context.update(context_from_server) + # TODO: We should use/move TypeDict from airflow.utils.context.Context return context - def xcom_pull(self, *args, **kwargs): ... + def xcom_pull( + self, + task_ids: str | None = None, # TODO: Simplify to a single task_id? (breaking change) + dag_id: str | None = None, + key: str = "return_value", # TODO: Make this a constant (``XCOM_RETURN_KEY``) + include_prior_dates: bool = False, # TODO: Add support for this + *, + map_indexes: int | None = None, + default: Any = None, + run_id: str | None = None, + ) -> Any: + """ + Pull XComs that optionally meet certain criteria. + + :param key: A key for the XCom. If provided, only XComs with matching + keys will be returned. The default key is ``'return_value'``, also + available as constant ``XCOM_RETURN_KEY``. This key is automatically + given to XComs returned by tasks (as opposed to being pushed + manually). To remove the filter, pass *None*. + :param task_ids: Only XComs from tasks with matching ids will be + pulled. Pass *None* to remove the filter. + :param dag_id: If provided, only pulls XComs from this DAG. If *None* + (default), the DAG of the calling task is used. + :param map_indexes: If provided, only pull XComs with matching indexes. + If *None* (default), this is inferred from the task(s) being pulled + (see below for details). + :param include_prior_dates: If False, only XComs from the current + logical_date are returned. If *True*, XComs from previous dates + are returned as well. + :param run_id: If provided, only pulls XComs from a DagRun w/a matching run_id. + If *None* (default), the run_id of the calling task is used. + + When pulling one single task (``task_id`` is *None* or a str) without + specifying ``map_indexes``, the return value is inferred from whether + the specified task is mapped. If not, value from the one single task + instance is returned. If the task to pull is mapped, an iterator (not a + list) yielding XComs from mapped task instances is returned. In either + case, ``default`` (*None* if not specified) is returned if no matching + XComs are found. + + When pulling multiple tasks (i.e. either ``task_id`` or ``map_index`` is + a non-str iterable), a list of matching XComs is returned. Elements in + the list is ordered by item ordering in ``task_id`` and ``map_index``. + """ + if dag_id is None: + dag_id = self.dag_id + if run_id is None: + run_id = self.run_id + + if task_ids is None: + task_ids = self.task_id + elif not isinstance(task_ids, str) and isinstance(task_ids, Iterable): + # TODO: Handle multiple task_ids or remove support + raise NotImplementedError("Multiple task_ids are not supported yet") + + if map_indexes is None: + map_indexes = self.map_index + elif isinstance(map_indexes, Iterable): + # TODO: Handle multiple map_indexes or remove support + raise NotImplementedError("Multiple map_indexes are not supported yet") + + log = structlog.get_logger(logger_name="task") + SUPERVISOR_COMMS.send_request( + log=log, + msg=GetXCom( + key=key, + dag_id=dag_id, + task_id=task_ids, + run_id=run_id, + map_index=map_indexes, + ), + ) + + msg = SUPERVISOR_COMMS.get_message() + if TYPE_CHECKING: + isinstance(msg, XComResult) - def xcom_push(self, *args, **kwargs): ... + value = msg.value + if value is not None: + return value + return default + + def xcom_push(self, key: str, value: Any): + """ + Make an XCom available for tasks to pull. + + :param key: Key to store the value under. + :param value: Value to store. Only be JSON-serializable may be used otherwise. + """ + log = structlog.get_logger(logger_name="task") + SUPERVISOR_COMMS.send_request( + log=log, + msg=SetXCom( + key=key, + value=value, + dag_id=self.dag_id, + task_id=self.task_id, + run_id=self.run_id, + ), + ) def parse(what: StartupDetails) -> RuntimeTaskInstance: @@ -188,7 +290,7 @@ def get_message(self) -> ReceiveMsgType: def send_request(self, log: Logger, msg: SendMsgType): encoded_msg = msg.model_dump_json().encode() + b"\n" - log.debug("Sending request", json=encoded_msg) + # log.debug("Sending request", json=encoded_msg) self.request_socket.write(encoded_msg) diff --git a/task_sdk/tests/api/test_client.py b/task_sdk/tests/api/test_client.py index 16b8d6c9bfe15..279502793ee23 100644 --- a/task_sdk/tests/api/test_client.py +++ b/task_sdk/tests/api/test_client.py @@ -313,13 +313,23 @@ class TestXCOMOperations: response parsing. """ - def test_xcom_get_success(self): + @pytest.mark.parametrize( + "value", + [ + pytest.param("value1", id="string-value"), + pytest.param({"key1": "value1"}, id="dict-value"), + pytest.param('{"key1": "value1"}', id="dict-str-value"), + pytest.param(["value1", "value2"], id="list-value"), + pytest.param({"key": "test_key", "value": {"key2": "value2"}}, id="nested-dict-value"), + ], + ) + def test_xcom_get_success(self, value): # Simulate a successful response from the server when getting an xcom def handle_request(request: httpx.Request) -> httpx.Response: if request.url.path == "/xcoms/dag_id/run_id/task_id/key": return httpx.Response( status_code=201, - json={"key": "test_key", "value": "test_value"}, + json={"key": "test_key", "value": value}, ) return httpx.Response(status_code=400, json={"detail": "Bad Request"}) @@ -332,7 +342,7 @@ def handle_request(request: httpx.Request) -> httpx.Response: ) assert isinstance(result, XComResponse) assert result.key == "test_key" - assert result.value == "test_value" + assert result.value == value def test_xcom_get_success_with_map_index(self): # Simulate a successful response from the server when getting an xcom with map_index passed diff --git a/task_sdk/tests/execution_time/test_context.py b/task_sdk/tests/execution_time/test_context.py index 65d2b50f8a17f..a3220c3bef1e3 100644 --- a/task_sdk/tests/execution_time/test_context.py +++ b/task_sdk/tests/execution_time/test_context.py @@ -67,7 +67,7 @@ def test_getattr_connection(self): ) as mock_supervisor_comms: mock_supervisor_comms.get_message.return_value = conn_result - # Fetch the connection; Triggers __getattr__ + # Fetch the connection; triggers __getattr__ conn = accessor.mysql_conn expected_conn = Connection(conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306) diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py index 098e5914f2551..53da57cf178eb 100644 --- a/task_sdk/tests/execution_time/test_supervisor.py +++ b/task_sdk/tests/execution_time/test_supervisor.py @@ -817,7 +817,7 @@ def watched_subprocess(self, mocker): GetXCom(dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key"), b'{"key":"test_key","value":"test_value","type":"XComResult"}\n', "xcoms.get", - ("test_dag", "test_run", "test_task", "test_key", -1), + ("test_dag", "test_run", "test_task", "test_key", None), XComResult(key="test_key", value="test_value"), id="get_xcom", ), diff --git a/tests/api_fastapi/execution_api/routes/test_xcoms.py b/tests/api_fastapi/execution_api/routes/test_xcoms.py index d9d33f28d4425..6347db9b6db28 100644 --- a/tests/api_fastapi/execution_api/routes/test_xcoms.py +++ b/tests/api_fastapi/execution_api/routes/test_xcoms.py @@ -21,6 +21,7 @@ import pytest +from airflow.api_fastapi.execution_api.datamodels.xcom import XComResponse from airflow.models.dagrun import DagRun from airflow.models.xcom import XCom from airflow.utils.session import create_session @@ -40,19 +41,21 @@ class TestXComsGetEndpoint: @pytest.mark.parametrize( ("value", "expected_value"), [ - ("value1", "value1"), - ({"key2": "value2"}, {"key2": "value2"}), - ({"key2": "value2", "key3": ["value3"]}, {"key2": "value2", "key3": ["value3"]}), - (["value1"], ["value1"]), + ('"value1"', '"value1"'), + ('{"key2": "value2"}', '{"key2": "value2"}'), + ('{"key2": "value2", "key3": ["value3"]}', '{"key2": "value2", "key3": ["value3"]}'), + ('["value1"]', '["value1"]'), ], ) def test_xcom_get_from_db(self, client, create_task_instance, session, value, expected_value): """Test that XCom value is returned from the database in JSON-compatible format.""" ti = create_task_instance() ti.xcom_push(key="xcom_1", value=value, session=session) - session.commit() + xcom = session.query(XCom).filter_by(task_id=ti.task_id, dag_id=ti.dag_id, key="xcom_1").first() + assert xcom.value == expected_value + response = client.get(f"/execution/xcoms/{ti.dag_id}/{ti.run_id}/{ti.task_id}/xcom_1") assert response.status_code == 200 @@ -86,19 +89,17 @@ class TestXComsSetEndpoint: @pytest.mark.parametrize( ("value", "expected_value"), [ - ('"value1"', "value1"), - ('{"key2": "value2"}', {"key2": "value2"}), - ('{"key2": "value2", "key3": ["value3"]}', {"key2": "value2", "key3": ["value3"]}), - ('["value1"]', ["value1"]), + ('"value1"', '"value1"'), + ('{"key2": "value2"}', '{"key2": "value2"}'), + ('{"key2": "value2", "key3": ["value3"]}', '{"key2": "value2", "key3": ["value3"]}'), + ('["value1"]', '["value1"]'), ], ) def test_xcom_set(self, client, create_task_instance, session, value, expected_value): """ Test that XCom value is set correctly. The value is passed as a JSON string in the request body. - This is then validated via Pydantic.Json type in the request body and converted to - a Python object before being sent to XCom.set. XCom.set then uses json.dumps to - serialize it and store the value in the database. This is done so that Task SDK in multiple - languages can use the same API to set XCom values. + XCom.set then uses json.dumps to serialize it and store the value in the database. + This is done so that Task SDK in multiple languages can use the same API to set XCom values. """ ti = create_task_instance() session.commit() @@ -114,6 +115,24 @@ def test_xcom_set(self, client, create_task_instance, session, value, expected_v xcom = session.query(XCom).filter_by(task_id=ti.task_id, dag_id=ti.dag_id, key="xcom_1").first() assert xcom.value == expected_value + @pytest.mark.parametrize( + "value", + ["value1", {"key2": "value2"}, ["value1"]], + ) + def test_xcom_set_invalid_json(self, client, create_task_instance, value): + response = client.post( + "/execution/xcoms/dag/runid/task/xcom_1", + json="invalid_json", + ) + + assert response.status_code == 400 + assert response.json() == { + "detail": { + "reason": "invalid_format", + "message": "XCom value is not a valid JSON-formatted string", + } + } + def test_xcom_access_denied(self, client): with mock.patch("airflow.api_fastapi.execution_api.routes.xcoms.has_xcom_access", return_value=False): response = client.post( @@ -128,3 +147,41 @@ def test_xcom_access_denied(self, client): "message": "Task does not have access to set XCom key 'xcom_perms'", } } + + @pytest.mark.parametrize( + ("value", "expected_value"), + [ + ('"value1"', '"value1"'), + ('{"key2": "value2"}', '{"key2": "value2"}'), + ('{"key2": "value2", "key3": ["value3"]}', '{"key2": "value2", "key3": ["value3"]}'), + ('["value1"]', '["value1"]'), + ], + ) + def test_xcom_roundtrip(self, client, create_task_instance, session, value, expected_value): + """ + Test that XCom value is set and retrieved correctly using API. + + This test sets an XCom value using the API and then retrieves it using the API so we can + ensure client and server are working correctly together. The server expects a JSON string + and it will also return a JSON string. It is the client's responsibility to parse the JSON + string into a native object. This is useful for Task SDKs in other languages. + """ + ti = create_task_instance() + + session.commit() + client.post( + f"/execution/xcoms/{ti.dag_id}/{ti.run_id}/{ti.task_id}/test_xcom_roundtrip", + json=value, + ) + + xcom = ( + session.query(XCom) + .filter_by(task_id=ti.task_id, dag_id=ti.dag_id, key="test_xcom_roundtrip") + .first() + ) + assert xcom.value == expected_value + + response = client.get(f"/execution/xcoms/{ti.dag_id}/{ti.run_id}/{ti.task_id}/test_xcom_roundtrip") + + assert response.status_code == 200 + assert XComResponse.model_validate_json(response.read()).value == expected_value diff --git a/tests/models/test_xcom.py b/tests/models/test_xcom.py index 85e3249fe6cbb..1e62b414de2da 100644 --- a/tests/models/test_xcom.py +++ b/tests/models/test_xcom.py @@ -324,18 +324,30 @@ def test_xcom_get_many_from_prior_dates(self, session, tis_for_xcom_get_many_fro class TestXComSet: - def test_xcom_set(self, session, task_instance): + @pytest.mark.parametrize( + ("key", "value", "expected_value"), + [ + pytest.param("xcom_dict", {"key": "value"}, {"key": "value"}, id="dict"), + pytest.param("xcom_int", 123, 123, id="int"), + pytest.param("xcom_float", 45.67, 45.67, id="float"), + pytest.param("xcom_str", "hello", "hello", id="str"), + pytest.param("xcom_bool", True, True, id="bool"), + pytest.param("xcom_list", [1, 2, 3], [1, 2, 3], id="list"), + ], + ) + def test_xcom_set(self, session, task_instance, key, value, expected_value): XCom.set( - key="xcom_1", - value={"key": "value"}, + key=key, + value=value, dag_id=task_instance.dag_id, task_id=task_instance.task_id, run_id=task_instance.run_id, session=session, ) stored_xcoms = session.query(XCom).all() - assert stored_xcoms[0].key == "xcom_1" - assert stored_xcoms[0].value == {"key": "value"} + assert stored_xcoms[0].key == key + assert isinstance(stored_xcoms[0].value, type(expected_value)) + assert stored_xcoms[0].value == expected_value assert stored_xcoms[0].dag_id == "dag" assert stored_xcoms[0].task_id == "task_1" assert stored_xcoms[0].logical_date == task_instance.logical_date