Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 465: Add an infection generating model for ODE problems #510

Merged
merged 54 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
196aa20
catch docstring now failing because output changed
SamuelBrand1 Oct 17, 2024
6acdb44
add OrdinaryDiffEq deps
SamuelBrand1 Oct 17, 2024
df2bed6
Add InfectionODEProcess
SamuelBrand1 Oct 17, 2024
01e11e7
add unit tests
SamuelBrand1 Oct 17, 2024
3d9d7eb
add benchmark
SamuelBrand1 Oct 17, 2024
08ea0cf
reformat
SamuelBrand1 Oct 17, 2024
d88279e
add line
SamuelBrand1 Oct 17, 2024
98dea81
Update InfectionODEProcess.jl
SamuelBrand1 Oct 17, 2024
c5f47ff
reformat
SamuelBrand1 Oct 18, 2024
5185bf5
rm prefix doctest
SamuelBrand1 Oct 22, 2024
be9b714
rename to ODEProcess
SamuelBrand1 Oct 22, 2024
062bbad
move docs to usual format
SamuelBrand1 Oct 22, 2024
7c4b430
Bring struct defn in line with our style
SamuelBrand1 Oct 23, 2024
e17b935
Merge branch 'main' into sciml-infection-model
SamuelBrand1 Oct 24, 2024
65948d9
Merge branch 'main' into sciml-infection-model
seabbs Oct 24, 2024
8f3fc4b
Merge branch 'sciml-infection-model' of https://github.com/CDCgov/Rt-…
SamuelBrand1 Oct 25, 2024
b35dc2f
move _expand_dist to Utils
SamuelBrand1 Oct 25, 2024
520cb67
Keep _expand_dist as an internal function
SamuelBrand1 Nov 1, 2024
835f3d2
Abstract types for Param types
SamuelBrand1 Nov 1, 2024
4f9f576
initial SIRParam commit
SamuelBrand1 Nov 4, 2024
91b8351
SEIR Params initial commit
SamuelBrand1 Nov 4, 2024
1514562
add kwarg passer to ODEProcess
SamuelBrand1 Nov 4, 2024
dcc47ca
Update EpiInfModels.jl
SamuelBrand1 Nov 4, 2024
dcb2fad
add jac_prototype
SamuelBrand1 Nov 4, 2024
58ac869
Param constructor unit tests
SamuelBrand1 Nov 4, 2024
2e368d9
rm old unit test
SamuelBrand1 Nov 4, 2024
0f63eb2
Make jac prototypes const valued
SamuelBrand1 Nov 4, 2024
1406ecc
test that jac_prototype is stable over jac calls
SamuelBrand1 Nov 4, 2024
86e46b8
make solver_options more flexible
SamuelBrand1 Nov 4, 2024
92b88b9
update unit test for ODEProcess
SamuelBrand1 Nov 4, 2024
2d1e1c4
switch to doc raw
SamuelBrand1 Nov 4, 2024
3f24571
Don't pass jac_prototype
SamuelBrand1 Nov 4, 2024
d14f2ff
docstrings for ODEProcess
SamuelBrand1 Nov 4, 2024
560e272
update generate_latent_infs doc strings.
SamuelBrand1 Nov 4, 2024
2f8b12e
remove undefined export
SamuelBrand1 Nov 4, 2024
62d5476
refactor doc string example
SamuelBrand1 Nov 5, 2024
bc701a6
refactor other docstring examples
SamuelBrand1 Nov 5, 2024
d079fb6
Rm AbstractParam type + ODEParam is latent model
SamuelBrand1 Nov 11, 2024
dc85052
fix aqua
SamuelBrand1 Nov 11, 2024
1dd1f09
fix other undefined export
SamuelBrand1 Nov 11, 2024
4b90bbd
Merge branch 'main' into sciml-infection-model
seabbs Nov 11, 2024
d5373fe
remove broken benchmark
SamuelBrand1 Nov 11, 2024
b39af84
catch rogue dep call
SamuelBrand1 Nov 11, 2024
34a56ea
Merge branch 'main' into sciml-infection-model
seabbs Nov 14, 2024
b17b6a9
revert _expand_dist
SamuelBrand1 Nov 14, 2024
bc52a4a
rm include
SamuelBrand1 Nov 14, 2024
feeb80a
refactor Z-t -> n
SamuelBrand1 Nov 14, 2024
a1b0472
pass through Z_t
SamuelBrand1 Nov 14, 2024
f68a0cf
start going to jldoctest
SamuelBrand1 Nov 14, 2024
149fb0d
docs example code to jldoctest
SamuelBrand1 Nov 15, 2024
9e6df0a
change to jldoctest
SamuelBrand1 Nov 15, 2024
f7a8a70
add a resolve step to benchmark
SamuelBrand1 Nov 15, 2024
f91d452
up benchmark manifest
SamuelBrand1 Nov 15, 2024
85a01e1
Update benchmark.yaml
SamuelBrand1 Nov 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions EpiAware/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Pathfinder = "b1d3bc72-d0e7-4279-b92f-7fa5d6d2d454"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -36,6 +37,7 @@ FillArrays = "1.11"
LinearAlgebra = ">= 1.9"
LogExpFunctions = "0.3"
MCMCChains = "6.0"
OrdinaryDiffEq = "6.89.0"
Pathfinder = "0.8, 0.9"
QuadGK = "2.9"
Random = ">= 1.9"
Expand Down
2 changes: 1 addition & 1 deletion EpiAware/src/EpiAwareUtils/prefix_submodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ submodel = prefix_submodel(FixedIntercept(0.1), generate_latent, string(1), 2)

