Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Implement snapshot saving #57

Merged
merged 3 commits into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions docs/source/upcoming_release_notes/57-snapshot_saving.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
57 snapshot saving
#################

API Breaks
----------
- N/A

Features
--------
- ability to save Snapshots for a Collection

Bugfixes
--------
- N/A

Maintenance
-----------
- generalized Client._gather_data to work on any type of Entry
- made Client._gather_data iterative rather than recursive to simplify the conditional logic

Contributors
------------
- shilorigins
168 changes: 126 additions & 42 deletions superscore/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@
import logging
import os
from pathlib import Path
from typing import Any, Generator, List, Optional, Union
from typing import Any, Dict, Generator, List, Optional, Union
from uuid import UUID

from superscore.backends import get_backend
from superscore.backends.core import _Backend
from superscore.control_layers import ControlLayer
from superscore.control_layers.status import TaskStatus
from superscore.model import Entry, Setpoint, Snapshot
from superscore.errors import CommunicationError
from superscore.model import (Collection, Entry, Nestable, Parameter, Readback,
Setpoint, Snapshot)
from superscore.type_hints import AnyEpicsType
from superscore.utils import build_abs_path

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -163,6 +166,35 @@ def compare(self, entry_l: Entry, entry_r: Entry) -> Any:
"""Compare two entries. Should be of same type, and return a diff"""
raise NotImplementedError

