Skip to content

Commit

Permalink
dispatch approach + fix unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
SamuelBrand1 committed Dec 20, 2024
1 parent 1069aaf commit ac290b1
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 44 deletions.
92 changes: 57 additions & 35 deletions pipeline/src/constructors/remake_latent_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ function remake_latent_model(
#Baseline choices
prior_dict = make_model_priors(pipeline)
igp = inference_config["igp"]
latent_model_name = inference_config["latent_namemodels"].first
target_std, target_autocorr = latent_model_name == "ar" ?
default_latent_model = inference_config["latent_namemodels"].second
target_std, target_autocorr = default_latent_model isa AR ?
_make_target_std_and_autocorr(igp; stationary = true) :
_make_target_std_and_autocorr(igp; stationary = false)

return _implement_latent_process(
target_std, target_autocorr, latent_model_name, pipeline)
target_std, target_autocorr, default_latent_model, pipeline)
end

"""
Expand All @@ -49,14 +49,34 @@ For Direct Infections process `log(I_t)` in a single time step a fluctuation of
The autocorrelation is expected to be 0.1.
"""
function _make_target_std_and_autocorr(igp; stationary::Bool)
if igp == Renewal
return stationary ? (0.75, 0.1) : (0.025, 0.1)
elseif igp == ExpGrowthRate
return stationary ? (0.1, 0.1) : (0.005, 0.1)
elseif igp == DirectInfections
return stationary ? (1.0, 0.5) : (0.025, 0.1)
end
function _make_target_std_and_autocorr(::Type{Renewal}; stationary::Bool)
return stationary ? (0.75, 0.1) : (0.025, 0.1)
end

function _make_target_std_and_autocorr(::Type{ExpGrowthRate}; stationary::Bool)
return stationary ? (0.1, 0.1) : (0.005, 0.1)
end

function _make_target_std_and_autocorr(::Type{DirectInfections}; stationary::Bool)
return stationary ? (1.0, 0.5) : (0.025, 0.1)
end

function _make_new_prior_dict(target_std, target_autocorr,
pipeline::AbstractRtwithoutRenewalPipeline; beta_eff_sample_size)
#Get default priors
prior_dict = make_model_priors(pipeline)
#Adjust priors based on target autocorrelation and standard deviation
damp_prior = Beta(target_autocorr * beta_eff_sample_size,
(1 - target_autocorr) * beta_eff_sample_size)
corr_corrected_noise_prior = HalfNormal(target_std * sqrt(1 - target_autocorr^2))
noise_prior = HalfNormal(target_std)
init_prior = prior_dict["transformed_process_init_prior"]
return Dict(
"transformed_process_init_prior" => init_prior,
"corr_corrected_noise_prior" => corr_corrected_noise_prior,
"noise_prior" => noise_prior,
"damp_param_prior" => damp_prior
)
end

"""
Expand All @@ -73,31 +93,33 @@ However, for priors this should get the expected order of magnitude.
"""
function _implement_latent_process(
target_std, target_autocorr, latent_model_name, pipeline; beta_eff_sample_size = 10)
target_std, target_autocorr, default_latent_model, pipeline; beta_eff_sample_size = 10)
prior_dict = make_model_priors(pipeline)
damp_prior = Beta(target_autocorr * beta_eff_sample_size,
(1 - target_autocorr) * beta_eff_sample_size)
corr_corrected_noise_std = HalfNormal(target_std * sqrt(1 - target_autocorr^2))
noise_std = HalfNormal(target_std)
init_prior = prior_dict["transformed_process_init_prior"]
if latent_model_name == "diff_ar"
_ar = AR(damp_priors = [damp_prior],
std_prior = corr_corrected_noise_std,
init_priors = [init_prior])
diff_ar = DiffLatentModel(;
model = _ar, init_priors = [init_prior])
return diff_ar
elseif latent_model_name == "ar"
ar = AR(damp_priors = [damp_prior],
std_prior = corr_corrected_noise_std,
init_priors = [init_prior])
return ar
elseif latent_model_name == "rw"
rw = RandomWalk(
std_prior = noise_std,
init_prior = init_prior)
return rw
end
new_priors = _make_new_prior_dict(
target_std, target_autocorr, pipeline; beta_eff_sample_size)

