Skip to content
This repository has been archived by the owner on Sep 13, 2023. It is now read-only.

Commit

Permalink
add separate reader/writer for lgb labels (#421)
Browse files Browse the repository at this point in the history
* add separate reader/writer for labels

* fix linter

* use free_raw_data as False in serialize test

* use construct in process so that labels is never None

fix some tests

fix lint

* remove the usage of construct

* fix pylint

* do not save labels when they are None

* saved key shouldn't have /data/data when labels is None

* fix test
  • Loading branch information
madhur-tandon authored Oct 17, 2022
1 parent 466fcac commit 86105ee
Show file tree
Hide file tree
Showing 2 changed files with 180 additions and 32 deletions.
93 changes: 77 additions & 16 deletions mlem/contrib/lightgbm.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import os
import posixpath
import tempfile
from typing import Any, ClassVar, Iterator, List, Optional, Tuple, Type
from typing import Any, ClassVar, Iterator, Optional, Tuple, Type

import flatdict
import lightgbm as lgb
from pydantic import BaseModel

Expand All @@ -28,6 +29,8 @@
)

LGB_REQUIREMENT = UnixPackageRequirement(package_name="libgomp1")
LIGHTGBM_DATA = "inner"
LIGHTGBM_LABEL = "label"


class LightGBMDataType(
Expand All @@ -37,20 +40,38 @@ class LightGBMDataType(
:class:`.DataType` implementation for `lightgbm.Dataset` type
:param inner: :class:`.DataType` instance for underlying data
:param labels: :class:`.DataType` instance for underlying labels
"""

type: ClassVar[str] = "lightgbm"
valid_types: ClassVar = (lgb.Dataset,)
inner: DataType
labels: Optional[DataType]

def serialize(self, instance: Any) -> dict:
    """Turn a ``lgb.Dataset`` into a serializable payload.

    The instance is first validated through ``check_type`` (called with
    ``lgb.Dataset`` and ``SerializationError``).

    :param instance: dataset whose ``data`` (and ``label``, if a labels
        type is set) are serialized
    :return: the inner serializer's output for ``instance.data`` when
        ``self.labels`` is ``None``; otherwise a dict holding that output
        under ``LIGHTGBM_DATA`` and the serialized ``instance.label``
        under ``LIGHTGBM_LABEL``
    """
    self.check_type(instance, lgb.Dataset, SerializationError)

    # The raw data is serialized the same way in both layouts.
    data_payload = self.inner.get_serializer().serialize(instance.data)
    if self.labels is None:
        # No labels type: the payload is flat, not nested under a key.
        return data_payload

    label_payload = self.labels.get_serializer().serialize(instance.label)
    return {LIGHTGBM_DATA: data_payload, LIGHTGBM_LABEL: label_payload}

def deserialize(self, obj: dict) -> Any:
v = self.inner.get_serializer().deserialize(obj)
if self.labels is not None:
data = self.inner.get_serializer().deserialize(obj[LIGHTGBM_DATA])
label = self.labels.get_serializer().deserialize(
obj[LIGHTGBM_LABEL]
)
else:
data = self.inner.get_serializer().deserialize(obj)
label = None
try:
return lgb.Dataset(v, free_raw_data=False)
return lgb.Dataset(data, label=label, free_raw_data=False)
except ValueError as e:
raise DeserializationError(
f"object: {obj} could not be converted to lightgbm dataset"
Expand All @@ -70,7 +91,12 @@ def get_writer(

@classmethod
def process(cls, obj: Any, **kwargs) -> DataType:
return LightGBMDataType(inner=DataAnalyzer.analyze(obj.data))
return LightGBMDataType(
inner=DataAnalyzer.analyze(obj.data),
labels=DataAnalyzer.analyze(obj.label)
if obj.label is not None
else None,
)

def get_model(self, prefix: str = "") -> Type[BaseModel]:
    """Return the model built by the inner data type's serializer.

    :param prefix: passed through unchanged to the inner serializer
    :return: whatever ``get_model`` of the inner serializer returns;
        the labels type is not consulted
    """
    inner_serializer = self.inner.get_serializer()
    return inner_serializer.get_model(prefix)
Expand All @@ -86,33 +112,68 @@ def write(
raise ValueError(
f"expected data to be of LightGBMDataType, got {type(data)} instead"
)
lightgbm_construct = data.data.construct()
raw_data = lightgbm_construct.get_data()
underlying_labels = lightgbm_construct.get_label().tolist()
inner_reader, art = data.inner.get_writer().write(
data.inner.copy().bind(raw_data), storage, path
)

lightgbm_raw = data.data

if data.labels is not None:
inner_reader, inner_art = data.inner.get_writer().write(
data.inner.copy().bind(lightgbm_raw.data),
storage,
posixpath.join(path, LIGHTGBM_DATA),
)
labels_reader, labels_art = data.labels.get_writer().write(
data.labels.copy().bind(lightgbm_raw.label),
storage,
posixpath.join(path, LIGHTGBM_LABEL),
)
res = dict(
flatdict.FlatterDict(
{LIGHTGBM_DATA: inner_art, LIGHTGBM_LABEL: labels_art},
delimiter="/",
)
)
else:
inner_reader, inner_art = data.inner.get_writer().write(
data.inner.copy().bind(lightgbm_raw.data),
storage,
path,
)
res = inner_art
labels_reader = None

return (
LightGBMDataReader(
data_type=data,
inner=inner_reader,
label=underlying_labels,
labels=labels_reader,
),
art,
res,
)


class LightGBMDataReader(DataReader):
type: ClassVar[str] = "lightgbm"
data_type: LightGBMDataType
inner: DataReader
label: List
labels: Optional[DataReader]

def read(self, artifacts: Artifacts) -> DataType:
inner_data_type = self.inner.read(artifacts)
return LightGBMDataType(inner=inner_data_type).bind(
if self.labels is not None:
artifacts = flatdict.FlatterDict(artifacts, delimiter="/")
inner_data_type = self.inner.read(artifacts[LIGHTGBM_DATA]) # type: ignore[arg-type]
labels_data_type = self.labels.read(artifacts[LIGHTGBM_LABEL]) # type: ignore[arg-type]
else:
inner_data_type = self.inner.read(artifacts)
labels_data_type = None
return LightGBMDataType(
inner=inner_data_type, labels=labels_data_type
).bind(
lgb.Dataset(
inner_data_type.data, label=self.label, free_raw_data=False
inner_data_type.data,
label=labels_data_type.data
if labels_data_type is not None
else None,
free_raw_data=False,
)
)

Expand Down
119 changes: 103 additions & 16 deletions tests/contrib/test_lightgbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import pytest

from mlem.contrib.lightgbm import (
LIGHTGBM_DATA,
LIGHTGBM_LABEL,
LightGBMDataReader,
LightGBMDataType,
LightGBMDataWriter,
Expand All @@ -12,7 +14,12 @@
from mlem.contrib.numpy import NumpyNdarrayType
from mlem.contrib.pandas import DataFrameType
from mlem.core.artifacts import LOCAL_STORAGE
from mlem.core.data_type import DataAnalyzer, DataType
from mlem.core.data_type import (
ArrayType,
DataAnalyzer,
DataType,
PrimitiveType,
)
from mlem.core.errors import DeserializationError, SerializationError
from mlem.core.model import ModelAnalyzer, ModelType
from mlem.core.requirements import UnixPackageRequirement
Expand Down Expand Up @@ -46,7 +53,7 @@ def df_payload():
def data_df(df_payload):
return lgb.Dataset(
df_payload,
label=np.array([0, 1]).tolist(),
label=np.array([0, 1]),
free_raw_data=False,
)

Expand Down Expand Up @@ -75,6 +82,8 @@ def test_hook_np(dtype_np: DataType):
assert set(dtype_np.get_requirements().modules) == {"lightgbm", "numpy"}
assert isinstance(dtype_np, LightGBMDataType)
assert isinstance(dtype_np.inner, NumpyNdarrayType)
assert isinstance(dtype_np.labels, ArrayType)
assert dtype_np.labels.dtype == PrimitiveType(data=None, ptype="float")
assert dtype_np.get_model().__name__ == dtype_np.inner.get_model().__name__
assert dtype_np.get_model().schema() == {
"title": "NumpyNdarray",
Expand All @@ -92,6 +101,7 @@ def test_hook_df(dtype_df: DataType):
assert set(dtype_df.get_requirements().modules) == {"lightgbm", "pandas"}
assert isinstance(dtype_df, LightGBMDataType)
assert isinstance(dtype_df.inner, DataFrameType)
assert isinstance(dtype_df.labels, NumpyNdarrayType)
assert dtype_df.get_model().__name__ == dtype_df.inner.get_model().__name__
assert dtype_df.get_model().schema() == {
"title": "DataFrame",
Expand All @@ -116,54 +126,131 @@ def test_hook_df(dtype_df: DataType):


@pytest.mark.parametrize(
"lgb_dtype, data_type",
[("dtype_np", NumpyNdarrayType), ("dtype_df", DataFrameType)],
"lgb_dtype, data_type, label_type",
[
("dtype_np", NumpyNdarrayType, ArrayType),
("dtype_df", DataFrameType, NumpyNdarrayType),
],
)
def test_lightgbm_source(lgb_dtype, data_type, request):
def test_lightgbm_source(lgb_dtype, data_type, label_type, request):
lgb_dtype = request.getfixturevalue(lgb_dtype)
assert isinstance(lgb_dtype, LightGBMDataType)
assert isinstance(lgb_dtype.inner, data_type)
assert isinstance(lgb_dtype.labels, label_type)

def custom_assert(x, y):
assert hasattr(x, "data")
assert hasattr(y, "data")
assert all(x.data == y.data)
assert all(x.label == y.label)
label_check = x.label == y.label
if isinstance(label_check, (list, np.ndarray)):
assert all(label_check)
else:
assert label_check

data_write_read_check(
artifacts = data_write_read_check(
lgb_dtype,
writer=LightGBMDataWriter(),
reader_type=LightGBMDataReader,
custom_assert=custom_assert,
)

if isinstance(lgb_dtype.inner, NumpyNdarrayType):
assert list(artifacts.keys()) == [
f"{LIGHTGBM_DATA}/data",
f"{LIGHTGBM_LABEL}/0/data",
f"{LIGHTGBM_LABEL}/1/data",
f"{LIGHTGBM_LABEL}/2/data",
f"{LIGHTGBM_LABEL}/3/data",
f"{LIGHTGBM_LABEL}/4/data",
]
assert artifacts[f"{LIGHTGBM_DATA}/data"].uri.endswith(
f"data/{LIGHTGBM_DATA}"
)
assert artifacts[f"{LIGHTGBM_LABEL}/0/data"].uri.endswith(
f"data/{LIGHTGBM_LABEL}/0"
)
assert artifacts[f"{LIGHTGBM_LABEL}/1/data"].uri.endswith(
f"data/{LIGHTGBM_LABEL}/1"
)
assert artifacts[f"{LIGHTGBM_LABEL}/2/data"].uri.endswith(
f"data/{LIGHTGBM_LABEL}/2"
)
assert artifacts[f"{LIGHTGBM_LABEL}/3/data"].uri.endswith(
f"data/{LIGHTGBM_LABEL}/3"
)
assert artifacts[f"{LIGHTGBM_LABEL}/4/data"].uri.endswith(
f"data/{LIGHTGBM_LABEL}/4"
)
else:
assert list(artifacts.keys()) == [
f"{LIGHTGBM_DATA}/data",
f"{LIGHTGBM_LABEL}/data",
]
assert artifacts[f"{LIGHTGBM_DATA}/data"].uri.endswith(
f"data/{LIGHTGBM_DATA}"
)
assert artifacts[f"{LIGHTGBM_LABEL}/data"].uri.endswith(
f"data/{LIGHTGBM_LABEL}"
)


def test_serialize__np(dtype_np, np_payload):
ds = lgb.Dataset(np_payload)
ds = lgb.Dataset(np_payload, label=np_payload.reshape((-1,)).tolist())
payload = dtype_np.serialize(ds)
assert payload == np_payload.tolist()
assert payload[LIGHTGBM_DATA] == np_payload.tolist()
assert payload[LIGHTGBM_LABEL] == np_payload.reshape((-1,)).tolist()

with pytest.raises(SerializationError):
dtype_np.serialize({"abc": 123}) # wrong type


def test_deserialize__np(dtype_np, np_payload):
ds = dtype_np.deserialize(np_payload)
ds = dtype_np.deserialize(
{
LIGHTGBM_DATA: np_payload,
LIGHTGBM_LABEL: np_payload.reshape((-1,)).tolist(),
}
)
assert isinstance(ds, lgb.Dataset)
assert np.all(ds.data == np_payload)
assert np.all(ds.label == np_payload.reshape((-1,)).tolist())

with pytest.raises(DeserializationError):
dtype_np.deserialize([[1], ["abc"]]) # illegal matrix
dtype_np.deserialize({LIGHTGBM_DATA: [[1], ["abc"]]}) # illegal matrix


def test_serialize__df(dtype_df, df_payload):
ds = lgb.Dataset(df_payload)
payload = dtype_df.serialize(ds)
assert payload["values"] == df_payload.to_dict("records")
def test_serialize__df(df_payload):
    """A label-less DataFrame dataset serializes flat and round-trips as one artifact."""
    ds = lgb.Dataset(df_payload, label=None, free_raw_data=False)
    dtype = DataType.create(obj=ds)

    # Serialize once and run every payload assertion against that dict.
    serialized = dtype.serialize(ds)
    assert serialized["values"] == df_payload.to_dict("records")
    # Membership must be tested on the serialized dict. Testing it on the
    # DataType object itself would (presumably, since it looks like a pydantic
    # model whose iteration yields (name, value) pairs - confirm) never match
    # a plain string key, so the assertion could not fail.
    assert LIGHTGBM_LABEL not in serialized

    def custom_assert(x, y):
        # Both sides must expose raw data; labels are expected to be None here,
        # so a plain equality check is enough.
        assert hasattr(x, "data")
        assert hasattr(y, "data")
        # NOTE(review): if x.data is a DataFrame, all(...) iterates column
        # labels rather than cell values - consider DataFrame.equals.
        assert all(x.data == y.data)
        assert x.label == y.label

    artifacts = data_write_read_check(
        dtype,
        writer=LightGBMDataWriter(),
        reader_type=LightGBMDataReader,
        custom_assert=custom_assert,
    )

    # Without labels the writer stores the data at the root path, so there is
    # exactly one artifact and its key is not nested under LIGHTGBM_DATA.
    assert len(artifacts.keys()) == 1
    assert list(artifacts.keys()) == ["data"]
    assert artifacts["data"].uri.endswith("/data")


def test_deserialize__df(dtype_df, df_payload):
ds = dtype_df.deserialize({"values": df_payload})
ds = dtype_df.deserialize(
{
LIGHTGBM_DATA: {"values": df_payload},
LIGHTGBM_LABEL: np.array([0, 1]).tolist(),
}
)
assert isinstance(ds, lgb.Dataset)
assert ds.data.equals(df_payload)

Expand Down

0 comments on commit 86105ee

Please sign in to comment.