diff --git a/src/ophyd_async/core/_detector.py b/src/ophyd_async/core/_detector.py index fbc4a7fbec..7212bc7842 100644 --- a/src/ophyd_async/core/_detector.py +++ b/src/ophyd_async/core/_detector.py @@ -20,7 +20,7 @@ WritesStreamAssets, ) from event_model import DataKey -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, NonNegativeInt, computed_field from ._device import Device from ._protocol import AsyncConfigurable, AsyncReadable @@ -54,7 +54,7 @@ class TriggerInfo(BaseModel): #: - 3 times for initial flat field images #: - 100 times for projections #: - 3 times for final flat field images - number: int | list[int] + number: NonNegativeInt | list[NonNegativeInt] #: Sort of triggers that will be sent trigger: DetectorTrigger = Field(default=DetectorTrigger.internal) #: What is the minimum deadtime between triggers @@ -69,6 +69,11 @@ class TriggerInfo(BaseModel): #: but publish 2 indices, and describe() will show a shape of (5, h, w) multiplier: int = 1 + @computed_field + @property + def total_frames(self) -> int: + return sum(self.number) if isinstance(self.number, list) else self.number + class DetectorControl(ABC): """ @@ -201,7 +206,7 @@ def __init__( self._current_number_index: int = 0 self._initial_frame: int self._last_frame: int - self._total_frames: int + self._total_frames_to_capture: int super().__init__(name) @@ -309,18 +314,14 @@ async def prepare(self, value: TriggerInfo) -> None: ) self._trigger_info = value self._initial_frame = await self.writer.get_indices_written() - self._scan_index: int = 0 if isinstance(self._trigger_info.number, list): - assert all( - frame >= 0 and type(frame) is int for frame in self._trigger_info.number - ), "Number of frames can only be greater than or equal to 0" - self._total_frames = sum(self._trigger_info.number) + self._total_frames_to_capture = sum(self._trigger_info.number) else: assert ( self._trigger_info.number >= 0 ), "Number of frames can only be greater than or equal to 0" - self._total_frames = self._trigger_info.number - self._last_frame = self._initial_frame + self._total_frames + self._total_frames_to_capture = self._trigger_info.number + self._last_frame = self._initial_frame + self._total_frames_to_capture self._describe, _ = await asyncio.gather( self.writer.open(value.multiplier), self.controller.prepare(value) ) @@ -331,10 +332,10 @@ async def prepare(self, value: TriggerInfo) -> None: @AsyncStatus.wrap async def kickoff(self): assert self._trigger_info, "Prepare must be called before kickoff!" - if self._frames_collected >= self._total_frames: + if self._frames_collected >= self._total_frames_to_capture: raise Exception( f"Kickoff called more than the configured number of " - f"{self._total_frames} iteration(s)!" + f"{self._total_frames_to_capture} iteration(s)!" ) if isinstance(self._trigger_info.number, list): self._frames_collected += self._trigger_info.number[ @@ -375,7 +376,7 @@ async def complete(self): else: if index >= self._trigger_info.number: break - if self._frames_collected == self._total_frames: + if self._frames_collected == self._total_frames_to_capture: self._frames_collected = 0 await self.controller.wait_for_idle() diff --git a/src/ophyd_async/epics/adaravis/_aravis_controller.py b/src/ophyd_async/epics/adaravis/_aravis_controller.py index d9761771c6..7bc33b29f9 100644 --- a/src/ophyd_async/epics/adaravis/_aravis_controller.py +++ b/src/ophyd_async/epics/adaravis/_aravis_controller.py @@ -30,12 +30,7 @@ def get_deadtime(self, exposure: float | None) -> float: return _HIGHEST_POSSIBLE_DEADTIME async def prepare(self, trigger_info: TriggerInfo): - num: int = ( - sum(trigger_info.number) - if isinstance(trigger_info.number, list) - else trigger_info.number - ) - if num == 0: + if trigger_info.total_frames == 0: image_mode = adcore.ImageMode.continuous else: image_mode = adcore.ImageMode.multiple @@ -48,7 +43,7 @@ async def prepare(self, trigger_info: TriggerInfo): await asyncio.gather( self._drv.trigger_source.set(trigger_source), - self._drv.num_images.set(num), + self._drv.num_images.set(trigger_info.total_frames), self._drv.image_mode.set(image_mode), ) diff --git a/src/ophyd_async/epics/adkinetix/_kinetix_controller.py b/src/ophyd_async/epics/adkinetix/_kinetix_controller.py index b77054c326..8096c4d788 100644 --- a/src/ophyd_async/epics/adkinetix/_kinetix_controller.py +++ b/src/ophyd_async/epics/adkinetix/_kinetix_controller.py @@ -27,14 +27,9 @@ def get_deadtime(self, exposure: float | None) -> float: return 0.001 async def prepare(self, trigger_info: TriggerInfo): - frames: int = ( - sum(trigger_info.number) - if isinstance(trigger_info.number, list) - else trigger_info.number - ) await asyncio.gather( self._drv.trigger_mode.set(KINETIX_TRIGGER_MODE_MAP[trigger_info.trigger]), - self._drv.num_images.set(frames), + self._drv.num_images.set(trigger_info.total_frames), self._drv.image_mode.set(adcore.ImageMode.multiple), ) if trigger_info.livetime is not None and trigger_info.trigger not in [ diff --git a/src/ophyd_async/epics/adpilatus/_pilatus_controller.py b/src/ophyd_async/epics/adpilatus/_pilatus_controller.py index f1b2ae31d0..51299cbde9 100644 --- a/src/ophyd_async/epics/adpilatus/_pilatus_controller.py +++ b/src/ophyd_async/epics/adpilatus/_pilatus_controller.py @@ -37,14 +37,11 @@ async def prepare(self, trigger_info: TriggerInfo): await adcore.set_exposure_time_and_acquire_period_if_supplied( self, self._drv, trigger_info.livetime ) - frames: int = ( - sum(trigger_info.number) - if isinstance(trigger_info.number, list) - else trigger_info.number - ) await asyncio.gather( self._drv.trigger_mode.set(self._get_trigger_mode(trigger_info.trigger)), - self._drv.num_images.set(999_999 if frames == 0 else frames), + self._drv.num_images.set( + 999_999 if trigger_info.total_frames == 0 else trigger_info.total_frames + ), self._drv.image_mode.set(adcore.ImageMode.multiple), ) diff --git a/src/ophyd_async/epics/adsimdetector/_sim_controller.py b/src/ophyd_async/epics/adsimdetector/_sim_controller.py index 9107bffc9b..402d95e68e 100644 --- a/src/ophyd_async/epics/adsimdetector/_sim_controller.py +++ b/src/ophyd_async/epics/adsimdetector/_sim_controller.py @@ -31,13 +31,8 @@ async def prepare(self, trigger_info: TriggerInfo): self.frame_timeout = ( DEFAULT_TIMEOUT + await self.driver.acquire_time.get_value() ) - frames: int = ( - sum(trigger_info.number) - if isinstance(trigger_info.number, list) - else trigger_info.number - ) await asyncio.gather( - self.driver.num_images.set(frames), + self.driver.num_images.set(trigger_info.total_frames), self.driver.image_mode.set(adcore.ImageMode.multiple), ) diff --git a/src/ophyd_async/epics/advimba/_vimba_controller.py b/src/ophyd_async/epics/advimba/_vimba_controller.py index 445306577e..71d41a98f7 100644 --- a/src/ophyd_async/epics/advimba/_vimba_controller.py +++ b/src/ophyd_async/epics/advimba/_vimba_controller.py @@ -34,15 +34,10 @@ def get_deadtime(self, exposure: float | None) -> float: return 0.001 async def prepare(self, trigger_info: TriggerInfo): - frames: int = ( - sum(trigger_info.number) - if isinstance(trigger_info.number, list) - else trigger_info.number - ) await asyncio.gather( self._drv.trigger_mode.set(TRIGGER_MODE[trigger_info.trigger]), self._drv.exposure_mode.set(EXPOSE_OUT_MODE[trigger_info.trigger]), - self._drv.num_images.set(frames), + self._drv.num_images.set(trigger_info.total_frames), self._drv.image_mode.set(adcore.ImageMode.multiple), ) if trigger_info.livetime is not None and trigger_info.trigger not in [ diff --git a/src/ophyd_async/epics/eiger/_eiger_controller.py b/src/ophyd_async/epics/eiger/_eiger_controller.py index 0ec47017e4..c3c3968adc 100644 --- a/src/ophyd_async/epics/eiger/_eiger_controller.py +++ b/src/ophyd_async/epics/eiger/_eiger_controller.py @@ -37,16 +37,11 @@ async def set_energy(self, energy: float, tolerance: float = 0.1): await self._drv.photon_energy.set(energy) async def prepare(self, trigger_info: TriggerInfo): - frames: int = ( - sum(trigger_info.number) - if isinstance(trigger_info.number, list) - else trigger_info.number - ) coros = [ self._drv.trigger_mode.set( EIGER_TRIGGER_MODE_MAP[trigger_info.trigger].value ), - self._drv.num_images.set(frames), + self._drv.num_images.set(trigger_info.total_frames), ] if trigger_info.livetime is not None: coros.extend( diff --git a/tests/core/test_flyer.py b/tests/core/test_flyer.py index bb48c7dc98..da3e82dce1 100644 --- a/tests/core/test_flyer.py +++ b/tests/core/test_flyer.py @@ -141,7 +141,7 @@ async def dummy_arm_2(self=None, trigger=None, num=0, exposure=None): @pytest.mark.parametrize("number_of_frames", [[1, 1, 1, 1], [2, 3, 100, 3]]) async def test_hardware_triggered_flyable( - RE: RunEngine, detectors: tuple[StandardDetector], number_of_frames + RE: RunEngine, detectors: tuple[StandardDetector], number_of_frames: list[int] ): docs = {}