diff --git a/examples/model.py b/examples/model.py index bc9ed85..7573bf3 100644 --- a/examples/model.py +++ b/examples/model.py @@ -12,8 +12,7 @@ def __init__(self, input_variables=None, output_variables=None): self.input_variables = input_variables self.output_variables = output_variables - def evaluate(self, input_variables): - input_variables = {input_var.name: input_var for input_var in input_variables} + def evaluate(self, input_variables: dict) -> dict: self.output_variables["output1"].value = np.random.uniform( input_variables["input1"].value, # lower dist bound input_variables["input2"].value, # upper dist bound @@ -22,4 +21,4 @@ def evaluate(self, input_variables): self.output_variables["output2"].value = input_variables["input1"].value self.output_variables["output3"].value = input_variables["input2"].value - return list(self.output_variables.values()) + return self.output_variables diff --git a/examples/read_only/model.py b/examples/read_only/model.py index 2028370..f138df9 100644 --- a/examples/read_only/model.py +++ b/examples/read_only/model.py @@ -14,7 +14,6 @@ def __init__(self, input_variables=None, output_variables=None): self.output_variables = output_variables def evaluate(self, input_variables): - input_variables = {input_var.name: input_var for input_var in input_variables} self.output_variables["output1"].value = np.random.uniform( input_variables["input1"].value, # lower dist bound input_variables["input2"].value, # upper dist bound @@ -23,4 +22,4 @@ def evaluate(self, input_variables): self.output_variables["output2"].value = input_variables["input1"].value self.output_variables["output3"].value = input_variables["input2"].value - return list(self.output_variables.values()) + return self.output_variables diff --git a/lume_epics/epics_ca_server.py b/lume_epics/epics_ca_server.py index 020ccc8..41e582e 100644 --- a/lume_epics/epics_ca_server.py +++ b/lume_epics/epics_ca_server.py @@ -270,8 +270,8 @@ def setup_server(self) -> None: if self.shutdown_event.is_set(): pass - for output in model_outputs.get("output_variables", []): - self._output_variables[output.name] = output + model_output_vars = model_outputs.get("output_variables", {}) + self._output_variables.update(model_output_vars) # differentiate between values to serve and not to serve to_serve = [] @@ -315,22 +315,23 @@ def setup_server(self) -> None: def update_pvs( self, - input_variables: List[InputVariable], - output_variables: List[OutputVariable], - ): + input_variables: Dict[str, InputVariable], + output_variables: Dict[str, OutputVariable], + ) -> None: """Update process variables over Channel Access. Args: - input_variables (List[InputVariable]): List of lume-epics output variables. + input_variables (Dict[str, InputVariable]): List of lume-epics output variables. - output_variables (List[OutputVariable]): List of lume-model output variables. + output_variables (Dict[str, OutputVariable]): List of lume-model output variables. """ - variables = input_variables + output_variables + variables = input_variables + variables.update(output_variables) # update variables if the driver is running if self._ca_driver is not None: - self._ca_driver.update_pvs(variables) + self._ca_driver.update_pvs(list(variables.values())) def run(self) -> None: """Start server process.""" @@ -339,8 +340,8 @@ def run(self) -> None: while not self.shutdown_event.is_set(): try: data = self._out_queue.get_nowait() - inputs = data.get("input_variables", []) - outputs = data.get("output_variables", []) + inputs = data.get("input_variables", {}) + outputs = data.get("output_variables", {}) self.update_pvs(inputs, outputs) except Empty: diff --git a/lume_epics/epics_pva_server.py b/lume_epics/epics_pva_server.py index f8537a6..27083c6 100644 --- a/lume_epics/epics_pva_server.py +++ b/lume_epics/epics_pva_server.py @@ -209,7 +209,8 @@ def setup_server(self) -> None: # if startup hasn't failed else: - self._output_variables.update(model_outputs.get("output_variables", {})) + model_output_vars = model_outputs.get("output_variables", {}) + self._output_variables.update(model_output_vars) variables = copy.deepcopy(self._input_variables) variables.update(self._output_variables) @@ -420,8 +421,10 @@ def update_pvs( output_variables (Dict[str, OutputVariable]): Dict of lume-model output variables. """ - variables = input_variables.update(output_variables) - for var_name, variable in variables.items(): + variables = input_variables + variables.update(output_variables) + + for variable in variables.values(): parent = self._field_to_parent_map.get(variable.name) if variable.name in self._input_variables and variable.is_constant: diff --git a/lume_epics/epics_server.py b/lume_epics/epics_server.py index b39d3d8..7863195 100644 --- a/lume_epics/epics_server.py +++ b/lume_epics/epics_server.py @@ -273,12 +273,12 @@ def run_comm_thread( # sync pva/ca if duplicated for protocol, queue in out_queues.items(): if protocol != data["protocol"]: - inputs = [ - self.input_variables[var] + inputs = { + var: self.input_variables[var] for var in data["vars"] if self._epics_config[var]["protocol"] in [protocol, "both"] - ] + } if len(inputs): queue.put({"input_variables": inputs}) @@ -289,13 +289,13 @@ def run_comm_thread( predicted_output = model.evaluate(model_input) for protocol, queue in out_queues.items(): - outputs = [ - var - for var_name, var in predicted_output.items() - if var_name in self._pva_fields - or self._epics_config[var_name]["protocol"] + outputs = { + var.name: var + for var in predicted_output.values() + if var.name in self._pva_fields + or self._epics_config[var.name]["protocol"] in [protocol, "both"] - ] + } queue.put({"output_variables": outputs}, timeout=0.1) except Exception as e: diff --git a/lume_epics/tests/launch_server.py b/lume_epics/tests/launch_server.py index afad154..799a30c 100644 --- a/lume_epics/tests/launch_server.py +++ b/lume_epics/tests/launch_server.py @@ -57,8 +57,7 @@ class TestModel(BaseModel): "output4": ArrayOutputVariable(name="output4"), } - def evaluate(self, input_variables): - self.input_variables = {variable.name: variable for variable in input_variables} + def evaluate(self, input_variables: dict): self.output_variables["output1"].value = ( self.input_variables["input1"].value * 2 @@ -111,7 +110,7 @@ def evaluate(self, input_variables): ) # return inputs * 2 - return list(self.output_variables.values()) + return self.output_variables if __name__ == "__main__":