Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: initial implementation of new DiscreteSystem #2507

Merged
merged 13 commits into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/pages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pages = [
"tutorials/modelingtoolkitize.md",
"tutorials/programmatically_generating.md",
"tutorials/stochastic_diffeq.md",
"tutorials/discrete_system.md",
"tutorials/parameter_identifiability.md",
"tutorials/bifurcation_diagram_computation.md",
"tutorials/SampledData.md",
Expand Down
51 changes: 51 additions & 0 deletions docs/src/tutorials/discrete_system.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# (Experimental) Modeling Discrete Systems
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it still experimental?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kind of, unless observed variables are working as intended now?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, since shifting the entire system by one step is basically a workaround and eventually we'll need proper compiler support for this. I've discussed with @baggepinnen and that will end up changing things such as the state realization and observed equations in the simplified system.


In this example, we will use the new [`DiscreteSystem`](@ref) API
to create an SIR model.

```@example discrete
using ModelingToolkit
using ModelingToolkit: t_nounits as t
using OrdinaryDiffEq: solve, FunctionMap

@inline function rate_to_proportion(r, t)
1 - exp(-r * t)
end
@parameters c δt β γ
@constants h = 1
@variables S(t) I(t) R(t)
k = ShiftIndex(t)
infection = rate_to_proportion(
β * c * I(k - 1) / (S(k - 1) * h + I(k - 1) + R(k - 1)), δt * h) * S(k - 1)
recovery = rate_to_proportion(γ * h, δt) * I(k - 1)

# Equations
eqs = [S(k) ~ S(k - 1) - infection * h,
I(k) ~ I(k - 1) + infection - recovery,
R(k) ~ R(k - 1) + recovery]
@mtkbuild sys = DiscreteSystem(eqs, t)

u0 = [S(k - 1) => 990.0, I(k - 1) => 10.0, R(k - 1) => 0.0]
p = [β => 0.05, c => 10.0, γ => 0.25, δt => 0.1]
tspan = (0.0, 100.0)
prob = DiscreteProblem(sys, u0, tspan, p)
sol = solve(prob, FunctionMap())
```

All shifts must be non-positive, i.e., discrete-time variables may only be indexed at index
`k, k-1, k-2, ...`. If default values are provided, they are treated as the value of the
variable at the previous timestep. For example, consider the following system to generate
the Fibonacci series:

```@example discrete
@variables x(t) = 1.0
@mtkbuild sys = DiscreteSystem([x ~ x(k - 1) + x(k - 2)], t)
```

The "default value" here should be interpreted as the value of `x` at all past timesteps.
For example, here `x(k-1)` and `x(k-2)` will be `1.0`, and the inital value of `x(k)` will

Check warning on line 46 in docs/src/tutorials/discrete_system.md

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"inital" should be "initial".
thus be `2.0`. During problem construction, the _past_ value of a variable should be
provided. For example, providing `[x => 1.0]` while constructing this problem will error.
Provide `[x(k-1) => 1.0]` instead. Note that values provided during problem construction
_do not_ apply to the entire history. Hence, if `[x(k-1) => 2.0]` is provided, the value of
`x(k-2)` will still be `1.0`.
3 changes: 3 additions & 0 deletions src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ include("systems/diffeqs/first_order_transform.jl")
include("systems/diffeqs/modelingtoolkitize.jl")
include("systems/diffeqs/basic_transformations.jl")

include("systems/discrete_system/discrete_system.jl")

include("systems/jumps/jumpsystem.jl")

include("systems/optimization/constraints_system.jl")
Expand Down Expand Up @@ -209,6 +211,7 @@ export ODESystem,
export DAEFunctionExpr, DAEProblemExpr
export SDESystem, SDEFunction, SDEFunctionExpr, SDEProblemExpr
export SystemStructure
export DiscreteSystem, DiscreteProblem, DiscreteFunction, DiscreteFunctionExpr
export JumpSystem
export ODEProblem, SDEProblem
export NonlinearFunction, NonlinearFunctionExpr
Expand Down
5 changes: 5 additions & 0 deletions src/clock.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@
SolverStepClock()
SolverStepClock(t)

A clock that ticks at each solver step (sometimes referred to as "continuous sample time"). This clock **does generally not have equidistant tick intervals**, instead, the tick interval depends on the adaptive step-size slection of the continuous solver, as well as any continuous event handling. If adaptivity of the solver is turned off and there are no continuous events, the tick interval will be given by the fixed solver time step `dt`.

Check warning on line 133 in src/clock.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"slection" should be "selection".

Due to possibly non-equidistant tick intervals, this clock should typically not be used with discrete-time systems that assume a fixed sample time, such as PID controllers and digital filters.
"""
Expand All @@ -146,3 +146,8 @@
function Base.:(==)(c1::SolverStepClock, c2::SolverStepClock)
((c1.t === nothing || c2.t === nothing) || isequal(c1.t, c2.t))
end

struct IntegerSequence <: AbstractClock
t::Union{Nothing, Symbolic}
IntegerSequence(t::Union{Num, Symbolic}) = new(value(t))
end
5 changes: 4 additions & 1 deletion src/discretedomain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@
vt = value(x)
if istree(vt)
op = operation(vt)
if op isa Shift
if op isa Sample
error("Cannot shift a `Sample`. Create a variable to represent the sampled value and shift that instead")

Check warning on line 42 in src/discretedomain.jl

View check run for this annotation

Codecov / codecov/patch

src/discretedomain.jl#L42

Added line #L42 was not covered by tests
elseif op isa Shift
if D.t === nothing || isequal(D.t, op.t)
arg = arguments(vt)[1]
newsteps = D.steps + op.steps
Expand Down Expand Up @@ -168,6 +170,7 @@
steps::Int
ShiftIndex(clock::TimeDomain = Inferred(), steps::Int = 0) = new(clock, steps)
ShiftIndex(t::Num, dt::Real, steps::Int = 0) = new(Clock(t, dt), steps)
ShiftIndex(t::Num, steps::Int = 0) = new(IntegerSequence(t), steps)
end

function (xn::Num)(k::ShiftIndex)
Expand Down
15 changes: 8 additions & 7 deletions src/structural_transformation/symbolics_tearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -382,8 +382,8 @@ function tearing_reassemble(state::TearingState, var_eq_matching;
dx = fullvars[dv]
# add `x_t`
order, lv = var_order(dv)
x_t = lower_varname(fullvars[lv], iv, order)
push!(fullvars, x_t)
x_t = lower_varname_withshift(fullvars[lv], iv, order)
push!(fullvars, simplify_shifts(x_t))
v_t = length(fullvars)
v_t_idx = add_vertex!(var_to_diff)
add_vertex!(graph, DST)
Expand Down Expand Up @@ -437,11 +437,12 @@ function tearing_reassemble(state::TearingState, var_eq_matching;
# We cannot solve the differential variable like D(x)
if isdervar(iv)
order, lv = var_order(iv)
dx = D(lower_varname(fullvars[lv], idep, order - 1))
eq = dx ~ ModelingToolkit.fixpoint_sub(
dx = D(simplify_shifts(lower_varname_withshift(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is the lower_varname_withshift inside a isdervar branch?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, that's the structural info.

fullvars[lv], idep, order - 1)))
eq = dx ~ simplify_shifts(ModelingToolkit.fixpoint_sub(
Symbolics.solve_for(neweqs[ieq],
fullvars[iv]),
total_sub)
total_sub; operator = ModelingToolkit.Shift))
for e in 𝑑neighbors(graph, iv)
e == ieq && continue
for v in 𝑠neighbors(graph, e)
Expand All @@ -450,7 +451,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching;
rem_edge!(graph, e, iv)
end
push!(diff_eqs, eq)
total_sub[eq.lhs] = eq.rhs
total_sub[simplify_shifts(eq.lhs)] = eq.rhs
push!(diffeq_idxs, ieq)
push!(diff_vars, diff_to_var[iv])
continue
Expand All @@ -469,7 +470,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching;
neweq = var ~ ModelingToolkit.fixpoint_sub(
simplify ?
Symbolics.simplify(rhs) : rhs,
total_sub)
total_sub; operator = ModelingToolkit.Shift)
push!(subeqs, neweq)
push!(solved_equations, ieq)
push!(solved_variables, iv)
Expand Down
37 changes: 37 additions & 0 deletions src/structural_transformation/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -412,3 +412,40 @@ function numerical_nlsolve(f, u0, p)
# TODO: robust initial guess, better debugging info, and residual check
sol.u
end

###
### Misc
###

function lower_varname_withshift(var, iv, order)
order == 0 && return var
if ModelingToolkit.isoperator(var, ModelingToolkit.Shift)
op = operation(var)
return Shift(op.t, order)(var)
end
return lower_varname(var, iv, order)
end

function isdoubleshift(var)
return ModelingToolkit.isoperator(var, ModelingToolkit.Shift) &&
ModelingToolkit.isoperator(arguments(var)[1], ModelingToolkit.Shift)
end

function simplify_shifts(var)
ModelingToolkit.hasshift(var) || return var
var isa Equation && return simplify_shifts(var.lhs) ~ simplify_shifts(var.rhs)
if isdoubleshift(var)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about triple shifts? Should we recurse on that, too?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing that out

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

op1 = operation(var)
vv1 = arguments(var)[1]
op2 = operation(vv1)
vv2 = arguments(vv1)[1]
s1 = op1.steps
s2 = op2.steps
t1 = op1.t
t2 = op2.t
return simplify_shifts(ModelingToolkit.Shift(t1 === nothing ? t2 : t1, s1 + s2)(vv2))
else
return similarterm(var, operation(var), simplify_shifts.(arguments(var)),
Symbolics.symtype(var); metadata = unwrap(var).metadata)
end
end
8 changes: 4 additions & 4 deletions src/systems/alias_elimination.jl
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ function observed2graph(eqs, unknowns)
lhs_j === nothing &&
throw(ArgumentError("The lhs $(eq.lhs) of $eq, doesn't appear in unknowns."))
assigns[i] = lhs_j
vs = vars(eq.rhs)
vs = vars(eq.rhs; op = Symbolics.Operator)
for v in vs
j = get(v2j, v, nothing)
j !== nothing && add_edge!(graph, i, j)
Expand All @@ -463,11 +463,11 @@ function observed2graph(eqs, unknowns)
return graph, assigns
end

function fixpoint_sub(x, dict)
y = fast_substitute(x, dict)
function fixpoint_sub(x, dict; operator = Nothing)
y = fast_substitute(x, dict; operator)
while !isequal(x, y)
y = x
x = fast_substitute(y, dict)
x = fast_substitute(y, dict; operator)
end

return x
Expand Down
61 changes: 40 additions & 21 deletions src/systems/clock_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,10 @@
tss = similar(cid_to_eq, S)
for (id, ieqs) in enumerate(cid_to_eq)
ts_i = system_subset(ts, ieqs)
@set! ts_i.structure.only_discrete = id != continuous_id
if id != continuous_id
ts_i = shift_discrete_system(ts_i)
@set! ts_i.structure.only_discrete = true
end
tss[id] = ts_i
end
return tss, inputs, continuous_id, id_to_clock
Expand All @@ -148,7 +151,7 @@
end
use_index_cache = has_index_cache(osys) && get_index_cache(osys) !== nothing
out = Sym{Any}(:out)
appended_parameters = parameters(syss[continuous_id])
appended_parameters = full_parameters(syss[continuous_id])
offset = length(appended_parameters)
param_to_idx = if use_index_cache
Dict{Any, ParameterIndex}(p => parameter_index(osys, p)
Expand Down Expand Up @@ -180,40 +183,46 @@
disc_to_cont_idxs = Int[]
end
for v in inputs[continuous_id]
vv = arguments(v)[1]
if vv in fullvars
push!(needed_disc_to_cont_obs, vv)
_v = arguments(v)[1]
if _v in fullvars
push!(needed_disc_to_cont_obs, _v)
push!(disc_to_cont_idxs, param_to_idx[v])
continue
end

# If the held quantity is calculated through observed
# it will be shifted forward by 1
_v = Shift(get_iv(sys), 1)(_v)
if _v in fullvars
push!(needed_disc_to_cont_obs, _v)
push!(disc_to_cont_idxs, param_to_idx[v])
continue
end
end
append!(appended_parameters, input, unknowns(sys))
append!(appended_parameters, input)
cont_to_disc_obs = build_explicit_observed_function(
use_index_cache ? osys : syss[continuous_id],
needed_cont_to_disc_obs,
throw = false,
expression = true,
output_type = SVector)
@set! sys.ps = appended_parameters
disc_to_cont_obs = build_explicit_observed_function(sys, needed_disc_to_cont_obs,
throw = false,
expression = true,
output_type = SVector,
ps = reorder_parameters(osys, full_parameters(sys)))
op = Shift,
ps = reorder_parameters(osys, appended_parameters))
ni = length(input)
ns = length(unknowns(sys))
disc = Func(
[
out,
DestructuredArgs(unknowns(osys)),
if use_index_cache
DestructuredArgs.(reorder_parameters(osys, full_parameters(osys)))
else
(DestructuredArgs(appended_parameters),)
end...,
DestructuredArgs.(reorder_parameters(osys, full_parameters(osys)))...,
get_iv(sys)
],
[],
let_block)
let_block) |> toexpr
if use_index_cache
cont_to_disc_idxs = [parameter_index(osys, sym) for sym in input]
disc_range = [parameter_index(osys, sym) for sym in unknowns(sys)]
Expand All @@ -235,8 +244,14 @@
end
empty_disc = isempty(disc_range)
disc_init = if use_index_cache
:(function (p, t)
:(function (u, p, t)
c2d_obs = $cont_to_disc_obs

Check warning on line 248 in src/systems/clock_inference.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/clock_inference.jl#L248

Added line #L248 was not covered by tests
d2c_obs = $disc_to_cont_obs
result = c2d_obs(u, p..., t)
for (val, i) in zip(result, $cont_to_disc_idxs)
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
end

Check warning on line 253 in src/systems/clock_inference.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/clock_inference.jl#L250-L253

Added lines #L250 - L253 were not covered by tests

disc_state = Tuple($(parameter_values)(p, i) for i in $disc_range)
result = d2c_obs(disc_state, p..., t)
for (val, i) in zip(result, $disc_to_cont_idxs)
Expand All @@ -248,11 +263,14 @@
repack(discretes) # to force recalculation of dependents
end)
else
:(function (p, t)
:(function (u, p, t)
c2d_obs = $cont_to_disc_obs

Check warning on line 267 in src/systems/clock_inference.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/clock_inference.jl#L267

Added line #L267 was not covered by tests
d2c_obs = $disc_to_cont_obs
c2d_view = view(p, $cont_to_disc_idxs)

Check warning on line 269 in src/systems/clock_inference.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/clock_inference.jl#L269

Added line #L269 was not covered by tests
d2c_view = view(p, $disc_to_cont_idxs)
disc_state = view(p, $disc_range)
copyto!(d2c_view, d2c_obs(disc_state, p, t))
disc_unknowns = view(p, $disc_range)
copyto!(c2d_view, c2d_obs(u, p, t))
copyto!(d2c_view, d2c_obs(disc_unknowns, p, t))

Check warning on line 273 in src/systems/clock_inference.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/clock_inference.jl#L271-L273

Added lines #L271 - L273 were not covered by tests
end)
end

Expand All @@ -277,9 +295,6 @@
# TODO: find a way to do this without allocating
disc = $disc

push!(saved_values.t, t)
push!(saved_values.saveval, $save_vec)

# Write continuous into to discrete: handles `Sample`
# Write discrete into to continuous
# Update discrete unknowns
Expand Down Expand Up @@ -329,6 +344,10 @@
:(copyto!(d2c_view, d2c_obs(disc_unknowns, p, t)))
end
)

push!(saved_values.t, t)
push!(saved_values.saveval, $save_vec)

# @show "after d2c", p
$(
if use_index_cache
Expand Down
Loading
Loading