diff --git a/EpiAware/Project.toml b/EpiAware/Project.toml index cb35f98bc..3779f1e88 100644 --- a/EpiAware/Project.toml +++ b/EpiAware/Project.toml @@ -11,6 +11,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" Optim = "429524aa-4258-5aef-a3af-852621145aeb" Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" +Pathfinder = "b1d3bc72-d0e7-4279-b92f-7fa5d6d2d454" QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" @@ -20,14 +21,14 @@ Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [compat] DataFramesMeta = "0.14" Distributions = "0.25" -DocStringExtensions = "0.9" LinearAlgebra = "1.9" LogExpFunctions = "0.3" Optim = "1.9" Parameters = "0.12" +Pathfinder = "0.8" QuadGK = "2.9" Random = "1.9" ReverseDiff = "1.15" SparseArrays = "1.10" Turing = "0.30" -julia = "1.10" +julia = "1.9" diff --git a/EpiAware/src/models.jl b/EpiAware/src/models.jl index 0c93c3b21..eb1fbd4f9 100644 --- a/EpiAware/src/models.jl +++ b/EpiAware/src/models.jl @@ -25,3 +25,38 @@ latent_model, process_aux = merge(latent_model_aux, generated_y_t_aux)) end + +@model function make_epi_aware(y_t, + time_steps, + ::Val{:safe}; + epi_model::AbstractEpiModel, + latent_model_model::AbstractLatentModel, + observation_model::AbstractObservationModel, + pos_shift = 1e-6) + try + #Latent process + @submodel latent_model, latent_model_aux = generate_latent( + latent_model_model, + time_steps) + + #Transform into infections + @submodel I_t = generate_latent_infs(epi_model, latent_model) + + #Predictive distribution of ascerted cases + @submodel generated_y_t, generated_y_t_aux = generate_observations( + observation_model, + y_t, + I_t; + pos_shift = pos_shift) + + #Generate quantities + return (; + generated_y_t, + I_t, + latent_model, + process_aux = merge(latent_model_aux, generated_y_t_aux)) + catch + Turing.@addlogprob! -Inf + return + end +end diff --git a/EpiAware/test/Project.toml b/EpiAware/test/Project.toml index 0271e5375..52faadb00 100644 --- a/EpiAware/test/Project.toml +++ b/EpiAware/test/Project.toml @@ -6,6 +6,8 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Optim = "429524aa-4258-5aef-a3af-852621145aeb" +Pathfinder = "b1d3bc72-d0e7-4279-b92f-7fa5d6d2d454" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd" diff --git a/EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl b/EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl index e686bf41a..93b0ea2fb 100644 --- a/EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl +++ b/EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl @@ -7,7 +7,8 @@ This is a toy model for demonstrating current functionality of EpiAware package. ### Latent Process -The latent process is a random walk defined by a Turing model `random_walk` of specified length `n`. +The latent process is a random walk defined by a Turing model `random_walk` of specified + length `n`. _Unfixed parameters_: - `σ²_RW`: The variance of the random walk process. Current defauly prior is @@ -25,10 +26,12 @@ X(0) &\sim \mathcal{N}(0, 1) \\ ### Log-Infections Model -The log-infections model is defined by a Turing model `log_infections` that takes the observed data `y_t` (or `missing` value), -an `EpiModel` object `epi_model`, and a `latent_model` model. In this case the latent process is a random walk model. +The log-infections model is defined by a Turing model `log_infections` that takes the + observed data `y_t` (or `missing` value), an `EpiModel` object `epi_model`, and a + `latent_model` model. In this case the latent process is a random walk model. -It also accepts optional arguments for the `process_priors`, `transform_function`, `pos_shift`, `neg_bin_cluster_factor`, and `neg_bin_cluster_factor_prior`. +It also accepts optional arguments for the `process_priors`, `transform_function`, + `pos_shift`, `neg_bin_cluster_factor`, and `neg_bin_cluster_factor_prior`. ```math \log I_t = \exp(X(t)). @@ -36,8 +39,9 @@ It also accepts optional arguments for the `process_priors`, `transform_function ### Observation model -The observation model is a negative binomial distribution with mean `μ` and cluster factor `r`. Delays are implemented -as the action of a sparse kernel on the infections $I(t)$. The delay kernel is contained in an `EpiModel` struct. +The observation model is a negative binomial distribution with mean `μ` and cluster factor + `r`. Delays are implemented as the action of a sparse kernel on the infections $I(t)$. +The delay kernel is contained in an `EpiModel` struct. ```math \begin{align} @@ -68,7 +72,8 @@ using Random using DynamicPPL using Statistics using DataFramesMeta -using CSV # For outputting the MCMC chain +using CSV +using Pathfinder Random.seed!(0) @@ -132,22 +137,42 @@ plot(gen.I_t, scatter!(random_epidemic.y_t, lab = "generated cases") #= -## Inference +## Model with observed data We treat the generated data as observed data and attempt to infer underlying infections. =# truth_data = random_epidemic.y_t -model = make_epi_aware(truth_data, time_horizon, ; epi_model = epi_model, +model = make_epi_aware(truth_data, time_horizon; epi_model = epi_model, latent_model_model = rwp, observation_model = obs_model, pos_shift = 1e-6) + +#= +### Pathfinder inference + +We can use pathfinder to get draws from the model. We can later use these draws to + initialize the MCMC chain. We can also compare a single run of pathfinder with +=# + +safe_model = make_epi_aware(truth_data, time_horizon, Val(:safe); + epi_model = epi_model, + latent_model_model = rwp, + observation_model = obs_model, + pos_shift = 1e-6) + +mpf_result = multipathfinder(safe_model, 1000; nruns = 10) + +mpf_chn = mpf_result.draws_transformed + @time chn = sample(model, NUTS(; adtype = AutoReverseDiff(true)), MCMCThreads(), 250, 4; - drop_warmup = true) + drop_warmup = true, + init_params = collect.(eachrow(mpf_chn.value[1:4, :, 1])) +) #= ## Postior predictive checking @@ -155,27 +180,56 @@ model = make_epi_aware(truth_data, time_horizon, ; epi_model = epi_model, We check the posterior predictive checking by plotting the predicted cases against the observed cases. =# -predicted_y_t = mapreduce(hcat, generated_quantities(log_infs_model, chn)) do gen - gen.generated_y_t +predicted_y_t, mpf_predicted_y_t = map((chn, mpf_chn)) do _chn + mapreduce(hcat, generated_quantities(log_infs_model, _chn)) do gen + gen.generated_y_t + end end -plot(predicted_y_t, c = :grey, alpha = 0.05, lab = "") -scatter!(truth_data, - lab = "Observed cases", - xlabel = "Time", - ylabel = "Cases", - title = "Posterior Predictive Checking", - ylims = (-0.5, maximum(truth_data) * 2.5)) +data_pred_plts = map(("NUTS", "multi-pf"), + (predicted_y_t, mpf_predicted_y_t)) do title_str, pred_y_t + plt = plot(pred_y_t, c = :grey, alpha = 0.05, lab = "") + scatter!(plt, truth_data, + lab = "Observed cases", + xlabel = "Time", + ylabel = "Cases", + title = "Posterior Predictive Checking: " * title_str, + ylims = (-0.5, maximum(truth_data) * 2.5)) + return plt +end + +plot(data_pred_plts..., + layout = (2, 1), + size = (500, 700)) #= ## Underlying inferred infections =# -predicted_I_t = mapreduce(hcat, generated_quantities(log_infs_model, chn)) do gen - gen.I_t +predicted_I_t, mpf_predicted_I_t = map((chn, mpf_chn)) do _chn + mapreduce(hcat, generated_quantities(log_infs_model, _chn)) do gen + gen.I_t + end +end + +plts_infs = map(("NUTS", "multi-pf"), + (predicted_I_t, mpf_predicted_I_t)) do title_str, pred_I_t + plt = plot(pred_I_t, c = :grey, alpha = 0.05, lab = "") + scatter!(plt, gen.I_t, + lab = "Actual infections", + xlabel = "Time", + ylabel = "Infections", + title = "Posterior Predictive Checking: " * title_str, + ylims = (-0.5, maximum(gen.I_t) * 1.5)) + return plt end -plot(predicted_I_t, c = :grey, alpha = 0.05, lab = "") +plot(plts_infs..., + layout = (2, 1), + size = (500, 700)) +plot(pf_predicted_I_t, c = :grey, alpha = 0.05, lab = "") +plot(predicted_I_t, c = :blue, alpha = 0.05, lab = "") + scatter!(gen.I_t, lab = "Actual infections", xlabel = "Time",