From 7861bc727c5124ee75489b7b3c51158ddec6b63f Mon Sep 17 00:00:00 2001 From: Vasily Ilin Date: Sun, 10 Sep 2023 15:00:14 -0400 Subject: [PATCH] Write RSSACR-Direct but it's incorrect. I will turn it into CR-RSSA. --- src/JumpProcesses.jl | 3 +- src/aggregators/aggregators.jl | 3 + src/spatial/bracketing.jl | 43 +++++-- src/spatial/directcrdirect.jl | 5 +- src/spatial/hop_rates.jl | 37 +++--- src/spatial/nsm.jl | 4 +- src/spatial/reaction_rates.jl | 15 +++ src/spatial/rssacrdirect.jl | 228 +++++++++++++++++++++++++++++++++ src/spatial/utils.jl | 25 ++-- test/spatial/ABC.jl | 67 ++++++---- 10 files changed, 364 insertions(+), 66 deletions(-) create mode 100644 src/spatial/rssacrdirect.jl diff --git a/src/JumpProcesses.jl b/src/JumpProcesses.jl index 39dd5465..640f832d 100644 --- a/src/JumpProcesses.jl +++ b/src/JumpProcesses.jl @@ -65,6 +65,7 @@ include("spatial/bracketing.jl") include("spatial/nsm.jl") include("spatial/directcrdirect.jl") +include("spatial/rssacrdirect.jl") include("aggregators/aggregated_api.jl") @@ -101,6 +102,6 @@ export ExtendedJumpArray export CartesianGrid, CartesianGridRej export SpatialMassActionJump export outdegree, num_sites, neighbors -export NSM, DirectCRDirect +export NSM, DirectCRDirect, RSSACRDirect end # module diff --git a/src/aggregators/aggregators.jl b/src/aggregators/aggregators.jl index c1553d03..46c220a5 100644 --- a/src/aggregators/aggregators.jl +++ b/src/aggregators/aggregators.jl @@ -159,6 +159,8 @@ algorithm with optimal binning, Journal of Chemical Physics 143, 074108 """ struct DirectCRDirect <: AbstractAggregatorAlgorithm end +struct RSSACRDirect <: AbstractAggregatorAlgorithm end + const JUMP_AGGREGATORS = (Direct(), DirectFW(), DirectCR(), SortingDirect(), RSSA(), FRM(), FRMFW(), NRM(), RSSACR(), RDirect(), Coevolve()) @@ -187,3 +189,4 @@ supports_variablerates(aggregator::Coevolve) = true is_spatial(aggregator::AbstractAggregatorAlgorithm) = false is_spatial(aggregator::NSM) = true is_spatial(aggregator::DirectCRDirect) = true +is_spatial(aggregator::RSSACRDirect) = true diff --git a/src/spatial/bracketing.jl b/src/spatial/bracketing.jl index 106add2f..290efc56 100644 --- a/src/spatial/bracketing.jl +++ b/src/spatial/bracketing.jl @@ -5,9 +5,15 @@ struct LowHigh{T} low::T high::T - LowHigh(low::T, high::T) where {T} = new{T}(deepcopy(low), deepcopy(high)) - LowHigh(pair::Tuple{T,T}) where {T} = new{T}(pair[1], pair[2]) - LowHigh(low_and_high::T) where {T} = new{T}(low_and_high, deepcopy(low_and_high)) + function LowHigh(low::T, high::T; do_copy = true) where {T} + if do_copy + return new{T}(deepcopy(low), deepcopy(high)) + else + return new{T}(low, high) + end + end + LowHigh(pair::Tuple{T,T}; kwargs...) where {T} = LowHigh(pair[1], pair[2]; kwargs...) + LowHigh(low_and_high::T; kwargs...) where {T} = LowHigh(low_and_high, low_and_high; kwargs...) end function Base.show(io::IO, ::MIME"text/plain", low_high::LowHigh) @@ -16,25 +22,39 @@ function Base.show(io::IO, ::MIME"text/plain", low_high::LowHigh) end @inline function update_u_brackets!(u_low_high::LowHigh, bracket_data, u::AbstractMatrix) - @inbounds for (i, uval) in enumerate(u) - u_low_high[i] = LowHigh(get_spec_brackets(bracket_data, i, uval)) + num_species, num_sites = size(u) + update_u_brackets!(u_low_high, bracket_data, u, 1:num_species, 1:num_sites) +end + +@inline function update_u_brackets!(u_low_high::LowHigh, bracket_data, u::AbstractMatrix, species_vec, sites) + @inbounds for site in sites + for species in species_vec + u_low_high[species, site] = LowHigh(get_spec_brackets(bracket_data, species, u[species, site])) + end end nothing end +function is_outside_brackets(u_low_high::LowHigh{M}, u::M, species, site) where {M} + return u[species, site] < u_low_high.low[species, site] || u[species, site] > u_low_high.high[species, site] +end + ### convenience functions for LowHigh ### -function setindex!(low_high::LowHigh, val::LowHigh, i) - low_high.low[i] = val.low - low_high.high[i] = val.high +function setindex!(low_high::LowHigh, val::LowHigh, i...) + low_high.low[i...] = val.low + low_high.high[i...] = val.high val end +get_majumps(rx_rates::LowHigh{R}) where {R <: RxRates} = get_majumps(rx_rates.low) + function total_site_rate(rx_rates::LowHigh, hop_rates::LowHigh, site) return LowHigh( total_site_rate(rx_rates.low, hop_rates.low, site), total_site_rate(rx_rates.high, hop_rates.high, site)) end +# Compatible with constant rate jumps, because u_low_high.low and u_low_high.high are used in rate(). function update_rx_rates!(rx_rates::LowHigh, rxs, u_low_high, integrator, site) update_rx_rates!(rx_rates.low, rxs, u_low_high.low, integrator, site) update_rx_rates!(rx_rates.high, rxs, u_low_high.high, integrator, site) @@ -44,3 +64,10 @@ function update_hop_rates!(hop_rates::LowHigh, species, u_low_high, site, spatia update_hop_rates!(hop_rates.low, species, u_low_high.low, site, spatial_system) update_hop_rates!(hop_rates.high, species, u_low_high.high, site, spatial_system) end + +function reset!(low_high::LowHigh) + reset!(low_high.low) + reset!(low_high.high) +end + +reset!(array::AbstractArray) = fill!(array, zero(eltype(array))) \ No newline at end of file diff --git a/src/spatial/directcrdirect.jl b/src/spatial/directcrdirect.jl index bb44b144..611846c3 100644 --- a/src/spatial/directcrdirect.jl +++ b/src/spatial/directcrdirect.jl @@ -4,7 +4,6 @@ const MINJUMPRATE = 2.0^exponent(1e-12) #NOTE state vector u is a matrix. u[i,j] is species i, site j -#NOTE hopping_constants is a matrix. hopping_constants[i,j] is species i, site j mutable struct DirectCRDirectJumpAggregation{T, S, F1, F2, RNG, J, RX, HOP, DEPGR, VJMAP, JVMAP, SS, U <: PriorityTable, W <: Function} <: @@ -107,12 +106,12 @@ end function initialize!(p::DirectCRDirectJumpAggregation, integrator, u, params, t) p.end_time = integrator.sol.prob.tspan[2] fill_rates_and_get_times!(p, integrator, t) - generate_jumps!(p, integrator, params, u, t) + generate_jumps!(p, integrator, u, params, t) nothing end # calculate the next jump / jump time -function generate_jumps!(p::DirectCRDirectJumpAggregation, integrator, params, u, t) +function generate_jumps!(p::DirectCRDirectJumpAggregation, integrator, u, params, t) p.next_jump_time = t + randexp(p.rng) / p.rt.gsum p.next_jump_time >= p.end_time && return nothing site = sample(p.rt, p.site_rates, p.rng) diff --git a/src/spatial/hop_rates.jl b/src/spatial/hop_rates.jl index ef7f73a5..2b26283c 100644 --- a/src/spatial/hop_rates.jl +++ b/src/spatial/hop_rates.jl @@ -57,29 +57,28 @@ function HopRates(p::Pair{SpecHop, SiteHop}, end """ - update_hop_rates!(hop_rates::AbstractHopRates, species::AbstractArray, u, site, spatial_system) + update_hop_rates!(hop_rates::HopRatesGraphDsi, species_vec, u, site, spatial_system) -update rates of all specs in species at site + update rates of all species in species_vec at site """ -function update_hop_rates!(hop_rates::AbstractHopRates, species::AbstractArray, u, site, - spatial_system) - @inbounds for spec in species - update_hop_rate!(hop_rates, spec, u, site, spatial_system) +function update_hop_rates!(hop_rates::AbstractHopRates, species_vec, u, site, spatial_system) + @inbounds for species in species_vec + rates = hop_rates.rates + old_rate = rates[species, site] + rates[species, site] = evalhoprate(hop_rates, u, species, site, + spatial_system) + hop_rates.sum_rates[site] += rates[species, site] - old_rate + old_rate end end -""" - update_hop_rate!(hop_rates::HopRatesGraphDsi, species, u, site, spatial_system) - -update rates of single species at site -""" -function update_hop_rate!(hop_rates::AbstractHopRates, species, u, site, spatial_system) - rates = hop_rates.rates - @inbounds old_rate = rates[species, site] - @inbounds rates[species, site] = evalhoprate(hop_rates, u, species, site, - spatial_system) - @inbounds hop_rates.sum_rates[site] += rates[species, site] - old_rate - old_rate +function recompute_site_hop_rate(hop_rates::HP, u, site, spatial_system) where {HP <: AbstractHopRates} + rate = zero(eltype(hop_rates.rates)) + num_species = size(hop_rates.rates, 1) + for species in 1:num_species + rate += evalhoprate(hop_rates, u, species, site, spatial_system) + end + return rate end """ @@ -197,7 +196,7 @@ end return hopping rate of species at site """ function evalhoprate(hop_rates::HopRatesGraphDsi, u, species, site, spatial_system) - @inbounds u[species, site] * hop_rates.hopping_constants[species, site] * + u[species, site] * hop_rates.hopping_constants[species, site] * outdegree(spatial_system, site) end diff --git a/src/spatial/nsm.jl b/src/spatial/nsm.jl index 3cfe7eed..a5341665 100644 --- a/src/spatial/nsm.jl +++ b/src/spatial/nsm.jl @@ -95,12 +95,12 @@ end function initialize!(p::NSMJumpAggregation, integrator, u, params, t) p.end_time = integrator.sol.prob.tspan[2] fill_rates_and_get_times!(p, integrator, t) - generate_jumps!(p, integrator, params, u, t) + generate_jumps!(p, integrator, u, params, t) nothing end # calculate the next jump / jump time -function generate_jumps!(p::NSMJumpAggregation, integrator, params, u, t) +function generate_jumps!(p::NSMJumpAggregation, integrator, u, params, t) p.next_jump_time, site = top_with_handle(p.pq) p.next_jump_time >= p.end_time && return nothing p.next_jump = sample_jump_direct(p, site) diff --git a/src/spatial/reaction_rates.jl b/src/spatial/reaction_rates.jl index 2aba9df1..54f38324 100644 --- a/src/spatial/reaction_rates.jl +++ b/src/spatial/reaction_rates.jl @@ -26,6 +26,7 @@ function RxRates(num_sites::Int, ma_jumps::M) where {M} end num_rxs(rx_rates::RxRates) = get_num_majumps(rx_rates.ma_jumps) +get_majumps(rx_rates::RxRates) = rx_rates.ma_jumps """ reset!(rx_rates::RxRates) @@ -77,6 +78,20 @@ function sample_rx_at_site(rx_rates::RxRates, site, rng) rand(rng) * total_site_rx_rate(rx_rates, site)) end +""" + recompute_site_rx_rate(rx_rates::RxRates, u, site) + +compute the total reaction rate at site at the current state u +""" +function recompute_site_rx_rate(rx_rates::RxRates, u, site) + rate = zero(eltype(rx_rates.rates)) + ma_jumps = rx_rates.ma_jumps + for rx in 1:num_rxs(rx_rates) + rate += eval_massaction_rate(u, rx, ma_jumps, site) + end + return rate +end + # helper functions function set_rx_rate_at_site!(rx_rates::RxRates, site, rx, rate) @inbounds old_rate = rx_rates.rates[rx, site] diff --git a/src/spatial/rssacrdirect.jl b/src/spatial/rssacrdirect.jl new file mode 100644 index 00000000..8413570c --- /dev/null +++ b/src/spatial/rssacrdirect.jl @@ -0,0 +1,228 @@ +# site chosen with RSSACR, rx or hop chosen with Direct + +############################ RSSACRDirect ################################### +const MINJUMPRATE = 2.0^exponent(1e-12) + +#NOTE state vector u is a matrix. u[i,j] is species i, site j +mutable struct RSSACRDirectJumpAggregation{T, BD, M, RNG, J, RX, HOP, DEPGR, + VJMAP, JVMAP, SS, U <: PriorityTable, S, F1, F2} <: + AbstractSSAJumpAggregator{T, S, F1, F2, RNG} + next_jump::SpatialJump{J} + prev_jump::SpatialJump{J} + next_jump_time::T + end_time::T + bracket_data::BD + u_low_high::LowHigh{M} # species bracketing + rx_rates::LowHigh{RX} + hop_rates::LowHigh{HOP} + site_rates::LowHigh{Vector{T}} + save_positions::Tuple{Bool, Bool} + rng::RNG + dep_gr::DEPGR #dep graph is same for each site + vartojumps_map::VJMAP #vartojumps_map is same for each site + jumptovars_map::JVMAP #jumptovars_map is same for each site + spatial_system::SS + numspecies::Int #number of species + rt::U + rates::F1 # legacy, not used + affects!::F2 # legacy, not used +end + +function RSSACRDirectJumpAggregation(nj::SpatialJump{J}, njt::T, et::T, bd::BD, + u_low_high::LowHigh{M}, rx_rates::LowHigh{RX}, + hop_rates::LowHigh{HOP}, site_rates::LowHigh{Vector{T}}, + sps::Tuple{Bool, Bool}, rng::RNG, spatial_system::SS; + num_specs, minrate = convert(T, MINJUMPRATE), + vartojumps_map = nothing, jumptovars_map = nothing, + dep_graph = nothing, + kwargs...) where {J, T, BD, RX, HOP, RNG, SS, M} + + # a dependency graph is needed + if dep_graph === nothing + dg = make_dependency_graph(num_specs, rx_rates.low.ma_jumps) + else + dg = dep_graph + # make sure each jump depends on itself + add_self_dependencies!(dg) + end + + # a species-to-reactions graph is needed + if vartojumps_map === nothing + vtoj_map = var_to_jumps_map(num_specs, rx_rates.low.ma_jumps) + else + vtoj_map = vartojumps_map + end + + if jumptovars_map === nothing + jtov_map = jump_to_vars_map(rx_rates.low.ma_jumps) + else + jtov_map = jumptovars_map + end + + # mapping from jump rate to group id + minexponent = exponent(minrate) + + # use the largest power of two that is <= the passed in minrate + minrate = 2.0^minexponent + ratetogroup = rate -> priortogid(rate, minexponent) + + # construct an empty initial priority table -- we'll reset this in init + rt = PriorityTable(ratetogroup, zeros(T, 1), minrate, 2 * minrate) + + RSSACRDirectJumpAggregation{T, BD, M, RNG, J, RX, HOP, typeof(dg), typeof(vtoj_map), typeof(jtov_map), SS, typeof(rt), Nothing, Nothing, Nothing}( + nj, nj, njt, et, bd, u_low_high, rx_rates, hop_rates, site_rates, sps, rng, dg, vtoj_map, jtov_map, spatial_system, num_specs, rt, nothing, nothing) +end + +############################# Required Functions ############################## +# creating the JumpAggregation structure (function wrapper-based constant jumps) +function aggregate(aggregator::RSSACRDirect, starting_state, p, t, end_time, + constant_jumps, ma_jumps, save_positions, rng; hopping_constants, + spatial_system, bracket_data = nothing, kwargs...) + T = typeof(end_time) + num_species = size(starting_state, 1) + majumps = ma_jumps + if majumps === nothing + majumps = MassActionJump(Vector{T}(), + Vector{Vector{Pair{Int, Int}}}(), + Vector{Vector{Pair{Int, Int}}}()) + end + + next_jump = SpatialJump{Int}(typemax(Int), typemax(Int), typemax(Int)) #a placeholder + next_jump_time = typemax(T) + rx_rates = LowHigh(RxRates(num_sites(spatial_system), majumps), + RxRates(num_sites(spatial_system), majumps); + do_copy = false) # do not copy ma_jumps + hop_rates = LowHigh(HopRates(hopping_constants, spatial_system), + HopRates(hopping_constants, spatial_system); + do_copy = false) # do not copy hopping_constants + site_rates = LowHigh(zeros(T, num_sites(spatial_system))) + bd = (bracket_data === nothing) ? BracketData{T, eltype(starting_state)}() : + bracket_data + u_low_high = LowHigh(starting_state) + + RSSACRDirectJumpAggregation(next_jump, next_jump_time, end_time, bd, u_low_high, + rx_rates, hop_rates, + site_rates, save_positions, rng, spatial_system; + num_specs = num_species, kwargs...) +end + +# set up a new simulation and calculate the first jump / jump time +function initialize!(p::RSSACRDirectJumpAggregation, integrator, u, params, t) + p.end_time = integrator.sol.prob.tspan[2] + fill_rates_and_get_times!(p, integrator, t) + generate_jumps!(p, integrator, u, params, t) + nothing +end + +# calculate the next jump / jump time +function generate_jumps!(p::RSSACRDirectJumpAggregation, integrator, u, params, t) + @unpack rng, rt, site_rates, rx_rates, hop_rates, spatial_system = p + time_delta = zero(t) + site = zero(eltype(u)) + while true + site = sample(rt, site_rates.high, rng) + time_delta += randexp(rng) + accept_jump(rx_rates, hop_rates, site_rates, u, site, spatial_system, rng) && break + end + p.next_jump_time = t + time_delta / groupsum(rt) + p.next_jump = sample_jump_direct(rx_rates.high, hop_rates.high, site, spatial_system, rng) + nothing +end + +# execute one jump, changing the system state +function execute_jumps!(p::RSSACRDirectJumpAggregation, integrator, u, params, t, + affects!) + update_state!(p, integrator) + update_dependent_rates!(p, integrator, t) + nothing +end + +######################## SSA specific helper routines ######################## +# Return true if site is accepted. +function accept_jump(rx_rates, hop_rates, site_rates, u, site, spatial_system, rng) + acceptance_threshold = rand(rng) * site_rates.high[site] + if acceptance_threshold < site_rates.low[site] + return true + else + site_rate = recompute_site_hop_rate(hop_rates.low, u, site, spatial_system) + + recompute_site_rx_rate(rx_rates.low, u, site) + return acceptance_threshold < site_rate + end +end + +""" + fill_rates_and_get_times!(aggregation::RSSACRDirectJumpAggregation, u, t) + +reset all stucts, reevaluate all rates, repopulate the priority table +""" +function fill_rates_and_get_times!(aggregation::RSSACRDirectJumpAggregation, integrator, t) + @unpack bracket_data, u_low_high, spatial_system, rx_rates, hop_rates, site_rates, rt = aggregation + u = integrator.u + update_u_brackets!(u_low_high::LowHigh, bracket_data, u::AbstractMatrix) + + reset!(rx_rates) + reset!(hop_rates) + reset!(site_rates) + + rxs = 1:num_rxs(rx_rates.low) + species = 1:(aggregation.numspecies) + + for site in 1:num_sites(spatial_system) + update_rx_rates!(rx_rates, rxs, u_low_high, integrator, site) + update_hop_rates!(hop_rates, species, u_low_high, site, spatial_system) + site_rates[site] = total_site_rate(rx_rates, hop_rates, site) + end + + # setup PriorityTable + reset!(rt) + for (pid, priority) in enumerate(site_rates.high) + insert!(rt, pid, priority) + end + nothing +end + +""" + update_dependent_rates!(p, integrator, t) + +recalculate jump rates for jumps that depend on the just executed jump (p.prev_jump) +""" +function update_dependent_rates!(p::RSSACRDirectJumpAggregation, + integrator, + t) + @unpack rx_rates, hop_rates, site_rates, u_low_high, bracket_data, vartojumps_map, jumptovars_map, spatial_system = p + + u = integrator.u + site_rates = p.site_rates + jump = p.prev_jump + + if is_hop(p, jump) + species_to_update = jump.jidx + sites_to_update = (jump.src, jump.dst) + else + species_to_update = jumptovars_map[reaction_id_from_jump(p, jump)] + sites_to_update = jump.src + end + + for site in sites_to_update, species in species_to_update + if is_outside_brackets(u_low_high, u, species, site) + update_u_brackets!(u_low_high, bracket_data, u, species, site) + update_rx_rates!(rx_rates, + vartojumps_map[species], + u_low_high, + integrator, + site) + update_hop_rates!(hop_rates, species, u_low_high, site, spatial_system) + + oldrate = site_rates.high[site] + site_rates[site] = total_site_rate(p.rx_rates, p.hop_rates, site) + update!(p.rt, site, oldrate, site_rates.high[site]) + end + end +end + +""" + num_constant_rate_jumps(aggregator::RSSACRDirectJumpAggregation) + +number of constant rate jumps +""" +num_constant_rate_jumps(aggregator::RSSACRDirectJumpAggregation) = 0 \ No newline at end of file diff --git a/src/spatial/utils.jl b/src/spatial/utils.jl index ddb9db41..16191002 100644 --- a/src/spatial/utils.jl +++ b/src/spatial/utils.jl @@ -27,18 +27,23 @@ end sample jump at site with direct method """ -function sample_jump_direct(p, site) - if rand(p.rng) * (total_site_rate(p.rx_rates, p.hop_rates, site)) < - total_site_rx_rate(p.rx_rates, site) - rx = sample_rx_at_site(p.rx_rates, site, p.rng) - return SpatialJump(site, rx + p.numspecies, site) +function sample_jump_direct(rx_rates, hop_rates, site, spatial_system, rng) + numspecies = size(hop_rates.rates, 1) + if rand(rng) * (total_site_rate(rx_rates, hop_rates, site)) < + total_site_rx_rate(rx_rates, site) + rx = sample_rx_at_site(rx_rates, site, rng) + return SpatialJump(site, rx + numspecies, site) else - species_to_diffuse, target_site = sample_hop_at_site(p.hop_rates, site, p.rng, - p.spatial_system) + species_to_diffuse, target_site = sample_hop_at_site(hop_rates, site, rng, + spatial_system) return SpatialJump(site, species_to_diffuse, target_site) end end +function sample_jump_direct(p, site) + sample_jump_direct(p.rx_rates, p.hop_rates, site, p.spatial_system, p.rng) +end + function total_site_rate(rx_rates::RxRates, hop_rates::AbstractHopRates, site) total_site_hop_rate(hop_rates, site) + total_site_rx_rate(rx_rates, site) end @@ -52,10 +57,10 @@ end function update_rates_after_hop!(p, integrator, source_site, target_site, species) u = integrator.u update_rx_rates!(p.rx_rates, p.vartojumps_map[species], integrator, source_site) - update_hop_rate!(p.hop_rates, species, u, source_site, p.spatial_system) + update_hop_rates!(p.hop_rates, species, u, source_site, p.spatial_system) update_rx_rates!(p.rx_rates, p.vartojumps_map[species], integrator, target_site) - update_hop_rate!(p.hop_rates, species, u, target_site, p.spatial_system) + update_hop_rates!(p.hop_rates, species, u, target_site, p.spatial_system) end """ @@ -70,7 +75,7 @@ function update_state!(p, integrator) else rx_index = reaction_id_from_jump(p, jump) @inbounds executerx!((@view integrator.u[:, jump.src]), rx_index, - p.rx_rates.ma_jumps) + get_majumps(p.rx_rates)) end # save jump that was just exectued p.prev_jump = jump diff --git a/test/spatial/ABC.jl b/test/spatial/ABC.jl index 558a701a..48a22c3e 100644 --- a/test/spatial/ABC.jl +++ b/test/spatial/ABC.jl @@ -2,12 +2,12 @@ using JumpProcesses, DiffEqBase # using BenchmarkTools using Test, Graphs -Nsims = 100 +Nsims = 1000 reltol = 0.05 non_spatial_mean = [65.7395, 65.7395, 434.2605] #mean of 10,000 simulations dim = 1 -linear_size = 5 +linear_size = 1 dims = Tuple(repeat([linear_size], dim)) num_nodes = prod(dims) starting_site = trunc(Int, (linear_size^dim + 1) / 2) @@ -47,27 +47,27 @@ end # testing grids = [CartesianGridRej(dims), Graphs.grid(dims)] -jump_problems = JumpProblem[JumpProblem(prob, NSM(), majumps, - hopping_constants = hopping_constants, - spatial_system = grid, - save_positions = (false, false)) for grid in grids] -push!(jump_problems, - JumpProblem(prob, DirectCRDirect(), majumps, hopping_constants = hopping_constants, - spatial_system = grids[1], save_positions = (false, false))) -# setup flattenned jump prob -push!(jump_problems, - JumpProblem(prob, NRM(), majumps, hopping_constants = hopping_constants, - spatial_system = grids[1], save_positions = (false, false))) -# test -for spatial_jump_prob in jump_problems - solution = solve(spatial_jump_prob, SSAStepper()) - mean_end_state = get_mean_end_state(spatial_jump_prob, Nsims) - mean_end_state = reshape(mean_end_state, num_species, num_nodes) - diff = sum(mean_end_state, dims = 2) - non_spatial_mean - for (i, d) in enumerate(diff) - @test abs(d) < reltol * non_spatial_mean[i] - end -end +# jump_problems = JumpProblem[JumpProblem(prob, NSM(), majumps, +# hopping_constants = hopping_constants, +# spatial_system = grid, +# save_positions = (false, false)) for grid in grids] +# push!(jump_problems, +# JumpProblem(prob, DirectCRDirect(), majumps, hopping_constants = hopping_constants, +# spatial_system = grids[1], save_positions = (false, false))) +# # setup flattenned jump prob +# push!(jump_problems, +# JumpProblem(prob, NRM(), majumps, hopping_constants = hopping_constants, +# spatial_system = grids[1], save_positions = (false, false))) +# # test +# for spatial_jump_prob in jump_problems +# solution = solve(spatial_jump_prob, SSAStepper()) +# mean_end_state = get_mean_end_state(spatial_jump_prob, Nsims) +# mean_end_state = reshape(mean_end_state, num_species, num_nodes) +# diff = sum(mean_end_state, dims = 2) - non_spatial_mean +# for (i, d) in enumerate(diff) +# @test abs(d) < reltol * non_spatial_mean[i] +# end +# end #using non-spatial SSAs to get the mean # non_spatial_rates = [0.1,1.0] @@ -77,3 +77,24 @@ end # non_spatial_prob = DiscreteProblem(u0,(0.0,end_time), non_spatial_rates) # jump_prob = JumpProblem(non_spatial_prob, Direct(), majumps) # non_spatial_mean = get_mean_end_state(jump_prob, 10000) + +spatial_jump_prob = JumpProblem(prob, RSSACRDirect(), majumps, hopping_constants = hopping_constants, + spatial_system = grids[1], save_positions = (false, false)) +sol = solve(spatial_jump_prob, SSAStepper()) +mean_end_state = get_mean_end_state(spatial_jump_prob, Nsims) +mean_end_state = reshape(mean_end_state, num_species, num_nodes) +diff = sum(mean_end_state, dims = 2) - non_spatial_mean +for (i, d) in enumerate(diff) + @test abs(d) < reltol * non_spatial_mean[i] +end + + +spatial_jump_prob = JumpProblem(prob, NSM(), majumps, hopping_constants = hopping_constants, + spatial_system = grids[1], save_positions = (false, false)) +sol = solve(spatial_jump_prob, SSAStepper()) +mean_end_state = get_mean_end_state(spatial_jump_prob, Nsims) +mean_end_state = reshape(mean_end_state, num_species, num_nodes) +diff = sum(mean_end_state, dims = 2) - non_spatial_mean +for (i, d) in enumerate(diff) + @test abs(d) < reltol * non_spatial_mean[i] +end \ No newline at end of file