Fix providers test cases that use the read or _read method
jason810496 committed Dec 25, 2024
1 parent c3fa4bb commit 0aaf0ab
Showing 8 changed files with 158 additions and 127 deletions.
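Every file in this commit follows the same migration: the handlers' read() used to return a (logs, metadatas) pair, with logs as a nested list of (host, message) tuples, and the updated tests instead unpack a three-tuple of hosts, lazily evaluated log streams, and metadata dicts. A minimal runnable sketch of the new shape as these tests exercise it — the read function and its values here are stand-ins inferred from the assertions below, not the actual Airflow implementation:

```python
from collections.abc import Iterator

# Stand-in for the new read() contract inferred from the updated tests:
# (hosts, log_streams, metadatas), where each stream yields log lines lazily.
def read(ti: str) -> tuple[list[str], list[Iterator[str]], list[dict]]:
    def stream() -> Iterator[str]:
        yield "*** Found logs\n"
        yield "Log line\n"

    return [""], [stream()], [{"end_of_log": True, "log_pos": 9}]

hosts, log_streams, metadatas = read("fake-ti")
log_str = "".join(line for line in log_streams[0])  # streams are single-pass iterators
assert hosts == [""]
assert log_str.endswith("Log line\n")
assert metadatas == [{"end_of_log": True, "log_pos": 9}]
```

Because the streams are single-pass generators, the tests below materialize each one with a join before asserting on its content.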
11 changes: 6 additions & 5 deletions providers/tests/amazon/aws/log/test_cloudwatch_task_handler.py
@@ -147,18 +147,19 @@ def test_read(self):
],
)

- msg_template = "*** Reading remote log from Cloudwatch log_group: {} log_stream: {}.\n{}\n"
+ msg_template = "*** Reading remote log from Cloudwatch log_group: {} log_stream: {}.\n{}"
events = "\n".join(
[
f"[{get_time_str(current_time-2000)}] First",
f"[{get_time_str(current_time-1000)}] Second",
f"[{get_time_str(current_time)}] Third",
]
)
- assert self.cloudwatch_task_handler.read(self.ti) == (
-     [[("", msg_template.format(self.remote_log_group, self.remote_log_stream, events))]],
-     [{"end_of_log": True}],
- )
+ hosts, log_streams, metadatas = self.cloudwatch_task_handler.read(self.ti)
+ assert hosts == [""]
+ log_str = "\n".join(line for line in log_streams[0])
+ assert log_str == msg_template.format(self.remote_log_group, self.remote_log_stream, events)
+ assert metadatas == [{"end_of_log": True}]

@pytest.mark.parametrize(
"end_date, expected_end_time",
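Two join conventions appear in the updated assertions: the Cloudwatch test joins its stream with "\n" (lines arrive without terminators, which is also why the message template drops its trailing \n), while the S3 test below joins with "" (lines keep their terminators, so log_pos grows from 8 to 9 once "Log line\n" counts its newline). A small illustration of that reading — the stream contents are illustrative, not taken from the handlers:

```python
# Two stream conventions suggested by these tests (an inference, not a
# documented contract): Cloudwatch yields bare lines, S3 keeps terminators.
cloudwatch_stream = iter(["first", "second"])  # no trailing newlines
s3_stream = iter(["Log line\n"])               # newline included

assert "\n".join(cloudwatch_stream) == "first\nsecond"
assert "".join(s3_stream) == "Log line\n"
assert len("Log line\n") == 9  # matches the new log_pos of 9 (previously 8)
```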
28 changes: 16 additions & 12 deletions providers/tests/amazon/aws/log/test_s3_task_handler.py
@@ -130,23 +130,27 @@ def test_read(self):
self.conn.put_object(Bucket="bucket", Key=self.remote_log_key, Body=b"Log line\n")
ti = copy.copy(self.ti)
ti.state = TaskInstanceState.SUCCESS
- log, metadata = self.s3_task_handler.read(ti)
- actual = log[0][0][-1]
- assert "*** Found logs in s3:\n*** * s3://bucket/remote/log/location/1.log\n" in actual
- assert actual.endswith("Log line")
- assert metadata == [{"end_of_log": True, "log_pos": 8}]
+ read_result = self.s3_task_handler.read(ti)
+ print("read_result", read_result)
+ _, log_streams, metadata_array = read_result
+ log_str = "".join(line for line in log_streams[0])
+ assert "*** Found logs in s3:\n*** * s3://bucket/remote/log/location/1.log\n" in log_str
+ assert log_str.endswith("Log line\n")
+ assert metadata_array == [{"end_of_log": True, "log_pos": 9}]

def test_read_when_s3_log_missing(self):
ti = copy.copy(self.ti)
ti.state = TaskInstanceState.SUCCESS
- self.s3_task_handler._read_from_logs_server = mock.Mock(return_value=([], []))
- log, metadata = self.s3_task_handler.read(ti)
- assert len(log) == 1
- assert len(log) == len(metadata)
- actual = log[0][0][-1]
+ self.s3_task_handler._read_from_logs_server = mock.Mock(return_value=([], [], 0))
+ read_result = self.s3_task_handler.read(ti)
+ print("read_result", read_result)
+ _, log_streams, metadata_array = read_result
+ assert len(log_streams) == 1
+ assert len(log_streams) == len(metadata_array)
+ log_str = "".join(line for line in log_streams[0])
expected = "*** No logs found on s3 for ti=<TaskInstance: dag_for_testing_s3_task_handler.task_for_testing_s3_log_handler test [success]>\n"
- assert expected in actual
- assert metadata[0] == {"end_of_log": True, "log_pos": 0}
+ assert expected in log_str
+ assert metadata_array[0] == {"end_of_log": True, "log_pos": 0}

def test_s3_read_when_log_missing(self):
handler = self.s3_task_handler
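The internal _read_from_logs_server stub also gains a third element; judging from the mocks in this commit (0 for the empty S3 case above, len("this\nlog\ncontent") in the celery test below), it appears to carry the total size of the fetched logs. A hedged sketch of stubbing that contract in a test — the tuple interpretation is an assumption, not documented API:

```python
from unittest import mock

# Assumed shape of the new _read_from_logs_server return value:
# (messages, parsed_log_streams, total_log_size) instead of (messages, logs).
handler = mock.Mock()
handler._read_from_logs_server.return_value = (
    ["this message"],                        # worker messages
    [iter(["this\n", "log\n", "content"])],  # one lazy stream per source
    len("this\nlog\ncontent"),               # total size, feeding log_pos
)

messages, streams, total = handler._read_from_logs_server()
assert total == 16  # consistent with {"end_of_log": False, "log_pos": 16} below
```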
16 changes: 11 additions & 5 deletions providers/tests/celery/log_handlers/test_log_handlers.py
@@ -37,6 +37,7 @@
from airflow.utils.types import DagRunType

from tests_common.test_utils.config import conf_vars
+ from tests_common.test_utils.file_task_handler import log_str_to_parsed_log_stream

pytestmark = pytest.mark.db_test

@@ -77,9 +78,14 @@ def test__read_for_celery_executor_fallbacks_to_worker(self, create_task_instance):
fth = FileTaskHandler("")

fth._read_from_logs_server = mock.Mock()
- fth._read_from_logs_server.return_value = ["this message"], ["this\nlog\ncontent"]
- actual = fth._read(ti=ti, try_number=1)
+ fth._read_from_logs_server.return_value = (
+     ["this message"],
+     [log_str_to_parsed_log_stream("this\nlog\ncontent")],
+     len("this\nlog\ncontent"),
+ )
+ log_stream, metadata = fth._read(ti=ti, try_number=1)
+ log_str = "\n".join(line for line in log_stream)
fth._read_from_logs_server.assert_called_once()
assert "*** this message\n" in actual[0]
assert actual[0].endswith("this\nlog\ncontent")
assert actual[1] == {"end_of_log": False, "log_pos": 16}
assert "*** this message\n" in log_str
assert log_str.endswith("this\nlog\ncontent")
assert metadata == {"end_of_log": False, "log_pos": 16}
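The celery test depends on a new log_str_to_parsed_log_stream helper from tests_common.test_utils.file_task_handler, whose implementation is not part of this diff. Since the test reassembles the stream with "\n".join, a plausible reading — an assumption about the helper, not its actual source — is a generator over the lines of a string:

```python
from collections.abc import Iterator

def log_str_to_parsed_log_stream(log_str: str) -> Iterator[str]:
    # Assumed behavior: split a raw log string into a lazy stream of lines,
    # so a plain string can stand in where a parsed log stream is expected.
    yield from log_str.splitlines()

# Joining with "\n" round-trips the original string, mirroring the assertion
# style used in the test above.
assert "\n".join(log_str_to_parsed_log_stream("this\nlog\ncontent")) == "this\nlog\ncontent"
```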
135 changes: 73 additions & 62 deletions providers/tests/elasticsearch/log/test_es_task_handler.py
@@ -202,43 +202,44 @@ def test_client_with_patterns(self):

def test_read(self, ti):
ts = pendulum.now()
- logs, metadatas = self.es_task_handler.read(
+ _, log_streams, metadatas = self.es_task_handler.read(
ti, 1, {"offset": 0, "last_log_timestamp": str(ts), "end_of_log": False}
)

- assert len(logs) == 1
- assert len(logs) == len(metadatas)
- assert len(logs[0]) == 1
- assert self.test_message == logs[0][0][-1]
+ assert len(log_streams) == 1
+ assert len(log_streams) == len(metadatas)
+ log_str = "".join(line for line in log_streams[0])
+ assert self.test_message == log_str
assert not metadatas[0]["end_of_log"]
assert metadatas[0]["offset"] == "1"
assert timezone.parse(metadatas[0]["last_log_timestamp"]) > ts

def test_read_with_patterns(self, ti):
ts = pendulum.now()
with mock.patch.object(self.es_task_handler, "index_patterns", new="test_*,other_*"):
- logs, metadatas = self.es_task_handler.read(
+ _, log_streams, metadatas = self.es_task_handler.read(
ti, 1, {"offset": 0, "last_log_timestamp": str(ts), "end_of_log": False}
)

- assert len(logs) == 1
- assert len(logs) == len(metadatas)
- assert len(logs[0]) == 1
- assert self.test_message == logs[0][0][-1]
+ assert len(log_streams) == 1
+ assert len(log_streams) == len(metadatas)
+ log_str = "".join(line for line in log_streams[0])
+ assert self.test_message == log_str
assert not metadatas[0]["end_of_log"]
assert metadatas[0]["offset"] == "1"
assert timezone.parse(metadatas[0]["last_log_timestamp"]) > ts

def test_read_with_patterns_no_match(self, ti):
ts = pendulum.now()
with mock.patch.object(self.es_task_handler, "index_patterns", new="test_other_*,test_another_*"):
- logs, metadatas = self.es_task_handler.read(
+ _, log_streams, metadatas = self.es_task_handler.read(
ti, 1, {"offset": 0, "last_log_timestamp": str(ts), "end_of_log": False}
)

- assert len(logs) == 1
- assert len(logs) == len(metadatas)
- assert logs == [[]]
+ assert len(log_streams) == 1
+ assert len(log_streams) == len(metadatas)
+ log_str = "".join(line for line in log_streams[0])
+ assert log_str == ""
assert not metadatas[0]["end_of_log"]
assert metadatas[0]["offset"] == "0"
# last_log_timestamp won't change if no log lines read.
@@ -265,22 +266,22 @@ def test_read_missing_logs(self, seconds, create_task_instance):
create_task_instance=create_task_instance,
)
ts = pendulum.now().add(seconds=-seconds)
- logs, metadatas = self.es_task_handler.read(ti, 1, {"offset": 0, "last_log_timestamp": str(ts)})
+ _, log_streams, metadatas = self.es_task_handler.read(
+     ti, 1, {"offset": 0, "last_log_timestamp": str(ts)}
+ )

- assert len(logs) == 1
+ assert len(log_streams) == 1
+ log_str = "".join(line for line in log_streams[0])
if seconds > 5:
# we expect a log not found message when checking began more than 5 seconds ago
- assert len(logs[0]) == 1
- actual_message = logs[0][0][1]
expected_pattern = r"^\*\*\* Log .* not found in Elasticsearch.*"
- assert re.match(expected_pattern, actual_message) is not None
+ assert re.match(expected_pattern, log_str) is not None
assert metadatas[0]["end_of_log"] is True
else:
# we've "waited" less than 5 seconds so it should not be "end of log" and should be no log message
- assert len(logs[0]) == 0
- assert logs == [[]]
+ assert log_str == ""
assert metadatas[0]["end_of_log"] is False
- assert len(logs) == len(metadatas)
+ assert len(log_streams) == len(metadatas)
assert metadatas[0]["offset"] == "0"
assert timezone.parse(metadatas[0]["last_log_timestamp"]) == ts

@@ -295,23 +296,25 @@ def test_read_with_match_phrase_query(self, ti):
self.es.index(index=self.index_name, doc_type=self.doc_type, body=another_body, id=1)

ts = pendulum.now()
- logs, metadatas = self.es_task_handler.read(
+ _, log_streams, metadatas = self.es_task_handler.read(
ti, 1, {"offset": "0", "last_log_timestamp": str(ts), "end_of_log": False, "max_offset": 2}
)
- assert len(logs) == 1
- assert len(logs) == len(metadatas)
- assert self.test_message == logs[0][0][-1]
- assert another_test_message != logs[0]
+ assert len(log_streams) == 1
+ assert len(log_streams) == len(metadatas)
+ log_str = "".join(line for line in log_streams[0])
+ assert self.test_message == log_str
+ assert another_test_message != log_str

assert not metadatas[0]["end_of_log"]
assert metadatas[0]["offset"] == "1"
assert timezone.parse(metadatas[0]["last_log_timestamp"]) > ts

def test_read_with_none_metadata(self, ti):
- logs, metadatas = self.es_task_handler.read(ti, 1)
- assert len(logs) == 1
- assert len(logs) == len(metadatas)
- assert self.test_message == logs[0][0][-1]
+ _, log_streams, metadatas = self.es_task_handler.read(ti, 1)
+ assert len(log_streams) == 1
+ assert len(log_streams) == len(metadatas)
+ log_str = "".join(line for line in log_streams[0])
+ assert self.test_message == log_str
assert not metadatas[0]["end_of_log"]
assert metadatas[0]["offset"] == "1"
assert timezone.parse(metadatas[0]["last_log_timestamp"]) < pendulum.now()
@@ -322,23 +325,25 @@ def test_read_nonexistent_log(self, ti):
# and doc_type regardless of match filters, so we delete the log entry instead
# of making a new TaskInstance to query.
self.es.delete(index=self.index_name, doc_type=self.doc_type, id=1)
- logs, metadatas = self.es_task_handler.read(
+ _, log_streams, metadatas = self.es_task_handler.read(
ti, 1, {"offset": 0, "last_log_timestamp": str(ts), "end_of_log": False}
)
- assert len(logs) == 1
- assert len(logs) == len(metadatas)
- assert logs == [[]]
+ assert len(log_streams) == 1
+ assert len(log_streams) == len(metadatas)
+ log_str = "".join(line for line in log_streams[0])
+ assert log_str == ""
assert not metadatas[0]["end_of_log"]
assert metadatas[0]["offset"] == "0"
# last_log_timestamp won't change if no log lines read.
assert timezone.parse(metadatas[0]["last_log_timestamp"]) == ts

def test_read_with_empty_metadata(self, ti):
ts = pendulum.now()
- logs, metadatas = self.es_task_handler.read(ti, 1, {})
- assert len(logs) == 1
- assert len(logs) == len(metadatas)
- assert self.test_message == logs[0][0][-1]
+ _, log_streams, metadatas = self.es_task_handler.read(ti, 1, {})
+ assert len(log_streams) == 1
+ assert len(log_streams) == len(metadatas)
+ log_str = "".join(line for line in log_streams[0])
+ assert self.test_message == log_str
assert not metadatas[0]["end_of_log"]
# offset should be initialized to 0 if not provided.
assert metadatas[0]["offset"] == "1"
@@ -348,10 +353,11 @@ def test_read_with_empty_metadata(self, ti):

# case where offset is missing but metadata not empty.
self.es.delete(index=self.index_name, doc_type=self.doc_type, id=1)
- logs, metadatas = self.es_task_handler.read(ti, 1, {"end_of_log": False})
- assert len(logs) == 1
- assert len(logs) == len(metadatas)
- assert logs == [[]]
+ _, log_streams, metadatas = self.es_task_handler.read(ti, 1, {"end_of_log": False})
+ assert len(log_streams) == 1
+ assert len(log_streams) == len(metadatas)
+ log_str = "".join(line for line in log_streams[0])
+ assert log_str == ""
assert not metadatas[0]["end_of_log"]
# offset should be initialized to 0 if not provided.
assert metadatas[0]["offset"] == "0"
@@ -367,7 +373,7 @@ def test_read_timeout(self, ti):
# if we had never retrieved any logs at all (offset=0), then we would have gotten
# a "logs not found" message after 5 seconds of trying
offset = 1
- logs, metadatas = self.es_task_handler.read(
+ _, log_streams, metadatas = self.es_task_handler.read(
task_instance=ti,
try_number=1,
metadata={
@@ -376,24 +382,25 @@
"end_of_log": False,
},
)
- assert len(logs) == 1
- assert len(logs) == len(metadatas)
- assert logs == [[]]
+ assert len(log_streams) == 1
+ assert len(log_streams) == len(metadatas)
+ log_str = "".join(line for line in log_streams[0])
+ assert log_str == ""
assert metadatas[0]["end_of_log"]
assert str(offset) == metadatas[0]["offset"]
assert timezone.parse(metadatas[0]["last_log_timestamp"]) == ts

def test_read_as_download_logs(self, ti):
ts = pendulum.now()
- logs, metadatas = self.es_task_handler.read(
+ _, log_streams, metadatas = self.es_task_handler.read(
ti,
1,
{"offset": 0, "last_log_timestamp": str(ts), "download_logs": True, "end_of_log": False},
)
- assert len(logs) == 1
- assert len(logs) == len(metadatas)
- assert len(logs[0]) == 1
- assert self.test_message == logs[0][0][-1]
+ assert len(log_streams) == 1
+ assert len(log_streams) == len(metadatas)
+ log_str = "".join(line for line in log_streams[0])
+ assert self.test_message == log_str
assert not metadatas[0]["end_of_log"]
assert metadatas[0]["download_logs"]
assert metadatas[0]["offset"] == "1"
@@ -403,13 +410,14 @@ def test_read_raises(self, ti):
with mock.patch.object(self.es_task_handler.log, "exception") as mock_exception:
with mock.patch.object(self.es_task_handler.client, "search") as mock_execute:
mock_execute.side_effect = SearchFailedException("Failed to read")
- logs, metadatas = self.es_task_handler.read(ti, 1)
+ _, log_streams, metadatas = self.es_task_handler.read(ti, 1)
assert mock_exception.call_count == 1
args, kwargs = mock_exception.call_args
assert "Could not read log with log_id:" in args[0]
- assert len(logs) == 1
- assert len(logs) == len(metadatas)
- assert logs == [[]]
+ assert len(log_streams) == 1
+ assert len(log_streams) == len(metadatas)
+ log_str = "".join(line for line in log_streams[0])
+ assert log_str == ""
assert not metadatas[0]["end_of_log"]
assert metadatas[0]["offset"] == "0"

@@ -444,10 +452,11 @@ def test_read_with_json_format(self, ti):
self.es_task_handler.set_context(ti)
self.es.index(index=self.index_name, doc_type=self.doc_type, body=self.body, id=id)

- logs, _ = self.es_task_handler.read(
+ _, log_streams, _ = self.es_task_handler.read(
ti, 1, {"offset": 0, "last_log_timestamp": str(ts), "end_of_log": False}
)
- assert logs[0][0][1] == "[2020-12-24 19:25:00,962] {taskinstance.py:851} INFO - some random stuff - "
+ log_str = "".join(line for line in log_streams[0])
+ assert log_str == "[2020-12-24 19:25:00,962] {taskinstance.py:851} INFO - some random stuff - "

def test_read_with_json_format_with_custom_offset_and_host_fields(self, ti):
ts = pendulum.now()
@@ -472,10 +481,11 @@ def test_read_with_json_format_with_custom_offset_and_host_fields(self, ti):
self.es_task_handler.set_context(ti)
self.es.index(index=self.index_name, doc_type=self.doc_type, body=self.body, id=id)

- logs, _ = self.es_task_handler.read(
+ _, log_streams, _ = self.es_task_handler.read(
ti, 1, {"offset": 0, "last_log_timestamp": str(ts), "end_of_log": False}
)
- assert logs[0][0][1] == "[2020-12-24 19:25:00,962] {taskinstance.py:851} INFO - some random stuff - "
+ log_str = "".join(line for line in log_streams[0])
+ assert log_str == "[2020-12-24 19:25:00,962] {taskinstance.py:851} INFO - some random stuff - "

def test_read_with_custom_offset_and_host_fields(self, ti):
ts = pendulum.now()
@@ -493,10 +503,11 @@ def test_read_with_custom_offset_and_host_fields(self, ti):
}
self.es.index(index=self.index_name, doc_type=self.doc_type, body=self.body, id=id)

- logs, _ = self.es_task_handler.read(
+ _, log_streams, _ = self.es_task_handler.read(
ti, 1, {"offset": 0, "last_log_timestamp": str(ts), "end_of_log": False}
)
- assert self.test_message == logs[0][0][1]
+ log_str = "".join(line for line in log_streams[0])
+ assert self.test_message == log_str

def test_close(self, ti):
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
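Across the Elasticsearch tests, the metadata dict (offset, last_log_timestamp, end_of_log, optionally download_logs and max_offset) still drives incremental reads; only the log payload moved into streams. A sketch of the polling loop a caller might build on the new three-tuple API — the function and its arguments are illustrative, not Airflow's actual code:

```python
import pendulum

def poll_task_logs(handler, ti) -> str:
    # Illustrative consumer: keep calling read() with the returned metadata
    # until the handler reports end_of_log, accumulating each pass's stream.
    # (A real caller would also bound the number of iterations.)
    metadata = {"offset": 0, "last_log_timestamp": str(pendulum.now()), "end_of_log": False}
    chunks = []
    while not metadata["end_of_log"]:
        _, log_streams, metadatas = handler.read(ti, 1, metadata)
        metadata = metadatas[0]
        chunks.append("".join(line for line in log_streams[0]))
    return "".join(chunks)
```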