Skip to content

Commit

Permalink
test commit cast objects to mock when checking
Browse files Browse the repository at this point in the history
  • Loading branch information
DiamondJoseph committed Nov 13, 2024
1 parent 90963cc commit 7b885bd
Showing 1 changed file with 88 additions and 55 deletions.
143 changes: 88 additions & 55 deletions tests/unit_tests/client/test_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections.abc import Callable
from unittest.mock import MagicMock, Mock, call
from typing import cast
from unittest.mock import MagicMock, call

import pytest
from bluesky_stomp.messaging import MessageContext
Expand Down Expand Up @@ -81,10 +82,12 @@ def mock_rest() -> BlueapiRestClient:

@pytest.fixture
def mock_events() -> EventBusClient:
mock_events = MagicMock(spec=EventBusClient)
ctx = Mock()
mock_events: EventBusClient = MagicMock(spec=EventBusClient)
ctx = MagicMock()
ctx.correlation_id = "foo"
mock_events.subscribe_to_all_events = lambda on_event: on_event(ctx, COMPLETE_EVENT)
cast(MagicMock, mock_events).subscribe_to_all_events = lambda on_event: on_event(
ctx, COMPLETE_EVENT
)
return mock_events


Expand All @@ -110,7 +113,7 @@ def test_get_nonexistant_plan(
client: BlueapiClient,
mock_rest: BlueapiRestClient,
):
mock_rest.get_plan.side_effect = KeyError("Not found")
cast(MagicMock, mock_rest.get_plan).side_effect = KeyError("Not found")
with pytest.raises(KeyError):
client.get_plan("baz")

Expand All @@ -127,7 +130,7 @@ def test_get_nonexistant_device(
client: BlueapiClient,
mock_rest: BlueapiRestClient,
):
mock_rest.get_device.side_effect = KeyError("Not found")
cast(MagicMock, mock_rest.get_device).side_effect = KeyError("Not found")
with pytest.raises(KeyError):
client.get_device("baz")

Expand All @@ -144,7 +147,7 @@ def test_get_nonexistent_task(
client: BlueapiClient,
mock_rest: BlueapiRestClient,
):
mock_rest.get_task.side_effect = KeyError("Not found")
cast(MagicMock, mock_rest.get_task).side_effect = KeyError("Not found")
with pytest.raises(KeyError):
client.get_task("baz")

Expand All @@ -166,23 +169,23 @@ def test_create_task(
mock_rest: BlueapiRestClient,
):
client.create_task(task=Task(name="foo"))
mock_rest.create_task.assert_called_once_with(Task(name="foo"))
cast(MagicMock, mock_rest.create_task).assert_called_once_with(Task(name="foo"))


def test_create_task_does_not_start_task(
client: BlueapiClient,
mock_rest: BlueapiRestClient,
):
client.create_task(task=Task(name="foo"))
mock_rest.update_worker_task.assert_not_called()
cast(MagicMock, mock_rest.update_worker_task).assert_not_called()


def test_clear_task(
client: BlueapiClient,
mock_rest: BlueapiRestClient,
):
client.clear_task(task_id="foo")
mock_rest.clear_task.assert_called_once_with("foo")
cast(MagicMock, mock_rest.clear_task).assert_called_once_with("foo")


def test_get_active_task(client: BlueapiClient):
Expand All @@ -194,14 +197,16 @@ def test_start_task(
mock_rest: BlueapiRestClient,
):
client.start_task(task=WorkerTask(task_id="bar"))
mock_rest.update_worker_task.assert_called_once_with(WorkerTask(task_id="bar"))
cast(MagicMock, mock_rest.update_worker_task).assert_called_once_with(
WorkerTask(task_id="bar")
)


def test_start_nonexistant_task(
client: BlueapiClient,
mock_rest: BlueapiRestClient,
):
mock_rest.update_worker_task.side_effect = KeyError("Not found")
cast(MagicMock, mock_rest.update_worker_task).side_effect = KeyError("Not found")
with pytest.raises(KeyError):
client.start_task(task=WorkerTask(task_id="bar"))

Expand All @@ -210,18 +215,24 @@ def test_create_and_start_task_calls_both_creating_and_starting_endpoints(
client: BlueapiClient,
mock_rest: BlueapiRestClient,
):
mock_rest.create_task.return_value = TaskResponse(task_id="baz")
mock_rest.update_worker_task.return_value = TaskResponse(task_id="baz")
cast(MagicMock, mock_rest.create_task).return_value = TaskResponse(task_id="baz")
cast(MagicMock, mock_rest.update_worker_task).return_value = TaskResponse(
task_id="baz"
)
client.create_and_start_task(Task(name="baz"))
mock_rest.create_task.assert_called_once_with(Task(name="baz"))
mock_rest.update_worker_task.assert_called_once_with(WorkerTask(task_id="baz"))
cast(MagicMock, mock_rest.create_task).assert_called_once_with(Task(name="baz"))
cast(MagicMock, mock_rest.update_worker_task).assert_called_once_with(
WorkerTask(task_id="baz")
)


