Skip to content

Commit

Permalink
added draft changes
Browse files Browse the repository at this point in the history
  • Loading branch information
ZohebShaikh committed May 8, 2024
1 parent 991196c commit 39054ef
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 20 deletions.
15 changes: 8 additions & 7 deletions src/ophyd_async/core/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ class DetectorWriter(ABC):
(e.g. an HDF5 file)"""

@abstractmethod
async def open(self, multiplier: int = 1) -> Dict[str, Descriptor]:
async def open(self, multiplier: int = 1) -> Dict[str, Descriptor]: # type: ignore
"""Open writer and wait for it to be ready for data.
Args:
Expand All @@ -122,7 +122,7 @@ async def open(self, multiplier: int = 1) -> Dict[str, Descriptor]:

@abstractmethod
def observe_indices_written(
self, timeout=DEFAULT_TIMEOUT
self, timeout: float | str = DEFAULT_TIMEOUT
) -> AsyncGenerator[int, None]:
"""Yield the index of each frame (or equivalent data point) as it is written"""

Expand All @@ -131,7 +131,7 @@ async def get_indices_written(self) -> int:
"""Get the number of indices written"""

@abstractmethod
def collect_stream_docs(self, indices_written: int) -> AsyncIterator[StreamAsset]:
def collect_stream_docs(self, indices_written: int) -> AsyncIterator[StreamAsset]: # type: ignore
"""Create Stream docs up to given number written"""

@abstractmethod
Expand Down Expand Up @@ -203,7 +203,7 @@ def controller(self) -> DetectorControl:
def writer(self) -> DetectorWriter:
return self._writer

@AsyncStatus.wrap
@AsyncStatus.wrap # type: ignore
async def stage(self) -> None:
# Disarm the detector, stop filewriting, and open file for writing.
await self._check_config_sigs()
Expand All @@ -212,20 +212,21 @@ async def stage(self) -> None:

async def _check_config_sigs(self):
"""Checks configuration signals are named and connected."""
signal: AsyncReadable
for signal in self._config_sigs:
if signal.name == "":
raise Exception(
"config signal must be named before it is passed to the detector"
)
try:
await signal.get_value()
await signal.get_value() # type: ignore
except NotImplementedError:
raise Exception(
f"config signal {signal._name} must be connected before it is "
f"config signal {signal._name} must be connected before it is " # type: ignore
+ "passed to the detector"
)

@AsyncStatus.wrap
@AsyncStatus.wrap # type: ignore
async def unstage(self) -> None:
# Stop data writing.
await self.writer.close()
Expand Down
28 changes: 15 additions & 13 deletions src/ophyd_async/core/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import sys
from types import FrameType
from typing import (
Any,
Coroutine,
Expand All @@ -15,13 +16,13 @@
TypeVar,
)

from bluesky.protocols import HasName
from bluesky.run_engine import call_in_bluesky_event_loop
from bluesky.protocols import HasName # type: ignore
from bluesky.run_engine import call_in_bluesky_event_loop # type: ignore

from .utils import DEFAULT_TIMEOUT, NotConnected, wait_for_connection


class Device(HasName):
class Device(HasName): # type: ignore
"""Common base class for all Ophyd Async Devices.
By default, names and connects all Device children.
Expand Down Expand Up @@ -127,9 +128,9 @@ class DeviceCollector:

def __init__(
self,
set_name=True,
connect=True,
sim=False,
set_name: bool = True,
connect: bool = True,
sim: bool = False,
timeout: float = 10.0,
):
self._set_name = set_name
Expand All @@ -146,9 +147,10 @@ def _caller_locals(self):
except ValueError:
_, _, tb = sys.exc_info()
assert tb, "Can't get traceback, this shouldn't happen"
caller_frame = tb.tb_frame
caller_frame: FrameType = tb.tb_frame
while caller_frame.f_locals.get("self", None) is self:
caller_frame = caller_frame.f_back
if caller_frame.f_back is not None:
caller_frame = caller_frame.f_back
return caller_frame.f_locals

def __enter__(self) -> "DeviceCollector":
Expand All @@ -161,7 +163,7 @@ async def __aenter__(self) -> "DeviceCollector":

async def _on_exit(self) -> None:
# Name and kick off connect for devices
connect_coroutines: Dict[str, Coroutine] = {}
connect_coroutines: Dict[str, Coroutine[Any, Any, Any]] = {}
for name, obj in self._objects_on_exit.items():
if name not in self._names_on_enter and isinstance(obj, Device):
if self._set_name and not obj.name:
Expand All @@ -175,18 +177,18 @@ async def _on_exit(self) -> None:
if connect_coroutines:
await wait_for_connection(**connect_coroutines)

async def __aexit__(self, type, value, traceback):
async def __aexit__(self):
self._objects_on_exit = self._caller_locals()
await self._on_exit()

def __exit__(self, type_, value, traceback):
def __exit__(self) -> Any:
self._objects_on_exit = self._caller_locals()
try:
fut = call_in_bluesky_event_loop(self._on_exit())
fut = call_in_bluesky_event_loop(self._on_exit()) # type: ignore
except RuntimeError:
raise NotConnected(
"Could not connect devices. Is the bluesky event loop running? See "
"https://blueskyproject.io/ophyd-async/main/"
"user/explanations/event-loop-choice.html for more info."
)
return fut
return fut # type: ignore

0 comments on commit 39054ef

Please sign in to comment.