diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 6211e284f..242e9f1c6 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -53,3 +53,8 @@ jobs: - name: Doctest run: | make doctest -C docs/ + - name: Lint + run: | + pip install black==22.1.0 isort + black . --check --diff + isort . --check --verbose --only-modified --diff diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 26b18b98d..1560fea36 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,8 @@ +**0.10.17 - 02/15/22** + + - Autoformat code with black and isort. + - Add black and isort checks to CI. + **0.10.16 - 02/13/22** - Update CI diff --git a/docs/source/conf.py b/docs/source/conf.py index c5fa355d2..86ff0612d 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -10,28 +10,29 @@ # All configuration values have a default; values that are commented out # serve to show the default. +import sys + # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. from pathlib import Path -import sys from docutils.nodes import Text from sphinx.ext.intersphinx import missing_reference - import vivarium_public_health + base_dir = Path(vivarium_public_health.__file__).parent about = {} with (base_dir / "__about__.py").open() as f: exec(f.read(), about) -sys.path.insert(0, str(Path('..').resolve())) +sys.path.insert(0, str(Path("..").resolve())) # -- Project information ----------------------------------------------------- -project = about['__title__'] +project = about["__title__"] copyright = f'2021, {about["__author__"]}' author = about["__author__"] @@ -45,35 +46,35 @@ # If your documentation needs a minimal Sphinx version, state it here. -needs_sphinx = '4.0' +needs_sphinx = "4.0" # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.intersphinx', - 'sphinx.ext.doctest', - 'sphinx.ext.todo', - 'sphinx.ext.coverage', - 'sphinx.ext.mathjax', - 'sphinx.ext.napoleon', - 'sphinx.ext.viewcode', - 'sphinx_click.ext', - 'matplotlib.sphinxext.plot_directive', + "sphinx.ext.autodoc", + "sphinx.ext.intersphinx", + "sphinx.ext.doctest", + "sphinx.ext.todo", + "sphinx.ext.coverage", + "sphinx.ext.mathjax", + "sphinx.ext.napoleon", + "sphinx.ext.viewcode", + "sphinx_click.ext", + "matplotlib.sphinxext.plot_directive", ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # # source_suffix = ['.rst', '.md'] -source_suffix = '.rst' +source_suffix = ".rst" # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. @@ -91,7 +92,7 @@ exclude_patterns = [] # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # If true, `todo` and `todoList` produce output, else they produce nothing. todo_include_todos = True @@ -102,8 +103,8 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -html_theme_path = ['_theme'] -html_theme = 'sphinx_rtd_theme' +html_theme_path = ["_theme"] +html_theme = "sphinx_rtd_theme" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the @@ -114,7 +115,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # Custom sidebar templates, must be a dictionary that maps document names # to template names. @@ -122,9 +123,9 @@ # This is required for the alabaster theme # refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars html_sidebars = { - '**': [ - 'globaltoc.html', # needs 'show_related': True theme option to display - 'searchbox.html', + "**": [ + "globaltoc.html", # needs 'show_related': True theme option to display + "searchbox.html", ] } @@ -141,15 +142,12 @@ # The paper size ('letterpaper' or 'a4paper'). # # 'papersize': 'letterpaper', - # The font size ('10pt', '11pt' or '12pt'). # # 'pointsize': '10pt', - # Additional stuff for the LaTeX preamble. # # 'preamble': '', - # Latex figure (float) alignment # # 'figure_align': 'htbp', @@ -159,8 +157,13 @@ # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, f'{about["__title__"]}.tex', f'{about["__title__"]} Documentation', - about["__author__"], 'manual'), + ( + master_doc, + f'{about["__title__"]}.tex', + f'{about["__title__"]} Documentation', + about["__author__"], + "manual", + ), ] @@ -169,8 +172,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ - (master_doc, f'{about["__title__"]}', f'{about["__title__"]} Documentation', - [author], 1) + (master_doc, f'{about["__title__"]}', f'{about["__title__"]} Documentation', [author], 1) ] @@ -180,17 +182,25 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, f'{about["__title__"]}', f'{about["__title__"]} Documentation', - author, f'{about["__title__"]}', about["__summary__"], - 'Miscellaneous'), + ( + master_doc, + f'{about["__title__"]}', + f'{about["__title__"]} Documentation', + author, + f'{about["__title__"]}', + about["__summary__"], + "Miscellaneous", + ), ] # Other docs we can link to -intersphinx_mapping = {'python': ('https://docs.python.org/3.8', None), - 'pandas': ('https://pandas.pydata.org/pandas-docs/stable/', None), - 'tables': ('https://www.pytables.org/', None), - 'numpy': ('https://numpy.org/doc/stable/', None), - 'vivarium': ('https://vivarium.readthedocs.io/en/latest/', None)} +intersphinx_mapping = { + "python": ("https://docs.python.org/3.8", None), + "pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None), + "tables": ("https://www.pytables.org/", None), + "numpy": ("https://numpy.org/doc/stable/", None), + "vivarium": ("https://vivarium.readthedocs.io/en/latest/", None), +} # -- Autodoc configuration ------------------------------------------------ @@ -198,18 +208,18 @@ autodoc_default_options = { # Automatically document members (e.g. classes in a module, # methods in a class, etc.) - 'members': True, + "members": True, # Order of items documented is determined by the order # of appearance in the source code - 'member-order': 'bysource', + "member-order": "bysource", # Generate docs even if an item has no docstring. - 'undoc-members': True, + "undoc-members": True, # Don't document things with a leading underscore. - 'private-members': False, + "private-members": False, } # Display type hints in the description instead of the signature. -autodoc_typehints = 'description' +autodoc_typehints = "description" # -- nitpicky mode -------------------------------------------------------- @@ -218,7 +228,7 @@ nitpicky = True nitpick_ignore = [] -for line in open('../nitpick-exceptions'): +for line in open("../nitpick-exceptions"): if line.strip() == "" or line.startswith("#"): continue dtype, target = line.split(None, 1) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..8574c8edd --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,13 @@ +[tool.black] +line_length = 94 +exclude = ''' +/( + vivarium +)/ + +''' + +[tool.isort] +profile = "black" +skip = ["vivarium"] +known_third_party = ["vivarium"] diff --git a/setup.py b/setup.py index 15fd54c7b..b21c49ba3 100644 --- a/setup.py +++ b/setup.py @@ -1,8 +1,7 @@ #!/usr/bin/env python import os -from setuptools import setup, find_packages - +from setuptools import find_packages, setup if __name__ == "__main__": @@ -17,40 +16,37 @@ long_description = f.read() install_requirements = [ - 'vivarium>=0.10.1', - 'numpy', - 'pandas', - 'scipy', - 'tables', - 'risk_distributions>=2.0.6', + "vivarium>=0.10.1", + "numpy", + "pandas", + "scipy", + "tables", + "risk_distributions>=2.0.6", ] test_requirements = [ - 'pytest', - 'pytest-mock', - 'hypothesis', + "pytest", + "pytest-mock", + "hypothesis", ] doc_requirements = [ - 'sphinx>=4.0', - 'sphinx-rtd-theme', - 'sphinx-click', - 'IPython', - 'matplotlib' + "sphinx>=4.0", + "sphinx-rtd-theme", + "sphinx-click", + "IPython", + "matplotlib", ] setup( - name=about['__title__'], - version=about['__version__'], - - description=about['__summary__'], + name=about["__title__"], + version=about["__version__"], + description=about["__summary__"], long_description=long_description, - license=about['__license__'], + license=about["__license__"], url=about["__uri__"], - author=about["__author__"], author_email=about["__email__"], - classifiers=[ "Intended Audience :: Developers", "Intended Audience :: Education", @@ -73,18 +69,15 @@ "Topic :: Scientific/Engineering :: Physics", "Topic :: Software Development :: Libraries", ], - - package_dir={'': 'src'}, - packages=find_packages(where='src'), + package_dir={"": "src"}, + packages=find_packages(where="src"), include_package_data=True, - install_requires=install_requirements, tests_require=test_requirements, extras_require={ - 'docs': doc_requirements, - 'test': test_requirements, - 'dev': doc_requirements + test_requirements, + "docs": doc_requirements, + "test": test_requirements, + "dev": doc_requirements + test_requirements, }, - zip_safe=False, ) diff --git a/src/vivarium_public_health/__about__.py b/src/vivarium_public_health/__about__.py index 3236dece8..fedf87082 100644 --- a/src/vivarium_public_health/__about__.py +++ b/src/vivarium_public_health/__about__.py @@ -1,13 +1,19 @@ __all__ = [ - "__title__", "__summary__", "__uri__", "__version__", "__author__", - "__email__", "__license__", "__copyright__", + "__title__", + "__summary__", + "__uri__", + "__version__", + "__author__", + "__email__", + "__license__", + "__copyright__", ] __title__ = "vivarium_public_health" __summary__ = "Components for modelling diseases, risks, and interventions with ``vivarium``" __uri__ = "https://github.com/ihmeuw/vivarium_public_health" -__version__ = "0.10.16" +__version__ = "0.10.17" __author__ = "The vivarium_public_health developers" __email__ = "vivarium.dev@gmail.com" diff --git a/src/vivarium_public_health/__init__.py b/src/vivarium_public_health/__init__.py index 0abf09faa..c75a8c910 100644 --- a/src/vivarium_public_health/__init__.py +++ b/src/vivarium_public_health/__init__.py @@ -1,5 +1,21 @@ -from vivarium_public_health.__about__ import (__author__, __copyright__, __email__, __license__, - __summary__, __title__, __uri__, __version__, ) +from vivarium_public_health.__about__ import ( + __author__, + __copyright__, + __email__, + __license__, + __summary__, + __title__, + __uri__, + __version__, +) -__all__ = [__author__, __copyright__, __email__, - __license__, __summary__, __title__, __uri__, __version__, ] +__all__ = [ + __author__, + __copyright__, + __email__, + __license__, + __summary__, + __title__, + __uri__, + __version__, +] diff --git a/src/vivarium_public_health/disease/__init__.py b/src/vivarium_public_health/disease/__init__.py index 7b58d688b..d55a090f0 100644 --- a/src/vivarium_public_health/disease/__init__.py +++ b/src/vivarium_public_health/disease/__init__.py @@ -1,7 +1,18 @@ -from .transition import RateTransition, ProportionTransition -from .state import (DiseaseState, TransientDiseaseState, - SusceptibleState, RecoveredState, BaseDiseaseState) from .model import DiseaseModel -from .models import (SI, SIR, SIS, SIS_fixed_duration, - NeonatalSWC_with_incidence, NeonatalSWC_without_incidence) +from .models import ( + SI, + SIR, + SIS, + NeonatalSWC_with_incidence, + NeonatalSWC_without_incidence, + SIS_fixed_duration, +) from .special_disease import RiskAttributableDisease +from .state import ( + BaseDiseaseState, + DiseaseState, + RecoveredState, + SusceptibleState, + TransientDiseaseState, +) +from .transition import ProportionTransition, RateTransition diff --git a/src/vivarium_public_health/disease/model.py b/src/vivarium_public_health/disease/model.py index 1c3f21008..79412441f 100644 --- a/src/vivarium_public_health/disease/model.py +++ b/src/vivarium_public_health/disease/model.py @@ -8,13 +8,12 @@ transitions at simulation initialization and during transitions. """ -import pandas as pd import numpy as np - +import pandas as pd from vivarium.exceptions import VivariumError from vivarium.framework.state_machine import Machine -from vivarium_public_health.disease import SusceptibleState +from vivarium_public_health.disease.state import SusceptibleState class DiseaseModelError(VivariumError): @@ -22,7 +21,9 @@ class DiseaseModelError(VivariumError): class DiseaseModel(Machine): - def __init__(self, cause, initial_state=None, get_data_functions=None, cause_type="cause", **kwargs): + def __init__( + self, cause, initial_state=None, get_data_functions=None, cause_type="cause", **kwargs + ): super().__init__(cause, **kwargs) self.cause = cause self.cause_type = cause_type @@ -32,7 +33,9 @@ def __init__(self, cause, initial_state=None, get_data_functions=None, cause_typ else: self.initial_state = self._get_default_initial_state() - self._get_data_functions = get_data_functions if get_data_functions is not None else {} + self._get_data_functions = ( + get_data_functions if get_data_functions is not None else {} + ) @property def name(self): @@ -40,16 +43,16 @@ def name(self): @property def state_names(self): - return [s.name.split('.')[1] for s in self.states] + return [s.name.split(".")[1] for s in self.states] @property def transition_names(self): - states = {s.name.split('.')[1]: s for s in self.states} + states = {s.name.split(".")[1]: s for s in self.states} transitions = [] for state in states.values(): for trans in state.transition_set.transitions: - _, _, init_state, _, end_state = trans.name.split('.') - transitions.append(f'{init_state}_TO_{end_state}') + _, _, init_state, _, end_state = trans.name.split(".") + transitions.append(f"{init_state}_TO_{end_state}") return transitions def setup(self, builder): @@ -60,54 +63,76 @@ def setup(self, builder): self.configuration_age_end = builder.configuration.population.age_end cause_specific_mortality_rate = self.load_cause_specific_mortality_rate_data(builder) - self.cause_specific_mortality_rate = builder.lookup.build_table(cause_specific_mortality_rate, - key_columns=['sex'], - parameter_columns=['age', 'year']) - builder.value.register_value_modifier('cause_specific_mortality_rate', - self.adjust_cause_specific_mortality_rate, - requires_columns=['age', 'sex']) - - self.population_view = builder.population.get_view(['age', 'sex', self.state_column]) - builder.population.initializes_simulants(self.on_initialize_simulants, - creates_columns=[self.state_column], - requires_columns=['age', 'sex'], - requires_streams=[f'{self.state_column}_initial_states']) - self.randomness = builder.randomness.get_stream(f'{self.state_column}_initial_states') - - builder.event.register_listener('time_step', self.on_time_step) - builder.event.register_listener('time_step__cleanup', self.on_time_step_cleanup) + self.cause_specific_mortality_rate = builder.lookup.build_table( + cause_specific_mortality_rate, + key_columns=["sex"], + parameter_columns=["age", "year"], + ) + builder.value.register_value_modifier( + "cause_specific_mortality_rate", + self.adjust_cause_specific_mortality_rate, + requires_columns=["age", "sex"], + ) + + self.population_view = builder.population.get_view(["age", "sex", self.state_column]) + builder.population.initializes_simulants( + self.on_initialize_simulants, + creates_columns=[self.state_column], + requires_columns=["age", "sex"], + requires_streams=[f"{self.state_column}_initial_states"], + ) + self.randomness = builder.randomness.get_stream(f"{self.state_column}_initial_states") + + builder.event.register_listener("time_step", self.on_time_step) + builder.event.register_listener("time_step__cleanup", self.on_time_step_cleanup) def on_initialize_simulants(self, pop_data): - population = self.population_view.subview(['age', 'sex']).get(pop_data.index) + population = self.population_view.subview(["age", "sex"]).get(pop_data.index) assert self.initial_state in {s.state_id for s in self.states} # FIXME: this is a hack to figure out whether or not we're at the simulation start based on the fact that the # fertility components create this user data - if pop_data.user_data['sim_state'] == 'setup': # simulation start + if pop_data.user_data["sim_state"] == "setup": # simulation start if self.configuration_age_start != self.configuration_age_end != 0: - state_names, weights_bins = self.get_state_weights(pop_data.index, "prevalence") + state_names, weights_bins = self.get_state_weights( + pop_data.index, "prevalence" + ) else: - raise NotImplementedError('We do not currently support an age 0 cohort. ' - 'configuration.population.age_start and configuration.population.age_end ' - 'cannot both be 0.') + raise NotImplementedError( + "We do not currently support an age 0 cohort. " + "configuration.population.age_start and configuration.population.age_end " + "cannot both be 0." + ) else: # on time step - if pop_data.user_data['age_start'] == pop_data.user_data['age_end'] == 0: - state_names, weights_bins = self.get_state_weights(pop_data.index, "birth_prevalence") + if pop_data.user_data["age_start"] == pop_data.user_data["age_end"] == 0: + state_names, weights_bins = self.get_state_weights( + pop_data.index, "birth_prevalence" + ) else: - state_names, weights_bins = self.get_state_weights(pop_data.index, "prevalence") + state_names, weights_bins = self.get_state_weights( + pop_data.index, "prevalence" + ) if state_names and not population.empty: # only do this if there are states in the model that supply prevalence data - population['sex_id'] = population.sex.apply({'Male': 1, 'Female': 2}.get) - - condition_column = self.assign_initial_status_to_simulants(population, state_names, weights_bins, - self.randomness.get_draw(population.index)) - - condition_column = condition_column.rename(columns={'condition_state': self.state_column}) + population["sex_id"] = population.sex.apply({"Male": 1, "Female": 2}.get) + + condition_column = self.assign_initial_status_to_simulants( + population, + state_names, + weights_bins, + self.randomness.get_draw(population.index), + ) + + condition_column = condition_column.rename( + columns={"condition_state": self.state_column} + ) else: - condition_column = pd.Series(self.initial_state, index=population.index, name=self.state_column) + condition_column = pd.Series( + self.initial_state, index=population.index, name=self.state_column + ) self.population_view.update(condition_column) def on_time_step(self, event): @@ -117,14 +142,18 @@ def on_time_step_cleanup(self, event): self.cleanup(event.index, event.time) def load_cause_specific_mortality_rate_data(self, builder): - if 'cause_specific_mortality_rate' not in self._get_data_functions: - only_morbid = builder.data.load(f'cause.{self.cause}.restrictions')['yld_only'] + if "cause_specific_mortality_rate" not in self._get_data_functions: + only_morbid = builder.data.load(f"cause.{self.cause}.restrictions")["yld_only"] if only_morbid: csmr_data = 0 else: - csmr_data = builder.data.load(f"{self.cause_type}.{self.cause}.cause_specific_mortality_rate") + csmr_data = builder.data.load( + f"{self.cause_type}.{self.cause}.cause_specific_mortality_rate" + ) else: - csmr_data = self._get_data_functions['cause_specific_mortality_rate'](self.cause, builder) + csmr_data = self._get_data_functions["cause_specific_mortality_rate"]( + self.cause, builder + ) return csmr_data def adjust_cause_specific_mortality_rate(self, index, rate): @@ -137,16 +166,20 @@ def _get_default_initial_state(self): return susceptible_states[0].state_id def get_state_weights(self, pop_index, prevalence_type): - states = [s for s in self.states - if hasattr(s, f'{prevalence_type}') and getattr(s, f'{prevalence_type}') is not None] + states = [ + s + for s in self.states + if hasattr(s, f"{prevalence_type}") + and getattr(s, f"{prevalence_type}") is not None + ] if not states: return states, None - weights = [getattr(s, f'{prevalence_type}')(pop_index) for s in states] + weights = [getattr(s, f"{prevalence_type}")(pop_index) for s in states] for w in weights: w.reset_index(inplace=True, drop=True) - weights += ((1 - np.sum(weights, axis=0)), ) + weights += ((1 - np.sum(weights, axis=0)),) weights = np.array(weights).T weights_bins = np.cumsum(weights, axis=1) @@ -156,11 +189,13 @@ def get_state_weights(self, pop_index, prevalence_type): return state_names, weights_bins @staticmethod - def assign_initial_status_to_simulants(simulants_df, state_names, weights_bins, propensities): - simulants = simulants_df[['age', 'sex']].copy() + def assign_initial_status_to_simulants( + simulants_df, state_names, weights_bins, propensities + ): + simulants = simulants_df[["age", "sex"]].copy() choice_index = (propensities.values[np.newaxis].T > weights_bins).sum(axis=1) initial_states = pd.Series(np.array(state_names)[choice_index], index=simulants.index) - simulants.loc[:, 'condition_state'] = initial_states + simulants.loc[:, "condition_state"] = initial_states return simulants diff --git a/src/vivarium_public_health/disease/models.py b/src/vivarium_public_health/disease/models.py index d21d8e9f3..f4fa32016 100644 --- a/src/vivarium_public_health/disease/models.py +++ b/src/vivarium_public_health/disease/models.py @@ -9,7 +9,12 @@ """ import pandas as pd -from vivarium_public_health.disease import SusceptibleState, RecoveredState, DiseaseState, DiseaseModel +from vivarium_public_health.disease.model import DiseaseModel +from vivarium_public_health.disease.state import ( + DiseaseState, + RecoveredState, + SusceptibleState, +) def SI(cause: str) -> DiseaseModel: @@ -17,7 +22,7 @@ def SI(cause: str) -> DiseaseModel: infected = DiseaseState(cause) healthy.allow_self_transitions() - healthy.add_transition(infected, source_data_type='rate') + healthy.add_transition(infected, source_data_type="rate") infected.allow_self_transitions() return DiseaseModel(cause, states=[healthy, infected]) @@ -29,9 +34,9 @@ def SIR(cause: str) -> DiseaseModel: recovered = RecoveredState(cause) healthy.allow_self_transitions() - healthy.add_transition(infected, source_data_type='rate') + healthy.add_transition(infected, source_data_type="rate") infected.allow_self_transitions() - infected.add_transition(recovered, source_data_type='rate') + infected.add_transition(recovered, source_data_type="rate") recovered.allow_self_transitions() return DiseaseModel(cause, states=[healthy, infected, recovered]) @@ -42,9 +47,9 @@ def SIS(cause: str) -> DiseaseModel: infected = DiseaseState(cause) healthy.allow_self_transitions() - healthy.add_transition(infected, source_data_type='rate') + healthy.add_transition(infected, source_data_type="rate") infected.allow_self_transitions() - infected.add_transition(healthy, source_data_type='rate') + infected.add_transition(healthy, source_data_type="rate") return DiseaseModel(cause, states=[healthy, infected]) @@ -53,10 +58,10 @@ def SIS_fixed_duration(cause: str, duration: str) -> DiseaseModel: duration = pd.Timedelta(days=float(duration) // 1, hours=(float(duration) % 1) * 24.0) healthy = SusceptibleState(cause) - infected = DiseaseState(cause, get_data_functions={'dwell_time': lambda _, __: duration}) + infected = DiseaseState(cause, get_data_functions={"dwell_time": lambda _, __: duration}) healthy.allow_self_transitions() - healthy.add_transition(infected, source_data_type='rate') + healthy.add_transition(infected, source_data_type="rate") infected.add_transition(healthy) infected.allow_self_transitions() @@ -67,21 +72,24 @@ def SIR_fixed_duration(cause: str, duration: str) -> DiseaseModel: duration = pd.Timedelta(days=float(duration) // 1, hours=(float(duration) % 1) * 24.0) healthy = SusceptibleState(cause) - infected = DiseaseState(cause, get_data_functions={'dwell_time': lambda _, __: duration}) + infected = DiseaseState(cause, get_data_functions={"dwell_time": lambda _, __: duration}) recovered = RecoveredState(cause) healthy.allow_self_transitions() - healthy.add_transition(infected, source_data_type='rate') + healthy.add_transition(infected, source_data_type="rate") infected.add_transition(recovered) infected.allow_self_transitions() recovered.allow_self_transitions() return DiseaseModel(cause, states=[healthy, infected, recovered]) - + def NeonatalSWC_without_incidence(cause): - with_condition_data_functions = {'birth_prevalence': - lambda cause, builder: builder.data.load(f"cause.{cause}.birth_prevalence")} + with_condition_data_functions = { + "birth_prevalence": lambda cause, builder: builder.data.load( + f"cause.{cause}.birth_prevalence" + ) + } healthy = SusceptibleState(cause) with_condition = DiseaseState(cause, get_data_functions=with_condition_data_functions) @@ -93,14 +101,17 @@ def NeonatalSWC_without_incidence(cause): def NeonatalSWC_with_incidence(cause): - with_condition_data_functions = {'birth_prevalence': - lambda cause, builder: builder.data.load(f"cause.{cause}.birth_prevalence")} + with_condition_data_functions = { + "birth_prevalence": lambda cause, builder: builder.data.load( + f"cause.{cause}.birth_prevalence" + ) + } healthy = SusceptibleState(cause) with_condition = DiseaseState(cause, get_data_functions=with_condition_data_functions) healthy.allow_self_transitions() - healthy.add_transition(with_condition, source_data_type='rate') + healthy.add_transition(with_condition, source_data_type="rate") with_condition.allow_self_transitions() return DiseaseModel(cause, states=[healthy, with_condition]) diff --git a/src/vivarium_public_health/disease/special_disease.py b/src/vivarium_public_health/disease/special_disease.py index 1d7be318a..038b56acf 100644 --- a/src/vivarium_public_health/disease/special_disease.py +++ b/src/vivarium_public_health/disease/special_disease.py @@ -6,9 +6,9 @@ This module contains frequently used, but non-standard disease models. """ -from collections import namedtuple -from operator import lt, gt import re +from collections import namedtuple +from operator import gt, lt import pandas as pd from vivarium.framework.values import list_combiner, union_post_processor @@ -41,7 +41,7 @@ class RiskAttributableDisease: For categorical risks, the threshold should be provided as a list of categories. This list contains the categories that indicate the simulant is experiencing the condition. For a dichotomous risk - there will be 2 categories. By convention ``cat1`` is used to + there will be 2 categories. By convention ``cat1`` is used to indicate the with condition state and would be the single item in the ``threshold`` setting list. @@ -81,10 +81,10 @@ class RiskAttributableDisease: """ configuration_defaults = { - 'risk_attributable_disease': { - 'threshold': None, - 'mortality': True, - 'recoverable': True + "risk_attributable_disease": { + "threshold": None, + "mortality": True, + "recoverable": True, } } @@ -93,20 +93,24 @@ def __init__(self, cause, risk): self.risk = EntityString(risk) self.state_column = self.cause.name self.state_id = self.cause.name - self.diseased_event_time_column = f'{self.cause.name}_event_time' - self.susceptible_event_time_column = f'susceptible_to_{self.cause.name}_event_time' + self.diseased_event_time_column = f"{self.cause.name}_event_time" + self.susceptible_event_time_column = f"susceptible_to_{self.cause.name}_event_time" self.configuration_defaults = { - self.cause.name: RiskAttributableDisease.configuration_defaults['risk_attributable_disease'] + self.cause.name: RiskAttributableDisease.configuration_defaults[ + "risk_attributable_disease" + ] } - self._state_names = [f'{self.cause.name}', f'susceptible_to_{self.cause.name}'] - self._transition_names = [f'susceptible_to_{self.cause.name}_TO_{self.cause.name}'] + self._state_names = [f"{self.cause.name}", f"susceptible_to_{self.cause.name}"] + self._transition_names = [f"susceptible_to_{self.cause.name}_TO_{self.cause.name}"] - self.excess_mortality_rate_pipeline_name = f'{self.cause.name}.excess_mortality_rate' - self.excess_mortality_rate_paf_pipeline_name = f'{self.excess_mortality_rate_pipeline_name}.paf' + self.excess_mortality_rate_pipeline_name = f"{self.cause.name}.excess_mortality_rate" + self.excess_mortality_rate_paf_pipeline_name = ( + f"{self.excess_mortality_rate_pipeline_name}.paf" + ) @property def name(self): - return f'disease_model.{self.cause.name}' + return f"disease_model.{self.cause.name}" @property def state_names(self): @@ -122,68 +126,96 @@ def setup(self, builder): self.adjust_state_and_transitions() self.clock = builder.time.clock() - disability_weight_data = builder.data.load(f'{self.cause}.disability_weight') - self.base_disability_weight = builder.lookup.build_table(disability_weight_data, key_columns=['sex'], - parameter_columns=['age', 'year']) + disability_weight_data = builder.data.load(f"{self.cause}.disability_weight") + self.base_disability_weight = builder.lookup.build_table( + disability_weight_data, key_columns=["sex"], parameter_columns=["age", "year"] + ) self.disability_weight = builder.value.register_value_producer( - f'{self.cause.name}.disability_weight', + f"{self.cause.name}.disability_weight", source=self.compute_disability_weight, - requires_columns=['age', 'sex', 'alive', self.cause.name] + requires_columns=["age", "sex", "alive", self.cause.name], + ) + builder.value.register_value_modifier( + "disability_weight", modifier=self.disability_weight ) - builder.value.register_value_modifier('disability_weight', modifier=self.disability_weight) cause_specific_mortality_rate = self.load_cause_specific_mortality_rate_data(builder) - self.cause_specific_mortality_rate = builder.lookup.build_table(cause_specific_mortality_rate, - key_columns=['sex'], - parameter_columns=['age', 'year']) - builder.value.register_value_modifier('cause_specific_mortality_rate', - self.adjust_cause_specific_mortality_rate, - requires_columns=['age', 'sex']) + self.cause_specific_mortality_rate = builder.lookup.build_table( + cause_specific_mortality_rate, + key_columns=["sex"], + parameter_columns=["age", "year"], + ) + builder.value.register_value_modifier( + "cause_specific_mortality_rate", + self.adjust_cause_specific_mortality_rate, + requires_columns=["age", "sex"], + ) excess_mortality_data = self.load_excess_mortality_rate_data(builder) - self.base_excess_mortality_rate = builder.lookup.build_table(excess_mortality_data, key_columns=['sex'], - parameter_columns=['age', 'year']) + self.base_excess_mortality_rate = builder.lookup.build_table( + excess_mortality_data, key_columns=["sex"], parameter_columns=["age", "year"] + ) self.excess_mortality_rate = builder.value.register_value_producer( self.excess_mortality_rate_pipeline_name, source=self.compute_excess_mortality_rate, - requires_columns=['age', 'sex', 'alive', self.cause.name], - requires_values=[self.excess_mortality_rate_paf_pipeline_name] + requires_columns=["age", "sex", "alive", self.cause.name], + requires_values=[self.excess_mortality_rate_paf_pipeline_name], ) paf = builder.lookup.build_table(0) self.joint_paf = builder.value.register_value_producer( self.excess_mortality_rate_paf_pipeline_name, source=lambda idx: [paf(idx)], preferred_combiner=list_combiner, - preferred_post_processor=union_post_processor + preferred_post_processor=union_post_processor, + ) + builder.value.register_value_modifier( + "mortality_rate", + modifier=self.adjust_mortality_rate, + requires_values=[self.excess_mortality_rate_pipeline_name], ) - builder.value.register_value_modifier('mortality_rate', - modifier=self.adjust_mortality_rate, - requires_values=[self.excess_mortality_rate_pipeline_name]) - distribution = builder.data.load(f'{self.risk}.distribution') - exposure_pipeline = builder.value.get_value(f'{self.risk.name}.exposure') + distribution = builder.data.load(f"{self.risk}.distribution") + exposure_pipeline = builder.value.get_value(f"{self.risk.name}.exposure") threshold = builder.configuration[self.cause.name].threshold - self.filter_by_exposure = self.get_exposure_filter(distribution, exposure_pipeline, threshold) - self.population_view = builder.population.get_view([self.cause.name, self.diseased_event_time_column, - self.susceptible_event_time_column, 'alive']) + self.filter_by_exposure = self.get_exposure_filter( + distribution, exposure_pipeline, threshold + ) + self.population_view = builder.population.get_view( + [ + self.cause.name, + self.diseased_event_time_column, + self.susceptible_event_time_column, + "alive", + ] + ) - builder.population.initializes_simulants(self.on_initialize_simulants, - creates_columns=[self.cause.name, - self.diseased_event_time_column, - self.susceptible_event_time_column], - requires_values=[f'{self.risk.name}.exposure']) + builder.population.initializes_simulants( + self.on_initialize_simulants, + creates_columns=[ + self.cause.name, + self.diseased_event_time_column, + self.susceptible_event_time_column, + ], + requires_values=[f"{self.risk.name}.exposure"], + ) - builder.event.register_listener('time_step', self.on_time_step) + builder.event.register_listener("time_step", self.on_time_step) def on_initialize_simulants(self, pop_data): - new_pop = pd.DataFrame({self.cause.name: f'susceptible_to_{self.cause.name}', - self.diseased_event_time_column: pd.Series(pd.NaT, index=pop_data.index), - self.susceptible_event_time_column: pd.Series(pd.NaT, index=pop_data.index)}, - index=pop_data.index) + new_pop = pd.DataFrame( + { + self.cause.name: f"susceptible_to_{self.cause.name}", + self.diseased_event_time_column: pd.Series(pd.NaT, index=pop_data.index), + self.susceptible_event_time_column: pd.Series(pd.NaT, index=pop_data.index), + }, + index=pop_data.index, + ) sick = self.filter_by_exposure(pop_data.index) new_pop.loc[sick, self.cause.name] = self.cause.name - new_pop.loc[sick, self.diseased_event_time_column] = self.clock() # match VPH disease, only set w/ condition + new_pop.loc[ + sick, self.diseased_event_time_column + ] = self.clock() # match VPH disease, only set w/ condition self.population_view.update(new_pop) @@ -192,9 +224,13 @@ def on_time_step(self, event): sick = self.filter_by_exposure(pop.index) # if this is recoverable, anyone who gets lower exposure in the event goes back in to susceptible status. if self.recoverable: - change_to_susceptible = (~sick) & (pop[self.cause.name] != f'susceptible_to_{self.cause.name}') + change_to_susceptible = (~sick) & ( + pop[self.cause.name] != f"susceptible_to_{self.cause.name}" + ) pop.loc[change_to_susceptible, self.susceptible_event_time_column] = event.time - pop.loc[change_to_susceptible, self.cause.name] = f'susceptible_to_{self.cause.name}' + pop.loc[ + change_to_susceptible, self.cause.name + ] = f"susceptible_to_{self.cause.name}" change_to_diseased = sick & (pop[self.cause.name] != self.cause.name) pop.loc[change_to_diseased, self.diseased_event_time_column] = event.time pop.loc[change_to_diseased, self.cause.name] = self.cause.name @@ -212,7 +248,9 @@ def compute_excess_mortality_rate(self, index): with_condition = self.with_condition(index) base_excess_mort = self.base_excess_mortality_rate(with_condition) joint_mediated_paf = self.joint_paf(with_condition) - excess_mortality_rate.loc[with_condition] = base_excess_mort * (1 - joint_mediated_paf.values) + excess_mortality_rate.loc[with_condition] = base_excess_mort * ( + 1 - joint_mediated_paf.values + ) return excess_mortality_rate def adjust_cause_specific_mortality_rate(self, index, rate): @@ -233,34 +271,41 @@ def adjust_mortality_rate(self, index, rates_df): return rates_df def with_condition(self, index): - pop = self.population_view.subview(['alive', self.cause.name]).get(index) - with_condition = pop.loc[(pop[self.cause.name] == self.cause.name) & (pop['alive'] == 'alive')].index + pop = self.population_view.subview(["alive", self.cause.name]).get(index) + with_condition = pop.loc[ + (pop[self.cause.name] == self.cause.name) & (pop["alive"] == "alive") + ].index return with_condition def get_exposure_filter(self, distribution, exposure_pipeline, threshold): - if distribution in ['dichotomous', 'ordered_polytomous', 'unordered_polytomous']: + if distribution in ["dichotomous", "ordered_polytomous", "unordered_polytomous"]: def categorical_filter(index): exposure = exposure_pipeline(index) return exposure.isin(threshold) + filter_function = categorical_filter else: # continuous - Threshold = namedtuple('Threshold', ['operator', 'value']) + Threshold = namedtuple("Threshold", ["operator", "value"]) threshold_val = re.findall(r"[-+]?\d*\.?\d+", threshold) if len(threshold_val) != 1: - raise ValueError(f'Your {threshold} is an incorrect threshold format. It should include ' - f'"<" or ">" along with an integer or float number. Your threshold does not ' - f'include a number or more than one number.') + raise ValueError( + f"Your {threshold} is an incorrect threshold format. It should include " + f'"<" or ">" along with an integer or float number. Your threshold does not ' + f"include a number or more than one number." + ) - allowed_operator = {'<', '>'} + allowed_operator = {"<", ">"} threshold_op = [s for s in threshold.split(threshold_val[0]) if s] # if threshold_op has more than 1 operators or 0 operator if len(threshold_op) != 1 or not allowed_operator.intersection(threshold_op): - raise ValueError(f'Your {threshold} is an incorrect threshold format. It should include ' - f'"<" or ">" along with an integer or float number.') + raise ValueError( + f"Your {threshold} is an incorrect threshold format. It should include " + f'"<" or ">" along with an integer or float number.' + ) op = gt if threshold_op[0] == ">" else lt threshold = Threshold(op, float(threshold_val[0])) @@ -268,24 +313,29 @@ def categorical_filter(index): def continuous_filter(index): exposure = exposure_pipeline(index) return threshold.operator(exposure, threshold.value) + filter_function = continuous_filter return filter_function def adjust_state_and_transitions(self): if self.recoverable: - self._transition_names.append(f'{self.cause.name}_TO_susceptible_to_{self.cause.name}') + self._transition_names.append( + f"{self.cause.name}_TO_susceptible_to_{self.cause.name}" + ) def load_cause_specific_mortality_rate_data(self, builder): if builder.configuration[self.cause.name].mortality: - csmr_data = builder.data.load(f'cause.{self.cause.name}.cause_specific_mortality_rate') + csmr_data = builder.data.load( + f"cause.{self.cause.name}.cause_specific_mortality_rate" + ) else: csmr_data = 0 return csmr_data def load_excess_mortality_rate_data(self, builder): if builder.configuration[self.cause.name].mortality: - emr_data = builder.data.load(f'cause.{self.cause.name}.excess_mortality_rate') + emr_data = builder.data.load(f"cause.{self.cause.name}.excess_mortality_rate") else: emr_data = 0 return emr_data diff --git a/src/vivarium_public_health/disease/state.py b/src/vivarium_public_health/disease/state.py index d7188cc6e..eebfc93b3 100644 --- a/src/vivarium_public_health/disease/state.py +++ b/src/vivarium_public_health/disease/state.py @@ -8,17 +8,21 @@ """ from typing import Callable, Dict -import pandas as pd import numpy as np +import pandas as pd from vivarium.framework.state_machine import State, Transient, Transition from vivarium.framework.values import list_combiner, union_post_processor -from vivarium_public_health.disease import RateTransition, ProportionTransition +from vivarium_public_health.disease.transition import ( + ProportionTransition, + RateTransition, +) class BaseDiseaseState(State): - - def __init__(self, cause, name_prefix='', side_effect_function=None, cause_type="cause", **kwargs): + def __init__( + self, cause, name_prefix="", side_effect_function=None, cause_type="cause", **kwargs + ): super().__init__(name_prefix + cause) # becomes state_id self.cause_type = cause_type self.cause = cause @@ -27,8 +31,8 @@ def __init__(self, cause, name_prefix='', side_effect_function=None, cause_type= if self.side_effect_function is not None: self._sub_components.append(side_effect_function) - self.event_time_column = self.state_id + '_event_time' - self.event_count_column = self.state_id + '_event_count' + self.event_time_column = self.state_id + "_event_time" + self.event_count_column = self.state_id + "_event_count" @property def columns_created(self): @@ -47,11 +51,13 @@ def setup(self, builder): self.clock = builder.time.clock() - view_columns = self.columns_created + [self._model, 'alive'] + view_columns = self.columns_created + [self._model, "alive"] self.population_view = builder.population.get_view(view_columns) - builder.population.initializes_simulants(self.on_initialize_simulants, - creates_columns=self.columns_created, - requires_columns=[self._model]) + builder.population.initializes_simulants( + self.on_initialize_simulants, + creates_columns=self.columns_created, + requires_columns=[self._model], + ) def on_initialize_simulants(self, pop_data): """Adds this state's columns to the simulation state table.""" @@ -59,9 +65,9 @@ def on_initialize_simulants(self, pop_data): if transition.start_active: transition.set_active(pop_data.index) - pop_update = pd.DataFrame({self.event_time_column: pd.NaT, - self.event_count_column: 0}, - index=pop_data.index) + pop_update = pd.DataFrame( + {self.event_time_column: pd.NaT, self.event_count_column: 0}, index=pop_data.index + ) self.population_view.update(pop_update) def _transition_side_effect(self, index, event_time): @@ -83,8 +89,13 @@ def _transition_side_effect(self, index, event_time): if self.side_effect_function is not None: self.side_effect_function(index, event_time) - def add_transition(self, output: State, source_data_type: str = None, - get_data_functions: Dict[str, Callable] = None, **kwargs) -> Transition: + def add_transition( + self, + output: State, + source_data_type: str = None, + get_data_functions: Dict[str, Callable] = None, + **kwargs, + ) -> Transition: """Builds a transition from this state to the given state. Parameters @@ -104,7 +115,7 @@ def add_transition(self, output: State, source_data_type: str = None, The created transition object. """ - transition_map = {'rate': RateTransition, 'proportion': ProportionTransition} + transition_map = {"rate": RateTransition, "proportion": ProportionTransition} if not source_data_type: return super().add_transition(output, **kwargs) @@ -117,49 +128,62 @@ def add_transition(self, output: State, source_data_type: str = None, class SusceptibleState(BaseDiseaseState): - def __init__(self, cause, *args, **kwargs): - super().__init__(cause, *args, name_prefix='susceptible_to_', **kwargs) - - def add_transition(self, output: State, source_data_type: str = None, - get_data_functions: Dict[str, Callable] = None, **kwargs) -> Transition: - if source_data_type == 'rate': + super().__init__(cause, *args, name_prefix="susceptible_to_", **kwargs) + + def add_transition( + self, + output: State, + source_data_type: str = None, + get_data_functions: Dict[str, Callable] = None, + **kwargs, + ) -> Transition: + if source_data_type == "rate": if get_data_functions is None: get_data_functions = { - 'incidence_rate': lambda cause, builder: builder.data.load(f"{self.cause_type}.{cause}.incidence_rate") + "incidence_rate": lambda cause, builder: builder.data.load( + f"{self.cause_type}.{cause}.incidence_rate" + ) } - elif 'incidence_rate' not in get_data_functions: - raise ValueError('You must supply an incidence rate function.') - elif source_data_type == 'proportion': - if 'proportion' not in get_data_functions: - raise ValueError('You must supply a proportion function.') + elif "incidence_rate" not in get_data_functions: + raise ValueError("You must supply an incidence rate function.") + elif source_data_type == "proportion": + if "proportion" not in get_data_functions: + raise ValueError("You must supply a proportion function.") return super().add_transition(output, source_data_type, get_data_functions, **kwargs) class RecoveredState(BaseDiseaseState): - def __init__(self, cause, *args, **kwargs): super().__init__(cause, *args, name_prefix="recovered_from_", **kwargs) - def add_transition(self, output: State, source_data_type: str = None, - get_data_functions: Dict[str, Callable] = None, **kwargs) -> Transition: - if source_data_type == 'rate': + def add_transition( + self, + output: State, + source_data_type: str = None, + get_data_functions: Dict[str, Callable] = None, + **kwargs, + ) -> Transition: + if source_data_type == "rate": if get_data_functions is None: get_data_functions = { - 'incidence_rate': lambda cause, builder: builder.data.load(f"{self.cause_type}.{cause}.incidence_rate") + "incidence_rate": lambda cause, builder: builder.data.load( + f"{self.cause_type}.{cause}.incidence_rate" + ) } - elif 'incidence_rate' not in get_data_functions: - raise ValueError('You must supply an incidence rate function.') - elif source_data_type == 'proportion': - if 'proportion' not in get_data_functions: - raise ValueError('You must supply a proportion function.') + elif "incidence_rate" not in get_data_functions: + raise ValueError("You must supply an incidence rate function.") + elif source_data_type == "proportion": + if "proportion" not in get_data_functions: + raise ValueError("You must supply a proportion function.") return super().add_transition(output, source_data_type, get_data_functions, **kwargs) class DiseaseState(BaseDiseaseState): """State representing a disease in a state machine model.""" + def __init__(self, cause, get_data_functions=None, cleanup_function=None, **kwargs): """ Parameters @@ -181,16 +205,23 @@ def __init__(self, cause, get_data_functions=None, cleanup_function=None, **kwar """ super().__init__(cause, **kwargs) - self.excess_mortality_rate_pipeline_name = f'{self.state_id}.excess_mortality_rate' - self.excess_mortality_rate_paf_pipeline_name = f'{self.excess_mortality_rate_pipeline_name}.paf' + self.excess_mortality_rate_pipeline_name = f"{self.state_id}.excess_mortality_rate" + self.excess_mortality_rate_paf_pipeline_name = ( + f"{self.excess_mortality_rate_pipeline_name}.paf" + ) - self._get_data_functions = get_data_functions if get_data_functions is not None else {} + self._get_data_functions = ( + get_data_functions if get_data_functions is not None else {} + ) self.cleanup_function = cleanup_function - if (self.cause is None and - not set(self._get_data_functions.keys()).issuperset(['disability_weight', 'dwell_time', 'prevalence'])): - raise ValueError('If you do not provide a cause, you must supply' - 'custom data gathering functions for disability_weight, prevalence, and dwell_time.') + if self.cause is None and not set(self._get_data_functions.keys()).issuperset( + ["disability_weight", "dwell_time", "prevalence"] + ): + raise ValueError( + "If you do not provide a cause, you must supply" + "custom data gathering functions for disability_weight, prevalence, and dwell_time." + ) # noinspection PyAttributeOutsideInit def setup(self, builder): @@ -204,60 +235,76 @@ def setup(self, builder): super().setup(builder) prevalence_data = self.load_prevalence_data(builder) - self.prevalence = builder.lookup.build_table(prevalence_data, key_columns=['sex'], - parameter_columns=['age', 'year']) + self.prevalence = builder.lookup.build_table( + prevalence_data, key_columns=["sex"], parameter_columns=["age", "year"] + ) birth_prevalence_data = self.load_birth_prevalence_data(builder) - self.birth_prevalence = builder.lookup.build_table(birth_prevalence_data, - key_columns=['sex'], - parameter_columns=['year']) + self.birth_prevalence = builder.lookup.build_table( + birth_prevalence_data, key_columns=["sex"], parameter_columns=["year"] + ) dwell_time_data = self.load_dwell_time_data(builder) self.dwell_time = builder.value.register_value_producer( - f'{self.state_id}.dwell_time', - source=builder.lookup.build_table(dwell_time_data, key_columns=['sex'], parameter_columns=['age', 'year']), - requires_columns=['age', 'sex'] + f"{self.state_id}.dwell_time", + source=builder.lookup.build_table( + dwell_time_data, key_columns=["sex"], parameter_columns=["age", "year"] + ), + requires_columns=["age", "sex"], ) disability_weight_data = self.load_disability_weight_data(builder) - self.base_disability_weight = builder.lookup.build_table(disability_weight_data, key_columns=['sex'], - parameter_columns=['age', 'year']) + self.base_disability_weight = builder.lookup.build_table( + disability_weight_data, key_columns=["sex"], parameter_columns=["age", "year"] + ) self.disability_weight = builder.value.register_value_producer( - f'{self.state_id}.disability_weight', + f"{self.state_id}.disability_weight", source=self.compute_disability_weight, - requires_columns=['age', 'sex', 'alive', self._model] + requires_columns=["age", "sex", "alive", self._model], + ) + builder.value.register_value_modifier( + "disability_weight", modifier=self.disability_weight ) - builder.value.register_value_modifier('disability_weight', modifier=self.disability_weight) excess_mortality_data = self.load_excess_mortality_rate_data(builder) - self.base_excess_mortality_rate = builder.lookup.build_table(excess_mortality_data, key_columns=['sex'], - parameter_columns=['age', 'year']) + self.base_excess_mortality_rate = builder.lookup.build_table( + excess_mortality_data, key_columns=["sex"], parameter_columns=["age", "year"] + ) self.excess_mortality_rate = builder.value.register_rate_producer( self.excess_mortality_rate_pipeline_name, source=self.compute_excess_mortality_rate, - requires_columns=['age', 'sex', 'alive', self._model], - requires_values=[self.excess_mortality_rate_paf_pipeline_name] + requires_columns=["age", "sex", "alive", self._model], + requires_values=[self.excess_mortality_rate_paf_pipeline_name], ) paf = builder.lookup.build_table(0) self.joint_paf = builder.value.register_value_producer( self.excess_mortality_rate_paf_pipeline_name, source=lambda idx: [paf(idx)], preferred_combiner=list_combiner, - preferred_post_processor=union_post_processor + preferred_post_processor=union_post_processor, + ) + builder.value.register_value_modifier( + "mortality_rate", + modifier=self.adjust_mortality_rate, + requires_values=[self.excess_mortality_rate_pipeline_name], ) - builder.value.register_value_modifier('mortality_rate', - modifier=self.adjust_mortality_rate, - requires_values=[self.excess_mortality_rate_pipeline_name]) - self.randomness_prevalence = builder.randomness.get_stream(f'{self.state_id}_prevalent_cases') + self.randomness_prevalence = builder.randomness.get_stream( + f"{self.state_id}_prevalent_cases" + ) def on_initialize_simulants(self, pop_data): super().on_initialize_simulants(pop_data) - simulants_with_condition = self.population_view.subview([self._model]).get(pop_data.index, query=f'{self._model}=="{self.state_id}"') + simulants_with_condition = self.population_view.subview([self._model]).get( + pop_data.index, query=f'{self._model}=="{self.state_id}"' + ) if not simulants_with_condition.empty: - infected_at = self._assign_event_time_for_prevalent_cases(simulants_with_condition, self.clock(), - self.randomness_prevalence.get_draw, - self.dwell_time) + infected_at = self._assign_event_time_for_prevalent_cases( + simulants_with_condition, + self.clock(), + self.randomness_prevalence.get_draw, + self.dwell_time, + ) infected_at.name = self.event_time_column self.population_view.update(infected_at) @@ -284,7 +331,9 @@ def compute_excess_mortality_rate(self, index): with_condition = self.with_condition(index) base_excess_mort = self.base_excess_mortality_rate(with_condition) joint_mediated_paf = self.joint_paf(with_condition) - excess_mortality_rate.loc[with_condition] = base_excess_mort * (1 - joint_mediated_paf.values) + excess_mortality_rate.loc[with_condition] = base_excess_mort * ( + 1 - joint_mediated_paf.values + ) return excess_mortality_rate def adjust_mortality_rate(self, index, rates_df): @@ -302,29 +351,45 @@ def adjust_mortality_rate(self, index, rates_df): return rates_df def with_condition(self, index): - pop = self.population_view.subview(['alive', self._model]).get(index) - with_condition = pop.loc[(pop[self._model] == self.state_id) & (pop['alive'] == 'alive')].index + pop = self.population_view.subview(["alive", self._model]).get(index) + with_condition = pop.loc[ + (pop[self._model] == self.state_id) & (pop["alive"] == "alive") + ].index return with_condition @staticmethod - def _assign_event_time_for_prevalent_cases(infected, current_time, randomness_func, dwell_time_func): + def _assign_event_time_for_prevalent_cases( + infected, current_time, randomness_func, dwell_time_func + ): dwell_time = dwell_time_func(infected.index) infected_at = dwell_time * randomness_func(infected.index) - infected_at = current_time - pd.to_timedelta(infected_at, unit='D') + infected_at = current_time - pd.to_timedelta(infected_at, unit="D") return infected_at - def add_transition(self, output: State, source_data_type: str = None, - get_data_functions: Dict[str, Callable] = None, **kwargs) -> Transition: - if source_data_type == 'rate': + def add_transition( + self, + output: State, + source_data_type: str = None, + get_data_functions: Dict[str, Callable] = None, + **kwargs, + ) -> Transition: + if source_data_type == "rate": if get_data_functions is None: get_data_functions = { - 'remission_rate': lambda cause, builder: builder.data.load(f"{self.cause_type}.{cause}.remission_rate") + "remission_rate": lambda cause, builder: builder.data.load( + f"{self.cause_type}.{cause}.remission_rate" + ) } - elif 'remission_rate' not in get_data_functions and 'transition_rate' not in get_data_functions: - raise ValueError('You must supply a transition rate or remission rate function.') - elif source_data_type == 'proportion': - if 'proportion' not in get_data_functions: - raise ValueError('You must supply a proportion function.') + elif ( + "remission_rate" not in get_data_functions + and "transition_rate" not in get_data_functions + ): + raise ValueError( + "You must supply a transition rate or remission rate function." + ) + elif source_data_type == "proportion": + if "proportion" not in get_data_functions: + raise ValueError("You must supply a proportion function.") return super().add_transition(output, source_data_type, get_data_functions, **kwargs) def next_state(self, index, event_time, population_view): @@ -357,7 +422,9 @@ def _filter_for_transition_eligibility(self, index, event_time): """ population = self.population_view.get(index, query='alive == "alive"') if np.any(self.dwell_time(index)) > 0: - state_exit_time = population[self.event_time_column] + pd.to_timedelta(self.dwell_time(index), unit='D') + state_exit_time = population[self.event_time_column] + pd.to_timedelta( + self.dwell_time(index), unit="D" + ) return population.loc[state_exit_time <= event_time].index else: return index @@ -367,35 +434,41 @@ def _cleanup_effect(self, index, event_time): self.cleanup_function(index, event_time) def load_prevalence_data(self, builder): - if 'prevalence' in self._get_data_functions: - return self._get_data_functions['prevalence'](self.cause, builder) + if "prevalence" in self._get_data_functions: + return self._get_data_functions["prevalence"](self.cause, builder) else: return builder.data.load(f"{self.cause_type}.{self.cause}.prevalence") def load_birth_prevalence_data(self, builder): - if 'birth_prevalence' in self._get_data_functions: - return self._get_data_functions['birth_prevalence'](self.cause, builder) + if "birth_prevalence" in self._get_data_functions: + return self._get_data_functions["birth_prevalence"](self.cause, builder) else: return 0 def load_dwell_time_data(self, builder): - if 'dwell_time' in self._get_data_functions: - dwell_time = self._get_data_functions['dwell_time'](self.cause, builder) + if "dwell_time" in self._get_data_functions: + dwell_time = self._get_data_functions["dwell_time"](self.cause, builder) else: dwell_time = 0 if isinstance(dwell_time, pd.Timedelta): dwell_time = dwell_time.total_seconds() / (60 * 60 * 24) - if (isinstance(dwell_time, pd.DataFrame) and np.any(dwell_time.value != 0)) or dwell_time > 0: + if ( + isinstance(dwell_time, pd.DataFrame) and np.any(dwell_time.value != 0) + ) or dwell_time > 0: self.transition_set.allow_null_transition = True return dwell_time def load_disability_weight_data(self, builder): - if 'disability_weight' in self._get_data_functions: - disability_weight = self._get_data_functions['disability_weight'](self.cause, builder) + if "disability_weight" in self._get_data_functions: + disability_weight = self._get_data_functions["disability_weight"]( + self.cause, builder + ) else: - disability_weight = builder.data.load(f"{self.cause_type}.{self.cause}.disability_weight") + disability_weight = builder.data.load( + f"{self.cause_type}.{self.cause}.disability_weight" + ) if isinstance(disability_weight, pd.DataFrame) and len(disability_weight) == 1: disability_weight = disability_weight.value[0] # sequela only have single value @@ -403,19 +476,18 @@ def load_disability_weight_data(self, builder): return disability_weight def load_excess_mortality_rate_data(self, builder): - only_morbid = builder.data.load(f'cause.{self._model}.restrictions')['yld_only'] - if 'excess_mortality_rate' in self._get_data_functions: - return self._get_data_functions['excess_mortality_rate'](self.cause, builder) + only_morbid = builder.data.load(f"cause.{self._model}.restrictions")["yld_only"] + if "excess_mortality_rate" in self._get_data_functions: + return self._get_data_functions["excess_mortality_rate"](self.cause, builder) elif only_morbid: return 0 else: - return builder.data.load(f'{self.cause_type}.{self.cause}.excess_mortality_rate') + return builder.data.load(f"{self.cause_type}.{self.cause}.excess_mortality_rate") def __repr__(self): - return 'DiseaseState({})'.format(self.state_id) + return "DiseaseState({})".format(self.state_id) class TransientDiseaseState(BaseDiseaseState, Transient): - def __repr__(self): - return 'TransientDiseaseState(name={})'.format(self.state_id) + return "TransientDiseaseState(name={})".format(self.state_id) diff --git a/src/vivarium_public_health/disease/transition.py b/src/vivarium_public_health/disease/transition.py index c7751c08b..ebdba24d2 100644 --- a/src/vivarium_public_health/disease/transition.py +++ b/src/vivarium_public_health/disease/transition.py @@ -7,7 +7,6 @@ """ import pandas as pd - from vivarium.framework.state_machine import Transition from vivarium.framework.utilities import rate_to_probability from vivarium.framework.values import list_combiner, union_post_processor @@ -15,24 +14,34 @@ class RateTransition(Transition): def __init__(self, input_state, output_state, get_data_functions=None, **kwargs): - super().__init__(input_state, output_state, probability_func=self._probability, **kwargs) - self._get_data_functions = get_data_functions if get_data_functions is not None else {} + super().__init__( + input_state, output_state, probability_func=self._probability, **kwargs + ) + self._get_data_functions = ( + get_data_functions if get_data_functions is not None else {} + ) # noinspection PyAttributeOutsideInit def setup(self, builder): rate_data, pipeline_name = self.load_transition_rate_data(builder) - self.base_rate = builder.lookup.build_table(rate_data, key_columns=['sex'], parameter_columns=['age', 'year']) - self.transition_rate = builder.value.register_rate_producer(pipeline_name, - source=self.compute_transition_rate, - requires_columns=['age', 'sex', 'alive'], - requires_values=[f'{pipeline_name}.paf']) + self.base_rate = builder.lookup.build_table( + rate_data, key_columns=["sex"], parameter_columns=["age", "year"] + ) + self.transition_rate = builder.value.register_rate_producer( + pipeline_name, + source=self.compute_transition_rate, + requires_columns=["age", "sex", "alive"], + requires_values=[f"{pipeline_name}.paf"], + ) paf = builder.lookup.build_table(0) - self.joint_paf = builder.value.register_value_producer(f'{pipeline_name}.paf', - source=lambda index: [paf(index)], - preferred_combiner=list_combiner, - preferred_post_processor=union_post_processor) + self.joint_paf = builder.value.register_value_producer( + f"{pipeline_name}.paf", + source=lambda index: [paf(index)], + preferred_combiner=list_combiner, + preferred_post_processor=union_post_processor, + ) - self.population_view = builder.population.get_view(['alive']) + self.population_view = builder.population.get_view(["alive"]) def compute_transition_rate(self, index): transition_rate = pd.Series(0, index=index) @@ -43,16 +52,23 @@ def compute_transition_rate(self, index): return transition_rate def load_transition_rate_data(self, builder): - if 'incidence_rate' in self._get_data_functions: - rate_data = self._get_data_functions['incidence_rate'](self.output_state.cause, builder) - pipeline_name = f'{self.output_state.state_id}.incidence_rate' - elif 'remission_rate' in self._get_data_functions: - rate_data = self._get_data_functions['remission_rate'](self.output_state.cause, builder) - pipeline_name = f'{self.input_state.state_id}.remission_rate' - elif 'transition_rate' in self._get_data_functions: - rate_data = self._get_data_functions['transition_rate'](builder, self.input_state.cause, - self.output_state.cause) - pipeline_name = f'{self.input_state.cause}_to_{self.output_state.cause}.transition_rate' + if "incidence_rate" in self._get_data_functions: + rate_data = self._get_data_functions["incidence_rate"]( + self.output_state.cause, builder + ) + pipeline_name = f"{self.output_state.state_id}.incidence_rate" + elif "remission_rate" in self._get_data_functions: + rate_data = self._get_data_functions["remission_rate"]( + self.output_state.cause, builder + ) + pipeline_name = f"{self.input_state.state_id}.remission_rate" + elif "transition_rate" in self._get_data_functions: + rate_data = self._get_data_functions["transition_rate"]( + builder, self.input_state.cause, self.output_state.cause + ) + pipeline_name = ( + f"{self.input_state.cause}_to_{self.output_state.cause}.transition_rate" + ) else: raise ValueError("No valid data functions supplied.") return rate_data, pipeline_name @@ -61,26 +77,31 @@ def _probability(self, index): return rate_to_probability(self.transition_rate(index)) def __str__(self): - return f'RateTransition(from={self.input_state.state_id}, to={self.output_state.state_id})' + return f"RateTransition(from={self.input_state.state_id}, to={self.output_state.state_id})" class ProportionTransition(Transition): def __init__(self, input_state, output_state, get_data_functions=None, **kwargs): - super().__init__(input_state, output_state, probability_func=self._probability, **kwargs) - self._get_data_functions = get_data_functions if get_data_functions is not None else {} + super().__init__( + input_state, output_state, probability_func=self._probability, **kwargs + ) + self._get_data_functions = ( + get_data_functions if get_data_functions is not None else {} + ) # noinspection PyAttributeOutsideInit def setup(self, builder): super().setup(builder) - get_proportion_func = self._get_data_functions.get('proportion', None) + get_proportion_func = self._get_data_functions.get("proportion", None) if get_proportion_func is None: - raise ValueError('Must supply a proportion function') + raise ValueError("Must supply a proportion function") self._proportion_data = get_proportion_func(self.output_state.cause, builder) - self.proportion = builder.lookup.build_table(self._proportion_data, key_columns=['sex'], - parameter_columns=['age', 'year']) + self.proportion = builder.lookup.build_table( + self._proportion_data, key_columns=["sex"], parameter_columns=["age", "year"] + ) def _probability(self, index): return self.proportion(index) def __str__(self): - return f'ProportionTransition(from={self.input_state.state_id}, {self.output_state.state_id})' + return f"ProportionTransition(from={self.input_state.state_id}, {self.output_state.state_id})" diff --git a/src/vivarium_public_health/metrics/__init__.py b/src/vivarium_public_health/metrics/__init__.py index e596841b7..3aff3368f 100644 --- a/src/vivarium_public_health/metrics/__init__.py +++ b/src/vivarium_public_health/metrics/__init__.py @@ -1,4 +1,4 @@ from .disability import DisabilityObserver +from .disease import DiseaseObserver from .mortality import MortalityObserver from .risk import CategoricalRiskObserver -from .disease import DiseaseObserver diff --git a/src/vivarium_public_health/metrics/disability.py b/src/vivarium_public_health/metrics/disability.py index 8271c0224..8f52fdbb2 100644 --- a/src/vivarium_public_health/metrics/disability.py +++ b/src/vivarium_public_health/metrics/disability.py @@ -10,10 +10,17 @@ from collections import Counter import pandas as pd -from vivarium.framework.values import list_combiner, union_post_processor, rescale_post_processor +from vivarium.framework.values import ( + list_combiner, + rescale_post_processor, + union_post_processor, +) from vivarium_public_health.disease import DiseaseState, RiskAttributableDisease -from .utilities import get_age_bins, get_years_lived_with_disability +from vivarium_public_health.metrics.utilities import ( + get_age_bins, + get_years_lived_with_disability, +) class DisabilityObserver: @@ -37,67 +44,86 @@ class DisabilityObserver: by_sex: True """ + configuration_defaults = { - 'metrics': { - 'disability': { - 'by_age': False, - 'by_year': False, - 'by_sex': False, + "metrics": { + "disability": { + "by_age": False, + "by_year": False, + "by_sex": False, } } } @property def name(self): - return 'disability_observer' + return "disability_observer" def setup(self, builder): self.config = builder.configuration.metrics.disability self.age_bins = get_age_bins(builder) self.clock = builder.time.clock() self.step_size = builder.time.step_size() - self.causes = [c.state_id - for c in builder.components.get_components_by_type((DiseaseState, RiskAttributableDisease))] + self.causes = [ + c.state_id + for c in builder.components.get_components_by_type( + (DiseaseState, RiskAttributableDisease) + ) + ] self.years_lived_with_disability = Counter() - self.disability_weight_pipelines = {cause: builder.value.get_value(f'{cause}.disability_weight') - for cause in self.causes} + self.disability_weight_pipelines = { + cause: builder.value.get_value(f"{cause}.disability_weight") + for cause in self.causes + } self.disability_weight = builder.value.register_value_producer( - 'disability_weight', + "disability_weight", source=lambda index: [pd.Series(0.0, index=index)], preferred_combiner=list_combiner, - preferred_post_processor=_disability_post_processor) + preferred_post_processor=_disability_post_processor, + ) - columns_required = ['tracked', 'alive', 'years_lived_with_disability'] + columns_required = ["tracked", "alive", "years_lived_with_disability"] if self.config.by_age: - columns_required += ['age'] + columns_required += ["age"] if self.config.by_sex: - columns_required += ['sex'] + columns_required += ["sex"] self.population_view = builder.population.get_view(columns_required) - builder.population.initializes_simulants(self.initialize_disability, - creates_columns=['years_lived_with_disability']) + builder.population.initializes_simulants( + self.initialize_disability, creates_columns=["years_lived_with_disability"] + ) # FIXME: The state table is modified before the clock advances. # In order to get an accurate representation of person time we need to look at # the state table before anything happens. - builder.event.register_listener('time_step__prepare', self.on_time_step_prepare) - builder.value.register_value_modifier('metrics', modifier=self.metrics) + builder.event.register_listener("time_step__prepare", self.on_time_step_prepare) + builder.value.register_value_modifier("metrics", modifier=self.metrics) def initialize_disability(self, pop_data): - self.population_view.update(pd.Series(0., index=pop_data.index, name='years_lived_with_disability')) + self.population_view.update( + pd.Series(0.0, index=pop_data.index, name="years_lived_with_disability") + ) def on_time_step_prepare(self, event): - pop = self.population_view.get(event.index, query='tracked == True and alive == "alive"') - ylds_this_step = get_years_lived_with_disability(pop, self.config.to_dict(), - self.clock().year, self.step_size(), - self.age_bins, self.disability_weight_pipelines, self.causes) + pop = self.population_view.get( + event.index, query='tracked == True and alive == "alive"' + ) + ylds_this_step = get_years_lived_with_disability( + pop, + self.config.to_dict(), + self.clock().year, + self.step_size(), + self.age_bins, + self.disability_weight_pipelines, + self.causes, + ) self.years_lived_with_disability.update(ylds_this_step) - pop.loc[:, 'years_lived_with_disability'] += self.disability_weight(pop.index) + pop.loc[:, "years_lived_with_disability"] += self.disability_weight(pop.index) self.population_view.update(pop) def metrics(self, index, metrics): - total_ylds = self.population_view.get(index)['years_lived_with_disability'].sum() - metrics['years_lived_with_disability'] = total_ylds + total_ylds = self.population_view.get(index)["years_lived_with_disability"].sum() + metrics["years_lived_with_disability"] = total_ylds metrics.update(self.years_lived_with_disability) return metrics diff --git a/src/vivarium_public_health/metrics/disease.py b/src/vivarium_public_health/metrics/disease.py index 756532e91..b8d13c403 100644 --- a/src/vivarium_public_health/metrics/disease.py +++ b/src/vivarium_public_health/metrics/disease.py @@ -11,8 +11,13 @@ import pandas as pd -from .utilities import (get_age_bins, get_prevalent_cases, get_state_person_time, - get_transition_count, TransitionString) +from vivarium_public_health.metrics.utilities import ( + TransitionString, + get_age_bins, + get_prevalent_cases, + get_state_person_time, + get_transition_count, +) class DiseaseObserver: @@ -46,18 +51,19 @@ class DiseaseObserver: day: 10 """ + configuration_defaults = { - 'metrics': { - 'disease_observer': { - 'by_age': False, - 'by_year': False, - 'by_sex': False, - 'sample_prevalence': { - 'sample': False, - 'date': { - 'month': 7, - 'day': 1, - } + "metrics": { + "disease_observer": { + "by_age": False, + "by_year": False, + "by_sex": False, + "sample_prevalence": { + "sample": False, + "date": { + "month": 7, + "day": 1, + }, }, } } @@ -66,54 +72,67 @@ class DiseaseObserver: def __init__(self, disease: str): self.disease = disease self.configuration_defaults = { - 'metrics': {f'{disease}_observer': DiseaseObserver.configuration_defaults['metrics']['disease_observer']} + "metrics": { + f"{disease}_observer": DiseaseObserver.configuration_defaults["metrics"][ + "disease_observer" + ] + } } @property def name(self): - return f'disease_observer.{self.disease}' + return f"disease_observer.{self.disease}" def setup(self, builder): - self.config = builder.configuration['metrics'][f'{self.disease}_observer'] + self.config = builder.configuration["metrics"][f"{self.disease}_observer"] self.clock = builder.time.clock() self.age_bins = get_age_bins(builder) self.counts = Counter() self.person_time = Counter() self.prevalence = Counter() - comp = builder.components.get_component(f'disease_model.{self.disease}') + comp = builder.components.get_component(f"disease_model.{self.disease}") self.states = comp.state_names self.transitions = comp.transition_names - self.previous_state_column = f'previous_{self.disease}' - builder.population.initializes_simulants(self.on_initialize_simulants, - creates_columns=[self.previous_state_column]) + self.previous_state_column = f"previous_{self.disease}" + builder.population.initializes_simulants( + self.on_initialize_simulants, creates_columns=[self.previous_state_column] + ) - columns_required = ['alive', f'{self.disease}', self.previous_state_column] + columns_required = ["alive", f"{self.disease}", self.previous_state_column] if self.config.by_age: - columns_required += ['age'] + columns_required += ["age"] if self.config.by_sex: - columns_required += ['sex'] + columns_required += ["sex"] self.population_view = builder.population.get_view(columns_required) - builder.value.register_value_modifier('metrics', self.metrics) + builder.value.register_value_modifier("metrics", self.metrics) # FIXME: The state table is modified before the clock advances. # In order to get an accurate representation of person time we need to look at # the state table before anything happens. - builder.event.register_listener('time_step__prepare', self.on_time_step_prepare) - builder.event.register_listener('collect_metrics', self.on_collect_metrics) + builder.event.register_listener("time_step__prepare", self.on_time_step_prepare) + builder.event.register_listener("collect_metrics", self.on_collect_metrics) def on_initialize_simulants(self, pop_data): - self.population_view.update(pd.Series('', index=pop_data.index, name=self.previous_state_column)) + self.population_view.update( + pd.Series("", index=pop_data.index, name=self.previous_state_column) + ) def on_time_step_prepare(self, event): pop = self.population_view.get(event.index) for state in self.states: # noinspection PyTypeChecker - state_person_time_this_step = get_state_person_time(pop, self.config, self.disease, - state, self.clock().year, event.step_size, - self.age_bins) + state_person_time_this_step = get_state_person_time( + pop, + self.config, + self.disease, + state, + self.clock().year, + event.step_size, + self.age_bins, + ) self.person_time.update(state_person_time_this_step) # This enables tracking of transitions between states @@ -122,22 +141,32 @@ def on_time_step_prepare(self, event): self.population_view.update(prior_state_pop) if self._should_sample(event.time): - point_prevalence = get_prevalent_cases(pop, self.config.to_dict(), self.disease, event.time, self.age_bins) + point_prevalence = get_prevalent_cases( + pop, self.config.to_dict(), self.disease, event.time, self.age_bins + ) self.prevalence.update(point_prevalence) def on_collect_metrics(self, event): pop = self.population_view.get(event.index) for transition in self.transitions: # noinspection PyTypeChecker - transition_counts_this_step = get_transition_count(pop, self.config, self.disease, - TransitionString(transition), event.time, self.age_bins) + transition_counts_this_step = get_transition_count( + pop, + self.config, + self.disease, + TransitionString(transition), + event.time, + self.age_bins, + ) self.counts.update(transition_counts_this_step) def _should_sample(self, event_time: pd.Timestamp) -> bool: """Returns true if we should sample on this time step.""" should_sample = self.config.sample_prevalence.sample if should_sample: - sample_date = pd.Timestamp(year=event_time.year, **self.config.sample_prevalence.date.to_dict()) + sample_date = pd.Timestamp( + year=event_time.year, **self.config.sample_prevalence.date.to_dict() + ) should_sample &= self.clock() <= sample_date < event_time return should_sample diff --git a/src/vivarium_public_health/metrics/mortality.py b/src/vivarium_public_health/metrics/mortality.py index a05d1303e..0909db913 100644 --- a/src/vivarium_public_health/metrics/mortality.py +++ b/src/vivarium_public_health/metrics/mortality.py @@ -8,11 +8,16 @@ """ from vivarium_public_health.disease import DiseaseState, RiskAttributableDisease -from .utilities import get_age_bins, get_person_time, get_deaths, get_years_of_life_lost +from vivarium_public_health.metrics.utilities import ( + get_age_bins, + get_deaths, + get_person_time, + get_years_of_life_lost, +) class MortalityObserver: - """ An observer for cause-specific deaths, ylls, and total person time. + """An observer for cause-specific deaths, ylls, and total person time. By default, this counts cause-specific deaths, years of life lost, and total person time over the full course of the simulation. It can be @@ -32,19 +37,20 @@ class MortalityObserver: by_sex: True """ + configuration_defaults = { - 'metrics': { - 'mortality': { - 'by_age': False, - 'by_year': False, - 'by_sex': False, + "metrics": { + "mortality": { + "by_age": False, + "by_year": False, + "by_sex": False, } } } @property def name(self): - return 'mortality_observer' + return "mortality_observer" def setup(self, builder): self.config = builder.configuration.metrics.mortality @@ -53,39 +59,67 @@ def setup(self, builder): self.start_time = self.clock() self.initial_pop_entrance_time = self.start_time - self.step_size() self.age_bins = get_age_bins(builder) - diseases = builder.components.get_components_by_type((DiseaseState, RiskAttributableDisease)) - self.causes = [c.state_id for c in diseases] + ['other_causes'] - - life_expectancy_data = builder.data.load("population.theoretical_minimum_risk_life_expectancy") - self.life_expectancy = builder.lookup.build_table(life_expectancy_data, key_columns=[], - parameter_columns=['age']) - - columns_required = ['tracked', 'alive', 'entrance_time', 'exit_time', 'cause_of_death', - 'years_of_life_lost', 'age'] + diseases = builder.components.get_components_by_type( + (DiseaseState, RiskAttributableDisease) + ) + self.causes = [c.state_id for c in diseases] + ["other_causes"] + + life_expectancy_data = builder.data.load( + "population.theoretical_minimum_risk_life_expectancy" + ) + self.life_expectancy = builder.lookup.build_table( + life_expectancy_data, key_columns=[], parameter_columns=["age"] + ) + + columns_required = [ + "tracked", + "alive", + "entrance_time", + "exit_time", + "cause_of_death", + "years_of_life_lost", + "age", + ] if self.config.by_sex: - columns_required += ['sex'] + columns_required += ["sex"] self.population_view = builder.population.get_view(columns_required) - builder.value.register_value_modifier('metrics', self.metrics) + builder.value.register_value_modifier("metrics", self.metrics) def metrics(self, index, metrics): pop = self.population_view.get(index) - pop.loc[pop.exit_time.isnull(), 'exit_time'] = self.clock() - - person_time = get_person_time(pop, self.config.to_dict(), self.start_time, self.clock(), self.age_bins) - deaths = get_deaths(pop, self.config.to_dict(), self.start_time, self.clock(), self.age_bins, self.causes) - ylls = get_years_of_life_lost(pop, self.config.to_dict(), self.start_time, self.clock(), - self.age_bins, self.life_expectancy, self.causes) + pop.loc[pop.exit_time.isnull(), "exit_time"] = self.clock() + + person_time = get_person_time( + pop, self.config.to_dict(), self.start_time, self.clock(), self.age_bins + ) + deaths = get_deaths( + pop, + self.config.to_dict(), + self.start_time, + self.clock(), + self.age_bins, + self.causes, + ) + ylls = get_years_of_life_lost( + pop, + self.config.to_dict(), + self.start_time, + self.clock(), + self.age_bins, + self.life_expectancy, + self.causes, + ) metrics.update(person_time) metrics.update(deaths) metrics.update(ylls) - the_living = pop[(pop.alive == 'alive') & pop.tracked] - the_dead = pop[pop.alive == 'dead'] - metrics['years_of_life_lost'] = self.life_expectancy(the_dead.index).sum() - metrics['total_population_living'] = len(the_living) - metrics['total_population_dead'] = len(the_dead) + the_living = pop[(pop.alive == "alive") & pop.tracked] + the_dead = pop[pop.alive == "dead"] + metrics["years_of_life_lost"] = self.life_expectancy(the_dead.index).sum() + metrics["total_population_living"] = len(the_living) + metrics["total_population_dead"] = len(the_dead) return metrics diff --git a/src/vivarium_public_health/metrics/population.py b/src/vivarium_public_health/metrics/population.py index 86c1e0bf2..4b73a9f28 100644 --- a/src/vivarium_public_health/metrics/population.py +++ b/src/vivarium_public_health/metrics/population.py @@ -11,11 +11,11 @@ import pandas as pd -from .utilities import get_age_bins, get_population_counts +from vivarium_public_health.metrics.utilities import get_age_bins, get_population_counts class PopulationObserver: - """ An observer for population counts. + """An observer for population counts. By default, this counts the population at a particular sample date annually. It can be configured to bin the population into age groups and @@ -40,23 +40,24 @@ class PopulationObserver: day: 10 """ + configuration_defaults = { - 'metrics': { - 'population': { - 'by_age': False, - 'by_year': False, - 'by_sex': False, - 'sample_date': { - 'month': 7, - 'day': 1, - } + "metrics": { + "population": { + "by_age": False, + "by_year": False, + "by_sex": False, + "sample_date": { + "month": 7, + "day": 1, + }, } } } @property def name(self): - return 'population_observer' + return "population_observer" def setup(self, builder): self.config = builder.configuration.metrics.population @@ -64,27 +65,31 @@ def setup(self, builder): self.age_bins = get_age_bins(builder) self.population = Counter() - columns_required = ['tracked', 'alive'] + columns_required = ["tracked", "alive"] if self.config.by_age: - columns_required += ['age'] + columns_required += ["age"] if self.config.by_sex: - columns_required += ['sex'] + columns_required += ["sex"] self.population_view = builder.population.get_view(columns_required) - builder.event.register_listener('time_step__prepare', self.on_time_step_prepare) + builder.event.register_listener("time_step__prepare", self.on_time_step_prepare) - builder.value.register_value_modifier('metrics', self.metrics) + builder.value.register_value_modifier("metrics", self.metrics) def on_time_step_prepare(self, event): pop = self.population_view.get(event.index) if self.should_sample(event.time): - population_counts = get_population_counts(pop, self.config.to_dict(), event.time, self.age_bins) + population_counts = get_population_counts( + pop, self.config.to_dict(), event.time, self.age_bins + ) self.population.update(population_counts) def should_sample(self, event_time: pd.Timestamp) -> bool: """Returns true if we should sample on this time step.""" - sample_date = pd.Timestamp(event_time.year, self.config.sample_date.month, self.config.sample_date.day) + sample_date = pd.Timestamp( + event_time.year, self.config.sample_date.month, self.config.sample_date.day + ) return self.clock() <= sample_date < event_time def metrics(self, index, metrics): diff --git a/src/vivarium_public_health/metrics/risk.py b/src/vivarium_public_health/metrics/risk.py index 7b5a27a3f..3885f54ca 100644 --- a/src/vivarium_public_health/metrics/risk.py +++ b/src/vivarium_public_health/metrics/risk.py @@ -10,15 +10,18 @@ from typing import Dict import pandas as pd - from vivarium.framework.engine import Builder from vivarium.framework.event import Event -from vivarium_public_health.metrics.utilities import get_age_bins, get_prevalent_cases, get_state_person_time +from vivarium_public_health.metrics.utilities import ( + get_age_bins, + get_prevalent_cases, + get_state_person_time, +) class CategoricalRiskObserver: - """ An observer for a categorical risk factor. + """An observer for a categorical risk factor. Observes category person time for a risk factor. @@ -41,18 +44,19 @@ class CategoricalRiskObserver: month: 12 day: 31 """ + configuration_defaults = { - 'metrics': { - 'risk': { - 'by_age': False, - 'by_year': False, - 'by_sex': False, - 'sample_exposure': { - 'sample': False, - 'date': { - 'month': 7, - 'day': 1, - } + "metrics": { + "risk": { + "by_age": False, + "by_year": False, + "by_sex": False, + "sample_exposure": { + "sample": False, + "date": { + "month": 7, + "day": 1, + }, }, } } @@ -68,54 +72,72 @@ def __init__(self, risk: str): """ self.risk = risk self.configuration_defaults = { - 'metrics': { - f'{self.risk}': CategoricalRiskObserver.configuration_defaults['metrics']['risk'] + "metrics": { + f"{self.risk}": CategoricalRiskObserver.configuration_defaults["metrics"][ + "risk" + ] } } @property def name(self): - return f'categorical_risk_observer.{self.risk}' + return f"categorical_risk_observer.{self.risk}" # noinspection PyAttributeOutsideInit def setup(self, builder: Builder): self.data = {} - self.config = builder.configuration[f'metrics'][f'{self.risk}'] + self.config = builder.configuration[f"metrics"][f"{self.risk}"] self.clock = builder.time.clock() - self.categories = builder.data.load(f'risk_factor.{self.risk}.categories') + self.categories = builder.data.load(f"risk_factor.{self.risk}.categories") self.age_bins = get_age_bins(builder) self.person_time = Counter() self.sampled_exposure = Counter() - columns_required = ['alive'] + columns_required = ["alive"] if self.config.by_age: - columns_required += ['age'] + columns_required += ["age"] if self.config.by_sex: - columns_required += ['sex'] + columns_required += ["sex"] self.population_view = builder.population.get_view(columns_required) - self.exposure = builder.value.get_value(f'{self.risk}.exposure') - builder.value.register_value_modifier('metrics', self.metrics) - builder.event.register_listener('time_step__prepare', self.on_time_step_prepare) + self.exposure = builder.value.get_value(f"{self.risk}.exposure") + builder.value.register_value_modifier("metrics", self.metrics) + builder.event.register_listener("time_step__prepare", self.on_time_step_prepare) def on_time_step_prepare(self, event: Event): - pop = pd.concat([self.population_view.get(event.index), pd.Series(self.exposure(event.index), name=self.risk)], - axis=1) + pop = pd.concat( + [ + self.population_view.get(event.index), + pd.Series(self.exposure(event.index), name=self.risk), + ], + axis=1, + ) for category in self.categories: - state_person_time_this_step = get_state_person_time(pop, self.config, self.risk, category, - self.clock().year, event.step_size, self.age_bins) + state_person_time_this_step = get_state_person_time( + pop, + self.config, + self.risk, + category, + self.clock().year, + event.step_size, + self.age_bins, + ) self.person_time.update(state_person_time_this_step) if self._should_sample(event.time): - sampled_exposure = get_prevalent_cases(pop, self.config.to_dict(), self.risk, event.time, self.age_bins) + sampled_exposure = get_prevalent_cases( + pop, self.config.to_dict(), self.risk, event.time, self.age_bins + ) self.sampled_exposure.update(sampled_exposure) def _should_sample(self, event_time: pd.Timestamp) -> bool: """Returns true if we should sample on this time step.""" should_sample = self.config.sample_exposure.sample if should_sample: - sample_date = pd.Timestamp(year=event_time.year, **self.config.sample_prevalence.date.to_dict()) + sample_date = pd.Timestamp( + year=event_time.year, **self.config.sample_prevalence.date.to_dict() + ) should_sample &= self.clock() <= sample_date < event_time return should_sample diff --git a/src/vivarium_public_health/metrics/utilities.py b/src/vivarium_public_health/metrics/utilities.py index e4786ea52..71e9c385f 100644 --- a/src/vivarium_public_health/metrics/utilities.py +++ b/src/vivarium_public_health/metrics/utilities.py @@ -9,7 +9,7 @@ """ from collections import ChainMap from string import Template -from typing import Union, List, Tuple, Dict, Callable +from typing import Callable, Dict, List, Tuple, Union import numpy as np import pandas as pd @@ -18,8 +18,8 @@ from vivarium_public_health.utilities import to_years -_MIN_AGE = 0. -_MAX_AGE = 150. +_MIN_AGE = 0.0 +_MAX_AGE = 150.0 _MIN_YEAR = 1900 _MAX_YEAR = 2100 @@ -46,16 +46,17 @@ class QueryString(str): 'abc and def' """ - def __add__(self, other: Union[str, 'QueryString']) -> 'QueryString': + + def __add__(self, other: Union[str, "QueryString"]) -> "QueryString": if self: if other: - return QueryString(str(self) + ' and ' + str(other)) + return QueryString(str(self) + " and " + str(other)) else: return self else: return QueryString(other) - def __radd__(self, other: Union[str, 'QueryString']) -> 'QueryString': + def __radd__(self, other: Union[str, "QueryString"]) -> "QueryString": return QueryString(other) + self @@ -65,28 +66,31 @@ class SubstituteString(str): Meant to be used with the OutputTemplate. """ - def substitute(self, *_, **__) -> 'SubstituteString': + + def substitute(self, *_, **__) -> "SubstituteString": """No-op method for consistency with OutputTemplate.""" return self class OutputTemplate(Template): """Output string template that enforces standardized formatting.""" + @staticmethod def format_template_value(value): """Formatting helper method for substituting values into a template.""" - return str(value).replace(' ', '_').lower() + return str(value).replace(" ", "_").lower() @staticmethod def get_mapping(*args, **kws): """Gets a consistent mapping from args passed to substitute.""" # This is copied directly from the first part of Template.substitute if not args: - raise TypeError("descriptor 'substitute' of 'Template' object " - "needs an argument") + raise TypeError( + "descriptor 'substitute' of 'Template' object " "needs an argument" + ) self, *args = args # allow the "self" keyword be passed if len(args) > 1: - raise TypeError('Too many positional arguments') + raise TypeError("Too many positional arguments") if not args: mapping = dict(kws) elif kws: @@ -95,7 +99,7 @@ def get_mapping(*args, **kws): mapping = args[0] return self, mapping - def substitute(*args, **kws) -> Union[SubstituteString, 'OutputTemplate']: + def substitute(*args, **kws) -> Union[SubstituteString, "OutputTemplate"]: """Substitutes provided values into the template. Users are allowed to pass any dictionary like object whose keys match @@ -140,18 +144,18 @@ def get_age_bins(builder) -> pd.DataFrame: and ``age_end``. """ - age_bins = builder.data.load('population.age_bins') + age_bins = builder.data.load("population.age_bins") # Works based on the fact that currently only models with age_start = 0 can include fertility age_start = builder.configuration.population.age_start min_bin_start = age_bins.age_start[np.asscalar(np.digitize(age_start, age_bins.age_end))] age_bins = age_bins[age_bins.age_start >= min_bin_start] - age_bins.loc[age_bins.age_start < age_start, 'age_start'] = age_start + age_bins.loc[age_bins.age_start < age_start, "age_start"] = age_start exit_age = builder.configuration.population.exit_age if exit_age: age_bins = age_bins[age_bins.age_start < exit_age] - age_bins.loc[age_bins.age_end > exit_age, 'age_end'] = exit_age + age_bins.loc[age_bins.age_end > exit_age, "age_end"] = exit_age return age_bins @@ -176,20 +180,19 @@ def get_output_template(by_age: bool, by_sex: bool, by_year: bool, **_) -> Outpu A template string with measure and possibly additional criteria. """ - template = '${measure}' + template = "${measure}" if by_year: - template += '_in_${year}' + template += "_in_${year}" if by_sex: - template += '_among_${sex}' + template += "_among_${sex}" if by_age: - template += '_in_age_group_${age_group}' + template += "_in_age_group_${age_group}" return OutputTemplate(template) -def get_age_sex_filter_and_iterables(config: Dict[str, bool], - age_bins: pd.DataFrame, - in_span: bool = False) -> Tuple[QueryString, - Tuple[List[Tuple[str, pd.Series]], List[str]]]: +def get_age_sex_filter_and_iterables( + config: Dict[str, bool], age_bins: pd.DataFrame, in_span: bool = False +) -> Tuple[QueryString, Tuple[List[Tuple[str, pd.Series]], List[str]]]: """Constructs a filter and a set of iterables for age and sex. The constructed filter and iterables are based on configuration for the @@ -218,27 +221,29 @@ def get_age_sex_filter_and_iterables(config: Dict[str, bool], """ age_sex_filter = QueryString("") - if config['by_age']: - ages = list(age_bins.set_index('age_group_name').iterrows()) + if config["by_age"]: + ages = list(age_bins.set_index("age_group_name").iterrows()) if in_span: - age_sex_filter += '{age_start} < age_at_span_end and age_at_span_start < {age_end}' + age_sex_filter += ( + "{age_start} < age_at_span_end and age_at_span_start < {age_end}" + ) else: - age_sex_filter += '{age_start} <= age and age < {age_end}' + age_sex_filter += "{age_start} <= age and age < {age_end}" else: - ages = [('all_ages', pd.Series({'age_start': _MIN_AGE, 'age_end': _MAX_AGE}))] + ages = [("all_ages", pd.Series({"age_start": _MIN_AGE, "age_end": _MAX_AGE}))] - if config['by_sex']: - sexes = ['Male', 'Female'] + if config["by_sex"]: + sexes = ["Male", "Female"] age_sex_filter += 'sex == "{sex}"' else: - sexes = ['Both'] + sexes = ["Both"] return age_sex_filter, (ages, sexes) -def get_time_iterable(config: Dict[str, bool], - sim_start: pd.Timestamp, - sim_end: pd.Timestamp) -> List[Tuple[str, Tuple[pd.Timestamp, pd.Timestamp]]]: +def get_time_iterable( + config: Dict[str, bool], sim_start: pd.Timestamp, sim_end: pd.Timestamp +) -> List[Tuple[str, Tuple[pd.Timestamp, pd.Timestamp]]]: """Constructs an iterable for time bins. The constructed iterable are based on configuration for the observer @@ -261,21 +266,29 @@ def get_time_iterable(config: Dict[str, bool], for the observers. """ - if config['by_year']: - time_spans = [(year, (pd.Timestamp(f'1-1-{year}'), pd.Timestamp(f'1-1-{year + 1}'))) - for year in range(sim_start.year, sim_end.year + 1)] + if config["by_year"]: + time_spans = [ + (year, (pd.Timestamp(f"1-1-{year}"), pd.Timestamp(f"1-1-{year + 1}"))) + for year in range(sim_start.year, sim_end.year + 1) + ] else: - time_spans = [('all_years', (pd.Timestamp(f'1-1-{_MIN_YEAR}'), pd.Timestamp(f'1-1-{_MAX_YEAR}')))] + time_spans = [ + ( + "all_years", + (pd.Timestamp(f"1-1-{_MIN_YEAR}"), pd.Timestamp(f"1-1-{_MAX_YEAR}")), + ) + ] return time_spans -def get_group_counts(pop: pd.DataFrame, - base_filter: str, - base_key: OutputTemplate, - config: Dict[str, bool], - age_bins: pd.DataFrame, - aggregate: Callable[[pd.DataFrame], Union[int, float]] = len - ) -> Dict[Union[SubstituteString, OutputTemplate], Union[int, float]]: +def get_group_counts( + pop: pd.DataFrame, + base_filter: str, + base_key: OutputTemplate, + config: Dict[str, bool], + age_bins: pd.DataFrame, + aggregate: Callable[[pd.DataFrame], Union[int, float]] = len, +) -> Dict[Union[SubstituteString, OutputTemplate], Union[int, float]]: """Gets a count of people in a custom subgroup. The user is responsible for providing a default filter (e.g. only alive @@ -316,7 +329,12 @@ def get_group_counts(pop: pd.DataFrame, for group, age_group in ages: start, end = age_group.age_start, age_group.age_end for sex in sexes: - filter_kwargs = {'age_start': start, 'age_end': end, 'sex': sex, 'age_group': group} + filter_kwargs = { + "age_start": start, + "age_end": end, + "sex": sex, + "age_group": group, + } group_key = base_key.substitute(**filter_kwargs) group_filter = base_filter.format(**filter_kwargs) in_group = pop.query(group_filter) if group_filter and not pop.empty else pop @@ -327,40 +345,59 @@ def get_group_counts(pop: pd.DataFrame, def get_susceptible_person_time(pop, config, disease, current_year, step_size, age_bins): - base_key = get_output_template(**config).substitute(measure=f'{disease}_susceptible_person_time', year=current_year) + base_key = get_output_template(**config).substitute( + measure=f"{disease}_susceptible_person_time", year=current_year + ) base_filter = QueryString(f'alive == "alive" and {disease} == "susceptible_to_{disease}"') - person_time = get_group_counts(pop, base_filter, base_key, config, age_bins, - aggregate=lambda x: len(x) * to_years(step_size)) + person_time = get_group_counts( + pop, + base_filter, + base_key, + config, + age_bins, + aggregate=lambda x: len(x) * to_years(step_size), + ) return person_time def get_disease_event_counts(pop, config, disease, event_time, age_bins): - base_key = get_output_template(**config).substitute(measure=f'{disease}_counts', year=event_time.year) + base_key = get_output_template(**config).substitute( + measure=f"{disease}_counts", year=event_time.year + ) # Can't use query with time stamps, so filter - pop = pop.loc[pop[f'{disease}_event_time'] == event_time] - base_filter = QueryString('') + pop = pop.loc[pop[f"{disease}_event_time"] == event_time] + base_filter = QueryString("") return get_group_counts(pop, base_filter, base_key, config, age_bins) def get_prevalent_cases(pop, config, disease, event_time, age_bins): config = config.copy() - config['by_year'] = True # This is always an annual point estimate - base_key = get_output_template(**config).substitute(measure=f'{disease}_prevalent_cases', year=event_time.year) + config["by_year"] = True # This is always an annual point estimate + base_key = get_output_template(**config).substitute( + measure=f"{disease}_prevalent_cases", year=event_time.year + ) base_filter = QueryString(f'alive == "alive" and {disease} != "susceptible_to_{disease}"') return get_group_counts(pop, base_filter, base_key, config, age_bins) def get_population_counts(pop, config, event_time, age_bins): config = config.copy() - config['by_year'] = True # This is always an annual point estimate - base_key = get_output_template(**config).substitute(measure=f'population_count', year=event_time.year) + config["by_year"] = True # This is always an annual point estimate + base_key = get_output_template(**config).substitute( + measure=f"population_count", year=event_time.year + ) base_filter = QueryString(f'alive == "alive"') return get_group_counts(pop, base_filter, base_key, config, age_bins) -def get_person_time(pop: pd.DataFrame, config: Dict[str, bool], sim_start: pd.Timestamp, - sim_end: pd.Timestamp, age_bins: pd.DataFrame) -> Dict[str, float]: - base_key = get_output_template(**config).substitute(measure='person_time') +def get_person_time( + pop: pd.DataFrame, + config: Dict[str, bool], + sim_start: pd.Timestamp, + sim_end: pd.Timestamp, + age_bins: pd.DataFrame, +) -> Dict[str, float]: + base_key = get_output_template(**config).substitute(measure="person_time") base_filter = QueryString("") time_spans = get_time_iterable(config, sim_start, sim_end) @@ -368,12 +405,16 @@ def get_person_time(pop: pd.DataFrame, config: Dict[str, bool], sim_start: pd.Ti for year, (t_start, t_end) in time_spans: year_key = base_key.substitute(year=year) lived_in_span = get_lived_in_span(pop, t_start, t_end) - person_time_in_span = get_person_time_in_span(lived_in_span, base_filter, year_key, config, age_bins) + person_time_in_span = get_person_time_in_span( + lived_in_span, base_filter, year_key, config, age_bins + ) person_time.update(person_time_in_span) return person_time -def get_lived_in_span(pop: pd.DataFrame, t_start: pd.Timestamp, t_end: pd.Timestamp) -> pd.DataFrame: +def get_lived_in_span( + pop: pd.DataFrame, t_start: pd.Timestamp, t_end: pd.Timestamp +) -> pd.DataFrame: """Gets a subset of the population that lived in the time span. Parameters @@ -398,23 +439,29 @@ def get_lived_in_span(pop: pd.DataFrame, t_start: pd.Timestamp, t_end: pd.Timest be greater than the age at the simulant's exit time. """ - lived_in_span = pop.loc[(t_start < pop['exit_time']) & (pop['entrance_time'] < t_end)] + lived_in_span = pop.loc[(t_start < pop["exit_time"]) & (pop["entrance_time"] < t_end)] span_entrance_time = lived_in_span.entrance_time.copy() span_entrance_time.loc[t_start > span_entrance_time] = t_start span_exit_time = lived_in_span.exit_time.copy() span_exit_time.loc[t_end < span_exit_time] = t_end - lived_in_span.loc[:, 'age_at_span_end'] = lived_in_span.age - to_years(lived_in_span.exit_time - - span_exit_time) - lived_in_span.loc[:, 'age_at_span_start'] = lived_in_span.age - to_years(lived_in_span.exit_time - - span_entrance_time) + lived_in_span.loc[:, "age_at_span_end"] = lived_in_span.age - to_years( + lived_in_span.exit_time - span_exit_time + ) + lived_in_span.loc[:, "age_at_span_start"] = lived_in_span.age - to_years( + lived_in_span.exit_time - span_entrance_time + ) return lived_in_span -def get_person_time_in_span(lived_in_span: pd.DataFrame, base_filter: QueryString, - span_key: OutputTemplate, config: Dict[str, bool], - age_bins: pd.DataFrame) -> Dict[Union[SubstituteString, OutputTemplate], float]: +def get_person_time_in_span( + lived_in_span: pd.DataFrame, + base_filter: QueryString, + span_key: OutputTemplate, + config: Dict[str, bool], + age_bins: pd.DataFrame, +) -> Dict[Union[SubstituteString, OutputTemplate], float]: """Counts the amount of person time lived in a particular time span. Parameters @@ -442,19 +489,27 @@ def get_person_time_in_span(lived_in_span: pd.DataFrame, base_filter: QueryStrin corresponds to a particular demographic group. """ person_time = {} - age_sex_filter, (ages, sexes) = get_age_sex_filter_and_iterables(config, age_bins, in_span=True) + age_sex_filter, (ages, sexes) = get_age_sex_filter_and_iterables( + config, age_bins, in_span=True + ) base_filter += age_sex_filter for group, age_bin in ages: a_start, a_end = age_bin.age_start, age_bin.age_end for sex in sexes: - filter_kwargs = {'sex': sex, 'age_start': a_start, - 'age_end': a_end, 'age_group': group} + filter_kwargs = { + "sex": sex, + "age_start": a_start, + "age_end": a_end, + "age_group": group, + } key = span_key.substitute(**filter_kwargs) group_filter = base_filter.format(**filter_kwargs) - in_group = lived_in_span.query(group_filter) if group_filter else lived_in_span.copy() + in_group = ( + lived_in_span.query(group_filter) if group_filter else lived_in_span.copy() + ) age_start = np.maximum(in_group.age_at_span_start, a_start) age_end = np.minimum(in_group.age_at_span_end, a_end) @@ -463,8 +518,14 @@ def get_person_time_in_span(lived_in_span: pd.DataFrame, base_filter: QueryStrin return person_time -def get_deaths(pop: pd.DataFrame, config: Dict[str, bool], sim_start: pd.Timestamp, - sim_end: pd.Timestamp, age_bins: pd.DataFrame, causes: List[str]) -> Dict[str, int]: +def get_deaths( + pop: pd.DataFrame, + config: Dict[str, bool], + sim_start: pd.Timestamp, + sim_end: pd.Timestamp, + age_bins: pd.DataFrame, + causes: List[str], +) -> Dict[str, int]: """Counts the number of deaths by cause. Parameters @@ -502,20 +563,24 @@ def get_deaths(pop: pd.DataFrame, config: Dict[str, bool], sim_start: pd.Timesta for year, (t_start, t_end) in time_spans: died_in_span = pop[(t_start <= pop.exit_time) & (pop.exit_time < t_end)] for cause in causes: - cause_year_key = base_key.substitute(measure=f'death_due_to_{cause}', year=year) + cause_year_key = base_key.substitute(measure=f"death_due_to_{cause}", year=year) cause_filter = base_filter.format(cause=cause) - group_deaths = get_group_counts(died_in_span, cause_filter, cause_year_key, config, age_bins) + group_deaths = get_group_counts( + died_in_span, cause_filter, cause_year_key, config, age_bins + ) deaths.update(group_deaths) return deaths -def get_years_of_life_lost(pop: pd.DataFrame, - config: Dict[str, bool], - sim_start: pd.Timestamp, - sim_end: pd.Timestamp, - age_bins: pd.DataFrame, - life_expectancy: LookupTable, - causes: List[str]) -> Dict[str, float]: +def get_years_of_life_lost( + pop: pd.DataFrame, + config: Dict[str, bool], + sim_start: pd.Timestamp, + sim_end: pd.Timestamp, + age_bins: pd.DataFrame, + life_expectancy: LookupTable, + causes: List[str], +) -> Dict[str, float]: """Counts the years of life lost by cause. Parameters @@ -556,21 +621,29 @@ def get_years_of_life_lost(pop: pd.DataFrame, for year, (t_start, t_end) in time_spans: died_in_span = pop[(t_start <= pop.exit_time) & (pop.exit_time < t_end)] for cause in causes: - cause_year_key = base_key.substitute(measure=f'ylls_due_to_{cause}', year=year) + cause_year_key = base_key.substitute(measure=f"ylls_due_to_{cause}", year=year) cause_filter = base_filter.format(cause=cause) - group_ylls = get_group_counts(died_in_span, cause_filter, cause_year_key, config, age_bins, - aggregate=lambda subgroup: sum(life_expectancy(subgroup.index))) + group_ylls = get_group_counts( + died_in_span, + cause_filter, + cause_year_key, + config, + age_bins, + aggregate=lambda subgroup: sum(life_expectancy(subgroup.index)), + ) years_of_life_lost.update(group_ylls) return years_of_life_lost -def get_years_lived_with_disability(pop: pd.DataFrame, - config: Dict[str, bool], - current_year: int, - step_size: pd.Timedelta, - age_bins: pd.DataFrame, - disability_weights: Dict[str, Pipeline], - causes: List[str]) -> Dict[str, float]: +def get_years_lived_with_disability( + pop: pd.DataFrame, + config: Dict[str, bool], + current_year: int, + step_size: pd.Timedelta, + age_bins: pd.DataFrame, + disability_weights: Dict[str, Pipeline], + causes: List[str], +) -> Dict[str, float]: """Counts the years lived with disability by cause in the time step. Parameters @@ -605,13 +678,15 @@ def get_years_lived_with_disability(pop: pd.DataFrame, years_lived_with_disability = {} for cause in causes: - cause_key = base_key.substitute(measure=f'ylds_due_to_{cause}') + cause_key = base_key.substitute(measure=f"ylds_due_to_{cause}") def count_ylds(sub_group): """Counts ylds attributable to a cause in the time step.""" return sum(disability_weights[cause](sub_group.index) * to_years(step_size)) - group_ylds = get_group_counts(pop, base_filter, cause_key, config, age_bins, aggregate=count_ylds) + group_ylds = get_group_counts( + pop, base_filter, cause_key, config, age_bins, aggregate=count_ylds + ) years_lived_with_disability.update(group_ylds) return years_lived_with_disability @@ -621,46 +696,67 @@ def clean_cause_of_death(pop: pd.DataFrame) -> pd.DataFrame: """Standardizes cause of death names to all read ``death_due_to_cause``.""" def _clean(cod: str) -> str: - if 'death' in cod or 'dead' in cod: + if "death" in cod or "dead" in cod: pass else: - cod = f'death_due_to_{cod}' + cod = f"death_due_to_{cod}" return cod pop.cause_of_death = pop.cause_of_death.apply(_clean) return pop -def get_state_person_time(pop: pd.DataFrame, config: Dict[str, bool], - state_machine: str, state: str, current_year: Union[str, int], - step_size: pd.Timedelta, age_bins: pd.DataFrame) -> Dict[str, float]: +def get_state_person_time( + pop: pd.DataFrame, + config: Dict[str, bool], + state_machine: str, + state: str, + current_year: Union[str, int], + step_size: pd.Timedelta, + age_bins: pd.DataFrame, +) -> Dict[str, float]: """Custom person time getter that handles state column name assumptions""" - base_key = get_output_template(**config).substitute(measure=f'{state}_person_time', - year=current_year) + base_key = get_output_template(**config).substitute( + measure=f"{state}_person_time", year=current_year + ) base_filter = QueryString(f'alive == "alive" and {state_machine} == "{state}"') - person_time = get_group_counts(pop, base_filter, base_key, config, age_bins, - aggregate=lambda x: len(x) * to_years(step_size)) + person_time = get_group_counts( + pop, + base_filter, + base_key, + config, + age_bins, + aggregate=lambda x: len(x) * to_years(step_size), + ) return person_time class TransitionString(str): - def __new__(cls, value): # noinspection PyArgumentList obj = str.__new__(cls, value.lower()) - obj.from_state, obj.to_state = value.split('_TO_') + obj.from_state, obj.to_state = value.split("_TO_") return obj -def get_transition_count(pop: pd.DataFrame, config: Dict[str, bool], - state_machine: str, transition: TransitionString, - event_time: pd.Timestamp, age_bins: pd.DataFrame) -> Dict[str, float]: +def get_transition_count( + pop: pd.DataFrame, + config: Dict[str, bool], + state_machine: str, + transition: TransitionString, + event_time: pd.Timestamp, + age_bins: pd.DataFrame, +) -> Dict[str, float]: """Counts transitions that occurred this step.""" - event_this_step = ((pop[f'previous_{state_machine}'] == transition.from_state) - & (pop[state_machine] == transition.to_state)) + event_this_step = (pop[f"previous_{state_machine}"] == transition.from_state) & ( + pop[state_machine] == transition.to_state + ) transitioned_pop = pop.loc[event_this_step] - base_key = get_output_template(**config).substitute(measure=f'{transition}_event_count', - year=event_time.year) - base_filter = QueryString('') - transition_count = get_group_counts(transitioned_pop, base_filter, base_key, config, age_bins) + base_key = get_output_template(**config).substitute( + measure=f"{transition}_event_count", year=event_time.year + ) + base_filter = QueryString("") + transition_count = get_group_counts( + transitioned_pop, base_filter, base_key, config, age_bins + ) return transition_count diff --git a/src/vivarium_public_health/mslt/delay.py b/src/vivarium_public_health/mslt/delay.py index 2da273f63..2bcc9cad7 100644 --- a/src/vivarium_public_health/mslt/delay.py +++ b/src/vivarium_public_health/mslt/delay.py @@ -7,8 +7,8 @@ lifetable simulation. """ -import pandas as pd import numpy as np +import pandas as pd class DelayedRisk: @@ -92,12 +92,12 @@ def __init__(self, name: str): self._name = name self.configuration_defaults = { name: { - 'constant_prevalence': False, - 'tobacco_tax': False, - 'delay': 20, + "constant_prevalence": False, + "tobacco_tax": False, + "delay": 20, }, } - + @property def name(self): return self._name @@ -116,50 +116,54 @@ def setup(self, builder): # Determine whether smoking prevalence should change over time. # The alternative scenario is that there is no remission; all people # who begin smoking will continue to smoke. - self.constant_prevalence = self.config[self.name]['constant_prevalence'] + self.constant_prevalence = self.config[self.name]["constant_prevalence"] - self.tobacco_tax = self.config[self.name]['tobacco_tax'] + self.tobacco_tax = self.config[self.name]["tobacco_tax"] - self.bin_years = int(self.config[self.name]['delay']) + self.bin_years = int(self.config[self.name]["delay"]) # Load the initial prevalence. - prev_data = pivot_load(builder,f'risk_factor.{self.name}.prevalence') - self.initial_prevalence = builder.lookup.build_table(prev_data, - key_columns=['sex'], - parameter_columns=['age','year']) + prev_data = pivot_load(builder, f"risk_factor.{self.name}.prevalence") + self.initial_prevalence = builder.lookup.build_table( + prev_data, key_columns=["sex"], parameter_columns=["age", "year"] + ) # Load the incidence rates for the BAU and intervention scenarios. inc_data = builder.lookup.build_table( - pivot_load(builder,f'risk_factor.{self.name}.incidence'), - key_columns=['sex'], - parameter_columns=['age','year'] + pivot_load(builder, f"risk_factor.{self.name}.incidence"), + key_columns=["sex"], + parameter_columns=["age", "year"], ) - inc_name = '{}.incidence'.format(self.name) - inc_int_name = '{}_intervention.incidence'.format(self.name) + inc_name = "{}.incidence".format(self.name) + inc_int_name = "{}_intervention.incidence".format(self.name) self.incidence = builder.value.register_rate_producer(inc_name, source=inc_data) - self.int_incidence = builder.value.register_rate_producer(inc_int_name, source=inc_data) + self.int_incidence = builder.value.register_rate_producer( + inc_int_name, source=inc_data + ) # Load the remission rates for the BAU and intervention scenarios. - rem_df = pivot_load(builder,f'risk_factor.{self.name}.remission') + rem_df = pivot_load(builder, f"risk_factor.{self.name}.remission") # In the constant-prevalence case, assume there is no remission. if self.constant_prevalence: - rem_df['value'] = 0.0 - rem_data = builder.lookup.build_table(rem_df, - key_columns=['sex'], - parameter_columns=['age','year']) - rem_name = '{}.remission'.format(self.name) - rem_int_name = '{}_intervention.remission'.format(self.name) + rem_df["value"] = 0.0 + rem_data = builder.lookup.build_table( + rem_df, key_columns=["sex"], parameter_columns=["age", "year"] + ) + rem_name = "{}.remission".format(self.name) + rem_int_name = "{}_intervention.remission".format(self.name) self.remission = builder.value.register_rate_producer(rem_name, source=rem_data) - self.int_remission = builder.value.register_rate_producer(rem_int_name, source=rem_data) + self.int_remission = builder.value.register_rate_producer( + rem_int_name, source=rem_data + ) # We apply separate mortality rates to the different exposure bins. # This requires having access to the life table mortality rate, and # also the relative risks associated with each bin. - self.acm_rate = builder.value.get_value('mortality_rate') - mort_rr_data = pivot_load(builder,f'risk_factor.{self.name}.mortality_relative_risk') - self.mortality_rr = builder.lookup.build_table(mort_rr_data, - key_columns=['sex'], - parameter_columns=['age','year']) + self.acm_rate = builder.value.get_value("mortality_rate") + mort_rr_data = pivot_load(builder, f"risk_factor.{self.name}.mortality_relative_risk") + self.mortality_rr = builder.lookup.build_table( + mort_rr_data, key_columns=["sex"], parameter_columns=["age", "year"] + ) # Register a modifier for each disease affected by this delayed risk. diseases = self.config[self.name].affects.keys() @@ -167,74 +171,75 @@ def setup(self, builder): self.register_modifier(builder, disease) # Load the disease-specific relative risks for each exposure bin. - dis_rr_data = pivot_load(builder,f'risk_factor.{self.name}.disease_relative_risk') + dis_rr_data = pivot_load(builder, f"risk_factor.{self.name}.disease_relative_risk") # Check that the relative risk table includes required columns. - key_columns = ['age_start', 'age_end', 'sex', - 'year_start', 'year_end'] + key_columns = ["age_start", "age_end", "sex", "year_start", "year_end"] if set(key_columns) & set(dis_rr_data.columns) != set(key_columns): # Fallback option, handle tables that do not define bin edges. - key_columns = ['age', 'sex', 'year'] + key_columns = ["age", "sex", "year"] if set(key_columns) & set(dis_rr_data.columns) != set(key_columns): - msg = 'Missing index columns for disease-specific relative risks' + msg = "Missing index columns for disease-specific relative risks" raise ValueError(msg) self.dis_rr = {} for disease in diseases: - dis_columns = [c for c in dis_rr_data.columns - if c.startswith(disease)] - dis_keys = [c for c in dis_rr_data.columns - if c in key_columns] + dis_columns = [c for c in dis_rr_data.columns if c.startswith(disease)] + dis_keys = [c for c in dis_rr_data.columns if c in key_columns] if not dis_columns or not dis_keys: - msg = 'No {} relative risks for disease {}' + msg = "No {} relative risks for disease {}" raise ValueError(msg.format(self.name, disease)) rr_data = dis_rr_data.loc[:, dis_keys + dis_columns] - dis_prefix = '{}_'.format(disease) - bau_prefix = '{}.'.format(self.name) - int_prefix = '{}_intervention.'.format(self.name) - bau_col = {c: c.replace(dis_prefix, bau_prefix).replace('post_', '') - for c in dis_columns} - int_col = {c: c.replace(dis_prefix, int_prefix).replace('post_', '') - for c in dis_columns} + dis_prefix = "{}_".format(disease) + bau_prefix = "{}.".format(self.name) + int_prefix = "{}_intervention.".format(self.name) + bau_col = { + c: c.replace(dis_prefix, bau_prefix).replace("post_", "") for c in dis_columns + } + int_col = { + c: c.replace(dis_prefix, int_prefix).replace("post_", "") for c in dis_columns + } for column in dis_columns: # NOTE: avoid SettingWithCopyWarning rr_data.loc[:, int_col[column]] = rr_data[column] rr_data = rr_data.rename(columns=bau_col) - self.dis_rr[disease] = builder.lookup.build_table(rr_data, - key_columns=['sex'], - parameter_columns=['age','year']) + self.dis_rr[disease] = builder.lookup.build_table( + rr_data, key_columns=["sex"], parameter_columns=["age", "year"] + ) # Add a handler to create the exposure bin columns. - req_columns = ['age', 'sex', 'population'] + req_columns = ["age", "sex", "population"] new_columns = self.get_bin_names() builder.population.initializes_simulants( self.on_initialize_simulants, creates_columns=new_columns, - requires_columns=req_columns) + requires_columns=req_columns, + ) # Load the effects of a tobacco tax. - tax_inc = pivot_load(builder,f'risk_factor.{self.name}.tax_effect_incidence') - tax_rem = pivot_load(builder,f'risk_factor.{self.name}.tax_effect_remission') - self.tax_effect_inc = builder.lookup.build_table(tax_inc, - key_columns=['sex'], - parameter_columns=['age','year']) - self.tax_effect_rem = builder.lookup.build_table(tax_rem, - key_columns=['sex'], - parameter_columns=['age','year']) + tax_inc = pivot_load(builder, f"risk_factor.{self.name}.tax_effect_incidence") + tax_rem = pivot_load(builder, f"risk_factor.{self.name}.tax_effect_remission") + self.tax_effect_inc = builder.lookup.build_table( + tax_inc, key_columns=["sex"], parameter_columns=["age", "year"] + ) + self.tax_effect_rem = builder.lookup.build_table( + tax_rem, key_columns=["sex"], parameter_columns=["age", "year"] + ) # Add a handler to move people from one bin to the next. - builder.event.register_listener('time_step__prepare', - self.on_time_step_prepare) + builder.event.register_listener("time_step__prepare", self.on_time_step_prepare) # Define the columns that we need to access during the simulation. view_columns = req_columns + new_columns self.population_view = builder.population.get_view(view_columns) - mortality_data = pivot_load(builder,'cause.all_causes.mortality') + mortality_data = pivot_load(builder, "cause.all_causes.mortality") self.tobacco_acmr = builder.value.register_rate_producer( - 'tobacco_acmr', source=builder.lookup.build_table(mortality_data, - key_columns=['sex'], - parameter_columns=['age','year'])) + "tobacco_acmr", + source=builder.lookup.build_table( + mortality_data, key_columns=["sex"], parameter_columns=["age", "year"] + ), + ) def get_bin_names(self): """Return the bin names for both the BAU and the intervention scenario. @@ -257,9 +262,9 @@ def get_bin_names(self): delay_bins = [str(0)] else: delay_bins = [str(s) for s in range(self.bin_years + 2)] - bins = ['no', 'yes'] + delay_bins - bau_bins = ['{}.{}'.format(self.name, bin) for bin in bins] - int_bins = ['{}_intervention.{}'.format(self.name, bin) for bin in bins] + bins = ["no", "yes"] + delay_bins + bau_bins = ["{}.{}".format(self.name, bin) for bin in bins] + int_bins = ["{}_intervention.{}".format(self.name, bin) for bin in bins] all_bins = bau_bins + int_bins return all_bins @@ -283,18 +288,21 @@ def on_initialize_simulants(self, pop_data): # NOTE: the number of current smokers is defined at the middle of each # year; i.e., it corresponds to the person_years. bau_acmr = self.tobacco_acmr.source(pop_data.index) - bau_probability_of_death = 1 - np.exp(- bau_acmr) + bau_probability_of_death = 1 - np.exp(-bau_acmr) pop.population *= 1 - 0.5 * bau_probability_of_death - prev = self.initial_prevalence(pop_data.index).mul(pop['population'], axis=0) + prev = self.initial_prevalence(pop_data.index).mul(pop["population"], axis=0) self.population_view.update(prev) # Rename the columns and apply the same initial prevalence for the # intervention. - bau_prefix = '{}.'.format(self.name) - int_prefix = '{}_intervention.'.format(self.name) - rename_to = {c: c.replace(bau_prefix, int_prefix) - for c in prev.columns if c.startswith(bau_prefix)} + bau_prefix = "{}.".format(self.name) + int_prefix = "{}_intervention.".format(self.name) + rename_to = { + c: c.replace(bau_prefix, int_prefix) + for c in prev.columns + if c.startswith(bau_prefix) + } int_prev = prev.rename(columns=rename_to) self.population_view.update(int_prev) @@ -322,8 +330,8 @@ def on_time_step_prepare(self, event): # Identify the relevant columns for the BAU and intervention. bin_cols = self.get_bin_names() - bau_prefix = '{}.'.format(self.name) - int_prefix = '{}_intervention.'.format(self.name) + bau_prefix = "{}.".format(self.name) + int_prefix = "{}_intervention.".format(self.name) bau_cols = [c for c in bin_cols if c.startswith(bau_prefix)] int_cols = [c for c in bin_cols if c.startswith(int_prefix)] @@ -359,12 +367,11 @@ def on_time_step_prepare(self, event): # NOTE: adjust the RR *after* calculating the ACMR adjustments, but # *before* calculating the survival probability for each exposure # level. - penultimate_cols = [s + str(self.bin_years) - for s in [bau_prefix, int_prefix]] + penultimate_cols = [s + str(self.bin_years) for s in [bau_prefix, int_prefix]] mort_rr.loc[:, penultimate_cols] = 1.0 # Calculate the mortality risk for non-smokers. - bau_surv_no = 1 - np.exp(- bau_acmr_no) + bau_surv_no = 1 - np.exp(-bau_acmr_no) # Calculate the survival probability for each exposure level: # (1 - mort_risk_non_smokers)^RR bau_surv_rate = mort_rr.loc[:, bau_cols].rpow(1 - bau_surv_no, axis=0) @@ -375,34 +382,33 @@ def on_time_step_prepare(self, event): # (intervention). # NOTE: we apply the same survival rate to each exposure level for # the intervention scenario as we used for the BAU scenario. - rename_to = {c: c.replace('.', '_intervention.') - for c in bau_surv_rate.columns} + rename_to = {c: c.replace(".", "_intervention.") for c in bau_surv_rate.columns} int_surv_rate = bau_surv_rate.rename(columns=rename_to) pop.loc[:, int_cols] = pop.loc[:, int_cols].mul(int_surv_rate) # Account for transitions between bins. # Note that the order of evaluation matters. - suffixes = ['', '_intervention'] + suffixes = ["", "_intervention"] # First, accumulate the final post-exposure bin. if self.bin_years > 0: for suffix in suffixes: - accum_col = '{}{}.{}'.format(self.name, suffix, self.bin_years + 1) - from_col = '{}{}.{}'.format(self.name, suffix, self.bin_years) + accum_col = "{}{}.{}".format(self.name, suffix, self.bin_years + 1) + from_col = "{}{}.{}".format(self.name, suffix, self.bin_years) pop[accum_col] += pop[from_col] # Then increase time since exposure for all other post-exposure bins. for n_years in reversed(range(self.bin_years)): for suffix in suffixes: - source_col = '{}{}.{}'.format(self.name, suffix, n_years) - dest_col = '{}{}.{}'.format(self.name, suffix, n_years + 1) + source_col = "{}{}.{}".format(self.name, suffix, n_years) + dest_col = "{}{}.{}".format(self.name, suffix, n_years + 1) pop[dest_col] = pop[source_col] # Account for incidence and remission. - col_no = '{}.no'.format(self.name) - col_int_no = '{}_intervention.no'.format(self.name) - col_yes = '{}.yes'.format(self.name) - col_int_yes = '{}_intervention.yes'.format(self.name) - col_zero = '{}.0'.format(self.name) - col_int_zero = '{}_intervention.0'.format(self.name) + col_no = "{}.no".format(self.name) + col_int_no = "{}_intervention.no".format(self.name) + col_yes = "{}.yes".format(self.name) + col_int_yes = "{}_intervention.yes".format(self.name) + col_zero = "{}.0".format(self.name) + col_int_zero = "{}_intervention.0".format(self.name) inc = inc_rate * pop[col_no] int_inc = int_inc_rate * pop[col_int_no] @@ -446,9 +452,11 @@ def register_modifier(self, builder, disease): """ # NOTE: we need to modify different rates for chronic and acute # diseases. For now, register modifiers for all possible rates. - rate_templates = ['{}_intervention.incidence', - '{}_intervention.excess_mortality', - '{}_intervention.yld_rate'] + rate_templates = [ + "{}_intervention.incidence", + "{}_intervention.excess_mortality", + "{}_intervention.yld_rate", + ] for template in rate_templates: rate_name = template.format(disease) modifier = lambda ix, rate: self.incidence_adjustment(disease, ix, rate) @@ -474,7 +482,7 @@ def incidence_adjustment(self, disease, index, incidence_rate): rr_values = pop[bin_cols] * incidence_rr # Calculate the mean relative-risk for the BAU scenario. - bau_prefix = '{}.'.format(self.name) + bau_prefix = "{}.".format(self.name) bau_cols = [c for c in bin_cols if c.startswith(bau_prefix)] # Sum over all of the bins in each row. mean_bau_rr = rr_values[bau_cols].sum(axis=1) / pop[bau_cols].sum(axis=1) @@ -482,7 +490,7 @@ def incidence_adjustment(self, disease, index, incidence_rate): mean_bau_rr = mean_bau_rr.fillna(1.0) # Calculate the mean relative-risk for the intervention scenario. - int_prefix = '{}_intervention.'.format(self.name) + int_prefix = "{}_intervention.".format(self.name) int_cols = [c for c in bin_cols if c.startswith(int_prefix)] # Sum over all of the bins in each row. mean_int_rr = rr_values[int_cols].sum(axis=1) / pop[int_cols].sum(axis=1) @@ -494,6 +502,7 @@ def incidence_adjustment(self, disease, index, incidence_rate): pif = pif.fillna(0.0) return incidence_rate * (1 - pif) + def pivot_load(builder, entity_key): """Helper method for loading dataframe from artifact. @@ -503,8 +512,15 @@ def pivot_load(builder, entity_key): """ data = builder.data.load(entity_key) - if 'measure' in data.columns : - data = data.pivot_table(index = [i for i in data.columns if i not in ['measure','value']], columns = 'measure', \ - values = 'value').rename_axis(None,axis = 1).reset_index() - - return data + if "measure" in data.columns: + data = ( + data.pivot_table( + index=[i for i in data.columns if i not in ["measure", "value"]], + columns="measure", + values="value", + ) + .rename_axis(None, axis=1) + .reset_index() + ) + + return data diff --git a/src/vivarium_public_health/mslt/disease.py b/src/vivarium_public_health/mslt/disease.py index 46710dccf..2b1ee708f 100644 --- a/src/vivarium_public_health/mslt/disease.py +++ b/src/vivarium_public_health/mslt/disease.py @@ -34,35 +34,35 @@ class AcuteDisease: def __init__(self, name): self._name = name - + @property def name(self): return self._name def setup(self, builder): """Load the morbidity and mortality data.""" - mty_data = builder.data.load(f'acute_disease.{self.name}.mortality') - mty_rate = builder.lookup.build_table(mty_data, - key_columns=['sex'], - parameter_columns=['age','year']) - yld_data = builder.data.load(f'acute_disease.{self.name}.morbidity') - yld_rate = builder.lookup.build_table(yld_data, - key_columns=['sex'], - parameter_columns=['age','year']) + mty_data = builder.data.load(f"acute_disease.{self.name}.mortality") + mty_rate = builder.lookup.build_table( + mty_data, key_columns=["sex"], parameter_columns=["age", "year"] + ) + yld_data = builder.data.load(f"acute_disease.{self.name}.morbidity") + yld_rate = builder.lookup.build_table( + yld_data, key_columns=["sex"], parameter_columns=["age", "year"] + ) self.excess_mortality = builder.value.register_rate_producer( - f'{self.name}.excess_mortality', - source=mty_rate) + f"{self.name}.excess_mortality", source=mty_rate + ) self.int_excess_mortality = builder.value.register_rate_producer( - f'{self.name}_intervention.excess_mortality', - source=mty_rate) + f"{self.name}_intervention.excess_mortality", source=mty_rate + ) self.disability_rate = builder.value.register_rate_producer( - f'{self.name}.yld_rate', - source=yld_rate) + f"{self.name}.yld_rate", source=yld_rate + ) self.int_disability_rate = builder.value.register_rate_producer( - f'{self.name}_intervention.yld_rate', - source=yld_rate) - builder.value.register_value_modifier('mortality_rate', self.mortality_adjustment) - builder.value.register_value_modifier('yld_rate', self.disability_adjustment) + f"{self.name}_intervention.yld_rate", source=yld_rate + ) + builder.value.register_value_modifier("mortality_rate", self.mortality_adjustment) + builder.value.register_value_modifier("yld_rate", self.disability_adjustment) def mortality_adjustment(self, index, mortality_rate): """ @@ -105,94 +105,102 @@ def __init__(self, name): self._name = name self.configuration_defaults = { self.name: { - 'simplified_no_remission_equations': False, + "simplified_no_remission_equations": False, }, } - + @property def name(self): return self._name def setup(self, builder): """Load the disease prevalence and rates data.""" - data_prefix = 'chronic_disease.{}.'.format(self.name) - bau_prefix = self.name + '.' - int_prefix = self.name + '_intervention.' + data_prefix = "chronic_disease.{}.".format(self.name) + bau_prefix = self.name + "." + int_prefix = self.name + "_intervention." self.clock = builder.time.clock() self.start_year = builder.configuration.time.start.year - self.simplified_equations = builder.configuration[self.name].simplified_no_remission_equations - - inc_data = builder.data.load(data_prefix + 'incidence') - i = builder.lookup.build_table(inc_data, - key_columns=['sex'], - parameter_columns=['age','year']) + self.simplified_equations = builder.configuration[ + self.name + ].simplified_no_remission_equations + + inc_data = builder.data.load(data_prefix + "incidence") + i = builder.lookup.build_table( + inc_data, key_columns=["sex"], parameter_columns=["age", "year"] + ) self.incidence = builder.value.register_rate_producer( - bau_prefix + 'incidence', source=i) + bau_prefix + "incidence", source=i + ) self.incidence_intervention = builder.value.register_rate_producer( - int_prefix + 'incidence', source=i) + int_prefix + "incidence", source=i + ) - rem_data = builder.data.load(data_prefix + 'remission') - r = builder.lookup.build_table(rem_data, - key_columns=['sex'], - parameter_columns=['age','year']) + rem_data = builder.data.load(data_prefix + "remission") + r = builder.lookup.build_table( + rem_data, key_columns=["sex"], parameter_columns=["age", "year"] + ) self.remission = builder.value.register_rate_producer( - bau_prefix + 'remission', source=r) + bau_prefix + "remission", source=r + ) - mty_data = builder.data.load(data_prefix + 'mortality') - f = builder.lookup.build_table(mty_data, - key_columns=['sex'], - parameter_columns=['age','year']) + mty_data = builder.data.load(data_prefix + "mortality") + f = builder.lookup.build_table( + mty_data, key_columns=["sex"], parameter_columns=["age", "year"] + ) self.excess_mortality = builder.value.register_rate_producer( - bau_prefix + 'excess_mortality', source=f) + bau_prefix + "excess_mortality", source=f + ) - yld_data = builder.data.load(data_prefix + 'morbidity') - yld_rate = builder.lookup.build_table(yld_data, - key_columns=['sex'], - parameter_columns=['age','year']) + yld_data = builder.data.load(data_prefix + "morbidity") + yld_rate = builder.lookup.build_table( + yld_data, key_columns=["sex"], parameter_columns=["age", "year"] + ) self.disability_rate = builder.value.register_rate_producer( - bau_prefix + 'yld_rate', source=yld_rate) + bau_prefix + "yld_rate", source=yld_rate + ) - prev_data = builder.data.load(data_prefix + 'prevalence') - self.initial_prevalence = builder.lookup.build_table(prev_data, - key_columns=['sex'], - parameter_columns=['age','year']) + prev_data = builder.data.load(data_prefix + "prevalence") + self.initial_prevalence = builder.lookup.build_table( + prev_data, key_columns=["sex"], parameter_columns=["age", "year"] + ) - builder.value.register_value_modifier( - 'mortality_rate', self.mortality_adjustment) - builder.value.register_value_modifier( - 'yld_rate', self.disability_adjustment) + builder.value.register_value_modifier("mortality_rate", self.mortality_adjustment) + builder.value.register_value_modifier("yld_rate", self.disability_adjustment) columns = [] - for scenario in ['', '_intervention']: - for rate in ['_S', '_C']: - for when in ['', '_previous']: + for scenario in ["", "_intervention"]: + for rate in ["_S", "_C"]: + for when in ["", "_previous"]: columns.append(self.name + rate + scenario + when) builder.population.initializes_simulants( self.on_initialize_simulants, creates_columns=columns, - requires_columns=['age', 'sex']) + requires_columns=["age", "sex"], + ) self.population_view = builder.population.get_view(columns) - builder.event.register_listener( - 'time_step__prepare', - self.on_time_step_prepare) + builder.event.register_listener("time_step__prepare", self.on_time_step_prepare) def on_initialize_simulants(self, pop_data): """Initialize the test population for which this disease is modeled.""" C = 1000 * self.initial_prevalence(pop_data.index) S = 1000 - C - pop = pd.DataFrame({f'{self.name}_S': S, - f'{self.name}_C': C, - f'{self.name}_S_previous': S, - f'{self.name}_C_previous': C, - f'{self.name}_S_intervention': S, - f'{self.name}_C_intervention': C, - f'{self.name}_S_intervention_previous': S, - f'{self.name}_C_intervention_previous': C}, - index=pop_data.index) + pop = pd.DataFrame( + { + f"{self.name}_S": S, + f"{self.name}_C": C, + f"{self.name}_S_previous": S, + f"{self.name}_C_previous": C, + f"{self.name}_S_intervention": S, + f"{self.name}_C_intervention": C, + f"{self.name}_S_intervention_previous": S, + f"{self.name}_C_intervention_previous": C, + }, + index=pop_data.index, + ) self.population_view.update(pop) @@ -208,9 +216,9 @@ def on_time_step_prepare(self, event): if pop.empty: return idx = pop.index - S_bau, C_bau = pop[f'{self.name}_S'], pop[f'{self.name}_C'] - S_int = pop[f'{self.name}_S_intervention'] - C_int = pop[f'{self.name}_C_intervention'] + S_bau, C_bau = pop[f"{self.name}_S"], pop[f"{self.name}_C"] + S_int = pop[f"{self.name}_S_intervention"] + C_int = pop[f"{self.name}_C_intervention"] # Extract all of the required rates *once only*. i_bau = self.incidence(idx) @@ -226,20 +234,23 @@ def on_time_step_prepare(self, event): # NOTE: for the 'mslt_reduce_chd' experiment, this results in a # slightly lower HALY gain than that obtained when using the # full equations (below). - new_S_bau = S_bau * np.exp(- i_bau) - new_S_int = S_int * np.exp(- i_int) - new_C_bau = C_bau * np.exp(- f) + S_bau - new_S_bau - new_C_int = C_int * np.exp(- f) + S_int - new_S_int - pop_update = pd.DataFrame({ - f'{self.name}_S': new_S_bau, - f'{self.name}_C': new_C_bau, - f'{self.name}_S_previous': S_bau, - f'{self.name}_C_previous': C_bau, - f'{self.name}_S_intervention': new_S_int, - f'{self.name}_C_intervention': new_C_int, - f'{self.name}_S_intervention_previous': S_int, - f'{self.name}_C_intervention_previous': C_int, - }, index=pop.index) + new_S_bau = S_bau * np.exp(-i_bau) + new_S_int = S_int * np.exp(-i_int) + new_C_bau = C_bau * np.exp(-f) + S_bau - new_S_bau + new_C_int = C_int * np.exp(-f) + S_int - new_S_int + pop_update = pd.DataFrame( + { + f"{self.name}_S": new_S_bau, + f"{self.name}_C": new_C_bau, + f"{self.name}_S_previous": S_bau, + f"{self.name}_C_previous": C_bau, + f"{self.name}_S_intervention": new_S_int, + f"{self.name}_C_intervention": new_C_int, + f"{self.name}_S_intervention_previous": S_int, + f"{self.name}_C_intervention_previous": C_int, + }, + index=pop.index, + ) self.population_view.update(pop_update) return @@ -277,36 +288,41 @@ def on_time_step_prepare(self, event): new_C_int = C_int.copy() # Calculate new_S_bau, new_C_bau, new_S_int, new_C_int. - num_S_bau = (2 * (v_bau - w_bau) * (S_bau * f_plus_r + C_bau * r) - + S_bau * (v_bau * (q_bau - l_bau) - + w_bau * (q_bau + l_bau))) - num_S_int = (2 * (v_int - w_int) * (S_int * f_plus_r + C_int * r) - + S_int * (v_int * (q_int - l_int) - + w_int * (q_int + l_int))) + num_S_bau = 2 * (v_bau - w_bau) * (S_bau * f_plus_r + C_bau * r) + S_bau * ( + v_bau * (q_bau - l_bau) + w_bau * (q_bau + l_bau) + ) + num_S_int = 2 * (v_int - w_int) * (S_int * f_plus_r + C_int * r) + S_int * ( + v_int * (q_int - l_int) + w_int * (q_int + l_int) + ) new_S_bau[nz_bau] = num_S_bau[nz_bau] / denom_bau[nz_bau] new_S_int[nz_int] = num_S_int[nz_int] / denom_int[nz_int] - num_C_bau = - ((v_bau - w_bau) * (2 * (f_plus_r * (S_bau + C_bau) - - l_bau * S_bau) - - l_bau * C_bau) - - (v_bau + w_bau) * q_bau * C_bau) - num_C_int = - ((v_int - w_int) * (2 * (f_plus_r * (S_int + C_int) - - l_int * S_int) - - l_int * C_int) - - (v_int + w_int) * q_int * C_int) + num_C_bau = -( + (v_bau - w_bau) + * (2 * (f_plus_r * (S_bau + C_bau) - l_bau * S_bau) - l_bau * C_bau) + - (v_bau + w_bau) * q_bau * C_bau + ) + num_C_int = -( + (v_int - w_int) + * (2 * (f_plus_r * (S_int + C_int) - l_int * S_int) - l_int * C_int) + - (v_int + w_int) * q_int * C_int + ) new_C_bau[nz_bau] = num_C_bau[nz_bau] / denom_bau[nz_bau] new_C_int[nz_int] = num_C_int[nz_int] / denom_int[nz_int] - pop_update = pd.DataFrame({ - f'{self.name}_S': new_S_bau, - f'{self.name}_C': new_C_bau, - f'{self.name}_S_previous': S_bau, - f'{self.name}_C_previous': C_bau, - f'{self.name}_S_intervention': new_S_int, - f'{self.name}_C_intervention': new_C_int, - f'{self.name}_S_intervention_previous': S_int, - f'{self.name}_C_intervention_previous': C_int, - }, index=pop.index) + pop_update = pd.DataFrame( + { + f"{self.name}_S": new_S_bau, + f"{self.name}_C": new_C_bau, + f"{self.name}_S_previous": S_bau, + f"{self.name}_C_previous": C_bau, + f"{self.name}_S_intervention": new_S_int, + f"{self.name}_C_intervention": new_C_int, + f"{self.name}_S_intervention_previous": S_int, + f"{self.name}_C_intervention_previous": C_int, + }, + index=pop.index, + ) self.population_view.update(pop_update) def mortality_adjustment(self, index, mortality_rate): @@ -317,12 +333,15 @@ def mortality_adjustment(self, index, mortality_rate): """ pop = self.population_view.get(index) - S, C = pop[f'{self.name}_S'], pop[f'{self.name}_C'] - S_prev, C_prev = pop[f'{self.name}_S_previous'], pop[f'{self.name}_C_previous'] + S, C = pop[f"{self.name}_S"], pop[f"{self.name}_C"] + S_prev, C_prev = pop[f"{self.name}_S_previous"], pop[f"{self.name}_C_previous"] D, D_prev = 1000 - S - C, 1000 - S_prev - C_prev - S_int, C_int = pop[f'{self.name}_S_intervention'], pop[f'{self.name}_C_intervention'] - S_int_prev, C_int_prev = pop[f'{self.name}_S_intervention_previous'], pop[f'{self.name}_C_intervention_previous'] + S_int, C_int = pop[f"{self.name}_S_intervention"], pop[f"{self.name}_C_intervention"] + S_int_prev, C_int_prev = ( + pop[f"{self.name}_S_intervention_previous"], + pop[f"{self.name}_C_intervention_previous"], + ) D_int, D_int_prev = 1000 - S_int - C_int, 1000 - S_int_prev - C_int_prev # NOTE: as per the spreadsheet, the denominator is from the same point @@ -342,10 +361,16 @@ def disability_adjustment(self, index, yld_rate): """ pop = self.population_view.get(index) - S, S_prev = pop[f'{self.name}_S'], pop[f'{self.name}_S_previous'] - C, C_prev = pop[f'{self.name}_C'], pop[f'{self.name}_C_previous'] - S_int, S_int_prev = pop[f'{self.name}_S_intervention'], pop[f'{self.name}_S_intervention_previous'] - C_int, C_int_prev = pop[f'{self.name}_C_intervention'], pop[f'{self.name}_C_intervention_previous'] + S, S_prev = pop[f"{self.name}_S"], pop[f"{self.name}_S_previous"] + C, C_prev = pop[f"{self.name}_C"], pop[f"{self.name}_C_previous"] + S_int, S_int_prev = ( + pop[f"{self.name}_S_intervention"], + pop[f"{self.name}_S_intervention_previous"], + ) + C_int, C_int_prev = ( + pop[f"{self.name}_C_intervention"], + pop[f"{self.name}_C_intervention_previous"], + ) # The prevalence rate is the mean number of diseased people over the # year, divided by the mean number of alive people over the year. diff --git a/src/vivarium_public_health/mslt/intervention.py b/src/vivarium_public_health/mslt/intervention.py index 6c56c3a03..f72d039f5 100644 --- a/src/vivarium_public_health/mslt/intervention.py +++ b/src/vivarium_public_health/mslt/intervention.py @@ -11,9 +11,10 @@ class ModifyAllCauseMortality: """Interventions that modify the all-cause mortality rate.""" + def __init__(self, name): self._name = name - + @property def name(self): return self._name @@ -22,9 +23,8 @@ def setup(self, builder): self.config = builder.configuration self.scale = self.config.intervention[self.name]["scale"] if self.scale < 0: - raise ValueError('Invalid scale: {}'.format(self.scale)) - builder.value.register_value_modifier('mortality_rate', - self.mortality_adjustment) + raise ValueError("Invalid scale: {}".format(self.scale)) + builder.value.register_value_modifier("mortality_rate", self.mortality_adjustment) def mortality_adjustment(self, index, rates): return rates * self.scale @@ -32,6 +32,7 @@ def mortality_adjustment(self, index, rates): class ModifyDiseaseRate: """Interventions that modify a rate associated with a chronic disease.""" + def __init__(self, name, disease, rate): self._name = name self.disease = disease @@ -47,8 +48,8 @@ def setup(self, builder): scale_name = "{}_{}_scale".format(self.disease, self.rate) self.scale = self.config.intervention[self.name][scale_name] if self.scale < 0: - raise ValueError('Invalid scale: {}'.format(self.scale)) - rate_name = '{}_intervention.{}'.format(self.disease, self.rate) + raise ValueError("Invalid scale: {}".format(self.scale)) + rate_name = "{}_intervention.{}".format(self.disease, self.rate) builder.value.register_value_modifier(rate_name, self.adjust_rate) def adjust_rate(self, index, rates): @@ -62,7 +63,7 @@ class ModifyDiseaseIncidence(ModifyDiseaseRate): """ def __init__(self, name, disease): - super().__init__(name=name, disease=disease, rate='incidence') + super().__init__(name=name, disease=disease, rate="incidence") class ModifyDiseaseMortality(ModifyDiseaseRate): @@ -72,7 +73,7 @@ class ModifyDiseaseMortality(ModifyDiseaseRate): """ def __init__(self, name, disease): - super().__init__(name=name, disease=disease, rate='excess_mortality') + super().__init__(name=name, disease=disease, rate="excess_mortality") class ModifyDiseaseMorbidity(ModifyDiseaseRate): @@ -82,7 +83,7 @@ class ModifyDiseaseMorbidity(ModifyDiseaseRate): """ def __init__(self, name, disease): - super().__init__(name=name, disease=disease, rate='yld_rate') + super().__init__(name=name, disease=disease, rate="yld_rate") class ModifyAcuteDiseaseIncidence: @@ -94,7 +95,7 @@ class ModifyAcuteDiseaseIncidence: def __init__(self, name): self._name = name - + @property def name(self): return self._name @@ -103,10 +104,10 @@ def setup(self, builder): self.config = builder.configuration self.scale = self.config.intervention[self.name].incidence_scale if self.scale < 0: - raise ValueError('Invalid incidence scale: {}'.format(self.scale)) - yld_rate = '{}_intervention.yld_rate'.format(self.name) + raise ValueError("Invalid incidence scale: {}".format(self.scale)) + yld_rate = "{}_intervention.yld_rate".format(self.name) builder.value.register_value_modifier(yld_rate, self.rate_adjustment) - mort_rate = '{}_intervention.excess_mortality'.format(self.name) + mort_rate = "{}_intervention.excess_mortality".format(self.name) builder.value.register_value_modifier(mort_rate, self.rate_adjustment) def rate_adjustment(self, index, rates): @@ -118,7 +119,7 @@ class ModifyAcuteDiseaseMorbidity: def __init__(self, name): self._name = name - + @property def name(self): return self._name @@ -127,8 +128,8 @@ def setup(self, builder): self.config = builder.configuration self.scale = self.config.intervention[self.name].yld_scale if self.scale < 0: - raise ValueError('Invalid YLD scale: {}'.format(self.scale)) - rate = '{}_intervention.yld_rate'.format(self.name) + raise ValueError("Invalid YLD scale: {}".format(self.scale)) + rate = "{}_intervention.yld_rate".format(self.name) builder.value.register_value_modifier(rate, self.disability_adjustment) def disability_adjustment(self, index, rates): @@ -140,7 +141,7 @@ class ModifyAcuteDiseaseMortality: def __init__(self, name): self._name = name - + @property def name(self): return self._name @@ -149,8 +150,8 @@ def setup(self, builder): self.config = builder.configuration self.scale = self.config.intervention[self.name].mortality_scale if self.scale < 0: - raise ValueError('Invalid mortality scale: {}'.format(self.scale)) - rate = '{}_intervention.excess_mortality'.format(self.name) + raise ValueError("Invalid mortality scale: {}".format(self.scale)) + rate = "{}_intervention.excess_mortality".format(self.name) builder.value.register_value_modifier(rate, self.mortality_adjustment) def mortality_adjustment(self, index, rates): @@ -159,17 +160,18 @@ def mortality_adjustment(self, index, rates): class TobaccoFreeGeneration: """Eradicate tobacco uptake at some point in time.""" + def __init__(self): - self.exposure = 'tobacco' - + self.exposure = "tobacco" + @property def name(self): - return 'tobacco_free_generation' + return "tobacco_free_generation" def setup(self, builder): - self.year = builder.configuration['tobacco_free_generation'].year + self.year = builder.configuration["tobacco_free_generation"].year self.clock = builder.time.clock() - rate_name = '{}_intervention.incidence'.format(self.exposure) + rate_name = "{}_intervention.incidence".format(self.exposure) builder.value.register_value_modifier(rate_name, self.adjust_rate) def adjust_rate(self, index, rates): @@ -182,22 +184,21 @@ def adjust_rate(self, index, rates): class TobaccoEradication: """Eradicate all tobacco use at some point in time.""" + def __init__(self): - self.exposure = 'tobacco' - + self.exposure = "tobacco" + @property def name(self): - return 'tobacco_eradication' + return "tobacco_eradication" def setup(self, builder): - self.year = builder.configuration['tobacco_eradication'].year + self.year = builder.configuration["tobacco_eradication"].year self.clock = builder.time.clock() - inc_rate_name = '{}_intervention.incidence'.format(self.exposure) - builder.value.register_value_modifier(inc_rate_name, - self.adjust_inc_rate) - rem_rate_name = '{}_intervention.remission'.format(self.exposure) - builder.value.register_value_modifier(rem_rate_name, - self.adjust_rem_rate) + inc_rate_name = "{}_intervention.incidence".format(self.exposure) + builder.value.register_value_modifier(inc_rate_name, self.adjust_inc_rate) + rem_rate_name = "{}_intervention.remission".format(self.exposure) + builder.value.register_value_modifier(rem_rate_name, self.adjust_rem_rate) def adjust_inc_rate(self, index, rates): this_year = self.clock().year diff --git a/src/vivarium_public_health/mslt/magic_wand_components.py b/src/vivarium_public_health/mslt/magic_wand_components.py index df52815cb..fbbb886cb 100644 --- a/src/vivarium_public_health/mslt/magic_wand_components.py +++ b/src/vivarium_public_health/mslt/magic_wand_components.py @@ -10,33 +10,30 @@ class MortalityShift: - @property def name(self): - return 'mortality_shift' + return "mortality_shift" def setup(self, builder): - builder.value.register_value_modifier('mortality_rate', self.mortality_adjustment) + builder.value.register_value_modifier("mortality_rate", self.mortality_adjustment) def mortality_adjustment(self, index, rates): - return rates * .5 + return rates * 0.5 class YLDShift: - @property def name(self): - return 'yld_shift' + return "yld_shift" def setup(self, builder): - builder.value.register_value_modifier('yld_rate', self.disability_adjustment) + builder.value.register_value_modifier("yld_rate", self.disability_adjustment) def disability_adjustment(self, index, rates): - return rates * .5 + return rates * 0.5 class IncidenceShift: - def __init__(self, name): self._name = name @@ -45,17 +42,18 @@ def name(self): return self._name def setup(self, builder): - builder.value.register_value_modifier(f'{self.name}_intervention.incidence', self.incidence_adjustment) + builder.value.register_value_modifier( + f"{self.name}_intervention.incidence", self.incidence_adjustment + ) def incidence_adjustment(self, index, rates): - return rates * .5 + return rates * 0.5 class ModifyAcuteDiseaseYLD: - def __init__(self, name): self._name = name - + @property def name(self): return self._name @@ -64,20 +62,19 @@ def setup(self, builder): self.config = builder.configuration self.scale = self.config.intervention[self.name].yld_scale if self.scale < 0: - raise ValueError(f'Invalid YLD scale: {self.scale}') + raise ValueError(f"Invalid YLD scale: {self.scale}") builder.value.register_value_modifier( - f'{self.name}_intervention.yld_rate', - self.disability_adjustment) + f"{self.name}_intervention.yld_rate", self.disability_adjustment + ) def disability_adjustment(self, index, rates): return rates * self.scale class ModifyAcuteDiseaseMortality: - def __init__(self, name): self._name = name - + @property def name(self): return self._name @@ -86,10 +83,10 @@ def setup(self, builder): self.config = builder.configuration self.scale = self.config.intervention[self.name].mortality_scale if self.scale < 0: - raise ValueError(f'Invalid mortality scale: {self.scale}') + raise ValueError(f"Invalid mortality scale: {self.scale}") builder.value.register_value_modifier( - f'{self.name}_intervention.excess_mortality', - self.mortality_adjustment) + f"{self.name}_intervention.excess_mortality", self.mortality_adjustment + ) def mortality_adjustment(self, index, rates): return rates * self.scale diff --git a/src/vivarium_public_health/mslt/observer.py b/src/vivarium_public_health/mslt/observer.py index 2a27a3123..6ff04c7db 100644 --- a/src/vivarium_public_health/mslt/observer.py +++ b/src/vivarium_public_health/mslt/observer.py @@ -10,7 +10,7 @@ import pandas as pd -def output_file(config, suffix, sep='_', ext='csv'): +def output_file(config, suffix, sep="_", ext="csv"): """ Determine the output file name for an observer, based on the prefix defined in ``config.observer.output_prefix`` and the (optional) @@ -28,19 +28,19 @@ def output_file(config, suffix, sep='_', ext='csv'): The output file extension. """ - if 'observer' not in config: - raise ValueError('observer.output_prefix not defined') - if 'output_prefix' not in config.observer: - raise ValueError('observer.output_prefix not defined') + if "observer" not in config: + raise ValueError("observer.output_prefix not defined") + if "output_prefix" not in config.observer: + raise ValueError("observer.output_prefix not defined") prefix = config.observer.output_prefix - if 'input_draw_number' in config.input_data: + if "input_draw_number" in config.input_data: draw = config.input_data.input_draw_number else: draw = 0 out_file = prefix + sep + suffix if draw > 0: - out_file += '{}{}'.format(sep, draw) - out_file += '.{}'.format(ext) + out_file += "{}{}".format(sep, draw) + out_file += ".{}".format(ext) return out_file @@ -57,40 +57,61 @@ class MorbidityMortality: """ - def __init__(self, output_suffix='mm'): + def __init__(self, output_suffix="mm"): self.output_suffix = output_suffix @property def name(self): - return 'morbidity_mortality_observer' + return "morbidity_mortality_observer" def setup(self, builder): # Record the key columns from the core multi-state life table. - columns = ['age', 'sex', - 'population', 'bau_population', - 'acmr', 'bau_acmr', - 'pr_death', 'bau_pr_death', - 'deaths', 'bau_deaths', - 'yld_rate', 'bau_yld_rate', - 'person_years', 'bau_person_years', - 'HALY', 'bau_HALY'] + columns = [ + "age", + "sex", + "population", + "bau_population", + "acmr", + "bau_acmr", + "pr_death", + "bau_pr_death", + "deaths", + "bau_deaths", + "yld_rate", + "bau_yld_rate", + "person_years", + "bau_person_years", + "HALY", + "bau_HALY", + ] self.population_view = builder.population.get_view(columns) self.clock = builder.time.clock() - builder.event.register_listener('collect_metrics', self.on_collect_metrics) - builder.event.register_listener('simulation_end', self.write_output) + builder.event.register_listener("collect_metrics", self.on_collect_metrics) + builder.event.register_listener("simulation_end", self.write_output) self.tables = [] - self.table_cols = ['sex', 'age', 'year', - 'population', 'bau_population', - 'prev_population', 'bau_prev_population', - 'acmr', 'bau_acmr', - 'pr_death', 'bau_pr_death', - 'deaths', 'bau_deaths', - 'yld_rate', 'bau_yld_rate', - 'person_years', 'bau_person_years', - 'HALY', 'bau_HALY'] - - self.output_file = output_file(builder.configuration, - self.output_suffix) + self.table_cols = [ + "sex", + "age", + "year", + "population", + "bau_population", + "prev_population", + "bau_prev_population", + "acmr", + "bau_acmr", + "pr_death", + "bau_pr_death", + "deaths", + "bau_deaths", + "yld_rate", + "bau_yld_rate", + "person_years", + "bau_person_years", + "HALY", + "bau_HALY", + ] + + self.output_file = output_file(builder.configuration, self.output_suffix) def on_collect_metrics(self, event): pop = self.population_view.get(event.index) @@ -98,10 +119,10 @@ def on_collect_metrics(self, event): # No tracked population remains. return - pop['year'] = self.clock().year + pop["year"] = self.clock().year # Record the population size prior to the deaths. - pop['prev_population'] = pop['population'] + pop['deaths'] - pop['bau_prev_population'] = pop['bau_population'] + pop['bau_deaths'] + pop["prev_population"] = pop["population"] + pop["deaths"] + pop["bau_prev_population"] = pop["bau_population"] + pop["bau_deaths"] self.tables.append(pop[self.table_cols]) def calculate_LE(self, table, py_col, denom_col): @@ -123,7 +144,7 @@ def calculate_LE(self, table, py_col, denom_col): """ # Group the person-years by cohort. - group_cols = ['year_of_birth', 'sex'] + group_cols = ["year_of_birth", "sex"] subset_cols = group_cols + [py_col] grouped = table.loc[:, subset_cols].groupby(by=group_cols)[py_col] # Calculate the reverse-cumulative sums of the adjusted person-years @@ -136,23 +157,21 @@ def calculate_LE(self, table, py_col, denom_col): def write_output(self, event): data = pd.concat(self.tables, ignore_index=True) - data['year_of_birth'] = data['year'] - data['age'] + data["year_of_birth"] = data["year"] - data["age"] # Sort the table by cohort (i.e., generation and sex), and then by # calendar year, so that results are output in the same order as in # the spreadsheet models. - data = data.sort_values(by=['year_of_birth', 'sex', 'age'], axis=0) + data = data.sort_values(by=["year_of_birth", "sex", "age"], axis=0) data = data.reset_index(drop=True) # Re-order the table columns. - cols = ['year_of_birth'] + self.table_cols + cols = ["year_of_birth"] + self.table_cols data = data[cols] # Calculate life expectancy and HALE for the BAU and intervention, # with respect to the initial population, not the survivors. - data['LE'] = self.calculate_LE(data, 'person_years', 'prev_population') - data['bau_LE'] = self.calculate_LE(data, 'bau_person_years', - 'bau_prev_population') - data['HALE'] = self.calculate_LE(data, 'HALY', 'prev_population') - data['bau_HALE'] = self.calculate_LE(data, 'bau_HALY', - 'bau_prev_population') + data["LE"] = self.calculate_LE(data, "person_years", "prev_population") + data["bau_LE"] = self.calculate_LE(data, "bau_person_years", "bau_prev_population") + data["HALE"] = self.calculate_LE(data, "HALY", "prev_population") + data["bau_HALE"] = self.calculate_LE(data, "bau_HALY", "bau_prev_population") data.to_csv(self.output_file, index=False) @@ -176,38 +195,49 @@ def __init__(self, name, output_suffix=None): if output_suffix is None: output_suffix = name.lower() self.output_suffix = output_suffix - + @property def name(self): - return f'{self._name}_observer' + return f"{self._name}_observer" def setup(self, builder): - bau_incidence_value = '{}.incidence'.format(self._name) - int_incidence_value = '{}_intervention.incidence'.format(self._name) + bau_incidence_value = "{}.incidence".format(self._name) + int_incidence_value = "{}_intervention.incidence".format(self._name) self.bau_incidence = builder.value.get_value(bau_incidence_value) self.int_incidence = builder.value.get_value(int_incidence_value) - self.bau_S_col = '{}_S'.format(self._name) - self.bau_C_col = '{}_C'.format(self._name) - self.int_S_col = '{}_S_intervention'.format(self._name) - self.int_C_col = '{}_C_intervention'.format(self._name) - - columns = ['age', 'sex', - self.bau_S_col, self.bau_C_col, - self.int_S_col, self.int_C_col] + self.bau_S_col = "{}_S".format(self._name) + self.bau_C_col = "{}_C".format(self._name) + self.int_S_col = "{}_S_intervention".format(self._name) + self.int_C_col = "{}_C_intervention".format(self._name) + + columns = [ + "age", + "sex", + self.bau_S_col, + self.bau_C_col, + self.int_S_col, + self.int_C_col, + ] self.population_view = builder.population.get_view(columns) - builder.event.register_listener('collect_metrics', self.on_collect_metrics) - builder.event.register_listener('simulation_end', self.write_output) + builder.event.register_listener("collect_metrics", self.on_collect_metrics) + builder.event.register_listener("simulation_end", self.write_output) self.tables = [] - self.table_cols = ['sex', 'age', 'year', - 'bau_incidence', 'int_incidence', - 'bau_prevalence', 'int_prevalence', - 'bau_deaths', 'int_deaths'] + self.table_cols = [ + "sex", + "age", + "year", + "bau_incidence", + "int_incidence", + "bau_prevalence", + "int_prevalence", + "bau_deaths", + "int_deaths", + ] self.clock = builder.time.clock() - self.output_file = output_file(builder.configuration, - self.output_suffix) + self.output_file = output_file(builder.configuration, self.output_suffix) def on_collect_metrics(self, event): pop = self.population_view.get(event.index) @@ -215,29 +245,33 @@ def on_collect_metrics(self, event): # No tracked population remains. return - pop['year'] = self.clock().year - pop['bau_incidence'] = self.bau_incidence(event.index) - pop['int_incidence'] = self.int_incidence(event.index) - pop['bau_prevalence'] = pop[self.bau_C_col] / (pop[self.bau_C_col] + pop[self.bau_S_col]) - pop['int_prevalence'] = pop[self.int_C_col] / (pop[self.bau_C_col] + pop[self.bau_S_col]) - pop['bau_deaths'] = 1000 - pop[self.bau_S_col] - pop[self.bau_C_col] - pop['int_deaths'] = 1000 - pop[self.int_S_col] - pop[self.int_C_col] + pop["year"] = self.clock().year + pop["bau_incidence"] = self.bau_incidence(event.index) + pop["int_incidence"] = self.int_incidence(event.index) + pop["bau_prevalence"] = pop[self.bau_C_col] / ( + pop[self.bau_C_col] + pop[self.bau_S_col] + ) + pop["int_prevalence"] = pop[self.int_C_col] / ( + pop[self.bau_C_col] + pop[self.bau_S_col] + ) + pop["bau_deaths"] = 1000 - pop[self.bau_S_col] - pop[self.bau_C_col] + pop["int_deaths"] = 1000 - pop[self.int_S_col] - pop[self.int_C_col] self.tables.append(pop.loc[:, self.table_cols]) def write_output(self, event): data = pd.concat(self.tables, ignore_index=True) - data['diff_incidence'] = data['int_incidence'] - data['bau_incidence'] - data['diff_prevalence'] = data['int_prevalence'] - data['bau_prevalence'] - data['year_of_birth'] = data['year'] - data['age'] - data['disease'] = self._name + data["diff_incidence"] = data["int_incidence"] - data["bau_incidence"] + data["diff_prevalence"] = data["int_prevalence"] - data["bau_prevalence"] + data["year_of_birth"] = data["year"] - data["age"] + data["disease"] = self._name # Sort the table by cohort (i.e., generation and sex), and then by # calendar year, so that results are output in the same order as in # the spreadsheet models. - data = data.sort_values(by=['year_of_birth', 'sex', 'age'], axis=0) + data = data.sort_values(by=["year_of_birth", "sex", "age"], axis=0) data = data.reset_index(drop=True) # Re-order the table columns. - diff_cols = ['diff_incidence', 'diff_prevalence'] - cols = ['disease', 'year_of_birth'] + self.table_cols + diff_cols + diff_cols = ["diff_incidence", "diff_prevalence"] + cols = ["disease", "year_of_birth"] + self.table_cols + diff_cols data = data[cols] data.to_csv(self.output_file, index=False) @@ -253,32 +287,39 @@ class TobaccoPrevalence: """ - def __init__(self, output_suffix='tobacco'): + def __init__(self, output_suffix="tobacco"): self.output_suffix = output_suffix - + @property def name(self): - return 'tobacco_prevalence_observer' + return "tobacco_prevalence_observer" def setup(self, builder): self.config = builder.configuration self.clock = builder.time.clock() - self.bin_years = int(self.config['tobacco']['delay']) + self.bin_years = int(self.config["tobacco"]["delay"]) - view_columns = ['age', 'sex', 'bau_population', 'population'] + self.get_bin_names() + view_columns = ["age", "sex", "bau_population", "population"] + self.get_bin_names() self.population_view = builder.population.get_view(view_columns) self.tables = [] - self.table_cols = ['age', 'sex', 'year', - 'bau_no', 'bau_yes', 'bau_previously', 'bau_population', - 'int_no', 'int_yes', 'int_previously', 'int_population'] - - builder.event.register_listener('collect_metrics', - self.on_collect_metrics) - builder.event.register_listener('simulation_end', - self.write_output) - self.output_file = output_file(builder.configuration, - self.output_suffix) + self.table_cols = [ + "age", + "sex", + "year", + "bau_no", + "bau_yes", + "bau_previously", + "bau_population", + "int_no", + "int_yes", + "int_previously", + "int_population", + ] + + builder.event.register_listener("collect_metrics", self.on_collect_metrics) + builder.event.register_listener("simulation_end", self.write_output) + self.output_file = output_file(builder.configuration, self.output_suffix) def get_bin_names(self): """Return the bin names for both the BAU and the intervention scenario. @@ -302,9 +343,9 @@ def get_bin_names(self): delay_bins = [str(0)] else: delay_bins = [str(s) for s in range(self.bin_years + 2)] - bins = ['no', 'yes'] + delay_bins - bau_bins = ['{}.{}'.format('tobacco', bin) for bin in bins] - int_bins = ['{}_intervention.{}'.format('tobacco', bin) for bin in bins] + bins = ["no", "yes"] + delay_bins + bau_bins = ["{}.{}".format("tobacco", bin) for bin in bins] + int_bins = ["{}_intervention.{}".format("tobacco", bin) for bin in bins] all_bins = bau_bins + int_bins return all_bins @@ -314,36 +355,38 @@ def on_collect_metrics(self, event): # No tracked population remains. return - bau_cols = [c for c in pop.columns.values - if c.startswith('{}.'.format('tobacco'))] - int_cols = [c for c in pop.columns.values - if c.startswith('{}_intervention.'.format('tobacco'))] + bau_cols = [c for c in pop.columns.values if c.startswith("{}.".format("tobacco"))] + int_cols = [ + c + for c in pop.columns.values + if c.startswith("{}_intervention.".format("tobacco")) + ] bau_denom = pop.reindex(columns=bau_cols).sum(axis=1) int_denom = pop.reindex(columns=int_cols).sum(axis=1) # Normalise prevalence with respect to the total population. - pop['bau_no'] = pop['{}.no'.format('tobacco')] / bau_denom - pop['bau_yes'] = pop['{}.yes'.format('tobacco')] / bau_denom - pop['bau_previously'] = 1 - pop['bau_no'] - pop['bau_yes'] - pop['int_no'] = pop['{}_intervention.no'.format('tobacco')] / int_denom - pop['int_yes'] = pop['{}_intervention.yes'.format('tobacco')] / int_denom - pop['int_previously'] = 1 - pop['int_no'] - pop['int_yes'] + pop["bau_no"] = pop["{}.no".format("tobacco")] / bau_denom + pop["bau_yes"] = pop["{}.yes".format("tobacco")] / bau_denom + pop["bau_previously"] = 1 - pop["bau_no"] - pop["bau_yes"] + pop["int_no"] = pop["{}_intervention.no".format("tobacco")] / int_denom + pop["int_yes"] = pop["{}_intervention.yes".format("tobacco")] / int_denom + pop["int_previously"] = 1 - pop["int_no"] - pop["int_yes"] - pop = pop.rename(columns={'population': 'int_population'}) + pop = pop.rename(columns={"population": "int_population"}) - pop['year'] = self.clock().year + pop["year"] = self.clock().year self.tables.append(pop.reindex(columns=self.table_cols).reset_index(drop=True)) def write_output(self, event): data = pd.concat(self.tables, ignore_index=True) - data['year_of_birth'] = data['year'] - data['age'] + data["year_of_birth"] = data["year"] - data["age"] # Sort the table by cohort (i.e., generation and sex), and then by # calendar year, so that results are output in the same order as in # the spreadsheet models. - data = data.sort_values(by=['year_of_birth', 'sex', 'age'], axis=0) + data = data.sort_values(by=["year_of_birth", "sex", "age"], axis=0) data = data.reset_index(drop=True) # Re-order the table columns. - cols = ['year_of_birth'] + self.table_cols + cols = ["year_of_birth"] + self.table_cols data = data.reindex(columns=cols) data.to_csv(self.output_file, index=False) diff --git a/src/vivarium_public_health/mslt/population.py b/src/vivarium_public_health/mslt/population.py index ec666261c..f07dd3ee4 100644 --- a/src/vivarium_public_health/mslt/population.py +++ b/src/vivarium_public_health/mslt/population.py @@ -33,26 +33,38 @@ class BasePopulation: """ configuration_defaults = { - 'population': { - 'max_age': 110, + "population": { + "max_age": 110, } } @property def name(self): - return 'base_population' - + return "base_population" + def setup(self, builder): """Load the population data.""" - columns = ['age', 'sex', 'population', 'bau_population', - 'acmr', 'bau_acmr', - 'pr_death', 'bau_pr_death', 'deaths', 'bau_deaths', - 'yld_rate', 'bau_yld_rate', - 'person_years', 'bau_person_years', - 'HALY', 'bau_HALY'] + columns = [ + "age", + "sex", + "population", + "bau_population", + "acmr", + "bau_acmr", + "pr_death", + "bau_pr_death", + "deaths", + "bau_deaths", + "yld_rate", + "bau_yld_rate", + "person_years", + "bau_person_years", + "HALY", + "bau_HALY", + ] self.pop_data = load_population_data(builder) - + # Create additional columns with placeholder (zero) values. for column in columns: if column in self.pop_data.columns: @@ -65,11 +77,13 @@ def setup(self, builder): self.clock = builder.time.clock() # Track all of the quantities that exist in the core spreadsheet table. - builder.population.initializes_simulants(self.on_initialize_simulants, creates_columns=columns) - self.population_view = builder.population.get_view(columns + ['tracked']) + builder.population.initializes_simulants( + self.on_initialize_simulants, creates_columns=columns + ) + self.population_view = builder.population.get_view(columns + ["tracked"]) # Age cohorts before each time-step (except the first time-step). - builder.event.register_listener('time_step__prepare', self.on_time_step_prepare) + builder.event.register_listener("time_step__prepare", self.on_time_step_prepare) def on_initialize_simulants(self, _): """Initialize each cohort.""" @@ -77,11 +91,11 @@ def on_initialize_simulants(self, _): def on_time_step_prepare(self, event): """Remove cohorts that have reached the maximum age.""" - pop = self.population_view.get(event.index, query='tracked == True') + pop = self.population_view.get(event.index, query="tracked == True") # Only increase cohort ages after the first time-step. if self.clock().year > self.start_year: - pop['age'] += 1 - pop.loc[pop.age > self.max_age, 'tracked'] = False + pop["age"] += 1 + pop.loc[pop.age > self.max_age, "tracked"] = False self.population_view.update(pop) @@ -90,26 +104,37 @@ class Mortality: This component reduces the population size of each cohort over time, according to the all-cause mortality rate. """ - + @property def name(self): - return 'mortality' + return "mortality" def setup(self, builder): """Load the all-cause mortality rate.""" - mortality_data = builder.data.load('cause.all_causes.mortality') + mortality_data = builder.data.load("cause.all_causes.mortality") self.mortality_rate = builder.value.register_rate_producer( - 'mortality_rate', source=builder.lookup.build_table(mortality_data, - key_columns=['sex'], - parameter_columns=['age','year'])) - - builder.event.register_listener('time_step', self.on_time_step) - - self.population_view = builder.population.get_view(['population', 'bau_population', - 'acmr', 'bau_acmr', - 'pr_death', 'bau_pr_death', - 'deaths', 'bau_deaths', - 'person_years', 'bau_person_years']) + "mortality_rate", + source=builder.lookup.build_table( + mortality_data, key_columns=["sex"], parameter_columns=["age", "year"] + ), + ) + + builder.event.register_listener("time_step", self.on_time_step) + + self.population_view = builder.population.get_view( + [ + "population", + "bau_population", + "acmr", + "bau_acmr", + "pr_death", + "bau_pr_death", + "deaths", + "bau_deaths", + "person_years", + "bau_person_years", + ] + ) def on_time_step(self, event): """ @@ -142,25 +167,31 @@ class Disability: cohort over time, according to the years lost due to disability (YLD) rate. """ - + @property def name(self): - return 'disability' + return "disability" def setup(self, builder): """Load the years lost due to disability (YLD) rate.""" - yld_data = builder.data.load('cause.all_causes.disability_rate') - yld_rate = builder.lookup.build_table(yld_data, - key_columns=['sex'], - parameter_columns=['age','year']) - self.yld_rate = builder.value.register_rate_producer('yld_rate', source=yld_rate) - - builder.event.register_listener('time_step', self.on_time_step) - - self.population_view = builder.population.get_view([ - 'bau_yld_rate', 'yld_rate', - 'bau_person_years', 'person_years', - 'bau_HALY', 'HALY']) + yld_data = builder.data.load("cause.all_causes.disability_rate") + yld_rate = builder.lookup.build_table( + yld_data, key_columns=["sex"], parameter_columns=["age", "year"] + ) + self.yld_rate = builder.value.register_rate_producer("yld_rate", source=yld_rate) + + builder.event.register_listener("time_step", self.on_time_step) + + self.population_view = builder.population.get_view( + [ + "bau_yld_rate", + "yld_rate", + "bau_person_years", + "person_years", + "bau_HALY", + "HALY", + ] + ) def on_time_step(self, event): """ @@ -178,7 +209,7 @@ def on_time_step(self, event): def load_population_data(builder): - pop_data = builder.data.load('population.structure') - pop_data = pop_data[['age', 'sex', 'value']].rename(columns={'value': 'population'}) - pop_data['bau_population'] = pop_data['population'] + pop_data = builder.data.load("population.structure") + pop_data = pop_data[["age", "sex", "value"]].rename(columns={"value": "population"}) + pop_data["bau_population"] = pop_data["population"] return pop_data diff --git a/src/vivarium_public_health/population/__init__.py b/src/vivarium_public_health/population/__init__.py index 0dc5e1591..fc9216afd 100644 --- a/src/vivarium_public_health/population/__init__.py +++ b/src/vivarium_public_health/population/__init__.py @@ -1,3 +1,7 @@ +from .add_new_birth_cohorts import ( + FertilityAgeSpecificRates, + FertilityCrudeBirthRate, + FertilityDeterministic, +) from .base_population import BasePopulation, generate_population -from .add_new_birth_cohorts import FertilityAgeSpecificRates, FertilityCrudeBirthRate, FertilityDeterministic from .mortality import Mortality diff --git a/src/vivarium_public_health/population/add_new_birth_cohorts.py b/src/vivarium_public_health/population/add_new_birth_cohorts.py index f671e7074..93f84dab2 100644 --- a/src/vivarium_public_health/population/add_new_birth_cohorts.py +++ b/src/vivarium_public_health/population/add_new_birth_cohorts.py @@ -6,22 +6,24 @@ This module contains several different models of fertility. """ -import pandas as pd import numpy as np +import pandas as pd from vivarium_public_health import utilities -from vivarium_public_health.population.data_transformations import get_live_births_per_year +from vivarium_public_health.population.data_transformations import ( + get_live_births_per_year, +) # TODO: Incorporate better data into gestational model (probably as a separate component) -PREGNANCY_DURATION = pd.Timedelta(days=9*utilities.DAYS_PER_MONTH) +PREGNANCY_DURATION = pd.Timedelta(days=9 * utilities.DAYS_PER_MONTH) class FertilityDeterministic: """Deterministic model of births.""" configuration_defaults = { - 'fertility': { - 'number_of_new_simulants_each_year': 1000, + "fertility": { + "number_of_new_simulants_each_year": 1000, }, } @@ -31,11 +33,13 @@ def name(self): def setup(self, builder): self.fractional_new_births = 0 - self.simulants_per_year = builder.configuration.fertility.number_of_new_simulants_each_year + self.simulants_per_year = ( + builder.configuration.fertility.number_of_new_simulants_each_year + ) self.simulant_creator = builder.population.get_simulant_creator() - builder.event.register_listener('time_step', self.on_time_step) + builder.event.register_listener("time_step", self.on_time_step) def on_time_step(self, event): """Adds a set number of simulants to the population each time step. @@ -47,18 +51,20 @@ def on_time_step(self, event): """ # Assume births are uniformly distributed throughout the year. step_size = utilities.to_years(event.step_size) - simulants_to_add = self.simulants_per_year*step_size + self.fractional_new_births + simulants_to_add = self.simulants_per_year * step_size + self.fractional_new_births self.fractional_new_births = simulants_to_add % 1 simulants_to_add = int(simulants_to_add) if simulants_to_add > 0: - self.simulant_creator(simulants_to_add, - population_configuration={ - 'age_start': 0, - 'age_end': 0, - 'sim_state': 'time_step', - }) + self.simulant_creator( + simulants_to_add, + population_configuration={ + "age_start": 0, + "age_end": 0, + "sim_state": "time_step", + }, + ) def __repr__(self): return "FertilityDeterministic()" @@ -92,9 +98,9 @@ class FertilityCrudeBirthRate: """ configuration_defaults = { - 'fertility': { - 'time_dependent_live_births': True, - 'time_dependent_population_fraction': False, + "fertility": { + "time_dependent_live_births": True, + "time_dependent_population_fraction": False, } } @@ -109,7 +115,7 @@ def setup(self, builder): self.randomness = builder.randomness self.simulant_creator = builder.population.get_simulant_creator() - builder.event.register_listener('time_step', self.on_time_step) + builder.event.register_listener("time_step", self.on_time_step) def on_time_step(self, event): """Adds new simulants every time step based on the Crude Birth Rate @@ -124,16 +130,18 @@ def on_time_step(self, event): mean_births = birth_rate * step_size # Assume births occur as a Poisson process - r = np.random.RandomState(seed=self.randomness.get_seed('crude_birth_rate')) + r = np.random.RandomState(seed=self.randomness.get_seed("crude_birth_rate")) simulants_to_add = r.poisson(mean_births) if simulants_to_add > 0: - self.simulant_creator(simulants_to_add, - population_configuration={ - 'age_start': 0, - 'age_end': 0, - 'sim_state': 'time_step', - }) + self.simulant_creator( + simulants_to_add, + population_configuration={ + "age_start": 0, + "age_end": 0, + "sim_state": "time_step", + }, + ) def __repr__(self): return "FertilityCrudeBirthRate()" @@ -146,10 +154,10 @@ class FertilityAgeSpecificRates: @property def name(self): - return 'age_specific_fertility' + return "age_specific_fertility" def setup(self, builder): - """ Setup the common randomness stream and + """Setup the common randomness stream and age-specific fertility lookup tables. Parameters ---------- @@ -157,37 +165,47 @@ def setup(self, builder): Framework coordination object. """ age_specific_fertility_rate = self.load_age_specific_fertility_rate_data(builder) - fertility_rate = builder.lookup.build_table(age_specific_fertility_rate, parameter_columns=['age', 'year']) - self.fertility_rate = builder.value.register_rate_producer('fertility rate', - source=fertility_rate, - requires_columns=['age']) - - self.randomness = builder.randomness.get_stream('fertility') - - self.population_view = builder.population.get_view(['last_birth_time', 'sex', 'parent_id']) + fertility_rate = builder.lookup.build_table( + age_specific_fertility_rate, parameter_columns=["age", "year"] + ) + self.fertility_rate = builder.value.register_rate_producer( + "fertility rate", source=fertility_rate, requires_columns=["age"] + ) + + self.randomness = builder.randomness.get_stream("fertility") + + self.population_view = builder.population.get_view( + ["last_birth_time", "sex", "parent_id"] + ) self.simulant_creator = builder.population.get_simulant_creator() - builder.population.initializes_simulants(self.on_initialize_simulants, - creates_columns=['last_birth_time', 'parent_id'], - requires_columns=['sex']) + builder.population.initializes_simulants( + self.on_initialize_simulants, + creates_columns=["last_birth_time", "parent_id"], + requires_columns=["sex"], + ) - builder.event.register_listener('time_step', self.on_time_step) + builder.event.register_listener("time_step", self.on_time_step) def on_initialize_simulants(self, pop_data): - """ Adds 'last_birth_time' and 'parent' columns to the state table.""" - pop = self.population_view.subview(['sex']).get(pop_data.index) - women = pop.loc[pop.sex == 'Female'].index + """Adds 'last_birth_time' and 'parent' columns to the state table.""" + pop = self.population_view.subview(["sex"]).get(pop_data.index) + women = pop.loc[pop.sex == "Female"].index - if pop_data.user_data['sim_state'] == 'setup': + if pop_data.user_data["sim_state"] == "setup": parent_id = -1 else: # 'sim_state' == 'time_step' - parent_id = pop_data.user_data['parent_ids'] - pop_update = pd.DataFrame({'last_birth_time': pd.NaT, 'parent_id': parent_id}, index=pop_data.index) + parent_id = pop_data.user_data["parent_ids"] + pop_update = pd.DataFrame( + {"last_birth_time": pd.NaT, "parent_id": parent_id}, index=pop_data.index + ) # FIXME: This is a misuse of the column and makes it invalid for # tracking metrics. # Do the naive thing, set so all women can have children # and none of them have had a child in the last year. - pop_update.loc[women, 'last_birth_time'] = pop_data.creation_time - pd.Timedelta(days=utilities.DAYS_PER_YEAR) + pop_update.loc[women, "last_birth_time"] = pop_data.creation_time - pd.Timedelta( + days=utilities.DAYS_PER_YEAR + ) self.population_view.update(pop_update) @@ -200,32 +218,36 @@ def on_time_step(self, event): """ # Get a view on all living women who haven't had a child in at least nine months. nine_months_ago = pd.Timestamp(event.time - PREGNANCY_DURATION) - population = self.population_view.get(event.index, query='alive == "alive" and sex =="Female"') + population = self.population_view.get( + event.index, query='alive == "alive" and sex =="Female"' + ) can_have_children = population.last_birth_time < nine_months_ago eligible_women = population[can_have_children] rate_series = self.fertility_rate(eligible_women.index) had_children = self.randomness.filter_for_rate(eligible_women, rate_series).copy() - had_children.loc[:, 'last_birth_time'] = event.time - self.population_view.update(had_children['last_birth_time']) + had_children.loc[:, "last_birth_time"] = event.time + self.population_view.update(had_children["last_birth_time"]) # If children were born, add them to the state table and record # who their mother was. num_babies = len(had_children) if num_babies: - self.simulant_creator(num_babies, - population_configuration={ - 'age_start': 0, - 'age_end': 0, - 'sim_state': 'time_step', - 'parent_ids': had_children.index - }) + self.simulant_creator( + num_babies, + population_configuration={ + "age_start": 0, + "age_end": 0, + "sim_state": "time_step", + "parent_ids": had_children.index, + }, + ) def load_age_specific_fertility_rate_data(self, builder): asfr_data = builder.data.load("covariate.age_specific_fertility_rate.estimate") - columns = ['year_start', 'year_end', 'age_start', 'age_end', 'mean_value'] - asfr_data = asfr_data.loc[asfr_data.sex == 'Female'][columns] + columns = ["year_start", "year_end", "age_start", "age_end", "mean_value"] + asfr_data = asfr_data.loc[asfr_data.sex == "Female"][columns] return asfr_data def __repr__(self): diff --git a/src/vivarium_public_health/population/base_population.py b/src/vivarium_public_health/population/base_population.py index d7276492b..41c591518 100644 --- a/src/vivarium_public_health/population/base_population.py +++ b/src/vivarium_public_health/population/base_population.py @@ -7,23 +7,26 @@ characteristics to simulants. """ -import pandas as pd import numpy as np +import pandas as pd from vivarium_public_health import utilities -from vivarium_public_health.population.data_transformations import (assign_demographic_proportions, - rescale_binned_proportions, - smooth_ages, load_population_structure) +from vivarium_public_health.population.data_transformations import ( + assign_demographic_proportions, + load_population_structure, + rescale_binned_proportions, + smooth_ages, +) class BasePopulation: """Component for producing and aging simulants based on demographic data.""" configuration_defaults = { - 'population': { - 'age_start': 0, - 'age_end': 125, - 'exit_age': None, + "population": { + "age_start": 0, + "age_end": 125, + "exit_age": None, } } @@ -42,30 +45,40 @@ def sub_components(self): def setup(self, builder): self.config = builder.configuration.population - self.randomness = {'general_purpose': builder.randomness.get_stream('population_generation'), - 'bin_selection': builder.randomness.get_stream('bin_selection', for_initialization=True), - 'age_smoothing': builder.randomness.get_stream('age_smoothing', for_initialization=True), - 'age_smoothing_age_bounds': builder.randomness.get_stream('age_smoothing_age_bounds', - for_initialization=True)} + self.randomness = { + "general_purpose": builder.randomness.get_stream("population_generation"), + "bin_selection": builder.randomness.get_stream( + "bin_selection", for_initialization=True + ), + "age_smoothing": builder.randomness.get_stream( + "age_smoothing", for_initialization=True + ), + "age_smoothing_age_bounds": builder.randomness.get_stream( + "age_smoothing_age_bounds", for_initialization=True + ), + } self.register_simulants = builder.randomness.register_simulants - columns = ['age', 'sex', 'alive', 'location', 'entrance_time', 'exit_time'] + columns = ["age", "sex", "alive", "location", "entrance_time", "exit_time"] self.population_view = builder.population.get_view(columns) - builder.population.initializes_simulants(self.generate_base_population, - creates_columns=columns) + builder.population.initializes_simulants( + self.generate_base_population, creates_columns=columns + ) source_population_structure = load_population_structure(builder) self.population_data = _build_population_data_table(source_population_structure) - builder.event.register_listener('time_step', self.on_time_step, priority=8) + builder.event.register_listener("time_step", self.on_time_step, priority=8) @staticmethod def select_sub_population_data(reference_population_data, year): reference_years = sorted(set(reference_population_data.year_start)) - ref_year_index = np.digitize(year, reference_years).item()-1 - return reference_population_data[reference_population_data.year_start == reference_years[ref_year_index]] + ref_year_index = np.digitize(year, reference_years).item() - 1 + return reference_population_data[ + reference_population_data.year_start == reference_years[ref_year_index] + ] # TODO: Move most of this docstring to an rst file. def generate_base_population(self, pop_data): @@ -91,18 +104,26 @@ def generate_base_population(self, pop_data): pop_data """ - age_params = {'age_start': pop_data.user_data.get('age_start', self.config.age_start), - 'age_end': pop_data.user_data.get('age_end', self.config.age_end)} - - sub_pop_data = self.select_sub_population_data(self.population_data, pop_data.creation_time.year) + age_params = { + "age_start": pop_data.user_data.get("age_start", self.config.age_start), + "age_end": pop_data.user_data.get("age_end", self.config.age_end), + } - self.population_view.update(generate_population(simulant_ids=pop_data.index, - creation_time=pop_data.creation_time, - step_size=pop_data.creation_window, - age_params=age_params, - population_data=sub_pop_data, - randomness_streams=self.randomness, - register_simulants=self.register_simulants)) + sub_pop_data = self.select_sub_population_data( + self.population_data, pop_data.creation_time.year + ) + + self.population_view.update( + generate_population( + simulant_ids=pop_data.index, + creation_time=pop_data.creation_time, + step_size=pop_data.creation_window, + age_params=age_params, + population_data=sub_pop_data, + randomness_streams=self.randomness, + register_simulants=self.register_simulants, + ) + ) def on_time_step(self, event): """Ages simulants each time step. @@ -113,7 +134,7 @@ def on_time_step(self, event): """ population = self.population_view.get(event.index, query="alive == 'alive'") - population['age'] += utilities.to_years(event.step_size) + population["age"] += utilities.to_years(event.step_size) self.population_view.update(population) def __repr__(self): @@ -133,24 +154,31 @@ def setup(self, builder): if builder.configuration.population.exit_age is None: return self.config = builder.configuration.population - self.population_view = builder.population.get_view(['age', 'exit_time', 'tracked']) - builder.event.register_listener('time_step__cleanup', self.on_time_step_cleanup) + self.population_view = builder.population.get_view(["age", "exit_time", "tracked"]) + builder.event.register_listener("time_step__cleanup", self.on_time_step_cleanup) def on_time_step_cleanup(self, event): population = self.population_view.get(event.index) max_age = float(self.config.exit_age) - pop = population[(population['age'] >= max_age) & population['tracked']].copy() + pop = population[(population["age"] >= max_age) & population["tracked"]].copy() if len(pop) > 0: - pop['tracked'] = pd.Series(False, index=pop.index) - pop['exit_time'] = event.time + pop["tracked"] = pd.Series(False, index=pop.index) + pop["exit_time"] = event.time self.population_view.update(pop) def __repr__(self): return "AgeOutSimulants()" -def generate_population(simulant_ids, creation_time, step_size, age_params, - population_data, randomness_streams, register_simulants): +def generate_population( + simulant_ids, + creation_time, + step_size, + age_params, + population_data, + randomness_streams, + register_simulants, +): """Produces a randomly generated set of simulants sampled from the provided `population_data`. Parameters @@ -196,22 +224,39 @@ def generate_population(simulant_ids, creation_time, step_size, age_params, Either 'Male' or 'Female'. The sex of the simulant. """ - simulants = pd.DataFrame({'entrance_time': pd.Series(creation_time, index=simulant_ids), - 'exit_time': pd.Series(pd.NaT, index=simulant_ids), - 'alive': pd.Series('alive', index=simulant_ids)}, - index=simulant_ids) - age_start = float(age_params['age_start']) - age_end = float(age_params['age_end']) + simulants = pd.DataFrame( + { + "entrance_time": pd.Series(creation_time, index=simulant_ids), + "exit_time": pd.Series(pd.NaT, index=simulant_ids), + "alive": pd.Series("alive", index=simulant_ids), + }, + index=simulant_ids, + ) + age_start = float(age_params["age_start"]) + age_end = float(age_params["age_end"]) if age_start == age_end: - return _assign_demography_with_initial_age(simulants, population_data, age_start, - step_size, randomness_streams, register_simulants) + return _assign_demography_with_initial_age( + simulants, + population_data, + age_start, + step_size, + randomness_streams, + register_simulants, + ) else: # age_params['age_start'] is not None and age_params['age_end'] is not None - return _assign_demography_with_age_bounds(simulants, population_data, age_start, - age_end, randomness_streams, register_simulants) - - -def _assign_demography_with_initial_age(simulants, pop_data, initial_age, - step_size, randomness_streams, register_simulants): + return _assign_demography_with_age_bounds( + simulants, + population_data, + age_start, + age_end, + randomness_streams, + register_simulants, + ) + + +def _assign_demography_with_initial_age( + simulants, pop_data, initial_age, step_size, randomness_streams, register_simulants +): """Assigns age, sex, and location information to the provided simulants given a fixed age. Parameters @@ -235,28 +280,40 @@ def _assign_demography_with_initial_age(simulants, pop_data, initial_age, pandas.DataFrame Table with same columns as `simulants` and with the additional columns 'age', 'sex', and 'location'. """ - pop_data = pop_data[(pop_data.age_start <= initial_age) & (pop_data.age_end >= initial_age)] + pop_data = pop_data[ + (pop_data.age_start <= initial_age) & (pop_data.age_end >= initial_age) + ] if pop_data.empty: - raise ValueError('The age {} is not represented by the population data structure'.format(initial_age)) + raise ValueError( + "The age {} is not represented by the population data structure".format( + initial_age + ) + ) - age_fuzz = randomness_streams['age_smoothing'].get_draw(simulants.index) * utilities.to_years(step_size) - simulants['age'] = initial_age + age_fuzz - register_simulants(simulants[['entrance_time', 'age']]) + age_fuzz = randomness_streams["age_smoothing"].get_draw( + simulants.index + ) * utilities.to_years(step_size) + simulants["age"] = initial_age + age_fuzz + register_simulants(simulants[["entrance_time", "age"]]) # Assign a demographically accurate location and sex distribution. - choices = pop_data.set_index(['sex', 'location'])['P(sex, location | age, year)'].reset_index() - decisions = randomness_streams['general_purpose'].choice(simulants.index, - choices=choices.index, - p=choices['P(sex, location | age, year)']) + choices = pop_data.set_index(["sex", "location"])[ + "P(sex, location | age, year)" + ].reset_index() + decisions = randomness_streams["general_purpose"].choice( + simulants.index, choices=choices.index, p=choices["P(sex, location | age, year)"] + ) - simulants['sex'] = choices.loc[decisions, 'sex'].values - simulants['location'] = choices.loc[decisions, 'location'].values + simulants["sex"] = choices.loc[decisions, "sex"].values + simulants["location"] = choices.loc[decisions, "location"].values return simulants -def _assign_demography_with_age_bounds(simulants, pop_data, age_start, age_end, randomness_streams, register_simulants): +def _assign_demography_with_age_bounds( + simulants, pop_data, age_start, age_end, randomness_streams, register_simulants +): """Assigns age, sex, and location information to the provided simulants given a range of ages. Parameters @@ -281,19 +338,26 @@ def _assign_demography_with_age_bounds(simulants, pop_data, age_start, age_end, pop_data = rescale_binned_proportions(pop_data, age_start, age_end) if pop_data.empty: raise ValueError( - 'The age range ({}, {}) is not represented by the population data structure'.format(age_start, age_end)) + "The age range ({}, {}) is not represented by the population data structure".format( + age_start, age_end + ) + ) # Assign a demographically accurate age, location, and sex distribution. sub_pop_data = pop_data[(pop_data.age_start >= age_start) & (pop_data.age_end <= age_end)] - choices = sub_pop_data.set_index(['age', 'sex', 'location'])['P(sex, location, age| year)'].reset_index() - decisions = randomness_streams['bin_selection'].choice(simulants.index, - choices=choices.index, - p=choices['P(sex, location, age| year)']) - simulants['age'] = choices.loc[decisions, 'age'].values - simulants['sex'] = choices.loc[decisions, 'sex'].values - simulants['location'] = choices.loc[decisions, 'location'].values - simulants = smooth_ages(simulants, pop_data, randomness_streams['age_smoothing_age_bounds']) - register_simulants(simulants[['entrance_time', 'age']]) + choices = sub_pop_data.set_index(["age", "sex", "location"])[ + "P(sex, location, age| year)" + ].reset_index() + decisions = randomness_streams["bin_selection"].choice( + simulants.index, choices=choices.index, p=choices["P(sex, location, age| year)"] + ) + simulants["age"] = choices.loc[decisions, "age"].values + simulants["sex"] = choices.loc[decisions, "sex"].values + simulants["location"] = choices.loc[decisions, "location"].values + simulants = smooth_ages( + simulants, pop_data, randomness_streams["age_smoothing_age_bounds"] + ) + register_simulants(simulants[["entrance_time", "age"]]) return simulants diff --git a/src/vivarium_public_health/population/data_transformations.py b/src/vivarium_public_health/population/data_transformations.py index 7e985af24..584c5790a 100644 --- a/src/vivarium_public_health/population/data_transformations.py +++ b/src/vivarium_public_health/population/data_transformations.py @@ -30,25 +30,25 @@ def assign_demographic_proportions(population_data): various population levels. """ - population_data['P(sex, location, age| year)'] = ( - population_data - .groupby('year_start', as_index=False) - .apply(lambda sub_pop: sub_pop.value / sub_pop.value.sum()) - .reset_index(level=0).value + population_data["P(sex, location, age| year)"] = ( + population_data.groupby("year_start", as_index=False) + .apply(lambda sub_pop: sub_pop.value / sub_pop.value.sum()) + .reset_index(level=0) + .value ) - population_data['P(sex, location | age, year)'] = ( - population_data - .groupby(['age', 'year_start'], as_index=False) - .apply(lambda sub_pop: sub_pop.value / sub_pop.value.sum()) - .reset_index(level=0).value + population_data["P(sex, location | age, year)"] = ( + population_data.groupby(["age", "year_start"], as_index=False) + .apply(lambda sub_pop: sub_pop.value / sub_pop.value.sum()) + .reset_index(level=0) + .value ) - population_data['P(age | year, sex, location)'] = ( - population_data - .groupby(['year_start', 'sex', 'location'], as_index=False) - .apply(lambda sub_pop: sub_pop.value / sub_pop.value.sum()) - .reset_index(level=0).value + population_data["P(age | year, sex, location)"] = ( + population_data.groupby(["year_start", "sex", "location"], as_index=False) + .apply(lambda sub_pop: sub_pop.value / sub_pop.value.sum()) + .reset_index(level=0) + .value ) return population_data @@ -81,40 +81,48 @@ def rescale_binned_proportions(pop_data, age_start, age_end): values are rescaled to reflect their smaller representation. """ if age_start > pop_data.age_end.max(): - raise ValueError('Provided population data is insufficient to model the requested age range.') + raise ValueError( + "Provided population data is insufficient to model the requested age range." + ) age_start = max(pop_data.age_start.min(), age_start) age_end = min(pop_data.age_end.max(), age_end) - 1e-8 pop_data = _add_edge_age_groups(pop_data.copy()) - columns_to_scale = ['P(sex, location, age| year)', 'P(age | year, sex, location)', 'value'] - for _, sub_pop in pop_data.groupby(['sex', 'location']): + columns_to_scale = [ + "P(sex, location, age| year)", + "P(age | year, sex, location)", + "value", + ] + for _, sub_pop in pop_data.groupby(["sex", "location"]): min_bin = sub_pop[(sub_pop.age_start <= age_start) & (age_start < sub_pop.age_end)] padding_bin = sub_pop[sub_pop.age_end == float(min_bin.age_start)] - min_scale = ((float(min_bin.age_end) - age_start) - / float(min_bin.age_end - min_bin.age_start)) + min_scale = (float(min_bin.age_end) - age_start) / float( + min_bin.age_end - min_bin.age_start + ) remainder = pop_data.loc[min_bin.index, columns_to_scale].values * (1 - min_scale) pop_data.loc[min_bin.index, columns_to_scale] *= min_scale pop_data.loc[padding_bin.index, columns_to_scale] += remainder - pop_data.loc[min_bin.index, 'age_start'] = age_start - pop_data.loc[padding_bin.index, 'age_end'] = age_start + pop_data.loc[min_bin.index, "age_start"] = age_start + pop_data.loc[padding_bin.index, "age_end"] = age_start max_bin = sub_pop[(sub_pop.age_end > age_end) & (age_end >= sub_pop.age_start)] padding_bin = sub_pop[sub_pop.age_start == float(max_bin.age_end)] - max_scale = ((age_end - float(max_bin.age_start)) - / float(max_bin.age_end - max_bin.age_start)) + max_scale = (age_end - float(max_bin.age_start)) / float( + max_bin.age_end - max_bin.age_start + ) remainder = pop_data.loc[max_bin.index, columns_to_scale] * (1 - max_scale) pop_data.loc[max_bin.index, columns_to_scale] *= max_scale pop_data.loc[padding_bin.index, columns_to_scale] += remainder.values - pop_data.loc[max_bin.index, 'age_end'] = age_end - pop_data.loc[padding_bin.index, 'age_start'] = age_end + pop_data.loc[max_bin.index, "age_end"] = age_end + pop_data.loc[padding_bin.index, "age_start"] = age_end return pop_data @@ -132,46 +140,57 @@ def _add_edge_age_groups(pop_data): ------- pandas.DataFrame """ - index_cols = ['location', 'year_start', 'year_end', 'sex'] - age_cols = ['age', 'age_start', 'age_end'] + index_cols = ["location", "year_start", "year_end", "sex"] + age_cols = ["age", "age_start", "age_end"] other_cols = [c for c in pop_data.columns if c not in index_cols + age_cols] pop_data = pop_data.set_index(index_cols) # For the lower bin, we want constant interpolation off the left side - min_valid_age = pop_data['age_start'].min() + min_valid_age = pop_data["age_start"].min() # This bin width needs to be the same as the lowest bin. - min_pad_age = min_valid_age - (pop_data['age_end'].min() - min_valid_age) + min_pad_age = min_valid_age - (pop_data["age_end"].min() - min_valid_age) min_pad_age_midpoint = (min_valid_age + min_pad_age) * 0.5 - lower_bin = pd.DataFrame({'age_start': min_pad_age, - 'age_end': min_valid_age, - 'age': min_pad_age_midpoint}, index=pop_data.index.unique()) - lower_bin[other_cols] = pop_data.loc[pop_data['age_start'] == min_valid_age, other_cols] + lower_bin = pd.DataFrame( + {"age_start": min_pad_age, "age_end": min_valid_age, "age": min_pad_age_midpoint}, + index=pop_data.index.unique(), + ) + lower_bin[other_cols] = pop_data.loc[pop_data["age_start"] == min_valid_age, other_cols] # For the upper bin, we want our interpolation to go to zero. - max_valid_age = pop_data['age_end'].max() + max_valid_age = pop_data["age_end"].max() # This bin width is not arbitrary. It effects the rate at which our interpolation zeros out. # Since for the 2016 round the maximum age is 125, we're assuming almost no one lives past that age, # so we make this bin 1 week. A more robust technique for this would be better. - max_pad_age = max_valid_age + 7/365 + max_pad_age = max_valid_age + 7 / 365 max_pad_age_midpoint = (max_valid_age + max_pad_age) * 0.5 - upper_bin = pd.DataFrame({'age_start': max_valid_age, - 'age_end': max_pad_age, - 'age': max_pad_age_midpoint}, index=pop_data.index.unique()) + upper_bin = pd.DataFrame( + {"age_start": max_valid_age, "age_end": max_pad_age, "age": max_pad_age_midpoint}, + index=pop_data.index.unique(), + ) # We're doing the multiplication to ensure we get the correct data shape and index. - upper_bin[other_cols] = 0 * pop_data.loc[pop_data['age_end'] == max_valid_age, other_cols] + upper_bin[other_cols] = 0 * pop_data.loc[pop_data["age_end"] == max_valid_age, other_cols] pop_data = pd.concat([lower_bin, pop_data, upper_bin], sort=False).reset_index() - pop_data = pop_data.rename(columns={'level_0': 'location', 'level_1': 'year_start', - 'level_2': 'year_end', 'level_3': 'sex'}) - return pop_data[index_cols + age_cols + other_cols].sort_values( - by=['location', 'year_start', 'year_end', 'age']).reset_index(drop=True) + pop_data = pop_data.rename( + columns={ + "level_0": "location", + "level_1": "year_start", + "level_2": "year_end", + "level_3": "sex", + } + ) + return ( + pop_data[index_cols + age_cols + other_cols] + .sort_values(by=["location", "year_start", "year_end", "age"]) + .reset_index(drop=True) + ) -AgeValues = namedtuple('AgeValues', ['current', 'young', 'old']) -EndpointValues = namedtuple('EndpointValues', ['left', 'right']) +AgeValues = namedtuple("AgeValues", ["current", "young", "old"]) +EndpointValues = namedtuple("EndpointValues", ["left", "right"]) def smooth_ages(simulants, population_data, randomness): @@ -194,19 +213,22 @@ def smooth_ages(simulants, population_data, randomness): Table with same columns as `simulants` with ages smoothed out within the age bins. """ simulants = simulants.copy() - for (sex, location), sub_pop in population_data.groupby(['sex', 'location']): + for (sex, location), sub_pop in population_data.groupby(["sex", "location"]): ages = sorted(sub_pop.age.unique()) - younger = [float(sub_pop.loc[sub_pop.age == ages[0], 'age_start'])] + ages[:-1] - older = ages[1:] + [float(sub_pop.loc[sub_pop.age == ages[-1], 'age_end'])] + younger = [float(sub_pop.loc[sub_pop.age == ages[0], "age_start"])] + ages[:-1] + older = ages[1:] + [float(sub_pop.loc[sub_pop.age == ages[-1], "age_end"])] uniform_all = randomness.get_draw(simulants.index) for age_set in zip(ages, younger, older): age = AgeValues(*age_set) - has_correct_demography = ((simulants.age == age.current) - & (simulants.sex == sex) & (simulants.location == location)) + has_correct_demography = ( + (simulants.age == age.current) + & (simulants.sex == sex) + & (simulants.location == location) + ) affected = simulants[has_correct_demography] if affected.empty: @@ -214,7 +236,9 @@ def smooth_ages(simulants, population_data, randomness): # bin endpoints endpoints, proportions = _get_bins_and_proportions(sub_pop, age) - pdf, slope, area, cdf_inflection_point = _construct_sampling_parameters(age, endpoints, proportions) + pdf, slope, area, cdf_inflection_point = _construct_sampling_parameters( + age, endpoints, proportions + ) # Make a draw from a uniform distribution uniform_rv = uniform_all.loc[affected.index] @@ -222,10 +246,16 @@ def smooth_ages(simulants, population_data, randomness): left_sims = affected[uniform_rv <= cdf_inflection_point] right_sims = affected[uniform_rv > cdf_inflection_point] - simulants.loc[left_sims.index, 'age'] = _compute_ages(uniform_rv[left_sims.index], - endpoints.left, pdf.left, slope.left, area) - simulants.loc[right_sims.index, 'age'] = _compute_ages(uniform_rv[right_sims.index] - cdf_inflection_point, - age.current, proportions.current, slope.right, area) + simulants.loc[left_sims.index, "age"] = _compute_ages( + uniform_rv[left_sims.index], endpoints.left, pdf.left, slope.left, area + ) + simulants.loc[right_sims.index, "age"] = _compute_ages( + uniform_rv[right_sims.index] - cdf_inflection_point, + age.current, + proportions.current, + slope.right, + area, + ) return simulants @@ -251,26 +281,40 @@ def _get_bins_and_proportions(pop_data, age): The `AgeValues` tuple has values (proportion of pop in current bin, proportion of pop in previous bin, proportion of pop in next bin) """ - left = float(pop_data.loc[pop_data.age == age.current, 'age_start']) - right = float(pop_data.loc[pop_data.age == age.current, 'age_end']) + left = float(pop_data.loc[pop_data.age == age.current, "age_start"]) + right = float(pop_data.loc[pop_data.age == age.current, "age_end"]) - if not pop_data.loc[pop_data.age == age.young, 'age_start'].empty: - lower_left = float(pop_data.loc[pop_data.age == age.young, 'age_start']) + if not pop_data.loc[pop_data.age == age.young, "age_start"].empty: + lower_left = float(pop_data.loc[pop_data.age == age.young, "age_start"]) else: lower_left = left - if not pop_data.loc[pop_data.age == age.old, 'age_end'].empty: - upper_right = float(pop_data.loc[pop_data.age == age.old, 'age_end']) + if not pop_data.loc[pop_data.age == age.old, "age_end"].empty: + upper_right = float(pop_data.loc[pop_data.age == age.old, "age_end"]) else: upper_right = right # proportion in this bin and the neighboring bins - proportion_column = 'P(age | year, sex, location)' + proportion_column = "P(age | year, sex, location)" # Here we make the assumption that P(left < age < right | year, sex, location) = p * (right - left) # in order to back out a point estimate for the probability density at the center of the interval. # This not the best assumption, but it'll do. - p_age = float(pop_data.loc[pop_data.age == age.current, proportion_column]/(right - left)) - p_young = float(pop_data.loc[pop_data.age == age.young, proportion_column]/(left - lower_left)) if age.young != left else p_age - p_old = float(pop_data.loc[pop_data.age == age.old, proportion_column]/(upper_right - right)) if age.old != right else 0 + p_age = float( + pop_data.loc[pop_data.age == age.current, proportion_column] / (right - left) + ) + p_young = ( + float( + pop_data.loc[pop_data.age == age.young, proportion_column] / (left - lower_left) + ) + if age.young != left + else p_age + ) + p_old = ( + float( + pop_data.loc[pop_data.age == age.old, proportion_column] / (upper_right - right) + ) + if age.old != right + else 0 + ) return EndpointValues(left, right), AgeValues(p_age, p_young, p_old) @@ -302,12 +346,16 @@ def _construct_sampling_parameters(age, endpoint, proportion): """ # pdf value at bin endpoints - pdf_left = ((proportion.current - proportion.young) / (age.current - age.young) - * (endpoint.left - age.young) + proportion.young) - pdf_right = ((proportion.old - proportion.current) / (age.old - age.current) - * (endpoint.right - age.current) + proportion.current) - area = 0.5 * ((proportion.current + pdf_left) * (age.current - endpoint.left) - + (pdf_right + proportion.current) * (endpoint.right - age.current)) + pdf_left = (proportion.current - proportion.young) / (age.current - age.young) * ( + endpoint.left - age.young + ) + proportion.young + pdf_right = (proportion.old - proportion.current) / (age.old - age.current) * ( + endpoint.right - age.current + ) + proportion.current + area = 0.5 * ( + (proportion.current + pdf_left) * (age.current - endpoint.left) + + (pdf_right + proportion.current) * (endpoint.right - age.current) + ) pdf = EndpointValues(pdf_left, pdf_right) @@ -317,7 +365,9 @@ def _construct_sampling_parameters(age, endpoint, proportion): slope = EndpointValues(m_left, m_right) # The decision bound on the uniform rv. - cdf_inflection_point = 1 / (2 * area) * (proportion.current + pdf.left) * (age.current - endpoint.left) + cdf_inflection_point = ( + 1 / (2 * area) * (proportion.current + pdf.left) * (age.current - endpoint.left) + ) return pdf, slope, area, cdf_inflection_point @@ -350,25 +400,31 @@ def _compute_ages(uniform_rv, start, height, slope, normalization): if abs(slope) < np.finfo(np.float32).eps: return start + normalization / height * uniform_rv else: - return start + height / slope * (np.sqrt(1 + 2 * normalization * slope / height ** 2 * uniform_rv) - 1) + return start + height / slope * ( + np.sqrt(1 + 2 * normalization * slope / height**2 * uniform_rv) - 1 + ) def get_cause_deleted_mortality_rate(all_cause_mortality_rate, list_of_csmrs): - index_cols = ['age_start', 'age_end', 'sex', 'year_start', 'year_end'] + index_cols = ["age_start", "age_end", "sex", "year_start", "year_end"] all_cause_mortality_rate = all_cause_mortality_rate.set_index(index_cols).copy() for csmr in list_of_csmrs: if csmr is None: continue - all_cause_mortality_rate = all_cause_mortality_rate.subtract(csmr.set_index(index_cols)).dropna() + all_cause_mortality_rate = all_cause_mortality_rate.subtract( + csmr.set_index(index_cols) + ).dropna() - return all_cause_mortality_rate.reset_index().rename(columns={'value': 'death_due_to_other_causes'}) + return all_cause_mortality_rate.reset_index().rename( + columns={"value": "death_due_to_other_causes"} + ) def load_population_structure(builder): data = builder.data.load("population.structure") # create an age column which is the midpoint of the age group - data['age'] = data.apply(lambda row: (row['age_start'] + row['age_end']) / 2, axis=1) - data['location'] = builder.data.load('population.location') + data["age"] = data.apply(lambda row: (row["age_start"] + row["age_end"]) / 2, axis=1) + data["location"] = builder.data.load("population.location") return data @@ -380,13 +436,19 @@ def get_live_births_per_year(builder): population_data = rescale_final_age_bin(builder, population_data) initial_population_size = builder.configuration.population.population_size - population_data = population_data.groupby(['year_start'])['value'].sum() - birth_data = (birth_data[birth_data.parameter == 'mean_value'] - .drop('parameter', 'columns') - .groupby(['year_start'])['value'].sum()) + population_data = population_data.groupby(["year_start"])["value"].sum() + birth_data = ( + birth_data[birth_data.parameter == "mean_value"] + .drop("parameter", "columns") + .groupby(["year_start"])["value"] + .sum() + ) start_year = builder.configuration.time.start.year - if builder.configuration.interpolation.extrapolate and start_year > birth_data.index.max(): + if ( + builder.configuration.interpolation.extrapolate + and start_year > birth_data.index.max() + ): start_year = birth_data.index.max() if not builder.configuration.fertility.time_dependent_live_births: @@ -398,41 +460,54 @@ def get_live_births_per_year(builder): live_birth_rate = initial_population_size / population_data * birth_data if isinstance(live_birth_rate, (int, float)): - live_birth_rate = pd.Series(live_birth_rate, index=pd.RangeIndex(builder.configuration.time.start.year, - builder.configuration.time.end.year + 1, - name='year')) + live_birth_rate = pd.Series( + live_birth_rate, + index=pd.RangeIndex( + builder.configuration.time.start.year, + builder.configuration.time.end.year + 1, + name="year", + ), + ) else: - live_birth_rate = (live_birth_rate - .reset_index() - .rename(columns={'year_start': 'year'}) - .set_index('year') - .value) + live_birth_rate = ( + live_birth_rate.reset_index() + .rename(columns={"year_start": "year"}) + .set_index("year") + .value + ) exceeds_data = builder.configuration.time.end.year > live_birth_rate.index.max() if exceeds_data: - new_index = pd.RangeIndex(live_birth_rate.index.min(), builder.configuration.time.end.year + 1) - live_birth_rate = live_birth_rate.reindex(new_index, - fill_value=live_birth_rate.at[live_birth_rate.index.max()]) + new_index = pd.RangeIndex( + live_birth_rate.index.min(), builder.configuration.time.end.year + 1 + ) + live_birth_rate = live_birth_rate.reindex( + new_index, fill_value=live_birth_rate.at[live_birth_rate.index.max()] + ) return live_birth_rate def rescale_final_age_bin(builder, population_data): - exit_age = builder.configuration.population.to_dict().get('exit_age', None) + exit_age = builder.configuration.population.to_dict().get("exit_age", None) if exit_age: - population_data = population_data.loc[population_data['age_start'] < exit_age].copy() - cut_bin_idx = (exit_age <= population_data['age_end']) - cut_age_start = population_data.loc[cut_bin_idx, 'age_start'] - cut_age_end = population_data.loc[cut_bin_idx, 'age_end'] - population_data.loc[cut_bin_idx, 'value'] *= ((exit_age - cut_age_start) / (cut_age_end - cut_age_start)) - population_data.loc[cut_bin_idx, 'age_end'] = exit_age + population_data = population_data.loc[population_data["age_start"] < exit_age].copy() + cut_bin_idx = exit_age <= population_data["age_end"] + cut_age_start = population_data.loc[cut_bin_idx, "age_start"] + cut_age_end = population_data.loc[cut_bin_idx, "age_end"] + population_data.loc[cut_bin_idx, "value"] *= (exit_age - cut_age_start) / ( + cut_age_end - cut_age_start + ) + population_data.loc[cut_bin_idx, "age_end"] = exit_age return population_data def validate_crude_birth_rate_data(builder, data_year_max): - exit_age = builder.configuration.population.to_dict().get('exit_age', None) + exit_age = builder.configuration.population.to_dict().get("exit_age", None) if exit_age and builder.configuration.population.age_end != exit_age: - raise ValueError('If you specify an exit age, the initial population age end must be the same ' - 'for the crude birth rate calculation to work.') + raise ValueError( + "If you specify an exit age, the initial population age end must be the same " + "for the crude birth rate calculation to work." + ) exceeds_data = builder.configuration.time.end.year > data_year_max if exceeds_data and not builder.configuration.interpolation.extrapolate: - raise ValueError('Trying to extrapolate beyond the end of available birth data.') + raise ValueError("Trying to extrapolate beyond the end of available birth data.") diff --git a/src/vivarium_public_health/population/mortality.py b/src/vivarium_public_health/population/mortality.py index 028a0a736..598fc430c 100644 --- a/src/vivarium_public_health/population/mortality.py +++ b/src/vivarium_public_health/population/mortality.py @@ -10,7 +10,6 @@ from typing import Callable import pandas as pd - from vivarium.framework.engine import Builder from vivarium.framework.event import Event from vivarium.framework.lookup import LookupTable @@ -22,13 +21,12 @@ class Mortality: - def __init__(self): - self._randomness_stream_name = 'mortality_handler' - self.cause_specific_mortality_rate_pipeline_name = 'cause_specific_mortality_rate' - self.mortality_rate_pipeline_name = 'mortality_rate' - self.cause_of_death_column_name = 'cause_of_death' - self.years_of_life_lost_column_name = 'years_of_life_lost' + self._randomness_stream_name = "mortality_handler" + self.cause_specific_mortality_rate_pipeline_name = "cause_specific_mortality_rate" + self.mortality_rate_pipeline_name = "mortality_rate" + self.cause_of_death_column_name = "cause_of_death" + self.years_of_life_lost_column_name = "years_of_life_lost" def __repr__(self) -> str: return f"Mortality()" @@ -39,7 +37,7 @@ def __repr__(self) -> str: @property def name(self) -> str: - return 'mortality' + return "mortality" ################# # Setup methods # @@ -68,37 +66,54 @@ def _get_clock(self, builder: Builder) -> Callable[[], Time]: # noinspection PyMethodMayBeStatic def _get_all_cause_mortality_rate(self, builder: Builder) -> LookupTable: acmr_data = builder.data.load("cause.all_causes.cause_specific_mortality_rate") - return builder.lookup.build_table(acmr_data, key_columns=['sex'], parameter_columns=['age', 'year']) + return builder.lookup.build_table( + acmr_data, key_columns=["sex"], parameter_columns=["age", "year"] + ) def _get_cause_specific_mortality_rate(self, builder: Builder) -> Pipeline: return builder.value.register_value_producer( - self.cause_specific_mortality_rate_pipeline_name, source=builder.lookup.build_table(0) + self.cause_specific_mortality_rate_pipeline_name, + source=builder.lookup.build_table(0), ) def _get_mortality_rate(self, builder: Builder) -> Pipeline: return builder.value.register_rate_producer( self.mortality_rate_pipeline_name, source=self._calculate_mortality_rate, - requires_columns=['age', 'sex'] + requires_columns=["age", "sex"], ) # noinspection PyMethodMayBeStatic def _get_life_expectancy(self, builder: Builder) -> LookupTable: - life_expectancy_data = builder.data.load("population.theoretical_minimum_risk_life_expectancy") - return builder.lookup.build_table(life_expectancy_data, parameter_columns=['age']) + life_expectancy_data = builder.data.load( + "population.theoretical_minimum_risk_life_expectancy" + ) + return builder.lookup.build_table(life_expectancy_data, parameter_columns=["age"]) def _get_population_view(self, builder: Builder) -> PopulationView: - return builder.population.get_view([self.cause_of_death_column_name, self.years_of_life_lost_column_name, - 'alive', 'exit_time', 'age', 'sex', 'location']) + return builder.population.get_view( + [ + self.cause_of_death_column_name, + self.years_of_life_lost_column_name, + "alive", + "exit_time", + "age", + "sex", + "location", + ] + ) def _register_simulant_initializer(self, builder: Builder) -> None: builder.population.initializes_simulants( self.on_initialize_simulants, - creates_columns=[self.cause_of_death_column_name, self.years_of_life_lost_column_name] + creates_columns=[ + self.cause_of_death_column_name, + self.years_of_life_lost_column_name, + ], ) def _register_on_timestep_listener(self, builder: Builder) -> None: - builder.event.register_listener('time_step', self.on_time_step, priority=0) + builder.event.register_listener("time_step", self.on_time_step, priority=0) ######################## # Event-driven methods # @@ -106,23 +121,30 @@ def _register_on_timestep_listener(self, builder: Builder) -> None: def on_initialize_simulants(self, pop_data: SimulantData) -> None: pop_update = pd.DataFrame( - {self.cause_of_death_column_name: 'not_dead', self.years_of_life_lost_column_name: 0.}, - index=pop_data.index + { + self.cause_of_death_column_name: "not_dead", + self.years_of_life_lost_column_name: 0.0, + }, + index=pop_data.index, ) self.population_view.update(pop_update) def on_time_step(self, event: Event) -> None: pop = self.population_view.get(event.index, query="alive =='alive'") prob_df = rate_to_probability(pd.DataFrame(self.mortality_rate(pop.index))) - prob_df['no_death'] = 1-prob_df.sum(axis=1) - prob_df['cause_of_death'] = self.random.choice(prob_df.index, prob_df.columns, prob_df) + prob_df["no_death"] = 1 - prob_df.sum(axis=1) + prob_df["cause_of_death"] = self.random.choice( + prob_df.index, prob_df.columns, prob_df + ) dead_pop = prob_df.query('cause_of_death != "no_death"').copy() if not dead_pop.empty: - dead_pop['alive'] = pd.Series('dead', index=dead_pop.index) - dead_pop['exit_time'] = event.time - dead_pop['years_of_life_lost'] = self.life_expectancy(dead_pop.index) - self.population_view.update(dead_pop[['alive', 'exit_time', 'cause_of_death', 'years_of_life_lost']]) + dead_pop["alive"] = pd.Series("dead", index=dead_pop.index) + dead_pop["exit_time"] = event.time + dead_pop["years_of_life_lost"] = self.life_expectancy(dead_pop.index) + self.population_view.update( + dead_pop[["alive", "exit_time", "cause_of_death", "years_of_life_lost"]] + ) ################################## # Pipeline sources and modifiers # @@ -132,4 +154,4 @@ def _calculate_mortality_rate(self, index: pd.Index) -> pd.DataFrame: acmr = self.all_cause_mortality_rate(index) csmr = self.cause_specific_mortality_rate(index, skip_post_processor=True) cause_deleted_mortality_rate = acmr - csmr - return pd.DataFrame({'other_causes': cause_deleted_mortality_rate}) + return pd.DataFrame({"other_causes": cause_deleted_mortality_rate}) diff --git a/src/vivarium_public_health/risks/__init__.py b/src/vivarium_public_health/risks/__init__.py index 5a2a364e1..b50de709c 100644 --- a/src/vivarium_public_health/risks/__init__.py +++ b/src/vivarium_public_health/risks/__init__.py @@ -1,4 +1,4 @@ +from .base_risk import Risk from .distributions import get_distribution from .effect import RiskEffect -from .base_risk import Risk from .implementations.low_birth_weight_and_short_gestation import LBWSGRisk diff --git a/src/vivarium_public_health/risks/base_risk.py b/src/vivarium_public_health/risks/base_risk.py index c4e03f686..d8336e0ea 100644 --- a/src/vivarium_public_health/risks/base_risk.py +++ b/src/vivarium_public_health/risks/base_risk.py @@ -10,15 +10,16 @@ from typing import Dict, List import pandas as pd - from vivarium.framework.engine import Builder from vivarium.framework.population import PopulationView, SimulantData from vivarium.framework.randomness import RandomnessStream from vivarium.framework.values import Pipeline -from vivarium_public_health.utilities import EntityString +from vivarium_public_health.risks.data_transformations import ( + get_exposure_post_processor, +) from vivarium_public_health.risks.distributions import SimulationDistribution -from vivarium_public_health.risks.data_transformations import get_exposure_post_processor +from vivarium_public_health.utilities import EntityString class Risk: @@ -81,7 +82,7 @@ class Risk: configuration_defaults = { "risk": { - "exposure": 'data', + "exposure": "data", "rebinned_exposed": [], "category_thresholds": [], } @@ -99,10 +100,10 @@ def __init__(self, risk: str): self.exposure_distribution = self._get_exposure_distribution() self._sub_components = [self.exposure_distribution] - self._randomness_stream_name = f'initial_{self.risk.name}_propensity' - self.propensity_column_name = f'{self.risk.name}_propensity' - self.propensity_pipeline_name = f'{self.risk.name}.propensity' - self.exposure_pipeline_name = f'{self.risk.name}.exposure' + self._randomness_stream_name = f"initial_{self.risk.name}_propensity" + self.propensity_column_name = f"{self.risk.name}_propensity" + self.propensity_pipeline_name = f"{self.risk.name}.propensity" + self.exposure_pipeline_name = f"{self.risk.name}.exposure" def __repr__(self) -> str: return f"Risk({self.risk})" @@ -123,7 +124,7 @@ def _get_exposure_distribution(self) -> SimulationDistribution: @property def name(self) -> str: - return f'risk.{self.risk}' + return f"risk.{self.risk}" @property def sub_components(self) -> List: @@ -149,21 +150,20 @@ def _get_propensity_pipeline(self, builder: Builder) -> Pipeline: return builder.value.register_value_producer( self.propensity_pipeline_name, source=lambda index: ( - self.population_view - .subview([self.propensity_column_name]) + self.population_view.subview([self.propensity_column_name]) .get(index) .squeeze(axis=1) ), - requires_columns=[self.propensity_column_name] + requires_columns=[self.propensity_column_name], ) def _get_exposure_pipeline(self, builder: Builder) -> Pipeline: return builder.value.register_value_producer( self.exposure_pipeline_name, source=self._get_current_exposure, - requires_columns=['age', 'sex'], + requires_columns=["age", "sex"], requires_values=[self.propensity_pipeline_name], - preferred_post_processor=get_exposure_post_processor(builder, self.risk) + preferred_post_processor=get_exposure_post_processor(builder, self.risk), ) def _get_population_view(self, builder: Builder) -> PopulationView: @@ -173,7 +173,7 @@ def _register_simulant_initializer(self, builder: Builder) -> None: builder.population.initializes_simulants( self.on_initialize_simulants, creates_columns=[self.propensity_column_name], - requires_streams=[self._randomness_stream_name] + requires_streams=[self._randomness_stream_name], ) ######################## @@ -181,8 +181,11 @@ def _register_simulant_initializer(self, builder: Builder) -> None: ######################## def on_initialize_simulants(self, pop_data: SimulantData) -> None: - self.population_view.update(pd.Series(self.randomness.get_draw(pop_data.index), - name=self.propensity_column_name)) + self.population_view.update( + pd.Series( + self.randomness.get_draw(pop_data.index), name=self.propensity_column_name + ) + ) ################################## # Pipeline sources and modifiers # diff --git a/src/vivarium_public_health/risks/data_transformations.py b/src/vivarium_public_health/risks/data_transformations.py index a3fe18029..dd6eedc9a 100644 --- a/src/vivarium_public_health/risks/data_transformations.py +++ b/src/vivarium_public_health/risks/data_transformations.py @@ -14,16 +14,16 @@ from vivarium_public_health.utilities import EntityString, TargetString - ############# # Utilities # ############# + def pivot_categorical(data: pd.DataFrame) -> pd.DataFrame: """Pivots data that is long on categories to be wide.""" - key_cols = ['sex', 'age_start', 'age_end', 'year_start', 'year_end'] + key_cols = ["sex", "age_start", "age_end", "year_start", "year_end"] key_cols = [k for k in key_cols if k in data.columns] - data = data.pivot_table(index=key_cols, columns='parameter', values='value').reset_index() + data = data.pivot_table(index=key_cols, columns="parameter", values="value").reset_index() data.columns.name = None return data @@ -32,6 +32,7 @@ def pivot_categorical(data: pd.DataFrame) -> pd.DataFrame: # Exposure data handlers # ########################## + def get_distribution_data(builder, risk: EntityString): validate_distribution_data_source(builder, risk) data = load_distribution_data(builder, risk) @@ -39,15 +40,17 @@ def get_distribution_data(builder, risk: EntityString): def get_exposure_post_processor(builder, risk: EntityString): - thresholds = builder.configuration[risk.name]['category_thresholds'] + thresholds = builder.configuration[risk.name]["category_thresholds"] if thresholds: thresholds = [-np.inf] + thresholds + [np.inf] - categories = [f'cat{i}' for i in range(1, len(thresholds))] + categories = [f"cat{i}" for i in range(1, len(thresholds))] def post_processor(exposure, _): - return pd.Series(pd.cut(exposure, thresholds, labels=categories), - index=exposure.index).astype(str) + return pd.Series( + pd.cut(exposure, thresholds, labels=categories), index=exposure.index + ).astype(str) + else: post_processor = None @@ -57,20 +60,22 @@ def post_processor(exposure, _): def load_distribution_data(builder, risk: EntityString): exposure_data = get_exposure_data(builder, risk) - data = {'distribution_type': get_distribution_type(builder, risk), - 'exposure': exposure_data, - 'exposure_standard_deviation': get_exposure_standard_deviation_data(builder, risk), - 'weights': get_exposure_distribution_weights(builder, risk)} + data = { + "distribution_type": get_distribution_type(builder, risk), + "exposure": exposure_data, + "exposure_standard_deviation": get_exposure_standard_deviation_data(builder, risk), + "weights": get_exposure_distribution_weights(builder, risk), + } return data def get_distribution_type(builder, risk: EntityString): risk_config = builder.configuration[risk.name] - if risk_config['exposure'] == 'data' and not risk_config['rebinned_exposed']: - distribution_type = builder.data.load(f'{risk}.distribution') + if risk_config["exposure"] == "data" and not risk_config["rebinned_exposed"]: + distribution_type = builder.data.load(f"{risk}.distribution") else: - distribution_type = 'dichotomous' + distribution_type = "dichotomous" return distribution_type @@ -79,7 +84,11 @@ def get_exposure_data(builder, risk: EntityString): exposure_data = load_exposure_data(builder, risk) exposure_data = rebin_exposure_data(builder, risk, exposure_data) - if get_distribution_type(builder, risk) in ['dichotomous', 'ordered_polytomous', 'unordered_polytomous']: + if get_distribution_type(builder, risk) in [ + "dichotomous", + "ordered_polytomous", + "unordered_polytomous", + ]: exposure_data = pivot_categorical(exposure_data) return exposure_data @@ -87,23 +96,23 @@ def get_exposure_data(builder, risk: EntityString): def load_exposure_data(builder, risk: EntityString): risk_config = builder.configuration[risk.name] - exposure_source = risk_config['exposure'] + exposure_source = risk_config["exposure"] - if exposure_source == 'data': - exposure_data = builder.data.load(f'{risk}.exposure') + if exposure_source == "data": + exposure_data = builder.data.load(f"{risk}.exposure") else: if isinstance(exposure_source, str): # Build from covariate - cat1 = builder.data.load(f'{exposure_source}.estimate') + cat1 = builder.data.load(f"{exposure_source}.estimate") # TODO: Generate a draw. - cat1 = cat1[cat1['parameter'] == 'mean_value'] - cat1['parameter'] = 'cat1' + cat1 = cat1[cat1["parameter"] == "mean_value"] + cat1["parameter"] = "cat1" else: # We have a numerical value - cat1 = builder.data.load('population.demographic_dimensions') - cat1['parameter'] = 'cat1' - cat1['value'] = float(exposure_source) + cat1 = builder.data.load("population.demographic_dimensions") + cat1["parameter"] = "cat1" + cat1["value"] = float(exposure_source) cat2 = cat1.copy() - cat2['parameter'] = 'cat2' - cat2['value'] = 1 - cat2['value'] + cat2["parameter"] = "cat2" + cat2["value"] = 1 - cat2["value"] exposure_data = pd.concat([cat1, cat2], ignore_index=True) return exposure_data @@ -111,8 +120,8 @@ def load_exposure_data(builder, risk: EntityString): def get_exposure_standard_deviation_data(builder, risk: EntityString): distribution_type = get_distribution_type(builder, risk) - if distribution_type in ['normal', 'lognormal', 'ensemble']: - exposure_sd = builder.data.load(f'{risk}.exposure_standard_deviation') + if distribution_type in ["normal", "lognormal", "ensemble"]: + exposure_sd = builder.data.load(f"{risk}.exposure_standard_deviation") else: exposure_sd = None return exposure_sd @@ -120,13 +129,13 @@ def get_exposure_standard_deviation_data(builder, risk: EntityString): def get_exposure_distribution_weights(builder, risk: EntityString): distribution_type = get_distribution_type(builder, risk) - if distribution_type == 'ensemble': - weights = builder.data.load(f'{risk}.exposure_distribution_weights') + if distribution_type == "ensemble": + weights = builder.data.load(f"{risk}.exposure_distribution_weights") weights = pivot_categorical(weights) - if 'glnorm' in weights.columns: - if np.any(weights['glnorm']): - raise NotImplementedError('glnorm distribution is not supported') - weights = weights.drop(columns='glnorm') + if "glnorm" in weights.columns: + if np.any(weights["glnorm"]): + raise NotImplementedError("glnorm distribution is not supported") + weights = weights.drop(columns="glnorm") else: weights = None return weights @@ -134,7 +143,7 @@ def get_exposure_distribution_weights(builder, risk: EntityString): def rebin_exposure_data(builder, risk: EntityString, exposure_data: pd.DataFrame): validate_rebin_source(builder, risk, exposure_data) - rebin_exposed_categories = set(builder.configuration[risk.name]['rebinned_exposed']) + rebin_exposed_categories = set(builder.configuration[risk.name]["rebinned_exposed"]) if rebin_exposed_categories: exposure_data = _rebin_exposure_data(exposure_data, rebin_exposed_categories) @@ -142,47 +151,71 @@ def rebin_exposure_data(builder, risk: EntityString, exposure_data: pd.DataFrame return exposure_data -def _rebin_exposure_data(exposure_data: pd.DataFrame, rebin_exposed_categories: set) -> pd.DataFrame: - exposure_data["parameter"] = (exposure_data["parameter"] - .map(lambda p: 'cat1' if p in rebin_exposed_categories else 'cat2')) - return exposure_data.groupby(list(exposure_data.columns.difference(['value']))).sum().reset_index() +def _rebin_exposure_data( + exposure_data: pd.DataFrame, rebin_exposed_categories: set +) -> pd.DataFrame: + exposure_data["parameter"] = exposure_data["parameter"].map( + lambda p: "cat1" if p in rebin_exposed_categories else "cat2" + ) + return ( + exposure_data.groupby(list(exposure_data.columns.difference(["value"]))) + .sum() + .reset_index() + ) ############################### # Relative risk data handlers # ############################### + def get_relative_risk_data(builder, risk: EntityString, target: TargetString): source_type = validate_relative_risk_data_source(builder, risk, target) relative_risk_data = load_relative_risk_data(builder, risk, target, source_type) relative_risk_data = rebin_relative_risk_data(builder, risk, relative_risk_data) - if get_distribution_type(builder, risk) in ['dichotomous', 'ordered_polytomous', 'unordered_polytomous']: + if get_distribution_type(builder, risk) in [ + "dichotomous", + "ordered_polytomous", + "unordered_polytomous", + ]: relative_risk_data = pivot_categorical(relative_risk_data) else: - relative_risk_data = relative_risk_data.drop(['parameter'], 'columns') + relative_risk_data = relative_risk_data.drop(["parameter"], "columns") return relative_risk_data -def load_relative_risk_data(builder, risk: EntityString, target: TargetString, source_type: str): - relative_risk_source = builder.configuration[f'effect_of_{risk.name}_on_{target.name}'][target.measure] +def load_relative_risk_data( + builder, risk: EntityString, target: TargetString, source_type: str +): + relative_risk_source = builder.configuration[f"effect_of_{risk.name}_on_{target.name}"][ + target.measure + ] - if source_type == 'data': - relative_risk_data = builder.data.load(f'{risk}.relative_risk') - correct_target = ((relative_risk_data['affected_entity'] == target.name) - & (relative_risk_data['affected_measure'] == target.measure)) - relative_risk_data = (relative_risk_data[correct_target] - .drop(['affected_entity', 'affected_measure'], 'columns')) + if source_type == "data": + relative_risk_data = builder.data.load(f"{risk}.relative_risk") + correct_target = (relative_risk_data["affected_entity"] == target.name) & ( + relative_risk_data["affected_measure"] == target.measure + ) + relative_risk_data = relative_risk_data[correct_target].drop( + ["affected_entity", "affected_measure"], "columns" + ) - elif source_type == 'relative risk value': - relative_risk_data = _make_relative_risk_data(builder, float(relative_risk_source['relative_risk'])) + elif source_type == "relative risk value": + relative_risk_data = _make_relative_risk_data( + builder, float(relative_risk_source["relative_risk"]) + ) else: # distribution - parameters = {k: v for k, v in relative_risk_source.to_dict().items() if v is not None} + parameters = { + k: v for k, v in relative_risk_source.to_dict().items() if v is not None + } random_state = np.random.RandomState( - builder.randomness.get_seed(f'effect_of_{risk.name}_on_{target.name}.{target.measure}') + builder.randomness.get_seed( + f"effect_of_{risk.name}_on_{target.name}.{target.measure}" + ) ) cat1_value = generate_relative_risk_from_distribution(random_state, parameters) relative_risk_data = _make_relative_risk_data(builder, cat1_value) @@ -190,27 +223,32 @@ def load_relative_risk_data(builder, risk: EntityString, target: TargetString, s return relative_risk_data -def generate_relative_risk_from_distribution(random_state: np.random.RandomState, - parameters: dict) -> Union[float, pd.Series, np.ndarray]: +def generate_relative_risk_from_distribution( + random_state: np.random.RandomState, parameters: dict +) -> Union[float, pd.Series, np.ndarray]: first = pd.Series(list(parameters.values())[0]) length = len(first) index = first.index for v in parameters.values(): if length != len(pd.Series(v)) or not index.equals(pd.Series(v).index): - raise ValueError('If specifying vectorized parameters, all parameters ' - 'must be the same length and have the same index.') - - if 'mean' in parameters: # normal distribution - rr_value = random_state.normal(parameters['mean'], parameters['se']) - elif 'log_mean' in parameters: # log distribution - log_value = parameters['log_mean'] + parameters['log_se']*random_state.randn() - if parameters['tau_squared']: - log_value += random_state.normal(0, parameters['tau_squared']) + raise ValueError( + "If specifying vectorized parameters, all parameters " + "must be the same length and have the same index." + ) + + if "mean" in parameters: # normal distribution + rr_value = random_state.normal(parameters["mean"], parameters["se"]) + elif "log_mean" in parameters: # log distribution + log_value = parameters["log_mean"] + parameters["log_se"] * random_state.randn() + if parameters["tau_squared"]: + log_value += random_state.normal(0, parameters["tau_squared"]) rr_value = np.exp(log_value) else: - raise NotImplementedError(f'Only normal distributions (supplying mean and se) and log distributions ' - f'(supplying log_mean, log_se, and tau_squared) are currently supported.') + raise NotImplementedError( + f"Only normal distributions (supplying mean and se) and log distributions " + f"(supplying log_mean, log_se, and tau_squared) are currently supported." + ) rr_value = np.maximum(1, rr_value) @@ -218,51 +256,63 @@ def generate_relative_risk_from_distribution(random_state: np.random.RandomState def _make_relative_risk_data(builder, cat1_value: float) -> pd.DataFrame: - cat1 = builder.data.load('population.demographic_dimensions') - cat1['parameter'] = 'cat1' - cat1['value'] = cat1_value + cat1 = builder.data.load("population.demographic_dimensions") + cat1["parameter"] = "cat1" + cat1["value"] = cat1_value cat2 = cat1.copy() - cat2['parameter'] = 'cat2' - cat2['value'] = 1 + cat2["parameter"] = "cat2" + cat2["value"] = 1 return pd.concat([cat1, cat2], ignore_index=True) -def rebin_relative_risk_data(builder, risk: EntityString, relative_risk_data: pd.DataFrame) -> pd.DataFrame: - """ When the polytomous risk is rebinned, matching relative risk needs to be rebinned. - After rebinning, rr for both exposed and unexposed categories should be the weighted sum of relative risk - of the component categories where weights are relative proportions of exposure of those categories. - For example, if cat1, cat2, cat3 are exposed categories and cat4 is unexposed with exposure [0.1,0.2,0.3,0.4], - for the matching rr = [rr1, rr2, rr3, 1], rebinned rr for the rebinned cat1 should be: - (0.1 *rr1 + 0.2 * rr2 + 0.3* rr3) / (0.1+0.2+0.3) +def rebin_relative_risk_data( + builder, risk: EntityString, relative_risk_data: pd.DataFrame +) -> pd.DataFrame: + """When the polytomous risk is rebinned, matching relative risk needs to be rebinned. + After rebinning, rr for both exposed and unexposed categories should be the weighted sum of relative risk + of the component categories where weights are relative proportions of exposure of those categories. + For example, if cat1, cat2, cat3 are exposed categories and cat4 is unexposed with exposure [0.1,0.2,0.3,0.4], + for the matching rr = [rr1, rr2, rr3, 1], rebinned rr for the rebinned cat1 should be: + (0.1 *rr1 + 0.2 * rr2 + 0.3* rr3) / (0.1+0.2+0.3) """ - rebin_exposed_categories = set(builder.configuration[risk.name]['rebinned_exposed']) + rebin_exposed_categories = set(builder.configuration[risk.name]["rebinned_exposed"]) validate_rebin_source(builder, risk, relative_risk_data) if rebin_exposed_categories: exposure_data = load_exposure_data(builder, risk) - relative_risk_data = _rebin_relative_risk_data(relative_risk_data, exposure_data, rebin_exposed_categories) + relative_risk_data = _rebin_relative_risk_data( + relative_risk_data, exposure_data, rebin_exposed_categories + ) return relative_risk_data -def _rebin_relative_risk_data(relative_risk_data: pd.DataFrame, exposure_data: pd.DataFrame, - rebin_exposed_categories: set) -> pd.DataFrame: - cols = list(exposure_data.columns.difference(['value'])) +def _rebin_relative_risk_data( + relative_risk_data: pd.DataFrame, + exposure_data: pd.DataFrame, + rebin_exposed_categories: set, +) -> pd.DataFrame: + cols = list(exposure_data.columns.difference(["value"])) relative_risk_data = relative_risk_data.merge(exposure_data, on=cols) - relative_risk_data['value_x'] = relative_risk_data.value_x.multiply(relative_risk_data.value_y) - relative_risk_data.parameter = (relative_risk_data["parameter"] - .map(lambda p: 'cat1' if p in rebin_exposed_categories else 'cat2')) + relative_risk_data["value_x"] = relative_risk_data.value_x.multiply( + relative_risk_data.value_y + ) + relative_risk_data.parameter = relative_risk_data["parameter"].map( + lambda p: "cat1" if p in rebin_exposed_categories else "cat2" + ) relative_risk_data = relative_risk_data.groupby(cols).sum().reset_index() - relative_risk_data['value'] = relative_risk_data.value_x.divide(relative_risk_data.value_y).fillna(0) - return relative_risk_data.drop(['value_x', 'value_y'], 'columns') + relative_risk_data["value"] = relative_risk_data.value_x.divide( + relative_risk_data.value_y + ).fillna(0) + return relative_risk_data.drop(["value_x", "value_y"], "columns") def get_exposure_effect(builder, risk: EntityString): distribution_type = get_distribution_type(builder, risk) - risk_exposure = builder.value.get_value(f'{risk.name}.exposure') + risk_exposure = builder.value.get_value(f"{risk.name}.exposure") - if distribution_type in ['normal', 'lognormal', 'ensemble']: + if distribution_type in ["normal", "lognormal", "ensemble"]: tmred = builder.data.load(f"{risk}.tmred") tmrel = 0.5 * (tmred["min"] + tmred["max"]) scale = builder.data.load(f"{risk}.relative_risk_scalar") @@ -271,19 +321,21 @@ def exposure_effect(rates, rr): exposure = risk_exposure(rr.index) relative_risk = np.maximum(rr.values ** ((exposure - tmrel) / scale), 1) return rates * relative_risk + else: + def exposure_effect(rates, rr: pd.DataFrame) -> pd.Series: - index_columns = ['index', risk.name] + index_columns = ["index", risk.name] exposure = risk_exposure(rr.index).reset_index() exposure.columns = index_columns exposure = exposure.set_index(index_columns) relative_risk = rr.stack().reset_index() - relative_risk.columns = index_columns + ['value'] + relative_risk.columns = index_columns + ["value"] relative_risk = relative_risk.set_index(index_columns) - effect = relative_risk.loc[exposure.index, 'value'].droplevel(risk.name) + effect = relative_risk.loc[exposure.index, "value"].droplevel(risk.name) affected_rates = rates * effect return affected_rates @@ -294,22 +346,27 @@ def exposure_effect(rates, rr: pd.DataFrame) -> pd.Series: # Population attributable fraction data handlers # ################################################## -def get_population_attributable_fraction_data(builder, risk: EntityString, target: TargetString): - exposure_source = builder.configuration[f'{risk.name}']['exposure'] + +def get_population_attributable_fraction_data( + builder, risk: EntityString, target: TargetString +): + exposure_source = builder.configuration[f"{risk.name}"]["exposure"] rr_source_type = validate_relative_risk_data_source(builder, risk, target) - if exposure_source == 'data' and rr_source_type == 'data' and risk.type == 'risk_factor': - paf_data = builder.data.load(f'{risk}.population_attributable_fraction') - correct_target = ((paf_data['affected_entity'] == target.name) - & (paf_data['affected_measure'] == target.measure)) - paf_data = (paf_data[correct_target] - .drop(['affected_entity', 'affected_measure'], 'columns')) + if exposure_source == "data" and rr_source_type == "data" and risk.type == "risk_factor": + paf_data = builder.data.load(f"{risk}.population_attributable_fraction") + correct_target = (paf_data["affected_entity"] == target.name) & ( + paf_data["affected_measure"] == target.measure + ) + paf_data = paf_data[correct_target].drop( + ["affected_entity", "affected_measure"], "columns" + ) else: - key_cols = ['sex', 'age_start', 'age_end', 'year_start', 'year_end'] + key_cols = ["sex", "age_start", "age_end", "year_start", "year_end"] exposure_data = get_exposure_data(builder, risk).set_index(key_cols) relative_risk_data = get_relative_risk_data(builder, risk, target).set_index(key_cols) mean_rr = (exposure_data * relative_risk_data).sum(axis=1) - paf_data = ((mean_rr - 1)/mean_rr).reset_index().rename(columns={0: 'value'}) + paf_data = ((mean_rr - 1) / mean_rr).reset_index().rename(columns={0: "value"}) return paf_data @@ -317,63 +374,85 @@ def get_population_attributable_fraction_data(builder, risk: EntityString, targe # Validators # ############## + def validate_distribution_data_source(builder, risk: EntityString): """Checks that the exposure distribution specification is valid.""" - exposure_type = builder.configuration[risk.name]['exposure'] - rebin = builder.configuration[risk.name]['rebinned_exposed'] - category_thresholds = builder.configuration[risk.name]['category_thresholds'] + exposure_type = builder.configuration[risk.name]["exposure"] + rebin = builder.configuration[risk.name]["rebinned_exposed"] + category_thresholds = builder.configuration[risk.name]["category_thresholds"] - if risk.type == 'alternative_risk_factor': - if exposure_type != 'data' or rebin: - raise ValueError('Parameterized risk components are not available for alternative risks.') + if risk.type == "alternative_risk_factor": + if exposure_type != "data" or rebin: + raise ValueError( + "Parameterized risk components are not available for alternative risks." + ) if not category_thresholds: - raise ValueError('Must specify category thresholds to use alternative risks.') + raise ValueError("Must specify category thresholds to use alternative risks.") - elif risk.type in ['risk_factor', 'coverage_gap']: + elif risk.type in ["risk_factor", "coverage_gap"]: if isinstance(exposure_type, (int, float)) and not 0 <= exposure_type <= 1: raise ValueError(f"Exposure should be in the range [0, 1]") - elif isinstance(exposure_type, str) and exposure_type.split('.')[0] not in ['covariate', 'data']: - raise ValueError(f"Exposure must be specified as 'data', an integer or float value, " - f"or as a string in the format covariate.covariate_name") + elif isinstance(exposure_type, str) and exposure_type.split(".")[0] not in [ + "covariate", + "data", + ]: + raise ValueError( + f"Exposure must be specified as 'data', an integer or float value, " + f"or as a string in the format covariate.covariate_name" + ) else: pass # All good else: - raise ValueError(f'Unknown risk type {risk.type} for risk {risk.name}') + raise ValueError(f"Unknown risk type {risk.type} for risk {risk.name}") def validate_relative_risk_data_source(builder, risk: EntityString, target: TargetString): - source_key = f'effect_of_{risk.name}_on_{target.name}' + source_key = f"effect_of_{risk.name}_on_{target.name}" relative_risk_source = builder.configuration[source_key][target.measure] - provided_keys = set(k for k, v in relative_risk_source.to_dict().items() if isinstance(v, (int, float))) + provided_keys = set( + k for k, v in relative_risk_source.to_dict().items() if isinstance(v, (int, float)) + ) - source_map = {'data': set(), - 'relative risk value': {'relative_risk'}, - 'normal distribution': {'mean', 'se'}, - 'log distribution': {'log_mean', 'log_se', 'tau_squared'}} + source_map = { + "data": set(), + "relative risk value": {"relative_risk"}, + "normal distribution": {"mean", "se"}, + "log distribution": {"log_mean", "log_se", "tau_squared"}, + } if provided_keys not in source_map.values(): - raise ValueError(f'The acceptable parameter options for specifying relative risk are: ' - f'{source_map.values()}. You provided {provided_keys} for {source_key}.') + raise ValueError( + f"The acceptable parameter options for specifying relative risk are: " + f"{source_map.values()}. You provided {provided_keys} for {source_key}." + ) source_type = [k for k, v in source_map.items() if provided_keys == v][0] - if source_type == 'relative risk value': - if not 1 <= relative_risk_source['relative_risk'] <= 100: - raise ValueError(f"If specifying a single value for relative risk, it should be in the " - f"range [1, 100]. You provided {relative_risk_source['relative_risk']} for {source_key}.") - elif source_type == 'normal distribution': - if relative_risk_source['mean'] <= 0 or relative_risk_source['se'] <= 0: - raise ValueError(f"To specify parameters for a normal distribution for a risk effect, you must provide" - f"both mean and se above 0. This is not the case for {source_key}.") - elif source_type == 'log distribution': - if relative_risk_source['log_mean'] <= 0 or relative_risk_source['log_se'] <= 0: - raise ValueError(f"To specify parameters for a log distribution for a risk effect, you must provide" - f"both log_mean and log_se above 0. This is not the case for {source_key}.") - if relative_risk_source['tau_squared'] < 0: - raise ValueError(f"To specify parameters for a log distribution for a risk effect, you must provide" - f"tau_squared >= 0. This is not the case for {source_key}.") + if source_type == "relative risk value": + if not 1 <= relative_risk_source["relative_risk"] <= 100: + raise ValueError( + f"If specifying a single value for relative risk, it should be in the " + f"range [1, 100]. You provided {relative_risk_source['relative_risk']} for {source_key}." + ) + elif source_type == "normal distribution": + if relative_risk_source["mean"] <= 0 or relative_risk_source["se"] <= 0: + raise ValueError( + f"To specify parameters for a normal distribution for a risk effect, you must provide" + f"both mean and se above 0. This is not the case for {source_key}." + ) + elif source_type == "log distribution": + if relative_risk_source["log_mean"] <= 0 or relative_risk_source["log_se"] <= 0: + raise ValueError( + f"To specify parameters for a log distribution for a risk effect, you must provide" + f"both log_mean and log_se above 0. This is not the case for {source_key}." + ) + if relative_risk_source["tau_squared"] < 0: + raise ValueError( + f"To specify parameters for a log distribution for a risk effect, you must provide" + f"tau_squared >= 0. This is not the case for {source_key}." + ) else: pass @@ -381,22 +460,32 @@ def validate_relative_risk_data_source(builder, risk: EntityString, target: Targ def validate_rebin_source(builder, risk: EntityString, data: pd.DataFrame): - rebin_exposed_categories = set(builder.configuration[risk.name]['rebinned_exposed']) + rebin_exposed_categories = set(builder.configuration[risk.name]["rebinned_exposed"]) - if rebin_exposed_categories and builder.configuration[risk.name]['category_thresholds']: - raise ValueError(f'Rebinning and category thresholds are mutually exclusive. ' - f'You provided both for {risk.name}.') + if rebin_exposed_categories and builder.configuration[risk.name]["category_thresholds"]: + raise ValueError( + f"Rebinning and category thresholds are mutually exclusive. " + f"You provided both for {risk.name}." + ) - if rebin_exposed_categories and 'polytomous' not in builder.data.load(f'{risk}.distribution'): - raise ValueError(f'Rebinning is only supported for polytomous risks. You provided rebinning exposed categories' - f'for {risk.name}, which is of type {builder.data.load(f"{risk}.distribution")}.') + if rebin_exposed_categories and "polytomous" not in builder.data.load( + f"{risk}.distribution" + ): + raise ValueError( + f"Rebinning is only supported for polytomous risks. You provided rebinning exposed categories" + f'for {risk.name}, which is of type {builder.data.load(f"{risk}.distribution")}.' + ) invalid_cats = rebin_exposed_categories.difference(set(data.parameter)) if invalid_cats: - raise ValueError(f'The following provided categories for the rebinned exposed category of {risk.name} ' - f'are not found in the exposure data: {invalid_cats}.') + raise ValueError( + f"The following provided categories for the rebinned exposed category of {risk.name} " + f"are not found in the exposure data: {invalid_cats}." + ) if rebin_exposed_categories == set(data.parameter): - raise ValueError(f'The provided categories for the rebinned exposed category of {risk.name} comprise all ' - f'categories for the exposure data. At least one category must be left out of the provided ' - f'categories to be rebinned into the unexposed category.') + raise ValueError( + f"The provided categories for the rebinned exposed category of {risk.name} comprise all " + f"categories for the exposure data. At least one category must be left out of the provided " + f"categories to be rebinned into the unexposed category." + ) diff --git a/src/vivarium_public_health/risks/distributions.py b/src/vivarium_public_health/risks/distributions.py index e2f6b2362..caec38899 100644 --- a/src/vivarium_public_health/risks/distributions.py +++ b/src/vivarium_public_health/risks/distributions.py @@ -9,11 +9,10 @@ """ import numpy as np import pandas as pd - -from risk_distributions import EnsembleDistribution, Normal, LogNormal - +from risk_distributions import EnsembleDistribution, LogNormal, Normal from vivarium.framework.engine import Builder from vivarium.framework.values import list_combiner, union_post_processor + from vivarium_public_health.risks.data_transformations import get_distribution_data @@ -32,7 +31,7 @@ def __init__(self, risk): @property def name(self): - return f'{self.risk}.exposure_distribution' + return f"{self.risk}.exposure_distribution" def setup(self, builder): distribution_data = get_distribution_data(builder, self.risk) @@ -43,52 +42,64 @@ def ppf(self, q): return self.implementation.ppf(q) def __repr__(self): - return f'ExposureDistribution({self.risk})' + return f"ExposureDistribution({self.risk})" class EnsembleSimulation: - def __init__(self, risk, weights, mean, sd): self.risk = risk self._weights, self._parameters = self._get_parameters(weights, mean, sd) @property def name(self): - return f'ensemble_simulation.{self.risk}' + return f"ensemble_simulation.{self.risk}" def setup(self, builder): - self.weights = builder.lookup.build_table(self._weights, key_columns=['sex'], - parameter_columns=['age', 'year']) - self.parameters = {k: builder.lookup.build_table(v, key_columns=['sex'], parameter_columns=['age', 'year']) - for k, v in self._parameters.items()} - - self._propensity = f'ensemble_propensity_{self.risk}' + self.weights = builder.lookup.build_table( + self._weights, key_columns=["sex"], parameter_columns=["age", "year"] + ) + self.parameters = { + k: builder.lookup.build_table( + v, key_columns=["sex"], parameter_columns=["age", "year"] + ) + for k, v in self._parameters.items() + } + + self._propensity = f"ensemble_propensity_{self.risk}" self.randomness = builder.randomness.get_stream(self._propensity) self.population_view = builder.population.get_view([self._propensity]) - builder.population.initializes_simulants(self.on_initialize_simulants, - creates_columns=[self._propensity], - requires_streams=[self._propensity]) + builder.population.initializes_simulants( + self.on_initialize_simulants, + creates_columns=[self._propensity], + requires_streams=[self._propensity], + ) def on_initialize_simulants(self, pop_data): - ensemble_propensity = self.randomness.get_draw(pop_data.index).rename(self._propensity) + ensemble_propensity = self.randomness.get_draw(pop_data.index).rename( + self._propensity + ) self.population_view.update(ensemble_propensity) def _get_parameters(self, weights, mean, sd): - index_cols = ['sex', 'age_start', 'age_end', 'year_start', 'year_end'] + index_cols = ["sex", "age_start", "age_end", "year_start", "year_end"] weights = weights.set_index(index_cols) - mean = mean.set_index(index_cols)['value'] - sd = sd.set_index(index_cols)['value'] + mean = mean.set_index(index_cols)["value"] + sd = sd.set_index(index_cols)["value"] weights, parameters = EnsembleDistribution.get_parameters(weights, mean=mean, sd=sd) - return weights.reset_index(), {name: p.reset_index() for name, p in parameters.items()} + return weights.reset_index(), { + name: p.reset_index() for name, p in parameters.items() + } def ppf(self, q): if not q.empty: q = clip(q) weights = self.weights(q.index) - parameters = {name: parameter(q.index) for name, parameter in self.parameters.items()} - ensemble_propensity = self.population_view.get(q.index).iloc[:,0] + parameters = { + name: parameter(q.index) for name, parameter in self.parameters.items() + } + ensemble_propensity = self.population_view.get(q.index).iloc[:, 0] x = EnsembleDistribution(weights, parameters).ppf(q, ensemble_propensity) x[x.isnull()] = 0 else: @@ -96,7 +107,7 @@ def ppf(self, q): return x def __repr__(self): - return f'EnsembleSimulation(risk={self.risk})' + return f"EnsembleSimulation(risk={self.risk})" class ContinuousDistribution: @@ -107,16 +118,17 @@ def __init__(self, risk, mean, sd, distribution=None): @property def name(self): - return f'simulation_distribution.{self.risk}' + return f"simulation_distribution.{self.risk}" def setup(self, builder): - self.parameters = builder.lookup.build_table(self._parameters, key_columns=['sex'], - parameter_columns=['age', 'year']) + self.parameters = builder.lookup.build_table( + self._parameters, key_columns=["sex"], parameter_columns=["age", "year"] + ) def _get_parameters(self, mean, sd): - index = ['sex', 'age_start', 'age_end', 'year_start', 'year_end'] - mean = mean.set_index(index)['value'] - sd = sd.set_index(index)['value'] + index = ["sex", "age_start", "age_end", "year_start", "year_end"] + mean = mean.set_index(index)["value"] + sd = sd.set_index(index)["value"] return self.distribution.get_parameters(mean=mean, sd=sd).reset_index() def ppf(self, q): @@ -136,29 +148,38 @@ class PolytomousDistribution: def __init__(self, risk: str, exposure_data: pd.DataFrame): self.risk = risk self.exposure_data = exposure_data - self.categories = sorted([column for column in self.exposure_data if 'cat' in column], - key=lambda column: int(column[3:])) + self.categories = sorted( + [column for column in self.exposure_data if "cat" in column], + key=lambda column: int(column[3:]), + ) @property def name(self): - return f'polytomous_distribution.{self.risk}' + return f"polytomous_distribution.{self.risk}" # noinspection PyAttributeOutsideInit def setup(self, builder: Builder): - self.exposure = builder.value.register_value_producer(f'{self.risk}.exposure_parameters', - source=builder.lookup.build_table(self.exposure_data, - key_columns=['sex'], - parameter_columns= - ['age', 'year'])) + self.exposure = builder.value.register_value_producer( + f"{self.risk}.exposure_parameters", + source=builder.lookup.build_table( + self.exposure_data, key_columns=["sex"], parameter_columns=["age", "year"] + ), + ) def ppf(self, x: pd.Series) -> pd.Series: exposure = self.exposure(x.index) sorted_exposures = exposure[self.categories] if not np.allclose(1, np.sum(sorted_exposures, axis=1)): - raise MissingDataError('All exposure data returned as 0.') - exposure_sum = sorted_exposures.cumsum(axis='columns') - category_index = pd.concat([exposure_sum[c] < x for c in exposure_sum.columns], axis=1).sum(axis=1) - return pd.Series(np.array(self.categories)[category_index], name=self.risk + '.exposure', index=x.index) + raise MissingDataError("All exposure data returned as 0.") + exposure_sum = sorted_exposures.cumsum(axis="columns") + category_index = pd.concat( + [exposure_sum[c] < x for c in exposure_sum.columns], axis=1 + ).sum(axis=1) + return pd.Series( + np.array(self.categories)[category_index], + name=self.risk + ".exposure", + index=x.index, + ) def __repr__(self): return f"PolytomousDistribution(risk={self.risk})" @@ -167,50 +188,65 @@ def __repr__(self): class DichotomousDistribution: def __init__(self, risk: str, exposure_data: pd.DataFrame): self.risk = risk - self.exposure_data = exposure_data.drop('cat2', axis=1) + self.exposure_data = exposure_data.drop("cat2", axis=1) @property def name(self): - return f'dichotomous_distribution.{self.risk}' + return f"dichotomous_distribution.{self.risk}" # noinspection PyAttributeOutsideInit def setup(self, builder: Builder): - self._base_exposure = builder.lookup.build_table(self.exposure_data, key_columns=['sex'], - parameter_columns=['age', 'year']) - self.exposure_proportion = builder.value.register_value_producer(f'{self.risk}.exposure_parameters', - source=self.exposure) + self._base_exposure = builder.lookup.build_table( + self.exposure_data, key_columns=["sex"], parameter_columns=["age", "year"] + ) + self.exposure_proportion = builder.value.register_value_producer( + f"{self.risk}.exposure_parameters", source=self.exposure + ) base_paf = builder.lookup.build_table(0) - self.joint_paf = builder.value.register_value_producer(f'{self.risk}.exposure_parameters.paf', - source=lambda index: [base_paf(index)], - preferred_combiner=list_combiner, - preferred_post_processor=union_post_processor) + self.joint_paf = builder.value.register_value_producer( + f"{self.risk}.exposure_parameters.paf", + source=lambda index: [base_paf(index)], + preferred_combiner=list_combiner, + preferred_post_processor=union_post_processor, + ) def exposure(self, index: pd.Index) -> pd.Series: base_exposure = self._base_exposure(index).values joint_paf = self.joint_paf(index).values - return pd.Series(base_exposure * (1-joint_paf), index=index, name='values') + return pd.Series(base_exposure * (1 - joint_paf), index=index, name="values") def ppf(self, x: pd.Series) -> pd.Series: exposed = x < self.exposure_proportion(x.index) - return pd.Series(exposed.replace({True: 'cat1', False: 'cat2'}), name=self.risk + '.exposure', index=x.index) + return pd.Series( + exposed.replace({True: "cat1", False: "cat2"}), + name=self.risk + ".exposure", + index=x.index, + ) def __repr__(self): return f"DichotomousDistribution(risk={self.risk})" def get_distribution(risk, distribution_type, exposure, exposure_standard_deviation, weights): - if distribution_type == 'dichotomous': + if distribution_type == "dichotomous": distribution = DichotomousDistribution(risk, exposure) - elif 'polytomous' in distribution_type: + elif "polytomous" in distribution_type: distribution = PolytomousDistribution(risk, exposure) - elif distribution_type == 'normal': - distribution = ContinuousDistribution(risk, mean=exposure, sd=exposure_standard_deviation, - distribution=Normal) - elif distribution_type == 'lognormal': - distribution = ContinuousDistribution(risk, mean=exposure, sd=exposure_standard_deviation, - distribution=LogNormal) - elif distribution_type == 'ensemble': - distribution = EnsembleSimulation(risk, weights, mean=exposure, sd=exposure_standard_deviation,) + elif distribution_type == "normal": + distribution = ContinuousDistribution( + risk, mean=exposure, sd=exposure_standard_deviation, distribution=Normal + ) + elif distribution_type == "lognormal": + distribution = ContinuousDistribution( + risk, mean=exposure, sd=exposure_standard_deviation, distribution=LogNormal + ) + elif distribution_type == "ensemble": + distribution = EnsembleSimulation( + risk, + weights, + mean=exposure, + sd=exposure_standard_deviation, + ) else: raise NotImplementedError(f"Unhandled distribution type {distribution_type}") return distribution diff --git a/src/vivarium_public_health/risks/effect.py b/src/vivarium_public_health/risks/effect.py index a07a950f2..6e69bc085 100644 --- a/src/vivarium_public_health/risks/effect.py +++ b/src/vivarium_public_health/risks/effect.py @@ -11,13 +11,14 @@ from typing import Callable, Dict import pandas as pd - from vivarium.framework.engine import Builder from vivarium.framework.lookup import LookupTable -from vivarium_public_health.risks.data_transformations import (get_relative_risk_data, - get_population_attributable_fraction_data, - get_exposure_effect) +from vivarium_public_health.risks.data_transformations import ( + get_exposure_effect, + get_population_attributable_fraction_data, + get_relative_risk_data, +) from vivarium_public_health.utilities import EntityString, TargetString @@ -38,14 +39,14 @@ class RiskEffect: """ configuration_defaults = { - 'effect_of_risk_on_target': { - 'measure': { - 'relative_risk': None, - 'mean': None, - 'se': None, - 'log_mean': None, - 'log_se': None, - 'tau_squared': None + "effect_of_risk_on_target": { + "measure": { + "relative_risk": None, + "mean": None, + "se": None, + "log_mean": None, + "log_se": None, + "tau_squared": None, } } } @@ -67,8 +68,8 @@ def __init__(self, risk: str, target: str): self.target = TargetString(target) self.configuration_defaults = self._get_configuration_defaults() - self.target_pipeline_name = f'{self.target.name}.{self.target.measure}' - self.target_paf_pipeline_name = f'{self.target_pipeline_name}.paf' + self.target_pipeline_name = f"{self.target.name}.{self.target.measure}" + self.target_paf_pipeline_name = f"{self.target_pipeline_name}.paf" def __repr__(self): return f"RiskEffect(risk={self.risk}, target={self.target})" @@ -79,8 +80,10 @@ def __repr__(self): def _get_configuration_defaults(self) -> Dict[str, Dict]: return { - f'effect_of_{self.risk.name}_on_{self.target.name}': { - self.target.measure: RiskEffect.configuration_defaults['effect_of_risk_on_target']['measure'] + f"effect_of_{self.risk.name}_on_{self.target.name}": { + self.target.measure: RiskEffect.configuration_defaults[ + "effect_of_risk_on_target" + ]["measure"] } } @@ -90,7 +93,7 @@ def _get_configuration_defaults(self) -> Dict[str, Dict]: @property def name(self) -> str: - return f'risk_effect.{self.risk}.{self.target}' + return f"risk_effect.{self.risk}.{self.target}" ################# # Setup methods # @@ -99,7 +102,9 @@ def name(self) -> str: # noinspection PyAttributeOutsideInit def setup(self, builder: Builder) -> None: self.relative_risk = self._get_relative_risk_source(builder) - self.population_attributable_fraction = self._get_population_attributable_fraction_source(builder) + self.population_attributable_fraction = ( + self._get_population_attributable_fraction_source(builder) + ) self.target_modifier = self._get_target_modifier(builder) self._register_target_modifier(builder) @@ -107,17 +112,19 @@ def setup(self, builder: Builder) -> None: def _get_relative_risk_source(self, builder: Builder) -> LookupTable: relative_risk_data = get_relative_risk_data(builder, self.risk, self.target) - return builder.lookup.build_table(relative_risk_data, - key_columns=['sex'], - parameter_columns=['age', 'year']) + return builder.lookup.build_table( + relative_risk_data, key_columns=["sex"], parameter_columns=["age", "year"] + ) def _get_population_attributable_fraction_source(self, builder: Builder) -> LookupTable: paf_data = get_population_attributable_fraction_data(builder, self.risk, self.target) - return builder.lookup.build_table(paf_data, - key_columns=['sex'], - parameter_columns=['age', 'year']) + return builder.lookup.build_table( + paf_data, key_columns=["sex"], parameter_columns=["age", "year"] + ) - def _get_target_modifier(self, builder: Builder) -> Callable[[pd.Index, pd.Series], pd.Series]: + def _get_target_modifier( + self, builder: Builder + ) -> Callable[[pd.Index, pd.Series], pd.Series]: exposure_effect = get_exposure_effect(builder, self.risk) def adjust_target(index: pd.Index, target: pd.Series) -> pd.Series: @@ -129,13 +136,13 @@ def _register_target_modifier(self, builder: Builder) -> None: builder.value.register_value_modifier( self.target_pipeline_name, modifier=self.target_modifier, - requires_values=[f'{self.risk.name}.exposure'], - requires_columns=['age', 'sex'] + requires_values=[f"{self.risk.name}.exposure"], + requires_columns=["age", "sex"], ) def _register_paf_modifier(self, builder: Builder) -> None: builder.value.register_value_modifier( self.target_paf_pipeline_name, modifier=self.population_attributable_fraction, - requires_columns=['age', 'sex'] + requires_columns=["age", "sex"], ) diff --git a/src/vivarium_public_health/risks/implementations/low_birth_weight_and_short_gestation.py b/src/vivarium_public_health/risks/implementations/low_birth_weight_and_short_gestation.py index 041c42043..750dda353 100644 --- a/src/vivarium_public_health/risks/implementations/low_birth_weight_and_short_gestation.py +++ b/src/vivarium_public_health/risks/implementations/low_birth_weight_and_short_gestation.py @@ -9,30 +9,22 @@ Note that because the input data is so large, it relies on a custom relative risk data loader that expects data saved in keys by draw. """ -from typing import Tuple from pathlib import Path +from typing import Tuple import pandas as pd +from vivarium.framework.randomness import RandomnessStream -from vivarium_public_health.risks import data_transformations as data_transformations +from vivarium_public_health.risks import RiskEffect, data_transformations from vivarium_public_health.utilities import EntityString, TargetString -from vivarium_public_health.risks.data_transformations import validate_relative_risk_data_source -from vivarium_public_health.risks.data_transformations import rebin_relative_risk_data -from vivarium_public_health.risks.data_transformations import get_distribution_type -from vivarium_public_health.risks.data_transformations import pivot_categorical -from vivarium_public_health.risks import RiskEffect -from vivarium.framework.randomness import RandomnessStream -MISSING_CATEGORY = 'cat212' +MISSING_CATEGORY = "cat212" class LBWSGRisk: configuration_defaults = { - 'low_birth_weight_and_short_gestation': { - 'exposure': 'data', - 'rebinned_exposed': [] - } + "low_birth_weight_and_short_gestation": {"exposure": "data", "rebinned_exposed": []} } @property @@ -40,64 +32,80 @@ def name(self): return "low_birth_weight_short_gestation_risk" def __init__(self): - self.risk = EntityString('risk_factor.low_birth_weight_and_short_gestation') + self.risk = EntityString("risk_factor.low_birth_weight_and_short_gestation") def setup(self, builder): self.exposure_distribution = LBWSGDistribution(builder) - self.birthweight_gestation_time_view = builder.population.get_view(['birth_weight', 'gestation_time']) + self.birthweight_gestation_time_view = builder.population.get_view( + ["birth_weight", "gestation_time"] + ) self._raw_exposure = builder.value.register_value_producer( - f'{self.risk.name}.raw_exposure', + f"{self.risk.name}.raw_exposure", source=lambda index: self.birthweight_gestation_time_view.get(index), - requires_columns=['birth_weight', 'gestation_time'] + requires_columns=["birth_weight", "gestation_time"], ) self.exposure = builder.value.register_value_producer( - f'{self.risk.name}.exposure', + f"{self.risk.name}.exposure", source=self._raw_exposure, preferred_post_processor=self.exposure_distribution.convert_to_categorical, - requires_values=f'{self.risk.name}.raw_exposure' + requires_values=f"{self.risk.name}.raw_exposure", ) - builder.population.initializes_simulants(self.on_initialize_simulants, - creates_columns=['birth_weight', 'gestation_time']) + builder.population.initializes_simulants( + self.on_initialize_simulants, creates_columns=["birth_weight", "gestation_time"] + ) def on_initialize_simulants(self, pop_data): - exposure = self.exposure_distribution.get_birth_weight_and_gestational_age(pop_data.index) - self.birthweight_gestation_time_view.update(pd.DataFrame({ - 'birth_weight': exposure['birth_weight'], - 'gestation_time': exposure['gestation_time'] - }, index=pop_data.index)) + exposure = self.exposure_distribution.get_birth_weight_and_gestational_age( + pop_data.index + ) + self.birthweight_gestation_time_view.update( + pd.DataFrame( + { + "birth_weight": exposure["birth_weight"], + "gestation_time": exposure["gestation_time"], + }, + index=pop_data.index, + ) + ) class LBWSGDistribution: - def __init__(self, builder): - self.risk = EntityString('risk_factor.low_birth_weight_and_short_gestation') - self.randomness = builder.randomness.get_stream(f'{self.risk.name}.exposure') + self.risk = EntityString("risk_factor.low_birth_weight_and_short_gestation") + self.randomness = builder.randomness.get_stream(f"{self.risk.name}.exposure") self.categories_by_interval = get_categories_by_interval(builder, self.risk) - self.intervals_by_category = self.categories_by_interval.reset_index().set_index('cat') + self.intervals_by_category = self.categories_by_interval.reset_index().set_index( + "cat" + ) self.max_gt_by_bw, self.max_bw_by_gt = self._get_boundary_mappings() - self.exposure_parameters = builder.lookup.build_table(get_exposure_data(builder, self.risk), - key_columns=['sex'], - parameter_columns=['age', 'year']) + self.exposure_parameters = builder.lookup.build_table( + get_exposure_data(builder, self.risk), + key_columns=["sex"], + parameter_columns=["age", "year"], + ) def get_birth_weight_and_gestational_age(self, index): - category_draw = self.randomness.get_draw(index, additional_key='category') + category_draw = self.randomness.get_draw(index, additional_key="category") exposure = self.exposure_parameters(index)[self.categories_by_interval.values] - exposure_sum = exposure.cumsum(axis='columns') - category_index = (exposure_sum.T < category_draw).T.sum('columns') - categorical_exposure = pd.Series(self.categories_by_interval.values[category_index], - index=index, name='cat') + exposure_sum = exposure.cumsum(axis="columns") + category_index = (exposure_sum.T < category_draw).T.sum("columns") + categorical_exposure = pd.Series( + self.categories_by_interval.values[category_index], index=index, name="cat" + ) return self._convert_to_continuous(categorical_exposure) def convert_to_categorical(self, exposure, _): exposure = self._convert_boundary_cases(exposure) - categorical_exposure = self.categories_by_interval.iloc[self._get_categorical_index(exposure)] + categorical_exposure = self.categories_by_interval.iloc[ + self._get_categorical_index(exposure) + ] categorical_exposure.index = exposure.index return categorical_exposure @@ -105,59 +113,84 @@ def _convert_boundary_cases(self, exposure): eps = 1e-4 outside_bounds = self._get_categorical_index(exposure) == -1 shift_down = outside_bounds & ( - (exposure.birth_weight < 1000) - | ((1000 < exposure.birth_weight) & (exposure.birth_weight < 4500) & (40 < exposure.gestation_time)) + (exposure.birth_weight < 1000) + | ( + (1000 < exposure.birth_weight) + & (exposure.birth_weight < 4500) + & (40 < exposure.gestation_time) + ) ) shift_left = outside_bounds & ( - (1000 < exposure.birth_weight) & (exposure.gestation_time < 34) - | (4500 < exposure.birth_weight) & (exposure.gestation_time < 42) + (1000 < exposure.birth_weight) & (exposure.gestation_time < 34) + | (4500 < exposure.birth_weight) & (exposure.gestation_time < 42) ) tmrel = outside_bounds & ( - (4500 < exposure.birth_weight) & (42 < exposure.gestation_time) + (4500 < exposure.birth_weight) & (42 < exposure.gestation_time) ) - exposure.loc[shift_down, 'gestation_time'] = (self.max_gt_by_bw - .loc[exposure.loc[shift_down, 'birth_weight']] - .values) - eps - exposure.loc[shift_left, 'birth_weight'] = (self.max_bw_by_gt - .loc[exposure.loc[shift_left, 'gestation_time']] - .values) - eps - exposure.loc[tmrel, 'gestation_time'] = 42 - eps - exposure.loc[tmrel, 'birth_weight'] = 4500 - eps + exposure.loc[shift_down, "gestation_time"] = ( + self.max_gt_by_bw.loc[exposure.loc[shift_down, "birth_weight"]].values + ) - eps + exposure.loc[shift_left, "birth_weight"] = ( + self.max_bw_by_gt.loc[exposure.loc[shift_left, "gestation_time"]].values + ) - eps + exposure.loc[tmrel, "gestation_time"] = 42 - eps + exposure.loc[tmrel, "birth_weight"] = 4500 - eps return exposure def _get_categorical_index(self, exposure): - exposure_bw_gt_index = exposure.set_index(['gestation_time', 'birth_weight']).index - return self.categories_by_interval.index.get_indexer(exposure_bw_gt_index, method=None) + exposure_bw_gt_index = exposure.set_index(["gestation_time", "birth_weight"]).index + return self.categories_by_interval.index.get_indexer( + exposure_bw_gt_index, method=None + ) def _convert_to_continuous(self, categorical_exposure): - draws = {'birth_weight': self.randomness.get_draw(categorical_exposure.index, additional_key='birth_weight'), - 'gestation_time': self.randomness.get_draw(categorical_exposure.index, additional_key='gestation_time')} + draws = { + "birth_weight": self.randomness.get_draw( + categorical_exposure.index, additional_key="birth_weight" + ), + "gestation_time": self.randomness.get_draw( + categorical_exposure.index, additional_key="gestation_time" + ), + } def single_values_from_category(row): - idx = row['index'] - bw_draw = draws['birth_weight'][idx] - gt_draw = draws['gestation_time'][idx] + idx = row["index"] + bw_draw = draws["birth_weight"][idx] + gt_draw = draws["gestation_time"][idx] - intervals = self.intervals_by_category.loc[row['cat']] + intervals = self.intervals_by_category.loc[row["cat"]] - birth_weight = (intervals.birth_weight.left - + bw_draw * (intervals.birth_weight.right - intervals.birth_weight.left)) - gestational_age = (intervals.gestation_time.left - + gt_draw * (intervals.gestation_time.right - intervals.gestation_time.left)) + birth_weight = intervals.birth_weight.left + bw_draw * ( + intervals.birth_weight.right - intervals.birth_weight.left + ) + gestational_age = intervals.gestation_time.left + gt_draw * ( + intervals.gestation_time.right - intervals.gestation_time.left + ) return birth_weight, gestational_age values = categorical_exposure.reset_index().apply(single_values_from_category, axis=1) - return pd.DataFrame(list(values), index=categorical_exposure.index, - columns=['birth_weight', 'gestation_time']) + return pd.DataFrame( + list(values), + index=categorical_exposure.index, + columns=["birth_weight", "gestation_time"], + ) def _get_boundary_mappings(self): cats = self.categories_by_interval.reset_index() - max_gt_by_bw = pd.Series({bw_interval: pd.Index(group.gestation_time).right.max() - for bw_interval, group in cats.groupby('birth_weight')}) - max_bw_by_gt = pd.Series({gt_interval: pd.Index(group.birth_weight).right.max() - for gt_interval, group in cats.groupby('gestation_time')}) + max_gt_by_bw = pd.Series( + { + bw_interval: pd.Index(group.gestation_time).right.max() + for bw_interval, group in cats.groupby("birth_weight") + } + ) + max_bw_by_gt = pd.Series( + { + gt_interval: pd.Index(group.birth_weight).right.max() + for gt_interval, group in cats.groupby("gestation_time") + } + ) return max_gt_by_bw, max_bw_by_gt @@ -168,14 +201,17 @@ def get_exposure_data(builder, risk): def get_categories_by_interval(builder, risk): - category_dict = builder.data.load(f'{risk}.categories') - category_dict[MISSING_CATEGORY] = 'Birth prevalence - [37, 38) wks, [1000, 1500) g' - cats = (pd.DataFrame.from_dict(category_dict, orient='index') - .reset_index() - .rename(columns={'index': 'cat', 0: 'name'})) - idx = pd.MultiIndex.from_tuples(cats.name.apply(get_intervals_from_name), - names=['gestation_time', 'birth_weight']) - cats = cats['cat'] + category_dict = builder.data.load(f"{risk}.categories") + category_dict[MISSING_CATEGORY] = "Birth prevalence - [37, 38) wks, [1000, 1500) g" + cats = ( + pd.DataFrame.from_dict(category_dict, orient="index") + .reset_index() + .rename(columns={"index": "cat", 0: "name"}) + ) + idx = pd.MultiIndex.from_tuples( + cats.name.apply(get_intervals_from_name), names=["gestation_time", "birth_weight"] + ) + cats = cats["cat"] cats.index = idx return cats @@ -186,24 +222,28 @@ def get_intervals_from_name(name: str) -> Tuple[pd.Interval, pd.Interval]: The first interval corresponds to gestational age in weeks, the second to birth weight in grams. """ - numbers_only = (name.replace('Birth prevalence - [', '') - .replace(',', '') - .replace(') wks [', ' ') - .replace(') g', '')) + numbers_only = ( + name.replace("Birth prevalence - [", "") + .replace(",", "") + .replace(") wks [", " ") + .replace(") g", "") + ) numbers_only = [int(n) for n in numbers_only.split()] - return (pd.Interval(numbers_only[0], numbers_only[1], closed='left'), - pd.Interval(numbers_only[2], numbers_only[3], closed='left')) + return ( + pd.Interval(numbers_only[0], numbers_only[1], closed="left"), + pd.Interval(numbers_only[2], numbers_only[3], closed="left"), + ) class LBWSGRiskEffect: """A component to model the impact of the low birth weight and short gestation - risk factor on the target rate of some affected entity. + risk factor on the target rate of some affected entity. """ configuration_defaults = { - 'effect_of_risk_on_target': { - 'measure': { - 'relative_risk': None, + "effect_of_risk_on_target": { + "measure": { + "relative_risk": None, } } } @@ -221,75 +261,102 @@ def __init__(self, target: str): supplied in the form "entity_type.entity_name.measure" where entity_type should be singular (e.g., cause instead of causes). """ - self.risk = EntityString('risk_factor.low_birth_weight_and_short_gestation') + self.risk = EntityString("risk_factor.low_birth_weight_and_short_gestation") self.target = TargetString(target) self.configuration_defaults = { - f'effect_of_{self.risk.name}_on_{self.target.name}': { - self.target.measure: RiskEffect.configuration_defaults['effect_of_risk_on_target']['measure'] + f"effect_of_{self.risk.name}_on_{self.target.name}": { + self.target.measure: RiskEffect.configuration_defaults[ + "effect_of_risk_on_target" + ]["measure"] } } def setup(self, builder): - self.randomness = builder.randomness.get_stream(f'effect_of_{self.risk.name}_on_{self.target.name}') - self.relative_risk = builder.lookup.build_table(self.get_relative_risk_data(builder), - key_columns=['sex'], - parameter_columns=['age', 'year']) + self.randomness = builder.randomness.get_stream( + f"effect_of_{self.risk.name}_on_{self.target.name}" + ) + self.relative_risk = builder.lookup.build_table( + self.get_relative_risk_data(builder), + key_columns=["sex"], + parameter_columns=["age", "year"], + ) self.population_attributable_fraction = builder.lookup.build_table( - data_transformations.get_population_attributable_fraction_data(builder, self.risk, self.target, self.randomness), - key_columns=['sex'], - parameter_columns=['age', 'year'] + data_transformations.get_population_attributable_fraction_data( + builder, self.risk, self.target, self.randomness + ), + key_columns=["sex"], + parameter_columns=["age", "year"], ) self.exposure_effect = data_transformations.get_exposure_effect(builder, self.risk) - builder.value.register_value_modifier(f'{self.target.name}.{self.target.measure}', - modifier=self.adjust_target) - builder.value.register_value_modifier(f'{self.target.name}.{self.target.measure}.paf', - modifier=self.population_attributable_fraction) + builder.value.register_value_modifier( + f"{self.target.name}.{self.target.measure}", modifier=self.adjust_target + ) + builder.value.register_value_modifier( + f"{self.target.name}.{self.target.measure}.paf", + modifier=self.population_attributable_fraction, + ) def adjust_target(self, index, target): return self.exposure_effect(target, self.relative_risk(index)) def get_relative_risk_data(self, builder): - rr_data = get_relative_risk_data_by_draw(builder, self.risk, self.target, self.randomness) - rr_data[MISSING_CATEGORY] = (rr_data['cat106'] + rr_data['cat116']) / 2 + rr_data = get_relative_risk_data_by_draw( + builder, self.risk, self.target, self.randomness + ) + rr_data[MISSING_CATEGORY] = (rr_data["cat106"] + rr_data["cat116"]) / 2 return rr_data # Pulled from vivarium_public_health.risks.data_transformations -def get_relative_risk_data_by_draw(builder, risk: EntityString, target: TargetString, randomness: RandomnessStream): - source_type = validate_relative_risk_data_source(builder, risk, target) +def get_relative_risk_data_by_draw( + builder, risk: EntityString, target: TargetString, randomness: RandomnessStream +): + source_type = data_transformations.validate_relative_risk_data_source( + builder, risk, target + ) relative_risk_data = load_relative_risk_data_by_draw(builder, risk, target, source_type) - relative_risk_data = rebin_relative_risk_data(builder, risk, relative_risk_data) + relative_risk_data = data_transformations.rebin_relative_risk_data( + builder, risk, relative_risk_data + ) - if get_distribution_type(builder, risk) in ['dichotomous', 'ordered_polytomous', 'unordered_polytomous']: - relative_risk_data = pivot_categorical(relative_risk_data) + distribution_type = data_transformations.get_distribution_type(builder, risk) + if distribution_type in ["dichotomous", "ordered_polytomous", "unordered_polytomous"]: + relative_risk_data = data_transformations.pivot_categorical(relative_risk_data) else: - relative_risk_data = relative_risk_data.drop(['parameter'], 'columns') + relative_risk_data = relative_risk_data.drop(["parameter"], "columns") return relative_risk_data -def load_relative_risk_data_by_draw(builder, risk: EntityString, target: TargetString, source_type: str): +def load_relative_risk_data_by_draw( + builder, risk: EntityString, target: TargetString, source_type: str +): artifact_path = Path(builder.data._manager.artifact.path).resolve() relative_risk_data = None - if source_type == 'data': - relative_risk_data = read_data_by_draw(str(artifact_path), f'{risk}.relative_risk', - builder.configuration.input_data.input_draw_number) - correct_target = ((relative_risk_data['affected_entity'] == target.name) - & (relative_risk_data['affected_measure'] == target.measure)) - relative_risk_data = (relative_risk_data[correct_target] - .drop(['affected_entity', 'affected_measure'], 'columns')) + if source_type == "data": + relative_risk_data = read_data_by_draw( + str(artifact_path), + f"{risk}.relative_risk", + builder.configuration.input_data.input_draw_number, + ) + correct_target = (relative_risk_data["affected_entity"] == target.name) & ( + relative_risk_data["affected_measure"] == target.measure + ) + relative_risk_data = relative_risk_data[correct_target].drop( + ["affected_entity", "affected_measure"], "columns" + ) return relative_risk_data def read_data_by_draw(path, key, draw): key = key.replace(".", "/") - with pd.HDFStore(path, mode='r') as store: - index = store.get(f'{key}/index') - draw = store.get(f'{key}/draw_{draw}') + with pd.HDFStore(path, mode="r") as store: + index = store.get(f"{key}/index") + draw = store.get(f"{key}/draw_{draw}") draw.rename("value", inplace=True) return pd.concat([index, draw], axis=1) diff --git a/src/vivarium_public_health/testing/mock_artifact.py b/src/vivarium_public_health/testing/mock_artifact.py index c4d851d7a..7cda5bb3b 100644 --- a/src/vivarium_public_health/testing/mock_artifact.py +++ b/src/vivarium_public_health/testing/mock_artifact.py @@ -8,75 +8,90 @@ """ import pandas as pd - -from vivarium.testing_utilities import build_table - from vivarium.framework.artifact import ArtifactManager +from vivarium.testing_utilities import build_table -from .utils import make_uniform_pop_data +from vivarium_public_health.testing.utils import make_uniform_pop_data MOCKERS = { - 'cause': { - 'prevalence': 0, - 'cause_specific_mortality_rate': 0, - 'excess_mortality_rate': 0, - 'remission_rate': 0, - 'incidence_rate': 0.001, - 'disability_weight': pd.DataFrame({'value': [0]}), - 'restrictions': lambda *args, **kwargs: {'yld_only': False} - }, - 'risk_factor': { - 'distribution': lambda *args, **kwargs: 'ensemble', - 'exposure': 120, - 'exposure_standard_deviation': 15, - 'relative_risk': build_table([1.5, "continuous", "test_cause", "incidence_rate"], 1990, 2017, - ("age", "sex", "year", "value", "parameter", "cause", "affected_measure")), - 'population_attributable_fraction': build_table([1, "test_cause_1", "incidence_rate"], 1990, 2017, - ("age", "sex", "year", "value", "cause", "affected_measure")), - 'tmred': lambda *args, **kwargs: { - "distribution": "uniform", - "min": 80, - "max": 100, - "inverted": False, - }, - 'exposure_parameters': lambda *args, **kwargs: { - 'scale': 1, - 'max_rr': 10, - 'max_val': 200, - 'min_val': 0, - }, - 'ensemble_weights': lambda *args, **kwargs: pd.DataFrame({'norm': 1}, index=[0]) + "cause": { + "prevalence": 0, + "cause_specific_mortality_rate": 0, + "excess_mortality_rate": 0, + "remission_rate": 0, + "incidence_rate": 0.001, + "disability_weight": pd.DataFrame({"value": [0]}), + "restrictions": lambda *args, **kwargs: {"yld_only": False}, + }, + "risk_factor": { + "distribution": lambda *args, **kwargs: "ensemble", + "exposure": 120, + "exposure_standard_deviation": 15, + "relative_risk": build_table( + [1.5, "continuous", "test_cause", "incidence_rate"], + 1990, + 2017, + ("age", "sex", "year", "value", "parameter", "cause", "affected_measure"), + ), + "population_attributable_fraction": build_table( + [1, "test_cause_1", "incidence_rate"], + 1990, + 2017, + ("age", "sex", "year", "value", "cause", "affected_measure"), + ), + "tmred": lambda *args, **kwargs: { + "distribution": "uniform", + "min": 80, + "max": 100, + "inverted": False, }, - 'sequela': { - 'prevalence': 0, - 'cause_specific_mortality_rate': 0, - 'excess_mortality_rate': 0, - 'remission_rate': 0, - 'incidence_rate': 0.001, - 'disability_weight': pd.DataFrame({'value': [0]}), - }, - 'etiology': { - 'population_attributable_fraction': build_table([1, "incidence_rate"], 1990, 2017, - ("age", "sex", "year", "value", "affected_measure")), - }, - 'healthcare_entity': { - 'cost': build_table([0, 'outpatient_visits'], 1990, 2017, - ("age", "sex", "year", "value", "healthcare_entity")), - 'utilization_rate': 0, - }, - # FIXME: this is a hack to get the MockArtifact to use the correct value - 'population.location': 'Kenya', - 'population': { - 'structure': make_uniform_pop_data(), - 'theoretical_minimum_risk_life_expectancy': (build_table(98.0, 1990, 1990) - .query('sex=="Female"') - .filter(['age_start', 'age_end', 'value'])) + "exposure_parameters": lambda *args, **kwargs: { + "scale": 1, + "max_rr": 10, + "max_val": 200, + "min_val": 0, }, + "ensemble_weights": lambda *args, **kwargs: pd.DataFrame({"norm": 1}, index=[0]), + }, + "sequela": { + "prevalence": 0, + "cause_specific_mortality_rate": 0, + "excess_mortality_rate": 0, + "remission_rate": 0, + "incidence_rate": 0.001, + "disability_weight": pd.DataFrame({"value": [0]}), + }, + "etiology": { + "population_attributable_fraction": build_table( + [1, "incidence_rate"], + 1990, + 2017, + ("age", "sex", "year", "value", "affected_measure"), + ), + }, + "healthcare_entity": { + "cost": build_table( + [0, "outpatient_visits"], + 1990, + 2017, + ("age", "sex", "year", "value", "healthcare_entity"), + ), + "utilization_rate": 0, + }, + # FIXME: this is a hack to get the MockArtifact to use the correct value + "population.location": "Kenya", + "population": { + "structure": make_uniform_pop_data(), + "theoretical_minimum_risk_life_expectancy": ( + build_table(98.0, 1990, 1990) + .query('sex=="Female"') + .filter(["age_start", "age_end", "value"]) + ), + }, } -class MockArtifact(): - +class MockArtifact: def __init__(self): self.mocks = MOCKERS.copy() @@ -84,7 +99,7 @@ def load(self, entity_key): if entity_key in self.mocks: return self.mocks[entity_key] - entity_type, *_, entity_measure = entity_key.split('.') + entity_type, *_, entity_measure = entity_key.split(".") assert entity_type in self.mocks assert entity_measure in self.mocks[entity_type] value = self.mocks[entity_type][entity_measure] @@ -101,13 +116,12 @@ def write(self, entity_key, data): class MockArtifactManager(ArtifactManager): - def __init__(self): self.artifact = self._load_artifact(None) @property def name(self): - return 'mock_artifact_manager' + return "mock_artifact_manager" def setup(self, builder): pass diff --git a/src/vivarium_public_health/testing/utils.py b/src/vivarium_public_health/testing/utils.py index 0299148ab..b7c750cf4 100644 --- a/src/vivarium_public_health/testing/utils.py +++ b/src/vivarium_public_health/testing/utils.py @@ -14,7 +14,7 @@ def make_uniform_pop_data(age_bin_midpoint=False): age_bins = [(n, n + 5) for n in range(0, 100, 5)] - sexes = ('Male', 'Female') + sexes = ("Male", "Female") years = zip(range(1990, 2018), range(1991, 2019)) locations = (1, 2) @@ -22,13 +22,17 @@ def make_uniform_pop_data(age_bin_midpoint=False): mins, maxes = zip(*age_bins) year_starts, year_ends = zip(*years) - pop = pd.DataFrame({'age_start': mins, - 'age_end': maxes, - 'sex': sexes, - 'year_start': year_starts, - 'year_end': year_ends, - 'location': locations, - 'value': [100] * len(mins)}) + pop = pd.DataFrame( + { + "age_start": mins, + "age_end": maxes, + "sex": sexes, + "year_start": year_starts, + "year_end": year_ends, + "location": locations, + "value": [100] * len(mins), + } + ) if age_bin_midpoint: # used for population tests - pop['age'] = pop.apply(lambda row: (row['age_start'] + row['age_end']) / 2, axis=1) + pop["age"] = pop.apply(lambda row: (row["age_start"] + row["age_end"]) / 2, axis=1) return pop diff --git a/src/vivarium_public_health/treatment/magic_wand.py b/src/vivarium_public_health/treatment/magic_wand.py index 3de42bcea..c4de428a4 100644 --- a/src/vivarium_public_health/treatment/magic_wand.py +++ b/src/vivarium_public_health/treatment/magic_wand.py @@ -13,33 +13,39 @@ class AbsoluteShift: configuration_defaults = { - 'intervention': { - 'target_value': 'baseline', - 'age_start': 0, - 'age_end': 125, + "intervention": { + "target_value": "baseline", + "age_start": 0, + "age_end": 125, } } def __init__(self, target): self.target = TargetString(target) self.configuration_defaults = { - f'intervention_on_{self.target.name}': AbsoluteShift.configuration_defaults['intervention'] + f"intervention_on_{self.target.name}": AbsoluteShift.configuration_defaults[ + "intervention" + ] } @property def name(self): - return f'absolute_shift_wand.{self.target}' + return f"absolute_shift_wand.{self.target}" def setup(self, builder): - self.config = builder.configuration[f'intervention_on_{self.target.name}'] - builder.value.register_value_modifier(f'{self.target.name}.{self.target.measure}', - modifier=self.intervention_effect, - requires_columns=['age']) - self.population_view = builder.population.get_view(['age']) + self.config = builder.configuration[f"intervention_on_{self.target.name}"] + builder.value.register_value_modifier( + f"{self.target.name}.{self.target.measure}", + modifier=self.intervention_effect, + requires_columns=["age"], + ) + self.population_view = builder.population.get_view(["age"]) def intervention_effect(self, index, value): - if self.config['target_value'] != 'baseline': + if self.config["target_value"] != "baseline": pop = self.population_view.get(index) - affected_group = pop[pop.age.between(self.config['age_start'], self.config['age_end'])] - value.loc[affected_group.index] = float(self.config['target_value']) + affected_group = pop[ + pop.age.between(self.config["age_start"], self.config["age_end"]) + ] + value.loc[affected_group.index] = float(self.config["target_value"]) return value diff --git a/src/vivarium_public_health/treatment/scale_up.py b/src/vivarium_public_health/treatment/scale_up.py index 59e336e17..577bb17dc 100644 --- a/src/vivarium_public_health/treatment/scale_up.py +++ b/src/vivarium_public_health/treatment/scale_up.py @@ -10,12 +10,12 @@ from typing import Callable, Dict, List, Tuple import pandas as pd - from vivarium.framework.engine import Builder from vivarium.framework.lookup import LookupTable from vivarium.framework.population import PopulationView from vivarium.framework.time import Time, get_time_stamp from vivarium.framework.values import Pipeline + from vivarium_public_health.utilities import EntityString @@ -53,12 +53,12 @@ class LinearScaleUp: configuration_defaults = { "treatment": { "start": { - "date": 'start', - "value": 'data', + "date": "start", + "value": "data", }, "end": { - "date": 'end', - "value": 'data', + "date": "end", + "value": "data", }, } } @@ -86,7 +86,7 @@ def _get_configuration_defaults(self) -> Dict[str, Dict]: @property def name(self) -> str: - return f'{self.treatment.name}_intervention' + return f"{self.treatment.name}_intervention" @property def configuration_key(self) -> str: @@ -101,8 +101,13 @@ def setup(self, builder: Builder) -> None: """Perform this component's setup.""" self.is_intervention_scenario = self._get_is_intervention_scenario(builder) self.clock = self._get_clock(builder) - self.scale_up_start_date, self.scale_up_end_date = self._get_scale_up_date_endpoints(builder) - self.scale_up_start_value, self.scale_up_end_value = self._get_scale_up_value_endpoints(builder) + self.scale_up_start_date, self.scale_up_end_date = self._get_scale_up_date_endpoints( + builder + ) + ( + self.scale_up_start_value, + self.scale_up_end_value, + ) = self._get_scale_up_value_endpoints(builder) required_columns = self._get_required_columns() self.pipelines = self._get_required_pipelines(builder) @@ -114,7 +119,7 @@ def setup(self, builder: Builder) -> None: # noinspection PyMethodMayBeStatic def _get_is_intervention_scenario(self, builder: Builder) -> bool: - return builder.configuration.intervention.scenario != 'baseline' + return builder.configuration.intervention.scenario != "baseline" # noinspection PyMethodMayBeStatic def _get_clock(self, builder: Builder) -> Callable[[], Time]: @@ -125,25 +130,27 @@ def _get_scale_up_date_endpoints(self, builder: Builder) -> Tuple[datetime, date scale_up_config = builder.configuration[self.configuration_key] def get_endpoint(endpoint_type: str) -> datetime: - if scale_up_config[endpoint_type]['date'] == endpoint_type: + if scale_up_config[endpoint_type]["date"] == endpoint_type: endpoint = get_time_stamp(builder.configuration.time[endpoint_type]) else: - endpoint = get_time_stamp(scale_up_config[endpoint_type]['date']) + endpoint = get_time_stamp(scale_up_config[endpoint_type]["date"]) return endpoint - return get_endpoint('start'), get_endpoint('end') + return get_endpoint("start"), get_endpoint("end") - def _get_scale_up_value_endpoints(self, builder: Builder) -> Tuple[LookupTable, LookupTable]: + def _get_scale_up_value_endpoints( + self, builder: Builder + ) -> Tuple[LookupTable, LookupTable]: scale_up_config = builder.configuration[self.configuration_key] def get_endpoint_value(endpoint_type: str) -> LookupTable: - if scale_up_config[endpoint_type]['value'] == 'data': + if scale_up_config[endpoint_type]["value"] == "data": endpoint = self._get_endpoint_value_from_data(builder, endpoint_type) else: - endpoint = builder.lookup.build_table(scale_up_config[endpoint_type]['value']) + endpoint = builder.lookup.build_table(scale_up_config[endpoint_type]["value"]) return endpoint - return get_endpoint_value('start'), get_endpoint_value('end') + return get_endpoint_value("start"), get_endpoint_value("end") # noinspection PyMethodMayBeStatic def _get_required_columns(self) -> List[str]: @@ -155,12 +162,14 @@ def _get_required_pipelines(self, builder: Builder) -> Dict[str, Pipeline]: def _register_intervention_modifiers(self, builder: Builder): builder.value.register_value_modifier( - f'{self.treatment}.exposure_parameters', + f"{self.treatment}.exposure_parameters", modifier=self._coverage_effect, ) # noinspection PyMethodMayBeStatic - def _get_population_view(self, builder: Builder, required_columns: List[str]) -> PopulationView: + def _get_population_view( + self, builder: Builder, required_columns: List[str] + ) -> PopulationView: return builder.population.get_view(required_columns) ################################## @@ -171,28 +180,39 @@ def _coverage_effect(self, idx: pd.Index, target: pd.Series) -> pd.Series: if not self.is_intervention_scenario or self.clock() < self.scale_up_start_date: scale_up_progress = 0.0 elif self.scale_up_start_date <= self.clock() < self.scale_up_end_date: - scale_up_progress = ((self.clock() - self.scale_up_start_date) - / (self.scale_up_end_date - self.scale_up_start_date)) + scale_up_progress = (self.clock() - self.scale_up_start_date) / ( + self.scale_up_end_date - self.scale_up_start_date + ) else: scale_up_progress = 1.0 - target = self._apply_scale_up(idx, target, scale_up_progress) if scale_up_progress else target + target = ( + self._apply_scale_up(idx, target, scale_up_progress) + if scale_up_progress + else target + ) return target ################## # Helper methods # ################## - def _get_endpoint_value_from_data(self, builder: Builder, endpoint_type: str) -> LookupTable: - if endpoint_type == 'start': - endpoint_data = builder.data.load(f'{self.treatment}.exposure') - elif endpoint_type == 'end': - endpoint_data = builder.data.load(f'alternate_{self.treatment}.exposure') + def _get_endpoint_value_from_data( + self, builder: Builder, endpoint_type: str + ) -> LookupTable: + if endpoint_type == "start": + endpoint_data = builder.data.load(f"{self.treatment}.exposure") + elif endpoint_type == "end": + endpoint_data = builder.data.load(f"alternate_{self.treatment}.exposure") else: - raise ValueError(f'Invalid endpoint type {endpoint_type}. Allowed types are "start" and "end".') + raise ValueError( + f'Invalid endpoint type {endpoint_type}. Allowed types are "start" and "end".' + ) return builder.lookup.build_table(endpoint_data) - def _apply_scale_up(self, idx: pd.Index, target: pd.Series, scale_up_progress: float) -> pd.Series: + def _apply_scale_up( + self, idx: pd.Index, target: pd.Series, scale_up_progress: float + ) -> pd.Series: start_value = self.scale_up_start_value(idx) end_value = self.scale_up_end_value(idx) value_increase = scale_up_progress * (end_value - start_value) diff --git a/src/vivarium_public_health/treatment/therapeutic_inertia.py b/src/vivarium_public_health/treatment/therapeutic_inertia.py index 1b62db19b..e6e05ebe3 100644 --- a/src/vivarium_public_health/treatment/therapeutic_inertia.py +++ b/src/vivarium_public_health/treatment/therapeutic_inertia.py @@ -17,23 +17,25 @@ class TherapeuticInertia: This is the probability of treatment during a healthcare visit.""" configuration_defaults = { - 'therapeutic_inertia': { - 'triangle_min': 0.65, - 'triangle_max': 0.9, - 'triangle_mode': 0.875 + "therapeutic_inertia": { + "triangle_min": 0.65, + "triangle_max": 0.9, + "triangle_mode": 0.875, } } @property def name(self): - return 'therapeutic_inertia' + return "therapeutic_inertia" def setup(self, builder): self.therapeutic_inertia_parameters = builder.configuration.therapeutic_inertia self._therapeutic_inertia = self.initialize_therapeutic_inertia(builder) ti_source = lambda index: pd.Series(self._therapeutic_inertia, index=index) - self.therapeutic_inertia = builder.value.register_value_producer('therapeutic_inertia', source=ti_source) + self.therapeutic_inertia = builder.value.register_value_producer( + "therapeutic_inertia", source=ti_source + ) def initialize_therapeutic_inertia(self, builder): triangle_min = self.therapeutic_inertia_parameters.triangle_min @@ -49,14 +51,18 @@ def initialize_therapeutic_inertia(self, builder): c = (triangle_mode - loc) / scale seed = builder.randomness.get_seed(self.name) - therapeutic_inertia = scipy.stats.triang(c, loc=loc, scale=scale).rvs(random_state=seed) + therapeutic_inertia = scipy.stats.triang(c, loc=loc, scale=scale).rvs( + random_state=seed + ) return therapeutic_inertia def __str__(self): - return (f'TherapeuticInertia(triangle_min={self.therapeutic_inertia_parameters.triangle_min}, ' - f'triangle_max={self.therapeutic_inertia_parameters.triangle_max}, ' - f'triangle_mode={self.therapeutic_inertia_parameters.triangle_mode})') + return ( + f"TherapeuticInertia(triangle_min={self.therapeutic_inertia_parameters.triangle_min}, " + f"triangle_max={self.therapeutic_inertia_parameters.triangle_max}, " + f"triangle_mode={self.therapeutic_inertia_parameters.triangle_mode})" + ) def __repr__(self): - return 'TherapeuticInertia()' + return "TherapeuticInertia()" diff --git a/src/vivarium_public_health/utilities.py b/src/vivarium_public_health/utilities.py index d757d1f6e..a720ee898 100644 --- a/src/vivarium_public_health/utilities.py +++ b/src/vivarium_public_health/utilities.py @@ -28,9 +28,11 @@ def name(self): return self._name def split_entity(self): - split = self.split('.') + split = self.split(".") if len(split) != 2: - raise ValueError(f'You must specify the entity as "entity_type.entity". You specified {self}.') + raise ValueError( + f'You must specify the entity as "entity_type.entity". You specified {self}.' + ) return split[0], split[1] @@ -54,11 +56,12 @@ def measure(self): return self._measure def split_target(self): - split = self.split('.') + split = self.split(".") if len(split) != 3: raise ValueError( f'You must specify the target as "affected_entity_type.affected_entity_name.affected_measure".' - f'You specified {self}.') + f"You specified {self}." + ) return split[0], split[1], split[2] diff --git a/tests/conftest.py b/tests/conftest.py index cf1b47d24..cba7dbe5d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,34 +1,35 @@ from pathlib import Path import pytest -from vivarium.framework.configuration import build_simulation_configuration from vivarium.config_tree import ConfigTree +from vivarium.framework.configuration import build_simulation_configuration @pytest.fixture() def base_config(): config = build_simulation_configuration() - config.update({ - 'time': { - 'start': {'year': 1990}, - 'end': {'year': 2010}, - 'step_size': 30.5 + config.update( + { + "time": {"start": {"year": 1990}, "end": {"year": 2010}, "step_size": 30.5}, + "randomness": {"key_columns": ["entrance_time", "age"]}, }, - 'randomness': {'key_columns': ['entrance_time', 'age']}, - }, source=str(Path(__file__).resolve()), layer='model_override') + source=str(Path(__file__).resolve()), + layer="model_override", + ) return config -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def base_plugins(): - config = {'required': { - 'data': { - 'controller': 'vivarium_public_health.testing.mock_artifact.MockArtifactManager', - 'builder_interface': 'vivarium.framework.artifact.ArtifactInterface' - } - } + config = { + "required": { + "data": { + "controller": "vivarium_public_health.testing.mock_artifact.MockArtifactManager", + "builder_interface": "vivarium.framework.artifact.ArtifactInterface", + } + } } return ConfigTree(config) diff --git a/tests/disease/test_disease.py b/tests/disease/test_disease.py index 3103a18a2..b58fdf35d 100644 --- a/tests/disease/test_disease.py +++ b/tests/disease/test_disease.py @@ -1,35 +1,41 @@ import numpy as np import pandas as pd import pytest - from vivarium import InteractiveContext from vivarium.framework.utilities import from_yearly from vivarium.framework.values import rescale_post_processor -from vivarium.testing_utilities import build_table, TestPopulation, metadata - -from vivarium_public_health.disease import (BaseDiseaseState, DiseaseState, - RateTransition, DiseaseModel) +from vivarium.testing_utilities import TestPopulation, build_table, metadata + +from vivarium_public_health.disease import ( + BaseDiseaseState, + DiseaseModel, + DiseaseState, + RateTransition, +) from vivarium_public_health.disease.state import SusceptibleState from vivarium_public_health.population import Mortality @pytest.fixture def disease(): - return 'test' + return "test" @pytest.fixture def assign_cause_mock(mocker): - return mocker.patch('vivarium_public_health.disease.model.DiseaseModel.assign_initial_status_to_simulants') + return mocker.patch( + "vivarium_public_health.disease.model.DiseaseModel.assign_initial_status_to_simulants" + ) @pytest.fixture def base_data(): def _set_prevalence(p): base_function = dict() - base_function['dwell_time'] = lambda _, __: pd.Timedelta(days=1) - base_function['prevalence'] = lambda _, __: p + base_function["dwell_time"] = lambda _, __: pd.Timedelta(days=1) + base_function["prevalence"] = lambda _, __: p return base_function + return _set_prevalence @@ -49,42 +55,47 @@ def test_dwell_time(assign_cause_mock, base_config, base_plugins, disease, base_ time_step = 10 assign_cause_mock.side_effect = lambda population, *args: pd.DataFrame( - {'condition_state': 'healthy'}, index=population.index) + {"condition_state": "healthy"}, index=population.index + ) - base_config.update({ - 'time': {'step_size': time_step}, - 'population': {'population_size': 10} - }, **metadata(__file__)) + base_config.update( + {"time": {"step_size": time_step}, "population": {"population_size": 10}}, + **metadata(__file__), + ) - healthy_state = BaseDiseaseState('healthy') + healthy_state = BaseDiseaseState("healthy") data_function = base_data(0) - data_function['dwell_time'] = lambda _, __: pd.Timedelta(days=28) - data_function['disability_weight'] = lambda _, __: 0.0 - event_state = DiseaseState('event', get_data_functions=data_function) - done_state = BaseDiseaseState('sick') + data_function["dwell_time"] = lambda _, __: pd.Timedelta(days=28) + data_function["disability_weight"] = lambda _, __: 0.0 + event_state = DiseaseState("event", get_data_functions=data_function) + done_state = BaseDiseaseState("sick") healthy_state.add_transition(event_state) event_state.add_transition(done_state) - model = DiseaseModel(disease, initial_state=healthy_state, states=[healthy_state, event_state, done_state]) + model = DiseaseModel( + disease, initial_state=healthy_state, states=[healthy_state, event_state, done_state] + ) - simulation = InteractiveContext(components=[TestPopulation(), model], - configuration=base_config, - plugin_configuration=base_plugins) + simulation = InteractiveContext( + components=[TestPopulation(), model], + configuration=base_config, + plugin_configuration=base_plugins, + ) # Move everyone into the event state simulation.step() event_time = simulation._clock.time - assert np.all(simulation.get_population()[disease] == 'event') + assert np.all(simulation.get_population()[disease] == "event") simulation.step() simulation.step() # Not enough time has passed for people to move out of the event state, so they should all still be there - assert np.all(simulation.get_population()[disease] == 'event') + assert np.all(simulation.get_population()[disease] == "event") simulation.step() # Now enough time has passed so people should transition away - assert np.all(simulation.get_population()[disease] == 'sick') + assert np.all(simulation.get_population()[disease] == "sick") assert np.all(simulation.get_population().event_event_time == pd.to_datetime(event_time)) assert np.all(simulation.get_population().event_event_count == 1) @@ -95,53 +106,63 @@ def test_dwell_time_with_mortality(base_config, base_plugins, disease): time_step = 10 pop_size = 100 - base_config.update({ - 'time': {'step_size': time_step}, - 'population': {'population_size': pop_size} - }, **metadata(__file__)) - healthy_state = BaseDiseaseState('healthy') + base_config.update( + {"time": {"step_size": time_step}, "population": {"population_size": pop_size}}, + **metadata(__file__), + ) + healthy_state = BaseDiseaseState("healthy") mort_get_data_funcs = { - 'dwell_time': lambda _, __: pd.Timedelta(days=14), - 'excess_mortality_rate': lambda _, __: build_table(0.7, year_start-1, year_end), - 'disability_weight': lambda _, __: 0.0 + "dwell_time": lambda _, __: pd.Timedelta(days=14), + "excess_mortality_rate": lambda _, __: build_table(0.7, year_start - 1, year_end), + "disability_weight": lambda _, __: 0.0, } - mortality_state = DiseaseState('event', get_data_functions=mort_get_data_funcs) - done_state = BaseDiseaseState('sick') + mortality_state = DiseaseState("event", get_data_functions=mort_get_data_funcs) + done_state = BaseDiseaseState("sick") healthy_state.add_transition(mortality_state) mortality_state.add_transition(done_state) - model = DiseaseModel(disease, initial_state=healthy_state, states=[healthy_state, mortality_state, done_state]) + model = DiseaseModel( + disease, + initial_state=healthy_state, + states=[healthy_state, mortality_state, done_state], + ) mortality = Mortality() - simulation = InteractiveContext(components=[TestPopulation(), model, mortality], - configuration=base_config, - plugin_configuration=base_plugins) + simulation = InteractiveContext( + components=[TestPopulation(), model, mortality], + configuration=base_config, + plugin_configuration=base_plugins, + ) # Move everyone into the event state simulation.step() - assert np.all(simulation.get_population()[disease] == 'event') + assert np.all(simulation.get_population()[disease] == "event") simulation.step() # Not enough time has passed for people to move out of the event state, so they should all still be there - assert np.all(simulation.get_population()[disease] == 'event') + assert np.all(simulation.get_population()[disease] == "event") simulation.step() # Make sure some people have died and remained in event state - assert (simulation.get_population()['alive'] == 'alive').sum() < pop_size + assert (simulation.get_population()["alive"] == "alive").sum() < pop_size - assert ((simulation.get_population()['alive'] == 'dead').sum() == - (simulation.get_population()[disease] == 'event').sum()) + assert (simulation.get_population()["alive"] == "dead").sum() == ( + simulation.get_population()[disease] == "event" + ).sum() # enough time has passed so living people should transition away to sick - assert ((simulation.get_population()['alive'] == 'alive').sum() == - (simulation.get_population()[disease] == 'sick').sum()) + assert (simulation.get_population()["alive"] == "alive").sum() == ( + simulation.get_population()[disease] == "sick" + ).sum() -@pytest.mark.parametrize('test_prevalence_level', [0, 0.35, 1]) -def test_prevalence_single_state_with_migration(base_config, base_plugins, disease, base_data, test_prevalence_level): +@pytest.mark.parametrize("test_prevalence_level", [0, 0.35, 1]) +def test_prevalence_single_state_with_migration( + base_config, base_plugins, disease, base_data, test_prevalence_level +): """ Test the prevalence for the single state over newly migrated population. Start with the initial population, check the prevalence for initial assignment. @@ -152,64 +173,95 @@ def test_prevalence_single_state_with_migration(base_config, base_plugins, disea year_start = base_config.time.start.year year_end = base_config.time.end.year - healthy = BaseDiseaseState('healthy') + healthy = BaseDiseaseState("healthy") data_funcs = base_data(test_prevalence_level) - data_funcs.update({'disability_weight': lambda _, __: 0.0}) - sick = DiseaseState('sick', get_data_functions=data_funcs) + data_funcs.update({"disability_weight": lambda _, __: 0.0}) + sick = DiseaseState("sick", get_data_functions=data_funcs) model = DiseaseModel(disease, initial_state=healthy, states=[healthy, sick]) - base_config.update({'population': {'population_size': 50000}}, **metadata(__file__)) - simulation = InteractiveContext(components=[TestPopulation(), model], - configuration=base_config, - plugin_configuration=base_plugins) + base_config.update({"population": {"population_size": 50000}}, **metadata(__file__)) + simulation = InteractiveContext( + components=[TestPopulation(), model], + configuration=base_config, + plugin_configuration=base_plugins, + ) error_message = "initial status of simulants should be matched to the prevalence data." - assert np.isclose(get_test_prevalence(simulation, 'sick'), test_prevalence_level, 0.01), error_message + assert np.isclose( + get_test_prevalence(simulation, "sick"), test_prevalence_level, 0.01 + ), error_message simulation._clock.step_forward() - assert np.isclose(get_test_prevalence(simulation, 'sick'), test_prevalence_level, .01), error_message - simulation.simulant_creator(50000, population_configuration={'age_start': 0, 'age_end': 5, - 'sim_state': 'time_step'}) - assert np.isclose(get_test_prevalence(simulation, 'sick'), test_prevalence_level, .01), error_message + assert np.isclose( + get_test_prevalence(simulation, "sick"), test_prevalence_level, 0.01 + ), error_message + simulation.simulant_creator( + 50000, + population_configuration={"age_start": 0, "age_end": 5, "sim_state": "time_step"}, + ) + assert np.isclose( + get_test_prevalence(simulation, "sick"), test_prevalence_level, 0.01 + ), error_message simulation._clock.step_forward() - simulation.simulant_creator(50000, population_configuration={'age_start': 0, 'age_end': 5, - 'sim_state': 'time_step'}) - assert np.isclose(get_test_prevalence(simulation, 'sick'), test_prevalence_level, .01), error_message - - -@pytest.mark.parametrize('test_prevalence_level', - [[0.15, 0.05, 0.35], [0, 0.15, 0.5], [0.2, 0.3, 0.5], [0, 0, 1], [0, 0, 0]]) -def test_prevalence_multiple_sequelae(base_config, base_plugins, disease, base_data, test_prevalence_level): + simulation.simulant_creator( + 50000, + population_configuration={"age_start": 0, "age_end": 5, "sim_state": "time_step"}, + ) + assert np.isclose( + get_test_prevalence(simulation, "sick"), test_prevalence_level, 0.01 + ), error_message + + +@pytest.mark.parametrize( + "test_prevalence_level", + [[0.15, 0.05, 0.35], [0, 0.15, 0.5], [0.2, 0.3, 0.5], [0, 0, 1], [0, 0, 0]], +) +def test_prevalence_multiple_sequelae( + base_config, base_plugins, disease, base_data, test_prevalence_level +): year_start = base_config.time.start.year year_end = base_config.time.end.year - healthy = BaseDiseaseState('healthy') + healthy = BaseDiseaseState("healthy") sequela = dict() for i, p in enumerate(test_prevalence_level): data_funcs = base_data(p) - data_funcs.update({'disability_weight': lambda _, __: 0.0}) - sequela[i] = DiseaseState('sequela'+str(i), get_data_functions=data_funcs) + data_funcs.update({"disability_weight": lambda _, __: 0.0}) + sequela[i] = DiseaseState("sequela" + str(i), get_data_functions=data_funcs) - model = DiseaseModel(disease, initial_state=healthy, states=[healthy, sequela[0], sequela[1], sequela[2]]) - base_config.update({'population': {'population_size': 100000}}, **metadata(__file__)) - simulation = InteractiveContext(components=[TestPopulation(), model], - configuration=base_config, - plugin_configuration=base_plugins) - error_message = "initial sequela status of simulants should be matched to the prevalence data." - assert np.allclose([get_test_prevalence(simulation, 'sequela0'), - get_test_prevalence(simulation, 'sequela1'), - get_test_prevalence(simulation, 'sequela2')], test_prevalence_level, .02), error_message + model = DiseaseModel( + disease, initial_state=healthy, states=[healthy, sequela[0], sequela[1], sequela[2]] + ) + base_config.update({"population": {"population_size": 100000}}, **metadata(__file__)) + simulation = InteractiveContext( + components=[TestPopulation(), model], + configuration=base_config, + plugin_configuration=base_plugins, + ) + error_message = ( + "initial sequela status of simulants should be matched to the prevalence data." + ) + assert np.allclose( + [ + get_test_prevalence(simulation, "sequela0"), + get_test_prevalence(simulation, "sequela1"), + get_test_prevalence(simulation, "sequela2"), + ], + test_prevalence_level, + 0.02, + ), error_message def test_prevalence_single_simulant(): # pandas has a bug on the case of single element with non-zero index; this test is to catch that case test_index = [20] - initial_state = 'healthy' - simulants_df = pd.DataFrame({'sex': 'Female', 'age': 3, 'sex_id': 2.0}, index=test_index) - state_names = ['sick', 'healthy'] + initial_state = "healthy" + simulants_df = pd.DataFrame({"sex": "Female", "age": 3, "sex_id": 2.0}, index=test_index) + state_names = ["sick", "healthy"] weights = np.array([[1, 1]]) - simulants = DiseaseModel.assign_initial_status_to_simulants(simulants_df, state_names, weights, - pd.Series(0.5, index=test_index)) - expected = simulants_df[['age', 'sex']] - expected['condition_state'] = 'sick' + simulants = DiseaseModel.assign_initial_status_to_simulants( + simulants_df, state_names, weights, pd.Series(0.5, index=test_index) + ) + expected = simulants_df[["age", "sex"]] + expected["condition_state"] = "sick" assert expected.equals(simulants) @@ -219,30 +271,35 @@ def test_mortality_rate(base_config, base_plugins, disease): time_step = pd.Timedelta(days=base_config.time.step_size) - healthy = BaseDiseaseState('healthy') + healthy = BaseDiseaseState("healthy") mort_get_data_funcs = { - 'dwell_time': lambda _, __: pd.Timedelta(days=0), - 'disability_weight': lambda _, __: 0.0, - 'prevalence': lambda _, __: build_table(0.000001, year_start-1, year_end, - ['age', 'year', 'sex', 'value']), - 'excess_mortality_rate': lambda _, __: build_table(0.7, year_start-1, year_end), + "dwell_time": lambda _, __: pd.Timedelta(days=0), + "disability_weight": lambda _, __: 0.0, + "prevalence": lambda _, __: build_table( + 0.000001, year_start - 1, year_end, ["age", "year", "sex", "value"] + ), + "excess_mortality_rate": lambda _, __: build_table(0.7, year_start - 1, year_end), } - mortality_state = DiseaseState('sick', get_data_functions=mort_get_data_funcs) + mortality_state = DiseaseState("sick", get_data_functions=mort_get_data_funcs) healthy.add_transition(mortality_state) model = DiseaseModel(disease, initial_state=healthy, states=[healthy, mortality_state]) - simulation = InteractiveContext(components=[TestPopulation(), model, Mortality()], - configuration=base_config, - plugin_configuration=base_plugins) + simulation = InteractiveContext( + components=[TestPopulation(), model, Mortality()], + configuration=base_config, + plugin_configuration=base_plugins, + ) - mortality_rate = simulation._values.get_value('mortality_rate') + mortality_rate = simulation._values.get_value("mortality_rate") simulation.step() # Folks instantly transition to sick so now our mortality rate should be much higher - assert np.allclose(from_yearly(0.7, time_step), mortality_rate(simulation.get_population().index)['sick']) + assert np.allclose( + from_yearly(0.7, time_step), mortality_rate(simulation.get_population().index)["sick"] + ) def test_incidence(base_config, base_plugins, disease): @@ -250,32 +307,37 @@ def test_incidence(base_config, base_plugins, disease): year_end = base_config.time.end.year time_step = pd.Timedelta(days=base_config.time.step_size) - healthy = BaseDiseaseState('healthy') - sick = BaseDiseaseState('sick') + healthy = BaseDiseaseState("healthy") + sick = BaseDiseaseState("sick") key = f"sequela.acute_myocardial_infarction_first_2_days.incidence_rate" transition = RateTransition( - input_state=healthy, output_state=sick, - get_data_functions={ - 'incidence_rate': lambda _, builder: builder.data.load(key) - }) + input_state=healthy, + output_state=sick, + get_data_functions={"incidence_rate": lambda _, builder: builder.data.load(key)}, + ) healthy.transition_set.append(transition) model = DiseaseModel(disease, initial_state=healthy, states=[healthy, sick]) - simulation = InteractiveContext(components=[TestPopulation(), model], - configuration=base_config, - plugin_configuration=base_plugins, - setup=False) + simulation = InteractiveContext( + components=[TestPopulation(), model], + configuration=base_config, + plugin_configuration=base_plugins, + setup=False, + ) simulation._data.write(key, 0.7) simulation.setup() - incidence_rate = simulation._values.get_value('sick.incidence_rate') + incidence_rate = simulation._values.get_value("sick.incidence_rate") simulation.step() - assert np.allclose(from_yearly(0.7, time_step), - incidence_rate(simulation.get_population().index), atol=0.00001) + assert np.allclose( + from_yearly(0.7, time_step), + incidence_rate(simulation.get_population().index), + atol=0.00001, + ) def test_risk_deletion(base_config, base_plugins, disease): @@ -286,14 +348,13 @@ def test_risk_deletion(base_config, base_plugins, disease): base_rate = 0.7 paf = 0.1 - healthy = BaseDiseaseState('healthy') - sick = BaseDiseaseState('sick') + healthy = BaseDiseaseState("healthy") + sick = BaseDiseaseState("sick") key = "sequela.acute_myocardial_infarction_first_2_days.incidence_rate" transition = RateTransition( - input_state=healthy, output_state=sick, - get_data_functions={ - 'incidence_rate': lambda _, builder: builder.data.load(key) - } + input_state=healthy, + output_state=sick, + get_data_functions={"incidence_rate": lambda _, builder: builder.data.load(key)}, ) healthy.transition_set.append(transition) @@ -302,32 +363,38 @@ def test_risk_deletion(base_config, base_plugins, disease): class PafModifier: @property def name(self): - return 'paf_modifier' + return "paf_modifier" def setup(self, builder): builder.value.register_value_modifier( - 'sick.incidence_rate.paf', + "sick.incidence_rate.paf", modifier=simulation._tables.build_table( build_table(paf, year_start, year_end), - key_columns=('sex',), - parameter_columns=['age', 'year'], - value_columns=None - ) + key_columns=("sex",), + parameter_columns=["age", "year"], + value_columns=None, + ), ) - simulation = InteractiveContext(components=[TestPopulation(), model, PafModifier()], - configuration=base_config, - plugin_configuration=base_plugins, setup=False) + simulation = InteractiveContext( + components=[TestPopulation(), model, PafModifier()], + configuration=base_config, + plugin_configuration=base_plugins, + setup=False, + ) simulation._data.write(key, base_rate) simulation.setup() - incidence_rate = simulation._values.get_value('sick.incidence_rate') + incidence_rate = simulation._values.get_value("sick.incidence_rate") simulation.step() expected_rate = base_rate * (1 - paf) - assert np.allclose(from_yearly(expected_rate, time_step), - incidence_rate(simulation.get_population().index), atol=0.00001) + assert np.allclose( + from_yearly(expected_rate, time_step), + incidence_rate(simulation.get_population().index), + atol=0.00001, + ) def test__assign_event_time_for_prevalent_cases(): @@ -339,71 +406,103 @@ def test__assign_event_time_for_prevalent_cases(): # 10* 0.4 = 4 ; 4 days before the current time expected = pd.Series(pd.Timestamp(2017, 1, 6, 12), index=pop_data.index) - assert expected.equals(DiseaseState._assign_event_time_for_prevalent_cases(pop_data, current_time, random_func, - dwell_time_func)) + assert expected.equals( + DiseaseState._assign_event_time_for_prevalent_cases( + pop_data, current_time, random_func, dwell_time_func + ) + ) def test_prevalence_birth_prevalence_initial_assignment(base_config, base_plugins, disease): - healthy = BaseDiseaseState('healthy') + healthy = BaseDiseaseState("healthy") - data_funcs = {'prevalence': lambda _, __: 1, - 'birth_prevalence': lambda _, __: 0.5, - 'disability_weight': lambda _, __: 0} - with_condition = DiseaseState('with_condition', get_data_functions=data_funcs) + data_funcs = { + "prevalence": lambda _, __: 1, + "birth_prevalence": lambda _, __: 0.5, + "disability_weight": lambda _, __: 0, + } + with_condition = DiseaseState("with_condition", get_data_functions=data_funcs) model = DiseaseModel(disease, initial_state=healthy, states=[healthy, with_condition]) - base_config.update({'population': {'population_size': 1000, 'age_start': 0, 'age_end': 5}}, **metadata(__file__)) - simulation = InteractiveContext(components=[TestPopulation(), model], - configuration=base_config, - plugin_configuration=base_plugins) + base_config.update( + {"population": {"population_size": 1000, "age_start": 0, "age_end": 5}}, + **metadata(__file__), + ) + simulation = InteractiveContext( + components=[TestPopulation(), model], + configuration=base_config, + plugin_configuration=base_plugins, + ) # prevalence should be used for assigning initial status at sim start assert np.isclose(get_test_prevalence(simulation, "with_condition"), 1) # birth prevalence should be used for assigning initial status to newly-borns on time steps simulation._clock.step_forward() - simulation.simulant_creator(1000, population_configuration={'age_start': 0, 'age_end': 0, 'sim_state': 'time_step'}) + simulation.simulant_creator( + 1000, + population_configuration={"age_start": 0, "age_end": 0, "sim_state": "time_step"}, + ) assert np.isclose(get_test_prevalence(simulation, "with_condition"), 0.75, 0.01) # and prevalence should be used for ages not start = end = 0 simulation._clock.step_forward() - simulation.simulant_creator(1000, population_configuration={'age_start': 0, 'age_end': 5, 'sim_state': 'time_step'}) + simulation.simulant_creator( + 1000, + population_configuration={"age_start": 0, "age_end": 5, "sim_state": "time_step"}, + ) assert np.isclose(get_test_prevalence(simulation, "with_condition"), 0.83, 0.01) def test_no_birth_prevalence_initial_assignment(base_config, base_plugins, disease): - healthy = BaseDiseaseState('healthy') + healthy = BaseDiseaseState("healthy") - data_funcs = {'prevalence': lambda _, __: 1, - 'disability_weight': lambda _, __: 0} - with_condition = DiseaseState('with_condition', get_data_functions=data_funcs) + data_funcs = {"prevalence": lambda _, __: 1, "disability_weight": lambda _, __: 0} + with_condition = DiseaseState("with_condition", get_data_functions=data_funcs) model = DiseaseModel(disease, initial_state=healthy, states=[healthy, with_condition]) - base_config.update({'population': {'population_size': 1000, 'age_start': 0, 'age_end': 5}}, **metadata(__file__)) - simulation = InteractiveContext(components=[TestPopulation(), model], - configuration=base_config, - plugin_configuration=base_plugins) + base_config.update( + {"population": {"population_size": 1000, "age_start": 0, "age_end": 5}}, + **metadata(__file__), + ) + simulation = InteractiveContext( + components=[TestPopulation(), model], + configuration=base_config, + plugin_configuration=base_plugins, + ) # prevalence should be used for assigning initial status at sim start assert np.isclose(get_test_prevalence(simulation, "with_condition"), 1) # with no birth prevalence provided, it should default to 0 for ages start = end = 0 simulation._clock.step_forward() - simulation.simulant_creator(1000, population_configuration={'age_start': 0, 'age_end': 0, 'sim_state': 'time_step'}) + simulation.simulant_creator( + 1000, + population_configuration={"age_start": 0, "age_end": 0, "sim_state": "time_step"}, + ) assert np.isclose(get_test_prevalence(simulation, "with_condition"), 0.5, 0.01) # and default to prevalence for ages not start = end = 0 simulation._clock.step_forward() - simulation.simulant_creator(1000, population_configuration={'age_start': 0, 'age_end': 5, 'sim_state': 'time_step'}) + simulation.simulant_creator( + 1000, + population_configuration={"age_start": 0, "age_end": 5, "sim_state": "time_step"}, + ) assert np.isclose(get_test_prevalence(simulation, "with_condition"), 0.67, 0.01) def test_state_transition_names(disease): - with_condition = DiseaseState('diarrheal_diseases') - healthy = SusceptibleState('diarrheal_diseases') + with_condition = DiseaseState("diarrheal_diseases") + healthy = SusceptibleState("diarrheal_diseases") healthy.add_transition(with_condition) with_condition.add_transition(healthy) model = DiseaseModel(disease, initial_state=healthy, states=[healthy, with_condition]) - assert set(model.state_names) == set(['diarrheal_diseases', 'susceptible_to_diarrheal_diseases']) - assert set(model.transition_names) == set(['diarrheal_diseases_TO_susceptible_to_diarrheal_diseases', - 'susceptible_to_diarrheal_diseases_TO_diarrheal_diseases']) + assert set(model.state_names) == set( + ["diarrheal_diseases", "susceptible_to_diarrheal_diseases"] + ) + assert set(model.transition_names) == set( + [ + "diarrheal_diseases_TO_susceptible_to_diarrheal_diseases", + "susceptible_to_diarrheal_diseases_TO_diarrheal_diseases", + ] + ) diff --git a/tests/disease/test_special_disease.py b/tests/disease/test_special_disease.py index 9bcf851fe..e563e870b 100644 --- a/tests/disease/test_special_disease.py +++ b/tests/disease/test_special_disease.py @@ -1,29 +1,36 @@ +from operator import gt, lt + import numpy as np import pandas as pd -from operator import gt, lt import pytest + from vivarium_public_health.disease import RiskAttributableDisease @pytest.fixture def disease_mock(mocker): def disease_with_distribution(distribution): - test_disease = RiskAttributableDisease('cause.test_cause', 'risk_factor.test_risk') + test_disease = RiskAttributableDisease("cause.test_cause", "risk_factor.test_risk") test_disease.distribution = distribution test_disease.population_view = mocker.Mock() test_disease.excess_mortality_rate = mocker.Mock() return test_disease + return disease_with_distribution -test_data = [('ordered_polytomous', ['cat1', 'cat2', 'cat3', 'cat4'], ['cat1']), - ('ordered_polytomous', ['cat1', 'cat2', 'cat3', 'cat4'], ['cat1', 'cat2']), - ('ordered_polytomous', ['cat1', 'cat2', 'cat3', 'cat4'], ['cat1', 'cat2', 'cat3']), - ('dichotomous', ['cat1', 'cat2'], ['cat1'])] +test_data = [ + ("ordered_polytomous", ["cat1", "cat2", "cat3", "cat4"], ["cat1"]), + ("ordered_polytomous", ["cat1", "cat2", "cat3", "cat4"], ["cat1", "cat2"]), + ("ordered_polytomous", ["cat1", "cat2", "cat3", "cat4"], ["cat1", "cat2", "cat3"]), + ("dichotomous", ["cat1", "cat2"], ["cat1"]), +] -@pytest.mark.parametrize('distribution, categories, threshold', test_data) -def test_filter_by_exposure_categorical(disease_mock, mocker, distribution, categories, threshold): +@pytest.mark.parametrize("distribution, categories, threshold", test_data) +def test_filter_by_exposure_categorical( + disease_mock, mocker, distribution, categories, threshold +): disease = disease_mock(distribution) test_index = range(500) per_cat = len(test_index) // len(categories) @@ -36,33 +43,55 @@ def test_filter_by_exposure_categorical(disease_mock, mocker, distribution, cate assert np.all(expected(test_index) == filter_func(test_index)) -test_data = [('ensemble', '>=7'), ('ensemble', '<=7.5'), ('lognormal', '=2.5'), ('normal', '4'), ('normal', '+4'), - ('lognormal', '>=')] +test_data = [ + ("ensemble", ">=7"), + ("ensemble", "<=7.5"), + ("lognormal", "=2.5"), + ("normal", "4"), + ("normal", "+4"), + ("lognormal", ">="), +] -@pytest.mark.parametrize('distribution, threshold', test_data) -def test_filter_by_exposure_continuous_incorrect_operator(disease_mock, distribution, threshold): +@pytest.mark.parametrize("distribution, threshold", test_data) +def test_filter_by_exposure_continuous_incorrect_operator( + disease_mock, distribution, threshold +): disease = disease_mock(distribution) disease.threshold = threshold - with pytest.raises(ValueError, match='incorrect threshold'): + with pytest.raises(ValueError, match="incorrect threshold"): disease.get_exposure_filter(distribution, lambda index: index, threshold) -test_data = [('ensemble', '>7'), ('ensemble', '<5'), ('lognormal', '<3.5'), ('normal', '>5.5')] +test_data = [ + ("ensemble", ">7"), + ("ensemble", "<5"), + ("lognormal", "<3.5"), + ("normal", ">5.5"), +] -@pytest.mark.parametrize('distribution, threshold', test_data) +@pytest.mark.parametrize("distribution, threshold", test_data) def test_filter_by_exposure_continuous(disease_mock, distribution, threshold): disease = disease_mock(distribution) disease.threshold = threshold - op = {'>', '<'}.intersection(list(threshold)).pop() + op = {">", "<"}.intersection(list(threshold)).pop() threshold_val = float(threshold.split(op)[-1]) threshold_op = gt if op == ">" else lt test_index = range(500) - current_exposure = lambda index: pd.Series([threshold_val - 0.2, threshold_val - 0.1, threshold_val, - threshold_val + 0.1, threshold_val + 0.2] *100, index=test_index) + current_exposure = lambda index: pd.Series( + [ + threshold_val - 0.2, + threshold_val - 0.1, + threshold_val, + threshold_val + 0.1, + threshold_val + 0.2, + ] + * 100, + index=test_index, + ) filter_func = disease.get_exposure_filter(distribution, current_exposure, threshold) expected = lambda index: threshold_op(current_exposure(index), threshold_val) @@ -71,39 +100,50 @@ def test_filter_by_exposure_continuous(disease_mock, distribution, threshold): def test_mortality_rate_pandas_dataframe(disease_mock): - disease = disease_mock('enesmble') + disease = disease_mock("enesmble") num_sims = 500 test_index = range(num_sims) - current_disease_status = [disease.cause.name] * int(0.2 * num_sims) + \ - [f'susceptible_to_{disease.cause.name}'] * int(num_sims * 0.8) - disease.population_view.get.side_effect = lambda index: pd.DataFrame({disease.cause.name: current_disease_status, - 'alive': 'alive'}, index=index) - expected_mortality_values = pd.Series(current_disease_status, name=disease.cause.name, - index=test_index).map({disease.cause.name: 0.05, - f'susceptible_to_{disease.cause.name}': 0}) + current_disease_status = [disease.cause.name] * int(0.2 * num_sims) + [ + f"susceptible_to_{disease.cause.name}" + ] * int(num_sims * 0.8) + disease.population_view.get.side_effect = lambda index: pd.DataFrame( + {disease.cause.name: current_disease_status, "alive": "alive"}, index=index + ) + expected_mortality_values = pd.Series( + current_disease_status, name=disease.cause.name, index=test_index + ).map({disease.cause.name: 0.05, f"susceptible_to_{disease.cause.name}": 0}) disease.excess_mortality_rate.return_value = expected_mortality_values - rates_df = pd.DataFrame({'other_causes': 0, 'another_test_cause': 0.001}, index=test_index) - expected = pd.DataFrame({'other_causes': 0, 'another_test_cause': 0.001, - disease.cause.name: expected_mortality_values}, index=test_index) + rates_df = pd.DataFrame( + {"other_causes": 0, "another_test_cause": 0.001}, index=test_index + ) + expected = pd.DataFrame( + { + "other_causes": 0, + "another_test_cause": 0.001, + disease.cause.name: expected_mortality_values, + }, + index=test_index, + ) assert np.all(expected == disease.adjust_mortality_rate(test_index, rates_df)) test_data = [ - ('disease_no_recovery', False), - ('disease_with_recovery', True), + ("disease_no_recovery", False), + ("disease_with_recovery", True), ] -@pytest.mark.parametrize('disease, recoverable', test_data) + +@pytest.mark.parametrize("disease, recoverable", test_data) def test_state_transition_names(disease, recoverable): - model = RiskAttributableDisease(f'cause.{disease}', f'risk_factor.{disease}') + model = RiskAttributableDisease(f"cause.{disease}", f"risk_factor.{disease}") model.recoverable = recoverable model.adjust_state_and_transitions() - states = [disease, f'susceptible_to_{disease}'] + states = [disease, f"susceptible_to_{disease}"] transitions = [ - f'susceptible_to_{disease}_TO_{disease}', + f"susceptible_to_{disease}_TO_{disease}", ] if recoverable: - transitions.append(f'{disease}_TO_susceptible_to_{disease}') + transitions.append(f"{disease}_TO_susceptible_to_{disease}") assert set(model.state_names) == set(states) assert set(model.transition_names) == set(transitions) diff --git a/tests/metrics/test_utilities.py b/tests/metrics/test_utilities.py index 520cbf3fa..6490bd204 100644 --- a/tests/metrics/test_utilities.py +++ b/tests/metrics/test_utilities.py @@ -1,17 +1,30 @@ -from itertools import product, combinations +from itertools import combinations, product import numpy as np import pandas as pd import pytest - from vivarium.testing_utilities import metadata -from vivarium_public_health.metrics.utilities import (QueryString, OutputTemplate, to_years, get_output_template, - get_susceptible_person_time, get_disease_event_counts, - get_age_sex_filter_and_iterables, - get_time_iterable, get_lived_in_span, get_person_time_in_span, - get_deaths, get_years_of_life_lost, - get_years_lived_with_disability, get_age_bins, - _MIN_YEAR, _MAX_YEAR, _MIN_AGE, _MAX_AGE) + +from vivarium_public_health.metrics.utilities import ( + _MAX_AGE, + _MAX_YEAR, + _MIN_AGE, + _MIN_YEAR, + OutputTemplate, + QueryString, + get_age_bins, + get_age_sex_filter_and_iterables, + get_deaths, + get_disease_event_counts, + get_lived_in_span, + get_output_template, + get_person_time_in_span, + get_susceptible_person_time, + get_time_iterable, + get_years_lived_with_disability, + get_years_of_life_lost, + to_years, +) @pytest.fixture(params=((0, 100, 5, 1000), (20, 100, 5, 1000))) @@ -21,39 +34,47 @@ def ages_and_bins(request): age_groups = request.param[2] num_ages = request.param[3] - ages = np.linspace(age_min, age_max - age_groups/num_ages, num_ages) + ages = np.linspace(age_min, age_max - age_groups / num_ages, num_ages) bin_ages, step = np.linspace(age_min, age_max, age_groups, endpoint=False, retstep=True) - age_bins = pd.DataFrame({'age_start': bin_ages, - 'age_end': bin_ages + step, - 'age_group_name': [str(name) for name in range(len(bin_ages))]}) + age_bins = pd.DataFrame( + { + "age_start": bin_ages, + "age_end": bin_ages + step, + "age_group_name": [str(name) for name in range(len(bin_ages))], + } + ) return ages, age_bins @pytest.fixture def sexes(): - return ['Male', 'Female'] + return ["Male", "Female"] @pytest.fixture(params=list(product((True, False), repeat=3))) def observer_config(request): - c = {'by_age': request.param[0], - 'by_sex': request.param[1], - 'by_year': request.param[2]} + c = {"by_age": request.param[0], "by_sex": request.param[1], "by_year": request.param[2]} return c @pytest.fixture() def builder(mocker): builder = mocker.MagicMock() - df = pd.DataFrame({'age_start': [0, 1, 4], - 'age_group_name': ['youngest', 'younger', 'young'], - 'age_end': [1, 4, 6]}) + df = pd.DataFrame( + { + "age_start": [0, 1, 4], + "age_group_name": ["youngest", "younger", "young"], + "age_end": [1, 4, 6], + } + ) builder.data.load.return_value = df return builder -@pytest.mark.parametrize('reference, test', product([QueryString(''), QueryString('abc')], [QueryString(''), ''])) +@pytest.mark.parametrize( + "reference, test", product([QueryString(""), QueryString("abc")], [QueryString(""), ""]) +) def test_query_string_empty(reference, test): result = str(reference) assert reference + test == result @@ -75,24 +96,24 @@ def test_query_string_empty(reference, test): assert isinstance(test, QueryString) -@pytest.mark.parametrize('a, b', product([QueryString('a')], [QueryString('b'), 'b'])) +@pytest.mark.parametrize("a, b", product([QueryString("a")], [QueryString("b"), "b"])) def test_query_string(a, b): - assert a + b == 'a and b' - assert a + b == QueryString('a and b') + assert a + b == "a and b" + assert a + b == QueryString("a and b") assert isinstance(a + b, QueryString) - assert b + a == 'b and a' - assert b + a == QueryString('b and a') + assert b + a == "b and a" + assert b + a == QueryString("b and a") assert isinstance(b + a, QueryString) a += b - assert a == 'a and b' - assert a == QueryString('a and b') + assert a == "a and b" + assert a == QueryString("a and b") assert isinstance(a, QueryString) b += a - assert b == 'b and a and b' - assert b == QueryString('b and a and b') + assert b == "b and a and b" + assert b == QueryString("b and a and b") assert isinstance(b, QueryString) @@ -100,151 +121,168 @@ def test_get_output_template(observer_config): template = get_output_template(**observer_config) assert isinstance(template, OutputTemplate) - assert '${measure}' in template.template - - if observer_config['by_year']: - assert '_in_${year}' in template.template - if observer_config['by_sex']: - assert '_among_${sex}' in template.template - if observer_config['by_age']: - assert '_in_age_group_${age_group}' in template.template - - -@pytest.mark.parametrize('measure, sex, age, year', - product(['test', 'Test'], ['female', 'Female'], - [1.0, 1, 'Early Neonatal'], [2011, '2011'])) + assert "${measure}" in template.template + + if observer_config["by_year"]: + assert "_in_${year}" in template.template + if observer_config["by_sex"]: + assert "_among_${sex}" in template.template + if observer_config["by_age"]: + assert "_in_age_group_${age_group}" in template.template + + +@pytest.mark.parametrize( + "measure, sex, age, year", + product( + ["test", "Test"], ["female", "Female"], [1.0, 1, "Early Neonatal"], [2011, "2011"] + ), +) def test_output_template(observer_config, measure, sex, age, year): template = get_output_template(**observer_config) out1 = template.substitute(measure=measure, sex=sex, age_group=age, year=year) - out2 = template.substitute(measure=measure).substitute(sex=sex).substitute(age_group=age).substitute(year=year) + out2 = ( + template.substitute(measure=measure) + .substitute(sex=sex) + .substitute(age_group=age) + .substitute(year=year) + ) assert out1 == out2 def test_output_template_exact(): template = get_output_template(by_age=True, by_sex=True, by_year=True) - out = template.substitute(measure='Test', sex='Female', age_group=1.0, year=2011) - expected = 'test_in_2011_among_female_in_age_group_1.0' + out = template.substitute(measure="Test", sex="Female", age_group=1.0, year=2011) + expected = "test_in_2011_among_female_in_age_group_1.0" assert out == expected - out = template.substitute(measure='Test', sex='Female', age_group='Early Neonatal', year=2011) - expected = 'test_in_2011_among_female_in_age_group_early_neonatal' + out = template.substitute( + measure="Test", sex="Female", age_group="Early Neonatal", year=2011 + ) + expected = "test_in_2011_among_female_in_age_group_early_neonatal" assert out == expected def test_get_age_sex_filter_and_iterables(ages_and_bins, observer_config): _, age_bins = ages_and_bins - age_sex_filter, (ages, sexes) = get_age_sex_filter_and_iterables(observer_config, age_bins) + age_sex_filter, (ages, sexes) = get_age_sex_filter_and_iterables( + observer_config, age_bins + ) assert isinstance(age_sex_filter, QueryString) - if observer_config['by_age'] and observer_config['by_sex']: + if observer_config["by_age"] and observer_config["by_sex"]: assert age_sex_filter == '{age_start} <= age and age < {age_end} and sex == "{sex}"' - for (g1, s1), (g2, s2) in zip(ages, age_bins.set_index('age_group_name').iterrows()): + for (g1, s1), (g2, s2) in zip(ages, age_bins.set_index("age_group_name").iterrows()): assert g1 == g2 assert s1.equals(s2) - assert sexes == ['Male', 'Female'] + assert sexes == ["Male", "Female"] - elif observer_config['by_age']: - assert age_sex_filter == '{age_start} <= age and age < {age_end}' + elif observer_config["by_age"]: + assert age_sex_filter == "{age_start} <= age and age < {age_end}" - for (g1, s1), (g2, s2) in zip(ages, age_bins.set_index('age_group_name').iterrows()): + for (g1, s1), (g2, s2) in zip(ages, age_bins.set_index("age_group_name").iterrows()): assert g1 == g2 assert s1.equals(s2) - assert sexes == ['Both'] - elif observer_config['by_sex']: + assert sexes == ["Both"] + elif observer_config["by_sex"]: assert age_sex_filter == 'sex == "{sex}"' assert len(ages) == 1 group, data = ages[0] - assert group == 'all_ages' - assert data['age_start'] == _MIN_AGE - assert data['age_end'] == _MAX_AGE + assert group == "all_ages" + assert data["age_start"] == _MIN_AGE + assert data["age_end"] == _MAX_AGE - assert sexes == ['Male', 'Female'] + assert sexes == ["Male", "Female"] else: - assert age_sex_filter == '' + assert age_sex_filter == "" assert len(ages) == 1 group, data = ages[0] - assert group == 'all_ages' - assert data['age_start'] == _MIN_AGE - assert data['age_end'] == _MAX_AGE + assert group == "all_ages" + assert data["age_start"] == _MIN_AGE + assert data["age_end"] == _MAX_AGE - assert sexes == ['Both'] + assert sexes == ["Both"] def test_get_age_sex_filter_and_iterables_with_span(ages_and_bins, observer_config): _, age_bins = ages_and_bins - age_sex_filter, (ages, sexes) = get_age_sex_filter_and_iterables(observer_config, age_bins, in_span=True) + age_sex_filter, (ages, sexes) = get_age_sex_filter_and_iterables( + observer_config, age_bins, in_span=True + ) assert isinstance(age_sex_filter, QueryString) - if observer_config['by_age'] and observer_config['by_sex']: + if observer_config["by_age"] and observer_config["by_sex"]: expected = '{age_start} < age_at_span_end and age_at_span_start < {age_end} and sex == "{sex}"' assert age_sex_filter == expected - for (g1, s1), (g2, s2) in zip(ages, age_bins.set_index('age_group_name').iterrows()): + for (g1, s1), (g2, s2) in zip(ages, age_bins.set_index("age_group_name").iterrows()): assert g1 == g2 assert s1.equals(s2) - assert sexes == ['Male', 'Female'] + assert sexes == ["Male", "Female"] - elif observer_config['by_age']: - assert age_sex_filter == '{age_start} < age_at_span_end and age_at_span_start < {age_end}' + elif observer_config["by_age"]: + assert ( + age_sex_filter + == "{age_start} < age_at_span_end and age_at_span_start < {age_end}" + ) - for (g1, s1), (g2, s2) in zip(ages, age_bins.set_index('age_group_name').iterrows()): + for (g1, s1), (g2, s2) in zip(ages, age_bins.set_index("age_group_name").iterrows()): assert g1 == g2 assert s1.equals(s2) - assert sexes == ['Both'] - elif observer_config['by_sex']: + assert sexes == ["Both"] + elif observer_config["by_sex"]: assert age_sex_filter == 'sex == "{sex}"' assert len(ages) == 1 group, data = ages[0] - assert group == 'all_ages' - assert data['age_start'] == _MIN_AGE - assert data['age_end'] == _MAX_AGE + assert group == "all_ages" + assert data["age_start"] == _MIN_AGE + assert data["age_end"] == _MAX_AGE - assert sexes == ['Male', 'Female'] + assert sexes == ["Male", "Female"] else: - assert age_sex_filter == '' + assert age_sex_filter == "" assert len(ages) == 1 group, data = ages[0] - assert group == 'all_ages' - assert data['age_start'] == _MIN_AGE - assert data['age_end'] == _MAX_AGE + assert group == "all_ages" + assert data["age_start"] == _MIN_AGE + assert data["age_end"] == _MAX_AGE - assert sexes == ['Both'] + assert sexes == ["Both"] -@pytest.mark.parametrize('year_start, year_end', [(2011, 2017), (2011, 2011)]) +@pytest.mark.parametrize("year_start, year_end", [(2011, 2017), (2011, 2011)]) def test_get_time_iterable_no_year(year_start, year_end): - config = {'by_year': False} - sim_start = pd.Timestamp(f'7-2-{year_start}') - sim_end = pd.Timestamp(f'3-15-{year_end}') + config = {"by_year": False} + sim_start = pd.Timestamp(f"7-2-{year_start}") + sim_end = pd.Timestamp(f"3-15-{year_end}") time_spans = get_time_iterable(config, sim_start, sim_end) assert len(time_spans) == 1 name, (start, end) = time_spans[0] - assert name == 'all_years' - assert start == pd.Timestamp(f'1-1-{_MIN_YEAR}') - assert end == pd.Timestamp(f'1-1-{_MAX_YEAR}') + assert name == "all_years" + assert start == pd.Timestamp(f"1-1-{_MIN_YEAR}") + assert end == pd.Timestamp(f"1-1-{_MAX_YEAR}") -@pytest.mark.parametrize('year_start, year_end', [(2011, 2017), (2011, 2011)]) +@pytest.mark.parametrize("year_start, year_end", [(2011, 2017), (2011, 2011)]) def test_get_time_iterable_with_year(year_start, year_end): - config = {'by_year': True} - sim_start = pd.Timestamp(f'7-2-{year_start}') - sim_end = pd.Timestamp(f'3-15-{year_end}') + config = {"by_year": True} + sim_start = pd.Timestamp(f"7-2-{year_start}") + sim_end = pd.Timestamp(f"3-15-{year_end}") time_spans = get_time_iterable(config, sim_start, sim_end) @@ -253,49 +291,55 @@ def test_get_time_iterable_with_year(year_start, year_end): for year, time_span in zip(years, time_spans): name, (start, end) = time_span assert name == year - assert start == pd.Timestamp(f'1-1-{year}') - assert end == pd.Timestamp(f'1-1-{year+1}') + assert start == pd.Timestamp(f"1-1-{year}") + assert end == pd.Timestamp(f"1-1-{year+1}") def test_get_susceptible_person_time(ages_and_bins, sexes, observer_config): ages, age_bins = ages_and_bins - disease = 'test_disease' - states = [f'susceptible_to_{disease}', disease] - pop = pd.DataFrame(list(product(ages, sexes, states)), columns=['age', 'sex', disease]) - pop['alive'] = 'alive' + disease = "test_disease" + states = [f"susceptible_to_{disease}", disease] + pop = pd.DataFrame(list(product(ages, sexes, states)), columns=["age", "sex", disease]) + pop["alive"] = "alive" # Shuffle the rows pop = pop.sample(frac=1).reset_index(drop=True) year = 2017 step_size = pd.Timedelta(days=7) - person_time = get_susceptible_person_time(pop, observer_config, disease, year, step_size, age_bins) + person_time = get_susceptible_person_time( + pop, observer_config, disease, year, step_size, age_bins + ) values = set(person_time.values()) assert len(values) == 1 - expected_value = to_years(step_size)*len(pop)/2 - if observer_config['by_sex']: + expected_value = to_years(step_size) * len(pop) / 2 + if observer_config["by_sex"]: expected_value /= 2 - if observer_config['by_age']: + if observer_config["by_age"]: expected_value /= len(age_bins) assert np.isclose(values.pop(), expected_value) # Doubling pop should double person time pop = pd.concat([pop, pop], axis=0, ignore_index=True) - person_time = get_susceptible_person_time(pop, observer_config, disease, year, step_size, age_bins) + person_time = get_susceptible_person_time( + pop, observer_config, disease, year, step_size, age_bins + ) values = set(person_time.values()) assert len(values) == 1 - assert np.isclose(values.pop(), 2*expected_value) + assert np.isclose(values.pop(), 2 * expected_value) def test_get_disease_event_counts(ages_and_bins, sexes, observer_config): ages, age_bins = ages_and_bins - disease = 'test_disease' - event_time = pd.Timestamp('1-1-2017') + disease = "test_disease" + event_time = pd.Timestamp("1-1-2017") states = [event_time, pd.NaT] - pop = pd.DataFrame(list(product(ages, sexes, states)), columns=['age', 'sex', f'{disease}_event_time']) + pop = pd.DataFrame( + list(product(ages, sexes, states)), columns=["age", "sex", f"{disease}_event_time"] + ) # Shuffle the rows pop = pop.sample(frac=1).reset_index(drop=True) @@ -304,9 +348,9 @@ def test_get_disease_event_counts(ages_and_bins, sexes, observer_config): values = set(counts.values()) assert len(values) == 1 expected_value = len(pop) / len(states) - if observer_config['by_sex']: + if observer_config["by_sex"]: expected_value /= 2 - if observer_config['by_age']: + if observer_config["by_age"]: expected_value /= len(age_bins) assert np.isclose(values.pop(), expected_value) @@ -322,25 +366,25 @@ def test_get_disease_event_counts(ages_and_bins, sexes, observer_config): def test_get_lived_in_span(): dt = pd.Timedelta(days=5) - reference_t = pd.Timestamp('1-10-2010') + reference_t = pd.Timestamp("1-10-2010") - early_1 = reference_t - 2*dt + early_1 = reference_t - 2 * dt early_2 = reference_t - dt t_start = reference_t mid_1 = reference_t + dt - mid_2 = reference_t + 2*dt + mid_2 = reference_t + 2 * dt - t_end = reference_t + 3*dt + t_end = reference_t + 3 * dt - late_1 = reference_t + 4*dt - late_2 = reference_t + 5*dt + late_1 = reference_t + 4 * dt + late_2 = reference_t + 5 * dt # 28 combinations, six of which are entirely out of the time span times = [early_1, early_2, t_start, mid_1, mid_2, t_end, late_1, late_2] starts, ends = zip(*combinations(times, 2)) - pop = pd.DataFrame({'age': to_years(10*dt), 'entrance_time': starts, 'exit_time': ends}) + pop = pd.DataFrame({"age": to_years(10 * dt), "entrance_time": starts, "exit_time": ends}) lived_in_span = get_lived_in_span(pop, t_start, t_end) # Indices here are from the combinatorics math. They represent @@ -353,35 +397,49 @@ def test_get_lived_in_span(): assert {0, 1, 7, 25, 26, 27}.intersection(lived_in_span.index) == set() exit_before_span_end = lived_in_span.exit_time <= t_end - assert np.all(lived_in_span.loc[exit_before_span_end, 'age_at_span_end'] - == lived_in_span.loc[exit_before_span_end, 'age']) + assert np.all( + lived_in_span.loc[exit_before_span_end, "age_at_span_end"] + == lived_in_span.loc[exit_before_span_end, "age"] + ) exit_after_span_end = ~exit_before_span_end age_at_end = lived_in_span.age - to_years(lived_in_span.exit_time - t_end) - assert np.all(lived_in_span.loc[exit_after_span_end, 'age_at_span_end'] - == age_at_end.loc[exit_after_span_end]) + assert np.all( + lived_in_span.loc[exit_after_span_end, "age_at_span_end"] + == age_at_end.loc[exit_after_span_end] + ) enter_after_span_start = lived_in_span.entrance_time >= t_start - age_at_start = lived_in_span.age - to_years(lived_in_span.exit_time - lived_in_span.entrance_time) - assert np.all(lived_in_span.loc[enter_after_span_start, 'age_at_span_start'] - == age_at_start.loc[enter_after_span_start]) + age_at_start = lived_in_span.age - to_years( + lived_in_span.exit_time - lived_in_span.entrance_time + ) + assert np.all( + lived_in_span.loc[enter_after_span_start, "age_at_span_start"] + == age_at_start.loc[enter_after_span_start] + ) enter_before_span_start = ~enter_after_span_start age_at_start = lived_in_span.age - to_years(lived_in_span.exit_time - t_start) - assert np.all(lived_in_span.loc[enter_before_span_start, 'age_at_span_start'] - == age_at_start.loc[enter_before_span_start]) + assert np.all( + lived_in_span.loc[enter_before_span_start, "age_at_span_start"] + == age_at_start.loc[enter_before_span_start] + ) def test_get_lived_in_span_no_one_in_span(): dt = pd.Timedelta(days=365.25) - t_start = pd.Timestamp('1-1-2010') + t_start = pd.Timestamp("1-1-2010") t_end = t_start + dt - pop = pd.DataFrame({'entrance_time': t_start - 2*dt, 'exit_time': t_start - dt, 'age': range(100)}) + pop = pd.DataFrame( + {"entrance_time": t_start - 2 * dt, "exit_time": t_start - dt, "age": range(100)} + ) lived_in_span = get_lived_in_span(pop, t_start, t_end) assert lived_in_span.empty - pop = pd.DataFrame({'entrance_time': t_end + dt, 'exit_time': t_end + 2*dt, 'age': range(100)}) + pop = pd.DataFrame( + {"entrance_time": t_end + dt, "exit_time": t_end + 2 * dt, "age": range(100)} + ) lived_in_span = get_lived_in_span(pop, t_start, t_end) assert lived_in_span.empty @@ -392,71 +450,102 @@ def test_get_person_time_in_span(ages_and_bins, observer_config): end = int(age_bins.age_end.max()) n_ages = len(list(range(start, end))) n_bins = len(age_bins) - segments_per_age = [(i + 1)*(n_ages - i) for i in range(n_ages)] + segments_per_age = [(i + 1) * (n_ages - i) for i in range(n_ages)] ages_per_bin = n_ages // n_bins - age_bins['expected_time'] = [sum(segments_per_age[ages_per_bin*i:ages_per_bin*(i+1)]) for i in range(n_bins)] + age_bins["expected_time"] = [ + sum(segments_per_age[ages_per_bin * i : ages_per_bin * (i + 1)]) + for i in range(n_bins) + ] age_starts, age_ends = zip(*combinations(range(start, end + 1), 2)) - women = pd.DataFrame({'age_at_span_start': age_starts, 'age_at_span_end': age_ends, 'sex': 'Female'}) + women = pd.DataFrame( + {"age_at_span_start": age_starts, "age_at_span_end": age_ends, "sex": "Female"} + ) men = women.copy() - men.loc[:, 'sex'] = 'Male' + men.loc[:, "sex"] = "Male" - lived_in_span = pd.concat([women, men], ignore_index=True).sample(frac=1).reset_index(drop=True) + lived_in_span = ( + pd.concat([women, men], ignore_index=True).sample(frac=1).reset_index(drop=True) + ) base_filter = QueryString("") - span_key = get_output_template(**observer_config).substitute(measure='person_time', year=2019) + span_key = get_output_template(**observer_config).substitute( + measure="person_time", year=2019 + ) - pt = get_person_time_in_span(lived_in_span, base_filter, span_key, observer_config, age_bins) + pt = get_person_time_in_span( + lived_in_span, base_filter, span_key, observer_config, age_bins + ) - if observer_config['by_age']: + if observer_config["by_age"]: for group, age_bin in age_bins.iterrows(): - group_pt = sum(set([v for k, v in pt.items() if f'in_age_group_{group}' in k])) - if observer_config['by_sex']: + group_pt = sum(set([v for k, v in pt.items() if f"in_age_group_{group}" in k])) + if observer_config["by_sex"]: assert group_pt == age_bin.expected_time else: assert group_pt == 2 * age_bin.expected_time else: group_pt = sum(set(pt.values())) - if observer_config['by_sex']: + if observer_config["by_sex"]: assert group_pt == age_bins.expected_time.sum() else: assert group_pt == 2 * age_bins.expected_time.sum() def test_get_deaths(ages_and_bins, sexes, observer_config): - alive = ['dead', 'alive'] + alive = ["dead", "alive"] ages, age_bins = ages_and_bins - exit_times = [pd.Timestamp('1-1-2012'), pd.Timestamp('1-1-2013')] - causes = ['cause_a', 'cause_b'] + exit_times = [pd.Timestamp("1-1-2012"), pd.Timestamp("1-1-2013")] + causes = ["cause_a", "cause_b"] - pop = pd.DataFrame(list(product(alive, ages, sexes, exit_times, causes)), - columns=['alive', 'age', 'sex', 'exit_time', 'cause_of_death']) + pop = pd.DataFrame( + list(product(alive, ages, sexes, exit_times, causes)), + columns=["alive", "age", "sex", "exit_time", "cause_of_death"], + ) # Shuffle the rows pop = pop.sample(frac=1).reset_index(drop=True) - deaths = get_deaths(pop, observer_config, pd.Timestamp('1-1-2010'), pd.Timestamp('1-1-2015'), age_bins, causes) + deaths = get_deaths( + pop, + observer_config, + pd.Timestamp("1-1-2010"), + pd.Timestamp("1-1-2015"), + age_bins, + causes, + ) values = set(deaths.values()) expected_value = len(pop) / (len(causes) * len(alive)) - if observer_config['by_year']: - assert len(values) == 2 # Uniform across bins with deaths, 0 in year bins without deaths + if observer_config["by_year"]: + assert ( + len(values) == 2 + ) # Uniform across bins with deaths, 0 in year bins without deaths expected_value /= 2 else: assert len(values) == 1 value = max(values) - if observer_config['by_sex']: + if observer_config["by_sex"]: expected_value /= 2 - if observer_config['by_age']: + if observer_config["by_age"]: expected_value /= len(age_bins) assert np.isclose(value, expected_value) # Doubling pop should double counts pop = pd.concat([pop, pop], axis=0, ignore_index=True) - deaths = get_deaths(pop, observer_config, pd.Timestamp('1-1-2010'), pd.Timestamp('1-1-2015'), age_bins, causes) + deaths = get_deaths( + pop, + observer_config, + pd.Timestamp("1-1-2010"), + pd.Timestamp("1-1-2015"), + age_bins, + causes, + ) values = set(deaths.values()) - if observer_config['by_year']: - assert len(values) == 2 # Uniform across bins with deaths, 0 in year bins without deaths + if observer_config["by_year"]: + assert ( + len(values) == 2 + ) # Uniform across bins with deaths, 0 in year bins without deaths else: assert len(values) == 1 value = max(values) @@ -464,45 +553,65 @@ def test_get_deaths(ages_and_bins, sexes, observer_config): def test_get_years_of_life_lost(ages_and_bins, sexes, observer_config): - alive = ['dead', 'alive'] + alive = ["dead", "alive"] ages, age_bins = ages_and_bins - exit_times = [pd.Timestamp('1-1-2012'), pd.Timestamp('1-1-2013')] - causes = ['cause_a', 'cause_b'] + exit_times = [pd.Timestamp("1-1-2012"), pd.Timestamp("1-1-2013")] + causes = ["cause_a", "cause_b"] - pop = pd.DataFrame(list(product(alive, ages, sexes, exit_times, causes)), - columns=['alive', 'age', 'sex', 'exit_time', 'cause_of_death']) + pop = pd.DataFrame( + list(product(alive, ages, sexes, exit_times, causes)), + columns=["alive", "age", "sex", "exit_time", "cause_of_death"], + ) # Shuffle the rows pop = pop.sample(frac=1).reset_index(drop=True) def life_expectancy(index): return pd.Series(1, index=index) - ylls = get_years_of_life_lost(pop, observer_config, pd.Timestamp('1-1-2010'), pd.Timestamp('1-1-2015'), - age_bins, life_expectancy, causes) + ylls = get_years_of_life_lost( + pop, + observer_config, + pd.Timestamp("1-1-2010"), + pd.Timestamp("1-1-2015"), + age_bins, + life_expectancy, + causes, + ) values = set(ylls.values()) expected_value = len(pop) / (len(causes) * len(alive)) - if observer_config['by_year']: - assert len(values) == 2 # Uniform across bins with deaths, 0 in year bins without deaths + if observer_config["by_year"]: + assert ( + len(values) == 2 + ) # Uniform across bins with deaths, 0 in year bins without deaths expected_value /= 2 else: assert len(values) == 1 value = max(values) - if observer_config['by_sex']: + if observer_config["by_sex"]: expected_value /= 2 - if observer_config['by_age']: + if observer_config["by_age"]: expected_value /= len(age_bins) assert np.isclose(value, expected_value) # Doubling pop should double counts pop = pd.concat([pop, pop], axis=0, ignore_index=True) - ylls = get_years_of_life_lost(pop, observer_config, pd.Timestamp('1-1-2010'), pd.Timestamp('1-1-2015'), - age_bins, life_expectancy, causes) + ylls = get_years_of_life_lost( + pop, + observer_config, + pd.Timestamp("1-1-2010"), + pd.Timestamp("1-1-2015"), + age_bins, + life_expectancy, + causes, + ) values = set(ylls.values()) - if observer_config['by_year']: - assert len(values) == 2 # Uniform across bins with deaths, 0 in year bins without deaths + if observer_config["by_year"]: + assert ( + len(values) == 2 + ) # Uniform across bins with deaths, 0 in year bins without deaths else: assert len(values) == 1 value = max(values) @@ -510,16 +619,18 @@ def life_expectancy(index): def test_get_years_lived_with_disability(ages_and_bins, sexes, observer_config): - alive = ['dead', 'alive'] + alive = ["dead", "alive"] ages, age_bins = ages_and_bins - causes = ['cause_a', 'cause_b'] - cause_a = ['susceptible_to_cause_a', 'cause_a'] - cause_b = ['susceptible_to_cause_b', 'cause_b'] + causes = ["cause_a", "cause_b"] + cause_a = ["susceptible_to_cause_a", "cause_a"] + cause_b = ["susceptible_to_cause_b", "cause_b"] year = 2010 step_size = pd.Timedelta(days=7) - pop = pd.DataFrame(list(product(alive, ages, sexes, cause_a, cause_b)), - columns=['alive', 'age', 'sex'] + causes) + pop = pd.DataFrame( + list(product(alive, ages, sexes, cause_a, cause_b)), + columns=["alive", "age", "sex"] + causes, + ) # Shuffle the rows pop = pop.sample(frac=1).reset_index(drop=True) @@ -527,47 +638,54 @@ def disability_weight(cause): def inner(index): sub_pop = pop.loc[index] return pd.Series(1, index=index) * (sub_pop[cause] == cause) + return inner disability_weights = {cause: disability_weight(cause) for cause in causes} - ylds = get_years_lived_with_disability(pop, observer_config, year, step_size, age_bins, disability_weights, causes) + ylds = get_years_lived_with_disability( + pop, observer_config, year, step_size, age_bins, disability_weights, causes + ) values = set(ylds.values()) assert len(values) == 1 states_per_cause = len(cause_a) expected_value = len(pop) / (len(alive) * states_per_cause) * to_years(step_size) - if observer_config['by_sex']: + if observer_config["by_sex"]: expected_value /= 2 - if observer_config['by_age']: + if observer_config["by_age"]: expected_value /= len(age_bins) assert np.isclose(values.pop(), expected_value) # Doubling pop should double person time pop = pd.concat([pop, pop], axis=0, ignore_index=True) - ylds = get_years_lived_with_disability(pop, observer_config, year, step_size, age_bins, disability_weights, causes) + ylds = get_years_lived_with_disability( + pop, observer_config, year, step_size, age_bins, disability_weights, causes + ) values = set(ylds.values()) assert len(values) == 1 assert np.isclose(values.pop(), 2 * expected_value) -@pytest.mark.parametrize('age_start, exit_age, result_age_end_values, result_age_start_values', - [(2, 5, {4, 5}, {2, 4}), - (0, None, {1, 4, 6}, {0, 1, 4}), - (1, 4, {4}, {1}), - (1, 3, {3}, {1}), - (0.8, 6, {1, 4, 6}, {0.8, 1, 4})]) -def test_get_age_bins(builder, base_config, age_start, exit_age, result_age_end_values, result_age_start_values): - base_config.update({ - 'population': { - 'age_start': age_start, - 'exit_age': exit_age - } - }, **metadata(__file__)) +@pytest.mark.parametrize( + "age_start, exit_age, result_age_end_values, result_age_start_values", + [ + (2, 5, {4, 5}, {2, 4}), + (0, None, {1, 4, 6}, {0, 1, 4}), + (1, 4, {4}, {1}), + (1, 3, {3}, {1}), + (0.8, 6, {1, 4, 6}, {0.8, 1, 4}), + ], +) +def test_get_age_bins( + builder, base_config, age_start, exit_age, result_age_end_values, result_age_start_values +): + base_config.update( + {"population": {"age_start": age_start, "exit_age": exit_age}}, **metadata(__file__) + ) builder.configuration = base_config df = get_age_bins(builder) assert set(df.age_end) == result_age_end_values assert set(df.age_start) == result_age_start_values - diff --git a/tests/population/test_add_new_birth_cohort.py b/tests/population/test_add_new_birth_cohort.py index 2ca2fc826..d08e3d3ff 100644 --- a/tests/population/test_add_new_birth_cohort.py +++ b/tests/population/test_add_new_birth_cohort.py @@ -4,32 +4,47 @@ import pandas as pd import pytest from vivarium import InteractiveContext -from vivarium.testing_utilities import TestPopulation, metadata, build_table +from vivarium.testing_utilities import TestPopulation, build_table, metadata from vivarium_public_health import utilities -from vivarium_public_health.population import FertilityDeterministic, FertilityCrudeBirthRate, FertilityAgeSpecificRates +from vivarium_public_health.population import ( + FertilityAgeSpecificRates, + FertilityCrudeBirthRate, + FertilityDeterministic, +) @pytest.fixture() def config(base_config): - base_config.update({ - 'population': { - 'population_size': 10000, - 'age_start': 0, - 'age_end': 125, + base_config.update( + { + "population": { + "population_size": 10000, + "age_start": 0, + "age_end": 125, + }, + "time": { + "step_size": 10, + }, }, - 'time': { - 'step_size': 10, - } - }, source=str(Path(__file__).resolve()), layer='override') + source=str(Path(__file__).resolve()), + layer="override", + ) return base_config def crude_birth_rate_data(live_births=500): - return (build_table(['mean_value', live_births], 1990, 2017, ('age', 'year', 'sex', 'parameter', 'value')) - .query('age_start == 25 and sex != "Both"') - .drop(['age_start', 'age_end'], 'columns')) + return ( + build_table( + ["mean_value", live_births], + 1990, + 2017, + ("age", "year", "sex", "parameter", "value"), + ) + .query('age_start == 25 and sex != "Both"') + .drop(["age_start", "age_end"], "columns") + ) def test_FertilityDeterministic(config): @@ -38,11 +53,10 @@ def test_FertilityDeterministic(config): step_size = config.time.step_size num_days = 100 - config.update({ - 'fertility': { - 'number_of_new_simulants_each_year': annual_new_simulants - } - }, **metadata(__file__)) + config.update( + {"fertility": {"number_of_new_simulants_each_year": annual_new_simulants}}, + **metadata(__file__) + ) components = [TestPopulation(), FertilityDeterministic()] simulation = InteractiveContext(components=components, configuration=config) @@ -50,44 +64,51 @@ def test_FertilityDeterministic(config): pop = simulation.get_population() assert num_steps == num_days // step_size - assert np.all(pop.alive == 'alive') - assert int(num_days * annual_new_simulants / utilities.DAYS_PER_YEAR) == len(pop.age) - pop_size + assert np.all(pop.alive == "alive") + assert ( + int(num_days * annual_new_simulants / utilities.DAYS_PER_YEAR) + == len(pop.age) - pop_size + ) def test_FertilityCrudeBirthRate(config, base_plugins): pop_size = config.population.population_size num_days = 100 components = [TestPopulation(), FertilityCrudeBirthRate()] - simulation = InteractiveContext(components=components, - configuration=config, - plugin_configuration=base_plugins, - setup=False) + simulation = InteractiveContext( + components=components, + configuration=config, + plugin_configuration=base_plugins, + setup=False, + ) simulation._data.write("covariate.live_births_by_sex.estimate", crude_birth_rate_data()) simulation.setup() simulation.run_for(duration=pd.Timedelta(days=num_days)) pop = simulation.get_population() - assert np.all(pop.alive == 'alive') + assert np.all(pop.alive == "alive") assert len(pop.age) > pop_size def test_FertilityCrudeBirthRate_extrapolate_fail(config, base_plugins): - config.update({ - 'interpolation': { - 'extrapolate': False - }, - 'time': { - 'start': {'year': 2016}, - 'end': {'year': 2025}, - }, - }) + config.update( + { + "interpolation": {"extrapolate": False}, + "time": { + "start": {"year": 2016}, + "end": {"year": 2025}, + }, + } + ) components = [TestPopulation(), FertilityCrudeBirthRate()] - simulation = InteractiveContext(components=components, - configuration=config, - plugin_configuration=base_plugins, - setup=False) + simulation = InteractiveContext( + components=components, + configuration=config, + plugin_configuration=base_plugins, + setup=False, + ) simulation._data.write("covariate.live_births_by_sex.estimate", crude_birth_rate_data()) with pytest.raises(ValueError): @@ -95,31 +116,35 @@ def test_FertilityCrudeBirthRate_extrapolate_fail(config, base_plugins): def test_FertilityCrudeBirthRate_extrapolate(base_config, base_plugins): - base_config.update({ - 'population': { - 'population_size': 10000, - 'age_start': 0, - 'age_end': 125, - }, - 'interpolation': { - 'extrapolate': True - }, - 'time': { - 'start': {'year': 2016}, - 'end': {'year': 2026}, - 'step_size': 365, - }, - }) + base_config.update( + { + "population": { + "population_size": 10000, + "age_start": 0, + "age_end": 125, + }, + "interpolation": {"extrapolate": True}, + "time": { + "start": {"year": 2016}, + "end": {"year": 2026}, + "step_size": 365, + }, + } + ) pop_size = base_config.population.population_size true_pop_size = 8000 # What's available in the mock artifact live_births_by_sex = 500 components = [TestPopulation(), FertilityCrudeBirthRate()] - simulation = simulation = InteractiveContext(components=components, - configuration=base_config, - plugin_configuration=base_plugins, - setup=False) - simulation._data.write("covariate.live_births_by_sex.estimate", crude_birth_rate_data(live_births_by_sex)) + simulation = simulation = InteractiveContext( + components=components, + configuration=base_config, + plugin_configuration=base_plugins, + setup=False, + ) + simulation._data.write( + "covariate.live_births_by_sex.estimate", crude_birth_rate_data(live_births_by_sex) + ) simulation.setup() birth_rate = [] @@ -127,9 +152,9 @@ def test_FertilityCrudeBirthRate_extrapolate(base_config, base_plugins): pop_start = len(simulation.get_population()) simulation.step() pop_end = len(simulation.get_population()) - birth_rate.append((pop_end - pop_start)/pop_size) + birth_rate.append((pop_end - pop_start) / pop_size) - given_birth_rate = 2*live_births_by_sex / true_pop_size + given_birth_rate = 2 * live_births_by_sex / true_pop_size np.testing.assert_allclose(birth_rate, given_birth_rate, atol=0.01) @@ -137,41 +162,51 @@ def test_fertility_module(base_config, base_plugins): start_population_size = 1000 num_days = 1000 time_step = 10 # Days - base_config.update({ - 'population': { - 'population_size': start_population_size, - 'age_start': 0, - 'age_end': 125}, - 'time': {'step_size': time_step} - }, layer='override') + base_config.update( + { + "population": { + "population_size": start_population_size, + "age_start": 0, + "age_end": 125, + }, + "time": {"step_size": time_step}, + }, + layer="override", + ) components = [TestPopulation(), FertilityAgeSpecificRates()] - simulation = simulation = InteractiveContext(components=components, - configuration=base_config, - plugin_configuration=base_plugins, - setup=False) - - asfr_data = build_table(0.05, 1990, 2017).rename(columns={'value': 'mean_value'}) + simulation = simulation = InteractiveContext( + components=components, + configuration=base_config, + plugin_configuration=base_plugins, + setup=False, + ) + + asfr_data = build_table(0.05, 1990, 2017).rename(columns={"value": "mean_value"}) simulation._data.write("covariate.age_specific_fertility_rate.estimate", asfr_data) simulation.setup() time_start = simulation._clock.time - assert 'last_birth_time' in simulation.get_population().columns,\ - 'expect Fertility module to update state table.' - assert 'parent_id' in simulation.get_population().columns, \ - 'expect Fertility module to update state table.' + assert ( + "last_birth_time" in simulation.get_population().columns + ), "expect Fertility module to update state table." + assert ( + "parent_id" in simulation.get_population().columns + ), "expect Fertility module to update state table." simulation.run_for(duration=pd.Timedelta(days=num_days)) pop = simulation.get_population() # No death in this model. - assert np.all(pop.alive == 'alive'), 'expect all simulants to be alive' + assert np.all(pop.alive == "alive"), "expect all simulants to be alive" # TODO: Write a more rigorous test. - assert len(pop.age) > start_population_size, 'expect new simulants' + assert len(pop.age) > start_population_size, "expect new simulants" for i in range(start_population_size, len(pop)): - assert pop.loc[pop.iloc[i].parent_id].last_birth_time >= time_start, 'expect all children to have mothers who' \ - ' gave birth after the simulation starts.' + assert pop.loc[pop.iloc[i].parent_id].last_birth_time >= time_start, ( + "expect all children to have mothers who" + " gave birth after the simulation starts." + ) diff --git a/tests/population/test_base_population.py b/tests/population/test_base_population.py index 81635585c..dfa409f33 100644 --- a/tests/population/test_base_population.py +++ b/tests/population/test_base_population.py @@ -1,5 +1,5 @@ -from pathlib import Path import math +from pathlib import Path import numpy as np import pandas as pd @@ -7,61 +7,80 @@ from vivarium import InteractiveContext from vivarium.testing_utilities import get_randomness -from vivarium_public_health import utilities import vivarium_public_health.population.base_population as bp import vivarium_public_health.population.data_transformations as dt +from vivarium_public_health import utilities from vivarium_public_health.testing.utils import make_uniform_pop_data @pytest.fixture def config(base_config): - base_config.update({ - 'population': { - 'age_start': 0, - 'age_end': 110, + base_config.update( + { + "population": { + "age_start": 0, + "age_end": 110, + }, }, - }, source=str(Path(__file__).resolve()), layer='model_override') + source=str(Path(__file__).resolve()), + layer="model_override", + ) return base_config @pytest.fixture def generate_population_mock(mocker): - return mocker.patch('vivarium_public_health.population.base_population.generate_population') + return mocker.patch( + "vivarium_public_health.population.base_population.generate_population" + ) @pytest.fixture def age_bounds_mock(mocker): - return mocker.patch('vivarium_public_health.population.base_population._assign_demography_with_age_bounds') + return mocker.patch( + "vivarium_public_health.population.base_population._assign_demography_with_age_bounds" + ) @pytest.fixture def initial_age_mock(mocker): - return mocker.patch('vivarium_public_health.population.base_population._assign_demography_with_initial_age') + return mocker.patch( + "vivarium_public_health.population.base_population._assign_demography_with_initial_age" + ) def make_base_simulants(): simulant_ids = range(100000) creation_time = pd.Timestamp(1990, 7, 2) - return pd.DataFrame({'entrance_time': pd.Series(pd.Timestamp(creation_time), index=simulant_ids), - 'exit_time': pd.Series(pd.NaT, index=simulant_ids), - 'alive': pd.Series('alive', index=simulant_ids)}, - index=simulant_ids) + return pd.DataFrame( + { + "entrance_time": pd.Series(pd.Timestamp(creation_time), index=simulant_ids), + "exit_time": pd.Series(pd.NaT, index=simulant_ids), + "alive": pd.Series("alive", index=simulant_ids), + }, + index=simulant_ids, + ) def make_full_simulants(): base_simulants = make_base_simulants() - base_simulants['location'] = pd.Series(1, index=base_simulants.index) - base_simulants['sex'] = pd.Series('Male', index=base_simulants.index).astype( - pd.api.types.CategoricalDtype(categories=['Male', 'Female'], ordered=False)) - base_simulants['age'] = np.random.uniform(0, 100, len(base_simulants)) - base_simulants['tracked'] = pd.Series(True, index=base_simulants.index) + base_simulants["location"] = pd.Series(1, index=base_simulants.index) + base_simulants["sex"] = pd.Series("Male", index=base_simulants.index).astype( + pd.api.types.CategoricalDtype(categories=["Male", "Female"], ordered=False) + ) + base_simulants["age"] = np.random.uniform(0, 100, len(base_simulants)) + base_simulants["tracked"] = pd.Series(True, index=base_simulants.index) return base_simulants def test_select_sub_population_data(): - data = pd.DataFrame({'year_start': [1990, 1995, 2000, 2005], - 'year_end': [1995, 2000, 2005, 2010], - 'population': [100, 110, 120, 130]}) + data = pd.DataFrame( + { + "year_start": [1990, 1995, 2000, 2005], + "year_end": [1995, 2000, 2005, 2010], + "population": [100, 110, 120, 130], + } + ) sub_pop = bp.BasePopulation.select_sub_population_data(data, 1997) @@ -74,34 +93,41 @@ def test_BasePopulation(config, base_plugins, generate_population_mock): sims = make_full_simulants() start_population_size = len(sims) - generate_population_mock.return_value = sims.drop(columns=['tracked']) + generate_population_mock.return_value = sims.drop(columns=["tracked"]) base_pop = bp.BasePopulation() components = [base_pop] - config.update({'population': {'population_size': start_population_size}, - 'time': {'step_size': time_step}}, layer='override') - simulation = InteractiveContext(components=components, - configuration=config, - plugin_configuration=base_plugins) + config.update( + { + "population": {"population_size": start_population_size}, + "time": {"step_size": time_step}, + }, + layer="override", + ) + simulation = InteractiveContext( + components=components, configuration=config, plugin_configuration=base_plugins + ) time_start = simulation._clock.time - pop_structure = simulation._data.load('population.structure') + pop_structure = simulation._data.load("population.structure") uniform_pop = dt.assign_demographic_proportions(pop_structure) assert base_pop.population_data.equals(uniform_pop) - age_params = {'age_start': config.population.age_start, - 'age_end': config.population.age_end} + age_params = { + "age_start": config.population.age_start, + "age_end": config.population.age_end, + } sub_pop = bp.BasePopulation.select_sub_population_data(uniform_pop, time_start.year) generate_population_mock.assert_called_once() # Get a dictionary of the arguments used in the call mock_args = generate_population_mock.call_args[1] - assert mock_args['creation_time'] == time_start - simulation._clock.step_size - assert mock_args['age_params'] == age_params - assert mock_args['population_data'].equals(sub_pop) - assert mock_args['randomness_streams'] == base_pop.randomness + assert mock_args["creation_time"] == time_start - simulation._clock.step_size + assert mock_args["age_params"] == age_params + assert mock_args["population_data"].equals(sub_pop) + assert mock_args["randomness_streams"] == base_pop.randomness pop = simulation.get_population() for column in pop: assert pop[column].equals(sims[column]) @@ -111,55 +137,67 @@ def test_BasePopulation(config, base_plugins, generate_population_mock): simulation.run_for(duration=pd.Timedelta(days=num_days)) pop = simulation.get_population() - assert np.allclose(pop.age, final_ages, atol=0.5 / utilities.DAYS_PER_YEAR) # Within a half of a day. + assert np.allclose( + pop.age, final_ages, atol=0.5 / utilities.DAYS_PER_YEAR + ) # Within a half of a day. def test_age_out_simulants(config, base_plugins): start_population_size = 10000 num_days = 600 time_step = 100 # Days - config.update({'population': { - 'population_size': start_population_size, - 'age_start': 4, - 'age_end': 4, - 'exit_age': 5, - }, - 'time': {'step_size': time_step} - }, layer='override') + config.update( + { + "population": { + "population_size": start_population_size, + "age_start": 4, + "age_end": 4, + "exit_age": 5, + }, + "time": {"step_size": time_step}, + }, + layer="override", + ) components = [bp.BasePopulation()] - simulation = InteractiveContext(components=components, - configuration=config, - plugin_configuration=base_plugins) + simulation = InteractiveContext( + components=components, configuration=config, plugin_configuration=base_plugins + ) time_start = simulation._clock.time assert len(simulation.get_population()) == len(simulation.get_population().age.unique()) simulation.run_for(duration=pd.Timedelta(days=num_days)) pop = simulation.get_population() assert len(pop) == len(pop[~pop.tracked]) - exit_after_300_days = pop.exit_time >= time_start + pd.Timedelta(300, unit='D') - exit_before_400_days = pop.exit_time <= time_start + pd.Timedelta(400, unit='D') + exit_after_300_days = pop.exit_time >= time_start + pd.Timedelta(300, unit="D") + exit_before_400_days = pop.exit_time <= time_start + pd.Timedelta(400, unit="D") assert len(pop) == len(pop[exit_after_300_days & exit_before_400_days]) def test_generate_population_age_bounds(age_bounds_mock, initial_age_mock): creation_time = pd.Timestamp(1990, 7, 2) step_size = pd.Timedelta(days=1) - age_params = {'age_start': 0, - 'age_end': 120} + age_params = {"age_start": 0, "age_end": 120} pop_data = dt.assign_demographic_proportions(make_uniform_pop_data(age_bin_midpoint=True)) - r = {k: get_randomness() for k in ['general_purpose', 'bin_selection', 'age_smoothing']} + r = {k: get_randomness() for k in ["general_purpose", "bin_selection", "age_smoothing"]} sims = make_base_simulants() simulant_ids = sims.index - bp.generate_population(simulant_ids, creation_time, step_size, - age_params, pop_data, r, lambda *args, **kwargs: None) + bp.generate_population( + simulant_ids, + creation_time, + step_size, + age_params, + pop_data, + r, + lambda *args, **kwargs: None, + ) age_bounds_mock.assert_called_once() mock_args = age_bounds_mock.call_args[0] assert mock_args[0].equals(sims) assert mock_args[1].equals(pop_data) - assert mock_args[2] == float(age_params['age_start']) - assert mock_args[3] == float(age_params['age_end']) + assert mock_args[2] == float(age_params["age_start"]) + assert mock_args[3] == float(age_params["age_end"]) assert mock_args[4] == r initial_age_mock.assert_not_called() @@ -167,22 +205,28 @@ def test_generate_population_age_bounds(age_bounds_mock, initial_age_mock): def test_generate_population_initial_age(age_bounds_mock, initial_age_mock): creation_time = pd.Timestamp(1990, 7, 2) step_size = pd.Timedelta(days=1) - age_params = {'age_start': 0, - 'age_end': 0} + age_params = {"age_start": 0, "age_end": 0} pop_data = dt.assign_demographic_proportions(make_uniform_pop_data(age_bin_midpoint=True)) - r = {k: get_randomness() for k in ['general_purpose', 'bin_selection', 'age_smoothing']} + r = {k: get_randomness() for k in ["general_purpose", "bin_selection", "age_smoothing"]} sims = make_base_simulants() simulant_ids = sims.index - bp.generate_population(simulant_ids, creation_time, step_size, - age_params, pop_data, r, lambda *args, **kwargs: None) + bp.generate_population( + simulant_ids, + creation_time, + step_size, + age_params, + pop_data, + r, + lambda *args, **kwargs: None, + ) initial_age_mock.assert_called_once() mock_args = initial_age_mock.call_args[0] assert mock_args[0].equals(sims) assert mock_args[1].equals(pop_data) - assert mock_args[2] == float(age_params['age_start']) + assert mock_args[2] == float(age_params["age_start"]) assert mock_args[3] == step_size assert mock_args[4] == r age_bounds_mock.assert_not_called() @@ -193,19 +237,25 @@ def test__assign_demography_with_initial_age(config): pop_data = pop_data[pop_data.year_start == 1990] simulants = make_base_simulants() initial_age = 20 - r = {k: get_randomness() for k in ['general_purpose', 'bin_selection', 'age_smoothing']} + r = {k: get_randomness() for k in ["general_purpose", "bin_selection", "age_smoothing"]} step_size = pd.Timedelta(days=config.time.step_size) - simulants = bp._assign_demography_with_initial_age(simulants, pop_data, initial_age, - step_size, r, lambda *args, **kwargs: None) + simulants = bp._assign_demography_with_initial_age( + simulants, pop_data, initial_age, step_size, r, lambda *args, **kwargs: None + ) assert len(simulants) == len(simulants.age.unique()) assert simulants.age.min() > initial_age assert simulants.age.max() < initial_age + utilities.to_years(step_size) - assert math.isclose(len(simulants[simulants.sex == 'Male']) / len(simulants), 0.5, abs_tol=0.01) + assert math.isclose( + len(simulants[simulants.sex == "Male"]) / len(simulants), 0.5, abs_tol=0.01 + ) for location in simulants.location.unique(): - assert math.isclose(len(simulants[simulants.location == location]) / len(simulants), - 1 / len(simulants.location.unique()), abs_tol=0.01) + assert math.isclose( + len(simulants[simulants.location == location]) / len(simulants), + 1 / len(simulants.location.unique()), + abs_tol=0.01, + ) def test__assign_demography_with_initial_age_zero(config): @@ -213,19 +263,25 @@ def test__assign_demography_with_initial_age_zero(config): pop_data = pop_data[pop_data.year_start == 1990] simulants = make_base_simulants() initial_age = 0 - r = {k: get_randomness() for k in ['general_purpose', 'bin_selection', 'age_smoothing']} + r = {k: get_randomness() for k in ["general_purpose", "bin_selection", "age_smoothing"]} step_size = utilities.to_time_delta(config.time.step_size) - simulants = bp._assign_demography_with_initial_age(simulants, pop_data, initial_age, - step_size, r, lambda *args, **kwargs: None) + simulants = bp._assign_demography_with_initial_age( + simulants, pop_data, initial_age, step_size, r, lambda *args, **kwargs: None + ) assert len(simulants) == len(simulants.age.unique()) assert simulants.age.min() > initial_age assert simulants.age.max() < initial_age + utilities.to_years(step_size) - assert math.isclose(len(simulants[simulants.sex == 'Male']) / len(simulants), 0.5, abs_tol=0.01) + assert math.isclose( + len(simulants[simulants.sex == "Male"]) / len(simulants), 0.5, abs_tol=0.01 + ) for location in simulants.location.unique(): - assert math.isclose(len(simulants[simulants.location == location]) / len(simulants), - 1 / len(simulants.location.unique()), abs_tol=0.01) + assert math.isclose( + len(simulants[simulants.location == location]) / len(simulants), + 1 / len(simulants.location.unique()), + abs_tol=0.01, + ) def test__assign_demography_with_initial_age_error(): @@ -233,12 +289,13 @@ def test__assign_demography_with_initial_age_error(): pop_data = pop_data[pop_data.year_start == 1990] simulants = make_base_simulants() initial_age = 200 - r = {k: get_randomness() for k in ['general_purpose', 'bin_selection', 'age_smoothing']} + r = {k: get_randomness() for k in ["general_purpose", "bin_selection", "age_smoothing"]} step_size = pd.Timedelta(days=1) with pytest.raises(ValueError): - bp._assign_demography_with_initial_age(simulants, pop_data, initial_age, - step_size, r, lambda *args, **kwargs: None) + bp._assign_demography_with_initial_age( + simulants, pop_data, initial_age, step_size, r, lambda *args, **kwargs: None + ) def test__assign_demography_with_age_bounds(): @@ -246,16 +303,30 @@ def test__assign_demography_with_age_bounds(): pop_data = pop_data[pop_data.year_start == 1990] simulants = make_base_simulants() age_start, age_end = 0, 180 - r = {k: get_randomness(k) for k in ['general_purpose', 'bin_selection', 'age_smoothing', 'age_smoothing_age_bounds']} - - simulants = bp._assign_demography_with_age_bounds(simulants, pop_data, age_start, - age_end, r, lambda *args, **kwargs: None) - - assert math.isclose(len(simulants[simulants.sex == 'Male']) / len(simulants), 0.5, abs_tol=0.01) + r = { + k: get_randomness(k) + for k in [ + "general_purpose", + "bin_selection", + "age_smoothing", + "age_smoothing_age_bounds", + ] + } + + simulants = bp._assign_demography_with_age_bounds( + simulants, pop_data, age_start, age_end, r, lambda *args, **kwargs: None + ) + + assert math.isclose( + len(simulants[simulants.sex == "Male"]) / len(simulants), 0.5, abs_tol=0.01 + ) for location in simulants.location.unique(): - assert math.isclose(len(simulants[simulants.location == location]) / len(simulants), - 1 / len(simulants.location.unique()), abs_tol=0.01) + assert math.isclose( + len(simulants[simulants.location == location]) / len(simulants), + 1 / len(simulants.location.unique()), + abs_tol=0.01, + ) ages = np.sort(simulants.age.values) age_deltas = ages[1:] - ages[:-1] @@ -263,15 +334,18 @@ def test__assign_demography_with_age_bounds(): num_bins = len(pop_data.age.unique()) n = len(simulants) assert math.isclose(age_deltas.mean(), age_bin_width * num_bins / n, rel_tol=1e-3) - assert age_deltas.max() < 100 * age_bin_width * num_bins / n # Make sure there are no big age gaps. + assert ( + age_deltas.max() < 100 * age_bin_width * num_bins / n + ) # Make sure there are no big age gaps. def test__assign_demography_with_age_bounds_error(): pop_data = dt.assign_demographic_proportions(make_uniform_pop_data(age_bin_midpoint=True)) simulants = make_base_simulants() age_start, age_end = 110, 120 - r = {k: get_randomness() for k in ['general_purpose', 'bin_selection', 'age_smoothing']} + r = {k: get_randomness() for k in ["general_purpose", "bin_selection", "age_smoothing"]} with pytest.raises(ValueError): - bp._assign_demography_with_age_bounds(simulants, pop_data, age_start, - age_end, r, lambda *args, **kwargs: None) + bp._assign_demography_with_age_bounds( + simulants, pop_data, age_start, age_end, r, lambda *args, **kwargs: None + ) diff --git a/tests/population/test_data_transformations.py b/tests/population/test_data_transformations.py index 44233280e..8da7bbbc5 100644 --- a/tests/population/test_data_transformations.py +++ b/tests/population/test_data_transformations.py @@ -2,22 +2,32 @@ import numpy as np import pandas as pd +from vivarium.testing_utilities import build_table, get_randomness -from vivarium.testing_utilities import get_randomness, build_table -from vivarium_public_health.testing.utils import make_uniform_pop_data import vivarium_public_health.population.data_transformations as dt +from vivarium_public_health.testing.utils import make_uniform_pop_data def test_assign_demographic_proportions(): pop_data = dt.assign_demographic_proportions(make_uniform_pop_data(age_bin_midpoint=True)) - assert np.allclose(pop_data['P(sex, location, age| year)'], len(pop_data.year_start.unique()) / len(pop_data)) assert np.allclose( - pop_data['P(sex, location | age, year)'], (len(pop_data.year_start.unique()) - * len(pop_data.age.unique()) / len(pop_data))) + pop_data["P(sex, location, age| year)"], + len(pop_data.year_start.unique()) / len(pop_data), + ) + assert np.allclose( + pop_data["P(sex, location | age, year)"], + (len(pop_data.year_start.unique()) * len(pop_data.age.unique()) / len(pop_data)), + ) assert np.allclose( - pop_data['P(age | year, sex, location)'], (len(pop_data.year_start.unique()) * len(pop_data.sex.unique()) - * len(pop_data.location.unique()) / len(pop_data))) + pop_data["P(age | year, sex, location)"], + ( + len(pop_data.year_start.unique()) + * len(pop_data.sex.unique()) + * len(pop_data.location.unique()) + / len(pop_data) + ), + ) def test_rescale_binned_proportions_full_range(): @@ -27,7 +37,10 @@ def test_rescale_binned_proportions_full_range(): pop_data_scaled = dt.rescale_binned_proportions(pop_data, age_start=0, age_end=100) pop_data_scaled = pop_data_scaled[pop_data_scaled.age.isin(pop_data.age.unique())] - assert np.allclose(pop_data['P(sex, location, age| year)'], pop_data_scaled['P(sex, location, age| year)']) + assert np.allclose( + pop_data["P(sex, location, age| year)"], + pop_data_scaled["P(sex, location, age| year)"], + ) def test_rescale_binned_proportions_clipped_ends(): @@ -36,11 +49,15 @@ def test_rescale_binned_proportions_clipped_ends(): scale = len(pop_data.location.unique()) * len(pop_data.sex.unique()) pop_data_scaled = dt.rescale_binned_proportions(pop_data, age_start=2, age_end=7) - base_p = 1/len(pop_data) - p_scaled = [base_p*7/5, base_p*3/5, base_p*2/5, base_p*8/5] + [base_p]*(len(pop_data_scaled)//scale - 5) + [0] + base_p = 1 / len(pop_data) + p_scaled = ( + [base_p * 7 / 5, base_p * 3 / 5, base_p * 2 / 5, base_p * 8 / 5] + + [base_p] * (len(pop_data_scaled) // scale - 5) + + [0] + ) - for group, sub_population in pop_data_scaled.groupby(['sex', 'location']): - assert np.allclose(sub_population['P(sex, location, age| year)'], p_scaled) + for group, sub_population in pop_data_scaled.groupby(["sex", "location"]): + assert np.allclose(sub_population["P(sex, location, age| year)"], p_scaled) def test_rescale_binned_proportions_age_bin_edges(): @@ -51,27 +68,37 @@ def test_rescale_binned_proportions_age_bin_edges(): pop_data_scaled = dt.rescale_binned_proportions(pop_data, age_start=5, age_end=10) assert len(pop_data_scaled.age.unique()) == len(pop_data.age.unique()) + 2 assert 7.5 in pop_data_scaled.age.unique() - correct_data = ([1/len(pop_data)]*(len(pop_data_scaled)//2 - 2) + [0, 0])*2 - assert np.allclose(pop_data_scaled['P(sex, location, age| year)'], correct_data) + correct_data = ([1 / len(pop_data)] * (len(pop_data_scaled) // 2 - 2) + [0, 0]) * 2 + assert np.allclose(pop_data_scaled["P(sex, location, age| year)"], correct_data) def test_smooth_ages(): pop_data = dt.assign_demographic_proportions(make_uniform_pop_data(age_bin_midpoint=True)) pop_data = pop_data[pop_data.year_start == 1990] - simulants = pd.DataFrame({'age': [22.5]*10000 + [52.5]*10000, - 'sex': ['Male', 'Female']*10000, - 'location': [1, 2]*10000}) + simulants = pd.DataFrame( + { + "age": [22.5] * 10000 + [52.5] * 10000, + "sex": ["Male", "Female"] * 10000, + "location": [1, 2] * 10000, + } + ) randomness = get_randomness() smoothed_simulants = dt.smooth_ages(simulants, pop_data, randomness) - assert math.isclose(len(smoothed_simulants.age.unique()), len(smoothed_simulants.index), abs_tol=1) + assert math.isclose( + len(smoothed_simulants.age.unique()), len(smoothed_simulants.index), abs_tol=1 + ) # Tolerance is 3*std_dev of the sample mean - assert math.isclose(smoothed_simulants.age.mean(), 37.5, abs_tol=3*math.sqrt(13.149778198**2/2000)) + assert math.isclose( + smoothed_simulants.age.mean(), 37.5, abs_tol=3 * math.sqrt(13.149778198**2 / 2000) + ) def test__get_bins_and_proportions_with_youngest_bin(): pop_data = dt.assign_demographic_proportions(make_uniform_pop_data(age_bin_midpoint=True)) - pop_data = pop_data[(pop_data.year_start == 1990) & (pop_data.location == 1) & (pop_data.sex == 'Male')] + pop_data = pop_data[ + (pop_data.year_start == 1990) & (pop_data.location == 1) & (pop_data.sex == "Male") + ] age = dt.AgeValues(current=2.5, young=0, old=7.5) endpoints, proportions = dt._get_bins_and_proportions(pop_data, age) assert endpoints.left == 0 @@ -84,7 +111,9 @@ def test__get_bins_and_proportions_with_youngest_bin(): def test__get_bins_and_proportions_with_oldest_bin(): pop_data = dt.assign_demographic_proportions(make_uniform_pop_data(age_bin_midpoint=True)) - pop_data = pop_data[(pop_data.year_start == 1990) & (pop_data.location == 1) & (pop_data.sex == 'Male')] + pop_data = pop_data[ + (pop_data.year_start == 1990) & (pop_data.location == 1) & (pop_data.sex == "Male") + ] age = dt.AgeValues(current=97.5, young=92.5, old=100) endpoints, proportions = dt._get_bins_and_proportions(pop_data, age) assert endpoints.left == 95 @@ -97,7 +126,9 @@ def test__get_bins_and_proportions_with_oldest_bin(): def test__get_bins_and_proportions_with_middle_bin(): pop_data = dt.assign_demographic_proportions(make_uniform_pop_data(age_bin_midpoint=True)) - pop_data = pop_data[(pop_data.year_start == 1990) & (pop_data.location == 1) & (pop_data.sex == 'Male')] + pop_data = pop_data[ + (pop_data.year_start == 1990) & (pop_data.location == 1) & (pop_data.sex == "Male") + ] age = dt.AgeValues(current=22.5, young=17.5, old=27.5) endpoints, proportions = dt._get_bins_and_proportions(pop_data, age) assert endpoints.left == 20 @@ -113,19 +144,35 @@ def test__construct_sampling_parameters(): endpoint = dt.EndpointValues(left=34, right=77) proportion = dt.AgeValues(current=0.1, young=0.5, old=0.3) - pdf, slope, area, cdf_inflection_point = dt._construct_sampling_parameters(age, endpoint, proportion) - - assert pdf.left == ((proportion.current - proportion.young)/(age.current - age.young) - * (endpoint.left - age.young) + proportion.young) - assert pdf.right == ((proportion.old - proportion.current) / (age.old - age.current) - * (endpoint.right - age.current) + proportion.current) - assert area == 0.5 * ((proportion.current + pdf.left)*(age.current - endpoint.left) - + (pdf.right + proportion.current)*(endpoint.right - age.current)) + pdf, slope, area, cdf_inflection_point = dt._construct_sampling_parameters( + age, endpoint, proportion + ) + + assert pdf.left == ( + (proportion.current - proportion.young) + / (age.current - age.young) + * (endpoint.left - age.young) + + proportion.young + ) + assert pdf.right == ( + (proportion.old - proportion.current) + / (age.old - age.current) + * (endpoint.right - age.current) + + proportion.current + ) + assert area == 0.5 * ( + (proportion.current + pdf.left) * (age.current - endpoint.left) + + (pdf.right + proportion.current) * (endpoint.right - age.current) + ) assert slope.left == (proportion.current - pdf.left) / (age.current - endpoint.left) assert slope.right == (pdf.right - proportion.current) / (endpoint.right - age.current) - assert cdf_inflection_point == 1 / (2 * area) * (proportion.current + pdf.left) * (age.current - endpoint.left) + assert cdf_inflection_point == 1 / (2 * area) * (proportion.current + pdf.left) * ( + age.current - endpoint.left + ) def test__compute_ages(): - assert dt._compute_ages(1, 10, 12, 0, 33) == 10 + 33/12*1 - assert dt._compute_ages(1, 10, 12, 5, 33) == 10 + 12/5*(np.sqrt(1+2*33*5/12**2*1) - 1) + assert dt._compute_ages(1, 10, 12, 0, 33) == 10 + 33 / 12 * 1 + assert dt._compute_ages(1, 10, 12, 5, 33) == 10 + 12 / 5 * ( + np.sqrt(1 + 2 * 33 * 5 / 12**2 * 1) - 1 + ) diff --git a/tests/risks/conftest.py b/tests/risks/conftest.py index 51e8f5c58..c56923825 100644 --- a/tests/risks/conftest.py +++ b/tests/risks/conftest.py @@ -1,31 +1,43 @@ -import pytest from typing import List -import pandas as pd +import pandas as pd +import pytest from vivarium.testing_utilities import build_table + from vivarium_public_health.risks.base_risk import Risk -def make_test_data_table(values: List, parameter='cat') -> pd.DataFrame: +def make_test_data_table(values: List, parameter="cat") -> pd.DataFrame: year_start = 1990 # same as the base config year_end = 2010 if len(values) == 1: - df = build_table(values[0], year_start, year_end, ('age', 'year', 'sex', 'value')) + df = build_table(values[0], year_start, year_end, ("age", "year", "sex", "value")) else: - cats = [f'{parameter}{i+1}' for i in range(len(values))] if parameter == 'cat' else parameter + cats = ( + [f"{parameter}{i+1}" for i in range(len(values))] + if parameter == "cat" + else parameter + ) df = [] for cat, value in zip(cats, values): - df.append(build_table([cat, value], year_start, year_end, ('age','year', 'sex', 'parameter', 'value'))) + df.append( + build_table( + [cat, value], + year_start, + year_end, + ("age", "year", "sex", "parameter", "value"), + ) + ) df = pd.concat(df) return df -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def continuous_risk(): year_start = 1990 year_end = 2010 - risk = 'test_risk' + risk = "test_risk" risk_data = dict() exposure_mean = make_test_data_table([130]) exposure_sd = make_test_data_table([15]) @@ -34,24 +46,35 @@ def continuous_risk(): paf_data = [] for cause in affected_causes: rr_data.append( - build_table([1.01, cause], year_start, year_end, ['age', 'sex', 'year', 'value', 'cause'], - ).melt(id_vars=('age_start', 'age_end', 'year_start', - 'year_end', 'sex', 'cause'), var_name='parameter', value_name='value') + build_table( + [1.01, cause], + year_start, + year_end, + ["age", "sex", "year", "value", "cause"], + ).melt( + id_vars=("age_start", "age_end", "year_start", "year_end", "sex", "cause"), + var_name="parameter", + value_name="value", + ) + ) + paf_data.append( + build_table( + [1, cause], year_start, year_end, ["age", "sex", "year", "value", "cause"] + ) ) - paf_data.append(build_table([1, cause], year_start, year_end, ['age', 'sex', 'year', 'value', 'cause'])) rr_data = pd.concat(rr_data) paf_data = pd.concat(paf_data) - paf_data['affected_measure'] = 'incidence_rate' - rr_data['affected_measure'] = 'incidence_rate' - risk_data['exposure'] = exposure_mean - risk_data['exposure_standard_deviation'] = exposure_sd - risk_data['relative_risk'] = rr_data - risk_data['population_attributable_fraction'] = paf_data - risk_data['affected_causes'] = affected_causes - risk_data['affected_risk_factors'] = [] + paf_data["affected_measure"] = "incidence_rate" + rr_data["affected_measure"] = "incidence_rate" + risk_data["exposure"] = exposure_mean + risk_data["exposure_standard_deviation"] = exposure_sd + risk_data["relative_risk"] = rr_data + risk_data["population_attributable_fraction"] = paf_data + risk_data["affected_causes"] = affected_causes + risk_data["affected_risk_factors"] = [] tmred = { - "distribution": 'uniform', + "distribution": "uniform", "min": 110.0, "max": 115.0, "inverted": False, @@ -63,24 +86,27 @@ def continuous_risk(): "min_val": 50.0, } tmrel = 0.5 * (tmred["max"] + tmred["min"]) - risk_data['tmred'] = tmred - risk_data['tmrel'] = tmrel - risk_data['exposure_parameters'] = exposure_parameters - risk_data['distribution'] = 'normal' + risk_data["tmred"] = tmred + risk_data["tmrel"] = tmrel + risk_data["exposure_parameters"] = exposure_parameters + risk_data["distribution"] = "normal" return Risk(f"risk_factor.{risk}"), risk_data -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def dichotomous_risk(): year_start = 1990 year_end = 2010 - risk = 'test_risk' + risk = "test_risk" risk_data = dict() exposure_data = build_table( - 0.5, year_start, year_end, ['age', 'year', 'sex', 'cat1', 'cat2'] - ).melt(id_vars=('age_start', 'age_end', - 'year_start', 'year_end', 'sex'), var_name='parameter', value_name='value') + 0.5, year_start, year_end, ["age", "year", "sex", "cat1", "cat2"] + ).melt( + id_vars=("age_start", "age_end", "year_start", "year_end", "sex"), + var_name="parameter", + value_name="value", + ) affected_causes = ["test_cause_1", "test_cause_2"] rr_data = [] @@ -88,36 +114,49 @@ def dichotomous_risk(): for cause in affected_causes: rr_data.append( build_table( - [1.01, 1, cause], year_start, year_end, ['age', 'year', 'sex', 'cat1', 'cat2', 'cause'] - ).melt(id_vars=('age_start', 'age_end', 'year_start', - 'year_end', 'sex', 'cause'), var_name='parameter', value_name='value') + [1.01, 1, cause], + year_start, + year_end, + ["age", "year", "sex", "cat1", "cat2", "cause"], + ).melt( + id_vars=("age_start", "age_end", "year_start", "year_end", "sex", "cause"), + var_name="parameter", + value_name="value", + ) + ) + paf_data.append( + build_table( + [1, cause], year_start, year_end, ["age", "sex", "year", "value", "cause"] + ) ) - paf_data.append(build_table([1, cause], year_start, year_end, ['age', 'sex', 'year', 'value', 'cause'])) rr_data = pd.concat(rr_data) paf_data = pd.concat(paf_data) - paf_data['affected_measure'] = 'incidence_rate' - rr_data['affected_measure'] = 'incidence_rate' - risk_data['exposure'] = exposure_data - risk_data['relative_risk'] = rr_data - risk_data['population_attributable_fraction'] = paf_data - risk_data['affected_causes'] = affected_causes - risk_data['affected_risk_factors'] = [] + paf_data["affected_measure"] = "incidence_rate" + rr_data["affected_measure"] = "incidence_rate" + risk_data["exposure"] = exposure_data + risk_data["relative_risk"] = rr_data + risk_data["population_attributable_fraction"] = paf_data + risk_data["affected_causes"] = affected_causes + risk_data["affected_risk_factors"] = [] incidence_rate = build_table(0.01, year_start, year_end) - risk_data['incidence_rate'] = incidence_rate - risk_data['distribution'] = 'dichotomous' - return Risk(f'risk_factor.{risk}'), risk_data + risk_data["incidence_rate"] = incidence_rate + risk_data["distribution"] = "dichotomous" + return Risk(f"risk_factor.{risk}"), risk_data -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def polytomous_risk(): year_start = 1990 year_end = 2010 - risk = 'test_risk' + risk = "test_risk" risk_data = dict() exposure_data = build_table( - 0.25, year_start, year_end, ['age', 'year', 'sex', 'cat1', 'cat2', 'cat3', 'cat4'] - ).melt(id_vars=('age_start', 'age_end', - 'year_start', 'year_end', 'sex'), var_name='parameter', value_name='value') + 0.25, year_start, year_end, ["age", "year", "sex", "cat1", "cat2", "cat3", "cat4"] + ).melt( + id_vars=("age_start", "age_end", "year_start", "year_end", "sex"), + var_name="parameter", + value_name="value", + ) affected_causes = ["test_cause_1", "test_cause_2"] rr_data = [] @@ -125,62 +164,88 @@ def polytomous_risk(): for cause in affected_causes: rr_data.append( build_table( - [1.03, 1.02, 1.01, 1, cause], year_start, year_end, ['age', 'year', 'sex', 'cat1', 'cat2', 'cat3', 'cat4', 'cause'] - ).melt(id_vars=('age_start', 'age_end', 'year_start', - 'year_end', 'sex', 'cause'), var_name='parameter', value_name='value') + [1.03, 1.02, 1.01, 1, cause], + year_start, + year_end, + ["age", "year", "sex", "cat1", "cat2", "cat3", "cat4", "cause"], + ).melt( + id_vars=("age_start", "age_end", "year_start", "year_end", "sex", "cause"), + var_name="parameter", + value_name="value", + ) + ) + paf_data.append( + build_table( + [1, cause], year_start, year_end, ["age", "sex", "year", "value", "cause"] + ) ) - paf_data.append(build_table([1, cause], year_start, year_end, ['age', 'sex', 'year', 'value', 'cause'])) rr_data = pd.concat(rr_data) paf_data = pd.concat(paf_data) - paf_data['affected_measure'] = 'incidence_rate' - rr_data['affected_measure'] = 'incidence_rate' - risk_data['exposure'] = exposure_data - risk_data['relative_risk'] = rr_data - risk_data['population_attributable_fraction'] = paf_data - risk_data['affected_causes'] = affected_causes - risk_data['affected_risk_factors'] = [] + paf_data["affected_measure"] = "incidence_rate" + rr_data["affected_measure"] = "incidence_rate" + risk_data["exposure"] = exposure_data + risk_data["relative_risk"] = rr_data + risk_data["population_attributable_fraction"] = paf_data + risk_data["affected_causes"] = affected_causes + risk_data["affected_risk_factors"] = [] incidence_rate = build_table(0.01, year_start, year_end) - risk_data['incidence_rate'] = incidence_rate - risk_data['distribution'] = 'polytomous' - return Risk(f'risk_factor.{risk}'), risk_data + risk_data["incidence_rate"] = incidence_rate + risk_data["distribution"] = "polytomous" + return Risk(f"risk_factor.{risk}"), risk_data -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def coverage_gap(): year_start = 1990 year_end = 2010 - cg = 'test_coverage_gap' + cg = "test_coverage_gap" cg_data = dict() cg_exposed = 0.6 cg_exposure_data = build_table( - [cg_exposed, 1 - cg_exposed], year_start, year_end, ['age', 'year', 'sex', 'cat1', 'cat2'] - ).melt(id_vars=('age_start', 'age_end', 'year_start', - 'year_end', 'sex',), var_name='parameter', value_name='value') - - + [cg_exposed, 1 - cg_exposed], + year_start, + year_end, + ["age", "year", "sex", "cat1", "cat2"], + ).melt( + id_vars=( + "age_start", + "age_end", + "year_start", + "year_end", + "sex", + ), + var_name="parameter", + value_name="value", + ) rr = 2 rr_data = build_table( - [rr, 1], year_start, year_end, ['age', 'year', 'sex', 'cat1', 'cat2'] - ).melt(id_vars=('age_start', 'age_end', - 'year_start', 'year_end', 'sex'), var_name='parameter', value_name='value') + [rr, 1], year_start, year_end, ["age", "year", "sex", "cat1", "cat2"] + ).melt( + id_vars=("age_start", "age_end", "year_start", "year_end", "sex"), + var_name="parameter", + value_name="value", + ) # paf is (sum(exposure(category)*rr(category) -1 )/ (sum(exposure(category)* rr(category) paf = (rr * cg_exposed + (1 - cg_exposed) - 1) / (rr * cg_exposed + (1 - cg_exposed)) paf_data = build_table( - paf, year_start, year_end, ['age', 'year', 'sex', 'population_attributable_fraction'] - ).melt(id_vars=('age_start', 'age_end', 'year_start', - 'year_end', 'sex'), var_name='population_attributable_fraction', value_name='value') - - paf_data['risk_factor'] = 'test_risk' - paf_data['affected_measure'] = 'exposure_parameters' - rr_data['affected_measure'] = 'exposure_parameters' - cg_data['exposure'] = cg_exposure_data - rr_data['risk_factor'] = 'test_risk' - cg_data['relative_risk'] = rr_data - cg_data['population_attributable_fraction'] = paf_data - cg_data['affected_causes'] = [] - cg_data['affected_risk_factors'] = ['test_risk'] - cg_data['distribution'] = 'dichotomous' - return Risk(f'coverage_gap.{cg}') , cg_data + paf, year_start, year_end, ["age", "year", "sex", "population_attributable_fraction"] + ).melt( + id_vars=("age_start", "age_end", "year_start", "year_end", "sex"), + var_name="population_attributable_fraction", + value_name="value", + ) + + paf_data["risk_factor"] = "test_risk" + paf_data["affected_measure"] = "exposure_parameters" + rr_data["affected_measure"] = "exposure_parameters" + cg_data["exposure"] = cg_exposure_data + rr_data["risk_factor"] = "test_risk" + cg_data["relative_risk"] = rr_data + cg_data["population_attributable_fraction"] = paf_data + cg_data["affected_causes"] = [] + cg_data["affected_risk_factors"] = ["test_risk"] + cg_data["distribution"] = "dichotomous" + return Risk(f"coverage_gap.{cg}"), cg_data diff --git a/tests/risks/test_data_transformations.py b/tests/risks/test_data_transformations.py index b2307e394..eb765a08b 100644 --- a/tests/risks/test_data_transformations.py +++ b/tests/risks/test_data_transformations.py @@ -1,44 +1,69 @@ -import pytest import pandas as pd +import pytest -from vivarium_public_health.risks.data_transformations import _rebin_exposure_data, _rebin_relative_risk_data +from vivarium_public_health.risks.data_transformations import ( + _rebin_exposure_data, + _rebin_relative_risk_data, +) -@pytest.mark.parametrize('rebin_categories, rebinned_values', [({'cat1', 'cat2'}, (0.7, 0.3)), - ({'cat1'}, (0.5, 0.5)), - ({'cat2'}, (0.2, 0.8)), - ({'cat2', 'cat3'}, (0.5, 0.5)), - ({'cat1', 'cat3'}, (0.8, 0.2))]) +@pytest.mark.parametrize( + "rebin_categories, rebinned_values", + [ + ({"cat1", "cat2"}, (0.7, 0.3)), + ({"cat1"}, (0.5, 0.5)), + ({"cat2"}, (0.2, 0.8)), + ({"cat2", "cat3"}, (0.5, 0.5)), + ({"cat1", "cat3"}, (0.8, 0.2)), + ], +) def test__rebin_exposure_data(rebin_categories, rebinned_values): - df = pd.DataFrame({'year': [1990, 1990, 1995, 1995]*3, - 'age': [10, 40, 10, 40]*3, - 'parameter': ['cat1']*4 + ['cat2']*4 + ['cat3']*4, - 'value': [0.5]*4 + [0.2]*4 + [0.3]*4}) + df = pd.DataFrame( + { + "year": [1990, 1990, 1995, 1995] * 3, + "age": [10, 40, 10, 40] * 3, + "parameter": ["cat1"] * 4 + ["cat2"] * 4 + ["cat3"] * 4, + "value": [0.5] * 4 + [0.2] * 4 + [0.3] * 4, + } + ) rebinned_df = _rebin_exposure_data(df, rebin_categories) assert rebinned_df.shape == (8, 4) - assert (rebinned_df[rebinned_df.parameter == 'cat1'].value == rebinned_values[0]).all() - assert (rebinned_df[rebinned_df.parameter == 'cat2'].value == rebinned_values[1]).all() + assert (rebinned_df[rebinned_df.parameter == "cat1"].value == rebinned_values[0]).all() + assert (rebinned_df[rebinned_df.parameter == "cat2"].value == rebinned_values[1]).all() -@pytest.mark.parametrize('rebin_categories, rebinned_values', [({'cat1', 'cat2'}, (10, 1)), - ({'cat1'}, (0, 7.3)), - ({'cat2'}, (10, 1)), - ({'cat2', 'cat3'}, (7.3, 0)), - ({'cat1', 'cat3'}, (1, 10))]) +@pytest.mark.parametrize( + "rebin_categories, rebinned_values", + [ + ({"cat1", "cat2"}, (10, 1)), + ({"cat1"}, (0, 7.3)), + ({"cat2"}, (10, 1)), + ({"cat2", "cat3"}, (7.3, 0)), + ({"cat1", "cat3"}, (1, 10)), + ], +) def test__rebin_relative_risk(rebin_categories, rebinned_values): - exp = pd.DataFrame({'year': [1990, 1990, 1995, 1995]*3, - 'age': [10, 40, 10, 40]*3, - 'parameter': ['cat1']*4 + ['cat2']*4 + ['cat3']*4, - 'value': [0.0]*4 + [0.7]*4 + [0.3]*4}) + exp = pd.DataFrame( + { + "year": [1990, 1990, 1995, 1995] * 3, + "age": [10, 40, 10, 40] * 3, + "parameter": ["cat1"] * 4 + ["cat2"] * 4 + ["cat3"] * 4, + "value": [0.0] * 4 + [0.7] * 4 + [0.3] * 4, + } + ) - rr = pd.DataFrame({'year': [1990, 1990, 1995, 1995]*3, - 'age': [10, 40, 10, 40]*3, - 'parameter': ['cat1']*4 + ['cat2']*4 + ['cat3']*4, - 'value': [5]*4 + [10]*4 + [1]*4}) + rr = pd.DataFrame( + { + "year": [1990, 1990, 1995, 1995] * 3, + "age": [10, 40, 10, 40] * 3, + "parameter": ["cat1"] * 4 + ["cat2"] * 4 + ["cat3"] * 4, + "value": [5] * 4 + [10] * 4 + [1] * 4, + } + ) rebinned_df = _rebin_relative_risk_data(rr, exp, rebin_categories) assert rebinned_df.shape == (8, 4) - assert (rebinned_df[rebinned_df.parameter == 'cat1'].value == rebinned_values[0]).all() - assert (rebinned_df[rebinned_df.parameter == 'cat2'].value == rebinned_values[1]).all() + assert (rebinned_df[rebinned_df.parameter == "cat1"].value == rebinned_values[0]).all() + assert (rebinned_df[rebinned_df.parameter == "cat2"].value == rebinned_values[1]).all() diff --git a/tests/test_utilities.py b/tests/test_utilities.py index 49ffb3796..e2857433d 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -1,18 +1,20 @@ -from hypothesis import given import hypothesis.strategies as st import pytest +from hypothesis import given from vivarium_public_health.utilities import EntityString, TargetString @st.composite def component_string(draw, min_components=0, max_components=None): - alphabet = st.characters(blacklist_characters=['.']) - string_parts = draw(st.lists(st.text(alphabet=alphabet), min_size=min_components, max_size=max_components)) - return '.'.join(string_parts) + alphabet = st.characters(blacklist_characters=["."]) + string_parts = draw( + st.lists(st.text(alphabet=alphabet), min_size=min_components, max_size=max_components) + ) + return ".".join(string_parts) -@given(component_string().filter(lambda x: len(x.split('.')) != 2)) +@given(component_string().filter(lambda x: len(x.split(".")) != 2)) def test_EntityString_fail(s): with pytest.raises(ValueError): EntityString(s) @@ -20,13 +22,13 @@ def test_EntityString_fail(s): @given(component_string(2, 2)) def test_EntityString_pass(s): - entity_type, entity_name = s.split('.') + entity_type, entity_name = s.split(".") r = EntityString(s) assert r.type == entity_type assert r.name == entity_name -@given(component_string().filter(lambda x: len(x.split('.')) != 3)) +@given(component_string().filter(lambda x: len(x.split(".")) != 3)) def test_TargetString_fail(s): with pytest.raises(ValueError): TargetString(s) @@ -34,10 +36,8 @@ def test_TargetString_fail(s): @given(component_string(3, 3)) def test_TargetString_pass(s): - target_type, target_name, target_measure = s.split('.') + target_type, target_name, target_measure = s.split(".") t = TargetString(s) assert t.type == target_type assert t.name == target_name assert t.measure == target_measure - -