diff --git a/docs/tutorials/notebooks b/docs/tutorials/notebooks index 0b945f9a..3a5a24fb 160000 --- a/docs/tutorials/notebooks +++ b/docs/tutorials/notebooks @@ -1 +1 @@ -Subproject commit 0b945f9a85edf606fe77c87605069d215cf6b02e +Subproject commit 3a5a24fbbd9882c6e381d9fe9a7446be2279f403 diff --git a/docs/usage/usage.md b/docs/usage/usage.md index 6dac5f75..7b4f4481 100644 --- a/docs/usage/usage.md +++ b/docs/usage/usage.md @@ -74,6 +74,7 @@ Other than tools, preprocessing steps usually don’t return an easily interpret :toctree: preprocessing :nosignatures: + preprocessing.encode preprocessing.pca preprocessing.regress_out preprocessing.subsample @@ -115,16 +116,6 @@ Other than tools, preprocessing steps usually don’t return an easily interpret preprocessing.mice_forest_impute ``` -### Encoding - -```{eval-rst} -.. autosummary:: - :toctree: preprocessing - :nosignatures: - - preprocessing.encode -``` - ### Normalization ```{eval-rst} @@ -406,6 +397,8 @@ Methods that extract and visualize tool-specific annotation in an AnnData object :nosignatures: anndata.infer_feature_types + anndata.feature_type_overview + anndata.replace_feature_types anndata.df_to_anndata anndata.anndata_to_df anndata.move_to_obs @@ -414,7 +407,6 @@ Methods that extract and visualize tool-specific annotation in an AnnData object anndata.get_obs_df anndata.get_var_df anndata.get_rank_features_df - anndata.type_overview ``` ## Settings diff --git a/ehrapy/anndata/__init__.py b/ehrapy/anndata/__init__.py index a762a49b..70f1e09c 100644 --- a/ehrapy/anndata/__init__.py +++ b/ehrapy/anndata/__init__.py @@ -1,4 +1,9 @@ -from ehrapy.anndata._feature_specifications import check_feature_types, infer_feature_types +from ehrapy.anndata._feature_specifications import ( + check_feature_types, + feature_type_overview, + infer_feature_types, + replace_feature_types, +) from ehrapy.anndata.anndata_ext import ( anndata_to_df, delete_from_obs, @@ -10,11 +15,12 @@ move_to_obs, move_to_x, rank_genes_groups_df, - 
type_overview, ) __all__ = [ "check_feature_types", + "replace_feature_types", + "feature_type_overview", "infer_feature_types", "anndata_to_df", "delete_from_obs", @@ -26,5 +32,4 @@ "move_to_obs", "move_to_x", "rank_genes_groups_df", - "type_overview", ] diff --git a/ehrapy/anndata/_constants.py b/ehrapy/anndata/_constants.py index f5b164b2..68ae6943 100644 --- a/ehrapy/anndata/_constants.py +++ b/ehrapy/anndata/_constants.py @@ -2,13 +2,7 @@ # ----------------------- # The column name and used values in adata.var for column types. -EHRAPY_TYPE_KEY = "ehrapy_column_type" # TODO: Change to ENCODING_TYPE_KEY -NUMERIC_TAG = "numeric" -NON_NUMERIC_TAG = "non_numeric" -NON_NUMERIC_ENCODED_TAG = "non_numeric_encoded" - - FEATURE_TYPE_KEY = "feature_type" -CONTINUOUS_TAG = "numeric" # TODO: Eventually rename to NUMERIC_TAG (as soon as the other NUMERIC_TAG is removed) +NUMERIC_TAG = "numeric" CATEGORICAL_TAG = "categorical" DATE_TAG = "date" diff --git a/ehrapy/anndata/_feature_specifications.py b/ehrapy/anndata/_feature_specifications.py index 4bf3d1fd..215a6427 100644 --- a/ehrapy/anndata/_feature_specifications.py +++ b/ehrapy/anndata/_feature_specifications.py @@ -1,18 +1,71 @@ +from collections.abc import Iterable from functools import wraps from typing import Literal import numpy as np import pandas as pd from anndata import AnnData +from dateutil.parser import isoparse # type: ignore from lamin_utils import logger from rich import print from rich.tree import Tree -from ehrapy.anndata._constants import CATEGORICAL_TAG, CONTINUOUS_TAG, DATE_TAG, FEATURE_TYPE_KEY -from ehrapy.anndata.anndata_ext import anndata_to_df +from ehrapy.anndata._constants import CATEGORICAL_TAG, DATE_TAG, FEATURE_TYPE_KEY, NUMERIC_TAG -def infer_feature_types(adata: AnnData, layer: str | None = None, output: Literal["tree", "dataframe"] | None = "tree"): +def _detect_feature_type(col: pd.Series) -> tuple[Literal["date", "categorical", "numeric"], bool]: + """Detect the feature type of a 
column in a pandas DataFrame. + + Args: + col: The column to detect the feature type for. + verbose: Whether to print warnings for uncertain feature types. Defaults to True. + + Returns: + The detected feature type (one of 'date', 'categorical', or 'numeric') and a boolean, which is True if the feature type is uncertain. + """ + n_elements = len(col) + col = col.dropna() + if len(col) == 0: + raise ValueError( + f"Feature {col.name} has only NaN values. Please drop the feature if you want to infer the feature type." + ) + majority_type = col.apply(type).value_counts().idxmax() + + if majority_type == pd.Timestamp: + return DATE_TAG, False # type: ignore + + if majority_type == str: + try: + col.apply(isoparse) + return DATE_TAG, False # type: ignore + except ValueError: + try: + col = pd.to_numeric(col, errors="raise") # Could be an encoded categorical or a numeric feature + majority_type = float + except ValueError: + # Features stored as Strings that cannot be converted to float are assumed to be categorical + return CATEGORICAL_TAG, False # type: ignore + + if majority_type not in [int, float]: + return CATEGORICAL_TAG, False # type: ignore + + # Guess categorical if the feature is an integer and the values are 0/1 to n-1/n with no gaps + if ( + (majority_type == int or (np.all(i.is_integer() for i in col))) + and (n_elements != col.nunique()) + and ( + (col.min() == 0 and np.all(np.sort(col.unique()) == np.arange(col.nunique()))) + or (col.min() == 1 and np.all(np.sort(col.unique()) == np.arange(1, col.nunique() + 1))) + ) + ): + return CATEGORICAL_TAG, True # type: ignore + + return NUMERIC_TAG, False # type: ignore + + +def infer_feature_types( + adata: AnnData, layer: str | None = None, output: Literal["tree", "dataframe"] | None = "tree", verbose: bool = True +): """Infer feature types from AnnData object. For each feature in adata.var_names, the method infers one of the following types: 'date', 'categorical', or 'numeric'. 
@@ -27,32 +80,43 @@ def infer_feature_types(adata: AnnData, layer: str | None = None, output: Litera layer: The layer to use from the AnnData object. If None, the X layer is used. output: The output format. Choose between 'tree', 'dataframe', or None. If 'tree', the feature types will be printed to the console in a tree format. If 'dataframe', a pandas DataFrame with the feature types will be returned. If None, nothing will be returned. Defaults to 'tree'. + verbose: Whether to print warnings for uncertain feature types. Defaults to True. + + Examples: + >>> import ehrapy as ep + >>> adata = ep.dt.mimic_2(encoded=False) + >>> ep.ad.infer_feature_types(adata) """ + from ehrapy.anndata.anndata_ext import anndata_to_df + feature_types = {} + uncertain_features = [] df = anndata_to_df(adata, layer=layer) for feature in adata.var_names: - col = df[feature].dropna() - majority_type = col.apply(type).value_counts().idxmax() - if majority_type == pd.Timestamp: - feature_types[feature] = DATE_TAG - elif majority_type not in [int, float, complex]: - feature_types[feature] = CATEGORICAL_TAG - # Guess categorical if the feature is an integer and the values are 0/1 to n-1 with no gaps - elif np.all(i.is_integer() for i in col) and ( - (col.min() == 0 and np.all(np.sort(col.unique()) == np.arange(col.nunique()))) - or (col.min() == 1 and np.all(np.sort(col.unique()) == np.arange(1, col.nunique() + 1))) + if ( + FEATURE_TYPE_KEY in adata.var.keys() + and adata.var[FEATURE_TYPE_KEY][feature] is not None + and not pd.isna(adata.var[FEATURE_TYPE_KEY][feature]) ): - feature_types[feature] = CATEGORICAL_TAG + feature_types[feature] = adata.var[FEATURE_TYPE_KEY][feature] else: - feature_types[feature] = CONTINUOUS_TAG + feature_types[feature], raise_warning = _detect_feature_type(df[feature]) + if raise_warning: + uncertain_features.append(feature) adata.var[FEATURE_TYPE_KEY] = pd.Series(feature_types)[adata.var_names] - logger.info( - f"Stored feature types in 
adata.var['{FEATURE_TYPE_KEY}']." - f" Please verify and adjust if necessary using adata.var['{FEATURE_TYPE_KEY}']['feature1']='corrected_type'." - ) + if verbose: + logger.warning( + f"{'Features' if len(uncertain_features) >1 else 'Feature'} {str(uncertain_features)[1:-1]} {'were' if len(uncertain_features) >1 else 'was'} detected as categorical features stored numerically." + f"Please verify and correct using `ep.ad.replace_feature_types` if necessary." + ) + + logger.info( + f"Stored feature types in adata.var['{FEATURE_TYPE_KEY}']." + f" Please verify and adjust if necessary using `ep.ad.replace_feature_types`." + ) if output == "tree": feature_type_overview(adata) @@ -65,9 +129,32 @@ def infer_feature_types(adata: AnnData, layer: str | None = None, output: Litera def check_feature_types(func): @wraps(func) def wrapper(adata, *args, **kwargs): + # Account for class methods that pass self as first argument + _self = None + if not isinstance(adata, AnnData) and len(args) > 0 and isinstance(args[0], AnnData): + _self = adata + adata = args[0] + args = args[1:] + if FEATURE_TYPE_KEY not in adata.var.keys(): - raise ValueError("Feature types are not specified in adata.var. Please run `infer_feature_types` first.") - np.all(adata.var[FEATURE_TYPE_KEY].isin([CATEGORICAL_TAG, CONTINUOUS_TAG, DATE_TAG])) + infer_feature_types(adata, output=None) + logger.warning( + f"Feature types were inferred and stored in adata.var[{FEATURE_TYPE_KEY}]. Please verify using `ep.ad.feature_type_overview` and adjust if necessary using `ep.ad.replace_feature_types`." + ) + + for feature in adata.var_names: + feature_type = adata.var[FEATURE_TYPE_KEY][feature] + if ( + feature_type is not None + and (not pd.isna(feature_type)) + and feature_type not in [CATEGORICAL_TAG, NUMERIC_TAG, DATE_TAG] + ): + logger.warning( + f"Feature '{feature}' has an invalid feature type '{feature_type}'. Please correct using `ep.ad.replace_feature_types`." 
+ ) + + if _self is not None: + return func(_self, adata, *args, **kwargs) return func(adata, *args, **kwargs) return wrapper @@ -75,7 +162,18 @@ def wrapper(adata, *args, **kwargs): @check_feature_types def feature_type_overview(adata: AnnData): - """Print an overview of the feature types in the AnnData object.""" + """Print an overview of the feature types and encoding modes in the AnnData object. + + Args: + adata: The AnnData object storing the EHR data. + + Examples: + >>> import ehrapy as ep + >>> adata = ep.dt.mimic_2(encoded=True) + >>> ep.ad.feature_type_overview(adata) + """ + from ehrapy.anndata.anndata_ext import anndata_to_df + tree = Tree( f"[b] Detected feature types for AnnData object with {len(adata.obs_names)} obs and {len(adata.var_names)} vars", guide_style="underline2", @@ -86,13 +184,55 @@ def feature_type_overview(adata: AnnData): branch.add(date) branch = tree.add("📐[b] Numerical features") - for numeric in sorted(adata.var_names[adata.var[FEATURE_TYPE_KEY] == CONTINUOUS_TAG]): + for numeric in sorted(adata.var_names[adata.var[FEATURE_TYPE_KEY] == NUMERIC_TAG]): branch.add(numeric) branch = tree.add("🗂️[b] Categorical features") cat_features = adata.var_names[adata.var[FEATURE_TYPE_KEY] == CATEGORICAL_TAG] df = anndata_to_df(adata[:, cat_features]) - for categorical in sorted(cat_features): - branch.add(f"{categorical} ({df.loc[:, categorical].nunique()} categories)") + + if "encoding_mode" in adata.var.keys(): + unencoded_vars = adata.var.loc[cat_features, "unencoded_var_names"].unique().tolist() + + for unencoded in sorted(unencoded_vars): + if unencoded in adata.var_names: + branch.add(f"{unencoded} ({df.loc[:, unencoded].nunique()} categories)") + else: + enc_mode = adata.var.loc[adata.var["unencoded_var_names"] == unencoded, "encoding_mode"].values[0] + branch.add(f"{unencoded} ({adata.obs[unencoded].nunique()} categories); {enc_mode} encoded") + + else: + for categorical in sorted(cat_features): + branch.add(f"{categorical} ({df.loc[:, 
categorical].nunique()} categories)") print(tree) + + +def replace_feature_types(adata, features: Iterable[str], corrected_type: str): + """Correct the feature types for a list of features inplace. + + Args: + adata: :class:`~anndata.AnnData` object storing the EHR data. + features: The features to correct. + corrected_type: The corrected feature type. One of 'date', 'categorical', or 'numeric'. + + Examples: + >>> import ehrapy as ep + >>> adata = ep.dt.diabetes_130_fairlearn() + >>> ep.ad.infer_feature_types(adata) + >>> ep.ad.replace_feature_types(adata, ["time_in_hospital", "number_diagnoses", "num_procedures"], "numeric") + """ + if corrected_type not in [CATEGORICAL_TAG, NUMERIC_TAG, DATE_TAG]: + raise ValueError( + f"Corrected type {corrected_type} not recognized. Choose between '{DATE_TAG}', '{CATEGORICAL_TAG}', or '{NUMERIC_TAG}'." + ) + + if FEATURE_TYPE_KEY not in adata.var.keys(): + raise ValueError( + "Feature types were not inferred. Please infer feature types using 'infer_feature_types' before correcting." 
+ ) + + if isinstance(features, str): + features = [features] + + adata.var[FEATURE_TYPE_KEY].loc[features] = corrected_type diff --git a/ehrapy/anndata/anndata_ext.py b/ehrapy/anndata/anndata_ext.py index db697850..e01ac7dd 100644 --- a/ehrapy/anndata/anndata_ext.py +++ b/ehrapy/anndata/anndata_ext.py @@ -9,14 +9,12 @@ import pandas as pd from anndata import AnnData, concat from lamin_utils import logger -from rich import print -from rich.text import Text -from rich.tree import Tree from scanpy.get import obs_df, rank_genes_groups_df, var_df from scipy import sparse from scipy.sparse import issparse -from ehrapy.anndata._constants import EHRAPY_TYPE_KEY, NON_NUMERIC_ENCODED_TAG, NON_NUMERIC_TAG, NUMERIC_TAG +from ehrapy.anndata import check_feature_types +from ehrapy.anndata._constants import FEATURE_TYPE_KEY, NUMERIC_TAG if TYPE_CHECKING: from collections.abc import Collection, Iterable, Sequence @@ -33,7 +31,7 @@ def df_to_anndata( """Transform a given pandas dataframe into an AnnData object. Note that columns containing boolean values (either 0/1 or T(t)rue/F(f)alse) - will be stored as boolean columns whereas the other non numerical columns will be stored as categorical values. + will be stored as boolean columns whereas the other non-numerical columns will be stored as categorical values. 
Args: df: The pandas dataframe to be transformed @@ -95,14 +93,7 @@ def df_to_anndata( # initializing an OrderedDict with a non-empty dict might not be intended, # see: https://stackoverflow.com/questions/25480089/right-way-to-initialize-an-ordereddict-using-its-constructor-such-that-it-retain/25480206 uns = OrderedDict() # type: ignore - # store all numerical/non-numerical columns that are not obs only - binary_columns = _detect_binary_columns(df, numerical_columns) - var = pd.DataFrame(index=list(dataframes.df.columns)) - var[EHRAPY_TYPE_KEY] = NON_NUMERIC_TAG - var.loc[var.index.isin(list(set(numerical_columns) | set(binary_columns))), EHRAPY_TYPE_KEY] = NUMERIC_TAG - # in case of encoded columns by ehrapy, want to be able to read it back in - var.loc[var.index.str.contains("ehrapycat"), EHRAPY_TYPE_KEY] = NON_NUMERIC_ENCODED_TAG all_num = True if len(numerical_columns) == len(list(dataframes.df.columns)) else False X = X.astype(np.number) if all_num else X.astype(object) @@ -176,7 +167,7 @@ def move_to_obs(adata: AnnData, to_obs: list[str] | str, copy_obs: bool = False) """Move inplace or copy features from X to obs. Note that columns containing boolean values (either 0/1 or True(true)/False(false)) - will be stored as boolean columns whereas the other non numerical columns will be stored as categorical. + will be stored as boolean columns whereas the other non-numerical columns will be stored as categorical. Args: adata: The AnnData object @@ -226,17 +217,18 @@ def move_to_obs(adata: AnnData, to_obs: list[str] | str, copy_obs: bool = False) return adata +@check_feature_types def _get_var_indices_for_type(adata: AnnData, tag: str) -> list[str]: """Get indices of columns in var for a given tag. 
Args: adata: The AnnData object - tag: The tag to search for, should be one of `NUMERIC_TAG`, `NON_NUMERIC_TAG` or `NON_NUMERIC_ENCODED_TAG` + tag: The tag to search for, should be one of 'CATEGORICAL_TAG', 'NUMERIC_TAG', 'DATE_TAG' Returns: List of numeric columns """ - return adata.var_names[adata.var[EHRAPY_TYPE_KEY] == tag].tolist() + return adata.var_names[adata.var[FEATURE_TYPE_KEY] == tag].tolist() def delete_from_obs(adata: AnnData, to_delete: list[str]) -> AnnData: @@ -311,7 +303,7 @@ def move_to_x(adata: AnnData, to_x: list[str] | str) -> AnnData: new_adata.obs = adata.obs[adata.obs.columns[~adata.obs.columns.isin(cols_not_in_x)]] # AnnData's concat discards var if they don't match in their keys, so we need to create a new var - created_var = _create_new_var(adata, cols_not_in_x) + created_var = pd.DataFrame(index=cols_not_in_x) new_adata.var = pd.concat([adata.var, created_var], axis=0) else: new_adata = adata @@ -336,141 +328,6 @@ def _get_column_indices(adata: AnnData, col_names: str | Iterable[str]) -> list[ return indices -def type_overview(data: AnnData, sort_by: str | None = None, sort_reversed: bool = False) -> None: # pragma: no cover - """Prints the current state of an :class:`~anndata.AnnData` object in a tree format. - - Output can be printed in sorted format by using one of `dtype`, `order`, `num_cats` or `None`, which sorts by data type, lexicographical order, - number of unique values (excluding NaN's) and unsorted respectively. Note that sorting by `num_cats` only affects - encoded variables currently and will display unencoded vars unsorted. - - Args: - data: :class:`~anndata.AnnData` object to display - sort_by: How the tree output should be sorted. 
One of `dtype`, `order`, `num_cats` or None (Defaults to None -> unsorted) - sort_reversed: Whether to sort in reversed order or not - - Examples: - >>> import ehrapy as ep - >>> adata = ep.dt.mimic_2(encoded=True) - >>> ep.ad.type_overview(adata) - """ - if isinstance(data, AnnData): - _adata_type_overview(data, sort_by, sort_reversed) - else: - raise ValueError(f"Unable to present object of type {type(data)}. Can only display AnnData objects!") - - -def _adata_type_overview( - adata: AnnData, sort_by: str | None = None, sort_reversed: bool = False -) -> None: # pragma: no cover - """Display the :class:`~anndata.AnnData object in its current state (encoded and unencoded variables, obs) - - Args: - adata: The :class:`~anndata.AnnData object to display - sort_by: Whether to sort output or not - sort_reversed: Whether to sort output in reversed order or not - """ - - tree = Tree( - f"[b green]Variable names for AnnData object with {len(adata.obs_names)} obs and {len(adata.var_names)} vars", - guide_style="underline2 bright_blue", - ) - - if "var_to_encoding" in adata.uns.keys(): - original_values = adata.uns["original_values_categoricals"] - branch = tree.add("🔐 Encoded variables", style="b green") - dtype_dict = _infer_dtype_per_encoded_var(list(original_values.keys()), original_values) - # sort encoded vars by lexicographical order of original values - if sort_by == "order": - encoded_list = sorted(original_values.keys(), reverse=sort_reversed) - for categorical in encoded_list: - branch.add( - f"[blue]{categorical} -> {dtype_dict[categorical][1]} categories;" - f" [green]{adata.uns['var_to_encoding'][categorical].replace('encoding', '').replace('_', ' ').strip()} [blue]encoded; [green]original data type: [blue]{dtype_dict[categorical][0]}" - ) - # sort encoded vars by data type of the original values or the number of unique values in original data (excluding NaNs) - elif sort_by == "dtype" or sort_by == "num_cats": - sorted_by_type = dict( - sorted( - 
dtype_dict.items(), key=lambda item: item[1][0 if sort_by == "dtype" else 1], reverse=sort_reversed - ) - ) - for categorical in sorted_by_type: - branch.add( - f"[blue]{categorical} -> {sorted_by_type[categorical][1]} categories;" - f" [green]{adata.uns['var_to_encoding'][categorical].replace('encoding', '').replace('_', ' ').strip()} [blue]encoded; [green]original data type: [blue]{sorted_by_type[categorical][0]}" - ) - # display in unsorted order - else: - encoded_list = original_values.keys() - for categorical in encoded_list: - branch.add( - f"[blue]{categorical} -> {dtype_dict[categorical][1]} categories;" - f" [green]{adata.uns['var_to_encoding'][categorical].replace('encoding', '').replace('_', ' ').strip()} [blue]encoded; [green]original data type: [blue]{dtype_dict[categorical][0]}" - ) - branch_num = tree.add(Text("🔓 Unencoded variables"), style="b green") - - if sort_by == "order": - var_names = sorted(adata.var_names.values, reverse=sort_reversed) - _sort_by_order_or_none(adata, branch_num, var_names) - elif sort_by == "dtype": - var_names = list(adata.var_names.values) - _sort_by_type(adata, branch_num, var_names, sort_reversed) - else: - var_names = list(adata.var_names.values) - _sort_by_order_or_none(adata, branch_num, var_names) - - if sort_by: - logger.info( - "Displaying AnnData object in sorted mode. Note that this might not be the exact same order of the variables in X or var are stored!" 
- ) - print(tree) - - -def _sort_by_order_or_none(adata: AnnData, branch, var_names: list[str]): - """Add branches to tree for sorting by order or unsorted.""" - var_names_val = list(adata.var_names.values) - for other_vars in var_names: - if not other_vars.startswith("ehrapycat"): - idx = var_names_val.index(other_vars) - unique_categoricals = pd.unique(adata.X[:, idx : idx + 1].flatten()) - data_type = pd.api.types.infer_dtype(unique_categoricals) - branch.add(f"[blue]{other_vars} -> [green]data type: [blue]{data_type}") - - -def _sort_by_type(adata: AnnData, branch, var_names: list[str], sort_reversed: bool): - """Sort tree output by datatype""" - tmp_dict = {} - var_names_val = list(adata.var_names.values) - - for other_vars in var_names: - if not other_vars.startswith("ehrapycat"): - idx = var_names_val.index(other_vars) - unique_categoricals = pd.unique(adata.X[:, idx : idx + 1].flatten()) - data_type = pd.api.types.infer_dtype(unique_categoricals) - tmp_dict[other_vars] = data_type - - sorted_by_type = dict(sorted(tmp_dict.items(), key=lambda item: item[1], reverse=sort_reversed)) - for var in sorted_by_type: - branch.add(f"[blue]{var} -> [green]data type: [blue]{sorted_by_type[var]}") - - -def _infer_dtype_per_encoded_var(encoded_list: list[str], original_values) -> dict[str, tuple[str, int]]: - """Infer dtype of each encoded varibale of an AnnData object.""" - dtype_dict = {} - for categorical in encoded_list: - unique_categoricals = pd.unique(original_values[categorical].flatten()) - categorical_type = pd.api.types.infer_dtype(unique_categoricals) - num_unique_values = pd.DataFrame(unique_categoricals).dropna()[0].nunique() - dtype_dict[categorical] = (categorical_type, num_unique_values) - - return dtype_dict - - -def _single_quote_string(name: str) -> str: # pragma: no cover - """Single quote a string to inject it into f-strings, since backslashes cannot be in double f-strings.""" - return f"'{name}'" - - def _assert_encoded(adata: AnnData): try: assert 
np.issubdtype(adata.X.dtype, np.number) @@ -478,6 +335,7 @@ def _assert_encoded(adata: AnnData): raise NotEncodedError("The AnnData object has not yet been encoded.") from AssertionError +@check_feature_types def get_numeric_vars(adata: AnnData) -> list[str]: """Fetches the column names for numeric variables in X. @@ -489,11 +347,7 @@ def get_numeric_vars(adata: AnnData) -> list[str]: """ _assert_encoded(adata) - # This behaviour is consistent with the previous behaviour, allowing for a simple fully numeric X - if EHRAPY_TYPE_KEY not in adata.var.columns: - return list(adata.var_names.values) - else: - return _get_var_indices_for_type(adata, NUMERIC_TAG) + return _get_var_indices_for_type(adata, NUMERIC_TAG) def assert_numeric_vars(adata: AnnData, vars: Sequence[str]): @@ -545,59 +399,6 @@ def set_numeric_vars( return adata -def _update_uns( - adata: AnnData, moved_columns: list[str], to_x: bool = False -) -> tuple[list[str], list[str], list[str] | None]: - """Updates .uns of adata to reflect the changes made on the object by moving columns from X to obs or vice versa. - - 1.) Moving `col1` from `X` to `obs`: `col1` is either numerical or non_numerical, so delete it from the corresponding entry in `uns` - 2.) Moving `col1` from `obs` to `X`: `col1` is either numerical or non_numerical, so add it to the corresponding entry in `uns` - - Args: - adata: class:`~anndata.AnnData` object - moved_columns: List of column names to be moved - to_x: Whether to move from `obs` to `X` or vice versa - - Returns: - :class:`~anndata.AnnData` object with updated .uns - """ - moved_columns_set = set(moved_columns) - if not to_x: # moving from `X` to `obs`, delete it from the corresponding entry in `uns`. 
- num_set = set(adata.uns["numerical_columns"].copy()) - non_num_set = set(adata.uns["non_numerical_columns"].copy()) - var_num = [] - for var in moved_columns_set: - if var in num_set: - var_num.append(var) - num_set -= {var} - elif var in non_num_set: - non_num_set -= {var} - return list(num_set), list(non_num_set), var_num - else: # moving from `obs` to `X`, add it to the corresponding entry in `uns`. - all_moved_non_num_columns = moved_columns_set & set(adata.obs.select_dtypes(exclude="number").columns) - all_moved_num_columns = list(moved_columns_set ^ all_moved_non_num_columns) - return all_moved_num_columns, list(all_moved_non_num_columns), None - - -def _create_new_var(adata: AnnData, cols_not_in_x: list[str]) -> pd.DataFrame: - """Create a new var DataFrame with the EHRAPY_TYPE_KEY column set for entries from .obs. - - Args: - adata: From where to get the .obs - cols_not_in_x: .obs columns to move to X - - Returns: - New var DataFrame with EHRAPY_TYPE_KEY column set for entries from .obs - """ - all_moved_num_columns = set(cols_not_in_x) & set(adata.obs.select_dtypes(include="number").columns) - - new_var = pd.DataFrame(index=cols_not_in_x) - new_var[EHRAPY_TYPE_KEY] = NON_NUMERIC_TAG - new_var.loc[list(all_moved_num_columns), EHRAPY_TYPE_KEY] = NUMERIC_TAG - - return new_var - - def _detect_binary_columns(df: pd.DataFrame, numerical_columns: list[str]) -> list[str]: """Detect all columns that contain only 0 and 1 (besides NaNs). 
@@ -628,7 +429,7 @@ def _cast_obs_columns(obs: pd.DataFrame) -> pd.DataFrame: """ # only cast non numerical columns object_columns = list(obs.select_dtypes(exclude=["number", "category", "bool"]).columns) - # type cast each non numerical column to either bool (if possible) or category else + # type cast each non-numerical column to either bool (if possible) or category else obs[object_columns] = obs[object_columns].apply( lambda obs_name: obs_name.astype("category") if not set(pd.unique(obs_name)).issubset({False, True, np.NaN}) diff --git a/ehrapy/data/_datasets.py b/ehrapy/data/_datasets.py index 55583eb1..37418404 100644 --- a/ehrapy/data/_datasets.py +++ b/ehrapy/data/_datasets.py @@ -3,6 +3,8 @@ from typing import TYPE_CHECKING from ehrapy import ehrapy_settings +from ehrapy.anndata import anndata_to_df, df_to_anndata, infer_feature_types, replace_feature_types +from ehrapy.anndata._constants import CATEGORICAL_TAG, DATE_TAG, FEATURE_TYPE_KEY, NUMERIC_TAG from ehrapy.io._read import read_csv, read_fhir, read_h5ad from ehrapy.preprocessing._encoding import encode @@ -37,6 +39,8 @@ def mimic_2( columns_obs_only=columns_obs_only, ) if encoded: + infer_feature_types(adata, output=None, verbose=False) + replace_feature_types(adata, "hour_icu_intime", NUMERIC_TAG) return encode(adata, autodetect=True) return adata @@ -66,7 +70,6 @@ def mimic_2_preprocessed() -> AnnData: def mimic_3_demo( - encoded: bool = False, anndata: bool = False, columns_obs_only: dict[str, list[str]] | list[str] | None = None, ) -> dict[str, AnnData] | dict[str, pd.DataFrame]: @@ -78,7 +81,6 @@ def mimic_3_demo( The resulting DataFrame can then be transformed into an AnnData object with :func:`~ehrapy.anndata.df_to_anndata`. Args: - encoded: Whether to return an already encoded object. anndata: Whether to return one AnnData object per CSV file. Defaults to False columns_obs_only: Columns to include in obs only and not X. 
@@ -97,10 +99,6 @@ def mimic_3_demo( columns_obs_only=columns_obs_only, archive_format="zip", ) - if encoded: - if not anndata: - raise ValueError("Can only encode AnnData objects. Set 'anndata=True' to get AnnData objects.") - encode(data, autodetect=True) return data @@ -133,6 +131,7 @@ def heart_failure(encoded: bool = False, columns_obs_only: dict[str, list[str]] index_column="patient_id", ) if encoded: + infer_feature_types(adata, output=None, verbose=False) return encode(adata, autodetect=True) return adata @@ -170,6 +169,11 @@ def diabetes_130_raw( columns_obs_only=columns_obs_only, ) if encoded: + infer_feature_types(adata, output=None, verbose=False) + replace_feature_types( + adata, ["admission_source_id", "discharge_disposition_id", "encounter_id", "patient_nbr"], CATEGORICAL_TAG + ) + replace_feature_types(adata, ["num_procedures", "number_diagnoses", "time_in_hospital"], NUMERIC_TAG) return encode(adata, autodetect=True) return adata @@ -211,6 +215,8 @@ def diabetes_130_fairlearn( ) if encoded: + infer_feature_types(adata, output=None, verbose=False) + replace_feature_types(adata, ["time_in_hospital", "number_diagnoses", "num_procedures"], NUMERIC_TAG) return encode(adata, autodetect=True) return adata @@ -238,13 +244,14 @@ def chronic_kidney_disease( >>> adata = ep.dt.chronic_kidney_disease(encoded=True) """ adata = read_csv( - dataset_path=f"{ehrapy_settings.datasetdir}/chronic_kidney_disease_precessed.csv", + dataset_path=f"{ehrapy_settings.datasetdir}/chronic_kidney_disease.csv", download_dataset_name="chronic_kidney_disease.csv", backup_url="https://figshare.com/ndownloader/files/33989261", columns_obs_only=columns_obs_only, index_column="Patient_id", ) if encoded: + infer_feature_types(adata, output=None, verbose=False) return encode(adata, autodetect=True) return adata @@ -279,6 +286,7 @@ def breast_tissue( index_column="patient_id", ) if encoded: + infer_feature_types(adata, output=None, verbose=False) return encode(adata, autodetect=True) 
return adata @@ -312,6 +320,8 @@ def cervical_cancer_risk_factors( index_column="patient_id", ) if encoded: + infer_feature_types(adata, output=None, verbose=False) + replace_feature_types(adata, ["STDs (number)", "STDs: Number of diagnosis"], NUMERIC_TAG) return encode(adata, autodetect=True) return adata @@ -346,6 +356,7 @@ def dermatology( index_column="patient_id", ) if encoded: + infer_feature_types(adata, output=None, verbose=False) return encode(adata, autodetect=True) return adata @@ -380,6 +391,7 @@ def echocardiogram( index_column="patient_id", ) if encoded: + infer_feature_types(adata, output=None, verbose=False) return encode(adata, autodetect=True) return adata @@ -413,6 +425,7 @@ def hepatitis( index_column="patient_id", ) if encoded: + infer_feature_types(adata, output=None, verbose=False) return encode(adata, autodetect=True) return adata @@ -447,6 +460,8 @@ def statlog_heart( index_column="patient_id", ) if encoded: + infer_feature_types(adata, output=None, verbose=False) + replace_feature_types(adata, "number of major vessels", NUMERIC_TAG) return encode(adata, autodetect=True) return adata @@ -480,6 +495,7 @@ def thyroid( index_column="patient_id", ) if encoded: + infer_feature_types(adata, output=None, verbose=False) return encode(adata, autodetect=True) return adata @@ -514,6 +530,7 @@ def breast_cancer_coimbra( index_column="patient_id", ) if encoded: + infer_feature_types(adata, output=None, verbose=False) return encode(adata, autodetect=True) return adata @@ -548,6 +565,7 @@ def parkinsons( index_column="measurement_id", ) if encoded: + infer_feature_types(adata, output=None, verbose=False) return encode(adata, autodetect=True) return adata @@ -581,6 +599,7 @@ def parkinsons_telemonitoring( index_column="measurement_id", ) if encoded: + infer_feature_types(adata, output=None, verbose=False) return encode(adata, autodetect=True) return adata @@ -615,6 +634,7 @@ def parkinsons_disease_classification( index_column="measurement_id", ) if 
encoded: + infer_feature_types(adata, output=None, verbose=False) return encode(adata, autodetect=True) return adata @@ -649,6 +669,7 @@ def parkinson_dataset_with_replicated_acoustic_features( index_column="measurement_id", ) if encoded: + infer_feature_types(adata, output=None, verbose=False) return encode(adata, autodetect=True) return adata @@ -683,6 +704,9 @@ def heart_disease( index_column="patient_id", ) if encoded: + infer_feature_types(adata, output=None, verbose=False) + replace_feature_types(adata, ["num"], NUMERIC_TAG) + replace_feature_types(adata, ["thal"], CATEGORICAL_TAG) return encode(adata, autodetect=True) return adata @@ -717,7 +741,16 @@ def synthea_1k_sample( archive_format="zip", ) + df = anndata_to_df(adata) + df.drop( + columns=[col for col in df.columns if any(isinstance(x, (list, dict)) for x in df[col].dropna())], inplace=True + ) + df.drop(columns=df.columns[df.isna().all()], inplace=True) + adata = df_to_anndata(df, index_column="id") + if encoded: + infer_feature_types(adata, output=None, verbose=False) + replace_feature_types(adata, ["resource.multipleBirthInteger", "resource.numberOfSeries"], NUMERIC_TAG) return encode(adata, autodetect=True) return adata diff --git a/ehrapy/io/_read.py b/ehrapy/io/_read.py index 157938e9..779034f6 100644 --- a/ehrapy/io/_read.py +++ b/ehrapy/io/_read.py @@ -44,7 +44,7 @@ def read_csv( columns_x_only: These columns will be added to X only and all remaining columns to obs. Note that datetime columns will always be added to .obs though. return_dfs: Whether to return one or several Pandas DataFrames. - cache: Whether to write to cache when reading or not. Defaults to False . + cache: Whether to write to cache when reading or not. Defaults to False. download_dataset_name: Name of the file or directory after download. backup_url: URL to download the data file(s) from, if the dataset is not yet on disk. is_archive: Whether the downloaded file is an archive. 
@@ -511,8 +511,6 @@ def _read_from_cache(path_cache: Path) -> AnnData: # in case columns_obs_only has not been passed except KeyError: columns_obs_only = [] - # required since reading from cache returns a numpy array instead of a list here - cached_adata.uns["numerical_columns"] = list(cached_adata.uns["numerical_columns"]) # recreate the original AnnData object with the index column for obs and obs only columns cached_adata = _decode_cached_adata(cached_adata, columns_obs_only) @@ -627,8 +625,10 @@ def _decode_cached_adata(adata: AnnData, column_obs_only: list[str]) -> AnnData: if not var_name.startswith("ehrapycat_"): break value_name = var_name[10:] - original_values = adata.uns["original_values_categoricals"][value_name] - adata.X[:, idx : idx + 1] = original_values + if value_name not in adata.obs.keys(): + raise ValueError(f"Unencoded values for feature '{value_name}' not found in obs!") + original_values = adata.obs[value_name] + adata.X[:, idx] = original_values # update var name per categorical var_names[idx] = value_name # drop all columns, that are not obs only in obs @@ -639,11 +639,8 @@ def _decode_cached_adata(adata: AnnData, column_obs_only: list[str]) -> AnnData: # set the new var names (unencoded ones) adata.var.index = var_names adata.layers["original"] = adata.X.copy() - # reset uns but keep numerical columns - numerical_columns = adata.uns["numerical_columns"] + # reset uns adata.uns = OrderedDict() - adata.uns["numerical_columns"] = numerical_columns - adata.uns["non_numerical_columns"] = list(set(adata.var_names) ^ set(numerical_columns)) return adata diff --git a/ehrapy/preprocessing/__init__.py b/ehrapy/preprocessing/__init__.py index 4cd5de71..2d9b3e6c 100644 --- a/ehrapy/preprocessing/__init__.py +++ b/ehrapy/preprocessing/__init__.py @@ -1,6 +1,6 @@ from ehrapy.preprocessing._balanced_sampling import balanced_sample from ehrapy.preprocessing._bias import detect_bias -from ehrapy.preprocessing._encoding import encode, undo_encoding +from 
ehrapy.preprocessing._encoding import encode from ehrapy.preprocessing._highly_variable_features import highly_variable_features from ehrapy.preprocessing._imputation import ( explicit_impute, @@ -28,7 +28,6 @@ "balanced_sample", "detect_bias", "encode", - "undo_encoding", "highly_variable_features", "explicit_impute", "knn_impute", diff --git a/ehrapy/preprocessing/_bias.py b/ehrapy/preprocessing/_bias.py index 5d325849..8a21d003 100644 --- a/ehrapy/preprocessing/_bias.py +++ b/ehrapy/preprocessing/_bias.py @@ -7,7 +7,7 @@ from anndata import AnnData from ehrapy.anndata import anndata_to_df, check_feature_types -from ehrapy.anndata._constants import CATEGORICAL_TAG, CONTINUOUS_TAG, DATE_TAG, FEATURE_TYPE_KEY +from ehrapy.anndata._constants import CATEGORICAL_TAG, DATE_TAG, FEATURE_TYPE_KEY, NUMERIC_TAG @check_feature_types @@ -166,7 +166,7 @@ def _get_group_name(encoded_feature: str, group_val: int) -> str | int: "Standardized Mean Difference": [], } adata.uns["smd"] = {} - continuous_var_names = adata.var_names[adata.var[FEATURE_TYPE_KEY] == CONTINUOUS_TAG] + continuous_var_names = adata.var_names[adata.var[FEATURE_TYPE_KEY] == NUMERIC_TAG] for sens_feature in cat_sens_features: sens_feature_groups = sorted(adata_df[sens_feature].unique()) if len(sens_feature_groups) == 1: @@ -188,8 +188,10 @@ def _get_group_name(encoded_feature: str, group_val: int) -> str | int: for comp_feature in continuous_var_names: if abs_smd[comp_feature] > smd_threshold: smd_results["Sensitive Feature"].append(sens_feature) - _get_group_name(sens_feature, group) if sens_feature.startswith("ehrapycat_") else group - smd_results["Sensitive Group"].append(group) + group_name = ( + _get_group_name(sens_feature, group) if sens_feature.startswith("ehrapycat_") else group + ) + smd_results["Sensitive Group"].append(group_name) smd_results["Compared Feature"].append(comp_feature) smd_results["Standardized Mean Difference"].append(smd[comp_feature]) adata.uns["smd"][sens_feature] = smd_df diff 
--git a/ehrapy/preprocessing/_encoding.py b/ehrapy/preprocessing/_encoding.py index 0cce853a..b0564e8b 100644 --- a/ehrapy/preprocessing/_encoding.py +++ b/ehrapy/preprocessing/_encoding.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections import OrderedDict, defaultdict +from collections import OrderedDict from itertools import chain from typing import Any @@ -8,26 +8,21 @@ import pandas as pd from anndata import AnnData from lamin_utils import logger -from rich import print from rich.progress import BarColumn, Progress from sklearn.preprocessing import LabelEncoder, OneHotEncoder +from ehrapy.anndata import anndata_to_df, check_feature_types from ehrapy.anndata._constants import ( CATEGORICAL_TAG, - CONTINUOUS_TAG, - DATE_TAG, - EHRAPY_TYPE_KEY, FEATURE_TYPE_KEY, - NON_NUMERIC_ENCODED_TAG, - NON_NUMERIC_TAG, NUMERIC_TAG, ) from ehrapy.anndata.anndata_ext import _get_var_indices_for_type -multi_encoding_modes = {"hash"} -available_encodings = {"one-hot", "label", "count", *multi_encoding_modes} +available_encodings = {"one-hot", "label"} +@check_feature_types def encode( adata: AnnData, autodetect: bool | dict = False, @@ -37,7 +32,7 @@ def encode( Categorical values could be either passed via parameters or are autodetected on the fly. The categorical values are also stored in obs and uns (for keeping the original, unencoded values). - The current encoding modes for each variable are also stored in uns (`var_to_encoding` key). + The current encoding modes for each variable are also stored in adata.var['encoding_mode']. Variable names in var are updated according to the encoding modes used. A variable name starting with `ehrapycat_` indicates an encoded column (or part of it). @@ -49,8 +44,6 @@ def encode( Available encodings are: 1. one-hot (https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.OneHotEncoder.html) 2. label (https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.LabelEncoder.html) - 3. 
count (https://contrib.scikit-learn.org/category_encoders/count.html) - 4. hash (https://contrib.scikit-learn.org/category_encoders/hashing.html) Args: adata: A :class:`~anndata.AnnData` object. @@ -65,7 +58,7 @@ def encode( Examples: >>> import ehrapy as ep >>> adata = ep.dt.mimic_2() - >>> adata_encoded = ep.pp.encode(adata, autodetect=True, encodings="one_hot_encoding") + >>> adata_encoded = ep.pp.encode(adata, autodetect=True, encodings="one-hot") >>> # Example using custom encodings per columns: >>> import ehrapy as ep @@ -75,259 +68,246 @@ def encode( ... adata, autodetect=False, encodings={"label": ["col1", "col2"], "one-hot": ["col3"]} ... ) """ - if isinstance(adata, AnnData): - if isinstance(encodings, str) and not autodetect: - raise ValueError("Passing a string for parameter encodings is only possible when using autodetect=True!") - elif autodetect and not isinstance(encodings, (str, type(None))): - raise ValueError( - f"Setting encode mode with autodetect=True only works by passing a string (encode mode name) or None not {type(encodings)}!" - ) - if "original" not in adata.layers.keys(): - adata.layers["original"] = adata.X.copy() - - # autodetect categorical values, which could lead to more categoricals - if autodetect: - if "var_to_encoding" in adata.uns.keys(): - logger.warning( - "The current AnnData object has been already encoded. Returning original AnnData object!" - ) - return adata - categoricals_names = _get_var_indices_for_type(adata, NON_NUMERIC_TAG) - - # no columns were detected, that would require an encoding (e.g. non-numerical columns) - if not categoricals_names: - logger.warning("Detected no columns that need to be encoded. 
Leaving passed AnnData object unchanged.") - return adata - # copy uns so it can be used in encoding process without mutating the original anndata object - orig_uns_copy = adata.uns.copy() - _add_categoricals_to_uns(adata, orig_uns_copy, categoricals_names) + if not isinstance(adata, AnnData): + raise ValueError(f"Cannot encode object of type {type(adata)}. Can only encode AnnData objects!") - encoded_x = None - encoded_var_names = adata.var_names.to_list() - if encodings not in available_encodings - multi_encoding_modes: - raise ValueError( - f"Unknown encoding mode {encodings}. Please provide one of the following encoding modes:\n" - f"{available_encodings - multi_encoding_modes}" - ) - single_encode_mode_switcher = { - "one-hot": _one_hot_encoding, - "label": _label_encoding, - } - with Progress( - "[progress.description]{task.description}", - BarColumn(), - "[progress.percentage]{task.percentage:>3.0f}%", - refresh_per_second=1500, - ) as progress: - task = progress.add_task(f"[red]Running {encodings} on detected columns ...", total=1) - # encode using the desired mode - encoded_x, encoded_var_names = single_encode_mode_switcher[encodings]( # type: ignore - adata, - encoded_x, - orig_uns_copy, - encoded_var_names, - categoricals_names, - progress, - task, - ) - progress.update(task, description="Updating layer originals ...") + if isinstance(encodings, str) and not autodetect: + raise ValueError("Passing a string for parameter encodings is only possible when using autodetect=True!") + elif autodetect and not isinstance(encodings, (str, type(None))): + raise ValueError( + f"Setting encode mode with autodetect=True only works by passing a string (encode mode name) or None not {type(encodings)}!" 
+ ) - # update layer content with the latest categorical encoding and the old other values - updated_layer = _update_layer_after_encoding( - adata.layers["original"], - encoded_x, - encoded_var_names, - adata.var_names.to_list(), - categoricals_names, - ) - progress.update(task, description=f"[bold blue]Finished {encodings} of autodetected columns.") + if "original" not in adata.layers.keys(): + adata.layers["original"] = adata.X.copy() - # copy non-encoded columns, and add new tag for encoded columns. This is needed to track encodings - new_var = pd.DataFrame(index=encoded_var_names) - new_var[EHRAPY_TYPE_KEY] = adata.var[EHRAPY_TYPE_KEY].copy() - new_var.loc[new_var.index.str.contains("ehrapycat")] = NON_NUMERIC_ENCODED_TAG - if FEATURE_TYPE_KEY in adata.var.keys(): - new_var[FEATURE_TYPE_KEY] = adata.var[FEATURE_TYPE_KEY].copy() - new_var.loc[new_var.index.str.contains("ehrapycat"), FEATURE_TYPE_KEY] = CATEGORICAL_TAG + # autodetect categorical values based on feature types stored in adata.var[FEATURE_TYPE_KEY] + if autodetect: + categoricals_names = _get_var_indices_for_type(adata, CATEGORICAL_TAG) - encoded_ann_data = AnnData( - encoded_x, - obs=adata.obs.copy(), - var=new_var, - uns=orig_uns_copy, - layers={"original": updated_layer}, + if "encoding_mode" in adata.var.keys(): + if adata.var["encoding_mode"].isnull().values.any(): + not_encoded_features = adata.var["encoding_mode"].isna().index + categoricals_names = [ + _categorical for _categorical in categoricals_names if _categorical in not_encoded_features + ] + else: + logger.warning( + "The current AnnData object has been already encoded. Returning original AnnData object!" 
) - encoded_ann_data.uns["var_to_encoding"] = {categorical: encodings for categorical in categoricals_names} - encoded_ann_data.uns["encoding_to_var"] = {encodings: categoricals_names} - - _add_categoricals_to_obs(adata, encoded_ann_data, categoricals_names) - - # user passed categorical values with encoding mode for each of them - else: - # re-encode data - if "var_to_encoding" in adata.uns.keys(): - encodings = _reorder_encodings(adata, encodings) # type: ignore - adata = _undo_encoding(adata, "all") + return adata - # are all specified encodings valid? - for encoding in encodings.keys(): # type: ignore - if encoding not in available_encodings: - raise ValueError( - f"Unknown encoding mode {encoding}. Please provide one of the following encoding modes:\n" - f"{available_encodings}" - ) - adata.uns["encoding_to_var"] = encodings - - categoricals_not_flat = list(chain(*encodings.values())) # type: ignore - # this is needed since multi-column encoding will get passed a list of list instead of a flat list - categoricals = list( - chain( - *( - _categoricals if isinstance(_categoricals, list) else (_categoricals,) - for _categoricals in categoricals_not_flat - ) - ) + # filter out categorical columns, that are already stored numerically + df_adata = anndata_to_df(adata) + categoricals_names = [ + feat for feat in categoricals_names if not np.all(df_adata[feat].apply(type).isin([int, float])) + ] + + # no columns were detected, that would require an encoding (e.g. non-numerical columns) + if not categoricals_names: + logger.warning("Detected no columns that need to be encoded. Leaving passed AnnData object unchanged.") + return adata + # update obs with the original categorical values + updated_obs = _update_obs(adata, categoricals_names) + + encoded_x = None + encoded_var_names = adata.var_names.to_list() + unencoded_var_names = adata.var_names.to_list() + if encodings not in available_encodings: + raise ValueError( + f"Unknown encoding mode {encodings}. 
Please provide one of the following encoding modes:\n" + f"{available_encodings}" ) - # ensure no categorical column gets encoded twice - if len(categoricals) != len(set(categoricals)): - raise ValueError( - "The categorical column names given contain at least one duplicate column. " - "Check the column names to ensure that no column is encoded twice!" - ) - elif any(cat in adata.var_names[adata.var[EHRAPY_TYPE_KEY] == NUMERIC_TAG] for cat in categoricals): - logger.warning( - "At least one of passed column names seems to have numerical dtype. In general it is not recommended " - "to encode numerical columns!" - ) - orig_uns_copy = adata.uns.copy() - _add_categoricals_to_uns(adata, orig_uns_copy, categoricals) - var_to_encoding = {} if "var_to_encoding" not in adata.uns.keys() else adata.uns["var_to_encoding"] - encoded_x = None - encoded_var_names = adata.var_names.to_list() - with Progress( - "[progress.description]{task.description}", - BarColumn(), - "[progress.percentage]{task.percentage:>3.0f}%", - refresh_per_second=1500, - ) as progress: - for encoding in encodings.keys(): # type: ignore - task = progress.add_task(f"[red]Setting up {encodings}", total=1) - encode_mode_switcher = { - "one-hot": _one_hot_encoding, - "label": _label_encoding, - } - progress.update(task, description=f"Running {encoding} ...") - # perform the actual encoding - encoded_x, encoded_var_names = encode_mode_switcher[encoding]( - adata, - encoded_x, - orig_uns_copy, - encoded_var_names, - encodings[encoding], # type: ignore - progress, - task, # type: ignore - ) - # update encoding history in uns - for categorical in encodings[encoding]: # type: ignore - # multi column encoding modes -> multiple encoded columns - if isinstance(categorical, list): - for column_name in categorical: - var_to_encoding[column_name] = encoding - else: - var_to_encoding[categorical] = encoding - - # update original layer content with the new categorical encoding and the old other values + 
single_encode_mode_switcher = { + "one-hot": _one_hot_encoding, + "label": _label_encoding, + } + with Progress( + "[progress.description]{task.description}", + BarColumn(), + "[progress.percentage]{task.percentage:>3.0f}%", + refresh_per_second=1500, + ) as progress: + task = progress.add_task(f"[red]Running {encodings} on detected columns ...", total=1) + # encode using the desired mode + encoded_x, encoded_var_names, unencoded_var_names = single_encode_mode_switcher[encodings]( # type: ignore + adata, + encoded_x, + updated_obs, + encoded_var_names, + unencoded_var_names, + categoricals_names, + progress, + task, + ) + progress.update(task, description="Updating layer originals ...") + + # update layer content with the latest categorical encoding and the old other values updated_layer = _update_layer_after_encoding( adata.layers["original"], encoded_x, encoded_var_names, adata.var_names.to_list(), - categoricals, + categoricals_names, ) + progress.update(task, description=f"[bold blue]Finished {encodings} of autodetected columns.") # copy non-encoded columns, and add new tag for encoded columns. 
This is needed to track encodings new_var = pd.DataFrame(index=encoded_var_names) - new_var[EHRAPY_TYPE_KEY] = adata.var[EHRAPY_TYPE_KEY].copy() - new_var.loc[new_var.index.str.contains("ehrapycat")] = NON_NUMERIC_ENCODED_TAG - if FEATURE_TYPE_KEY in adata.var.keys(): - new_var[FEATURE_TYPE_KEY] = adata.var[FEATURE_TYPE_KEY].copy() - new_var.loc[new_var.index.str.contains("ehrapycat"), FEATURE_TYPE_KEY] = CATEGORICAL_TAG - - try: - encoded_ann_data = AnnData( - X=encoded_x, - obs=adata.obs.copy(), - var=new_var, - uns=orig_uns_copy, - layers={"original": updated_layer}, - ) - # update current encodings in uns - encoded_ann_data.uns["var_to_encoding"] = var_to_encoding + new_var[FEATURE_TYPE_KEY] = adata.var[FEATURE_TYPE_KEY].copy() + new_var[FEATURE_TYPE_KEY].loc[new_var.index.str.contains("ehrapycat")] = CATEGORICAL_TAG - # if the user did not pass every non-numerical column for encoding, an Anndata object cannot be created - except ValueError: - raise AnnDataCreationError( - "Creation of AnnData object failed. Ensure that you passed all non numerical, " - "categorical values for encoding!" - ) from None + new_var["unencoded_var_names"] = unencoded_var_names - _add_categoricals_to_obs(adata, encoded_ann_data, categoricals) + new_var["encoding_mode"] = [encodings if var in categoricals_names else None for var in unencoded_var_names] - encoded_ann_data.X = encoded_ann_data.X.astype(np.float32) + encoded_ann_data = AnnData( + encoded_x, + obs=updated_obs, + var=new_var, + uns=adata.uns.copy(), + layers={"original": updated_layer}, + ) - return encoded_ann_data + # user passed categorical values with encoding mode for each of them else: - raise ValueError(f"Cannot encode object of type {type(adata)}. Can only encode AnnData objects!") + # re-encode data + if "encoding_mode" in adata.var.keys(): + encodings = _reorder_encodings(adata, encodings) # type: ignore + adata = _undo_encoding(adata) + + # are all specified encodings valid? 
+ for encoding in encodings.keys(): # type: ignore + if encoding not in available_encodings: + raise ValueError( + f"Unknown encoding mode {encoding}. Please provide one of the following encoding modes:\n" + f"{available_encodings}" + ) + categoricals = list(chain(*encodings.values())) # type: ignore -def undo_encoding( - data: AnnData, - columns: str = "all", -) -> AnnData | None: - """Undo the current encodings applied to all columns in X. + # ensure no categorical column gets encoded twice + if len(categoricals) != len(set(categoricals)): + raise ValueError( + "The categorical column names given contain at least one duplicate column. " + "Check the column names to ensure that no column is encoded twice!" + ) + elif any( + _categorical in adata.var_names[adata.var[FEATURE_TYPE_KEY] == NUMERIC_TAG] for _categorical in categoricals + ): + logger.warning( + "At least one of passed column names seems to have numerical dtype. In general it is not recommended " + "to encode numerical columns!" + ) - This currently resets the AnnData object to its initial state. 
+ updated_obs = _update_obs(adata, categoricals) + + encoding_mode = {} + encoded_x = None + encoded_var_names = adata.var_names.to_list() + unencoded_var_names = adata.var_names.to_list() + with Progress( + "[progress.description]{task.description}", + BarColumn(), + "[progress.percentage]{task.percentage:>3.0f}%", + refresh_per_second=1500, + ) as progress: + for encoding in encodings.keys(): # type: ignore + task = progress.add_task(f"[red]Setting up {encodings}", total=1) + encode_mode_switcher = { + "one-hot": _one_hot_encoding, + "label": _label_encoding, + } + progress.update(task, description=f"Running {encoding} ...") + # perform the actual encoding + encoded_x, encoded_var_names, unencoded_var_names = encode_mode_switcher[encoding]( + adata, + encoded_x, + updated_obs, + encoded_var_names, + unencoded_var_names, + encodings[encoding], # type: ignore + progress, + task, # type: ignore + ) - Args: - data: The :class:`~anndata.AnnData` object - columns: The names of the columns to reset encoding for. Defaults to all columns. + for _categorical in encodings[encoding]: # type: ignore + _categorical = [_categorical] if isinstance(_categorical, str) else _categorical # type: ignore + for column_name in _categorical: + # get idx of column in unencoded_var_names + indices = [i for i, var in enumerate(unencoded_var_names) if var == column_name] + encoded_var = [encoded_var_names[idx] for idx in indices] + for var in encoded_var: + encoding_mode[var] = encoding + + # update original layer content with the new categorical encoding and the old other values + updated_layer = _update_layer_after_encoding( + adata.layers["original"], + encoded_x, + encoded_var_names, + adata.var_names.to_list(), + categoricals, + ) - Returns: - A (partially) encoding reset :class:`~anndata.AnnData` + # copy non-encoded columns, and add new tag for encoded columns. 
This is needed to track encodings + new_var = pd.DataFrame(index=encoded_var_names) - Examples: - >>> import ehrapy as ep - >>> # adata_encoded is an encoded AnnData object - >>> adata_undone = ep.pp.encode.undo_encoding(adata_encoded) - """ - if isinstance(data, AnnData): - return _undo_encoding(data, columns) - else: - raise ValueError(f"Cannot decode object of type {type(data)}. Can only decode AnnData objects!") + new_var[FEATURE_TYPE_KEY] = adata.var[FEATURE_TYPE_KEY].copy() + new_var[FEATURE_TYPE_KEY].loc[new_var.index.str.contains("ehrapycat")] = CATEGORICAL_TAG + + new_var["unencoded_var_names"] = unencoded_var_names + + # update encoding mode in var, keeping already annotated columns + if "encoding_mode" in adata.var.keys(): + encoding_mode.update({key: value for key, value in adata.var["encoding_mode"].items() if value is not None}) + new_var["encoding_mode"] = [None] * len(new_var) + for _categorical in encoding_mode.keys(): + new_var["encoding_mode"].loc[_categorical] = encoding_mode[_categorical] + + try: + encoded_ann_data = AnnData( + X=encoded_x, + obs=updated_obs, + var=new_var, + uns=adata.uns.copy(), + layers={"original": updated_layer}, + ) + + # if the user did not pass every non-numerical column for encoding, an Anndata object cannot be created + except ValueError: + raise AnnDataCreationError( + "Creation of AnnData object failed. Ensure that you passed all non numerical, " + "categorical values for encoding!" + ) from None + + encoded_ann_data.X = encoded_ann_data.X.astype(np.float32) + + return encoded_ann_data def _one_hot_encoding( adata: AnnData, X: np.ndarray | None, - uns: dict[str, Any], + updated_obs: pd.DataFrame, var_names: list[str], + unencoded_var_names: list[str], categories: list[str], progress: Progress, task, -) -> tuple[np.ndarray, list[str]]: +) -> tuple[np.ndarray, list[str], list[str]]: """Encode categorical columns using one hot encoding. 
Args: adata: The current AnnData object X: Current (encoded) X - uns: A copy of the original uns + updated_obs: A copy of the original obs where the original categorical values are stored that will be encoded var_names: Var names of current AnnData object categories: The name of the categorical columns to be encoded Returns: Encoded new X and the corresponding new var names """ - original_values = _initial_encoding(uns, categories) + original_values = _initial_encoding(updated_obs, categories) progress.update(task, description="[bold blue]Running one-hot encoding on passed columns ...") encoder = OneHotEncoder(handle_unknown="ignore", sparse_output=False).fit(original_values) @@ -336,6 +316,7 @@ def _one_hot_encoding( for idx, category in enumerate(categories) for suffix in encoder.categories_[idx] ] + unencoded_prefixes = [category for idx, category in enumerate(categories) for suffix in encoder.categories_[idx]] transformed = encoder.transform(original_values) # X is None if this is the first encoding "round" -> take the former X if X is None: @@ -343,34 +324,37 @@ def _one_hot_encoding( progress.advance(task, 1) progress.update(task, description="[blue]Updating X and var ...") - temp_x, temp_var_names = _update_encoded_data(X, transformed, var_names, categorical_prefixes, categories) + temp_x, temp_var_names, unencoded_var_names = _update_encoded_data( + X, transformed, var_names, categorical_prefixes, categories, unencoded_prefixes, unencoded_var_names + ) progress.update(task, description="[blue]Finished one-hot encoding.") - return temp_x, temp_var_names + return temp_x, temp_var_names, unencoded_var_names def _label_encoding( adata: AnnData, X: np.ndarray | None, - uns: dict[str, Any], + updated_obs: pd.DataFrame, var_names: list[str], + unencoded_var_names: list[str], categoricals: list[str], progress: Progress, task, -) -> tuple[np.ndarray, list[str]]: +) -> tuple[np.ndarray, list[str], list[str]]: """Encode categorical columns using label encoding. 
Args: adata: The current AnnData object X: Current (encoded) X - uns: A copy of the original uns + updated_obs: A copy of the original obs where the original categorical values are stored that will be encoded var_names: Var names of current AnnData object categoricals: The name of the categorical columns, that need to be encoded Returns: Encoded new X and the corresponding new var names """ - original_values = _initial_encoding(uns, categoricals) + original_values = _initial_encoding(updated_obs, categoricals) # label encoding expects input array to be 1D, so iterate over all columns and encode them one by one for idx in range(original_values.shape[1]): progress.update(task, description=f"[blue]Running label encoding on column {categoricals[idx]} ...") @@ -381,16 +365,18 @@ def _label_encoding( # need a column vector instead of row vector original_values[:, idx : idx + 1] = transformed[..., None] progress.advance(task, 1 / len(categoricals)) - category_prefixes = [f"ehrapycat_{categorical}" for categorical in categoricals] + category_prefixes = [f"ehrapycat_{_categorical}" for _categorical in categoricals] # X is None if this is the first encoding "round" -> take the former X if X is None: X = adata.X progress.update(task, description="[blue]Updating X and var ...") - temp_x, temp_var_names = _update_encoded_data(X, original_values, var_names, category_prefixes, categoricals) + temp_x, temp_var_names, unencoded_var_names = _update_encoded_data( + X, original_values, var_names, category_prefixes, categoricals, categoricals, unencoded_var_names + ) progress.update(task, description="[blue]Finished label encoding.") - return temp_x, temp_var_names + return temp_x, temp_var_names, unencoded_var_names def _update_layer_after_encoding( @@ -440,47 +426,15 @@ def _update_layer_after_encoding( raise ValueError("Ensure that all columns which require encoding are being encoded.") from e -def _update_multi_encoded_data( - X: np.ndarray, - transformed: np.ndarray, - var_names: 
list[str], - encoded_var_names: list[str], - categoricals: list[str], -) -> tuple[np.ndarray, list[str]]: - """Update X and var_names after applying multi column encoding modes to some columns - - Args: - X: Current (former) X - transformed: The encoded (transformed) categorical columns - var_names: Var names of current AnnData object - encoded_var_names: The name(s) of the encoded column(s) - categoricals: The categorical values that were encoded recently - - Returns: - Encoded new X and the corresponding new var names - """ - idx = [] - for pos, name in enumerate(var_names): - if name in categoricals: - idx.append(pos) - # delete the original categorical column - del_cat_column_x = np.delete(X, list(idx), 1) - # create the new, encoded X - temp_x = np.hstack((transformed, del_cat_column_x)) - # delete old categorical name - var_names = [col_name for col_idx, col_name in enumerate(var_names) if col_idx not in idx] - temp_var_names = encoded_var_names + var_names - - return temp_x, temp_var_names - - def _update_encoded_data( X: np.ndarray, transformed: np.ndarray, var_names: list[str], categorical_prefixes: list[str], categoricals: list[str], -) -> tuple[np.ndarray, list[str]]: + unencoded_prefixes: list[str], + unencoded_var_names: list[str], +) -> tuple[np.ndarray, list[str], list[str]]: """Update X and var_names after each encoding. 
Args: @@ -489,9 +443,10 @@ def _update_encoded_data( var_names: Var names of current AnnData object categorical_prefixes: The name(s) of the encoded column(s) categoricals: The categorical values that were encoded recently + unencoded_prefixes: The unencoded names of the categorical columns that were encoded Returns: - Encoded new X and the corresponding new var names + Encoded new X, the corresponding new var names, and the unencoded var names """ idx = _get_categoricals_old_indices(var_names, categoricals) # delete the original categorical column @@ -502,94 +457,69 @@ def _update_encoded_data( var_names = [col_name for col_idx, col_name in enumerate(var_names) if col_idx not in idx] temp_var_names = categorical_prefixes + var_names - return temp_x, temp_var_names + unencoded_var_names = [col_name for col_idx, col_name in enumerate(unencoded_var_names) if col_idx not in idx] + unencoded_var_names = unencoded_prefixes + unencoded_var_names + + return temp_x, temp_var_names, unencoded_var_names def _initial_encoding( - uns: dict[str, Any], + obs: pd.DataFrame, categoricals: list[str], ) -> np.ndarray: """Get all original values for all categoricals that need to be encoded (again). 
Args: - uns: A copy of the original AnnData object's uns + obs: A copy of the original obs where the original categorical values are stored that will be encoded categoricals: All categoricals that need to be encoded Returns: Numpy array of all original categorial values """ - uns_: dict[str, np.ndarray] = uns # create numpy array from all original categorical values, that will be encoded (again) - array = np.array( - [uns_["original_values_categoricals"][categoricals[i]].ravel() for i in range(len(categoricals))] - ).transpose() + array = np.array([obs[categoricals[i]].ravel() for i in range(len(categoricals))]).transpose() return array def _undo_encoding( adata: AnnData, - columns: str = "all", - suppress_warning: bool = False, + verbose: bool = True, ) -> AnnData | None: """Undo the current encodings applied to all columns in X. This currently resets the AnnData object to its initial state. Args: adata: The AnnData object - columns: The names of the columns to reset encoding for. Defaults to all columns. This resets the AnnData object to its initial state. - suppress_warning: Whether warnings should be suppressed or not. + verbose: Set to False to suppress warnings. Defaults to True. 
Returns: A (partially) encoding reset AnnData object """ - if "var_to_encoding" not in adata.uns.keys(): - if not suppress_warning: - logger.warning("Calling undo_encoding on unencoded AnnData object.") - return None + # get all encoded features + categoricals = _get_encoded_features(adata) - # get all encoded variables - encoded_categoricals = list(adata.uns["original_values_categoricals"].keys()) # get all columns that should be stored in obs only - columns_obs_only = [ - column_name for column_name in list(adata.obs.columns) if column_name not in encoded_categoricals - ] + columns_obs_only = [column_name for column_name in list(adata.obs.columns) if column_name not in categoricals] - if columns == "all": - categoricals = list(adata.uns["original_values_categoricals"].keys()) - else: - logger.error("Currently, one can only reset encodings for all columns! Aborting...") - return None - transformed = _initial_encoding(adata.uns, categoricals) + transformed = _initial_encoding(adata.obs, categoricals) temp_x, temp_var_names = _delete_all_encodings(adata) new_x = np.hstack((transformed, temp_x)) if temp_x is not None else transformed new_var_names = categoricals + temp_var_names if temp_var_names is not None else categoricals + # only keep columns in obs that were stored in obs only -> delete every encoded column from obs new_obs = adata.obs[columns_obs_only] - uns = OrderedDict() - # reset uns and keep numerical/non-numerical columns - num_vars = _get_var_indices_for_type(adata, NUMERIC_TAG) - non_num_vars = _get_var_indices_for_type(adata, NON_NUMERIC_TAG) - for cat in categoricals: - original_values = adata.uns["original_values_categoricals"][cat] - type_first_nan = original_values[np.where(original_values != np.nan)][0] - if isinstance(type_first_nan, (int, float, complex)) and not isinstance(type_first_nan, bool): - num_vars.append(cat) - else: - non_num_vars.append(cat) var = pd.DataFrame(index=new_var_names) - var[EHRAPY_TYPE_KEY] = NON_NUMERIC_TAG - # Notice 
previously encoded columns are now newly added, and will stay tagged as non numeric - var.loc[num_vars, EHRAPY_TYPE_KEY] = NUMERIC_TAG - - uns["numerical_columns"] = num_vars - uns["non_numerical_columns"] = non_num_vars + var[FEATURE_TYPE_KEY] = [ + adata.var.loc[adata.var["unencoded_var_names"] == unenc_var_name, FEATURE_TYPE_KEY].unique()[0] + for unenc_var_name in new_var_names + ] return AnnData( new_x, obs=new_obs, var=var, - uns=uns, + uns=OrderedDict(), layers={"original": new_x.copy()}, ) @@ -619,7 +549,7 @@ def _delete_all_encodings(adata: AnnData) -> tuple[np.ndarray | None, list | Non def _reorder_encodings(adata: AnnData, new_encodings: dict[str, list[list[str]] | list[str]]): - """Reorder the encodings and update which column will be encoded using which mode (with which columns in case of multi column encoding modes). + """Reorder the encodings and update which column will be encoded using which mode. Args: adata: The AnnData object to be reencoded @@ -628,8 +558,8 @@ def _reorder_encodings(adata: AnnData, new_encodings: dict[str, list[list[str]] Returns: An updated encoding scheme """ - flattened_modes: list[list[str] | str] = sum(new_encodings.values(), []) # type: ignore - latest_encoded_columns = list(chain(*(i if isinstance(i, list) else (i,) for i in flattened_modes))) + latest_encoded_columns = sum(new_encodings.values(), []) + # check for duplicates and raise an error if any if len(set(latest_encoded_columns)) != len(latest_encoded_columns): logger.error( @@ -637,63 +567,25 @@ def _reorder_encodings(adata: AnnData, new_encodings: dict[str, list[list[str]] "cannot be encoded at the same time using different encoding modes!" 
) raise DuplicateColumnEncodingError - old_encode_mode = adata.uns["var_to_encoding"] - for categorical in latest_encoded_columns: - encode_mode = old_encode_mode.get(categorical) - # if None, this categorical has not been encoded before but will be encoded now for the first time - # multi column encoding mode - if encode_mode in multi_encoding_modes: - encoded_categoricals_with_mode = adata.uns["encoding_to_var"][encode_mode] - _found = False - for column_list in encoded_categoricals_with_mode: - for column_name in column_list: - if column_name == categorical: - column_list.remove(column_name) - _found = True - break - # a categorical can only be encoded once and therefore found once - if _found: - break - # filter all lists that are now empty since all variables will be reencoded from this list - updated_multi_list = filter(None, encoded_categoricals_with_mode) - # if no columns remain that will be encoded with this encode mode, delete this mode from modes as well - if not list(updated_multi_list): - del adata.uns["encoding_to_var"][encode_mode] - # single column encoding mode - elif encode_mode in available_encodings: - encoded_categoricals_with_mode = adata.uns["encoding_to_var"][encode_mode] - for ind, column_name in enumerate(encoded_categoricals_with_mode): - if column_name == categorical: - del encoded_categoricals_with_mode[ind] - break - # if encoding mode is - if not encoded_categoricals_with_mode: - del adata.uns["encoding_to_var"][encode_mode] - - return _update_new_encode_modes(new_encodings, adata.uns["encoding_to_var"]) - - -def _update_new_encode_modes( - new_encodings: dict[str, list[list[str]] | list[str]], - filtered_old_encodings: dict[str, list[list[str]] | list[str]], -): - """Update the encoding scheme. - - If the encoding mode exists in the filtered old encodings, append all values (columns that should be encoded using this mode) to this key. 
- If not, defaultdict ensures that no KeyError will be raised and the values are simply appended to the default value ([]). - Args: - new_encodings: The new encoding modes passed by the user; basically what will be passed for encodings when calling the encode API - filtered_old_encodings: The old encoding modes, but with all columns removed that will be reencoded + encodings = {} + for encode_mode in available_encodings: + encoded_categoricals_with_mode = ( + adata.var.loc[adata.var["encoding_mode"] == encode_mode, "unencoded_var_names"].unique().tolist() + ) - Returns: - The updated encoding scheme - """ - updated_encodings = defaultdict(list) # type: ignore - for k, v in chain(new_encodings.items(), filtered_old_encodings.items()): - updated_encodings[k] += v + encodings[encode_mode] = new_encodings[encode_mode] if encode_mode in new_encodings.keys() else [] + # add all columns that were encoded with the current mode before and are not reencoded + encodings[encode_mode] += [ + _categorical + for _categorical in encoded_categoricals_with_mode + if _categorical not in latest_encoded_columns + ] - return dict(updated_encodings) + if len(encodings[encode_mode]) == 0: + del encodings[encode_mode] + + return encodings def _get_categoricals_old_indices(old_var_names: list[str], encoded_categories: list[str]) -> set[int]: @@ -720,53 +612,49 @@ def _get_categoricals_old_indices(old_var_names: list[str], encoded_categories: return idx_list -def _add_categoricals_to_obs(original: AnnData, new: AnnData, categorical_names: list[str]) -> None: +def _update_obs(adata: AnnData, categorical_names: list[str]) -> pd.DataFrame: """Add the original categorical values to obs. 
Args: - original: The original AnnData object - new: The new AnnData object + adata: The original AnnData object categorical_names: Name of each categorical column + + Returns: + Updated obs with the original categorical values added """ - for idx, var_name in enumerate(original.var_names): - if var_name in new.obs.columns: + updated_obs = adata.obs.copy() + for idx, var_name in enumerate(adata.var_names): + if var_name in updated_obs.columns: continue elif var_name in categorical_names: - new.obs[var_name] = original.X[::, idx : idx + 1].flatten() + updated_obs[var_name] = adata.X[::, idx : idx + 1].flatten() # note: this will count binary columns (0 and 1 only) as well # needed for writing to .h5ad files - if set(pd.unique(new.obs[var_name])).issubset({False, True, np.NaN}): - new.obs[var_name] = new.obs[var_name].astype("bool") - # get all non bool object columns and cast the to category dtype - object_columns = list(new.obs.select_dtypes(include="object").columns) - new.obs[object_columns] = new.obs[object_columns].astype("category") + if set(pd.unique(updated_obs[var_name])).issubset({False, True, np.NaN}): + updated_obs[var_name] = updated_obs[var_name].astype("bool") + # get all non bool object columns and cast them to category dtype + object_columns = list(updated_obs.select_dtypes(include="object").columns) + updated_obs[object_columns] = updated_obs[object_columns].astype("category") logger.info(f"The original categorical values `{categorical_names}` were added to obs.") + return updated_obs -def _add_categoricals_to_uns(original: AnnData, new: AnnData, categorical_names: list[str]) -> None: - """Add the original categorical values to uns. 
- Args: - original: The original AnnData object - new: The new AnnData object - categorical_names: Name of each categorical column - """ - is_initial = "original_values_categoricals" in original.uns.keys() - new["original_values_categoricals"] = {} if not is_initial else original.uns["original_values_categoricals"].copy() - - for idx, var_name in enumerate(original.var_names): - if is_initial and var_name in new["original_values_categoricals"]: - continue - elif var_name in categorical_names: - # keep numerical dtype when writing original values to uns - if var_name in original.var_names[original.var[EHRAPY_TYPE_KEY] == NUMERIC_TAG]: - new["original_values_categoricals"][var_name] = original.X[::, idx : idx + 1].astype("float") - else: - new["original_values_categoricals"][var_name] = original.X[::, idx : idx + 1].astype("str") +def _get_encoded_features(adata: AnnData) -> list[str]: + """Get all encoded features in an AnnData object. + Args: + adata: The AnnData object -class AlreadyEncodedWarning(UserWarning): - pass + Returns: + List of all unencoded names of features that were encoded + """ + encoded_features = [ + unencoded_feature + for enc_mode, unencoded_feature in adata.var[["encoding_mode", "unencoded_var_names"]].values + if enc_mode is not None and not pd.isna(enc_mode) + ] + return list(set(encoded_features)) class AnnDataCreationError(ValueError): @@ -775,7 +663,3 @@ class AnnDataCreationError(ValueError): class DuplicateColumnEncodingError(ValueError): pass - - -class HashEncodingError(Exception): - pass diff --git a/ehrapy/preprocessing/_quality_control.py b/ehrapy/preprocessing/_quality_control.py index b84f060c..d27ed5b9 100644 --- a/ehrapy/preprocessing/_quality_control.py +++ b/ehrapy/preprocessing/_quality_control.py @@ -7,10 +7,10 @@ import numpy as np import pandas as pd from lamin_utils import logger -from rich import print from thefuzz import process from ehrapy.anndata import anndata_to_df +from ehrapy.preprocessing._encoding import 
_get_encoded_features if TYPE_CHECKING: from collections.abc import Collection @@ -104,15 +104,19 @@ def _obs_qc_metrics( obs_metrics = pd.DataFrame(index=adata.obs_names) var_metrics = pd.DataFrame(index=adata.var_names) mtx = adata.X if layer is None else adata.layers[layer] - if "original_values_categoricals" in adata.uns: - for original_values_categorical in list(adata.uns["original_values_categoricals"]): + + if "encoding_mode" in adata.var: + for original_values_categorical in _get_encoded_features(adata): mtx = mtx.astype(object) index = np.where(var_metrics.index.str.contains(original_values_categorical))[0] + + if original_values_categorical not in adata.obs.keys(): + raise KeyError(f"Original values for {original_values_categorical} not found in adata.obs.") mtx[:, index[0]] = np.squeeze( np.where( - adata.uns["original_values_categoricals"][original_values_categorical].astype(object) == "nan", + adata.obs[original_values_categorical].astype(object) == "nan", np.nan, - adata.uns["original_values_categoricals"][original_values_categorical].astype(object), + adata.obs[original_values_categorical].astype(object), ) ) @@ -136,16 +140,20 @@ def _var_qc_metrics(adata: AnnData, layer: str | None = None) -> pd.DataFrame: var_metrics = pd.DataFrame(index=adata.var_names) mtx = adata.X if layer is None else adata.layers[layer] categorical_indices = np.ndarray([0], dtype=int) - if "original_values_categoricals" in adata.uns: - for original_values_categorical in list(adata.uns["original_values_categoricals"]): + + if "encoding_mode" in adata.var.keys(): + for original_values_categorical in _get_encoded_features(adata): mtx = copy.deepcopy(mtx.astype(object)) index = np.where(var_metrics.index.str.startswith("ehrapycat_" + original_values_categorical))[0] + + if original_values_categorical not in adata.obs.keys(): + raise KeyError(f"Original values for {original_values_categorical} not found in adata.obs.") mtx[:, index] = np.tile( np.where( - 
adata.uns["original_values_categoricals"][original_values_categorical].astype(object) == "nan", + adata.obs[original_values_categorical].astype(object) == "nan", np.nan, - adata.uns["original_values_categoricals"][original_values_categorical].astype(object), - ), + adata.obs[original_values_categorical].astype(object), + ).reshape(-1, 1), mtx[:, index].shape[1], ) categorical_indices = np.concatenate([categorical_indices, index]) diff --git a/ehrapy/tools/cohort_tracking/_cohort_tracker.py b/ehrapy/tools/cohort_tracking/_cohort_tracker.py index 41c1e714..2e3fc362 100644 --- a/ehrapy/tools/cohort_tracking/_cohort_tracker.py +++ b/ehrapy/tools/cohort_tracking/_cohort_tracker.py @@ -13,6 +13,9 @@ from scanpy import AnnData from tableone import TableOne +from ehrapy.anndata._constants import CATEGORICAL_TAG, DATE_TAG, FEATURE_TYPE_KEY, NUMERIC_TAG +from ehrapy.anndata._feature_specifications import _detect_feature_type + if TYPE_CHECKING: from collections.abc import Sequence @@ -35,14 +38,6 @@ def _check_no_new_categories(df: pd.DataFrame, categorical: pd.DataFrame, catego raise ValueError(f"New category in {col}: {diff}") -def _detect_categorical_columns(data) -> list: - # TODO grab this from ehrapy once https://github.com/theislab/ehrapy/issues/662 addressed - numeric_cols = set(data.select_dtypes("number").columns) - categorical_cols = set(data.columns) - numeric_cols - - return list(categorical_cols) - - import matplotlib.text as mtext @@ -92,10 +87,15 @@ def __init__(self, adata: AnnData, columns: Sequence = None, categorical: Sequen self._tracked_text: list = [] self._tracked_operations: list = [] - # if categorical columns specified, use them - # else, follow tableone's logic + # if categorical columns specified, use them, else infer the feature types self.categorical = ( - categorical if categorical is not None else _detect_categorical_columns(adata.obs[self.columns]) + categorical + if categorical is not None + else [ + col + for col in 
adata.obs[self.columns].columns + if _detect_feature_type(adata.obs[col])[0] == CATEGORICAL_TAG + ] ) self._categorical_categories: dict = { @@ -461,7 +461,7 @@ def plot_flowchart( Examples: >>> import ehrapy as ep - >>> adata = ep.dt.diabetes_130_fairlearn(columns_obs_only="gender", "race") + >>> adata = ep.dt.diabetes_130_fairlearn(columns_obs_only=["gender", "race"]) >>> cohort_tracker = ep.tl.CohortTracker(adata) >>> cohort_tracker(adata, label="Initial Cohort") >>> adata = adata[:1000] diff --git a/ehrapy/tools/feature_ranking/_feature_importances.py b/ehrapy/tools/feature_ranking/_feature_importances.py index 280ee138..d636aa4d 100644 --- a/ehrapy/tools/feature_ranking/_feature_importances.py +++ b/ehrapy/tools/feature_ranking/_feature_importances.py @@ -12,7 +12,7 @@ from sklearn.svm import SVC, SVR from ehrapy.anndata import anndata_to_df, check_feature_types -from ehrapy.anndata._constants import CATEGORICAL_TAG, CONTINUOUS_TAG, DATE_TAG, FEATURE_TYPE_KEY +from ehrapy.anndata._constants import CATEGORICAL_TAG, DATE_TAG, FEATURE_TYPE_KEY, NUMERIC_TAG @check_feature_types @@ -92,7 +92,7 @@ def rank_features_supervised( f"Feature {predicted_feature} is of type 'date' and cannot be used for prediction. Please choose a continuous or categorical feature." 
) - if prediction_type == CONTINUOUS_TAG: + if prediction_type == NUMERIC_TAG: if model == "regression": predictor = LinearRegression(**kwargs) elif model == "svm": diff --git a/ehrapy/tools/feature_ranking/_rank_features_groups.py b/ehrapy/tools/feature_ranking/_rank_features_groups.py index bd3b3ec3..a8d89d18 100644 --- a/ehrapy/tools/feature_ranking/_rank_features_groups.py +++ b/ehrapy/tools/feature_ranking/_rank_features_groups.py @@ -7,8 +7,13 @@ import pandas as pd import scanpy as sc -from ehrapy.anndata import move_to_x -from ehrapy.anndata._constants import EHRAPY_TYPE_KEY, NON_NUMERIC_ENCODED_TAG, NUMERIC_TAG +from ehrapy.anndata import check_feature_types, infer_feature_types, move_to_x +from ehrapy.anndata._constants import ( + CATEGORICAL_TAG, + DATE_TAG, + FEATURE_TYPE_KEY, + NUMERIC_TAG, +) from ehrapy.preprocessing import encode if TYPE_CHECKING: @@ -148,6 +153,7 @@ def _get_groups_order(groups_subset, group_names, reference): return tuple(groups_order) +@check_feature_types def _evaluate_categorical_features( adata, groupby, @@ -202,11 +208,14 @@ def _evaluate_categorical_features( groups_order = _get_groups_order(groups_subset=groups, group_names=group_names, reference=reference) groups_values = adata.obs[groupby].to_numpy() - for feature in adata.var_names[adata.var[EHRAPY_TYPE_KEY] == NON_NUMERIC_ENCODED_TAG]: + for feature in adata.var_names[adata.var[FEATURE_TYPE_KEY] == CATEGORICAL_TAG]: if feature == groupby or "ehrapycat_" + feature == groupby or feature == "ehrapycat_" + groupby: continue - feature_values = adata[:, feature].X.flatten().toarray() + try: + feature_values = adata[:, feature].X.flatten().toarray() + except ValueError as e: + raise ValueError(f"Feature {feature} is not encoded. 
Please encode it using `ehrapy.pp.encode`") from e pvals = [] scores = [] @@ -296,6 +305,7 @@ def _check_columns_to_rank_dict(columns_to_rank): return _var_subset, _obs_subset +@check_feature_types def rank_features_groups( adata: AnnData, groupby: str, @@ -351,7 +361,7 @@ def rank_features_groups( columns_to_rank: Subset of columns to rank. If 'all', all columns are used. If a dictionary, it must have keys 'var_names' and/or 'obs_names' and values must be iterables of strings such as {'var_names': ['glucose'], 'obs_names': ['age', 'height']}. - **kwds: Are passed to test methods. Currently this affects only parameters that + **kwds: Are passed to test methods. Currently, this affects only parameters that are passed to :class:`sklearn.linear_model.LogisticRegression`. For instance, you can pass `penalty='l1'` to try to come up with a minimal set of genes that are good predictors (sparse solution meaning few non-zero fitted coefficients). @@ -459,6 +469,22 @@ def rank_features_groups( # the 0th column is a dummy of zeros and is meaningless in this case, and needs to be removed adata_minimal = adata_minimal[:, 1:] + # if the feature type is set in adata.obs, we store the respective feature type in adata_minimal.var + adata_minimal.var[FEATURE_TYPE_KEY] = [ + adata.var[FEATURE_TYPE_KEY].loc[feature] + if feature not in adata.obs.keys() and FEATURE_TYPE_KEY in adata.var.keys() + else CATEGORICAL_TAG + if adata.obs[feature].dtype == "category" + else DATE_TAG + if pd.api.types.is_datetime64_any_dtype(adata.obs[feature]) + else NUMERIC_TAG + if pd.api.types.is_numeric_dtype(adata.obs[feature]) + else None + for feature in adata_minimal.var_names + ] + # we infer the feature type for all features for which adata.obs did not provide information on the type + infer_feature_types(adata_minimal, output=None) + adata_minimal = encode(adata_minimal, autodetect=True, encodings="label") # this is needed because encode() doesn't add this key if there are no categorical columns to 
encode if "encoded_non_numerical_columns" not in adata_minimal.uns: @@ -486,12 +512,12 @@ def rank_features_groups( group_names = pd.Categorical(adata.obs[groupby].astype(str)).categories.tolist() - if list(adata.var_names[adata.var[EHRAPY_TYPE_KEY] == NUMERIC_TAG]): + if list(adata.var_names[adata.var[FEATURE_TYPE_KEY] == NUMERIC_TAG]): # Rank numerical features # Without copying `numerical_adata` is a view, and code throws an error # because of "object" type of .X - numerical_adata = adata[:, adata.var_names[adata.var[EHRAPY_TYPE_KEY] == NUMERIC_TAG]].copy() + numerical_adata = adata[:, adata.var_names[adata.var[FEATURE_TYPE_KEY] == NUMERIC_TAG]].copy() numerical_adata.X = numerical_adata.X.astype(float) sc.tl.rank_genes_groups( @@ -524,7 +550,7 @@ def rank_features_groups( groups_order=group_names, ) - if list(adata.var_names[adata.var[EHRAPY_TYPE_KEY] == NON_NUMERIC_ENCODED_TAG]): + if list(adata.var_names[adata.var[FEATURE_TYPE_KEY] == CATEGORICAL_TAG]): ( categorical_names, categorical_scores, diff --git a/pyproject.toml b/pyproject.toml index cedafea5..a75aa463 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,7 +67,8 @@ dependencies = [ "tableone", "imbalanced-learn", "fknni", - "filelock" + "python-dateutil", + "filelock", ] [project.optional-dependencies] diff --git a/tests/anndata/test_anndata_ext.py b/tests/anndata/test_anndata_ext.py index a65eab60..e45a5d09 100644 --- a/tests/anndata/test_anndata_ext.py +++ b/tests/anndata/test_anndata_ext.py @@ -11,7 +11,7 @@ from pandas.testing import assert_frame_equal import ehrapy as ep -from ehrapy.anndata._constants import EHRAPY_TYPE_KEY, NON_NUMERIC_TAG, NUMERIC_TAG +from ehrapy.anndata._constants import CATEGORICAL_TAG, FEATURE_TYPE_KEY, NUMERIC_TAG from ehrapy.anndata.anndata_ext import ( NotEncodedError, _assert_encoded, @@ -54,11 +54,10 @@ def setup_binary_df_to_anndata() -> DataFrame: col2_val = ["another_str" + str(idx) for idx in range(100)] col3_val = [0 for _ in range(100)] col4_val = [1.0 for 
_ in range(100)] - col5_val = [np.NaN for _ in range(100)] - col6_val = [0.0 if idx % 2 == 0 else np.NaN for idx in range(100)] - col7_val = [idx % 2 for idx in range(100)] - col8_val = [float(idx % 2) for idx in range(100)] - col9_val = [idx % 3 if idx % 3 in {0, 1} else np.NaN for idx in range(100)] + col5_val = [0.0 if idx % 2 == 0 else np.NaN for idx in range(100)] + col6_val = [idx % 2 for idx in range(100)] + col7_val = [float(idx % 2) for idx in range(100)] + col8_val = [idx % 3 if idx % 3 in {0, 1} else np.NaN for idx in range(100)] df = DataFrame( { "col1": col1_val, @@ -66,10 +65,9 @@ def setup_binary_df_to_anndata() -> DataFrame: "col3": col3_val, "col4": col4_val, "col5": col5_val, - "col6": col6_val, - "col7_binary_int": col7_val, - "col8_binary_float": col8_val, - "col9_binary_missing_values": col9_val, + "col6_binary_int": col6_val, + "col7_binary_float": col7_val, + "col8_binary_missing_values": col8_val, } ) @@ -142,20 +140,25 @@ def test_move_to_x(adata_move_obs_mix): assert set(new_adata_num.obs.columns) == {"name"} assert {str(col) for col in new_adata_num.obs.dtypes} == {"category"} assert {str(col) for col in new_adata_non_num.obs.dtypes} == {"float32", "category"} + assert_frame_equal( new_adata_non_num.var, DataFrame( - {EHRAPY_TYPE_KEY: [NUMERIC_TAG, NUMERIC_TAG, NON_NUMERIC_TAG]}, + {FEATURE_TYPE_KEY: [NUMERIC_TAG, NUMERIC_TAG, CATEGORICAL_TAG]}, index=["los_days", "b12_values", "name"], ), ) + assert_frame_equal( new_adata_num.var, DataFrame( - {EHRAPY_TYPE_KEY: [NUMERIC_TAG, NUMERIC_TAG, NON_NUMERIC_TAG, NUMERIC_TAG]}, + {FEATURE_TYPE_KEY: [NUMERIC_TAG, NUMERIC_TAG, CATEGORICAL_TAG, np.nan]}, index=["los_days", "b12_values", "name", "clinic_id"], ), ) + ep.ad.infer_feature_types(new_adata_num, output=None) + assert np.all(new_adata_num.var[FEATURE_TYPE_KEY] == [NUMERIC_TAG, NUMERIC_TAG, CATEGORICAL_TAG, NUMERIC_TAG]) + assert_frame_equal( new_adata_num.obs, DataFrame( @@ -163,6 +166,7 @@ def test_move_to_x(adata_move_obs_mix): 
index=[str(idx) for idx in range(5)], ).astype({"name": "category"}), ) + assert_frame_equal( new_adata_non_num.obs, DataFrame( @@ -215,7 +219,6 @@ def test_delete_from_obs(adata_move_obs_mix): adata = delete_from_obs(adata, ["los_days"]) assert not {"los_days"}.issubset(set(adata.obs.columns)) assert {"los_days"}.issubset(set(adata.var_names)) - assert EHRAPY_TYPE_KEY in adata.var.columns def test_df_to_anndata_simple(setup_df_to_anndata): @@ -356,21 +359,21 @@ def test_anndata_to_df_layers(setup_anndata_to_df): def test_detect_binary_columns(setup_binary_df_to_anndata): adata = df_to_anndata(setup_binary_df_to_anndata) + ep.ad.infer_feature_types(adata, output=None) assert_frame_equal( adata.var, DataFrame( { - EHRAPY_TYPE_KEY: [ - NON_NUMERIC_TAG, - NON_NUMERIC_TAG, - NUMERIC_TAG, - NUMERIC_TAG, - NUMERIC_TAG, - NUMERIC_TAG, - NUMERIC_TAG, - NUMERIC_TAG, - NUMERIC_TAG, + FEATURE_TYPE_KEY: [ + CATEGORICAL_TAG, + CATEGORICAL_TAG, + CATEGORICAL_TAG, + CATEGORICAL_TAG, + CATEGORICAL_TAG, + CATEGORICAL_TAG, + CATEGORICAL_TAG, + CATEGORICAL_TAG, ] }, index=[ @@ -379,10 +382,9 @@ def test_detect_binary_columns(setup_binary_df_to_anndata): "col3", "col4", "col5", - "col6", - "col7_binary_int", - "col8_binary_float", - "col9_binary_missing_values", + "col6_binary_int", + "col7_binary_float", + "col8_binary_missing_values", ], ), ) @@ -393,38 +395,17 @@ def test_detect_mixed_binary_columns(): {"Col1": list(range(4)), "Col2": ["str" + str(i) for i in range(4)], "Col3": [1.0, 0.0, np.nan, 1.0]} ) adata = ep.ad.df_to_anndata(df) + ep.ad.infer_feature_types(adata, output=None) + assert_frame_equal( adata.var, DataFrame( - {EHRAPY_TYPE_KEY: [NUMERIC_TAG, NON_NUMERIC_TAG, NUMERIC_TAG]}, + {FEATURE_TYPE_KEY: [NUMERIC_TAG, CATEGORICAL_TAG, CATEGORICAL_TAG]}, index=["Col1", "Col2", "Col3"], ), ) -@pytest.fixture -def adata_numeric(): - obs_data = {"ID": ["Patient1", "Patient2", "Patient3"], "Age": [31, 94, 62]} - - X_numeric = np.array([[1, 3.4, 2.1, 4], [2, 6.9, 7.6, 2], [1, 4.5, 
1.3, 7]], dtype=np.dtype(float)) - var_numeric = { - "Feature": ["Numeric1", "Numeric2", "Numeric3", "Numeric4"], - "Type": ["Numeric", "Numeric", "Numeric", "Numeric"], - } - - adata_numeric = AnnData( - X=X_numeric, - obs=pd.DataFrame(data=obs_data), - var=pd.DataFrame(data=var_numeric, index=var_numeric["Feature"]), - uns=OrderedDict(), - ) - adata_numeric.var[EHRAPY_TYPE_KEY] = [NUMERIC_TAG, NUMERIC_TAG, NON_NUMERIC_TAG, NON_NUMERIC_TAG] - adata_numeric.uns["numerical_columns"] = ["Numeric1", "Numeric2"] - adata_numeric.uns["non_numerical_columns"] = ["String1", "String2"] - - return adata_numeric - - @pytest.fixture def adata_strings_encoded(): obs_data = {"ID": ["Patient1", "Patient2", "Patient3"], "Age": [31, 94, 62]} @@ -446,9 +427,8 @@ def adata_strings_encoded(): obs=pd.DataFrame(data=obs_data), var=pd.DataFrame(data=var_strings, index=var_strings["Feature"]), ) - adata_strings.var[EHRAPY_TYPE_KEY] = [NUMERIC_TAG, NUMERIC_TAG, NON_NUMERIC_TAG, NON_NUMERIC_TAG] - adata_strings.uns["numerical_columns"] = ["Numeric1", "Numeric2"] - adata_strings.uns["non_numerical_columns"] = ["String1", "String2"] + adata_strings.var[FEATURE_TYPE_KEY] = [NUMERIC_TAG, NUMERIC_TAG, CATEGORICAL_TAG, CATEGORICAL_TAG] + adata_encoded = ep.pp.encode(adata_strings.copy(), autodetect=True, encodings="label") return adata_strings, adata_encoded diff --git a/tests/anndata/test_feature_specifications.py b/tests/anndata/test_feature_specifications.py index 0be8f0b5..2c298d55 100644 --- a/tests/anndata/test_feature_specifications.py +++ b/tests/anndata/test_feature_specifications.py @@ -6,7 +6,7 @@ import ehrapy as ep from ehrapy.anndata import check_feature_types, df_to_anndata -from ehrapy.anndata._constants import CATEGORICAL_TAG, CONTINUOUS_TAG, DATE_TAG, FEATURE_TYPE_KEY +from ehrapy.anndata._constants import CATEGORICAL_TAG, DATE_TAG, FEATURE_TYPE_KEY, NUMERIC_TAG from ehrapy.io._read import read_csv _TEST_PATH = f"{Path(__file__).parents[1]}/preprocessing/test_data_imputation" @@ 
-34,9 +34,9 @@ def test_feature_type_inference(adata): assert adata.var[FEATURE_TYPE_KEY]["feature1"] == CATEGORICAL_TAG assert adata.var[FEATURE_TYPE_KEY]["feature2"] == CATEGORICAL_TAG assert adata.var[FEATURE_TYPE_KEY]["feature3"] == CATEGORICAL_TAG - assert adata.var[FEATURE_TYPE_KEY]["feature4"] == CONTINUOUS_TAG + assert adata.var[FEATURE_TYPE_KEY]["feature4"] == NUMERIC_TAG assert adata.var[FEATURE_TYPE_KEY]["feature5"] == CATEGORICAL_TAG - assert adata.var[FEATURE_TYPE_KEY]["feature6"] == CONTINUOUS_TAG + assert adata.var[FEATURE_TYPE_KEY]["feature6"] == NUMERIC_TAG assert adata.var[FEATURE_TYPE_KEY]["feature7"] == DATE_TAG @@ -45,9 +45,9 @@ def test_check_feature_types(adata): def test_func(adata): pass - with pytest.raises(ValueError) as e: - test_func(adata) - assert str(e.value).startswith("Feature types are not specified in adata.var.") + assert FEATURE_TYPE_KEY not in adata.var.keys() + test_func(adata) + assert FEATURE_TYPE_KEY in adata.var.keys() ep.ad.infer_feature_types(adata, output=None) test_func(adata) @@ -64,22 +64,20 @@ def test_func_with_return(adata): def test_feature_types_impute_num_adata(): adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute_num.csv") ep.ad.infer_feature_types(adata, output=None) - assert np.all(adata.var[FEATURE_TYPE_KEY] == [CONTINUOUS_TAG, CONTINUOUS_TAG, CONTINUOUS_TAG]) + assert np.all(adata.var[FEATURE_TYPE_KEY] == [NUMERIC_TAG, NUMERIC_TAG, NUMERIC_TAG]) return adata def test_feature_types_impute_adata(): adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute.csv") ep.ad.infer_feature_types(adata, output=None) - assert np.all(adata.var[FEATURE_TYPE_KEY] == [CATEGORICAL_TAG, CONTINUOUS_TAG, CATEGORICAL_TAG, CATEGORICAL_TAG]) + assert np.all(adata.var[FEATURE_TYPE_KEY] == [NUMERIC_TAG, NUMERIC_TAG, CATEGORICAL_TAG, CATEGORICAL_TAG]) def test_feature_types_impute_iris(): adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute_iris.csv") ep.ad.infer_feature_types(adata, output=None) - assert np.all( - 
adata.var[FEATURE_TYPE_KEY] == [CONTINUOUS_TAG, CONTINUOUS_TAG, CONTINUOUS_TAG, CONTINUOUS_TAG, CATEGORICAL_TAG] - ) + assert np.all(adata.var[FEATURE_TYPE_KEY] == [NUMERIC_TAG, NUMERIC_TAG, NUMERIC_TAG, NUMERIC_TAG, CATEGORICAL_TAG]) def test_feature_types_impute_feature_types_titanic(): @@ -91,11 +89,90 @@ def test_feature_types_impute_feature_types_titanic(): CATEGORICAL_TAG, CATEGORICAL_TAG, CATEGORICAL_TAG, - CONTINUOUS_TAG, - CONTINUOUS_TAG, - CONTINUOUS_TAG, - CONTINUOUS_TAG, - CONTINUOUS_TAG, + NUMERIC_TAG, + NUMERIC_TAG, + NUMERIC_TAG, + NUMERIC_TAG, + NUMERIC_TAG, CATEGORICAL_TAG, CATEGORICAL_TAG, ] + + +def test_date_detection(): + df = pd.DataFrame( + { + "date1": pd.to_datetime(["2021-01-01", "2024-04-16", "2021-01-03"]), + "date2": ["2021-01-01", "2024-04-16", "2021-01-03"], + "date3": ["2024-04-16 07:45:13", "2024-04-16", "2024-04"], + "not_date": ["not_a_date", "2024-04-16", "2021-01-03"], + } + ) + adata = df_to_anndata(df) + ep.ad.infer_feature_types(adata, output=None) + assert np.all(adata.var[FEATURE_TYPE_KEY] == [DATE_TAG, DATE_TAG, DATE_TAG, CATEGORICAL_TAG]) + + +def test_all_possible_types(): + df = pd.DataFrame( + { + "f1": [42, 17, 93, 235], + "f2": ["apple", "banana", "cherry", "date"], + "f3": [1, 2, 3, 1], + "f4": [1.0, 2.0, 1.0, 2.0], + "f5": ["20200101", "20200102", "20200103", "20200104"], + "f6": [True, False, True, False], + "f7": [np.nan, 1, np.nan, 2], + "f8": ["apple", 1, "banana", 2], + "f9": ["001", "002", "003", "002"], + "f10": ["5", "5", "5", "5"], + "f11": ["A1", "A2", "B1", "B2"], + "f12": [90210, 10001, 60614, 80588], + "f13": [0.25, 0.5, 0.75, 1.0], + "f14": ["2125551234", "2125555678", "2125559012", "2125553456"], + "f15": ["$100", "€150", "£200", "¥250"], + "f16": [101, 102, 103, 104], + "f17": [1e3, 5e-2, 3.1e2, 2.7e-1], + "f18": ["23.5", "324", "4.5", "0.5"], + "f19": [1, 2, 3, 4], + "f20": ["001", "002", "003", "004"], + } + ) + + adata = df_to_anndata(df) + ep.ad.infer_feature_types(adata, output=None) + + assert 
np.all( + adata.var[FEATURE_TYPE_KEY] + == [ + NUMERIC_TAG, + CATEGORICAL_TAG, + CATEGORICAL_TAG, + CATEGORICAL_TAG, + DATE_TAG, + CATEGORICAL_TAG, + CATEGORICAL_TAG, + CATEGORICAL_TAG, + CATEGORICAL_TAG, + NUMERIC_TAG, + CATEGORICAL_TAG, + NUMERIC_TAG, + NUMERIC_TAG, + NUMERIC_TAG, + CATEGORICAL_TAG, + NUMERIC_TAG, + NUMERIC_TAG, + NUMERIC_TAG, + NUMERIC_TAG, + NUMERIC_TAG, + ] + ) + + +def test_partial_annotation(adata): + adata.var[FEATURE_TYPE_KEY] = ["dummy", np.nan, np.nan, NUMERIC_TAG, None, np.nan, None] + ep.ad.infer_feature_types(adata, output=None) + assert np.all( + adata.var[FEATURE_TYPE_KEY] + == ["dummy", CATEGORICAL_TAG, CATEGORICAL_TAG, NUMERIC_TAG, CATEGORICAL_TAG, NUMERIC_TAG, DATE_TAG] + ) diff --git a/tests/preprocessing/test_bias.py b/tests/preprocessing/test_bias.py index 34438bfe..6c3f588c 100644 --- a/tests/preprocessing/test_bias.py +++ b/tests/preprocessing/test_bias.py @@ -3,7 +3,7 @@ import pytest import ehrapy as ep -from ehrapy.anndata._constants import CATEGORICAL_TAG, CONTINUOUS_TAG, FEATURE_TYPE_KEY +from ehrapy.anndata._constants import CATEGORICAL_TAG, FEATURE_TYPE_KEY, NUMERIC_TAG @pytest.fixture @@ -20,7 +20,7 @@ def adata(): } ) adata = ep.ad.df_to_anndata(df) - adata.var[FEATURE_TYPE_KEY] = [CONTINUOUS_TAG] * 4 + [CATEGORICAL_TAG] * 2 + adata.var[FEATURE_TYPE_KEY] = [NUMERIC_TAG] * 4 + [CATEGORICAL_TAG] * 2 return adata diff --git a/tests/preprocessing/test_encoding.py b/tests/preprocessing/test_encoding.py index 10e38b97..ab6a1f8f 100644 --- a/tests/preprocessing/test_encoding.py +++ b/tests/preprocessing/test_encoding.py @@ -1,11 +1,12 @@ from pathlib import Path +import numpy as np import pandas as pd import pytest from pandas import CategoricalDtype, DataFrame from pandas.testing import assert_frame_equal -from ehrapy.anndata._constants import EHRAPY_TYPE_KEY, NON_NUMERIC_ENCODED_TAG, NON_NUMERIC_TAG, NUMERIC_TAG +from ehrapy.anndata._constants import CATEGORICAL_TAG, FEATURE_TYPE_KEY, NUMERIC_TAG from ehrapy.io._read 
import read_csv from ehrapy.preprocessing._encoding import DuplicateColumnEncodingError, _reorder_encodings, encode @@ -45,10 +46,24 @@ def test_autodetect_encode(): "b12_values", } - assert encoded_ann_data.uns["var_to_encoding"] == { - "survival": "one-hot", - "clinic_day": "one-hot", - } + assert np.all( + encoded_ann_data.var["unencoded_var_names"] + == [ + "survival", + "survival", + "clinic_day", + "clinic_day", + "clinic_day", + "clinic_day", + "patient_id", + "los_days", + "b12_values", + ] + ) + + assert np.all(encoded_ann_data.var["encoding_mode"][:6] == ["one-hot"] * 6) + assert np.all(enc is None for enc in encoded_ann_data.var["encoding_mode"][6:]) + assert id(encoded_ann_data.X) != id(encoded_ann_data.layers["original"]) assert adata is not None and adata.X is not None and adata.obs is not None and adata.uns is not None assert id(encoded_ann_data) != id(adata) @@ -57,42 +72,30 @@ def test_autodetect_encode(): assert id(encoded_ann_data.var) != id(adata.var) assert all(column in set(encoded_ann_data.obs.columns) for column in ["survival", "clinic_day"]) assert not any(column in set(adata.obs.columns) for column in ["survival", "clinic_day"]) + assert_frame_equal( adata.var, DataFrame( - {EHRAPY_TYPE_KEY: [NUMERIC_TAG, NUMERIC_TAG, NUMERIC_TAG, NON_NUMERIC_TAG, NON_NUMERIC_TAG]}, + {FEATURE_TYPE_KEY: [NUMERIC_TAG, NUMERIC_TAG, NUMERIC_TAG, CATEGORICAL_TAG, CATEGORICAL_TAG]}, index=["patient_id", "los_days", "b12_values", "survival", "clinic_day"], ), ) - assert_frame_equal( - encoded_ann_data.var, - DataFrame( - { - EHRAPY_TYPE_KEY: [ - NON_NUMERIC_ENCODED_TAG, - NON_NUMERIC_ENCODED_TAG, - NON_NUMERIC_ENCODED_TAG, - NON_NUMERIC_ENCODED_TAG, - NON_NUMERIC_ENCODED_TAG, - NON_NUMERIC_ENCODED_TAG, - NUMERIC_TAG, - NUMERIC_TAG, - NUMERIC_TAG, - ] - }, - index=[ - "ehrapycat_survival_False", - "ehrapycat_survival_True", - "ehrapycat_clinic_day_Friday", - "ehrapycat_clinic_day_Monday", - "ehrapycat_clinic_day_Saturday", - "ehrapycat_clinic_day_Sunday", - 
"patient_id", - "los_days", - "b12_values", - ], - ), + + assert np.all( + encoded_ann_data.var[FEATURE_TYPE_KEY] + == [ + CATEGORICAL_TAG, + CATEGORICAL_TAG, + CATEGORICAL_TAG, + CATEGORICAL_TAG, + CATEGORICAL_TAG, + CATEGORICAL_TAG, + NUMERIC_TAG, + NUMERIC_TAG, + NUMERIC_TAG, + ] ) + assert pd.api.types.is_bool_dtype(encoded_ann_data.obs["survival"].dtype) assert isinstance(encoded_ann_data.obs["clinic_day"].dtype, CategoricalDtype) @@ -116,10 +119,13 @@ def test_autodetect_custom_mode(): "b12_values", } - assert encoded_ann_data.uns["var_to_encoding"] == { - "survival": "label", - "clinic_day": "label", - } + assert np.all( + encoded_ann_data.var["unencoded_var_names"] + == ["survival", "clinic_day", "patient_id", "los_days", "b12_values"] + ) + assert np.all(encoded_ann_data.var["encoding_mode"][:2] == ["label"] * 2) + assert np.all(enc is None for enc in encoded_ann_data.var["encoding_mode"][2:]) + assert id(encoded_ann_data.X) != id(encoded_ann_data.layers["original"]) assert adata is not None and adata.X is not None and adata.obs is not None and adata.uns is not None assert id(encoded_ann_data) != id(adata) @@ -128,34 +134,26 @@ def test_autodetect_custom_mode(): assert id(encoded_ann_data.var) != id(adata.var) assert all(column in set(encoded_ann_data.obs.columns) for column in ["survival", "clinic_day"]) assert not any(column in set(adata.obs.columns) for column in ["survival", "clinic_day"]) + assert_frame_equal( adata.var, DataFrame( - {EHRAPY_TYPE_KEY: [NUMERIC_TAG, NUMERIC_TAG, NUMERIC_TAG, NON_NUMERIC_TAG, NON_NUMERIC_TAG]}, + {FEATURE_TYPE_KEY: [NUMERIC_TAG, NUMERIC_TAG, NUMERIC_TAG, CATEGORICAL_TAG, CATEGORICAL_TAG]}, index=["patient_id", "los_days", "b12_values", "survival", "clinic_day"], ), ) - assert_frame_equal( - encoded_ann_data.var, - DataFrame( - { - EHRAPY_TYPE_KEY: [ - NON_NUMERIC_ENCODED_TAG, - NON_NUMERIC_ENCODED_TAG, - NUMERIC_TAG, - NUMERIC_TAG, - NUMERIC_TAG, - ] - }, - index=[ - "ehrapycat_survival", - "ehrapycat_clinic_day", - 
"patient_id", - "los_days", - "b12_values", - ], - ), + + assert np.all( + encoded_ann_data.var[FEATURE_TYPE_KEY] + == [ + CATEGORICAL_TAG, + CATEGORICAL_TAG, + NUMERIC_TAG, + NUMERIC_TAG, + NUMERIC_TAG, + ] ) + assert pd.api.types.is_bool_dtype(encoded_ann_data.obs["survival"].dtype) assert isinstance(encoded_ann_data.obs["clinic_day"].dtype, CategoricalDtype) @@ -186,10 +184,14 @@ def test_custom_encode(): "ehrapycat_clinic_day_Sunday", ] ) - assert encoded_ann_data.uns["var_to_encoding"] == { - "survival": "label", - "clinic_day": "one-hot", - } + + assert np.all( + encoded_ann_data.var["unencoded_var_names"] + == ["clinic_day", "clinic_day", "clinic_day", "clinic_day", "survival", "patient_id", "los_days", "b12_values"] + ) + assert np.all(encoded_ann_data.var["encoding_mode"][:5] == ["one-hot"] * 4 + ["label"]) + assert np.all(enc is None for enc in encoded_ann_data.var["encoding_mode"][5:]) + assert id(encoded_ann_data.X) != id(encoded_ann_data.layers["original"]) assert adata is not None and adata.X is not None and adata.obs is not None and adata.uns is not None assert id(encoded_ann_data) != id(adata) @@ -198,40 +200,29 @@ def test_custom_encode(): assert id(encoded_ann_data.var) != id(adata.var) assert all(column in set(encoded_ann_data.obs.columns) for column in ["survival", "clinic_day"]) assert not any(column in set(adata.obs.columns) for column in ["survival", "clinic_day"]) + assert_frame_equal( adata.var, DataFrame( - {EHRAPY_TYPE_KEY: [NUMERIC_TAG, NUMERIC_TAG, NUMERIC_TAG, NON_NUMERIC_TAG, NON_NUMERIC_TAG]}, + {FEATURE_TYPE_KEY: [NUMERIC_TAG, NUMERIC_TAG, NUMERIC_TAG, CATEGORICAL_TAG, CATEGORICAL_TAG]}, index=["patient_id", "los_days", "b12_values", "survival", "clinic_day"], ), ) - assert_frame_equal( - encoded_ann_data.var, - DataFrame( - { - EHRAPY_TYPE_KEY: [ - NON_NUMERIC_ENCODED_TAG, - NON_NUMERIC_ENCODED_TAG, - NON_NUMERIC_ENCODED_TAG, - NON_NUMERIC_ENCODED_TAG, - NON_NUMERIC_ENCODED_TAG, - NUMERIC_TAG, - NUMERIC_TAG, - NUMERIC_TAG, - ] - }, 
- index=[ - "ehrapycat_clinic_day_Friday", - "ehrapycat_clinic_day_Monday", - "ehrapycat_clinic_day_Saturday", - "ehrapycat_clinic_day_Sunday", - "ehrapycat_survival", - "patient_id", - "los_days", - "b12_values", - ], - ), + + assert np.all( + encoded_ann_data.var[FEATURE_TYPE_KEY] + == [ + CATEGORICAL_TAG, + CATEGORICAL_TAG, + CATEGORICAL_TAG, + CATEGORICAL_TAG, + CATEGORICAL_TAG, + NUMERIC_TAG, + NUMERIC_TAG, + NUMERIC_TAG, + ] ) + assert pd.api.types.is_bool_dtype(encoded_ann_data.obs["survival"].dtype) assert isinstance(encoded_ann_data.obs["clinic_day"].dtype, CategoricalDtype) @@ -245,7 +236,8 @@ def test_custom_encode_again_single_columns_encoding(): ) encoded_ann_data_again = encode(encoded_ann_data, autodetect=False, encodings={"label": ["clinic_day"]}) assert encoded_ann_data_again.X.shape == (5, 5) - assert list(encoded_ann_data_again.obs.columns) == ["survival", "clinic_day"] + assert len(encoded_ann_data_again.obs.columns) == 2 + assert set(encoded_ann_data_again.obs.columns) == {"survival", "clinic_day"} assert "ehrapycat_survival" in list(encoded_ann_data_again.var_names) assert "ehrapycat_clinic_day" in list(encoded_ann_data_again.var_names) assert all( @@ -257,10 +249,12 @@ def test_custom_encode_again_single_columns_encoding(): "ehrapycat_clinic_day_Sunday", ] ) - assert encoded_ann_data_again.uns["var_to_encoding"] == { - "survival": "label", - "clinic_day": "label", - } + + assert np.all( + encoded_ann_data_again.var["encoding_mode"].loc[["ehrapycat_survival", "ehrapycat_clinic_day"]] + == ["label", "label"] + ) + assert id(encoded_ann_data_again.X) != id(encoded_ann_data_again.layers["original"]) assert pd.api.types.is_bool_dtype(encoded_ann_data.obs["survival"].dtype) assert isinstance(encoded_ann_data.obs["clinic_day"].dtype, CategoricalDtype) @@ -275,17 +269,26 @@ def test_custom_encode_again_multiple_columns_encoding(): encodings={"label": ["survival"], "one-hot": ["clinic_day"]}, ) assert encoded_ann_data_again.X.shape == (5, 8) - assert 
list(encoded_ann_data_again.obs.columns) == ["survival", "clinic_day"] + assert len(encoded_ann_data_again.obs.columns) == 2 + assert set(encoded_ann_data_again.obs.columns) == {"survival", "clinic_day"} assert "ehrapycat_survival" in list(encoded_ann_data_again.var_names) assert "ehrapycat_clinic_day_Friday" in list(encoded_ann_data_again.var_names) assert all( survival_outcome not in list(encoded_ann_data_again.var_names) for survival_outcome in ["ehrapycat_survival_False", "ehrapycat_survival_True"] ) - assert encoded_ann_data_again.uns["var_to_encoding"] == { - "survival": "label", - "clinic_day": "one-hot", - } + + assert np.all( + encoded_ann_data_again.var.loc[encoded_ann_data_again.var["unencoded_var_names"] == "survival", "encoding_mode"] + == "label" + ) + assert np.all( + encoded_ann_data_again.var.loc[ + encoded_ann_data_again.var["unencoded_var_names"] == "clinic_day", "encoding_mode" + ] + == "one-hot" + ) + assert id(encoded_ann_data_again.X) != id(encoded_ann_data_again.layers["original"]) assert pd.api.types.is_bool_dtype(encoded_ann_data.obs["survival"].dtype) assert isinstance(encoded_ann_data.obs["clinic_day"].dtype, CategoricalDtype) @@ -294,21 +297,15 @@ def test_custom_encode_again_multiple_columns_encoding(): def test_update_encoding_scheme_1(): # just a dummy adata object that won't be used actually adata = read_csv(dataset_path=f"{_TEST_PATH}/dataset1.csv") - adata.uns["encoding_to_var"] = { - "label": ["col1", "col2", "col3"], - "one-hot": ["col4"], - } - adata.uns["var_to_encoding"] = { - "col1": "label", - "col2": "label", - "col3": "label", - "col4": "one-hot", - } + + adata.var["unencoded_var_names"] = ["col1", "col2", "col3", "col4", "col5"] + adata.var["encoding_mode"] = ["label", "label", "label", "one-hot", "one-hot"] + new_encodings = {"one-hot": ["col1"], "label": ["col2", "col3", "col4"]} expected_encodings = { "label": ["col2", "col3", "col4"], - "one-hot": ["col1"], + "one-hot": ["col1", "col5"], } updated_encodings = 
_reorder_encodings(adata, new_encodings) diff --git a/tests/preprocessing/test_highly_variable_features.py b/tests/preprocessing/test_highly_variable_features.py index e8e4b2f3..8b4da744 100644 --- a/tests/preprocessing/test_highly_variable_features.py +++ b/tests/preprocessing/test_highly_variable_features.py @@ -4,7 +4,6 @@ def test_highly_variable_features(): adata = ep.dt.dermatology(encoded=True) - ep.ad.infer_feature_types(adata, output=None) ep.pp.knn_impute(adata) highly_variable_features(adata) @@ -15,7 +14,6 @@ def test_highly_variable_features(): assert "variances_norm" in adata.var.columns adata = ep.dt.dermatology(encoded=True) - ep.ad.infer_feature_types(adata, output=None) ep.pp.knn_impute(adata) highly_variable_features(adata, top_features_percentage=0.5) assert adata.var["highly_variable"].sum() == 17 diff --git a/tests/preprocessing/test_imputation.py b/tests/preprocessing/test_imputation.py index 3164e641..42ea0a8b 100644 --- a/tests/preprocessing/test_imputation.py +++ b/tests/preprocessing/test_imputation.py @@ -7,7 +7,7 @@ from sklearn.exceptions import ConvergenceWarning import ehrapy as ep -from ehrapy.anndata._constants import CATEGORICAL_TAG, CONTINUOUS_TAG, DATE_TAG, FEATURE_TYPE_KEY +from ehrapy.anndata._constants import CATEGORICAL_TAG, DATE_TAG, FEATURE_TYPE_KEY, NUMERIC_TAG from ehrapy.io._read import read_csv from ehrapy.preprocessing._imputation import ( _warn_imputation_threshold, @@ -25,28 +25,24 @@ @pytest.fixture def impute_num_adata(): adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute_num.csv") - ep.ad.infer_feature_types(adata, output=None) return adata @pytest.fixture def impute_adata(): adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute.csv") - ep.ad.infer_feature_types(adata, output=None) return adata @pytest.fixture def impute_iris(): adata = read_csv(dataset_path=f"{_TEST_PATH}/test_impute_iris.csv") - ep.ad.infer_feature_types(adata, output=None) return adata @pytest.fixture def impute_titanic(): adata = 
read_csv(dataset_path=f"{_TEST_PATH}/test_impute_titanic.csv") - ep.ad.infer_feature_types(adata, output=None) return adata diff --git a/tests/preprocessing/test_normalization.py b/tests/preprocessing/test_normalization.py index d18e3947..66e6b8c6 100644 --- a/tests/preprocessing/test_normalization.py +++ b/tests/preprocessing/test_normalization.py @@ -8,7 +8,7 @@ from anndata import AnnData import ehrapy as ep -from ehrapy.anndata._constants import EHRAPY_TYPE_KEY, NON_NUMERIC_TAG, NUMERIC_TAG +from ehrapy.anndata._constants import CATEGORICAL_TAG, FEATURE_TYPE_KEY, NUMERIC_TAG from ehrapy.io._read import read_csv CURRENT_DIR = Path(__file__).parent @@ -37,7 +37,7 @@ def adata_to_norm(): var_data = { "Feature": ["Integer1", "Numeric1", "Numeric2", "Numeric3", "String1", "String2"], "Type": ["Integer", "Numeric", "Numeric", "Numeric", "String", "String"], - EHRAPY_TYPE_KEY: [NON_NUMERIC_TAG, NUMERIC_TAG, NUMERIC_TAG, "ignore", NON_NUMERIC_TAG, NON_NUMERIC_TAG], + FEATURE_TYPE_KEY: [CATEGORICAL_TAG, NUMERIC_TAG, NUMERIC_TAG, "ignore", CATEGORICAL_TAG, CATEGORICAL_TAG], } adata = AnnData( X=X_data, diff --git a/tests/tools/feature_ranking/test_feature_importances.py b/tests/tools/feature_ranking/test_feature_importances.py index 285ce091..6eb0ec48 100644 --- a/tests/tools/feature_ranking/test_feature_importances.py +++ b/tests/tools/feature_ranking/test_feature_importances.py @@ -3,7 +3,7 @@ import pytest from anndata import AnnData -from ehrapy.anndata._constants import CATEGORICAL_TAG, CONTINUOUS_TAG, DATE_TAG, FEATURE_TYPE_KEY +from ehrapy.anndata._constants import CATEGORICAL_TAG, DATE_TAG, FEATURE_TYPE_KEY, NUMERIC_TAG from ehrapy.tools import rank_features_supervised @@ -13,7 +13,7 @@ def test_continuous_prediction(): adata = AnnData(X) adata.var_names = ["target", "feature1", "feature2"] - adata.var[FEATURE_TYPE_KEY] = [CONTINUOUS_TAG] * 3 + adata.var[FEATURE_TYPE_KEY] = [NUMERIC_TAG] * 3 for model in ["regression", "svm", "rf"]: 
rank_features_supervised(adata, "target", model=model, input_features="all") diff --git a/tests/tools/feature_ranking/test_rank_features_groups.py b/tests/tools/feature_ranking/test_rank_features_groups.py index 39bd4557..88d13be6 100644 --- a/tests/tools/feature_ranking/test_rank_features_groups.py +++ b/tests/tools/feature_ranking/test_rank_features_groups.py @@ -6,6 +6,7 @@ import ehrapy as ep import ehrapy.tools.feature_ranking._rank_features_groups as _utils +from ehrapy.anndata._constants import FEATURE_TYPE_KEY, NUMERIC_TAG from ehrapy.io._read import read_csv CURRENT_DIR = Path(__file__).parent @@ -205,6 +206,10 @@ def test_get_groups_order(self): def test_evaluate_categorical_features(self): adata = ep.dt.mimic_2(encoded=False) + ep.ad.infer_feature_types(adata, output=None) + adata.var[FEATURE_TYPE_KEY].loc["hour_icu_intime"] = ( + NUMERIC_TAG # This is detected as categorical, so we need to correct that + ) adata = ep.pp.encode(adata, autodetect=True, encodings="label") group_names = pd.Categorical(adata.obs["service_unit"].astype(str)).categories.tolist() @@ -235,7 +240,6 @@ def test_evaluate_categorical_features(self): class TestRankFeaturesGroups: def test_real_dataset(self): adata = ep.dt.mimic_2(encoded=True) - ep.tl.rank_features_groups(adata, groupby="service_unit") assert "rank_features_groups" in adata.uns @@ -255,9 +259,8 @@ def test_real_dataset(self): def test_only_continous_features(self): adata = ep.dt.mimic_2(encoded=True) - adata.uns["non_numerical_columns"] = [] - ep.tl.rank_features_groups(adata, groupby="service_unit") + assert "rank_features_groups" in adata.uns assert "names" in adata.uns["rank_features_groups"] assert "pvals" in adata.uns["rank_features_groups"] @@ -267,8 +270,6 @@ def test_only_continous_features(self): def test_only_cat_features(self): adata = ep.dt.mimic_2(encoded=True) - adata.uns["numerical_columns"] = [] - ep.tl.rank_features_groups(adata, groupby="service_unit") assert "rank_features_groups" in adata.uns 
assert "names" in adata.uns["rank_features_groups"] @@ -312,7 +313,6 @@ def test_rank_features_groups_generates_outputs(self, field_to_rank): dataset_path=f"{_TEST_PATH}/dataset1.csv", columns_obs_only=["disease", "station", "sys_bp_entry", "dia_bp_entry"], ) - ep.tl.rank_features_groups(adata, groupby="disease", field_to_rank=field_to_rank) # check standard rank_features_groups entries