diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py
index 8d48b5ab6aeb3..e3812d894fbaa 100644
--- a/airflow/dag_processing/processor.py
+++ b/airflow/dag_processing/processor.py
@@ -31,8 +31,11 @@
     TaskCallbackRequest,
 )
 from airflow.configuration import conf
+from airflow.models import Variable
 from airflow.models.dagbag import DagBag
-from airflow.sdk.execution_time.comms import GetConnection, GetVariable
+from airflow.sdk.api.datamodels._generated import VariableResponse
+from airflow.sdk.execution_time import task_runner
+from airflow.sdk.execution_time.comms import GetConnection, GetVariable, VariableResult
 from airflow.sdk.execution_time.supervisor import WatchedSubprocess
 from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG
 from airflow.stats import Stats
@@ -43,26 +46,27 @@
     from airflow.typing_compat import Self
     from airflow.utils.context import Context
 
+COMMS_DECODER: task_runner.CommsDecoder[ToChild, ToParent]
+
 def _parse_file_entrypoint():
     import os
 
     import structlog
 
-    from airflow.sdk.execution_time import task_runner
     # Parse DAG file, send JSON back up!
-
-    comms_decoder = task_runner.CommsDecoder[DagFileParseRequest, DagFileParsingResult](
+    global COMMS_DECODER
+    COMMS_DECODER = task_runner.CommsDecoder[ToChild, ToParent](
         input=sys.stdin,
-        decoder=TypeAdapter[DagFileParseRequest](DagFileParseRequest),
+        decoder=TypeAdapter[ToChild](ToChild),
     )
 
-    msg = comms_decoder.get_message()
-    comms_decoder.request_socket = os.fdopen(msg.requests_fd, "wb", buffering=0)
+    msg = COMMS_DECODER.get_message()
+    COMMS_DECODER.request_socket = os.fdopen(msg.requests_fd, "wb", buffering=0)
 
     log = structlog.get_logger(logger_name="task")
 
     result = _parse_file(msg, log)
-    comms_decoder.send_request(log, result)
+    COMMS_DECODER.send_request(log, result)
 
 
 def _parse_file(msg: DagFileParseRequest, log: FilteringBoundLogger) -> DagFileParsingResult:
@@ -180,6 +184,11 @@ class DagFileParsingResult(BaseModel):
     Field(discriminator="type"),
 ]
 
+ToChild = Annotated[
+    Union[DagFileParseRequest, VariableResult],
+    Field(discriminator="type"),
+]
+
 
 @attrs.define()
 class DagFileProcessorProcess(WatchedSubprocess):
@@ -234,8 +243,16 @@ def _handle_request(self, msg: ToParent, log: FilteringBoundLogger) -> None:  #
         if isinstance(msg, DagFileParsingResult):
             self.parsing_result = msg
             return
-        # GetVariable etc -- parsing a dag can run top level code that asks for an Airflow Variable
-        super()._handle_request(msg, log)
+        elif isinstance(msg, GetVariable):
+            key = msg.key
+            try:
+                value = Variable.get(key)
+            except KeyError:
+                log.exception("Variable: %s does not exist", key)
+                raise
+            var_result = VariableResult.from_variable_response(VariableResponse(key=key, value=value))
+            resp = var_result.model_dump_json(exclude_unset=True).encode()
+            self.stdin.write(resp + b"\n")
 
     @property
     def is_ready(self) -> bool:
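Reviewer note: the `_handle_request` branch above answers a parsing subprocess's `GetVariable` request by writing a single newline-terminated JSON frame to the child's stdin. A minimal sketch of the exchange (field layout inferred from the models and their `type` discriminator above; the key and value are illustrative, not exact wire bytes):

    # Child (DAG parsing process) -> parent, over the requests socket:
    request = {"key": "my_var", "type": "GetVariable"}

    # Parent (DagFileProcessorProcess) -> child, via the child's stdin:
    response = {"key": "my_var", "value": "my_value", "type": "VariableResult"}

The child's `COMMS_DECODER` validates each inbound frame against `TypeAdapter[ToChild]`, dispatching on the `"type"` field.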
diff --git a/airflow/example_dags/example_get_variable_using_task_sdk.py b/airflow/example_dags/example_get_variable_using_task_sdk.py
new file mode 100644
index 0000000000000..9d55b12cab2ec
--- /dev/null
+++ b/airflow/example_dags/example_get_variable_using_task_sdk.py
@@ -0,0 +1,33 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from airflow import DAG
+from airflow.models.baseoperator import BaseOperator
+from airflow.sdk import Variable
+
+value = Variable.get(key="my_var")
+
+
+class CustomOperator(BaseOperator):
+    def execute(self, context):
+        print(f"Variable fetched at top level of the DAG file has value: {value}")
+
+
+with DAG(dag_id="example_get_variable_using_task_sdk") as dag:
+    CustomOperator(task_id="print_top_level_variable")
diff --git a/task_sdk/src/airflow/sdk/definitions/variable.py b/task_sdk/src/airflow/sdk/definitions/variable.py
index 5f458580065c5..e13ad98863e4c 100644
--- a/task_sdk/src/airflow/sdk/definitions/variable.py
+++ b/task_sdk/src/airflow/sdk/definitions/variable.py
@@ -21,6 +21,9 @@
 
 import attrs
 
+from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType
+from airflow.sdk.execution_time.variable import _get_variable
+
 
 @attrs.define
 class Variable:
@@ -34,8 +37,15 @@ class Variable:
     """
 
     key: str
-    # keeping as any for supporting deserialize_json
+    # Kept as Any to support `deserialize_json`
     value: Any | None = None
     description: str | None = None
 
-    # TODO: Extend this definition for reading/writing variables without context
+    @classmethod
+    def get(cls, key: str, default: Any = None) -> Any:
+        try:
+            return _get_variable(key).value
+        except AirflowRuntimeError as e:
+            if e.error.error == ErrorType.VARIABLE_NOT_FOUND:
+                return default
+            raise
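Reviewer note: a minimal usage sketch of the new classmethod (the key and default are illustrative, not part of this change):

    from airflow.sdk import Variable

    # Returns the stored value, or the default when the key does not exist.
    team = Variable.get("team_name", default="data-platform")

Only `VARIABLE_NOT_FOUND` is swallowed in favour of the default; any other runtime error propagates to the caller.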
diff --git a/task_sdk/src/airflow/sdk/execution_time/context.py b/task_sdk/src/airflow/sdk/execution_time/context.py
index 72ac2af225e85..095b76c19f7c0 100644
--- a/task_sdk/src/airflow/sdk/execution_time/context.py
+++ b/task_sdk/src/airflow/sdk/execution_time/context.py
@@ -21,12 +21,12 @@
 import structlog
 
 from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType
+from airflow.sdk.execution_time.variable import _get_variable
 from airflow.sdk.types import NOTSET
 
 if TYPE_CHECKING:
     from airflow.sdk.definitions.connection import Connection
-    from airflow.sdk.definitions.variable import Variable
-    from airflow.sdk.execution_time.comms import ConnectionResult, VariableResult
+    from airflow.sdk.execution_time.comms import ConnectionResult
 
 
 def _convert_connection_result_conn(conn_result: ConnectionResult) -> Connection:
@@ -36,16 +36,6 @@ def _convert_connection_result_conn(conn_result: ConnectionResult) -> Connection:
     return Connection(**conn_result.model_dump(exclude={"type"}, by_alias=True))
 
 
-def _convert_variable_result_to_variable(var_result: VariableResult, deserialize_json: bool) -> Variable:
-    from airflow.sdk.definitions.variable import Variable
-
-    if deserialize_json:
-        import json
-
-        var_result.value = json.loads(var_result.value)  # type: ignore
-    return Variable(**var_result.model_dump(exclude={"type"}))
-
-
 def _get_connection(conn_id: str) -> Connection:
     # TODO: This should probably be moved to a separate module like `airflow.sdk.execution_time.comms`
     # or `airflow.sdk.execution_time.connection`
@@ -66,26 +56,6 @@ def _get_connection(conn_id: str) -> Connection:
     return _convert_connection_result_conn(msg)
 
 
-def _get_variable(key: str, deserialize_json: bool) -> Variable:
-    # TODO: This should probably be moved to a separate module like `airflow.sdk.execution_time.comms`
-    # or `airflow.sdk.execution_time.variable`
-    # A reason to not move it to `airflow.sdk.execution_time.comms` is that it
-    # will make that module depend on Task SDK, which is not ideal because we intend to
-    # keep Task SDK as a separate package than execution time mods.
-    from airflow.sdk.execution_time.comms import ErrorResponse, GetVariable
-    from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
-
-    log = structlog.get_logger(logger_name="task")
-    SUPERVISOR_COMMS.send_request(log=log, msg=GetVariable(key=key))
-    msg = SUPERVISOR_COMMS.get_message()
-    if isinstance(msg, ErrorResponse):
-        raise AirflowRuntimeError(msg)
-
-    if TYPE_CHECKING:
-        assert isinstance(msg, VariableResult)
-    return _convert_variable_result_to_variable(msg, deserialize_json)
-
-
 class ConnectionAccessor:
     """Wrapper to access Connection entries in template."""
 
diff --git a/task_sdk/src/airflow/sdk/execution_time/variable.py b/task_sdk/src/airflow/sdk/execution_time/variable.py
new file mode 100644
index 0000000000000..49ff411fb1cb7
--- /dev/null
+++ b/task_sdk/src/airflow/sdk/execution_time/variable.py
@@ -0,0 +1,58 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import structlog
+
+from airflow.sdk.exceptions import AirflowRuntimeError
+
+if TYPE_CHECKING:
+    from airflow.sdk.definitions.variable import Variable
+    from airflow.sdk.execution_time.comms import VariableResult
+
+
+def _convert_variable_result_to_variable(var_result: VariableResult, deserialize_json: bool) -> Variable:
+    from airflow.sdk.definitions.variable import Variable
+
+    if deserialize_json:
+        import json
+
+        var_result.value = json.loads(var_result.value)  # type: ignore
+    return Variable(**var_result.model_dump(exclude={"type"}))
+
+
+def _get_variable(key: str, deserialize_json: bool = False) -> Variable:
+    from airflow.sdk.execution_time.comms import ErrorResponse, GetVariable
+
+    try:
+        # Inside a running task, requests go through the task runner's supervisor comms.
+        from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS as COMMS  # type: ignore
+    except ImportError:
+        # SUPERVISOR_COMMS is unbound outside a task, so this is DAG-level (parse-time) code.
+        from airflow.dag_processing.processor import COMMS_DECODER as COMMS  # type: ignore
+
+    log = structlog.get_logger(logger_name="task")
+    COMMS.send_request(log=log, msg=GetVariable(key=key))
+    msg = COMMS.get_message()
+    if isinstance(msg, ErrorResponse):
+        raise AirflowRuntimeError(msg)
+
+    if TYPE_CHECKING:
+        assert isinstance(msg, VariableResult)
+    return _convert_variable_result_to_variable(msg, deserialize_json)
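Reviewer note: `deserialize_json` is exposed only on this internal helper, not on `Variable.get`. A sketch, assuming a variable whose stored value is the JSON string '{"retries": 3}' (the key name is illustrative):

    from airflow.sdk.execution_time.variable import _get_variable

    # With deserialize_json=True the stored JSON string is parsed, so .value is a dict.
    config = _get_variable("pipeline_config", deserialize_json=True).value
    retries = config["retries"]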
diff --git a/task_sdk/tests/execution_time/test_context.py b/task_sdk/tests/execution_time/test_context.py
index d3bf589c84e1d..44004fef5f133 100644
--- a/task_sdk/tests/execution_time/test_context.py
+++ b/task_sdk/tests/execution_time/test_context.py
@@ -27,8 +27,8 @@
     ConnectionAccessor,
     VariableAccessor,
     _convert_connection_result_conn,
-    _convert_variable_result_to_variable,
 )
+from airflow.sdk.execution_time.variable import _convert_variable_result_to_variable
 
 
 def test_convert_connection_result_conn():
diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py
index 2dc8c0ef5aad5..961e9ab9c311c 100644
--- a/task_sdk/tests/execution_time/test_task_runner.py
+++ b/task_sdk/tests/execution_time/test_task_runner.py
@@ -977,3 +977,60 @@ def execute(self, context):
     assert str(exc_info.value) == (
         f"Returned dictionary keys must be strings when using multiple_outputs, found 2 ({int}) instead"
     )
+
+
+class TestWritingTasksWithSDKDefinitions:
+    def test_run_task_which_gets_variable(
+        self, time_machine, mocked_parse, make_ti_context, mock_supervisor_comms
+    ):
+        """Test running a basic task that gets a variable."""
+        instant = timezone.datetime(2024, 12, 3, 10, 0)
+        time_machine.move_to(instant, tick=False)
+
+        class CustomOperator(BaseOperator):
+            def execute(self, context):
+                # Get a variable and return it
+                value = Variable.get(key="my_var")
+                return value
+
+        task = CustomOperator(task_id="get_var")
+        what = StartupDetails(
+            ti=TaskInstance(
+                id=uuid7(),
+                task_id="get_var",
+                dag_id="get_var_dag",
+                run_id="c",
+                try_number=1,
+            ),
+            file="",
+            requests_fd=0,
+            ti_context=make_ti_context(),
+        )
+
+        # Parse the task instance
+        runtime_ti = mocked_parse(what, "get_var_dag", task)
+
+        # Mock `get_message` to return a VariableResult
+        mock_supervisor_comms.get_message.return_value = VariableResult(
+            key="my_var",
+            value="my_value",
+        )
+
+        # Run the task
+        run(runtime_ti, log=mock.MagicMock())
+
+        # Assert that a GetVariable request was sent, then a TaskState marking success
+        expected_calls = [
+            mock.call.send_request(
+                msg=GetVariable(
+                    key="my_var",
+                ),
+                log=mock.ANY,
+            ),
+            mock.call.send_request(
+                msg=TaskState(end_date=instant, state=TerminalTIState.SUCCESS),
+                log=mock.ANY,
+            ),
+        ]
+
+        mock_supervisor_comms.send_request.assert_has_calls(expected_calls, any_order=True)
diff --git a/tests/dag_processing/test_processor.py b/tests/dag_processing/test_processor.py
index 7f94a30f08974..b3d3a0d67febf 100644
--- a/tests/dag_processing/test_processor.py
+++ b/tests/dag_processing/test_processor.py
@@ -21,6 +21,7 @@
 import sys
 from socket import socketpair
 from typing import TYPE_CHECKING
+from unittest import mock
 from unittest.mock import patch
 
 import pytest
@@ -37,6 +38,7 @@
 from airflow.models import DagBag, TaskInstance
 from airflow.models.baseoperator import BaseOperator
 from airflow.models.serialized_dag import SerializedDagModel
+from airflow.sdk.execution_time.comms import GetVariable, VariableResult
 from airflow.sdk.execution_time.task_runner import CommsDecoder
 from airflow.utils import timezone
 from airflow.utils.session import create_session
@@ -127,6 +129,13 @@ def fake_collect_dags(dagbag: DagBag, *args, **kwargs):
         assert resp.import_errors is not None
         assert "a.py" in resp.import_errors
 
+    @patch("airflow.dag_processing.processor.COMMS_DECODER", create=True)
+    def test_parse_with_top_level_variable(self, mock_comms_decoder):
+        mock_comms_decoder.get_message.return_value = VariableResult(key="my_var", value="my_value")
+        dagbag = DagBag(dag_folder="/dev/null", include_examples=True, read_dags_from_db=False)
+        assert "example_get_variable_using_task_sdk" in dagbag.dag_ids
+        mock_comms_decoder.send_request.assert_called_once_with(log=mock.ANY, msg=GetVariable(key="my_var"))
+
     # @conf_vars({("logging", "dag_processor_log_target"): "stdout"})
     # @mock.patch("airflow.dag_processing.processor.settings.dispose_orm", MagicMock)
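Reviewer note: end to end, a top-level `Variable.get` in a DAG file now resolves at parse time roughly as follows (paraphrased from the changes above, not an exact trace):

    # 1. DAG file (parsing subprocess): Variable.get("my_var")
    #    -> _get_variable falls back to COMMS_DECODER, because SUPERVISOR_COMMS
    #       is only bound while the task runner is executing a task.
    # 2. The child sends GetVariable(key="my_var") over the requests socket.
    # 3. The parent's _handle_request reads the value via airflow.models.Variable
    #    and writes a VariableResult frame back to the child's stdin.
    # 4. The child decodes the frame and hands "my_value" back to the DAG code.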