Skip to content

Commit

Permalink
move abstract types into a single script and add a AbstractModel type
Browse files Browse the repository at this point in the history
  • Loading branch information
seabbs committed Feb 28, 2024
1 parent 8f0c5df commit c1f0781
Show file tree
Hide file tree
Showing 8 changed files with 492 additions and 0 deletions.
1 change: 1 addition & 0 deletions EpiAware/src/EpiAware.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ export EpiData, Renewal, ExpGrowthRate, DirectInfections, AbstractEpiModel,
# Exported Turing model constructors
export make_epi_aware

include("abstract-types.jl")
include("epi-models.jl")
include("utilities.jl")
include("latent-models.jl")
Expand Down
7 changes: 7 additions & 0 deletions EpiAware/src/abstract-types.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
abstract type AbstractModel end

abstract type AbstractEpiModel <: AbstractModel end

abstract type AbstractLatentModel <: AbstractModel end

abstract type AbstractObservationModel <: AbstractModel end
116 changes: 116 additions & 0 deletions EpiAware/src/epi-models.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
struct EpiData{T <: Real, F <: Function}
gen_int::Vector{T}
len_gen_int::Integer
transformation::F

#Inner constructors for EpiData object
function EpiData(gen_int,
transformation::Function)
@assert all(gen_int .>= 0) "Generation interval must be non-negative"
@assert sum(gen_int)1 "Generation interval must sum to 1"

new{eltype(gen_int), typeof(transformation)}(gen_int,
length(gen_int),
transformation)
end

function EpiData(gen_distribution::ContinuousDistribution;
D_gen,
Δd = 1.0,
transformation::Function = exp)
gen_int = create_discrete_pmf(gen_distribution, Δd = Δd, D = D_gen) |>
p -> p[2:end] ./ sum(p[2:end])

return EpiData(gen_int, transformation)
end
end

struct DirectInfections{S <: Sampleable} <: AbstractEpiModel
data::EpiData
initialisation_prior::S
end

struct ExpGrowthRate{S <: Sampleable} <: AbstractEpiModel
data::EpiData
initialisation_prior::S
end

struct Renewal{S <: Sampleable} <: AbstractEpiModel
data::EpiData
initialisation_prior::S
end

"""
function (epi_model::Renewal)(recent_incidence, Rt)
Compute new incidence based on recent incidence and Rt.
This is a callable function on `Renewal` structs, that encodes new incidence prediction
given recent incidence and Rt according to basic renewal process.
```math
I_t = R_t \\sum_{i=1}^{n-1} I_{t-i} g_i
```
where `I_t` is the new incidence, `R_t` is the reproduction number, `I_{t-i}` is the recent incidence
and `g_i` is the generation interval.
# Arguments
- `recent_incidence`: Array of recent incidence values.
- `Rt`: Reproduction number.
# Returns
- Tuple containing the updated incidence array and the new incidence value.
"""
function (epi_model::Renewal)(recent_incidence, Rt)
new_incidence = Rt * dot(recent_incidence, epi_model.data.gen_int)
return ([new_incidence; recent_incidence[1:(epi_model.data.len_gen_int - 1)]],
new_incidence)
end

function generate_latent_infs(epi_model::AbstractEpiModel, latent_model)
@info "No concrete implementation for `generate_latent_infs` is defined."
return nothing
end

@model function generate_latent_infs(epi_model::DirectInfections, _It)
init_incidence ~ epi_model.initialisation_prior
return epi_model.data.transformation.(init_incidence .+ _It)
end

@model function generate_latent_infs(epi_model::ExpGrowthRate, rt)
init_incidence ~ epi_model.initialisation_prior
return exp.(init_incidence .+ cumsum(rt))
end

"""
generate_latent_infs(epi_model::Renewal, _Rt)
`Turing` model constructor for latent infections using the `Renewal` object `epi_model` and time-varying unconstrained reproduction number `_Rt`.
`generate_latent_infs` creates a `Turing` model for sampling latent infections with given unconstrainted
reproduction number `_Rt` but random initial incidence scale. The initial incidence pre-time one is given as
a scale on top of an exponential growing process with exponential growth rate given by `R_to_r`applied to the
first value of `Rt`.
# Arguments
- `epi_model::Renewal`: The epidemiological model.
- `_Rt`: Time-varying unconstrained (e.g. log-) reproduction number.
# Returns
- `I_t`: Array of latent infections over time.
"""
@model function generate_latent_infs(epi_model::Renewal, _Rt)
init_incidence ~ epi_model.initialisation_prior
I₀ = epi_model.data.transformation(init_incidence)
Rt = epi_model.data.transformation.(_Rt)

