diff --git a/xgboost_ray/data_sources/ray_dataset.py b/xgboost_ray/data_sources/ray_dataset.py index 90cf8b9e..013677b9 100644 --- a/xgboost_ray/data_sources/ray_dataset.py +++ b/xgboost_ray/data_sources/ray_dataset.py @@ -57,8 +57,13 @@ def load_data(data: "ray.data.dataset.Dataset", else: data = [data[i] for i in indices] - local_df = [ds.to_pandas(limit=DATASET_TO_PANDAS_LIMIT) for ds in data] - return Pandas.load_data(pd.concat(local_df, copy=False), ignore=ignore) + if isinstance(data, ray.data.dataset.Dataset): + local_df = data.to_pandas(limit=DATASET_TO_PANDAS_LIMIT) + else: + local_df = pd.concat( + [ds.to_pandas(limit=DATASET_TO_PANDAS_LIMIT) for ds in data], + copy=False) + return Pandas.load_data(local_df, ignore=ignore) @staticmethod def convert_to_series(data: Union["ray.data.dataset.Dataset", Sequence[