Skip to content

Commit

Permalink
Merge pull request #21 from jacquelinegarrahan/model_load
Browse files Browse the repository at this point in the history
Model load enhancements
  • Loading branch information
jacquelinegarrahan authored Oct 21, 2020
2 parents eaf5916 + 3bbdc6f commit 075f818
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 22 deletions.
2 changes: 2 additions & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ The `KerasModel` packaged in the toolkit will be compatible with models saved us
An example of a model built using the functional API is given below:

```python
from tensorflow import keras
import tensorflow as tf

sepal_length_input = keras.Input(shape=(1,), name="SepalLength")
sepal_width_input = keras.Input(shape=(1,), name="SepalWidth")
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion examples/iris_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""
from lume_model.utils import model_from_yaml

with open("examples/files/iris_config.yaml", "r") as f:
with open("examples/files/iris_config.yml", "r") as f:
model = model_from_yaml(f)

model.random_evaluate()
32 changes: 18 additions & 14 deletions lume_model/keras/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@
logger = logging.getLogger(__name__)


# Registry of the toolkit's custom Keras layers, keyed by the class name that
# keras saves into the model file. Passed to load_model(custom_objects=...)
# so serialized models referencing these layers can be deserialized.
# NOTE(review): KerasModel.__init__ mutates this module-level dict in place
# via base_layers.update(custom_layers), so custom layers registered by one
# model instance remain registered for every later instance in the process —
# confirm that cross-instance sharing is intended.
base_layers = {
    "ScaleLayer": ScaleLayer,
    "UnscaleLayer": UnscaleLayer,
    "UnscaleImgLayer": UnscaleImgLayer,
}


class KerasModel(SurrogateModel):
"""
The KerasModel class is used for the loading and evaluation of online models. It is designed to
Expand All @@ -35,34 +42,31 @@ def __init__(
output_variables: Dict[str, OutputVariable],
input_format: dict = {},
output_format: dict = {},
custom_layers: dict = {},
) -> None:
"""Initializes the model and stores inputs/outputs.
Args:
model_file (str): Path to model file generated with keras.save()
input_variables (List[InputVariable]): list of model input variables
output_variables (List[OutputVariable]): list of model output variables
custom_layers
"""

# Save init
self.input_variables = input_variables
self.output_variables = output_variables
self._model_file = model_file
self._input_format = input_format
self._output_format = output_format
self._model_file = model_file

base_layers.update(custom_layers)

# load model in thread safe manner
self._thread_graph = tf.Graph()
with self._thread_graph.as_default():
self._model = load_model(
model_file,
custom_objects={
"ScaleLayer": ScaleLayer,
"UnscaleLayer": UnscaleLayer,
"UnscaleImgLayer": UnscaleImgLayer,
},
)
self._model = load_model(model_file, custom_objects=base_layers,)

def evaluate(self, input_variables: List[InputVariable]) -> List[OutputVariable]:
"""Evaluate model using new input variables.
Expand Down Expand Up @@ -183,13 +187,13 @@ def parse_output(self, model_output):
"""
output_dict = {}

if self._output_format["type"] == "softmax":
if not self._output_format.get("type") or self._output_format["type"] == "raw":
for idx, output_name in enumerate(self._model.output_names):
softmax_output = list(model_output[idx])
output_dict[output_name] = softmax_output.index(max(softmax_output))
output_dict[output_name] = model_output[idx]

if self._output_format["type"] == "raw":
elif self._output_format["type"] == "softmax":
for idx, output_name in enumerate(self._model.output_names):
output_dict[output_name] = model_output[idx]
softmax_output = list(model_output[idx])
output_dict[output_name] = softmax_output.index(max(softmax_output))

return output_dict
2 changes: 1 addition & 1 deletion lume_model/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ def rootdir():

@pytest.fixture
def config_file(rootdir):
    """Yield an open read handle to the iris test config file.

    Args:
        rootdir: fixture providing the test root directory path.

    Yields:
        An open text-mode file object for ``test_files/iris_config.yml``.
    """
    # Yield from inside a context manager so pytest closes the handle at
    # fixture teardown; a plain ``return open(...)`` leaks the descriptor.
    with open(f"{rootdir}/test_files/iris_config.yml", "r") as f:
        yield f
File renamed without changes.
37 changes: 31 additions & 6 deletions lume_model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,6 @@ def model_from_yaml(
lume_model_var = ScalarOutputVariable(**variable_config)

elif variable_config["type"] == "image":
variable_config["default"] = np.load(variable_config["default"])
variable_config["axis_labels"] = [
variable_config["x_label"],
variable_config["y_label"],
Expand Down Expand Up @@ -217,13 +216,25 @@ def model_from_yaml(

model_class = locate(config["model"]["model_class"])
if "kwargs" in config["model"]:
model_kwargs.update(config["model"]["kwargs"])

if "input_format" in config["model"]:
model_kwargs["input_format"] = config["model"]["input_format"]
if "custom_layers" in config["model"]["kwargs"]:
custom_layers = config["model"]["kwargs"]["custom_layers"]

# delete key to avoid overwrite
del config["model"]["kwargs"]["custom_layers"]
model_kwargs["custom_layers"] = {}

for layer, import_path in custom_layers.items():
layer_class = locate(import_path)

if layer_class is not None:
model_kwargs["custom_layers"][layer] = layer_class

else:
logger.exception("Layer class %s not found.", layer)
sys.exit()

if "output_format" in config["model"]:
model_kwargs["output_format"] = config["model"]["output_format"]
model_kwargs.update(config["model"]["kwargs"])

if model_class is None:
logger.exception("No model class found.")
Expand Down Expand Up @@ -267,6 +278,13 @@ def variables_from_yaml(config_file):

# build variable
if variable_config["type"] == "scalar":

if variable_config.get("is_constant"):
variable_config["range"] = [
variable_config["default"],
variable_config["default"],
]

lume_model_var = ScalarInputVariable(**variable_config)

elif variable_config["type"] == "image":
Expand All @@ -275,6 +293,13 @@ def variables_from_yaml(config_file):
variable_config["x_label"],
variable_config["y_label"],
]

if variable_config.get("is_constant"):
variable_config["range"] = [
np.amin(variable_config["default"]),
np.amax(variable_config["default"]),
]

lume_model_var = ImageInputVariable(**variable_config)

else:
Expand Down

0 comments on commit 075f818

Please sign in to comment.