Skip to content

Commit

Permalink
Refactor tests
Browse files Browse the repository at this point in the history
  • Loading branch information
chongshenng committed Sep 23, 2024
1 parent 25d5c1e commit 48adaa5
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 19 deletions.
7 changes: 1 addition & 6 deletions src/py/flwr/cli/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ def stream_logs(run_id: int, channel: grpc.Channel, duration: int) -> None:
break


def print_logs(
run_id: int, channel: grpc.Channel, timeout: int, is_test: bool = False
) -> None:
def print_logs(run_id: int, channel: grpc.Channel, timeout: int) -> None:
"""Print logs from the beginning of a run.
The `is_test` parameter is only used for `pytest` and must be `False` otherwise.
Expand All @@ -60,9 +58,6 @@ def print_logs(
# Enforce timeout for graceful exit
for res in stub.StreamLogs(req, timeout=timeout):
print(res.log_output)
# Break out of while-loop when using pytest
if is_test:
break
except grpc.RpcError as e:
# pylint: disable=E1101
if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
Expand Down
44 changes: 31 additions & 13 deletions src/py/flwr/cli/log_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,43 @@


import unittest
from unittest.mock import Mock, patch
from typing import NoReturn
from unittest.mock import Mock, call, patch

from flwr.proto.exec_pb2 import StreamLogsResponse # pylint: disable=E0611

from .log import print_logs, stream_logs


class InterruptedStreamLogsResponse:
"""Create a StreamLogsResponse object with KeyboardInterrupt."""

@property
def log_output(self) -> NoReturn:
"""Raise KeyboardInterrupt to exit logstream test gracefully."""
raise KeyboardInterrupt


class TestFlwrLog(unittest.TestCase):
"""Unit tests for `flwr log` CLI functions."""

def setUp(self) -> None:
"""Initialize mock ExecStub before each test."""
mock_response_iterator = iter(
[StreamLogsResponse(log_output=f"result_{i}") for i in range(1, 4)]
)
self.expected_calls = [
call("log_output_1"),
call("log_output_2"),
call("log_output_3"),
]
mock_response_iterator = [
iter(
[StreamLogsResponse(log_output=f"log_output_{i}") for i in range(1, 4)]
+ [InterruptedStreamLogsResponse()]
)
]
self.mock_stub = Mock()
self.mock_stub.StreamLogs.return_value = mock_response_iterator
self.mock_stub.StreamLogs.side_effect = mock_response_iterator
self.patcher = patch("flwr.cli.log.ExecStub", return_value=self.mock_stub)

self.patcher.start()

# Create mock channel
Expand All @@ -46,15 +65,14 @@ def tearDown(self) -> None:
def test_flwr_log_stream_method(self) -> None:
"""Test stream_logs."""
with patch("builtins.print") as mock_print:
stream_logs(run_id=123, channel=self.mock_channel, duration=1)
mock_print.assert_any_call("result_1")
mock_print.assert_any_call("result_2")
mock_print.assert_any_call("result_3")
with self.assertRaises(KeyboardInterrupt):
stream_logs(run_id=123, channel=self.mock_channel, duration=1)
# Assert that mock print was called with the expected arguments
mock_print.assert_has_calls(self.expected_calls)

def test_flwr_log_print_method(self) -> None:
"""Test print_logs."""
with patch("builtins.print") as mock_print:
print_logs(run_id=123, channel=self.mock_channel, timeout=0, is_test=True)
mock_print.assert_any_call("result_1")
mock_print.assert_any_call("result_2")
mock_print.assert_any_call("result_3")
print_logs(run_id=123, channel=self.mock_channel, timeout=0)
# Assert that mock print was called with the expected arguments
mock_print.assert_has_calls(self.expected_calls)

0 comments on commit 48adaa5

Please sign in to comment.