Skip to content

Commit

Permalink
Merge pull request #2 from ihmeuw/notebook_model_refactor
Browse files Browse the repository at this point in the history
lint with isort and then black
  • Loading branch information
aflaxman authored Aug 19, 2024
2 parents 392f108 + e56c233 commit 4004c1d
Show file tree
Hide file tree
Showing 17 changed files with 237 additions and 153 deletions.
24 changes: 11 additions & 13 deletions src/vivarium_sodium_reduction/components/interventions.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
from typing import Any, Dict

import pandas as pd
from typing import Dict, Any
from vivarium import Component
from vivarium.framework.engine import Builder
from vivarium.framework.population import SimulantData


class RelativeShiftIntervention(Component):
"""Applies a relative shift to a target value."""

CONFIGURATION_DEFAULTS = {
'intervention': {
'shift_factor': 0.1,
'age_start': 0,
'age_end': 125,
"intervention": {
"shift_factor": 0.1,
"age_start": 0,
"age_end": 125,
}
}

Expand All @@ -25,22 +27,18 @@ def name(self) -> str:

@property
def configuration_defaults(self) -> Dict[str, Dict[str, Any]]:
return {
f'{self.name}': self.CONFIGURATION_DEFAULTS['intervention']
}
return {f"{self.name}": self.CONFIGURATION_DEFAULTS["intervention"]}

def setup(self, builder: Builder) -> None:
self.config = builder.configuration[self.name]
self.shift_factor = self.config.shift_factor
self.age_start = self.config.age_start
self.age_end = self.config.age_end

self.population_view = builder.population.get_view(['age'])
self.population_view = builder.population.get_view(["age"])

builder.value.register_value_modifier(
self.target,
modifier=self.adjust_exposure,
requires_columns=['age']
self.target, modifier=self.adjust_exposure, requires_columns=["age"]
)

def adjust_exposure(self, index: pd.Index, exposure: pd.Series) -> pd.Series:
Expand All @@ -49,4 +47,4 @@ def adjust_exposure(self, index: pd.Index, exposure: pd.Series) -> pd.Series:
(self.age_start <= pop.age) & (pop.age < self.age_end)
].index
exposure.loc[applicable_index] *= self.shift_factor
return exposure
return exposure
65 changes: 43 additions & 22 deletions src/vivarium_sodium_reduction/components/risks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@
from vivarium.framework.randomness import get_hash
from vivarium.framework.values import Pipeline
from vivarium_public_health.risks import Risk
from vivarium_public_health.risks.data_transformations import get_exposure_post_processor
from vivarium_public_health.risks.data_transformations import (
get_exposure_post_processor,
)
from vivarium_public_health.utilities import EntityString


class DropValueRisk(Risk):
def __init__(self, risk: str):
super().__init__(risk)
Expand Down Expand Up @@ -57,11 +60,13 @@ def post_processor(exposure, _):

return post_processor


class CorrelatedRisk(DropValueRisk):
"""A risk that can be correlated with another risk.
TODO: document strategy used in this component in more detail,
Abie had an AI adapt it from https://github.com/ihmeuw/vivarium_nih_us_cvd"""

@property
def columns_created(self) -> List[str]:
return []
Expand All @@ -84,8 +89,10 @@ def on_initialize_simulants(self, pop_data: SimulantData) -> None:
def on_time_step_prepare(self, event: Event) -> None:
pass


class RiskCorrelation(Component):
"""A component that generates a specified correlation between two risk exposures."""

@property
def columns_created(self) -> List[str]:
return self.propensity_column_names + self.exposure_column_names
Expand All @@ -104,17 +111,27 @@ def __init__(self, risk1: str, risk2: str, correlation: str):
correlation_matrix = np.array([[1, float(correlation)], [float(correlation), 1]])
self.correlated_risks = [EntityString(risk) for risk in correlated_risks]
self.correlation_matrix = correlation_matrix
self.propensity_column_names = [f"{risk.name}_propensity" for risk in self.correlated_risks]
self.exposure_column_names = [f"{risk.name}_exposure" for risk in self.correlated_risks]
self.propensity_column_names = [
f"{risk.name}_propensity" for risk in self.correlated_risks
]
self.exposure_column_names = [
f"{risk.name}_exposure" for risk in self.correlated_risks
]
self.ensemble_propensities = [
f"ensemble_propensity_" + risk
for risk in self.correlated_risks
if risk_factors[risk.name].distribution == "ensemble"
]

