Skip to content

Commit

Permalink
Write RSSACR-Direct but it's incorrect. I will turn it into CR-RSSA.
Browse files Browse the repository at this point in the history
  • Loading branch information
Vilin97 committed Sep 10, 2023
1 parent 3f1fd74 commit 7861bc7
Show file tree
Hide file tree
Showing 10 changed files with 364 additions and 66 deletions.
3 changes: 2 additions & 1 deletion src/JumpProcesses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ include("spatial/bracketing.jl")

include("spatial/nsm.jl")
include("spatial/directcrdirect.jl")
include("spatial/rssacrdirect.jl")

include("aggregators/aggregated_api.jl")

Expand Down Expand Up @@ -101,6 +102,6 @@ export ExtendedJumpArray
export CartesianGrid, CartesianGridRej
export SpatialMassActionJump
export outdegree, num_sites, neighbors
export NSM, DirectCRDirect
export NSM, DirectCRDirect, RSSACRDirect

end # module
3 changes: 3 additions & 0 deletions src/aggregators/aggregators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ algorithm with optimal binning, Journal of Chemical Physics 143, 074108
"""
struct DirectCRDirect <: AbstractAggregatorAlgorithm end

struct RSSACRDirect <: AbstractAggregatorAlgorithm end

const JUMP_AGGREGATORS = (Direct(), DirectFW(), DirectCR(), SortingDirect(), RSSA(), FRM(),
FRMFW(), NRM(), RSSACR(), RDirect(), Coevolve())

Expand Down Expand Up @@ -187,3 +189,4 @@ supports_variablerates(aggregator::Coevolve) = true
is_spatial(aggregator::AbstractAggregatorAlgorithm) = false
is_spatial(aggregator::NSM) = true
is_spatial(aggregator::DirectCRDirect) = true
is_spatial(aggregator::RSSACRDirect) = true
43 changes: 35 additions & 8 deletions src/spatial/bracketing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,15 @@ struct LowHigh{T}
low::T
high::T

LowHigh(low::T, high::T) where {T} = new{T}(deepcopy(low), deepcopy(high))
LowHigh(pair::Tuple{T,T}) where {T} = new{T}(pair[1], pair[2])
LowHigh(low_and_high::T) where {T} = new{T}(low_and_high, deepcopy(low_and_high))
function LowHigh(low::T, high::T; do_copy = true) where {T}
if do_copy
return new{T}(deepcopy(low), deepcopy(high))
else
return new{T}(low, high)
end
end
LowHigh(pair::Tuple{T,T}; kwargs...) where {T} = LowHigh(pair[1], pair[2]; kwargs...)
LowHigh(low_and_high::T; kwargs...) where {T} = LowHigh(low_and_high, low_and_high; kwargs...)
end

function Base.show(io::IO, ::MIME"text/plain", low_high::LowHigh)
Expand All @@ -16,25 +22,39 @@ function Base.show(io::IO, ::MIME"text/plain", low_high::LowHigh)
end

@inline function update_u_brackets!(u_low_high::LowHigh, bracket_data, u::AbstractMatrix)
@inbounds for (i, uval) in enumerate(u)
u_low_high[i] = LowHigh(get_spec_brackets(bracket_data, i, uval))
num_species, num_sites = size(u)
update_u_brackets!(u_low_high, bracket_data, u, 1:num_species, 1:num_sites)
end

@inline function update_u_brackets!(u_low_high::LowHigh, bracket_data, u::AbstractMatrix, species_vec, sites)
@inbounds for site in sites
for species in species_vec
u_low_high[species, site] = LowHigh(get_spec_brackets(bracket_data, species, u[species, site]))
end
end
nothing
end

function is_outside_brackets(u_low_high::LowHigh{M}, u::M, species, site) where {M}
return u[species, site] < u_low_high.low[species, site] || u[species, site] > u_low_high.high[species, site]
end

### convenience functions for LowHigh ###
function setindex!(low_high::LowHigh, val::LowHigh, i)
low_high.low[i] = val.low
low_high.high[i] = val.high
function setindex!(low_high::LowHigh, val::LowHigh, i...)
low_high.low[i...] = val.low
low_high.high[i...] = val.high
val
end

get_majumps(rx_rates::LowHigh{R}) where {R <: RxRates} = get_majumps(rx_rates.low)

function total_site_rate(rx_rates::LowHigh, hop_rates::LowHigh, site)
return LowHigh(
total_site_rate(rx_rates.low, hop_rates.low, site),
total_site_rate(rx_rates.high, hop_rates.high, site))
end

# Compatible with constant rate jumps, because u_low_high.low and u_low_high.high are used in rate().
function update_rx_rates!(rx_rates::LowHigh, rxs, u_low_high, integrator, site)
update_rx_rates!(rx_rates.low, rxs, u_low_high.low, integrator, site)
update_rx_rates!(rx_rates.high, rxs, u_low_high.high, integrator, site)
Expand All @@ -44,3 +64,10 @@ function update_hop_rates!(hop_rates::LowHigh, species, u_low_high, site, spatia
update_hop_rates!(hop_rates.low, species, u_low_high.low, site, spatial_system)
update_hop_rates!(hop_rates.high, species, u_low_high.high, site, spatial_system)
end

function reset!(low_high::LowHigh)
reset!(low_high.low)
reset!(low_high.high)
end

reset!(array::AbstractArray) = fill!(array, zero(eltype(array)))
5 changes: 2 additions & 3 deletions src/spatial/directcrdirect.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
const MINJUMPRATE = 2.0^exponent(1e-12)

#NOTE state vector u is a matrix. u[i,j] is species i, site j
#NOTE hopping_constants is a matrix. hopping_constants[i,j] is species i, site j
mutable struct DirectCRDirectJumpAggregation{T, S, F1, F2, RNG, J, RX, HOP, DEPGR,
VJMAP, JVMAP, SS, U <: PriorityTable,
W <: Function} <:
Expand Down Expand Up @@ -107,12 +106,12 @@ end
function initialize!(p::DirectCRDirectJumpAggregation, integrator, u, params, t)
p.end_time = integrator.sol.prob.tspan[2]
fill_rates_and_get_times!(p, integrator, t)
generate_jumps!(p, integrator, params, u, t)
generate_jumps!(p, integrator, u, params, t)
nothing
end

# calculate the next jump / jump time
function generate_jumps!(p::DirectCRDirectJumpAggregation, integrator, params, u, t)
function generate_jumps!(p::DirectCRDirectJumpAggregation, integrator, u, params, t)
p.next_jump_time = t + randexp(p.rng) / p.rt.gsum
p.next_jump_time >= p.end_time && return nothing
site = sample(p.rt, p.site_rates, p.rng)
Expand Down
37 changes: 18 additions & 19 deletions src/spatial/hop_rates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,29 +57,28 @@ function HopRates(p::Pair{SpecHop, SiteHop},
end

"""
update_hop_rates!(hop_rates::AbstractHopRates, species::AbstractArray, u, site, spatial_system)
update_hop_rates!(hop_rates::HopRatesGraphDsi, species_vec, u, site, spatial_system)
update rates of all specs in species at site
update rates of all species in species_vec at site
"""
function update_hop_rates!(hop_rates::AbstractHopRates, species::AbstractArray, u, site,
spatial_system)
@inbounds for spec in species
update_hop_rate!(hop_rates, spec, u, site, spatial_system)
function update_hop_rates!(hop_rates::AbstractHopRates, species_vec, u, site, spatial_system)
@inbounds for species in species_vec
rates = hop_rates.rates
old_rate = rates[species, site]
rates[species, site] = evalhoprate(hop_rates, u, species, site,
spatial_system)
hop_rates.sum_rates[site] += rates[species, site] - old_rate
old_rate
end
end

"""
update_hop_rate!(hop_rates::HopRatesGraphDsi, species, u, site, spatial_system)
update rates of single species at site
"""
function update_hop_rate!(hop_rates::AbstractHopRates, species, u, site, spatial_system)
rates = hop_rates.rates
@inbounds old_rate = rates[species, site]
@inbounds rates[species, site] = evalhoprate(hop_rates, u, species, site,
spatial_system)
@inbounds hop_rates.sum_rates[site] += rates[species, site] - old_rate
old_rate
function recompute_site_hop_rate(hop_rates::HP, u, site, spatial_system) where {HP <: AbstractHopRates}
rate = zero(eltype(hop_rates.rates))
num_species = size(hop_rates.rates, 1)
for species in 1:num_species
rate += evalhoprate(hop_rates, u, species, site, spatial_system)
end
return rate
end

"""
Expand Down Expand Up @@ -197,7 +196,7 @@ end
return hopping rate of species at site
"""
function evalhoprate(hop_rates::HopRatesGraphDsi, u, species, site, spatial_system)
@inbounds u[species, site] * hop_rates.hopping_constants[species, site] *
u[species, site] * hop_rates.hopping_constants[species, site] *
outdegree(spatial_system, site)
end

Expand Down
4 changes: 2 additions & 2 deletions src/spatial/nsm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,12 @@ end
function initialize!(p::NSMJumpAggregation, integrator, u, params, t)
p.end_time = integrator.sol.prob.tspan[2]
fill_rates_and_get_times!(p, integrator, t)
generate_jumps!(p, integrator, params, u, t)
generate_jumps!(p, integrator, u, params, t)
nothing
end

# calculate the next jump / jump time
function generate_jumps!(p::NSMJumpAggregation, integrator, params, u, t)
function generate_jumps!(p::NSMJumpAggregation, integrator, u, params, t)
p.next_jump_time, site = top_with_handle(p.pq)
p.next_jump_time >= p.end_time && return nothing
p.next_jump = sample_jump_direct(p, site)
Expand Down
15 changes: 15 additions & 0 deletions src/spatial/reaction_rates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ function RxRates(num_sites::Int, ma_jumps::M) where {M}
end

num_rxs(rx_rates::RxRates) = get_num_majumps(rx_rates.ma_jumps)
get_majumps(rx_rates::RxRates) = rx_rates.ma_jumps

"""
reset!(rx_rates::RxRates)
Expand Down Expand Up @@ -77,6 +78,20 @@ function sample_rx_at_site(rx_rates::RxRates, site, rng)
rand(rng) * total_site_rx_rate(rx_rates, site))
end

"""
recompute_site_rx_rate(rx_rates::RxRates, u, site)
compute the total reaction rate at site at the current state u
"""
function recompute_site_rx_rate(rx_rates::RxRates, u, site)
rate = zero(eltype(rx_rates.rates))
ma_jumps = rx_rates.ma_jumps
for rx in 1:num_rxs(rx_rates)
rate += eval_massaction_rate(u, rx, ma_jumps, site)
end
return rate
end

# helper functions
function set_rx_rate_at_site!(rx_rates::RxRates, site, rx, rate)
@inbounds old_rate = rx_rates.rates[rx, site]
Expand Down
Loading

0 comments on commit 7861bc7

Please sign in to comment.