Skip to content

Commit

Permalink
Add multi label support v2 (#306)
Browse files Browse the repository at this point in the history
* Add multi label support to xgboost ray

* fix lint

* add a missing change

* add another missing change

* fix lint
  • Loading branch information
louis-huang authored Mar 2, 2024
1 parent 5a840af commit e904925
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 9 deletions.
6 changes: 3 additions & 3 deletions xgboost_ray/data_sources/data_source.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union

import pandas as pd
from ray.actor import ActorHandle
Expand Down Expand Up @@ -118,12 +118,12 @@ def convert_to_series(data: Any) -> pd.Series:
@classmethod
def get_column(
cls, data: pd.DataFrame, column: Any
) -> Tuple[pd.Series, Optional[str]]:
) -> Tuple[pd.Series, Optional[Union[str, List]]]:
"""Helper method wrapping around convert to series.
This method should usually not be overwritten.
"""
if isinstance(column, str):
if isinstance(column, str) or isinstance(column, List):
return data[column], column
elif column is not None:
return cls.convert_to_series(column), None
Expand Down
17 changes: 14 additions & 3 deletions xgboost_ray/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,10 @@ def _split_dataframe(

label, exclude = data_source.get_column(local_data, self.label)
if exclude:
exclude_cols.add(exclude)
if isinstance(exclude, List):
exclude_cols.update(exclude)
else:
exclude_cols.add(exclude)

weight, exclude = data_source.get_column(local_data, self.weight)
if exclude:
Expand Down Expand Up @@ -406,7 +409,11 @@ def get_data_source(self) -> Type[DataSource]:
): # noqa: E721:
# Label is an object of a different type than the main data.
# We have to make sure they are compatible
if not data_source.is_data_type(self.label):
# if it's a parquet data source and label is a list,
# then we consider it a multi-label data
if not data_source.is_data_type(self.label) and not (
isinstance(self.label, List) and data_source.__name__ == "Parquet"
):
raise ValueError(
"The passed `data` and `label` types are not compatible."
"\nFIX THIS by passing the same types to the "
Expand Down Expand Up @@ -521,7 +528,11 @@ def get_data_source(self) -> Type[DataSource]:
f"RayDMatrix."
)

if self.label is not None and not isinstance(self.label, str):
if (
self.label is not None
and not isinstance(self.label, str)
and not isinstance(self.label, List)
):
raise ValueError(
f"Invalid `label` value for distributed datasets: "
f"{self.label}. Only strings are supported. "
Expand Down
59 changes: 56 additions & 3 deletions xgboost_ray/tests/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ def setUp(self):
* repeat
)
self.y = np.array([0, 1, 2, 3] * repeat)
self.multi_y = np.array(
[
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 1],
[0, 0, 1, 0],
]
* repeat
)

@classmethod
def setUpClass(cls):
Expand Down Expand Up @@ -62,7 +71,7 @@ def testColumnOrdering(self):

assert data.columns.tolist() == cols[:-1]

def _testMatrixCreation(self, in_x, in_y, **kwargs):
def _testMatrixCreation(self, in_x, in_y, multi_label=False, **kwargs):
if "sharding" not in kwargs:
kwargs["sharding"] = RayShardingMode.BATCH
mat = RayDMatrix(in_x, in_y, **kwargs)
Expand All @@ -81,7 +90,10 @@ def _load_data(params):
x, y = _load_data(params)

self.assertTrue(np.allclose(self.x, x))
self.assertTrue(np.allclose(self.y, y))
if multi_label:
self.assertTrue(np.allclose(self.multi_y, y))
else:
self.assertTrue(np.allclose(self.y, y))

# Multi actor check
mat = RayDMatrix(in_x, in_y, **kwargs)
Expand All @@ -95,7 +107,10 @@ def _load_data(params):
x2, y2 = _load_data(params)

self.assertTrue(np.allclose(self.x, concat_dataframes([x1, x2])))
self.assertTrue(np.allclose(self.y, concat_dataframes([y1, y2])))
if multi_label:
self.assertTrue(np.allclose(self.multi_y, concat_dataframes([y1, y2])))
else:
self.assertTrue(np.allclose(self.y, concat_dataframes([y1, y2])))

def testFromNumpy(self):
in_x = self.x
Expand Down Expand Up @@ -276,6 +291,22 @@ def testFromMultiCSVString(self):
[data_file_1, data_file_2], "label", distributed=True
)

def testFromParquetStringMultiLabel(self):
with tempfile.TemporaryDirectory() as dir:
data_file = os.path.join(dir, "data.parquet")

data_df = pd.DataFrame(self.x, columns=["a", "b", "c", "d"])
labels = [f"label_{label}" for label in range(4)]
data_df[labels] = self.multi_y
data_df.to_parquet(data_file)

self._testMatrixCreation(
data_file, labels, multi_label=True, distributed=False
)
self._testMatrixCreation(
data_file, labels, multi_label=True, distributed=True
)

def testFromParquetString(self):
with tempfile.TemporaryDirectory() as dir:
data_file = os.path.join(dir, "data.parquet")
Expand All @@ -287,6 +318,28 @@ def testFromParquetString(self):
self._testMatrixCreation(data_file, "label", distributed=False)
self._testMatrixCreation(data_file, "label", distributed=True)

def testFromMultiParquetStringMultiLabel(self):
with tempfile.TemporaryDirectory() as dir:
data_file_1 = os.path.join(dir, "data_1.parquet")
data_file_2 = os.path.join(dir, "data_2.parquet")

data_df = pd.DataFrame(self.x, columns=["a", "b", "c", "d"])
labels = [f"label_{label}" for label in range(4)]
data_df[labels] = self.multi_y

df_1 = data_df[0 : len(data_df) // 2]
df_2 = data_df[len(data_df) // 2 :]

df_1.to_parquet(data_file_1)
df_2.to_parquet(data_file_2)

self._testMatrixCreation(
[data_file_1, data_file_2], labels, multi_label=True, distributed=False
)
self._testMatrixCreation(
[data_file_1, data_file_2], labels, multi_label=True, distributed=True
)

def testFromMultiParquetString(self):
with tempfile.TemporaryDirectory() as dir:
data_file_1 = os.path.join(dir, "data_1.parquet")
Expand Down

0 comments on commit e904925

Please sign in to comment.