Skip to content

Commit

Permalink
Implement DriverInterface.
Browse files Browse the repository at this point in the history
  • Loading branch information
PeterKraus committed May 31, 2024
1 parent 673e5ec commit 6ed7723
Show file tree
Hide file tree
Showing 3 changed files with 188 additions and 57 deletions.
22 changes: 11 additions & 11 deletions src/tomato/daemon/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,14 @@ def tomato_driver() -> None:
return

kwargs = dict(settings=daemon.drvs[args.driver].settings)
driver = getattr(tomato.drivers, args.driver).Driver(**kwargs)
interface = getattr(tomato.drivers, args.driver).DriverInterface(**kwargs)

logger.info(f"registering devices in driver {args.driver!r}")
for dev in daemon.devs.values():
if dev.driver == args.driver:
for channel in dev.channels:
driver.dev_register(address=dev.address, channel=channel)
logger.debug(f"{driver.devmap=}")
interface.dev_register(address=dev.address, channel=channel)
logger.debug(f"{interface.devmap=}")

logger.info(f"driver {args.driver!r} bootstrapped successfully")

Expand All @@ -117,7 +117,7 @@ def tomato_driver() -> None:
port=port,
pid=pid,
connected_at=str(datetime.now(timezone.utc)),
settings=driver.settings,
settings=interface.settings,
)
req.send_pyobj(
dict(cmd="driver", params=params, sender=f"{__name__}.tomato_driver")
Expand Down Expand Up @@ -155,34 +155,34 @@ def tomato_driver() -> None:
data=dict(status=status, driver=args.driver),
)
elif msg["cmd"] == "settings":
driver.settings = msg["params"]
params["settings"] = driver.settings
interface.settings = msg["params"]
params["settings"] = interface.settings
ret = Reply(
success=True,
msg="settings received",
data=msg.get("params"),
)
elif msg["cmd"] == "dev_register":
driver.dev_register(**msg["params"])
interface.dev_register(**msg["params"])
ret = Reply(
success=True,
msg="device registered",
data=msg.get("params"),
)
elif msg["cmd"] == "task_status":
ret = driver.task_status(**msg["params"])
ret = interface.task_status(**msg["params"])
elif msg["cmd"] == "task_start":
ret = driver.task_start(**msg["params"])
ret = interface.task_start(**msg["params"])
elif msg["cmd"] == "task_data":
ret = driver.task_data(**msg["params"])
ret = interface.task_data(**msg["params"])
logger.debug(f"{ret=}")
rep.send_pyobj(ret)
if status == "stop":
break

logger.info(f"driver {args.driver!r} is beginning teardown")

driver.teardown()
interface.teardown()

logger.critical(f"driver {args.driver!r} is quitting")

Expand Down
95 changes: 49 additions & 46 deletions src/tomato/drivers/example_counter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from functools import wraps

from tomato.drivers.example_counter.counter import Counter
from tomato.models import Reply
from tomato.models import Reply, DriverInterface
from xarray import Dataset

logger = logging.getLogger(__name__)
Expand All @@ -24,8 +24,8 @@ def wrapper(self, **kwargs):
return wrapper


class Driver:
class Device:
class DriverInterface(DriverInterface):
class DeviceInterface:
dev: Counter
conn: Connection
proc: Process
Expand All @@ -35,49 +35,48 @@ def __init__(self):
self.conn, conn = Pipe()
self.proc = Process(target=self.dev.run_counter, args=(conn,))

devmap: dict[tuple, Device]
devmap: dict[tuple, DeviceInterface]
settings: dict

attrs: dict = dict(
delay=dict(type=float, rw=True),
time=dict(type=float, rw=True),
started=dict(type=bool, rw=True),
val=dict(type=int, rw=False),
)

tasks: dict = dict(
count=dict(
time=dict(type=float),
delay=dict(type=float),
),
random=dict(
time=dict(type=float),
delay=dict(type=float),
min=dict(type=float),
max=dict(type=float),
),
)

