Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files (browse the repository at this point in the history)
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jan 8, 2024
1 parent 39c78eb commit e2b60df
Showing 1 changed file with 17 additions and 27 deletions.
44 changes: 17 additions & 27 deletions ehrdata.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,17 @@
import awkward as ak
import numpy as np
import pandas as pd
import csv
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import ehrapy as ep
import scanpy as sc
from anndata import AnnData
import mudata as md
from mudata import MuData
from typing import List, Union, Literal
import os
import glob
import dask.dataframe as dd
from thefuzz import process
import sys
from rich import print as rprint
import missingno as msno
import warnings
import numbers
import os
import warnings
from typing import Literal, Union

import awkward as ak
import dask.dataframe as dd
import ehrapy as ep
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from rich import print as rprint

clinical_tables_columns = {
"person": ["person_id", "year_of_birth", "gender_source_value"],
Expand Down Expand Up
@@ -102,7 +93,6 @@ def get_close_matches_using_dict(word, possibilities, n=2, cutoff=0.6):
Optional arg cutoff (default 0.6) is a float in [0, 1]. Possibilities
that don't score at least that similar to word are ignored.
"""

if not n > 0:
raise ValueError("n must be > 0: %r" % (n,))
if not 0.0 <= cutoff <= 1.0:
Expand Down Expand Up
@@ -133,7 +123,7 @@ def get_column_types(csv_path=None, columns=None):
column_types = {}
parse_dates = []
if csv_path:
with open(csv_path, "r") as f:
with open(csv_path) as f:
dict_reader = csv.DictReader(f)
columns = dict_reader.fieldnames
columns_lowercase = [column.lower() for column in columns]
Expand Down Expand Up
@@ -311,14 +301,14 @@ def load(self, level="stay_level", tables=["visit_occurrence", "person", "death"
for visit_occurrence_id in adata.obs.index:
obs_list.append(list(self.measurement[self.measurement['visit_occurrence_id'] == int(visit_occurrence_id)][column]))
adata.obsm[column]= ak.Array(obs_list)
for column in self.drug_exposure.columns:
if column != 'visit_occurrence_id':
obs_list = []
for visit_occurrence_id in adata.obs.index:
obs_list.append(list(self.drug_exposure[self.drug_exposure['visit_occurrence_id'] == int(visit_occurrence_id)][column]))
adata.obsm[column]= ak.Array(obs_list)
for column in self.observation.columns:
if column != 'visit_occurrence_id':
obs_list = []
Expand Down Expand Up
@@ -368,7 +358,7 @@ def feature_statistics(
plt.tight_layout()
return feature_counts

def map_concept_id(self, concept_id: Union[str, List], verbose=True):
def map_concept_id(self, concept_id: Union[str, list], verbose=True):
column_types, parse_dates = get_column_types(self.filepath["concept_relationship"])
df_concept_relationship = dd.read_csv(
self.filepath["concept_relationship"], dtype=column_types, parse_dates=parse_dates
Expand Down Expand Up
@@ -413,7 +403,7 @@ def map_concept_id(self, concept_id: Union[str, List], verbose=True):
else:
return concept_id_1, concept_id_2

def get_concept_name(self, concept_id: Union[str, List], raise_error=False, verbose=True):
def get_concept_name(self, concept_id: Union[str, list], raise_error=False, verbose=True):
if isinstance(concept_id, numbers.Integral):
concept_id = [concept_id]

Expand Down Expand Up
@@ -474,9 +464,9 @@ def extract_features(
"drug_exposure",
"condition_occurrence",
],
features: str or int or List[Union[str, int]] = None,
features: str or int or list[Union[str, int]] = None,
key: str = None,
columns_in_source_table: str or List[str] = None,
columns_in_source_table: str or list[str] = None,
map_concept=True,
add_aggregation_to_X: bool = True,
aggregation_methods=None,
Expand Down

0 comments on commit e2b60df

Please sign in to comment.