Skip to content

Commit

Permalink
Bugfix/sbachmei/mic 5296 exposure column for nonloglinear risk effects (
Browse files Browse the repository at this point in the history
#481)

* rename LBWSRisk 'exposure_column_name' staticmethod to not conflict w/ base Risk attribute
* create the exposure col only if needed by a NonLogLinearRiskEffect
  • Loading branch information
stevebachmeier authored Sep 4, 2024
1 parent 6ca60d1 commit 70d8dd8
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 17 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
**3.0.6 - 09/04/24**

- Fix bug that was occurring when RiskEffect's rr_source was a float or DataFrame
- Better handle exposure column creation in Risk component
- Rename LBWSRisk 'exposure_column_name()' staticmethod to not collide with Risk attr

**3.0.5 - 08/29/24**

Expand Down
35 changes: 24 additions & 11 deletions src/vivarium_public_health/risks/base_risk.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,10 @@ def configuration_defaults(self) -> Dict[str, Any]:

@property
def columns_created(self) -> List[str]:
return [self.propensity_column_name, self.exposure_column_name]
columns_to_create = [self.propensity_column_name]
if self.create_exposure_column:
columns_to_create.append(self.exposure_column_name)
return columns_to_create

@property
def initialization_requirements(self) -> Dict[str, List[str]]:
Expand Down Expand Up @@ -172,6 +175,16 @@ def setup(self, builder: Builder) -> None:
self.propensity = self.get_propensity_pipeline(builder)
self.exposure = self.get_exposure_pipeline(builder)

# We want to set this to True iff there is a non-loglinear risk effect
# on this risk instance
self.create_exposure_column = bool(
[
component
for component in builder.components.list_components()
if component.startswith(f"non_log_linear_risk_effect.{self.risk.name}_on_")
]
)

def get_distribution_type(self, builder: Builder) -> str:
"""Get the distribution type for the risk from the configuration.
Expand Down Expand Up @@ -270,19 +283,19 @@ def get_exposure_pipeline(self, builder: Builder) -> Pipeline:
########################

def on_initialize_simulants(self, pop_data: SimulantData) -> None:
propensity_values = self.randomness.get_draw(pop_data.index)
df = pd.DataFrame(
{
self.propensity_column_name: self.randomness.get_draw(pop_data.index),
self.exposure_column_name: self.exposure_distribution.ppf(propensity_values),
}
propensity = pd.Series(
self.randomness.get_draw(pop_data.index), name=self.propensity_column_name
)
self.population_view.update(df)
self.population_view.update(propensity)
self.update_exposure_column(pop_data.index)

def on_time_step_prepare(self, event: Event) -> None:
exposure_values = self.exposure(event.index)
exposure_col = pd.Series(exposure_values, name=self.exposure_column_name)
self.population_view.update(exposure_col)
self.update_exposure_column(event.index)

def update_exposure_column(self, index: pd.Index) -> None:
if self.create_exposure_column:
exposure = pd.Series(self.exposure(index), name=self.exposure_column_name)
self.population_view.update(exposure)

##################################
# Pipeline sources and modifiers #
Expand Down
4 changes: 4 additions & 0 deletions src/vivarium_public_health/risks/effect.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,10 @@ def columns_required(self) -> list[str]:
# Setup methods #
#################

@staticmethod
def get_name(risk: EntityString, target: TargetString) -> str:
return f"non_log_linear_risk_effect.{risk.name}_on_{target}"

def build_all_lookup_tables(self, builder: Builder) -> None:
rr_data = self.get_relative_risk_data(builder)
self.validate_rr_data(rr_data)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def birth_exposure_pipeline_name(axis: str) -> str:
return f"{axis}.birth_exposure"

@staticmethod
def exposure_column_name(axis: str) -> str:
def get_exposure_column_name(axis: str) -> str:
return f"{axis}_exposure"

##############
Expand All @@ -206,7 +206,7 @@ def configuration_defaults(self) -> Dict[str, Any]:

@property
def columns_created(self) -> List[str]:
return [self.exposure_column_name(axis) for axis in self.AXES]
return [self.get_exposure_column_name(axis) for axis in self.AXES]

#####################
# Lifecycle methods #
Expand Down Expand Up @@ -256,7 +256,7 @@ def get_pipeline(axis_: str):

def on_initialize_simulants(self, pop_data: SimulantData) -> None:
birth_exposures = {
self.exposure_column_name(axis): self.birth_exposures[
self.get_exposure_column_name(axis): self.birth_exposures[
self.birth_exposure_pipeline_name(axis)
](pop_data.index)
for axis in self.AXES
Expand Down Expand Up @@ -318,7 +318,7 @@ def __init__(self, target: str):
super().__init__("risk_factor.low_birth_weight_and_short_gestation", target)

self.lbwsg_exposure_column_names = [
LBWSGRisk.exposure_column_name(axis) for axis in LBWSGRisk.AXES
LBWSGRisk.get_exposure_column_name(axis) for axis in LBWSGRisk.AXES
]
self.relative_risk_pipeline_name = (
f"effect_of_{self.risk.name}_on_{self.target.name}.relative_risk"
Expand Down Expand Up @@ -433,8 +433,8 @@ def on_initialize_simulants(self, pop_data: SimulantData) -> None:
pop = self.population_view.subview(["sex"] + self.lbwsg_exposure_column_names).get(
pop_data.index
)
birth_weight = pop[LBWSGRisk.exposure_column_name(BIRTH_WEIGHT)]
gestational_age = pop[LBWSGRisk.exposure_column_name(GESTATIONAL_AGE)]
birth_weight = pop[LBWSGRisk.get_exposure_column_name(BIRTH_WEIGHT)]
gestational_age = pop[LBWSGRisk.get_exposure_column_name(GESTATIONAL_AGE)]

is_male = pop["sex"] == "Male"
is_tmrel = (self.TMREL_GESTATIONAL_AGE_INTERVAL.left <= gestational_age) & (
Expand Down

0 comments on commit 70d8dd8

Please sign in to comment.