diff --git a/CHANGELOG.rst b/CHANGELOG.rst index e12ecbcee..46cb6b91c 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,8 @@ +**0.10.24 - 05/11/23** + + - Standardize builder, cause argument order in state get data functions + - Mends a bug where configured key_columns for randomness were not used in register_simulants + **0.10.23 - 05/03/23** - Throw error when artifact doesn't contain relative risk data for desired target diff --git a/src/vivarium_public_health/__about__.py b/src/vivarium_public_health/__about__.py index fdb73eedc..5faa3ab6d 100644 --- a/src/vivarium_public_health/__about__.py +++ b/src/vivarium_public_health/__about__.py @@ -13,7 +13,7 @@ __summary__ = "Components for modelling diseases, risks, and interventions with ``vivarium``" __uri__ = "https://github.com/ihmeuw/vivarium_public_health" -__version__ = "0.10.23" +__version__ = "0.10.24" __author__ = "The vivarium_public_health developers" __email__ = "vivarium.dev@gmail.com" diff --git a/src/vivarium_public_health/disease/state.py b/src/vivarium_public_health/disease/state.py index 7dce1d9f9..8cc25d718 100644 --- a/src/vivarium_public_health/disease/state.py +++ b/src/vivarium_public_health/disease/state.py @@ -146,7 +146,7 @@ def add_transition( if source_data_type == "rate": if get_data_functions is None: get_data_functions = { - "incidence_rate": lambda cause, builder: builder.data.load( + "incidence_rate": lambda builder, cause: builder.data.load( f"{self.cause_type}.{cause}.incidence_rate" ) } @@ -173,7 +173,7 @@ def add_transition( if source_data_type == "rate": if get_data_functions is None: get_data_functions = { - "incidence_rate": lambda cause, builder: builder.data.load( + "incidence_rate": lambda builder, cause: builder.data.load( f"{self.cause_type}.{cause}.incidence_rate" ) } @@ -385,7 +385,7 @@ def add_transition( if source_data_type == "rate": if get_data_functions is None: get_data_functions = { - "remission_rate": lambda cause, builder: builder.data.load( + "remission_rate": lambda builder, cause: builder.data.load( f"{self.cause_type}.{cause}.remission_rate" ) } @@ -446,19 +446,19 @@ def _cleanup_effect(self, index, event_time): def load_prevalence_data(self, builder): if "prevalence" in self._get_data_functions: - return self._get_data_functions["prevalence"](self.cause, builder) + return self._get_data_functions["prevalence"](builder, self.cause) else: return builder.data.load(f"{self.cause_type}.{self.cause}.prevalence") def load_birth_prevalence_data(self, builder): if "birth_prevalence" in self._get_data_functions: - return self._get_data_functions["birth_prevalence"](self.cause, builder) + return self._get_data_functions["birth_prevalence"](builder, self.cause) else: return 0 def load_dwell_time_data(self, builder): if "dwell_time" in self._get_data_functions: - dwell_time = self._get_data_functions["dwell_time"](self.cause, builder) + dwell_time = self._get_data_functions["dwell_time"](builder, self.cause) else: dwell_time = 0 @@ -474,7 +474,7 @@ def load_dwell_time_data(self, builder): def load_disability_weight_data(self, builder): if "disability_weight" in self._get_data_functions: disability_weight = self._get_data_functions["disability_weight"]( - self.cause, builder + builder, self.cause ) else: disability_weight = builder.data.load( @@ -489,7 +489,7 @@ def load_disability_weight_data(self, builder): def load_excess_mortality_rate_data(self, builder): only_morbid = builder.data.load(f"cause.{self._model}.restrictions")["yld_only"] if "excess_mortality_rate" in self._get_data_functions: - return self._get_data_functions["excess_mortality_rate"](self.cause, builder) + return self._get_data_functions["excess_mortality_rate"](builder, self.cause) elif only_morbid: return 0 else: diff --git a/src/vivarium_public_health/disease/transition.py b/src/vivarium_public_health/disease/transition.py index 54b86dff2..10554e941 100644 --- a/src/vivarium_public_health/disease/transition.py +++ b/src/vivarium_public_health/disease/transition.py @@ -62,12 +62,12 @@ def compute_transition_rate(self, index): def load_transition_rate_data(self, builder): if "incidence_rate" in self._get_data_functions: rate_data = self._get_data_functions["incidence_rate"]( - self.output_state.cause, builder + builder, self.output_state.cause ) pipeline_name = f"{self.output_state.state_id}.incidence_rate" elif "remission_rate" in self._get_data_functions: rate_data = self._get_data_functions["remission_rate"]( - self.output_state.cause, builder + builder, self.output_state.cause ) pipeline_name = f"{self.input_state.state_id}.remission_rate" elif "transition_rate" in self._get_data_functions: @@ -103,7 +103,7 @@ def setup(self, builder): get_proportion_func = self._get_data_functions.get("proportion", None) if get_proportion_func is None: raise ValueError("Must supply a proportion function") - self._proportion_data = get_proportion_func(self.output_state.cause, builder) + self._proportion_data = get_proportion_func(builder, self.output_state.cause) self.proportion = builder.lookup.build_table( self._proportion_data, key_columns=["sex"], parameter_columns=["age", "year"] ) diff --git a/src/vivarium_public_health/population/base_population.py b/src/vivarium_public_health/population/base_population.py index b8b202a8f..f49c3b25b 100644 --- a/src/vivarium_public_health/population/base_population.py +++ b/src/vivarium_public_health/population/base_population.py @@ -7,7 +7,7 @@ characteristics to simulants. """ -from typing import Callable, Dict, List +from typing import Callable, Dict, Iterable, List import numpy as np import pandas as pd @@ -51,6 +51,7 @@ def sub_components(self) -> List: # noinspection PyAttributeOutsideInit def setup(self, builder: Builder) -> None: self.config = builder.configuration.population + self.key_columns = builder.configuration.randomness.key_columns if self.config.include_sex not in ["Male", "Female", "Both"]: raise ValueError( "Configuration key 'population.include_sex' must be one " @@ -102,7 +103,7 @@ def on_initialize_simulants(self, pop_data: SimulantData) -> None: In general, most simulation components (except for those computing summary statistics) ignore simulants if they are not in the 'alive' category. The 'entrance_time' and 'exit_time' categories simply mark when the simulant enters or leaves the simulation, - respectively. Here we are agnostic to the methods of entrance and exit (e.g birth, + respectively. Here we are agnostic to the methods of entrance and exit (e.g., birth, migration, death, etc.) as these characteristics can be inferred from this column and other information about the simulant and the simulation parameters. @@ -127,6 +128,7 @@ def on_initialize_simulants(self, pop_data: SimulantData) -> None: population_data=sub_pop_data, randomness_streams=self.randomness, register_simulants=self.register_simulants, + key_columns=self.key_columns, ) ) @@ -187,6 +189,7 @@ def generate_population( population_data: pd.DataFrame, randomness_streams: Dict[str, RandomnessStream], register_simulants: Callable[[pd.DataFrame], None], + key_columns: Iterable[str] = ("entrance_time", "age"), ) -> pd.DataFrame: """Produces a random set of simulants sampled from the provided `population_data`. @@ -212,6 +215,8 @@ def generate_population( The size of the initial time step. register_simulants A function to register the new simulants with the CRN framework. + key_columns + A list of key columns for random number generation. Returns ------- @@ -261,6 +266,7 @@ def generate_population( age_end, randomness_streams, register_simulants, + key_columns, ) @@ -335,6 +341,7 @@ def _assign_demography_with_age_bounds( age_end: float, randomness_streams: Dict[str, RandomnessStream], register_simulants: Callable[[pd.DataFrame], None], + key_columns: Iterable[str] = ("entrance_time", "age"), ) -> pd.DataFrame: """Assigns an age, sex, and location to the provided simulants given a range of ages. @@ -352,6 +359,8 @@ def _assign_demography_with_age_bounds( Source of random number generation within the vivarium common random number framework. register_simulants A function to register the new simulants with the CRN framework. + key_columns + A list of key columns for random number generation. Returns ------- @@ -381,5 +390,5 @@ def _assign_demography_with_age_bounds( simulants = smooth_ages( simulants, pop_data, randomness_streams["age_smoothing_age_bounds"] ) - register_simulants(simulants[["entrance_time", "age"]]) + register_simulants(simulants[list(key_columns)]) return simulants diff --git a/tests/disease/test_disease.py b/tests/disease/test_disease.py index 62bf0af75..6a244bfce 100644 --- a/tests/disease/test_disease.py +++ b/tests/disease/test_disease.py @@ -303,8 +303,6 @@ def test_mortality_rate(base_config, base_plugins, disease): def test_incidence(base_config, base_plugins, disease): - year_start = base_config.time.start.year - year_end = base_config.time.end.year time_step = pd.Timedelta(days=base_config.time.step_size) healthy = BaseDiseaseState("healthy") @@ -314,7 +312,7 @@ def test_incidence(base_config, base_plugins, disease): transition = RateTransition( input_state=healthy, output_state=sick, - get_data_functions={"incidence_rate": lambda _, builder: builder.data.load(key)}, + get_data_functions={"incidence_rate": lambda builder, _: builder.data.load(key)}, ) healthy.transition_set.append(transition) @@ -354,7 +352,7 @@ def test_risk_deletion(base_config, base_plugins, disease): transition = RateTransition( input_state=healthy, output_state=sick, - get_data_functions={"incidence_rate": lambda _, builder: builder.data.load(key)}, + get_data_functions={"incidence_rate": lambda builder, _: builder.data.load(key)}, ) healthy.transition_set.append(transition)