diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6b2b842a..5e921952 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -50,7 +50,7 @@ repos: hooks: - id: pylint name: pylint - entry: pylint + entry: pylint -v language: system types: [ python ] # - repo: https://github.com/PyCQA/bandit diff --git a/mlem/core/errors.py b/mlem/core/errors.py index 9f070cfe..709b58a6 100644 --- a/mlem/core/errors.py +++ b/mlem/core/errors.py @@ -33,7 +33,11 @@ def __init__(self, path, fs=None, rev=None) -> None: super().__init__(self.message) -class RevisionNotFound(MlemError): +class LocationNotFound(MlemError): + """Thrown if MLEM could not resolve location""" + + +class RevisionNotFound(LocationNotFound): _message = "Revision '{rev}' wasn't found in path={path}, fs={fs}" def __init__( diff --git a/mlem/core/meta_io.py b/mlem/core/meta_io.py index d5353ef8..11b4cfff 100644 --- a/mlem/core/meta_io.py +++ b/mlem/core/meta_io.py @@ -14,6 +14,7 @@ from mlem.core.errors import ( HookNotFound, InvalidArgumentError, + LocationNotFound, MlemObjectNotFound, RevisionNotFound, ) @@ -32,6 +33,7 @@ class Location(BaseModel): repo: Optional[str] rev: Optional[str] uri: str + repo_uri: Optional[str] fs: AbstractFileSystem class Config: @@ -45,14 +47,6 @@ def fullpath(self): def path_in_repo(self): return posixpath.relpath(self.fullpath, self.repo) - @property - def repo_uri(self): - if self.repo is None: - return None - # not sure if this is ok - # maybe we need to merge Location with UriResolver and implement this separately for each case - return self.uri[: -len(self.path)] - @contextlib.contextmanager def open(self, mode="r", **kwargs): with self.fs.open(self.fullpath, mode, **kwargs) as f: @@ -60,7 +54,7 @@ def open(self, mode="r", **kwargs): @classmethod def abs(cls, path: str, fs: AbstractFileSystem): - return Location(path=path, repo=None, fs=fs, uri=path) + return Location(path=path, repo=None, fs=fs, uri=path, repo_uri=None) def update_path(self, path): if not self.uri.endswith(self.path): @@ -71,6 +65,9 @@ def update_path(self, path): def exists(self): return self.fs.exists(self.fullpath) + def is_same_repo(self, other: "Location"): + return other.fs == self.fs and other.repo == self.repo + class UriResolver(ABC): impls: List[Type["UriResolver"]] = [] @@ -146,12 +143,14 @@ def process( fs, path = cls.get_fs(path, rev) if repo is None and find_repo: path, repo = cls.get_repo(path, fs) + uri = cls.get_uri(path, repo, rev, fs) return Location( path=path, repo=repo, rev=rev, - uri=cls.get_uri(path, repo, rev, fs), + uri=uri, fs=fs, + repo_uri=cls.get_repo_uri(path, repo, rev, fs, uri), ) @classmethod @@ -186,6 +185,19 @@ def get_repo( path = posixpath.relpath(path, repo) return path, repo + @classmethod + def get_repo_uri( # pylint: disable=unused-argument + cls, + path: str, + repo: Optional[str], + rev: Optional[str], + fs: AbstractFileSystem, + uri: str, + ): + if repo is None: + return None + return uri[: -len(path)] + class GithubResolver(UriResolver): PROTOCOL = "github://" @@ -214,7 +226,11 @@ def get_fs( ) -> Tuple[GithubFileSystem, str]: options = get_github_envs() if not uri.startswith(cls.PROTOCOL): - options.update(get_github_kwargs(uri)) + try: + github_kwargs = get_github_kwargs(uri) + except ValueError as e: + raise LocationNotFound(*e.args) from e + options.update(github_kwargs) path = options.pop("path") options["sha"] = rev or options.get("sha", None) else: @@ -228,7 +244,9 @@ def get_fs( options["org"], options["repo"], options["sha"] ): raise RevisionNotFound(options["sha"], uri) from e - raise + raise LocationNotFound( + f"Could not resolve github location {uri}" + ) from e return fs, path @classmethod @@ -266,6 +284,17 @@ def pre_process( return path, repo, rev, fs + @classmethod + def get_repo_uri( + cls, + path: str, + repo: Optional[str], + rev: Optional[str], + fs: GithubFileSystem, + uri: str, + ): + return f"https://github.com/{fs.org}/{fs.repo}/" + class FSSpecResolver(UriResolver): @classmethod diff --git a/mlem/core/objects.py b/mlem/core/objects.py index 58bdfa9b..22107020 100644 --- a/mlem/core/objects.py +++ b/mlem/core/objects.py @@ -281,20 +281,30 @@ def make_link( raise MlemObjectNotSavedError( "Cannot create link for not saved meta object" ) - if absolute: - link = MlemLink( - path=self.loc.path, - repo=self.loc.repo_uri, - rev=self.loc.rev, - link_type=self.resolved_type, + link = MlemLink( + path=self.loc.path, + repo=self.loc.repo_uri, + rev=self.loc.rev, + link_type=self.resolved_type, + ) + if path is not None: + ( + location, + _, + ) = link._parse_dump_args( # pylint: disable=protected-access + path, repo, fs, False, external=external ) - else: - link = MlemLink( - path=self.get_metafile_path(self.name), - link_type=self.resolved_type, + if ( + not absolute + and self.loc.is_same_repo(location) + and self.loc.rev is None + ): + link.path = self.get_metafile_path(self.name) + link.link_type = self.resolved_type + link.repo = None + link._write_meta( # pylint: disable=protected-access + location, False ) - if path is not None: - link.dump(path, fs, repo, external=external, link=False) return link @classmethod diff --git a/tests/api/test_commands.py b/tests/api/test_commands.py index 95b6e7ab..688befd5 100644 --- a/tests/api/test_commands.py +++ b/tests/api/test_commands.py @@ -72,10 +72,36 @@ def test_link_in_mlem_dir(model_path_mlem_repo): assert os.path.exists(link_dumped_to) loaded_link_object = load_meta(link_dumped_to, follow_links=False) assert isinstance(loaded_link_object, MlemLink) + assert loaded_link_object.repo is None + assert loaded_link_object.rev is None + assert ( + loaded_link_object.path + == os.path.relpath(model_path, mlem_repo) + MLEM_EXT + ) model = load_meta(link_dumped_to) assert isinstance(model, ModelMeta) +@long +def test_link_from_remote_to_local(current_test_branch, mlem_repo): + link( + "simple/data/model", + source_repo=MLEM_TEST_REPO, + rev="main", + target="remote", + target_repo=mlem_repo, + ) + loaded_link_object = load_meta( + "remote", repo=mlem_repo, follow_links=False + ) + assert isinstance(loaded_link_object, MlemLink) + assert loaded_link_object.repo == MLEM_TEST_REPO + assert loaded_link_object.rev == "main" + assert loaded_link_object.path == "simple/data/model" + MLEM_EXT + model = loaded_link_object.load_link() + assert isinstance(model, ModelMeta) + + def test_ls_local(filled_mlem_repo): objects = ls(filled_mlem_repo) assert len(objects) == 1