def test_create_and_start_task_fails_if_task_creation_fails(
client: BlueapiClient,
mock_rest: BlueapiRestClient,
):
mock_rest.create_task.side_effect = BlueskyRemoteControlError("No can do")
cast(MagicMock, mock_rest.create_task).side_effect = BlueskyRemoteControlError(
"No can do"
)
with pytest.raises(BlueskyRemoteControlError):
client.create_and_start_task(Task(name="baz"))

Expand All @@ -230,8 +241,10 @@ def test_create_and_start_task_fails_if_task_id_is_wrong(
client: BlueapiClient,
mock_rest: BlueapiRestClient,
):
mock_rest.create_task.return_value = TaskResponse(task_id="baz")
mock_rest.update_worker_task.return_value = TaskResponse(task_id="bar")
cast(MagicMock, mock_rest.create_task).return_value = TaskResponse(task_id="baz")
cast(MagicMock, mock_rest.update_worker_task).return_value = TaskResponse(
task_id="bar"
)
with pytest.raises(BlueskyRemoteControlError):
client.create_and_start_task(Task(name="baz"))

Expand All @@ -240,8 +253,10 @@ def test_create_and_start_task_fails_if_task_start_fails(
client: BlueapiClient,
mock_rest: BlueapiRestClient,
):
mock_rest.create_task.return_value = TaskResponse(task_id="baz")
mock_rest.update_worker_task.side_effect = BlueskyRemoteControlError("No can do")
cast(MagicMock, mock_rest.create_task).return_value = TaskResponse(task_id="baz")
cast(
MagicMock, mock_rest.update_worker_task
).side_effect = BlueskyRemoteControlError("No can do")
with pytest.raises(BlueskyRemoteControlError):
client.create_and_start_task(Task(name="baz"))

Expand All @@ -255,15 +270,15 @@ def test_reload_environment(
mock_rest: BlueapiRestClient,
):
client.reload_environment()
mock_rest.get_environment.assert_called_once()
mock_rest.delete_environment.assert_called_once()
cast(MagicMock, mock_rest.get_environment).assert_called_once()
cast(MagicMock, mock_rest.delete_environment).assert_called_once()


