Skip to content

Commit

Permalink
Fix the main part of the SSA code. Time to clean up.
Browse files Browse the repository at this point in the history
  • Loading branch information
Vilin97 committed Sep 11, 2023
1 parent 7861bc7 commit e36f466
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 52 deletions.
4 changes: 2 additions & 2 deletions src/spatial/bracketing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ end
nothing
end

function is_outside_brackets(u_low_high::LowHigh{M}, u::M, species, site) where {M}
return u[species, site] < u_low_high.low[species, site] || u[species, site] > u_low_high.high[species, site]
function is_inside_brackets(u_low_high::LowHigh{M}, u::M, species, site) where {M}
return u_low_high.low[species, site] < u[species, site] < u_low_high.high[species, site]
end

### convenience functions for LowHigh ###
Expand Down
9 changes: 1 addition & 8 deletions src/spatial/hop_rates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,7 @@ function update_hop_rates!(hop_rates::AbstractHopRates, species_vec, u, site, sp
end
end

function recompute_site_hop_rate(hop_rates::HP, u, site, spatial_system) where {HP <: AbstractHopRates}
rate = zero(eltype(hop_rates.rates))
num_species = size(hop_rates.rates, 1)
for species in 1:num_species
rate += evalhoprate(hop_rates, u, species, site, spatial_system)
end
return rate
end
hop_rate(hop_rates, species, site) = @inbounds hop_rates.rates[species, site]

"""
total_site_hop_rate(hop_rates::AbstractHopRates, site)
Expand Down
17 changes: 3 additions & 14 deletions src/spatial/reaction_rates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ function reset!(rx_rates::RxRates)
nothing
end

rx_rate(rx_rates, rx, site) = rx_rates.rates[rx, site]
evalrxrate(rx_rates, u, rx, site) = eval_massaction_rate(u, rx, rx_rates.ma_jumps, site)

"""
total_site_rx_rate(rx_rates::RxRates, site)
Expand Down Expand Up @@ -78,20 +81,6 @@ function sample_rx_at_site(rx_rates::RxRates, site, rng)
rand(rng) * total_site_rx_rate(rx_rates, site))
end

"""
recompute_site_rx_rate(rx_rates::RxRates, u, site)
compute the total reaction rate at site at the current state u
"""
function recompute_site_rx_rate(rx_rates::RxRates, u, site)
rate = zero(eltype(rx_rates.rates))
ma_jumps = rx_rates.ma_jumps
for rx in 1:num_rxs(rx_rates)
rate += eval_massaction_rate(u, rx, ma_jumps, site)
end
return rate
end

