Skip to content

Commit

Permalink
Issue 405: model specific priors (#565)
Browse files Browse the repository at this point in the history
  • Loading branch information
SamuelBrand1 authored Jan 6, 2025
1 parent 0d4095e commit 94b4643
Show file tree
Hide file tree
Showing 14 changed files with 411 additions and 207 deletions.
1 change: 1 addition & 0 deletions pipeline/plots/priorpredictive/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Prior predictive plots
3 changes: 2 additions & 1 deletion pipeline/src/EpiAwarePipeline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ export TruthSimulationConfig, InferenceConfig
export make_gi_params, make_inf_generating_processes, make_model_priors,
make_epiaware_name_latentmodel_pairs, make_Rt, make_truth_data_configs,
make_default_params, make_inference_configs, make_tspan, make_inference_method,
make_delay_distribution, make_delay_distribution, make_observation_model
make_delay_distribution, make_delay_distribution, make_observation_model,
remake_latent_model

# Exported functions: pipeline components
export do_truthdata, do_inference, do_pipeline
Expand Down
1 change: 1 addition & 0 deletions pipeline/src/constructors/constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ include("make_tspan.jl")
include("make_default_params.jl")
include("make_delay_distribution.jl")
include("make_observation_model.jl")
include("remake_latent_model.jl")
2 changes: 1 addition & 1 deletion pipeline/src/constructors/make_model_priors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ deviation 1e-1.
"""
function make_model_priors(pipeline::AbstractEpiAwarePipeline)
transformed_process_init_prior = Normal(0.0, 0.25)
transformed_process_init_prior = Normal(0.0, 0.1)
std_prior = HalfNormal(0.025)
damp_param_prior = Beta(1, 9)
log_I0_prior = Normal(log(100.0), 1e-1)
Expand Down
130 changes: 130 additions & 0 deletions pipeline/src/constructors/remake_latent_model.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""
Constructs and returns a latent model based on the provided `inference_config` and `pipeline`.
The purpose of this function is to make adjustments to the latent model based on the
full `inference_config` provided.
The `pipeline` argument is used for dispatch purposes.
The prior decisions are based on the target standard deviation and autocorrelation of the latent process,
which are determined by the infection generating process (igp) and whether the latent process is stationary or non-stationary
via the `_make_target_std_and_autocorr` function.
# Returns
- A latent model object which can be one of `DiffLatentModel`, `AR`, or `RandomWalk` depending on the `latent_model_name` and `igp` specified in `inference_config`.
"""
function remake_latent_model(
inference_config::Dict, pipeline::AbstractRtwithoutRenewalPipeline)
#Baseline choices
prior_dict = make_model_priors(pipeline)
igp = inference_config["igp"]
default_latent_model = inference_config["latent_namemodels"].second
target_std, target_autocorr = default_latent_model isa AR ?
_make_target_std_and_autocorr(igp; stationary = true) :
_make_target_std_and_autocorr(igp; stationary = false)

return _implement_latent_process(
target_std, target_autocorr, default_latent_model, pipeline)
end

"""
This function sets the target standard deviation for an infection generating process (igp)
based on whether the latent process representation of its dynamics are stationary or non-stationary.
## Stationary Processes
- For Renewal process `log(R_t)` in the long run a fluctuation of 0.75 (e.g. ~ 75% of the mean) is not unexpected.
- For Exponential Growth Rate process `r_t` in the long run a fluctuation of 0.2 is not unexpected e.g. going from
`rt = 0.1` (7 day doubling time) to `rt = -0.1` (7 day halving time) is a 0.2 time-to-time fluctuation.
- For Direct Infections process `log(I_t)` in the long run a fluctuation of 2.0 (i.e a couple of orders of magnitude) is not unexpected.
For stationary latent processes Direct Infections and rt processes the autocorrelation is expected to be high at 0.9,
because persistence in residual away from mean is expected. Otherwise, the autocorrelation is expected to be 0.1.
## Non-Stationary Processes
For Renewal process `log(R_t)` in a single time step a fluctuation of 0.025 (e.g. ~ 2.5% of the mean) is not unexpected.
For Exponential Growth Rate process `r_t` in a single time step a fluctuation of 0.005 is not unexpected.
For Direct Infections process `log(I_t)` in a single time step a fluctuation of 0.025 is not unexpected.
The autocorrelation is expected to be 0.1.
"""
function _make_target_std_and_autocorr(::Type{Renewal}; stationary::Bool)
return stationary ? (0.75, 0.1) : (0.025, 0.1)
end

function _make_target_std_and_autocorr(::Type{ExpGrowthRate}; stationary::Bool)
return stationary ? (0.2, 0.9) : (0.005, 0.1)
end

function _make_target_std_and_autocorr(::Type{DirectInfections}; stationary::Bool)
return stationary ? (2.0, 0.9) : (0.25, 0.1)
end

function _make_new_prior_dict(target_std, target_autocorr,
pipeline::AbstractRtwithoutRenewalPipeline; beta_eff_sample_size)
#Get default priors
prior_dict = make_model_priors(pipeline)
#Adjust priors based on target autocorrelation and standard deviation
damp_prior = Beta(target_autocorr * beta_eff_sample_size,
(1 - target_autocorr) * beta_eff_sample_size)
corr_corrected_noise_prior = HalfNormal(target_std * sqrt(1 - target_autocorr^2))
noise_prior = HalfNormal(target_std)
init_prior = prior_dict["transformed_process_init_prior"]
return Dict(
"transformed_process_init_prior" => init_prior,
"corr_corrected_noise_prior" => corr_corrected_noise_prior,
"noise_prior" => noise_prior,
"damp_param_prior" => damp_prior
)
end

"""
Constructs and returns a latent model based on an approximation to the specified target standard deviation and autocorrelation.
NB: The stationary variance of an AR(1) process is given by `σ² = σ²_ε / (1 - ρ²)` where `σ²_ε` is the variance of the noise and `ρ` is the autocorrelation.
The approximation here are based on `E[1/(1 - ρ²)`] ≈ 1 / (1 - E[ρ²])` which only holds for fairly tight distributions of `ρ`.
However, for priors this should get the expected order of magnitude.
# Models
- `"diff_ar"`: Constructs a `DiffLatentModel` with an autoregressive (AR) process.
- `"ar"`: Constructs an autoregressive (AR) process.
- `"rw"`: Constructs a random walk (RW) process.
"""
function _implement_latent_process(
target_std, target_autocorr, default_latent_model, pipeline; beta_eff_sample_size = 10)
prior_dict = make_model_priors(pipeline)
new_priors = _make_new_prior_dict(
target_std, target_autocorr, pipeline; beta_eff_sample_size)

return _make_latent(default_latent_model, new_priors)
end

function _make_latent(::AR, new_priors)
damp_prior = new_priors["damp_param_prior"]
corr_corrected_noise_std = new_priors["corr_corrected_noise_prior"]
init_prior = new_priors["transformed_process_init_prior"]
return AR(damp_priors = [damp_prior],
std_prior = corr_corrected_noise_std,
init_priors = [init_prior])
end

function _make_latent(::DiffLatentModel, new_priors)
init_prior = new_priors["transformed_process_init_prior"]
ar = _make_latent(AR(), new_priors)
return DiffLatentModel(; model = ar, init_priors = [init_prior])
end

function _make_latent(::RandomWalk, new_priors)
noise_std = new_priors["noise_prior"]
init_prior = new_priors["transformed_process_init_prior"]
return RandomWalk(std_prior = noise_std, init_prior = init_prior)
end

"""
Pass through fallback dispatch.
"""
function remake_latent_model(inference_config::Dict, pipeline::AbstractEpiAwarePipeline)
inference_config["latent_namemodels"].second
end
11 changes: 9 additions & 2 deletions pipeline/src/constructors/selector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,15 @@ end

"""
Internal method for selecting from a list of items based on the pipeline type.
Example/test mode is to return a randomly selected item from the list.
Example/test mode is to return a randomly selected item from the list. Prior predictive mode
only runs on configurations with the furthest ahead horizon.
"""
function _selector(list, pipeline::AbstractRtwithoutRenewalPipeline)
return pipeline.testmode ? [rand(list)] : list
if pipeline.priorpredictive
maxT = maximum([config["T"] for config in list])
_list = filter(config -> config["T"] == maxT, list)
return pipeline.testmode ? [rand(_list)] : _list
else
return pipeline.testmode ? [rand(list)] : list
end
end
43 changes: 23 additions & 20 deletions pipeline/src/infer/InferenceConfig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Inference configuration struct for specifying the parameters and models used in
"""
struct InferenceConfig{
T, F, IGP, L, O, E, D <: Distribution, X <: Integer,
P <: AbstractRtwithoutRenewalPipeline}
P <: AbstractEpiAwarePipeline}
gi_mean::T
gi_std::T
igp::IGP
Expand Down Expand Up @@ -51,26 +51,29 @@ struct InferenceConfig{
case_data, truth_I_t, truth_I0, tspan, epimethod,
transformation, log_I0_prior, lookahead, latent_model_name, pipeline)
end
end

function InferenceConfig(
inference_config::Dict; case_data, truth_I_t, truth_I0, tspan, epimethod, pipeline)
InferenceConfig(
inference_config["igp"],
inference_config["latent_namemodels"].second,
inference_config["observation_model"];
gi_mean = inference_config["gi_mean"],
gi_std = inference_config["gi_std"],
case_data = case_data,
truth_I_t = truth_I_t,
truth_I0 = truth_I0,
tspan = tspan,
epimethod = epimethod,
log_I0_prior = inference_config["log_I0_prior"],
lookahead = inference_config["lookahead"],
latent_model_name = inference_config["latent_namemodels"].first,
pipeline
)
end
function InferenceConfig(
inference_config::Dict, pipeline::AbstractEpiAwarePipeline;
case_data, truth_I_t, truth_I0, tspan, epimethod)
latent_model = remake_latent_model(inference_config::Dict, pipeline)

InferenceConfig(
inference_config["igp"],
latent_model,
inference_config["observation_model"];
gi_mean = inference_config["gi_mean"],
gi_std = inference_config["gi_std"],
case_data = case_data,
truth_I_t = truth_I_t,
truth_I0 = truth_I0,
tspan = tspan,
epimethod = epimethod,
log_I0_prior = inference_config["log_I0_prior"],
lookahead = inference_config["lookahead"],
latent_model_name = inference_config["latent_namemodels"].first,
pipeline
)
end

"""
Expand Down
4 changes: 2 additions & 2 deletions pipeline/src/infer/generate_inference_results.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ function generate_inference_results(
pipeline; T = inference_config["T"], lookback = inference_config["lookback"])
inference_method = make_inference_method(pipeline)
config = InferenceConfig(
inference_config; case_data = truthdata["y_t"], truth_I_t = truthdata["I_t"],
truth_I0 = truthdata["truth_I0"], tspan, epimethod = inference_method, pipeline = pipeline)
inference_config, pipeline; case_data = truthdata["y_t"], truth_I_t = truthdata["I_t"],
truth_I0 = truthdata["truth_I0"], tspan, epimethod = inference_method)

# produce or load inference results
prfx = _inference_prefix(truthdata, inference_config, pipeline)
Expand Down
Loading

0 comments on commit 94b4643

Please sign in to comment.