Skip to content

Commit

Permalink
Add signal connection cache (#368)
Browse files Browse the repository at this point in the history
* from scratch precise update

* add tests for the force reconnect and caching feature

* further fixes

* add the check for no changing connect type

* Update signal.py

* removed `_previous_connect_was_mock` from device

* deleted tests that were redundant or not testing the right behavior

* adapt device tests

* wip: made `Device` also fail if `mock` value changes

* respond to time approx comment

* more tests modified

* fix timeout test

* stuck atmaking the connect task not None

* cleaned up tests for lazy connection

* delete self._initial_backend

---------

Co-authored-by: Eva Lott <[email protected]>
  • Loading branch information
stan-dot and evalott100 authored Jun 26, 2024
1 parent ce3b85d commit f8009fd
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 34 deletions.
1 change: 0 additions & 1 deletion src/ophyd_async/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@
"WatchableAsyncStatus",
"assert_configuration",
"assert_emitted",
"assert_mock_put_called_with",
"assert_reading",
"assert_value",
"callback_on_mock_put",
Expand Down
24 changes: 18 additions & 6 deletions src/ophyd_async/core/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ class Device(HasName):
parent: Optional["Device"] = None
# None if connect hasn't started, a Task if it has
_connect_task: Optional[asyncio.Task] = None
_connect_mock_arg: bool = False

# Used to check if the previous connect was mocked,
# if the next mock value differs then we fail
_previous_connect_was_mock = None

def __init__(self, name: str = "") -> None:
self.set_name(name)
Expand Down Expand Up @@ -90,11 +93,21 @@ async def connect(
timeout:
Time to wait before failing with a TimeoutError.
"""

if (
self._previous_connect_was_mock is not None
and self._previous_connect_was_mock != mock
):
raise RuntimeError(
f"`connect(mock={mock})` called on a `Device` where the previous "
f"connect was `mock={self._previous_connect_was_mock}`. Changing mock "
"value between connects is not permitted."
)
self._previous_connect_was_mock = mock

# If previous connect with same args has started and not errored, can use it
can_use_previous_connect = (
self._connect_task
and not (self._connect_task.done() and self._connect_task.exception())
and self._connect_mock_arg == mock
can_use_previous_connect = self._connect_task and not (
self._connect_task.done() and self._connect_task.exception()
)
if force_reconnect or not can_use_previous_connect:
# Kick off a connection
Expand All @@ -105,7 +118,6 @@ async def connect(
for name, child_device in self.children()
}
self._connect_task = asyncio.create_task(wait_for_connection(**coros))
self._connect_mock_arg = mock

assert self._connect_task, "Connect task not created, this shouldn't happen"
# Wait for it to complete
Expand Down
40 changes: 32 additions & 8 deletions src/ophyd_async/core/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __init__(
name: str = "",
) -> None:
self._timeout = timeout
self._initial_backend = self._backend = backend
self._backend = backend
super().__init__(name)

async def connect(
Expand All @@ -73,19 +73,43 @@ async def connect(
backend: Optional[SignalBackend[T]] = None,
):
if backend:
if self._initial_backend and backend is not self._initial_backend:
raise ValueError(
"Backend at connection different from initialised one."
)
if self._backend and backend is not self._backend:
raise ValueError("Backend at connection different from previous one.")

self._backend = backend
if mock and not isinstance(self._backend, MockSignalBackend):
if (
self._previous_connect_was_mock is not None
and self._previous_connect_was_mock != mock
):
raise RuntimeError(
f"`connect(mock={mock})` called on a `Signal` where the previous "
f"connect was `mock={self._previous_connect_was_mock}`. Changing mock "
"value between connects is not permitted."
)
self._previous_connect_was_mock = mock

if mock and not issubclass(type(self._backend), MockSignalBackend):
# Using a soft backend, look to the initial value
self._backend = MockSignalBackend(initial_backend=self._backend)

if self._backend is None:
raise RuntimeError("`connect` called on signal without backend")
self.log.debug(f"Connecting to {self.source}")
await self._backend.connect(timeout=timeout)

can_use_previous_connection: bool = self._connect_task is not None and not (
self._connect_task.done() and self._connect_task.exception()
)

if force_reconnect or not can_use_previous_connection:
self.log.debug(f"Connecting to {self.source}")
self._connect_task = asyncio.create_task(
self._backend.connect(timeout=timeout)
)
else:
self.log.debug(f"Reusing previous connection to {self.source}")
assert (
self._connect_task
), "this assert is for type analysis and will never fail"
await self._connect_task

@property
def source(self) -> str:
Expand Down
64 changes: 47 additions & 17 deletions tests/core/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,29 +128,59 @@ async def test_device_log_has_correct_name():


async def test_device_lazily_connects(RE):
async with DeviceCollector(mock=True, connect=False):
mock_motor = motor.Motor("BLxxI-MO-TABLE-01:X")
class MockSignalBackendFailingFirst(MockSignalBackend):
succeed_on_connect = False

assert not mock_motor._connect_task
async def connect(self, timeout=DEFAULT_TIMEOUT):
if self.succeed_on_connect:
self.succeed_on_connect = False
await super().connect(timeout=timeout)
else:
self.succeed_on_connect = True
raise RuntimeError("connect fail")

# When ready to connect
RE(ensure_connected(mock_motor, mock=True))
test_motor = motor.Motor("BLxxI-MO-TABLE-01:X")
test_motor.user_setpoint._backend = MockSignalBackendFailingFirst(int)

with pytest.raises(NotConnected, match="RuntimeError: connect fail"):
await test_motor.connect(mock=True)

assert (
mock_motor._connect_task
and mock_motor._connect_task.done()
and not mock_motor._connect_task.exception()
test_motor._connect_task
and test_motor._connect_task.done()
and test_motor._connect_task.exception()
)

RE(ensure_connected(test_motor, mock=True))

assert (
test_motor._connect_task
and test_motor._connect_task.done()
and not test_motor._connect_task.exception()
)

# TODO https://github.com/bluesky/ophyd-async/issues/413
RE(ensure_connected(test_motor, mock=True, force_reconnect=True))

assert (
test_motor._connect_task
and test_motor._connect_task.done()
and test_motor._connect_task.exception()
)


async def test_device_mock_and_back_again(RE):
async def test_device_refuses_two_connects_differing_on_mock_attribute(RE):
motor = SimMotor("motor")
assert not motor._connect_task
await motor.connect(mock=False)
assert isinstance(motor.units._backend, SoftSignalBackend)
assert motor._connect_task
await motor.connect(mock=True)
assert isinstance(motor.units._backend, MockSignalBackend)
with pytest.raises(RuntimeError) as exc:
await motor.connect(mock=True)
assert str(exc.value) == (
"`connect(mock=True)` called on a `Device` where the previous connect was "
"`mock=False`. Changing mock value between connects is not permitted."
)


class MotorBundle(Device):
Expand Down Expand Up @@ -185,7 +215,7 @@ async def test_device_with_children_lazily_connects(RE):
)


async def test_device_with_device_collector_lazily_connects():
async def test_device_with_device_collector_refuses_to_connect_if_mock_switch():
mock_motor = motor.Motor("NONE_EXISTENT")
with pytest.raises(NotConnected):
await mock_motor.connect(mock=False, timeout=0.01)
Expand All @@ -194,11 +224,11 @@ async def test_device_with_device_collector_lazily_connects():
and mock_motor._connect_task.done()
and mock_motor._connect_task.exception()
)
await mock_motor.connect(mock=True, timeout=0.01)
assert (
mock_motor._connect_task is not None
and mock_motor._connect_task.done()
and not mock_motor._connect_task.exception()
with pytest.raises(RuntimeError) as exc:
await mock_motor.connect(mock=True, timeout=0.01)
assert str(exc.value) == (
"`connect(mock=True)` called on a `Device` where the previous connect was "
"`mock=False`. Changing mock value between connects is not permitted."
)


Expand Down
97 changes: 95 additions & 2 deletions tests/core/test_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@
wait_for_value,
)
from ophyd_async.core.signal import _SignalCache
from ophyd_async.core.utils import DEFAULT_TIMEOUT
from ophyd_async.epics.signal import epics_signal_r, epics_signal_rw
from ophyd_async.plan_stubs import ensure_connected


async def test_signals_equality_raises():
Expand Down Expand Up @@ -63,8 +65,9 @@ async def test_signal_can_be_given_backend_on_connect():
async def test_signal_connect_fails_with_different_backend_on_connection():
sim_signal = Signal(MockSignalBackend(str))

with pytest.raises(ValueError):
with pytest.raises(ValueError) as exc:
await sim_signal.connect(mock=True, backend=MockSignalBackend(int))
assert str(exc.value) == "Backend at connection different from previous one."

with pytest.raises(ValueError):
await sim_signal.connect(mock=True, backend=SoftSignalBackend(str))
Expand All @@ -76,11 +79,101 @@ async def test_signal_connect_fails_if_different_backend_but_same_by_value():

with pytest.raises(ValueError) as exc:
await sim_signal.connect(mock=False, backend=MockSignalBackend(str))
assert str(exc.value) == "Backend at connection different from initialised one."
assert str(exc.value) == "Backend at connection different from previous one."

await sim_signal.connect(mock=False, backend=initial_backend)


async def test_signal_connects_to_previous_backend(caplog):
caplog.set_level(logging.DEBUG)
int_mock_backend = MockSignalBackend(int)
original_connect = int_mock_backend.connect

async def new_connect(timeout=1):
await asyncio.sleep(0.1)
await original_connect(timeout=timeout)

int_mock_backend.connect = new_connect
signal = Signal(int_mock_backend)
assert await time_taken_by(
asyncio.gather(signal.connect(), signal.connect())
) == pytest.approx(0.1, rel=1e-2)
response = f"Reusing previous connection to {signal.source}"
assert response in caplog.text


async def test_signal_connects_with_force_reconnect(caplog):
caplog.set_level(logging.DEBUG)
signal = Signal(MockSignalBackend(int))
await signal.connect()
assert signal._backend.datatype == int
await signal.connect(force_reconnect=True)
response = f"Connecting to {signal.source}"
assert response in caplog.text
assert "Reusing previous connection to" not in caplog.text


@pytest.mark.parametrize(
"first, second",
[(True, False), (True, False)],
)
async def test_rejects_reconnect_when_connects_have_diff_mock_status(
caplog, first, second
):
caplog.set_level(logging.DEBUG)
signal = Signal(MockSignalBackend(int))
await signal.connect(mock=first)
assert signal._backend.datatype == int
with pytest.raises(RuntimeError) as exc:
await signal.connect(mock=second)

assert f"`connect(mock={second})` called on a `Signal` where the previous " in str(
exc.value
)

response = f"Connecting to {signal.source}"
assert response in caplog.text


async def test_signal_lazily_connects(RE):
class MockSignalBackendFailingFirst(MockSignalBackend):
succeed_on_connect = False

async def connect(self, timeout=DEFAULT_TIMEOUT):
if self.succeed_on_connect:
self.succeed_on_connect = False
await super().connect(timeout=timeout)
else:
self.succeed_on_connect = True
raise RuntimeError("connect fail")

signal = SignalRW(MockSignalBackendFailingFirst(int))

with pytest.raises(RuntimeError, match="connect fail"):
await signal.connect(mock=False)

assert (
signal._connect_task
and signal._connect_task.done()
and signal._connect_task.exception()
)

RE(ensure_connected(signal, mock=False))
assert (
signal._connect_task
and signal._connect_task.done()
and not signal._connect_task.exception()
)

# TODO https://github.com/bluesky/ophyd-async/issues/413
RE(ensure_connected(signal, mock=False, force_reconnect=True))
assert (
signal._connect_task
and signal._connect_task.done()
and signal._connect_task.exception()
)


async def time_taken_by(coro) -> float:
start = time.monotonic()
await coro
Expand Down

0 comments on commit f8009fd

Please sign in to comment.