diff --git a/src/vivarium_sodium_reduction/components/interventions.py b/src/vivarium_sodium_reduction/components/interventions.py
index 28a4256..771b2ed 100644
--- a/src/vivarium_sodium_reduction/components/interventions.py
+++ b/src/vivarium_sodium_reduction/components/interventions.py
@@ -1,17 +1,19 @@
+from typing import Any, Dict
+
import pandas as pd
-from typing import Dict, Any
from vivarium import Component
from vivarium.framework.engine import Builder
from vivarium.framework.population import SimulantData
+
class RelativeShiftIntervention(Component):
"""Applies a relative shift to a target value."""
CONFIGURATION_DEFAULTS = {
- 'intervention': {
- 'shift_factor': 0.1,
- 'age_start': 0,
- 'age_end': 125,
+ "intervention": {
+ "shift_factor": 0.1,
+ "age_start": 0,
+ "age_end": 125,
}
}
@@ -25,9 +27,7 @@ def name(self) -> str:
@property
def configuration_defaults(self) -> Dict[str, Dict[str, Any]]:
- return {
- f'{self.name}': self.CONFIGURATION_DEFAULTS['intervention']
- }
+ return {f"{self.name}": self.CONFIGURATION_DEFAULTS["intervention"]}
def setup(self, builder: Builder) -> None:
self.config = builder.configuration[self.name]
@@ -35,12 +35,10 @@ def setup(self, builder: Builder) -> None:
self.age_start = self.config.age_start
self.age_end = self.config.age_end
- self.population_view = builder.population.get_view(['age'])
+ self.population_view = builder.population.get_view(["age"])
builder.value.register_value_modifier(
- self.target,
- modifier=self.adjust_exposure,
- requires_columns=['age']
+ self.target, modifier=self.adjust_exposure, requires_columns=["age"]
)
def adjust_exposure(self, index: pd.Index, exposure: pd.Series) -> pd.Series:
@@ -49,4 +47,4 @@ def adjust_exposure(self, index: pd.Index, exposure: pd.Series) -> pd.Series:
(self.age_start <= pop.age) & (pop.age < self.age_end)
].index
exposure.loc[applicable_index] *= self.shift_factor
- return exposure
\ No newline at end of file
+ return exposure
diff --git a/src/vivarium_sodium_reduction/components/risks.py b/src/vivarium_sodium_reduction/components/risks.py
index 36f85d3..74e2462 100644
--- a/src/vivarium_sodium_reduction/components/risks.py
+++ b/src/vivarium_sodium_reduction/components/risks.py
@@ -11,9 +11,12 @@
from vivarium.framework.randomness import get_hash
from vivarium.framework.values import Pipeline
from vivarium_public_health.risks import Risk
-from vivarium_public_health.risks.data_transformations import get_exposure_post_processor
+from vivarium_public_health.risks.data_transformations import (
+ get_exposure_post_processor,
+)
from vivarium_public_health.utilities import EntityString
+
class DropValueRisk(Risk):
def __init__(self, risk: str):
super().__init__(risk)
@@ -57,11 +60,13 @@ def post_processor(exposure, _):
return post_processor
+
class CorrelatedRisk(DropValueRisk):
"""A risk that can be correlated with another risk.
-
+
TODO: document strategy used in this component in more detail,
Abie had an AI adapt it from https://github.com/ihmeuw/vivarium_nih_us_cvd"""
+
@property
def columns_created(self) -> List[str]:
return []
@@ -84,8 +89,10 @@ def on_initialize_simulants(self, pop_data: SimulantData) -> None:
def on_time_step_prepare(self, event: Event) -> None:
pass
+
class RiskCorrelation(Component):
"""A component that generates a specified correlation between two risk exposures."""
+
@property
def columns_created(self) -> List[str]:
return self.propensity_column_names + self.exposure_column_names
@@ -104,8 +111,12 @@ def __init__(self, risk1: str, risk2: str, correlation: str):
correlation_matrix = np.array([[1, float(correlation)], [float(correlation), 1]])
self.correlated_risks = [EntityString(risk) for risk in correlated_risks]
self.correlation_matrix = correlation_matrix
- self.propensity_column_names = [f"{risk.name}_propensity" for risk in self.correlated_risks]
- self.exposure_column_names = [f"{risk.name}_exposure" for risk in self.correlated_risks]
+ self.propensity_column_names = [
+ f"{risk.name}_propensity" for risk in self.correlated_risks
+ ]
+ self.exposure_column_names = [
+ f"{risk.name}_exposure" for risk in self.correlated_risks
+ ]
self.ensemble_propensities = [
f"ensemble_propensity_" + risk
for risk in self.correlated_risks
@@ -113,8 +124,14 @@ def __init__(self, risk1: str, risk2: str, correlation: str):
]
def setup(self, builder: Builder) -> None:
- self.distributions = {risk: builder.components.get_component(risk).exposure_distribution for risk in self.correlated_risks}
- self.exposures = {risk: builder.value.get_value(f"{risk.name}.exposure") for risk in self.correlated_risks}
+ self.distributions = {
+ risk: builder.components.get_component(risk).exposure_distribution
+ for risk in self.correlated_risks
+ }
+ self.exposures = {
+ risk: builder.value.get_value(f"{risk.name}.exposure")
+ for risk in self.correlated_risks
+ }
self.input_draw = builder.configuration.input_data.input_draw_number
self.random_seed = builder.configuration.randomness.random_seed
@@ -124,21 +141,21 @@ def on_initialize_simulants(self, pop_data: SimulantData) -> None:
np.random.seed(get_hash(f"{self.input_draw}_{self.random_seed}"))
probit_propensity = np.random.multivariate_normal(
- mean=[0] * len(self.correlated_risks),
- cov=self.correlation_matrix,
- size=len(pop)
+ mean=[0] * len(self.correlated_risks), cov=self.correlation_matrix, size=len(pop)
)
correlated_propensities = scipy.stats.norm().cdf(probit_propensity)
propensities[self.propensity_column_names] = correlated_propensities
def get_exposure_from_propensity(propensity_col: pd.Series) -> pd.Series:
- risk = propensity_col.name.replace('_propensity','')
- exposure_values = self.distributions['risk_factor.' + risk].ppf(propensity_col)
+ risk = propensity_col.name.replace("_propensity", "")
+ exposure_values = self.distributions["risk_factor." + risk].ppf(propensity_col)
return pd.Series(exposure_values)
exposures = propensities.apply(get_exposure_from_propensity)
- exposures.columns = [col.replace('_propensity','_exposure') for col in propensities.columns]
-
+ exposures.columns = [
+ col.replace("_propensity", "_exposure") for col in propensities.columns
+ ]
+
self.population_view.update(pd.concat([propensities, exposures], axis=1))
def on_time_step_prepare(self, event: Event) -> None:
@@ -147,33 +164,37 @@ def on_time_step_prepare(self, event: Event) -> None:
exposure_col = pd.Series(exposure_values, name=f"{risk.name}_exposure")
self.population_view.update(exposure_col)
+
class SodiumSBPEffect(Component):
@property
def name(self):
return "sodium_sbp_effect"
def setup(self, builder: Builder):
- self.sodium_exposure = builder.value.get_value('diet_high_in_sodium.exposure')
- self.sodium_exposure_raw = builder.value.get_value('diet_high_in_sodium.raw_exposure')
-
+ self.sodium_exposure = builder.value.get_value("diet_high_in_sodium.exposure")
+ self.sodium_exposure_raw = builder.value.get_value("diet_high_in_sodium.raw_exposure")
+
builder.value.register_value_modifier(
- 'high_systolic_blood_pressure.drop_value',
+ "high_systolic_blood_pressure.drop_value",
modifier=self.sodium_effect_on_sbp,
- requires_columns=['age', 'sex'],
- requires_values=['diet_high_in_sodium.exposure', 'diet_high_in_sodium.raw_exposure']
+ requires_columns=["age", "sex"],
+ requires_values=[
+ "diet_high_in_sodium.exposure",
+ "diet_high_in_sodium.raw_exposure",
+ ],
)
def sodium_effect_on_sbp(self, index, sbp_drop_value):
sodium_exposure = self.sodium_exposure(index)
sodium_exposure_raw = self.sodium_exposure_raw(index)
-
+
sodium_threshold = 2.0 # g/day
mmHg_per_g_sodium = 10 # mmHg increase per 1g sodium above threshold
-
+
sbp_increase = pd.Series(0, index=index)
sodium_drop = sodium_exposure_raw - sodium_exposure
# TODO: use threshold
sbp_drop_due_to_sodium_drop = sodium_drop * mmHg_per_g_sodium
-
+
return sbp_drop_value + sbp_drop_due_to_sodium_drop
diff --git a/src/vivarium_sodium_reduction/constants/data_keys.py b/src/vivarium_sodium_reduction/constants/data_keys.py
index 156318d..9d5921e 100644
--- a/src/vivarium_sodium_reduction/constants/data_keys.py
+++ b/src/vivarium_sodium_reduction/constants/data_keys.py
@@ -36,8 +36,12 @@ class __SomeDisease(NamedTuple):
# Keys that will be loaded into the artifact. must have a colon type declaration
SOME_DISEASE_PREVALENCE: TargetString = TargetString("cause.some_disease.prevalence")
- SOME_DISEASE_INCIDENCE_RATE: TargetString = TargetString("cause.some_disease.incidence_rate")
- SOME_DISEASE_REMISSION_RATE: TargetString = TargetString("cause.some_disease.remission_rate")
+ SOME_DISEASE_INCIDENCE_RATE: TargetString = TargetString(
+ "cause.some_disease.incidence_rate"
+ )
+ SOME_DISEASE_REMISSION_RATE: TargetString = TargetString(
+ "cause.some_disease.remission_rate"
+ )
DISABILITY_WEIGHT: TargetString = TargetString("cause.some_disease.disability_weight")
EMR: TargetString = TargetString("cause.some_disease.excess_mortality_rate")
CSMR: TargetString = TargetString("cause.some_disease.cause_specific_mortality_rate")
diff --git a/src/vivarium_sodium_reduction/constants/data_values.py b/src/vivarium_sodium_reduction/constants/data_values.py
index b2b3458..06cf48a 100644
--- a/src/vivarium_sodium_reduction/constants/data_values.py
+++ b/src/vivarium_sodium_reduction/constants/data_values.py
@@ -29,5 +29,6 @@
SCALE_UP_START_DT = datetime(2021, 1, 1)
SCALE_UP_END_DT = datetime(2030, 1, 1)
SCREENING_SCALE_UP_GOAL_COVERAGE = 0.50
-SCREENING_SCALE_UP_DIFFERENCE = SCREENING_SCALE_UP_GOAL_COVERAGE - PROBABILITY_ATTENDING_SCREENING_START_MEAN
-
+SCREENING_SCALE_UP_DIFFERENCE = (
+ SCREENING_SCALE_UP_GOAL_COVERAGE - PROBABILITY_ATTENDING_SCREENING_START_MEAN
+)
diff --git a/src/vivarium_sodium_reduction/constants/models.py b/src/vivarium_sodium_reduction/constants/models.py
index d135bf5..3543fa5 100644
--- a/src/vivarium_sodium_reduction/constants/models.py
+++ b/src/vivarium_sodium_reduction/constants/models.py
@@ -35,4 +35,6 @@ def __new__(cls, value):
STATES = tuple(state for model in STATE_MACHINE_MAP.values() for state in model["states"])
-TRANSITIONS = tuple(state for model in STATE_MACHINE_MAP.values() for state in model["transitions"])
+TRANSITIONS = tuple(
+ state for model in STATE_MACHINE_MAP.values() for state in model["transitions"]
+)
diff --git a/src/vivarium_sodium_reduction/constants/paths.py b/src/vivarium_sodium_reduction/constants/paths.py
index 851e727..4af4711 100644
--- a/src/vivarium_sodium_reduction/constants/paths.py
+++ b/src/vivarium_sodium_reduction/constants/paths.py
@@ -5,5 +5,6 @@
BASE_DIR = Path(vivarium_sodium_reduction.__file__).resolve().parent
-ARTIFACT_ROOT = Path(f"/mnt/team/simulation_science/pub/models/{metadata.PROJECT_NAME}/artifacts/")
-
+ARTIFACT_ROOT = Path(
+ f"/mnt/team/simulation_science/pub/models/{metadata.PROJECT_NAME}/artifacts/"
+)
diff --git a/src/vivarium_sodium_reduction/constants/results.py b/src/vivarium_sodium_reduction/constants/results.py
index 8e4065d..e21809e 100644
--- a/src/vivarium_sodium_reduction/constants/results.py
+++ b/src/vivarium_sodium_reduction/constants/results.py
@@ -27,10 +27,10 @@
THROWAWAY_COLUMNS = [f"{state}_event_count" for state in models.STATES]
DEATH_COLUMN_TEMPLATE = "MEASURE_death_due_to_{CAUSE_OF_DEATH}_AGE_GROUP_{AGE_GROUP}_CURRENT_YEAR_{YEAR}_SEX_{SEX}"
-YLLS_COLUMN_TEMPLATE = "MEASURE_ylls_due_to_{CAUSE_OF_DEATH}_AGE_GROUP_{AGE_GROUP}_CURRENT_YEAR_{YEAR}_SEX_{SEX}"
-YLDS_COLUMN_TEMPLATE = (
- "MEASURE_ylds_due_to_{CAUSE_OF_DISABILITY}_AGE_GROUP_{AGE_GROUP}_CURRENT_YEAR_{YEAR}_SEX_{SEX}"
+YLLS_COLUMN_TEMPLATE = (
+ "MEASURE_ylls_due_to_{CAUSE_OF_DEATH}_AGE_GROUP_{AGE_GROUP}_CURRENT_YEAR_{YEAR}_SEX_{SEX}"
)
+YLDS_COLUMN_TEMPLATE = "MEASURE_ylds_due_to_{CAUSE_OF_DISABILITY}_AGE_GROUP_{AGE_GROUP}_CURRENT_YEAR_{YEAR}_SEX_{SEX}"
STATE_PERSON_TIME_COLUMN_TEMPLATE = (
"MEASURE_{STATE}_person_time_AGE_GROUP_{AGE_GROUP}_CURRENT_YEAR_{YEAR}_SEX_{SEX}"
)
@@ -46,8 +46,7 @@
"transition_count": TRANSITION_COUNT_COLUMN_TEMPLATE,
}
-NON_COUNT_TEMPLATES = [
-]
+NON_COUNT_TEMPLATES = []
SEXES = ("male", "female")
# TODO - add literals for years in the model
@@ -87,11 +86,18 @@ def RESULT_COLUMNS(kind="all"):
columns = list(STANDARD_COLUMNS.values()) + columns
else:
template = COLUMN_TEMPLATES[kind]
- filtered_field_map = {field: values
- for field, values in TEMPLATE_FIELD_MAP.items() if f"{{{field}}}" in template}
- fields, value_groups = filtered_field_map.keys(), itertools.product(*filtered_field_map.values())
+ filtered_field_map = {
+ field: values
+ for field, values in TEMPLATE_FIELD_MAP.items()
+ if f"{{{field}}}" in template
+ }
+ fields, value_groups = filtered_field_map.keys(), itertools.product(
+ *filtered_field_map.values()
+ )
for value_group in value_groups:
- columns.append(template.format(**{field: value for field, value in zip(fields, value_group)}))
+ columns.append(
+ template.format(**{field: value for field, value in zip(fields, value_group)})
+ )
return columns
@@ -101,13 +107,21 @@ def RESULTS_MAP(kind):
raise ValueError(f"Unknown result column type {kind}")
columns = []
template = COLUMN_TEMPLATES[kind]
- filtered_field_map = {field: values
- for field, values in TEMPLATE_FIELD_MAP.items() if f"{{{field}}}" in template}
- fields, value_groups = list(filtered_field_map.keys()), list(itertools.product(*filtered_field_map.values()))
+ filtered_field_map = {
+ field: values
+ for field, values in TEMPLATE_FIELD_MAP.items()
+ if f"{{{field}}}" in template
+ }
+ fields, value_groups = list(filtered_field_map.keys()), list(
+ itertools.product(*filtered_field_map.values())
+ )
for value_group in value_groups:
- columns.append(template.format(**{field: value for field, value in zip(fields, value_group)}))
+ columns.append(
+ template.format(**{field: value for field, value in zip(fields, value_group)})
+ )
df = pd.DataFrame(value_groups, columns=map(lambda x: x.lower(), fields))
df["key"] = columns
- df["measure"] = kind # per researcher feedback, this column is useful, even when it"s identical for all rows
+ df["measure"] = (
+        kind  # per researcher feedback, this column is useful, even when it's identical for all rows
+ )
return df.set_index("key").sort_index()
-
diff --git a/src/vivarium_sodium_reduction/constants/scenarios.py b/src/vivarium_sodium_reduction/constants/scenarios.py
index 23d78ec..c72a4e9 100644
--- a/src/vivarium_sodium_reduction/constants/scenarios.py
+++ b/src/vivarium_sodium_reduction/constants/scenarios.py
@@ -8,11 +8,11 @@
class InterventionScenario:
def __init__(
- self,
- name: str,
- # todo add additional interventions
- # has_treatment_one: bool = False,
- # has_treatment_two: bool = False,
+ self,
+ name: str,
+ # todo add additional interventions
+ # has_treatment_one: bool = False,
+ # has_treatment_two: bool = False,
):
self.name = name
# self.has_treatment_one = has_treatment_one
diff --git a/src/vivarium_sodium_reduction/data/builder.py b/src/vivarium_sodium_reduction/data/builder.py
index fb695fe..ab95bc1 100644
--- a/src/vivarium_sodium_reduction/data/builder.py
+++ b/src/vivarium_sodium_reduction/data/builder.py
@@ -9,6 +9,7 @@
Logging in this module should be done at the ``debug`` level.
"""
+
from pathlib import Path
from typing import Optional
diff --git a/src/vivarium_sodium_reduction/data/loader.py b/src/vivarium_sodium_reduction/data/loader.py
index e712023..4bec7e4 100644
--- a/src/vivarium_sodium_reduction/data/loader.py
+++ b/src/vivarium_sodium_reduction/data/loader.py
@@ -12,6 +12,7 @@
No logging is done here. Logging is done in vivarium inputs itself and forwarded.
"""
+
from typing import List, Optional, Union
import pandas as pd
diff --git a/src/vivarium_sodium_reduction/results_processing/process_results.py b/src/vivarium_sodium_reduction/results_processing/process_results.py
index a4a0461..def61d8 100644
--- a/src/vivarium_sodium_reduction/results_processing/process_results.py
+++ b/src/vivarium_sodium_reduction/results_processing/process_results.py
@@ -7,34 +7,33 @@
from vivarium_sodium_reduction.constants import results, scenarios
-SCENARIO_COLUMN = 'scenario'
-GROUPBY_COLUMNS = [
- results.INPUT_DRAW_COLUMN,
- SCENARIO_COLUMN
-]
+SCENARIO_COLUMN = "scenario"
+GROUPBY_COLUMNS = [results.INPUT_DRAW_COLUMN, SCENARIO_COLUMN]
OUTPUT_COLUMN_SORT_ORDER = [
- 'age_group',
- 'sex',
- 'year',
- 'risk',
- 'cause',
- 'measure',
- 'input_draw'
+ "age_group",
+ "sex",
+ "year",
+ "risk",
+ "cause",
+ "measure",
+ "input_draw",
]
RENAME_COLUMNS = {
- 'age_group': 'age',
- 'cause_of_death': 'cause',
+ "age_group": "age",
+ "cause_of_death": "cause",
}
def make_measure_data(data):
measure_data = MeasureData(
population=get_population_data(data),
- ylls=get_by_cause_measure_data(data, 'ylls'),
- ylds=get_by_cause_measure_data(data, 'ylds'),
- deaths=get_by_cause_measure_data(data, 'deaths'),
- state_person_time=get_state_person_time_measure_data(data, 'disease_state_person_time'),
- transition_count=get_transition_count_measure_data(data, 'disease_transition_count'),
+ ylls=get_by_cause_measure_data(data, "ylls"),
+ ylds=get_by_cause_measure_data(data, "ylds"),
+ deaths=get_by_cause_measure_data(data, "deaths"),
+ state_person_time=get_state_person_time_measure_data(
+ data, "disease_state_person_time"
+ ),
+ transition_count=get_transition_count_measure_data(data, "disease_transition_count"),
)
return measure_data
@@ -49,15 +48,14 @@ class MeasureData(NamedTuple):
def dump(self, output_dir: Path):
for key, df in self._asdict().items():
- df.to_csv(output_dir / f'{key}.csv')
+ df.to_csv(output_dir / f"{key}.csv")
def read_data(path: Path, single_run: bool) -> (pd.DataFrame, List[str]):
data = pd.read_hdf(path)
# noinspection PyUnresolvedReferences
data = (
- data
- .drop(columns=data.columns.intersection(results.THROWAWAY_COLUMNS))
+ data.drop(columns=data.columns.intersection(results.THROWAWAY_COLUMNS))
.reset_index(drop=True)
.rename(
columns={
@@ -74,17 +72,19 @@ def read_data(path: Path, single_run: bool) -> (pd.DataFrame, List[str]):
keyspace = {
results.INPUT_DRAW_COLUMN: [0],
results.RANDOM_SEED_COLUMN: [0],
- results.OUTPUT_SCENARIO_COLUMN: [scenarios.INTERVENTION_SCENARIOS.BASELINE.name]
+ results.OUTPUT_SCENARIO_COLUMN: [scenarios.INTERVENTION_SCENARIOS.BASELINE.name],
}
else:
data[results.INPUT_DRAW_COLUMN] = data[results.INPUT_DRAW_COLUMN].astype(int)
data[results.RANDOM_SEED_COLUMN] = data[results.RANDOM_SEED_COLUMN].astype(int)
- with (path.parent / 'keyspace.yaml').open() as f:
+ with (path.parent / "keyspace.yaml").open() as f:
keyspace = yaml.full_load(f)
return data, keyspace
-def filter_out_incomplete(data: pd.DataFrame, keyspace: Dict[str, Union[str, int]]) -> pd.DataFrame:
+def filter_out_incomplete(
+ data: pd.DataFrame, keyspace: Dict[str, Union[str, int]]
+) -> pd.DataFrame:
output = []
for draw in keyspace[results.INPUT_DRAW_COLUMN]:
# For each draw, gather all random seeds completed for all scenarios.
@@ -108,33 +108,35 @@ def aggregate_over_seed(data: pd.DataFrame) -> pd.DataFrame:
# non_count_data = data[non_count_columns + GROUPBY_COLUMNS].groupby(GROUPBY_COLUMNS).mean()
count_data = data[count_columns + GROUPBY_COLUMNS].groupby(GROUPBY_COLUMNS).sum()
- return pd.concat([
- count_data,
- # non_count_data
- ], axis=1).reset_index()
+ return pd.concat(
+ [
+ count_data,
+ # non_count_data
+ ],
+ axis=1,
+ ).reset_index()
def pivot_data(data: pd.DataFrame) -> pd.DataFrame:
return (
- data
- .set_index(GROUPBY_COLUMNS)
+ data.set_index(GROUPBY_COLUMNS)
.stack()
.reset_index()
- .rename(columns={f'level_{len(GROUPBY_COLUMNS)}': 'key', 0: 'value'})
+ .rename(columns={f"level_{len(GROUPBY_COLUMNS)}": "key", 0: "value"})
)
def sort_data(data: pd.DataFrame) -> pd.DataFrame:
sort_order = [c for c in OUTPUT_COLUMN_SORT_ORDER if c in data.columns]
- other_cols = [c for c in data.columns if c not in sort_order and c != 'value']
- data = data[sort_order + other_cols + ['value']].sort_values(sort_order)
+ other_cols = [c for c in data.columns if c not in sort_order and c != "value"]
+ data = data[sort_order + other_cols + ["value"]].sort_values(sort_order)
return data.reset_index(drop=True)
def apply_results_map(data: pd.DataFrame, kind: str) -> pd.DataFrame:
logger.info(f"Mapping {kind} data to stratifications.")
map_df = results.RESULTS_MAP(kind)
- data = data.set_index('key')
+ data = data.set_index("key")
data = data.join(map_df).reset_index(drop=True)
data = data.rename(columns=RENAME_COLUMNS)
logger.info(f"Mapping {kind} complete.")
@@ -142,10 +144,14 @@ def apply_results_map(data: pd.DataFrame, kind: str) -> pd.DataFrame:
def get_population_data(data: pd.DataFrame) -> pd.DataFrame:
- total_pop = pivot_data(data[[results.TOTAL_POPULATION_COLUMN]
- + results.RESULT_COLUMNS('population')
- + GROUPBY_COLUMNS])
- total_pop = total_pop.rename(columns={'key': 'measure'})
+ total_pop = pivot_data(
+ data[
+ [results.TOTAL_POPULATION_COLUMN]
+ + results.RESULT_COLUMNS("population")
+ + GROUPBY_COLUMNS
+ ]
+ )
+ total_pop = total_pop.rename(columns={"key": "measure"})
return sort_data(total_pop)
@@ -167,6 +173,6 @@ def get_state_person_time_measure_data(data: pd.DataFrame, measure: str) -> pd.D
def get_transition_count_measure_data(data: pd.DataFrame, measure: str) -> pd.DataFrame:
# Oops, edge case.
- data = data.drop(columns=[c for c in data.columns if 'event_count' in c and '2041' in c])
+ data = data.drop(columns=[c for c in data.columns if "event_count" in c and "2041" in c])
data = get_measure_data(data, measure)
return sort_data(data)
diff --git a/src/vivarium_sodium_reduction/tools/app_logging.py b/src/vivarium_sodium_reduction/tools/app_logging.py
index f6ea409..b85b1f5 100644
--- a/src/vivarium_sodium_reduction/tools/app_logging.py
+++ b/src/vivarium_sodium_reduction/tools/app_logging.py
@@ -4,7 +4,9 @@
from loguru import logger
-def add_logging_sink(sink: TextIO, verbose: int, colorize: bool = False, serialize: bool = False):
+def add_logging_sink(
+ sink: TextIO, verbose: int, colorize: bool = False, serialize: bool = False
+):
"""Adds a logging sink to the global process logger.
Parameters
@@ -20,14 +22,26 @@ def add_logging_sink(sink: TextIO, verbose: int, colorize: bool = False, seriali
to the logging sink.
"""
- message_format = ('{time:YYYY-MM-DD HH:mm:ss.SSS} | {elapsed} | '
- '{function}:{line} - {message}')
+ message_format = (
+ "{time:YYYY-MM-DD HH:mm:ss.SSS} | {elapsed} | "
+ "{function}:{line} - {message}"
+ )
if verbose == 0:
- logger.add(sink, colorize=colorize, level="WARNING", format=message_format, serialize=serialize)
+ logger.add(
+ sink,
+ colorize=colorize,
+ level="WARNING",
+ format=message_format,
+ serialize=serialize,
+ )
elif verbose == 1:
- logger.add(sink, colorize=colorize, level="INFO", format=message_format, serialize=serialize)
+ logger.add(
+ sink, colorize=colorize, level="INFO", format=message_format, serialize=serialize
+ )
elif verbose >= 2:
- logger.add(sink, colorize=colorize, level="DEBUG", format=message_format, serialize=serialize)
+ logger.add(
+ sink, colorize=colorize, level="DEBUG", format=message_format, serialize=serialize
+ )
def configure_logging_to_terminal(verbose: int):
@@ -44,15 +58,17 @@ def configure_logging_to_terminal(verbose: int):
def decode_status(drmaa, job_status):
- decoder_map = {drmaa.JobState.UNDETERMINED: 'undetermined',
- drmaa.JobState.QUEUED_ACTIVE: 'queued_active',
- drmaa.JobState.SYSTEM_ON_HOLD: 'system_hold',
- drmaa.JobState.USER_ON_HOLD: 'user_hold',
- drmaa.JobState.USER_SYSTEM_ON_HOLD: 'user_system_hold',
- drmaa.JobState.RUNNING: 'running',
- drmaa.JobState.SYSTEM_SUSPENDED: 'system_suspended',
- drmaa.JobState.USER_SUSPENDED: 'user_suspended',
- drmaa.JobState.DONE: 'finished',
- drmaa.JobState.FAILED: 'failed'}
+ decoder_map = {
+ drmaa.JobState.UNDETERMINED: "undetermined",
+ drmaa.JobState.QUEUED_ACTIVE: "queued_active",
+ drmaa.JobState.SYSTEM_ON_HOLD: "system_hold",
+ drmaa.JobState.USER_ON_HOLD: "user_hold",
+ drmaa.JobState.USER_SYSTEM_ON_HOLD: "user_system_hold",
+ drmaa.JobState.RUNNING: "running",
+ drmaa.JobState.SYSTEM_SUSPENDED: "system_suspended",
+ drmaa.JobState.USER_SUSPENDED: "user_suspended",
+ drmaa.JobState.DONE: "finished",
+ drmaa.JobState.FAILED: "failed",
+ }
return decoder_map[job_status]
diff --git a/src/vivarium_sodium_reduction/tools/cli.py b/src/vivarium_sodium_reduction/tools/cli.py
index af0ddd4..eee36f0 100644
--- a/src/vivarium_sodium_reduction/tools/cli.py
+++ b/src/vivarium_sodium_reduction/tools/cli.py
@@ -5,9 +5,11 @@
from vivarium.framework.utilities import handle_exceptions
from vivarium_sodium_reduction.constants import metadata, paths
-from vivarium_sodium_reduction.tools import (build_artifacts,
- build_results,
- configure_logging_to_terminal)
+from vivarium_sodium_reduction.tools import (
+ build_artifacts,
+ build_results,
+ configure_logging_to_terminal,
+)
@click.command()
diff --git a/src/vivarium_sodium_reduction/tools/make_artifacts.py b/src/vivarium_sodium_reduction/tools/make_artifacts.py
index 3ca65c2..0ab1a75 100644
--- a/src/vivarium_sodium_reduction/tools/make_artifacts.py
+++ b/src/vivarium_sodium_reduction/tools/make_artifacts.py
@@ -6,6 +6,7 @@
Use your best judgement.
"""
+
import shutil
import sys
import time
diff --git a/src/vivarium_sodium_reduction/tools/make_results.py b/src/vivarium_sodium_reduction/tools/make_results.py
index 92f6f18..9af37b2 100644
--- a/src/vivarium_sodium_reduction/tools/make_results.py
+++ b/src/vivarium_sodium_reduction/tools/make_results.py
@@ -8,21 +8,23 @@
def build_results(output_file: str, single_run: bool) -> None:
output_file = Path(output_file)
- measure_dir = output_file.parent / 'count_data'
+ measure_dir = output_file.parent / "count_data"
if measure_dir.exists():
shutil.rmtree(measure_dir)
measure_dir.mkdir(exist_ok=True, mode=0o775)
- logger.info(f'Reading in output data from {str(output_file)}.')
+ logger.info(f"Reading in output data from {str(output_file)}.")
data, keyspace = process_results.read_data(output_file, single_run)
- logger.info(f'Filtering incomplete data from outputs.')
+ logger.info(f"Filtering incomplete data from outputs.")
rows = len(data)
data = process_results.filter_out_incomplete(data, keyspace)
new_rows = len(data)
- logger.info(f'Filtered {rows - new_rows} from data due to incomplete information. {new_rows} remaining.')
+ logger.info(
+ f"Filtered {rows - new_rows} from data due to incomplete information. {new_rows} remaining."
+ )
data = process_results.aggregate_over_seed(data)
- logger.info(f'Computing raw count and proportion data.')
+ logger.info(f"Computing raw count and proportion data.")
measure_data = process_results.make_measure_data(data)
- logger.info(f'Writing raw count and proportion data to {str(measure_dir)}')
+ logger.info(f"Writing raw count and proportion data to {str(measure_dir)}")
measure_data.dump(measure_dir)
- logger.info('**DONE**')
+ logger.info("**DONE**")
diff --git a/src/vivarium_sodium_reduction/utilities.py b/src/vivarium_sodium_reduction/utilities.py
index ff2d1bf..d45b959 100644
--- a/src/vivarium_sodium_reduction/utilities.py
+++ b/src/vivarium_sodium_reduction/utilities.py
@@ -50,14 +50,16 @@ def delete_if_exists(*paths: Union[Path, List[Path]], confirm=False):
# Assumes all paths have the same root dir
root = existing_paths[0].parent
names = [p.name for p in existing_paths]
- click.confirm(f"Existing files {names} found in directory {root}. Do you want to delete and replace?",
- abort=True)
+ click.confirm(
+ f"Existing files {names} found in directory {root}. Do you want to delete and replace?",
+ abort=True,
+ )
for p in existing_paths:
- logger.info(f'Deleting artifact at {str(p)}.')
+ logger.info(f"Deleting artifact at {str(p)}.")
p.unlink()
-def read_data_by_draw(artifact_path: str, key : str, draw: int) -> pd.DataFrame:
+def read_data_by_draw(artifact_path: str, key: str, draw: int) -> pd.DataFrame:
"""Reads data from the artifact on a per-draw basis. This
is necessary for Low Birthweight Short Gestation (LBWSG) data.
@@ -72,32 +74,34 @@ def read_data_by_draw(artifact_path: str, key : str, draw: int) -> pd.DataFrame:
"""
key = key.replace(".", "/")
- with pd.HDFStore(artifact_path, mode='r') as store:
- index = store.get(f'{key}/index')
- draw = store.get(f'{key}/draw_{draw}')
+ with pd.HDFStore(artifact_path, mode="r") as store:
+ index = store.get(f"{key}/index")
+ draw = store.get(f"{key}/draw_{draw}")
draw = draw.rename("value")
data = pd.concat([index, draw], axis=1)
- data = data.drop(columns='location')
+ data = data.drop(columns="location")
data = pivot_categorical(data)
- data[project_globals.LBWSG_MISSING_CATEGORY.CAT] = project_globals.LBWSG_MISSING_CATEGORY.EXPOSURE
+ data[project_globals.LBWSG_MISSING_CATEGORY.CAT] = (
+ project_globals.LBWSG_MISSING_CATEGORY.EXPOSURE
+ )
return data
def get_norm(
- mean: float,
- sd: float = None,
- ninety_five_pct_confidence_interval: Tuple[float, float] = None
+ mean: float,
+ sd: float = None,
+ ninety_five_pct_confidence_interval: Tuple[float, float] = None,
) -> stats.norm:
sd = _get_standard_deviation(mean, sd, ninety_five_pct_confidence_interval)
return stats.norm(loc=mean, scale=sd)
def get_truncnorm(
- mean: float,
- sd: float = None,
- ninety_five_pct_confidence_interval: Tuple[float, float] = None,
- lower_clip: float = 0.0,
- upper_clip: float = 1.0
+ mean: float,
+ sd: float = None,
+ ninety_five_pct_confidence_interval: Tuple[float, float] = None,
+ lower_clip: float = 0.0,
+ upper_clip: float = 1.0,
) -> stats.norm:
sd = _get_standard_deviation(mean, sd, ninety_five_pct_confidence_interval)
a = (lower_clip - mean) / sd if sd else mean - 1e-03
@@ -106,12 +110,16 @@ def get_truncnorm(
def _get_standard_deviation(
- mean: float, sd: float, ninety_five_pct_confidence_interval: Tuple[float, float]
+ mean: float, sd: float, ninety_five_pct_confidence_interval: Tuple[float, float]
) -> float:
if sd is None and ninety_five_pct_confidence_interval is None:
- raise ValueError("Must provide either a standard deviation or a 95% confidence interval.")
+ raise ValueError(
+ "Must provide either a standard deviation or a 95% confidence interval."
+ )
if sd is not None and ninety_five_pct_confidence_interval is not None:
- raise ValueError("Cannot provide both a standard deviation and a 95% confidence interval.")
+ raise ValueError(
+ "Cannot provide both a standard deviation and a 95% confidence interval."
+ )
if ninety_five_pct_confidence_interval is not None:
lower = ninety_five_pct_confidence_interval[0]
upper = ninety_five_pct_confidence_interval[1]
@@ -126,8 +134,9 @@ def _get_standard_deviation(
return sd
-def get_lognorm_from_quantiles(median: float, lower: float, upper: float,
- quantiles: Tuple[float, float] = (0.025, 0.975)) -> stats.lognorm:
+def get_lognorm_from_quantiles(
+ median: float, lower: float, upper: float, quantiles: Tuple[float, float] = (0.025, 0.975)
+) -> stats.lognorm:
"""Returns a frozen lognormal distribution with the specified median, such that
(lower, upper) are approximately equal to the quantiles with ranks
(quantile_ranks[0], quantile_ranks[1]).
@@ -149,17 +158,21 @@ def get_lognorm_from_quantiles(median: float, lower: float, upper: float,
norm_quantiles = np.log([lower, upper])
# standard deviation of Y = log(X) computed from the above quantiles for Y
# and the corresponding standard normal quantiles
- sigma = (norm_quantiles[1] - norm_quantiles[0]) / (stdnorm_quantiles[1] - stdnorm_quantiles[0])
+ sigma = (norm_quantiles[1] - norm_quantiles[0]) / (
+ stdnorm_quantiles[1] - stdnorm_quantiles[0]
+ )
# Frozen lognormal distribution for X = exp(Y)
# (s=sigma is the shape parameter; the scale parameter is exp(mu), which equals the median)
return stats.lognorm(s=sigma, scale=median)
-def get_random_variable_draws(number: int, seeded_distribution: SeededDistribution) -> np.array:
+def get_random_variable_draws(
+ number: int, seeded_distribution: SeededDistribution
+) -> np.array:
return np.array([get_random_variable(x, seeded_distribution) for x in range(number)])
def get_random_variable(draw: int, seeded_distribution: SeededDistribution) -> float:
seed, distribution = seeded_distribution
- np.random.seed(get_hash(f'{seed}_draw_{draw}'))
+ np.random.seed(get_hash(f"{seed}_draw_{draw}"))
return distribution.rvs()
diff --git a/update_readme.py b/update_readme.py
index d1d72a8..526e23d 100644
--- a/update_readme.py
+++ b/update_readme.py
@@ -1,6 +1,7 @@
""" This script updates the README.rst file with the latest information about
the project. It is intended to be run from the github "update README" workflow.
"""
+
import json
import re