Setup obs #51

Merged: 17 commits, Nov 2, 2024
1 change: 1 addition & 0 deletions pyproject.toml
@@ -29,6 +29,7 @@ dependencies = [
"duckdb",
# for debug logging (referenced from the issue template)
"session-info",
"xarray",
]
optional-dependencies.dev = [
"pre-commit",
112 changes: 90 additions & 22 deletions src/ehrdata/dt/datasets.py
@@ -17,23 +17,36 @@ def _get_table_list() -> list:
return flat_table_list


def _set_up_duckdb(path: Path, backend_handle: DuckDBPyConnection) -> None:
def _set_up_duckdb(path: Path, backend_handle: DuckDBPyConnection, prefix: str = "") -> None:
tables = _get_table_list()

used_tables = []
missing_tables = []
for table in tables:
# if path exists lowercase, uppercase, capitalized:
table_path = f"{path}/{table}.csv"
if os.path.exists(table_path):
if table == "measurement":
backend_handle.register(
table, backend_handle.read_csv(f"{path}/{table}.csv", dtype={"measurement_source_value": str})
)
unused_files = []
for file_name in os.listdir(path):
file_name_trunk = file_name.split(".")[0].lower()

if file_name_trunk in tables or file_name_trunk.replace(prefix, "") in tables:
used_tables.append(file_name_trunk.replace(prefix, ""))

if file_name_trunk == "measurement":
dtype = {"measurement_source_value": str}
else:
backend_handle.register(table, backend_handle.read_csv(f"{path}/{table}.csv"))
dtype = None

backend_handle.register(
file_name_trunk.replace(prefix, ""),
backend_handle.read_csv(f"{path}/{file_name_trunk}.csv", dtype=dtype),
)
else:
missing_tables.append([table])
unused_files.append(file_name)

for table in tables:
if table not in used_tables:
missing_tables.append(table)

print("missing tables: ", missing_tables)
print("unused files: ", unused_files)


def mimic_iv_omop(backend_handle: DuckDBPyConnection, data_path: Path | None = None) -> None:
@@ -80,14 +93,14 @@ def mimic_iv_omop(backend_handle: DuckDBPyConnection, data_path: Path | None = N
else:
print(f"Failed to download the file. Status code: {response.status_code}")
return

return _set_up_duckdb(data_path + "/1_omop_data_csv", backend_handle)
# TODO: handle capitalized and lowercase file names, and files that merely contain the table name
return _set_up_duckdb(data_path + "/1_omop_data_csv", backend_handle, prefix="2b_")


def gibleed_omop(backend_handle: DuckDBPyConnection, data_path: Path | None = None) -> None:
"""Loads the GIBleed dataset.
"""Loads the GIBleed dataset in the OMOP Common Data model.

More details: https://github.com/OHDSI/EunomiaDatasets.
More details: https://github.com/OHDSI/EunomiaDatasets/tree/main/datasets/GiBleed.

Parameters
----------
@@ -109,13 +122,38 @@ def gibleed_omop(backend_handle: DuckDBPyConnection, data_path: Path | None = No
>>> ed.dt.gibleed_omop(backend_handle=con)
>>> con.execute("SHOW TABLES;").fetchall()
"""
# TODO:
# https://github.com/darwin-eu/EunomiaDatasets/tree/main/datasets/GiBleed
raise NotImplementedError()
if data_path is None:
data_path = Path("ehrapy_data/GIBleed_dataset")

if data_path.exists():
print(f"Path to data exists, load tables from there: {data_path}")
else:
print("Downloading data...")
URL = "https://github.com/OHDSI/EunomiaDatasets/raw/main/datasets/GiBleed/GiBleed_5.3.zip"
response = requests.get(URL)

if response.status_code == 200:
# extract_path = data_path / "gibleed_data_csv"
# extract_path.mkdir(parents=True, exist_ok=True)

# Use zipfile and io to open the ZIP file in memory
with zipfile.ZipFile(io.BytesIO(response.content)) as z:
# Extract all contents of the ZIP file into the correct subdirectory
z.extractall(data_path)  # extract into data_path
print(f"Download successful. ZIP file downloaded and extracted successfully to {data_path}.")

else:
print(f"Failed to download the file. Status code: {response.status_code}")

# extracted_folder = next(data_path.iterdir(), data_path)
# extracted_folder = next((folder for folder in data_path.iterdir() if folder.is_dir() and "_csv" in folder.name and "__MACOSX" not in folder.name), data_path)
return _set_up_duckdb(data_path / "GiBleed_5.3", backend_handle)
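The same download-and-cache pattern is shared by mimic_iv_omop and synthea27nj_omop; condensed, it reads roughly as below (a sketch, not part of the diff; the URL and target path are the ones used by gibleed_omop above).

import io
import zipfile
from pathlib import Path

import requests

URL = "https://github.com/OHDSI/EunomiaDatasets/raw/main/datasets/GiBleed/GiBleed_5.3.zip"
data_path = Path("ehrapy_data/GIBleed_dataset")

if data_path.exists():
    print(f"Path to data exists, load tables from there: {data_path}")
else:
    response = requests.get(URL)
    if response.status_code == 200:
        # unpack the in-memory ZIP; the CSVs land under data_path / "GiBleed_5.3"
        with zipfile.ZipFile(io.BytesIO(response.content)) as z:
            z.extractall(data_path)
    else:
        print(f"Failed to download the file. Status code: {response.status_code}")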


def synthea27nj_omop(backend_handle: DuckDBPyConnection, data_path: Path | None = None) -> None:
"""Loads the Synthe27Nj dataset.
"""Loads the Synthea27NJ dataset in the OMOP Common Data model.

More details: https://github.com/darwin-eu/EunomiaDatasets/tree/main/datasets/Synthea27Nj.

Parameters
----------
@@ -137,9 +175,39 @@ def synthea27nj_omop(backend_handle: DuckDBPyConnection, data_path: Path | None
>>> ed.dt.synthea27nj_omop(backend_handle=con)
>>> con.execute("SHOW TABLES;").fetchall()
"""
# TODO
# https://github.com/darwin-eu/EunomiaDatasets/tree/main/datasets/Synthea27Nj
raise NotImplementedError()
if data_path is None:
data_path = Path("ehrapy_data/Synthea27Nj")

if data_path.exists():
print(f"Path to data exists, load tables from there: {data_path}")
else:
print("Downloading data...")
URL = "https://github.com/OHDSI/EunomiaDatasets/raw/main/datasets/Synthea27Nj/Synthea27Nj_5.4.zip"
response = requests.get(URL)

if response.status_code == 200:
extract_path = data_path / "synthea27nj_omop_csv"
extract_path.mkdir(parents=True, exist_ok=True)

# Use zipfile and io to open the ZIP file in memory
with zipfile.ZipFile(io.BytesIO(response.content)) as z:
# Extract all contents of the ZIP file into the correct subdirectory
z.extractall(extract_path) # Extracting to 'extract_path'
print(f"Download successful. ZIP file downloaded and extracted successfully to {extract_path}.")

else:
print(f"Failed to download the file. Status code: {response.status_code}")
return

extracted_folder = next(
(
folder
for folder in data_path.iterdir()
if folder.is_dir() and "_csv" in folder.name and "__MACOSX" not in folder.name
),
data_path,
)
return _set_up_duckdb(extracted_folder, backend_handle)


def mimic_ii(backend_handle: DuckDBPyConnection, data_path: Path | None = None) -> None:
24 changes: 13 additions & 11 deletions src/ehrdata/io/omop/__init__.py
@@ -1,17 +1,19 @@
from .omop import (
extract_condition_occurrence,
extract_device_exposure,
extract_drug_exposure,
extract_measurement,
extract_note,
extract_observation,
extract_observation_period,
extract_person,
extract_person_observation_period,
extract_procedure_occurrence,
extract_specimen,
get_table,
get_time_interval_table,
load,
# extract_condition_occurrence,
# extract_device_exposure,
# extract_drug_exposure,
# extract_measurement,
# extract_note,
# extract_observation,
# extract_observation_period,
# extract_person,
# extract_person_observation_period,
# extract_procedure_occurrence,
# extract_specimen,
register_omop_to_db_connection,
setup_obs,
setup_variables,
)
139 changes: 139 additions & 0 deletions src/ehrdata/io/omop/_queries.py
@@ -0,0 +1,139 @@
from collections.abc import Sequence

import duckdb
import pandas as pd

START_DATE_KEY = {
"visit_occurrence": "visit_start_date",
"observation_period": "observation_period_start_date",
"cohort": "cohort_start_date",
}
END_DATE_KEY = {
"visit_occurrence": "visit_end_date",
"observation_period": "observation_period_end_date",
"cohort": "cohort_end_date",
}
TIME_DEFINING_TABLE_SUBJECT_KEY = {
"visit_occurrence": "person_id",
"observation_period": "person_id",
"cohort": "subject_id",
}

AGGREGATION_STRATEGY_KEY = {
"last": "LAST",
"first": "FIRST",
"mean": "MEAN",
"median": "MEDIAN",
"mode": "MODE",
"sum": "SUM",
"count": "COUNT",
"min": "MIN",
"max": "MAX",
"std": "STD",
}


def _generate_timedeltas(interval_length_number: int, interval_length_unit: str, num_intervals: int) -> pd.DataFrame:
timedeltas_dataframe = pd.DataFrame(
{
"interval_start_offset": [
pd.to_timedelta(i * interval_length_number, interval_length_unit) for i in range(num_intervals)
],
"interval_end_offset": [
pd.to_timedelta(i * interval_length_number, interval_length_unit) for i in range(1, num_intervals + 1)
],
"interval_step": list(range(num_intervals)),
}
)
return timedeltas_dataframe
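As an illustration (not part of the diff; importing the private helper assumes ehrdata is installed with this module in place), three one-day intervals yield start offsets of 0, 1 and 2 days, end offsets of 1, 2 and 3 days, and interval_step values 0 through 2:

from ehrdata.io.omop._queries import _generate_timedeltas

frame = _generate_timedeltas(interval_length_number=1, interval_length_unit="day", num_intervals=3)
print(frame)
#   interval_start_offset interval_end_offset  interval_step
# 0                0 days              1 days              0
# 1                1 days              2 days              1
# 2                2 days              3 days              2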


def _write_timedeltas_to_db(
backend_handle: duckdb.duckdb.DuckDBPyConnection,
timedeltas_dataframe,
) -> None:
backend_handle.execute("DROP TABLE IF EXISTS timedeltas")
backend_handle.execute(
"""
CREATE TABLE timedeltas (
interval_start_offset INTERVAL,
interval_end_offset INTERVAL,
interval_step INTEGER
)
"""
)
backend_handle.execute("INSERT INTO timedeltas SELECT * FROM timedeltas_dataframe")


def _drop_timedeltas(backend_handle: duckdb.duckdb.DuckDBPyConnection):
backend_handle.execute("DROP TABLE IF EXISTS timedeltas")


def _generate_value_query(data_table: str, data_field_to_keep: Sequence, aggregation_strategy: str) -> str:
query = f"{', ' .join([f'CASE WHEN COUNT(*) = 0 THEN NULL ELSE {aggregation_strategy}({column}) END AS {column}' for column in data_field_to_keep])}"
return query
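For example (editorial note, not part of the diff; same import assumption as above), data_field_to_keep=["value_as_number"] combined with the "last" strategy yields the aggregation fragment that is interpolated into the SELECT below:

from ehrdata.io.omop._queries import AGGREGATION_STRATEGY_KEY, _generate_value_query

fragment = _generate_value_query("measurement", ["value_as_number"], AGGREGATION_STRATEGY_KEY["last"])
print(fragment)
# CASE WHEN COUNT(*) = 0 THEN NULL ELSE LAST(value_as_number) END AS value_as_number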


def time_interval_table_query_long_format(
backend_handle: duckdb.duckdb.DuckDBPyConnection,
time_defining_table: str,
data_table: str,
interval_length_number: int,
interval_length_unit: str,
num_intervals: int,
aggregation_strategy: str,
data_field_to_keep: Sequence[str] | str,
) -> pd.DataFrame:
"""Returns a long format DataFrame from the data_table. The following columns should be considered the indices of this long format: person_id, data_table_concept_id, interval_step. The other columns, except for start_date and end_date, should be considered the values."""
if isinstance(data_field_to_keep, str):
data_field_to_keep = [data_field_to_keep]

timedeltas_dataframe = _generate_timedeltas(interval_length_number, interval_length_unit, num_intervals)

_write_timedeltas_to_db(
backend_handle,
timedeltas_dataframe,
)

# multi-step query
# 1. Create person_time_defining_table, which matches the one created for obs. Needs to contain the person_id, and the start date in particular.
# 2. Create person_data_table (data_table is typically measurement), which contains the cross product of person_id and the distinct concept_id s.
# 3. Create long_format_backbone, which is the left join of person_time_defining_table and person_data_table.
# 4. Create long_format_intervals, which is the cross product of long_format_backbone and timedeltas. This table contains most notably the person_id, the concept_id, the interval start and end dates.
# 5. Create the final table, which is the join with the data_table (typically measurement); each measurement is assigned to its person_id, its concept_id, and the interval it fits into.
df = backend_handle.execute(
f"""
WITH person_time_defining_table AS ( \
SELECT person.person_id as person_id, {START_DATE_KEY[time_defining_table]} as start_date, {END_DATE_KEY[time_defining_table]} as end_date \
FROM person \
JOIN {time_defining_table} ON person.person_id = {time_defining_table}.{TIME_DEFINING_TABLE_SUBJECT_KEY[time_defining_table]} \
), \
person_data_table AS( \
WITH distinct_data_table_concept_ids AS ( \
SELECT DISTINCT {data_table}_concept_id
FROM {data_table} \
)
SELECT person.person_id, {data_table}_concept_id as data_table_concept_id \
FROM person \
CROSS JOIN distinct_data_table_concept_ids \
), \
long_format_backbone as ( \
SELECT person_time_defining_table.person_id, data_table_concept_id, start_date, end_date \
FROM person_time_defining_table \
LEFT JOIN person_data_table USING(person_id)\
), \
long_format_intervals as ( \
SELECT person_id, data_table_concept_id, interval_step, start_date, start_date + interval_start_offset as interval_start, start_date + interval_end_offset as interval_end \
FROM long_format_backbone \
CROSS JOIN timedeltas \
) \
SELECT lfi.person_id, lfi.data_table_concept_id, interval_step, interval_start, interval_end, {_generate_value_query(data_table, data_field_to_keep, AGGREGATION_STRATEGY_KEY[aggregation_strategy])} \
FROM long_format_intervals as lfi \
LEFT JOIN {data_table} ON lfi.person_id = {data_table}.person_id AND lfi.data_table_concept_id = {data_table}.{data_table}_concept_id AND {data_table}.{data_table}_date BETWEEN lfi.interval_start AND lfi.interval_end \
GROUP BY lfi.person_id, lfi.data_table_concept_id, interval_step, interval_start, interval_end
"""
).df()

_drop_timedeltas(backend_handle)

return df
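An end-to-end sketch of this query (not part of the diff): a tiny in-memory OMOP-like database with one person, one visit and two measurements is aggregated into two daily intervals. Table layouts and values are invented for illustration, and the import assumes an installed ehrdata that exposes this module.

import duckdb

from ehrdata.io.omop._queries import time_interval_table_query_long_format

con = duckdb.connect()
con.execute("CREATE TABLE person (person_id INTEGER)")
con.execute("INSERT INTO person VALUES (1)")
con.execute("CREATE TABLE visit_occurrence (person_id INTEGER, visit_start_date DATE, visit_end_date DATE)")
con.execute("INSERT INTO visit_occurrence VALUES (1, DATE '2024-01-01', DATE '2024-01-03')")
con.execute(
    "CREATE TABLE measurement (person_id INTEGER, measurement_concept_id INTEGER, "
    "measurement_date DATE, value_as_number DOUBLE)"
)
con.execute("INSERT INTO measurement VALUES (1, 3027018, DATE '2024-01-01', 80.0), (1, 3027018, DATE '2024-01-02', 90.0)")

long_df = time_interval_table_query_long_format(
    backend_handle=con,
    time_defining_table="visit_occurrence",
    data_table="measurement",
    interval_length_number=1,
    interval_length_unit="day",
    num_intervals=2,
    aggregation_strategy="last",
    data_field_to_keep=["value_as_number"],
)
print(long_df)  # one row per (person_id, data_table_concept_id, interval_step)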