Skip to content

Commit e9cc50e

Browse files
fix: implement proper initialization convention for discrete variables
1 parent 5a19c6a commit e9cc50e

File tree

7 files changed

+109
-19
lines changed

7 files changed

+109
-19
lines changed

docs/src/tutorials/discrete_system.md

+7-2
Original file line numberDiff line numberDiff line change
@@ -42,5 +42,10 @@ the Fibonacci series:
4242
@mtkbuild sys = DiscreteSystem([x ~ x(k - 1) + x(k - 2)], t)
4343
```
4444

45-
Note that the default value is treated as the initial value of `x(k-1)`. The value for
46-
`x(k-2)` must be provided during problem construction.
45+
The "default value" here should be interpreted as the value of `x` at all past timesteps.
46+
For example, here `x(k-1)` and `x(k-2)` will be `1.0`, and the inital value of `x(k)` will
47+
thus be `2.0`. During problem construction, the _past_ value of a variable should be
48+
provided. For example, providing `[x => 1.0]` while constructing this problem will error.
49+
Provide `[x(k-1) => 1.0]` instead. Note that values provided during problem construction
50+
_do not_ apply to the entire history. Hence, if `[x(k-1) => 2.0]` is provided, the value of
51+
`x(k-2)` will still be `1.0`.

src/systems/diffeqs/abstractodesystem.jl

+35
Original file line numberDiff line numberDiff line change
@@ -864,6 +864,41 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
864864
# since they will be checked in the initialization problem's construction
865865
# TODO: make check for if a DAE cheaper than calculating the mass matrix a second time!
866866
ci = infer_clocks!(ClockInference(TearingState(sys)))
867+
868+
if eltype(parammap) <: Pair
869+
parammap = Dict(unwrap(k) => v for (k, v) in todict(parammap))
870+
elseif parammap isa AbstractArray
871+
if isempty(parammap)
872+
parammap = SciMLBase.NullParameters()
873+
else
874+
parammap = Dict(unwrap.(parameters(sys)) .=> parammap)
875+
end
876+
end
877+
clockedparammap = Dict()
878+
defs = ModelingToolkit.get_defaults(sys)
879+
for v in ps
880+
v = unwrap(v)
881+
is_discrete_domain(v) || continue
882+
op = operation(v)
883+
if !isa(op, Symbolics.Operator) && parammap != SciMLBase.NullParameters() &&
884+
haskey(parammap, v)
885+
error("Initial conditions for discrete variables must be for the past state of the unknown. Instead of providing the condition for $v, provide the condition for $(Shift(iv, -1)(v)).")
886+
end
887+
shiftedv = StructuralTransformations.simplify_shifts(Shift(iv, -1)(v))
888+
if parammap != SciMLBase.NullParameters() &&
889+
(val = get(parammap, shiftedv, nothing)) !== nothing
890+
clockedparammap[v] = val
891+
elseif op isa Shift
892+
root = arguments(v)[1]
893+
haskey(defs, root) || error("Initial condition for $v not provided.")
894+
clockedparammap[v] = defs[root]
895+
end
896+
end
897+
parammap = if parammap == SciMLBase.NullParameters()
898+
clockedparammap
899+
else
900+
merge(parammap, clockedparammap)
901+
end
867902
# TODO: make it work with clocks
868903
# ModelingToolkit.get_tearing_state(sys) !== nothing => Requires structural_simplify first
869904
if sys isa ODESystem && (implicit_dae || !isempty(missingvars)) &&

src/systems/discrete_system/discrete_system.jl

+23-6
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,8 @@ function DiscreteSystem(eqs, iv; kwargs...)
171171
ps = OrderedSet()
172172
iv = value(iv)
173173
for eq in eqs
174-
collect_vars!(allunknowns, ps, eq.lhs, iv)
175-
collect_vars!(allunknowns, ps, eq.rhs, iv)
174+
collect_vars!(allunknowns, ps, eq.lhs, iv; op = Shift)
175+
collect_vars!(allunknowns, ps, eq.rhs, iv; op = Shift)
176176
if istree(eq.lhs) && operation(eq.lhs) isa Shift
177177
isequal(iv, operation(eq.lhs).t) ||
178178
throw(ArgumentError("A DiscreteSystem can only have one independent variable."))
@@ -205,21 +205,38 @@ function generate_function(
205205
end
206206

207207
function process_DiscreteProblem(constructor, sys::DiscreteSystem, u0map, parammap;
208-
version = nothing,
209208
linenumbers = true, parallel = SerialForm(),
210209
eval_expression = true,
211210
use_union = false,
212211
tofloat = !use_union,
213212
kwargs...)
213+
iv = get_iv(sys)
214214
eqs = equations(sys)
215215
dvs = unknowns(sys)
216216
ps = parameters(sys)
217217

218+
trueu0map = Dict()
219+
for (k, v) in u0map
220+
k = unwrap(k)
221+
if !((op = operation(k)) isa Shift)
222+
error("Initial conditions must be for the past state of the unknowns. Instead of providing the condition for $k, provide the condition for $(Shift(iv, -1)(k)).")
223+
end
224+
trueu0map[Shift(iv, op.steps + 1)(arguments(k)[1])] = v
225+
end
226+
defs = ModelingToolkit.get_defaults(sys)
227+
for var in dvs
228+
if (op = operation(var)) isa Shift && !haskey(trueu0map, var)
229+
root = arguments(var)[1]
230+
haskey(defs, root) || error("Initial condition for $var not provided.")
231+
trueu0map[var] = defs[root]
232+
end
233+
end
234+
@show trueu0map u0map
218235
if has_index_cache(sys) && get_index_cache(sys) !== nothing
219-
u0, defs = get_u0(sys, u0map, parammap)
220-
p = MTKParameters(sys, parammap)
236+
u0, defs = get_u0(sys, trueu0map, parammap)
237+
p = MTKParameters(sys, parammap, trueu0map)
221238
else
222-
u0, p, defs = get_u0_p(sys, u0map, parammap; tofloat, use_union)
239+
u0, p, defs = get_u0_p(sys, trueu0map, parammap; tofloat, use_union)
223240
end
224241

225242
check_eqs_u0(eqs, dvs, u0; kwargs...)

src/utils.jl

+5-2
Original file line numberDiff line numberDiff line change
@@ -243,11 +243,14 @@ function collect_var_to_name!(vars, xs)
243243
x = unwrap(x)
244244
if hasmetadata(x, Symbolics.GetindexParent)
245245
xarr = getmetadata(x, Symbolics.GetindexParent)
246+
hasname(xarr) || continue
246247
vars[Symbolics.getname(xarr)] = xarr
247248
else
248249
if istree(x) && operation(x) === getindex
249250
x = arguments(x)[1]
250251
end
252+
x = unwrap(x)
253+
hasname(x) || continue
251254
vars[Symbolics.getname(unwrap(x))] = x
252255
end
253256
end
@@ -434,11 +437,11 @@ function find_derivatives!(vars, expr, f)
434437
return vars
435438
end
436439

437-
function collect_vars!(unknowns, parameters, expr, iv)
440+
function collect_vars!(unknowns, parameters, expr, iv; op = Differential)
438441
if issym(expr)
439442
collect_var!(unknowns, parameters, expr, iv)
440443
else
441-
for var in vars(expr)
444+
for var in vars(expr; op)
442445
if istree(var) && operation(var) isa Differential
443446
var, _ = var_from_nested_derivative(var)
444447
end

test/clock.jl

+19-4
Original file line numberDiff line numberDiff line change
@@ -109,21 +109,21 @@ ss = structural_simplify(sys);
109109

110110
Tf = 1.0
111111
prob = ODEProblem(ss, [x => 0.1], (0.0, Tf),
112-
[kp => 1.0; ud => 2.1; ud(k - 1) => 2.0])
112+
[kp => 1.0; ud(k - 1) => 2.1; ud(k - 2) => 2.0])
113113
# create integrator so callback is evaluated at t=0 and we can test correct param values
114114
int = init(prob, Tsit5(); kwargshandle = KeywordArgSilent)
115115
@test sort(vcat(int.p...)) == [0.1, 1.0, 2.1, 2.1, 2.1] # yd, kp, ud(k-1), ud, Hold(ud)
116116
prob = ODEProblem(ss, [x => 0.1], (0.0, Tf),
117-
[kp => 1.0; ud => 2.1; ud(k - 1) => 2.0]) # recreate problem to empty saved values
117+
[kp => 1.0; ud(k - 1) => 2.1; ud(k - 2) => 2.0]) # recreate problem to empty saved values
118118
sol = solve(prob, Tsit5(), kwargshandle = KeywordArgSilent)
119119

120120
ss_nosplit = structural_simplify(sys; split = false)
121121
prob_nosplit = ODEProblem(ss_nosplit, [x => 0.1], (0.0, Tf),
122-
[kp => 1.0; ud => 2.1; ud(k - 1) => 2.0])
122+
[kp => 1.0; ud(k - 1) => 2.1; ud(k - 2) => 2.0])
123123
int = init(prob_nosplit, Tsit5(); kwargshandle = KeywordArgSilent)
124124
@test sort(int.p) == [0.1, 1.0, 2.1, 2.1, 2.1] # yd, kp, ud(k-1), ud, Hold(ud)
125125
prob_nosplit = ODEProblem(ss_nosplit, [x => 0.1], (0.0, Tf),
126-
[kp => 1.0; ud => 2.1; ud(k - 1) => 2.0]) # recreate problem to empty saved values
126+
[kp => 1.0; ud(k - 1) => 2.1; ud(k - 2) => 2.0]) # recreate problem to empty saved values
127127
sol_nosplit = solve(prob_nosplit, Tsit5(), kwargshandle = KeywordArgSilent)
128128
# For all inputs in parameters, just initialize them to 0.0, and then set them
129129
# in the callback.
@@ -516,3 +516,18 @@ sol = solve(prob, Tsit5(), kwargshandle = KeywordArgSilent)
516516
@test sol.prob.kwargs[:disc_saved_values][1].t == sol.t[1:2:end] # Test that the discrete-tiem system executed at every step of the continuous solver. The solver saves each time step twice, one state value before discrete affect and one after.
517517
@test_nowarn ModelingToolkit.build_explicit_observed_function(
518518
model, model.counter.ud)(sol.u[1], prob.p..., sol.t[1])
519+
520+
@variables x(t)=1.0 y(t)=1.0
521+
eqs = [D(y) ~ Hold(x)
522+
x ~ x(k - 1) + x(k - 2)]
523+
@mtkbuild sys = ODESystem(eqs, t)
524+
prob = ODEProblem(sys, [], (0.0, 10.0))
525+
int = init(prob, Tsit5(); kwargshandle = KeywordArgSilent)
526+
@test int.ps[x] == 2.0
527+
@test int.ps[x(k - 1)] == 1.0
528+
529+
@test_throws ErrorException ODEProblem(sys, [], (0.0, 10.0), [x => 2.0])
530+
prob = ODEProblem(sys, [], (0.0, 10.0), [x(k - 1) => 2.0])
531+
int = init(prob, Tsit5(); kwargshandle = KeywordArgSilent)
532+
@test int.ps[x] == 3.0
533+
@test int.ps[x(k - 1)] == 2.0

test/discrete_system.jl

+17-2
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ for df in [
5151
end
5252

5353
# Problem
54-
u0 = [S => 990.0, I => 10.0, R => 0.0]
54+
u0 = [S(k - 1) => 990.0, I(k - 1) => 10.0, R(k - 1) => 0.0]
5555
p ==> 0.05, c => 10.0, γ => 0.25, δt => 0.1, nsteps => 400]
5656
tspan = (0.0, ModelingToolkit.value(substitute(nsteps, p))) # value function (from Symbolics) is used to convert a Num to Float64
5757
prob_map = DiscreteProblem(syss, u0, tspan, p)
@@ -215,7 +215,22 @@ eqs = [u ~ 1
215215
x ~ x(k - 1) + u
216216
y ~ x + u]
217217
@mtkbuild de = DiscreteSystem(eqs, t)
218-
prob = DiscreteProblem(de, [x => 0.0], (0, 10))
218+
prob = DiscreteProblem(de, [x(k - 1) => 0.0], (0, 10))
219219
sol = solve(prob, FunctionMap())
220220

221221
@test reduce(vcat, sol.u) == 1:11
222+
223+
# test that default values apply to the entire history
224+
@variables x(t) = 1.0
225+
@mtkbuild de = DiscreteSystem([x ~ x(k - 1) + x(k - 2)], t)
226+
prob = DiscreteProblem(de, [], (0, 10))
227+
@test prob[x] == 2.0
228+
@test prob[x(k - 1)] == 1.0
229+
230+
# must provide initial conditions for history
231+
@test_throws ErrorException DiscreteProblem(de, [x => 2.0], (0, 10))
232+
233+
# initial values only affect _that timestep_, not the entire history
234+
prob = DiscreteProblem(de, [x(k - 1) => 2.0], (0, 10))
235+
@test prob[x] == 3.0
236+
@test prob[x(k - 1)] == 2.0

test/parameter_dependencies.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -66,18 +66,18 @@ end
6666

6767
Tf = 1.0
6868
prob = ODEProblem(sys, [x => 0.0, y => 0.0], (0.0, Tf),
69-
[kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0])
69+
[kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; yd(k - 2) => 2.0])
7070
@test_nowarn solve(prob, Tsit5(); kwargshandle = KeywordArgSilent)
7171

7272
@mtkbuild sys = ODESystem(eqs, t; parameter_dependencies = [kq => 2kp],
7373
discrete_events = [[0.5] => [kp ~ 2.0]])
7474
prob = ODEProblem(sys, [x => 0.0, y => 0.0], (0.0, Tf),
75-
[kp => 1.0; z => 3.0; z(k + 1) => 2.0])
75+
[kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; yd(k - 2) => 2.0])
7676
@test prob.ps[kp] == 1.0
7777
@test prob.ps[kq] == 2.0
7878
@test_nowarn solve(prob, Tsit5(), kwargshandle = KeywordArgSilent)
7979
prob = ODEProblem(sys, [x => 0.0, y => 0.0], (0.0, Tf),
80-
[kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0])
80+
[kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; yd(k - 2) => 2.0])
8181
integ = init(prob, Tsit5(), kwargshandle = KeywordArgSilent)
8282
@test integ.ps[kp] == 1.0
8383
@test integ.ps[kq] == 2.0

0 commit comments

Comments
 (0)