Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pre-commit.ci] pre-commit autoupdate #5

Merged
merged 14 commits into from
Feb 15, 2024
Prev Previous commit
Next Next commit
Refactor code in _omop.py files
xinyuejohn committed Feb 15, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
commit 012b30aaffcc4dac71cfce3f98b5a7f6d519b650
42 changes: 20 additions & 22 deletions ehrdata/pl/_omop.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from typing import List, Union, Literal, Optional
from ehrdata.utils.omop_utils import *
from ehrdata.tl import get_concept_name
import seaborn as sns
from typing import Literal

import matplotlib.pyplot as plt
import seaborn as sns

from ehrdata.tl import get_concept_name
from ehrdata.utils.omop_utils import get_column_types, map_concept_id, read_table


# TODO allow users to pass features
def feature_counts(
@@ -17,29 +20,24 @@ def feature_counts(
"condition_occurrence",
],
number=20,
key = None
):

if source == 'measurement':
columns = ["value_as_number", "time", "visit_occurrence_id", "measurement_concept_id"]
elif source == 'observation':
columns = ["value_as_number", "value_as_string", "measurement_datetime"]
elif source == 'condition_occurrence':
columns = None
else:
raise KeyError(f"Extracting data from {source} is not supported yet")

filepath_dict = adata.uns['filepath_dict']
tables = adata.uns['tables']

key=None,
):
# if source == 'measurement':
# columns = ["value_as_number", "time", "visit_occurrence_id", "measurement_concept_id"]
# elif source == 'observation':
# columns = ["value_as_number", "value_as_string", "measurement_datetime"]
# elif source == 'condition_occurrence':
# columns = None
# else:
# raise KeyError(f"Extracting data from {source} is not supported yet")

column_types = get_column_types(adata.uns, table_name=source)
df_source = read_table(adata.uns, table_name=source, dtype=column_types, usecols=[f"{source}_concept_id"])
feature_counts = df_source[f"{source}_concept_id"].value_counts()
if adata.uns['use_dask']:
if adata.uns["use_dask"]:
feature_counts = feature_counts.compute()
feature_counts = feature_counts.to_frame().reset_index(drop=False)[0:number]


