Refactor code in _omop.py files
xinyuejohn committed Feb 15, 2024
1 parent cffde3e commit 0bb953e
Showing 4 changed files with 380 additions and 234 deletions.
42 changes: 20 additions & 22 deletions ehrdata/pl/_omop.py
@@ -1,8 +1,11 @@
-from typing import List, Union, Literal, Optional
-from ehrdata.utils.omop_utils import *
-from ehrdata.tl import get_concept_name
-import seaborn as sns
+from typing import Literal
+
 import matplotlib.pyplot as plt
+import seaborn as sns
+
+from ehrdata.tl import get_concept_name
+from ehrdata.utils.omop_utils import get_column_types, map_concept_id, read_table
+
 
 # TODO allow users to pass features
 def feature_counts(
@@ -17,29 +20,24 @@ def feature_counts(
         "condition_occurrence",
     ],
     number=20,
-    key = None
-):
-
-    if source == 'measurement':
-        columns = ["value_as_number", "time", "visit_occurrence_id", "measurement_concept_id"]
-    elif source == 'observation':
-        columns = ["value_as_number", "value_as_string", "measurement_datetime"]
-    elif source == 'condition_occurrence':
-        columns = None
-    else:
-        raise KeyError(f"Extracting data from {source} is not supported yet")
-
-    filepath_dict = adata.uns['filepath_dict']
-    tables = adata.uns['tables']
-
+    key=None,
+):
+    # if source == 'measurement':
+    #     columns = ["value_as_number", "time", "visit_occurrence_id", "measurement_concept_id"]
+    # elif source == 'observation':
+    #     columns = ["value_as_number", "value_as_string", "measurement_datetime"]
+    # elif source == 'condition_occurrence':
+    #     columns = None
+    # else:
+    #     raise KeyError(f"Extracting data from {source} is not supported yet")
+
     column_types = get_column_types(adata.uns, table_name=source)
     df_source = read_table(adata.uns, table_name=source, dtype=column_types, usecols=[f"{source}_concept_id"])
     feature_counts = df_source[f"{source}_concept_id"].value_counts()
-    if adata.uns['use_dask']:
+    if adata.uns["use_dask"]:
         feature_counts = feature_counts.compute()
     feature_counts = feature_counts.to_frame().reset_index(drop=False)[0:number]
 
-
     feature_counts[f"{source}_concept_id_1"], feature_counts[f"{source}_concept_id_2"] = map_concept_id(
         adata.uns, concept_id=feature_counts[f"{source}_concept_id"], verbose=False
     )
@@ -56,4 +54,4 @@ def feature_counts(
     ax = sns.barplot(feature_counts, x="feature_name", y="count")
     ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
     plt.tight_layout()
-    return feature_counts
+    return feature_counts
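
For orientation, a hedged usage sketch of the refactored feature_counts (not part of the commit). It assumes the function is re-exported as ehrdata.pl.feature_counts and that adata is an AnnData whose .uns already carries the OMOP metadata the function reads (table paths, "use_dask"); how that object is built is outside this diff:

    import ehrdata as ed

    # adata = ...  # OMOP-backed AnnData; its construction is outside this diff

    # Plot and return the 20 most frequent measurement concepts as a DataFrame
    # with concept IDs, mapped concept names, and counts.
    counts = ed.pl.feature_counts(adata, source="measurement", number=20)
    print(counts.head())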
93 changes: 52 additions & 41 deletions ehrdata/pp/_omop.py
@@ -1,7 +1,12 @@
-from typing import List, Union, Literal, Optional
-from ehrdata.utils.omop_utils import *
-import ehrapy as ep
 import warnings
+from typing import Literal, Union
+
+import ehrapy as ep
+import pandas as pd
+from rich import print as rprint
+
+from ehrdata.utils.omop_utils import get_column_types, get_feature_info, read_table
+
 
 def get_feature_statistics(
     adata,
@@ -14,10 +19,12 @@ def get_feature_statistics(
         "drug_exposure",
         "condition_occurrence",
     ],
-    features: Union[str, int , List[Union[str, int]]] = None,
+    features: Union[str, int, list[Union[str, int]]] = None,
     level="stay_level",
     value_col: str = None,
-    aggregation_methods: Union[Literal["min", "max", "mean", "std", "count"], List[Literal["min", "max", "mean", "std", "count"]]]=None,
+    aggregation_methods: Union[
+        Literal["min", "max", "mean", "std", "count"], list[Literal["min", "max", "mean", "std", "count"]]
+    ] = None,
     add_aggregation_to_X: bool = True,
     verbose: bool = False,
     use_dask: bool = None,
@@ -28,16 +35,22 @@ def get_feature_statistics(
         key = f"{source.split('_')[0]}_concept_id"
     else:
         raise KeyError(f"Extracting data from {source} is not supported yet")
 
-    if source == 'measurement':
-        value_col = 'value_as_number'
-        warnings.warn(f"Extracting values from {value_col}. Value in measurement table could be saved in these columns: value_as_number, value_source_value.\nSpecify value_col to extract value from desired column.")
-        source_table_columns = ['visit_occurrence_id', 'measurement_datetime', key, value_col]
-    elif source == 'observation':
-        value_col = 'value_as_number'
-        warnings.warn(f"Extracting values from {value_col}. Value in observation table could be saved in these columns: value_as_number, value_as_string, value_source_value.\nSpecify value_col to extract value from desired column.")
-        source_table_columns = ['visit_occurrence_id', "observation_datetime", key, value_col]
-    elif source == 'condition_occurrence':
+    if source == "measurement":
+        value_col = "value_as_number"
+        warnings.warn(
+            f"Extracting values from {value_col}. Value in measurement table could be saved in these columns: value_as_number, value_source_value.\nSpecify value_col to extract value from desired column.",
+            stacklevel=2,
+        )
+        source_table_columns = ["visit_occurrence_id", "measurement_datetime", key, value_col]
+    elif source == "observation":
+        value_col = "value_as_number"
+        warnings.warn(
+            f"Extracting values from {value_col}. Value in observation table could be saved in these columns: value_as_number, value_as_string, value_source_value.\nSpecify value_col to extract value from desired column.",
+            stacklevel=2,
+        )
+        source_table_columns = ["visit_occurrence_id", "observation_datetime", key, value_col]
+    elif source == "condition_occurrence":
         source_table_columns = None
     else:
         raise KeyError(f"Extracting data from {source} is not supported yet")
@@ -49,62 +62,60 @@ def get_feature_statistics(
         use_dask = True
 
     column_types = get_column_types(adata.uns, table_name=source)
-    df_source = read_table(adata.uns, table_name=source, dtype=column_types, usecols=source_table_columns, use_dask=use_dask)
-
+    df_source = read_table(
+        adata.uns, table_name=source, dtype=column_types, usecols=source_table_columns, use_dask=use_dask
+    )
+
     info_df = get_feature_info(adata.uns, features=features, verbose=verbose)
-    info_dict = info_df[['feature_id', 'feature_name']].set_index('feature_id').to_dict()['feature_name']
+    info_dict = info_df[["feature_id", "feature_name"]].set_index("feature_id").to_dict()["feature_name"]
 
     # Select featrues
    df_source = df_source[df_source[key].isin(list(info_df.feature_id))]
-    #TODO Select time
-    #da_measurement = da_measurement[(da_measurement.time >= 0) & (da_measurement.time <= 48*60*60)]
-    #df_source[f'{source}_name'] = df_source[key].map(info_dict)
+    # TODO Select time
+    # da_measurement = da_measurement[(da_measurement.time >= 0) & (da_measurement.time <= 48*60*60)]
+    # df_source[f'{source}_name'] = df_source[key].map(info_dict)
     if aggregation_methods is None:
         aggregation_methods = ["min", "max", "mean", "std", "count"]
-    if level == 'stay_level':
-        result = df_source.groupby(['visit_occurrence_id', key]).agg({
-            value_col: aggregation_methods})
-
+    if level == "stay_level":
+        result = df_source.groupby(["visit_occurrence_id", key]).agg({value_col: aggregation_methods})
+
         if use_dask:
             result = result.compute()
         result = result.reset_index(drop=False)
         result.columns = ["_".join(a) for a in result.columns.to_flat_index()]
-        result.columns = result.columns.str.removesuffix('_')
-        result.columns = result.columns.str.removeprefix(f'{value_col}_')
-        result[f'{source}_name'] = result[key].map(info_dict)
+        result.columns = result.columns.str.removesuffix("_")
+        result.columns = result.columns.str.removeprefix(f"{value_col}_")
+        result[f"{source}_name"] = result[key].map(info_dict)
 
-        df_statistics = result.pivot(index='visit_occurrence_id',
-                                     columns=f'{source}_name',
-                                     values=aggregation_methods)
+        df_statistics = result.pivot(index="visit_occurrence_id", columns=f"{source}_name", values=aggregation_methods)
         df_statistics.columns = df_statistics.columns.swaplevel()
         df_statistics.columns = ["_".join(a) for a in df_statistics.columns.to_flat_index()]
 
-
         # TODO
         sort_columns = True
         if sort_columns:
             new_column_order = []
             for feature in features:
-                for suffix in (f'_{aggregation_method}' for aggregation_method in aggregation_methods):
-                    col_name = f'{feature}{suffix}'
+                for suffix in (f"_{aggregation_method}" for aggregation_method in aggregation_methods):
+                    col_name = f"{feature}{suffix}"
                     if col_name in df_statistics.columns:
                         new_column_order.append(col_name)
 
             df_statistics.columns = new_column_order
 
     df_statistics.index = df_statistics.index.astype(str)
-    adata.obs = pd.merge(adata.obs, df_statistics, how='left', left_index=True, right_index=True)
-
+    adata.obs = pd.merge(adata.obs, df_statistics, how="left", left_index=True, right_index=True)
+
     if add_aggregation_to_X:
         uns = adata.uns
         obsm = adata.obsm
         varm = adata.varm
-        layers = adata.layers
+        # layers = adata.layers
         adata = ep.ad.move_to_x(adata, list(df_statistics.columns))
         adata.uns = uns
         adata.obsm = obsm
         adata.varm = varm
         # It will change
         # adata.layers = layers
-    return adata
+    return adata
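
A similarly hedged usage sketch for get_feature_statistics (not part of the commit), assuming a package-level re-export as ehrdata.pp.get_feature_statistics; the concept IDs below are hypothetical placeholders:

    import ehrdata as ed

    # adata = ...  # OMOP-backed AnnData, as in the earlier sketch

    # Aggregate value_as_number per visit for two illustrative measurement
    # concepts and move the resulting min/max/mean columns into adata.X.
    adata = ed.pp.get_feature_statistics(
        adata,
        source="measurement",
        features=[3012888, 3027018],  # hypothetical OMOP concept IDs
        aggregation_methods=["min", "max", "mean"],
        add_aggregation_to_X=True,
    )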
29 changes: 15 additions & 14 deletions ehrdata/tl/_omop.py
@@ -1,28 +1,28 @@
-from ehrdata.utils.omop_utils import * #get_column_types, read_table, df_to_dict
-from typing import List, Union, Literal, Optional, Dict
 import numbers
-from rich import print as rprint
+from typing import Union
 
 from anndata import AnnData
+from rich import print as rprint
+
+from ehrdata.utils.omop_utils import df_to_dict, get_column_types, read_table
 
-def get_concept_name(
-    adata: Union[AnnData, Dict],
-    concept_id: Union[str, List],
-    raise_error=False,
-    verbose=True):
-
 
+def get_concept_name(adata: Union[AnnData, dict], concept_id: Union[str, list], raise_error=False, verbose=True):
     if isinstance(concept_id, numbers.Integral):
         concept_id = [concept_id]
 
     if isinstance(adata, AnnData):
         adata_dict = adata.uns
     else:
         adata_dict = adata
 
     column_types = get_column_types(adata_dict, table_name="concept")
     df_concept = read_table(adata_dict, table_name="concept", dtype=column_types)
     # TODO dask Support
-    #df_concept.compute().dropna(subset=["concept_id", "concept_name"], inplace=True, ignore_index=True) # usecols=vocabularies_tables_columns["concept"]
-    df_concept.dropna(subset=["concept_id", "concept_name"], inplace=True, ignore_index=True) # usecols=vocabularies_tables_columns["concept"]
+    # df_concept.compute().dropna(subset=["concept_id", "concept_name"], inplace=True, ignore_index=True) # usecols=vocabularies_tables_columns["concept"]
+    df_concept.dropna(
+        subset=["concept_id", "concept_name"], inplace=True, ignore_index=True
+    )  # usecols=vocabularies_tables_columns["concept"]
     concept_dict = df_to_dict(df=df_concept, key="concept_id", value="concept_name")
     concept_name = []
     concept_name_not_found = []
@@ -43,6 +43,7 @@ def get_concept_name(
     else:
         return concept_name
 
+
 # TODO
 def get_concept_id():
-    pass
+    pass
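
Finally, a hedged usage sketch for get_concept_name (not part of the commit). Per the signature above, it accepts either an AnnData or the bare .uns dictionary, so both call forms are shown under the assumed ehrdata.tl re-export; the IDs are illustrative:

    import ehrdata as ed

    # Map OMOP concept IDs to names via the local "concept" vocabulary table.
    names = ed.tl.get_concept_name(adata, concept_id=[3012888, 3027018])

    # Equivalent call passing the metadata dictionary directly; a single
    # integer ID is wrapped into a list internally.
    names = ed.tl.get_concept_name(adata.uns, concept_id=3012888, verbose=False)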