Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIP-72: Add support to get Variables in task SDK to author tasks #45458

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 29 additions & 12 deletions airflow/dag_processing/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,11 @@
TaskCallbackRequest,
)
from airflow.configuration import conf
from airflow.models import Variable
from airflow.models.dagbag import DagBag
from airflow.sdk.execution_time.comms import GetConnection, GetVariable
from airflow.sdk.api.datamodels._generated import VariableResponse
from airflow.sdk.execution_time import task_runner
from airflow.sdk.execution_time.comms import GetConnection, GetVariable, VariableResult
from airflow.sdk.execution_time.supervisor import WatchedSubprocess
from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG
from airflow.stats import Stats
Expand All @@ -43,26 +46,22 @@
from airflow.typing_compat import Self
from airflow.utils.context import Context

COMMS_DECODER: task_runner.CommsDecoder[ToChild, ToParent]


def _parse_file_entrypoint():
import os

import structlog

from airflow.sdk.execution_time import task_runner
# Parse DAG file, send JSON back up!

comms_decoder = task_runner.CommsDecoder[DagFileParseRequest, DagFileParsingResult](
input=sys.stdin,
decoder=TypeAdapter[DagFileParseRequest](DagFileParseRequest),
)
msg = comms_decoder.get_message()
comms_decoder.request_socket = os.fdopen(msg.requests_fd, "wb", buffering=0)
msg = COMMS_DECODER.get_message()
COMMS_DECODER.request_socket = os.fdopen(msg.requests_fd, "wb", buffering=0)

log = structlog.get_logger(logger_name="task")

result = _parse_file(msg, log)
comms_decoder.send_request(log, result)
COMMS_DECODER.send_request(log, result)


def _parse_file(msg: DagFileParseRequest, log: FilteringBoundLogger) -> DagFileParsingResult:
Expand Down Expand Up @@ -180,6 +179,11 @@ class DagFileParsingResult(BaseModel):
Field(discriminator="type"),
]

ToChild = Annotated[
Union[DagFileParseRequest, VariableResult],
Field(discriminator="type"),
]


@attrs.define()
class DagFileProcessorProcess(WatchedSubprocess):
Expand All @@ -203,6 +207,11 @@ def start( # type: ignore[override]
target: Callable[[], None] = _parse_file_entrypoint,
**kwargs,
) -> Self:
global COMMS_DECODER
COMMS_DECODER = task_runner.CommsDecoder[ToChild, ToParent](
input=sys.stdin,
decoder=TypeAdapter[ToChild](ToChild),
)
return super().start(path, callbacks, target=target, client=None, **kwargs) # type:ignore[arg-type]

