Skip to content
This repository has been archived by the owner on Sep 13, 2023. It is now read-only.

Commit

Permalink
fix dvc download (#109)
Browse files Browse the repository at this point in the history
* fix dvc download

* fix lazy field for dataset
add gitpython to deps
fix dvc downloading

* conditional pull

* fix isadir error

* fix catboost and upload saving

* fix tests

* fix #83

* fix requirement duplication
requirements tests
requirement expand

* fix test

* spelling
  • Loading branch information
mike0sv authored Nov 11, 2021
1 parent d240f1f commit c7f90a6
Show file tree
Hide file tree
Showing 30 changed files with 496 additions and 80 deletions.
12 changes: 9 additions & 3 deletions mlem/cli/apply.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Tuple
from typing import Any, Optional, Tuple

import click

Expand All @@ -25,13 +25,19 @@
help="Whether to create links for outputs in .mlem directory.",
)
def apply(
model: ModelMeta, output: str, method: str, args: Tuple[Any], link: bool
model: ModelMeta,
output: Optional[str],
method: str,
args: Tuple[Any],
link: bool,
):
"""Apply a model to supplied data."""
from mlem.api import apply

click.echo("applying")
apply(model, *args, method=method, output=output, link=link)
result = apply(model, *args, method=method, output=output, link=link)
if output is None:
click.echo(result)


@mlem_command()
Expand Down
3 changes: 3 additions & 0 deletions mlem/contrib/catboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ def _get_model_file_name(self, model):
return self.classifier_file_name
return self.regressor_file_name

class Config:
use_enum_values = True


class CatBoostModel(ModelType, ModelHook):
"""
Expand Down
41 changes: 35 additions & 6 deletions mlem/contrib/dvc.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,30 @@
import contextlib
import os.path
from typing import IO, ClassVar, Iterator, Tuple
from urllib.parse import unquote_plus

from fsspec import AbstractFileSystem
from fsspec.implementations.github import GithubFileSystem
from fsspec.implementations.local import LocalFileSystem

from mlem.core.artifacts import LocalArtifact, LocalStorage, Storage
from mlem.core.meta_io import get_fs

BATCH_SIZE = 10 ** 5


def find_dvc_repo_root(path: str):
    """Return the nearest directory at or above ``path`` that contains ``.dvc``.

    The search starts at ``path`` itself and climbs one ``os.path.dirname``
    step at a time. ``path`` does not have to exist; only the ``.dvc``
    directory of some ancestor has to.

    :param path: file or directory path expected to live inside a dvc repo
    :return: the first ancestor (or ``path`` itself) holding a ``.dvc`` dir
    :raises NotDvcRepoError: if no ancestor contains a ``.dvc`` directory
    """
    current = path
    while True:
        if os.path.isdir(os.path.join(current, ".dvc")):
            return current
        parent = os.path.dirname(current)
        # dirname() is a fixed point at the top of any path ("/", "", "C:\\",
        # "//"), so comparing with the previous value terminates for relative
        # and non-POSIX paths too, not only when the literal "/" is reached.
        if parent == current:
            break
        current = parent

    # Imported only on the failure path so the successful lookup does not
    # depend on importing dvc.
    from dvc.exceptions import NotDvcRepoError

    raise NotDvcRepoError(f"Path {path} is not in dvc repo")


class DVCStorage(LocalStorage):
"""For now this storage is user-managed dvc storage, which means user should
Expand Down Expand Up @@ -38,9 +55,13 @@ class DVCArtifact(LocalArtifact):
uri: str

def download(self, target_path: str) -> LocalArtifact:
from dvc.repo import Repo

Repo.get_url(self.uri, out=target_path)
with self.open() as fin, open(
os.path.join(target_path, os.path.basename(self.uri)), "wb"
) as fout:
batch = fin.read(BATCH_SIZE)
while batch:
fout.write(batch)
batch = fin.read(BATCH_SIZE)
return LocalArtifact(uri=target_path)

