Commit
refactor time interval writing sql query
eroell committed Dec 10, 2024
1 parent 32589a9 commit a049fa6
Showing 2 changed files with 44 additions and 163 deletions.
129 changes: 2 additions & 127 deletions src/ehrdata/io/omop/_queries.py
@@ -115,7 +115,7 @@ def _generate_value_query(data_table: str, data_field_to_keep: Sequence, aggrega
return is_present_query + value_query


def _time_interval_table(
def _write_long_time_interval_table(
backend_handle: duckdb.duckdb.DuckDBPyConnection,
time_defining_table: str,
data_table: str,
@@ -125,8 +125,7 @@ def _time_interval_table(
aggregation_strategy: str,
data_field_to_keep: Sequence[str] | str,
keep_date: str = "",
return_as_df: bool = False,
) -> pd.DataFrame | None:
) -> None:
if isinstance(data_field_to_keep, str):
data_field_to_keep = [data_field_to_keep]

@@ -219,127 +218,3 @@
WHERE long_person_timestamp_feature_value.person_id = RP.person_id;
"""
backend_handle.execute(add_person_range_index_query)

if return_as_df:
return backend_handle.execute("SELECT * FROM long_person_timestamp_feature_value").df()
else:
return None
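
The `add_person_range_index_query` at the end of the function above gives every distinct `person_id` a dense, zero-based `person_index` via a `ROW_NUMBER` window. A minimal standalone sketch of that technique against an in-memory DuckDB table (the `obs` table and its values are illustrative, not part of this commit):

```python
import duckdb

con = duckdb.connect()  # in-memory database
con.execute(
    "CREATE TABLE obs AS SELECT * FROM (VALUES (7, 1.0), (7, 2.0), (42, 3.0)) AS t(person_id, val)"
)

# Same pattern as add_person_range_index_query: rank the distinct person_ids,
# then write the 0-based rank back onto every row of the long table.
con.execute("ALTER TABLE obs ADD COLUMN person_index INTEGER")
con.execute("""
    WITH RankedPersons AS (
        SELECT person_id,
               ROW_NUMBER() OVER (ORDER BY person_id) - 1 AS idx
        FROM (SELECT DISTINCT person_id FROM obs) AS unique_persons
    )
    UPDATE obs SET person_index = RP.idx
    FROM RankedPersons AS RP
    WHERE obs.person_id = RP.person_id
""")

print(con.execute("SELECT * FROM obs ORDER BY person_id").df())
# person_id 7 -> person_index 0, person_id 42 -> person_index 1
```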


# def _get_time_interval_table(
# backend_handle: duckdb.duckdb.DuckDBPyConnection,
# time_defining_table: str,
# data_table: str,
# interval_length_number: int,
# interval_length_unit: str,
# num_intervals: int,
# aggregation_strategy: str,
# data_field_to_keep: Sequence[str] | str,
# keep_date: str = "",
# ):
# return backend_handle.execute("SELECT * FROM long_person_timestamp_feature_value").df()


# def _time_interval_table_for_dataloader(
# backend_handle: duckdb.duckdb.DuckDBPyConnection,
# time_defining_table: str,
# data_table: str,
# interval_length_number: int,
# interval_length_unit: str,
# num_intervals: int,
# aggregation_strategy: str,
# data_field_to_keep: Sequence[str] | str,
# keep_date: str = "",
# ):
# if isinstance(data_field_to_keep, str):
# data_field_to_keep = [data_field_to_keep]

# if keep_date == "":
# keep_date = "timepoint"

# timedeltas_dataframe = _generate_timedeltas(interval_length_number, interval_length_unit, num_intervals)

# _write_timedeltas_to_db(
# backend_handle,
# timedeltas_dataframe,
# )

# # multi-step query
# # 1. Create person_time_defining_table, which matches the one created for obs. Needs to contain the person_id, and the start date in particular.
# # 2. Create person_data_table (data_table is typically measurement), which contains the cross product of person_id and the distinct concept_id s.
# # 3. Create long_format_backbone, which is the left join of person_time_defining_table and person_data_table.
# # 4. Create long_format_intervals, which is the cross product of long_format_backbone and timedeltas. This table contains most notably the person_id, the concept_id, the interval start and end dates.
# # 5. Create the final table, which is the join with the data_table (typically measurement); each measurement is assigned to its person_id, its concept_id, and the interval it fits into.
# prepare_alias_query = f"""
# CREATE TABLE long_person_timestamp_feature_value AS \
# WITH person_time_defining_table AS ( \
# SELECT person.person_id as person_id, {DATA_TABLE_DATE_KEYS["start"][time_defining_table]} as start_date, {DATA_TABLE_DATE_KEYS["end"][time_defining_table]} as end_date \
# FROM person \
# JOIN {time_defining_table} ON person.person_id = {time_defining_table}.{TIME_DEFINING_TABLE_SUBJECT_KEY[time_defining_table]} \
# WHERE visit_concept_id = 262 \
# ), \
# person_data_table AS( \
# WITH distinct_data_table_concept_ids AS ( \
# SELECT DISTINCT {DATA_TABLE_CONCEPT_ID_TRUNK[data_table]}_concept_id
# FROM {data_table} \
# )
# SELECT person.person_id, {DATA_TABLE_CONCEPT_ID_TRUNK[data_table]}_concept_id as data_table_concept_id \
# FROM person \
# CROSS JOIN distinct_data_table_concept_ids \
# ), \
# long_format_backbone as ( \
# SELECT person_time_defining_table.person_id, data_table_concept_id, start_date, end_date \
# FROM person_time_defining_table \
# LEFT JOIN person_data_table USING(person_id)\
# ), \
# long_format_intervals as ( \
# SELECT person_id, data_table_concept_id, interval_step, start_date, start_date + interval_start_offset as interval_start, start_date + interval_end_offset as interval_end \
# FROM long_format_backbone \
# CROSS JOIN timedeltas \
# ), \
# data_table_with_presence_indicator as( \
# SELECT *, 1 as is_present \
# FROM {data_table} \
# ) \
# """

# if keep_date in ["timepoint", "start", "end"]:
# select_query = f"""
# SELECT lfi.person_id, lfi.data_table_concept_id, interval_step, interval_start, interval_end, {_generate_value_query("data_table_with_presence_indicator", data_field_to_keep, AGGREGATION_STRATEGY_KEY[aggregation_strategy])} \
# FROM long_format_intervals as lfi \
# LEFT JOIN data_table_with_presence_indicator ON lfi.person_id = data_table_with_presence_indicator.person_id AND lfi.data_table_concept_id = data_table_with_presence_indicator.{DATA_TABLE_CONCEPT_ID_TRUNK[data_table]}_concept_id AND data_table_with_presence_indicator.{DATA_TABLE_DATE_KEYS[keep_date][data_table]} BETWEEN lfi.interval_start AND lfi.interval_end \
# GROUP BY lfi.person_id, lfi.data_table_concept_id, interval_step, interval_start, interval_end
# """

# elif keep_date == "interval":
# select_query = f"""
# SELECT lfi.person_id, lfi.data_table_concept_id, interval_step, interval_start, interval_end, {_generate_value_query("data_table_with_presence_indicator", data_field_to_keep, AGGREGATION_STRATEGY_KEY[aggregation_strategy])} \
# FROM long_format_intervals as lfi \
# LEFT JOIN data_table_with_presence_indicator ON lfi.person_id = data_table_with_presence_indicator.person_id \
# AND lfi.data_table_concept_id = data_table_with_presence_indicator.{DATA_TABLE_CONCEPT_ID_TRUNK[data_table]}_concept_id \
# AND (data_table_with_presence_indicator.{DATA_TABLE_DATE_KEYS["start"][data_table]} BETWEEN lfi.interval_start AND lfi.interval_end \
# OR data_table_with_presence_indicator.{DATA_TABLE_DATE_KEYS["end"][data_table]} BETWEEN lfi.interval_start AND lfi.interval_end \
# OR (data_table_with_presence_indicator.{DATA_TABLE_DATE_KEYS["start"][data_table]} < lfi.interval_start AND data_table_with_presence_indicator.{DATA_TABLE_DATE_KEYS["end"][data_table]} > lfi.interval_end)) \
# GROUP BY lfi.person_id, lfi.data_table_concept_id, interval_step, interval_start, interval_end
# """

# query = prepare_alias_query + select_query
# backend_handle.execute("DROP TABLE IF EXISTS long_person_timestamp_feature_value")
# backend_handle.execute(query)
# add_person_range_index_query = """
# ALTER TABLE long_person_timestamp_feature_value
# ADD COLUMN person_index INTEGER;

# WITH RankedPersons AS (
# SELECT person_id,
# ROW_NUMBER() OVER (ORDER BY person_id) - 1 AS idx
# FROM (SELECT DISTINCT person_id FROM long_person_timestamp_feature_value) AS unique_persons
# )
# UPDATE long_person_timestamp_feature_value
# SET person_index = RP.idx
# FROM RankedPersons RP
# WHERE long_person_timestamp_feature_value.person_id = RP.person_id;
# """
# backend_handle.execute(add_person_range_index_query)

# return None
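
Net effect of the changes in this file: `_write_long_time_interval_table` now only materializes `long_person_timestamp_feature_value` inside DuckDB and returns `None`; the `return_as_df` switch is gone, and callers read the long table back when they need it. A hedged sketch of that read path, with a hand-built stand-in for the long table (the values and column set are illustrative; `to_xarray` requires xarray to be installed):

```python
import duckdb

con = duckdb.connect()  # in-memory database

# Stand-in for what _write_long_time_interval_table leaves behind in DuckDB.
con.execute("""
    CREATE TABLE long_person_timestamp_feature_value AS
    SELECT * FROM (VALUES
        (1, 3004249, 0, 120.0),
        (1, 3004249, 1, 118.0),
        (2, 3004249, 0, 131.0)
    ) AS t(person_id, data_table_concept_id, interval_step, value_as_number)
""")

# Read side, as now done in omop.py: pull the long table into pandas and
# densify it into an xarray Dataset (person x feature x interval step).
df = con.execute("SELECT * FROM long_person_timestamp_feature_value").df()
ds = df.set_index(["person_id", "data_table_concept_id", "interval_step"]).to_xarray()
print(ds["value_as_number"].shape)  # (2 persons, 1 feature, 2 interval steps)
```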
78 changes: 42 additions & 36 deletions src/ehrdata/io/omop/omop.py
@@ -32,7 +32,7 @@
_check_valid_observation_table,
_check_valid_variable_data_tables,
)
from ehrdata.io.omop._queries import _time_interval_table
from ehrdata.io.omop._queries import _write_long_time_interval_table

DOWNLOAD_VERIFICATION_TAG = "download_verification_tag"

@@ -345,30 +345,21 @@ def setup_variables(
logging.warning(f"No data found in {data_tables[0]}. Returning edata without additional variables.")
return edata

# TODO: if instantiate_tensor
ds = (
_time_interval_table(
backend_handle=backend_handle,
time_defining_table=time_defining_table,
data_table=data_tables[0],
data_field_to_keep=data_field_to_keep,
interval_length_number=interval_length_number,
interval_length_unit=interval_length_unit,
num_intervals=num_intervals,
aggregation_strategy=aggregation_strategy,
return_as_df=True,
)
.set_index(["person_id", "data_table_concept_id", "interval_step"])
.to_xarray()
_write_long_time_interval_table(
backend_handle=backend_handle,
time_defining_table=time_defining_table,
data_table=data_tables[0],
data_field_to_keep=data_field_to_keep,
interval_length_number=interval_length_number,
interval_length_unit=interval_length_unit,
num_intervals=num_intervals,
aggregation_strategy=aggregation_strategy,
)

# TODO: if instantiate_tensor! rdbms backed, make ds independent but build on long table
_check_one_unit_per_feature(backend_handle)
# TODO ignore? go with more vanilla omop style. _check_one_unit_per_feature(ds, unit_key="unit_source_value")

unit_report = _create_feature_unit_concept_id_report(backend_handle)

var = ds["data_table_concept_id"].to_dataframe()
var = backend_handle.execute("SELECT DISTINCT data_table_concept_id FROM long_person_timestamp_feature_value").df()

if enrich_var_with_feature_info or enrich_var_with_unit_info:
concepts = backend_handle.sql("SELECT * FROM concept").df()
@@ -398,9 +389,19 @@ def setup_variables(
suffixes=("", "_unit"),
)

t = ds["interval_step"].to_dataframe()
t = pd.DataFrame({"interval_step": np.arange(num_intervals)})

edata = EHRData(r=ds[data_field_to_keep[0]].values, obs=edata.obs, var=var, uns=edata.uns, t=t)
if instantiate_tensor:
ds = (
(backend_handle.execute("SELECT * FROM long_person_timestamp_feature_value").df())
.set_index(["person_id", "data_table_concept_id", "interval_step"])
.to_xarray()
)

else:
ds = None

edata = EHRData(r=ds[data_field_to_keep[0]].values if ds else None, obs=edata.obs, var=var, uns=edata.uns, t=t)
edata.uns[f"unit_report_{data_tables[0]}"] = unit_report

return edata
@@ -420,6 +421,7 @@ def setup_interval_variables(
enrich_var_with_feature_info: bool = False,
enrich_var_with_unit_info: bool = False,
keep_date: Literal["start", "end", "interval"] = "start",
instantiate_tensor: bool = True,
):
"""Setup the interval variables
@@ -453,6 +455,8 @@
Whether to enrich the var table with feature information. If a concept_id is not found in the concept table, the feature information will be NaN.
date_type
Whether to keep the start or end date, or the interval span.
instantiate_tensor
Whether to instantiate the tensor into the .r field of the EHRData object.
Returns
-------
@@ -483,24 +487,26 @@
logging.warning(f"No data in {data_tables}.")
return edata

_write_long_time_interval_table(
backend_handle=backend_handle,
time_defining_table=time_defining_table,
data_table=data_tables[0],
data_field_to_keep=data_field_to_keep,
interval_length_number=interval_length_number,
interval_length_unit=interval_length_unit,
num_intervals=num_intervals,
aggregation_strategy=aggregation_strategy,
keep_date=keep_date,
)

ds = (
_time_interval_table(
backend_handle=backend_handle,
time_defining_table=time_defining_table,
data_table=data_tables[0],
data_field_to_keep=data_field_to_keep,
interval_length_number=interval_length_number,
interval_length_unit=interval_length_unit,
num_intervals=num_intervals,
aggregation_strategy=aggregation_strategy,
keep_date=keep_date,
return_as_df=True,
)
backend_handle.execute("SELECT * FROM long_person_timestamp_feature_value")
.df()
.set_index(["person_id", "data_table_concept_id", "interval_step"])
.to_xarray()
)

var = ds["data_table_concept_id"].to_dataframe()
var = backend_handle.execute("SELECT DISTINCT data_table_concept_id FROM long_person_timestamp_feature_value").df()

if enrich_var_with_feature_info or enrich_var_with_unit_info:
concepts = backend_handle.sql("SELECT * FROM concept").df()
@@ -509,7 +515,7 @@
if enrich_var_with_feature_info:
var = pd.merge(var, concepts, how="left", left_index=True, right_on="concept_id")

t = ds["interval_step"].to_dataframe()
t = pd.DataFrame({"interval_step": np.arange(num_intervals)})

edata = EHRData(r=ds[data_field_to_keep[0]].values, obs=edata.obs, var=var, uns=edata.uns, t=t)

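A hedged usage sketch for the `instantiate_tensor` flag introduced in this commit. It assumes `setup_variables` accepts the keyword arguments its body above uses, that the import path matches this repo's layout, and that `con` is a DuckDB connection to an OMOP database; all argument values are illustrative:

```python
from ehrdata.io.omop import setup_variables  # import path assumed

# With instantiate_tensor=False, only the long table
# long_person_timestamp_feature_value is written into DuckDB; the dense
# tensor in edata.r is skipped and can be built later from that table.
edata = setup_variables(
    edata,
    backend_handle=con,
    data_tables=["measurement"],
    data_field_to_keep="value_as_number",
    interval_length_number=1,
    interval_length_unit="day",
    num_intervals=30,
    aggregation_strategy="last",
    instantiate_tensor=False,
)
```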