def _on_child_started( # type: ignore[override]
Expand Down Expand Up @@ -234,8 +243,16 @@ def _handle_request(self, msg: ToParent, log: FilteringBoundLogger) -> None: #
if isinstance(msg, DagFileParsingResult):
self.parsing_result = msg
return
# GetVariable etc -- parsing a dag can run top level code that asks for an Airflow Variable
super()._handle_request(msg, log)
Comment on lines -237 to -238
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We wont really need this, because for cases of variables, connecitons, we will have to interact with the DB model directly. If we go to super(). _handle_request, it brings the SDK API client into picture, which shouldn't be needed for DAG level stuff

Comment on lines -237 to -238
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We wont really need this, because for cases of variables, connecitons, we will have to interact with the DB model directly. If we go to super(). _handle_request, it brings the SDK API client into picture, which shouldn't be needed for DAG level stuff

elif isinstance(msg, GetVariable):
key = msg.key
try:
value = Variable.get(key)
except KeyError:
log.exception("Variable: %s does not exist", key)
raise
var_result = VariableResult.from_variable_response(VariableResponse(key=key, value=value))
resp = var_result.model_dump_json(exclude_unset=True).encode()
self.stdin.write(resp + b"\n")

@property
def is_ready(self) -> bool:
Expand Down
33 changes: 33 additions & 0 deletions airflow/example_dags/example_get_variable_using_task_sdk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from airflow import DAG
from airflow.models.baseoperator import BaseOperator
from airflow.sdk import Variable

value = Variable.get(key="my_var")


class CustomOperator(BaseOperator):
def execute(self, context):
print(f"Variable defined at top level of dag has value: {value}")


with DAG(dag_id="example_get_variable_using_task_sdk") as dag:
CustomOperator(task_id="print_top_level_variable")
14 changes: 12 additions & 2 deletions task_sdk/src/airflow/sdk/definitions/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@

import attrs

from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType
from airflow.sdk.execution_time.variable import _get_variable


@attrs.define
class Variable:
Expand All @@ -34,8 +37,15 @@ class Variable:
"""

key: str
# keeping as any for supporting deserialize_json
# keeping as any for supporting `deserialize_json`
value: Any | None = None
description: str | None = None

# TODO: Extend this definition for reading/writing variables without context
@classmethod
def get(cls, key: str, default: Any = None) -> Any:
try:
return _get_variable(key).value
except AirflowRuntimeError as e:
if e.error.error == ErrorType.VARIABLE_NOT_FOUND:
return default
raise
34 changes: 2 additions & 32 deletions task_sdk/src/airflow/sdk/execution_time/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@
import structlog

from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType
from airflow.sdk.execution_time.variable import _get_variable
from airflow.sdk.types import NOTSET

if TYPE_CHECKING:
from airflow.sdk.definitions.connection import Connection
from airflow.sdk.definitions.variable import Variable
from airflow.sdk.execution_time.comms import ConnectionResult, VariableResult
from airflow.sdk.execution_time.comms import ConnectionResult


def _convert_connection_result_conn(conn_result: ConnectionResult) -> Connection:
Expand All @@ -36,16 +36,6 @@ def _convert_connection_result_conn(conn_result: ConnectionResult) -> Connection
return Connection(**conn_result.model_dump(exclude={"type"}, by_alias=True))


def _convert_variable_result_to_variable(var_result: VariableResult, deserialize_json: bool) -> Variable:
from airflow.sdk.definitions.variable import Variable

if deserialize_json:
import json

var_result.value = json.loads(var_result.value) # type: ignore
return Variable(**var_result.model_dump(exclude={"type"}))


def _get_connection(conn_id: str) -> Connection:
# TODO: This should probably be moved to a separate module like `airflow.sdk.execution_time.comms`
# or `airflow.sdk.execution_time.connection`
Expand All @@ -66,26 +56,6 @@ def _get_connection(conn_id: str) -> Connection:
return _convert_connection_result_conn(msg)


def _get_variable(key: str, deserialize_json: bool) -> Variable:
# TODO: This should probably be moved to a separate module like `airflow.sdk.execution_time.comms`
# or `airflow.sdk.execution_time.variable`
# A reason to not move it to `airflow.sdk.execution_time.comms` is that it
# will make that module depend on Task SDK, which is not ideal because we intend to
# keep Task SDK as a separate package than execution time mods.
from airflow.sdk.execution_time.comms import ErrorResponse, GetVariable
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

log = structlog.get_logger(logger_name="task")
SUPERVISOR_COMMS.send_request(log=log, msg=GetVariable(key=key))
msg = SUPERVISOR_COMMS.get_message()
if isinstance(msg, ErrorResponse):
raise AirflowRuntimeError(msg)

if TYPE_CHECKING:
assert isinstance(msg, VariableResult)
return _convert_variable_result_to_variable(msg, deserialize_json)
Comment on lines -69 to -86
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is the right time to move these helpers to airflow.sdk.execution_time.variable. We might be running into a circular import otherwise



class ConnectionAccessor:
"""Wrapper to access Connection entries in template."""

Expand Down
58 changes: 58 additions & 0 deletions task_sdk/src/airflow/sdk/execution_time/variable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

import structlog

from airflow.sdk.exceptions import AirflowRuntimeError

if TYPE_CHECKING:
from airflow.sdk.definitions.variable import Variable
from airflow.sdk.execution_time.comms import VariableResult


def _convert_variable_result_to_variable(var_result: VariableResult, deserialize_json: bool) -> Variable:
from airflow.sdk.definitions.variable import Variable

if deserialize_json:
import json

var_result.value = json.loads(var_result.value) # type: ignore
return Variable(**var_result.model_dump(exclude={"type"}))


def _get_variable(key: str, deserialize_json: bool = False) -> Variable:
from airflow.sdk.execution_time.comms import ErrorResponse, GetVariable

try:
# We check the hypothesis if the request for variable came from task.
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS as COMMS # type: ignore
except ImportError:
# If not, hypothesis is false and this request is from dag level.
from airflow.dag_processing.processor import COMMS_DECODER as COMMS # type: ignore

log = structlog.get_logger(logger_name="task")
COMMS.send_request(log=log, msg=GetVariable(key=key))
msg = COMMS.get_message()
if isinstance(msg, ErrorResponse):
raise AirflowRuntimeError(msg)

if TYPE_CHECKING:
assert isinstance(msg, VariableResult)
return _convert_variable_result_to_variable(msg, deserialize_json)
57 changes: 57 additions & 0 deletions task_sdk/tests/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -977,3 +977,60 @@ def execute(self, context):
assert str(exc_info.value) == (
f"Returned dictionary keys must be strings when using multiple_outputs, found 2 ({int}) instead"
)


class TestWritingTasksWithSDKDefinitions:
def test_run_task_which_gets_variable(
self, time_machine, mocked_parse, make_ti_context, mock_supervisor_comms
):
"""Test running a basic task that gets a variable"""
instant = timezone.datetime(2024, 12, 3, 10, 0)
time_machine.move_to(instant, tick=False)

class CustomOperator(BaseOperator):
def execute(self, context):
# Gets a variable and returns it
value = Variable.get(key="my_var")
return value

task = CustomOperator(task_id="get_var")
what = StartupDetails(
ti=TaskInstance(
id=uuid7(),
task_id="get_var",
dag_id="get_var_dag",
run_id="c",
try_number=1,
),
file="",
requests_fd=0,
ti_context=make_ti_context(),
)

# Parse the task instance
runtime_ti = mocked_parse(what, "get_var_dag", task)

# Mock the `get_message` to return VariableResult
mock_supervisor_comms.get_message.return_value = VariableResult(
key="my_var",
value="my_value",
)

# Run the task
run(runtime_ti, log=mock.MagicMock())

# Assert if a GetVariable call was made and later TaskState to mark as success
expected_calls = [
mock.call.send_request(
msg=GetVariable(
key="my_var",
),
log=mock.ANY,
),
mock.call.send_request(
msg=TaskState(end_date=instant, state=TerminalTIState.SUCCESS),
log=mock.ANY,
),
]

mock_supervisor_comms.send_request.assert_has_calls(expected_calls, any_order=True)
9 changes: 9 additions & 0 deletions tests/dag_processing/test_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import sys
from socket import socketpair
from typing import TYPE_CHECKING
from unittest import mock
from unittest.mock import patch

import pytest
Expand All @@ -37,6 +38,7 @@
from airflow.models import DagBag, TaskInstance
from airflow.models.baseoperator import BaseOperator
from airflow.models.serialized_dag import SerializedDagModel
from airflow.sdk.execution_time.comms import GetVariable, VariableResult
from airflow.sdk.execution_time.task_runner import CommsDecoder
from airflow.utils import timezone
from airflow.utils.session import create_session
Expand Down Expand Up @@ -127,6 +129,13 @@ def fake_collect_dags(dagbag: DagBag, *args, **kwargs):
assert resp.import_errors is not None
assert "a.py" in resp.import_errors

@patch("airflow.dag_processing.processor.COMMS_DECODER", create=True)
def test_parse_with_top_level_variable(self, mock_comms_decoder):
mock_comms_decoder.get_message.return_value = VariableResult(key="my_var", value="my_value")
dagbag = DagBag(dag_folder="/dev/null", include_examples=True, read_dags_from_db=False)
assert "example_get_variable_using_task_sdk" in dagbag.dag_ids
mock_comms_decoder.send_request.assert_called_once_with(log=mock.ANY, msg=GetVariable(key="my_var"))


# @conf_vars({("logging", "dag_processor_log_target"): "stdout"})
# @mock.patch("airflow.dag_processing.processor.settings.dispose_orm", MagicMock)
Expand Down
Loading