diff --git a/mlem/contrib/lightgbm.py b/mlem/contrib/lightgbm.py index b45fad44..a2bf6f51 100644 --- a/mlem/contrib/lightgbm.py +++ b/mlem/contrib/lightgbm.py @@ -1,8 +1,9 @@ import os import posixpath import tempfile -from typing import Any, ClassVar, Iterator, List, Optional, Tuple, Type +from typing import Any, ClassVar, Iterator, Optional, Tuple, Type +import flatdict import lightgbm as lgb from pydantic import BaseModel @@ -28,6 +29,8 @@ ) LGB_REQUIREMENT = UnixPackageRequirement(package_name="libgomp1") +LIGHTGBM_DATA = "inner" +LIGHTGBM_LABEL = "label" class LightGBMDataType( @@ -37,20 +40,38 @@ class LightGBMDataType( :class:`.DataType` implementation for `lightgbm.Dataset` type :param inner: :class:`.DataType` instance for underlying data + :param labels: :class:`.DataType` instance for underlying labels """ type: ClassVar[str] = "lightgbm" valid_types: ClassVar = (lgb.Dataset,) inner: DataType + labels: Optional[DataType] def serialize(self, instance: Any) -> dict: self.check_type(instance, lgb.Dataset, SerializationError) + if self.labels is not None: + return { + LIGHTGBM_DATA: self.inner.get_serializer().serialize( + instance.data + ), + LIGHTGBM_LABEL: self.labels.get_serializer().serialize( + instance.label + ), + } return self.inner.get_serializer().serialize(instance.data) def deserialize(self, obj: dict) -> Any: - v = self.inner.get_serializer().deserialize(obj) + if self.labels is not None: + data = self.inner.get_serializer().deserialize(obj[LIGHTGBM_DATA]) + label = self.labels.get_serializer().deserialize( + obj[LIGHTGBM_LABEL] + ) + else: + data = self.inner.get_serializer().deserialize(obj) + label = None try: - return lgb.Dataset(v, free_raw_data=False) + return lgb.Dataset(data, label=label, free_raw_data=False) except ValueError as e: raise DeserializationError( f"object: {obj} could not be converted to lightgbm dataset" @@ -70,7 +91,12 @@ def get_writer( @classmethod def process(cls, obj: Any, **kwargs) -> DataType: - return LightGBMDataType(inner=DataAnalyzer.analyze(obj.data)) + return LightGBMDataType( + inner=DataAnalyzer.analyze(obj.data), + labels=DataAnalyzer.analyze(obj.label) + if obj.label is not None + else None, + ) def get_model(self, prefix: str = "") -> Type[BaseModel]: return self.inner.get_serializer().get_model(prefix) @@ -86,19 +112,42 @@ def write( raise ValueError( f"expected data to be of LightGBMDataType, got {type(data)} instead" ) - lightgbm_construct = data.data.construct() - raw_data = lightgbm_construct.get_data() - underlying_labels = lightgbm_construct.get_label().tolist() - inner_reader, art = data.inner.get_writer().write( - data.inner.copy().bind(raw_data), storage, path - ) + + lightgbm_raw = data.data + + if data.labels is not None: + inner_reader, inner_art = data.inner.get_writer().write( + data.inner.copy().bind(lightgbm_raw.data), + storage, + posixpath.join(path, LIGHTGBM_DATA), + ) + labels_reader, labels_art = data.labels.get_writer().write( + data.labels.copy().bind(lightgbm_raw.label), + storage, + posixpath.join(path, LIGHTGBM_LABEL), + ) + res = dict( + flatdict.FlatterDict( + {LIGHTGBM_DATA: inner_art, LIGHTGBM_LABEL: labels_art}, + delimiter="/", + ) + ) + else: + inner_reader, inner_art = data.inner.get_writer().write( + data.inner.copy().bind(lightgbm_raw.data), + storage, + path, + ) + res = inner_art + labels_reader = None + return ( LightGBMDataReader( data_type=data, inner=inner_reader, - label=underlying_labels, + labels=labels_reader, ), - art, + res, ) @@ -106,13 +155,25 @@ class LightGBMDataReader(DataReader): type: ClassVar[str] = "lightgbm" data_type: LightGBMDataType inner: DataReader - label: List + labels: Optional[DataReader] def read(self, artifacts: Artifacts) -> DataType: - inner_data_type = self.inner.read(artifacts) - return LightGBMDataType(inner=inner_data_type).bind( + if self.labels is not None: + artifacts = flatdict.FlatterDict(artifacts, delimiter="/") + inner_data_type = self.inner.read(artifacts[LIGHTGBM_DATA]) # type: ignore[arg-type] + labels_data_type = self.labels.read(artifacts[LIGHTGBM_LABEL]) # type: ignore[arg-type] + else: + inner_data_type = self.inner.read(artifacts) + labels_data_type = None + return LightGBMDataType( + inner=inner_data_type, labels=labels_data_type + ).bind( lgb.Dataset( - inner_data_type.data, label=self.label, free_raw_data=False + inner_data_type.data, + label=labels_data_type.data + if labels_data_type is not None + else None, + free_raw_data=False, ) ) diff --git a/tests/contrib/test_lightgbm.py b/tests/contrib/test_lightgbm.py index bd2e193a..6addd714 100644 --- a/tests/contrib/test_lightgbm.py +++ b/tests/contrib/test_lightgbm.py @@ -4,6 +4,8 @@ import pytest from mlem.contrib.lightgbm import ( + LIGHTGBM_DATA, + LIGHTGBM_LABEL, LightGBMDataReader, LightGBMDataType, LightGBMDataWriter, @@ -12,7 +14,12 @@ from mlem.contrib.numpy import NumpyNdarrayType from mlem.contrib.pandas import DataFrameType from mlem.core.artifacts import LOCAL_STORAGE -from mlem.core.data_type import DataAnalyzer, DataType +from mlem.core.data_type import ( + ArrayType, + DataAnalyzer, + DataType, + PrimitiveType, +) from mlem.core.errors import DeserializationError, SerializationError from mlem.core.model import ModelAnalyzer, ModelType from mlem.core.requirements import UnixPackageRequirement @@ -46,7 +53,7 @@ def df_payload(): def data_df(df_payload): return lgb.Dataset( df_payload, - label=np.array([0, 1]).tolist(), + label=np.array([0, 1]), free_raw_data=False, ) @@ -75,6 +82,8 @@ def test_hook_np(dtype_np: DataType): assert set(dtype_np.get_requirements().modules) == {"lightgbm", "numpy"} assert isinstance(dtype_np, LightGBMDataType) assert isinstance(dtype_np.inner, NumpyNdarrayType) + assert isinstance(dtype_np.labels, ArrayType) + assert dtype_np.labels.dtype == PrimitiveType(data=None, ptype="float") assert dtype_np.get_model().__name__ == dtype_np.inner.get_model().__name__ assert dtype_np.get_model().schema() == { "title": "NumpyNdarray", @@ -92,6 +101,7 @@ def test_hook_df(dtype_df: DataType): assert set(dtype_df.get_requirements().modules) == {"lightgbm", "pandas"} assert isinstance(dtype_df, LightGBMDataType) assert isinstance(dtype_df.inner, DataFrameType) + assert isinstance(dtype_df.labels, NumpyNdarrayType) assert dtype_df.get_model().__name__ == dtype_df.inner.get_model().__name__ assert dtype_df.get_model().schema() == { "title": "DataFrame", @@ -116,54 +126,131 @@ def test_hook_df(dtype_df: DataType): @pytest.mark.parametrize( - "lgb_dtype, data_type", - [("dtype_np", NumpyNdarrayType), ("dtype_df", DataFrameType)], + "lgb_dtype, data_type, label_type", + [ + ("dtype_np", NumpyNdarrayType, ArrayType), + ("dtype_df", DataFrameType, NumpyNdarrayType), + ], ) -def test_lightgbm_source(lgb_dtype, data_type, request): +def test_lightgbm_source(lgb_dtype, data_type, label_type, request): lgb_dtype = request.getfixturevalue(lgb_dtype) assert isinstance(lgb_dtype, LightGBMDataType) assert isinstance(lgb_dtype.inner, data_type) + assert isinstance(lgb_dtype.labels, label_type) def custom_assert(x, y): assert hasattr(x, "data") assert hasattr(y, "data") assert all(x.data == y.data) - assert all(x.label == y.label) + label_check = x.label == y.label + if isinstance(label_check, (list, np.ndarray)): + assert all(label_check) + else: + assert label_check - data_write_read_check( + artifacts = data_write_read_check( lgb_dtype, writer=LightGBMDataWriter(), reader_type=LightGBMDataReader, custom_assert=custom_assert, ) + if isinstance(lgb_dtype.inner, NumpyNdarrayType): + assert list(artifacts.keys()) == [ + f"{LIGHTGBM_DATA}/data", + f"{LIGHTGBM_LABEL}/0/data", + f"{LIGHTGBM_LABEL}/1/data", + f"{LIGHTGBM_LABEL}/2/data", + f"{LIGHTGBM_LABEL}/3/data", + f"{LIGHTGBM_LABEL}/4/data", + ] + assert artifacts[f"{LIGHTGBM_DATA}/data"].uri.endswith( + f"data/{LIGHTGBM_DATA}" + ) + assert artifacts[f"{LIGHTGBM_LABEL}/0/data"].uri.endswith( + f"data/{LIGHTGBM_LABEL}/0" + ) + assert artifacts[f"{LIGHTGBM_LABEL}/1/data"].uri.endswith( + f"data/{LIGHTGBM_LABEL}/1" + ) + assert artifacts[f"{LIGHTGBM_LABEL}/2/data"].uri.endswith( + f"data/{LIGHTGBM_LABEL}/2" + ) + assert artifacts[f"{LIGHTGBM_LABEL}/3/data"].uri.endswith( + f"data/{LIGHTGBM_LABEL}/3" + ) + assert artifacts[f"{LIGHTGBM_LABEL}/4/data"].uri.endswith( + f"data/{LIGHTGBM_LABEL}/4" + ) + else: + assert list(artifacts.keys()) == [ + f"{LIGHTGBM_DATA}/data", + f"{LIGHTGBM_LABEL}/data", + ] + assert artifacts[f"{LIGHTGBM_DATA}/data"].uri.endswith( + f"data/{LIGHTGBM_DATA}" + ) + assert artifacts[f"{LIGHTGBM_LABEL}/data"].uri.endswith( + f"data/{LIGHTGBM_LABEL}" + ) + def test_serialize__np(dtype_np, np_payload): - ds = lgb.Dataset(np_payload) + ds = lgb.Dataset(np_payload, label=np_payload.reshape((-1,)).tolist()) payload = dtype_np.serialize(ds) - assert payload == np_payload.tolist() + assert payload[LIGHTGBM_DATA] == np_payload.tolist() + assert payload[LIGHTGBM_LABEL] == np_payload.reshape((-1,)).tolist() with pytest.raises(SerializationError): dtype_np.serialize({"abc": 123}) # wrong type def test_deserialize__np(dtype_np, np_payload): - ds = dtype_np.deserialize(np_payload) + ds = dtype_np.deserialize( + { + LIGHTGBM_DATA: np_payload, + LIGHTGBM_LABEL: np_payload.reshape((-1,)).tolist(), + } + ) assert isinstance(ds, lgb.Dataset) assert np.all(ds.data == np_payload) + assert np.all(ds.label == np_payload.reshape((-1,)).tolist()) with pytest.raises(DeserializationError): - dtype_np.deserialize([[1], ["abc"]]) # illegal matrix + dtype_np.deserialize({LIGHTGBM_DATA: [[1], ["abc"]]}) # illegal matrix -def test_serialize__df(dtype_df, df_payload): - ds = lgb.Dataset(df_payload) - payload = dtype_df.serialize(ds) - assert payload["values"] == df_payload.to_dict("records") +def test_serialize__df(df_payload): + ds = lgb.Dataset(df_payload, label=None, free_raw_data=False) + payload = DataType.create(obj=ds) + assert payload.serialize(ds)["values"] == df_payload.to_dict("records") + assert LIGHTGBM_LABEL not in payload + + def custom_assert(x, y): + assert hasattr(x, "data") + assert hasattr(y, "data") + assert all(x.data == y.data) + assert x.label == y.label + + artifacts = data_write_read_check( + payload, + writer=LightGBMDataWriter(), + reader_type=LightGBMDataReader, + custom_assert=custom_assert, + ) + + assert len(artifacts.keys()) == 1 + assert list(artifacts.keys()) == ["data"] + assert artifacts["data"].uri.endswith("/data") def test_deserialize__df(dtype_df, df_payload): - ds = dtype_df.deserialize({"values": df_payload}) + ds = dtype_df.deserialize( + { + LIGHTGBM_DATA: {"values": df_payload}, + LIGHTGBM_LABEL: np.array([0, 1]).tolist(), + } + ) assert isinstance(ds, lgb.Dataset) assert ds.data.equals(df_payload)