Skip to content

Commit

Permalink
Merge pull request #311 from ihmeuw/develop
Browse files Browse the repository at this point in the history
Release candidate v0.10.24
  • Loading branch information
mattkappel authored May 11, 2023
2 parents efe0f35 + 59d04df commit 06879f5
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 19 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
**0.10.24 - 05/11/23**

- Standardize builder, cause argument order in state get data functions
- Mends a bug where configured key_columns for randomness were not used in register_simulants

**0.10.23 - 05/03/23**

- Throw error when artifact doesn't contain relative risk data for desired target
Expand Down
2 changes: 1 addition & 1 deletion src/vivarium_public_health/__about__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
__summary__ = "Components for modelling diseases, risks, and interventions with ``vivarium``"
__uri__ = "https://github.com/ihmeuw/vivarium_public_health"

__version__ = "0.10.23"
__version__ = "0.10.24"

__author__ = "The vivarium_public_health developers"
__email__ = "[email protected]"
Expand Down
16 changes: 8 additions & 8 deletions src/vivarium_public_health/disease/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def add_transition(
if source_data_type == "rate":
if get_data_functions is None:
get_data_functions = {
"incidence_rate": lambda cause, builder: builder.data.load(
"incidence_rate": lambda builder, cause: builder.data.load(
f"{self.cause_type}.{cause}.incidence_rate"
)
}
Expand All @@ -173,7 +173,7 @@ def add_transition(
if source_data_type == "rate":
if get_data_functions is None:
get_data_functions = {
"incidence_rate": lambda cause, builder: builder.data.load(
"incidence_rate": lambda builder, cause: builder.data.load(
f"{self.cause_type}.{cause}.incidence_rate"
)
}
Expand Down Expand Up @@ -385,7 +385,7 @@ def add_transition(
if source_data_type == "rate":
if get_data_functions is None:
get_data_functions = {
"remission_rate": lambda cause, builder: builder.data.load(
"remission_rate": lambda builder, cause: builder.data.load(
f"{self.cause_type}.{cause}.remission_rate"
)
}
Expand Down Expand Up @@ -446,19 +446,19 @@ def _cleanup_effect(self, index, event_time):

def load_prevalence_data(self, builder):
if "prevalence" in self._get_data_functions:
return self._get_data_functions["prevalence"](self.cause, builder)
return self._get_data_functions["prevalence"](builder, self.cause)
else:
return builder.data.load(f"{self.cause_type}.{self.cause}.prevalence")

def load_birth_prevalence_data(self, builder):
if "birth_prevalence" in self._get_data_functions:
return self._get_data_functions["birth_prevalence"](self.cause, builder)
return self._get_data_functions["birth_prevalence"](builder, self.cause)
else:
return 0

def load_dwell_time_data(self, builder):
if "dwell_time" in self._get_data_functions:
dwell_time = self._get_data_functions["dwell_time"](self.cause, builder)
dwell_time = self._get_data_functions["dwell_time"](builder, self.cause)
else:
dwell_time = 0

Expand All @@ -474,7 +474,7 @@ def load_dwell_time_data(self, builder):
def load_disability_weight_data(self, builder):
if "disability_weight" in self._get_data_functions:
disability_weight = self._get_data_functions["disability_weight"](
self.cause, builder
builder, self.cause
)
else:
disability_weight = builder.data.load(
Expand All @@ -489,7 +489,7 @@ def load_disability_weight_data(self, builder):
def load_excess_mortality_rate_data(self, builder):
only_morbid = builder.data.load(f"cause.{self._model}.restrictions")["yld_only"]
if "excess_mortality_rate" in self._get_data_functions:
return self._get_data_functions["excess_mortality_rate"](self.cause, builder)
return self._get_data_functions["excess_mortality_rate"](builder, self.cause)
elif only_morbid:
return 0
else:
Expand Down
6 changes: 3 additions & 3 deletions src/vivarium_public_health/disease/transition.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,12 @@ def compute_transition_rate(self, index):
def load_transition_rate_data(self, builder):
if "incidence_rate" in self._get_data_functions:
rate_data = self._get_data_functions["incidence_rate"](
self.output_state.cause, builder
builder, self.output_state.cause
)
pipeline_name = f"{self.output_state.state_id}.incidence_rate"
elif "remission_rate" in self._get_data_functions:
rate_data = self._get_data_functions["remission_rate"](
self.output_state.cause, builder
builder, self.output_state.cause
)
pipeline_name = f"{self.input_state.state_id}.remission_rate"
elif "transition_rate" in self._get_data_functions:
Expand Down Expand Up @@ -103,7 +103,7 @@ def setup(self, builder):
get_proportion_func = self._get_data_functions.get("proportion", None)
if get_proportion_func is None:
raise ValueError("Must supply a proportion function")
self._proportion_data = get_proportion_func(self.output_state.cause, builder)
self._proportion_data = get_proportion_func(builder, self.output_state.cause)
self.proportion = builder.lookup.build_table(
self._proportion_data, key_columns=["sex"], parameter_columns=["age", "year"]
)
Expand Down
15 changes: 12 additions & 3 deletions src/vivarium_public_health/population/base_population.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
characteristics to simulants.
"""
from typing import Callable, Dict, List
from typing import Callable, Dict, Iterable, List

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -51,6 +51,7 @@ def sub_components(self) -> List:
# noinspection PyAttributeOutsideInit
def setup(self, builder: Builder) -> None:
self.config = builder.configuration.population
self.key_columns = builder.configuration.randomness.key_columns
if self.config.include_sex not in ["Male", "Female", "Both"]:
raise ValueError(
"Configuration key 'population.include_sex' must be one "
Expand Down Expand Up @@ -102,7 +103,7 @@ def on_initialize_simulants(self, pop_data: SimulantData) -> None:
In general, most simulation components (except for those computing summary statistics)
ignore simulants if they are not in the 'alive' category. The 'entrance_time' and
'exit_time' categories simply mark when the simulant enters or leaves the simulation,
respectively. Here we are agnostic to the methods of entrance and exit (e.g birth,
respectively. Here we are agnostic to the methods of entrance and exit (e.g., birth,
migration, death, etc.) as these characteristics can be inferred from this column and
other information about the simulant and the simulation parameters.
Expand All @@ -127,6 +128,7 @@ def on_initialize_simulants(self, pop_data: SimulantData) -> None:
population_data=sub_pop_data,
randomness_streams=self.randomness,
register_simulants=self.register_simulants,
key_columns=self.key_columns,
)
)

Expand Down Expand Up @@ -187,6 +189,7 @@ def generate_population(
population_data: pd.DataFrame,
randomness_streams: Dict[str, RandomnessStream],
register_simulants: Callable[[pd.DataFrame], None],
key_columns: Iterable[str] = ("entrance_time", "age"),
) -> pd.DataFrame:
"""Produces a random set of simulants sampled from the provided `population_data`.
Expand All @@ -212,6 +215,8 @@ def generate_population(
The size of the initial time step.
register_simulants
A function to register the new simulants with the CRN framework.
key_columns
A list of key columns for random number generation.
Returns
-------
Expand Down Expand Up @@ -261,6 +266,7 @@ def generate_population(
age_end,
randomness_streams,
register_simulants,
key_columns,
)


Expand Down Expand Up @@ -335,6 +341,7 @@ def _assign_demography_with_age_bounds(
age_end: float,
randomness_streams: Dict[str, RandomnessStream],
register_simulants: Callable[[pd.DataFrame], None],
key_columns: Iterable[str] = ("entrance_time", "age"),
) -> pd.DataFrame:
"""Assigns an age, sex, and location to the provided simulants given a range of ages.
Expand All @@ -352,6 +359,8 @@ def _assign_demography_with_age_bounds(
Source of random number generation within the vivarium common random number framework.
register_simulants
A function to register the new simulants with the CRN framework.
key_columns
A list of key columns for random number generation.
Returns
-------
Expand Down Expand Up @@ -381,5 +390,5 @@ def _assign_demography_with_age_bounds(
simulants = smooth_ages(
simulants, pop_data, randomness_streams["age_smoothing_age_bounds"]
)
register_simulants(simulants[["entrance_time", "age"]])
register_simulants(simulants[list(key_columns)])
return simulants
6 changes: 2 additions & 4 deletions tests/disease/test_disease.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,8 +303,6 @@ def test_mortality_rate(base_config, base_plugins, disease):


def test_incidence(base_config, base_plugins, disease):
year_start = base_config.time.start.year
year_end = base_config.time.end.year
time_step = pd.Timedelta(days=base_config.time.step_size)

healthy = BaseDiseaseState("healthy")
Expand All @@ -314,7 +312,7 @@ def test_incidence(base_config, base_plugins, disease):
transition = RateTransition(
input_state=healthy,
output_state=sick,
get_data_functions={"incidence_rate": lambda _, builder: builder.data.load(key)},
get_data_functions={"incidence_rate": lambda builder, _: builder.data.load(key)},
)
healthy.transition_set.append(transition)

Expand Down Expand Up @@ -354,7 +352,7 @@ def test_risk_deletion(base_config, base_plugins, disease):
transition = RateTransition(
input_state=healthy,
output_state=sick,
get_data_functions={"incidence_rate": lambda _, builder: builder.data.load(key)},
get_data_functions={"incidence_rate": lambda builder, _: builder.data.load(key)},
)
healthy.transition_set.append(transition)

Expand Down

0 comments on commit 06879f5

Please sign in to comment.