@contextlib.contextmanager
Expand All @@ -58,9 +79,17 @@ def open(self) -> Iterator[IO]:
mode="rb",
) as f:
yield f
else:
with fs.open(path) as f:
yield f
return
elif isinstance(fs, LocalFileSystem):
if not os.path.exists(path):
root = find_dvc_repo_root(path)
# alternative caching impl
# Repo(root).pull(os.path.relpath(path, root))
with open(os.path.relpath(path, root), mode="rb") as f:
yield f
return
with fs.open(path) as f:
yield f

def relative(self, fs: AbstractFileSystem, path: str) -> "DVCArtifact":
relative = super().relative(fs, path)
Expand Down
13 changes: 13 additions & 0 deletions mlem/contrib/lightgbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
from mlem.core.hooks import IsInstanceHookMixin
from mlem.core.model import ModelHook, ModelIO, ModelType, Signature
from mlem.core.requirements import (
AddRequirementHook,
InstallableRequirement,
Requirement,
Requirements,
UnixPackageRequirement,
)
Expand Down Expand Up @@ -131,3 +133,14 @@ def get_requirements(self) -> Requirements:
+ InstallableRequirement.from_module(mod=lgb)
+ LGB_REQUIREMENT
)


class LGBMLibgompHook(AddRequirementHook):
    """Requirement hook that contributes ``LGB_REQUIREMENT`` for lightgbm.

    NOTE(review): how ``to_add`` is consumed is defined by
    ``AddRequirementHook``, which is not visible here - confirm there.
    """

    to_add = LGB_REQUIREMENT

    @classmethod
    def is_object_valid(cls, obj: Requirement) -> bool:
        """Tell whether ``obj`` is the installable ``lightgbm`` requirement.

        :param obj: requirement to check
        :return: True only for an ``InstallableRequirement`` whose ``module``
            equals ``"lightgbm"``
        """
        if not isinstance(obj, InstallableRequirement):
            return False
        return obj.module == "lightgbm"
4 changes: 3 additions & 1 deletion mlem/contrib/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,6 @@ def get_requirements(self) -> Requirements:
) # FIXME: https://github.com/iterative/mlem/issues/34 # optimize methods reqs

# some sklearn compatible model (either from library or user code) - fallback
return super().get_requirements()
return super().get_requirements() + InstallableRequirement.from_module(
sklearn
)
12 changes: 12 additions & 0 deletions mlem/contrib/xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
from mlem.core.hooks import IsInstanceHookMixin
from mlem.core.model import ModelHook, ModelIO, ModelType, Signature
from mlem.core.requirements import (
AddRequirementHook,
InstallableRequirement,
Requirement,
Requirements,
UnixPackageRequirement,
WithRequirements,
Expand Down Expand Up @@ -172,3 +174,13 @@ def get_requirements(self) -> Requirements:
+ InstallableRequirement.from_module(xgboost)
+ XGB_REQUIREMENT
)


class XGBLibgopmHook(AddRequirementHook):
    """Requirement hook that contributes ``XGB_REQUIREMENT`` for xgboost.

    NOTE(review): how ``to_add`` is consumed is defined by
    ``AddRequirementHook``, which is not visible here - confirm there.
    """

    to_add = XGB_REQUIREMENT

    @classmethod
    def is_object_valid(cls, obj: Requirement) -> bool:
        """Tell whether ``obj`` is the installable ``xgboost`` requirement.

        :param obj: requirement to check
        :return: True only for an ``InstallableRequirement`` whose ``module``
            equals ``"xgboost"``
        """
        if not isinstance(obj, InstallableRequirement):
            return False
        return obj.module == "xgboost"
16 changes: 11 additions & 5 deletions mlem/core/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,10 @@ class FSSpecStorage(Storage):
storage_options: Optional[Dict[str, str]] = {}