feature_counts[f"{source}_concept_id_1"], feature_counts[f"{source}_concept_id_2"] = map_concept_id(
adata.uns, concept_id=feature_counts[f"{source}_concept_id"], verbose=False
)
@@ -56,4 +54,4 @@ def feature_counts(
ax = sns.barplot(feature_counts, x="feature_name", y="count")
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
plt.tight_layout()
return feature_counts
return feature_counts
93 changes: 52 additions & 41 deletions ehrdata/pp/_omop.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from typing import List, Union, Literal, Optional
from ehrdata.utils.omop_utils import *
import ehrapy as ep
import warnings
from typing import Literal, Union

import ehrapy as ep
import pandas as pd
from rich import print as rprint

from ehrdata.utils.omop_utils import get_column_types, get_feature_info, read_table


def get_feature_statistics(
adata,
@@ -14,10 +19,12 @@ def get_feature_statistics(
"drug_exposure",
"condition_occurrence",
],
features: Union[str, int , List[Union[str, int]]] = None,
features: Union[str, int, list[Union[str, int]]] = None,
level="stay_level",
value_col: str = None,
aggregation_methods: Union[Literal["min", "max", "mean", "std", "count"], List[Literal["min", "max", "mean", "std", "count"]]]=None,
aggregation_methods: Union[
Literal["min", "max", "mean", "std", "count"], list[Literal["min", "max", "mean", "std", "count"]]
] = None,
add_aggregation_to_X: bool = True,
verbose: bool = False,
use_dask: bool = None,
@@ -28,16 +35,22 @@ def get_feature_statistics(
key = f"{source.split('_')[0]}_concept_id"
else:
raise KeyError(f"Extracting data from {source} is not supported yet")

if source == 'measurement':
value_col = 'value_as_number'
warnings.warn(f"Extracting values from {value_col}. Value in measurement table could be saved in these columns: value_as_number, value_source_value.\nSpecify value_col to extract value from desired column.")
source_table_columns = ['visit_occurrence_id', 'measurement_datetime', key, value_col]
elif source == 'observation':
value_col = 'value_as_number'
warnings.warn(f"Extracting values from {value_col}. Value in observation table could be saved in these columns: value_as_number, value_as_string, value_source_value.\nSpecify value_col to extract value from desired column.")
source_table_columns = ['visit_occurrence_id', "observation_datetime", key, value_col]
elif source == 'condition_occurrence':

if source == "measurement":
value_col = "value_as_number"
warnings.warn(
f"Extracting values from {value_col}. Value in measurement table could be saved in these columns: value_as_number, value_source_value.\nSpecify value_col to extract value from desired column.",
stacklevel=2,
)
source_table_columns = ["visit_occurrence_id", "measurement_datetime", key, value_col]
elif source == "observation":
value_col = "value_as_number"
warnings.warn(
f"Extracting values from {value_col}. Value in observation table could be saved in these columns: value_as_number, value_as_string, value_source_value.\nSpecify value_col to extract value from desired column.",
stacklevel=2,
)
source_table_columns = ["visit_occurrence_id", "observation_datetime", key, value_col]
elif source == "condition_occurrence":
source_table_columns = None
else:
raise KeyError(f"Extracting data from {source} is not supported yet")
@@ -49,62 +62,60 @@ def get_feature_statistics(
use_dask = True

column_types = get_column_types(adata.uns, table_name=source)
df_source = read_table(adata.uns, table_name=source, dtype=column_types, usecols=source_table_columns, use_dask=use_dask)

df_source = read_table(
adata.uns, table_name=source, dtype=column_types, usecols=source_table_columns, use_dask=use_dask
)

info_df = get_feature_info(adata.uns, features=features, verbose=verbose)
info_dict = info_df[['feature_id', 'feature_name']].set_index('feature_id').to_dict()['feature_name']
info_dict = info_df[["feature_id", "feature_name"]].set_index("feature_id").to_dict()["feature_name"]

# Select features
df_source = df_source[df_source[key].isin(list(info_df.feature_id))]
#TODO Select time
#da_measurement = da_measurement[(da_measurement.time >= 0) & (da_measurement.time <= 48*60*60)]
#df_source[f'{source}_name'] = df_source[key].map(info_dict)
# TODO Select time
# da_measurement = da_measurement[(da_measurement.time >= 0) & (da_measurement.time <= 48*60*60)]
# df_source[f'{source}_name'] = df_source[key].map(info_dict)
if aggregation_methods is None:
aggregation_methods = ["min", "max", "mean", "std", "count"]
if level == 'stay_level':
result = df_source.groupby(['visit_occurrence_id', key]).agg({
value_col: aggregation_methods})

if level == "stay_level":
result = df_source.groupby(["visit_occurrence_id", key]).agg({value_col: aggregation_methods})

if use_dask:
result = result.compute()
result = result.reset_index(drop=False)
result.columns = ["_".join(a) for a in result.columns.to_flat_index()]
result.columns = result.columns.str.removesuffix('_')
result.columns = result.columns.str.removeprefix(f'{value_col}_')
result[f'{source}_name'] = result[key].map(info_dict)
result.columns = result.columns.str.removesuffix("_")
result.columns = result.columns.str.removeprefix(f"{value_col}_")
result[f"{source}_name"] = result[key].map(info_dict)

df_statistics = result.pivot(index='visit_occurrence_id',
columns=f'{source}_name',
values=aggregation_methods)
df_statistics = result.pivot(index="visit_occurrence_id", columns=f"{source}_name", values=aggregation_methods)
df_statistics.columns = df_statistics.columns.swaplevel()
df_statistics.columns = ["_".join(a) for a in df_statistics.columns.to_flat_index()]


# TODO
sort_columns = True
if sort_columns:
new_column_order = []
for feature in features:
for suffix in (f'_{aggregation_method}' for aggregation_method in aggregation_methods):
col_name = f'{feature}{suffix}'
for suffix in (f"_{aggregation_method}" for aggregation_method in aggregation_methods):
col_name = f"{feature}{suffix}"
if col_name in df_statistics.columns:
new_column_order.append(col_name)

df_statistics.columns = new_column_order

df_statistics.index = df_statistics.index.astype(str)
adata.obs = pd.merge(adata.obs, df_statistics, how='left', left_index=True, right_index=True)

adata.obs = pd.merge(adata.obs, df_statistics, how="left", left_index=True, right_index=True)

if add_aggregation_to_X:
uns = adata.uns
obsm = adata.obsm
varm = adata.varm
layers = adata.layers
# layers = adata.layers
adata = ep.ad.move_to_x(adata, list(df_statistics.columns))
adata.uns = uns
adata.obsm = obsm
adata.varm = varm
# It will change
# adata.layers = layers
return adata
return adata
29 changes: 15 additions & 14 deletions ehrdata/tl/_omop.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
from ehrdata.utils.omop_utils import * #get_column_types, read_table, df_to_dict
from typing import List, Union, Literal, Optional, Dict
import numbers
from rich import print as rprint
from typing import Union

from anndata import AnnData
from rich import print as rprint

from ehrdata.utils.omop_utils import df_to_dict, get_column_types, read_table

def get_concept_name(
adata: Union[AnnData, Dict],
concept_id: Union[str, List],
raise_error=False,
verbose=True):


def get_concept_name(adata: Union[AnnData, dict], concept_id: Union[str, list], raise_error=False, verbose=True):
if isinstance(concept_id, numbers.Integral):
concept_id = [concept_id]

if isinstance(adata, AnnData):
adata_dict = adata.uns
else:
adata_dict = adata

column_types = get_column_types(adata_dict, table_name="concept")
df_concept = read_table(adata_dict, table_name="concept", dtype=column_types)
# TODO dask support
#df_concept.compute().dropna(subset=["concept_id", "concept_name"], inplace=True, ignore_index=True) # usecols=vocabularies_tables_columns["concept"]
df_concept.dropna(subset=["concept_id", "concept_name"], inplace=True, ignore_index=True) # usecols=vocabularies_tables_columns["concept"]
# df_concept.compute().dropna(subset=["concept_id", "concept_name"], inplace=True, ignore_index=True) # usecols=vocabularies_tables_columns["concept"]
df_concept.dropna(
subset=["concept_id", "concept_name"], inplace=True, ignore_index=True
) # usecols=vocabularies_tables_columns["concept"]
concept_dict = df_to_dict(df=df_concept, key="concept_id", value="concept_name")
concept_name = []
concept_name_not_found = []
@@ -43,6 +43,7 @@ def get_concept_name(
else:
return concept_name


# TODO
def get_concept_id():
pass
pass
Loading