Skip to content

Commit

Permalink
Fix a number of pyright typing errors (#669)
Browse files Browse the repository at this point in the history
  • Loading branch information
DominicOram authored Jul 12, 2024
1 parent e3a3ba4 commit 6f02f54
Show file tree
Hide file tree
Showing 12 changed files with 51 additions and 48 deletions.
8 changes: 4 additions & 4 deletions src/dodal/devices/attenuator.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(self, prefix: str, name: str = ""):
super().__init__(name)

@AsyncStatus.wrap
async def set(self, transmission: float):
async def set(self, value: float):
"""Set the transmission to the fractional (0-1) value given.
The attenuator IOC will then insert filters to reach the desired transmission for
Expand All @@ -58,16 +58,16 @@ async def set(self, transmission: float):

LOGGER.debug("Using current energy ")
await self._use_current_energy.trigger()
LOGGER.info(f"Setting desired transmission to {transmission}")
await self._desired_transmission.set(transmission)
LOGGER.info(f"Setting desired transmission to {value}")
await self._desired_transmission.set(value)
LOGGER.debug("Sending change filter command")
await self._change.trigger()

await asyncio.gather(
*[
wait_for_value(
self._filters_in_position[i],
await self._calculated_filter_states[i].get_value(),
bool(await self._calculated_filter_states[i].get_value()),
None,
)
for i in range(16)
Expand Down
4 changes: 3 additions & 1 deletion src/dodal/devices/detector/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ class Config:
DetectorSizeConstants: lambda d: d.det_type_string,
}

@root_validator(pre=True, skip_on_failure=True) # type: ignore # should be replaced with model_validator once move to pydantic 2 is complete
@root_validator(
pre=True, skip_on_failure=True
) # should be replaced with model_validator once move to pydantic 2 is complete
def create_beamxy_and_runnumber(cls, values: dict[str, Any]) -> dict[str, Any]:
values["beam_xy_converter"] = DetectorDistanceToBeamXYConverter(
values["det_dist_to_beam_converter_path"]
Expand Down
16 changes: 9 additions & 7 deletions src/dodal/devices/eiger.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from ophyd import Component, Device, EpicsSignalRO, Signal
from ophyd.areadetector.cam import EigerDetectorCam
from ophyd.status import AndStatus, Status, SubscriptionStatus
from ophyd.status import AndStatus, Status, StatusBase

from dodal.devices.detector import DetectorParams, TriggerMode
from dodal.devices.eiger_odin import EigerOdin
Expand All @@ -27,6 +27,7 @@ class InternalEigerTriggerMode(Enum):
class EigerDetector(Device):
class ArmingSignal(Signal):
def set(self, value, *, timeout=None, settle_time=None, **kwargs):
assert isinstance(self.parent, EigerDetector)
return self.parent.async_stage()

do_arm = Component(ArmingSignal)
Expand All @@ -41,7 +42,7 @@ def set(self, value, *, timeout=None, settle_time=None, **kwargs):
ALL_FRAMES_TIMEOUT = 120
ARMING_TIMEOUT = 60

filewriters_finished: SubscriptionStatus
filewriters_finished: StatusBase

detector_params: DetectorParams | None = None

Expand Down Expand Up @@ -155,7 +156,7 @@ def disable_roi_mode(self):
def enable_roi_mode(self):
return self.change_roi_mode(True)

def change_roi_mode(self, enable: bool) -> Status:
def change_roi_mode(self, enable: bool) -> StatusBase:
assert self.detector_params is not None
detector_dimensions = (
self.detector_params.detector_size_constants.roi_size_pixels
Expand Down Expand Up @@ -206,7 +207,7 @@ def set_odin_number_of_frame_chunks(self) -> Status:
)
return status

def set_odin_pvs(self) -> Status:
def set_odin_pvs(self) -> StatusBase:
assert self.detector_params is not None
file_prefix = self.detector_params.full_filename
status = self.odin.file_writer.file_path.set(
Expand Down Expand Up @@ -264,7 +265,7 @@ def set_detector_threshold(self, energy: float, tolerance: float = 0.1) -> Statu
status.set_finished()
return status

def set_num_triggers_and_captures(self) -> Status:
def set_num_triggers_and_captures(self) -> StatusBase:
"""Sets the number of triggers and the number of images for the Eiger to capture
during the datacollection. The number of images is the number of images per
trigger.
Expand Down Expand Up @@ -295,7 +296,7 @@ def set_num_triggers_and_captures(self) -> Status:

return status

def _wait_for_odin_status(self) -> Status:
def _wait_for_odin_status(self) -> StatusBase:
self.forward_bit_depth_to_filewriter()
await_value(self.odin.meta.active, 1).wait(self.GENERAL_STATUS_TIMEOUT)

Expand All @@ -308,7 +309,7 @@ def _wait_for_odin_status(self) -> Status:
)
return status

