Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pathfinder initialisation #110

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions EpiAware/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
Pathfinder = "b1d3bc72-d0e7-4279-b92f-7fa5d6d2d454"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Expand All @@ -20,14 +21,14 @@ Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
[compat]
DataFramesMeta = "0.14"
Distributions = "0.25"
DocStringExtensions = "0.9"
LinearAlgebra = "1.9"
LogExpFunctions = "0.3"
Optim = "1.9"
Parameters = "0.12"
Pathfinder = "0.8"
QuadGK = "2.9"
Random = "1.9"
ReverseDiff = "1.15"
SparseArrays = "1.10"
Turing = "0.30"
julia = "1.10"
julia = "1.9"
35 changes: 35 additions & 0 deletions EpiAware/src/models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,38 @@
latent_model,
process_aux = merge(latent_model_aux, generated_y_t_aux))
end

@model function make_epi_aware(y_t,
time_steps,
::Val{:safe};
epi_model::AbstractEpiModel,
latent_model_model::AbstractLatentModel,
observation_model::AbstractObservationModel,
pos_shift = 1e-6)
try
#Latent process
@submodel latent_model, latent_model_aux = generate_latent(
latent_model_model,
time_steps)

#Transform into infections
@submodel I_t = generate_latent_infs(epi_model, latent_model)

#Predictive distribution of ascerted cases
@submodel generated_y_t, generated_y_t_aux = generate_observations(
observation_model,
y_t,
I_t;
pos_shift = pos_shift)

#Generate quantities
return (;
generated_y_t,
I_t,
latent_model,
process_aux = merge(latent_model_aux, generated_y_t_aux))
catch
Turing.@addlogprob! -Inf
return
end
end
2 changes: 2 additions & 0 deletions EpiAware/test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Pathfinder = "b1d3bc72-d0e7-4279-b92f-7fa5d6d2d454"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
Expand Down
98 changes: 76 additions & 22 deletions EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ This is a toy model for demonstrating current functionality of EpiAware package.

### Latent Process

The latent process is a random walk defined by a Turing model `random_walk` of specified length `n`.
The latent process is a random walk defined by a Turing model `random_walk` of specified
length `n`.

_Unfixed parameters_:
- `σ²_RW`: The variance of the random walk process. Current defauly prior is
Expand All @@ -25,19 +26,22 @@ X(0) &\sim \mathcal{N}(0, 1) \\

### Log-Infections Model

The log-infections model is defined by a Turing model `log_infections` that takes the observed data `y_t` (or `missing` value),
an `EpiModel` object `epi_model`, and a `latent_model` model. In this case the latent process is a random walk model.
The log-infections model is defined by a Turing model `log_infections` that takes the
observed data `y_t` (or `missing` value), an `EpiModel` object `epi_model`, and a
`latent_model` model. In this case the latent process is a random walk model.

It also accepts optional arguments for the `process_priors`, `transform_function`, `pos_shift`, `neg_bin_cluster_factor`, and `neg_bin_cluster_factor_prior`.
It also accepts optional arguments for the `process_priors`, `transform_function`,
`pos_shift`, `neg_bin_cluster_factor`, and `neg_bin_cluster_factor_prior`.

```math
\log I_t = \exp(X(t)).
```

### Observation model

The observation model is a negative binomial distribution with mean `μ` and cluster factor `r`. Delays are implemented
as the action of a sparse kernel on the infections $I(t)$. The delay kernel is contained in an `EpiModel` struct.
The observation model is a negative binomial distribution with mean `μ` and cluster factor
`r`. Delays are implemented as the action of a sparse kernel on the infections $I(t)$.
The delay kernel is contained in an `EpiModel` struct.