def snap(self, entry: Collection) -> Snapshot:
"""
Asyncronously read data for all PVs under ``entry``, and store in a
Snapshot. PVs that can't be read will have an exception as their value.

Parameters
----------
entry : Collection
the Collection to save

Returns
-------
Snapshot
a Snapshot corresponding to the input Collection
"""
logger.debug(f"Saving Snapshot for Collection {entry.uuid}")
pvs, _ = self._gather_data(entry)
pvs.extend(Collection.meta_pvs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm realizing now that we're defining meta PVs as a class variable, which prevents us from properly representing capturing changes to the meta_pv list. More concretely, if we:

  • define one collection expecting a set of meta_pvs
  • change the (global) list of meta_pvs
  • take a snapshot with that same old collection

The new snapshot will pull the new meta_pvs, and might inadvertently be missing / duplicate PVs. This could be the behavior we want, but I think we need to find a way to make it more obvious if we do. Alternately we record meta_pvs as an instance variable (not class variable), and fill it with the global configuration on creation.

We also need to determine how we capture this in the Filestore backend, as currently this won't be recorded properly. This probably needs to be a separate dictionary / store external to the Root tree that gets stored

values = self.cl.get(pvs)
data = {}
for pv, value in zip(pvs, values):
if isinstance(value, CommunicationError):
logger.debug(f"Couldn't read value for {pv}, storing \"None\"")
data[pv] = None
else:
logger.debug(f"Storing {pv} = {value}")
data[pv] = value
shilorigins marked this conversation as resolved.
Show resolved Hide resolved
return self._build_snapshot(entry, data)

def apply(
self,
entry: Union[Setpoint, Snapshot],
Expand Down Expand Up @@ -194,7 +226,7 @@ def apply(

# Gather pv-value list and apply at once
status_list = []
pv_list, data_list = self._gather_data(entry)
pv_list, data_list = self._gather_data(entry, writable_only=True)
if sequential:
for pv, data in zip(pv_list, data_list):
logger.debug(f'Putting {pv} = {data}')
Expand All @@ -210,57 +242,109 @@ def apply(

def _gather_data(
self,
entry: Union[Setpoint, Snapshot, UUID],
pv_list: Optional[List[str]] = None,
data_list: Optional[List[Any]] = None
) -> Optional[tuple[List[str], List[Any]]]:
entry: Union[Entry, UUID],
writable_only: bool = False,
) -> tuple[List[str], Optional[List[Any]]]:
"""
Gather writable pv name - data pairs recursively.
If pv_list and data_list are provided, gathered data will be added to
these lists in-place. If both lists are omitted, this function will return
the two lists after gathering.

Queries the backend to fill any UUID values found.
Gather PV name - data pairs that are accessible from ``entry``. Queries
the backend to fill any UUIDs found.

Parameters
----------
entry : Union[Setpoint, Snapshot, UUID]
Entry to gather writable data from
pv_list : Optional[List[str]], optional
List of addresses to write data to, by default None
data_list : Optional[List[Any]], optional
List of data to write to addresses in ``pv_list``, by default None
entry : Union[Entry, UUID]
Entry to gather data from
writable_only : bool
If True, only include writable data e.g. omit Readbacks; by default False

Returns
-------
Optional[tuple[List[str], List[Any]]]
tuple[List[str], Optional[List[Any]]]
the filled pv_list and data_list
"""
top_level = False
if (pv_list is None) and (data_list is None):
pv_list = []
data_list = []
top_level = True
elif (pv_list is None) or (data_list is None):
raise ValueError(
"Arguments pv_list and data_list must either both be provided "
"or both omitted."
)
pv_list = []
data_list = []

seen = set()
q = [entry]
while len(q) > 0:
entry = q.pop()
uuid = entry if isinstance(entry, UUID) else entry.uuid
if uuid in seen:
continue
elif isinstance(entry, UUID):
entry = self.backend.get_entry(entry)
seen.add(entry.uuid)

if isinstance(entry, Nestable):
q.extend(reversed(entry.children)) # preserve execution order
elif isinstance(entry, Readback) and writable_only:
pass
else: # entry is Parameter, Setpoint, or Readback
pv_list.append(entry.pv_name)
if hasattr(entry, "data"):
data_list.append(entry.data)
if getattr(entry, "readback", None) is not None:
q.append(entry.readback)
return pv_list, data_list

def _build_snapshot(
self,
coll: Collection,
values: Dict[str, AnyEpicsType],
) -> Snapshot:
"""
Traverse a Collection, assembling a Snapshot using pre-fetched data
along the way

if isinstance(entry, Snapshot):
for child in entry.children:
self._gather_data(child, pv_list, data_list)
elif isinstance(entry, UUID):
child_entry = self.backend.get_entry(entry)
self._gather_data(child_entry, pv_list, data_list)
elif isinstance(entry, Setpoint):
pv_list.append(entry.pv_name)
data_list.append(entry.data)
Parameters
----------
coll : Collection
The collection being saved
values : Dict[str, AnyEpicsType]
A dictionary mapping PV names to pre-fetched values

# Readbacks are not writable, and are not gathered
Returns
-------
Snapshot
A Snapshot corresponding to the input Collection
"""
snapshot = Snapshot(
title=coll.title,
tags=coll.tags.copy(),
origin_collection=coll
)

for child in coll.children:
if isinstance(child, UUID):
child = self.backend.get(child)
if isinstance(child, Parameter):
if child.readback is not None:
readback = Readback(
shilorigins marked this conversation as resolved.
Show resolved Hide resolved
pv_name=child.readback.pv_name,
description=child.readback.description,
data=values[child.readback.pv_name]
)
else:
readback = None
setpoint = Setpoint(
pv_name=child.pv_name,
description=child.description,
data=values[child.pv_name],
readback=readback
)
snapshot.children.append(setpoint)
elif isinstance(child, Collection):
snapshot.append(self._build_snapshot(child, values))

snapshot.meta_pvs = []
for pv in Collection.meta_pvs:
readback = Readback(
pv_name=readback.pv_name,
data=values[readback.pv_name]
)
snapshot.meta_pvs.append(readback)

if top_level:
return pv_list, data_list
return snapshot

def validate(self, entry: Entry):
"""
Expand Down
5 changes: 3 additions & 2 deletions superscore/control_layers/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""
import asyncio
import logging
from collections.abc import Iterable
from functools import singledispatchmethod
from typing import Any, Callable, Dict, List, Optional, Union

Expand Down Expand Up @@ -99,13 +100,13 @@ def _get_single(self, address: str) -> Any:
return asyncio.run(self._get_one(address))

@get.register
def _get_list(self, address: list) -> Any:
def _get_list(self, address: Iterable) -> Any:
"""Synchronously get a list of ``address``"""
async def gathered_coros():
coros = []
for p in address:
coros.append(self._get_one(p))
return await asyncio.gather(*coros)
return await asyncio.gather(*coros, return_exceptions=True)

return asyncio.run(gathered_coros())

Expand Down
40 changes: 40 additions & 0 deletions superscore/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,46 @@ def sample_database() -> Root:
return root


@pytest.fixture(scope='function')
def parameter_with_readback() -> Parameter:
"""
A simple setpoint-readback parameter pair
"""
readback = Parameter(
uuid="64772c61-c117-445b-b0c8-4c17fd1625d9",
pv_name="RBV",
description="A readback PV",
)
setpoint = Parameter(
uuid="b508344d-1fe9-473b-8d43-9499d0e8e23f",
pv_name="SET",
description="A setpoint PV",
readback=readback,
)
return setpoint


@pytest.fixture(scope='function')
def setpoint_with_readback() -> Setpoint:
"""
A simple setpoint-readback value pair
"""
readback = Readback(
uuid="7b30ddba-9fae-4691-988c-07384c29fe22",
pv_name="RBV",
description="A readback PV",
data=False,
)
setpoint = Setpoint(
uuid="418ed1ab-f1cf-4188-8f4c-ae7cbaf00e6c",
pv_name="SET",
description="A setpoint PV",
data=True,
readback=readback,
)
return setpoint


@pytest.fixture(scope='function')
def filestore_backend(tmp_path: Path) -> FilestoreBackend:
fp = Path(__file__).parent / 'db' / 'filestore.json'
Expand Down
52 changes: 50 additions & 2 deletions superscore/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

from superscore.backends.filestore import FilestoreBackend
from superscore.client import Client
from superscore.model import Root
from superscore.errors import CommunicationError
from superscore.model import Parameter, Readback, Root, Setpoint

from .conftest import MockTaskStatus

Expand Down Expand Up @@ -39,8 +40,23 @@ def sscore_cfg(xdg_config_patch: Path):
os.environ["XDG_CONFIG_HOME"] = xdg_cfg


def test_gather_data(mock_client, sample_database):
snapshot = sample_database.entries[3]
orig_pv = snapshot.children[0]
dup_pv = Setpoint(
uuid=orig_pv.uuid,
description=orig_pv.description,
pv_name=orig_pv.pv_name,
data="You shouldn't see this",
)
snapshot.children.append(dup_pv)
pvs, data_list = mock_client._gather_data(snapshot)
assert len(pvs) == len(data_list) == 3
assert data_list[pvs.index("MY:PREFIX:mtr1.ACCL")] == 2


@patch('superscore.control_layers.core.ControlLayer.put')
def test_apply(put_mock, mock_client: Client, sample_database: Root):
def test_apply(put_mock, mock_client: Client, sample_database: Root, setpoint_with_readback):
put_mock.return_value = MockTaskStatus()
snap = sample_database.entries[3]
mock_client.apply(snap)
Expand All @@ -53,6 +69,38 @@ def test_apply(put_mock, mock_client: Client, sample_database: Root):
mock_client.apply(snap, sequential=True)
assert put_mock.call_count == 3

put_mock.reset_mock()
mock_client.apply(setpoint_with_readback, sequential=True)
assert put_mock.call_count == 1


@patch('superscore.control_layers.core.ControlLayer._get_one')
shilorigins marked this conversation as resolved.
Show resolved Hide resolved
def test_snap(
get_mock,
mock_client: Client,
sample_database: Root,
parameter_with_readback: Parameter
):
coll = sample_database.entries[2]
coll.children.append(parameter_with_readback)

get_mock.side_effect = range(5)
snapshot = mock_client.snap(coll)
assert get_mock.call_count == 5
assert all([snapshot.children[i].data == i for i in range(4)]) # children saved in order
setpoint = snapshot.children[-1]
assert isinstance(setpoint, Setpoint)
assert isinstance(setpoint.readback, Readback)
assert setpoint.readback.data == 4 # readback saved after setpoint


@patch('superscore.control_layers.core.ControlLayer._get_one')
def test_snap_exception(get_mock, mock_client: Client, sample_database: Root):
coll = sample_database.entries[2]
get_mock.side_effect = [0, 1, CommunicationError, 3, 4]
snapshot = mock_client.snap(coll)
assert snapshot.children[2].data is None


def test_from_cfg(sscore_cfg: str):
client = Client.from_config()
Expand Down
Loading