def __init__(self, settings=None):
self.devmap = {}
self.settings = settings if settings is not None else {}
def attrs(self, **kwargs) -> dict:
return dict(
delay=dict(type=float, rw=True),
time=dict(type=float, rw=True),
started=dict(type=bool, rw=True),
val=dict(type=int, rw=False),
)

def tasks(self, **kwargs) -> dict:
return dict(
count=dict(
time=dict(type=float),
delay=dict(type=float),
),
random=dict(
time=dict(type=float),
delay=dict(type=float),
min=dict(type=float),
max=dict(type=float),
),
)

def dev_register(self, address: str, channel: int, **kwargs):
key = (address, channel)
self.devmap[key] = self.Device()
self.devmap[key] = self.DeviceInterface()
self.devmap[key].proc.start()

@in_devmap
def dev_attr_set(self, attr: str, val: Any, address: str, channel: int, **kwargs):
def dev_set_attr(self, attr: str, val: Any, address: str, channel: int, **kwargs):
key = (address, channel)
if attr in self.attrs:
if self.attrs[attr]["rw"] and isinstance(val, self.attrs[attr]["type"]):
if attr in self.attrs():
params = self.attrs()[attr]
if params["rw"] and isinstance(val, params["type"]):
self.devmap[key].conn.send(("set", attr, val))

@in_devmap
def dev_attr_get(self, attr: str, address: str, channel: int, **kwargs):
def dev_get_attr(self, attr: str, address: str, channel: int, **kwargs):
key = (address, channel)
if attr in self.attrs:
if attr in self.attrs():
self.devmap[key].conn.send(("get", attr, None))
return self.devmap[key].conn.recv()

Expand All @@ -87,24 +86,16 @@ def dev_status(self, address: str, channel: int, **kwargs):
self.devmap[key].conn.send(("status", None, None))
return self.devmap[key].conn.recv()

@in_devmap
def task_status(self, address: str, channel: int):
started = self.dev_attr_get(attr="started", address=address, channel=channel)
if not started:
return Reply(success=True, msg="ready")
else:
return Reply(success=True, msg="running")

@in_devmap
def task_start(self, address: str, channel: int, task: str, **kwargs):
if task not in self.tasks:
if task not in self.tasks():
return Reply(
success=False,
msg=f"unknown task {task!r} requested",
data=self.tasks,
data=self.tasks(),
)

reqs = self.tasks[task]
reqs = self.tasks()[task]
for k, v in reqs.items():
if k not in kwargs and "default" not in v:
logger.critical("Somehow we're here")
Expand All @@ -123,14 +114,22 @@ def task_start(self, address: str, channel: int, task: str, **kwargs):
msg=f"parameter {k!r} is wrong type",
data=reqs,
)
self.dev_attr_set(attr=k, val=val, address=address, channel=channel)
self.dev_attr_set(attr="started", val=True, address=address, channel=channel)
self.dev_set_attr(attr=k, val=val, address=address, channel=channel)
self.dev_set_attr(attr="started", val=True, address=address, channel=channel)
return Reply(
success=True,
msg=f"task {task!r} started successfully",
data=kwargs,
)

@in_devmap
def task_status(self, address: str, channel: int):
started = self.dev_get_attr(attr="started", address=address, channel=channel)
if not started:
return Reply(success=True, msg="ready")
else:
return Reply(success=True, msg="running")

@in_devmap
def task_data(self, address: str, channel: int, **kwargs):
key = (address, channel)
Expand Down Expand Up @@ -171,3 +170,7 @@ def teardown(self):
logger.error(f"device {key!r} is still alive")
else:
logger.debug(f"device {key!r} successfully closed")


if __name__ == "__main__":
test = DriverInterface()
128 changes: 128 additions & 0 deletions src/tomato/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from pydantic import BaseModel, Field
from typing import Optional, Any, Mapping, Sequence, Literal
from pathlib import Path
from abc import ABCMeta, abstractmethod
import xarray as xr


