-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
move abstract types into a single script and add a AbstractModel type
- Loading branch information
Showing
8 changed files
with
492 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
abstract type AbstractModel end | ||
|
||
abstract type AbstractEpiModel <: AbstractModel end | ||
|
||
abstract type AbstractLatentModel <: AbstractModel end | ||
|
||
abstract type AbstractObservationModel <: AbstractModel end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
struct EpiData{T <: Real, F <: Function} | ||
gen_int::Vector{T} | ||
len_gen_int::Integer | ||
transformation::F | ||
|
||
#Inner constructors for EpiData object | ||
function EpiData(gen_int, | ||
transformation::Function) | ||
@assert all(gen_int .>= 0) "Generation interval must be non-negative" | ||
@assert sum(gen_int)≈1 "Generation interval must sum to 1" | ||
|
||
new{eltype(gen_int), typeof(transformation)}(gen_int, | ||
length(gen_int), | ||
transformation) | ||
end | ||
|
||
function EpiData(gen_distribution::ContinuousDistribution; | ||
D_gen, | ||
Δd = 1.0, | ||
transformation::Function = exp) | ||
gen_int = create_discrete_pmf(gen_distribution, Δd = Δd, D = D_gen) |> | ||
p -> p[2:end] ./ sum(p[2:end]) | ||
|
||
return EpiData(gen_int, transformation) | ||
end | ||
end | ||
|
||
struct DirectInfections{S <: Sampleable} <: AbstractEpiModel | ||
data::EpiData | ||
initialisation_prior::S | ||
end | ||
|
||
struct ExpGrowthRate{S <: Sampleable} <: AbstractEpiModel | ||
data::EpiData | ||
initialisation_prior::S | ||
end | ||
|
||
struct Renewal{S <: Sampleable} <: AbstractEpiModel | ||
data::EpiData | ||
initialisation_prior::S | ||
end | ||
|
||
""" | ||
function (epi_model::Renewal)(recent_incidence, Rt) | ||
Compute new incidence based on recent incidence and Rt. | ||
This is a callable function on `Renewal` structs, that encodes new incidence prediction | ||
given recent incidence and Rt according to basic renewal process. | ||
```math | ||
I_t = R_t \\sum_{i=1}^{n-1} I_{t-i} g_i | ||
``` | ||
where `I_t` is the new incidence, `R_t` is the reproduction number, `I_{t-i}` is the recent incidence | ||
and `g_i` is the generation interval. | ||
# Arguments | ||
- `recent_incidence`: Array of recent incidence values. | ||
- `Rt`: Reproduction number. | ||
# Returns | ||
- Tuple containing the updated incidence array and the new incidence value. | ||
""" | ||
function (epi_model::Renewal)(recent_incidence, Rt) | ||
new_incidence = Rt * dot(recent_incidence, epi_model.data.gen_int) | ||
return ([new_incidence; recent_incidence[1:(epi_model.data.len_gen_int - 1)]], | ||
new_incidence) | ||
end | ||
|
||
function generate_latent_infs(epi_model::AbstractEpiModel, latent_model) | ||
@info "No concrete implementation for `generate_latent_infs` is defined." | ||
return nothing | ||
end | ||
|
||
@model function generate_latent_infs(epi_model::DirectInfections, _It) | ||
init_incidence ~ epi_model.initialisation_prior | ||
return epi_model.data.transformation.(init_incidence .+ _It) | ||
end | ||
|
||
@model function generate_latent_infs(epi_model::ExpGrowthRate, rt) | ||
init_incidence ~ epi_model.initialisation_prior | ||
return exp.(init_incidence .+ cumsum(rt)) | ||
end | ||
|
||
""" | ||
generate_latent_infs(epi_model::Renewal, _Rt) | ||
`Turing` model constructor for latent infections using the `Renewal` object `epi_model` and time-varying unconstrained reproduction number `_Rt`. | ||
`generate_latent_infs` creates a `Turing` model for sampling latent infections with given unconstrainted | ||
reproduction number `_Rt` but random initial incidence scale. The initial incidence pre-time one is given as | ||
a scale on top of an exponential growing process with exponential growth rate given by `R_to_r`applied to the | ||
first value of `Rt`. | ||
# Arguments | ||
- `epi_model::Renewal`: The epidemiological model. | ||
- `_Rt`: Time-varying unconstrained (e.g. log-) reproduction number. | ||
# Returns | ||
- `I_t`: Array of latent infections over time. | ||
""" | ||
@model function generate_latent_infs(epi_model::Renewal, _Rt) | ||
init_incidence ~ epi_model.initialisation_prior | ||
I₀ = epi_model.data.transformation(init_incidence) | ||
Rt = epi_model.data.transformation.(_Rt) | ||
|
||
r_approx = R_to_r(Rt[1], epi_model) | ||
init = I₀ * [exp(-r_approx * t) for t in 0:(epi_model.data.len_gen_int - 1)] | ||
|
||
I_t, _ = scan(epi_model, init, Rt) | ||
return I_t | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
struct RandomWalk{D <: Sampleable, S <: Sampleable} <: AbstractLatentModel | ||
init_prior::D | ||
var_prior::S | ||
end | ||
|
||
function default_rw_priors() | ||
return (:var_RW_prior => truncated(Normal(0.0, 0.05), 0.0, Inf), | ||
:init_rw_value_prior => Normal()) |> Dict | ||
end | ||
|
||
function generate_latent(latent_model::AbstractLatentModel, n) | ||
@info "No concrete implementation for generate_latent is defined." | ||
return nothing | ||
end | ||
|
||
@model function generate_latent(latent_model::RandomWalk, n) | ||
ϵ_t ~ MvNormal(ones(n)) | ||
σ²_RW ~ latent_model.var_prior | ||
rw_init ~ latent_model.init_prior | ||
σ_RW = sqrt(σ²_RW) | ||
rw = Vector{eltype(ϵ_t)}(undef, n) | ||
|
||
rw[1] = rw_init + σ_RW * ϵ_t[1] | ||
for t in 2:n | ||
rw[t] = rw[t - 1] + σ_RW * ϵ_t[t] | ||
end | ||
return rw, (; σ_RW, rw_init) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
struct DelayObservations{T <: AbstractFloat, S <: Sampleable} <: AbstractObservationModel | ||
delay_kernel::SparseMatrixCSC{T, Integer} | ||
neg_bin_cluster_factor_prior::S | ||
|
||
function DelayObservations(delay_int, | ||
time_horizon, | ||
neg_bin_cluster_factor_prior) | ||
@assert all(delay_int .>= 0) "Delay interval must be non-negative" | ||
@assert sum(delay_int)≈1 "Delay interval must sum to 1" | ||
|
||
K = generate_observation_kernel(delay_int, time_horizon) | ||
|
||
new{eltype(K), typeof(neg_bin_cluster_factor_prior)}(K, | ||
neg_bin_cluster_factor_prior) | ||
end | ||
|
||
function DelayObservations(; | ||
delay_distribution::ContinuousDistribution, | ||
time_horizon::Integer, | ||
neg_bin_cluster_factor_prior::Sampleable, | ||
D_delay, | ||
Δd = 1.0) | ||
delay_int = create_discrete_pmf(delay_distribution; Δd = Δd, D = D_delay) | ||
return DelayObservations(delay_int, time_horizon, neg_bin_cluster_factor_prior) | ||
end | ||
end | ||
|
||
function default_delay_obs_priors() | ||
return (:neg_bin_cluster_factor_prior => Gamma(3, 0.05 / 3),) |> Dict | ||
end | ||
|
||
function generate_observations(observation_model::AbstractObservationModel, | ||
y_t, | ||
I_t; | ||
pos_shift) | ||
@info "No concrete implementation for generate_observations is defined." | ||
return nothing | ||
end | ||
|
||
@model function generate_observations(observation_model::DelayObservations, | ||
y_t, | ||
I_t; | ||
pos_shift) | ||
#Parameters | ||
neg_bin_cluster_factor ~ observation_model.neg_bin_cluster_factor_prior | ||
|
||
#Predictive distribution | ||
case_pred_dists = (observation_model.delay_kernel * I_t) .+ pos_shift .|> | ||
μ -> mean_cc_neg_bin(μ, neg_bin_cluster_factor) | ||
|
||
#Likelihood | ||
y_t ~ arraydist(case_pred_dists) | ||
|
||
return y_t, (; neg_bin_cluster_factor,) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
|
||
@testitem "EpiData constructor" begin | ||
gen_int = [0.2, 0.3, 0.5] | ||
transformation = exp | ||
|
||
data = EpiData(gen_int, transformation) | ||
|
||
@test length(data.gen_int) == 3 | ||
@test data.len_gen_int == 3 | ||
@test sum(data.gen_int) ≈ 1 | ||
@test data.transformation(0.0) == 1.0 | ||
end | ||
|
||
@testitem "EpiData constructor with distributions" begin | ||
using Distributions | ||
|
||
gen_distribution = Uniform(0.0, 10.0) | ||
cluster_coeff = 0.8 | ||
time_horizon = 10 | ||
D_gen = 10.0 | ||
Δd = 1.0 | ||
|
||
data = EpiData(gen_distribution; | ||
D_gen = 10.0) | ||
|
||
@test data.len_gen_int == Int64(D_gen / Δd) - 1 | ||
|
||
@test sum(data.gen_int) ≈ 1 | ||
end | ||
|
||
@testitem "Renewal function: internal generate infs" begin | ||
using LinearAlgebra, Distributions | ||
gen_int = [0.2, 0.3, 0.5] | ||
delay_int = [0.1, 0.4, 0.5] | ||
cluster_coeff = 0.8 | ||
time_horizon = 10 | ||
transformation = exp | ||
|
||
data = EpiData(gen_int, transformation) | ||
epi_model = Renewal(data, Normal()) | ||
|
||
function generate_infs(recent_incidence, Rt) | ||
new_incidence = Rt * dot(recent_incidence, epi_model.data.gen_int) | ||
[new_incidence; recent_incidence[1:(epi_model.data.len_gen_int - 1)]], new_incidence | ||
end | ||
|
||
recent_incidence = [10, 20, 30] | ||
Rt = 1.5 | ||
|
||
expected_new_incidence = Rt * dot(recent_incidence, [0.2, 0.3, 0.5]) | ||
expected_output = [expected_new_incidence; recent_incidence[1:2]], | ||
expected_new_incidence | ||
|
||
@test generate_infs(recent_incidence, Rt) == expected_output | ||
end | ||
|
||
@testitem "generate_latent_infs dispatched on ExpGrowthRate" begin | ||
using Distributions, Turing, HypothesisTests, DynamicPPL | ||
gen_int = [0.2, 0.3, 0.5] | ||
transformation = exp | ||
|
||
data = EpiData(gen_int, transformation) | ||
log_init_incidence_prior = Normal() | ||
rt_model = ExpGrowthRate(data, log_init_incidence_prior) | ||
|
||
#Example incidence data | ||
recent_incidence = [10.0, 20.0, 30.0] | ||
log_init = log(5.0) | ||
rt = [log(recent_incidence[1]) - log_init; diff(log.(recent_incidence))] | ||
|
||
#Check log_init is sampled from the correct distribution | ||
sample_init_inc = sample(EpiAware.generate_latent_infs(rt_model, rt), Prior(), 1000) |> | ||
chn -> chn[:init_incidence] |> | ||
Array |> | ||
vec | ||
|
||
ks_test_pval = ExactOneSampleKSTest(sample_init_inc, log_init_incidence_prior) |> pvalue | ||
@test ks_test_pval > 1e-6 #Very unlikely to fail if the model is correctly implemented | ||
|
||
#Check that the generated incidence is correct given correct initialisation | ||
mdl_incidence = generated_quantities(EpiAware.generate_latent_infs(rt_model, rt), | ||
(init_incidence = log_init,)) | ||
@test mdl_incidence ≈ recent_incidence | ||
end | ||
|
||
@testitem "generate_latent_infs dispatched on DirectInfections" begin | ||
using Distributions, Turing, HypothesisTests, DynamicPPL | ||
gen_int = [0.2, 0.3, 0.5] | ||
transformation = exp | ||
|
||
data = EpiData(gen_int, transformation) | ||
log_init_incidence_prior = Normal() | ||
|
||
direct_inf_model = DirectInfections(data, log_init_incidence_prior) | ||
|
||
log_init_scale = log(1.0) | ||
log_incidence = [10, 20, 30] .|> log | ||
expected_incidence = exp.(log_init_scale .+ log_incidence) | ||
|
||
#Check log_init is sampled from the correct distribution | ||
sample_init_inc = sample( | ||
EpiAware.generate_latent_infs(direct_inf_model, log_incidence), | ||
Prior(), 1000) |> | ||
chn -> chn[:init_incidence] |> | ||
Array |> | ||
vec | ||
|
||
ks_test_pval = ExactOneSampleKSTest(sample_init_inc, log_init_incidence_prior) |> pvalue | ||
@test ks_test_pval > 1e-6 #Very unlikely to fail if the model is correctly implemented | ||
|
||
#Check that the generated incidence is correct given correct initialisation | ||
mdl_incidence = generated_quantities( | ||
EpiAware.generate_latent_infs(direct_inf_model, | ||
log_incidence), | ||
(init_incidence = log_init_scale,)) | ||
|
||
@test mdl_incidence ≈ expected_incidence | ||
end | ||
@testitem "generate_latent_infs function: default" begin | ||
latent_model = [0.1, 0.2, 0.3] | ||
init_incidence = 10.0 | ||
|
||
struct TestEpiModel <: EpiAware.AbstractEpiModel | ||
end | ||
|
||
@test isnothing(EpiAware.generate_latent_infs(TestEpiModel(), latent_model)) | ||
end | ||
@testitem "generate_latent_infs dispatched on Renewal" begin | ||
using Distributions, Turing, HypothesisTests, DynamicPPL, LinearAlgebra | ||
gen_int = [0.2, 0.3, 0.5] | ||
transformation = exp | ||
|
||
data = EpiData(gen_int, transformation) | ||
log_init_incidence_prior = Normal() | ||
|
||
renewal_model = Renewal(data, log_init_incidence_prior) | ||
|
||
#Actual Rt | ||
Rt = [1.0, 1.2, 1.5, 1.5, 1.5] | ||
log_Rt = log.(Rt) | ||
initial_incidence = [1.0, 1.0, 1.0]#aligns with initial exp growth rate of 0. | ||
|
||
#Check log_init is sampled from the correct distribution | ||
@time sample_init_inc = sample(EpiAware.generate_latent_infs(renewal_model, log_Rt), | ||
Prior(), 1000) |> | ||
chn -> chn[:init_incidence] |> | ||
Array |> | ||
vec | ||
|
||
ks_test_pval = ExactOneSampleKSTest(sample_init_inc, log_init_incidence_prior) |> pvalue | ||
@test ks_test_pval > 1e-6 #Very unlikely to fail if the model is correctly implemented | ||
|
||
#Check that the generated incidence is correct given correct initialisation | ||
#Check first three days "by hand" | ||
mdl_incidence = generated_quantities( | ||
EpiAware.generate_latent_infs(renewal_model, | ||
log_Rt), (init_incidence = 0.0,)) | ||
|
||
day1_incidence = dot(initial_incidence, gen_int) * Rt[1] | ||
day2_incidence = dot(initial_incidence, gen_int) * Rt[2] | ||
day3_incidence = dot([day2_incidence, 1.0, 1.0], gen_int) * Rt[3] | ||
|
||
@test mdl_incidence[1:3] ≈ [day1_incidence, day2_incidence, day3_incidence] | ||
end |
Oops, something went wrong.