diff --git a/mlem/core/metadata.py b/mlem/core/metadata.py index ea7dd3ff..4d1a56cb 100644 --- a/mlem/core/metadata.py +++ b/mlem/core/metadata.py @@ -2,24 +2,41 @@ Functions to work with metadata: saving, loading, searching for MLEM object by given path. """ +import logging import posixpath -from typing import Any, Optional, Type, TypeVar, Union, overload +from typing import Any, Dict, List, Optional, Type, TypeVar, Union, overload from fsspec import AbstractFileSystem from typing_extensions import Literal -from mlem.core.errors import HookNotFound, MlemObjectNotFound +from mlem.core.errors import HookNotFound, MlemObjectNotFound, MlemRootNotFound from mlem.core.meta_io import Location, UriResolver, get_meta_path from mlem.core.objects import DatasetMeta, MlemMeta, ModelMeta, find_object from mlem.utils.path import make_posix +logger = logging.getLogger(__name__) -def get_object_metadata(obj: Any, tmp_sample_data=None) -> MlemMeta: + +def get_object_metadata( + obj: Any, + tmp_sample_data=None, + description: str = None, + params: Dict[str, str] = None, + tags: List[str] = None, +) -> MlemMeta: """Convert given object to appropriate MlemMeta subclass""" try: - return DatasetMeta.from_data(obj) + return DatasetMeta.from_data( + obj, description=description, params=params, tags=tags + ) except HookNotFound: - return ModelMeta.from_obj(obj, sample_data=tmp_sample_data) + return ModelMeta.from_obj( + obj, + sample_data=tmp_sample_data, + description=description, + params=params, + tags=tags, + ) def save( @@ -29,8 +46,12 @@ def save( dvc: bool = False, tmp_sample_data=None, fs: Union[str, AbstractFileSystem] = None, - link: bool = True, + link: bool = None, external: Optional[bool] = None, + description: str = None, + params: Dict[str, str] = None, + tags: List[str] = None, + update: bool = False, ) -> MlemMeta: """Saves given object to a given path @@ -38,17 +59,38 @@ def save( obj: Object to dump path: If not located on LocalFileSystem, then should be uri or `fs` argument should be provided + repo: path to mlem repo (optional) dvc: Store the object's artifacts with dvc tmp_sample_data: If the object is a model or function, you can provide input data sample, so MLEM will include it's schema in the model's metadata fs: FileSystem for the `path` argument link: Whether to create a link in .mlem folder found for `path` + external: if obj is saved to repo, whether to put it outside of .mlem dir + description: description for object + params: arbitrary params for object + tags: tags for object + update: whether to keep old description/tags/params if new values were not provided Returns: None """ - meta = get_object_metadata(obj, tmp_sample_data) + if update and (description is None or params is None or tags is None): + try: + old_meta = load_meta(path, repo=repo, fs=fs, load_value=False) + description = description or old_meta.description + params = params or old_meta.params + tags = tags or old_meta.tags + except MlemObjectNotFound: + logger.warning( + "Saving with update=True, but no existing object found at %s %s %s", + repo, + path, + fs, + ) + meta = get_object_metadata( + obj, tmp_sample_data, description=description, params=params, tags=tags + ) meta.dump(path, fs=fs, repo=repo, link=link, external=external) if dvc: # TODO dvc add ./%name% https://github.com/iterative/mlem/issues/47 @@ -68,7 +110,8 @@ def load( path (str): Path to the object. Could be local path or path inside a git repo. repo (Optional[str], optional): URL to repo if object is located there. rev (Optional[str], optional): revision, could be git commit SHA, branch name or tag. - follow_links (bool, optional): If object we read is a MLEM link, whether to load the actual object link points to. Defaults to True. + follow_links (bool, optional): If object we read is a MLEM link, whether to load the + actual object link points to. Defaults to True. Returns: Any: Python object saved by MLEM @@ -100,6 +143,20 @@ def load_meta( ... +@overload +def load_meta( + path: str, + repo: Optional[str] = None, + rev: Optional[str] = None, + follow_links: bool = True, + load_value: bool = False, + fs: Optional[AbstractFileSystem] = None, + *, + force_type: Optional[Type[T]] = None, +) -> T: + ... + + def load_meta( path: str, repo: Optional[str] = None, @@ -116,7 +173,8 @@ def load_meta( path (str): Path to the object. Could be local path or path inside a git repo. repo (Optional[str], optional): URL to repo if object is located there. rev (Optional[str], optional): revision, could be git commit SHA, branch name or tag. - follow_links (bool, optional): If object we read is a MLEM link, whether to load the actual object link points to. Defaults to True. + follow_links (bool, optional): If object we read is a MLEM link, whether to load the + actual object link points to. Defaults to True. load_value (bool, optional): Load actual python object incorporated in MlemMeta object. Defaults to False. fs: filesystem to load from. If not provided, will be inferred from path force_type: type of meta to be loaded. Defaults to MlemMeta (any mlem meta) @@ -163,7 +221,7 @@ def find_meta_location(location: Location) -> Location: _, path = find_object( location.path, fs=location.fs, repo=location.repo ) - except ValueError as e: + except (ValueError, MlemRootNotFound) as e: raise MlemObjectNotFound( f"MLEM object was not found at {location.uri}" ) from e diff --git a/mlem/core/objects.py b/mlem/core/objects.py index f8fe3ef7..b27879bd 100644 --- a/mlem/core/objects.py +++ b/mlem/core/objects.py @@ -7,7 +7,17 @@ from abc import ABC, abstractmethod from functools import partial from inspect import isabstract -from typing import Any, ClassVar, Dict, Optional, Tuple, Type, TypeVar, Union +from typing import ( + Any, + ClassVar, + Dict, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, +) from fsspec import AbstractFileSystem from fsspec.implementations.local import LocalFileSystem @@ -67,6 +77,9 @@ class MlemMeta(MlemObject): __abstract__: ClassVar[bool] = True object_type: ClassVar[str] location: Optional[Location] = None + description: Optional[str] = None + params: Dict[str, str] = {} + tags: List[str] = [] @property def loc(self) -> Location: @@ -486,11 +499,22 @@ class ModelMeta(_WithArtifacts): ) @classmethod - def from_obj(cls, model: Any, sample_data: Any = None) -> "ModelMeta": + def from_obj( + cls, + model: Any, + sample_data: Any = None, + description: str = None, + tags: List[str] = None, + params: Dict[str, str] = None, + ) -> "ModelMeta": mt = ModelAnalyzer.analyze(model, sample_data=sample_data) mt.model = model return ModelMeta( - model_type=mt, requirements=mt.get_requirements().expanded + model_type=mt, + requirements=mt.get_requirements().expanded, + description=description, + tags=tags or [], + params=params or {}, ) def write_value(self) -> Artifacts: @@ -539,12 +563,21 @@ def data(self): return self.dataset.data @classmethod - def from_data(cls, data: Any) -> "DatasetMeta": + def from_data( + cls, + data: Any, + description: str = None, + params: Dict[str, str] = None, + tags: List[str] = None, + ) -> "DatasetMeta": dataset = Dataset.create( data, ) meta = DatasetMeta( - requirements=dataset.dataset_type.get_requirements().expanded + requirements=dataset.dataset_type.get_requirements().expanded, + description=description, + params=params or {}, + tags=tags or [], ) meta.dataset = dataset return meta @@ -676,5 +709,5 @@ def find_object( ) if len(source_paths) > 1: raise ValueError(f"Ambiguous object {path}: {source_paths}") - type, source_path = source_paths[0] - return type, source_path + type_, source_path = source_paths[0] + return type_, source_path diff --git a/tests/core/test_metadata.py b/tests/core/test_metadata.py index 84bfd02c..4a8f69dd 100644 --- a/tests/core/test_metadata.py +++ b/tests/core/test_metadata.py @@ -6,6 +6,7 @@ import pytest import yaml +from pytest_lazyfixture import lazy_fixture from sklearn.datasets import load_iris from sklearn.ensemble import RandomForestClassifier from sklearn.tree import DecisionTreeClassifier @@ -26,10 +27,43 @@ ) +@pytest.mark.parametrize("obj", [lazy_fixture("model"), lazy_fixture("train")]) +def test_save_with_meta_fields(obj, tmpdir): + path = str(tmpdir / "obj") + save(obj, path, description="desc", params={"a": "b"}, tags=["tag"]) + new = load_meta(path) + assert new.description == "desc" + assert new.params == {"a": "b"} + assert new.tags == ["tag"] + + +def test_save_with_meta_fields_update(model, train, tmpdir): + path = str(tmpdir / "obj") + save( + model, + path, + description="desc", + params={"a": "b"}, + tags=["tag"], + update=True, + ) + save(train, path, update=True) + new = load_meta(path) + assert new.description == "desc" + assert new.params == {"a": "b"} + assert new.tags == ["tag"] + + +def test_saving_with_repo(model, tmpdir): + path = str(tmpdir / "obj") + save(model, path) + load_meta(path) + + def test_model_saving_without_sample_data(model, tmpdir_factory): - dir = str(tmpdir_factory.mktemp("saving-models-without-sample-data")) + path = str(tmpdir_factory.mktemp("saving-models-without-sample-data")) # link=True would require having .mlem folder somewhere - save(model, dir, link=False) + save(model, path, link=False) def test_model_saving_in_mlem_repo_root(model_train_target, tmpdir_factory): @@ -122,8 +156,8 @@ def test_load_link_with_fsspec_path(current_test_branch): "path": f"github://{MLEM_TEST_REPO_ORG}:{MLEM_TEST_REPO_NAME}@{quote_plus(current_test_branch)}/simple/data/model/mlem.yaml", "object_type": "link", } - with tempfile.TemporaryDirectory() as dir: - path = os.path.join(dir, "link.mlem.yaml") + with tempfile.TemporaryDirectory() as dirname: + path = os.path.join(dirname, "link.mlem.yaml") with open(path, "w", encoding="utf-8") as f: f.write(yaml.safe_dump(link_contents)) model = load(path)