diff --git a/pyproject.toml b/pyproject.toml
index 186ce78bba..de71dc06ac 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -32,6 +32,7 @@ dependencies = [
     "aiohttp",
     "redis",
     "deepdiff",
+    "scanspec>=0.7.3",
 ]
 dynamic = ["version"]

@@ -47,6 +48,7 @@ dev = [
     # Commented out due to dependency version conflict with pydantic 1.x
     # "copier",
     "myst-parser",
+    "ophyd_async[sim]",
     "pipdeptree",
     "pre-commit",
     "psutil",
diff --git a/src/dodal/plan_stubs/__init__.py b/src/dodal/plan_stubs/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/dodal/plans/check_topup.py b/src/dodal/plan_stubs/check_topup.py
similarity index 100%
rename from src/dodal/plans/check_topup.py
rename to src/dodal/plan_stubs/check_topup.py
diff --git a/src/dodal/plans/data_session_metadata.py b/src/dodal/plan_stubs/data_session.py
similarity index 93%
rename from src/dodal/plans/data_session_metadata.py
rename to src/dodal/plan_stubs/data_session.py
index bfdd91d5c2..bed1946ddf 100644
--- a/src/dodal/plans/data_session_metadata.py
+++ b/src/dodal/plan_stubs/data_session.py
@@ -2,7 +2,7 @@
 from bluesky import preprocessors as bpp
 from bluesky.utils import MsgGenerator, make_decorator

-from dodal.common.beamlines import beamline_utils
+from dodal.common.beamlines.beamline_utils import get_path_provider
 from dodal.common.types import UpdatingPathProvider

 DATA_SESSION = "data_session"
@@ -31,7 +31,7 @@ def attach_data_session_metadata_wrapper(
         Iterator[Msg]: Plan messages
     """
     if provider is None:
-        provider = beamline_utils.get_path_provider()
+        provider = get_path_provider()
     yield from bps.wait_for([provider.update])
     ress = yield from bps.wait_for([provider.data_session])
     data_session = ress[0].result()
diff --git a/src/dodal/plans/motor_util_plans.py b/src/dodal/plan_stubs/motor_utils.py
similarity index 98%
rename from src/dodal/plans/motor_util_plans.py
rename to src/dodal/plan_stubs/motor_utils.py
index 7dfcb3037e..dbb4419a13 100644
--- a/src/dodal/plans/motor_util_plans.py
+++ b/src/dodal/plan_stubs/motor_utils.py
@@ -23,7 +23,7 @@ def __init__(
         super().__init__(*args)


-def _check_and_cache_values(
+def check_and_cache_values(
     devices_and_positions: dict[MovableReadableDevice, float],
     smallest_move: float,
     maximum_move: float,
@@ -89,7 +89,7 @@ def move_and_reset_wrapper(
                     on. If false it is left up to the caller to wait on them.
                     Defaults to True.
     """
-    initial_positions = yield from _check_and_cache_values(
+    initial_positions = yield from check_and_cache_values(
         device_and_positions, smallest_move, maximum_move
     )
diff --git a/src/dodal/plan_stubs/wrapped.py b/src/dodal/plan_stubs/wrapped.py
new file mode 100644
index 0000000000..ba5651a49d
--- /dev/null
+++ b/src/dodal/plan_stubs/wrapped.py
@@ -0,0 +1,150 @@
+import itertools
+from collections.abc import Mapping
+from typing import Annotated, Any
+
+import bluesky.plan_stubs as bps
+from bluesky.protocols import Movable
+from bluesky.utils import MsgGenerator
+
+"""
+Wrappers for Bluesky built-in plan stubs with type hinting
+"""
+
+Group = Annotated[str, "String identifier used by 'wait' or stubs that await"]
+
+
+# After bluesky 1.14, bounds for stubs that move can be narrowed
+# https://github.com/bluesky/bluesky/issues/1821
+def set_absolute(
+    movable: Movable, value: Any, group: Group | None = None, wait: bool = False
+) -> MsgGenerator:
+    """
+    Set a device, wrapper for `bps.abs_set`.
+
+    Args:
+        movable (Movable): The device to set
+        value (Any): The new value
+        group (Group | None, optional): The message group to associate with the
+                                        setting, for sequencing. Defaults to None.
+        wait (bool, optional): If True, wait until all setting is complete
+                               (e.g. a motor has finished moving). Defaults to False.
+
+    Returns:
+        MsgGenerator: Plan
+
+    Yields:
+        Iterator[MsgGenerator]: Bluesky messages
+    """
+    return (yield from bps.abs_set(movable, value, group=group, wait=wait))
+
+
+def set_relative(
+    movable: Movable, value: Any, group: Group | None = None, wait: bool = False
+) -> MsgGenerator:
+    """
+    Change a device, wrapper for `bps.rel_set`.
+
+    Args:
+        movable (Movable): The device to set
+        value (Any): The new value
+        group (Group | None, optional): The message group to associate with the
+                                        setting, for sequencing. Defaults to None.
+        wait (bool, optional): If True, wait until all setting is complete
+                               (e.g. a motor has finished moving). Defaults to False.
+
+    Returns:
+        MsgGenerator: Plan
+
+    Yields:
+        Iterator[MsgGenerator]: Bluesky messages
+    """
+
+    return (yield from bps.rel_set(movable, value, group=group, wait=wait))
+
+
+def move(moves: Mapping[Movable, Any], group: Group | None = None) -> MsgGenerator:
+    """
+    Move a device, wrapper for `bps.mv`.
+
+    Args:
+        moves (Mapping[Movable, Any]): Mapping of Movables to target positions
+        group (Group | None, optional): The message group to associate with the
+                                        setting, for sequencing. Defaults to None.
+
+    Returns:
+        MsgGenerator: Plan
+
+    Yields:
+        Iterator[MsgGenerator]: Bluesky messages
+    """
+
+    return (
+        # type ignore until https://github.com/bluesky/bluesky/issues/1809
+        yield from bps.mv(*itertools.chain.from_iterable(moves.items()), group=group)  # type: ignore
+    )
+
+
+def move_relative(
+    moves: Mapping[Movable, Any], group: Group | None = None
+) -> MsgGenerator:
+    """
+    Move a device relative to its current position, wrapper for `bps.mvr`.
+
+    Args:
+        moves (Mapping[Movable, Any]): Mapping of Movables to target deltas
+        group (Group | None, optional): The message group to associate with the
+                                        setting, for sequencing. Defaults to None.
+
+    Returns:
+        MsgGenerator: Plan
+
+    Yields:
+        Iterator[MsgGenerator]: Bluesky messages
+    """
+
+    return (
+        # type ignore until https://github.com/bluesky/bluesky/issues/1809
+        yield from bps.mvr(*itertools.chain.from_iterable(moves.items()), group=group)  # type: ignore
+    )
+
+
+def sleep(time: float) -> MsgGenerator:
+    """
+    Suspend all action for a given time, wrapper for `bps.sleep`.
+
+    Args:
+        time (float): Time to wait in seconds
+
+    Returns:
+        MsgGenerator: Plan
+
+    Yields:
+        Iterator[MsgGenerator]: Bluesky messages
+    """
+
+    return (yield from bps.sleep(time))
+
+
+def wait(
+    group: Group | None = None,
+    timeout: float | None = None,
+) -> MsgGenerator:
+    """
+    Wait for a group status to complete, wrapper for `bps.wait`.
+    Does not expose move_on, as, when used as a stub, it will not fail on timeout.
+
+    Args:
+        group (Group | None, optional): The name of the group to wait for, defaults
+                                        to None, in which case waits for all
+                                        groups that have not yet been awaited.
+        timeout (float | None, default=None): a timeout in seconds
+
+    Returns:
+        MsgGenerator: Plan
+
+    Yields:
+        Iterator[MsgGenerator]: Bluesky messages
+    """
+
+    return (yield from bps.wait(group, timeout=timeout))
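An illustrative usage sketch for these stubs (a hypothetical example, not part of the diff; it assumes the ophyd-async SimMotor used by the tests below):

    from bluesky.run_engine import RunEngine
    from bluesky.utils import MsgGenerator
    from ophyd_async.core import DeviceCollector
    from ophyd_async.sim.demo import SimMotor

    from dodal.plan_stubs.wrapped import move, set_absolute, wait

    RE = RunEngine()
    with DeviceCollector():
        x_axis = SimMotor()
        y_axis = SimMotor()

    def example_plan() -> MsgGenerator:
        # Start both moves under one named group, then block until both finish
        yield from set_absolute(x_axis, 0.5, group="setup")
        yield from set_absolute(y_axis, 1.0, group="setup")
        yield from wait("setup")
        # move() sets every axis in the mapping and waits for all of them
        yield from move({x_axis: 0.0, y_axis: 0.0})

    RE(example_plan())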
diff --git a/src/dodal/plans/__init__.py b/src/dodal/plans/__init__.py
new file mode 100644
index 0000000000..fb40245969
--- /dev/null
+++ b/src/dodal/plans/__init__.py
@@ -0,0 +1,4 @@
+from .scanspec import spec_scan
+from .wrapped import count
+
+__all__ = ["count", "spec_scan"]
diff --git a/src/dodal/plans/scanspec.py b/src/dodal/plans/scanspec.py
new file mode 100644
index 0000000000..660e0a0c19
--- /dev/null
+++ b/src/dodal/plans/scanspec.py
@@ -0,0 +1,66 @@
+import operator
+from functools import reduce
+from typing import Annotated, Any
+
+import bluesky.plans as bp
+from bluesky.protocols import Movable, Readable
+from cycler import Cycler, cycler
+from pydantic import Field, validate_call
+from scanspec.specs import Spec
+
+from dodal.common import MsgGenerator
+from dodal.plan_stubs.data_session import attach_data_session_metadata_decorator
+
+
+@attach_data_session_metadata_decorator()
+@validate_call(config={"arbitrary_types_allowed": True})
+def spec_scan(
+    detectors: Annotated[
+        set[Readable],
+        Field(
+            description="Set of readable devices, will take a reading at each point, \
+            in addition to any Movables in the Spec",
+        ),
+    ],
+    spec: Annotated[
+        Spec[Movable],
+        Field(description="ScanSpec modelling the path of the scan"),
+    ],
+    metadata: dict[str, Any] | None = None,
+) -> MsgGenerator:
+    """Generic plan for reading `detectors` at every point of a ScanSpec `Spec`.
+    A `Spec` is an N-dimensional path.
+    """
+    # TODO: https://github.com/bluesky/scanspec/issues/154
+    # support Static.duration: Spec[Literal["DURATION"]]
+
+    _md = {
+        "plan_args": {
+            "detectors": {det.name for det in detectors},
+            "spec": repr(spec),
+        },
+        "plan_name": "spec_scan",
+        "shape": spec.shape(),
+        **(metadata or {}),
+    }
+
+    yield from bp.scan_nd(tuple(detectors), _as_cycler(spec), md=_md)
+
+
+def _as_cycler(spec: Spec[Movable]) -> Cycler:
+    """
+    Convert a scanspec to a cycler for compatibility with legacy Bluesky plans such as
+    `bp.scan_nd`. Use the midpoints of the scanspec since cyclers are normally used
+    for software triggered scans.
+
+    Args:
+        spec: A scanspec
+
+    Returns:
+        Cycler: A new cycler
+    """
+
+    midpoints = spec.frames().midpoints
+    # Need to "add" the cyclers for all the axes together. The code below is
+    # effectively: cycler(motor1, [...]) + cycler(motor2, [...]) + ...
+    return reduce(operator.add, (cycler(*args) for args in midpoints.items()))
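An illustrative sketch of what `_as_cycler` consumes (not part of the diff; scanspec axes are normally Movables here, plain strings are used purely for illustration):

    from scanspec.specs import Line

    # A 3x3 grid: the product of two Lines, as in the tests below
    spec = Line("y", 0, 5, 3) * Line("x", 0, 5, 3)
    assert spec.shape() == (3, 3)

    # spec_scan flattens the grid's midpoints into a cycler for bp.scan_nd:
    midpoints = spec.frames().midpoints
    assert len(midpoints["x"]) == 9  # 9 software-triggered points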
diff --git a/src/dodal/plans/wrapped.py b/src/dodal/plans/wrapped.py
new file mode 100644
index 0000000000..9589bb1e77
--- /dev/null
+++ b/src/dodal/plans/wrapped.py
@@ -0,0 +1,57 @@
+from collections.abc import Sequence
+from typing import Annotated, Any
+
+import bluesky.plans as bp
+from bluesky.protocols import Readable
+from pydantic import Field, NonNegativeFloat, validate_call
+
+from dodal.common import MsgGenerator
+from dodal.plan_stubs.data_session import attach_data_session_metadata_decorator
+
+"""This module wraps plan(s) from bluesky.plans until required handling for them is
+moved into bluesky or better handled in downstream services.
+
+Required decorators are installed on plan import
+https://github.com/DiamondLightSource/blueapi/issues/474
+
+Non-serialisable fields are ignored when they are optional
+https://github.com/DiamondLightSource/blueapi/issues/711
+
+We may also need other adjustments for UI purposes, e.g.
+forcing uniqueness or orderedness of Readables,
+or exposing limits and metadata (e.g. units).
+"""
+
+
+@attach_data_session_metadata_decorator()
+@validate_call(config={"arbitrary_types_allowed": True})
+def count(
+    detectors: Annotated[
+        set[Readable],
+        Field(
+            description="Set of readable devices, will take a reading at each point",
+            min_length=1,
+        ),
+    ],
+    num: Annotated[int, Field(description="Number of frames to collect", ge=1)] = 1,
+    delay: Annotated[
+        NonNegativeFloat | Sequence[NonNegativeFloat],
+        Field(
+            description="Delay between readings: if a sequence, len(delay) must be \
+            num - 1 and each entry is the delay before the next reading; if a \
+            single value, it is used as the delay between every reading",
+            json_schema_extra={"units": "s"},
+        ),
+    ] = 0.0,
+    metadata: dict[str, Any] | None = None,
+) -> MsgGenerator:
+    """Reads from a number of devices.
+    Wraps bluesky.plans.count(det, num, delay, md=metadata) exposing only serializable
+    parameters and metadata."""
+    if isinstance(delay, Sequence):
+        assert (
+            len(delay) == num - 1
+        ), f"Number of delays given must be {num - 1}: was given {len(delay)}"
+    metadata = metadata or {}
+    metadata["shape"] = (num,)
+    yield from bp.count(tuple(detectors), num, delay=delay, md=metadata)
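An illustrative sketch of the wrapped `count` (not part of the diff; `det` here stands for any Readable, e.g. the PatternDetector fixture in tests/plans/conftest.py, and actually running the plan assumes a path provider has been configured, since `attach_data_session_metadata_decorator` looks one up; the tests below patch `get_path_provider`):

    from dodal.plans.wrapped import count

    # Three readings with 0.1 s between consecutive readings:
    # a sequence delay must have length num - 1
    plan = count({det}, num=3, delay=(0.1, 0.1))

    # validate_call rejects bad arguments at call time, before any message
    # reaches the RunEngine, e.g. an empty detector set:
    count(set())  # raises pydantic.ValidationError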
diff --git a/tests/plans/test_motor_util_plans.py b/tests/plan_stubs/test_motor_util_plans.py
similarity index 95%
rename from tests/plans/test_motor_util_plans.py
rename to tests/plan_stubs/test_motor_util_plans.py
index b9bc13a0d8..bde30adcde 100644
--- a/tests/plans/test_motor_util_plans.py
+++ b/tests/plan_stubs/test_motor_util_plans.py
@@ -14,9 +14,9 @@
 from ophyd_async.epics.motor import Motor

 from dodal.devices.util.test_utils import patch_motor
-from dodal.plans.motor_util_plans import (
+from dodal.plan_stubs.motor_utils import (
     MoveTooLarge,
-    _check_and_cache_values,
+    check_and_cache_values,
     home_and_reset_wrapper,
 )
@@ -59,7 +59,7 @@ def my_device(RE):
     "device_type",
     [DeviceWithOnlyMotors, DeviceWithNoMotors, DeviceWithSomeMotors],
 )
-@patch("dodal.plans.motor_util_plans.move_and_reset_wrapper")
+@patch("dodal.plan_stubs.motor_utils.move_and_reset_wrapper")
 def test_given_types_of_device_when_home_and_reset_wrapper_called_then_motors_and_zeros_passed_to_move_and_reset_wrapper(
     patch_move_and_reset, device_type, RE
 ):
@@ -80,7 +80,7 @@ def test_given_a_device_when_check_and_cache_values_then_motor_values_returned(
         set_mock_value(motor.user_readback, i * 100)

     motors_and_positions: dict[Motor, float] = RE(
-        _check_and_cache_values(
+        check_and_cache_values(
             {motor_obj: 0.0 for motor_obj in my_device.motors}, 0, 1000
         )
     ).plan_result  # type: ignore
@@ -109,7 +109,7 @@ def test_given_a_device_with_a_too_large_move_when_check_and_cache_values_then_e
     motors_and_positions = {motor_obj: new_position for motor_obj in my_device.motors}

     with pytest.raises(MoveTooLarge) as e:
-        RE(_check_and_cache_values(motors_and_positions, 0, max))
+        RE(check_and_cache_values(motors_and_positions, 0, max))

     assert e.value.axis == my_device.y
     assert e.value.maximum_move == max
@@ -136,7 +136,7 @@ def test_given_a_device_where_one_move_too_small_when_check_and_cache_values_the
     }

     motors_and_positions: dict[Motor, float] = RE(
-        _check_and_cache_values(motors_and_new_positions, min, 1000)
+        check_and_cache_values(motors_and_new_positions, min, 1000)
     ).plan_result  # type: ignore

     cached_positions = motors_and_positions.values()
@@ -156,7 +156,7 @@ def test_given_a_device_where_all_moves_too_small_when_check_and_cache_values_th
     motors_and_new_positions = {motor_obj: 0.0 for motor_obj in my_device.motors}

     motors_and_positions: dict[Motor, float] = RE(
-        _check_and_cache_values(motors_and_new_positions, 40, 1000)
+        check_and_cache_values(motors_and_new_positions, 40, 1000)
     ).plan_result  # type: ignore

     cached_positions = motors_and_positions.values()
diff --git a/tests/plans/test_topup_plan.py b/tests/plan_stubs/test_topup_plan.py
similarity index 91%
rename from tests/plans/test_topup_plan.py
rename to tests/plan_stubs/test_topup_plan.py
index 2b9fcce512..384e79779c 100644
--- a/tests/plans/test_topup_plan.py
+++ b/tests/plan_stubs/test_topup_plan.py
@@ -7,7 +7,7 @@
 from dodal.beamlines import i03
 from dodal.devices.synchrotron import Synchrotron, SynchrotronMode
-from dodal.plans.check_topup import (
+from dodal.plan_stubs.check_topup import (
     check_topup_and_wait_if_necessary,
     wait_for_topup_complete,
 )
@@ -18,8 +18,8 @@ def synchrotron(RE) -> Synchrotron:
     return i03.synchrotron(fake_with_ophyd_sim=True)


-@patch("dodal.plans.check_topup.wait_for_topup_complete")
-@patch("dodal.plans.check_topup.bps.sleep")
+@patch("dodal.plan_stubs.check_topup.wait_for_topup_complete")
+@patch("dodal.plan_stubs.check_topup.bps.sleep")
 def test_when_topup_before_end_of_collection_wait(
     fake_sleep: MagicMock, fake_wait: MagicMock, synchrotron: Synchrotron, RE: RunEngine
 ):
@@ -37,8 +37,8 @@ def test_when_topup_before_end_of_collection_wait(
     fake_sleep.assert_called_once_with(61.0)


-@patch("dodal.plans.check_topup.bps.rd")
-@patch("dodal.plans.check_topup.bps.sleep")
+@patch("dodal.plan_stubs.check_topup.bps.rd")
+@patch("dodal.plan_stubs.check_topup.bps.sleep")
 def test_wait_for_topup_complete(
     fake_sleep: MagicMock, fake_rd: MagicMock, synchrotron: Synchrotron, RE: RunEngine
 ):
@@ -59,8 +59,8 @@ def fake_generator(value):
     fake_sleep.assert_called_with(0.1)


-@patch("dodal.plans.check_topup.bps.sleep")
-@patch("dodal.plans.check_topup.bps.null")
+@patch("dodal.plan_stubs.check_topup.bps.sleep")
+@patch("dodal.plan_stubs.check_topup.bps.null")
 def test_no_waiting_if_decay_mode(
     fake_null: MagicMock, fake_sleep: MagicMock, synchrotron: Synchrotron, RE: RunEngine
 ):
@@ -77,7 +77,7 @@ def test_no_waiting_if_decay_mode(
     assert fake_sleep.call_count == 0


-@patch("dodal.plans.check_topup.bps.null")
+@patch("dodal.plan_stubs.check_topup.bps.null")
 def test_no_waiting_when_mode_does_not_allow_gating(
     fake_null: MagicMock, synchrotron: Synchrotron, RE: RunEngine
 ):
@@ -120,7 +120,7 @@ def test_no_waiting_when_mode_does_not_allow_gating(
         (29, 39, 35, 1, 0, "topup_long_delay.txt"),
     ],
 )
-@patch("dodal.plans.check_topup.bps.sleep")
+@patch("dodal.plan_stubs.check_topup.bps.sleep")
 def test_topup_not_allowed_when_exceeds_threshold_percentage_of_topup_time(
     mock_sleep,
     RE: RunEngine,
diff --git a/tests/plan_stubs/test_wrapped_stubs.py b/tests/plan_stubs/test_wrapped_stubs.py
new file mode 100644
index 0000000000..a8a58bc9f0
--- /dev/null
+++ b/tests/plan_stubs/test_wrapped_stubs.py
@@ -0,0 +1,144 @@
+from unittest.mock import ANY
+
+import pytest
+from bluesky.run_engine import RunEngine
+from bluesky.utils import Msg
+from ophyd_async.core import (
+    DeviceCollector,
+)
+from ophyd_async.sim.demo import SimMotor
+
+from dodal.plan_stubs.wrapped import (
+    move,
+    move_relative,
+    set_absolute,
+    set_relative,
+    sleep,
+    wait,
+)
+
+
+@pytest.fixture
+def x_axis(RE: RunEngine) -> SimMotor:
+    with DeviceCollector():
+        x_axis = SimMotor()
+    return x_axis
+
+
+@pytest.fixture
+def y_axis(RE: RunEngine) -> SimMotor:
+    with DeviceCollector():
+        y_axis = SimMotor()
+    return y_axis
+
+
+def test_set_absolute(x_axis: SimMotor):
+    assert list(set_absolute(x_axis, 0.5)) == [Msg("set", x_axis, 0.5, group=None)]
+
+
+def test_set_absolute_with_group(x_axis: SimMotor):
+    assert list(set_absolute(x_axis, 0.5, group="foo")) == [
+        Msg("set", x_axis, 0.5, group="foo")
+    ]
+
+
+def test_set_absolute_with_wait(x_axis: SimMotor):
+    msgs = list(set_absolute(x_axis, 0.5, wait=True))
+    assert len(msgs) == 2
+    assert msgs[0] == Msg("set", x_axis, 0.5, group=ANY)
+    assert msgs[1] == Msg("wait", group=msgs[0].kwargs["group"])
+
+
+def test_set_absolute_with_group_and_wait(x_axis: SimMotor):
+    assert list(set_absolute(x_axis, 0.5, group="foo", wait=True)) == [
+        Msg("set", x_axis, 0.5, group="foo"),
+        Msg("wait", group="foo"),
+    ]
+
+
+def test_set_relative(x_axis: SimMotor):
+    assert list(set_relative(x_axis, 0.5)) == [
+        Msg("read", x_axis),
+        Msg("set", x_axis, 0.5, group=None),
+    ]
+
+
+def test_set_relative_with_group(x_axis: SimMotor):
+    assert list(set_relative(x_axis, 0.5, group="foo")) == [
+        Msg("read", x_axis),
+        Msg("set", x_axis, 0.5, group="foo"),
+    ]
+
+
+def test_set_relative_with_wait(x_axis: SimMotor):
+    msgs = list(set_relative(x_axis, 0.5, wait=True))
+    assert len(msgs) == 3
+    assert msgs[0] == Msg("read", x_axis)
+    assert msgs[1] == Msg("set", x_axis, 0.5, group=ANY)
+    assert msgs[2] == Msg("wait", group=msgs[1].kwargs["group"])
+
+
+def test_set_relative_with_group_and_wait(x_axis: SimMotor):
+    assert list(set_relative(x_axis, 0.5, group="foo", wait=True)) == [
+        Msg("read", x_axis),
+        Msg("set", x_axis, 0.5, group="foo"),
+        Msg("wait", group="foo"),
+    ]
+
+
+def test_move(x_axis: SimMotor, y_axis: SimMotor):
+    msgs = list(move({x_axis: 0.5, y_axis: 1.0}))
+    assert msgs[0] == Msg("set", x_axis, 0.5, group=ANY)
+    assert msgs[1] == Msg("set", y_axis, 1.0, group=msgs[0].kwargs["group"])
+    assert msgs[2] == Msg("wait", group=msgs[0].kwargs["group"])
+
+
+def test_move_group(x_axis: SimMotor, y_axis: SimMotor):
+    msgs = list(move({x_axis: 0.5, y_axis: 1.0}, group="foo"))
+    assert msgs[0] == Msg("set", x_axis, 0.5, group="foo")
+    assert msgs[1] == Msg("set", y_axis, 1.0, group="foo")
+    assert msgs[2] == Msg("wait", group="foo")
+
+
+def test_move_relative(x_axis: SimMotor, y_axis: SimMotor):
+    msgs = list(move_relative({x_axis: 0.5, y_axis: 1.0}))
+    assert msgs[0] == Msg("read", x_axis)
+    assert msgs[1] == Msg("set", x_axis, 0.5, group=ANY)
+    group = msgs[1].kwargs["group"]
+    assert msgs[2] == Msg("read", y_axis)
+    assert msgs[3] == Msg("set", y_axis, 1.0, group=group)
+    assert msgs[4] == Msg("wait", group=group)
+
+
+def test_move_relative_group(x_axis: SimMotor, y_axis: SimMotor):
+    msgs = list(move_relative({x_axis: 0.5, y_axis: 1.0}, group="foo"))
+    assert msgs[0] == Msg("read", x_axis)
+    assert msgs[1] == Msg("set", x_axis, 0.5, group="foo")
+    assert msgs[2] == Msg("read", y_axis)
+    assert msgs[3] == Msg("set", y_axis, 1.0, group="foo")
+    assert msgs[4] == Msg("wait", group="foo")
+
+
+def test_sleep():
+    assert list(sleep(1.5)) == [Msg("sleep", None, 1.5)]
+
+
+def test_wait():
+    # Waits for all groups
+    assert list(wait()) == [Msg("wait", group=None, timeout=None, move_on=False)]
+
+
+def test_wait_group():
+    assert list(wait("foo")) == [Msg("wait", group="foo", timeout=None, move_on=False)]
+
+
+def test_wait_timeout():
+    assert list(wait(timeout=5.0)) == [
+        Msg("wait", group=None, timeout=5.0, move_on=False)
+    ]
+
+
+def test_wait_group_and_timeout():
+    assert list(wait("foo", 5.0)) == [
+        Msg("wait", group="foo", timeout=5.0, move_on=False)
+    ]
diff --git a/tests/plans/conftest.py b/tests/plans/conftest.py
new file mode 100644
index 0000000000..45c7b944ee
--- /dev/null
+++ b/tests/plans/conftest.py
@@ -0,0 +1,40 @@
+from pathlib import Path
+from unittest.mock import patch
+
+import pytest
+from bluesky.run_engine import RunEngine
+from ophyd_async.core import DeviceCollector, PathProvider, StandardDetector
+from ophyd_async.sim.demo import PatternDetector, SimMotor
+
+
+@pytest.fixture
+def det(
+    RE: RunEngine,
+    tmp_path: Path,
+    path_provider,
+) -> StandardDetector:
+    with DeviceCollector(mock=True):
+        det = PatternDetector(tmp_path / "foo.h5")
+    return det
+
+
+@pytest.fixture
+def x_axis(RE: RunEngine) -> SimMotor:
+    with DeviceCollector(mock=True):
+        x_axis = SimMotor()
+    return x_axis
+
+
+@pytest.fixture
+def y_axis(RE: RunEngine) -> SimMotor:
+    with DeviceCollector(mock=True):
+        y_axis = SimMotor()
+    return y_axis
+
+
+@pytest.fixture
+def path_provider(static_path_provider: PathProvider):
+    # Prevents issue with leftover state from beamline tests
+    with patch("dodal.plan_stubs.data_session.get_path_provider") as mock:
+        mock.return_value = static_path_provider
+        yield
diff --git a/tests/plans/test_compliance.py b/tests/plans/test_compliance.py
new file mode 100644
index 0000000000..de4ef2846e
--- /dev/null
+++ b/tests/plans/test_compliance.py
@@ -0,0 +1,78 @@
+import inspect
+from collections.abc import Iterable
+from types import ModuleType
+from typing import Any, get_type_hints
+
+from bluesky.utils import MsgGenerator
+
+from dodal import plan_stubs, plans
+from dodal.common.types import PlanGenerator
+
+"""
+Bluesky distinguishes between `plans`: complete experimental procedures, which open and
+close data collection runs, and which may be part of a larger plan that collects data
+multiple times, but may also be run alone to collect data, and `plan_stubs`: which
+do not create & complete data collection runs and are either isolated behaviours or
+building blocks for plans.
+
+In order to make it clearer when a MsgGenerator can be safely used without considering
+the enclosing run (as opening a run whilst in a run without explicitly passing a RunID
+is likely to cause both runs to fail), when it is required to manage a run, and when
+running a procedure will create data documents, we have adopted this standard.
+
+We further impose other requirements on both plans and stubs exported from these
+modules to enable them to be exposed in UIs in a consistent way:
+- They must have a docstring
+- They must not use variadic arguments (*args, **kwargs)
+- Plans must and stubs may have an optional argument for metadata, named "metadata"
+- The metadata argument, where present, must be optional with a default of None
+"""
+
+
+def is_bluesky_plan_generator(func: Any) -> bool:
+    try:
+        return callable(func) and get_type_hints(func).get("return") == MsgGenerator
+    except TypeError:
+        # get_type_hints fails on some objects (such as Union or Optional)
+        return False
+
+
+def get_all_available_generators(mod: ModuleType) -> Iterable[PlanGenerator]:
+    for value in mod.__dict__.values():
+        if is_bluesky_plan_generator(value):
+            yield value
+
+
+def assert_hard_requirements(plan: PlanGenerator, signature: inspect.Signature):
+    assert plan.__doc__ is not None, f"'{plan.__name__}' has no docstring"
+    for parameter in signature.parameters.values():
+        assert (
+            parameter.kind is not parameter.VAR_POSITIONAL
+            and parameter.kind is not parameter.VAR_KEYWORD
+        ), f"'{plan.__name__}' has variadic arguments"
+
+
+def assert_metadata_requirements(plan: PlanGenerator, signature: inspect.Signature):
+    assert (
+        "metadata" in signature.parameters
+    ), f"'{plan.__name__}' does not allow metadata"
+    metadata = signature.parameters["metadata"]
+    assert (
+        metadata.annotation == dict[str, Any] | None and metadata.default is None
+    ), f"'{plan.__name__}' metadata is not optional"
+    assert metadata.default is None, f"'{plan.__name__}' metadata default is mutable"
+
+
+def test_plans_comply():
+    for plan in get_all_available_generators(plans):
+        signature = inspect.Signature.from_callable(plan)
+        assert_hard_requirements(plan, signature)
+        assert_metadata_requirements(plan, signature)
+
+
+def test_stubs_comply():
+    for stub in get_all_available_generators(plan_stubs):
+        signature = inspect.Signature.from_callable(stub)
+        assert_hard_requirements(stub, signature)
+        if "metadata" in signature.parameters:
+            assert_metadata_requirements(stub, signature)
diff --git a/tests/plans/test_scanspec.py b/tests/plans/test_scanspec.py
new file mode 100644
index 0000000000..bfa0b6acb1
--- /dev/null
+++ b/tests/plans/test_scanspec.py
@@ -0,0 +1,183 @@
+from collections.abc import Sequence
+from functools import reduce
+from typing import cast
+
+import pytest
+from bluesky.run_engine import RunEngine
+from event_model.documents import (
+    DocumentType,
+    Event,
+    EventDescriptor,
+    RunStart,
+    RunStop,
+    StreamResource,
+)
+from ophyd_async.core import StandardDetector
+from ophyd_async.sim.demo import SimMotor
+from scanspec.specs import Line
+
+from dodal.plans import spec_scan
+
+
+@pytest.fixture
+def documents_from_expected_shape(
+    request: pytest.FixtureRequest,
+    det: StandardDetector,
+    RE: RunEngine,
+    x_axis: SimMotor,
+    y_axis: SimMotor,
+) -> dict[str, list[DocumentType]]:
+    shape: Sequence[int] = request.param
+    motors = [x_axis, y_axis]
+    # Does not support static, https://github.com/bluesky/scanspec/issues/154
+    # spec = Static.duration(1)
+    spec = Line(motors[0], 0, 5, shape[0])
+    for i in range(1, len(shape)):
+        spec = spec * Line(motors[i], 0, 5, shape[i])
+
+    docs: dict[str, list[DocumentType]] = {}
+    RE(
+        spec_scan({det}, spec),  # type: ignore
+        lambda name, doc: docs.setdefault(name, []).append(doc),
+    )
+    return docs
+
+
+spec_and_shape = (
+    # Does not support static, https://github.com/bluesky/scanspec/issues/154
+    # [(), (1,)],  # static
+    [(1,), (1,)],
+    [(3,), (3,)],
+    [(1, 1), (1, 1)],
+    [(3, 3), (3, 3)],
+)
+
+
+def length_from_shape(shape: tuple[int, ...]) -> int:
+    return reduce(lambda x, y: x * y, shape)
+
+
+@pytest.mark.parametrize(
+    "documents_from_expected_shape, shape",
+    spec_and_shape,
+    indirect=["documents_from_expected_shape"],
+)
+def test_plan_produces_expected_start_document(
+    documents_from_expected_shape: dict[str, list[DocumentType]],
+    shape: tuple[int, ...],
+    x_axis: SimMotor,
+    y_axis: SimMotor,
+):
+    axes = len(shape)
+    expected_data_keys = (
+        [
+            x_axis.hints.get("fields", [])[0],
+            y_axis.hints.get("fields", [])[0],
+        ]
+        if axes == 2
+        else [x_axis.hints.get("fields", [])[0]]
+    )
+    dimensions = [([data_key], "primary") for data_key in expected_data_keys]
+    docs = documents_from_expected_shape.get("start")
+    assert docs and len(docs) == 1
+    start = cast(RunStart, docs[0])
+    assert start.get("shape") == shape
+    assert (hints := start.get("hints"))
+    for dimension in dimensions:
+        assert dimension in hints.get("dimensions")  # type: ignore
+
+
+@pytest.mark.parametrize(
+    "documents_from_expected_shape, shape",
+    spec_and_shape,
+    indirect=["documents_from_expected_shape"],
+)
+def test_plan_produces_expected_stop_document(
+    documents_from_expected_shape: dict[str, list[DocumentType]], shape: tuple[int, ...]
+):
+    docs = documents_from_expected_shape.get("stop")
+    assert docs and len(docs) == 1
+    stop = cast(RunStop, docs[0])
+    assert stop.get("num_events") == {"primary": length_from_shape(shape)}
+    assert stop.get("exit_status") == "success"
+
+
+@pytest.mark.parametrize(
+    "documents_from_expected_shape, shape",
+    spec_and_shape,
+    indirect=["documents_from_expected_shape"],
+)
+def test_plan_produces_expected_descriptor(
+    documents_from_expected_shape: dict[str, list[DocumentType]],
+    det: StandardDetector,
+    shape: tuple[int, ...],
+):
+    docs = documents_from_expected_shape.get("descriptor")
+    assert docs and len(docs) == 1
+    descriptor = cast(EventDescriptor, docs[0])
+    object_keys = descriptor.get("object_keys")
+    assert object_keys is not None and det.name in object_keys
+    assert descriptor.get("name") == "primary"
+
+
+@pytest.mark.parametrize(
+    "documents_from_expected_shape, shape",
+    spec_and_shape,
+    indirect=["documents_from_expected_shape"],
+)
+def test_plan_produces_expected_events(
+    documents_from_expected_shape: dict[str, list[DocumentType]],
+    shape: tuple[int, ...],
+    det: StandardDetector,
+    x_axis: SimMotor,
+    y_axis: SimMotor,
+):
+    axes = len(shape)
+    expected_data_keys = (
+        {
+            x_axis.hints.get("fields", [])[0],
+            y_axis.hints.get("fields", [])[0],
+        }
+        if axes == 2
+        else {x_axis.hints.get("fields", [])[0]}
+    )
+    docs = documents_from_expected_shape.get("event")
+    assert docs and len(docs) == length_from_shape(shape)
+    for i in range(len(docs)):
+        event = cast(Event, docs[i])
+        assert len(event.get("data")) == axes
+        assert event.get("data").keys() == expected_data_keys
+        assert event.get("seq_num") == i + 1
+
+
+@pytest.mark.parametrize(
+    "documents_from_expected_shape, shape",
+    spec_and_shape,
+    indirect=["documents_from_expected_shape"],
+)
+def test_plan_produces_expected_resources(
+    documents_from_expected_shape: dict[str, list[DocumentType]],
+    shape: tuple[int, ...],
+    det: StandardDetector,
+):
+    docs = documents_from_expected_shape.get("stream_resource")
+    data_keys = [det.name, f"{det.name}-sum"]
+    assert docs and len(docs) == len(data_keys)
+    for i in range(len(docs)):
+        resource = cast(StreamResource, docs[i])
+        assert resource.get("data_key") == data_keys[i]
+
+
+@pytest.mark.parametrize(
+    "documents_from_expected_shape, shape",
+    spec_and_shape,
+    indirect=["documents_from_expected_shape"],
+)
+def test_plan_produces_expected_datums(
+    documents_from_expected_shape: dict[str, list[DocumentType]],
+    shape: tuple[int, ...],
+    det: StandardDetector,
+):
+    docs = documents_from_expected_shape.get("stream_datum")
+    data_keys = [det.name, f"{det.name}-sum"]
+    assert docs and len(docs) == len(data_keys) * length_from_shape(shape)
diff --git a/tests/plans/test_wrapped.py b/tests/plans/test_wrapped.py
new file mode 100644
index 0000000000..2fbe289b72
--- /dev/null
+++ b/tests/plans/test_wrapped.py
@@ -0,0 +1,158 @@
+from collections.abc import Sequence
+from typing import cast
+
+import pytest
+from bluesky.protocols import Readable
+from bluesky.run_engine import RunEngine
+from event_model.documents import (
+    DocumentType,
+    Event,
+    EventDescriptor,
+    RunStart,
+    RunStop,
+    StreamResource,
+)
+from ophyd_async.core import (
+    StandardDetector,
+)
+from pydantic import ValidationError
+
+from dodal.plans.wrapped import count
+
+
+@pytest.fixture
+def documents_from_num(
+    request: pytest.FixtureRequest, det: StandardDetector, RE: RunEngine
+) -> dict[str, list[DocumentType]]:
+    docs: dict[str, list[DocumentType]] = {}
+    RE(
+        count({det}, num=request.param),
+        lambda name, doc: docs.setdefault(name, []).append(doc),
+    )
+    return docs
+
+
+def test_count_delay_validation(det: StandardDetector, RE):
+    args: dict[float | Sequence[float], str] = {  # type: ignore
+        # List wrong length
+        (1,): "Number of delays given must be 2: was given 1",
+        (1, 2, 3): "Number of delays given must be 2: was given 3",
+        # Delay non-physical
+        # negative time
+        -1: "Input should be greater than or equal to 0",
+        (-1, 2): "Input should be greater than or equal to 0",
+        # null time
+        None: "Input should be a valid number",
+        (None, 2): "Input should be a valid number",
+        # non-numeric time
+        "foo": "Input should be a valid number",
+        ("foo", 2): "Input should be a valid number",
+    }
+    for delay, reason in args.items():
+        with pytest.raises((ValidationError, AssertionError), match=reason):
+            RE(count({det}, num=3, delay=delay))
+
+
+def test_count_detectors_validation(RE):
+    args: dict[str, set[Readable]] = {
+        # No device to read
+        "Set should have at least 1 item after validation, not 0": set(),
+        # Not Readable
+        "Input should be an instance of Readable": set("foo"),  # type: ignore
+    }
+    for reason, dets in args.items():
+        with pytest.raises(ValidationError, match=reason):
+            RE(count(dets))
+
+
+def test_count_num_validation(det: StandardDetector, RE):
+    args: dict[int, str] = {
+        -1: "Input should be greater than or equal to 1",
+        0: "Input should be greater than or equal to 1",
+        "str": "Input should be a valid integer",  # type: ignore
+    }
+    for num, reason in args.items():
+        with pytest.raises(ValidationError, match=reason):
+            RE(count({det}, num=num))
+
+
+@pytest.mark.parametrize(
+    "documents_from_num, shape", ([1, (1,)], [3, (3,)]), indirect=["documents_from_num"]
+)
+def test_plan_produces_expected_start_document(
+    documents_from_num: dict[str, list[DocumentType]], shape: tuple[int, ...]
+):
+    docs = documents_from_num.get("start")
+    assert docs and len(docs) == 1
+    start = cast(RunStart, docs[0])
+    assert start.get("shape") == shape
+    assert (hints := start.get("hints")) and (
+        hints.get("dimensions") == [(("time",), "primary")]
+    )
+
+
+@pytest.mark.parametrize(
+    "documents_from_num, length", ([1, 1], [3, 3]), indirect=["documents_from_num"]
+)
+def test_plan_produces_expected_stop_document(
+    documents_from_num: dict[str, list[DocumentType]], length: int
+):
+    docs = documents_from_num.get("stop")
+    assert docs and len(docs) == 1
+    stop = cast(RunStop, docs[0])
+    assert stop.get("num_events") == {"primary": length}
+    assert stop.get("exit_status") == "success"
+
+
+@pytest.mark.parametrize("documents_from_num", [1], indirect=True)
+def test_plan_produces_expected_descriptor(
+    documents_from_num: dict[str, list[DocumentType]], det: StandardDetector
+):
+    docs = documents_from_num.get("descriptor")
+    assert docs and len(docs) == 1
+    descriptor = cast(EventDescriptor, docs[0])
+    object_keys = descriptor.get("object_keys")
+    assert object_keys is not None and det.name in object_keys
+    assert descriptor.get("name") == "primary"
+
+
+@pytest.mark.parametrize(
+    "documents_from_num, length", ([1, 1], [3, 3]), indirect=["documents_from_num"]
+)
+def test_plan_produces_expected_events(
+    documents_from_num: dict[str, list[DocumentType]],
+    length: int,
+    det: StandardDetector,
+):
+    docs = documents_from_num.get("event")
+    assert docs and len(docs) == length
+    for i in range(len(docs)):
+        event = cast(Event, docs[i])
+        assert not event.get("data")  # empty data
+        assert event.get("seq_num") == i + 1
+
+
+@pytest.mark.parametrize("documents_from_num", [1, 3], indirect=True)
+def test_plan_produces_expected_resources(
+    documents_from_num: dict[str, list[DocumentType]],
+    det: StandardDetector,
+):
+    docs = documents_from_num.get("stream_resource")
+    data_keys = [det.name, f"{det.name}-sum"]
+    assert docs and len(docs) == len(data_keys)
+    for i in range(len(docs)):
+        resource = cast(StreamResource, docs[i])
+        assert resource.get("data_key") == data_keys[i]
+
+
+@pytest.mark.parametrize(
+    "documents_from_num, length", ([1, 1], [3, 3]), indirect=["documents_from_num"]
+)
+def test_plan_produces_expected_datums(
+    documents_from_num: dict[str, list[DocumentType]],
+    length: int,
+    det: StandardDetector,
+):
+    docs = documents_from_num.get("stream_datum")
+    data_keys = [det.name, f"{det.name}-sum"]
+    assert docs and len(docs) == len(data_keys) * length
diff --git a/tests/preprocessors/test_filesystem_metadata.py b/tests/preprocessors/test_filesystem_metadata.py
index a4ad0559ec..24b66a0d34 100644
--- a/tests/preprocessors/test_filesystem_metadata.py
+++ b/tests/preprocessors/test_filesystem_metadata.py
@@ -26,7 +26,7 @@
     LocalDirectoryServiceClient,
     StaticVisitPathProvider,
 )
-from dodal.plans.data_session_metadata import (
+from dodal.plan_stubs.data_session import (
     DATA_SESSION,
     attach_data_session_metadata_wrapper,
 )