Skip to content

Commit

Permalink
Merge pull request #430 from isaacsas/fix_saveat_plotting
Browse files Browse the repository at this point in the history
Fix saveat plotting
  • Loading branch information
isaacsas authored Jul 26, 2024
2 parents 7be7a72 + 9902ecc commit a5fb32a
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 24 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ PoissonRandom = "0.4"
RandomNumbers = "1.5"
RecursiveArrayTools = "3.12"
Reexport = "1.0"
SciMLBase = "2.30.1"
SciMLBase = "2.46"
StaticArrays = "1.9"
SymbolicIndexingInterface = "0.3.13"
UnPack = "1.0.2"
Expand Down
2 changes: 1 addition & 1 deletion src/JumpProcesses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using RandomNumbers, LinearAlgebra, Markdown, DocStringExtensions
using DataStructures, PoissonRandom, Random, ArrayInterface
using FunctionWrappers, UnPack
using Graphs
using SciMLBase: SciMLBase
using SciMLBase: SciMLBase, isdenseplot
using Base.FastMath: add_fast

import DiffEqBase: DiscreteCallback, init, solve, solve!, plot_indices, initialize!
Expand Down
10 changes: 8 additions & 2 deletions src/SSA_stepper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,11 +183,11 @@ function DiffEqBase.__init(jump_prob::JumpProblem,
u = typeof(prob.u0)[]
end

sol = DiffEqBase.build_solution(prob, alg, t, u, dense = false,
save_everystep = any(cb.save_positions)
sol = DiffEqBase.build_solution(prob, alg, t, u, dense = save_everystep,
calculate_error = false,
stats = DiffEqBase.Stats(0),
interp = DiffEqBase.ConstantInterpolation(t, u))
save_everystep = any(cb.save_positions)

if saveat isa Number
_saveat = prob.tspan[1]:saveat:prob.tspan[2]
Expand Down Expand Up @@ -331,3 +331,9 @@ function DiffEqBase.terminate!(integrator::SSAIntegrator, retcode = ReturnCode.T
end

export SSAStepper

function SciMLBase.isdenseplot(sol::ODESolution{
T, N, uType, uType2, DType, tType, rateType, discType, P,
SSAStepper}) where {T, N, uType, uType2, DType, tType, rateType, discType, P}
sol.dense
end
74 changes: 54 additions & 20 deletions test/save_positions.jl
Original file line number Diff line number Diff line change
@@ -1,28 +1,62 @@
using JumpProcesses, OrdinaryDiffEq, Test
using JumpProcesses, OrdinaryDiffEq, Test, SciMLBase
using StableRNGs
rng = StableRNG(12345)

# test that we only save when a jump occurs
for alg in (Coevolve(),)
u0 = [0]
tspan = (0.0, 30.0)
let
for alg in (Coevolve(),)
u0 = [0]
tspan = (0.0, 30.0)

dprob = DiscreteProblem(u0, tspan)
# set the rate to 0, so that no jump ever occurs; but urate is positive so
# Coevolve will consider many candidates before the end of the simmulation.
# None of these points should be saved.
jump = VariableRateJump((u, p, t) -> 0, (integrator) -> integrator.u[1] += 1;
urate = (u, p, t) -> 1.0, rateinterval = (u, p, t) -> 5.0)
jumpproblem = JumpProblem(dprob, alg, jump; dep_graph = [[1]],
save_positions = (false, true), rng)
sol = solve(jumpproblem, SSAStepper())
@test sol.t == [0.0, 30.0]

oprob = ODEProblem((du, u, p, t) -> 0, u0, tspan)
jump = VariableRateJump((u, p, t) -> 0, (integrator) -> integrator.u[1] += 1;
urate = (u, p, t) -> 1.0, rateinterval = (u, p, t) -> 5.0)
jumpproblem = JumpProblem(oprob, alg, jump; dep_graph = [[1]],
save_positions = (false, true), rng)
sol = solve(jumpproblem, Tsit5(); save_everystep = false)
@test sol.t == [0.0, 30.0]
end
end

# test isdenseplot gives correct values for SSAStepper and non-SSAStepper models
let
rate(u, p, t) = max(u[1], 0.0)
affect!(integ) = (integ.u[1] -= 1; nothing)
crj = ConstantRateJump(rate, affect!)
u0 = [10.0]
tspan = (0.0, 10.0)
dprob = DiscreteProblem(u0, tspan)
# set the rate to 0, so that no jump ever occurs; but urate is positive so
# Coevolve will consider many candidates before the end of the simmulation.
# None of these points should be saved.
jump = VariableRateJump((u, p, t) -> 0, (integrator) -> integrator.u[1] += 1;
urate = (u, p, t) -> 1.0, rateinterval = (u, p, t) -> 5.0)
jumpproblem = JumpProblem(dprob, alg, jump; dep_graph = [[1]],
save_positions = (false, true))
sol = solve(jumpproblem, SSAStepper())
@test sol.t == [0.0, 30.0]
sps = ((true, true), (true, false), (false, true), (false, false))

# for pure jump problems dense = save_everystep
vals = (true, true, true, false)
for (sp, val) in zip(sps, vals)
jprob = JumpProblem(dprob, Direct(), crj; save_positions = sp, rng)
sol = solve(jprob, SSAStepper())
@test SciMLBase.isdenseplot(sol) == val
end

# for mixed problems sol.dense currently ignores save_positions
oprob = ODEProblem((du, u, p, t) -> du[1] = 0.1, u0, tspan)
for sp in sps
jprob = JumpProblem(oprob, Direct(), crj; save_positions = sp, rng)
sol = solve(jprob, Tsit5())
@test sol.dense == true
@test SciMLBase.isdenseplot(sol) == true

oprob = ODEProblem((du, u, p, t) -> 0, u0, tspan)
jump = VariableRateJump((u, p, t) -> 0, (integrator) -> integrator.u[1] += 1;
urate = (u, p, t) -> 1.0, rateinterval = (u, p, t) -> 5.0)
jumpproblem = JumpProblem(oprob, alg, jump; dep_graph = [[1]],
save_positions = (false, true))
sol = solve(jumpproblem, Tsit5(); save_everystep = false)
@test sol.t == [0.0, 30.0]
sol = solve(jprob, Tsit5(); dense = false)
@test sol.dense == false
@test SciMLBase.isdenseplot(sol) == false
end
end

0 comments on commit a5fb32a

Please sign in to comment.