From ac290b17e24e7178ab459f29879920533b61371f Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Fri, 20 Dec 2024 13:58:55 +0000 Subject: [PATCH] dispatch approach + fix unit tests --- .../src/constructors/remake_latent_model.jl | 92 ++++++++++++------- .../test/constructors/remake_latent_model.jl | 20 ++-- 2 files changed, 68 insertions(+), 44 deletions(-) diff --git a/pipeline/src/constructors/remake_latent_model.jl b/pipeline/src/constructors/remake_latent_model.jl index acc68c3bf..53d36d422 100644 --- a/pipeline/src/constructors/remake_latent_model.jl +++ b/pipeline/src/constructors/remake_latent_model.jl @@ -18,13 +18,13 @@ function remake_latent_model( #Baseline choices prior_dict = make_model_priors(pipeline) igp = inference_config["igp"] - latent_model_name = inference_config["latent_namemodels"].first - target_std, target_autocorr = latent_model_name == "ar" ? + default_latent_model = inference_config["latent_namemodels"].second + target_std, target_autocorr = default_latent_model isa AR ? _make_target_std_and_autocorr(igp; stationary = true) : _make_target_std_and_autocorr(igp; stationary = false) return _implement_latent_process( - target_std, target_autocorr, latent_model_name, pipeline) + target_std, target_autocorr, default_latent_model, pipeline) end """ @@ -49,14 +49,34 @@ For Direct Infections process `log(I_t)` in a single time step a fluctuation of The autocorrelation is expected to be 0.1. """ -function _make_target_std_and_autocorr(igp; stationary::Bool) - if igp == Renewal - return stationary ? (0.75, 0.1) : (0.025, 0.1) - elseif igp == ExpGrowthRate - return stationary ? (0.1, 0.1) : (0.005, 0.1) - elseif igp == DirectInfections - return stationary ? (1.0, 0.5) : (0.025, 0.1) - end +function _make_target_std_and_autocorr(::Type{Renewal}; stationary::Bool) + return stationary ? (0.75, 0.1) : (0.025, 0.1) +end + +function _make_target_std_and_autocorr(::Type{ExpGrowthRate}; stationary::Bool) + return stationary ? (0.1, 0.1) : (0.005, 0.1) +end + +function _make_target_std_and_autocorr(::Type{DirectInfections}; stationary::Bool) + return stationary ? (1.0, 0.5) : (0.025, 0.1) +end + +function _make_new_prior_dict(target_std, target_autocorr, + pipeline::AbstractRtwithoutRenewalPipeline; beta_eff_sample_size) + #Get default priors + prior_dict = make_model_priors(pipeline) + #Adjust priors based on target autocorrelation and standard deviation + damp_prior = Beta(target_autocorr * beta_eff_sample_size, + (1 - target_autocorr) * beta_eff_sample_size) + corr_corrected_noise_prior = HalfNormal(target_std * sqrt(1 - target_autocorr^2)) + noise_prior = HalfNormal(target_std) + init_prior = prior_dict["transformed_process_init_prior"] + return Dict( + "transformed_process_init_prior" => init_prior, + "corr_corrected_noise_prior" => corr_corrected_noise_prior, + "noise_prior" => noise_prior, + "damp_param_prior" => damp_prior + ) end """ @@ -73,31 +93,33 @@ However, for priors this should get the expected order of magnitude. """ function _implement_latent_process( - target_std, target_autocorr, latent_model_name, pipeline; beta_eff_sample_size = 10) + target_std, target_autocorr, default_latent_model, pipeline; beta_eff_sample_size = 10) prior_dict = make_model_priors(pipeline) - damp_prior = Beta(target_autocorr * beta_eff_sample_size, - (1 - target_autocorr) * beta_eff_sample_size) - corr_corrected_noise_std = HalfNormal(target_std * sqrt(1 - target_autocorr^2)) - noise_std = HalfNormal(target_std) - init_prior = prior_dict["transformed_process_init_prior"] - if latent_model_name == "diff_ar" - _ar = AR(damp_priors = [damp_prior], - std_prior = corr_corrected_noise_std, - init_priors = [init_prior]) - diff_ar = DiffLatentModel(; - model = _ar, init_priors = [init_prior]) - return diff_ar - elseif latent_model_name == "ar" - ar = AR(damp_priors = [damp_prior], - std_prior = corr_corrected_noise_std, - init_priors = [init_prior]) - return ar - elseif latent_model_name == "rw" - rw = RandomWalk( - std_prior = noise_std, - init_prior = init_prior) - return rw - end + new_priors = _make_new_prior_dict( + target_std, target_autocorr, pipeline; beta_eff_sample_size) + + return _make_latent(default_latent_model, new_priors) +end + +function _make_latent(::AR, new_priors) + damp_prior = new_priors["damp_param_prior"] + corr_corrected_noise_std = new_priors["corr_corrected_noise_prior"] + init_prior = new_priors["transformed_process_init_prior"] + return AR(damp_priors = [damp_prior], + std_prior = corr_corrected_noise_std, + init_priors = [init_prior]) +end + +function _make_latent(::DiffLatentModel, new_priors) + init_prior = new_priors["transformed_process_init_prior"] + ar = _make_latent(AR(), new_priors) + return DiffLatentModel(; model = ar, init_priors = [init_prior]) +end + +function _make_latent(::RandomWalk, new_priors) + noise_std = new_priors["noise_prior"] + init_prior = new_priors["transformed_process_init_prior"] + return RandomWalk(std_prior = noise_std, init_prior = init_prior) end """ diff --git a/pipeline/test/constructors/remake_latent_model.jl b/pipeline/test/constructors/remake_latent_model.jl index 305eb9543..9f6d9f5f3 100644 --- a/pipeline/test/constructors/remake_latent_model.jl +++ b/pipeline/test/constructors/remake_latent_model.jl @@ -7,49 +7,51 @@ ) end pipeline = MockPipeline() - + ar = AR() + diff_ar = DiffLatentModel(model = ar) + rw = RandomWalk() @testset "diff_ar model" begin inference_config = Dict( - "igp" => ExpGrowthRate, "latent_namemodels" => ("diff_ar" => "diff_ar")) + "igp" => ExpGrowthRate, "latent_namemodels" => Pair("diff_ar", diff_ar)) model = remake_latent_model(inference_config, pipeline) @test model isa DiffLatentModel @test model.model isa AR inference_config = Dict( - "igp" => DirectInfections, "latent_namemodels" => ("diff_ar" => "diff_ar")) + "igp" => DirectInfections, "latent_namemodels" => Pair("diff_ar", diff_ar)) model = remake_latent_model(inference_config, pipeline) @test model isa DiffLatentModel @test model.model isa AR end @testset "ar model" begin - inference_config = Dict("igp" => Renewal, "latent_namemodels" => Pair("ar", "ar")) + inference_config = Dict("igp" => Renewal, "latent_namemodels" => Pair("ar", ar)) model = remake_latent_model(inference_config, pipeline) @test model isa AR inference_config = Dict( - "igp" => ExpGrowthRate, "latent_namemodels" => Pair("ar", "ar")) + "igp" => ExpGrowthRate, "latent_namemodels" => Pair("ar", ar)) model = remake_latent_model(inference_config, pipeline) @test model isa AR inference_config = Dict( - "igp" => DirectInfections, "latent_namemodels" => Pair("ar", "ar")) + "igp" => DirectInfections, "latent_namemodels" => Pair("ar", ar)) model = remake_latent_model(inference_config, pipeline) @test model isa AR end @testset "rw model" begin - inference_config = Dict("igp" => Renewal, "latent_namemodels" => Pair("rw", "rw")) + inference_config = Dict("igp" => Renewal, "latent_namemodels" => Pair("rw", rw)) model = remake_latent_model(inference_config, pipeline) @test model isa RandomWalk inference_config = Dict( - "igp" => ExpGrowthRate, "latent_namemodels" => Pair("rw", "rw")) + "igp" => ExpGrowthRate, "latent_namemodels" => Pair("rw", rw)) model = remake_latent_model(inference_config, pipeline) @test model isa RandomWalk inference_config = Dict( - "igp" => DirectInfections, "latent_namemodels" => Pair("rw", "rw")) + "igp" => DirectInfections, "latent_namemodels" => Pair("rw", rw)) model = remake_latent_model(inference_config, pipeline) @test model isa RandomWalk end