Skip to content

Commit

Permalink
Enhancement/core ehrdata class (#38)
Browse files Browse the repository at this point in the history
* add EHRData class

* ehrdata w/ 3D .r field and slicing

* remove comments

* introduce setup_obs and setup_var, some fixes

* 2 more tests

* nullable .r
  • Loading branch information
eroell authored Oct 8, 2024
1 parent 6c2230a commit dd37f01
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 23 deletions.
31 changes: 30 additions & 1 deletion src/ehrdata/core/ehrdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ def __init__(

if r is not None:
self.layers[R_LAYER_KEY] = r
# else:
# self.layers[R_LAYER_KEY] = np.zeros((self._adata.shape[0], self._adata.shape[1], 0))

if t is not None:
if isinstance(t, pd.DataFrame):
Expand All @@ -74,6 +76,8 @@ def __init__(
else:
if R_LAYER_KEY not in self.layers.keys():
self.t = pd.DataFrame(pd.RangeIndex(1))
elif len(self.layers[R_LAYER_KEY].shape) <= 2:
self.t = pd.DataFrame(pd.RangeIndex(1))
else:
self.t = pd.DataFrame(index=pd.RangeIndex(self.layers[R_LAYER_KEY].shape[2]))

Expand Down Expand Up @@ -143,6 +147,24 @@ def varp(self):
def varp(self, input):
self._adata.varp = input

@property
def var_names(self):
    """Return the ``var_names`` field of the wrapped AnnData object."""
    wrapped = self._adata
    return wrapped.var_names


@var_names.setter
def var_names(self, input):
    # Assignment is delegated to the wrapped AnnData object.
    wrapped = self._adata
    wrapped.var_names = input

@property
def obs_names(self):
    """Return the ``obs_names`` field of the wrapped AnnData object."""
    wrapped = self._adata
    return wrapped.obs_names


@obs_names.setter
def obs_names(self, input):
    # Assignment is delegated to the wrapped AnnData object.
    wrapped = self._adata
    wrapped.obs_names = input

@property
def uns(self):
"""Field from AnnData."""
Expand All @@ -166,7 +188,10 @@ def layers(self, key, input):
@property
def r(self):
    """3-Dimensional tensor, aligned with obs along first axis, var along second axis, and allowing a 3rd axis.

    Returns
    -------
    The layer stored under ``R_LAYER_KEY`` in the wrapped AnnData object,
    or ``None`` if no such layer has been set.
    """
    layers = self._adata.layers
    # The layer is optional: report its absence as None instead of letting
    # the lookup raise a KeyError.
    if R_LAYER_KEY not in layers:
        return None
    return layers[R_LAYER_KEY]

@r.setter
def r(self, input):
Expand Down Expand Up @@ -212,3 +237,7 @@ def _unpack_index(self, index):
return index[0], slice(None)
else:
raise IndexError("invalid number of indices")

def copy(self):
    """Return a new EHRData object built from copies of ``_adata`` and ``_t``."""
    adata_copy = self._adata.copy()
    t_copy = self._t.copy()
    return EHRData(adata=adata_copy, t=t_copy)
6 changes: 4 additions & 2 deletions src/ehrdata/io/omop/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
extract_observation,
extract_observation_period,
extract_person,
extract_person_observation_period,
extract_procedure_occurrence,
extract_specimen,
extract_tables,
get_time_interval_table,
load,
time_interval_table,
setup_obs,
setup_variables,
)
158 changes: 139 additions & 19 deletions src/ehrdata/io/omop/omop.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,138 @@ def _check_sanity_of_database(backend_handle: duckdb.DuckDB):
pass


def setup_obs(
    backend_handle: "duckdb.DuckDBPyConnection",
    observation_table: Literal["person", "observation_period", "person_observation_period", "condition_occurrence"],
):
    """Setup the observation table.

    This function sets up the observation table for the EHRData project.
    For this, a table from the OMOP CDM which represents the observed unit should be selected.
    A unit can be a person, an observation period, the join of these two tables, or a condition occurrence.

    Parameters
    ----------
    backend_handle
        The backend handle to the database. It is passed unchanged to the
        ``extract_*`` function matching `observation_table`.
    observation_table
        The observation table to be used.

    Returns
    -------
    An EHRData object with populated .obs field.

    Raises
    ------
    ValueError
        If `observation_table` is not one of the supported table names.
    """
    if observation_table == "person":
        obs = extract_person(backend_handle)
    elif observation_table == "observation_period":
        obs = extract_observation_period(backend_handle)
    elif observation_table == "person_observation_period":
        obs = extract_person_observation_period(backend_handle)
    elif observation_table == "condition_occurrence":
        obs = extract_condition_occurrence(backend_handle)
    else:
        # List every accepted value so the message matches the signature.
        raise ValueError(
            "observation_table must be either 'person', 'observation_period', "
            "'person_observation_period', or 'condition_occurrence'."
        )

    # Function-scope import kept from the original code (presumably to avoid a
    # circular import with the ehrdata package -- TODO confirm). It runs after
    # validation so that invalid input fails before anything else happens.
    from ehrdata import EHRData

    return EHRData(obs=obs)


def setup_variables(
backend_handle: Literal[str, duckdb, Path],
edata,
tables: Sequence[
Literal[
"measurement", "observation", "procedure_occurrence", "specimen", "device_exposure", "drug_exposure", "note"
]
],
start_time: Literal["observation_period_start"] | pd.Timestamp | str,
interval_length_number: int,
interval_length_unit: str,
num_intervals: int,
concept_ids: Literal["all"] | Sequence = "all",
aggregation_strategy: str = "last",
):
"""Setup the variables.
This function sets up the variables for the EHRData project.
For this, a selection of tables from the OMOP CDM which represents the variables should be selected.
The tables can be measurement, observation, procedure_occurrence, specimen, device_exposure, drug_exposure, or note.
Parameters
----------
backend_handle
The backend handle to the database.
edata
The EHRData object to which the variables should be added.
tables
The tables to be used.
start_time
Starting time for values to be included. Can be 'observation_period' start, which takes the 'observation_period_start' value from obs, or a specific Timestamp.
interval_length_number
Numeric value of the length of one interval.
interval_length_unit
Unit belonging to the interval length. See the units of `pandas.to_timedelta <https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.to_timedelta.html>`_
num_intervals
Numer of intervals
Returns
-------
An EHRData object with populated .var field.
"""
from ehrdata import EHRData

concept_ids_present_list = []
time_interval_tables = []
for table in tables:
if table == "measurement":
concept_ids_present = (
backend_handle.sql("SELECT * FROM measurement").df()["measurement_concept_id"].unique()
)
extracted_awkward = extract_measurement(backend_handle)
time_interval_table = get_time_interval_table(
backend_handle,
extracted_awkward,
edata.obs,
start_time="observation_period_start",
interval_length_number=interval_length_number,
interval_length_unit=interval_length_unit,
num_intervals=num_intervals,
concept_ids=concept_ids,
aggregation_strategy=aggregation_strategy,
)
# TODO: implement the following
# elif table == "observation":
# var = extract_observation(backend_handle)
# elif table == "procedure_occurrence":
# var = extract_procedure_occurrence(backend_handle)
# elif table == "specimen":
# var = extract_specimen(backend_handle)
# elif table == "device_exposure":
# var = extract_device_exposure(backend_handle)
# elif table == "drug_exposure":
# var = extract_drug_exposure(backend_handle)
# elif table == "note":
# var = extract_note(backend_handle)
else:
raise ValueError(
"tables must be a sequence of 'measurement', 'observation', 'procedure_occurrence', 'specimen', 'device_exposure', 'drug_exposure', or 'note'."
)
concept_ids_present_list.append(concept_ids_present)
time_interval_tables.append(time_interval_table)
if len(time_interval_tables) > 1:
time_interval_table = np.concatenate([time_interval_table, time_interval_table], axis=1)
concept_ids_present = pd.concat(concept_ids_present_list)
else:
time_interval_table = time_interval_tables[0]
concept_ids_present = concept_ids_present_list[0]

# TODO: copy other fields too. or other way? is is somewhat scverse-y by taking and returing anndata object...
edata = EHRData(r=time_interval_table, obs=edata.obs, var=concept_ids_present)

return edata


def load(
backend_handle: Literal[str, duckdb, Path],
# folder_path: str,
Expand All @@ -37,29 +169,17 @@ def load(
raise NotImplementedError(f"Backend {backend_handle} not supported. Choose a valid backend.")


def extract_tables():
    """Extract tables of an OMOP CDM Database.

    Placeholder: nothing is extracted yet and ``None`` is returned.
    """
    # TODO: iterate on the potential API for all of this -- should these be
    # functions or object methods?
    #   * extract person, measurements, ....
    #   * define features
    #   * make vars and corresponding .obsm, with specific key
    # "Static" values could then be put into .X with
    #   ep.ts.aggregate(adata, metric="counts")
    # or
    #   ep.ts.aggregate(adata, metrics={counts: [Antibiotic treatment, BP measurement], "average": [Heart Rate]})
    return None


def extract_person(duckdb_instance):
    """Return the full ``person`` table of an OMOP CDM database.

    The query is run through ``duckdb_instance.sql`` and the result is
    converted with ``.df()``.
    """
    query = "SELECT * FROM person"
    relation = duckdb_instance.sql(query)
    return relation.df()

# TODO: check if every person has an observation (can happen)
# TODO: check if every observation has a person (would be data quality issue)
# TODO: figure out how to handle multiple observation periods per person; here the person vs view discussion comes into play


def extract_observation_period(duckdb_instance):
    """Return the full ``observation_period`` table of an OMOP CDM database.

    The query is run through ``duckdb_instance.sql`` and the result is
    converted with ``.df()``.
    """
    query = "SELECT * FROM observation_period"
    relation = duckdb_instance.sql(query)
    return relation.df()


def extract_person_observation_period(duckdb_instance):
"""Extract observation table of an OMOP CDM Database."""
return duckdb_instance.sql(
"SELECT * \
Expand Down Expand Up @@ -150,7 +270,7 @@ def _get_interval_table_from_awkward_array(
return time_frame


def time_interval_table(
def get_time_interval_table(
# self,
con,
ts: ak.Array,
Expand All @@ -161,7 +281,7 @@ def time_interval_table(
| str, # observation_period_start, birthdate, specific date as popular options?
interval_length_number: int,
interval_length_unit: str,
num_intervals: int,
num_intervals: str | int = "max_observation_duration",
concept_ids: Literal["all"] | Sequence = "all",
aggregation_strategy: str = "first", # what to do if multiple obs. in 1 interval. first, last, mean, median, most_frequent for categories
# strategy="locf",
Expand Down
8 changes: 7 additions & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@


def test_ehrdata_init_empty():
    """An EHRData object created without arguments exposes ``.r`` as None."""
    edata = EHRData()
    assert edata.r is None


def test_ehrdata_init_standard():
Expand Down Expand Up @@ -75,3 +76,8 @@ def test_ehrdata_slice_3D():
assert adata_sliced.shape[1] == 1
assert adata_sliced.layers[R_LAYER_KEY].shape == (2, 1, 1)
assert adata_sliced.t.shape == (1, 0)


def test_copy():
    """Copying an empty EHRData object yields a new, distinct EHRData object."""
    edata = EHRData()
    edata_copy = edata.copy()

    # copy() builds a fresh EHRData, so the result must not be the same object.
    assert isinstance(edata_copy, EHRData)
    assert edata_copy is not edata

0 comments on commit dd37f01

Please sign in to comment.