Skip to content
This repository has been archived by the owner on Sep 2, 2024. It is now read-only.

Commit

Permalink
Migrate tests to Bluesky RunEngineSimulator (#1488)
Browse files Browse the repository at this point in the history
* replace SimRunEngine in tests with library version
  • Loading branch information
dperl-dls authored Jul 26, 2024
1 parent 51d02ba commit 4cbcc13
Show file tree
Hide file tree
Showing 17 changed files with 176 additions and 350 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ install_requires =
daq-config-server @ git+https://github.com/DiamondLightSource/daq-config-server.git
ophyd == 1.9.0
ophyd-async >= 0.3a5
bluesky >= 1.13.0a3
bluesky >= 1.13.0a4
blueapi >= 0.4.3-rc1
dls-dodal @ git+https://github.com/DiamondLightSource/dodal.git@fff4a4e8fdcf534de6768c2b45a0dfa7a2e2ccc4

Expand Down
4 changes: 1 addition & 3 deletions src/hyperion/external_interaction/ispyb/ispyb_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,7 @@ def _begin_or_update_deposition(
)
)
else:
assert (
ispyb_ids.data_collection_group_id
), "Attempt to update data collection without a data collection group ID"
assert ispyb_ids.data_collection_group_id, "Attempt to update data collection without a data collection group ID"

grid_ids = list(ispyb_ids.grid_ids)
data_collection_ids_out = list(ispyb_ids.data_collection_ids)
Expand Down
152 changes: 2 additions & 150 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
import sys
import threading
from functools import partial
from typing import Any, Callable, Generator, Optional, Sequence
from typing import Any, Generator, Sequence
from unittest.mock import MagicMock, patch

import bluesky.plan_stubs as bps
import pytest
from bluesky.run_engine import RunEngine
from bluesky.simulators import RunEngineSimulator
from bluesky.utils import Msg
from dodal.beamlines import i03
from dodal.common.beamlines import beamline_utils
Expand Down Expand Up @@ -743,155 +744,6 @@ def extract_metafile(input_filename, output_filename):
output_fo.write(metafile_fo.read())


class RunEngineSimulator:
"""This class simulates a Bluesky RunEngine by recording and injecting responses to messages according to the
bluesky Message Protocol (see bluesky docs for details).
Basic usage consists of
1) Registering various handlers to respond to anticipated messages in the experiment plan and fire any
needed callbacks.
2) Calling simulate_plan()
3) Examining the returned message list and making asserts against them"""

def __init__(self):
self.message_handlers = []
self.callbacks = {}
self.next_callback_token = 0

def add_handler_for_callback_subscribes(self):
"""Add a handler that registers all the callbacks from subscribe messages so we can call them later.
You probably want to call this as one of the first things unless you have a good reason not to.
"""
self.message_handlers.append(
MessageHandler(
lambda msg: msg.command == "subscribe",
lambda msg: self._add_callback(msg.args),
)
)

def add_handler(
self,
commands: Sequence[str],
obj_name: Optional[str],
handler: Callable[[Msg], object],
):
"""Add the specified handler for a particular message
Args:
commands: the command name for the message as defined in bluesky Message Protocol, or a sequence if more
than one matches
obj_name: the name property of the obj to match, can be None as not all messages have a name
handler: a lambda that accepts a Msg and returns an object; the object is sent to the current yield statement
in the generator, and is used when reading values from devices, the structure of the object depends on device
hinting.
"""
if isinstance(commands, str):
commands = [commands]

self.message_handlers.append(
MessageHandler(
lambda msg: msg.command in commands
and (obj_name is None or (msg.obj and msg.obj.name == obj_name)),
handler,
)
)

def add_wait_handler(
self, handler: Callable[[Msg], None], group: str = "any"
) -> None:
"""Add a wait handler for a particular message
Args:
handler: a lambda that accepts a Msg, use this to execute any code that simulates something that's
supposed to complete when a group finishes
group: name of the group to wait for, default is any which matches them all
"""
self.message_handlers.append(
MessageHandler(
lambda msg: msg.command == "wait"
and (group == "any" or msg.kwargs["group"] == group),
handler,
)
)

def fire_callback(self, document_name, document) -> None:
"""Fire all the callbacks registered for this document type in order to simulate something happening
Args:
document_name: document name as defined in the Bluesky Message Protocol 'subscribe' call,
all subscribers filtering on this document name will be called
document: the document to send
"""
for callback_func, callback_docname in self.callbacks.values():
if callback_docname == "all" or callback_docname == document_name:
callback_func(document_name, document)

def simulate_plan(self, gen: Generator[Msg, object, object]) -> list[Msg]:
"""Simulate the RunEngine executing the plan
Args:
gen: the generator function that executes the plan
Returns:
a list of the messages generated by the plan
"""
messages = []
send_value = None
try:
while msg := gen.send(send_value):
send_value = None
messages.append(msg)
LOGGER.debug(f"<{msg}")
if handler := next(
(h for h in self.message_handlers if h.predicate(msg)), None
):
send_value = handler.runnable(msg)

