From 3f74888adf0519e54a31ad41999535ba1188b9e4 Mon Sep 17 00:00:00 2001 From: Sam Isaacson Date: Sun, 1 Sep 2024 21:21:25 -0400 Subject: [PATCH] fix broadcast issue --- src/solve.jl | 7 ++++--- test/variable_rate.jl | 8 +++++--- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/solve.jl b/src/solve.jl index c64046c7..6ba2fe23 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -58,8 +58,8 @@ function resetted_jump_problem(_jump_prob, seed) if !isempty(jump_prob.variable_jumps) @assert jump_prob.prob.u0 isa ExtendedJumpArray - @. jump_prob.prob.u0.jump_u = -randexp(_jump_prob.rng, - eltype(_jump_prob.prob.tspan)) + ttype = eltype(_jump_prob.prob.tspan) + @. jump_prob.prob.u0.jump_u = -randexp(_jump_prob.rng, ttype) end jump_prob end @@ -71,6 +71,7 @@ function reset_jump_problem!(jump_prob, seed) if !isempty(jump_prob.variable_jumps) @assert jump_prob.prob.u0 isa ExtendedJumpArray - @. jump_prob.prob.u0.jump_u = -randexp(jump_prob.rng, eltype(jump_prob.prob.tspan)) + ttype = eltype(jump_prob.prob.tspan) + @. jump_prob.prob.u0.jump_u = -randexp(jump_prob.rng, ttype) end end diff --git a/test/variable_rate.jl b/test/variable_rate.jl index 64b3873f..d7bf2d9c 100644 --- a/test/variable_rate.jl +++ b/test/variable_rate.jl @@ -261,9 +261,10 @@ let ode_prob = ODEProblem(ode_fxn, u0, tspan, p) sjm_prob = JumpProblem(ode_prob, b_jump, d_jump; rng) + @test allunique(sjm_prob.prob.u0.jump_u) u0old = copy(sjm_prob.prob.u0.jump_u) for i in 1:Nsims - sol = solve(sjm_prob, Tsit5(); saveat = tspan[2]) + sol = solve(sjm_prob, Tsit5(); saveat = tspan[2]) @test allunique(sjm_prob.prob.u0.jump_u) @test all(u0old != sjm_prob.prob.u0.jump_u) u0old .= sjm_prob.prob.u0.jump_u @@ -272,7 +273,8 @@ end # accuracy test based on # https://github.com/SciML/JumpProcesses.jl/issues/320 -# note that testing that precisely is not trivial +# note that even with the seeded StableRNG this test is not +# deterministic for some reason. let rng = StableRNG(12345) b = 2.0 @@ -304,7 +306,7 @@ let d_jump = VariableRateJump(d_rate, death!) ode_prob = ODEProblem(ode_fxn, u0, tspan, p) - sjm_prob = JumpProblem(ode_prob, b_jump, d_jump) + sjm_prob = JumpProblem(ode_prob, b_jump, d_jump; rng) dt = .1 tsave = range(tspan[1], tspan[2]; step = dt) umean = zeros(length(tsave))