def setup(self, builder: Builder) -> None:
self.distributions = {risk: builder.components.get_component(risk).exposure_distribution for risk in self.correlated_risks}
self.exposures = {risk: builder.value.get_value(f"{risk.name}.exposure") for risk in self.correlated_risks}
self.distributions = {
risk: builder.components.get_component(risk).exposure_distribution
for risk in self.correlated_risks
}
self.exposures = {
risk: builder.value.get_value(f"{risk.name}.exposure")
for risk in self.correlated_risks
}
self.input_draw = builder.configuration.input_data.input_draw_number
self.random_seed = builder.configuration.randomness.random_seed

Expand All @@ -124,21 +141,21 @@ def on_initialize_simulants(self, pop_data: SimulantData) -> None:

np.random.seed(get_hash(f"{self.input_draw}_{self.random_seed}"))
probit_propensity = np.random.multivariate_normal(
mean=[0] * len(self.correlated_risks),
cov=self.correlation_matrix,
size=len(pop)
mean=[0] * len(self.correlated_risks), cov=self.correlation_matrix, size=len(pop)
)
correlated_propensities = scipy.stats.norm().cdf(probit_propensity)
propensities[self.propensity_column_names] = correlated_propensities

def get_exposure_from_propensity(propensity_col: pd.Series) -> pd.Series:
risk = propensity_col.name.replace('_propensity','')
exposure_values = self.distributions['risk_factor.' + risk].ppf(propensity_col)
risk = propensity_col.name.replace("_propensity", "")
exposure_values = self.distributions["risk_factor." + risk].ppf(propensity_col)
return pd.Series(exposure_values)

exposures = propensities.apply(get_exposure_from_propensity)
exposures.columns = [col.replace('_propensity','_exposure') for col in propensities.columns]

exposures.columns = [
col.replace("_propensity", "_exposure") for col in propensities.columns
]

self.population_view.update(pd.concat([propensities, exposures], axis=1))

def on_time_step_prepare(self, event: Event) -> None:
Expand All @@ -147,33 +164,37 @@ def on_time_step_prepare(self, event: Event) -> None:
exposure_col = pd.Series(exposure_values, name=f"{risk.name}_exposure")
self.population_view.update(exposure_col)


class SodiumSBPEffect(Component):
@property
def name(self):
return "sodium_sbp_effect"

def setup(self, builder: Builder):
self.sodium_exposure = builder.value.get_value('diet_high_in_sodium.exposure')
self.sodium_exposure_raw = builder.value.get_value('diet_high_in_sodium.raw_exposure')
self.sodium_exposure = builder.value.get_value("diet_high_in_sodium.exposure")
self.sodium_exposure_raw = builder.value.get_value("diet_high_in_sodium.raw_exposure")

builder.value.register_value_modifier(
'high_systolic_blood_pressure.drop_value',
"high_systolic_blood_pressure.drop_value",
modifier=self.sodium_effect_on_sbp,
requires_columns=['age', 'sex'],
requires_values=['diet_high_in_sodium.exposure', 'diet_high_in_sodium.raw_exposure']
requires_columns=["age", "sex"],
requires_values=[
"diet_high_in_sodium.exposure",
"diet_high_in_sodium.raw_exposure",
],
)

def sodium_effect_on_sbp(self, index, sbp_drop_value):
sodium_exposure = self.sodium_exposure(index)
sodium_exposure_raw = self.sodium_exposure_raw(index)

sodium_threshold = 2.0 # g/day
mmHg_per_g_sodium = 10 # mmHg increase per 1g sodium above threshold

sbp_increase = pd.Series(0, index=index)
sodium_drop = sodium_exposure_raw - sodium_exposure
# TODO: use threshold