We can now draw a sample from the submodel.

```julia
```@example
rand(submodel)
```
"
Expand Down
10 changes: 8 additions & 2 deletions EpiAware/src/EpiInfModels/EpiInfModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,15 @@ module EpiInfModels
using ..EpiAwareBase
using ..EpiAwareUtils

using Turing, Distributions, DocStringExtensions, LinearAlgebra, LogExpFunctions
using LogExpFunctions: xexpy

using Turing, Distributions, DocStringExtensions, LinearAlgebra, OrdinaryDiffEq

#Export parameter helpers
export EpiData

#Export models
export EpiData, DirectInfections, ExpGrowthRate, Renewal
export DirectInfections, ExpGrowthRate, Renewal, ODEProcess

#Export functions
export R_to_r, r_to_R, expected_Rt
Expand All @@ -20,6 +25,7 @@ include("DirectInfections.jl")
include("ExpGrowthRate.jl")
include("RenewalSteps.jl")
include("Renewal.jl")
include("ODEProcess.jl")
include("utils.jl")

end
306 changes: 306 additions & 0 deletions EpiAware/src/EpiInfModels/ODEProcess.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,306 @@
@doc raw"""
A structure representing an infection process modeled by an Ordinary Differential Equation (ODE).
At a high level, an `ODEProcess` struct object combines:

- An `AbstractTuringParamModel` which defines the ODE model in terms of `OrdinaryDiffEq` types,
the parameters of the ODE model and a method to generate the parameters.
- A technique for solving and interpreting the ODE model using the `SciML` ecosystem. This includes
the solver used in the ODE solution, keyword arguments to send to the solver and a function
to map the `ODESolution` solution object to latent infections.

# Constructors
- `ODEProcess(prob::ODEProblem; ts, solver, sol2infs)`: Create an `ODEProcess`
object with the ODE problem `prob`, time points `ts`, solver `solver`, and function `sol2infs`.

# Predefined ODE models
Two basic ODE models are provided in the `EpiAware` package: `SIRParams` and `SEIRParams`.
In both cases these are defined in terms of the proportions of the population in each compartment
of the SIR and SEIR models respectively.

## SIR model

```math
\begin{aligned}
\frac{dS}{dt} &= -\beta SI \\
\frac{dI}{dt} &= \beta SI - \gamma I \\
\frac{dR}{dt} &= \gamma I
\end{aligned}
```
Where `S` is the proportion of the population that is susceptible, `I` is the proportion of the
population that is infected and `R` is the proportion of the population that is recovered. The
parameters are the infectiousness `β` and the recovery rate `γ`.

```jldoctest sirexample; output = false
using EpiAware, OrdinaryDiffEq, Distributions

# Create an instance of SIRParams
sirparams = SIRParams(
tspan = (0.0, 100.0),
infectiousness = LogNormal(log(0.3), 0.05),
recovery_rate = LogNormal(log(0.1), 0.05),
initial_prop_infected = Beta(1, 99)
)
nothing

# output

```

## SEIR model

```math
\begin{aligned}
\frac{dS}{dt} &= -\beta SI \\
\frac{dE}{dt} &= \beta SI - \alpha E \\
\frac{dI}{dt} &= \alpha E - \gamma I \\
\frac{dR}{dt} &= \gamma I
\end{aligned}
```
Where `S` is the proportion of the population that is susceptible, `E` is the proportion of the
population that is exposed, `I` is the proportion of the population that is infected and `R` is
the proportion of the population that is recovered. The parameters are the infectiousness `β`,
the incubation rate `α` and the recovery rate `γ`.

```jldoctest; output = false
using EpiAware, OrdinaryDiffEq, Distributions, Random
Random.seed!(1234)

