Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Get the latest artifact of a model by creation date instead of version name #3343

Merged
45 changes: 18 additions & 27 deletions src/zenml/models/v2/core/model_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, field_validator

from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH
from zenml.enums import ModelStages
from zenml.enums import ArtifactType, ModelStages
from zenml.metadata.metadata_types import MetadataType
from zenml.models.v2.base.filter import AnyQuery
from zenml.models.v2.base.page import Page
Expand Down Expand Up @@ -426,32 +426,36 @@ def pipeline_runs(self) -> Dict[str, "PipelineRunResponse"]:

def _get_linked_object(
self,
collection: Dict[str, Dict[str, UUID]],
name: str,
version: Optional[str] = None,
type: Optional[ArtifactType] = None,
) -> Optional["ArtifactVersionResponse"]:
"""Get the artifact linked to this model version given type.

Args:
collection: The collection to search in (one of
self.model_artifact_ids, self.data_artifact_ids,
self.deployment_artifact_ids)
name: The name of the artifact to retrieve.
version: The version of the artifact to retrieve (None for
latest/non-versioned)
type: The type of the artifact to filter by.

Returns:
Specific version of an artifact from collection or None
"""
from zenml.client import Client

client = Client()
artifact_versions = Client().list_artifact_versions(
sort_by="desc:created",
size=1,
name=name,
version=version,
model_version_id=self.id,
type=type,
hydrate=True,
)

if name not in collection:
if not artifact_versions.items:
return None
if version is None:
version = max(collection[name].keys())
return client.get_artifact_version(collection[name][version])
return artifact_versions.items[0]

def get_artifact(
self,
Expand All @@ -468,12 +472,7 @@ def get_artifact(
Returns:
Specific version of an artifact or None
"""
all_artifact_ids = {
**self.model_artifact_ids,
**self.data_artifact_ids,
**self.deployment_artifact_ids,
}
return self._get_linked_object(all_artifact_ids, name, version)
return self._get_linked_object(name, version)

def get_model_artifact(
self,
Expand All @@ -490,7 +489,7 @@ def get_model_artifact(
Returns:
Specific version of the model artifact or None
"""
return self._get_linked_object(self.model_artifact_ids, name, version)
return self._get_linked_object(name, version, ArtifactType.MODEL)

def get_data_artifact(
self,
Expand All @@ -507,11 +506,7 @@ def get_data_artifact(
Returns:
Specific version of the data artifact or None
"""
return self._get_linked_object(
self.data_artifact_ids,
name,
version,
)
return self._get_linked_object(name, version, ArtifactType.DATA)

def get_deployment_artifact(
self,
Expand All @@ -528,11 +523,7 @@ def get_deployment_artifact(
Returns:
Specific version of the deployment artifact or None
"""
return self._get_linked_object(
self.deployment_artifact_ids,
name,
version,
)
return self._get_linked_object(name, version, ArtifactType.SERVICE)

def get_pipeline_run(self, name: str) -> "PipelineRunResponse":
"""Get pipeline run linked to this version.
Expand Down
181 changes: 80 additions & 101 deletions tests/unit/models/test_model_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,9 @@
# permissions and limitations under the License.

from datetime import datetime
from unittest.mock import patch
from uuid import uuid4

import pytest

from tests.unit.steps.test_external_artifact import MockZenmlClient
from zenml.enums import ArtifactType
from zenml.models import (
ModelResponse,
ModelResponseBody,
Expand All @@ -28,106 +25,88 @@
ModelVersionResponseMetadata,
)

ARTIFACT_VERSION_IDS = [uuid4(), uuid4()]

def test_model_version_response_artifact_fetching(
clean_client, mocker, sample_workspace_model
):
"""Test artifact fetching from a model version response."""
mock_list_artifact_versions = mocker.patch.object(
type(clean_client),
"list_artifact_versions",
)

@pytest.mark.parametrize(
"artifact_object_ids,query_name,query_version,expected",
(
(
{"artifact": {"1": ARTIFACT_VERSION_IDS[0]}},
"artifact",
None,
ARTIFACT_VERSION_IDS[0],
model = ModelResponse(
id=uuid4(),
name="model",
body=ModelResponseBody(
created=datetime.now(),
updated=datetime.now(),
tags=[],
),
(
{
"artifact": {
"1": ARTIFACT_VERSION_IDS[0],
"2": ARTIFACT_VERSION_IDS[1],
}
},
"artifact",
None,
ARTIFACT_VERSION_IDS[1],
metadata=ModelResponseMetadata(
workspace=sample_workspace_model,
),
(
{
"artifact": {
"1": ARTIFACT_VERSION_IDS[0],
"2": ARTIFACT_VERSION_IDS[1],
}
},
"artifact",
"1",
ARTIFACT_VERSION_IDS[0],
)
mv = ModelVersionResponse(
id=uuid4(),
name="foo",
body=ModelVersionResponseBody(
created=datetime.now(),
updated=datetime.now(),
model=model,
number=-1,
),
(
{},
"artifact",
None,
None,
metadata=ModelVersionResponseMetadata(
workspace=sample_workspace_model,
),
),
ids=[
"No collision",
"Latest version",
"Specific version",
"Not found",
],
)
def test_getters(
artifact_object_ids,
query_name,
query_version,
expected,
sample_workspace_model,
):
"""Test that the getters work as expected."""
with patch.dict(
"sys.modules",
{
"zenml.client": MockZenmlClient,
},
):
model = ModelResponse(
id=uuid4(),
name="model",
body=ModelResponseBody(
created=datetime.now(),
updated=datetime.now(),
tags=[],
),
metadata=ModelResponseMetadata(
workspace=sample_workspace_model,
),
)
mv = ModelVersionResponse(
id=uuid4(),
name="foo",
body=ModelVersionResponseBody(
created=datetime.now(),
updated=datetime.now(),
model=model,
number=-1,
data_artifact_ids=artifact_object_ids,
),
metadata=ModelVersionResponseMetadata(
workspace=sample_workspace_model,
),
)
if expected != "RuntimeError":
got = mv.get_data_artifact(
name=query_name,
version=query_version,
)
if got is not None:
assert got.id == expected
else:
assert expected is None
else:
with pytest.raises(RuntimeError):
mv.get_data_artifact(
name=query_name,
version=query_version,
)
)

artifact_name = "artifact_name"
version_name = "version_name"

mv.get_artifact(artifact_name, version_name)
mock_list_artifact_versions.assert_called_once_with(
sort_by="desc:created",
size=1,
name=artifact_name,
version=version_name,
model_version_id=mv.id,
type=None,
hydrate=True,
)
mock_list_artifact_versions.reset_mock()

mv.get_data_artifact(artifact_name, version_name)
mock_list_artifact_versions.assert_called_once_with(
sort_by="desc:created",
size=1,
name=artifact_name,
version=version_name,
model_version_id=mv.id,
type=ArtifactType.DATA,
hydrate=True,
)
mock_list_artifact_versions.reset_mock()

mv.get_model_artifact(artifact_name, version_name)
mock_list_artifact_versions.assert_called_once_with(
sort_by="desc:created",
size=1,
name=artifact_name,
version=version_name,
model_version_id=mv.id,
type=ArtifactType.MODEL,
hydrate=True,
)
mock_list_artifact_versions.reset_mock()

mv.get_deployment_artifact(artifact_name, version_name)
mock_list_artifact_versions.assert_called_once_with(
sort_by="desc:created",
size=1,
name=artifact_name,
version=version_name,
model_version_id=mv.id,
type=ArtifactType.SERVICE,
hydrate=True,
)
Loading