sbp_drop_due_to_sodium_drop = sodium_drop * mmHg_per_g_sodium

return sbp_drop_value + sbp_drop_due_to_sodium_drop
8 changes: 6 additions & 2 deletions src/vivarium_sodium_reduction/constants/data_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,12 @@ class __SomeDisease(NamedTuple):

# Keys that will be loaded into the artifact. must have a colon type declaration
SOME_DISEASE_PREVALENCE: TargetString = TargetString("cause.some_disease.prevalence")
SOME_DISEASE_INCIDENCE_RATE: TargetString = TargetString("cause.some_disease.incidence_rate")
SOME_DISEASE_REMISSION_RATE: TargetString = TargetString("cause.some_disease.remission_rate")
SOME_DISEASE_INCIDENCE_RATE: TargetString = TargetString(
"cause.some_disease.incidence_rate"
)
SOME_DISEASE_REMISSION_RATE: TargetString = TargetString(
"cause.some_disease.remission_rate"
)
DISABILITY_WEIGHT: TargetString = TargetString("cause.some_disease.disability_weight")
EMR: TargetString = TargetString("cause.some_disease.excess_mortality_rate")
CSMR: TargetString = TargetString("cause.some_disease.cause_specific_mortality_rate")
Expand Down
5 changes: 3 additions & 2 deletions src/vivarium_sodium_reduction/constants/data_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,6 @@
SCALE_UP_START_DT = datetime(2021, 1, 1)
SCALE_UP_END_DT = datetime(2030, 1, 1)
SCREENING_SCALE_UP_GOAL_COVERAGE = 0.50
SCREENING_SCALE_UP_DIFFERENCE = SCREENING_SCALE_UP_GOAL_COVERAGE - PROBABILITY_ATTENDING_SCREENING_START_MEAN

SCREENING_SCALE_UP_DIFFERENCE = (
SCREENING_SCALE_UP_GOAL_COVERAGE - PROBABILITY_ATTENDING_SCREENING_START_MEAN
)
4 changes: 3 additions & 1 deletion src/vivarium_sodium_reduction/constants/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,6 @@ def __new__(cls, value):


STATES = tuple(state for model in STATE_MACHINE_MAP.values() for state in model["states"])
TRANSITIONS = tuple(state for model in STATE_MACHINE_MAP.values() for state in model["transitions"])
TRANSITIONS = tuple(
state for model in STATE_MACHINE_MAP.values() for state in model["transitions"]
)
5 changes: 3 additions & 2 deletions src/vivarium_sodium_reduction/constants/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@

BASE_DIR = Path(vivarium_sodium_reduction.__file__).resolve().parent

ARTIFACT_ROOT = Path(f"/mnt/team/simulation_science/pub/models/{metadata.PROJECT_NAME}/artifacts/")

ARTIFACT_ROOT = Path(
f"/mnt/team/simulation_science/pub/models/{metadata.PROJECT_NAME}/artifacts/"
)
44 changes: 29 additions & 15 deletions src/vivarium_sodium_reduction/constants/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@
THROWAWAY_COLUMNS = [f"{state}_event_count" for state in models.STATES]

DEATH_COLUMN_TEMPLATE = "MEASURE_death_due_to_{CAUSE_OF_DEATH}_AGE_GROUP_{AGE_GROUP}_CURRENT_YEAR_{YEAR}_SEX_{SEX}"
YLLS_COLUMN_TEMPLATE = "MEASURE_ylls_due_to_{CAUSE_OF_DEATH}_AGE_GROUP_{AGE_GROUP}_CURRENT_YEAR_{YEAR}_SEX_{SEX}"
YLDS_COLUMN_TEMPLATE = (
"MEASURE_ylds_due_to_{CAUSE_OF_DISABILITY}_AGE_GROUP_{AGE_GROUP}_CURRENT_YEAR_{YEAR}_SEX_{SEX}"
YLLS_COLUMN_TEMPLATE = (
"MEASURE_ylls_due_to_{CAUSE_OF_DEATH}_AGE_GROUP_{AGE_GROUP}_CURRENT_YEAR_{YEAR}_SEX_{SEX}"
)
YLDS_COLUMN_TEMPLATE = "MEASURE_ylds_due_to_{CAUSE_OF_DISABILITY}_AGE_GROUP_{AGE_GROUP}_CURRENT_YEAR_{YEAR}_SEX_{SEX}"
STATE_PERSON_TIME_COLUMN_TEMPLATE = (
"MEASURE_{STATE}_person_time_AGE_GROUP_{AGE_GROUP}_CURRENT_YEAR_{YEAR}_SEX_{SEX}"
)
Expand All @@ -46,8 +46,7 @@
"transition_count": TRANSITION_COUNT_COLUMN_TEMPLATE,
}

