Skip to content

Commit

Permalink
Merge pull request #108 from kathryn-baker/demo-fixes
Browse files Browse the repository at this point in the history
Demo fixes
  • Loading branch information
ChristopherMayes authored Jun 30, 2023
2 parents 5639f3f + 1464253 commit 8da4089
Show file tree
Hide file tree
Showing 9 changed files with 316 additions and 10 deletions.
63 changes: 63 additions & 0 deletions examples/files/california_config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
model:
kwargs:
model_file: examples/files/california_regression.pt
model_class: lume_model.pytorch.PyTorchModel
model_info: examples/files/california_model_info.json
output_format:
type: variable
requirements:
torch: 1.12

input_variables:
MedInc:
default: 3.7857346534729004
range:
- 0.4999000132083893
- 15.000100135803223
type: scalar
HouseAge:
default: 29.282135009765625
range:
- 1.0
- 52.0
type: scalar
AveRooms:
default: 5.4074907302856445
range:
- 0.8461538553237915
- 141.90908813476562
type: scalar
AveBedrms:
default: 1.1071722507476807
range:
- 0.375
- 34.06666564941406
type: scalar
Population:
default: 1437.0687255859375
range:
- 3.0
- 28566.0
type: scalar
AveOccup:
default: 3.035413980484009
range:
- 0.692307710647583
- 599.7142944335938
type: scalar
Latitude:
default: 35.28323745727539
range:
- 32.65999984741211
- 41.95000076293945
type: scalar
Longitude:
default: -119.11573028564453
range:
- -124.3499984741211
- -114.30999755859375
type: scalar

output_variables:
MedHouseVal:
type: scalar
38 changes: 38 additions & 0 deletions examples/files/california_epics_config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
input_variables:
MedInc:
pvname: MedInc
protocol: pva
serve: true
HouseAge:
pvname: HouseAge
protocol: pva
serve: true
AveRooms:
pvname: AveRooms
protocol: pva
serve: true
AveBedrms:
pvname: AveBedrms
protocol: pva
serve: true
Population:
pvname: Population
protocol: pva
serve: true
AveOccup:
pvname: AveOccup
protocol: pva
serve: true
Latitude:
pvname: Latitude
protocol: pva
serve: true
Longitude:
pvname: Longitude
protocol: pva
serve: true
output_variables:
MedHouseVal:
pvname: MedHouseVal
protocol: pva
serve: true
48 changes: 48 additions & 0 deletions examples/files/california_model_info.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
{
"train_input_mins": [
0.4999000132083893,
1.0,
0.8461538553237915,
0.375,
3.0,
0.692307710647583,
32.65999984741211,
-124.3499984741211
],
"train_input_maxs": [
15.000100135803223,
52.0,
141.90908813476562,
34.06666564941406,
28566.0,
599.7142944335938,
41.95000076293945,
-114.30999755859375
],
"model_in_list": [
"MedInc",
"HouseAge",
"AveRooms",
"AveBedrms",
"Population",
"AveOccup",
"Latitude",
"Longitude"
],
"model_out_list": [
"MedHouseVal"
],
"loc_in": {
"MedInc": 0,
"HouseAge": 1,
"AveRooms": 2,
"AveBedrms": 3,
"Population": 4,
"AveOccup": 5,
"Latitude": 6,
"Longitude": 7
},
"loc_out": {
"MedHouseVal": 0
}
}
28 changes: 28 additions & 0 deletions examples/files/california_normalization.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
{
"x_mean": [
3.7857348454995345,
29.282134699245518,
5.4074908062777505,
1.1071723692329019,
1437.0687339932165,
3.0354138620008504,
35.283234587868925,
-119.11572989538202
],
"x_scale": [
1.8973481323832475,
12.395369585636528,
2.8019995847982195,
0.5464532026823882,
1141.2672447074444,
5.462164977692798,
2.026003694993028,
1.8325257865272555
],
"y_mean": [
1.985587451669133
],
"y_scale": [
1.1218406322612422
]
}
Binary file added examples/files/california_regression.pt
Binary file not shown.
89 changes: 89 additions & 0 deletions examples/pytorch/california_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from bokeh.io import curdoc
from bokeh.layouts import column, row
from bokeh.models import Div, Button

from lume_epics.client.controller import Controller
from lume_model.utils import variables_from_yaml
from lume_epics.utils import config_from_yaml

from lume_epics.client.widgets.tables import ValueTable
from lume_epics.client.widgets.controls import build_sliders
from lume_epics.client.controller import Controller

# load the model and the variables from LUME model
with open("examples/files/california_config.yml", "r") as f:
input_variables, output_variables = variables_from_yaml(f)

# load the EPICS pv definitions
with open("examples/files/california_epics_config.yml", "r") as f:
epics_config = config_from_yaml(f)

# create controller from epics config
controller = Controller(epics_config)

# prepare as list for rendering
# define the variables that have range to make as sliders
sliding_variables = [
input_var
for input_var in input_variables.values()
if input_var.value_range[0] != input_var.value_range[1]
]
input_variables = list(input_variables.values())
output_variables = list(output_variables.values())

