Skip to content

Commit

Permalink
Issue 559: Diagnostic analysis over all inference runs (#560)
Browse files Browse the repository at this point in the history
* Create make_mcmc_diagnostic_dataframe.jl

* reorg scripts and add more success/fail analysis

* Add function to get run info to avoid repetition (DRY)

* Add function to do diagnostics

* export new func

* update SI

* Issue 561: Soft min transformation (#562)

Also removed unnecessary call to `fetch`

* base values on pipeline types

* breakdown mcmc convergence test function

Adds more stats and a unit test
  • Loading branch information
SamuelBrand1 authored Dec 19, 2024
1 parent 9b467f2 commit 7dde0e9
Show file tree
Hide file tree
Showing 14 changed files with 265 additions and 68 deletions.
1 change: 1 addition & 0 deletions manuscript/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
DataFramesMeta = "1313f7d8-7da2-5740-9ea0-a2ca25f37964"
DrWatson = "634d3b9d-ee7a-5ddf-bec9-22491ea816e1"
Expand Down
43 changes: 24 additions & 19 deletions manuscript/index.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ index_location = @__DIR__()
Pkg.activate(index_location)
Pkg.resolve()
Pkg.instantiate()
Pkg.add(["CairoMakie", "JLD2", "DataFramesMeta", "DrWatson"])
Pkg.add(["CairoMakie", "JLD2", "DataFramesMeta", "DrWatson", "CSV"])
using DataFramesMeta, JLD2
using DataFramesMeta, JLD2, CSV
```

Expand Down Expand Up @@ -66,23 +66,28 @@ We noted that for a substantial number of the model configurations there were mo
priorpred_dir = joinpath(@__DIR__(),"..", "pipeline/data/priorpredictive/")
priorpred_datafiles = readdir(priorpred_dir) |>
fns -> filter(fn -> contains(fn, ".jld2"), fns) #filter for .jld2 files
priorpred_outcomes_df = mapreduce(vcat, priorpred_datafiles) do fn
D = load(joinpath(priorpred_dir, fn))
igp = D["inference_config"]["igp"]
latent_model = D["inference_config"]["latent_model"]
gi_mean = D["inference_config"]["gi_mean"]
T1, T2 = split(D["inference_config"]["tspan"], "_")
runsuccess = D["priorpredictive"] .== "Pass"
df = DataFrame(
infection_gen_proc = igp,
latent_model = latent_model,
gi_mean = gi_mean,
T1 = T1,
T2 = T2,
T_diff = parse(Int, T2) - parse(Int, T1),
runsuccess = runsuccess,
)
# The prior predictive pass/fail table is cached as a CSV next to this document:
# when the cache exists it is read back, otherwise it is rebuilt from the .jld2
# outputs listed in `priorpred_datafiles` and written out for the next render.
priorpred_outcomes_df = let csv_path = joinpath(index_location, "pass_fail_rdn1.csv")
    if isfile(csv_path)
        CSV.File(csv_path) |> DataFrame
    else
        # One dataframe per prior predictive output file, stacked row-wise
        outcomes = mapreduce(vcat, priorpred_datafiles) do fn
            D = load(joinpath(priorpred_dir, fn))
            config = D["inference_config"]
            igp = config["igp"]
            latent_model = config["latent_model"]
            gi_mean = config["gi_mean"]
            # `tspan` is stored as "<T1>_<T2>"
            T1, T2 = split(config["tspan"], "_")
            runsuccess = D["priorpredictive"] .== "Pass"
            DataFrame(
                infection_gen_proc = igp,
                latent_model = latent_model,
                gi_mean = gi_mean,
                T1 = T1,
                T2 = T2,
                T_diff = parse(Int, T2) - parse(Int, T1),
                runsuccess = runsuccess
            )
        end
        CSV.write(csv_path, outcomes)
        outcomes
    end
end
```

Expand Down
20 changes: 20 additions & 0 deletions pipeline/scripts/create_mcmc_diagonostic_script.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
## Analysis of the prediction dataframes for mcmc diagnostics
# For every scenario and true GI mean, load each matching inference output and
# stack the per-run MCMC diagnostic dataframes into one table.
diagnostic_df = mapreduce(vcat, scenarios) do scenario
    mapreduce(vcat, true_gi_means) do true_gi_mean
        target_str = "truth_gi_mean_" * string(true_gi_mean) * "_"
        # Keep only the .jld2 outputs generated with this true GI mean
        files = readdir(datadir("epiaware_observables/" * scenario)) |>
                strs -> filter(s -> occursin("jld2", s), strs) |>
                        strs -> filter(s -> occursin(target_str, s), strs)

        mapreduce(vcat, files) do filename
            output = load(joinpath(datadir("epiaware_observables"), scenario, filename))
            try
                make_mcmc_diagnostic_dataframe(output, true_gi_mean, scenario)
            catch e
                # Report the failure and contribute an empty dataframe. An empty
                # `catch` would evaluate to `nothing`, which `vcat` cannot combine
                # with the dataframes from the other runs.
                @warn "Error in $filename" exception = (e, catch_backtrace())
                DataFrame()
            end
        end
    end
end

## Save the mcmc diagnostics
CSV.write("manuscript/inference_diagnostics_rnd2.csv", diagnostic_df)
26 changes: 24 additions & 2 deletions pipeline/scripts/create_postprocessing_dataframes.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,24 @@
using EpiAwarePipeline, EpiAware, JLD2, DrWatson, DataFramesMeta, CSV
using EpiAwarePipeline, EpiAware, JLD2, DrWatson, DataFramesMeta, CSV, MCMCChains

pipelinetypes = [
MeasuresOutbreakPipeline,
SmoothOutbreakPipeline,
SmoothEndemicPipeline,
RoughEndemicPipeline
]
## Define scenarios
scenarios = ["measures_outbreak", "smooth_outbreak", "smooth_endemic", "rough_endemic"]

scenarios = pipelinetypes .|> pipetype -> pipetype().prefix

## Define true GI means
# Every pipeline type must report the same GI means; otherwise this errors
true_gi_means = let ensemble_gi_means = map(
        pipetype -> make_gi_params(pipetype())["gi_means"], pipelinetypes)
    if !all(gi_means == ensemble_gi_means[1] for gi_means in ensemble_gi_means)
        error("GI means are not the same")
    end
    ensemble_gi_means[1]
end

if !isfile(plotsdir("plotting_data/predictions.csv"))
@info "Prediction dataframe does not exist, generating now"
Expand All @@ -12,3 +29,8 @@ if !isfile(plotsdir("plotting_data/truthdata.csv"))
@info "Truth dataframe does not exist, generating now"
include("create_truth_dataframe.jl")
end

# Guard on the file that create_mcmc_diagonostic_script.jl actually writes.
# The pass/fail CSV is produced by the prediction script, so checking for it
# here would skip the diagnostics whenever predictions had already been run.
if !isfile("manuscript/inference_diagnostics_rnd2.csv")
    @info "Diagnostic dataframe does not exist, generating now"
    include("create_mcmc_diagonostic_script.jl")
end
26 changes: 11 additions & 15 deletions pipeline/scripts/create_prediction_dataframe.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
## Define true GI means
true_gi_means = [2.0, 10.0, 20.0]

## Load the prediction dataframes or record fails
## Structure to record success/failure
success_configs = Dict[]
failed_configs = Dict[]

## Analysis of the prediction dataframes
dfs = mapreduce(vcat, scenarios) do scenario
mapreduce(vcat, true_gi_means) do true_gi_mean
target_str = "truth_gi_mean_" * string(true_gi_mean) * "_"
Expand All @@ -14,39 +13,36 @@ dfs = mapreduce(vcat, scenarios) do scenario
mapreduce(vcat, files) do filename
output = load(joinpath(datadir("epiaware_observables"), scenario, filename))
try
# Build the dataframe before recording success: if this call throws, control
# jumps straight to `catch`, so the run is recorded only as a failure instead of
# appearing in both `success_configs` and `failed_configs`.
prediction_df = make_prediction_dataframe_from_output(output, true_gi_mean, scenario)
push!(success_configs,
    merge(output["inference_config"], Dict("runsuccess" => true)))
prediction_df
catch e
@warn "Error in $filename"
push!(failed_configs, output["inference_config"])
push!(failed_configs,
merge(output["inference_config"], Dict("runsuccess" => false)))
return DataFrame()
end
end
end
end

## Gather the failed data
failed_df = mapreduce(vcat, failed_configs) do D
## Gather the pass/failed data
pass_fail_df = mapreduce(vcat, [success_configs; failed_configs]) do D
igp = D["igp"] |> str -> split(str, ".")[end]
latent_model = D["latent_model"]
gi_mean = D["gi_mean"]
T1, T2 = split(D["tspan"], "_")
runsuccess = D["priorpredictive"] .== "Pass"
df = DataFrame(
infection_gen_proc = igp,
latent_model = latent_model,
gi_mean = gi_mean,
T1 = T1,
T2 = T2,
T_diff = parse(Int, T2) - parse(Int, T1),
runsuccess = runsuccess
runsuccess = D["runsuccess"]
)
end

##
grped_failed_df = failed_df |>
df -> @groupby(df, :infection_gen_proc, :latent_model) |>
gd -> @combine(gd, :n_fail=sum(1 .- :runsuccess))

## Save the prediction and failed dataframes
CSV.write(plotsdir("plotting_data/predictions.csv"), dfs)
CSV.write(plotsdir("plotting_data/failed_preds.csv"), failed_df)
CSV.write("manuscript/inference_pass_fail_rnd2.csv", pass_fail_df)
2 changes: 1 addition & 1 deletion pipeline/src/EpiAwarePipeline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ export score_parameters, simple_crps, summarise_crps

# Exported functions: Analysis functions for constructing dataframes
export make_prediction_dataframe_from_output, make_truthdata_dataframe,
make_scoring_dataframe_from_output
make_scoring_dataframe_from_output, make_mcmc_diagnostic_dataframe

# Exported functions: Make main plots
export figureone, figuretwo
Expand Down
2 changes: 2 additions & 0 deletions pipeline/src/analysis/analysis.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
include("config_mappings.jl")
include("make_truthdata_dataframe.jl")
include("make_prediction_dataframe_from_output.jl")
include("make_scoring_dataframe_from_output.jl")
include("make_mcmc_diagnostic_dataframe.jl")
33 changes: 33 additions & 0 deletions pipeline/src/analysis/config_mappings.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""
Extracts and returns relevant information from the given inference configuration dictionary.
# Returns
- `NamedTuple`: A named tuple containing the following fields:
- `igp_model::String`: The IGP model name extracted from the configuration.
- `latent_model::String`: The latent model name from the configuration.
- `used_gi_mean::Float64`: The mean generation interval (GI) used in the configuration.
- `used_gi_std::Float64`: The standard deviation of the generation interval (GI) used in the configuration.
- `start_time::Int`: The start time parsed from the configuration's time span.
- `reference_time::Int`: The reference time parsed from the configuration's time span.
- `used_gi_means::Vector{Float64}`: A vector of GI means, either a single value if the IGP model is "Renewal" or a list of values generated by `make_gi_params` otherwise.
"""
function _get_info_from_config(inference_config)
    # Keep only the last component of the dotted IGP name, e.g. "A.B.Renewal" -> "Renewal"
    igp_model = last(split(inference_config["igp"], "."))
    latent_model = inference_config["latent_model"]
    used_gi_mean = inference_config["gi_mean"]
    used_gi_std = inference_config["gi_std"]

    # `tspan` is stored as "<start>_<reference>"; both parts are parsed as integers
    tspan_parts = split(inference_config["tspan"], "_")
    start_time = parse(Int, tspan_parts[1])
    reference_time = parse(Int, tspan_parts[2])

    # A Renewal model uses only its own GI mean; any other IGP model is paired
    # with every GI mean provided by `make_gi_params` for the example pipeline.
    if igp_model == "Renewal"
        used_gi_means = [used_gi_mean]
    else
        used_gi_means = make_gi_params(EpiAwareExamplePipeline())["gi_means"]
    end

    return (; igp_model, latent_model, used_gi_mean, used_gi_std,
        start_time, reference_time, used_gi_means)
end
76 changes: 76 additions & 0 deletions pipeline/src/analysis/make_mcmc_diagnostic_dataframe.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""
Collects the statistics of a vector `x` that are relevant for MCMC diagnostics.
"""
function _get_stats(x, threshold; pass_above = true)
    # An element passes when it is on the required side of `threshold` (inclusive):
    # at or above it when `pass_above` is true, at or below it otherwise.
    passes = pass_above ? (x .>= threshold) : (x .<= threshold)
    return (; x_mean = mean(x), prop_pass = mean(passes),
        x_min = minimum(x), x_max = maximum(x))
end

"""
Collects the convergence statistics over the parameters that are not cluster factor.
"""
function _collect_stats(chn_nt, not_cluster_factor; bulk_ess_threshold,
        tail_ess_threshold, rhat_diff_threshold)
    # ESS values pass when they are at or above their threshold
    ess_bulk = _get_stats(chn_nt.ess_bulk[not_cluster_factor], bulk_ess_threshold)
    ess_tail = _get_stats(chn_nt.ess_tail[not_cluster_factor], tail_ess_threshold)
    # R-hat is summarised by its absolute distance from 1, which passes when it is
    # at or below the threshold
    rhat_distance = abs.(chn_nt.rhat[not_cluster_factor] .- 1)
    rhat_diff = _get_stats(rhat_distance, rhat_diff_threshold; pass_above = false)
    return (; ess_bulk, ess_tail, rhat_diff)
end

"""
Generate a DataFrame containing MCMC diagnostic metrics. The metrics are the proportion of
parameters that pass the bulk effective sample size (ESS) threshold, the proportion of
parameters that pass the tail ESS threshold, the proportion of parameters that pass the R-hat
absolute difference from 1 threshold, whether the model has a cluster factor parameter, and the tail ESS
of the cluster factor parameter.
# Arguments
- `output::Dict`: A dictionary containing the inference results.
- `bulk_ess_threshold::Int`: The threshold for bulk effective sample size (ESS). Default is 500.
- `tail_ess_threshold::Int`: The threshold for tail effective sample size (ESS). Default is 100.
- `rhat_diff_threshold::Float64`: The threshold for the difference of R-hat from 1. Default is 0.02.
"""
function make_mcmc_diagnostic_dataframe(
        output, true_mean_gi, scenario; bulk_ess_threshold = 500,
        tail_ess_threshold = 100, rhat_diff_threshold = 0.02)
    # `true_mean_gi` and `scenario` are copied unchanged into the `True_GI_Mean`
    # and `Scenario` columns of the result.

    #Get the scenario, IGP model, latent model and true mean GI
    inference_config = output["inference_config"]
    info = _get_info_from_config(inference_config)

    #Get the convergence diagnostics
    chn_nt = output["inference_results"].samples |> summarize |> summary -> summary.nt
    cluster_factor_idxs = chn_nt.parameters .== Symbol("obs.cluster_factor")
    has_cluster_factor = any(cluster_factor_idxs)
    not_cluster_factor = .!cluster_factor_idxs
    # Only index into the cluster factor entries when there is one: indexing an
    # empty selection throws a BoundsError for models without a cluster factor,
    # so those models report `missing` instead.
    cluster_factor_tail = has_cluster_factor ?
                          first(chn_nt.ess_tail[cluster_factor_idxs]) : missing

    #Collect the statistics over the parameters that are not the cluster factor
    stats_for_targets = _collect_stats(chn_nt, not_cluster_factor; bulk_ess_threshold,
        tail_ess_threshold, rhat_diff_threshold)

    #Create the dataframe: one row per used GI mean
    df = mapreduce(vcat, info.used_gi_means) do used_gi_mean
        DataFrame(
            Scenario = scenario,
            igp_model = info.igp_model,
            latent_model = info.latent_model,
            True_GI_Mean = true_mean_gi,
            used_gi_mean = used_gi_mean,
            reference_time = info.reference_time,
            has_cluster_factor = has_cluster_factor,
            cluster_factor_tail = cluster_factor_tail)
    end

    #Add stats columns, named "<diagnostic>_<statistic>", with the same value on every row
    for (key, stats) in pairs(stats_for_targets)
        df[!, string(key) * "_" * "mean"] .= stats.x_mean
        df[!, string(key) * "_" * "prop_pass"] .= stats.prop_pass
        df[!, string(key) * "_" * "min"] .= stats.x_min
        df[!, string(key) * "_" * "max"] .= stats.x_max
    end
    return df
end
36 changes: 11 additions & 25 deletions pipeline/src/analysis/make_prediction_dataframe_from_output.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,42 +20,28 @@ function make_prediction_dataframe_from_output(
inference_config = output["inference_config"]
forecasts = output["forecast_results"]
#Get the scenario, IGP model, latent model and true mean GI
igp_model = inference_config["igp"] |> igp_name -> split(igp_name, ".")[end]
latent_model = inference_config["latent_model"]
used_gi_mean = inference_config["gi_mean"]
used_gi_std = inference_config["gi_std"]
(start_time, reference_time) = inference_config["tspan"] |>
tspan -> split(tspan, "_") |>
tspan -> (
parse(Int, tspan[1]), parse(Int, tspan[2]))

#Get the quantiles for the targets across the gi mean scenarios
#if Renewal model, then we use the underlying epi model
#otherwise we use the epi datas to loop over different gi mean implications
used_gi_means = igp_model == "Renewal" ?
[used_gi_mean] :
make_gi_params(EpiAwareExamplePipeline())["gi_means"]

used_epidatas = map(used_gi_means) do
_make_epidata(ḡ, used_gi_std; transformation = transformation)
info = _get_info_from_config(inference_config)
#Get the epi datas
used_epidatas = map(info.used_gi_means) do
_make_epidata(ḡ, info.used_gi_std; transformation = transformation)
end

#Generate the quantiles for the targets
preds = map(used_epidatas) do epi_data
generate_quantiles_for_targets(forecasts, epi_data, qs)
end

#Create the dataframe columnwise
df = mapreduce(vcat, preds, used_gi_means) do pred, used_gi_mean
df = mapreduce(vcat, preds, info.used_gi_means) do pred, used_gi_mean
mapreduce(vcat, keys(pred)) do target
target_mat = pred[target]
target_times = collect(1:size(target_mat, 1)) .+ (start_time - 1)
target_times = collect(1:size(target_mat, 1)) .+ (info.start_time - 1)
_df = DataFrame(target_times = target_times)
_df[!, "Scenario"] .= scenario
_df[!, "IGP_Model"] .= igp_model
_df[!, "Latent_Model"] .= latent_model
_df[!, "igp_model"] .= info.igp_model
_df[!, "latent_model"] .= info.latent_model
_df[!, "True_GI_Mean"] .= true_mean_gi
_df[!, "Used_GI_Mean"] .= used_gi_mean
_df[!, "Reference_Time"] .= reference_time
_df[!, "used_gi_mean"] .= used_gi_mean
_df[!, "reference_time"] .= info.reference_time
_df[!, "Target"] .= string(target)
# quantile predictions
for (j, q) in enumerate(qs)
Expand Down
22 changes: 21 additions & 1 deletion pipeline/src/constructors/make_observation_model.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Constructs an observation model for the given pipeline. This is the defualt method.
Constructs an observation model for the given pipeline. This is the default method.
# Arguments
- `pipeline::AbstractEpiAwarePipeline`: The pipeline for which the observation model is constructed.
Expand All @@ -18,3 +18,23 @@ function make_observation_model(pipeline::AbstractEpiAwarePipeline)
obs = LatentDelay(dayofweek_logit_ascert, delay_distribution)
return obs
end

# Negative of the value at which `_softmin` levels off
const negC = -1e15
"""
    _softmin(x)

Soft minimum of `x` and `1e15`, computed as `-logaddexp(negC, -x)`: a smooth
transition from `x -> x` to a maximum value of 1e15.
"""
function _softmin(x)
    return -logaddexp(negC, -x)
end

function make_observation_model(pipeline::AbstractRtwithoutRenewalPipeline)
    # Method for pipelines without a renewal model: the day-of-week ascertainment
    # combines its two inputs by elementwise product passed through `_softmin`.
    cluster_factor_scale = make_default_params(pipeline)["cluster_factor"]
    negbin_error = NegativeBinomialError(
        cluster_factor_prior = HalfNormal(cluster_factor_scale))
    softmin_product = (x, y) -> _softmin.(x .* y)

    #Model for ascertainment based on day of the week
    dayofweek_logit_ascert = ascertainment_dayofweek(
        negbin_error; transform = softmin_product)

    #Default continuous-time model for latent delay in observations
    delay_distribution = make_delay_distribution(pipeline)

    #Model for latent delay in observations
    return LatentDelay(dayofweek_logit_ascert, delay_distribution)
end
Loading

0 comments on commit 7dde0e9

Please sign in to comment.