if send_value:
LOGGER.debug(f">send {send_value}")
except StopIteration:
pass
return messages

def _add_callback(self, msg_args):
self.callbacks[self.next_callback_token] = msg_args
self.next_callback_token += 1

def assert_message_and_return_remaining(
self,
messages: list[Msg],
predicate: Callable[[Msg], bool],
group: Optional[str] = None,
):
"""Find the next message matching the predicate, assert that we found it
Return: all the remaining messages starting from the matched message"""
indices = [
i
for i in range(len(messages))
if (
not group
or (messages[i].kwargs and messages[i].kwargs.get("group") == group)
)
and predicate(messages[i])
]
assert indices, f"Nothing matched predicate {predicate}"
return messages[indices[0] :]

def mock_message_generator(
self,
function_name: str,
) -> Callable[..., Generator[Msg, object, object]]:
"""Returns a callable that returns a generator yielding a Msg object recording the call arguments.
This can be used to mock methods returning a bluesky plan or portion thereof, call it from within a unit test
using the RunEngineSimulator, and then perform asserts on the message to verify in-order execution of the plan
"""

def mock_method(*args, **kwargs):
yield Msg(function_name, None, *args, **kwargs)

return mock_method


class MessageHandler:
def __init__(self, p: Callable[[Msg], bool], r: Callable[[Msg], object]):
self.predicate = p
self.runnable = r


