Skip to content

Commit 0906986

Browse files
authored
Restore call convention compatibility in get_model (#304)
A bug surfaced where first time evaluation of a model fails due to the Model constructor throwing if the model does not exist. Looking deeper, we see that most calls to get_model expect a possible None response and check at the call site. Unfortunately we get the same WebserviceException class for a model not being found as we do a REST error or similar. This change is a stopgap mitigation to restore compatibility with the existing callers, and compromises by allowing the model version dependent behavior to continue passing on exceptions. In a future follow up we should settle on a convention and allow version checks to propagate failure while still giving the possibility for handling a service exception in the caller.
1 parent b21a46e commit 0906986

File tree

1 file changed

+30
-15
lines changed

1 file changed

+30
-15
lines changed

diabetes_regression/util/model_helper.py

+30-15
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88

99
def get_current_workspace() -> Workspace:
1010
"""
11-
Retrieves and returns the latest model from the workspace
12-
by its name and tag. Will not work when ran locally.
11+
Retrieves and returns the current workspace.
12+
Will not work when ran locally.
1313
1414
Parameters:
1515
None
@@ -30,8 +30,8 @@ def get_model(
3030
aml_workspace: Workspace = None
3131
) -> AMLModel:
3232
"""
33-
Retrieves and returns the latest model from the workspace
34-
by its name and (optional) tag.
33+
Retrieves and returns a model from the workspace by its name
34+
and (optional) tag.
3535
3636
Parameters:
3737
aml_workspace (Workspace): aml.core Workspace that the model lives.
@@ -40,25 +40,40 @@ def get_model(
4040
(optional) tag (str): the tag value & name the model was registered under.
4141
4242
Return:
43-
A single aml model from the workspace that matches the name and tag.
43+
A single aml model from the workspace that matches the name and tag, or
44+
None.
4445
"""
4546
if aml_workspace is None:
4647
print("No workspace defined - using current experiment workspace.")
4748
aml_workspace = get_current_workspace()
4849

49-
if tag_name is not None and tag_value is not None:
50+
tags = None
51+
if tag_name is not None or tag_value is not None:
52+
# Both a name and value must be specified to use tags.
53+
if tag_name is None or tag_value is None:
54+
raise ValueError(
55+
"model_tag_name and model_tag_value should both be supplied"
56+
+ "or excluded" # NOQA: E501
57+
)
58+
tags = [[tag_name, tag_value]]
59+
60+
model = None
61+
if model_version is not None:
62+
# TODO(tcare): Finding a specific version currently expects exceptions
63+
# to propagate in the case we can't find the model. This call may
64+
# result in a WebserviceException that may or may not be due to the
65+
# model not existing.
5066
model = AMLModel(
5167
aml_workspace,
5268
name=model_name,
5369
version=model_version,
54-
tags=[[tag_name, tag_value]])
55-
elif (tag_name is None and tag_value is not None) or (
56-
tag_value is None and tag_name is not None
57-
):
58-
raise ValueError(
59-
"model_tag_name and model_tag_value should both be supplied"
60-
+ "or excluded" # NOQA: E501
61-
)
70+
tags=tags)
6271
else:
63-
model = AMLModel(aml_workspace, name=model_name, version=model_version) # NOQA: E501
72+
models = AMLModel.list(
73+
aml_workspace, name=model_name, tags=tags, latest=True)
74+
if len(models) == 1:
75+
model = models[0]
76+
elif len(models) > 1:
77+
raise Exception("Expected only one model")
78+
6479
return model

0 commit comments

Comments
 (0)