diff --git a/CHANGELOG.rst b/CHANGELOG.rst index c62e2c559..e0e91dfb0 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,8 @@ +**2.3.1 - 3/11/24** + + - Update Mortality Observer to include tracked in population filter + - Fix bug in get_initialization_parameters to only remove existing keys if necessary + **2.3.0 - 3/7/24** - Update population configuration keys to be more descriptive diff --git a/src/vivarium_public_health/disease/state.py b/src/vivarium_public_health/disease/state.py index b54db761a..0f57b9420 100644 --- a/src/vivarium_public_health/disease/state.py +++ b/src/vivarium_public_health/disease/state.py @@ -85,8 +85,9 @@ def on_initialize_simulants(self, pop_data: SimulantData) -> None: def get_initialization_parameters(self) -> Dict[str, Any]: """Exclude side effect function and cause type from name and __repr__.""" initialization_parameters = super().get_initialization_parameters() - del initialization_parameters["side_effect_function"] - del initialization_parameters["cause_type"] + for key in ["side_effect_function", "cause_type"]: + if key in initialization_parameters.keys(): + del initialization_parameters[key] return initialization_parameters def get_initial_event_times(self, pop_data: SimulantData) -> pd.DataFrame: diff --git a/src/vivarium_public_health/metrics/mortality.py b/src/vivarium_public_health/metrics/mortality.py index fcb2e9c9e..8f527e33c 100644 --- a/src/vivarium_public_health/metrics/mortality.py +++ b/src/vivarium_public_health/metrics/mortality.py @@ -78,14 +78,18 @@ def setup(self, builder: Builder) -> None: self._cause_components = builder.components.get_components_by_type( (DiseaseState, RiskAttributableDisease) ) - + self.causes_of_death = ["other_causes"] + [ + cause.state_id for cause in self._cause_components if cause.has_excess_mortality + ] + self.required_death_columns = ["alive", "exit_time"] + self.required_yll_columns = [ + "alive", + "cause_of_death", + "exit_time", + "years_of_life_lost", + ] if not self.config.aggregate: - causes_of_death = ["other_causes"] + [ - cause.state_id - for cause in self._cause_components - if cause.has_excess_mortality - ] - for cause_of_death in causes_of_death: + for cause_of_death in self.causes_of_death: self._register_mortality_observations( builder, cause_of_death, f'cause_of_death == "{cause_of_death}"' ) @@ -100,15 +104,15 @@ def _register_mortality_observations( self, builder: Builder, cause: str, additional_pop_filter: str = "" ) -> None: pop_filter = ( - 'alive == "dead"' + 'alive == "dead" and tracked == True' if additional_pop_filter == "" - else f'alive == "dead" and {additional_pop_filter}' + else f'alive == "dead" and tracked == True and {additional_pop_filter}' ) builder.results.register_observation( name=f"death_due_to_{cause}", pop_filter=pop_filter, aggregator=self.count_deaths, - requires_columns=["alive", "exit_time"], + requires_columns=self.required_death_columns, additional_stratifications=self.config.include, excluded_stratifications=self.config.exclude, when="collect_metrics", @@ -117,12 +121,7 @@ def _register_mortality_observations( name=f"ylls_due_to_{cause}", pop_filter=pop_filter, aggregator=self.calculate_ylls, - requires_columns=[ - "alive", - "cause_of_death", - "exit_time", - "years_of_life_lost", - ], + requires_columns=self.required_yll_columns, additional_stratifications=self.config.include, excluded_stratifications=self.config.exclude, when="collect_metrics",