Skip to content

Commit eeccbeb

Browse files
refactor: initial conditions represent values at t-1
1 parent 67991b7 commit eeccbeb

File tree

4 files changed

+19
-13
lines changed

4 files changed

+19
-13
lines changed

src/systems/diffeqs/abstractodesystem.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -1045,7 +1045,7 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
10451045
discrete_cbs = map(affects, clocks, svs) do affect, clock, sv
10461046
if clock isa Clock
10471047
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt;
1048-
final_affect = true)
1048+
final_affect = true, initial_affect = true)
10491049
elseif clock isa SolverStepClock
10501050
affect = DiscreteSaveAffect(affect, sv)
10511051
DiscreteCallback(Returns(true), affect,
@@ -1080,7 +1080,7 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
10801080
prob = ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...)
10811081
if !isempty(inits)
10821082
for init in inits
1083-
init(prob.u0, prob.p, tspan[1])
1083+
# init(prob.u0, prob.p, tspan[1])
10841084
end
10851085
end
10861086
prob

src/systems/discrete_system/discrete_system.jl

+1
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ function SciMLBase.DiscreteProblem(
253253

254254
f, u0, p = process_DiscreteProblem(
255255
DiscreteFunction, sys, u0map, parammap; eval_expression, eval_module)
256+
u0 = f(u0, p, tspan[1])
256257
DiscreteProblem(f, u0, tspan, p; kwargs...)
257258
end
258259

test/clock.jl

+14-9
Original file line numberDiff line numberDiff line change
@@ -112,14 +112,18 @@ prob = ODEProblem(ss, [x => 0.1], (0.0, Tf),
112112
[kp => 1.0; ud => 2.1; ud(k - 1) => 2.0])
113113
# create integrator so callback is evaluated at t=0 and we can test correct param values
114114
int = init(prob, Tsit5(); kwargshandle = KeywordArgSilent)
115-
@test sort(vcat(int.p...)) == [0.1, 1.0, 2.0, 2.1, 2.1] # yd, kp, ud(k-1), ud, Hold(ud)
115+
@test sort(vcat(int.p...)) == [0.1, 1.0, 2.1, 2.1, 2.1] # yd, kp, ud(k-1), ud, Hold(ud)
116+
prob = ODEProblem(ss, [x => 0.1], (0.0, Tf),
117+
[kp => 1.0; ud => 2.1; ud(k - 1) => 2.0]) # recreate problem to empty saved values
116118
sol = solve(prob, Tsit5(), kwargshandle = KeywordArgSilent)
117119

118120
ss_nosplit = structural_simplify(sys; split = false)
119121
prob_nosplit = ODEProblem(ss_nosplit, [x => 0.1], (0.0, Tf),
120122
[kp => 1.0; ud => 2.1; ud(k - 1) => 2.0])
121123
int = init(prob_nosplit, Tsit5(); kwargshandle = KeywordArgSilent)
122-
@test sort(int.p) == [0.1, 1.0, 2.0, 2.1, 2.1] # yd, kp, ud(k-1), ud, Hold(ud)
124+
@test sort(int.p) == [0.1, 1.0, 2.1, 2.1, 2.1] # yd, kp, ud(k-1), ud, Hold(ud)
125+
prob_nosplit = ODEProblem(ss_nosplit, [x => 0.1], (0.0, Tf),
126+
[kp => 1.0; ud => 2.1; ud(k - 1) => 2.0]) # recreate problem to empty saved values
123127
sol_nosplit = solve(prob_nosplit, Tsit5(), kwargshandle = KeywordArgSilent)
124128
# For all inputs in parameters, just initialize them to 0.0, and then set them
125129
# in the callback.
@@ -145,7 +149,8 @@ function affect!(integrator, saved_values)
145149
nothing
146150
end
147151
saved_values = SavedValues(Float64, Vector{Float64})
148-
cb = PeriodicCallback(Base.Fix2(affect!, saved_values), 0.1; final_affect = true)
152+
cb = PeriodicCallback(
153+
Base.Fix2(affect!, saved_values), 0.1; final_affect = true, initial_affect = true)
149154
# kp ud
150155
prob = ODEProblem(foo!, [0.1], (0.0, Tf), [1.0, 2.1, 2.0], callback = cb)
151156
sol2 = solve(prob, Tsit5())
@@ -308,8 +313,8 @@ if VERSION >= v"1.7"
308313
integrator.p[3] = ud2
309314
nothing
310315
end
311-
cb1 = PeriodicCallback(affect1!, dt; final_affect = true)
312-
cb2 = PeriodicCallback(affect2!, dt2; final_affect = true)
316+
cb1 = PeriodicCallback(affect1!, dt; final_affect = true, initial_affect = true)
317+
cb2 = PeriodicCallback(affect2!, dt2; final_affect = true, initial_affect = true)
313318
cb = CallbackSet(cb1, cb2)
314319
# kp ud1 ud2
315320
prob = ODEProblem(foo!, [0.0], (0.0, 1.0), [1.0, 1.0, 1.0], callback = cb)
@@ -438,10 +443,10 @@ y = res.y[:]
438443
prob = ODEProblem(ssys,
439444
[model.plant.x => 0.0; model.controller.kp => 2.0; model.controller.ki => 2.0],
440445
(0.0, Tf))
441-
442-
@test prob.ps[Hold(ssys.holder.input.u)] == 1 # constant output * kp issue https://github.com/SciML/ModelingToolkit.jl/issues/2356
443-
@test prob.ps[ssys.controller.x] == 0 # c2d
444-
@test prob.ps[Sample(d)(ssys.sampler.input.u)] == 0 # disc state
446+
int = init(prob, Tsit5(); kwargshandle = KeywordArgSilent)
447+
@test int.ps[Hold(ssys.holder.input.u)] == 2 # constant output * kp issue https://github.com/SciML/ModelingToolkit.jl/issues/2356
448+
@test int.ps[ssys.controller.x] == 1 # c2d
449+
@test int.ps[Sample(d)(ssys.sampler.input.u)] == 0 # disc state
445450
sol = solve(prob,
446451
Tsit5(),
447452
kwargshandle = KeywordArgSilent,

test/discrete_system.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ function sir_map!(u_diff, u, p, t)
9999
end
100100
nothing
101101
end;
102-
u0 = [990.0, 10.0, 0.0];
102+
u0 = prob_map2.u0;
103103
p = [0.05, 10.0, 0.25, 0.1];
104104
prob_map = DiscreteProblem(sir_map!, u0, tspan, p);
105105
sol_map2 = solve(prob_map, FunctionMap());
@@ -216,4 +216,4 @@ eqs = [u ~ 1
216216
prob = DiscreteProblem(de, [x => 0.0], (0, 10))
217217
sol = solve(prob, FunctionMap())
218218

219-
@test reduce(vcat, sol.u) == 0:10
219+
@test reduce(vcat, sol.u) == 1:11

0 commit comments

Comments
 (0)