Skip to content

Commit

Permalink
added code review changes
Browse files Browse the repository at this point in the history
  • Loading branch information
ZohebShaikh committed Sep 26, 2024
1 parent 72c36d3 commit 638dffa
Show file tree
Hide file tree
Showing 8 changed files with 24 additions and 51 deletions.
27 changes: 14 additions & 13 deletions src/ophyd_async/core/_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
WritesStreamAssets,
)
from event_model import DataKey
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, NonNegativeInt, computed_field

from ._device import Device
from ._protocol import AsyncConfigurable, AsyncReadable
Expand Down Expand Up @@ -54,7 +54,7 @@ class TriggerInfo(BaseModel):
#: - 3 times for initial flat field images
#: - 100 times for projections
#: - 3 times for final flat field images
number: int | list[int]
number: NonNegativeInt | list[NonNegativeInt]
#: Sort of triggers that will be sent
trigger: DetectorTrigger = Field(default=DetectorTrigger.internal)
#: What is the minimum deadtime between triggers
Expand All @@ -69,6 +69,11 @@ class TriggerInfo(BaseModel):
#: but publish 2 indices, and describe() will show a shape of (5, h, w)
multiplier: int = 1

@computed_field
@property
def total_frames(self) -> int:
return sum(self.number) if isinstance(self.number, list) else self.number