# define the plots we want to see - sliders for all input values
# and tables summarising the current state of the inputs and
# output values
sliders = build_sliders(sliding_variables, controller)
input_value_table = ValueTable(input_variables, controller)
output_value_table = ValueTable(output_variables, controller)


title_div = Div(
text=f"<b>California Housing Prediction: Last update {controller.last_update}</b>",
style={
"font-size": "150%",
"color": "#3881e8",
"text-align": "center",
"width": "100%",
},
)


def update_div_text():
global controller
title_div.text = (
f"<b>California Housing Prediction: Last update {controller.last_update}</b>"
)


def reset_slider_values():
for slider in sliders:
slider.reset()


slider_reset_button = Button(label="Reset")
slider_reset_button.on_click(reset_slider_values)

# render
curdoc().title = "California Housing Prediction"
curdoc().add_root(
column(
row(column(title_div, width=600)),
row(
column(
[slider_reset_button] + [slider.bokeh_slider for slider in sliders],
width=350,
),
column(input_value_table.table, output_value_table.table, width=350),
),
),
)

# add refresh callbacks to ensure that the values are updated
# curdoc().add_periodic_callback(image_plot.update, 1000)
for slider in sliders:
curdoc().add_periodic_callback(slider.update, 1000)
curdoc().add_periodic_callback(update_div_text, 1000)
curdoc().add_periodic_callback(input_value_table.update, 1000)
curdoc().add_periodic_callback(output_value_table.update, 1000)
42 changes: 42 additions & 0 deletions examples/pytorch/california_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from lume_epics.epics_server import Server
from lume_model.utils import model_from_yaml
from lume_epics.utils import config_from_yaml
from pathlib import Path
import json
from botorch.models.transforms.input import AffineInputTransform
import torch
from pprint import pprint

if __name__ == "__main__":
# load the model and the variables from LUME model
with open("examples/files/california_config.yml", "r") as f:
model_class, model_kwargs = model_from_yaml(f, load_model=False)

# load the EPICS pv definitions
with open("examples/files/california_epics_config.yml", "r") as f:
epics_config = config_from_yaml(f)

# load the transformers required for the model
with open("examples/files/california_normalization.json", "r") as f:
normalizations = json.load(f)

input_transformer = AffineInputTransform(
len(normalizations["x_mean"]),
coefficient=torch.tensor(normalizations["x_scale"]),
offset=torch.tensor(normalizations["x_mean"]),
)
output_transformer = AffineInputTransform(
len(normalizations["y_mean"]),
coefficient=torch.tensor(normalizations["y_scale"]),
offset=torch.tensor(normalizations["y_mean"]),
)

# update the model kwargs with the transformers
model_kwargs["input_transformers"] = [input_transformer]
model_kwargs["output_transformers"] = [output_transformer]

# start the EPICS server
server = Server(model_class, epics_config, model_kwargs=model_kwargs)

# monitor = False does not loop in main thread
server.start(monitor=True)
3 changes: 3 additions & 0 deletions lume_epics/client/widgets/controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ def update(self):
"""
self.bokeh_slider.value = self.controller.get_value(self.pvname)

def reset(self):
self.bokeh_slider.value = self.variable.default


def build_sliders(
Expand Down
15 changes: 5 additions & 10 deletions lume_epics/epics_pva_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,14 +117,16 @@ def update_pv(self, pvname: str, value: Union[np.ndarray, float]) -> None:
varname = self._pvname_to_varname_map[pvname]
model_variable = self._input_variables[varname]

# check for already cached variable
model_variable = self._cached_values.get(varname, model_variable)

if model_variable.variable_type == "image":
model_variable.x_min = value.attrib["x_min"]
model_variable.x_max = value.attrib["x_max"]
model_variable.y_min = value.attrib["y_min"]
model_variable.y_max = value.attrib["y_max"]

# check for already cached variable
model_variable = self._cached_values.get(varname, model_variable)
else:
model_variable.value = value

self._cached_values[varname] = model_variable

Expand Down Expand Up @@ -178,7 +180,6 @@ def setup_server(self) -> None:
].default

else:

if self._context is None:
self._context = Context("pva")

Expand All @@ -198,7 +199,6 @@ def setup_server(self) -> None:
self._initialize_model()
model_outputs = None
while not self.shutdown_event.is_set() and model_outputs is None:

try:
model_outputs = self._out_queue.get(timeout=0.1)
except Empty:
Expand All @@ -224,14 +224,11 @@ def setup_server(self) -> None:
self._structures = {}
self._structure_specs = {}
for variable_name, config in self._epics_config.items():

if config["serve"]:

fields = config.get("fields")
pvname = config.get("pvname")

if fields is not None:

spec = []
structure = {}

Expand All @@ -255,7 +252,6 @@ def setup_server(self) -> None:
spec.append((field, "v"))
table_rep = ()
for col in variable.columns:

# here we assume double type in tables...
table_rep += (col, "ad")

Expand Down Expand Up @@ -320,7 +316,6 @@ def setup_server(self) -> None:
elif variable.variable_type == "table":
table_rep = ()
for col in variable.columns:

# here we assume double type in tables...
table_rep += (col, "ad")

Expand Down

0 comments on commit 8da4089

Please sign in to comment.