Skip to content

Commit

Permalink
Allow Computing Standard Errors Across Runs in `compute_summary_statistics` (#1558)
Browse files Browse the repository at this point in the history

* implement use_standard_error in compute_summary_statistics

* add control (and warning) in relation to the arguments

* Update src/tlo/analysis/utils.py

* better control of arguments

* better control of arguments

* get number of runs per draw correctly

* add test
  • Loading branch information
tbhallett authored Jan 8, 2025
1 parent b24161b commit b94f99e
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 6 deletions.
34 changes: 28 additions & 6 deletions src/tlo/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats as st
import squarify

from tlo import Date, Simulation, logging, util
Expand Down Expand Up @@ -360,8 +361,9 @@ def generate_series(dataframe: pd.DataFrame) -> pd.Series:

def compute_summary_statistics(
results: pd.DataFrame,
central_measure: Literal["mean", "median"] = "median",
central_measure: Union[Literal["mean", "median"], None] = None,
width_of_range: float = 0.95,
use_standard_error: bool = False,
only_central: bool = False,
collapse_columns: bool = False,
) -> pd.DataFrame:
Expand All @@ -371,13 +373,23 @@ def compute_summary_statistics(
measure of the median and a 95% interval range.
:param results: The dataframe of results to compute summary statistics of.
:param central_measure: The name of the central measure to use - either 'mean' or 'median'.
:param central_measure: The name of the central measure to use - either 'mean' or 'median' (defaults to 'median')
:param width_of_range: The width of the range to compute the statistics (e.g. 0.95 for the 95% interval).
:param use_standard_error: Whether the range should represent the standard error; otherwise it is just a
description of the variation of runs. If selected, then the central measure is always the mean.
:param collapse_columns: Whether to simplify the columnar index if there is only one run (cannot be done otherwise).
:param only_central: Whether to only report the central value (dropping the range).
:return: A dataframe with computed summary statistics.
"""

if use_standard_error:
if not central_measure == 'mean':
warnings.warn("When using 'standard-error' the central measure in the summary statistics is always the mean.")
central_measure = 'mean'
elif central_measure is None:
# If no argument is provided for 'central_measure' (and not using standard-error), default to using 'median'
central_measure = 'median'

stats = dict()
grouped_results = results.groupby(axis=1, by='draw', sort=False)

Expand All @@ -388,9 +400,19 @@ def compute_summary_statistics(
else:
raise ValueError(f"Unknown stat: {central_measure}")

lower_quantile = (1. - width_of_range) / 2.
stats["lower"] = grouped_results.quantile(lower_quantile)
stats["upper"] = grouped_results.quantile(1 - lower_quantile)
if not use_standard_error:
lower_quantile = (1. - width_of_range) / 2.
stats["lower"] = grouped_results.quantile(lower_quantile)
stats["upper"] = grouped_results.quantile(1 - lower_quantile)
else:
# Use standard error concept whereby we're using the intervals to express a 95% CI on the value of the mean.
# This will make width of uncertainty become narrower with more runs.
std_deviation = grouped_results.std()
num_runs_per_draw = grouped_results.size().T
std_error = std_deviation.div(np.sqrt(num_runs_per_draw))
z_value = st.norm.ppf(1 - (1. - width_of_range) / 2.)
stats["lower"] = stats['central'] - z_value * std_error
stats["upper"] = stats['central'] + z_value * std_error

summary = pd.concat(stats, axis=1)
summary.columns = summary.columns.swaplevel(1, 0)
Expand Down
46 changes: 46 additions & 0 deletions tests/test_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,52 @@ def test_compute_summary_statistics():
summarize(results_one_draw, collapse_columns=True)
)


def test_compute_summary_statistics_use_standard_error():
    """Check computation of standard error statistics."""

    draws = ("DrawA", "DrawB", "DrawC")
    results_multiple_draws = pd.DataFrame(
        columns=pd.MultiIndex.from_tuples(
            [(draw, f"{draw}_Run{run}") for draw in draws for run in (1, 2)],
            names=("draw", "run"),
        ),
        index=["TimePoint0", "TimePoint1", "TimePoint2", "TimePoint3"],
        data=np.array([[0, 21, 1000, 2430, 111, 30],  # <-- randomly chosen numbers
                       [9, 22, 10440, 1960, 2222, 40],
                       [4, 23, 10200, 1989, 3333, 50],
                       [555, 24, 1000, 2022, 4444, 60]
                       ]),
    )

    # Compute summary using standard error
    summary = compute_summary_statistics(results_multiple_draws, use_standard_error=True)

    # Compute expectation for what the standard should be for Draw A:
    # mean across the two runs, with a 95% CI of mean +/- 1.96 standard errors.
    draw_a = results_multiple_draws["DrawA"]
    mean = draw_a.mean(axis=1)
    se = draw_a.std(axis=1) / np.sqrt(2)
    expectation_for_draw_a = pd.concat(
        {
            "lower": mean - 1.96 * se,
            "central": mean,
            "upper": mean + 1.96 * se,
        },
        axis=1,
    )
    expectation_for_draw_a.columns.name = "stat"

    # Check actual computation matches expectation
    pd.testing.assert_frame_equal(expectation_for_draw_a, summary["DrawA"], rtol=1e-3)


def test_control_loggers_from_same_module_independently(seed, tmpdir):
"""Check that detailed/summary loggers in the same module can configured independently."""

Expand Down

0 comments on commit b94f99e

Please sign in to comment.