Skip to content

Commit

Permalink
Updated to new tpcp version and updated caching accordingly
Browse files Browse the repository at this point in the history
  • Loading branch information
AKuederle committed Dec 19, 2023
1 parent 63c3b77 commit 3b71eb5
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 82 deletions.
57 changes: 16 additions & 41 deletions gaitlink/data/_mobilised_cvs_dmo_dataset.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
import binascii
import warnings
from functools import lru_cache
from pathlib import Path
from typing import ClassVar, Generic, Literal, Optional, TypeAlias, TypeVar, Union
from typing import ClassVar, Literal, Optional, TypeAlias, Union

import pandas as pd
from joblib import Memory, Parallel, delayed
from tpcp import Dataset
from tpcp._hash import custom_hash
from tpcp.caching import hybrid_cache
from tqdm.auto import tqdm

from gaitlink.data._mobilsed_weartime_loader import load_weartime_from_daily_mcroberts_report
from gaitlink.data._utils import staggered_cache

SITE_CODES: TypeAlias = Literal[
"CAU",
Expand Down Expand Up @@ -64,31 +62,12 @@
}


# TODO: Replace with tpcp version once released

T = TypeVar("T")


class UniversalHashableWrapper(Generic[T]):
def __init__(self, obj: T) -> None:
self.obj = obj

def __hash__(self):
"""Hash the object using the pickle based approach."""
return int(binascii.hexlify(custom_hash(self.obj).encode("utf-8")), 16)

def __eq__(self, other):
"""Compare the object using their hash."""
return custom_hash(self.obj) == custom_hash(other.obj)


def _create_index(
dmo_path: Path, site_pid_map_path: Path, timezones: UniversalHashableWrapper[dict[SITE_CODES, str]], memory: Memory
dmo_path: Path, site_pid_map_path: Path, timezones: dict[SITE_CODES, str], memory: Memory
) -> pd.DataFrame:
site_data = staggered_cache(_load_site_pid_map, memory, 1)(site_pid_map_path=site_pid_map_path, timezones=timezones)
dmo_data, _ = staggered_cache(_load_dmo_data, memory, 1)(
dmo_path=dmo_path, timezone_per_subject=UniversalHashableWrapper(site_data)
)
cache = hybrid_cache(memory, 1)
site_data = cache(_load_site_pid_map)(site_pid_map_path=site_pid_map_path, timezones=timezones)
dmo_data, _ = cache(_load_dmo_data)(dmo_path=dmo_path, timezone_per_subject=site_data)

visit_type = dmo_path.name.split("-")[1].upper()

Expand All @@ -103,10 +82,8 @@ def _create_index(
)


def _load_site_pid_map(
site_pid_map_path: Path, timezones: UniversalHashableWrapper[dict[SITE_CODES, str]]
) -> pd.DataFrame:
timezones_df = pd.DataFrame.from_dict(timezones.obj, orient="index", columns=["timezone"])
def _load_site_pid_map(site_pid_map_path: Path, timezones: dict[SITE_CODES, str]) -> pd.DataFrame:
timezones_df = pd.DataFrame.from_dict(timezones, orient="index", columns=["timezone"])

site_data = (
pd.read_csv(site_pid_map_path)[["Local.Participant", "Participant.Site"]]
Expand Down Expand Up @@ -141,17 +118,15 @@ def _load_pid_mid_map(compliance_report: Path) -> pd.DataFrame:
)


def _load_dmo_data(
dmo_path: Path, timezone_per_subject: UniversalHashableWrapper[pd.DataFrame]
) -> tuple[pd.DataFrame, pd.DataFrame]:
def _load_dmo_data(dmo_path: Path, timezone_per_subject: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
warnings.warn(
"Initial data loading. This might take a while! But, don't worry, we cache the loaded results.\n\n"
"If you are seeing this message multiple times, you might want to consider using a joblib memory by "
"passing ``memory=Memory('some/cache/path)`` to the dataset constructor to cache the index creation"
"between script executions.",
stacklevel=1,
)
timezone_per_subject = timezone_per_subject.obj
timezone_per_subject = timezone_per_subject

Check failure on line 129 in gaitlink/data/_mobilised_cvs_dmo_dataset.py

View workflow job for this annotation

GitHub Actions / lint

Ruff (PLW0127)

gaitlink/data/_mobilised_cvs_dmo_dataset.py:129:5: PLW0127 Self-assignment of variable `timezone_per_subject`

dmos = [
"duration",
Expand Down Expand Up @@ -280,9 +255,9 @@ def visit_type(self):
return self.dmo_export_path.split("-")[1].upper()

def _get_participant_site_metadata(self) -> pd.DataFrame:
return staggered_cache(_load_site_pid_map, self.memory, 1)(
return hybrid_cache(self.memory, 1)(_load_site_pid_map)(
site_pid_map_path=Path(self.site_pid_map_path),
timezones=UniversalHashableWrapper(self.TIME_ZONES),
timezones=self.TIME_ZONES,
)

def _get_pid_mid_map(self) -> pd.DataFrame:
Expand Down Expand Up @@ -390,9 +365,9 @@ def process_single_file(path):
return results

def _get_dmo_data(self):
return staggered_cache(_load_dmo_data, self.memory, 1)(
return hybrid_cache(self.memory, 1)(_load_dmo_data)(
dmo_path=Path(self.dmo_export_path),
timezone_per_subject=UniversalHashableWrapper(self._get_participant_site_metadata()),
timezone_per_subject=self._get_participant_site_metadata(),
)

def _extract_relevant_data(self, data: pd.DataFrame) -> pd.DataFrame:
Expand Down Expand Up @@ -438,9 +413,9 @@ def timezone(self) -> str:
return self._get_participant_site_metadata().loc[p_id, "timezone"]

def create_index(self) -> pd.DataFrame:
return self.memory.cache(_create_index)(
return hybrid_cache(self.memory, 1)(_create_index)(
Path(self.dmo_export_path),
Path(self.site_pid_map_path),
UniversalHashableWrapper(self.TIME_ZONES),
self.TIME_ZONES,
memory=self.memory,
)
35 changes: 0 additions & 35 deletions gaitlink/data/_utils.py

This file was deleted.

10 changes: 5 additions & 5 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ readme = "README.md"

[tool.poetry.dependencies]
python = ">=3.9,<3.13"
tpcp = "^0.27.0"
tpcp = "^0.29.0"
pandas = ">=2.1.0"
scipy = ">=1.11.2"
numpy = ">=1.25.2"
Expand Down

0 comments on commit 3b71eb5

Please sign in to comment.