def _wait_fan_ready(self) -> Status:
def _wait_fan_ready(self) -> StatusBase:
self.filewriters_finished = self.odin.create_finished_status()
LOGGER.info("Eiger staging: awaiting odin fan ready")
return await_value(self.odin.fan.ready, 1, self.GENERAL_STATUS_TIMEOUT)
Expand All @@ -332,6 +333,7 @@ def disarm_detector(self):

def do_arming_chain(self) -> Status:
functions_to_do_arm = []
assert self.detector_params
detector_params: DetectorParams = self.detector_params
if detector_params.use_roi_mode:
functions_to_do_arm.append(self.enable_roi_mode)
Expand Down
6 changes: 3 additions & 3 deletions src/dodal/devices/eiger_odin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from ophyd import Component, Device, EpicsSignal, EpicsSignalRO, EpicsSignalWithRBV
from ophyd.areadetector.plugins import HDF5Plugin_V22
from ophyd.sim import NullStatus
from ophyd.status import Status, SubscriptionStatus
from ophyd.status import StatusBase

from dodal.devices.status import await_value

Expand Down Expand Up @@ -120,7 +120,7 @@ class EigerOdin(Device):
meta = Component(OdinMetaListener, "OD:META:")
nodes = Component(OdinNodesStatus, "")

def create_finished_status(self) -> SubscriptionStatus:
def create_finished_status(self) -> StatusBase:
writing_finished = await_value(self.meta.ready, 0)
for node_pv in self.nodes.nodes:
writing_finished &= await_value(node_pv.writing, 0)
Expand Down Expand Up @@ -157,7 +157,7 @@ def check_odin_initialised(self) -> Tuple[bool, str]:

return not errors, "\n".join(errors)

def stop(self) -> Status:
def stop(self) -> StatusBase:
"""Stop odin manually"""
status = self.file_writer.capture.set(0)
status &= self.meta.stop_writing.set(1)
Expand Down
1 change: 1 addition & 0 deletions src/dodal/devices/oav/oav_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class ZoomController(Device):
sxst = Component(EpicsSignal, "MP:SELECT.SXST")

