Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement a proposed modular API to specify a data (i.e. case) generating process #39

Merged
merged 7 commits into from
Feb 15, 2024
11 changes: 8 additions & 3 deletions EpiAware/src/EpiAware.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,18 @@ using Distributions,
Parameters,
QuadGK

# Exported utilities
export scan,
create_discrete_pmf,
growth_rate_to_reproductive_ratio,
generate_observation_kernel,
EpiModel,
log_infections,
random_walk
default_rw_priors

# Exported types
export EpiData, Renewal, ExpGrowthRate, DirectInfections

# Exported Turing model constructors
export make_epi_inference_model, random_walk

include("utilities.jl")
include("epimodel.jl")
Expand Down
91 changes: 39 additions & 52 deletions EpiAware/src/epimodel.jl
Original file line number Diff line number Diff line change
@@ -1,96 +1,83 @@
abstract type AbstractEpiModel end



"""
struct EpiModel{T<:Real} <: AbstractEpiModel

EpiModel represents an epidemiological model with generation intervals, delay intervals, and observation delay kernel.

# Fields
- `gen_int::Vector{T}`: Discrete generation inteval, runs from 1, 2, ... to the end of the vector.
- `delay_int::Vector{T}`: Discrete delay distribution runs from 0, 1, ... to the end of the vector less 1.
- `delay_kernel::SparseMatrixCSC{T,Integer}`: Sparse matrix representing the observation delay kernel.
- `cluster_coeff::T`: Cluster coefficient for negative binomial observations.
- `len_gen_int::Integer`: Length of `gen_int`.
- `len_delay_int::Integer`: Length of `delay_int`.
- `time_horizon::Integer`: Length of the generated data.

# Constructors
- `EpiModel(gen_int, delay_int, cluster_coeff, time_horizon::Integer)`: Constructs an EpiModel object with given generation intervals, delay intervals, cluster coefficient, and time horizon.
- `EpiModel(gen_distribution::ContinuousDistribution, delay_distribution::ContinuousDistribution, cluster_coeff, time_horizon::Integer; Δd = 1.0, D_gen, D_delay)`: Constructs an EpiModel object with generation and delay distributions, cluster coefficient, time horizon, and optional parameters.

"""
struct EpiModel{T<:Real} <: AbstractEpiModel
struct EpiData{T<:Real,F<:Function}
gen_int::Vector{T}
delay_int::Vector{T}
delay_kernel::SparseMatrixCSC{T,Integer}
cluster_coeff::T
len_gen_int::Integer #length(gen_int) just to save recalc
len_delay_int::Integer #length(delay_int) just to save recalc
len_gen_int::Integer
len_delay_int::Integer
time_horizon::Integer
transformation::F

#Inner constructors for EpiModel object
function EpiModel(gen_int, delay_int, cluster_coeff, time_horizon::Integer)
#Inner constructors for EpiData object
function EpiData(
gen_int,
delay_int,
cluster_coeff,
time_horizon::Integer,
transformation::Function,
)
@assert all(gen_int .>= 0) "Generation interval must be non-negative"
@assert all(delay_int .>= 0) "Delay interval must be non-negative"
@assert sum(gen_int) ≈ 1 "Generation interval must sum to 1"
@assert sum(delay_int) ≈ 1 "Delay interval must sum to 1"

K = generate_observation_kernel(delay_int, time_horizon)