```math
\begin{align}
Expand Down Expand Up @@ -68,7 +72,8 @@ using Random
using DynamicPPL
using Statistics
using DataFramesMeta
using CSV # For outputting the MCMC chain
using CSV
using Pathfinder

Random.seed!(0)

Expand Down Expand Up @@ -132,50 +137,99 @@ plot(gen.I_t,
scatter!(random_epidemic.y_t, lab = "generated cases")

#=
## Inference
## Model with observed data

We treat the generated data as observed data and attempt to infer underlying infections.
=#

truth_data = random_epidemic.y_t

model = make_epi_aware(truth_data, time_horizon, ; epi_model = epi_model,
model = make_epi_aware(truth_data, time_horizon; epi_model = epi_model,
latent_model_model = rwp, observation_model = obs_model,
pos_shift = 1e-6)

#=
### Pathfinder inference

We can use pathfinder to get draws from the model. We can later use these draws to
initialize the MCMC chain. We can also compare a single run of pathfinder with
=#

safe_model = make_epi_aware(truth_data, time_horizon, Val(:safe);
epi_model = epi_model,
latent_model_model = rwp,
observation_model = obs_model,
pos_shift = 1e-6)

mpf_result = multipathfinder(safe_model, 1000; nruns = 10)

mpf_chn = mpf_result.draws_transformed

@time chn = sample(model,
NUTS(; adtype = AutoReverseDiff(true)),
MCMCThreads(),
250,
4;
drop_warmup = true)
drop_warmup = true,
init_params = collect.(eachrow(mpf_chn.value[1:4, :, 1]))
)

#=
## Postior predictive checking

We check the posterior predictive checking by plotting the predicted cases against the observed cases.
=#

predicted_y_t = mapreduce(hcat, generated_quantities(log_infs_model, chn)) do gen
gen.generated_y_t
predicted_y_t, mpf_predicted_y_t = map((chn, mpf_chn)) do _chn
mapreduce(hcat, generated_quantities(log_infs_model, _chn)) do gen
gen.generated_y_t
end
end

plot(predicted_y_t, c = :grey, alpha = 0.05, lab = "")
scatter!(truth_data,
lab = "Observed cases",
xlabel = "Time",
ylabel = "Cases",
title = "Posterior Predictive Checking",
ylims = (-0.5, maximum(truth_data) * 2.5))
data_pred_plts = map(("NUTS", "multi-pf"),
(predicted_y_t, mpf_predicted_y_t)) do title_str, pred_y_t
plt = plot(pred_y_t, c = :grey, alpha = 0.05, lab = "")
scatter!(plt, truth_data,
lab = "Observed cases",
xlabel = "Time",
ylabel = "Cases",
title = "Posterior Predictive Checking: " * title_str,
ylims = (-0.5, maximum(truth_data) * 2.5))
return plt
end

plot(data_pred_plts...,
layout = (2, 1),
size = (500, 700))

#=
## Underlying inferred infections
=#

predicted_I_t = mapreduce(hcat, generated_quantities(log_infs_model, chn)) do gen
gen.I_t
predicted_I_t, mpf_predicted_I_t = map((chn, mpf_chn)) do _chn
mapreduce(hcat, generated_quantities(log_infs_model, _chn)) do gen
gen.I_t
end
end

plts_infs = map(("NUTS", "multi-pf"),
(predicted_I_t, mpf_predicted_I_t)) do title_str, pred_I_t
plt = plot(pred_I_t, c = :grey, alpha = 0.05, lab = "")
scatter!(plt, gen.I_t,
lab = "Actual infections",
xlabel = "Time",
ylabel = "Infections",
title = "Posterior Predictive Checking: " * title_str,
ylims = (-0.5, maximum(gen.I_t) * 1.5))
return plt
end

plot(predicted_I_t, c = :grey, alpha = 0.05, lab = "")
plot(plts_infs...,
layout = (2, 1),
size = (500, 700))
plot(pf_predicted_I_t, c = :grey, alpha = 0.05, lab = "")
plot(predicted_I_t, c = :blue, alpha = 0.05, lab = "")

scatter!(gen.I_t,
lab = "Actual infections",
xlabel = "Time",
Expand Down
Loading