Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ExtendedJumpArray remake fixes II #448

Merged
merged 2 commits into from
Aug 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions src/problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,27 +121,31 @@ function DiffEqBase.remake(jprob::JumpProblem; kwargs...)
else
_kwargs = kwargs
end
dprob = DiffEqBase.remake(jprob.prob; _kwargs...)
newprob = DiffEqBase.remake(jprob.prob; _kwargs...)
else
dprob = DiffEqBase.remake(jprob.prob; kwargs...)
newprob = DiffEqBase.remake(jprob.prob; kwargs...)
end

# if the parameters were changed we must remake the MassActionJump too
if (:p ∈ keys(kwargs)) && using_params(jprob.massaction_jump)
update_parameters!(jprob.massaction_jump, dprob.p; kwargs...)
update_parameters!(jprob.massaction_jump, newprob.p; kwargs...)
end
else
any(k -> k in keys(kwargs), (:u0, :p, :tspan)) &&
error("If remaking a JumpProblem you can not pass both prob and any of u0, p, or tspan.")
dprob = kwargs[:prob]
newprob = kwargs[:prob]

# when passing a new wrapped problem directly we require u0 has the correct type
(typeof(newprob.u0) == typeof(jprob.prob.u0)) ||
error("The new u0 within the passed prob does not have the same type as the existing u0. Please pass a u0 of type $(typeof(jprob.prob.u0)).")

# we can't know if p was changed, so we must remake the MassActionJump
if using_params(jprob.massaction_jump)
update_parameters!(jprob.massaction_jump, dprob.p; kwargs...)
update_parameters!(jprob.massaction_jump, newprob.p; kwargs...)
end
end

T(dprob, jprob.aggregator, jprob.discrete_jump_aggregation, jprob.jump_callback,
T(newprob, jprob.aggregator, jprob.discrete_jump_aggregation, jprob.jump_callback,
jprob.variable_jumps, jprob.regular_jump, jprob.massaction_jump, jprob.rng,
jprob.kwargs)
end
Expand Down
24 changes: 24 additions & 0 deletions test/remake_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,27 @@ let
@test all(==(0.0), sol[1, :])
@test_throws ErrorException jprob4=remake(jprob, u0 = 1)
end

# tests when changing u0 via a passed in prob
let
f(du, u, p, t) = (du .= 0; nothing)
prob = ODEProblem(f, [0.0], (0.0, 1.0))
rrate(u, p, t) = u[1]
aaffect!(integrator) = (integrator.u[1] += 1; nothing)
vrj = VariableRateJump(rrate, aaffect!)
jprob = JumpProblem(prob, vrj; rng)
sol = solve(jprob, Tsit5())
@test all(==(0.0), sol[1, :])
u0 = [4.0]
prob2 = remake(jprob.prob; u0)
@test_throws ErrorException jprob2=remake(jprob; prob = prob2)
u0eja = JumpProcesses.remake_extended_u0(jprob.prob, u0, rng)
prob3 = remake(jprob.prob; u0 = u0eja)
jprob3 = remake(jprob; prob = prob3)
@test jprob3.prob.u0 isa ExtendedJumpArray
@test jprob3.prob.u0 === u0eja
sol = solve(jprob3, Tsit5())
u = sol[1, :]
@test length(u) > 2
@test all(>(u0[1]), u[3:end])
end
Loading