Skip to content
This repository was archived by the owner on Sep 13, 2023. It is now read-only.

Support combining batched prediction results on DatasetType #290

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion mlem/api/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ def apply(
batch_data = get_data_value(part, batch_size)
for batch in batch_data:
preds = w.call_method(resolved_method, batch.data)
res += [*preds] # TODO: merge results
res.append(preds)
dt = w.methods[resolved_method].returns
res = dt.combine(res)
else:
res = [
w.call_method(resolved_method, get_data_value(part))
Expand Down
3 changes: 3 additions & 0 deletions mlem/contrib/lightgbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ def process(cls, obj: Any, **kwargs) -> DataType:
def get_model(self, prefix: str = "") -> Type[BaseModel]:
    """Build the model by delegating to the serializer of ``self.inner``.

    :param prefix: forwarded unchanged to the inner serializer's ``get_model``
    :return: whatever the inner serializer's ``get_model`` returns
    """
    inner_serializer = self.inner.get_serializer()
    return inner_serializer.get_model(prefix)

def combine(self, batched_data: List[lgb.Dataset]) -> lgb.Dataset:
    """Combine per-batch ``lgb.Dataset`` objects into a single dataset.

    Not implemented for this type: always raises ``NotImplementedError``.
    """
    raise NotImplementedError


class LightGBMDataWriter(DataWriter):
type: ClassVar[str] = "lightgbm"
Expand Down
20 changes: 19 additions & 1 deletion mlem/contrib/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
DataType,
DataWriter,
)
from mlem.core.errors import DeserializationError, SerializationError
from mlem.core.errors import (
DeserializationError,
InvalidDatatypeForBatchLoading,
SerializationError,
)
from mlem.core.requirements import LibRequirementsMixin


Expand Down Expand Up @@ -79,6 +83,9 @@ def get_writer(self, project: str = None, filename: str = None, **kwargs):
def get_model(self, prefix: str = "") -> Type:
    """Return the Python type derived from this object's ``dtype``.

    :param prefix: accepted for interface compatibility; not used here
    :return: result of ``python_type_from_np_string_repr(self.dtype)``
    """
    dtype_repr = self.dtype
    return python_type_from_np_string_repr(dtype_repr)

def combine(self, batched_data: List[np.number]) -> np.number:
    """Combine per-batch ``np.number`` values into a single value.

    Not implemented for this type: always raises ``NotImplementedError``.
    """
    raise NotImplementedError


class NumpyNdarrayType(
LibRequirementsMixin, DataType, DataHook, DataSerializer
Expand All @@ -101,6 +108,17 @@ class NumpyNdarrayType(
def _abstract_shape(shape):
return (None,) + shape[1:]

def combine(self, batched_data: List[np.ndarray]) -> np.ndarray:
    """Concatenate per-batch arrays into one array along the first axis.

    :param batched_data: arrays produced for consecutive batches; each one
        must be accepted by ``self.is_object_valid``
    :return: a single array with the batches joined along ``axis=0``
    :raises InvalidDatatypeForBatchLoading: if ``batched_data`` is empty or
        any element is rejected by ``self.is_object_valid``
    """
    # ``all()`` over an empty sequence is vacuously True, so without this
    # guard an empty list would pass validation and then fail inside
    # ``np.concatenate`` with an unrelated-looking ValueError.
    if len(batched_data) == 0:
        raise InvalidDatatypeForBatchLoading(
            f"Expected at least one {self.type} value for batch-loading."
        )

    if not all(self.is_object_valid(elem) for elem in batched_data):
        raise InvalidDatatypeForBatchLoading(
            f"Expected all values to be {self.type} for batch-loading."
        )

    # Batches are joined along the first axis only; numpy requires every
    # other dimension to match across batches.
    return np.concatenate(batched_data, axis=0)

@classmethod
def process(cls, obj, **kwargs) -> DataType:
return NumpyNdarrayType(
Expand Down
3 changes: 3 additions & 0 deletions mlem/contrib/pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,9 @@ def serialize(self, instance: pd.DataFrame):

return {"values": (instance.to_dict("records"))}

def combine(self, batched_data: List[pd.DataFrame]) -> pd.DataFrame:
    """Combine per-batch ``pd.DataFrame`` objects into a single frame.

    Not implemented for this type: always raises ``NotImplementedError``.
    """
    raise NotImplementedError


class SeriesType(_PandasDataType):
"""
Expand Down
3 changes: 3 additions & 0 deletions mlem/contrib/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ def process(cls, obj: torch.Tensor, **kwargs) -> DataType:
dtype=str(obj.dtype)[len("torch") + 1 :],
)

def combine(self, batched_data: List[torch.Tensor]) -> torch.Tensor:
    """Combine per-batch ``torch.Tensor`` objects into a single tensor.

    Not implemented for this type: always raises ``NotImplementedError``.
    """
    raise NotImplementedError


class TorchTensorWriter(DataWriter):
type: ClassVar[str] = "torch"
Expand Down
3 changes: 3 additions & 0 deletions mlem/contrib/xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ def get_writer(
def get_model(self, prefix: str = "") -> Type[BaseModel]:
    """Return the ``BaseModel`` type for this object.

    Not implemented here: always raises ``NotImplementedError``; ``prefix``
    is unused.
    """
    raise NotImplementedError

def combine(self, batched_data: List[xgboost.DMatrix]) -> xgboost.DMatrix:
    """Combine per-batch ``xgboost.DMatrix`` objects into a single matrix.

    Not implemented for this type: always raises ``NotImplementedError``.
    """
    raise NotImplementedError


class XGBoostModelIO(ModelIO):
"""
Expand Down
25 changes: 25 additions & 0 deletions mlem/core/data_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Sized,
Tuple,
Type,
TypeVar,
Union,
)

Expand All @@ -28,6 +29,8 @@
from mlem.core.requirements import Requirements, WithRequirements
from mlem.utils.module import get_object_requirements

T = TypeVar("T")


class DataType(ABC, MlemABC, WithRequirements):
"""
Expand All @@ -52,6 +55,10 @@ def check_type(obj, exp_type, exc_type):
def get_requirements(self) -> Requirements:
    """Return the requirements computed for this object.

    :return: result of ``get_object_requirements(self)``
    """
    requirements = get_object_requirements(self)
    return requirements

@abstractmethod
def combine(self, batched_data: List[T]) -> T:
    """Combine a list of per-batch values into one value of the same kind.

    Abstract: subclasses must provide their own implementation. The base
    body only raises ``NotImplementedError``.

    :param batched_data: values obtained for individual batches
    :return: a single combined value
    """
    raise NotImplementedError

@abstractmethod
def get_writer(
self, project: str = None, filename: str = None, **kwargs
Expand Down Expand Up @@ -112,10 +119,16 @@ def get_writer(
def get_model(self, prefix: str = "") -> Type[BaseModel]:
    """Return the ``BaseModel`` type for this object.

    Not implemented here: always raises ``NotImplementedError``; ``prefix``
    is unused.
    """
    raise NotImplementedError

def combine(self, batched_data: List[T]) -> T:
    """Combine per-batch values into a single value.

    Not implemented for this type: always raises ``NotImplementedError``.
    """
    raise NotImplementedError


class DataHook(Hook[DataType], ABC):
    """Base class for hooks to analyze data objects"""

    def combine(self, batched_data: List[T]) -> T:
        """Combine per-batch values into a single value.

        The base hook provides no implementation and always raises
        ``NotImplementedError``.
        """
        raise NotImplementedError


class DataAnalyzer(Analyzer):
"""Analyzer for data objects"""
Expand Down Expand Up @@ -200,6 +213,9 @@ def get_requirements(self) -> Requirements:
def get_model(self, prefix: str = "") -> Type[BaseModel]:
    """Return ``self.to_type``; ``prefix`` is accepted but not used."""
    return self.to_type

def combine(self, batched_data: List[Any]) -> Any:
    """Combine per-batch values into a single value.

    Not implemented for this type: always raises ``NotImplementedError``.
    """
    raise NotImplementedError


class PrimitiveWriter(DataWriter):
type: ClassVar[str] = "primitive"
Expand Down Expand Up @@ -267,6 +283,9 @@ def get_model(self, prefix: str = "") -> Type[BaseModel]:
__root__=(List[self.dtype.get_serializer().get_model(subname)], ...), # type: ignore
)

def combine(self, batched_data: List[T]) -> T:
    """Combine per-batch values into a single value.

    Not implemented for this type: always raises ``NotImplementedError``.
    """
    raise NotImplementedError


class ArrayWriter(DataWriter):
type: ClassVar[str] = "array"
Expand Down Expand Up @@ -364,6 +383,9 @@ def get_model(self, prefix: str = "") -> Type[BaseModel]:
),
)

def combine(self, batched_data: List[Any]) -> Any:
    """Combine per-batch values into a single value.

    Not implemented for this type: always raises ``NotImplementedError``.
    """
    raise NotImplementedError


def _check_type_and_size(obj, dtype, size, exc_type):
DataType.check_type(obj, dtype, exc_type)
Expand Down Expand Up @@ -532,6 +554,9 @@ def get_model(self, prefix="") -> Type[BaseModel]:
}
return create_model(prefix + "DictType", **kwargs) # type: ignore

def combine(self, batched_data: List[T]) -> T:
    """Combine per-batch values into a single value.

    Not implemented for this type: always raises ``NotImplementedError``.
    """
    raise NotImplementedError


class DictWriter(DataWriter):
type: ClassVar[str] = "dict"
Expand Down
4 changes: 4 additions & 0 deletions mlem/core/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ class UnsupportedDataBatchLoading(MlemError):
"""Thrown if batch loading of data is called for import workflow"""


class InvalidDatatypeForBatchLoading(MlemError):
    """Thrown if values being combined during batch loading do not all have the expected data type"""


class WrongMethodError(ValueError, MlemError):
"""Thrown if wrong method name for model is provided"""

Expand Down
2 changes: 1 addition & 1 deletion mlem/runtime/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class InterfaceDescriptor(BaseModel):

class Interface(ABC, MlemABC):
"""Base class for runtime interfaces.
Describes a set of methods togerher with their signatures (arguments
Describes a set of methods together with their signatures (arguments
and return type) and executors - actual python callables to be run
when the method is invoked. Used to setup `Server`"""

Expand Down
6 changes: 3 additions & 3 deletions tests/cli/test_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from sklearn.tree import DecisionTreeClassifier

from mlem.api import load, save
from mlem.core.data_type import ArrayType
from mlem.contrib.numpy import NumpyNdarrayType
from mlem.core.errors import MlemProjectNotFound
from mlem.core.metadata import load_meta
from mlem.core.objects import MlemData
Expand Down Expand Up @@ -89,9 +89,9 @@ def test_apply_batch(runner, model_path_batch, data_path_batch):
predictions_meta = load_meta(
path, load_value=True, force_type=MlemData
)
assert isinstance(predictions_meta.data_type, ArrayType)
assert isinstance(predictions_meta.dataset, NumpyNdarrayType)
predictions = predictions_meta.get_value()
assert isinstance(predictions, list)
assert isinstance(predictions, ndarray)


def test_apply_with_import(runner, model_meta_saved_single, tmp_path_factory):
Expand Down
5 changes: 4 additions & 1 deletion tests/runtime/test_interface.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, ClassVar
from typing import Any, ClassVar, List

import pytest

Expand Down Expand Up @@ -30,6 +30,9 @@ def get_writer(
) -> DataWriter:
raise NotImplementedError

def combine(self, batched_data: List[Any]) -> Any:
    """Combine per-batch values into a single value.

    Not needed by this test helper type: always raises
    ``NotImplementedError``.
    """
    raise NotImplementedError


@pytest.fixture
def interface() -> Interface:
Expand Down