Skip to content

Commit

Permalink
Unify feature type detection (#724)
Browse files Browse the repository at this point in the history
* Updated cohort tracker to new feature type detection

* Improved date detection

* Started updating rank_features_groups

* Updated rank_features_groups to use infer_feature_types

* Updated anndata_ext and rank_features_groups to use infer_feature_types

* Updated encoding

* Encode only non-numerical categorical features

* Remove old encoding constants

* Rename CONTINUOUS_TAG to NUMERIC_TAG

* Updated dateutil dependency

* Fixed detection of categorical columns stored numerically

* Resolved Code ToDos

* Removed unused fixture

* Specific warnings

* Use warning instead of info

* Remove multi-column encoding

* Remove storing things in uns during encoding

* Looked through datasets until chronic_kidney_disease

* Fixed _reorder_encodings

* Allow for partial encoding with autodetect

* Dataloader modified until synthea_1k_sample

* Updated synthea_1k_sample dataloader

* Show encoding mode in feature_type_overview

* Updated encoding so that it doesn't save unencoded data in uns

* Updated QC to new encoding functionalities

* Added examples

* PR Reviews

* Updated submodule

* Removed unnecessary infer_feature_types calls

* Print only one feature type inferece warning for uncertain cases

* Type inference in encoding without complex

* Renamed correct_feature_types to replace_feature_types

* Updated usage.md
  • Loading branch information
Lilly-May authored May 19, 2024
1 parent 8a0b4f0 commit 6331db2
Show file tree
Hide file tree
Showing 25 changed files with 843 additions and 913 deletions.
2 changes: 1 addition & 1 deletion docs/tutorials/notebooks
14 changes: 3 additions & 11 deletions docs/usage/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ Other than tools, preprocessing steps usually don’t return an easily interpret
:toctree: preprocessing
:nosignatures:
preprocessing.encode
preprocessing.pca
preprocessing.regress_out
preprocessing.subsample
Expand Down Expand Up @@ -115,16 +116,6 @@ Other than tools, preprocessing steps usually don’t return an easily interpret
preprocessing.mice_forest_impute
```

### Encoding

```{eval-rst}
.. autosummary::
:toctree: preprocessing
:nosignatures:
preprocessing.encode
```

### Normalization

```{eval-rst}
Expand Down Expand Up @@ -406,6 +397,8 @@ Methods that extract and visualize tool-specific annotation in an AnnData object
:nosignatures:
anndata.infer_feature_types
anndata.feature_type_overview
anndata.replace_feature_types
anndata.df_to_anndata
anndata.anndata_to_df
anndata.move_to_obs
Expand All @@ -414,7 +407,6 @@ Methods that extract and visualize tool-specific annotation in an AnnData object
anndata.get_obs_df
anndata.get_var_df
anndata.get_rank_features_df
anndata.type_overview
```

## Settings
Expand Down
11 changes: 8 additions & 3 deletions ehrapy/anndata/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from ehrapy.anndata._feature_specifications import check_feature_types, infer_feature_types
from ehrapy.anndata._feature_specifications import (
check_feature_types,
feature_type_overview,
infer_feature_types,
replace_feature_types,
)
from ehrapy.anndata.anndata_ext import (
anndata_to_df,
delete_from_obs,
Expand All @@ -10,11 +15,12 @@
move_to_obs,
move_to_x,
rank_genes_groups_df,
type_overview,
)

__all__ = [
"check_feature_types",
"replace_feature_types",
"feature_type_overview",
"infer_feature_types",
"anndata_to_df",
"delete_from_obs",
Expand All @@ -26,5 +32,4 @@
"move_to_obs",
"move_to_x",
"rank_genes_groups_df",
"type_overview",
]
8 changes: 1 addition & 7 deletions ehrapy/anndata/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,7 @@
# -----------------------
# The column name and used values in adata.var for column types.

EHRAPY_TYPE_KEY = "ehrapy_column_type" # TODO: Change to ENCODING_TYPE_KEY
NUMERIC_TAG = "numeric"
NON_NUMERIC_TAG = "non_numeric"
NON_NUMERIC_ENCODED_TAG = "non_numeric_encoded"


FEATURE_TYPE_KEY = "feature_type"
CONTINUOUS_TAG = "numeric" # TODO: Eventually rename to NUMERIC_TAG (as soon as the other NUMERIC_TAG is removed)
NUMERIC_TAG = "numeric"
CATEGORICAL_TAG = "categorical"
DATE_TAG = "date"
190 changes: 165 additions & 25 deletions ehrapy/anndata/_feature_specifications.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,71 @@
from collections.abc import Iterable
from functools import wraps
from typing import Literal

import numpy as np
import pandas as pd
from anndata import AnnData
from dateutil.parser import isoparse # type: ignore
from lamin_utils import logger
from rich import print
from rich.tree import Tree

from ehrapy.anndata._constants import CATEGORICAL_TAG, CONTINUOUS_TAG, DATE_TAG, FEATURE_TYPE_KEY
from ehrapy.anndata.anndata_ext import anndata_to_df
from ehrapy.anndata._constants import CATEGORICAL_TAG, DATE_TAG, FEATURE_TYPE_KEY, NUMERIC_TAG


def infer_feature_types(adata: AnnData, layer: str | None = None, output: Literal["tree", "dataframe"] | None = "tree"):
def _detect_feature_type(col: pd.Series) -> tuple[Literal["date", "categorical", "numeric"], bool]:
"""Detect the feature type of a column in a pandas DataFrame.
Args:
col: The column to detect the feature type for.
verbose: Whether to print warnings for uncertain feature types. Defaults to True.
Returns:
The detected feature type (one of 'date', 'categorical', or 'numeric') and a boolean, which is True if the feature type is uncertain.
"""
n_elements = len(col)
col = col.dropna()
if len(col) == 0:
raise ValueError(
f"Feature {col.name} has only NaN values. Please drop the feature if you want to infer the feature type."
)
majority_type = col.apply(type).value_counts().idxmax()

if majority_type == pd.Timestamp:
return DATE_TAG, False # type: ignore

if majority_type == str:
try:
col.apply(isoparse)
return DATE_TAG, False # type: ignore
except ValueError:
try:
col = pd.to_numeric(col, errors="raise") # Could be an encoded categorical or a numeric feature
majority_type = float
except ValueError:
# Features stored as Strings that cannot be converted to float are assumed to be categorical
return CATEGORICAL_TAG, False # type: ignore

if majority_type not in [int, float]:
return CATEGORICAL_TAG, False # type: ignore

# Guess categorical if the feature is an integer and the values are 0/1 to n-1/n with no gaps
if (
(majority_type == int or (np.all(i.is_integer() for i in col)))
and (n_elements != col.nunique())
and (
(col.min() == 0 and np.all(np.sort(col.unique()) == np.arange(col.nunique())))
or (col.min() == 1 and np.all(np.sort(col.unique()) == np.arange(1, col.nunique() + 1)))
)
):
return CATEGORICAL_TAG, True # type: ignore

return NUMERIC_TAG, False # type: ignore


def infer_feature_types(
adata: AnnData, layer: str | None = None, output: Literal["tree", "dataframe"] | None = "tree", verbose: bool = True
):
"""Infer feature types from AnnData object.
For each feature in adata.var_names, the method infers one of the following types: 'date', 'categorical', or 'numeric'.
Expand All @@ -27,32 +80,43 @@ def infer_feature_types(adata: AnnData, layer: str | None = None, output: Litera
layer: The layer to use from the AnnData object. If None, the X layer is used.
output: The output format. Choose between 'tree', 'dataframe', or None. If 'tree', the feature types will be printed to the console in a tree format.
If 'dataframe', a pandas DataFrame with the feature types will be returned. If None, nothing will be returned. Defaults to 'tree'.
verbose: Whether to print warnings for uncertain feature types. Defaults to True.
Examples:
>>> import ehrapy as ep
>>> adata = ep.dt.mimic_2(encoded=False)
>>> ep.ad.infer_feature_types(adata)
"""
from ehrapy.anndata.anndata_ext import anndata_to_df

feature_types = {}
uncertain_features = []

df = anndata_to_df(adata, layer=layer)
for feature in adata.var_names:
col = df[feature].dropna()
majority_type = col.apply(type).value_counts().idxmax()
if majority_type == pd.Timestamp:
feature_types[feature] = DATE_TAG
elif majority_type not in [int, float, complex]:
feature_types[feature] = CATEGORICAL_TAG
# Guess categorical if the feature is an integer and the values are 0/1 to n-1 with no gaps
elif np.all(i.is_integer() for i in col) and (
(col.min() == 0 and np.all(np.sort(col.unique()) == np.arange(col.nunique())))
or (col.min() == 1 and np.all(np.sort(col.unique()) == np.arange(1, col.nunique() + 1)))
if (
FEATURE_TYPE_KEY in adata.var.keys()
and adata.var[FEATURE_TYPE_KEY][feature] is not None
and not pd.isna(adata.var[FEATURE_TYPE_KEY][feature])
):
feature_types[feature] = CATEGORICAL_TAG
feature_types[feature] = adata.var[FEATURE_TYPE_KEY][feature]
else:
feature_types[feature] = CONTINUOUS_TAG
feature_types[feature], raise_warning = _detect_feature_type(df[feature])
if raise_warning:
uncertain_features.append(feature)

adata.var[FEATURE_TYPE_KEY] = pd.Series(feature_types)[adata.var_names]

logger.info(
f"Stored feature types in adata.var['{FEATURE_TYPE_KEY}']."
f" Please verify and adjust if necessary using adata.var['{FEATURE_TYPE_KEY}']['feature1']='corrected_type'."
)
if verbose:
logger.warning(
f"{'Features' if len(uncertain_features) >1 else 'Feature'} {str(uncertain_features)[1:-1]} {'were' if len(uncertain_features) >1 else 'was'} detected as categorical features stored numerically."
f"Please verify and correct using `ep.ad.replace_feature_types` if necessary."
)

logger.info(
f"Stored feature types in adata.var['{FEATURE_TYPE_KEY}']."
f" Please verify and adjust if necessary using `ep.ad.replace_feature_types`."
)

if output == "tree":
feature_type_overview(adata)
Expand All @@ -65,17 +129,51 @@ def infer_feature_types(adata: AnnData, layer: str | None = None, output: Litera
def check_feature_types(func):
@wraps(func)
def wrapper(adata, *args, **kwargs):
# Account for class methods that pass self as first argument
_self = None
if not isinstance(adata, AnnData) and len(args) > 0 and isinstance(args[0], AnnData):
_self = adata
adata = args[0]
args = args[1:]

if FEATURE_TYPE_KEY not in adata.var.keys():
raise ValueError("Feature types are not specified in adata.var. Please run `infer_feature_types` first.")
np.all(adata.var[FEATURE_TYPE_KEY].isin([CATEGORICAL_TAG, CONTINUOUS_TAG, DATE_TAG]))
infer_feature_types(adata, output=None)
logger.warning(
f"Feature types were inferred and stored in adata.var[{FEATURE_TYPE_KEY}]. Please verify using `ep.ad.feature_type_overview` and adjust if necessary using `ep.ad.replace_feature_types`."
)

for feature in adata.var_names:
feature_type = adata.var[FEATURE_TYPE_KEY][feature]
if (
feature_type is not None
and (not pd.isna(feature_type))
and feature_type not in [CATEGORICAL_TAG, NUMERIC_TAG, DATE_TAG]
):
logger.warning(
f"Feature '{feature}' has an invalid feature type '{feature_type}'. Please correct using `ep.ad.replace_feature_types`."
)

if _self is not None:
return func(_self, adata, *args, **kwargs)
return func(adata, *args, **kwargs)

return wrapper


@check_feature_types
def feature_type_overview(adata: AnnData):
"""Print an overview of the feature types in the AnnData object."""
"""Print an overview of the feature types and encoding modes in the AnnData object.
Args:
adata: The AnnData object storing the EHR data.
Examples:
>>> import ehrapy as ep
>>> adata = ep.dt.mimic_2(encoded=True)
>>> ep.ad.feature_type_overview(adata)
"""
from ehrapy.anndata.anndata_ext import anndata_to_df

tree = Tree(
f"[b] Detected feature types for AnnData object with {len(adata.obs_names)} obs and {len(adata.var_names)} vars",
guide_style="underline2",
Expand All @@ -86,13 +184,55 @@ def feature_type_overview(adata: AnnData):
branch.add(date)

branch = tree.add("📐[b] Numerical features")
for numeric in sorted(adata.var_names[adata.var[FEATURE_TYPE_KEY] == CONTINUOUS_TAG]):
for numeric in sorted(adata.var_names[adata.var[FEATURE_TYPE_KEY] == NUMERIC_TAG]):
branch.add(numeric)

branch = tree.add("🗂️[b] Categorical features")
cat_features = adata.var_names[adata.var[FEATURE_TYPE_KEY] == CATEGORICAL_TAG]
df = anndata_to_df(adata[:, cat_features])
for categorical in sorted(cat_features):
branch.add(f"{categorical} ({df.loc[:, categorical].nunique()} categories)")

if "encoding_mode" in adata.var.keys():
unencoded_vars = adata.var.loc[cat_features, "unencoded_var_names"].unique().tolist()

for unencoded in sorted(unencoded_vars):
if unencoded in adata.var_names:
branch.add(f"{unencoded} ({df.loc[:, unencoded].nunique()} categories)")
else:
enc_mode = adata.var.loc[adata.var["unencoded_var_names"] == unencoded, "encoding_mode"].values[0]
branch.add(f"{unencoded} ({adata.obs[unencoded].nunique()} categories); {enc_mode} encoded")

else:
for categorical in sorted(cat_features):
branch.add(f"{categorical} ({df.loc[:, categorical].nunique()} categories)")

print(tree)


def replace_feature_types(adata, features: Iterable[str], corrected_type: str):
"""Correct the feature types for a list of features inplace.
Args:
adata: :class:`~anndata.AnnData` object storing the EHR data.
features: The features to correct.
corrected_type: The corrected feature type. One of 'date', 'categorical', or 'numeric'.
Examples:
>>> import ehrapy as ep
>>> adata = ep.dt.diabetes_130_fairlearn()
>>> ep.ad.infer_feature_types(adata)
>>> ep.ad.replace_feature_types(adata, ["time_in_hospital", "number_diagnoses", "num_procedures"], "numeric")
"""
if corrected_type not in [CATEGORICAL_TAG, NUMERIC_TAG, DATE_TAG]:
raise ValueError(
f"Corrected type {corrected_type} not recognized. Choose between '{DATE_TAG}', '{CATEGORICAL_TAG}', or '{NUMERIC_TAG}'."
)

if FEATURE_TYPE_KEY not in adata.var.keys():
raise ValueError(
"Feature types were not inferred. Please infer feature types using 'infer_feature_types' before correcting."
)

if isinstance(features, str):
features = [features]

adata.var[FEATURE_TYPE_KEY].loc[features] = corrected_type
Loading

0 comments on commit 6331db2

Please sign in to comment.