Skip to content

Commit

Permalink
Merge pull request #20 from jacquelinegarrahan/model_load
Browse files Browse the repository at this point in the history
BUG: Clean up the model class use in utils
  • Loading branch information
jacquelinegarrahan authored Oct 20, 2020
2 parents b0fb230 + 335772d commit eaf5916
Showing 1 changed file with 13 additions and 12 deletions.
25 changes: 13 additions & 12 deletions lume_model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,17 @@ def load_variables(variable_file: str) -> Tuple[dict]:
return variables["input_variables"], variables["output_variables"]


def model_from_yaml(config_file, model_class=None, model_kwargs=None):
def model_from_yaml(
config_file, model_class=None, model_kwargs: dict = None, load_model: bool = True
):
"""Creates model from yaml configuration. The model class for initialization may
either be passed to the function as a kwarg or defined in the config file. This function will
attempt to import the path specified in the yaml.
Args:
config_file: Config file
model_class: Class for initializing model
load_model (bool): If True, will return model. If False, will return model class and model_kwargs.
Returns:
model: Initialized model
Expand Down Expand Up @@ -212,7 +215,7 @@ def model_from_yaml(config_file, model_class=None, model_kwargs=None):
else:
logger.warning("Module not installed")

klass = locate(config["model"]["model_class"])
model_class = locate(config["model"]["model_class"])
if "kwargs" in config["model"]:
model_kwargs.update(config["model"]["kwargs"])

Expand All @@ -222,23 +225,21 @@ def model_from_yaml(config_file, model_class=None, model_kwargs=None):
if "output_format" in config["model"]:
model_kwargs["output_format"] = config["model"]["output_format"]

try:
model = klass(**model_kwargs)
except:
logger.exception(f"Unable to load model with args: {model_kwargs}")
sys.exit()

elif model_class is not None:
if model_kwargs:
model_kwargs.update((model_kwargs))
if model_class is None:
logger.exception("No model class found.")
sys.exit()

if load_model:
try:
model = model_class(**model_kwargs)
except:
logger.exception(f"Unable to load model with args: {model_kwargs}")
sys.exit()

return model
return model

else:
return model_class, model_kwargs


def variables_from_yaml(config_file):
Expand Down

0 comments on commit eaf5916

Please sign in to comment.