# Create an instance of SIRParams
seirparams = SEIRParams(
tspan = (0.0, 100.0),
infectiousness = LogNormal(log(0.3), 0.05),
incubation_rate = LogNormal(log(0.1), 0.05),
recovery_rate = LogNormal(log(0.1), 0.05),
initial_prop_infected = Beta(1, 99)
)
nothing

# output

```

# Usage example with `ODEProcess` and predefined SIR model

In this example we define an `ODEProcess` object using the predefined `SIRParams` model from
above. We then generate latent infections using the `generate_latent_infs` function, and refit
the model using a `Turing` model.

We assume that the latent infections are observed with a Poisson likelihood around their
ODE model prediction. The population size is `N = 1000`, which we put into the `sol2infs` function,
which maps the ODE solution to the number of infections. Recall that the `EpiAware` default SIR
implementation assumes the model is in density/proportion form. Also, note that since the `sol2infs`
function is a link function that maps the ODE solution to the expected number of infections we also
apply the `LogExpFunctions.softplus` function to ensure that the expected number of infections is non-negative.
Note that the `softplus` function is a smooth approximation to the ReLU function `x -> max(0, x)`.
The utility of this approach is that small negative output from the ODE solver (e.g. ~ -1e-10) will be
mapped to small positive values, without needing to use strict positivity constraints in the model.

First, we define the `ODEProcess` object which combines the SIR model with the `sol2infs` link
function and the solver options.

```jldoctest sirexample; output = false
using Turing, LogExpFunctions
N = 1000.0

sir_process = ODEProcess(
params = sirparams,
sol2infs = sol -> softplus.(N .* sol[2, :]),
solver_options = Dict(:verbose => false, :saveat => 1.0)
)
nothing

# output

```

Second, we define a `PoissionError` observation model for linking the the number of infections.

```jldoctest sirexample; output = false
pois_obs = PoissonError()
nothing

# output

```

Next, we create a `Turing` model for the full generative process: this solves the ODE model for
the latent infections and then samples the observed infections from a Poisson distribution with this
as the average.

NB: The `nothing` argument is a dummy latent process, e.g. a log-Rt time series, that is not
used in the SIR model, but might be used in other models.

```jldoctest sirexample; output = false
@model function fit_ode_model(data)
@submodel I_t = generate_latent_infs(sir_process, nothing)
@submodel y_t = generate_observations(pois_obs, data, I_t)

return y_t
end
nothing

# output

```

We can generate some test data from the model by passing `missing` as the argument to the model.
This tells `Turing` that there is no data to condition on, so it will sample from the prior parameters
and then generate infections. In this case, we do it in a way where we cache the sampled parameters
as `θ` for later use.

```jldoctest sirexample; output = false
# Sampled parameters
gen_mdl = fit_ode_model(missing)
θ = rand(gen_mdl)
test_data = (gen_mdl | θ)()
nothing

# output

```

Now, we can refit the model but this time we condition on the test data. We suppress the
output of the sampling process to keep the output clean, but you can remove the `@suppress` macro.

```jldoctest sirexample; output = false
using Suppressor
inference_mdl = fit_ode_model(test_data)
chn = Suppressor.@suppress sample(inference_mdl, NUTS(), 2_000)
summarize(chn)
nothing

# output

```

We can compare the summarized chain to the sampled parameters in `θ` to see that the model is
fitting the data well and recovering a credible interval containing the true parameters.

# Custom ODE models

To define a custom ODE model, you need to define:

- Some `CustomModel <: AbstractTuringLatentModel` struct
that contains the ODE problem as a field called `prob`, as well as sufficient fields to
define or sample the parameters of the ODE model.
- A method for `EpiAwareBase.generate_latent(params::CustomModel, Z_t)` that generates the
initial condition and parameters of the ODE model, potentially conditional on a sample from a latent process `Z_t`.
This method must return a `Tuple` `(u0, p)` where `u0` is the initial condition and `p` is the parameters.

Here is an example of a simple custom ODE model for _specified_ exponential growth:

```jldoctest customexample; output = false
using EpiAware, Turing, OrdinaryDiffEq
# Define a simple exponential growth model for testing
function expgrowth(du, u, p, t)
du[1] = p[1] * u[1]
end

r = log(2) / 7 # Growth rate corresponding to 7 day doubling time

# Define the ODE problem using SciML
prob = ODEProblem(expgrowth, [1.0], (0.0, 10.0), [r])

# Define the custom parameters struct
struct CustomModel <: AbstractTuringLatentModel
prob::ODEProblem
r::Float64
u0::Float64
end
custom_ode = CustomModel(prob, r, 1.0)

