diff --git a/mlem/api/commands.py b/mlem/api/commands.py index f574164b..f3d5c757 100644 --- a/mlem/api/commands.py +++ b/mlem/api/commands.py @@ -91,7 +91,9 @@ def apply( batch_data = get_data_value(part, batch_size) for batch in batch_data: preds = w.call_method(resolved_method, batch.data) - res += [*preds] # TODO: merge results + res.append(preds) + dt = w.methods[resolved_method].returns + res = dt.combine(res) else: res = [ w.call_method(resolved_method, get_data_value(part)) diff --git a/mlem/contrib/lightgbm.py b/mlem/contrib/lightgbm.py index b45fad44..ee2537f9 100644 --- a/mlem/contrib/lightgbm.py +++ b/mlem/contrib/lightgbm.py @@ -75,6 +75,9 @@ def process(cls, obj: Any, **kwargs) -> DataType: def get_model(self, prefix: str = "") -> Type[BaseModel]: return self.inner.get_serializer().get_model(prefix) + def combine(self, batched_data: List[lgb.Dataset]) -> lgb.Dataset: + raise NotImplementedError + class LightGBMDataWriter(DataWriter): type: ClassVar[str] = "lightgbm" diff --git a/mlem/contrib/numpy.py b/mlem/contrib/numpy.py index 63169805..88000124 100644 --- a/mlem/contrib/numpy.py +++ b/mlem/contrib/numpy.py @@ -12,7 +12,11 @@ DataType, DataWriter, ) -from mlem.core.errors import DeserializationError, SerializationError +from mlem.core.errors import ( + DeserializationError, + InvalidDatatypeForBatchLoading, + SerializationError, +) from mlem.core.requirements import LibRequirementsMixin @@ -79,6 +83,9 @@ def get_writer(self, project: str = None, filename: str = None, **kwargs): def get_model(self, prefix: str = "") -> Type: return python_type_from_np_string_repr(self.dtype) + def combine(self, batched_data: List[np.number]) -> np.number: + raise NotImplementedError + class NumpyNdarrayType( LibRequirementsMixin, DataType, DataHook, DataSerializer @@ -101,6 +108,17 @@ class NumpyNdarrayType( def _abstract_shape(shape): return (None,) + shape[1:] + def combine(self, batched_data: List[np.ndarray]) -> np.ndarray: + is_valid_type = all( + 
self.is_object_valid(elem) for elem in batched_data + ) + if not is_valid_type: + raise InvalidDatatypeForBatchLoading( + f"Expected all values to be {self.type} for batch-loading." + ) + + return np.concatenate(batched_data, axis=0) + @classmethod def process(cls, obj, **kwargs) -> DataType: return NumpyNdarrayType( diff --git a/mlem/contrib/pandas.py b/mlem/contrib/pandas.py index 3a62a4a8..8cb6936f 100644 --- a/mlem/contrib/pandas.py +++ b/mlem/contrib/pandas.py @@ -217,6 +217,9 @@ def serialize(self, instance: pd.DataFrame): return {"values": (instance.to_dict("records"))} + def combine(self, batched_data: List[pd.DataFrame]) -> pd.DataFrame: + raise NotImplementedError + class SeriesType(_PandasDataType): """ diff --git a/mlem/contrib/torch.py b/mlem/contrib/torch.py index c2b373b1..ce85a210 100644 --- a/mlem/contrib/torch.py +++ b/mlem/contrib/torch.py @@ -97,6 +97,9 @@ def process(cls, obj: torch.Tensor, **kwargs) -> DataType: dtype=str(obj.dtype)[len("torch") + 1 :], ) + def combine(self, batched_data: List[torch.Tensor]) -> torch.Tensor: + raise NotImplementedError + class TorchTensorWriter(DataWriter): type: ClassVar[str] = "torch" diff --git a/mlem/contrib/xgboost.py b/mlem/contrib/xgboost.py index c7db3fe8..a0503984 100644 --- a/mlem/contrib/xgboost.py +++ b/mlem/contrib/xgboost.py @@ -111,6 +111,9 @@ def get_writer( def get_model(self, prefix: str = "") -> Type[BaseModel]: raise NotImplementedError + def combine(self, batched_data: List[xgboost.DMatrix]) -> xgboost.DMatrix: + raise NotImplementedError + class XGBoostModelIO(ModelIO): """ diff --git a/mlem/core/data_type.py b/mlem/core/data_type.py index 4356fbf9..7b552b3d 100644 --- a/mlem/core/data_type.py +++ b/mlem/core/data_type.py @@ -14,6 +14,7 @@ Sized, Tuple, Type, + TypeVar, Union, ) @@ -28,6 +29,8 @@ from mlem.core.requirements import Requirements, WithRequirements from mlem.utils.module import get_object_requirements +T = TypeVar("T") + class DataType(ABC, MlemABC, WithRequirements): """ @@ 
-52,6 +55,10 @@ def check_type(obj, exp_type, exc_type): def get_requirements(self) -> Requirements: return get_object_requirements(self) + @abstractmethod + def combine(self, batched_data: List[T]) -> T: + raise NotImplementedError + @abstractmethod def get_writer( self, project: str = None, filename: str = None, **kwargs @@ -112,10 +119,16 @@ def get_writer( def get_model(self, prefix: str = "") -> Type[BaseModel]: raise NotImplementedError + def combine(self, batched_data: List[T]) -> T: + raise NotImplementedError + class DataHook(Hook[DataType], ABC): """Base class for hooks to analyze data objects""" + def combine(self, batched_data: List[T]) -> T: + raise NotImplementedError + class DataAnalyzer(Analyzer): """Analyzer for data objects""" @@ -200,6 +213,9 @@ def get_requirements(self) -> Requirements: def get_model(self, prefix: str = "") -> Type[BaseModel]: return self.to_type + def combine(self, batched_data: List[Any]) -> Any: + raise NotImplementedError + class PrimitiveWriter(DataWriter): type: ClassVar[str] = "primitive" @@ -267,6 +283,9 @@ def get_model(self, prefix: str = "") -> Type[BaseModel]: __root__=(List[self.dtype.get_serializer().get_model(subname)], ...), # type: ignore ) + def combine(self, batched_data: List[T]) -> T: + raise NotImplementedError + class ArrayWriter(DataWriter): type: ClassVar[str] = "array" @@ -364,6 +383,9 @@ def get_model(self, prefix: str = "") -> Type[BaseModel]: ), ) + def combine(self, batched_data: List[Any]) -> Any: + raise NotImplementedError + def _check_type_and_size(obj, dtype, size, exc_type): DataType.check_type(obj, dtype, exc_type) @@ -532,6 +554,9 @@ def get_model(self, prefix="") -> Type[BaseModel]: } return create_model(prefix + "DictType", **kwargs) # type: ignore + def combine(self, batched_data: List[T]) -> T: + raise NotImplementedError + class DictWriter(DataWriter): type: ClassVar[str] = "dict" diff --git a/mlem/core/errors.py b/mlem/core/errors.py index 7b2a5aaf..b0fee8bf 100644 --- 
a/mlem/core/errors.py +++ b/mlem/core/errors.py @@ -91,6 +91,10 @@ class UnsupportedDataBatchLoading(MlemError): """Thrown if batch loading of data is called for import workflow""" +class InvalidDatatypeForBatchLoading(MlemError): """Thrown if batch loading of dataset has incorrect types""" + + class WrongMethodError(ValueError, MlemError): """Thrown if wrong method name for model is provided""" diff --git a/mlem/runtime/interface.py b/mlem/runtime/interface.py index e526b517..6eec047d 100644 --- a/mlem/runtime/interface.py +++ b/mlem/runtime/interface.py @@ -25,7 +25,7 @@ class InterfaceDescriptor(BaseModel): class Interface(ABC, MlemABC): """Base class for runtime interfaces. - Describes a set of methods togerher with their signatures (arguments + Describes a set of methods together with their signatures (arguments and return type) and executors - actual python callables to be run when the method is invoked. Used to setup `Server`""" diff --git a/tests/cli/test_apply.py b/tests/cli/test_apply.py index 76af85f2..fd8e40ef 100644 --- a/tests/cli/test_apply.py +++ b/tests/cli/test_apply.py @@ -9,7 +9,7 @@ from sklearn.tree import DecisionTreeClassifier from mlem.api import load, save -from mlem.core.data_type import ArrayType +from mlem.contrib.numpy import NumpyNdarrayType from mlem.core.errors import MlemProjectNotFound from mlem.core.metadata import load_meta from mlem.core.objects import MlemData @@ -89,9 +89,9 @@ def test_apply_batch(runner, model_path_batch, data_path_batch): predictions_meta = load_meta( path, load_value=True, force_type=MlemData ) - assert isinstance(predictions_meta.data_type, ArrayType) + assert isinstance(predictions_meta.data_type, NumpyNdarrayType) predictions = predictions_meta.get_value() - assert isinstance(predictions, list) + assert isinstance(predictions, ndarray) def test_apply_with_import(runner, model_meta_saved_single, tmp_path_factory): diff --git a/tests/runtime/test_interface.py b/tests/runtime/test_interface.py index
856dddd0..991aeb75 100644 --- a/tests/runtime/test_interface.py +++ b/tests/runtime/test_interface.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar +from typing import Any, ClassVar, List import pytest @@ -30,6 +30,9 @@ def get_writer( ) -> DataWriter: raise NotImplementedError + def combine(self, batched_data: List[Any]) -> Any: + raise NotImplementedError + @pytest.fixture def interface() -> Interface: