diff --git a/src/spatial/bracketing.jl b/src/spatial/bracketing.jl index 290efc56..3c810f95 100644 --- a/src/spatial/bracketing.jl +++ b/src/spatial/bracketing.jl @@ -35,8 +35,8 @@ end nothing end -function is_outside_brackets(u_low_high::LowHigh{M}, u::M, species, site) where {M} - return u[species, site] < u_low_high.low[species, site] || u[species, site] > u_low_high.high[species, site] +function is_inside_brackets(u_low_high::LowHigh{M}, u::M, species, site) where {M} + return u_low_high.low[species, site] < u[species, site] < u_low_high.high[species, site] end ### convenience functions for LowHigh ### diff --git a/src/spatial/hop_rates.jl b/src/spatial/hop_rates.jl index 2b26283c..9e5f430a 100644 --- a/src/spatial/hop_rates.jl +++ b/src/spatial/hop_rates.jl @@ -72,14 +72,7 @@ function update_hop_rates!(hop_rates::AbstractHopRates, species_vec, u, site, sp end end -function recompute_site_hop_rate(hop_rates::HP, u, site, spatial_system) where {HP <: AbstractHopRates} - rate = zero(eltype(hop_rates.rates)) - num_species = size(hop_rates.rates, 1) - for species in 1:num_species - rate += evalhoprate(hop_rates, u, species, site, spatial_system) - end - return rate -end +hop_rate(hop_rates, species, site) = @inbounds hop_rates.rates[species, site] """ total_site_hop_rate(hop_rates::AbstractHopRates, site) diff --git a/src/spatial/reaction_rates.jl b/src/spatial/reaction_rates.jl index 54f38324..a9c78c9d 100644 --- a/src/spatial/reaction_rates.jl +++ b/src/spatial/reaction_rates.jl @@ -39,6 +39,9 @@ function reset!(rx_rates::RxRates) nothing end +rx_rate(rx_rates, rx, site) = rx_rates.rates[rx, site] +evalrxrate(rx_rates, u, rx, site) = eval_massaction_rate(u, rx, rx_rates.ma_jumps, site) + """ total_site_rx_rate(rx_rates::RxRates, site) @@ -78,20 +81,6 @@ function sample_rx_at_site(rx_rates::RxRates, site, rng) rand(rng) * total_site_rx_rate(rx_rates, site)) end -""" - recompute_site_rx_rate(rx_rates::RxRates, u, site) - -compute the total reaction rate at site at the current state u -""" -function recompute_site_rx_rate(rx_rates::RxRates, u, site) - rate = zero(eltype(rx_rates.rates)) - ma_jumps = rx_rates.ma_jumps - for rx in 1:num_rxs(rx_rates) - rate += eval_massaction_rate(u, rx, ma_jumps, site) - end - return rate -end - # helper functions function set_rx_rate_at_site!(rx_rates::RxRates, site, rx, rate) @inbounds old_rate = rx_rates.rates[rx, site] diff --git a/src/spatial/rssacrdirect.jl b/src/spatial/rssacrdirect.jl index 8413570c..80596624 100644 --- a/src/spatial/rssacrdirect.jl +++ b/src/spatial/rssacrdirect.jl @@ -15,7 +15,7 @@ mutable struct RSSACRDirectJumpAggregation{T, BD, M, RNG, J, RX, HOP, DEPGR, u_low_high::LowHigh{M} # species bracketing rx_rates::LowHigh{RX} hop_rates::LowHigh{HOP} - site_rates::LowHigh{Vector{T}} + site_rates::LowHigh{Vector{T}} # TODO(vilin97): we never use site_rates.low save_positions::Tuple{Bool, Bool} rng::RNG dep_gr::DEPGR #dep graph is same for each site @@ -69,8 +69,24 @@ function RSSACRDirectJumpAggregation(nj::SpatialJump{J}, njt::T, et::T, bd::BD, # construct an empty initial priority table -- we'll reset this in init rt = PriorityTable(ratetogroup, zeros(T, 1), minrate, 2 * minrate) - RSSACRDirectJumpAggregation{T, BD, M, RNG, J, RX, HOP, typeof(dg), typeof(vtoj_map), typeof(jtov_map), SS, typeof(rt), Nothing, Nothing, Nothing}( - nj, nj, njt, et, bd, u_low_high, rx_rates, hop_rates, site_rates, sps, rng, dg, vtoj_map, jtov_map, spatial_system, num_specs, rt, nothing, nothing) + RSSACRDirectJumpAggregation{ + T, + BD, + M, + RNG, + J, + RX, + HOP, + typeof(dg), + typeof(vtoj_map), + typeof(jtov_map), + SS, + typeof(rt), + Nothing, + Nothing, + Nothing, + }(nj, nj, njt, et, bd, u_low_high, rx_rates, hop_rates, site_rates, sps, rng, dg, + vtoj_map, jtov_map, spatial_system, num_specs, rt, nothing, nothing) end ############################# Required Functions ############################## @@ -118,14 +134,16 @@ end function generate_jumps!(p::RSSACRDirectJumpAggregation, integrator, u, params, t) @unpack rng, rt, site_rates, rx_rates, hop_rates, spatial_system = p time_delta = zero(t) - site = zero(eltype(u)) while true site = sample(rt, site_rates.high, rng) + jump = sample_jump_direct(rx_rates.high, hop_rates.high, site, spatial_system, rng) time_delta += randexp(rng) - accept_jump(rx_rates, hop_rates, site_rates, u, site, spatial_system, rng) && break + if accept_jump(p, u, jump) + p.next_jump_time = t + time_delta / groupsum(rt) + p.next_jump = jump + break + end end - p.next_jump_time = t + time_delta / groupsum(rt) - p.next_jump = sample_jump_direct(rx_rates.high, hop_rates.high, site, spatial_system, rng) nothing end @@ -139,14 +157,37 @@ end ######################## SSA specific helper routines ######################## # Return true if site is accepted. -function accept_jump(rx_rates, hop_rates, site_rates, u, site, spatial_system, rng) - acceptance_threshold = rand(rng) * site_rates.high[site] - if acceptance_threshold < site_rates.low[site] +function accept_jump(p, u, jump) + if is_hop(p, jump) + return accept_hop(p, u, jump) + else + return accept_rx(p, u, jump) + end +end + +function accept_hop(p, u, jump) + @unpack hop_rates, spatial_system, rng = p + species, site = jump.jidx, jump.src + acceptance_threshold = rand(rng) * hop_rate(hop_rates.high, species, site) + if hop_rate(hop_rates.low, species, site) > acceptance_threshold return true else - site_rate = recompute_site_hop_rate(hop_rates.low, u, site, spatial_system) + - recompute_site_rx_rate(rx_rates.low, u, site) - return acceptance_threshold < site_rate + # compute the real rate. Could have used hop_rates.high as well. + real_rate = evalhoprate(hop_rates.low, u, species, site, spatial_system) + return real_rate > acceptance_threshold + end +end + +function accept_rx(p, u, jump) + @unpack rx_rates, rng = p + rx, site = reaction_id_from_jump(p, jump), jump.src + acceptance_threshold = rand(rng) * rx_rate(rx_rates.high, rx, site) + if rx_rate(rx_rates.low, rx, site) > acceptance_threshold + return true + else + # compute the real rate. Could have used rx_rates.high as well. + real_rate = evalrxrate(rx_rates.low, u, rx, site) + return real_rate > acceptance_threshold end end @@ -186,25 +227,20 @@ end recalculate jump rates for jumps that depend on the just executed jump (p.prev_jump) """ -function update_dependent_rates!(p::RSSACRDirectJumpAggregation, - integrator, - t) - @unpack rx_rates, hop_rates, site_rates, u_low_high, bracket_data, vartojumps_map, jumptovars_map, spatial_system = p - - u = integrator.u - site_rates = p.site_rates +function update_dependent_rates!(p::RSSACRDirectJumpAggregation, integrator, t) jump = p.prev_jump - if is_hop(p, jump) - species_to_update = jump.jidx - sites_to_update = (jump.src, jump.dst) + update_brackets!(p, integrator, jump.jidx, (jump.src, jump.dst)) else - species_to_update = jumptovars_map[reaction_id_from_jump(p, jump)] - sites_to_update = jump.src + update_brackets!(p, integrator, p.jumptovars_map[reaction_id_from_jump(p, jump)], jump.src) end +end +function update_brackets!(p, integrator, species_to_update, sites_to_update) + @unpack rx_rates, hop_rates, site_rates, u_low_high, bracket_data, vartojumps_map, spatial_system = p + u = integrator.u for site in sites_to_update, species in species_to_update - if is_outside_brackets(u_low_high, u, species, site) + if !is_inside_brackets(u_low_high, u, species, site) update_u_brackets!(u_low_high, bracket_data, u, species, site) update_rx_rates!(rx_rates, vartojumps_map[species], @@ -218,6 +254,7 @@ function update_dependent_rates!(p::RSSACRDirectJumpAggregation, update!(p.rt, site, oldrate, site_rates.high[site]) end end + nothing end """ diff --git a/test/spatial/ABC.jl b/test/spatial/ABC.jl index 48a22c3e..6dd9c7ee 100644 --- a/test/spatial/ABC.jl +++ b/test/spatial/ABC.jl @@ -2,12 +2,12 @@ using JumpProcesses, DiffEqBase # using BenchmarkTools using Test, Graphs -Nsims = 1000 +Nsims = 100 reltol = 0.05 non_spatial_mean = [65.7395, 65.7395, 434.2605] #mean of 10,000 simulations dim = 1 -linear_size = 1 +linear_size = 5 dims = Tuple(repeat([linear_size], dim)) num_nodes = prod(dims) starting_site = trunc(Int, (linear_size^dim + 1) / 2)