From f87beb8dae6873957854a62faf29a39a62e6adc1 Mon Sep 17 00:00:00 2001 From: Amogh Date: Tue, 7 Jan 2025 15:26:56 +0530 Subject: [PATCH 1/5] AIP-72: Add support to get Variables in task SDK to author tasks --- airflow/dag_processing/processor.py | 41 +++++++++---- .../example_get_variable_using_task_sdk.py | 33 +++++++++++ .../src/airflow/sdk/definitions/variable.py | 15 ++++- .../src/airflow/sdk/execution_time/context.py | 14 +++-- .../tests/execution_time/test_task_runner.py | 57 +++++++++++++++++++ tests/dag_processing/test_processor.py | 9 +++ 6 files changed, 151 insertions(+), 18 deletions(-) create mode 100644 airflow/example_dags/example_get_variable_using_task_sdk.py diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py index 981e00341a8e6..848b68d2bfb96 100644 --- a/airflow/dag_processing/processor.py +++ b/airflow/dag_processing/processor.py @@ -31,8 +31,11 @@ TaskCallbackRequest, ) from airflow.configuration import conf +from airflow.models import Variable from airflow.models.dagbag import DagBag -from airflow.sdk.execution_time.comms import GetConnection, GetVariable +from airflow.sdk.api.datamodels._generated import VariableResponse +from airflow.sdk.execution_time import task_runner +from airflow.sdk.execution_time.comms import GetConnection, GetVariable, VariableResult from airflow.sdk.execution_time.supervisor import WatchedSubprocess from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG from airflow.stats import Stats @@ -43,26 +46,22 @@ from airflow.typing_compat import Self from airflow.utils.context import Context +COMMS_DECODER: task_runner.CommsDecoder[ToChild, ToParent] + def _parse_file_entrypoint(): import os import structlog - from airflow.sdk.execution_time import task_runner # Parse DAG file, send JSON back up! 
- - comms_decoder = task_runner.CommsDecoder[DagFileParseRequest, DagFileParsingResult]( - input=sys.stdin, - decoder=TypeAdapter[DagFileParseRequest](DagFileParseRequest), - ) - msg = comms_decoder.get_message() - comms_decoder.request_socket = os.fdopen(msg.requests_fd, "wb", buffering=0) + msg = COMMS_DECODER.get_message() + COMMS_DECODER.request_socket = os.fdopen(msg.requests_fd, "wb", buffering=0) log = structlog.get_logger(logger_name="task") result = _parse_file(msg, log) - comms_decoder.send_request(log, result) + COMMS_DECODER.send_request(log, result) def _parse_file(msg: DagFileParseRequest, log: FilteringBoundLogger) -> DagFileParsingResult: @@ -181,6 +180,11 @@ class DagFileParsingResult(BaseModel): Field(discriminator="type"), ] +ToChild = Annotated[ + Union[DagFileParseRequest, VariableResult], + Field(discriminator="type"), +] + @attrs.define() class DagFileProcessorProcess(WatchedSubprocess): @@ -204,6 +208,11 @@ def start( # type: ignore[override] target: Callable[[], None] = _parse_file_entrypoint, **kwargs, ) -> Self: + global COMMS_DECODER + COMMS_DECODER = task_runner.CommsDecoder[ToChild, ToParent]( + input=sys.stdin, + decoder=TypeAdapter[ToChild](ToChild), + ) return super().start(path, callbacks, target=target, client=None, **kwargs) # type:ignore[arg-type] def _on_child_started( # type: ignore[override] @@ -235,8 +244,16 @@ def _handle_request(self, msg: ToParent, log: FilteringBoundLogger) -> None: # if isinstance(msg, DagFileParsingResult): self.parsing_result = msg return - # GetVariable etc -- parsing a dag can run top level code that asks for an Airflow Variable - super()._handle_request(msg, log) + elif isinstance(msg, GetVariable): + key = msg.key + try: + value = Variable.get(key) + except KeyError: + log.exception("Variable: %s does not exist", key) + raise + var_result = VariableResult.from_variable_response(VariableResponse(key=key, value=value)) + resp = var_result.model_dump_json(exclude_unset=True).encode() + 
self.stdin.write(resp + b"\n") @property def is_ready(self) -> bool: diff --git a/airflow/example_dags/example_get_variable_using_task_sdk.py b/airflow/example_dags/example_get_variable_using_task_sdk.py new file mode 100644 index 0000000000000..9d55b12cab2ec --- /dev/null +++ b/airflow/example_dags/example_get_variable_using_task_sdk.py @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +from airflow import DAG +from airflow.models.baseoperator import BaseOperator +from airflow.sdk import Variable + +value = Variable.get(key="my_var") + + +class CustomOperator(BaseOperator): + def execute(self, context): + print(f"Variable defined at top level of dag has value: {value}") + + +with DAG(dag_id="example_get_variable_using_task_sdk") as dag: + CustomOperator(task_id="print_top_level_variable") diff --git a/task_sdk/src/airflow/sdk/definitions/variable.py b/task_sdk/src/airflow/sdk/definitions/variable.py index 5f458580065c5..b8e75cc78edf1 100644 --- a/task_sdk/src/airflow/sdk/definitions/variable.py +++ b/task_sdk/src/airflow/sdk/definitions/variable.py @@ -21,6 +21,10 @@ import attrs +from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType +from airflow.sdk.execution_time.context import _get_variable +from airflow.sdk.types import NOTSET + @attrs.define class Variable: @@ -34,8 +38,15 @@ class Variable: """ key: str - # keeping as any for supporting deserialize_json + # keeping as any for supporting `deserialize_json` value: Any | None = None description: str | None = None - # TODO: Extend this definition for reading/writing variables without context + @classmethod + def get(cls, key: str, default: Any = NOTSET) -> Any: + try: + return _get_variable(key).value + except AirflowRuntimeError as e: + if e.error.error == ErrorType.VARIABLE_NOT_FOUND: + return default + raise diff --git a/task_sdk/src/airflow/sdk/execution_time/context.py b/task_sdk/src/airflow/sdk/execution_time/context.py index d96f5aeda2c95..6e462c2047dff 100644 --- a/task_sdk/src/airflow/sdk/execution_time/context.py +++ b/task_sdk/src/airflow/sdk/execution_time/context.py @@ -66,18 +66,24 @@ def _get_connection(conn_id: str) -> Connection: return _convert_connection_result_conn(msg) -def _get_variable(key: str, deserialize_json: bool) -> Variable: +def _get_variable(key: str, deserialize_json: bool = False) -> Variable: # TODO: This 
should probably be moved to a separate module like `airflow.sdk.execution_time.comms` # or `airflow.sdk.execution_time.variable` # A reason to not move it to `airflow.sdk.execution_time.comms` is that it # will make that module depend on Task SDK, which is not ideal because we intend to # keep Task SDK as a separate package than execution time mods. from airflow.sdk.execution_time.comms import ErrorResponse, GetVariable - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + try: + # We check the hypothesis if the request for variable came from task. + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS as COMMS # type: ignore + except ImportError: + # If not, hypothesis is false and this request is from dag level. + from airflow.dag_processing.processor import COMMS_DECODER as COMMS # type: ignore log = structlog.get_logger(logger_name="task") - SUPERVISOR_COMMS.send_request(log=log, msg=GetVariable(key=key)) - msg = SUPERVISOR_COMMS.get_message() + COMMS.send_request(log=log, msg=GetVariable(key=key)) + msg = COMMS.get_message() if isinstance(msg, ErrorResponse): raise AirflowRuntimeError(msg) diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index 8749bb2be1085..a8aad7995c088 100644 --- a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -878,3 +878,60 @@ def execute(self, context): assert str(exc_info.value) == ( f"Returned dictionary keys must be strings when using multiple_outputs, found 2 ({int}) instead" ) + + +class TestWritingTasksWithSDKDefinitions: + def test_run_task_which_gets_variable( + self, time_machine, mocked_parse, make_ti_context, mock_supervisor_comms + ): + """Test running a basic task that gets a variable""" + instant = timezone.datetime(2024, 12, 3, 10, 0) + time_machine.move_to(instant, tick=False) + + class CustomOperator(BaseOperator): + def execute(self, context): + # Gets a 
variable and returns it + value = Variable.get(key="my_var") + return value + + task = CustomOperator(task_id="get_var") + what = StartupDetails( + ti=TaskInstance( + id=uuid7(), + task_id="get_var", + dag_id="get_var_dag", + run_id="c", + try_number=1, + ), + file="", + requests_fd=0, + ti_context=make_ti_context(), + ) + + # Parse the task instance + runtime_ti = mocked_parse(what, "get_var_dag", task) + + # Mock the `get_message` to return VariableResult + mock_supervisor_comms.get_message.return_value = VariableResult( + key="my_var", + value="my_value", + ) + + # Run the task + run(runtime_ti, log=mock.MagicMock()) + + # Assert if a GetVariable call was made and later TaskState to mark as success + expected_calls = [ + mock.call.send_request( + msg=GetVariable( + key="my_var", + ), + log=mock.ANY, + ), + mock.call.send_request( + msg=TaskState(end_date=instant, state=TerminalTIState.SUCCESS), + log=mock.ANY, + ), + ] + + mock_supervisor_comms.send_request.assert_has_calls(expected_calls, any_order=True) diff --git a/tests/dag_processing/test_processor.py b/tests/dag_processing/test_processor.py index f90475a2e96ff..67f301a32a355 100644 --- a/tests/dag_processing/test_processor.py +++ b/tests/dag_processing/test_processor.py @@ -20,6 +20,7 @@ import pathlib import sys from typing import TYPE_CHECKING +from unittest import mock from unittest.mock import patch import pytest @@ -35,6 +36,7 @@ from airflow.models import DagBag, TaskInstance from airflow.models.baseoperator import BaseOperator from airflow.models.serialized_dag import SerializedDagModel +from airflow.sdk.execution_time.comms import GetVariable, VariableResult from airflow.utils import timezone from airflow.utils.session import create_session from airflow.utils.state import DagRunState, TaskInstanceState @@ -124,6 +126,13 @@ def fake_collect_dags(dagbag: DagBag, *args, **kwargs): assert resp.import_errors is not None assert "a.py" in resp.import_errors + 
@patch("airflow.dag_processing.processor.COMMS_DECODER", create=True) + def test_parse_with_top_level_variable(self, mock_comms_decoder): + mock_comms_decoder.get_message.return_value = VariableResult(key="my_var", value="my_value") + dagbag = DagBag(dag_folder="/dev/null", include_examples=True, read_dags_from_db=False) + assert "example_get_variable_using_task_sdk" in dagbag.dag_ids + mock_comms_decoder.send_request.assert_called_once_with(log=mock.ANY, msg=GetVariable(key="my_var")) + # @conf_vars({("logging", "dag_processor_log_target"): "stdout"}) # @mock.patch("airflow.dag_processing.processor.settings.dispose_orm", MagicMock) From 8b7c8ee446a8ffbc0c3c22a2926a19819d332958 Mon Sep 17 00:00:00 2001 From: Amogh Date: Wed, 8 Jan 2025 20:26:20 +0530 Subject: [PATCH 2/5] not found variable as None --- task_sdk/src/airflow/sdk/definitions/variable.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/task_sdk/src/airflow/sdk/definitions/variable.py b/task_sdk/src/airflow/sdk/definitions/variable.py index b8e75cc78edf1..17ea91ba52425 100644 --- a/task_sdk/src/airflow/sdk/definitions/variable.py +++ b/task_sdk/src/airflow/sdk/definitions/variable.py @@ -23,7 +23,6 @@ from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType from airflow.sdk.execution_time.context import _get_variable -from airflow.sdk.types import NOTSET @attrs.define @@ -43,7 +42,7 @@ class Variable: description: str | None = None @classmethod - def get(cls, key: str, default: Any = NOTSET) -> Any: + def get(cls, key: str, default: Any = None) -> Any: try: return _get_variable(key).value except AirflowRuntimeError as e: From 609cc9ea7e3f19fa2b28273827ed689d9828bb16 Mon Sep 17 00:00:00 2001 From: Amogh Date: Thu, 9 Jan 2025 11:01:31 +0530 Subject: [PATCH 3/5] moving variable helpers outside of context --- .../src/airflow/sdk/definitions/variable.py | 2 +- .../src/airflow/sdk/execution_time/context.py | 40 +------------ .../airflow/sdk/execution_time/variable.py | 58 
+++++++++++++++++++ 3 files changed, 61 insertions(+), 39 deletions(-) create mode 100644 task_sdk/src/airflow/sdk/execution_time/variable.py diff --git a/task_sdk/src/airflow/sdk/definitions/variable.py b/task_sdk/src/airflow/sdk/definitions/variable.py index 17ea91ba52425..e13ad98863e4c 100644 --- a/task_sdk/src/airflow/sdk/definitions/variable.py +++ b/task_sdk/src/airflow/sdk/definitions/variable.py @@ -22,7 +22,7 @@ import attrs from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType -from airflow.sdk.execution_time.context import _get_variable +from airflow.sdk.execution_time.variable import _get_variable @attrs.define diff --git a/task_sdk/src/airflow/sdk/execution_time/context.py b/task_sdk/src/airflow/sdk/execution_time/context.py index dda3945e57bfe..095b76c19f7c0 100644 --- a/task_sdk/src/airflow/sdk/execution_time/context.py +++ b/task_sdk/src/airflow/sdk/execution_time/context.py @@ -21,12 +21,12 @@ import structlog from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType +from airflow.sdk.execution_time.variable import _get_variable from airflow.sdk.types import NOTSET if TYPE_CHECKING: from airflow.sdk.definitions.connection import Connection - from airflow.sdk.definitions.variable import Variable - from airflow.sdk.execution_time.comms import ConnectionResult, VariableResult + from airflow.sdk.execution_time.comms import ConnectionResult def _convert_connection_result_conn(conn_result: ConnectionResult) -> Connection: @@ -36,16 +36,6 @@ def _convert_connection_result_conn(conn_result: ConnectionResult) -> Connection return Connection(**conn_result.model_dump(exclude={"type"}, by_alias=True)) -def _convert_variable_result_to_variable(var_result: VariableResult, deserialize_json: bool) -> Variable: - from airflow.sdk.definitions.variable import Variable - - if deserialize_json: - import json - - var_result.value = json.loads(var_result.value) # type: ignore - return Variable(**var_result.model_dump(exclude={"type"})) - - def 
_get_connection(conn_id: str) -> Connection: # TODO: This should probably be moved to a separate module like `airflow.sdk.execution_time.comms` # or `airflow.sdk.execution_time.connection` @@ -66,32 +56,6 @@ def _get_connection(conn_id: str) -> Connection: return _convert_connection_result_conn(msg) -def _get_variable(key: str, deserialize_json: bool = False) -> Variable: - # TODO: This should probably be moved to a separate module like `airflow.sdk.execution_time.comms` - # or `airflow.sdk.execution_time.variable` - # A reason to not move it to `airflow.sdk.execution_time.comms` is that it - # will make that module depend on Task SDK, which is not ideal because we intend to - # keep Task SDK as a separate package than execution time mods. - from airflow.sdk.execution_time.comms import ErrorResponse, GetVariable - - try: - # We check the hypothesis if the request for variable came from task. - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS as COMMS # type: ignore - except ImportError: - # If not, hypothesis is false and this request is from dag level. - from airflow.dag_processing.processor import COMMS_DECODER as COMMS # type: ignore - - log = structlog.get_logger(logger_name="task") - COMMS.send_request(log=log, msg=GetVariable(key=key)) - msg = COMMS.get_message() - if isinstance(msg, ErrorResponse): - raise AirflowRuntimeError(msg) - - if TYPE_CHECKING: - assert isinstance(msg, VariableResult) - return _convert_variable_result_to_variable(msg, deserialize_json) - - class ConnectionAccessor: """Wrapper to access Connection entries in template.""" diff --git a/task_sdk/src/airflow/sdk/execution_time/variable.py b/task_sdk/src/airflow/sdk/execution_time/variable.py new file mode 100644 index 0000000000000..49ff411fb1cb7 --- /dev/null +++ b/task_sdk/src/airflow/sdk/execution_time/variable.py @@ -0,0 +1,58 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING + +import structlog + +from airflow.sdk.exceptions import AirflowRuntimeError + +if TYPE_CHECKING: + from airflow.sdk.definitions.variable import Variable + from airflow.sdk.execution_time.comms import VariableResult + + +def _convert_variable_result_to_variable(var_result: VariableResult, deserialize_json: bool) -> Variable: + from airflow.sdk.definitions.variable import Variable + + if deserialize_json: + import json + + var_result.value = json.loads(var_result.value) # type: ignore + return Variable(**var_result.model_dump(exclude={"type"})) + + +def _get_variable(key: str, deserialize_json: bool = False) -> Variable: + from airflow.sdk.execution_time.comms import ErrorResponse, GetVariable + + try: + # First, assume the request for the variable came from a running task. + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS as COMMS # type: ignore + except ImportError: + # Otherwise, the request came from DAG-level code, so use the DAG processor's decoder. 
+ from airflow.dag_processing.processor import COMMS_DECODER as COMMS # type: ignore + + log = structlog.get_logger(logger_name="task") + COMMS.send_request(log=log, msg=GetVariable(key=key)) + msg = COMMS.get_message() + if isinstance(msg, ErrorResponse): + raise AirflowRuntimeError(msg) + + if TYPE_CHECKING: + assert isinstance(msg, VariableResult) + return _convert_variable_result_to_variable(msg, deserialize_json) From be9f2aa2935d104543032e7469e016f9fefee1e1 Mon Sep 17 00:00:00 2001 From: Amogh Date: Thu, 9 Jan 2025 12:41:44 +0530 Subject: [PATCH 4/5] fixing tests --- airflow/dag_processing/processor.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py index 8b5290d591b2e..e3812d894fbaa 100644 --- a/airflow/dag_processing/processor.py +++ b/airflow/dag_processing/processor.py @@ -55,6 +55,11 @@ def _parse_file_entrypoint(): import structlog # Parse DAG file, send JSON back up! + global COMMS_DECODER + COMMS_DECODER = task_runner.CommsDecoder[ToChild, ToParent]( + input=sys.stdin, + decoder=TypeAdapter[ToChild](ToChild), + ) msg = COMMS_DECODER.get_message() COMMS_DECODER.request_socket = os.fdopen(msg.requests_fd, "wb", buffering=0) @@ -207,11 +212,6 @@ def start( # type: ignore[override] target: Callable[[], None] = _parse_file_entrypoint, **kwargs, ) -> Self: - global COMMS_DECODER - COMMS_DECODER = task_runner.CommsDecoder[ToChild, ToParent]( - input=sys.stdin, - decoder=TypeAdapter[ToChild](ToChild), - ) return super().start(path, callbacks, target=target, client=None, **kwargs) # type:ignore[arg-type] def _on_child_started( # type: ignore[override] From e05c8fc4f26e4b448f157ecf8f2145b6cf6d3e62 Mon Sep 17 00:00:00 2001 From: Amogh Date: Thu, 9 Jan 2025 13:09:42 +0530 Subject: [PATCH 5/5] fix mypy --- task_sdk/tests/execution_time/test_context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/task_sdk/tests/execution_time/test_context.py 
b/task_sdk/tests/execution_time/test_context.py index d3bf589c84e1d..44004fef5f133 100644 --- a/task_sdk/tests/execution_time/test_context.py +++ b/task_sdk/tests/execution_time/test_context.py @@ -27,8 +27,8 @@ ConnectionAccessor, VariableAccessor, _convert_connection_result_conn, - _convert_variable_result_to_variable, ) +from airflow.sdk.execution_time.variable import _convert_variable_result_to_variable def test_convert_connection_result_conn():