From 727a827ea557fb1e9a4aef91749be64500f80895 Mon Sep 17 00:00:00 2001 From: Samuel Brand <48288458+SamuelBrand1@users.noreply.github.com> Date: Thu, 19 Dec 2024 12:46:02 +0000 Subject: [PATCH] breakdown mcmc convergence test function Adds more stats and a unit test --- .../make_mcmc_diagnostic_dataframe.jl | 44 +++++++++++++++---- .../make_mcmc_diagnostic_dataframe.jl | 37 ++++++++++++++++ pipeline/test/analysis/test_analysis.jl | 1 + 3 files changed, 74 insertions(+), 8 deletions(-) create mode 100644 pipeline/test/analysis/make_mcmc_diagnostic_dataframe.jl diff --git a/pipeline/src/analysis/make_mcmc_diagnostic_dataframe.jl b/pipeline/src/analysis/make_mcmc_diagnostic_dataframe.jl index 80ef7d27a..b51a5c706 100644 --- a/pipeline/src/analysis/make_mcmc_diagnostic_dataframe.jl +++ b/pipeline/src/analysis/make_mcmc_diagnostic_dataframe.jl @@ -1,3 +1,28 @@ +""" +Collects the statistics of a vector `x` that are relevant for MCMC diagnostics. +""" +function _get_stats(x, threshold; pass_above = true) + if pass_above + return (; x_mean = mean(x), prop_pass = mean(x .>= threshold), + x_min = minimum(x), x_max = maximum(x)) + else + return (; x_mean = mean(x), prop_pass = mean(x .<= threshold), + x_min = minimum(x), x_max = maximum(x)) + end +end + +""" +Collects the convergence statistics over the parameters that are not cluster factor. +""" +function _collect_stats(chn_nt, not_cluster_factor; bulk_ess_threshold, + tail_ess_threshold, rhat_diff_threshold) + ess_bulk = chn_nt.ess_bulk[not_cluster_factor] |> x -> _get_stats(x, bulk_ess_threshold) + ess_tail = chn_nt.ess_tail[not_cluster_factor] |> x -> _get_stats(x, tail_ess_threshold) + rhat_diff = abs.(chn_nt.rhat[not_cluster_factor] .- 1) |> + x -> _get_stats(x, rhat_diff_threshold; pass_above = false) + return (; ess_bulk, ess_tail, rhat_diff) +end + """ Generate a DataFrame containing MCMC diagnostic metrics. The metrics are the proportion of parameters that pass the bulk effective sample size (ESS) threshold, the proportion of @@ -23,6 +48,9 @@ function make_mcmc_diagnostic_dataframe( has_cluster_factor = any(cluster_factor_idxs) not_cluster_factor = .~cluster_factor_idxs cluster_factor_tail = chn_nt.ess_tail[cluster_factor_idxs][1] + #Collect the statistics + stats_for_targets = _collect_stats(chn_nt, not_cluster_factor; bulk_ess_threshold, + tail_ess_threshold, rhat_diff_threshold) #Create the dataframe df = mapreduce(vcat, info.used_gi_means) do used_gi_mean @@ -33,16 +61,16 @@ function make_mcmc_diagnostic_dataframe( True_GI_Mean = true_mean_gi, used_gi_mean = used_gi_mean, reference_time = info.reference_time, - bulk_ess_threshold = (chn_nt.ess_bulk[not_cluster_factor] .> - bulk_ess_threshold) |> - mean, - tail_ess_threshold = (chn_nt.ess_tail[not_cluster_factor] .> - tail_ess_threshold) |> - mean, - rhat_diff_threshold = (abs.(chn_nt.rhat[not_cluster_factor] .- 1) .< - rhat_diff_threshold) |> mean, has_cluster_factor = has_cluster_factor, cluster_factor_tail = has_cluster_factor ? cluster_factor_tail : missing) end + #Add stats columns + for key in keys(stats_for_targets) + stats = getfield(stats_for_targets, key) + df[!, string(key) * "_" * "mean"] .= stats.x_mean + df[!, string(key) * "_" * "prop_pass"] .= stats.prop_pass + df[!, string(key) * "_" * "min"] .= stats.x_min + df[!, string(key) * "_" * "max"] .= stats.x_max + end return df end diff --git a/pipeline/test/analysis/make_mcmc_diagnostic_dataframe.jl b/pipeline/test/analysis/make_mcmc_diagnostic_dataframe.jl new file mode 100644 index 000000000..b2473d918 --- /dev/null +++ b/pipeline/test/analysis/make_mcmc_diagnostic_dataframe.jl @@ -0,0 +1,37 @@ +@testset "test MCMC convergence analysis on toy obs model" begin + using JLD2, DataFramesMeta, Turing, EpiAware + # Reuse the local config + _output = load(joinpath(@__DIR__(), "test_data.jld2")) + inference_config = _output["inference_config"] + # Create a simple test model to test mcmc diagnostics via prior sampling + obs = make_observation_model(SmoothEndemicPipeline()) + @model function test_model() + x ~ filldist(Normal(0, 1), 20) + @submodel prefix="obs" y_t=generate_observations(obs, missing, exp.(x)) + end + n = 1000 + samples = sample(test_model(), Prior(), n) + + # Create a simple output to test the function + output = Dict( + "inference_config" => inference_config, + "inference_results" => (; samples,) + ) + + true_mean_gi = 10.0 + scenario = "rough_endemic" + df = make_mcmc_diagnostic_dataframe( + output, true_mean_gi, "rough_endemic") + # Check pass throughs + @test typeof(df) == DataFrame + @test size(df, 1) == 3 # Number of rows should match the length of used_gi_means + @test df[1, :Scenario] == scenario + @test df[1, :latent_model] == inference_config["latent_model"] + @test df[1, :True_GI_Mean] == true_mean_gi + # Prior sampling should be uncorrelated and meet all the convergence criteria + @test all(df[:, :ess_bulk_prop_pass] .== 1.0) + @test all(df[:, :ess_tail_prop_pass] .== 1.0) + @test all(df[:, :rhat_diff_prop_pass] .== 1.0) + @test all(df[:, :has_cluster_factor] .== true) + @test all(df[1, :cluster_factor_tail] .> n / 2) +end diff --git a/pipeline/test/analysis/test_analysis.jl b/pipeline/test/analysis/test_analysis.jl index 1e2608aea..6a45d2214 100644 --- a/pipeline/test/analysis/test_analysis.jl +++ b/pipeline/test/analysis/test_analysis.jl @@ -1,2 +1,3 @@ include("make_prediction_dataframe_from_output.jl") include("make_truthdata_dataframe.jl") +include("make_mcmc_diagnostic_dataframe.jl")