Skip to content

Commit

Permalink
Set up bracketing for spatial solvers.
Browse files Browse the repository at this point in the history
This is in preparation for the spatial SSA RSSACRDirect.
  • Loading branch information
Vilin97 committed Sep 10, 2023
1 parent 52ded9a commit 3f1fd74
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 3 deletions.
1 change: 1 addition & 0 deletions src/JumpProcesses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ include("spatial/hop_rates.jl")
include("spatial/reaction_rates.jl")
include("spatial/flatten.jl")
include("spatial/utils.jl")
include("spatial/bracketing.jl")

include("spatial/nsm.jl")
include("spatial/directcrdirect.jl")
Expand Down
46 changes: 46 additions & 0 deletions src/spatial/bracketing.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
################## Spatial bracketing ##################

# struct to store brackets like (ulow, uhigh), (rx_rates_low, rx_rates_high), (hop_rates_low, hop_rates_high), (site_rates_low, site_rates_high)
struct LowHigh{T}
low::T
high::T

LowHigh(low::T, high::T) where {T} = new{T}(deepcopy(low), deepcopy(high))
LowHigh(pair::Tuple{T,T}) where {T} = new{T}(pair[1], pair[2])
LowHigh(low_and_high::T) where {T} = new{T}(low_and_high, deepcopy(low_and_high))
end

function Base.show(io::IO, ::MIME"text/plain", low_high::LowHigh)
println(io, "Low: \n $(low_high.low)")
println(io, "High: \n $(low_high.high)")

Check warning on line 15 in src/spatial/bracketing.jl

View check run for this annotation

Codecov / codecov/patch

src/spatial/bracketing.jl#L13-L15

Added lines #L13 - L15 were not covered by tests
end

@inline function update_u_brackets!(u_low_high::LowHigh, bracket_data, u::AbstractMatrix)
@inbounds for (i, uval) in enumerate(u)
u_low_high[i] = LowHigh(get_spec_brackets(bracket_data, i, uval))
end
nothing
end

### convenience functions for LowHigh ###
function setindex!(low_high::LowHigh, val::LowHigh, i)
low_high.low[i] = val.low
low_high.high[i] = val.high
val
end

function total_site_rate(rx_rates::LowHigh, hop_rates::LowHigh, site)
return LowHigh(
total_site_rate(rx_rates.low, hop_rates.low, site),
total_site_rate(rx_rates.high, hop_rates.high, site))
end

function update_rx_rates!(rx_rates::LowHigh, rxs, u_low_high, integrator, site)
update_rx_rates!(rx_rates.low, rxs, u_low_high.low, integrator, site)
update_rx_rates!(rx_rates.high, rxs, u_low_high.high, integrator, site)
end

function update_hop_rates!(hop_rates::LowHigh, species, u_low_high, site, spatial_system)
update_hop_rates!(hop_rates.low, species, u_low_high.low, site, spatial_system)
update_hop_rates!(hop_rates.high, species, u_low_high.high, site, spatial_system)
end
11 changes: 8 additions & 3 deletions src/spatial/reaction_rates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,21 @@ end
update rates of all reactions in rxs at site
"""
function update_rx_rates!(rx_rates::RxRates{F, M}, rxs, integrator,
site) where {F, M <: AbstractMassActionJump}
u = integrator.u
function update_rx_rates!(rx_rates::RxRates{F, M}, rxs, u::AbstractMatrix, integrator,
site) where {F, M}
ma_jumps = rx_rates.ma_jumps
@inbounds for rx in rxs
rate = eval_massaction_rate(u, rx, ma_jumps, site)
set_rx_rate_at_site!(rx_rates, site, rx, rate)
end
end

function update_rx_rates!(rx_rates::RxRates{F, M}, rxs, integrator,
site) where {F, M <: AbstractMassActionJump}
u = integrator.u
update_rx_rates!(rx_rates, rxs, u, integrator, site)
end

"""
sample_rx_at_site(rx_rates::RxRates, site, rng)
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ using JumpProcesses, DiffEqBase, SafeTestsets
@time @safetestset "Reaction rates" begin include("spatial/reaction_rates.jl") end
@time @safetestset "Hop rates" begin include("spatial/hop_rates.jl") end
@time @safetestset "Topology" begin include("spatial/topology.jl") end
@time @safetestset "Spatial bracketing Tests" begin include("spatial/bracketing.jl") end
@time @safetestset "Spatial A + B <--> C" begin include("spatial/ABC.jl") end
@time @safetestset "Spatially Varying Reaction Rates" begin include("spatial/spatial_majump.jl") end
@time @safetestset "Pure diffusion" begin include("spatial/diffusion.jl") end
Expand Down
53 changes: 53 additions & 0 deletions test/spatial/bracketing.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
using JumpProcesses, Test
const JP = JumpProcesses