def test_reload_environment_failure(
client: BlueapiClient,
mock_rest: BlueapiRestClient,
):
mock_rest.get_environment.return_value = EnvironmentResponse(
cast(MagicMock, mock_rest.get_environment).return_value = EnvironmentResponse(
initialized=False, error_message="foo"
)
with pytest.raises(BlueskyRemoteControlError, match="foo"):
Expand All @@ -275,7 +290,7 @@ def test_abort(
mock_rest: BlueapiRestClient,
):
client.abort(reason="foo")
mock_rest.cancel_current_task.assert_called_once_with(
cast(MagicMock, mock_rest.cancel_current_task).assert_called_once_with(
WorkerState.ABORTING,
reason="foo",
)
Expand All @@ -286,15 +301,17 @@ def test_stop(
mock_rest: BlueapiRestClient,
):
client.stop()
mock_rest.cancel_current_task.assert_called_once_with(WorkerState.STOPPING)
cast(MagicMock, mock_rest.cancel_current_task).assert_called_once_with(
WorkerState.STOPPING
)


def test_pause(
client: BlueapiClient,
mock_rest: BlueapiRestClient,
):
client.pause(defer=True)
mock_rest.set_state.assert_called_once_with(
cast(MagicMock, mock_rest.set_state).assert_called_once_with(
WorkerState.PAUSED,
defer=True,
)
Expand All @@ -305,7 +322,7 @@ def test_resume(
mock_rest: BlueapiRestClient,
):
client.resume()
mock_rest.set_state.assert_called_once_with(
cast(MagicMock, mock_rest.set_state).assert_called_once_with(
WorkerState.RUNNING,
defer=False,
)
Expand All @@ -322,32 +339,42 @@ def test_cannot_run_task_without_message_bus(client: BlueapiClient):
def test_run_task_sets_up_control(
client_with_events: BlueapiClient,
mock_rest: BlueapiRestClient,
mock_events: MagicMock,
mock_events: EventBusClient,
):
mock_rest.create_task.return_value = TaskResponse(task_id="foo")
mock_rest.update_worker_task.return_value = TaskResponse(task_id="foo")
ctx = Mock()
cast(MagicMock, mock_rest.create_task).return_value = TaskResponse(task_id="foo")
cast(MagicMock, mock_rest.update_worker_task).return_value = TaskResponse(
task_id="foo"
)
ctx = MagicMock()
ctx.correlation_id = "foo"
mock_events.subscribe_to_all_events = lambda on_event: on_event(COMPLETE_EVENT, ctx)
cast(MagicMock, mock_events).subscribe_to_all_events = lambda on_event: on_event(
COMPLETE_EVENT, ctx
)

client_with_events.run_task(Task(name="foo"))
mock_rest.create_task.assert_called_once_with(Task(name="foo"))
mock_rest.update_worker_task.assert_called_once_with(WorkerTask(task_id="foo"))
cast(MagicMock, mock_rest.create_task).assert_called_once_with(Task(name="foo"))
cast(MagicMock, mock_rest.update_worker_task).assert_called_once_with(
WorkerTask(task_id="foo")
)


def test_run_task_fails_on_failing_event(
client_with_events: BlueapiClient,
mock_rest: BlueapiRestClient,
mock_events: MagicMock,
mock_events: EventBusClient,
):
mock_rest.create_task.return_value = TaskResponse(task_id="foo")
mock_rest.update_worker_task.return_value = TaskResponse(task_id="foo")
cast(MagicMock, mock_rest.create_task).return_value = TaskResponse(task_id="foo")
cast(MagicMock, mock_rest.update_worker_task).return_value = TaskResponse(
task_id="foo"
)

ctx = Mock()
ctx = MagicMock()
ctx.correlation_id = "foo"
mock_events.subscribe_to_all_events = lambda on_event: on_event(FAILED_EVENT, ctx)
cast(MagicMock, mock_events).subscribe_to_all_events = lambda on_event: on_event(
FAILED_EVENT, ctx
)

on_event = Mock()
on_event = MagicMock()
with pytest.raises(BlueskyStreamingError):
client_with_events.run_task(Task(name="foo"), on_event=on_event)

Expand All @@ -372,22 +399,24 @@ def test_run_task_fails_on_failing_event(
def test_run_task_calls_event_callback(
client_with_events: BlueapiClient,
mock_rest: BlueapiRestClient,
mock_events: MagicMock,
mock_events: EventBusClient,
test_event: AnyEvent,
):
mock_rest.create_task.return_value = TaskResponse(task_id="foo")
mock_rest.update_worker_task.return_value = TaskResponse(task_id="foo")
cast(MagicMock, mock_rest.create_task).return_value = TaskResponse(task_id="foo")
cast(MagicMock, mock_rest.update_worker_task).return_value = TaskResponse(
task_id="foo"
)

ctx = Mock()
ctx = MagicMock()
ctx.correlation_id = "foo"

def callback(on_event: Callable[[AnyEvent, MessageContext], None]):
on_event(test_event, ctx)
on_event(COMPLETE_EVENT, ctx)

mock_events.subscribe_to_all_events = callback
cast(MagicMock, mock_events).subscribe_to_all_events = callback

mock_on_event = Mock()
mock_on_event = MagicMock()
client_with_events.run_task(Task(name="foo"), on_event=mock_on_event)

assert mock_on_event.mock_calls == [call(test_event), call(COMPLETE_EVENT)]
Expand All @@ -411,22 +440,24 @@ def callback(on_event: Callable[[AnyEvent, MessageContext], None]):
def test_run_task_ignores_non_matching_events(
client_with_events: BlueapiClient,
mock_rest: BlueapiRestClient,
mock_events: MagicMock,
mock_events: EventBusClient,
test_event: AnyEvent,
):
mock_rest.create_task.return_value = TaskResponse(task_id="foo")
mock_rest.update_worker_task.return_value = TaskResponse(task_id="foo")
cast(MagicMock, mock_rest.create_task).return_value = TaskResponse(task_id="foo")
cast(MagicMock, mock_rest.update_worker_task).return_value = TaskResponse(
task_id="foo"
)

ctx = Mock()
ctx = MagicMock()
ctx.correlation_id = "foo"

def callback(on_event: Callable[[AnyEvent, MessageContext], None]):
on_event(test_event, ctx)
on_event(COMPLETE_EVENT, ctx)

mock_events.subscribe_to_all_events = callback
cast(MagicMock, mock_events).subscribe_to_all_events = callback

mock_on_event = Mock()
mock_on_event = MagicMock()
client_with_events.run_task(Task(name="foo"), on_event=mock_on_event)

mock_on_event.assert_called_once_with(COMPLETE_EVENT)
Expand Down Expand Up @@ -506,8 +537,10 @@ def test_create_and_start_task_span_ok(
client: BlueapiClient,
mock_rest: BlueapiRestClient,
):
mock_rest.create_task.return_value = TaskResponse(task_id="baz")
mock_rest.update_worker_task.return_value = TaskResponse(task_id="baz")
cast(MagicMock, mock_rest.create_task).return_value = TaskResponse(task_id="baz")
cast(MagicMock, mock_rest.update_worker_task).return_value = TaskResponse(
task_id="baz"
)
with asserting_span_exporter(exporter, "create_and_start_task", "task"):
client.create_and_start_task(Task(name="baz"))

Expand Down

0 comments on commit 7b885bd

Please sign in to comment.