class Driver(BaseModel):
Expand Down Expand Up @@ -68,3 +70,129 @@ class Reply(BaseModel):
success: bool
msg: str
data: Optional[Any] = None


class DriverInterface(metaclass=ABCMeta):
class DeviceInterface(metaclass=ABCMeta):
"""Class used to implement management of each individual device."""

pass

devmap: dict[tuple, DeviceInterface]
"""Map of registered devices, the tuple keys are components = (address, channel)"""

settings: dict[str, str]
"""A settings map to contain driver-specific settings such as `dllpath` for BioLogic"""

def __init__(self, settings=None):
self.devmap = {}
self.settings = settings if settings is not None else {}

def dev_register(self, address: str, channel: int, **kwargs: dict) -> None:
"""
Register a Device and its Component in this DriverInterface, creating a
:obj:`self.DeviceInterface` object in the :obj:`self.devmap` if necessary, or
updating existing channels in :obj:`self.devmap`.
"""
self.devmap[(address, channel)] = self.DeviceInterface(**kwargs)

def dev_teardown(self, address: str, channel: int, **kwargs: dict) -> None:
"""
Emergency stop function. Set the device into a documented, safe state.
The function is to be only called in case of critical errors, not as part of
normal operation.
"""
pass

@abstractmethod
def attrs(self, address: str, channel: int, **kwargs) -> dict:
"""
Function that returns all gettable and settable attributes, their rw status,
and whether they are to be printed in `dev_status`.
This is the "low level" control interface, intended for the device dashboard.
Example:
--------
return dict(
delay = dict(type=float, rw=True, status=False),
time = dict(type=float, rw=True, status=False),
started = dict(type=bool, rw=True, status=True),
val = dict(type=int, rw=False, status=True),
)
"""
pass

@abstractmethod
def dev_set_attr(self, attr: str, val: Any, address: str, channel: int, **kwargs):
"""Set the value of a read-write attr on a Component"""
pass

@abstractmethod
def dev_get_attr(self, attr: str, address: str, channel: int, **kwargs):
"""Get the value of any attr from a Component"""
pass

def dev_status(self, address: str, channel: int, **kwargs) -> dict[str, Any]:
"""Get a status report from a Component"""
ret = {}
for k, v in self.attrs(address=address, channel=channel, **kwargs).items():
if v.status:
ret[k] = self.dev_get_attr(
attr=k, address=address, channel=channel, **kwargs
)
return ret

# @abstractmethod
# def dev_get_data(self, address: str, channel: int, **kwargs):
# """Get a data report from a Component"""
# pass

@abstractmethod
def tasks(self, address: str, channel: int, **kwargs) -> dict:
"""
Function that returns all tasks that can be submitted to the Device. This
implements the driver specific language. Each task in tasks can only contain
elements present in :func:`self.attrs`.
Example:
return dict(
count = dict(time = dict(type=float), delay = dict(type=float),
)
"""
pass

@abstractmethod
def task_start(self, address: str, channel: int, task: str, **kwargs) -> None:
"""start a task on a (ready) component"""
pass

@abstractmethod
def task_status(self, address: str, channel: int) -> Literal["running", "ready"]:
"""check task status of the component"""
pass

@abstractmethod
def task_data(self, address: str, channel: int, **kwargs) -> xr.Dataset:
"""get any cached data for the current task on the component"""
pass

# @abstractmethod
# def task_stop(self, address: str, channel: int) -> xr.Dataset:
# """stops the current task, making the component ready and returning any data"""
# pass

@abstractmethod
def status(self) -> dict:
"""return status info of the driver"""
pass

@abstractmethod
def teardown(self) -> None:
"""
Stop all tasks, tear down all devices, close all processes.
Users can assume the devices are put in a safe state (valves closed, power off).
"""
pass

0 comments on commit 6ed7723

Please sign in to comment.