Skip to content

Commit

Permalink
MNT: Replace args with kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
jacquelinegarrahan committed Oct 15, 2020
1 parent 265ae20 commit 26ebec0
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 13 deletions.
3 changes: 1 addition & 2 deletions examples/files/iris_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@ model:
model_class: lume_model.keras.KerasModel
requirements:
tensorflow: 2.3.1
args:
kwargs:
model_file: examples/files/iris_model.h5
output_format:
type: softmax


input_variables:
SepalLength:
name: SepalLength
Expand Down
2 changes: 1 addition & 1 deletion lume_model/tests/test_files/iris_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ model:
model_class: lume_model.keras.KerasModel
requirements:
tensorflow: 2.3.1
args:
kwargs:
model_file: examples/files/iris_model.h5
output_format:
type: softmax
Expand Down
20 changes: 10 additions & 10 deletions lume_model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def model_from_yaml(config_file, model_class=None, model_kwargs=None):
sys.exit()

model = None
model_args = {
model_kwargs = {
"input_variables": input_variables,
"output_variables": output_variables,
}
Expand Down Expand Up @@ -210,29 +210,29 @@ def model_from_yaml(config_file, model_class=None, model_kwargs=None):
logger.warning("Module not installed")

klass = locate(config["model"]["model_class"])
if "args" in config["model"]:
model_args.update(config["model"]["args"])
if "kwargs" in config["model"]:
model_kwargs.update(config["model"]["kwargs"])

if "input_format" in config["model"]:
model_args["input_format"] = config["model"]["input_format"]
model_kwargs["input_format"] = config["model"]["input_format"]

if "output_format" in config["model"]:
model_args["output_format"] = config["model"]["output_format"]
model_kwargs["output_format"] = config["model"]["output_format"]

try:
model = klass(**model_args)
model = klass(**model_kwargs)
except:
logger.exception(f"Unable to load model with args: {model_args}")
logger.exception(f"Unable to load model with args: {model_kwargs}")
sys.exit()

elif model_class is not None:
if model_kwargs:
model_args.update((model_kwargs))
model_kwargs.update((model_kwargs))

try:
model = model_class(**model_args)
model = model_class(**model_kwargs)
except:
logger.exception(f"Unable to load model with args: {model_args}")
logger.exception(f"Unable to load model with args: {model_kwargs}")
sys.exit()

return model
Expand Down

0 comments on commit 26ebec0

Please sign in to comment.