diff --git a/ehrdata.py b/ehrdata.py
index 534839c..98c6725 100644
--- a/ehrdata.py
+++ b/ehrdata.py
@@ -1,26 +1,17 @@
-import awkward as ak
-import numpy as np
-import pandas as pd
 import csv
-import pandas as pd
-import matplotlib.pyplot as plt
-import seaborn as sns
-import ehrapy as ep
-import scanpy as sc
-from anndata import AnnData
-import mudata as md
-from mudata import MuData
-from typing import List, Union, Literal
-import os
 import glob
-import dask.dataframe as dd
-from thefuzz import process
-import sys
-from rich import print as rprint
-import missingno as msno
-import warnings
 import numbers
+import os
+import warnings
+from typing import Literal, Union
+import awkward as ak
+import dask.dataframe as dd
+import ehrapy as ep
+import matplotlib.pyplot as plt
+import pandas as pd
+import seaborn as sns
+from rich import print as rprint
 
 
 clinical_tables_columns = {
     "person": ["person_id", "year_of_birth", "gender_source_value"],
@@ -102,7 +93,6 @@ def get_close_matches_using_dict(word, possibilities, n=2, cutoff=0.6):
     Optional arg cutoff (default 0.6) is a float in [0, 1]. Possibilities
     that don't score at least that similar to word are ignored.
     """
-
     if not n > 0:
         raise ValueError("n must be > 0: %r" % (n,))
     if not 0.0 <= cutoff <= 1.0:
@@ -133,7 +123,7 @@ def get_column_types(csv_path=None, columns=None):
     column_types = {}
     parse_dates = []
     if csv_path:
-        with open(csv_path, "r") as f:
+        with open(csv_path) as f:
            dict_reader = csv.DictReader(f)
            columns = dict_reader.fieldnames
            columns_lowercase = [column.lower() for column in columns]
@@ -311,14 +301,14 @@ def load(self, level="stay_level", tables=["visit_occurrence", "person", "death"
                for visit_occurrence_id in adata.obs.index:
                    obs_list.append(list(self.measurement[self.measurement['visit_occurrence_id'] == int(visit_occurrence_id)][column]))
                adata.obsm[column]= ak.Array(obs_list)
-        
+
        for column in self.drug_exposure.columns:
            if column != 'visit_occurrence_id':
                obs_list = []
                for visit_occurrence_id in adata.obs.index:
                    obs_list.append(list(self.drug_exposure[self.drug_exposure['visit_occurrence_id'] == int(visit_occurrence_id)][column]))
                adata.obsm[column]= ak.Array(obs_list)
-        
+
        for column in self.observation.columns:
            if column != 'visit_occurrence_id':
                obs_list = []
@@ -368,7 +358,7 @@ def feature_statistics(
        plt.tight_layout()
        return feature_counts
 
-    def map_concept_id(self, concept_id: Union[str, List], verbose=True):
+    def map_concept_id(self, concept_id: Union[str, list], verbose=True):
        column_types, parse_dates = get_column_types(self.filepath["concept_relationship"])
        df_concept_relationship = dd.read_csv(
            self.filepath["concept_relationship"], dtype=column_types, parse_dates=parse_dates
@@ -413,7 +403,7 @@ def map_concept_id(self, concept_id: Union[str, List], verbose=True):
        else:
            return concept_id_1, concept_id_2
 
-    def get_concept_name(self, concept_id: Union[str, List], raise_error=False, verbose=True):
+    def get_concept_name(self, concept_id: Union[str, list], raise_error=False, verbose=True):
        if isinstance(concept_id, numbers.Integral):
            concept_id = [concept_id]
 
@@ -474,9 +464,9 @@ def extract_features(
            "drug_exposure",
            "condition_occurrence",
        ],
-        features: str or int or List[Union[str, int]] = None,
+        features: str or int or list[Union[str, int]] = None,
        key: str = None,
-        columns_in_source_table: str or List[str] = None,
+        columns_in_source_table: str or list[str] = None,
        map_concept=True,
        add_aggregation_to_X: bool = True,
        aggregation_methods=None,