From f529541df1f9f0d98171bf8dfa0273d4011a5988 Mon Sep 17 00:00:00 2001 From: lixfz Date: Thu, 15 Feb 2024 22:41:40 +0800 Subject: [PATCH] Update column selector --- hypernets/tabular/cuml_ex/_data_cleaner.py | 4 ++-- hypernets/tabular/dask_ex/_transformers.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/hypernets/tabular/cuml_ex/_data_cleaner.py b/hypernets/tabular/cuml_ex/_data_cleaner.py index c974a899..97a00caf 100644 --- a/hypernets/tabular/cuml_ex/_data_cleaner.py +++ b/hypernets/tabular/cuml_ex/_data_cleaner.py @@ -30,7 +30,7 @@ def as_local(self): reduce_mem_usage=self.reduce_mem_usage, int_convert_to=self.int_convert_to) copy_attrs_as_local(self, target, 'df_meta_', 'columns_', 'dropped_constant_columns_', - 'dropped_idness_columns_', 'dropped_duplicated_columns_') + 'dropped_idness_columns_', 'dropped_duplicated_columns_') return target @@ -52,7 +52,7 @@ def _get_duplicated_columns(df): @staticmethod def replace_nan_chars(X: cudf.DataFrame, nan_chars): - cat_cols = X.select_dtypes(['object', ]) + cat_cols = X.select_dtypes(['object', 'string', ]) if cat_cols.shape[1] > 0: cat_cols = cat_cols.replace(nan_chars, cupy.nan) X[cat_cols.columns] = cat_cols diff --git a/hypernets/tabular/dask_ex/_transformers.py b/hypernets/tabular/dask_ex/_transformers.py index 7107d311..b0bb35dc 100644 --- a/hypernets/tabular/dask_ex/_transformers.py +++ b/hypernets/tabular/dask_ex/_transformers.py @@ -292,7 +292,7 @@ def fit(self, X, y=None): self.dtypes_ = {c: X[c].dtype for c in X.columns} if self.columns is None: - columns = X.select_dtypes(include=["category", 'object', 'bool']).columns.to_list() + columns = X.select_dtypes(include=['category', 'object', 'string', 'bool']).columns.to_list() else: columns = self.columns