diff --git a/docs/index.md b/docs/index.md index 1f2cb41..f147242 100644 --- a/docs/index.md +++ b/docs/index.md @@ -174,6 +174,8 @@ The `KerasModel` packaged in the toolkit will be compatible with models saved us An example of a model built using the functional API is given below: ```python +from tensorflow import keras +import tensorflow as tf sepal_length_input = keras.Input(shape=(1,), name="SepalLength") sepal_width_input = keras.Input(shape=(1,), name="SepalWidth") diff --git a/examples/files/iris_config.yaml b/examples/files/iris_config.yml similarity index 100% rename from examples/files/iris_config.yaml rename to examples/files/iris_config.yml diff --git a/examples/iris_model.py b/examples/iris_model.py index bba66f9..b2f731c 100644 --- a/examples/iris_model.py +++ b/examples/iris_model.py @@ -3,7 +3,7 @@ """ from lume_model.utils import model_from_yaml -with open("examples/files/iris_config.yaml", "r") as f: +with open("examples/files/iris_config.yml", "r") as f: model = model_from_yaml(f) model.random_evaluate() diff --git a/lume_model/keras/__init__.py b/lume_model/keras/__init__.py index ec5e964..691de3a 100644 --- a/lume_model/keras/__init__.py +++ b/lume_model/keras/__init__.py @@ -13,6 +13,13 @@ logger = logging.getLogger(__name__) +base_layers = { + "ScaleLayer": ScaleLayer, + "UnscaleLayer": UnscaleLayer, + "UnscaleImgLayer": UnscaleImgLayer, +} + + class KerasModel(SurrogateModel): """ The KerasModel class is used for the loading and evaluation of online models. It is designed to @@ -35,6 +42,7 @@ def __init__( output_variables: Dict[str, OutputVariable], input_format: dict = {}, output_format: dict = {}, + custom_layers: dict = {}, ) -> None: """Initializes the model and stores inputs/outputs. @@ -42,27 +50,23 @@ def __init__( model_file (str): Path to model file generated with keras.save() input_variables (List[InputVariable]): list of model input variables output_variables (List[OutputVariable]): list of model output variables + custom_layers """ # Save init self.input_variables = input_variables self.output_variables = output_variables + self._model_file = model_file self._input_format = input_format self._output_format = output_format - self._model_file = model_file + + base_layers.update(custom_layers) # load model in thread safe manner self._thread_graph = tf.Graph() with self._thread_graph.as_default(): - self._model = load_model( - model_file, - custom_objects={ - "ScaleLayer": ScaleLayer, - "UnscaleLayer": UnscaleLayer, - "UnscaleImgLayer": UnscaleImgLayer, - }, - ) + self._model = load_model(model_file, custom_objects=base_layers,) def evaluate(self, input_variables: List[InputVariable]) -> List[OutputVariable]: """Evaluate model using new input variables. @@ -183,13 +187,13 @@ def parse_output(self, model_output): """ output_dict = {} - if self._output_format["type"] == "softmax": + if not self._output_format.get("type") or self._output_format["type"] == "raw": for idx, output_name in enumerate(self._model.output_names): - softmax_output = list(model_output[idx]) - output_dict[output_name] = softmax_output.index(max(softmax_output)) + output_dict[output_name] = model_output[idx] - if self._output_format["type"] == "raw": + elif self._output_format["type"] == "softmax": for idx, output_name in enumerate(self._model.output_names): - output_dict[output_name] = model_output[idx] + softmax_output = list(model_output[idx]) + output_dict[output_name] = softmax_output.index(max(softmax_output)) return output_dict diff --git a/lume_model/tests/conftest.py b/lume_model/tests/conftest.py index 8f22f16..76f0eb3 100644 --- a/lume_model/tests/conftest.py +++ b/lume_model/tests/conftest.py @@ -9,4 +9,4 @@ def rootdir(): @pytest.fixture def config_file(rootdir): - return open(f"{rootdir}/test_files/iris_config.yaml", "r") + return open(f"{rootdir}/test_files/iris_config.yml", "r") diff --git a/lume_model/tests/test_files/iris_config.yaml b/lume_model/tests/test_files/iris_config.yml similarity index 100% rename from lume_model/tests/test_files/iris_config.yaml rename to lume_model/tests/test_files/iris_config.yml diff --git a/lume_model/utils.py b/lume_model/utils.py index d9d3077..14f0d73 100644 --- a/lume_model/utils.py +++ b/lume_model/utils.py @@ -157,7 +157,6 @@ def model_from_yaml( lume_model_var = ScalarOutputVariable(**variable_config) elif variable_config["type"] == "image": - variable_config["default"] = np.load(variable_config["default"]) variable_config["axis_labels"] = [ variable_config["x_label"], variable_config["y_label"], @@ -217,13 +216,25 @@ def model_from_yaml( model_class = locate(config["model"]["model_class"]) if "kwargs" in config["model"]: - model_kwargs.update(config["model"]["kwargs"]) - if "input_format" in config["model"]: - model_kwargs["input_format"] = config["model"]["input_format"] + if "custom_layers" in config["model"]["kwargs"]: + custom_layers = config["model"]["kwargs"]["custom_layers"] + + # delete key to avoid overwrite + del config["model"]["kwargs"]["custom_layers"] + model_kwargs["custom_layers"] = {} + + for layer, import_path in custom_layers.items(): + layer_class = locate(import_path) + + if layer_class is not None: + model_kwargs["custom_layers"][layer] = layer_class + + else: + logger.exception("Layer class %s not found.", layer) + sys.exit() - if "output_format" in config["model"]: - model_kwargs["output_format"] = config["model"]["output_format"] + model_kwargs.update(config["model"]["kwargs"]) if model_class is None: logger.exception("No model class found.") @@ -267,6 +278,13 @@ def variables_from_yaml(config_file): # build variable if variable_config["type"] == "scalar": + + if variable_config.get("is_constant"): + variable_config["range"] = [ + variable_config["default"], + variable_config["default"], + ] + lume_model_var = ScalarInputVariable(**variable_config) elif variable_config["type"] == "image": @@ -275,6 +293,13 @@ def variables_from_yaml(config_file): variable_config["x_label"], variable_config["y_label"], ] + + if variable_config.get("is_constant"): + variable_config["range"] = [ + np.amin(variable_config["default"]), + np.amax(variable_config["default"]), + ] + lume_model_var = ImageInputVariable(**variable_config) else: