diff --git a/src/problem.jl b/src/problem.jl index 12624526..6fccf951 100644 --- a/src/problem.jl +++ b/src/problem.jl @@ -121,27 +121,31 @@ function DiffEqBase.remake(jprob::JumpProblem; kwargs...) else _kwargs = kwargs end - dprob = DiffEqBase.remake(jprob.prob; _kwargs...) + newprob = DiffEqBase.remake(jprob.prob; _kwargs...) else - dprob = DiffEqBase.remake(jprob.prob; kwargs...) + newprob = DiffEqBase.remake(jprob.prob; kwargs...) end # if the parameters were changed we must remake the MassActionJump too if (:p ∈ keys(kwargs)) && using_params(jprob.massaction_jump) - update_parameters!(jprob.massaction_jump, dprob.p; kwargs...) + update_parameters!(jprob.massaction_jump, newprob.p; kwargs...) end else any(k -> k in keys(kwargs), (:u0, :p, :tspan)) && error("If remaking a JumpProblem you can not pass both prob and any of u0, p, or tspan.") - dprob = kwargs[:prob] + newprob = kwargs[:prob] + + # when passing a new wrapped problem directly we require u0 has the correct type + (typeof(newprob.u0) == typeof(jprob.prob.u0)) || + error("The new u0 within the passed prob does not have the same type as the existing u0. Please pass a u0 of type $(typeof(jprob.prob.u0)).") # we can't know if p was changed, so we must remake the MassActionJump if using_params(jprob.massaction_jump) - update_parameters!(jprob.massaction_jump, dprob.p; kwargs...) + update_parameters!(jprob.massaction_jump, newprob.p; kwargs...) end end - T(dprob, jprob.aggregator, jprob.discrete_jump_aggregation, jprob.jump_callback, + T(newprob, jprob.aggregator, jprob.discrete_jump_aggregation, jprob.jump_callback, jprob.variable_jumps, jprob.regular_jump, jprob.massaction_jump, jprob.rng, jprob.kwargs) end diff --git a/test/remake_test.jl b/test/remake_test.jl index 13a45bef..676e615d 100644 --- a/test/remake_test.jl +++ b/test/remake_test.jl @@ -93,3 +93,27 @@ let @test all(==(0.0), sol[1, :]) @test_throws ErrorException jprob4=remake(jprob, u0 = 1) end + +# tests when changing u0 via a passed in prob +let + f(du, u, p, t) = (du .= 0; nothing) + prob = ODEProblem(f, [0.0], (0.0, 1.0)) + rrate(u, p, t) = u[1] + aaffect!(integrator) = (integrator.u[1] += 1; nothing) + vrj = VariableRateJump(rrate, aaffect!) + jprob = JumpProblem(prob, vrj; rng) + sol = solve(jprob, Tsit5()) + @test all(==(0.0), sol[1, :]) + u0 = [4.0] + prob2 = remake(jprob.prob; u0) + @test_throws ErrorException jprob2=remake(jprob; prob = prob2) + u0eja = JumpProcesses.remake_extended_u0(jprob.prob, u0, rng) + prob3 = remake(jprob.prob; u0 = u0eja) + jprob3 = remake(jprob; prob = prob3) + @test jprob3.prob.u0 isa ExtendedJumpArray + @test jprob3.prob.u0 === u0eja + sol = solve(jprob3, Tsit5()) + u = sol[1, :] + @test length(u) > 2 + @test all(>(u0[1]), u[3:end]) +end