def upload(self, local_path: str, target_path: str) -> FSSpecArtifact:
self.get_fs().upload(
local_path, os.path.join(self.base_path, target_path)
)
fs = self.get_fs()
path = os.path.join(self.base_path, target_path)
fs.makedirs(os.path.dirname(path), exist_ok=True)
fs.upload(local_path, path)
return FSSpecArtifact(uri=self.create_uri(target_path))

@contextlib.contextmanager
Expand Down Expand Up @@ -146,6 +147,10 @@ class LocalStorage(FSSpecStorage):
type: ClassVar = "local"
fs = LocalFileSystem()

@property
def base_path(self):
return self.uri

def relative(self, fs: AbstractFileSystem, path: str) -> "Storage":
if isinstance(fs, LocalFileSystem):
return LocalStorage(uri=self.create_uri(path))
Expand All @@ -158,11 +163,12 @@ def relative(self, fs: AbstractFileSystem, path: str) -> "Storage":
return storage

def upload(self, local_path: str, target_path: str) -> "LocalArtifact":
return LocalArtifact(uri=super().upload(local_path, target_path).uri)
super().upload(local_path, target_path)
return LocalArtifact(uri=target_path)

@contextlib.contextmanager
def open(self, path) -> Iterator[Tuple[IO, "LocalArtifact"]]:
with super().open(os.path.join(self.uri, path)) as (io, _):
with super().open(path) as (io, _):
yield io, LocalArtifact(uri=path)


Expand Down
8 changes: 8 additions & 0 deletions mlem/core/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,11 @@ class MlemObjectNotSavedError(ValueError, MlemError):

class ObjectExistsError(ValueError, MlemError):
"""Thrown if we attempt to write object, but something already exists in the given path"""


class HookNotFound(MlemError):
    """Thrown when hook lookup finds no suitable hook for the given object"""


class MultipleHooksFound(MlemError):
    """Thrown when hook lookup finds more than one suitable hook for the same
    object, so the choice would be ambiguous"""
10 changes: 6 additions & 4 deletions mlem/core/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from abc import ABC, abstractmethod
from typing import Any, ClassVar, Generic, List, Tuple, Type, TypeVar

from mlem.core.errors import HookNotFound, MultipleHooksFound

logger = logging.getLogger(__name__)

ANALYZER_FIELD = "analyzer"
Expand Down Expand Up @@ -40,7 +42,7 @@ def is_object_valid(cls, obj: Any) -> bool:
:param obj: object to analyze
:return: True or False
"""
raise NotImplementedError()
raise NotImplementedError

@classmethod
@abstractmethod
Expand All @@ -52,7 +54,7 @@ def process(cls, obj: Any, **kwargs) -> T:
:param kwargs: additional information to be used for analysis
:return: analysis result
"""
raise NotImplementedError()
raise NotImplementedError

