Skip to content

Commit 55aa1e6

Browse files
feat: shift all discrete systems by 1 to fix correctness issues
1 parent 3824ce7 commit 55aa1e6

File tree

10 files changed

+147
-83
lines changed

10 files changed

+147
-83
lines changed

src/structural_transformation/utils.jl

+1
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,7 @@ end
433433

434434
function simplify_shifts(var)
435435
ModelingToolkit.hasshift(var) || return var
436+
var isa Equation && return simplify_shifts(var.lhs) ~ simplify_shifts(var.rhs)
436437
if isdoubleshift(var)
437438
op1 = operation(var)
438439
vv1 = arguments(var)[1]

src/systems/alias_elimination.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ function observed2graph(eqs, unknowns)
453453
lhs_j === nothing &&
454454
throw(ArgumentError("The lhs $(eq.lhs) of $eq, doesn't appear in unknowns."))
455455
assigns[i] = lhs_j
456-
vs = vars(eq.rhs)
456+
vs = vars(eq.rhs; op = Symbolics.Operator)
457457
for v in vs
458458
j = get(v2j, v, nothing)
459459
j !== nothing && add_edge!(graph, i, j)

src/systems/clock_inference.jl

+54-23
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,10 @@ function split_system(ci::ClockInference{S}) where {S}
133133
tss = similar(cid_to_eq, S)
134134
for (id, ieqs) in enumerate(cid_to_eq)
135135
ts_i = system_subset(ts, ieqs)
136-
@set! ts_i.structure.only_discrete = id != continuous_id
136+
if id != continuous_id
137+
ts_i = shift_discrete_system(ts_i)
138+
@set! ts_i.structure.only_discrete = true
139+
end
137140
tss[id] = ts_i
138141
end
139142
return tss, inputs, continuous_id, id_to_clock
@@ -148,7 +151,7 @@ function generate_discrete_affect(
148151
end
149152
use_index_cache = has_index_cache(osys) && get_index_cache(osys) !== nothing
150153
out = Sym{Any}(:out)
151-
appended_parameters = parameters(syss[continuous_id])
154+
appended_parameters = full_parameters(syss[continuous_id])
152155
offset = length(appended_parameters)
153156
param_to_idx = if use_index_cache
154157
Dict{Any, ParameterIndex}(p => parameter_index(osys, p)
@@ -157,6 +160,7 @@ function generate_discrete_affect(
157160
Dict{Any, Int}(reverse(en) for en in enumerate(appended_parameters))
158161
end
159162
affect_funs = []
163+
init_funs = []
160164
svs = []
161165
clocks = TimeDomain[]
162166
for (i, (sys, input)) in enumerate(zip(syss, inputs))
@@ -183,47 +187,38 @@ function generate_discrete_affect(
183187
if _v in fullvars
184188
push!(needed_disc_to_cont_obs, _v)
185189
push!(disc_to_cont_idxs, param_to_idx[v])
190+
continue
186191
end
187192

188-
# In the above case, `_v` was in `observed(sys)`
189-
# It may also be in `unknowns(sys)`, in which case it
190-
# will be shifted back by one step
191-
if istree(v) && (op = operation(v)) isa Shift
192-
_v = arguments(_v)[1]
193-
_v = Shift(op.t, op.steps - 1)(_v)
194-
else
195-
_v = Shift(get_iv(sys), -1)(_v)
196-
end
193+
# If the held quantity is calculated through observed
194+
# it will be shifted forward by 1
195+
_v = Shift(get_iv(sys), 1)(_v)
197196
if _v in fullvars
198197
push!(needed_disc_to_cont_obs, _v)
199198
push!(disc_to_cont_idxs, param_to_idx[v])
199+
continue
200200
end
201201
end
202-
append!(appended_parameters, input, unknowns(sys))
202+
append!(appended_parameters, input)
203203
cont_to_disc_obs = build_explicit_observed_function(
204204
use_index_cache ? osys : syss[continuous_id],
205205
needed_cont_to_disc_obs,
206206
throw = false,
207207
expression = true,
208208
output_type = SVector)
209-
@set! sys.ps = appended_parameters
210209
disc_to_cont_obs = build_explicit_observed_function(sys, needed_disc_to_cont_obs,
211210
throw = false,
212211
expression = true,
213212
output_type = SVector,
214213
op = Shift,
215-
ps = reorder_parameters(osys, full_parameters(sys)))
214+
ps = reorder_parameters(osys, appended_parameters))
216215
ni = length(input)
217216
ns = length(unknowns(sys))
218217
disc = Func(
219218
[
220219
out,
221220
DestructuredArgs(unknowns(osys)),
222-
if use_index_cache
223-
DestructuredArgs.(reorder_parameters(osys, full_parameters(osys)))
224-
else
225-
(DestructuredArgs(appended_parameters),)
226-
end...,
221+
DestructuredArgs.(reorder_parameters(osys, full_parameters(osys)))...,
227222
get_iv(sys)
228223
],
229224
[],
@@ -248,6 +243,36 @@ function generate_discrete_affect(
248243
end
249244
end
250245
empty_disc = isempty(disc_range)
246+
disc_init = if use_index_cache
247+
:(function (u, p, t)
248+
c2d_obs = $cont_to_disc_obs
249+
d2c_obs = $disc_to_cont_obs
250+
result = c2d_obs(u, p..., t)
251+
for (val, i) in zip(result, $cont_to_disc_idxs)
252+
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
253+
end
254+
255+
disc_state = Tuple($(parameter_values)(p, i) for i in $disc_range)
256+
result = d2c_obs(disc_state, p..., t)
257+
for (val, i) in zip(result, $disc_to_cont_idxs)
258+
# prevent multiple updates to dependents
259+
_set_parameter_unchecked!(p, val, i; update_dependent = false)
260+
end
261+
discretes, repack, _ = $(SciMLStructures.canonicalize)(
262+
$(SciMLStructures.Discrete()), p)
263+
repack(discretes) # to force recalculation of dependents
264+
end)
265+
else
266+
:(function (u, p, t)
267+
c2d_obs = $cont_to_disc_obs
268+
d2c_obs = $disc_to_cont_obs
269+
c2d_view = view(p, $cont_to_disc_idxs)
270+
d2c_view = view(p, $disc_to_cont_idxs)
271+
disc_unknowns = view(p, $disc_range)
272+
copyto!(c2d_view, c2d_obs(u, p, t))
273+
copyto!(d2c_view, d2c_obs(disc_unknowns, p, t))
274+
end)
275+
end
251276

252277
# @show disc_to_cont_idxs
253278
# @show cont_to_disc_idxs
@@ -270,9 +295,6 @@ function generate_discrete_affect(
270295
# TODO: find a way to do this without allocating
271296
disc = $disc
272297

273-
push!(saved_values.t, t)
274-
push!(saved_values.saveval, $save_vec)
275-
276298
# Write continuous into to discrete: handles `Sample`
277299
# Write discrete into to continuous
278300
# Update discrete unknowns
@@ -322,6 +344,10 @@ function generate_discrete_affect(
322344
:(copyto!(d2c_view, d2c_obs(disc_unknowns, p, t)))
323345
end
324346
)
347+
348+
push!(saved_values.t, t)
349+
push!(saved_values.saveval, $save_vec)
350+
325351
# @show "after d2c", p
326352
$(
327353
if use_index_cache
@@ -335,15 +361,20 @@ function generate_discrete_affect(
335361
end)
336362
sv = SavedValues(Float64, Vector{Float64})
337363
push!(affect_funs, affect!)
364+
push!(init_funs, disc_init)
338365
push!(svs, sv)
339366
end
340367
if eval_expression
341368
affects = map(affect_funs) do a
342369
drop_expr(@RuntimeGeneratedFunction(eval_module, toexpr(LiteralExpr(a))))
343370
end
371+
inits = map(init_funs) do a
372+
drop_expr(@RuntimeGeneratedFunction(eval_module, toexpr(LiteralExpr(a))))
373+
end
344374
else
345375
affects = map(a -> toexpr(LiteralExpr(a)), affect_funs)
376+
inits = map(a -> toexpr(LiteralExpr(a)), init_funs)
346377
end
347378
defaults = Dict{Any, Any}(v => 0.0 for v in Iterators.flatten(inputs))
348-
return affects, clocks, svs, appended_parameters, defaults
379+
return affects, inits, clocks, svs, appended_parameters, defaults
349380
end

src/systems/diffeqs/abstractodesystem.jl

+20-30
Original file line numberDiff line numberDiff line change
@@ -1038,12 +1038,13 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
10381038
t = tspan !== nothing ? tspan[1] : tspan,
10391039
check_length, warn_initialize_determined, kwargs...)
10401040
cbs = process_events(sys; callback, kwargs...)
1041-
affects = []
1041+
inits = []
10421042
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
1043-
affects, clocks, svs = ModelingToolkit.generate_discrete_affect(sys, dss...)
1043+
affects, inits, clocks, svs = ModelingToolkit.generate_discrete_affect(sys, dss...)
10441044
discrete_cbs = map(affects, clocks, svs) do affect, clock, sv
10451045
if clock isa Clock
1046-
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt)
1046+
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt;
1047+
final_affect = true)
10471048
elseif clock isa SolverStepClock
10481049
affect = DiscreteSaveAffect(affect, sv)
10491050
DiscreteCallback(Returns(true), affect,
@@ -1061,12 +1062,6 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
10611062
else
10621063
cbs = CallbackSet(cbs, discrete_cbs...)
10631064
end
1064-
# initialize by running affects
1065-
dummy_saveval = (; t = [], saveval = [])
1066-
for affect! in affects
1067-
affect!(
1068-
(; u = u0, p = p, t = tspan !== nothing ? tspan[1] : tspan), dummy_saveval)
1069-
end
10701065
else
10711066
svs = nothing
10721067
end
@@ -1080,7 +1075,14 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
10801075
if svs !== nothing
10811076
kwargs1 = merge(kwargs1, (disc_saved_values = svs,))
10821077
end
1083-
ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...)
1078+
1079+
prob = ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...)
1080+
if !isempty(inits)
1081+
for init in inits
1082+
init(prob.u0, prob.p, tspan[1])
1083+
end
1084+
end
1085+
prob
10841086
end
10851087
get_callback(prob::ODEProblem) = prob.kwargs[:callback]
10861088

@@ -1149,12 +1151,12 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [],
11491151
h(p::MTKParameters, t) = h_oop(p..., t)
11501152
u0 = h(p, tspan[1])
11511153
cbs = process_events(sys; callback, kwargs...)
1152-
inits = []
11531154
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
1154-
affects, inits, clocks, svs = ModelingToolkit.generate_discrete_affect(sys, dss...)
1155+
affects, clocks, svs = ModelingToolkit.generate_discrete_affect(sys, dss...)
11551156
discrete_cbs = map(affects, clocks, svs) do affect, clock, sv
11561157
if clock isa Clock
1157-
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt)
1158+
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt;
1159+
final_affect = true, initial_affect = true)
11581160
else
11591161
error("$clock is not a supported clock type.")
11601162
end
@@ -1180,13 +1182,7 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [],
11801182
if svs !== nothing
11811183
kwargs1 = merge(kwargs1, (disc_saved_values = svs,))
11821184
end
1183-
prob = DDEProblem{iip}(f, u0, h, tspan, p; kwargs1..., kwargs...)
1184-
if !isempty(inits)
1185-
for init in inits
1186-
init(prob.p, tspan[1])
1187-
end
1188-
end
1189-
prob
1185+
DDEProblem{iip}(f, u0, h, tspan, p; kwargs1..., kwargs...)
11901186
end
11911187

11921188
function DiffEqBase.SDDEProblem(sys::AbstractODESystem, args...; kwargs...)
@@ -1211,12 +1207,12 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [],
12111207
h(p, t) = h_oop(p, t)
12121208
u0 = h(p, tspan[1])
12131209
cbs = process_events(sys; callback, kwargs...)
1214-
inits = []
12151210
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
1216-
affects, inits, clocks, svs = ModelingToolkit.generate_discrete_affect(sys, dss...)
1211+
affects, clocks, svs = ModelingToolkit.generate_discrete_affect(sys, dss...)
12171212
discrete_cbs = map(affects, clocks, svs) do affect, clock, sv
12181213
if clock isa Clock
1219-
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt)
1214+
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt;
1215+
final_affect = true, initial_affect = true)
12201216
else
12211217
error("$clock is not a supported clock type.")
12221218
end
@@ -1253,15 +1249,9 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [],
12531249
else
12541250
noise_rate_prototype = zeros(eltype(u0), size(noiseeqs))
12551251
end
1256-
prob = SDDEProblem{iip}(f, f.g, u0, h, tspan, p;
1252+
SDDEProblem{iip}(f, f.g, u0, h, tspan, p;
12571253
noise_rate_prototype =
12581254
noise_rate_prototype, kwargs1..., kwargs...)
1259-
if !isempty(inits)
1260-
for init in inits
1261-
init(prob.p, tspan[1])
1262-
end
1263-
end
1264-
prob
12651255
end
12661256

12671257
"""

src/systems/diffeqs/odesystem.jl

+6-3
Original file line numberDiff line numberDiff line change
@@ -453,13 +453,16 @@ function build_explicit_observed_function(sys, ts;
453453
if inputs !== nothing
454454
ps = setdiff(ps, inputs) # Inputs have been converted to parameters by io_preprocessing, remove those from the parameter list
455455
end
456-
if has_index_cache(sys) && get_index_cache(sys) !== nothing
457-
ps = DestructuredArgs.(reorder_parameters(get_index_cache(sys), ps))
458-
elseif ps isa Tuple
456+
if ps isa Tuple
459457
ps = DestructuredArgs.(ps, inbounds = !checkbounds)
458+
elseif has_index_cache(sys) && get_index_cache(sys) !== nothing
459+
ps = DestructuredArgs.(reorder_parameters(get_index_cache(sys), ps))
460460
else
461461
ps = (DestructuredArgs(ps, inbounds = !checkbounds),)
462462
end
463+
if isempty(ps)
464+
ps = (DestructuredArgs([]),)
465+
end
463466
dvs = DestructuredArgs(unknowns(sys), inbounds = !checkbounds)
464467
if inputs === nothing
465468
args = [dvs, ps..., ivs...]

src/systems/discrete_system/discrete_system.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,8 @@ function DiscreteSystem(eqs::AbstractVector{<:Equation}, iv, dvs, ps;
139139
iv′ = value(iv)
140140
dvs′ = value.(dvs)
141141
ps′ = value.(ps)
142-
if !all(hasshift, eqs)
143-
error("All equations in a `DiscreteSystem` must be difference equations")
142+
if any(hasderiv, eqs) || any(hashold, eqs) || any(hassample, eqs) || any(hasdiff, eqs)
143+
error("Equations in a `DiscreteSystem` can only have `Shift` operators.")
144144
end
145145
if !(isempty(default_u0) && isempty(default_p))
146146
Base.depwarn(

src/systems/systemstructure.jl

+29-7
Original file line numberDiff line numberDiff line change
@@ -369,12 +369,6 @@ function TearingState(sys; quick_cancel = false, check = true)
369369
steps = 0
370370
tt = iv
371371
v = var
372-
if lshift < 0
373-
defs = ModelingToolkit.get_defaults(sys)
374-
if (_val = get(defs, var, nothing)) !== nothing
375-
defs[Shift(tt, -1)(v)] = _val
376-
end
377-
end
378372
else
379373
continue
380374
end
@@ -434,10 +428,14 @@ function TearingState(sys; quick_cancel = false, check = true)
434428

435429
eq_to_diff = DiffGraph(nsrcs(graph))
436430

437-
return TearingState(sys, fullvars,
431+
ts = TearingState(sys, fullvars,
438432
SystemStructure(complete(var_to_diff), complete(eq_to_diff),
439433
complete(graph), nothing, var_types, sys isa DiscreteSystem),
440434
Any[])
435+
if sys isa DiscreteSystem
436+
ts = shift_discrete_system(ts)
437+
end
438+
return ts
441439
end
442440

443441
function lower_order_var(dervar, t)
@@ -458,6 +456,30 @@ function lower_order_var(dervar, t)
458456
diffvar
459457
end
460458

459+
function shift_discrete_system(ts::TearingState)
460+
@unpack fullvars, sys = ts
461+
discvars = OrderedSet()
462+
eqs = equations(sys)
463+
for eq in eqs
464+
vars!(discvars, eq; op = Union{Sample, Hold})
465+
end
466+
iv = get_iv(sys)
467+
discmap = Dict(k => StructuralTransformations.simplify_shifts(Shift(iv, 1)(k))
468+
for k in discvars
469+
if any(isequal(k), fullvars) && !isa(operation(k), Union{Sample, Hold}))
470+
for i in eachindex(fullvars)
471+
fullvars[i] = StructuralTransformations.simplify_shifts(fast_substitute(
472+
fullvars[i], discmap; operator = Union{Sample, Hold}))
473+
end
474+
for i in eachindex(eqs)
475+
eqs[i] = StructuralTransformations.simplify_shifts(fast_substitute(
476+
eqs[i], discmap; operator = Union{Sample, Hold}))
477+
end
478+
@set! ts.sys.eqs = eqs
479+
@set! ts.fullvars = fullvars
480+
return ts
481+
end
482+
461483
using .BipartiteGraphs: Label, BipartiteAdjacencyList
462484
struct SystemStructurePrintMatrix <:
463485
AbstractMatrix{Union{Label, BipartiteAdjacencyList}}

src/variables.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,8 @@ function _varmap_to_vars(varmap::Dict, varlist; defaults = Dict(), check = false
195195
values = Dict()
196196
for var in varlist
197197
var = unwrap(var)
198-
val = unwrap(fixpoint_sub(fixpoint_sub(var, varmap), defaults))
198+
val = unwrap(fixpoint_sub(fixpoint_sub(var, varmap; operator = Symbolics.Operator),
199+
defaults; operator = Symbolics.Operator))
199200
if symbolic_type(val) === NotSymbolic()
200201
values[var] = val
201202
end

0 commit comments

Comments
 (0)