diff --git a/src/vivarium_sodium_reduction/components/interventions.py b/src/vivarium_sodium_reduction/components/interventions.py
index 28a4256..771b2ed 100644
--- a/src/vivarium_sodium_reduction/components/interventions.py
+++ b/src/vivarium_sodium_reduction/components/interventions.py
@@ -1,17 +1,19 @@
+from typing import Any, Dict
+
 import pandas as pd
-from typing import Dict, Any
 from vivarium import Component
 from vivarium.framework.engine import Builder
 from vivarium.framework.population import SimulantData
 
+
 class RelativeShiftIntervention(Component):
     """Applies a relative shift to a target value."""
 
     CONFIGURATION_DEFAULTS = {
-        'intervention': {
-            'shift_factor': 0.1,
-            'age_start': 0,
-            'age_end': 125,
+        "intervention": {
+            "shift_factor": 0.1,
+            "age_start": 0,
+            "age_end": 125,
         }
     }
@@ -25,9 +27,7 @@ def name(self) -> str:
 
     @property
     def configuration_defaults(self) -> Dict[str, Dict[str, Any]]:
-        return {
-            f'{self.name}': self.CONFIGURATION_DEFAULTS['intervention']
-        }
+        return {f"{self.name}": self.CONFIGURATION_DEFAULTS["intervention"]}
 
     def setup(self, builder: Builder) -> None:
         self.config = builder.configuration[self.name]
@@ -35,12 +35,10 @@ def setup(self, builder: Builder) -> None:
         self.age_start = self.config.age_start
         self.age_end = self.config.age_end
 
-        self.population_view = builder.population.get_view(['age'])
+        self.population_view = builder.population.get_view(["age"])
 
         builder.value.register_value_modifier(
-            self.target,
-            modifier=self.adjust_exposure,
-            requires_columns=['age']
+            self.target, modifier=self.adjust_exposure, requires_columns=["age"]
         )
 
     def adjust_exposure(self, index: pd.Index, exposure: pd.Series) -> pd.Series:
@@ -49,4 +47,4 @@ def adjust_exposure(self, index: pd.Index, exposure: pd.Series) -> pd.Series:
             (self.age_start <= pop.age) & (pop.age < self.age_end)
         ].index
         exposure.loc[applicable_index] *= self.shift_factor
-        return exposure
\ No newline at end of file
+        return exposure
diff --git a/src/vivarium_sodium_reduction/components/risks.py b/src/vivarium_sodium_reduction/components/risks.py
index 36f85d3..74e2462 100644
--- a/src/vivarium_sodium_reduction/components/risks.py
+++ b/src/vivarium_sodium_reduction/components/risks.py
@@ -11,9 +11,12 @@
 from vivarium.framework.randomness import get_hash
 from vivarium.framework.values import Pipeline
 from vivarium_public_health.risks import Risk
-from vivarium_public_health.risks.data_transformations import get_exposure_post_processor
+from vivarium_public_health.risks.data_transformations import (
+    get_exposure_post_processor,
+)
 from vivarium_public_health.utilities import EntityString
 
+
 class DropValueRisk(Risk):
     def __init__(self, risk: str):
         super().__init__(risk)
@@ -57,11 +60,13 @@ def post_processor(exposure, _):
 
         return post_processor
 
+
 class CorrelatedRisk(DropValueRisk):
     """A risk that can be correlated with another risk.
-
+
    TODO: document strategy used in this component in more detail,
    Abie had an AI adapt it from https://github.com/ihmeuw/vivarium_nih_us_cvd"""
+
     @property
     def columns_created(self) -> List[str]:
         return []
@@ -84,8 +89,10 @@ def on_initialize_simulants(self, pop_data: SimulantData) -> None:
     def on_time_step_prepare(self, event: Event) -> None:
         pass
 
+
 class RiskCorrelation(Component):
     """A component that generates a specified correlation between two risk exposures."""
+
     @property
     def columns_created(self) -> List[str]:
         return self.propensity_column_names + self.exposure_column_names
@@ -104,8 +111,12 @@ def __init__(self, risk1: str, risk2: str, correlation: str):
         correlation_matrix = np.array([[1, float(correlation)], [float(correlation), 1]])
         self.correlated_risks = [EntityString(risk) for risk in correlated_risks]
         self.correlation_matrix = correlation_matrix
-        self.propensity_column_names = [f"{risk.name}_propensity" for risk in self.correlated_risks]
-        self.exposure_column_names = [f"{risk.name}_exposure" for risk in self.correlated_risks]
+        self.propensity_column_names = [
+            f"{risk.name}_propensity" for risk in self.correlated_risks
+        ]
+        self.exposure_column_names = [
+            f"{risk.name}_exposure" for risk in self.correlated_risks
+        ]
         self.ensemble_propensities = [
             f"ensemble_propensity_" + risk
             for risk in self.correlated_risks
@@ -113,8 +124,14 @@ def __init__(self, risk1: str, risk2: str, correlation: str):
         ]
 
     def setup(self, builder: Builder) -> None:
-        self.distributions = {risk: builder.components.get_component(risk).exposure_distribution for risk in self.correlated_risks}
-        self.exposures = {risk: builder.value.get_value(f"{risk.name}.exposure") for risk in self.correlated_risks}
+        self.distributions = {
+            risk: builder.components.get_component(risk).exposure_distribution
+            for risk in self.correlated_risks
+        }
+        self.exposures = {
+            risk: builder.value.get_value(f"{risk.name}.exposure")
+            for risk in self.correlated_risks
+        }
         self.input_draw = builder.configuration.input_data.input_draw_number
         self.random_seed = builder.configuration.randomness.random_seed
@@ -124,21 +141,21 @@ def on_initialize_simulants(self, pop_data: SimulantData) -> None:
         np.random.seed(get_hash(f"{self.input_draw}_{self.random_seed}"))
 
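+        # Gaussian copula: draw jointly normal values in probit space with the
+        # requested correlation, then map each marginal through the standard
+        # normal CDF to get correlated uniform propensities.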
         probit_propensity = np.random.multivariate_normal(
-            mean=[0] * len(self.correlated_risks),
-            cov=self.correlation_matrix,
-            size=len(pop)
+            mean=[0] * len(self.correlated_risks), cov=self.correlation_matrix, size=len(pop)
         )
         correlated_propensities = scipy.stats.norm().cdf(probit_propensity)
         propensities[self.propensity_column_names] = correlated_propensities
 
         def get_exposure_from_propensity(propensity_col: pd.Series) -> pd.Series:
-            risk = propensity_col.name.replace('_propensity','')
-            exposure_values = self.distributions['risk_factor.' + risk].ppf(propensity_col)
+            risk = propensity_col.name.replace("_propensity", "")
+            exposure_values = self.distributions["risk_factor." + risk].ppf(propensity_col)
             return pd.Series(exposure_values)
 
         exposures = propensities.apply(get_exposure_from_propensity)
-        exposures.columns = [col.replace('_propensity','_exposure') for col in propensities.columns]
-
+        exposures.columns = [
+            col.replace("_propensity", "_exposure") for col in propensities.columns
+        ]
+
         self.population_view.update(pd.concat([propensities, exposures], axis=1))
 
     def on_time_step_prepare(self, event: Event) -> None:
@@ -147,33 +164,37 @@ def on_time_step_prepare(self, event: Event) -> None:
             exposure_col = pd.Series(exposure_values, name=f"{risk.name}_exposure")
             self.population_view.update(exposure_col)
 
+
 class SodiumSBPEffect(Component):
     @property
     def name(self):
         return "sodium_sbp_effect"
 
     def setup(self, builder: Builder):
-        self.sodium_exposure = builder.value.get_value('diet_high_in_sodium.exposure')
-        self.sodium_exposure_raw = builder.value.get_value('diet_high_in_sodium.raw_exposure')
-
+        self.sodium_exposure = builder.value.get_value("diet_high_in_sodium.exposure")
+        self.sodium_exposure_raw = builder.value.get_value("diet_high_in_sodium.raw_exposure")
+
         builder.value.register_value_modifier(
-            'high_systolic_blood_pressure.drop_value',
+            "high_systolic_blood_pressure.drop_value",
             modifier=self.sodium_effect_on_sbp,
-            requires_columns=['age', 'sex'],
-            requires_values=['diet_high_in_sodium.exposure', 'diet_high_in_sodium.raw_exposure']
+            requires_columns=["age", "sex"],
+            requires_values=[
+                "diet_high_in_sodium.exposure",
+                "diet_high_in_sodium.raw_exposure",
+            ],
         )
 
     def sodium_effect_on_sbp(self, index, sbp_drop_value):
         sodium_exposure = self.sodium_exposure(index)
         sodium_exposure_raw = self.sodium_exposure_raw(index)
-
+
         sodium_threshold = 2.0  # g/day
         mmHg_per_g_sodium = 10  # mmHg increase per 1g sodium above threshold
-
+
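+        # Assumes a linear dose response: each gram/day of sodium reduction
+        # lowers SBP by mmHg_per_g_sodium; sodium_threshold is not applied yet
+        # (see the TODO below).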
         sbp_increase = pd.Series(0, index=index)
         sodium_drop = sodium_exposure_raw - sodium_exposure  # TODO: use threshold
         sbp_drop_due_to_sodium_drop = sodium_drop * mmHg_per_g_sodium
-
+
         return sbp_drop_value + sbp_drop_due_to_sodium_drop
diff --git a/src/vivarium_sodium_reduction/constants/data_keys.py b/src/vivarium_sodium_reduction/constants/data_keys.py
index 156318d..9d5921e 100644
--- a/src/vivarium_sodium_reduction/constants/data_keys.py
+++ b/src/vivarium_sodium_reduction/constants/data_keys.py
@@ -36,8 +36,12 @@ class __SomeDisease(NamedTuple):
     # Keys that will be loaded into the artifact. must have a colon type declaration
     SOME_DISEASE_PREVALENCE: TargetString = TargetString("cause.some_disease.prevalence")
-    SOME_DISEASE_INCIDENCE_RATE: TargetString = TargetString("cause.some_disease.incidence_rate")
-    SOME_DISEASE_REMISSION_RATE: TargetString = TargetString("cause.some_disease.remission_rate")
+    SOME_DISEASE_INCIDENCE_RATE: TargetString = TargetString(
+        "cause.some_disease.incidence_rate"
+    )
+    SOME_DISEASE_REMISSION_RATE: TargetString = TargetString(
+        "cause.some_disease.remission_rate"
+    )
     DISABILITY_WEIGHT: TargetString = TargetString("cause.some_disease.disability_weight")
     EMR: TargetString = TargetString("cause.some_disease.excess_mortality_rate")
     CSMR: TargetString = TargetString("cause.some_disease.cause_specific_mortality_rate")
diff --git a/src/vivarium_sodium_reduction/constants/data_values.py b/src/vivarium_sodium_reduction/constants/data_values.py
index b2b3458..06cf48a 100644
--- a/src/vivarium_sodium_reduction/constants/data_values.py
+++ b/src/vivarium_sodium_reduction/constants/data_values.py
@@ -29,5 +29,6 @@
 SCALE_UP_START_DT = datetime(2021, 1, 1)
 SCALE_UP_END_DT = datetime(2030, 1, 1)
 SCREENING_SCALE_UP_GOAL_COVERAGE = 0.50
-SCREENING_SCALE_UP_DIFFERENCE = SCREENING_SCALE_UP_GOAL_COVERAGE - PROBABILITY_ATTENDING_SCREENING_START_MEAN
-
+SCREENING_SCALE_UP_DIFFERENCE = (
+    SCREENING_SCALE_UP_GOAL_COVERAGE - PROBABILITY_ATTENDING_SCREENING_START_MEAN
+)
diff --git a/src/vivarium_sodium_reduction/constants/models.py b/src/vivarium_sodium_reduction/constants/models.py
index d135bf5..3543fa5 100644
--- a/src/vivarium_sodium_reduction/constants/models.py
+++ b/src/vivarium_sodium_reduction/constants/models.py
@@ -35,4 +35,6 @@ def __new__(cls, value):
 
 STATES = tuple(state for model in STATE_MACHINE_MAP.values() for state in model["states"])
-TRANSITIONS = tuple(state for model in STATE_MACHINE_MAP.values() for state in model["transitions"])
+TRANSITIONS = tuple(
+    state for model in STATE_MACHINE_MAP.values() for state in model["transitions"]
+)
diff --git a/src/vivarium_sodium_reduction/constants/paths.py b/src/vivarium_sodium_reduction/constants/paths.py
index 851e727..4af4711 100644
--- a/src/vivarium_sodium_reduction/constants/paths.py
+++ b/src/vivarium_sodium_reduction/constants/paths.py
@@ -5,5 +5,6 @@
 BASE_DIR = Path(vivarium_sodium_reduction.__file__).resolve().parent
 
-ARTIFACT_ROOT = Path(f"/mnt/team/simulation_science/pub/models/{metadata.PROJECT_NAME}/artifacts/")
-
+ARTIFACT_ROOT = Path(
+    f"/mnt/team/simulation_science/pub/models/{metadata.PROJECT_NAME}/artifacts/"
+)
diff --git a/src/vivarium_sodium_reduction/constants/results.py b/src/vivarium_sodium_reduction/constants/results.py
index 8e4065d..e21809e 100644
--- a/src/vivarium_sodium_reduction/constants/results.py
+++ b/src/vivarium_sodium_reduction/constants/results.py
@@ -27,10 +27,10 @@
 THROWAWAY_COLUMNS = [f"{state}_event_count" for state in models.STATES]
 
 DEATH_COLUMN_TEMPLATE = "MEASURE_death_due_to_{CAUSE_OF_DEATH}_AGE_GROUP_{AGE_GROUP}_CURRENT_YEAR_{YEAR}_SEX_{SEX}"
-YLLS_COLUMN_TEMPLATE = "MEASURE_ylls_due_to_{CAUSE_OF_DEATH}_AGE_GROUP_{AGE_GROUP}_CURRENT_YEAR_{YEAR}_SEX_{SEX}"
-YLDS_COLUMN_TEMPLATE = (
-    "MEASURE_ylds_due_to_{CAUSE_OF_DISABILITY}_AGE_GROUP_{AGE_GROUP}_CURRENT_YEAR_{YEAR}_SEX_{SEX}"
+YLLS_COLUMN_TEMPLATE = (
+    "MEASURE_ylls_due_to_{CAUSE_OF_DEATH}_AGE_GROUP_{AGE_GROUP}_CURRENT_YEAR_{YEAR}_SEX_{SEX}"
 )
+YLDS_COLUMN_TEMPLATE = "MEASURE_ylds_due_to_{CAUSE_OF_DISABILITY}_AGE_GROUP_{AGE_GROUP}_CURRENT_YEAR_{YEAR}_SEX_{SEX}"
 STATE_PERSON_TIME_COLUMN_TEMPLATE = (
     "MEASURE_{STATE}_person_time_AGE_GROUP_{AGE_GROUP}_CURRENT_YEAR_{YEAR}_SEX_{SEX}"
 )
@@ -46,8 +46,7 @@
     "transition_count": TRANSITION_COUNT_COLUMN_TEMPLATE,
 }
 
-NON_COUNT_TEMPLATES = [
-]
+NON_COUNT_TEMPLATES = []
 
 SEXES = ("male", "female")
 # TODO - add literals for years in the model
@@ -87,11 +86,18 @@ def RESULT_COLUMNS(kind="all"):
         columns = list(STANDARD_COLUMNS.values()) + columns
     else:
         template = COLUMN_TEMPLATES[kind]
-        filtered_field_map = {field: values
-                              for field, values in TEMPLATE_FIELD_MAP.items() if f"{{{field}}}" in template}
-        fields, value_groups = filtered_field_map.keys(), itertools.product(*filtered_field_map.values())
+        filtered_field_map = {
+            field: values
+            for field, values in TEMPLATE_FIELD_MAP.items()
+            if f"{{{field}}}" in template
+        }
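+        # One output column per combination of the template's remaining fields
+        # (e.g. AGE_GROUP x YEAR x SEX), via the cartesian product below.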
"MEASURE_{STATE}_person_time_AGE_GROUP_{AGE_GROUP}_CURRENT_YEAR_{YEAR}_SEX_{SEX}" ) @@ -46,8 +46,7 @@ "transition_count": TRANSITION_COUNT_COLUMN_TEMPLATE, } -NON_COUNT_TEMPLATES = [ -] +NON_COUNT_TEMPLATES = [] SEXES = ("male", "female") # TODO - add literals for years in the model @@ -87,11 +86,18 @@ def RESULT_COLUMNS(kind="all"): columns = list(STANDARD_COLUMNS.values()) + columns else: template = COLUMN_TEMPLATES[kind] - filtered_field_map = {field: values - for field, values in TEMPLATE_FIELD_MAP.items() if f"{{{field}}}" in template} - fields, value_groups = filtered_field_map.keys(), itertools.product(*filtered_field_map.values()) + filtered_field_map = { + field: values + for field, values in TEMPLATE_FIELD_MAP.items() + if f"{{{field}}}" in template + } + fields, value_groups = filtered_field_map.keys(), itertools.product( + *filtered_field_map.values() + ) for value_group in value_groups: - columns.append(template.format(**{field: value for field, value in zip(fields, value_group)})) + columns.append( + template.format(**{field: value for field, value in zip(fields, value_group)}) + ) return columns @@ -101,13 +107,21 @@ def RESULTS_MAP(kind): raise ValueError(f"Unknown result column type {kind}") columns = [] template = COLUMN_TEMPLATES[kind] - filtered_field_map = {field: values - for field, values in TEMPLATE_FIELD_MAP.items() if f"{{{field}}}" in template} - fields, value_groups = list(filtered_field_map.keys()), list(itertools.product(*filtered_field_map.values())) + filtered_field_map = { + field: values + for field, values in TEMPLATE_FIELD_MAP.items() + if f"{{{field}}}" in template + } + fields, value_groups = list(filtered_field_map.keys()), list( + itertools.product(*filtered_field_map.values()) + ) for value_group in value_groups: - columns.append(template.format(**{field: value for field, value in zip(fields, value_group)})) + columns.append( + template.format(**{field: value for field, value in zip(fields, value_group)}) + ) df = pd.DataFrame(value_groups, columns=map(lambda x: x.lower(), fields)) df["key"] = columns - df["measure"] = kind # per researcher feedback, this column is useful, even when it"s identical for all rows + df["measure"] = ( + kind # per researcher feedback, this column is useful, even when it"s identical for all rows + ) return df.set_index("key").sort_index() - diff --git a/src/vivarium_sodium_reduction/constants/scenarios.py b/src/vivarium_sodium_reduction/constants/scenarios.py index 23d78ec..c72a4e9 100644 --- a/src/vivarium_sodium_reduction/constants/scenarios.py +++ b/src/vivarium_sodium_reduction/constants/scenarios.py @@ -8,11 +8,11 @@ class InterventionScenario: def __init__( - self, - name: str, - # todo add additional interventions - # has_treatment_one: bool = False, - # has_treatment_two: bool = False, + self, + name: str, + # todo add additional interventions + # has_treatment_one: bool = False, + # has_treatment_two: bool = False, ): self.name = name # self.has_treatment_one = has_treatment_one diff --git a/src/vivarium_sodium_reduction/data/builder.py b/src/vivarium_sodium_reduction/data/builder.py index fb695fe..ab95bc1 100644 --- a/src/vivarium_sodium_reduction/data/builder.py +++ b/src/vivarium_sodium_reduction/data/builder.py @@ -9,6 +9,7 @@ Logging in this module should be done at the ``debug`` level. 
""" + from pathlib import Path from typing import Optional diff --git a/src/vivarium_sodium_reduction/data/loader.py b/src/vivarium_sodium_reduction/data/loader.py index e712023..4bec7e4 100644 --- a/src/vivarium_sodium_reduction/data/loader.py +++ b/src/vivarium_sodium_reduction/data/loader.py @@ -12,6 +12,7 @@ No logging is done here. Logging is done in vivarium inputs itself and forwarded. """ + from typing import List, Optional, Union import pandas as pd diff --git a/src/vivarium_sodium_reduction/results_processing/process_results.py b/src/vivarium_sodium_reduction/results_processing/process_results.py index a4a0461..def61d8 100644 --- a/src/vivarium_sodium_reduction/results_processing/process_results.py +++ b/src/vivarium_sodium_reduction/results_processing/process_results.py @@ -7,34 +7,33 @@ from vivarium_sodium_reduction.constants import results, scenarios -SCENARIO_COLUMN = 'scenario' -GROUPBY_COLUMNS = [ - results.INPUT_DRAW_COLUMN, - SCENARIO_COLUMN -] +SCENARIO_COLUMN = "scenario" +GROUPBY_COLUMNS = [results.INPUT_DRAW_COLUMN, SCENARIO_COLUMN] OUTPUT_COLUMN_SORT_ORDER = [ - 'age_group', - 'sex', - 'year', - 'risk', - 'cause', - 'measure', - 'input_draw' + "age_group", + "sex", + "year", + "risk", + "cause", + "measure", + "input_draw", ] RENAME_COLUMNS = { - 'age_group': 'age', - 'cause_of_death': 'cause', + "age_group": "age", + "cause_of_death": "cause", } def make_measure_data(data): measure_data = MeasureData( population=get_population_data(data), - ylls=get_by_cause_measure_data(data, 'ylls'), - ylds=get_by_cause_measure_data(data, 'ylds'), - deaths=get_by_cause_measure_data(data, 'deaths'), - state_person_time=get_state_person_time_measure_data(data, 'disease_state_person_time'), - transition_count=get_transition_count_measure_data(data, 'disease_transition_count'), + ylls=get_by_cause_measure_data(data, "ylls"), + ylds=get_by_cause_measure_data(data, "ylds"), + deaths=get_by_cause_measure_data(data, "deaths"), + state_person_time=get_state_person_time_measure_data( + data, "disease_state_person_time" + ), + transition_count=get_transition_count_measure_data(data, "disease_transition_count"), ) return measure_data @@ -49,15 +48,14 @@ class MeasureData(NamedTuple): def dump(self, output_dir: Path): for key, df in self._asdict().items(): - df.to_csv(output_dir / f'{key}.csv') + df.to_csv(output_dir / f"{key}.csv") def read_data(path: Path, single_run: bool) -> (pd.DataFrame, List[str]): data = pd.read_hdf(path) # noinspection PyUnresolvedReferences data = ( - data - .drop(columns=data.columns.intersection(results.THROWAWAY_COLUMNS)) + data.drop(columns=data.columns.intersection(results.THROWAWAY_COLUMNS)) .reset_index(drop=True) .rename( columns={ @@ -74,17 +72,19 @@ def read_data(path: Path, single_run: bool) -> (pd.DataFrame, List[str]): keyspace = { results.INPUT_DRAW_COLUMN: [0], results.RANDOM_SEED_COLUMN: [0], - results.OUTPUT_SCENARIO_COLUMN: [scenarios.INTERVENTION_SCENARIOS.BASELINE.name] + results.OUTPUT_SCENARIO_COLUMN: [scenarios.INTERVENTION_SCENARIOS.BASELINE.name], } else: data[results.INPUT_DRAW_COLUMN] = data[results.INPUT_DRAW_COLUMN].astype(int) data[results.RANDOM_SEED_COLUMN] = data[results.RANDOM_SEED_COLUMN].astype(int) - with (path.parent / 'keyspace.yaml').open() as f: + with (path.parent / "keyspace.yaml").open() as f: keyspace = yaml.full_load(f) return data, keyspace -def filter_out_incomplete(data: pd.DataFrame, keyspace: Dict[str, Union[str, int]]) -> pd.DataFrame: +def filter_out_incomplete( + data: pd.DataFrame, keyspace: Dict[str, Union[str, 
     output = []
     for draw in keyspace[results.INPUT_DRAW_COLUMN]:
         # For each draw, gather all random seeds completed for all scenarios.
@@ -108,33 +108,35 @@ def aggregate_over_seed(data: pd.DataFrame) -> pd.DataFrame:
     # non_count_data = data[non_count_columns + GROUPBY_COLUMNS].groupby(GROUPBY_COLUMNS).mean()
     count_data = data[count_columns + GROUPBY_COLUMNS].groupby(GROUPBY_COLUMNS).sum()
-    return pd.concat([
-        count_data,
-        # non_count_data
-    ], axis=1).reset_index()
+    return pd.concat(
+        [
+            count_data,
+            # non_count_data
+        ],
+        axis=1,
+    ).reset_index()
 
 
 def pivot_data(data: pd.DataFrame) -> pd.DataFrame:
     return (
-        data
-        .set_index(GROUPBY_COLUMNS)
+        data.set_index(GROUPBY_COLUMNS)
         .stack()
         .reset_index()
-        .rename(columns={f'level_{len(GROUPBY_COLUMNS)}': 'key', 0: 'value'})
+        .rename(columns={f"level_{len(GROUPBY_COLUMNS)}": "key", 0: "value"})
     )
 
 
 def sort_data(data: pd.DataFrame) -> pd.DataFrame:
     sort_order = [c for c in OUTPUT_COLUMN_SORT_ORDER if c in data.columns]
-    other_cols = [c for c in data.columns if c not in sort_order and c != 'value']
-    data = data[sort_order + other_cols + ['value']].sort_values(sort_order)
+    other_cols = [c for c in data.columns if c not in sort_order and c != "value"]
+    data = data[sort_order + other_cols + ["value"]].sort_values(sort_order)
     return data.reset_index(drop=True)
 
 
 def apply_results_map(data: pd.DataFrame, kind: str) -> pd.DataFrame:
     logger.info(f"Mapping {kind} data to stratifications.")
     map_df = results.RESULTS_MAP(kind)
-    data = data.set_index('key')
+    data = data.set_index("key")
     data = data.join(map_df).reset_index(drop=True)
     data = data.rename(columns=RENAME_COLUMNS)
     logger.info(f"Mapping {kind} complete.")
@@ -142,10 +144,14 @@ def apply_results_map(data: pd.DataFrame, kind: str) -> pd.DataFrame:
 
 
 def get_population_data(data: pd.DataFrame) -> pd.DataFrame:
-    total_pop = pivot_data(data[[results.TOTAL_POPULATION_COLUMN]
-                                + results.RESULT_COLUMNS('population')
-                                + GROUPBY_COLUMNS])
-    total_pop = total_pop.rename(columns={'key': 'measure'})
+    total_pop = pivot_data(
+        data[
+            [results.TOTAL_POPULATION_COLUMN]
+            + results.RESULT_COLUMNS("population")
+            + GROUPBY_COLUMNS
+        ]
+    )
+    total_pop = total_pop.rename(columns={"key": "measure"})
     return sort_data(total_pop)
@@ -167,6 +173,6 @@ def get_state_person_time_measure_data(data: pd.DataFrame, measure: str) -> pd.D
 
 def get_transition_count_measure_data(data: pd.DataFrame, measure: str) -> pd.DataFrame:
     # Oops, edge case.
-    data = data.drop(columns=[c for c in data.columns if 'event_count' in c and '2041' in c])
+    data = data.drop(columns=[c for c in data.columns if "event_count" in c and "2041" in c])
     data = get_measure_data(data, measure)
     return sort_data(data)
diff --git a/src/vivarium_sodium_reduction/tools/app_logging.py b/src/vivarium_sodium_reduction/tools/app_logging.py
index f6ea409..b85b1f5 100644
--- a/src/vivarium_sodium_reduction/tools/app_logging.py
+++ b/src/vivarium_sodium_reduction/tools/app_logging.py
@@ -4,7 +4,9 @@
 from loguru import logger
 
 
-def add_logging_sink(sink: TextIO, verbose: int, colorize: bool = False, serialize: bool = False):
+def add_logging_sink(
+    sink: TextIO, verbose: int, colorize: bool = False, serialize: bool = False
+):
     """Adds a logging sink to the global process logger.
 
     Parameters
@@ -20,14 +22,26 @@ def add_logging_sink(sink: TextIO, verbose: int, colorize: bool = False, seriali
         to the logging sink.
 
""" - message_format = ('{time:YYYY-MM-DD HH:mm:ss.SSS} | {elapsed} | ' - '{function}:{line} - {message}') + message_format = ( + "{time:YYYY-MM-DD HH:mm:ss.SSS} | {elapsed} | " + "{function}:{line} - {message}" + ) if verbose == 0: - logger.add(sink, colorize=colorize, level="WARNING", format=message_format, serialize=serialize) + logger.add( + sink, + colorize=colorize, + level="WARNING", + format=message_format, + serialize=serialize, + ) elif verbose == 1: - logger.add(sink, colorize=colorize, level="INFO", format=message_format, serialize=serialize) + logger.add( + sink, colorize=colorize, level="INFO", format=message_format, serialize=serialize + ) elif verbose >= 2: - logger.add(sink, colorize=colorize, level="DEBUG", format=message_format, serialize=serialize) + logger.add( + sink, colorize=colorize, level="DEBUG", format=message_format, serialize=serialize + ) def configure_logging_to_terminal(verbose: int): @@ -44,15 +58,17 @@ def configure_logging_to_terminal(verbose: int): def decode_status(drmaa, job_status): - decoder_map = {drmaa.JobState.UNDETERMINED: 'undetermined', - drmaa.JobState.QUEUED_ACTIVE: 'queued_active', - drmaa.JobState.SYSTEM_ON_HOLD: 'system_hold', - drmaa.JobState.USER_ON_HOLD: 'user_hold', - drmaa.JobState.USER_SYSTEM_ON_HOLD: 'user_system_hold', - drmaa.JobState.RUNNING: 'running', - drmaa.JobState.SYSTEM_SUSPENDED: 'system_suspended', - drmaa.JobState.USER_SUSPENDED: 'user_suspended', - drmaa.JobState.DONE: 'finished', - drmaa.JobState.FAILED: 'failed'} + decoder_map = { + drmaa.JobState.UNDETERMINED: "undetermined", + drmaa.JobState.QUEUED_ACTIVE: "queued_active", + drmaa.JobState.SYSTEM_ON_HOLD: "system_hold", + drmaa.JobState.USER_ON_HOLD: "user_hold", + drmaa.JobState.USER_SYSTEM_ON_HOLD: "user_system_hold", + drmaa.JobState.RUNNING: "running", + drmaa.JobState.SYSTEM_SUSPENDED: "system_suspended", + drmaa.JobState.USER_SUSPENDED: "user_suspended", + drmaa.JobState.DONE: "finished", + drmaa.JobState.FAILED: "failed", + } return decoder_map[job_status] diff --git a/src/vivarium_sodium_reduction/tools/cli.py b/src/vivarium_sodium_reduction/tools/cli.py index af0ddd4..eee36f0 100644 --- a/src/vivarium_sodium_reduction/tools/cli.py +++ b/src/vivarium_sodium_reduction/tools/cli.py @@ -5,9 +5,11 @@ from vivarium.framework.utilities import handle_exceptions from vivarium_sodium_reduction.constants import metadata, paths -from vivarium_sodium_reduction.tools import (build_artifacts, - build_results, - configure_logging_to_terminal) +from vivarium_sodium_reduction.tools import ( + build_artifacts, + build_results, + configure_logging_to_terminal, +) @click.command() diff --git a/src/vivarium_sodium_reduction/tools/make_artifacts.py b/src/vivarium_sodium_reduction/tools/make_artifacts.py index 3ca65c2..0ab1a75 100644 --- a/src/vivarium_sodium_reduction/tools/make_artifacts.py +++ b/src/vivarium_sodium_reduction/tools/make_artifacts.py @@ -6,6 +6,7 @@ Use your best judgement. 
""" + import shutil import sys import time diff --git a/src/vivarium_sodium_reduction/tools/make_results.py b/src/vivarium_sodium_reduction/tools/make_results.py index 92f6f18..9af37b2 100644 --- a/src/vivarium_sodium_reduction/tools/make_results.py +++ b/src/vivarium_sodium_reduction/tools/make_results.py @@ -8,21 +8,23 @@ def build_results(output_file: str, single_run: bool) -> None: output_file = Path(output_file) - measure_dir = output_file.parent / 'count_data' + measure_dir = output_file.parent / "count_data" if measure_dir.exists(): shutil.rmtree(measure_dir) measure_dir.mkdir(exist_ok=True, mode=0o775) - logger.info(f'Reading in output data from {str(output_file)}.') + logger.info(f"Reading in output data from {str(output_file)}.") data, keyspace = process_results.read_data(output_file, single_run) - logger.info(f'Filtering incomplete data from outputs.') + logger.info(f"Filtering incomplete data from outputs.") rows = len(data) data = process_results.filter_out_incomplete(data, keyspace) new_rows = len(data) - logger.info(f'Filtered {rows - new_rows} from data due to incomplete information. {new_rows} remaining.') + logger.info( + f"Filtered {rows - new_rows} from data due to incomplete information. {new_rows} remaining." + ) data = process_results.aggregate_over_seed(data) - logger.info(f'Computing raw count and proportion data.') + logger.info(f"Computing raw count and proportion data.") measure_data = process_results.make_measure_data(data) - logger.info(f'Writing raw count and proportion data to {str(measure_dir)}') + logger.info(f"Writing raw count and proportion data to {str(measure_dir)}") measure_data.dump(measure_dir) - logger.info('**DONE**') + logger.info("**DONE**") diff --git a/src/vivarium_sodium_reduction/utilities.py b/src/vivarium_sodium_reduction/utilities.py index ff2d1bf..d45b959 100644 --- a/src/vivarium_sodium_reduction/utilities.py +++ b/src/vivarium_sodium_reduction/utilities.py @@ -50,14 +50,16 @@ def delete_if_exists(*paths: Union[Path, List[Path]], confirm=False): # Assumes all paths have the same root dir root = existing_paths[0].parent names = [p.name for p in existing_paths] - click.confirm(f"Existing files {names} found in directory {root}. Do you want to delete and replace?", - abort=True) + click.confirm( + f"Existing files {names} found in directory {root}. Do you want to delete and replace?", + abort=True, + ) for p in existing_paths: - logger.info(f'Deleting artifact at {str(p)}.') + logger.info(f"Deleting artifact at {str(p)}.") p.unlink() -def read_data_by_draw(artifact_path: str, key : str, draw: int) -> pd.DataFrame: +def read_data_by_draw(artifact_path: str, key: str, draw: int) -> pd.DataFrame: """Reads data from the artifact on a per-draw basis. This is necessary for Low Birthweight Short Gestation (LBWSG) data. 
@@ -72,32 +74,34 @@ def read_data_by_draw(artifact_path: str, key : str, draw: int) -> pd.DataFrame:
 
     """
     key = key.replace(".", "/")
-    with pd.HDFStore(artifact_path, mode='r') as store:
-        index = store.get(f'{key}/index')
-        draw = store.get(f'{key}/draw_{draw}')
+    with pd.HDFStore(artifact_path, mode="r") as store:
+        index = store.get(f"{key}/index")
+        draw = store.get(f"{key}/draw_{draw}")
     draw = draw.rename("value")
     data = pd.concat([index, draw], axis=1)
-    data = data.drop(columns='location')
+    data = data.drop(columns="location")
     data = pivot_categorical(data)
-    data[project_globals.LBWSG_MISSING_CATEGORY.CAT] = project_globals.LBWSG_MISSING_CATEGORY.EXPOSURE
+    data[project_globals.LBWSG_MISSING_CATEGORY.CAT] = (
+        project_globals.LBWSG_MISSING_CATEGORY.EXPOSURE
+    )
     return data
 
 
 def get_norm(
-        mean: float,
-        sd: float = None,
-        ninety_five_pct_confidence_interval: Tuple[float, float] = None
+    mean: float,
+    sd: float = None,
+    ninety_five_pct_confidence_interval: Tuple[float, float] = None,
 ) -> stats.norm:
     sd = _get_standard_deviation(mean, sd, ninety_five_pct_confidence_interval)
     return stats.norm(loc=mean, scale=sd)
 
 
 def get_truncnorm(
-        mean: float,
-        sd: float = None,
-        ninety_five_pct_confidence_interval: Tuple[float, float] = None,
-        lower_clip: float = 0.0,
-        upper_clip: float = 1.0
+    mean: float,
+    sd: float = None,
+    ninety_five_pct_confidence_interval: Tuple[float, float] = None,
+    lower_clip: float = 0.0,
+    upper_clip: float = 1.0,
 ) -> stats.norm:
     sd = _get_standard_deviation(mean, sd, ninety_five_pct_confidence_interval)
     a = (lower_clip - mean) / sd if sd else mean - 1e-03
@@ -106,12 +110,16 @@ def get_truncnorm(
 
 def _get_standard_deviation(
-        mean: float, sd: float, ninety_five_pct_confidence_interval: Tuple[float, float]
+    mean: float, sd: float, ninety_five_pct_confidence_interval: Tuple[float, float]
 ) -> float:
     if sd is None and ninety_five_pct_confidence_interval is None:
-        raise ValueError("Must provide either a standard deviation or a 95% confidence interval.")
+        raise ValueError(
+            "Must provide either a standard deviation or a 95% confidence interval."
+        )
     if sd is not None and ninety_five_pct_confidence_interval is not None:
-        raise ValueError("Cannot provide both a standard deviation and a 95% confidence interval.")
+        raise ValueError(
+            "Cannot provide both a standard deviation and a 95% confidence interval."
+        )
     if ninety_five_pct_confidence_interval is not None:
         lower = ninety_five_pct_confidence_interval[0]
         upper = ninety_five_pct_confidence_interval[1]
@@ -126,8 +134,9 @@ def _get_standard_deviation(
     return sd
 
 
-def get_lognorm_from_quantiles(median: float, lower: float, upper: float,
-                               quantiles: Tuple[float, float] = (0.025, 0.975)) -> stats.lognorm:
+def get_lognorm_from_quantiles(
+    median: float, lower: float, upper: float, quantiles: Tuple[float, float] = (0.025, 0.975)
+) -> stats.lognorm:
     """Returns a frozen lognormal distribution with the specified median, such that
     (lower, upper) are approximately equal to the quantiles with ranks
     (quantile_ranks[0], quantile_ranks[1]).
@@ -149,17 +158,21 @@ def get_lognorm_from_quantiles(median: float, lower: float, upper: float,
     norm_quantiles = np.log([lower, upper])
     # standard deviation of Y = log(X) computed from the above quantiles for Y
     # and the corresponding standard normal quantiles
-    sigma = (norm_quantiles[1] - norm_quantiles[0]) / (stdnorm_quantiles[1] - stdnorm_quantiles[0])
+    sigma = (norm_quantiles[1] - norm_quantiles[0]) / (
+        stdnorm_quantiles[1] - stdnorm_quantiles[0]
+    )
     # Frozen lognormal distribution for X = exp(Y)
     # (s=sigma is the shape parameter; the scale parameter is exp(mu), which equals the median)
     return stats.lognorm(s=sigma, scale=median)
 
 
-def get_random_variable_draws(number: int, seeded_distribution: SeededDistribution) -> np.array:
+def get_random_variable_draws(
+    number: int, seeded_distribution: SeededDistribution
+) -> np.array:
     return np.array([get_random_variable(x, seeded_distribution) for x in range(number)])
 
 
 def get_random_variable(draw: int, seeded_distribution: SeededDistribution) -> float:
     seed, distribution = seeded_distribution
-    np.random.seed(get_hash(f'{seed}_draw_{draw}'))
+    np.random.seed(get_hash(f"{seed}_draw_{draw}"))
     return distribution.rvs()
diff --git a/update_readme.py b/update_readme.py
index d1d72a8..526e23d 100644
--- a/update_readme.py
+++ b/update_readme.py
@@ -1,6 +1,7 @@
 """
 This script updates the README.rst file with the latest information about the project.
 It is intended to be run from the github "update README" workflow.
 """
+
 import json
 import re
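A quick sanity check on the lognormal helper reformatted above (a minimal sketch; the import path comes from the diff, but the parameter values are invented for illustration):

    from vivarium_sodium_reduction.utilities import get_lognorm_from_quantiles

    # The fitted distribution reproduces the median exactly; the requested
    # quantiles are only hit approximately when (lower, upper) is not
    # symmetric around the median on the log scale.
    dist = get_lognorm_from_quantiles(median=2.0, lower=1.2, upper=3.5)
    assert abs(dist.median() - 2.0) < 1e-9
    print(dist.ppf([0.025, 0.975]))  # roughly [1.17, 3.42]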