diff --git a/lume_model/utils.py b/lume_model/utils.py index 0ee67c0..d9d3077 100644 --- a/lume_model/utils.py +++ b/lume_model/utils.py @@ -91,7 +91,9 @@ def load_variables(variable_file: str) -> Tuple[dict]: return variables["input_variables"], variables["output_variables"] -def model_from_yaml(config_file, model_class=None, model_kwargs=None): +def model_from_yaml( + config_file, model_class=None, model_kwargs: dict = None, load_model: bool = True +): """Creates model from yaml configuration. The model class for initialization may either be passed to the function as a kwarg or defined in the config file. This function will attempt to import the path specified in the yaml. @@ -99,6 +101,7 @@ def model_from_yaml(config_file, model_class=None, model_kwargs=None): Args: config_file: Config file model_class: Class for initializing model + load_model (bool): If True, will return model. If False, will return model class and model_kwargs. Returns: model: Initialized model @@ -212,7 +215,7 @@ def model_from_yaml(config_file, model_class=None, model_kwargs=None): else: logger.warning("Module not installed") - klass = locate(config["model"]["model_class"]) + model_class = locate(config["model"]["model_class"]) if "kwargs" in config["model"]: model_kwargs.update(config["model"]["kwargs"]) @@ -222,23 +225,21 @@ def model_from_yaml(config_file, model_class=None, model_kwargs=None): if "output_format" in config["model"]: model_kwargs["output_format"] = config["model"]["output_format"] - try: - model = klass(**model_kwargs) - except: - logger.exception(f"Unable to load model with args: {model_kwargs}") - sys.exit() - - elif model_class is not None: - if model_kwargs: - model_kwargs.update((model_kwargs)) + if model_class is None: + logger.exception("No model class found.") + sys.exit() + if load_model: try: model = model_class(**model_kwargs) except: logger.exception(f"Unable to load model with args: {model_kwargs}") sys.exit() - return model + return model + + else: + return model_class, model_kwargs def variables_from_yaml(config_file):