# helper functions
function set_rx_rate_at_site!(rx_rates::RxRates, site, rx, rate)
@inbounds old_rate = rx_rates.rates[rx, site]
Expand Down
89 changes: 63 additions & 26 deletions src/spatial/rssacrdirect.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ mutable struct RSSACRDirectJumpAggregation{T, BD, M, RNG, J, RX, HOP, DEPGR,
u_low_high::LowHigh{M} # species bracketing
rx_rates::LowHigh{RX}
hop_rates::LowHigh{HOP}
site_rates::LowHigh{Vector{T}}
site_rates::LowHigh{Vector{T}} # TODO(vilin97): we never use site_rates.low
save_positions::Tuple{Bool, Bool}
rng::RNG
dep_gr::DEPGR #dep graph is same for each site
Expand Down Expand Up @@ -69,8 +69,24 @@ function RSSACRDirectJumpAggregation(nj::SpatialJump{J}, njt::T, et::T, bd::BD,
# construct an empty initial priority table -- we'll reset this in init
rt = PriorityTable(ratetogroup, zeros(T, 1), minrate, 2 * minrate)

RSSACRDirectJumpAggregation{T, BD, M, RNG, J, RX, HOP, typeof(dg), typeof(vtoj_map), typeof(jtov_map), SS, typeof(rt), Nothing, Nothing, Nothing}(
nj, nj, njt, et, bd, u_low_high, rx_rates, hop_rates, site_rates, sps, rng, dg, vtoj_map, jtov_map, spatial_system, num_specs, rt, nothing, nothing)
RSSACRDirectJumpAggregation{
T,
BD,
M,
RNG,
J,
RX,
HOP,
typeof(dg),
typeof(vtoj_map),
typeof(jtov_map),
SS,
typeof(rt),
Nothing,
Nothing,
Nothing,
}(nj, nj, njt, et, bd, u_low_high, rx_rates, hop_rates, site_rates, sps, rng, dg,
vtoj_map, jtov_map, spatial_system, num_specs, rt, nothing, nothing)
end

############################# Required Functions ##############################
Expand Down Expand Up @@ -118,14 +134,16 @@ end
function generate_jumps!(p::RSSACRDirectJumpAggregation, integrator, u, params, t)
@unpack rng, rt, site_rates, rx_rates, hop_rates, spatial_system = p
time_delta = zero(t)
site = zero(eltype(u))
while true
site = sample(rt, site_rates.high, rng)
jump = sample_jump_direct(rx_rates.high, hop_rates.high, site, spatial_system, rng)
time_delta += randexp(rng)
accept_jump(rx_rates, hop_rates, site_rates, u, site, spatial_system, rng) && break
if accept_jump(p, u, jump)
p.next_jump_time = t + time_delta / groupsum(rt)
p.next_jump = jump
break
end
end
p.next_jump_time = t + time_delta / groupsum(rt)
p.next_jump = sample_jump_direct(rx_rates.high, hop_rates.high, site, spatial_system, rng)
nothing
end

Expand All @@ -139,14 +157,37 @@ end

######################## SSA specific helper routines ########################
# Return true if site is accepted.
function accept_jump(rx_rates, hop_rates, site_rates, u, site, spatial_system, rng)
acceptance_threshold = rand(rng) * site_rates.high[site]
if acceptance_threshold < site_rates.low[site]
function accept_jump(p, u, jump)
if is_hop(p, jump)
return accept_hop(p, u, jump)
else
return accept_rx(p, u, jump)
end
end

function accept_hop(p, u, jump)
@unpack hop_rates, spatial_system, rng = p
species, site = jump.jidx, jump.src
acceptance_threshold = rand(rng) * hop_rate(hop_rates.high, species, site)
if hop_rate(hop_rates.low, species, site) > acceptance_threshold
return true
else
site_rate = recompute_site_hop_rate(hop_rates.low, u, site, spatial_system) +
recompute_site_rx_rate(rx_rates.low, u, site)
return acceptance_threshold < site_rate
# compute the real rate. Could have used hop_rates.high as well.
real_rate = evalhoprate(hop_rates.low, u, species, site, spatial_system)
return real_rate > acceptance_threshold
end
end

function accept_rx(p, u, jump)
@unpack rx_rates, rng = p
rx, site = reaction_id_from_jump(p, jump), jump.src
acceptance_threshold = rand(rng) * rx_rate(rx_rates.high, rx, site)
if rx_rate(rx_rates.low, rx, site) > acceptance_threshold
return true
else
# compute the real rate. Could have used rx_rates.high as well.
real_rate = evalrxrate(rx_rates.low, u, rx, site)
return real_rate > acceptance_threshold
end
end

Expand Down Expand Up @@ -186,25 +227,20 @@ end
recalculate jump rates for jumps that depend on the just executed jump (p.prev_jump)
"""
function update_dependent_rates!(p::RSSACRDirectJumpAggregation,
integrator,
t)
@unpack rx_rates, hop_rates, site_rates, u_low_high, bracket_data, vartojumps_map, jumptovars_map, spatial_system = p

u = integrator.u
site_rates = p.site_rates
function update_dependent_rates!(p::RSSACRDirectJumpAggregation, integrator, t)
jump = p.prev_jump

if is_hop(p, jump)
species_to_update = jump.jidx
sites_to_update = (jump.src, jump.dst)
update_brackets!(p, integrator, jump.jidx, (jump.src, jump.dst))
else
species_to_update = jumptovars_map[reaction_id_from_jump(p, jump)]
sites_to_update = jump.src
update_brackets!(p, integrator, p.jumptovars_map[reaction_id_from_jump(p, jump)], jump.src)
end
end

function update_brackets!(p, integrator, species_to_update, sites_to_update)
@unpack rx_rates, hop_rates, site_rates, u_low_high, bracket_data, vartojumps_map, spatial_system = p
u = integrator.u
for site in sites_to_update, species in species_to_update
if is_outside_brackets(u_low_high, u, species, site)
if !is_inside_brackets(u_low_high, u, species, site)
update_u_brackets!(u_low_high, bracket_data, u, species, site)
update_rx_rates!(rx_rates,
vartojumps_map[species],
Expand All @@ -218,6 +254,7 @@ function update_dependent_rates!(p::RSSACRDirectJumpAggregation,
update!(p.rt, site, oldrate, site_rates.high[site])
end
end
nothing
end

"""
Expand Down
4 changes: 2 additions & 2 deletions test/spatial/ABC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ using JumpProcesses, DiffEqBase
# using BenchmarkTools
using Test, Graphs

Nsims = 1000
Nsims = 100
reltol = 0.05
non_spatial_mean = [65.7395, 65.7395, 434.2605] #mean of 10,000 simulations

dim = 1
linear_size = 1
linear_size = 5
dims = Tuple(repeat([linear_size], dim))
num_nodes = prod(dims)
starting_site = trunc(Int, (linear_size^dim + 1) / 2)
Expand Down

0 comments on commit e36f466

Please sign in to comment.