Skip to content

Commit

Permalink
Merge pull request #57 from shilorigins/devagr/save-snapshot
Browse files Browse the repository at this point in the history
ENH: Implement snapshot saving
  • Loading branch information
shilorigins committed Aug 2, 2024
2 parents dc7ded7 + 0fcf43b commit 6d44dc9
Show file tree
Hide file tree
Showing 6 changed files with 245 additions and 47 deletions.
23 changes: 23 additions & 0 deletions docs/source/upcoming_release_notes/57-snapshot_saving.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
57 snapshot saving
#################

API Breaks
----------
- N/A

Features
--------
- ability to save Snapshots for a Collection

Bugfixes
--------
- N/A

Maintenance
-----------
- generalized Client._gather_data to work on any type of Entry
- made Client._gather_data iterative rather than recursive to simplify the conditional logic

Contributors
------------
- shilorigins
168 changes: 126 additions & 42 deletions superscore/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@
import logging
import os
from pathlib import Path
from typing import Any, Generator, List, Optional, Union
from typing import Any, Dict, Generator, List, Optional, Union
from uuid import UUID

from superscore.backends import get_backend
from superscore.backends.core import _Backend
from superscore.control_layers import ControlLayer
from superscore.control_layers.status import TaskStatus
from superscore.model import Entry, Setpoint, Snapshot
from superscore.errors import CommunicationError
from superscore.model import (Collection, Entry, Nestable, Parameter, Readback,
Setpoint, Snapshot)
from superscore.type_hints import AnyEpicsType
from superscore.utils import build_abs_path

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -163,6 +166,35 @@ def compare(self, entry_l: Entry, entry_r: Entry) -> Any:
"""Compare two entries. Should be of same type, and return a diff"""
raise NotImplementedError

def snap(self, entry: Collection) -> Snapshot:
"""
Asyncronously read data for all PVs under ``entry``, and store in a
Snapshot. PVs that can't be read will have an exception as their value.
Parameters
----------
entry : Collection
the Collection to save
Returns
-------
Snapshot
a Snapshot corresponding to the input Collection
"""
logger.debug(f"Saving Snapshot for Collection {entry.uuid}")
pvs, _ = self._gather_data(entry)
pvs.extend(Collection.meta_pvs)
values = self.cl.get(pvs)
data = {}
for pv, value in zip(pvs, values):
if isinstance(value, CommunicationError):
logger.debug(f"Couldn't read value for {pv}, storing \"None\"")
data[pv] = None
else:
logger.debug(f"Storing {pv} = {value}")
data[pv] = value
return self._build_snapshot(entry, data)