class DetectorControl(ABC):
"""
Expand Down Expand Up @@ -201,7 +206,7 @@ def __init__(
self._current_number_index: int = 0
self._initial_frame: int
self._last_frame: int
self._total_frames: int
self._total_frames_to_capture: int

super().__init__(name)

Expand Down Expand Up @@ -309,18 +314,14 @@ async def prepare(self, value: TriggerInfo) -> None:
)
self._trigger_info = value
self._initial_frame = await self.writer.get_indices_written()
self._scan_index: int = 0
if isinstance(self._trigger_info.number, list):
assert all(
frame >= 0 and type(frame) is int for frame in self._trigger_info.number
), "Number of frames can only be greater than or equal to 0"
self._total_frames = sum(self._trigger_info.number)
self._total_frames_to_capture = sum(self._trigger_info.number)
else:
assert (
self._trigger_info.number >= 0
), "Number of frames can only be greater than or equal to 0"
self._total_frames = self._trigger_info.number
self._last_frame = self._initial_frame + self._total_frames
self._total_frames_to_capture = self._trigger_info.number
self._last_frame = self._initial_frame + self._total_frames_to_capture
self._describe, _ = await asyncio.gather(
self.writer.open(value.multiplier), self.controller.prepare(value)
)
Expand All @@ -331,10 +332,10 @@ async def prepare(self, value: TriggerInfo) -> None:
@AsyncStatus.wrap
async def kickoff(self):
assert self._trigger_info, "Prepare must be called before kickoff!"
if self._frames_collected >= self._total_frames:
if self._frames_collected >= self._total_frames_to_capture:
raise Exception(
f"Kickoff called more than the configured number of "
f"{self._total_frames} iteration(s)!"
f"{self._total_frames_to_capture} iteration(s)!"
)
if isinstance(self._trigger_info.number, list):
self._frames_collected += self._trigger_info.number[
Expand Down Expand Up @@ -375,7 +376,7 @@ async def complete(self):
else:
if index >= self._trigger_info.number:
break
if self._frames_collected == self._total_frames:
if self._frames_collected == self._total_frames_to_capture:
self._frames_collected = 0
await self.controller.wait_for_idle()

Expand Down
9 changes: 2 additions & 7 deletions src/ophyd_async/epics/adaravis/_aravis_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,7 @@ def get_deadtime(self, exposure: float | None) -> float:
return _HIGHEST_POSSIBLE_DEADTIME

async def prepare(self, trigger_info: TriggerInfo):
num: int = (
sum(trigger_info.number)
if isinstance(trigger_info.number, list)
else trigger_info.number
)
if num == 0:
if trigger_info.total_frames == 0:
image_mode = adcore.ImageMode.continuous
else:
image_mode = adcore.ImageMode.multiple
Expand All @@ -48,7 +43,7 @@ async def prepare(self, trigger_info: TriggerInfo):

await asyncio.gather(
self._drv.trigger_source.set(trigger_source),
self._drv.num_images.set(num),
self._drv.num_images.set(trigger_info.total_frames),
self._drv.image_mode.set(image_mode),
)

Expand Down
7 changes: 1 addition & 6 deletions src/ophyd_async/epics/adkinetix/_kinetix_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,9 @@ def get_deadtime(self, exposure: float | None) -> float:
return 0.001

async def prepare(self, trigger_info: TriggerInfo):
frames: int = (
sum(trigger_info.number)
if isinstance(trigger_info.number, list)
else trigger_info.number
)
await asyncio.gather(
self._drv.trigger_mode.set(KINETIX_TRIGGER_MODE_MAP[trigger_info.trigger]),
self._drv.num_images.set(frames),
self._drv.num_images.set(trigger_info.total_frames),
self._drv.image_mode.set(adcore.ImageMode.multiple),
)
if trigger_info.livetime is not None and trigger_info.trigger not in [
Expand Down
9 changes: 3 additions & 6 deletions src/ophyd_async/epics/adpilatus/_pilatus_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,11 @@ async def prepare(self, trigger_info: TriggerInfo):
await adcore.set_exposure_time_and_acquire_period_if_supplied(
self, self._drv, trigger_info.livetime
)
frames: int = (
sum(trigger_info.number)
if isinstance(trigger_info.number, list)
else trigger_info.number
)
await asyncio.gather(
self._drv.trigger_mode.set(self._get_trigger_mode(trigger_info.trigger)),
self._drv.num_images.set(999_999 if frames == 0 else frames),
self._drv.num_images.set(
999_999 if trigger_info.total_frames == 0 else trigger_info.total_frames
),
self._drv.image_mode.set(adcore.ImageMode.multiple),
)

Expand Down
7 changes: 1 addition & 6 deletions src/ophyd_async/epics/adsimdetector/_sim_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,8 @@ async def prepare(self, trigger_info: TriggerInfo):
self.frame_timeout = (
DEFAULT_TIMEOUT + await self.driver.acquire_time.get_value()
)
frames: int = (
sum(trigger_info.number)
if isinstance(trigger_info.number, list)
else trigger_info.number
)
await asyncio.gather(
self.driver.num_images.set(frames),
self.driver.num_images.set(trigger_info.total_frames),
self.driver.image_mode.set(adcore.ImageMode.multiple),
)

Expand Down
7 changes: 1 addition & 6 deletions src/ophyd_async/epics/advimba/_vimba_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,10 @@ def get_deadtime(self, exposure: float | None) -> float:
return 0.001

async def prepare(self, trigger_info: TriggerInfo):
frames: int = (
sum(trigger_info.number)
if isinstance(trigger_info.number, list)
else trigger_info.number
)
await asyncio.gather(
self._drv.trigger_mode.set(TRIGGER_MODE[trigger_info.trigger]),
self._drv.exposure_mode.set(EXPOSE_OUT_MODE[trigger_info.trigger]),
self._drv.num_images.set(frames),
self._drv.num_images.set(trigger_info.total_frames),
self._drv.image_mode.set(adcore.ImageMode.multiple),
)
if trigger_info.livetime is not None and trigger_info.trigger not in [
Expand Down
7 changes: 1 addition & 6 deletions src/ophyd_async/epics/eiger/_eiger_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,11 @@ async def set_energy(self, energy: float, tolerance: float = 0.1):
await self._drv.photon_energy.set(energy)

async def prepare(self, trigger_info: TriggerInfo):
frames: int = (
sum(trigger_info.number)
if isinstance(trigger_info.number, list)
else trigger_info.number
)
coros = [
self._drv.trigger_mode.set(
EIGER_TRIGGER_MODE_MAP[trigger_info.trigger].value
),
self._drv.num_images.set(frames),
self._drv.num_images.set(trigger_info.total_frames),
]
if trigger_info.livetime is not None:
coros.extend(
Expand Down
2 changes: 1 addition & 1 deletion tests/core/test_flyer.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ async def dummy_arm_2(self=None, trigger=None, num=0, exposure=None):

@pytest.mark.parametrize("number_of_frames", [[1, 1, 1, 1], [2, 3, 100, 3]])
async def test_hardware_triggered_flyable(
RE: RunEngine, detectors: tuple[StandardDetector], number_of_frames
RE: RunEngine, detectors: tuple[StandardDetector], number_of_frames: list[int]
):
docs = {}

Expand Down

0 comments on commit 638dffa

Please sign in to comment.