NON_COUNT_TEMPLATES = [
]
NON_COUNT_TEMPLATES = []

SEXES = ("male", "female")
# TODO - add literals for years in the model
Expand Down Expand Up @@ -87,11 +86,18 @@ def RESULT_COLUMNS(kind="all"):
columns = list(STANDARD_COLUMNS.values()) + columns
else:
template = COLUMN_TEMPLATES[kind]
filtered_field_map = {field: values
for field, values in TEMPLATE_FIELD_MAP.items() if f"{{{field}}}" in template}
fields, value_groups = filtered_field_map.keys(), itertools.product(*filtered_field_map.values())
filtered_field_map = {
field: values
for field, values in TEMPLATE_FIELD_MAP.items()
if f"{{{field}}}" in template
}
fields, value_groups = filtered_field_map.keys(), itertools.product(
*filtered_field_map.values()
)
for value_group in value_groups:
columns.append(template.format(**{field: value for field, value in zip(fields, value_group)}))
columns.append(
template.format(**{field: value for field, value in zip(fields, value_group)})
)
return columns


Expand All @@ -101,13 +107,21 @@ def RESULTS_MAP(kind):
raise ValueError(f"Unknown result column type {kind}")
columns = []
template = COLUMN_TEMPLATES[kind]
filtered_field_map = {field: values
for field, values in TEMPLATE_FIELD_MAP.items() if f"{{{field}}}" in template}
fields, value_groups = list(filtered_field_map.keys()), list(itertools.product(*filtered_field_map.values()))
filtered_field_map = {
field: values
for field, values in TEMPLATE_FIELD_MAP.items()
if f"{{{field}}}" in template
}
fields, value_groups = list(filtered_field_map.keys()), list(
itertools.product(*filtered_field_map.values())
)
for value_group in value_groups:
columns.append(template.format(**{field: value for field, value in zip(fields, value_group)}))
columns.append(
template.format(**{field: value for field, value in zip(fields, value_group)})
)
df = pd.DataFrame(value_groups, columns=map(lambda x: x.lower(), fields))
df["key"] = columns
df["measure"] = kind # per researcher feedback, this column is useful, even when it"s identical for all rows
df["measure"] = (
kind # per researcher feedback, this column is useful, even when it"s identical for all rows
)
return df.set_index("key").sort_index()

10 changes: 5 additions & 5 deletions src/vivarium_sodium_reduction/constants/scenarios.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
class InterventionScenario:

def __init__(
self,
name: str,
# todo add additional interventions
# has_treatment_one: bool = False,
# has_treatment_two: bool = False,
self,
name: str,
# todo add additional interventions
# has_treatment_one: bool = False,
# has_treatment_two: bool = False,
):
self.name = name
# self.has_treatment_one = has_treatment_one
Expand Down
1 change: 1 addition & 0 deletions src/vivarium_sodium_reduction/data/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Logging in this module should be done at the ``debug`` level.
"""

from pathlib import Path
from typing import Optional

Expand Down
1 change: 1 addition & 0 deletions src/vivarium_sodium_reduction/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
No logging is done here. Logging is done in vivarium inputs itself and forwarded.
"""

from typing import List, Optional, Union

import pandas as pd
Expand Down
Loading

0 comments on commit 4004c1d

Please sign in to comment.