return _make_latent(default_latent_model, new_priors)
end

function _make_latent(::AR, new_priors)
damp_prior = new_priors["damp_param_prior"]
corr_corrected_noise_std = new_priors["corr_corrected_noise_prior"]
init_prior = new_priors["transformed_process_init_prior"]
return AR(damp_priors = [damp_prior],
std_prior = corr_corrected_noise_std,
init_priors = [init_prior])
end

function _make_latent(::DiffLatentModel, new_priors)
init_prior = new_priors["transformed_process_init_prior"]
ar = _make_latent(AR(), new_priors)
return DiffLatentModel(; model = ar, init_priors = [init_prior])
end

function _make_latent(::RandomWalk, new_priors)
noise_std = new_priors["noise_prior"]
init_prior = new_priors["transformed_process_init_prior"]
return RandomWalk(std_prior = noise_std, init_prior = init_prior)
end

"""
Expand Down
20 changes: 11 additions & 9 deletions pipeline/test/constructors/remake_latent_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,49 +7,51 @@
)
end
pipeline = MockPipeline()

ar = AR()
diff_ar = DiffLatentModel(model = ar)
rw = RandomWalk()
@testset "diff_ar model" begin
inference_config = Dict(
"igp" => ExpGrowthRate, "latent_namemodels" => ("diff_ar" => "diff_ar"))
"igp" => ExpGrowthRate, "latent_namemodels" => Pair("diff_ar", diff_ar))
model = remake_latent_model(inference_config, pipeline)
@test model isa DiffLatentModel
@test model.model isa AR

inference_config = Dict(
"igp" => DirectInfections, "latent_namemodels" => ("diff_ar" => "diff_ar"))
"igp" => DirectInfections, "latent_namemodels" => Pair("diff_ar", diff_ar))
model = remake_latent_model(inference_config, pipeline)
@test model isa DiffLatentModel
@test model.model isa AR
end

@testset "ar model" begin
inference_config = Dict("igp" => Renewal, "latent_namemodels" => Pair("ar", "ar"))
inference_config = Dict("igp" => Renewal, "latent_namemodels" => Pair("ar", ar))
model = remake_latent_model(inference_config, pipeline)
@test model isa AR

inference_config = Dict(
"igp" => ExpGrowthRate, "latent_namemodels" => Pair("ar", "ar"))
"igp" => ExpGrowthRate, "latent_namemodels" => Pair("ar", ar))
model = remake_latent_model(inference_config, pipeline)
@test model isa AR

inference_config = Dict(
"igp" => DirectInfections, "latent_namemodels" => Pair("ar", "ar"))
"igp" => DirectInfections, "latent_namemodels" => Pair("ar", ar))
model = remake_latent_model(inference_config, pipeline)
@test model isa AR
end

@testset "rw model" begin
inference_config = Dict("igp" => Renewal, "latent_namemodels" => Pair("rw", "rw"))
inference_config = Dict("igp" => Renewal, "latent_namemodels" => Pair("rw", rw))
model = remake_latent_model(inference_config, pipeline)
@test model isa RandomWalk

inference_config = Dict(
"igp" => ExpGrowthRate, "latent_namemodels" => Pair("rw", "rw"))
"igp" => ExpGrowthRate, "latent_namemodels" => Pair("rw", rw))
model = remake_latent_model(inference_config, pipeline)
@test model isa RandomWalk

inference_config = Dict(
"igp" => DirectInfections, "latent_namemodels" => Pair("rw", "rw"))
"igp" => DirectInfections, "latent_namemodels" => Pair("rw", rw))
model = remake_latent_model(inference_config, pipeline)
@test model isa RandomWalk
end
Expand Down

0 comments on commit ac290b1

Please sign in to comment.