From a6ecd501c404ae2264837a8342656deec6ecadc5 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Mon, 19 Feb 2024 10:29:09 +0000 Subject: [PATCH] Add observation process as a Turing model --- EpiAware/src/EpiAware.jl | 4 ++- EpiAware/src/models.jl | 25 ++++++-------- EpiAware/src/observation-processes.jl | 24 +++++++++++++ .../toy_model_log_infs_RW.jl | 5 ++- EpiAware/test/test_models.jl | 34 ++++++------------- EpiAware/test/test_observation-processes.jl | 31 +++++++++++++++++ 6 files changed, 81 insertions(+), 42 deletions(-) create mode 100644 EpiAware/src/observation-processes.jl create mode 100644 EpiAware/test/test_observation-processes.jl diff --git a/EpiAware/src/EpiAware.jl b/EpiAware/src/EpiAware.jl index c8f754194..35f6a7ebb 100644 --- a/EpiAware/src/EpiAware.jl +++ b/EpiAware/src/EpiAware.jl @@ -36,6 +36,7 @@ export scan, growth_rate_to_reproductive_ratio, generate_observation_kernel, default_rw_priors, + default_delay_obs_priors, neg_MGF, dneg_MGF_dr, fast_R_to_r_approx @@ -44,11 +45,12 @@ export scan, export EpiData, Renewal, ExpGrowthRate, DirectInfections, AbstractEpiModel # Exported Turing model constructors -export make_epi_inference_model, random_walk +export make_epi_inference_model, random_walk, delay_observations include("epimodel.jl") include("utilities.jl") include("models.jl") include("latent-processes.jl") +include("observation-processes.jl") end diff --git a/EpiAware/src/models.jl b/EpiAware/src/models.jl index b127362c3..3f06cd8cb 100644 --- a/EpiAware/src/models.jl +++ b/EpiAware/src/models.jl @@ -1,15 +1,11 @@ @model function make_epi_inference_model( y_t, epimodel::AbstractEpiModel, - latent_process; + latent_process, + observation_process; latent_process_priors, pos_shift = 1e-6, - neg_bin_cluster_factor = missing, - neg_bin_cluster_factor_prior = Gamma(3, 0.05 / 3), ) - #Prior - neg_bin_cluster_factor ~ neg_bin_cluster_factor_prior - #Latent process time_steps = epimodel.data.time_horizon @submodel latent_process, latent_process_aux = @@ -18,14 +14,15 @@ #Transform into infections I_t = epimodel(latent_process, latent_process_aux) - #Predictive distribution - case_pred_dists = - (epimodel.data.delay_kernel * I_t) .+ pos_shift .|> - μ -> mean_cc_neg_bin(μ, neg_bin_cluster_factor) - - #Likelihood - y_t ~ arraydist(case_pred_dists) + #Predictive distribution of ascerted cases + @submodel generated_y_t, generated_y_t_aux = observation_process( + y_t, + I_t, + epimodel::AbstractEpiModel; + observation_process_priors = latent_process_priors, + pos_shift = pos_shift, + ) #Generate quantities - return (; I_t, latent_process, latent_process_aux) + return (; generated_y_t, I_t, latent_process, latent_process_aux, generated_y_t_aux) end diff --git a/EpiAware/src/observation-processes.jl b/EpiAware/src/observation-processes.jl new file mode 100644 index 000000000..90d904449 --- /dev/null +++ b/EpiAware/src/observation-processes.jl @@ -0,0 +1,24 @@ +function default_delay_obs_priors() + return (neg_bin_cluster_factor_prior = Gamma(3, 0.05 / 3),) +end + +@model function delay_observations( + y_t, + I_t, + epimodel::AbstractEpiModel; + observation_process_priors = default_delay_obs_priors(), + pos_shift = 1e-6, +) + #Parameters + neg_bin_cluster_factor ~ observation_process_priors.neg_bin_cluster_factor_prior + + #Predictive distribution + case_pred_dists = + (epimodel.data.delay_kernel * I_t) .+ pos_shift .|> + μ -> mean_cc_neg_bin(μ, neg_bin_cluster_factor) + + #Likelihood + y_t ~ arraydist(case_pred_dists) + + return y_t, (; neg_bin_cluster_factor,) +end diff --git a/EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl b/EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl index 2a3dfebd1..f097e0f14 100644 --- a/EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl +++ b/EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl @@ -104,10 +104,9 @@ log_infs_model = make_epi_inference_model( missing, toy_log_infs, random_walk, - latent_process_priors = default_rw_priors(), + delay_observations; + latent_process_priors = merge(default_rw_priors(), default_delay_obs_priors()), pos_shift = 1e-6, - neg_bin_cluster_factor = 0.5, - neg_bin_cluster_factor_prior = Gamma(3, 0.05 / 3), ) diff --git a/EpiAware/test/test_models.jl b/EpiAware/test/test_models.jl index 38d7ca0be..311201051 100644 --- a/EpiAware/test/test_models.jl +++ b/EpiAware/test/test_models.jl @@ -4,12 +4,9 @@ # Define test inputs y_t = missing # Data will be generated from the model data = EpiData([0.2, 0.3, 0.5], [0.1, 0.4, 0.5], 0.8, 10, exp) - latent_process_priors = default_rw_priors() - transform_function = exp - n_generate_ahead = 0 + latent_process_priors = merge(default_rw_priors(), default_delay_obs_priors()) pos_shift = 1e-6 - neg_bin_cluster_factor = 0.5 - neg_bin_cluster_factor_prior = Gamma(3, 0.05 / 3) + epimodel = DirectInfections(data) @@ -17,11 +14,10 @@ test_mdl = make_epi_inference_model( y_t, epimodel, - random_walk; + random_walk, + delay_observations; latent_process_priors, pos_shift, - neg_bin_cluster_factor, - neg_bin_cluster_factor_prior, ) # Define expected outputs for a conditional model @@ -43,12 +39,8 @@ end # Define test inputs y_t = missing # Data will be generated from the model data = EpiData([0.2, 0.3, 0.5], [0.1, 0.4, 0.5], 0.8, 10, exp) - latent_process_priors = default_rw_priors() - transform_function = exp - n_generate_ahead = 0 + latent_process_priors = merge(default_rw_priors(), default_delay_obs_priors()) pos_shift = 1e-6 - neg_bin_cluster_factor = 0.5 - neg_bin_cluster_factor_prior = Gamma(3, 0.05 / 3) epimodel = ExpGrowthRate(data) @@ -56,11 +48,10 @@ end test_mdl = make_epi_inference_model( y_t, epimodel, - random_walk; + random_walk, + delay_observations; latent_process_priors, pos_shift, - neg_bin_cluster_factor, - neg_bin_cluster_factor_prior, ) # Define expected outputs for a conditional model @@ -82,12 +73,8 @@ end # Define test inputs y_t = missing # Data will be generated from the model data = EpiData([0.2, 0.3, 0.5], [0.1, 0.4, 0.5], 0.8, 10, exp) - latent_process_priors = default_rw_priors() - transform_function = exp - n_generate_ahead = 0 + latent_process_priors = merge(default_rw_priors(), default_delay_obs_priors()) pos_shift = 1e-6 - neg_bin_cluster_factor = 0.5 - neg_bin_cluster_factor_prior = Gamma(3, 0.05 / 3) epimodel = Renewal(data) @@ -95,11 +82,10 @@ end test_mdl = make_epi_inference_model( y_t, epimodel, - random_walk; + random_walk, + delay_observations; latent_process_priors, pos_shift, - neg_bin_cluster_factor, - neg_bin_cluster_factor_prior, ) # Define expected outputs for a conditional model diff --git a/EpiAware/test/test_observation-processes.jl b/EpiAware/test/test_observation-processes.jl new file mode 100644 index 000000000..238bfd5fa --- /dev/null +++ b/EpiAware/test/test_observation-processes.jl @@ -0,0 +1,31 @@ +@testitem "Testing delay obs against theoretical properties" begin + using DynamicPPL, Turing + # Set up test data with fixed infection + I_t = [10.0, 20.0, 30.0] + + # Replace with your own implementation of AbstractEpiModel + # Delay kernel is just event observed on same day + data = EpiData([0.2, 0.3, 0.5], [1.0], 0.8, 3, exp) + epimodel = DirectInfections(data) + # Set up priors + observation_process_priors = default_delay_obs_priors() + + # Call the function + mdl = delay_observations( + missing, + I_t, + epimodel; + observation_process_priors = observation_process_priors, + ) + fix_mdl = fix(mdl, neg_bin_cluster_factor = 0.00001) # Effectively Poisson sampling + + n_samples = 1000 + mean_first_obs = + sample(fix_mdl, Prior(), n_samples) |> + chn -> generated_quantities(fix_mdl, chn) .|> (gen -> gen[1][1]) |> mean + + theoretical_std_of_empiral_mean = sqrt(I_t[1]) / sqrt(n_samples) + @test mean(mean_first_obs) - I_t[1] < 5 * theoretical_std_of_empiral_mean && + mean(mean_first_obs) - I_t[1] > -5 * theoretical_std_of_empiral_mean + +end