Skip to content
This repository has been archived by the owner on Sep 13, 2023. It is now read-only.

Commit

Permalink
fix requirement collection for pipelines (#284)
Browse files Browse the repository at this point in the history
* fix requirement collection for pipelines

* add predict_proba for pipelines
  • Loading branch information
mike0sv authored Jun 7, 2022
1 parent 471eec6 commit b70975d
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 5 deletions.
4 changes: 3 additions & 1 deletion mlem/contrib/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
def rename_recursively(model: Type[BaseModel], prefix: str):
    """Prefix the name of ``model`` and of every nested ``BaseModel`` field type.

    The model's ``__name__`` is rewritten in place to ``"{prefix}_{name}"``,
    then the same renaming is applied to each field whose ``type_`` is itself
    a ``BaseModel`` subclass.

    Args:
        model: model class whose ``__name__`` and ``__fields__`` are processed.
        prefix: string placed, followed by an underscore, before each name.
    """
    model.__name__ = f"{prefix}_{model.__name__}"
    for field in model.__fields__.values():
        field_type = field.type_
        # issubclass() raises TypeError when its first argument is not a
        # class, so make sure we have a real class before asking.
        if isinstance(field_type, type) and issubclass(field_type, BaseModel):
            rename_recursively(field_type, prefix)


Expand Down
27 changes: 24 additions & 3 deletions mlem/contrib/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ def process(
return SklearnModel(io=SimplePickleIO(), methods=methods).bind(obj)

def get_requirements(self) -> Requirements:
if get_object_base_module(self.model) is sklearn:
if get_object_base_module(self.model) is sklearn and not isinstance(
self.model, Pipeline
):
return Requirements.resolve(
InstallableRequirement.from_module(sklearn)
) + get_object_requirements(
Expand All @@ -80,7 +82,7 @@ class SklearnPipelineType(SklearnModel):
def process(
cls, obj: Any, sample_data: Optional[Any] = None, **kwargs
) -> ModelType:
mt = SklearnModel(io=SimplePickleIO(), methods={}).bind(obj)
mt = SklearnPipelineType(io=SimplePickleIO(), methods={}).bind(obj)
predict = obj.predict
predict_args = {"X": sample_data}
if hasattr(predict, "__wrapped__"):
Expand All @@ -90,8 +92,27 @@ def process(
predict, auto_infer=sample_data is not None, **predict_args
)
mt.methods["sklearn_predict"] = sk_predict_sig
predict_sig = sk_predict_sig.copy()
predict_sig = sk_predict_sig.copy(deep=True)
predict_sig.args[0].name = "data"
predict_sig.varkw = None
predict_sig.name = PREDICT_METHOD_NAME
mt.methods[PREDICT_METHOD_NAME] = predict_sig

if hasattr(obj, "predict_proba"):
predict_proba = obj.predict_proba
predict_proba_args = {"X": sample_data}
if hasattr(predict_proba, "__wrapped__"):
predict_proba = predict_proba.__wrapped__
predict_proba_args["self"] = obj
sk_predict_proba_sig = Signature.from_method(
predict_proba,
auto_infer=sample_data is not None,
**predict_proba_args
)
mt.methods["sklearn_predict_proba"] = sk_predict_proba_sig
predict_proba_sig = sk_predict_proba_sig.copy(deep=True)
predict_proba_sig.args[0].name = "data"
predict_proba_sig.varkw = None
predict_proba_sig.name = PREDICT_PROBA_METHOD_NAME
mt.methods[PREDICT_PROBA_METHOD_NAME] = predict_proba_sig
return mt
4 changes: 3 additions & 1 deletion mlem/runtime/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,14 @@ def __getattr__(self, name):
raise WrongMethodError(f"{name} method is not exposed by server")
return _MethodCall(
method=self.methods[name],
name=name,
call_method=self._call_method,
)


class _MethodCall(BaseModel):
method: Signature
name: str
call_method: Callable

def __call__(self, *args, **kwargs):
Expand Down Expand Up @@ -83,7 +85,7 @@ def __call__(self, *args, **kwargs):
logger.debug(
'Calling server method "%s", args: %s ...', self.method.name, data
)
out = self.call_method(self.method.name, data)
out = self.call_method(self.name, data)
logger.debug("Server call returned %s", out)
return self.method.returns.get_serializer().deserialize(out)

Expand Down
9 changes: 9 additions & 0 deletions tests/contrib/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from mlem.core.artifacts import LOCAL_STORAGE
from mlem.core.data_type import DataAnalyzer
from mlem.core.model import ModelAnalyzer
from mlem.core.objects import MlemModel
from mlem.core.requirements import UnixPackageRequirement
from tests.conftest import check_model_type_common_interface, long

Expand Down Expand Up @@ -163,6 +164,14 @@ def test_model_type_lgb__dump_load(tmpdir, lgbm_model, inp_data):
]


def test_pipeline_requirements(lgbm_model):
    """Check the requirement modules collected for a Pipeline wrapping ``lgbm_model``."""
    pipeline = Pipeline(steps=[("model", lgbm_model)])
    collected_modules = set(MlemModel.from_obj(pipeline).requirements.modules)

    assert collected_modules == {
        "sklearn",
        "lightgbm",
        "pandas",
        "numpy",
        "scipy",
    }


# Copyright 2019 Zyfra
# Copyright 2021 Iterative
#
Expand Down

0 comments on commit b70975d

Please sign in to comment.