new{eltype(gen_int)}(
new{eltype(gen_int),typeof(transformation)}(
gen_int,
delay_int,
K,
cluster_coeff,
length(gen_int),
length(delay_int),
time_horizon,
transformation,
)
end

function EpiModel(
function EpiData(
gen_distribution::ContinuousDistribution,
delay_distribution::ContinuousDistribution,
cluster_coeff,
time_horizon::Integer;
Δd = 1.0,
D_gen,
D_delay,
Δd = 1.0,
transformation::Function = exp,
)
gen_int =
create_discrete_pmf(gen_distribution, Δd = Δd, D = D_gen) |>
p -> p[2:end] ./ sum(p[2:end])
delay_int = create_discrete_pmf(delay_distribution, Δd = Δd, D = D_delay)

K = generate_observation_kernel(delay_int, time_horizon)

new{eltype(gen_int)}(
gen_int,
delay_int,
K,
cluster_coeff,
length(gen_int),
length(delay_int),
time_horizon,
)
return EpiData(gen_int, delay_int, cluster_coeff, time_horizon, transformation)
end
end

"""
(epi_model::EpiModel)(recent_incidence, Rt)
struct DirectInfections <: AbstractEpiModel
data::EpiData
end

Apply the EpiModel to calculate new incidence based on recent incidence and Rt.
function (epi_model::DirectInfections)(recent_incidence, _I_t)
nothing, epi_model.data.transformation(_I_t)
end

struct ExpGrowthRate <: AbstractEpiModel
data::EpiData
end

# Arguments
- `recent_incidence`: Array of recent incidence values.
- `Rt`: Reproduction number.
function (epi_model::ExpGrowthRate)(recent_incidence, rt)
new_incidence = recent_incidence * exp(rt)
seabbs marked this conversation as resolved.
Show resolved Hide resolved
new_incidence, new_incidence
end

struct Renewal <: AbstractEpiModel
data::EpiData
end

# Returns
- `new_incidence`: Array of new incidence values.
"""
function (epi_model::EpiModel)(recent_incidence, Rt)
new_incidence = Rt * dot(recent_incidence, epi_model.gen_int)
[new_incidence; recent_incidence[1:(epi_model.len_gen_int-1)]], new_incidence
function (epi_model::Renewal)(recent_incidence, Rt)
new_incidence = Rt * dot(recent_incidence, epi_model.data.gen_int)
[new_incidence; recent_incidence[1:(epi_model.data.len_gen_int-1)]], new_incidence
end
35 changes: 14 additions & 21 deletions EpiAware/src/latent-processes.jl
Original file line number Diff line number Diff line change
@@ -1,32 +1,25 @@
const STANDARD_RW_PRIORS =
(var_RW_dist = truncated(Normal(0.0, 0.05), 0.0, Inf), init_rw_value_dist = Normal())


"""
random_walk(n, ϵ_t = missing; latent_process_priors = (var_RW_dist = truncated(Normal(0., 0.05), 0., Inf),), ::Type{T} = Float64) where {T <: Real}

Constructs a random walk model.

# Arguments
- `n`: The number of time steps.
- `ϵ_t`: The random noise vector. Defaults to `missing`, in which case it is sampled from the standard multivariate normal distribution.
- `latent_process_priors`: The prior distribution for the latent process parameters. Defaults to `(var_RW_dist = truncated(Normal(0., 0.05), 0., Inf),)`.
function default_rw_priors()
return (
seabbs marked this conversation as resolved.
Show resolved Hide resolved
var_RW_dist = truncated(Normal(0.0, 0.05), 0.0, Inf),
init_rw_value_dist = Normal(),
)
end

# Returns
- `rw`: The random walk process.
- `σ_RW`: The standard deviation of the random walk process.
"""
@model function random_walk(
n,
ϵ_t = missing,
::Type{T} = Float64;
latent_process_priors = STANDARD_RW_PRIORS,
) where {T<:Real}
latent_process_priors = default_rw_priors(),
) where {T<:AbstractFloat}
rw = Vector{T}(undef, n)
ϵ_t ~ MvNormal(ones(n))
σ²_RW ~ latent_process_priors.var_RW_dist
init_rw_value ~ latent_process_priors.init_rw_value_dist
σ_RW = sqrt(σ²_RW)
rw .= init_rw_value .+ cumsum(σ_RW * ϵ_t)
return rw, (; σ_RW, init_rw_value)

rw[1] = init_rw_value + σ_RW * ϵ_t[1]
for t = 2:n
seabbs marked this conversation as resolved.
Show resolved Hide resolved
rw[t] = rw[t-1] + σ_RW * ϵ_t[t]
end
return rw, (; σ_RW, init_rw_value, init = rw[1])
end
43 changes: 6 additions & 37 deletions EpiAware/src/models.jl
Original file line number Diff line number Diff line change
@@ -1,39 +1,8 @@
"""
log_infections(y_t, epimodel::EpiModel, latent_process;
latent_process_priors,
transform_function = exp,
n_generate_ahead = 0,
pos_shift = 1e-6,
neg_bin_cluster_factor = missing,
neg_bin_cluster_factor_prior = Gamma(3, 0.05 / 3))

A Turing model for Log-infections undelying observed epidemiological data.

This function defines a log-infections model for epidemiological data.
It takes the observed data `y_t`, an `EpiModel` object `epimodel`, and a `latent_process`
model. It also accepts optional arguments for the `latent_process_priors`, `transform_function`,
`n_generate_ahead`, `pos_shift`, `neg_bin_cluster_factor`, and `neg_bin_cluster_factor_prior`.

## Arguments
- `y_t`: Observed data.
- `epimodel`: Epidemiological model.
- `latent_process`: Latent process model.
- `latent_process_priors`: Priors for the latent process model.
- `transform_function`: Function to transform the latent process into infections. Default is `exp`.
- `n_generate_ahead`: Number of time steps to generate ahead. Default is `0`.
- `pos_shift`: Positive shift to avoid zero values. Default is `1e-6`.
- `neg_bin_cluster_factor`: Missing value for the negative binomial cluster factor. Default is `missing`.
- `neg_bin_cluster_factor_prior`: Prior distribution for the negative binomial cluster factor. Default is `Gamma(3, 0.05 / 3)`.

## Returns
A named tuple containing the generated quantities `I_t` and `latent_process_parameters`.
"""
@model function log_infections(
@model function make_epi_inference_model(
y_t,
epimodel::EpiModel,
epimodel::AbstractEpiModel,
latent_process;
seabbs marked this conversation as resolved.
Show resolved Hide resolved
latent_process_priors,
transform_function = exp,
pos_shift = 1e-6,
neg_bin_cluster_factor = missing,
neg_bin_cluster_factor_prior = Gamma(3, 0.05 / 3),
Expand All @@ -42,16 +11,16 @@ A named tuple containing the generated quantities `I_t` and `latent_process_para
neg_bin_cluster_factor ~ neg_bin_cluster_factor_prior

#Latent process
time_steps = epimodel.time_horizon
@submodel _I_t, latent_process_parameters =
time_steps = epimodel.data.time_horizon
@submodel latent_process, latent_process_parameters =
latent_process(time_steps; latent_process_priors = latent_process_priors)

#Transform into infections
I_t = transform_function.(_I_t)
I_t, _ = scan(epimodel, latent_process_parameters.init, latent_process)

#Predictive distribution
case_pred_dists =
(epimodel.delay_kernel * I_t) .+ pos_shift .|>
(epimodel.data.delay_kernel * I_t) .+ pos_shift .|>
μ -> mean_cc_neg_bin(μ, neg_bin_cluster_factor)
seabbs marked this conversation as resolved.
Show resolved Hide resolved

#Likelihood
Expand Down
59 changes: 35 additions & 24 deletions EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using Pkg: generate
#=
# Toy model for running analysis:

Expand Down Expand Up @@ -47,18 +46,16 @@ r &\sim \text{Gamma}(3, 0.05/3).
\end{align}
```

## Load dependencies in `TestEnv`
## Load dependencies

This script should be run from the root folder of `EpiAware` and with the active environment.

=#

split(pwd(), "/")[end] != "EpiAware" && begin
cd("./EpiAware")
using Pkg
Pkg.activate(".")

using TestEnv
TestEnv.activate()
end

using TestEnv # Run in Test environment mode
TestEnv.activate()

using EpiAware
using Turing
Expand All @@ -70,16 +67,18 @@ Random.seed!(0)

#=
## Create an `EpiModel` struct
Somewhat randomly chosen parameters for the `EpiModel` struct.

- Medium length generation interval distribution.
- Median 2 day, std 4.3 day delay distribution.
- 100 days of simulations
=#

truth_GI = Gamma(1, 2)
truth_delay = Uniform(0.0, 5.0)
truth_GI = Gamma(2, 5)
truth_delay = LogNormal(2.0, 1.0)
neg_bin_cluster_factor = 0.05
time_horizon = 100

epimodel = EpiModel(
model_data = EpiData(
truth_GI,
truth_delay,
neg_bin_cluster_factor,
Expand All @@ -89,29 +88,41 @@ epimodel = EpiModel(
)

#=
## Define a log-infections model
The log-infections model is defined by a Turing model `log_infections`.
## Define the data generating process

In this case we don't have observed data, so we use `missing` value for `y_t`.
In this case we use the `DirectInfections` model.
=#
toy_log_infs = log_infections(

toy_log_infs = DirectInfections(model_data)

#=
## Generate a `Turing` `Model`
We don't have observed data, so we use `missing` value for `y_t`.
=#

log_infs_model = make_epi_inference_model(
missing,
epimodel,
random_walk;
latent_process_priors = EpiAware.STANDARD_RW_PRIORS,
toy_log_infs,
random_walk,
latent_process_priors = default_rw_priors(),
pos_shift = 1e-6,
neg_bin_cluster_factor = 0.5,
neg_bin_cluster_factor_prior = Gamma(3, 0.05 / 3),
)



#=
## Sample from the model
I define a fixed version of the model with initial infections set to 10 and variance of the random walk process set to 0.1.
We can sample from the model using the `rand` function, and plot the generated infections against generated cases.
=#
cond_toy = fix(toy_log_infs, (init_rw_value = log(10.0), σ²_RW = 0.1))
random_epidemic = rand(cond_toy)

# We can get the generated infections using `generated_quantities` function. Because the observed
# cases are "defined" with a `~` operator they can be accessed directly from the randomly sampled
# process.

cond_toy = fix(log_infs_model, (init_rw_value = log(10.0), σ²_RW = 0.1))
random_epidemic = rand(cond_toy)
gen = generated_quantities(cond_toy, random_epidemic)
plot(
gen.I_t,
Expand All @@ -120,4 +131,4 @@ plot(
ylabel = "Infections",
title = "Generated Infections",
)
scatter!(X.y_t, lab = "generated cases")
scatter!(random_epidemic.y_t, lab = "generated cases")
Loading
Loading