Skip to content

Commit

Permalink
Add _get_compatible_read_for_providers function
Browse files Browse the repository at this point in the history
- Make it compatible for providers that implemented _read method.
  • Loading branch information
jason810496 committed Dec 25, 2024
1 parent f9e354c commit c3fa4bb
Showing 1 changed file with 50 additions and 1 deletion.
51 changes: 50 additions & 1 deletion airflow/utils/log/file_task_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from functools import cached_property, partial
from itertools import chain
from pathlib import Path
from types import GeneratorType
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from urllib.parse import urljoin

Expand Down Expand Up @@ -284,6 +285,42 @@ def _get_compatible_parse_log_stream(remote_logs: list[str]) -> _ParsedLogStream
yield timestamp, line_num, line


def _get_compatible_read_for_providers(read_response: tuple) -> tuple[Iterable[str], dict[str, Any]]:
"""
Compatible utility for transforming `_read` method return value for providers.
Providers methods return type might be:
- `tuple[str,dict[str,Any]]`
- alibaba.cloud.log.oss_task_handler.OssTaskHandler
- amazon.aws.log.cloudwatch_task_handler.CloudwatchTaskHandler
- redis.log.redis_task_handler.RedisTaskHandler
- `tuple[list[tuple[str,str]],dict[str,Any]]` ( tuple[list[host,log_documents],metadata] )
- For this case, we need to split host and log_documents and put host into metadata
- elasticsearch.log.es_task_handler.ElasticsearchTaskHandler
- opensearch.log.os_task_handler.OpenSearchTaskHandler
"""
if len(read_response) != 2:
raise ValueError("Unexpected return value from _read")
# for tuple[str,dict[str,Any]]
if isinstance(read_response[0], str):
log_str, metadata = read_response
return (log_str.splitlines(), metadata)

# for tuple[list[tuple[str,str]],dict[str,Any]]
if isinstance(read_response[0], list):
host_by_logs, metadata = read_response
if len(host_by_logs) > 0:
metadata["host"] = host_by_logs[0][0]

def _host_by_logs_to_log_stream(host_by_logs):
for _, log in host_by_logs:
yield log

return (_host_by_logs_to_log_stream(host_by_logs), metadata)

raise ValueError("Unexpected return value from _read")


class FileTaskHandler(logging.Handler):
"""
FileTaskHandler is a python log handler that handles and reads task instance logs.
Expand Down Expand Up @@ -624,11 +661,23 @@ def read(

# subclasses implement _read and may not have log_type, which was added recently
for i, try_number_element in enumerate(try_numbers):
log_stream, out_metadata = self._read(task_instance, try_number_element, metadata)
log_stream: Iterable[str]
out_metadata: dict[str, Any]
read_response = self._read(task_instance, try_number_element, metadata)
if len(read_response) != 2:
raise ValueError("Unexpected return value from _read")
if not (isinstance(read_response[0], GeneratorType) or isinstance(read_response[0], chain)):
# providers haven't adapted to stream-based log reading yet
log_stream, out_metadata = _get_compatible_read_for_providers(read_response)
else:
log_stream, out_metadata = read_response
# es_task_handler return logs grouped by host. wrap other handler returning log string
# with default/ empty host so that UI can render the response in the same way
if not self._read_grouped_logs():
hosts[i] = task_instance.hostname
else:
# the host is stored in metadata
hosts[i] = out_metadata.get("host", "")

logs[i] = log_stream
metadata_array[i] = out_metadata
Expand Down

0 comments on commit c3fa4bb

Please sign in to comment.