def __init_subclass__(cls, *args, **kwargs):
if not inspect.isabstract(cls):
Expand Down Expand Up @@ -186,10 +188,10 @@ def _find_hook(cls, obj) -> Type[Hook[T]]:
if len(lp_hooks) == 1:
return lp_hooks[0]
if len(lp_hooks) > 1:
raise ValueError(
raise MultipleHooksFound(
f"Multiple suitable hooks for object {obj} ({lp_hooks})"
)
raise ValueError(
raise HookNotFound(
f"No suitable {cls.base_hook_class.__name__} for object of type "
f'"{type(obj).__name__}". Registered hooks: {cls.hooks}'
)
Expand Down
2 changes: 1 addition & 1 deletion mlem/core/meta_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def get_path_by_fs_path(fs: AbstractFileSystem, path: str):
Another alternative is to support this on fsspec level, but we need to contribute it ourselves"""
if isinstance(fs, GithubFileSystem):
# here "rev" should be already url encoded
return f"{fs.protocol}://{fs.org}:{fs.repo}@{fs.root}/{path}"
return f"https://github.com/{fs.org}/{fs.repo}/tree/{fs.root}/{path}"
protocol = fs.protocol
if isinstance(protocol, (list, tuple)):
if any(path.startswith(p) for p in protocol):
Expand Down
3 changes: 2 additions & 1 deletion mlem/core/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing_extensions import Literal
from yaml import safe_load

from mlem.core.errors import HookNotFound
from mlem.core.meta_io import get_fs, get_meta_path, get_path_by_repo_path_rev
from mlem.core.objects import DatasetMeta, MlemMeta, ModelMeta, find_object
from mlem.utils.root import find_mlem_root
Expand All @@ -19,7 +20,7 @@ def get_object_metadata(obj: Any, tmp_sample_data=None) -> MlemMeta:
"""Convert given object to appropriate MlemMeta subclass"""
try:
return DatasetMeta.from_data(obj)
except ValueError: # TODO need separate analysis exception
except HookNotFound:
return ModelMeta.from_obj(obj, sample_data=tmp_sample_data)


Expand Down
5 changes: 3 additions & 2 deletions mlem/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)
from mlem.core.hooks import Analyzer, Hook
from mlem.core.requirements import Requirements, WithRequirements
from mlem.utils.module import get_object_requirements


class ModelIO(MlemObject):
Expand Down Expand Up @@ -255,7 +256,7 @@ def get_requirements(self) -> Requirements:
for m in self.methods.values()
for r in m.get_requirements().__root__
]
)
) + get_object_requirements(self.model)


class ModelHook(Hook[ModelType], ABC):
Expand All @@ -264,7 +265,7 @@ class ModelHook(Hook[ModelType], ABC):
def process( # pylint: disable=arguments-differ # so what
cls, obj: Any, sample_data: Optional[Any] = None, **kwargs
) -> ModelType:
pass
raise NotImplementedError


class ModelAnalyzer(Analyzer[ModelType]):
Expand Down
18 changes: 14 additions & 4 deletions mlem/core/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,9 @@ class ModelMeta(_ExternalMeta):
def from_obj(cls, model: Any, sample_data: Any = None) -> "ModelMeta":
mt = ModelAnalyzer.analyze(model, sample_data=sample_data)
mt.model = model
return ModelMeta(model_type=mt, requirements=mt.get_requirements())
return ModelMeta(
model_type=mt, requirements=mt.get_requirements().expanded
)

def write_value(self, mlem_root: str) -> Artifacts:
if self.model_type.model is not None:
Expand All @@ -417,7 +419,7 @@ def write_value(self, mlem_root: str) -> Artifacts:
self.model_type.model,
)
else:
raise NotImplementedError() # TODO: https://github.com/iterative/mlem/issues/37
raise NotImplementedError # TODO: https://github.com/iterative/mlem/issues/37
# self.get_artifacts().materialize(path)
return artifacts

Expand All @@ -438,7 +440,15 @@ def __getattr__(self, item):
class DatasetMeta(_ExternalMeta):
__transient_fields__ = {"dataset"}
object_type: ClassVar = "dataset"
reader: Optional[DatasetReader] = None
reader_cache: Optional[Dict]
reader: Optional[DatasetReader]
reader, reader_raw, reader_cache = lazy_field(
DatasetReader,
"reader",
"reader_cache",
parse_as_type=Optional[DatasetReader],
default=None,
)
dataset: ClassVar[Optional[Dataset]] = None

@property
Expand All @@ -451,7 +461,7 @@ def from_data(cls, data: Any) -> "DatasetMeta":
data,
)
meta = DatasetMeta(
requirements=dataset.dataset_type.get_requirements()
requirements=dataset.dataset_type.get_requirements().expanded
)
meta.dataset = dataset
return meta
Expand Down
Loading

0 comments on commit c7f90a6

Please sign in to comment.