r_approx = R_to_r(Rt[1], epi_model)
init = I₀ * [exp(-r_approx * t) for t in 0:(epi_model.data.len_gen_int - 1)]

I_t, _ = scan(epi_model, init, Rt)
return I_t
end
28 changes: 28 additions & 0 deletions EpiAware/src/latent-models.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
struct RandomWalk{D <: Sampleable, S <: Sampleable} <: AbstractLatentModel
init_prior::D
var_prior::S
end

function default_rw_priors()
return (:var_RW_prior => truncated(Normal(0.0, 0.05), 0.0, Inf),
:init_rw_value_prior => Normal()) |> Dict
end

function generate_latent(latent_model::AbstractLatentModel, n)
@info "No concrete implementation for generate_latent is defined."
return nothing
end

@model function generate_latent(latent_model::RandomWalk, n)
ϵ_t ~ MvNormal(ones(n))
σ²_RW ~ latent_model.var_prior
rw_init ~ latent_model.init_prior
σ_RW = sqrt(σ²_RW)
rw = Vector{eltype(ϵ_t)}(undef, n)

rw[1] = rw_init + σ_RW * ϵ_t[1]
for t in 2:n
rw[t] = rw[t - 1] + σ_RW * ϵ_t[t]
end
return rw, (; σ_RW, rw_init)
end
55 changes: 55 additions & 0 deletions EpiAware/src/observation-models.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
struct DelayObservations{T <: AbstractFloat, S <: Sampleable} <: AbstractObservationModel
delay_kernel::SparseMatrixCSC{T, Integer}
neg_bin_cluster_factor_prior::S

function DelayObservations(delay_int,
time_horizon,
neg_bin_cluster_factor_prior)
@assert all(delay_int .>= 0) "Delay interval must be non-negative"
@assert sum(delay_int)1 "Delay interval must sum to 1"

K = generate_observation_kernel(delay_int, time_horizon)

new{eltype(K), typeof(neg_bin_cluster_factor_prior)}(K,
neg_bin_cluster_factor_prior)
end

function DelayObservations(;
delay_distribution::ContinuousDistribution,
time_horizon::Integer,
neg_bin_cluster_factor_prior::Sampleable,
D_delay,
Δd = 1.0)
delay_int = create_discrete_pmf(delay_distribution; Δd = Δd, D = D_delay)
return DelayObservations(delay_int, time_horizon, neg_bin_cluster_factor_prior)
end
end

function default_delay_obs_priors()
return (:neg_bin_cluster_factor_prior => Gamma(3, 0.05 / 3),) |> Dict
end

function generate_observations(observation_model::AbstractObservationModel,
y_t,
I_t;
pos_shift)
@info "No concrete implementation for generate_observations is defined."
return nothing
end

@model function generate_observations(observation_model::DelayObservations,
y_t,
I_t;
pos_shift)
#Parameters
neg_bin_cluster_factor ~ observation_model.neg_bin_cluster_factor_prior

#Predictive distribution
case_pred_dists = (observation_model.delay_kernel * I_t) .+ pos_shift .|>
μ -> mean_cc_neg_bin(μ, neg_bin_cluster_factor)

#Likelihood
y_t ~ arraydist(case_pred_dists)

return y_t, (; neg_bin_cluster_factor,)
end
164 changes: 164 additions & 0 deletions EpiAware/test/test_epi-models.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@

@testitem "EpiData constructor" begin
gen_int = [0.2, 0.3, 0.5]
transformation = exp

data = EpiData(gen_int, transformation)

@test length(data.gen_int) == 3
@test data.len_gen_int == 3
@test sum(data.gen_int) 1
@test data.transformation(0.0) == 1.0
end

@testitem "EpiData constructor with distributions" begin
using Distributions

gen_distribution = Uniform(0.0, 10.0)
cluster_coeff = 0.8
time_horizon = 10
D_gen = 10.0
Δd = 1.0

data = EpiData(gen_distribution;
D_gen = 10.0)

@test data.len_gen_int == Int64(D_gen / Δd) - 1

@test sum(data.gen_int) 1
end

@testitem "Renewal function: internal generate infs" begin
using LinearAlgebra, Distributions
gen_int = [0.2, 0.3, 0.5]
delay_int = [0.1, 0.4, 0.5]
cluster_coeff = 0.8
time_horizon = 10
transformation = exp

data = EpiData(gen_int, transformation)
epi_model = Renewal(data, Normal())

function generate_infs(recent_incidence, Rt)
new_incidence = Rt * dot(recent_incidence, epi_model.data.gen_int)
[new_incidence; recent_incidence[1:(epi_model.data.len_gen_int - 1)]], new_incidence
end

recent_incidence = [10, 20, 30]
Rt = 1.5

