Skip to content

Commit

Permalink
Merge branch 'main' into bundle-parsing
Browse files (browse the repository at this point in the history)
  • Loading branch information
jedcunningham authored Jan 11, 2025
2 parents 4f3df7c + 4c01f7d commit c5e18d1
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 180 deletions.
39 changes: 15 additions & 24 deletions airflow/dag_processing/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import os
import sys
import traceback
from collections.abc import Generator
from typing import TYPE_CHECKING, Annotated, Callable, Literal, Union

import attrs
Expand Down Expand Up @@ -186,7 +185,7 @@ class DagFileParsingResult(BaseModel):
]


@attrs.define()
@attrs.define(kw_only=True)
class DagFileProcessorProcess(WatchedSubprocess):
"""
Parses dags with Task SDK API.
Expand All @@ -199,6 +198,7 @@ class DagFileProcessorProcess(WatchedSubprocess):
"""

parsing_result: DagFileParsingResult | None = None
decoder: TypeAdapter[ToParent] = TypeAdapter[ToParent](ToParent)

@classmethod
def start( # type: ignore[override]
Expand All @@ -208,39 +208,30 @@ def start( # type: ignore[override]
target: Callable[[], None] = _parse_file_entrypoint,
**kwargs,
) -> Self:
return super().start(path, callbacks, target=target, client=None, **kwargs) # type:ignore[arg-type]
proc = super().start(target=target, **kwargs)
proc._on_child_started(callbacks, path)
return proc

def _on_child_started( # type: ignore[override]
self, callbacks: list[CallbackRequest], path: str | os.PathLike[str], child_comms_fd: int
) -> None:
def _on_child_started(self, callbacks: list[CallbackRequest], path: str | os.PathLike[str]) -> None:
msg = DagFileParseRequest(
file=os.fspath(path),
requests_fd=child_comms_fd,
requests_fd=self._requests_fd,
callback_requests=callbacks,
)
self.stdin.write(msg.model_dump_json().encode() + b"\n")

def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, None]:
# TODO: Make decoder an instance variable, then this can live in the base class
decoder = TypeAdapter[ToParent](ToParent)

while True:
line = yield

try:
msg = decoder.validate_json(line)
except Exception:
log.exception("Unable to decode message", line=line)
continue

self._handle_request(msg, log) # type: ignore[arg-type]

def _handle_request(self, msg: ToParent, log: FilteringBoundLogger) -> None: # type: ignore[override]
# TODO: GetVariable etc -- parsing a dag can run top level code that asks for an Airflow Variable
resp = None
if isinstance(msg, DagFileParsingResult):
self.parsing_result = msg
return
# GetVariable etc -- parsing a dag can run top level code that asks for an Airflow Variable
super()._handle_request(msg, log)
else:
log.error("Unhandled request", msg=msg)
return

if resp:
self.stdin.write(resp + b"\n")

@property
def is_ready(self) -> bool:
Expand Down
Loading

0 comments on commit c5e18d1

Please sign in to comment.