Skip to content

Commit

Permalink
Merge pull request #1 from ihmeuw/notebook_model_refactor
Browse files Browse the repository at this point in the history
Refactor minimal .ipynb model into more traditional vivarium model form
  • Loading branch information
aflaxman authored Aug 19, 2024
2 parents a87891c + 6853832 commit 392f108
Show file tree
Hide file tree
Showing 5 changed files with 275 additions and 18 deletions.
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@

install_requirements = [
"gbd_mapping==4.0.0",
"vivarium==2.3.8",
"vivarium_public_health==2.3.3",
"vivarium==3.0.0",
"vivarium_public_health>=3.0.0,<4.0.0",
"click",
"jinja2",
"loguru",
Expand Down
1 change: 1 addition & 0 deletions src/vivarium_sodium_reduction/components/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

52 changes: 52 additions & 0 deletions src/vivarium_sodium_reduction/components/interventions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import pandas as pd
from typing import Dict, Any
from vivarium import Component
from vivarium.framework.engine import Builder
from vivarium.framework.population import SimulantData

class RelativeShiftIntervention(Component):
"""Applies a relative shift to a target value."""

CONFIGURATION_DEFAULTS = {
'intervention': {
'shift_factor': 0.1,
'age_start': 0,
'age_end': 125,
}
}

def __init__(self, target: str):
super().__init__()
self.target = target

@property
def name(self) -> str:
return f"relative_shift_intervention.{self.target}"

@property
def configuration_defaults(self) -> Dict[str, Dict[str, Any]]:
return {
f'{self.name}': self.CONFIGURATION_DEFAULTS['intervention']
}

def setup(self, builder: Builder) -> None:
self.config = builder.configuration[self.name]
self.shift_factor = self.config.shift_factor
self.age_start = self.config.age_start
self.age_end = self.config.age_end

self.population_view = builder.population.get_view(['age'])

builder.value.register_value_modifier(
self.target,
modifier=self.adjust_exposure,
requires_columns=['age']
)

def adjust_exposure(self, index: pd.Index, exposure: pd.Series) -> pd.Series:
pop = self.population_view.get(index)
applicable_index = pop.loc[
(self.age_start <= pop.age) & (pop.age < self.age_end)
].index
exposure.loc[applicable_index] *= self.shift_factor
return exposure
179 changes: 179 additions & 0 deletions src/vivarium_sodium_reduction/components/risks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
from typing import Dict, List, Optional

import numpy as np
import pandas as pd
import scipy
from gbd_mapping import risk_factors
from vivarium import Component
from vivarium.framework.engine import Builder
from vivarium.framework.event import Event
from vivarium.framework.population import SimulantData
from vivarium.framework.randomness import get_hash
from vivarium.framework.values import Pipeline
from vivarium_public_health.risks import Risk
from vivarium_public_health.risks.data_transformations import get_exposure_post_processor
from vivarium_public_health.utilities import EntityString

class DropValueRisk(Risk):
def __init__(self, risk: str):
super().__init__(risk)
self.raw_exposure_pipeline_name = f"{self.risk.name}.raw_exposure"
self.drop_value_pipeline_name = f"{self.risk.name}.drop_value"

def setup(self, builder: Builder) -> None:
super().setup(builder)
self.raw_exposure = self.get_raw_exposure_pipeline(builder)
self.drop_value = self.get_drop_value_pipeline(builder)

def get_drop_value_pipeline(self, builder: Builder) -> Pipeline:
return builder.value.register_value_producer(
self.drop_value_pipeline_name,
source=lambda index: pd.Series(0.0, index=index),
)

def get_raw_exposure_pipeline(self, builder: Builder) -> Pipeline:
return builder.value.register_value_producer(
self.raw_exposure_pipeline_name,
source=self.get_current_exposure,
requires_columns=["age", "sex"],
requires_values=[self.propensity_pipeline_name],
)

def get_exposure_pipeline(self, builder: Builder) -> Pipeline:
return builder.value.register_value_producer(
self.exposure_pipeline_name,
source=self.get_current_exposure,
requires_columns=["age", "sex"],
requires_values=[self.propensity_pipeline_name],
preferred_post_processor=self.get_drop_value_post_processor(builder, self.risk),
)

def get_drop_value_post_processor(self, builder: Builder, risk: EntityString):
drop_value_pipeline = builder.value.get_value(self.drop_value_pipeline_name)

def post_processor(exposure, _):
drop_values = drop_value_pipeline(exposure.index)
return exposure - drop_values

return post_processor

class CorrelatedRisk(DropValueRisk):
"""A risk that can be correlated with another risk.
TODO: document strategy used in this component in more detail,
Abie had an AI adapt it from https://github.com/ihmeuw/vivarium_nih_us_cvd"""
@property
def columns_created(self) -> List[str]:
return []

@property
def columns_required(self) -> Optional[List[str]]:
return [self.propensity_column_name]

@property
def initialization_requirements(self) -> Dict[str, List[str]]:
return {
"requires_columns": [],
"requires_values": [],
"requires_streams": [],
}

def on_initialize_simulants(self, pop_data: SimulantData) -> None:
pass

def on_time_step_prepare(self, event: Event) -> None:
pass

class RiskCorrelation(Component):
"""A component that generates a specified correlation between two risk exposures."""
@property
def columns_created(self) -> List[str]:
return self.propensity_column_names + self.exposure_column_names

@property
def columns_required(self) -> Optional[List[str]]:
return ["age"]

@property
def initialization_requirements(self) -> Dict[str, List[str]]:
return {"requires_columns": ["age"] + self.ensemble_propensities}

def __init__(self, risk1: str, risk2: str, correlation: str):
super().__init__()
correlated_risks = [risk1, risk2]
correlation_matrix = np.array([[1, float(correlation)], [float(correlation), 1]])
self.correlated_risks = [EntityString(risk) for risk in correlated_risks]
self.correlation_matrix = correlation_matrix
self.propensity_column_names = [f"{risk.name}_propensity" for risk in self.correlated_risks]
self.exposure_column_names = [f"{risk.name}_exposure" for risk in self.correlated_risks]
self.ensemble_propensities = [
f"ensemble_propensity_" + risk
for risk in self.correlated_risks
if risk_factors[risk.name].distribution == "ensemble"
]

def setup(self, builder: Builder) -> None:
self.distributions = {risk: builder.components.get_component(risk).exposure_distribution for risk in self.correlated_risks}
self.exposures = {risk: builder.value.get_value(f"{risk.name}.exposure") for risk in self.correlated_risks}
self.input_draw = builder.configuration.input_data.input_draw_number
self.random_seed = builder.configuration.randomness.random_seed

def on_initialize_simulants(self, pop_data: SimulantData) -> None:
pop = self.population_view.subview(["age"]).get(pop_data.index)
propensities = pd.DataFrame(index=pop.index)

np.random.seed(get_hash(f"{self.input_draw}_{self.random_seed}"))
probit_propensity = np.random.multivariate_normal(
mean=[0] * len(self.correlated_risks),
cov=self.correlation_matrix,
size=len(pop)
)
correlated_propensities = scipy.stats.norm().cdf(probit_propensity)
propensities[self.propensity_column_names] = correlated_propensities

def get_exposure_from_propensity(propensity_col: pd.Series) -> pd.Series:
risk = propensity_col.name.replace('_propensity','')
exposure_values = self.distributions['risk_factor.' + risk].ppf(propensity_col)
return pd.Series(exposure_values)

exposures = propensities.apply(get_exposure_from_propensity)
exposures.columns = [col.replace('_propensity','_exposure') for col in propensities.columns]

self.population_view.update(pd.concat([propensities, exposures], axis=1))

def on_time_step_prepare(self, event: Event) -> None:
for risk in self.exposures:
exposure_values = self.exposures[risk](event.index)
exposure_col = pd.Series(exposure_values, name=f"{risk.name}_exposure")
self.population_view.update(exposure_col)

class SodiumSBPEffect(Component):
@property
def name(self):
return "sodium_sbp_effect"

def setup(self, builder: Builder):
self.sodium_exposure = builder.value.get_value('diet_high_in_sodium.exposure')
self.sodium_exposure_raw = builder.value.get_value('diet_high_in_sodium.raw_exposure')

builder.value.register_value_modifier(
'high_systolic_blood_pressure.drop_value',
modifier=self.sodium_effect_on_sbp,
requires_columns=['age', 'sex'],
requires_values=['diet_high_in_sodium.exposure', 'diet_high_in_sodium.raw_exposure']
)

def sodium_effect_on_sbp(self, index, sbp_drop_value):
sodium_exposure = self.sodium_exposure(index)
sodium_exposure_raw = self.sodium_exposure_raw(index)

sodium_threshold = 2.0 # g/day
mmHg_per_g_sodium = 10 # mmHg increase per 1g sodium above threshold

sbp_increase = pd.Series(0, index=index)
sodium_drop = sodium_exposure_raw - sodium_exposure
# TODO: use threshold

sbp_drop_due_to_sodium_drop = sodium_drop * mmHg_per_g_sodium

return sbp_drop_value + sbp_drop_due_to_sodium_drop
57 changes: 41 additions & 16 deletions src/vivarium_sodium_reduction/model_specifications/model_spec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,37 +3,62 @@ components:
population:
- BasePopulation()
- Mortality()
metrics:
- DisabilityObserver()
- MortalityObserver()
disease:
- SI('ischemic_heart_disease')
- SI('ischemic_stroke')
# metrics:
# - DisabilityObserver()
# - MortalityObserver()
vivarium_sodium_reduction.components.risks:
- CorrelatedRisk('risk_factor.high_systolic_blood_pressure')
- CorrelatedRisk('risk_factor.diet_high_in_sodium')
- RiskCorrelation('risk_factor.high_systolic_blood_pressure', 'risk_factor.diet_high_in_sodium', '0.75')
- SodiumSBPEffect()
vivarium_public_health.risks:
- NonLogLinearRiskEffect('risk_factor.high_systolic_blood_pressure',
'cause.ischemic_heart_disease.incidence_rate')
- NonLogLinearRiskEffect('risk_factor.high_systolic_blood_pressure',
'cause.ischemic_stroke.incidence_rate')
vivarium_sodium_reduction.components.interventions:
- RelativeShiftIntervention('diet_high_in_sodium.exposure')

# Causes an error if left empty. Uncomment when you have components.
# vivarium_sodium_reduction.components:

configuration:
input_data:
artifact_path: /home/j/Project/simulation_science/sodium_usa.hdf
input_draw_number: 0
interpolation:
order: 0
extrapolate: False
extrapolate: True
randomness:
map_size: 1_000_000
key_columns: ['entrance_time', 'age']
random_seed: 0
time:
start:
year: 2005
month: 7
day: 2
year: 2025
month: 1
day: 1
end:
year: 2010
month: 7
day: 2
step_size: 1 # Days
year: 2044
month: 12
day: 31
step_size: 168 # 28*6 (approximately 6 months, in days)
population:
population_size: 100
age_start: 0
age_end: 100
population_size: 10_000
initialization_age_min: 25
initialization_age_max: 125

relative_shift_intervention:
diet_high_in_sodium:
exposure:
- shift_factor: 1.0
- age_start: 0
- age_end: 125

risk_factor:
high_systolic_blood_pressure:
- tmred: 1.0

stratification:
default:
Expand Down

0 comments on commit 392f108

Please sign in to comment.