fluctuation_rate = 0.1
threshold = 25
Δu = 4
bracket_data = BracketData(fluctuation_rate, threshold, Δu)

n = 3 # number of sites

# set up spatial system
spatial_system = CartesianGrid((n,)) # n sites
site_rates = JP.LowHigh(zeros(n), zeros(n))

# set up reaction rates
majump_rates = [0.1] # death at rate 0.1
reactstoch = [[1 => 1]]
netstoch = [[1 => -1]]
majump = MassActionJump(majump_rates, reactstoch,
netstoch)
rx_rates = JP.LowHigh(JP.RxRates(n, majump))

# set up hop rates
hop_constants = [1.0]
hop_rates = JP.LowHigh(JP.HopRates(hop_constants, spatial_system))

# set up species brackets
u = 100*ones(Int, 1, n) # 2 species, n sites
u_low_high = JP.LowHigh(u, u)
JP.update_u_brackets!(u_low_high, bracket_data, u)

# update reaction rates, hop rates and site rates
rxs = [1] # vector of all reactions
species_vec = [1] # vector of all species
integrator = Nothing # only needed for constant rate jumps
for site in 1:num_sites(spatial_system)
JP.update_rx_rates!(rx_rates, rxs, u_low_high, integrator, site)
JP.update_hop_rates!(hop_rates, species_vec, u_low_high, site, spatial_system)
site_rates[site] = JP.total_site_rate(rx_rates, hop_rates, site)
end

# test species brackets
@test u_low_high.low[1, 1] u[1, 1] * (1 - fluctuation_rate) atol = 1
@test u_low_high.high[1, 1] u[1, 1] * (1 + fluctuation_rate) atol = 1

# test site rate brackets
site = 1
rx = 1
species = 1
@test JP.total_site_rx_rate(rx_rates.low, site) == majump_rates[rx] * u_low_high.low[species, site]
@test JP.total_site_rx_rate(rx_rates.high, site) == majump_rates[rx] * u_low_high.high[species, site]
@test JP.total_site_hop_rate(hop_rates.low, site) == hop_constants[site] * u_low_high.low[species, site]
@test JP.total_site_hop_rate(hop_rates.high, site) == hop_constants[site] * u_low_high.high[species, site]
1 change: 1 addition & 0 deletions test/spatial/run_spatial_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using JumpProcesses, DiffEqBase, SafeTestsets
@time @safetestset "Reaction rates" begin include("reaction_rates.jl") end
@time @safetestset "Hop rates" begin include("hop_rates.jl") end
@time @safetestset "Topology" begin include("topology.jl") end
@time @safetestset "Bracketing" begin include("bracketing.jl") end
@time @safetestset "Spatial A + B <--> C" begin include("ABC.jl") end
@time @safetestset "Pure diffusion" begin include("diffusion.jl") end
@time @safetestset "Spatially Varying Reaction Rates" begin include("spatial_majump.jl") end
Expand Down

0 comments on commit 3f1fd74

Please sign in to comment.