Skip to content

Commit

Permalink
fix broadcast issue
Browse files Browse the repository at this point in the history
  • Loading branch information
isaacsas committed Sep 2, 2024
1 parent 103c487 commit 3f74888
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
7 changes: 4 additions & 3 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ function resetted_jump_problem(_jump_prob, seed)

if !isempty(jump_prob.variable_jumps)
@assert jump_prob.prob.u0 isa ExtendedJumpArray
@. jump_prob.prob.u0.jump_u = -randexp(_jump_prob.rng,
eltype(_jump_prob.prob.tspan))
ttype = eltype(_jump_prob.prob.tspan)
@. jump_prob.prob.u0.jump_u = -randexp(_jump_prob.rng, ttype)
end
jump_prob
end
Expand All @@ -71,6 +71,7 @@ function reset_jump_problem!(jump_prob, seed)

if !isempty(jump_prob.variable_jumps)
@assert jump_prob.prob.u0 isa ExtendedJumpArray
@. jump_prob.prob.u0.jump_u = -randexp(jump_prob.rng, eltype(jump_prob.prob.tspan))
ttype = eltype(jump_prob.prob.tspan)
@. jump_prob.prob.u0.jump_u = -randexp(jump_prob.rng, ttype)
end
end
8 changes: 5 additions & 3 deletions test/variable_rate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -261,9 +261,10 @@ let

ode_prob = ODEProblem(ode_fxn, u0, tspan, p)
sjm_prob = JumpProblem(ode_prob, b_jump, d_jump; rng)
@test allunique(sjm_prob.prob.u0.jump_u)
u0old = copy(sjm_prob.prob.u0.jump_u)
for i in 1:Nsims
sol = solve(sjm_prob, Tsit5(); saveat = tspan[2])
sol = solve(sjm_prob, Tsit5(); saveat = tspan[2])
@test allunique(sjm_prob.prob.u0.jump_u)
@test all(u0old != sjm_prob.prob.u0.jump_u)
u0old .= sjm_prob.prob.u0.jump_u
Expand All @@ -272,7 +273,8 @@ end

# accuracy test based on
# https://github.com/SciML/JumpProcesses.jl/issues/320
# note that testing that precisely is not trivial
# note that even with the seeded StableRNG this test is not
# deterministic for some reason.
let
rng = StableRNG(12345)
b = 2.0
Expand Down Expand Up @@ -304,7 +306,7 @@ let
d_jump = VariableRateJump(d_rate, death!)

ode_prob = ODEProblem(ode_fxn, u0, tspan, p)
sjm_prob = JumpProblem(ode_prob, b_jump, d_jump)
sjm_prob = JumpProblem(ode_prob, b_jump, d_jump; rng)
dt = .1
tsave = range(tspan[1], tspan[2]; step = dt)
umean = zeros(length(tsave))
Expand Down

0 comments on commit 3f74888

Please sign in to comment.