Skip to content

Commit

Permalink
Merge pull request #79 from CDCgov/refactor-for-more-multi-dispatch
Browse files Browse the repository at this point in the history
Refactor for more multi dispatch
  • Loading branch information
seabbs authored Feb 27, 2024
2 parents a488736 + 5533e53 commit 8b9edf9
Show file tree
Hide file tree
Showing 18 changed files with 506 additions and 451 deletions.
79 changes: 42 additions & 37 deletions EpiAware/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,63 +9,68 @@
- Solid lines indicate implemented features/analysis.
- Dashed lines indicate planned features/analysis.

## Proposed `EpiAware` model diagram
## Current `EpiAware` model diagram
```mermaid
flowchart LR
A["Underlying dists.
and specify length of sims
---------------------
EpiData"]
A["Underlying GI
Bijector"]
EpiModel["AbstractEpiModel
----------------------
Choice of target
for latent process:
B["Choice of target
for latent process
---------------------
DirectInfections
ExpGrowthRate
Renewal"]
C["Observational Data
InitModel["Priors for
initial scale of incidence"]
DataW[Data wrangling and QC]
ObsData["Observational Data
---------------------
Obs. cases y_t"]
D["Latent processes
LatentProcPriors["Latent process priors"]
LatentProc["AbstractLatentProcess
---------------------
RandomWalkLatentProcess"]
ObsModelPriors["Observation model priors
choice of delayed obs. model"]
ObsModel["AbstractObservationModel
---------------------
random_walk"]
DelayObservations"]
E["Turing model constructor
---------------------
make_epi_inference_model"]
F["Latent Process priors
---------------------
default_rw_priors"]
G[Posterior draws]
H[Posterior checking]
I[Post-processing]
DataW[Data wrangling and QC]
J["Observation models
---------------------
delay_observations"]
K["Observation model priors
---------------------
default_delay_obs_priors"]
ObservationModel["ObservationModel
---------------------
delay_observations_model"]
LatentProcess["LatentProcess
---------------------
random_walk_process"]
A --> EpiModel
B --> EpiModel
A --> EpiData
EpiData --> EpiModel
InitModel --> EpiModel
EpiModel -->E
C-->E
D-->LatentProcess
F-->LatentProcess
J-->ObservationModel
K-->ObservationModel
LatentProcess-->E
ObservationModel-->E
ObsData-->E
DataW-.->ObsData
LatentProcPriors-->LatentProc
LatentProc-->E
ObsModelPriors-->ObsModel
ObsModel-->E
E-->|sample...NUTS...| G
G-.->H
H-.->I
DataW-.->C
```
7 changes: 2 additions & 5 deletions EpiAware/src/EpiAware.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,18 @@ using Distributions,
DataFramesMeta

# Exported utilities
export create_discrete_pmf, default_rw_priors, default_delay_obs_priors,
default_initialisation_prior, spread_draws
export create_discrete_pmf, spread_draws

# Exported types
export EpiData, Renewal, ExpGrowthRate, DirectInfections

# Exported Turing model constructors
export make_epi_inference_model, delay_observations_model, random_walk_process,
initialize_incidence
export make_epi_inference_model

include("epimodel.jl")
include("utilities.jl")
include("latent-processes.jl")
include("observation-processes.jl")
include("initialisation.jl")
include("models.jl")

end
128 changes: 77 additions & 51 deletions EpiAware/src/epimodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,91 +2,117 @@ abstract type AbstractEpiModel end

struct EpiData{T <: Real, F <: Function}
gen_int::Vector{T}
delay_int::Vector{T}
delay_kernel::SparseMatrixCSC{T, Integer}
cluster_coeff::T
len_gen_int::Integer
len_delay_int::Integer
time_horizon::Integer
transformation::F

#Inner constructors for EpiData object
function EpiData(
gen_int,
delay_int,
cluster_coeff,
time_horizon::Integer,
transformation::Function
)
function EpiData(gen_int,
transformation::Function)
@assert all(gen_int .>= 0) "Generation interval must be non-negative"
@assert all(delay_int .>= 0) "Delay interval must be non-negative"
@assert sum(gen_int)1 "Generation interval must sum to 1"
@assert sum(delay_int)1 "Delay interval must sum to 1"

K = generate_observation_kernel(delay_int, time_horizon)

new{eltype(gen_int), typeof(transformation)}(
gen_int,
delay_int,
K,
cluster_coeff,
new{eltype(gen_int), typeof(transformation)}(gen_int,
length(gen_int),
length(delay_int),
time_horizon,
transformation
)
transformation)
end