def set_flatfield_on_zoom_level_one(self, value):
self.parent: "OAV"
flat_applied = self.parent.proc.port_name.get()
no_flat_applied = self.parent.cam.port_name.get()
return self.parent.grid_snapshot.input_plugin.set(
Expand Down
24 changes: 13 additions & 11 deletions src/dodal/devices/oav/pin_image_recognition/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,13 @@ def __init__(self, prefix: str, name: str = ""):
self._prefix: str = prefix
self._name = name

self.triggered_tip, _ = soft_signal_r_and_setter(Tip, name="triggered_tip")
self.triggered_top_edge, _ = soft_signal_r_and_setter(
self.triggered_tip, self._tip_setter = soft_signal_r_and_setter(
Tip, name="triggered_tip"
)
self.triggered_top_edge, self._top_edge_setter = soft_signal_r_and_setter(
NDArray[np.uint32], name="triggered_top_edge"
)
self.triggered_bottom_edge, _ = soft_signal_r_and_setter(
self.triggered_bottom_edge, self._bottom_edge_setter = soft_signal_r_and_setter(
NDArray[np.uint32], name="triggered_bottom_edge"
)
self.array_data = epics_signal_r(NDArray[np.uint8], f"pva://{prefix}PVA:ARRAY")
Expand Down Expand Up @@ -85,14 +87,14 @@ def __init__(self, prefix: str, name: str = ""):

super().__init__(name=name)

async def _set_triggered_values(self, results: SampleLocation):
def _set_triggered_values(self, results: SampleLocation):
tip = (results.tip_x, results.tip_y)
if tip == self.INVALID_POSITION:
raise InvalidPinException
else:
await self.triggered_tip._backend.put(tip)
await self.triggered_top_edge._backend.put(results.edge_top)
await self.triggered_bottom_edge._backend.put(results.edge_bottom)
self._tip_setter(tip)
self._top_edge_setter(results.edge_top)
self._bottom_edge_setter(results.edge_bottom)

async def _get_tip_and_edge_data(
self, array_data: NDArray[np.uint8]
Expand Down Expand Up @@ -150,7 +152,7 @@ async def _set_triggered_tip():
async for value in observe_value(self.array_data):
try:
location = await self._get_tip_and_edge_data(value)
await self._set_triggered_values(location)
self._set_triggered_values(location)
except Exception as e:
LOGGER.warn(
f"Failed to detect pin-tip location, will retry with next image: {e}"
Expand All @@ -166,6 +168,6 @@ async def _set_triggered_tip():
LOGGER.error(
f"No tip found in {await self.validity_timeout.get_value()} seconds."
)
await self.triggered_tip._backend.put(self.INVALID_POSITION)
await self.triggered_bottom_edge._backend.put(np.array([]))
await self.triggered_top_edge._backend.put(np.array([]))
self._tip_setter(self.INVALID_POSITION)
self._bottom_edge_setter(np.array([]))
self._top_edge_setter(np.array([]))
4 changes: 2 additions & 2 deletions src/dodal/devices/util/epics_util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import partial
from typing import Callable
from typing import Callable, Sequence

from bluesky.protocols import Movable
from ophyd import Component, EpicsSignal
Expand All @@ -26,7 +26,7 @@ def epics_signal_put_wait(pv_name: str, wait: float = 3.0) -> Component[EpicsSig


def run_functions_without_blocking(
functions_to_chain: list[Callable[[], StatusBase]],
functions_to_chain: Sequence[Callable[[], StatusBase]],
timeout: float = 60.0,
associated_obj: OphydDevice | None = None,
) -> Status:
Expand Down
6 changes: 2 additions & 4 deletions tests/devices/unit_tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,8 @@ def test_if_one_status_errors_then_later_functions_not_called():
NullStatus,
tester,
]
expected_obj = "TEST OBJECT"
returned_status = run_functions_without_blocking(
status_calls, associated_obj=expected_obj
status_calls, associated_obj=MagicMock()
)
with pytest.raises(StatusException):
returned_status.wait(0.1)
Expand All @@ -154,9 +153,8 @@ def test_if_one_status_pending_then_later_functions_not_called():
NullStatus,
tester,
]
expected_obj = "TEST OBJECT"
returned_status = run_functions_without_blocking(
status_calls, associated_obj=expected_obj
status_calls, associated_obj=MagicMock()
)
with pytest.raises(WaitTimeoutError):
returned_status.wait(0.1)
Expand Down
2 changes: 1 addition & 1 deletion tests/devices/unit_tests/test_xbpm_feedback.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from bluesky import RunEngine
from bluesky import plan_stubs as bps
from bluesky.run_engine import RunEngine
from ophyd_async.core import DeviceCollector, set_mock_value

from dodal.devices.xbpm_feedback import XBPMFeedback
Expand Down
9 changes: 4 additions & 5 deletions tests/devices/unit_tests/test_zebra.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from unittest.mock import AsyncMock
from unittest.mock import AsyncMock, MagicMock

import pytest
from bluesky.run_engine import RunEngine
from mockito import mock, verify
from ophyd_async.core import set_mock_value

from dodal.devices.zebra import (
Expand Down Expand Up @@ -100,8 +99,8 @@ async def run_configurer_test(
configurer = LogicGateConfigurer(prefix="", name="test fake logicconfigurer")
await configurer.connect(mock=True)

mock_gate_control = mock()
mock_pvs = [mock() for i in range(6)]
mock_gate_control = MagicMock()
mock_pvs = [MagicMock() for i in range(6)]
mock_gate_control.enable = mock_pvs[0]
mock_gate_control.sources = {i: mock_pvs[i] for i in range(1, 5)}
mock_gate_control.invert = mock_pvs[5]
Expand All @@ -113,7 +112,7 @@ async def run_configurer_test(
configurer.apply_or_gate_config(gate_num, config)

for pv, value in zip(mock_pvs, expected_pv_values):
verify(pv).set(value)
pv.set.assert_called_once_with(value)


async def test_apply_and_logic_gate_configuration_32_and_51_inv_and_1():
Expand Down
6 changes: 2 additions & 4 deletions tests/preprocessors/test_filesystem_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import bluesky.plans as bp
import pytest
from aiohttp import ClientResponseError
from bluesky import RunEngine
from bluesky.preprocessors import (
run_decorator,
run_wrapper,
Expand All @@ -18,12 +17,11 @@
HasName,
Readable,
Reading,
Status,
Triggerable,
)
from bluesky.run_engine import RunEngine
from event_model.documents.event_descriptor import DataKey
from ophyd.status import StatusBase
from ophyd.status import Status
from ophyd_async.core import DeviceCollector, DirectoryProvider
from pydantic import BaseModel

Expand Down Expand Up @@ -76,7 +74,7 @@ async def describe(self) -> dict[str, DataKey]:
}

def trigger(self) -> Status:
status = StatusBase()
status = Status()
status.set_finished()
return status

Expand Down
13 changes: 7 additions & 6 deletions tests/unit_tests/test_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ def test_handlers_set_at_correct_default_level(
for handler in handlers.values():
mock_logger.addHandler.assert_any_call(handler)

handlers["debug_memory_handler"].setLevel.assert_called_once_with(logging.DEBUG)
handlers["graylog_handler"].setLevel.assert_called_once_with(logging.INFO)
handlers["info_file_handler"].setLevel.assert_any_call(logging.INFO)
handlers["info_file_handler"].setLevel.assert_any_call(logging.DEBUG)
handlers["stream_handler"].setLevel.assert_called_once_with(logging.INFO)
handlers["debug_memory_handler"].setLevel.assert_called_once_with(logging.DEBUG) # type: ignore
handlers["graylog_handler"].setLevel.assert_called_once_with(logging.INFO) # type: ignore
handlers["info_file_handler"].setLevel.assert_any_call(logging.INFO) # type: ignore
handlers["info_file_handler"].setLevel.assert_any_call(logging.DEBUG) # type: ignore
handlers["stream_handler"].setLevel.assert_called_once_with(logging.INFO) # type: ignore


@patch("dodal.log.GELFTCPHandler", autospec=True)
Expand Down Expand Up @@ -134,7 +134,7 @@ def test_messages_logged_from_dodal_get_sent_to_graylog_and_file(
mock_graylog_handler_class.assert_called_once_with(
"graylog-log-target.diamond.ac.uk", 12231
)
mock_GELFTCPHandler.handle.assert_called()
mock_GELFTCPHandler.handle.assert_called() # type: ignore
mock_filehandler_emit.assert_called()


Expand Down Expand Up @@ -172,6 +172,7 @@ def mock_set_up_graylog_handler(logger, host, port):
assert mock_GELFTCPHandler.port == 5555

LOGGER.info("test")
assert isinstance(mock_GELFTCPHandler.emit, MagicMock)
mock_GELFTCPHandler.emit.assert_called()
assert mock_GELFTCPHandler.emit.call_args.args[0].beamline == "dev"

Expand Down

0 comments on commit 6f02f54

Please sign in to comment.