@pytest.fixture
def sim_run_engine():
return RunEngineSimulator()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pytest
from bluesky.run_engine import RunEngine
from bluesky.simulators import RunEngineSimulator, assert_message_and_return_remaining
from dodal.common.beamlines.beamline_parameters import GDABeamlineParameters
from dodal.devices.focusing_mirror import (
FocusingMirrorWithStripes,
Expand Down Expand Up @@ -165,51 +166,51 @@ def test_adjust_dcm_pitch_roll_vfm_from_lut(
vfm: FocusingMirrorWithStripes,
vfm_mirror_voltages: VFMMirrorVoltages,
beamline_parameters: GDABeamlineParameters,
sim_run_engine,
sim_run_engine: RunEngineSimulator,
):
sim_run_engine.add_handler_for_callback_subscribes()
sim_run_engine.add_handler(
"read",
"dcm-bragg_in_degrees",
lambda msg: {"dcm-bragg_in_degrees": {"value": 5.0}},
"dcm-bragg_in_degrees",
)

messages = sim_run_engine.simulate_plan(
adjust_dcm_pitch_roll_vfm_from_lut(undulator_dcm, vfm, vfm_mirror_voltages, 7.5)
)

messages = sim_run_engine.assert_message_and_return_remaining(
messages = assert_message_and_return_remaining(
messages,
lambda msg: msg.command == "set"
and msg.obj.name == "dcm-pitch_in_mrad"
and abs(msg.args[0] - -0.75859) < 1e-5
and msg.kwargs["group"] == "DCM_GROUP",
)
messages = sim_run_engine.assert_message_and_return_remaining(
messages = assert_message_and_return_remaining(
messages[1:],
lambda msg: msg.command == "set"
and msg.obj.name == "dcm-roll_in_mrad"
and abs(msg.args[0] - 4.0) < 1e-5
and msg.kwargs["group"] == "DCM_GROUP",
)
messages = sim_run_engine.assert_message_and_return_remaining(
messages = assert_message_and_return_remaining(
messages[1:],
lambda msg: msg.command == "set"
and msg.obj.name == "dcm-offset_in_mm"
and msg.args == (25.6,)
and msg.kwargs["group"] == "DCM_GROUP",
)
messages = sim_run_engine.assert_message_and_return_remaining(
messages = assert_message_and_return_remaining(
messages[1:],
lambda msg: msg.command == "set"
and msg.obj.name == "vfm-stripe"
and msg.args == (MirrorStripe.RHODIUM,),
)
messages = sim_run_engine.assert_message_and_return_remaining(
messages = assert_message_and_return_remaining(
messages[1:],
lambda msg: msg.command == "wait",
)
messages = sim_run_engine.assert_message_and_return_remaining(
messages = assert_message_and_return_remaining(
messages[1:],
lambda msg: msg.command == "trigger" and msg.obj.name == "vfm-apply_stripe",
)
Expand All @@ -223,17 +224,17 @@ def test_adjust_dcm_pitch_roll_vfm_from_lut(
(6, 4),
(7, -46),
):
messages = sim_run_engine.assert_message_and_return_remaining(
messages = assert_message_and_return_remaining(
messages[1:],
lambda msg: msg.command == "set"
and msg.obj.name == f"vfm_mirror_voltages-voltage_channels-{channel}"
and msg.args == (expected_voltage,),
)
messages = sim_run_engine.assert_message_and_return_remaining(
messages = assert_message_and_return_remaining(
messages[1:],
lambda msg: msg.command == "wait" and msg.kwargs["group"] == "DCM_GROUP",
)
messages = sim_run_engine.assert_message_and_return_remaining(
messages = assert_message_and_return_remaining(
messages[1:],
lambda msg: msg.command == "set"
and msg.obj.name == "vfm-x_mm"
Expand Down
27 changes: 12 additions & 15 deletions tests/unit_tests/device_setup_plans/test_setup_panda.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
from bluesky.plan_stubs import null
from bluesky.run_engine import RunEngine
from bluesky.simulators import RunEngineSimulator
from dodal.devices.fast_grid_scan import PandAGridScanParams
from ophyd_async.panda import SeqTrigger

Expand All @@ -16,14 +17,14 @@
setup_panda_for_flyscan,
)

from ...conftest import RunEngineSimulator


def get_smargon_speed(x_step_size_mm: float, time_between_x_steps_ms: float) -> float:
return x_step_size_mm / time_between_x_steps_ms


def run_simulating_setup_panda_functions(plan: str, mock_load_device=MagicMock):
def run_simulating_setup_panda_functions(
plan: str, sim_run_engine: RunEngineSimulator, mock_load_device=MagicMock
):
num_of_sets = 0
num_of_waits = 0
mock_panda = MagicMock()
Expand All @@ -37,11 +38,7 @@ def count_commands(msg):
num_of_waits += 1

sim = RunEngineSimulator()
sim.add_handler(
["set", "wait"],
None,
count_commands,
)
sim.add_handler(["set", "wait"], count_commands)

if plan == "setup":
smargon_speed = get_smargon_speed(0.1, 1)
Expand All @@ -63,9 +60,9 @@ def count_commands(msg):


@patch("hyperion.device_setup_plans.setup_panda.load_device")
def test_setup_panda_performs_correct_plans(mock_load_device):
def test_setup_panda_performs_correct_plans(mock_load_device, sim_run_engine):
num_of_sets, num_of_waits = run_simulating_setup_panda_functions(
"setup", mock_load_device
"setup", sim_run_engine, mock_load_device
)
mock_load_device.assert_called_once()
assert num_of_sets == 9
Expand Down Expand Up @@ -181,9 +178,7 @@ def assert_set_table_has_been_waited_on(*args, **kwargs):
), patch(
"hyperion.device_setup_plans.setup_panda.bps.wait",
MagicMock(side_effect=handle_wait),
), patch(
"hyperion.device_setup_plans.setup_panda.load_device"
), patch(
), patch("hyperion.device_setup_plans.setup_panda.load_device"), patch(
"hyperion.device_setup_plans.setup_panda.bps.abs_set"
):
RE(
Expand All @@ -201,8 +196,10 @@ def assert_set_table_has_been_waited_on(*args, **kwargs):

# It also would be useful to have some system tests which check that (at least)
# all the blocks which were enabled on setup are also disabled on tidyup
def test_disarm_panda_disables_correct_blocks():
num_of_sets, num_of_waits = run_simulating_setup_panda_functions("disarm")
def test_disarm_panda_disables_correct_blocks(sim_run_engine):
num_of_sets, num_of_waits = run_simulating_setup_panda_functions(
"disarm", sim_run_engine
)
assert num_of_sets == 6
assert num_of_waits == 1

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np
import pytest
from bluesky.run_engine import RunEngine, RunEngineResult
from bluesky.simulators import assert_message_and_return_remaining
from bluesky.utils import FailedStatus, Msg
from dodal.beamlines import i03
from dodal.common.beamlines.beamline_utils import clear_device
Expand Down Expand Up @@ -998,30 +999,30 @@ def test_read_hardware_during_collection_occurs_after_eiger_arm(
)
sim_run_engine.add_handler(
"read",
"synchrotron-synchrotron_mode",
lambda msg: {"values": {"value": SynchrotronMode.USER}},
"synchrotron-synchrotron_mode",
)
msgs = sim_run_engine.simulate_plan(
run_gridscan(
fake_fgs_composite, test_fgs_params_panda_zebra, feature_controlled
)
)
msgs = sim_run_engine.assert_message_and_return_remaining(
msgs = assert_message_and_return_remaining(
msgs, lambda msg: msg.command == "stage" and msg.obj.name == "eiger"
)
msgs = sim_run_engine.assert_message_and_return_remaining(
msgs = assert_message_and_return_remaining(
msgs,
lambda msg: msg.command == "kickoff"
and msg.obj == feature_controlled.fgs_motors,
)
msgs = sim_run_engine.assert_message_and_return_remaining(
msgs = assert_message_and_return_remaining(
msgs, lambda msg: msg.command == "create"
)
msgs = sim_run_engine.assert_message_and_return_remaining(
msgs = assert_message_and_return_remaining(
msgs,
lambda msg: msg.command == "read" and msg.obj.name == "eiger_bit_depth",
)
msgs = sim_run_engine.assert_message_and_return_remaining(
msgs = assert_message_and_return_remaining(
msgs, lambda msg: msg.command == "save"
)

Expand Down
Loading

0 comments on commit 4cbcc13

Please sign in to comment.