function EpiData(
gen_distribution::ContinuousDistribution,
delay_distribution::ContinuousDistribution,
cluster_coeff,
time_horizon::Integer;
function EpiData(gen_distribution::ContinuousDistribution;
D_gen,
D_delay,
Δd = 1.0,
transformation::Function = exp
)
transformation::Function = exp)
gen_int = create_discrete_pmf(gen_distribution, Δd = Δd, D = D_gen) |>
p -> p[2:end] ./ sum(p[2:end])
delay_int = create_discrete_pmf(delay_distribution, Δd = Δd, D = D_delay)

return EpiData(gen_int, delay_int, cluster_coeff, time_horizon, transformation)
return EpiData(gen_int, transformation)
end
end

struct DirectInfections <: AbstractEpiModel
struct DirectInfections{S <: Sampleable} <: AbstractEpiModel
data::EpiData
initialisation_prior::S
end

function (epimodel::DirectInfections)(_It, init)
epimodel.data.transformation.(init .+ _It)
struct ExpGrowthRate{S <: Sampleable} <: AbstractEpiModel
data::EpiData
initialisation_prior::S
end

struct ExpGrowthRate <: AbstractEpiModel
struct Renewal{S <: Sampleable} <: AbstractEpiModel
data::EpiData
initialisation_prior::S
end

function (epimodel::ExpGrowthRate)(rt, init)
init .+ cumsum(rt) .|> exp
"""
function (epimodel::Renewal)(recent_incidence, Rt)
Compute new incidence based on recent incidence and Rt.
This is a callable function on `Renewal` structs, that encodes new incidence prediction
given recent incidence and Rt according to basic renewal process.
```math
I_t = R_t \\sum_{i=1}^{n-1} I_{t-i} g_i
```
where `I_t` is the new incidence, `R_t` is the reproduction number, `I_{t-i}` is the recent incidence
and `g_i` is the generation interval.
# Arguments
- `recent_incidence`: Array of recent incidence values.
- `Rt`: Reproduction number.
# Returns
- Tuple containing the updated incidence array and the new incidence value.
"""
function (epimodel::Renewal)(recent_incidence, Rt)
new_incidence = Rt * dot(recent_incidence, epimodel.data.gen_int)
return ([new_incidence; recent_incidence[1:(epimodel.data.len_gen_int - 1)]],
new_incidence)
end

struct Renewal <: AbstractEpiModel
data::EpiData
function generate_latent_infs(epimodel::AbstractEpiModel, latent_process)
@info "No concrete implementation for `generate_latent_infs` is defined."
return nothing
end

function (epimodel::Renewal)(_Rt, init)
I₀ = epimodel.data.transformation(init)
@model function generate_latent_infs(epimodel::DirectInfections, _It)
init_incidence ~ epimodel.initialisation_prior
return epimodel.data.transformation.(init_incidence .+ _It)
end

@model function generate_latent_infs(epimodel::ExpGrowthRate, rt)
init_incidence ~ epimodel.initialisation_prior
return exp.(init_incidence .+ cumsum(rt))
end

"""
generate_latent_infs(epimodel::Renewal, _Rt)
`Turing` model constructor for latent infections using the `Renewal` object `epimodel` and time-varying unconstrained reproduction number `_Rt`.
`generate_latent_infs` creates a `Turing` model for sampling latent infections with given unconstrainted
reproduction number `_Rt` but random initial incidence scale. The initial incidence pre-time one is given as
a scale on top of an exponential growing process with exponential growth rate given by `R_to_r`applied to the
first value of `Rt`.
# Arguments
- `epimodel::Renewal`: The epidemiological model.
- `_Rt`: Time-varying unconstrained (e.g. log-) reproduction number.
# Returns
- `I_t`: Array of latent infections over time.
"""
@model function generate_latent_infs(epimodel::Renewal, _Rt)
init_incidence ~ epimodel.initialisation_prior
I₀ = epimodel.data.transformation(init_incidence)
Rt = epimodel.data.transformation.(_Rt)

r_approx = R_to_r(Rt[1], epimodel)
init = I₀ * [exp(-r_approx * t) for t in 0:(epimodel.data.len_gen_int - 1)]

function generate_infs(recent_incidence, Rt)
new_incidence = Rt * dot(recent_incidence, epimodel.data.gen_int)
[new_incidence; recent_incidence[1:(epimodel.data.len_gen_int - 1)]], new_incidence
end

I_t, _ = scan(generate_infs, init, Rt)
I_t, _ = scan(epimodel, init, Rt)
return I_t
end
30 changes: 0 additions & 30 deletions EpiAware/src/initialisation.jl

This file was deleted.

Loading

0 comments on commit 8b9edf9

Please sign in to comment.