Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ficticious jump algorithm for time-dependent variable rate jumps. #252

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/JumpProcesses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ include("aggregators/prioritytable.jl")
include("aggregators/directcr.jl")
include("aggregators/rssacr.jl")
include("aggregators/rdirect.jl")
include("aggregators/extrande.jl")
include("aggregators/coevolve.jl")

# spatial:
Expand Down Expand Up @@ -84,6 +85,7 @@ export Direct, DirectFW, SortingDirect, DirectCR
export BracketData, RSSA
export FRM, FRMFW, NRM
export RSSACR, RDirect
export Extrande
export Coevolve

export get_num_majumps, needs_depgraph, needs_vartojumps_map
Expand Down
13 changes: 12 additions & 1 deletion src/aggregators/aggregators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,18 @@ algorithm with optimal binning, Journal of Chemical Physics 143, 074108
"""
struct DirectCRDirect <: AbstractAggregatorAlgorithm end

"""
The Extrande method for simulating variable rate jumps with user-defined bounds
on jumps rates and validity intervals via rejection.

Stochastic Simulation of Biomolecular Networks in Dynamic Environments, Voliotis
M, Thomas P, Grima R, Bowsher CG, PLOS Computational Biology 12(6): e1004923.
(2016); doi.org/10.1371/journal.pcbi.1004923
"""
struct Extrande <: AbstractAggregatorAlgorithm end

const JUMP_AGGREGATORS = (Direct(), DirectFW(), DirectCR(), SortingDirect(), RSSA(), FRM(),
FRMFW(), NRM(), RSSACR(), RDirect(), Coevolve())
FRMFW(), NRM(), RSSACR(), RDirect(), Coevolve(), Extrande())

# For JumpProblem construction without an aggregator
struct NullAggregator <: AbstractAggregatorAlgorithm end
Expand All @@ -181,6 +191,7 @@ needs_vartojumps_map(aggregator::RSSACR) = true
# true if aggregator supports variable rates
supports_variablerates(aggregator::AbstractAggregatorAlgorithm) = false
supports_variablerates(aggregator::Coevolve) = true
supports_variablerates(aggregator::Extrande) = true

is_spatial(aggregator::AbstractAggregatorAlgorithm) = false
is_spatial(aggregator::NSM) = true
Expand Down
122 changes: 122 additions & 0 deletions src/aggregators/extrande.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Define the aggregator.
struct Extrande <: AbstractAggregatorAlgorithm end

"""
Extrande sampling method for jumps with defined rate bounds.
"""

nullaffect!(integrator) = nothing
const NullAffectJump = ConstantRateJump((u, p, t) -> 0.0, nullaffect!)

mutable struct ExtrandeJumpAggregation{T, S, F1, F2, F3, F4, RNG} <:
AbstractSSAJumpAggregator
next_jump::Int
prev_jump::Int
next_jump_time::T
end_time::T
cur_rates::Vector{T}
sum_rate::T
ma_jumps::S
rate_bnds::F3
wds::F4
rates::F1
affects!::F2
save_positions::Tuple{Bool, Bool}
rng::RNG
end

function ExtrandeJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S,
rs::F1, affs!::F2, sps::Tuple{Bool, Bool}, rng::RNG;
rate_bounds::F3, windows::F4,
kwargs...) where {T, S, F1, F2, F3, F4, RNG}
ExtrandeJumpAggregation{T, S, F1, F2, F3, F4, RNG}(nj, nj, njt, et, crs, sr, maj,
rate_bounds, windows, rs, affs!, sps,
rng)
end

############################# Required Functions ##############################
function aggregate(aggregator::Extrande, u, p, t, end_time, constant_jumps,
ma_jumps, save_positions, rng; variable_jumps = (), kwargs...)
rates, affects! = get_jump_info_fwrappers(u, p, t,
(constant_jumps..., variable_jumps...,
NullAffectJump))
rbnds, wnds = get_va_jump_bound_info_fwrapper(u, p, t,
(constant_jumps..., variable_jumps...,
NullAffectJump))
build_jump_aggregation(ExtrandeJumpAggregation, u, p, t, end_time, ma_jumps,
rates, affects!, save_positions, rng; u = u, rate_bounds = rbnds,
windows = wnds, kwargs...)
end

# set up a new simulation and calculate the first jump / jump time
function initialize!(p::ExtrandeJumpAggregation, integrator, u, params, t)
p.end_time = integrator.sol.prob.tspan[2]
generate_jumps!(p, integrator, u, params, t)
end

# execute one jump, changing the system state
@inline function execute_jumps!(p::ExtrandeJumpAggregation, integrator, u, params, t)
# execute jump
u = update_state!(p, integrator, u)
nothing
end

@fastmath function next_extrande_jump(p::ExtrandeJumpAggregation, u, params, t)
ttnj = typemax(typeof(t))
Wmin = typemax(typeof(t))
Bmax = zero(t)

prev_rate = zero(t)
new_rate = zero(t)
cur_rates = p.cur_rates

# Mass action rates
majumps = p.ma_jumps
idx = get_num_majumps(majumps)

@inbounds for i in 1:idx
new_rate = evalrxrate(u, i, majumps)
cur_rates[i] = add_fast(new_rate, prev_rate)
prev_rate = cur_rates[i]
Bmax += prev_rate
end

# Calculate the total rate bound and the largest common validity window.
if !isempty(p.rate_bnds)
@inbounds for i in 1:length(p.wds)
Wmin = min(Wmin, p.wds[i](u, params, t))
Bmax += p.rate_bnds[i](u, params, t)
end
end

# Rejection sampling.
nextrx = length(cur_rates)
prop_ttnj = randexp(p.rng) / Bmax
if prop_ttnj < Wmin
if !isempty(p.rates)
idx += 1
fill_cur_rates(u, params, prop_ttnj + t, p.cur_rates, idx, p.rates...)
@inbounds for i in idx:length(cur_rates)
cur_rates[i] = add_fast(cur_rates[i], prev_rate)
prev_rate = cur_rates[i]
end
end
UBmax = rand(p.rng) * Bmax
ttnj = prop_ttnj
if p.cur_rates[end] ≥ UBmax
nextrx = searchsortedfirst(p.cur_rates, UBmax)
end
else
ttnj = Wmin
end

return nextrx, ttnj
end

function generate_jumps!(p::ExtrandeJumpAggregation, integrator, u, params, t)
nextexj, ttnexj = next_extrande_jump(p, u, params, t)
p.next_jump = nextexj
p.next_jump_time = t + ttnexj

nothing
end
24 changes: 24 additions & 0 deletions src/jumps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -702,3 +702,27 @@ function get_jump_info_fwrappers(u, p, t, constant_jumps)

rates, affects!
end

##### helpers for splitting variable rate jumps with rate bounds and without #####

function rate_window_function(jump)
# Assumes that if no window is given the rate bound is valid for all times.
return !(jump.rateinterval isa Nothing) ? jump.rateinterval : (u, p, t) -> Inf
end

function get_va_jump_bound_info_fwrapper(u, p, t, jumps)
RateWrapper = FunctionWrappers.FunctionWrapper{typeof(t),
Tuple{typeof(u), typeof(p), typeof(t)}}

if (jumps !== nothing) && !isempty(jumps)
rates = [j isa VariableRateJump ? RateWrapper(j.urate) : RateWrapper(j.rate)
for j in jumps]
wnds = [j isa VariableRateJump ? RateWrapper(rate_window_function(j)) :
RateWrapper((u, p, t) -> Inf) for j in jumps]
else
rates = Vector{RateWrapper}()
wnds = Vector{RateWrapper}()
end

rates, wnds
end
2 changes: 1 addition & 1 deletion test/degenerate_rx_cases.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ doprint = false
doplot = false

methods = (RDirect(), RSSACR(), Direct(), DirectFW(), FRM(), FRMFW(), SortingDirect(),
NRM(), RSSA(), DirectCR(), Coevolve())
NRM(), RSSA(), DirectCR(), Coevolve(), Extrande())

# one reaction case, mass action jump, vector of data
rate = [2.0]
Expand Down
74 changes: 74 additions & 0 deletions test/extrande.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
using DiffEqBase, JumpProcesses, OrdinaryDiffEq, Test
using StableRNGs
using Statistics
rng = StableRNG(48572)

f = function (du, u, p, t)
du[1] = 0.0
end

rate = (u, p, t) -> t < 5.0 ? 1.0 : 0.0
rbound = (u, p, t) -> 1.0
rinterval = (u, p, t) -> Inf
affect! = (integrator) -> (integrator.u[1] = integrator.u[1] + 1)
jump = VariableRateJump(rate, affect!; urate = rbound, rateinterval = rinterval)

prob = ODEProblem(f, [0.0], (0.0, 10.0))
jump_prob = JumpProblem(prob, Extrande(), jump; rng = rng)

# Test that process doesn't jump when rate switches to 0.
sol = solve(jump_prob, Tsit5())
@test sol(5.0)[1] == sol[end][1]

# Birth-death process with time-varying birth rates.
Nsims = 1000000
u0 = [10.0]

function runsimulations(jump_prob, testts)
Psamp = zeros(Int, length(testts), Nsims)
for i in 1:Nsims
sol_ = solve(jump_prob, Tsit5())
Psamp[:, i] = getindex.(sol_(testts).u, 1)
end
mean(Psamp, dims = 2)
end

# Variable rate birth jumps.
rateb = (u, p, t) -> (0.1 * sin(t) + 0.2)
ratebbound = (u, p, t) -> 0.3
ratebwindow = (u, p, t) -> Inf
affectb! = (integrator) -> (integrator.u[1] = integrator.u[1] + 1)
jumpb = VariableRateJump(rateb, affectb!; urate = ratebbound, rateinterval = ratebwindow)

# Constant rate death jumps.
rated = (u, p, t) -> u[1] * 0.08
affectd! = (integrator) -> (integrator.u[1] = integrator.u[1] - 1)
jumpd = ConstantRateJump(rated, affectd!)

# Problem definition.
bd_prob = ODEProblem(f, u0, (0.0, 2pi))
jump_bd_prob = JumpProblem(bd_prob, Extrande(), jumpb, jumpd)

test_times = range(1.0, stop = 2pi, length = 3)
means = runsimulations(jump_bd_prob, test_times)

# ODE for the mean.
fu = function (du, u, p, t)
du[1] = (0.1 * sin(t) + 0.2) - (u[1] * 0.08)
end

ode_prob = ODEProblem(fu, u0, (0.0, 2 * pi))
ode_sol = solve(ode_prob, Tsit5())

# Test extrande against the ODE mean.
@test prod(isapprox.(means, getindex.(ode_sol(test_times).u, 1), rtol = 1e-3))

# Make sure interfaces correctly with Mass Action Jumps.
reactant_stoich = [[1 => 1]]
net_stoich = [[1 => -1]]
majd = MassActionJump(reactant_stoich, net_stoich; param_idxs = [1])
bmajd_prob = ODEProblem(f, u0, (0.0, 2pi), [0.08])
jump_bmajd_prob = JumpProblem(bmajd_prob, Extrande(), jumpb, majd)

means_mass_action = runsimulations(jump_bmajd_prob, test_times)
@test prod(isapprox.(means_mass_action, getindex.(ode_sol(test_times).u, 1), rtol = 1e-3))
4 changes: 2 additions & 2 deletions test/hawkes_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ h = [Float64[]]

Eλ, Varλ = expected_stats_hawkes_problem(p, tspan)

algs = (Direct(), Coevolve(), Coevolve())
algs = (Direct(), Coevolve(), Coevolve(), Extrande())
uselrate = zeros(Bool, length(algs))
uselrate[3] = true
Nsims = 250
Expand All @@ -122,7 +122,7 @@ for (i, alg) in enumerate(algs)
reset_history!(h)
sols[n] = solve(jump_prob, stepper)
end
if typeof(alg) <: Coevolve
if typeof(alg) <: Union{Coevolve, Extrande}
λs = permutedims(mapreduce((sol) -> empirical_rate(sol), hcat, sols))
else
cols = length(sols[1].u[1].u)
Expand Down
2 changes: 1 addition & 1 deletion test/linearreaction_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ tf = 0.1
baserate = 0.1
A0 = 100
exactmean = (t, ratevec) -> A0 * exp(-sum(ratevec) * t)
SSAalgs = [RSSACR(), Direct(), RSSA()]
SSAalgs = [RSSACR(), Direct(), RSSA(), Extrande()]

spec_to_dep_jumps = [collect(1:Nrxs), []]
jump_to_dep_specs = [[1, 2] for i in 1:Nrxs]
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,5 @@ using JumpProcesses, DiffEqBase, SafeTestsets
@time @safetestset "Spatial A + B <--> C" begin include("spatial/ABC.jl") end
@time @safetestset "Spatially Varying Reaction Rates" begin include("spatial/spatial_majump.jl") end
@time @safetestset "Pure diffusion" begin include("spatial/diffusion.jl") end
@time @safetestset "Ficticious Jump " begin include("extrande.jl") end
end