expected_new_incidence = Rt * dot(recent_incidence, [0.2, 0.3, 0.5])
expected_output = [expected_new_incidence; recent_incidence[1:2]],
expected_new_incidence

@test generate_infs(recent_incidence, Rt) == expected_output
end

@testitem "generate_latent_infs dispatched on ExpGrowthRate" begin
using Distributions, Turing, HypothesisTests, DynamicPPL
gen_int = [0.2, 0.3, 0.5]
transformation = exp

data = EpiData(gen_int, transformation)
log_init_incidence_prior = Normal()
rt_model = ExpGrowthRate(data, log_init_incidence_prior)

#Example incidence data
recent_incidence = [10.0, 20.0, 30.0]
log_init = log(5.0)
rt = [log(recent_incidence[1]) - log_init; diff(log.(recent_incidence))]

#Check log_init is sampled from the correct distribution
sample_init_inc = sample(EpiAware.generate_latent_infs(rt_model, rt), Prior(), 1000) |>
chn -> chn[:init_incidence] |>
Array |>
vec

ks_test_pval = ExactOneSampleKSTest(sample_init_inc, log_init_incidence_prior) |> pvalue
@test ks_test_pval > 1e-6 #Very unlikely to fail if the model is correctly implemented

#Check that the generated incidence is correct given correct initialisation
mdl_incidence = generated_quantities(EpiAware.generate_latent_infs(rt_model, rt),
(init_incidence = log_init,))
@test mdl_incidence recent_incidence
end

@testitem "generate_latent_infs dispatched on DirectInfections" begin
using Distributions, Turing, HypothesisTests, DynamicPPL
gen_int = [0.2, 0.3, 0.5]
transformation = exp

data = EpiData(gen_int, transformation)
log_init_incidence_prior = Normal()

direct_inf_model = DirectInfections(data, log_init_incidence_prior)

log_init_scale = log(1.0)
log_incidence = [10, 20, 30] .|> log
expected_incidence = exp.(log_init_scale .+ log_incidence)

#Check log_init is sampled from the correct distribution
sample_init_inc = sample(
EpiAware.generate_latent_infs(direct_inf_model, log_incidence),
Prior(), 1000) |>
chn -> chn[:init_incidence] |>
Array |>
vec

ks_test_pval = ExactOneSampleKSTest(sample_init_inc, log_init_incidence_prior) |> pvalue
@test ks_test_pval > 1e-6 #Very unlikely to fail if the model is correctly implemented

#Check that the generated incidence is correct given correct initialisation
mdl_incidence = generated_quantities(
EpiAware.generate_latent_infs(direct_inf_model,
log_incidence),
(init_incidence = log_init_scale,))

@test mdl_incidence expected_incidence
end
@testitem "generate_latent_infs function: default" begin
latent_model = [0.1, 0.2, 0.3]
init_incidence = 10.0

struct TestEpiModel <: EpiAware.AbstractEpiModel
end

@test isnothing(EpiAware.generate_latent_infs(TestEpiModel(), latent_model))
end
@testitem "generate_latent_infs dispatched on Renewal" begin
using Distributions, Turing, HypothesisTests, DynamicPPL, LinearAlgebra
gen_int = [0.2, 0.3, 0.5]
transformation = exp

data = EpiData(gen_int, transformation)
log_init_incidence_prior = Normal()

renewal_model = Renewal(data, log_init_incidence_prior)

#Actual Rt
Rt = [1.0, 1.2, 1.5, 1.5, 1.5]
log_Rt = log.(Rt)
initial_incidence = [1.0, 1.0, 1.0]#aligns with initial exp growth rate of 0.

#Check log_init is sampled from the correct distribution
@time sample_init_inc = sample(EpiAware.generate_latent_infs(renewal_model, log_Rt),
Prior(), 1000) |>
chn -> chn[:init_incidence] |>
Array |>
vec

ks_test_pval = ExactOneSampleKSTest(sample_init_inc, log_init_incidence_prior) |> pvalue
@test ks_test_pval > 1e-6 #Very unlikely to fail if the model is correctly implemented

#Check that the generated incidence is correct given correct initialisation
#Check first three days "by hand"
mdl_incidence = generated_quantities(
EpiAware.generate_latent_infs(renewal_model,
log_Rt), (init_incidence = 0.0,))

day1_incidence = dot(initial_incidence, gen_int) * Rt[1]
day2_incidence = dot(initial_incidence, gen_int) * Rt[2]
day3_incidence = dot([day2_incidence, 1.0, 1.0], gen_int) * Rt[3]

@test mdl_incidence[1:3] [day1_incidence, day2_incidence, day3_incidence]
end
Loading

0 comments on commit c1f0781

Please sign in to comment.