# Define the custom generate_latent function
@model function EpiAwareBase.generate_latent(params::CustomModel, n)
return ([params.u0], [params.r])
end
nothing

# output

```

This model is not random! But we can still use it to generate latent infections.

```jldoctest customexample; output = false
# Define the ODEProcess
expgrowth_model = ODEProcess(
params = custom_ode,
sol2infs = sol -> sol[1, :]
)
infs = generate_latent_infs(expgrowth_model, nothing)()
nothing

# output

```
"""
@kwdef struct ODEProcess{
P <: AbstractTuringLatentModel, S, F <: Function, D <:
Union{Dict, NamedTuple}} <:
EpiAwareBase.AbstractTuringEpiModel
"The ODE problem and parameters, where `P` is a subtype of `AbstractTuringLatentModel`."
params::P
"The solver used for the ODE problem. Default is `AutoVern7(Rodas5())`, which is an auto
switching solver aimed at medium/low tolerances."
solver::S = AutoVern7(Rodas5())
"A function that maps the solution object of the ODE to infection counts."
sol2infs::F
"The extra solver options for the ODE problem. Can be either a `Dict` or a `NamedTuple`
containing the solver options."
solver_options::D = Dict(:verbose => false, :saveat => 1.0)
end

@doc raw"""
Implement the `generate_latent_infs` function for the `ODEProcess` model.

This function remakes the ODE problem with the provided initial conditions and parameters,
solves it using the specified solver, and then transforms the solution into latent infections
using the `sol2infs` function.

# Example usage with predefined SIR model

In this example we define an `ODEProcess` object using the predefined `SIRParams` model and
generate an expected infection time series using SIR model parameters sampled from their priors.

```jldoctest; output = false
using EpiAware, OrdinaryDiffEq, Distributions, Turing, LogExpFunctions

# Create an instance of SIRParams
sirparams = SIRParams(
tspan = (0.0, 100.0),
infectiousness = LogNormal(log(0.3), 0.05),
recovery_rate = LogNormal(log(0.1), 0.05),
initial_prop_infected = Beta(1, 99)
)

#Population size

N = 1000.0

sir_process = ODEProcess(
params = sirparams,
sol2infs = sol -> softplus.(N .* sol[2, :]),
solver_options = Dict(:verbose => false, :saveat => 1.0)
)

generated_It = generate_latent_infs(sir_process, nothing)()
nothing

# output

```

"""
@model function EpiAwareBase.generate_latent_infs(epi_model::ODEProcess, Z_t)
prob, solver, sol2infs, solver_options = epi_model.params.prob,
epi_model.solver, epi_model.sol2infs, epi_model.solver_options
n = isnothing(Z_t) ? 0 : size(Z_t, 1)

@submodel u0, p = generate_latent(epi_model.params, n)

_prob = remake(prob; u0 = u0, p = p)
sol = solve(_prob, solver; solver_options...)
I_t = sol2infs(sol)

return I_t
end
10 changes: 8 additions & 2 deletions EpiAware/src/EpiLatentModels/EpiLatentModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,15 @@ using LogExpFunctions: softmax

using FillArrays: Fill

using Turing, Distributions, DocStringExtensions, LinearAlgebra
using Turing, Distributions, DocStringExtensions, LinearAlgebra, SparseArrays,
OrdinaryDiffEq

#Export models
export FixedIntercept, Intercept, RandomWalk, AR, HierarchicalNormal

#Export ODE definitions
export SIRParams, SEIRParams

# Export tools for manipulating latent models
export CombineLatentModels, ConcatLatentModels, BroadcastLatentModel

Expand All @@ -29,10 +33,13 @@ export broadcast_rule, broadcast_dayofweek, broadcast_weekly, equal_dimensions
export DiffLatentModel, TransformLatentModel, PrefixLatentModel, RecordExpectedLatent

include("docstrings.jl")
include("utils.jl")
include("models/Intercept.jl")
include("models/RandomWalk.jl")
include("models/AR.jl")
include("models/HierarchicalNormal.jl")
include("odemodels/SIRParams.jl")
include("odemodels/SEIRParams.jl")
include("modifiers/DiffLatentModel.jl")
include("modifiers/TransformLatentModel.jl")
include("modifiers/PrefixLatentModel.jl")
Expand All @@ -42,6 +49,5 @@ include("manipulators/ConcatLatentModels.jl")
include("manipulators/broadcast/LatentModel.jl")
include("manipulators/broadcast/rules.jl")
include("manipulators/broadcast/helpers.jl")
include("utils.jl")

end
Loading
Loading