Skip to content

feat: initial implementation of new DiscreteSystem #2507

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/pages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pages = [
"tutorials/modelingtoolkitize.md",
"tutorials/programmatically_generating.md",
"tutorials/stochastic_diffeq.md",
"tutorials/discrete_system.md",
"tutorials/parameter_identifiability.md",
"tutorials/bifurcation_diagram_computation.md",
"tutorials/SampledData.md",
Expand Down
51 changes: 51 additions & 0 deletions docs/src/tutorials/discrete_system.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# (Experimental) Modeling Discrete Systems
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it still experimental?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kind of, unless observed variables are working as intended now?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, since shifting the entire system by one step is basically a workaround and eventually we'll need proper compiler support for this. I've discussed with @baggepinnen and that will end up changing things such as the state realization and observed equations in the simplified system.


In this example, we will use the new [`DiscreteSystem`](@ref) API
to create an SIR model.

```@example discrete
using ModelingToolkit
using ModelingToolkit: t_nounits as t
using OrdinaryDiffEq: solve, FunctionMap

@inline function rate_to_proportion(r, t)
1 - exp(-r * t)
end
@parameters c δt β γ
@constants h = 1
@variables S(t) I(t) R(t)
k = ShiftIndex(t)
infection = rate_to_proportion(
β * c * I(k - 1) / (S(k - 1) * h + I(k - 1) + R(k - 1)), δt * h) * S(k - 1)
recovery = rate_to_proportion(γ * h, δt) * I(k - 1)

# Equations
eqs = [S(k) ~ S(k - 1) - infection * h,
I(k) ~ I(k - 1) + infection - recovery,
R(k) ~ R(k - 1) + recovery]
@mtkbuild sys = DiscreteSystem(eqs, t)

u0 = [S(k - 1) => 990.0, I(k - 1) => 10.0, R(k - 1) => 0.0]
p = [β => 0.05, c => 10.0, γ => 0.25, δt => 0.1]
tspan = (0.0, 100.0)
prob = DiscreteProblem(sys, u0, tspan, p)
sol = solve(prob, FunctionMap())
```

All shifts must be non-positive, i.e., discrete-time variables may only be indexed at index
`k, k-1, k-2, ...`. If default values are provided, they are treated as the value of the
variable at the previous timestep. For example, consider the following system to generate
the Fibonacci series:

```@example discrete
@variables x(t) = 1.0
@mtkbuild sys = DiscreteSystem([x ~ x(k - 1) + x(k - 2)], t)
```

The "default value" here should be interpreted as the value of `x` at all past timesteps.
For example, here `x(k-1)` and `x(k-2)` will be `1.0`, and the inital value of `x(k)` will
thus be `2.0`. During problem construction, the _past_ value of a variable should be
provided. For example, providing `[x => 1.0]` while constructing this problem will error.
Provide `[x(k-1) => 1.0]` instead. Note that values provided during problem construction
_do not_ apply to the entire history. Hence, if `[x(k-1) => 2.0]` is provided, the value of
`x(k-2)` will still be `1.0`.
3 changes: 3 additions & 0 deletions src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ include("systems/diffeqs/first_order_transform.jl")
include("systems/diffeqs/modelingtoolkitize.jl")
include("systems/diffeqs/basic_transformations.jl")

include("systems/discrete_system/discrete_system.jl")

include("systems/jumps/jumpsystem.jl")

include("systems/optimization/constraints_system.jl")
Expand Down Expand Up @@ -209,6 +211,7 @@ export ODESystem,
export DAEFunctionExpr, DAEProblemExpr
export SDESystem, SDEFunction, SDEFunctionExpr, SDEProblemExpr
export SystemStructure
export DiscreteSystem, DiscreteProblem, DiscreteFunction, DiscreteFunctionExpr
export JumpSystem
export ODEProblem, SDEProblem
export NonlinearFunction, NonlinearFunctionExpr
Expand Down
5 changes: 5 additions & 0 deletions src/clock.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,8 @@ Base.hash(c::SolverStepClock, seed::UInt) = seed ⊻ 0x953d7b9a18874b91
function Base.:(==)(c1::SolverStepClock, c2::SolverStepClock)
((c1.t === nothing || c2.t === nothing) || isequal(c1.t, c2.t))
end

struct IntegerSequence <: AbstractClock
t::Union{Nothing, Symbolic}
IntegerSequence(t::Union{Num, Symbolic}) = new(value(t))
end
5 changes: 4 additions & 1 deletion src/discretedomain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ function (D::Shift)(x::Num, allow_zero = false)
vt = value(x)
if istree(vt)
op = operation(vt)
if op isa Shift
if op isa Sample
error("Cannot shift a `Sample`. Create a variable to represent the sampled value and shift that instead")
elseif op isa Shift
if D.t === nothing || isequal(D.t, op.t)
arg = arguments(vt)[1]
newsteps = D.steps + op.steps
Expand Down Expand Up @@ -168,6 +170,7 @@ struct ShiftIndex
steps::Int
ShiftIndex(clock::TimeDomain = Inferred(), steps::Int = 0) = new(clock, steps)
ShiftIndex(t::Num, dt::Real, steps::Int = 0) = new(Clock(t, dt), steps)
ShiftIndex(t::Num, steps::Int = 0) = new(IntegerSequence(t), steps)
end

function (xn::Num)(k::ShiftIndex)
Expand Down
15 changes: 8 additions & 7 deletions src/structural_transformation/symbolics_tearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -382,8 +382,8 @@ function tearing_reassemble(state::TearingState, var_eq_matching;
dx = fullvars[dv]
# add `x_t`
order, lv = var_order(dv)
x_t = lower_varname(fullvars[lv], iv, order)
push!(fullvars, x_t)
x_t = lower_varname_withshift(fullvars[lv], iv, order)
push!(fullvars, simplify_shifts(x_t))
v_t = length(fullvars)
v_t_idx = add_vertex!(var_to_diff)
add_vertex!(graph, DST)
Expand Down Expand Up @@ -437,11 +437,12 @@ function tearing_reassemble(state::TearingState, var_eq_matching;
# We cannot solve the differential variable like D(x)
if isdervar(iv)
order, lv = var_order(iv)
dx = D(lower_varname(fullvars[lv], idep, order - 1))
eq = dx ~ ModelingToolkit.fixpoint_sub(
dx = D(simplify_shifts(lower_varname_withshift(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is the lower_varname_withshift inside a isdervar branch?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, that's the structural info.

fullvars[lv], idep, order - 1)))
eq = dx ~ simplify_shifts(ModelingToolkit.fixpoint_sub(
Symbolics.solve_for(neweqs[ieq],
fullvars[iv]),
total_sub)
total_sub; operator = ModelingToolkit.Shift))
for e in 𝑑neighbors(graph, iv)
e == ieq && continue
for v in 𝑠neighbors(graph, e)
Expand All @@ -450,7 +451,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching;
rem_edge!(graph, e, iv)
end
push!(diff_eqs, eq)
total_sub[eq.lhs] = eq.rhs
total_sub[simplify_shifts(eq.lhs)] = eq.rhs
push!(diffeq_idxs, ieq)
push!(diff_vars, diff_to_var[iv])
continue
Expand All @@ -469,7 +470,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching;
neweq = var ~ ModelingToolkit.fixpoint_sub(
simplify ?
Symbolics.simplify(rhs) : rhs,
total_sub)
total_sub; operator = ModelingToolkit.Shift)
push!(subeqs, neweq)
push!(solved_equations, ieq)
push!(solved_variables, iv)
Expand Down
37 changes: 37 additions & 0 deletions src/structural_transformation/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -412,3 +412,40 @@ function numerical_nlsolve(f, u0, p)
# TODO: robust initial guess, better debugging info, and residual check
sol.u
end

###
### Misc
###

function lower_varname_withshift(var, iv, order)
order == 0 && return var
if ModelingToolkit.isoperator(var, ModelingToolkit.Shift)
op = operation(var)
return Shift(op.t, order)(var)
end
return lower_varname(var, iv, order)
end

function isdoubleshift(var)
return ModelingToolkit.isoperator(var, ModelingToolkit.Shift) &&
ModelingToolkit.isoperator(arguments(var)[1], ModelingToolkit.Shift)
end

function simplify_shifts(var)
ModelingToolkit.hasshift(var) || return var
var isa Equation && return simplify_shifts(var.lhs) ~ simplify_shifts(var.rhs)
if isdoubleshift(var)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about triple shifts? Should we recurse on that, too?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing that out

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

op1 = operation(var)
vv1 = arguments(var)[1]
op2 = operation(vv1)
vv2 = arguments(vv1)[1]
s1 = op1.steps
s2 = op2.steps
t1 = op1.t
t2 = op2.t
return simplify_shifts(ModelingToolkit.Shift(t1 === nothing ? t2 : t1, s1 + s2)(vv2))
else
return similarterm(var, operation(var), simplify_shifts.(arguments(var)),
Symbolics.symtype(var); metadata = unwrap(var).metadata)
end
end
8 changes: 4 additions & 4 deletions src/systems/alias_elimination.jl
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ function observed2graph(eqs, unknowns)
lhs_j === nothing &&
throw(ArgumentError("The lhs $(eq.lhs) of $eq, doesn't appear in unknowns."))
assigns[i] = lhs_j
vs = vars(eq.rhs)
vs = vars(eq.rhs; op = Symbolics.Operator)
for v in vs
j = get(v2j, v, nothing)
j !== nothing && add_edge!(graph, i, j)
Expand All @@ -463,11 +463,11 @@ function observed2graph(eqs, unknowns)
return graph, assigns
end

function fixpoint_sub(x, dict)
y = fast_substitute(x, dict)
function fixpoint_sub(x, dict; operator = Nothing)
y = fast_substitute(x, dict; operator)
while !isequal(x, y)
y = x
x = fast_substitute(y, dict)
x = fast_substitute(y, dict; operator)
end

return x
Expand Down
61 changes: 40 additions & 21 deletions src/systems/clock_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,10 @@ function split_system(ci::ClockInference{S}) where {S}
tss = similar(cid_to_eq, S)
for (id, ieqs) in enumerate(cid_to_eq)
ts_i = system_subset(ts, ieqs)
@set! ts_i.structure.only_discrete = id != continuous_id
if id != continuous_id
ts_i = shift_discrete_system(ts_i)
@set! ts_i.structure.only_discrete = true
end
tss[id] = ts_i
end
return tss, inputs, continuous_id, id_to_clock
Expand All @@ -148,7 +151,7 @@ function generate_discrete_affect(
end
use_index_cache = has_index_cache(osys) && get_index_cache(osys) !== nothing
out = Sym{Any}(:out)
appended_parameters = parameters(syss[continuous_id])
appended_parameters = full_parameters(syss[continuous_id])
offset = length(appended_parameters)
param_to_idx = if use_index_cache
Dict{Any, ParameterIndex}(p => parameter_index(osys, p)
Expand Down Expand Up @@ -180,40 +183,46 @@ function generate_discrete_affect(
disc_to_cont_idxs = Int[]
end
for v in inputs[continuous_id]
vv = arguments(v)[1]
if vv in fullvars
push!(needed_disc_to_cont_obs, vv)
_v = arguments(v)[1]
if _v in fullvars
push!(needed_disc_to_cont_obs, _v)
push!(disc_to_cont_idxs, param_to_idx[v])
continue
end

# If the held quantity is calculated through observed
# it will be shifted forward by 1
_v = Shift(get_iv(sys), 1)(_v)
if _v in fullvars
push!(needed_disc_to_cont_obs, _v)
push!(disc_to_cont_idxs, param_to_idx[v])
continue
end
end
append!(appended_parameters, input, unknowns(sys))
append!(appended_parameters, input)
cont_to_disc_obs = build_explicit_observed_function(
use_index_cache ? osys : syss[continuous_id],
needed_cont_to_disc_obs,
throw = false,
expression = true,
output_type = SVector)
@set! sys.ps = appended_parameters
disc_to_cont_obs = build_explicit_observed_function(sys, needed_disc_to_cont_obs,
throw = false,
expression = true,
output_type = SVector,
ps = reorder_parameters(osys, full_parameters(sys)))
op = Shift,
ps = reorder_parameters(osys, appended_parameters))
ni = length(input)
ns = length(unknowns(sys))
disc = Func(
[
out,
DestructuredArgs(unknowns(osys)),
if use_index_cache
DestructuredArgs.(reorder_parameters(osys, full_parameters(osys)))
else
(DestructuredArgs(appended_parameters),)
end...,
DestructuredArgs.(reorder_parameters(osys, full_parameters(osys)))...,
get_iv(sys)
],
[],
let_block)
let_block) |> toexpr
if use_index_cache
cont_to_disc_idxs = [parameter_index(osys, sym) for sym in input]
disc_range = [parameter_index(osys, sym) for sym in unknowns(sys)]
Expand All @@ -235,8 +244,14 @@ function generate_discrete_affect(
end
empty_disc = isempty(disc_range)
disc_init = if use_index_cache
:(function (p, t)
:(function (u, p, t)
c2d_obs = $cont_to_disc_obs
d2c_obs = $disc_to_cont_obs
result = c2d_obs(u, p..., t)
for (val, i) in zip(result, $cont_to_disc_idxs)
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
end

disc_state = Tuple($(parameter_values)(p, i) for i in $disc_range)
result = d2c_obs(disc_state, p..., t)
for (val, i) in zip(result, $disc_to_cont_idxs)
Expand All @@ -248,11 +263,14 @@ function generate_discrete_affect(
repack(discretes) # to force recalculation of dependents
end)
else
:(function (p, t)
:(function (u, p, t)
c2d_obs = $cont_to_disc_obs
d2c_obs = $disc_to_cont_obs
c2d_view = view(p, $cont_to_disc_idxs)
d2c_view = view(p, $disc_to_cont_idxs)
disc_state = view(p, $disc_range)
copyto!(d2c_view, d2c_obs(disc_state, p, t))
disc_unknowns = view(p, $disc_range)
copyto!(c2d_view, c2d_obs(u, p, t))
copyto!(d2c_view, d2c_obs(disc_unknowns, p, t))
end)
end

Expand All @@ -277,9 +295,6 @@ function generate_discrete_affect(
# TODO: find a way to do this without allocating
disc = $disc

push!(saved_values.t, t)
push!(saved_values.saveval, $save_vec)

# Write continuous into to discrete: handles `Sample`
# Write discrete into to continuous
# Update discrete unknowns
Expand Down Expand Up @@ -329,6 +344,10 @@ function generate_discrete_affect(
:(copyto!(d2c_view, d2c_obs(disc_unknowns, p, t)))
end
)

push!(saved_values.t, t)
push!(saved_values.saveval, $save_vec)

# @show "after d2c", p
$(
if use_index_cache
Expand Down
Loading