Skip to content

Commit

Permalink
Merge pull request #99 from jacquelinegarrahan/v1.6-bug-fixes
Browse files Browse the repository at this point in the history
V1.6 bug fixes
  • Loading branch information
jacquelinegarrahan committed May 31, 2022
2 parents 67fd64f + 69809ef commit 0102c98
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 31 deletions.
5 changes: 2 additions & 3 deletions examples/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ def __init__(self, input_variables=None, output_variables=None):
self.input_variables = input_variables
self.output_variables = output_variables

def evaluate(self, input_variables):
input_variables = {input_var.name: input_var for input_var in input_variables}
def evaluate(self, input_variables: dict) -> dict:
self.output_variables["output1"].value = np.random.uniform(
input_variables["input1"].value, # lower dist bound
input_variables["input2"].value, # upper dist bound
Expand All @@ -22,4 +21,4 @@ def evaluate(self, input_variables):
self.output_variables["output2"].value = input_variables["input1"].value
self.output_variables["output3"].value = input_variables["input2"].value

return list(self.output_variables.values())
return self.output_variables
3 changes: 1 addition & 2 deletions examples/read_only/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ def __init__(self, input_variables=None, output_variables=None):
self.output_variables = output_variables

def evaluate(self, input_variables):
input_variables = {input_var.name: input_var for input_var in input_variables}
self.output_variables["output1"].value = np.random.uniform(
input_variables["input1"].value, # lower dist bound
input_variables["input2"].value, # upper dist bound
Expand All @@ -23,4 +22,4 @@ def evaluate(self, input_variables):
self.output_variables["output2"].value = input_variables["input1"].value
self.output_variables["output3"].value = input_variables["input2"].value

return list(self.output_variables.values())
return self.output_variables
23 changes: 12 additions & 11 deletions lume_epics/epics_ca_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,8 @@ def setup_server(self) -> None:
if self.shutdown_event.is_set():
pass

for output in model_outputs.get("output_variables", []):
self._output_variables[output.name] = output
model_output_vars = model_outputs.get("output_variables", {})
self._output_variables.update(model_output_vars)

# differentiate between values to serve and not to serve
to_serve = []
Expand Down Expand Up @@ -315,22 +315,23 @@ def setup_server(self) -> None:

def update_pvs(
self,
input_variables: List[InputVariable],
output_variables: List[OutputVariable],
):
input_variables: Dict[str, InputVariable],
output_variables: Dict[str, OutputVariable],
) -> None:
"""Update process variables over Channel Access.
Args:
input_variables (List[InputVariable]): List of lume-epics output variables.
input_variables (Dict[str, InputVariable]): List of lume-epics output variables.
output_variables (List[OutputVariable]): List of lume-model output variables.
output_variables (Dict[str, OutputVariable]): List of lume-model output variables.
"""
variables = input_variables + output_variables
variables = input_variables
variables.update(output_variables)

# update variables if the driver is running
if self._ca_driver is not None:
self._ca_driver.update_pvs(variables)
self._ca_driver.update_pvs(list(variables.values()))

def run(self) -> None:
"""Start server process."""
Expand All @@ -339,8 +340,8 @@ def run(self) -> None:
while not self.shutdown_event.is_set():
try:
data = self._out_queue.get_nowait()
inputs = data.get("input_variables", [])
outputs = data.get("output_variables", [])
inputs = data.get("input_variables", {})
outputs = data.get("output_variables", {})
self.update_pvs(inputs, outputs)

except Empty:
Expand Down
9 changes: 6 additions & 3 deletions lume_epics/epics_pva_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,8 @@ def setup_server(self) -> None:

# if startup hasn't failed
else:
self._output_variables.update(model_outputs.get("output_variables", {}))
model_output_vars = model_outputs.get("output_variables", {})
self._output_variables.update(model_output_vars)

variables = copy.deepcopy(self._input_variables)
variables.update(self._output_variables)
Expand Down Expand Up @@ -420,8 +421,10 @@ def update_pvs(
output_variables (Dict[str, OutputVariable]): Dict of lume-model output variables.
"""
variables = input_variables.update(output_variables)
for var_name, variable in variables.items():
variables = input_variables
variables.update(output_variables)

for variable in variables.values():
parent = self._field_to_parent_map.get(variable.name)

if variable.name in self._input_variables and variable.is_constant:
Expand Down
18 changes: 9 additions & 9 deletions lume_epics/epics_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,12 +273,12 @@ def run_comm_thread(
# sync pva/ca if duplicated
for protocol, queue in out_queues.items():
if protocol != data["protocol"]:
inputs = [
self.input_variables[var]
inputs = {
var: self.input_variables[var]
for var in data["vars"]
if self._epics_config[var]["protocol"]
in [protocol, "both"]
]
}

if len(inputs):
queue.put({"input_variables": inputs})
Expand All @@ -289,13 +289,13 @@ def run_comm_thread(
predicted_output = model.evaluate(model_input)

for protocol, queue in out_queues.items():
outputs = [
var
for var_name, var in predicted_output.items()
if var_name in self._pva_fields
or self._epics_config[var_name]["protocol"]
outputs = {
var.name: var
for var in predicted_output.values()
if var.name in self._pva_fields
or self._epics_config[var.name]["protocol"]
in [protocol, "both"]
]
}
queue.put({"output_variables": outputs}, timeout=0.1)

except Exception as e:
Expand Down
5 changes: 2 additions & 3 deletions lume_epics/tests/launch_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ class TestModel(BaseModel):
"output4": ArrayOutputVariable(name="output4"),
}

def evaluate(self, input_variables):
self.input_variables = {variable.name: variable for variable in input_variables}
def evaluate(self, input_variables: dict):

self.output_variables["output1"].value = (
self.input_variables["input1"].value * 2
Expand Down Expand Up @@ -111,7 +110,7 @@ def evaluate(self, input_variables):
)

# return inputs * 2
return list(self.output_variables.values())
return self.output_variables


if __name__ == "__main__":
Expand Down

0 comments on commit 0102c98

Please sign in to comment.