From 14adcd3b6c6ca951573cc47e528c751bc98a6a40 Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Sat, 5 Oct 2024 00:00:36 +0200 Subject: [PATCH] categorical_dtype update --- docs/notebooks | 2 +- src/moscot/plotting/_utils.py | 3 +-- src/moscot/problems/time/_mixins.py | 4 ++-- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/docs/notebooks b/docs/notebooks index c48edf3d0..5b9d4e07d 160000 --- a/docs/notebooks +++ b/docs/notebooks @@ -1 +1 @@ -Subproject commit c48edf3d0acb6dc191bb571320357b9119a6c559 +Subproject commit 5b9d4e07d7188b1c391ec47a1c5d957da1ab2bca diff --git a/src/moscot/plotting/_utils.py b/src/moscot/plotting/_utils.py index 358ae5fa2..49fdf3c8a 100644 --- a/src/moscot/plotting/_utils.py +++ b/src/moscot/plotting/_utils.py @@ -18,7 +18,6 @@ import numpy as np import pandas as pd -from pandas.api.types import is_categorical_dtype from sklearn.preprocessing import MinMaxScaler import matplotlib as mpl @@ -523,7 +522,7 @@ def _color_transition(c1: str, c2: str, num: int, alpha: float) -> List[str]: def _create_col_colors(adata: AnnData, obs_col: str, subset: Union[str, List[str]]) -> Optional[mpl.colors.Colormap]: if isinstance(subset, list): subset = subset[0] - if not is_categorical_dtype(adata.obs[obs_col]): + if not isinstance(adata.obs[obs_col].dtype, pd.CategoricalDtype): raise TypeError(f"`adata.obs[{obs_col!r}] must be of categorical type.") for i, cat in enumerate(adata.obs[obs_col].cat.categories): diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py index 29a44aed0..4673934c1 100644 --- a/src/moscot/problems/time/_mixins.py +++ b/src/moscot/problems/time/_mixins.py @@ -18,7 +18,7 @@ import numpy as np import pandas as pd -from pandas.api.types import infer_dtype, is_categorical_dtype, is_numeric_dtype +from pandas.api.types import infer_dtype, is_numeric_dtype from anndata import AnnData @@ -1034,7 +1034,7 @@ def temporal_key(self: TemporalMixinProtocol[K, B], key: Optional[str]) -> None: raise KeyError(f"Unable to find temporal key in `adata.obs[{key!r}]`.") self.adata.obs[key] = self.adata.obs[key].astype("category") col = self.adata.obs[key] - if not (is_categorical_dtype(col) and is_numeric_dtype(col.cat.categories)): + if not (isinstance(col.dtype, pd.CategoricalDtype) and is_numeric_dtype(col.cat.categories)): raise TypeError( f"Expected `adata.obs[{key!r}]` to be categorical with numeric categories, " f"found `{infer_dtype(col)}`."