def apply(
self,
entry: Union[Setpoint, Snapshot],
Expand Down Expand Up @@ -194,7 +226,7 @@ def apply(

# Gather pv-value list and apply at once
status_list = []
pv_list, data_list = self._gather_data(entry)
pv_list, data_list = self._gather_data(entry, writable_only=True)
if sequential:
for pv, data in zip(pv_list, data_list):
logger.debug(f'Putting {pv} = {data}')
Expand All @@ -210,57 +242,109 @@ def apply(

def _gather_data(
self,
entry: Union[Setpoint, Snapshot, UUID],
pv_list: Optional[List[str]] = None,
data_list: Optional[List[Any]] = None
) -> Optional[tuple[List[str], List[Any]]]:
entry: Union[Entry, UUID],
writable_only: bool = False,
) -> tuple[List[str], Optional[List[Any]]]:
"""
Gather writable pv name - data pairs recursively.
If pv_list and data_list are provided, gathered data will be added to
these lists in-place. If both lists are omitted, this function will return
the two lists after gathering.
Queries the backend to fill any UUID values found.
Gather PV name - data pairs that are accessible from ``entry``. Queries
the backend to fill any UUIDs found.
Parameters
----------
entry : Union[Setpoint, Snapshot, UUID]
Entry to gather writable data from
pv_list : Optional[List[str]], optional
List of addresses to write data to, by default None
data_list : Optional[List[Any]], optional
List of data to write to addresses in ``pv_list``, by default None
entry : Union[Entry, UUID]
Entry to gather data from
writable_only : bool
If True, only include writable data e.g. omit Readbacks; by default False
Returns
-------
Optional[tuple[List[str], List[Any]]]
tuple[List[str], Optional[List[Any]]]
the filled pv_list and data_list
"""
top_level = False
if (pv_list is None) and (data_list is None):
pv_list = []
data_list = []
top_level = True
elif (pv_list is None) or (data_list is None):
raise ValueError(
"Arguments pv_list and data_list must either both be provided "
"or both omitted."
)
pv_list = []
data_list = []

seen = set()
q = [entry]
while len(q) > 0:
entry = q.pop()
uuid = entry if isinstance(entry, UUID) else entry.uuid
if uuid in seen:
continue
elif isinstance(entry, UUID):
entry = self.backend.get_entry(entry)
seen.add(entry.uuid)

if isinstance(entry, Nestable):
q.extend(reversed(entry.children)) # preserve execution order
elif isinstance(entry, Readback) and writable_only:
pass
else: # entry is Parameter, Setpoint, or Readback
pv_list.append(entry.pv_name)
if hasattr(entry, "data"):
data_list.append(entry.data)
if getattr(entry, "readback", None) is not None:
q.append(entry.readback)
return pv_list, data_list

def _build_snapshot(
self,
coll: Collection,
values: Dict[str, AnyEpicsType],
) -> Snapshot:
"""
Traverse a Collection, assembling a Snapshot using pre-fetched data
along the way
if isinstance(entry, Snapshot):
for child in entry.children:
self._gather_data(child, pv_list, data_list)
elif isinstance(entry, UUID):
child_entry = self.backend.get_entry(entry)
self._gather_data(child_entry, pv_list, data_list)
elif isinstance(entry, Setpoint):
pv_list.append(entry.pv_name)
data_list.append(entry.data)
Parameters
----------
coll : Collection
The collection being saved
values : Dict[str, AnyEpicsType]
A dictionary mapping PV names to pre-fetched values
# Readbacks are not writable, and are not gathered
Returns
-------
Snapshot
A Snapshot corresponding to the input Collection
"""
snapshot = Snapshot(
title=coll.title,
tags=coll.tags.copy(),
origin_collection=coll
)

for child in coll.children:
if isinstance(child, UUID):
child = self.backend.get(child)
if isinstance(child, Parameter):
if child.readback is not None:
readback = Readback(
pv_name=child.readback.pv_name,
description=child.readback.description,
data=values[child.readback.pv_name]
)
else:
readback = None
setpoint = Setpoint(
pv_name=child.pv_name,
description=child.description,
data=values[child.pv_name],
readback=readback
)
snapshot.children.append(setpoint)
elif isinstance(child, Collection):
snapshot.append(self._build_snapshot(child, values))

snapshot.meta_pvs = []
for pv in Collection.meta_pvs:
readback = Readback(
pv_name=readback.pv_name,
data=values[readback.pv_name]
)
snapshot.meta_pvs.append(readback)

if top_level:
return pv_list, data_list
return snapshot

def validate(self, entry: Entry):
"""
Expand Down
5 changes: 3 additions & 2 deletions superscore/control_layers/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""
import asyncio
import logging
from collections.abc import Iterable
from functools import singledispatchmethod
from typing import Any, Callable, Dict, List, Optional, Union

Expand Down Expand Up @@ -99,13 +100,13 @@ def _get_single(self, address: str) -> Any:
return asyncio.run(self._get_one(address))

@get.register
def _get_list(self, address: list) -> Any:
def _get_list(self, address: Iterable) -> Any:
"""Synchronously get a list of ``address``"""
async def gathered_coros():
coros = []
for p in address:
coros.append(self._get_one(p))
return await asyncio.gather(*coros)
return await asyncio.gather(*coros, return_exceptions=True)

return asyncio.run(gathered_coros())

Expand Down
40 changes: 40 additions & 0 deletions superscore/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,46 @@ def sample_database() -> Root:
return root


@pytest.fixture(scope='function')
def parameter_with_readback() -> Parameter:
"""
A simple setpoint-readback parameter pair
"""
readback = Parameter(
uuid="64772c61-c117-445b-b0c8-4c17fd1625d9",
pv_name="RBV",
description="A readback PV",
)
setpoint = Parameter(
uuid="b508344d-1fe9-473b-8d43-9499d0e8e23f",
pv_name="SET",
description="A setpoint PV",
readback=readback,
)
return setpoint


@pytest.fixture(scope='function')
def setpoint_with_readback() -> Setpoint:
"""
A simple setpoint-readback value pair
"""
readback = Readback(
uuid="7b30ddba-9fae-4691-988c-07384c29fe22",
pv_name="RBV",
description="A readback PV",
data=False,
)
setpoint = Setpoint(
uuid="418ed1ab-f1cf-4188-8f4c-ae7cbaf00e6c",
pv_name="SET",
description="A setpoint PV",
data=True,
readback=readback,
)
return setpoint


@pytest.fixture(scope='function')
def filestore_backend(tmp_path: Path) -> FilestoreBackend:
fp = Path(__file__).parent / 'db' / 'filestore.json'
Expand Down
52 changes: 50 additions & 2 deletions superscore/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

from superscore.backends.filestore import FilestoreBackend
from superscore.client import Client
from superscore.model import Root
from superscore.errors import CommunicationError
from superscore.model import Parameter, Readback, Root, Setpoint

from .conftest import MockTaskStatus

Expand Down Expand Up @@ -39,8 +40,23 @@ def sscore_cfg(xdg_config_patch: Path):
os.environ["XDG_CONFIG_HOME"] = xdg_cfg


def test_gather_data(mock_client, sample_database):
snapshot = sample_database.entries[3]
orig_pv = snapshot.children[0]
dup_pv = Setpoint(
uuid=orig_pv.uuid,
description=orig_pv.description,
pv_name=orig_pv.pv_name,
data="You shouldn't see this",
)
snapshot.children.append(dup_pv)
pvs, data_list = mock_client._gather_data(snapshot)
assert len(pvs) == len(data_list) == 3
assert data_list[pvs.index("MY:PREFIX:mtr1.ACCL")] == 2


@patch('superscore.control_layers.core.ControlLayer.put')
def test_apply(put_mock, mock_client: Client, sample_database: Root):
def test_apply(put_mock, mock_client: Client, sample_database: Root, setpoint_with_readback):
put_mock.return_value = MockTaskStatus()
snap = sample_database.entries[3]
mock_client.apply(snap)
Expand All @@ -53,6 +69,38 @@ def test_apply(put_mock, mock_client: Client, sample_database: Root):
mock_client.apply(snap, sequential=True)
assert put_mock.call_count == 3

put_mock.reset_mock()
mock_client.apply(setpoint_with_readback, sequential=True)
assert put_mock.call_count == 1


@patch('superscore.control_layers.core.ControlLayer._get_one')
def test_snap(
get_mock,
mock_client: Client,
sample_database: Root,
parameter_with_readback: Parameter
):
coll = sample_database.entries[2]
coll.children.append(parameter_with_readback)

get_mock.side_effect = range(5)
snapshot = mock_client.snap(coll)
assert get_mock.call_count == 5
assert all([snapshot.children[i].data == i for i in range(4)]) # children saved in order
setpoint = snapshot.children[-1]
assert isinstance(setpoint, Setpoint)
assert isinstance(setpoint.readback, Readback)
assert setpoint.readback.data == 4 # readback saved after setpoint


@patch('superscore.control_layers.core.ControlLayer._get_one')
def test_snap_exception(get_mock, mock_client: Client, sample_database: Root):
coll = sample_database.entries[2]
get_mock.side_effect = [0, 1, CommunicationError, 3, 4]
snapshot = mock_client.snap(coll)
assert snapshot.children[2].data is None


def test_from_cfg(sscore_cfg: str):
client = Client.from_config()
Expand Down
Loading

0 comments on commit 6d44dc9

Please sign in to comment.