diff --git a/lume_epics/client/controller.py b/lume_epics/client/controller.py index 2f72dd1..bea259f 100644 --- a/lume_epics/client/controller.py +++ b/lume_epics/client/controller.py @@ -196,9 +196,12 @@ def get(self, pvname: str) -> np.ndarray: """ self._set_up_pv_monitor(pvname) + pv = self._pv_registry.get(pvname, None) + if pv: return pv["value"] + return None def get_value(self, pvname): @@ -289,7 +292,7 @@ def get_array(self, pvname) -> dict: # context returns numpy array with WRITEABLE=False # copy to manipulate array below - array = self.get(array) + array = self.get(pvname) if array is not None: return array diff --git a/lume_epics/epics_ca_server.py b/lume_epics/epics_ca_server.py index 9d3a18d..690d824 100644 --- a/lume_epics/epics_ca_server.py +++ b/lume_epics/epics_ca_server.py @@ -4,7 +4,6 @@ import time import signal from typing import Dict - from lume_model.variables import Variable, InputVariable, OutputVariable import numpy as np from pcaspy import Driver, SimpleServer @@ -72,7 +71,10 @@ def __init__( self._in_queue = in_queue self._out_queue = out_queue self._providers = {} - self._running = running_indicator + self._running_indicator = running_indicator + + # cached pv values + self._cached_values = {} def update_pv(self, pvname, value) -> None: """Adds update to input process variable to the input queue. @@ -85,7 +87,13 @@ def update_pv(self, pvname, value) -> None: """ val = value pvname = pvname.replace(f"{self._prefix}:", "") - self._in_queue.put({"protocol": self.protocol, "pvname": pvname, "value": val}) + + self._cached_values.update({pvname: val}) + + # only update if not running + if not self._running_indicator.value: + self._in_queue.put({"protocol": self.protocol, "pvs": self._cached_values}) + self._cached_values = {} def setup_server(self) -> None: """Configure and start server. @@ -137,7 +145,6 @@ def run(self) -> None: """ self.setup_server() - self._running.value = True while not self.exit_event.is_set(): try: data = self._out_queue.get_nowait() @@ -149,7 +156,7 @@ def run(self) -> None: logger.debug("out queue empty") self.server_thread.stop() - self._running.value = False + # self.server_thread.join() logger.info("Channel access server stopped.") def shutdown(self): @@ -387,6 +394,10 @@ def write(self, pvname: str, value: Union[float, np.ndarray]) -> bool: ) return False + if value is None: + logger.debug(f"None value provided for {pvname}") + return False + if model_var_name in self.server._input_variables: if self.server._input_variables[model_var_name].is_constant: diff --git a/lume_epics/epics_pva_server.py b/lume_epics/epics_pva_server.py index 8401198..f401b55 100644 --- a/lume_epics/epics_pva_server.py +++ b/lume_epics/epics_pva_server.py @@ -41,8 +41,8 @@ def __init__( output_variables: List[OutputVariable], in_queue: multiprocessing.Queue, out_queue: multiprocessing.Queue, - running_indicator: multiprocessing.Value, conf_proxy: DictProxy, + running_indicator=multiprocessing.Value, *args, **kwargs, ) -> None: @@ -71,7 +71,9 @@ def __init__( self._out_queue = out_queue self._providers = {} self._conf = conf_proxy - self._running = running_indicator + self._running_indicator = running_indicator + + self._cached_values = {} def update_pv(self, pvname: str, value: Union[np.ndarray, float]) -> None: """Adds update to input process variable to the input queue. @@ -85,7 +87,13 @@ def update_pv(self, pvname: str, value: Union[np.ndarray, float]) -> None: # Hack for now to get the pickable value val = value.raw.value pvname = pvname.replace(f"{self._prefix}:", "") - self._in_queue.put({"protocol": self.protocol, "pvname": pvname, "value": val}) + + self._cached_values.update({"pvname": val}) + + # only update if not running + if not self._running_indicator: + self._in_queue.put({"protocol": self.protocol, "pvs": self._cached_values}) + self._cached_values = {} def setup_server(self) -> None: """Configure and start server. @@ -229,8 +237,8 @@ def update_pvs( logger.debug( "pvAccess array process variable %s updated.", variable.name ) - if variable.value_type == "str": - value = variable.value + if variable.value_type == "string": + value = list(variable.value) else: value = variable.value.view(NTNDArrayData) @@ -252,7 +260,6 @@ def run(self) -> None: """ self.setup_server() - self._running.value = True # mark running while not self.exit_event.is_set(): @@ -261,12 +268,18 @@ def run(self) -> None: inputs = data.get("input_variables", []) outputs = data.get("output_variables", []) self.update_pvs(inputs, outputs) + + # check cached values + if len(self._cached_values) > 0 and not self._running_indicator: + self._in_queue.put( + {"protocol": self.protocol, "pvs": self._cached_values} + ) + except Empty: time.sleep(0.01) logger.debug("out queue empty") self.pva_server.stop() - self._running.value = False logger.info("pvAccess server stopped.") def shutdown(self): @@ -308,7 +321,7 @@ def put(self, pv: SharedPV, op: ServOpWrap) -> None: """ # update input values and global input process variable state - if not self.is_constant: + if not self.is_constant and op.value() is not None: pv.post(op.value()) self.server.update_pv(pvname=self.pvname, value=op.value()) # mark server operation as complete diff --git a/lume_epics/epics_server.py b/lume_epics/epics_server.py index 5bdfabc..b5a7a49 100644 --- a/lume_epics/epics_server.py +++ b/lume_epics/epics_server.py @@ -13,28 +13,29 @@ from .epics_ca_server import CAServer logger = logging.getLogger(__name__) +multiprocessing.set_start_method("fork") class Server: """ Server for EPICS process variables. Can be optionally initialized with only - pvAccess or Channel Access protocols; but, defaults to serving over both. + pvAccess or Channel Access protocols; but, defaults to serving over both. Attributes: model (SurrogateModel): SurrogateModel class to be served input_variables (List[Variable]): List of lume-model variables passed to model. - ouput_variables (List[Variable]): List of lume-model variables to use as + ouput_variables (List[Variable]): List of lume-model variables to use as outputs. - ca_server (SimpleServer): Server class that interfaces between the Channel + ca_server (SimpleServer): Server class that interfaces between the Channel Access client and the driver. - ca_driver (CADriver): Class used by server to handle to process variable + ca_driver (CADriver): Class used by server to handle to process variable read/write requests. - pva_server (P4PServer): Threaded p4p server used for serving pvAccess + pva_server (P4PServer): Threaded p4p server used for serving pvAccess variables. exit_event (Event): Threading exit event marking server shutdown. @@ -79,9 +80,8 @@ def __init__( self.prefix = prefix self.protocols = protocols - model = model_class(**model_kwargs) - self.input_variables = model.input_variables - + self.model = model_class(**model_kwargs) + self.input_variables = self.model.input_variables # update inputs for starting value to be the default for variable in self.input_variables.values(): @@ -90,8 +90,8 @@ def __init__( model_input = list(self.input_variables.values()) - self.input_variables = model.input_variables - self.output_variables = model.evaluate(model_input) + self.input_variables = self.model.input_variables + self.output_variables = self.model.evaluate(model_input) self.output_variables = { variable.name: variable for variable in self.output_variables } @@ -103,20 +103,19 @@ def __init__( self.exit_event = Event() + self._running_indicator = multiprocessing.Value("b", False) + + # we use the running marker to make sure pvs + ca don't just keep adding queue elements self.comm_thread = threading.Thread( target=self.run_comm_thread, - args=(model_class,), kwargs={ "model_kwargs": model_kwargs, "in_queue": self.in_queue, - "out_queues": self.out_queues - } + "out_queues": self.out_queues, + "running_indicator": self._running_indicator, + }, ) - # track running servers - self._ca_running = multiprocessing.Value('b', False) - self._pva_running = multiprocessing.Value('b', False) - # initialize channel access server if "ca" in protocols: self.ca_process = CAServer( @@ -125,7 +124,7 @@ def __init__( output_variables=self.output_variables, in_queue=self.in_queue, out_queue=self.out_queues["ca"], - running_indicator=self._ca_running, + running_indicator=self._running_indicator, ) # initialize pvAccess server @@ -139,8 +138,8 @@ def __init__( output_variables=self.output_variables, in_queue=self.in_queue, out_queue=self.out_queues["pva"], - running_indicator = self._pva_running, - conf_proxy = self._pva_conf, + conf_proxy=self._pva_conf, + running_indicator=self._running_indicator, ) def __enter__(self): @@ -154,33 +153,51 @@ def __exit__(self, exc_type, exc_val, exc_tb): """ self.stop() - def run_comm_thread(self, model_class, model_kwargs={}, in_queue: multiprocessing.Queue=None, - out_queues: Dict[str, multiprocessing.Queue]=None): + def run_comm_thread( + self, + *, + running_indicator: multiprocessing.Value, + model_kwargs={}, + in_queue: multiprocessing.Queue = None, + out_queues: Dict[str, multiprocessing.Queue] = None, + ): """Handles communications between pvAccess server, Channel Access server, and model. - + Arguments: model_class: Model class to be executed. model_kwargs (dict): Dictionary of model keyword arguments. - in_queue (multiprocessing.Queue): + in_queue (multiprocessing.Queue): out_queues (Dict[str: multiprocessing.Queue]): Maps protocol to output assignment queue. + running_marker (multiprocessing.Value): multiprocessing marker for whether comm thread computing or not """ - model = model_class(**model_kwargs) + model = self.model while not self.exit_event.is_set(): try: + data = in_queue.get(timeout=0.1) - self.input_variables[data["pvname"]].value = data["value"] + + # mark running + running_indicator.value = True + + for pv in data["pvs"]: + self.input_variables[pv].value = data["pvs"][pv] + + # sync pva/ca for protocol, queue in out_queues.items(): if protocol == data["protocol"]: continue + queue.put( - {"input_variables": - [self.input_variables[data["pvname"]]] + { + "input_variables": [ + self.input_variables[pv] for pv in data["pvs"] + ] } ) @@ -188,10 +205,13 @@ def run_comm_thread(self, model_class, model_kwargs={}, in_queue: multiprocessin model_input = list(self.input_variables.values()) predicted_output = model.evaluate(model_input) for protocol, queue in out_queues.items(): - queue.put({"output_variables": predicted_output}, - timeout=0.1) + queue.put({"output_variables": predicted_output}, timeout=0.1) + + running_indicator.value = False + except Empty: continue + except Full: logger.error(f"{protocol} queue is full.") @@ -200,7 +220,7 @@ def run_comm_thread(self, model_class, model_kwargs={}, in_queue: multiprocessin def start(self, monitor: bool = True) -> None: """Starts server using set server protocol(s). - Args: + Args: monitor (bool): Indicates whether to run the server in the background or to continually monitor. If monitor = False, the server must be explicitly stopped using server.stop() @@ -232,8 +252,8 @@ def stop(self) -> None: if "ca" in self.protocols: self.ca_process.shutdown() - + if "pva" in self.protocols: self.pva_process.shutdown() - logger.info("Server is stopped.") \ No newline at end of file + logger.info("Server is stopped.")