diff --git a/docs/pages.jl b/docs/pages.jl index 32edc1cf4b..c3c4adfda6 100644 --- a/docs/pages.jl +++ b/docs/pages.jl @@ -8,6 +8,7 @@ pages = [ "tutorials/modelingtoolkitize.md", "tutorials/programmatically_generating.md", "tutorials/stochastic_diffeq.md", + "tutorials/discrete_system.md", "tutorials/parameter_identifiability.md", "tutorials/bifurcation_diagram_computation.md", "tutorials/SampledData.md", diff --git a/docs/src/tutorials/discrete_system.md b/docs/src/tutorials/discrete_system.md new file mode 100644 index 0000000000..666125e20e --- /dev/null +++ b/docs/src/tutorials/discrete_system.md @@ -0,0 +1,51 @@ +# (Experimental) Modeling Discrete Systems + +In this example, we will use the new [`DiscreteSystem`](@ref) API +to create an SIR model. + +```@example discrete +using ModelingToolkit +using ModelingToolkit: t_nounits as t +using OrdinaryDiffEq: solve, FunctionMap + +@inline function rate_to_proportion(r, t) + 1 - exp(-r * t) +end +@parameters c δt β γ +@constants h = 1 +@variables S(t) I(t) R(t) +k = ShiftIndex(t) +infection = rate_to_proportion( + β * c * I(k - 1) / (S(k - 1) * h + I(k - 1) + R(k - 1)), δt * h) * S(k - 1) +recovery = rate_to_proportion(γ * h, δt) * I(k - 1) + +# Equations +eqs = [S(k) ~ S(k - 1) - infection * h, + I(k) ~ I(k - 1) + infection - recovery, + R(k) ~ R(k - 1) + recovery] +@mtkbuild sys = DiscreteSystem(eqs, t) + +u0 = [S(k - 1) => 990.0, I(k - 1) => 10.0, R(k - 1) => 0.0] +p = [β => 0.05, c => 10.0, γ => 0.25, δt => 0.1] +tspan = (0.0, 100.0) +prob = DiscreteProblem(sys, u0, tspan, p) +sol = solve(prob, FunctionMap()) +``` + +All shifts must be non-positive, i.e., discrete-time variables may only be indexed at index +`k, k-1, k-2, ...`. If default values are provided, they are treated as the value of the +variable at the previous timestep. For example, consider the following system to generate +the Fibonacci series: + +```@example discrete +@variables x(t) = 1.0 +@mtkbuild sys = DiscreteSystem([x ~ x(k - 1) + x(k - 2)], t) +``` + +The "default value" here should be interpreted as the value of `x` at all past timesteps. +For example, here `x(k-1)` and `x(k-2)` will be `1.0`, and the inital value of `x(k)` will +thus be `2.0`. During problem construction, the _past_ value of a variable should be +provided. For example, providing `[x => 1.0]` while constructing this problem will error. +Provide `[x(k-1) => 1.0]` instead. Note that values provided during problem construction +_do not_ apply to the entire history. Hence, if `[x(k-1) => 2.0]` is provided, the value of +`x(k-2)` will still be `1.0`. diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index c0645f52b1..979021e2f0 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -147,6 +147,8 @@ include("systems/diffeqs/first_order_transform.jl") include("systems/diffeqs/modelingtoolkitize.jl") include("systems/diffeqs/basic_transformations.jl") +include("systems/discrete_system/discrete_system.jl") + include("systems/jumps/jumpsystem.jl") include("systems/optimization/constraints_system.jl") @@ -209,6 +211,7 @@ export ODESystem, export DAEFunctionExpr, DAEProblemExpr export SDESystem, SDEFunction, SDEFunctionExpr, SDEProblemExpr export SystemStructure +export DiscreteSystem, DiscreteProblem, DiscreteFunction, DiscreteFunctionExpr export JumpSystem export ODEProblem, SDEProblem export NonlinearFunction, NonlinearFunctionExpr diff --git a/src/clock.jl b/src/clock.jl index da88f02c39..7ca1707724 100644 --- a/src/clock.jl +++ b/src/clock.jl @@ -146,3 +146,8 @@ Base.hash(c::SolverStepClock, seed::UInt) = seed ⊻ 0x953d7b9a18874b91 function Base.:(==)(c1::SolverStepClock, c2::SolverStepClock) ((c1.t === nothing || c2.t === nothing) || isequal(c1.t, c2.t)) end + +struct IntegerSequence <: AbstractClock + t::Union{Nothing, Symbolic} + IntegerSequence(t::Union{Num, Symbolic}) = new(value(t)) +end diff --git a/src/discretedomain.jl b/src/discretedomain.jl index c8f2bbcd84..68e8e17b03 100644 --- a/src/discretedomain.jl +++ b/src/discretedomain.jl @@ -38,7 +38,9 @@ function (D::Shift)(x::Num, allow_zero = false) vt = value(x) if istree(vt) op = operation(vt) - if op isa Shift + if op isa Sample + error("Cannot shift a `Sample`. Create a variable to represent the sampled value and shift that instead") + elseif op isa Shift if D.t === nothing || isequal(D.t, op.t) arg = arguments(vt)[1] newsteps = D.steps + op.steps @@ -168,6 +170,7 @@ struct ShiftIndex steps::Int ShiftIndex(clock::TimeDomain = Inferred(), steps::Int = 0) = new(clock, steps) ShiftIndex(t::Num, dt::Real, steps::Int = 0) = new(Clock(t, dt), steps) + ShiftIndex(t::Num, steps::Int = 0) = new(IntegerSequence(t), steps) end function (xn::Num)(k::ShiftIndex) diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl index 23a5258104..9445986a72 100644 --- a/src/structural_transformation/symbolics_tearing.jl +++ b/src/structural_transformation/symbolics_tearing.jl @@ -382,8 +382,8 @@ function tearing_reassemble(state::TearingState, var_eq_matching; dx = fullvars[dv] # add `x_t` order, lv = var_order(dv) - x_t = lower_varname(fullvars[lv], iv, order) - push!(fullvars, x_t) + x_t = lower_varname_withshift(fullvars[lv], iv, order) + push!(fullvars, simplify_shifts(x_t)) v_t = length(fullvars) v_t_idx = add_vertex!(var_to_diff) add_vertex!(graph, DST) @@ -437,11 +437,12 @@ function tearing_reassemble(state::TearingState, var_eq_matching; # We cannot solve the differential variable like D(x) if isdervar(iv) order, lv = var_order(iv) - dx = D(lower_varname(fullvars[lv], idep, order - 1)) - eq = dx ~ ModelingToolkit.fixpoint_sub( + dx = D(simplify_shifts(lower_varname_withshift( + fullvars[lv], idep, order - 1))) + eq = dx ~ simplify_shifts(ModelingToolkit.fixpoint_sub( Symbolics.solve_for(neweqs[ieq], fullvars[iv]), - total_sub) + total_sub; operator = ModelingToolkit.Shift)) for e in 𝑑neighbors(graph, iv) e == ieq && continue for v in 𝑠neighbors(graph, e) @@ -450,7 +451,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching; rem_edge!(graph, e, iv) end push!(diff_eqs, eq) - total_sub[eq.lhs] = eq.rhs + total_sub[simplify_shifts(eq.lhs)] = eq.rhs push!(diffeq_idxs, ieq) push!(diff_vars, diff_to_var[iv]) continue @@ -469,7 +470,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching; neweq = var ~ ModelingToolkit.fixpoint_sub( simplify ? Symbolics.simplify(rhs) : rhs, - total_sub) + total_sub; operator = ModelingToolkit.Shift) push!(subeqs, neweq) push!(solved_equations, ieq) push!(solved_variables, iv) diff --git a/src/structural_transformation/utils.jl b/src/structural_transformation/utils.jl index d13618bdce..3ae8fb224f 100644 --- a/src/structural_transformation/utils.jl +++ b/src/structural_transformation/utils.jl @@ -412,3 +412,40 @@ function numerical_nlsolve(f, u0, p) # TODO: robust initial guess, better debugging info, and residual check sol.u end + +### +### Misc +### + +function lower_varname_withshift(var, iv, order) + order == 0 && return var + if ModelingToolkit.isoperator(var, ModelingToolkit.Shift) + op = operation(var) + return Shift(op.t, order)(var) + end + return lower_varname(var, iv, order) +end + +function isdoubleshift(var) + return ModelingToolkit.isoperator(var, ModelingToolkit.Shift) && + ModelingToolkit.isoperator(arguments(var)[1], ModelingToolkit.Shift) +end + +function simplify_shifts(var) + ModelingToolkit.hasshift(var) || return var + var isa Equation && return simplify_shifts(var.lhs) ~ simplify_shifts(var.rhs) + if isdoubleshift(var) + op1 = operation(var) + vv1 = arguments(var)[1] + op2 = operation(vv1) + vv2 = arguments(vv1)[1] + s1 = op1.steps + s2 = op2.steps + t1 = op1.t + t2 = op2.t + return simplify_shifts(ModelingToolkit.Shift(t1 === nothing ? t2 : t1, s1 + s2)(vv2)) + else + return similarterm(var, operation(var), simplify_shifts.(arguments(var)), + Symbolics.symtype(var); metadata = unwrap(var).metadata) + end +end diff --git a/src/systems/alias_elimination.jl b/src/systems/alias_elimination.jl index 2e553151f8..cded8edbfe 100644 --- a/src/systems/alias_elimination.jl +++ b/src/systems/alias_elimination.jl @@ -453,7 +453,7 @@ function observed2graph(eqs, unknowns) lhs_j === nothing && throw(ArgumentError("The lhs $(eq.lhs) of $eq, doesn't appear in unknowns.")) assigns[i] = lhs_j - vs = vars(eq.rhs) + vs = vars(eq.rhs; op = Symbolics.Operator) for v in vs j = get(v2j, v, nothing) j !== nothing && add_edge!(graph, i, j) @@ -463,11 +463,11 @@ function observed2graph(eqs, unknowns) return graph, assigns end -function fixpoint_sub(x, dict) - y = fast_substitute(x, dict) +function fixpoint_sub(x, dict; operator = Nothing) + y = fast_substitute(x, dict; operator) while !isequal(x, y) y = x - x = fast_substitute(y, dict) + x = fast_substitute(y, dict; operator) end return x diff --git a/src/systems/clock_inference.jl b/src/systems/clock_inference.jl index dab56cf916..c4c18d5bdb 100644 --- a/src/systems/clock_inference.jl +++ b/src/systems/clock_inference.jl @@ -133,7 +133,10 @@ function split_system(ci::ClockInference{S}) where {S} tss = similar(cid_to_eq, S) for (id, ieqs) in enumerate(cid_to_eq) ts_i = system_subset(ts, ieqs) - @set! ts_i.structure.only_discrete = id != continuous_id + if id != continuous_id + ts_i = shift_discrete_system(ts_i) + @set! ts_i.structure.only_discrete = true + end tss[id] = ts_i end return tss, inputs, continuous_id, id_to_clock @@ -148,7 +151,7 @@ function generate_discrete_affect( end use_index_cache = has_index_cache(osys) && get_index_cache(osys) !== nothing out = Sym{Any}(:out) - appended_parameters = parameters(syss[continuous_id]) + appended_parameters = full_parameters(syss[continuous_id]) offset = length(appended_parameters) param_to_idx = if use_index_cache Dict{Any, ParameterIndex}(p => parameter_index(osys, p) @@ -180,40 +183,46 @@ function generate_discrete_affect( disc_to_cont_idxs = Int[] end for v in inputs[continuous_id] - vv = arguments(v)[1] - if vv in fullvars - push!(needed_disc_to_cont_obs, vv) + _v = arguments(v)[1] + if _v in fullvars + push!(needed_disc_to_cont_obs, _v) + push!(disc_to_cont_idxs, param_to_idx[v]) + continue + end + + # If the held quantity is calculated through observed + # it will be shifted forward by 1 + _v = Shift(get_iv(sys), 1)(_v) + if _v in fullvars + push!(needed_disc_to_cont_obs, _v) push!(disc_to_cont_idxs, param_to_idx[v]) + continue end end - append!(appended_parameters, input, unknowns(sys)) + append!(appended_parameters, input) cont_to_disc_obs = build_explicit_observed_function( use_index_cache ? osys : syss[continuous_id], needed_cont_to_disc_obs, throw = false, expression = true, output_type = SVector) - @set! sys.ps = appended_parameters disc_to_cont_obs = build_explicit_observed_function(sys, needed_disc_to_cont_obs, throw = false, expression = true, output_type = SVector, - ps = reorder_parameters(osys, full_parameters(sys))) + op = Shift, + ps = reorder_parameters(osys, appended_parameters)) ni = length(input) ns = length(unknowns(sys)) disc = Func( [ out, DestructuredArgs(unknowns(osys)), - if use_index_cache - DestructuredArgs.(reorder_parameters(osys, full_parameters(osys))) - else - (DestructuredArgs(appended_parameters),) - end..., + DestructuredArgs.(reorder_parameters(osys, full_parameters(osys)))..., get_iv(sys) ], [], - let_block) + let_block) |> toexpr if use_index_cache cont_to_disc_idxs = [parameter_index(osys, sym) for sym in input] disc_range = [parameter_index(osys, sym) for sym in unknowns(sys)] @@ -235,8 +244,14 @@ function generate_discrete_affect( end empty_disc = isempty(disc_range) disc_init = if use_index_cache - :(function (p, t) + :(function (u, p, t) + c2d_obs = $cont_to_disc_obs d2c_obs = $disc_to_cont_obs + result = c2d_obs(u, p..., t) + for (val, i) in zip(result, $cont_to_disc_idxs) + $(_set_parameter_unchecked!)(p, val, i; update_dependent = false) + end + disc_state = Tuple($(parameter_values)(p, i) for i in $disc_range) result = d2c_obs(disc_state, p..., t) for (val, i) in zip(result, $disc_to_cont_idxs) @@ -248,11 +263,14 @@ function generate_discrete_affect( repack(discretes) # to force recalculation of dependents end) else - :(function (p, t) + :(function (u, p, t) + c2d_obs = $cont_to_disc_obs d2c_obs = $disc_to_cont_obs + c2d_view = view(p, $cont_to_disc_idxs) d2c_view = view(p, $disc_to_cont_idxs) - disc_state = view(p, $disc_range) - copyto!(d2c_view, d2c_obs(disc_state, p, t)) + disc_unknowns = view(p, $disc_range) + copyto!(c2d_view, c2d_obs(u, p, t)) + copyto!(d2c_view, d2c_obs(disc_unknowns, p, t)) end) end @@ -277,9 +295,6 @@ function generate_discrete_affect( # TODO: find a way to do this without allocating disc = $disc - push!(saved_values.t, t) - push!(saved_values.saveval, $save_vec) - # Write continuous into to discrete: handles `Sample` # Write discrete into to continuous # Update discrete unknowns @@ -329,6 +344,10 @@ function generate_discrete_affect( :(copyto!(d2c_view, d2c_obs(disc_unknowns, p, t))) end ) + + push!(saved_values.t, t) + push!(saved_values.saveval, $save_vec) + # @show "after d2c", p $( if use_index_cache diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 9886c0a2d0..2ae14962bb 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -812,7 +812,8 @@ function get_u0_p(sys, u0, p, defs end -function get_u0(sys, u0map, parammap = nothing; symbolic_u0 = false) +function get_u0( + sys, u0map, parammap = nothing; symbolic_u0 = false, toterm = default_toterm) dvs = unknowns(sys) ps = parameters(sys) defs = defaults(sys) @@ -821,9 +822,10 @@ function get_u0(sys, u0map, parammap = nothing; symbolic_u0 = false) end defs = mergedefaults(defs, u0map, dvs) if symbolic_u0 - u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false, use_union = false) + u0 = varmap_to_vars( + u0map, dvs; defaults = defs, tofloat = false, use_union = false, toterm) else - u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true) + u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true, toterm) end return u0, defs end @@ -862,6 +864,41 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap; # since they will be checked in the initialization problem's construction # TODO: make check for if a DAE cheaper than calculating the mass matrix a second time! ci = infer_clocks!(ClockInference(TearingState(sys))) + + if eltype(parammap) <: Pair + parammap = Dict(unwrap(k) => v for (k, v) in todict(parammap)) + elseif parammap isa AbstractArray + if isempty(parammap) + parammap = SciMLBase.NullParameters() + else + parammap = Dict(unwrap.(parameters(sys)) .=> parammap) + end + end + clockedparammap = Dict() + defs = ModelingToolkit.get_defaults(sys) + for v in ps + v = unwrap(v) + is_discrete_domain(v) || continue + op = operation(v) + if !isa(op, Symbolics.Operator) && parammap != SciMLBase.NullParameters() && + haskey(parammap, v) + error("Initial conditions for discrete variables must be for the past state of the unknown. Instead of providing the condition for $v, provide the condition for $(Shift(iv, -1)(v)).") + end + shiftedv = StructuralTransformations.simplify_shifts(Shift(iv, -1)(v)) + if parammap != SciMLBase.NullParameters() && + (val = get(parammap, shiftedv, nothing)) !== nothing + clockedparammap[v] = val + elseif op isa Shift + root = arguments(v)[1] + haskey(defs, root) || error("Initial condition for $v not provided.") + clockedparammap[v] = defs[root] + end + end + parammap = if parammap == SciMLBase.NullParameters() + clockedparammap + else + merge(parammap, clockedparammap) + end # TODO: make it work with clocks # ModelingToolkit.get_tearing_state(sys) !== nothing => Requires structural_simplify first if sys isa ODESystem && (implicit_dae || !isempty(missingvars)) && @@ -1042,7 +1079,8 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map = affects, inits, clocks, svs = ModelingToolkit.generate_discrete_affect(sys, dss...) discrete_cbs = map(affects, clocks, svs) do affect, clock, sv if clock isa Clock - PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt) + PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt; + final_affect = true, initial_affect = true) elseif clock isa SolverStepClock affect = DiscreteSaveAffect(affect, sv) DiscreteCallback(Returns(true), affect, @@ -1073,10 +1111,11 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map = if svs !== nothing kwargs1 = merge(kwargs1, (disc_saved_values = svs,)) end + prob = ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...) if !isempty(inits) for init in inits - init(prob.p, tspan[1]) + # init(prob.u0, prob.p, tspan[1]) end end prob @@ -1148,12 +1187,12 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [], h(p::MTKParameters, t) = h_oop(p..., t) u0 = h(p, tspan[1]) cbs = process_events(sys; callback, kwargs...) - inits = [] if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing - affects, inits, clocks, svs = ModelingToolkit.generate_discrete_affect(sys, dss...) + affects, clocks, svs = ModelingToolkit.generate_discrete_affect(sys, dss...) discrete_cbs = map(affects, clocks, svs) do affect, clock, sv if clock isa Clock - PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt) + PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt; + final_affect = true, initial_affect = true) else error("$clock is not a supported clock type.") end @@ -1179,13 +1218,7 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [], if svs !== nothing kwargs1 = merge(kwargs1, (disc_saved_values = svs,)) end - prob = DDEProblem{iip}(f, u0, h, tspan, p; kwargs1..., kwargs...) - if !isempty(inits) - for init in inits - init(prob.p, tspan[1]) - end - end - prob + DDEProblem{iip}(f, u0, h, tspan, p; kwargs1..., kwargs...) end function DiffEqBase.SDDEProblem(sys::AbstractODESystem, args...; kwargs...) @@ -1210,12 +1243,12 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [], h(p, t) = h_oop(p, t) u0 = h(p, tspan[1]) cbs = process_events(sys; callback, kwargs...) - inits = [] if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing - affects, inits, clocks, svs = ModelingToolkit.generate_discrete_affect(sys, dss...) + affects, clocks, svs = ModelingToolkit.generate_discrete_affect(sys, dss...) discrete_cbs = map(affects, clocks, svs) do affect, clock, sv if clock isa Clock - PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt) + PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt; + final_affect = true, initial_affect = true) else error("$clock is not a supported clock type.") end @@ -1252,15 +1285,9 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [], else noise_rate_prototype = zeros(eltype(u0), size(noiseeqs)) end - prob = SDDEProblem{iip}(f, f.g, u0, h, tspan, p; + SDDEProblem{iip}(f, f.g, u0, h, tspan, p; noise_rate_prototype = noise_rate_prototype, kwargs1..., kwargs...) - if !isempty(inits) - for init in inits - init(prob.p, tspan[1]) - end - end - prob end """ diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index 526d83d48b..6557b2a4a9 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -378,6 +378,7 @@ function build_explicit_observed_function(sys, ts; checkbounds = true, drop_expr = drop_expr, ps = full_parameters(sys), + op = Operator, throw = true) if (isscalar = !(ts isa AbstractVector)) ts = [ts] @@ -385,7 +386,7 @@ function build_explicit_observed_function(sys, ts; ts = unwrap.(Symbolics.scalarize(ts)) vars = Set() - foreach(Base.Fix1(vars!, vars), ts) + foreach(v -> vars!(vars, v; op), ts) ivs = independent_variables(sys) dep_vars = scalarize(setdiff(vars, ivs)) @@ -452,13 +453,16 @@ function build_explicit_observed_function(sys, ts; if inputs !== nothing ps = setdiff(ps, inputs) # Inputs have been converted to parameters by io_preprocessing, remove those from the parameter list end - if has_index_cache(sys) && get_index_cache(sys) !== nothing - ps = DestructuredArgs.(reorder_parameters(get_index_cache(sys), ps)) - elseif ps isa Tuple + if ps isa Tuple ps = DestructuredArgs.(ps, inbounds = !checkbounds) + elseif has_index_cache(sys) && get_index_cache(sys) !== nothing + ps = DestructuredArgs.(reorder_parameters(get_index_cache(sys), ps)) else ps = (DestructuredArgs(ps, inbounds = !checkbounds),) end + if isempty(ps) + ps = (DestructuredArgs([]),) + end dvs = DestructuredArgs(unknowns(sys), inbounds = !checkbounds) if inputs === nothing args = [dvs, ps..., ivs...] diff --git a/src/systems/discrete_system/discrete_system.jl b/src/systems/discrete_system/discrete_system.jl new file mode 100644 index 0000000000..530d2d03bc --- /dev/null +++ b/src/systems/discrete_system/discrete_system.jl @@ -0,0 +1,373 @@ +""" +$(TYPEDEF) +A system of difference equations. +# Fields +$(FIELDS) +# Example +``` +using ModelingToolkit +using ModelingToolkit: t_nounits as t +@parameters σ=28.0 ρ=10.0 β=8/3 δt=0.1 +@variables x(t)=1.0 y(t)=0.0 z(t)=0.0 +k = ShiftIndex(t) +eqs = [x(k+1) ~ σ*(y-x), + y(k+1) ~ x*(ρ-z)-y, + z(k+1) ~ x*y - β*z] +@named de = DiscreteSystem(eqs,t,[x,y,z],[σ,ρ,β]; tspan = (0, 1000.0)) # or +@named de = DiscreteSystem(eqs) +``` +""" +struct DiscreteSystem <: AbstractTimeDependentSystem + """ + A tag for the system. If two systems have the same tag, then they are + structurally identical. + """ + tag::UInt + """The differential equations defining the discrete system.""" + eqs::Vector{Equation} + """Independent variable.""" + iv::BasicSymbolic{Real} + """Dependent (state) variables. Must not contain the independent variable.""" + unknowns::Vector + """Parameter variables. Must not contain the independent variable.""" + ps::Vector + """Time span.""" + tspan::Union{NTuple{2, Any}, Nothing} + """Array variables.""" + var_to_name::Any + """Observed states.""" + observed::Vector{Equation} + """ + The name of the system + """ + name::Symbol + """ + The internal systems. These are required to have unique names. + """ + systems::Vector{DiscreteSystem} + """ + The default values to use when initial conditions and/or + parameters are not supplied in `DiscreteProblem`. + """ + defaults::Dict + """ + Inject assignment statements before the evaluation of the RHS function. + """ + preface::Any + """ + Type of the system. + """ + connector_type::Any + """ + A mapping from dependent parameters to expressions describing how they are calculated from + other parameters. + """ + parameter_dependencies::Union{Nothing, Dict} + """ + Metadata for the system, to be used by downstream packages. + """ + metadata::Any + """ + Metadata for MTK GUI. + """ + gui_metadata::Union{Nothing, GUIMetadata} + """ + Cache for intermediate tearing state. + """ + tearing_state::Any + """ + Substitutions generated by tearing. + """ + substitutions::Any + """ + If a model `sys` is complete, then `sys.x` no longer performs namespacing. + """ + complete::Bool + """ + Cached data for fast symbolic indexing. + """ + index_cache::Union{Nothing, IndexCache} + """ + The hierarchical parent system before simplification. + """ + parent::Any + + function DiscreteSystem(tag, discreteEqs, iv, dvs, ps, tspan, var_to_name, + observed, + name, + systems, defaults, preface, connector_type, parameter_dependencies = nothing, + metadata = nothing, gui_metadata = nothing, + tearing_state = nothing, substitutions = nothing, + complete = false, index_cache = nothing, parent = nothing; + checks::Union{Bool, Int} = true) + if checks == true || (checks & CheckComponents) > 0 + check_variables(dvs, iv) + check_parameters(ps, iv) + end + if checks == true || (checks & CheckUnits) > 0 + u = __get_unit_type(dvs, ps, iv) + check_units(u, discreteEqs) + end + new(tag, discreteEqs, iv, dvs, ps, tspan, var_to_name, observed, name, + systems, + defaults, + preface, connector_type, parameter_dependencies, metadata, gui_metadata, + tearing_state, substitutions, complete, index_cache, parent) + end +end + +""" + $(TYPEDSIGNATURES) +Constructs a DiscreteSystem. +""" +function DiscreteSystem(eqs::AbstractVector{<:Equation}, iv, dvs, ps; + observed = Num[], + systems = DiscreteSystem[], + tspan = nothing, + name = nothing, + default_u0 = Dict(), + default_p = Dict(), + defaults = _merge(Dict(default_u0), Dict(default_p)), + preface = nothing, + connector_type = nothing, + parameter_dependencies = nothing, + metadata = nothing, + gui_metadata = nothing, + kwargs...) + name === nothing && + throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro")) + iv′ = value(iv) + dvs′ = value.(dvs) + ps′ = value.(ps) + if any(hasderiv, eqs) || any(hashold, eqs) || any(hassample, eqs) || any(hasdiff, eqs) + error("Equations in a `DiscreteSystem` can only have `Shift` operators.") + end + if !(isempty(default_u0) && isempty(default_p)) + Base.depwarn( + "`default_u0` and `default_p` are deprecated. Use `defaults` instead.", + :DiscreteSystem, force = true) + end + defaults = todict(defaults) + defaults = Dict(value(k) => value(v) for (k, v) in pairs(defaults)) + + var_to_name = Dict() + process_variables!(var_to_name, defaults, dvs′) + process_variables!(var_to_name, defaults, ps′) + isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed)) + + sysnames = nameof.(systems) + if length(unique(sysnames)) != length(sysnames) + throw(ArgumentError("System names must be unique.")) + end + DiscreteSystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)), + eqs, iv′, dvs′, ps′, tspan, var_to_name, observed, name, systems, + defaults, preface, connector_type, parameter_dependencies, metadata, gui_metadata, kwargs...) +end + +function DiscreteSystem(eqs, iv; kwargs...) + eqs = collect(eqs) + diffvars = OrderedSet() + allunknowns = OrderedSet() + ps = OrderedSet() + iv = value(iv) + for eq in eqs + collect_vars!(allunknowns, ps, eq.lhs, iv; op = Shift) + collect_vars!(allunknowns, ps, eq.rhs, iv; op = Shift) + if istree(eq.lhs) && operation(eq.lhs) isa Shift + isequal(iv, operation(eq.lhs).t) || + throw(ArgumentError("A DiscreteSystem can only have one independent variable.")) + eq.lhs in diffvars && + throw(ArgumentError("The shift variable $(eq.lhs) is not unique in the system of equations.")) + push!(diffvars, eq.lhs) + end + end + new_ps = OrderedSet() + for p in ps + if istree(p) && operation(p) === getindex + par = arguments(p)[begin] + if Symbolics.shape(Symbolics.unwrap(par)) !== Symbolics.Unknown() && + all(par[i] in ps for i in eachindex(par)) + push!(new_ps, par) + else + push!(new_ps, p) + end + else + push!(new_ps, p) + end + end + return DiscreteSystem(eqs, iv, + collect(allunknowns), collect(new_ps); kwargs...) +end + +function generate_function( + sys::DiscreteSystem, dvs = unknowns(sys), ps = full_parameters(sys); kwargs...) + generate_custom_function(sys, [eq.rhs for eq in equations(sys)], dvs, ps; kwargs...) +end + +function process_DiscreteProblem(constructor, sys::DiscreteSystem, u0map, parammap; + linenumbers = true, parallel = SerialForm(), + eval_expression = true, + use_union = false, + tofloat = !use_union, + kwargs...) + iv = get_iv(sys) + eqs = equations(sys) + dvs = unknowns(sys) + ps = parameters(sys) + + trueu0map = Dict() + for (k, v) in u0map + k = unwrap(k) + if !((op = operation(k)) isa Shift) + error("Initial conditions must be for the past state of the unknowns. Instead of providing the condition for $k, provide the condition for $(Shift(iv, -1)(k)).") + end + trueu0map[Shift(iv, op.steps + 1)(arguments(k)[1])] = v + end + defs = ModelingToolkit.get_defaults(sys) + for var in dvs + if (op = operation(var)) isa Shift && !haskey(trueu0map, var) + root = arguments(var)[1] + haskey(defs, root) || error("Initial condition for $var not provided.") + trueu0map[var] = defs[root] + end + end + @show trueu0map u0map + if has_index_cache(sys) && get_index_cache(sys) !== nothing + u0, defs = get_u0(sys, trueu0map, parammap) + p = MTKParameters(sys, parammap, trueu0map) + else + u0, p, defs = get_u0_p(sys, trueu0map, parammap; tofloat, use_union) + end + + check_eqs_u0(eqs, dvs, u0; kwargs...) + + f = constructor(sys, dvs, ps, u0; + linenumbers = linenumbers, parallel = parallel, + syms = Symbol.(dvs), paramsyms = Symbol.(ps), + eval_expression = eval_expression, kwargs...) + return f, u0, p +end + +""" + $(TYPEDSIGNATURES) +Generates an DiscreteProblem from an DiscreteSystem. +""" +function SciMLBase.DiscreteProblem( + sys::DiscreteSystem, u0map = [], tspan = get_tspan(sys), + parammap = SciMLBase.NullParameters(); + eval_module = @__MODULE__, + eval_expression = true, + use_union = false, + kwargs... +) + if !iscomplete(sys) + error("A completed `DiscreteSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `DiscreteProblem`") + end + dvs = unknowns(sys) + ps = parameters(sys) + eqs = equations(sys) + iv = get_iv(sys) + + f, u0, p = process_DiscreteProblem( + DiscreteFunction, sys, u0map, parammap; eval_expression, eval_module) + u0 = f(u0, p, tspan[1]) + DiscreteProblem(f, u0, tspan, p; kwargs...) +end + +function SciMLBase.DiscreteFunction(sys::DiscreteSystem, args...; kwargs...) + DiscreteFunction{true}(sys, args...; kwargs...) +end + +function SciMLBase.DiscreteFunction{true}(sys::DiscreteSystem, args...; kwargs...) + DiscreteFunction{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...) +end + +function SciMLBase.DiscreteFunction{false}(sys::DiscreteSystem, args...; kwargs...) + DiscreteFunction{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...) +end +function SciMLBase.DiscreteFunction{iip, specialize}( + sys::DiscreteSystem, + dvs = unknowns(sys), + ps = full_parameters(sys), + u0 = nothing; + version = nothing, + p = nothing, + t = nothing, + eval_expression = true, + eval_module = @__MODULE__, + analytic = nothing, + kwargs...) where {iip, specialize} + if !iscomplete(sys) + error("A completed `DiscreteSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `DiscreteProblem`") + end + f_gen = generate_function(sys, dvs, ps; expression = Val{eval_expression}, + expression_module = eval_module, kwargs...) + f_oop, f_iip = eval_expression ? + (drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen) : + f_gen + f(u, p, t) = f_oop(u, p, t) + f(du, u, p, t) = f_iip(du, u, p, t) + + if specialize === SciMLBase.FunctionWrapperSpecialize && iip + if u0 === nothing || p === nothing || t === nothing + error("u0, p, and t must be specified for FunctionWrapperSpecialize on DiscreteFunction.") + end + f = SciMLBase.wrapfun_iip(f, (u0, u0, p, t)) + end + + observedfun = let sys = sys, dict = Dict() + function generate_observed(obsvar, u, p, t) + obs = get!(dict, value(obsvar)) do + build_explicit_observed_function(sys, obsvar) + end + p isa MTKParameters ? obs(u, p..., t) : obs(u, p, t) + end + end + + DiscreteFunction{iip, specialize}(f; + sys = sys, + observed = observedfun, + analytic = analytic) +end + +""" +```julia +DiscreteFunctionExpr{iip}(sys::DiscreteSystem, dvs = states(sys), + ps = parameters(sys); + version = nothing, + kwargs...) where {iip} +``` + +Create a Julia expression for an `DiscreteFunction` from the [`DiscreteSystem`](@ref). +The arguments `dvs` and `ps` are used to set the order of the dependent +variable and parameter vectors, respectively. +""" +struct DiscreteFunctionExpr{iip} end +struct DiscreteFunctionClosure{O, I} <: Function + f_oop::O + f_iip::I +end +(f::DiscreteFunctionClosure)(u, p, t) = f.f_oop(u, p, t) +(f::DiscreteFunctionClosure)(du, u, p, t) = f.f_iip(du, u, p, t) + +function DiscreteFunctionExpr{iip}(sys::DiscreteSystem, dvs = unknowns(sys), + ps = parameters(sys), u0 = nothing; + version = nothing, p = nothing, + linenumbers = false, + simplify = false, + kwargs...) where {iip} + f_oop, f_iip = generate_function(sys, dvs, ps; expression = Val{true}, kwargs...) + + fsym = gensym(:f) + _f = :($fsym = $DiscreteFunctionClosure($f_oop, $f_iip)) + + ex = quote + $_f + DiscreteFunction{$iip}($fsym) + end + !linenumbers ? Base.remove_linenums!(ex) : ex +end + +function DiscreteFunctionExpr(sys::DiscreteSystem, args...; kwargs...) + DiscreteFunctionExpr{true}(sys, args...; kwargs...) +end diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index 7d3969bd02..54285a431f 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -327,24 +327,6 @@ function TearingState(sys; quick_cancel = false, check = true) dvar = var idx = varidx - if ModelingToolkit.isoperator(dvar, ModelingToolkit.Shift) - if !(idx in dervaridxs) - push!(dervaridxs, idx) - end - op = operation(dvar) - tt = op.t - steps = op.steps - v = arguments(dvar)[1] - for s in (steps - 1):-1:1 - sf = Shift(tt, s) - dvar = sf(v) - idx = addvar!(dvar) - if !(idx in dervaridxs) - push!(dervaridxs, idx) - end - end - idx = addvar!(v) - end if istree(var) && operation(var) isa Symbolics.Operator && !isdifferential(var) && (it = input_timedomain(var)) !== nothing @@ -364,14 +346,51 @@ function TearingState(sys; quick_cancel = false, check = true) eqs[i] = eqs[i].lhs ~ rhs end end - + lowest_shift = Dict() + for var in fullvars + if ModelingToolkit.isoperator(var, ModelingToolkit.Shift) + steps = operation(var).steps + if steps > 0 + error("Only non-positive shifts allowed. Found $var with a shift of $steps") + end + v = arguments(var)[1] + lowest_shift[v] = min(get(lowest_shift, v, 0), steps) + end + end + for var in fullvars + if ModelingToolkit.isoperator(var, ModelingToolkit.Shift) + op = operation(var) + steps = op.steps + v = arguments(var)[1] + lshift = lowest_shift[v] + tt = op.t + elseif haskey(lowest_shift, var) + lshift = lowest_shift[var] + steps = 0 + tt = iv + v = var + else + continue + end + if lshift < steps + push!(dervaridxs, var2idx[var]) + end + for s in (steps - 1):-1:(lshift + 1) + sf = Shift(tt, s) + dvar = sf(v) + idx = addvar!(dvar) + if !(idx in dervaridxs) + push!(dervaridxs, idx) + end + end + end # sort `fullvars` such that the mass matrix is as diagonal as possible. dervaridxs = collect(dervaridxs) sorted_fullvars = OrderedSet(fullvars[dervaridxs]) var_to_old_var = Dict(zip(fullvars, fullvars)) for dervaridx in dervaridxs dervar = fullvars[dervaridx] - diffvar = var_to_old_var[lower_order_var(dervar)] + diffvar = var_to_old_var[lower_order_var(dervar, iv)] if !(diffvar in sorted_fullvars) push!(sorted_fullvars, diffvar) end @@ -393,7 +412,7 @@ function TearingState(sys; quick_cancel = false, check = true) var_to_diff = DiffGraph(nvars, true) for dervaridx in dervaridxs dervar = fullvars[dervaridx] - diffvar = lower_order_var(dervar) + diffvar = lower_order_var(dervar, iv) diffvaridx = var2idx[diffvar] push!(diffvars, diffvar) var_to_diff[diffvaridx] = dervaridx @@ -409,28 +428,58 @@ function TearingState(sys; quick_cancel = false, check = true) eq_to_diff = DiffGraph(nsrcs(graph)) - return TearingState(sys, fullvars, + ts = TearingState(sys, fullvars, SystemStructure(complete(var_to_diff), complete(eq_to_diff), - complete(graph), nothing, var_types, false), + complete(graph), nothing, var_types, sys isa DiscreteSystem), Any[]) + if sys isa DiscreteSystem + ts = shift_discrete_system(ts) + end + return ts end -function lower_order_var(dervar) +function lower_order_var(dervar, t) if isdifferential(dervar) diffvar = arguments(dervar)[1] - else # shift + elseif ModelingToolkit.isoperator(dervar, ModelingToolkit.Shift) s = operation(dervar) step = s.steps - 1 vv = arguments(dervar)[1] - if step >= 1 + if step != 0 diffvar = Shift(s.t, step)(vv) else diffvar = vv end + else + return Shift(t, -1)(dervar) end diffvar end +function shift_discrete_system(ts::TearingState) + @unpack fullvars, sys = ts + discvars = OrderedSet() + eqs = equations(sys) + for eq in eqs + vars!(discvars, eq; op = Union{Sample, Hold}) + end + iv = get_iv(sys) + discmap = Dict(k => StructuralTransformations.simplify_shifts(Shift(iv, 1)(k)) + for k in discvars + if any(isequal(k), fullvars) && !isa(operation(k), Union{Sample, Hold})) + for i in eachindex(fullvars) + fullvars[i] = StructuralTransformations.simplify_shifts(fast_substitute( + fullvars[i], discmap; operator = Union{Sample, Hold})) + end + for i in eachindex(eqs) + eqs[i] = StructuralTransformations.simplify_shifts(fast_substitute( + eqs[i], discmap; operator = Union{Sample, Hold})) + end + @set! ts.sys.eqs = eqs + @set! ts.fullvars = fullvars + return ts +end + using .BipartiteGraphs: Label, BipartiteAdjacencyList struct SystemStructurePrintMatrix <: AbstractMatrix{Union{Label, BipartiteAdjacencyList}} diff --git a/src/utils.jl b/src/utils.jl index 5fa79530aa..c8ea9ecaf3 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -243,11 +243,14 @@ function collect_var_to_name!(vars, xs) x = unwrap(x) if hasmetadata(x, Symbolics.GetindexParent) xarr = getmetadata(x, Symbolics.GetindexParent) + hasname(xarr) || continue vars[Symbolics.getname(xarr)] = xarr else if istree(x) && operation(x) === getindex x = arguments(x)[1] end + x = unwrap(x) + hasname(x) || continue vars[Symbolics.getname(unwrap(x))] = x end end @@ -434,11 +437,11 @@ function find_derivatives!(vars, expr, f) return vars end -function collect_vars!(unknowns, parameters, expr, iv) +function collect_vars!(unknowns, parameters, expr, iv; op = Differential) if issym(expr) collect_var!(unknowns, parameters, expr, iv) else - for var in vars(expr) + for var in vars(expr; op) if istree(var) && operation(var) isa Differential var, _ = var_from_nested_derivative(var) end @@ -799,35 +802,76 @@ end # Symbolics needs to call unwrap on the substitution rules, but most of the time # we don't want to do that in MTK. const Eq = Union{Equation, Inequality} -function fast_substitute(eq::Eq, subs) +function fast_substitute(eq::Eq, subs; operator = Nothing) if eq isa Inequality - Inequality(fast_substitute(eq.lhs, subs), fast_substitute(eq.rhs, subs), + Inequality(fast_substitute(eq.lhs, subs; operator), + fast_substitute(eq.rhs, subs; operator), eq.relational_op) else - Equation(fast_substitute(eq.lhs, subs), fast_substitute(eq.rhs, subs)) + Equation(fast_substitute(eq.lhs, subs; operator), + fast_substitute(eq.rhs, subs; operator)) end end -function fast_substitute(eq::T, subs::Pair) where {T <: Eq} - T(fast_substitute(eq.lhs, subs), fast_substitute(eq.rhs, subs)) +function fast_substitute(eq::T, subs::Pair; operator = Nothing) where {T <: Eq} + T(fast_substitute(eq.lhs, subs; operator), fast_substitute(eq.rhs, subs; operator)) end -fast_substitute(eqs::AbstractArray, subs) = fast_substitute.(eqs, (subs,)) -fast_substitute(a, b) = substitute(a, b) -function fast_substitute(expr, pair::Pair) +function fast_substitute(eqs::AbstractArray, subs; operator = Nothing) + fast_substitute.(eqs, (subs,); operator) +end +function fast_substitute(eqs::AbstractArray, subs::Pair; operator = Nothing) + fast_substitute.(eqs, (subs,); operator) +end +for (exprType, subsType) in Iterators.product((Num, Symbolics.Arr), (Any, Pair)) + @eval function fast_substitute(expr::$exprType, subs::$subsType; operator = Nothing) + fast_substitute(value(expr), subs; operator) + end +end +function fast_substitute(expr, subs; operator = Nothing) + if (_val = get(subs, expr, nothing)) !== nothing + return _val + end + istree(expr) || return expr + op = fast_substitute(operation(expr), subs; operator) + args = SymbolicUtils.unsorted_arguments(expr) + if !(op isa operator) + canfold = Ref(!(op isa Symbolic)) + args = let canfold = canfold + map(args) do x + x′ = fast_substitute(x, subs; operator) + canfold[] = canfold[] && !(x′ isa Symbolic) + x′ + end + end + canfold[] && return op(args...) + end + similarterm(expr, + op, + args, + symtype(expr); + metadata = metadata(expr)) +end +function fast_substitute(expr, pair::Pair; operator = Nothing) a, b = pair isequal(expr, a) && return b - + if a isa AbstractArray + for (ai, bi) in zip(a, b) + expr = fast_substitute(expr, ai => bi; operator) + end + end istree(expr) || return expr - op = fast_substitute(operation(expr), pair) - canfold = Ref(!(op isa Symbolic)) - args = let canfold = canfold - map(SymbolicUtils.unsorted_arguments(expr)) do x - x′ = fast_substitute(x, pair) - canfold[] = canfold[] && !(x′ isa Symbolic) - x′ + op = fast_substitute(operation(expr), pair; operator) + args = SymbolicUtils.unsorted_arguments(expr) + if !(op isa operator) + canfold = Ref(!(op isa Symbolic)) + args = let canfold = canfold + map(args) do x + x′ = fast_substitute(x, pair; operator) + canfold[] = canfold[] && !(x′ isa Symbolic) + x′ + end end + canfold[] && return op(args...) end - canfold[] && return op(args...) - similarterm(expr, op, args, diff --git a/src/variables.jl b/src/variables.jl index c6b2f67d58..3fbec6bfd1 100644 --- a/src/variables.jl +++ b/src/variables.jl @@ -104,6 +104,9 @@ state_priority(x) = convert(Float64, getmetadata(x, VariableStatePriority, 0.0)) function default_toterm(x) if istree(x) && (op = operation(x)) isa Operator if !(op isa Differential) + if op isa Shift && op.steps < 0 + return x + end x = normalize_to_differential(op)(arguments(x)...) end Symbolics.diff2term(x) @@ -192,7 +195,8 @@ function _varmap_to_vars(varmap::Dict, varlist; defaults = Dict(), check = false values = Dict() for var in varlist var = unwrap(var) - val = unwrap(fixpoint_sub(fixpoint_sub(var, varmap), defaults)) + val = unwrap(fixpoint_sub(fixpoint_sub(var, varmap; operator = Symbolics.Operator), + defaults; operator = Symbolics.Operator)) if symbolic_type(val) === NotSymbolic() values[var] = val end diff --git a/test/clock.jl b/test/clock.jl index bb2b52e3ea..b7964909d0 100644 --- a/test/clock.jl +++ b/test/clock.jl @@ -69,7 +69,10 @@ sss, = ModelingToolkit._structural_simplify!(deepcopy(tss[1]), (inputs[1], ())) @test equations(sss) == [D(x) ~ u - x] sss, = ModelingToolkit._structural_simplify!(deepcopy(tss[2]), (inputs[2], ())) @test isempty(equations(sss)) -@test observed(sss) == [yd ~ Sample(t, dt)(y); r ~ 1.0; ud ~ kp * (r - yd)] +d = Clock(t, dt) +k = ShiftIndex(d) +@test observed(sss) == [yd(k + 1) ~ Sample(t, dt)(y); r(k + 1) ~ 1.0; + ud(k + 1) ~ kp * (r(k + 1) - yd(k + 1))] d = Clock(t, dt) # Note that TearingState reorders the equations @@ -89,40 +92,38 @@ d = Clock(t, dt) @info "Testing shift normalization" dt = 0.1 -@variables x(t) y(t) u(t) yd(t) ud(t) r(t) z(t) +@variables x(t) y(t) u(t) yd(t) ud(t) @parameters kp d = Clock(t, dt) k = ShiftIndex(d) eqs = [yd ~ Sample(t, dt)(y) - ud ~ kp * (r - yd) + z(k) - r ~ 1.0 + ud ~ kp * yd + ud(k - 2) # plant (time continuous part) u ~ Hold(ud) D(x) ~ -x + u - y ~ x - z(k + 2) ~ z(k) + yd - #= - z(k + 2) ~ z(k) + yd - => - z′(k + 1) ~ z(k) + yd - z(k + 1) ~ z′(k) - =# - ] + y ~ x] @named sys = ODESystem(eqs, t) ss = structural_simplify(sys); Tf = 1.0 -prob = ODEProblem(ss, [x => 0.0, y => 0.0], (0.0, Tf), - [kp => 1.0; z => 3.0; z(k + 1) => 2.0]) -@test sort(vcat(prob.p...)) == [0, 1.0, 2.0, 3.0, 4.0] # yd, kp, z(k+1), z(k), ud +prob = ODEProblem(ss, [x => 0.1], (0.0, Tf), + [kp => 1.0; ud(k - 1) => 2.1; ud(k - 2) => 2.0]) +# create integrator so callback is evaluated at t=0 and we can test correct param values +int = init(prob, Tsit5(); kwargshandle = KeywordArgSilent) +@test sort(vcat(int.p...)) == [0.1, 1.0, 2.1, 2.1, 2.1] # yd, kp, ud(k-1), ud, Hold(ud) +prob = ODEProblem(ss, [x => 0.1], (0.0, Tf), + [kp => 1.0; ud(k - 1) => 2.1; ud(k - 2) => 2.0]) # recreate problem to empty saved values sol = solve(prob, Tsit5(), kwargshandle = KeywordArgSilent) ss_nosplit = structural_simplify(sys; split = false) -prob_nosplit = ODEProblem(ss_nosplit, [x => 0.0, y => 0.0], (0.0, Tf), - [kp => 1.0; z => 3.0; z(k + 1) => 2.0]) -@test sort(prob_nosplit.p) == [0, 1.0, 2.0, 3.0, 4.0] # yd, kp, z(k+1), z(k), ud +prob_nosplit = ODEProblem(ss_nosplit, [x => 0.1], (0.0, Tf), + [kp => 1.0; ud(k - 1) => 2.1; ud(k - 2) => 2.0]) +int = init(prob_nosplit, Tsit5(); kwargshandle = KeywordArgSilent) +@test sort(int.p) == [0.1, 1.0, 2.1, 2.1, 2.1] # yd, kp, ud(k-1), ud, Hold(ud) +prob_nosplit = ODEProblem(ss_nosplit, [x => 0.1], (0.0, Tf), + [kp => 1.0; ud(k - 1) => 2.1; ud(k - 2) => 2.0]) # recreate problem to empty saved values sol_nosplit = solve(prob_nosplit, Tsit5(), kwargshandle = KeywordArgSilent) # For all inputs in parameters, just initialize them to 0.0, and then set them # in the callback. @@ -134,30 +135,24 @@ function foo!(du, u, p, t) du[1] = -x + ud end function affect!(integrator, saved_values) - z_t, z = integrator.p[3], integrator.p[4] yd = integrator.u[1] kp = integrator.p[1] - r = 1.0 + ud = integrator.p[2] + udd = integrator.p[3] - push!(saved_values.t, integrator.t) - push!(saved_values.saveval, [z_t, z]) - - # Update the discrete state - z_t, z = z + yd, z_t - # @show z_t, z - integrator.p[3] = z_t - integrator.p[4] = z + integrator.p[2] = kp * yd + udd + integrator.p[3] = ud - ud = kp * (r - yd) + z - integrator.p[2] = ud + push!(saved_values.t, integrator.t) + push!(saved_values.saveval, [integrator.p[2], integrator.p[3]]) nothing end saved_values = SavedValues(Float64, Vector{Float64}) -cb = PeriodicCallback(Base.Fix2(affect!, saved_values), 0.1) -# kp ud z_t z -prob = ODEProblem(foo!, [0.0], (0.0, Tf), [1.0, 4.0, 2.0, 3.0], callback = cb) -# ud initializes to kp * (r - yd) + z = 1 * (1 - 0) + 3 = 4 +cb = PeriodicCallback( + Base.Fix2(affect!, saved_values), 0.1; final_affect = true, initial_affect = true) +# kp ud +prob = ODEProblem(foo!, [0.1], (0.0, Tf), [1.0, 2.1, 2.0], callback = cb) sol2 = solve(prob, Tsit5()) @test sol.u == sol2.u @test sol_nosplit.u == sol2.u @@ -217,7 +212,7 @@ end function filt(; name) @variables x(t)=0 u(t)=0 y(t)=0 a = 1 / exp(dt) - eqs = [x(k + 1) ~ a * x + (1 - a) * u(k) + eqs = [x ~ a * x(k - 1) + (1 - a) * u(k - 1) y ~ x] ODESystem(eqs, t, name = name) end @@ -318,8 +313,8 @@ if VERSION >= v"1.7" integrator.p[3] = ud2 nothing end - cb1 = PeriodicCallback(affect1!, dt) - cb2 = PeriodicCallback(affect2!, dt2) + cb1 = PeriodicCallback(affect1!, dt; final_affect = true, initial_affect = true) + cb2 = PeriodicCallback(affect2!, dt2; final_affect = true, initial_affect = true) cb = CallbackSet(cb1, cb2) # kp ud1 ud2 prob = ODEProblem(foo!, [0.0], (0.0, 1.0), [1.0, 1.0, 1.0], callback = cb) @@ -423,7 +418,7 @@ ci, varmap = infer_clocks(expand_connections(_model)) @test varmap[_model.feedback.output.u] == d @test varmap[_model.feedback.input2.u] == d -@test_skip ssys = structural_simplify(model) +ssys = structural_simplify(model) Tf = 0.2 timevec = 0:(d.dt):Tf @@ -445,20 +440,20 @@ y = res.y[:] # ref = Constant(k = 0.5) # ; model.controller.x(k-1) => 0.0 +prob = ODEProblem(ssys, + [model.plant.x => 0.0; model.controller.kp => 2.0; model.controller.ki => 2.0], + (0.0, Tf)) +int = init(prob, Tsit5(); kwargshandle = KeywordArgSilent) +@test int.ps[Hold(ssys.holder.input.u)] == 2 # constant output * kp issue https://github.com/SciML/ModelingToolkit.jl/issues/2356 +@test int.ps[ssys.controller.x] == 1 # c2d +@test int.ps[Sample(d)(ssys.sampler.input.u)] == 0 # disc state +sol = solve(prob, + Tsit5(), + kwargshandle = KeywordArgSilent, + abstol = 1e-8, + reltol = 1e-8) @test_skip begin - prob = ODEProblem(ssys, - [model.plant.x => 0.0; model.controller.kp => 2.0; model.controller.ki => 2.0], - (0.0, Tf)) - - @test prob.p[9] == 1 # constant output * kp issue https://github.com/SciML/ModelingToolkit.jl/issues/2356 - @test prob.p[10] == 0 # c2d - @test prob.p[11] == 0 # disc state - sol = solve(prob, - Tsit5(), - kwargshandle = KeywordArgSilent, - abstol = 1e-8, - reltol = 1e-8) - plot([y sol(timevec, idxs = model.plant.output.u)], m = :o, lab = ["CS" "MTK"]) + # plot([y sol(timevec, idxs = model.plant.output.u)], m = :o, lab = ["CS" "MTK"]) ## @@ -487,9 +482,11 @@ k = ShiftIndex(c) @variables begin count(t) = 0 u(t) = 0 + ud(t) = 0 end @equations begin - count(k + 1) ~ Sample(c)(u) + ud ~ Sample(c)(u) + count ~ ud(k - 1) end end @@ -517,3 +514,20 @@ prob = ODEProblem(model, [], (0.0, 10.0)) sol = solve(prob, Tsit5(), kwargshandle = KeywordArgSilent) @test sol.prob.kwargs[:disc_saved_values][1].t == sol.t[1:2:end] # Test that the discrete-tiem system executed at every step of the continuous solver. The solver saves each time step twice, one state value before discrete affect and one after. +@test_nowarn ModelingToolkit.build_explicit_observed_function( + model, model.counter.ud)(sol.u[1], prob.p..., sol.t[1]) + +@variables x(t)=1.0 y(t)=1.0 +eqs = [D(y) ~ Hold(x) + x ~ x(k - 1) + x(k - 2)] +@mtkbuild sys = ODESystem(eqs, t) +prob = ODEProblem(sys, [], (0.0, 10.0)) +int = init(prob, Tsit5(); kwargshandle = KeywordArgSilent) +@test int.ps[x] == 2.0 +@test int.ps[x(k - 1)] == 1.0 + +@test_throws ErrorException ODEProblem(sys, [], (0.0, 10.0), [x => 2.0]) +prob = ODEProblem(sys, [], (0.0, 10.0), [x(k - 1) => 2.0]) +int = init(prob, Tsit5(); kwargshandle = KeywordArgSilent) +@test int.ps[x] == 3.0 +@test int.ps[x(k - 1)] == 2.0 diff --git a/test/discrete_system.jl b/test/discrete_system.jl new file mode 100644 index 0000000000..3c2238bca3 --- /dev/null +++ b/test/discrete_system.jl @@ -0,0 +1,236 @@ +# Example: Compartmental models in epidemiology +#= +- https://github.com/epirecipes/sir-julia/blob/master/markdown/function_map/function_map.md +- https://en.wikipedia.org/wiki/Compartmental_models_in_epidemiology#Deterministic_versus_stochastic_epidemic_models +=# +using ModelingToolkit, Test +using ModelingToolkit: t_nounits as t +using ModelingToolkit: get_metadata, MTKParameters + +# Make sure positive shifts error +@variables x(t) +k = ShiftIndex(t) +@test_throws ErrorException @mtkbuild sys = DiscreteSystem([x(k + 1) ~ x + x(k - 1)], t) + +@inline function rate_to_proportion(r, t) + 1 - exp(-r * t) +end; + +# Independent and dependent variables and parameters +@parameters c nsteps δt β γ +@constants h = 1 +@variables S(t) I(t) R(t) +infection = rate_to_proportion( + β * c * I(k - 1) / (S(k - 1) * h + I(k - 1) + R(k - 1)), δt * h) * S(k - 1) +recovery = rate_to_proportion(γ * h, δt) * I(k - 1) + +# Equations +eqs = [S ~ S(k - 1) - infection * h, + I ~ I(k - 1) + infection - recovery, + R ~ R(k - 1) + recovery] + +# System +@named sys = DiscreteSystem(eqs, t, [S, I, R], [c, nsteps, δt, β, γ]) +syss = structural_simplify(sys) +@test syss == syss + +for df in [ + DiscreteFunction(syss), + eval(DiscreteFunctionExpr(syss)) +] + + # iip + du = zeros(3) + u = collect(1:3) + p = MTKParameters(syss, parameters(syss) .=> collect(1:5)) + df.f(du, u, p, 0) + @test du ≈ [0.01831563888873422, 0.9816849729159067, 4.999999388195359] + + # oop + @test df.f(u, p, 0) ≈ [0.01831563888873422, 0.9816849729159067, 4.999999388195359] +end + +# Problem +u0 = [S(k - 1) => 990.0, I(k - 1) => 10.0, R(k - 1) => 0.0] +p = [β => 0.05, c => 10.0, γ => 0.25, δt => 0.1, nsteps => 400] +tspan = (0.0, ModelingToolkit.value(substitute(nsteps, p))) # value function (from Symbolics) is used to convert a Num to Float64 +prob_map = DiscreteProblem(syss, u0, tspan, p) +@test prob_map.f.sys === syss + +# Solution +using OrdinaryDiffEq +sol_map = solve(prob_map, FunctionMap()); +@test sol_map[S] isa Vector + +# Using defaults constructor +@parameters c=10.0 nsteps=400 δt=0.1 β=0.05 γ=0.25 +@variables S(t)=990.0 I(t)=10.0 R(t)=0.0 R2(t) + +infection2 = rate_to_proportion(β * c * I(k - 1) / (S(k - 1) + I(k - 1) + R(k - 1)), δt) * + S(k - 1) +recovery2 = rate_to_proportion(γ, δt) * I(k - 1) + +eqs2 = [S ~ S(k - 1) - infection2, + I ~ I(k - 1) + infection2 - recovery2, + R ~ R(k - 1) + recovery2, + R2 ~ R] + +@mtkbuild sys = DiscreteSystem( + eqs2, t, [S, I, R, R2], [c, nsteps, δt, β, γ]; controls = [β, γ], tspan) +@test ModelingToolkit.defaults(sys) != Dict() + +prob_map2 = DiscreteProblem(sys) +sol_map2 = solve(prob_map2, FunctionMap()); + +@test sol_map.u ≈ sol_map2.u +@test sol_map.prob.p == sol_map2.prob.p +@test_throws Any sol_map2[R2] +@test sol_map2[R2(k + 1)][begin:(end - 1)] == sol_map2[R][(begin + 1):end] +# Direct Implementation + +function sir_map!(u_diff, u, p, t) + (S, I, R) = u + (β, c, γ, δt) = p + N = S + I + R + infection = rate_to_proportion(β * c * I / N, δt) * S + recovery = rate_to_proportion(γ, δt) * I + @inbounds begin + u_diff[1] = S - infection + u_diff[2] = I + infection - recovery + u_diff[3] = R + recovery + end + nothing +end; +u0 = prob_map2.u0; +p = [0.05, 10.0, 0.25, 0.1]; +prob_map = DiscreteProblem(sir_map!, u0, tspan, p); +sol_map2 = solve(prob_map, FunctionMap()); + +@test Array(sol_map) ≈ Array(sol_map2) + +# Delayed difference equation +# @variables x(..) y(..) z(t) +# D1 = Difference(t; dt = 1.5) +# D2 = Difference(t; dt = 2) + +# @test ModelingToolkit.is_delay_var(Symbolics.value(t), Symbolics.value(x(t - 2))) +# @test ModelingToolkit.is_delay_var(Symbolics.value(t), Symbolics.value(y(t - 1))) +# @test !ModelingToolkit.is_delay_var(Symbolics.value(t), Symbolics.value(z)) +# @test_throws ErrorException ModelingToolkit.get_delay_val(Symbolics.value(t), +# Symbolics.arguments(Symbolics.value(x(t + +# 2)))[1]) +# @test_throws ErrorException z(t) + +# # Equations +# eqs = [ +# D1(x(t)) ~ 0.4x(t) + 0.3x(t - 1.5) + 0.1x(t - 3), +# D2(y(t)) ~ 0.3y(t) + 0.7y(t - 2) + 0.1z * h, +# ] + +# # System +# @named sys = DiscreteSystem(eqs, t, [x(t), x(t - 1.5), x(t - 3), y(t), y(t - 2), z], []) + +# eqs2, max_delay = ModelingToolkit.linearize_eqs(sys; return_max_delay = true) + +# @test max_delay[Symbolics.operation(Symbolics.value(x(t)))] ≈ 3 +# @test max_delay[Symbolics.operation(Symbolics.value(y(t)))] ≈ 2 + +# linearized_eqs = [eqs +# x(t - 3.0) ~ x(t - 1.5) +# x(t - 1.5) ~ x(t) +# y(t - 2.0) ~ y(t)] +# @test all(eqs2 .== linearized_eqs) + +# observed variable handling +@variables x(t) RHS(t) +@parameters τ +@named fol = DiscreteSystem( + [x ~ (1 - x(k - 1)) / τ], t, [x, RHS], [τ]; observed = [RHS ~ (1 - x) / τ * h]) +@test isequal(RHS, @nonamespace fol.RHS) +RHS2 = RHS +@unpack RHS = fol +@test isequal(RHS, RHS2) + +# @testset "Preface tests" begin +# using OrdinaryDiffEq +# using Symbolics +# using DiffEqBase: isinplace +# using ModelingToolkit +# using SymbolicUtils.Code +# using SymbolicUtils: Sym + +# c = [0] +# f = function f(c, d::Vector{Float64}, u::Vector{Float64}, p, t::Float64, dt::Float64) +# c .= [c[1] + 1] +# d .= randn(length(u)) +# nothing +# end + +# dummy_identity(x, _) = x +# @register_symbolic dummy_identity(x, y) + +# u0 = ones(5) +# p0 = Float64[] +# syms = [Symbol(:a, i) for i in 1:5] +# syms_p = Symbol[] +# dt = 0.1 +# @assert isinplace(f, 6) +# wf = let c = c, buffer = similar(u0), u = similar(u0), p = similar(p0), dt = dt +# t -> (f(c, buffer, u, p, t, dt); buffer) +# end + +# num = hash(f) ⊻ length(u0) ⊻ length(p0) +# buffername = Symbol(:fmi_buffer_, num) + +# Δ = DiscreteUpdate(t; dt = dt) +# us = map(s -> (@variables $s(t))[1], syms) +# ps = map(s -> (@variables $s(t))[1], syms_p) +# buffer, = @variables $buffername[1:length(u0)] +# dummy_var = Sym{Any}(:_) # this is safe because _ cannot be a rvalue in Julia + +# ss = Iterators.flatten((us, ps)) +# vv = Iterators.flatten((u0, p0)) +# defs = Dict{Any, Any}(s => v for (s, v) in zip(ss, vv)) + +# preface = [Assignment(dummy_var, SetArray(true, term(getfield, wf, Meta.quot(:u)), us)) +# Assignment(dummy_var, SetArray(true, term(getfield, wf, Meta.quot(:p)), ps)) +# Assignment(buffer, term(wf, t))] +# eqs = map(1:length(us)) do i +# Δ(us[i]) ~ dummy_identity(buffer[i], us[i]) +# end + +# @mtkbuild sys = DiscreteSystem(eqs, t, us, ps; defaults = defs, preface = preface) +# prob = DiscreteProblem(sys, [], (0.0, 1.0)) +# sol = solve(prob, FunctionMap(); dt = dt) +# @test c[1] + 1 == length(sol) +# end + +@variables x(t) y(t) +testdict = Dict([:test => 1]) +@named sys = DiscreteSystem([x(k + 1) ~ 1.0], t, [x], []; metadata = testdict) +@test get_metadata(sys) == testdict + +@variables x(t) y(t) u(t) +eqs = [u ~ 1 + x ~ x(k - 1) + u + y ~ x + u] +@mtkbuild de = DiscreteSystem(eqs, t) +prob = DiscreteProblem(de, [x(k - 1) => 0.0], (0, 10)) +sol = solve(prob, FunctionMap()) + +@test reduce(vcat, sol.u) == 1:11 + +# test that default values apply to the entire history +@variables x(t) = 1.0 +@mtkbuild de = DiscreteSystem([x ~ x(k - 1) + x(k - 2)], t) +prob = DiscreteProblem(de, [], (0, 10)) +@test prob[x] == 2.0 +@test prob[x(k - 1)] == 1.0 + +# must provide initial conditions for history +@test_throws ErrorException DiscreteProblem(de, [x => 2.0], (0, 10)) + +# initial values only affect _that timestep_, not the entire history +prob = DiscreteProblem(de, [x(k - 1) => 2.0], (0, 10)) +@test prob[x] == 3.0 +@test prob[x(k - 1)] == 2.0 diff --git a/test/parameter_dependencies.jl b/test/parameter_dependencies.jl index 78538d23c7..d55043383a 100644 --- a/test/parameter_dependencies.jl +++ b/test/parameter_dependencies.jl @@ -61,23 +61,23 @@ end u ~ Hold(ud) D(x) ~ -x + u y ~ x - z(k + 2) ~ z(k) + yd] + z(k) ~ z(k - 2) + yd(k - 2)] @mtkbuild sys = ODESystem(eqs, t; parameter_dependencies = [kq => 2kp]) Tf = 1.0 prob = ODEProblem(sys, [x => 0.0, y => 0.0], (0.0, Tf), - [kp => 1.0; z => 3.0; z(k + 1) => 2.0]) + [kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; yd(k - 2) => 2.0]) @test_nowarn solve(prob, Tsit5(); kwargshandle = KeywordArgSilent) @mtkbuild sys = ODESystem(eqs, t; parameter_dependencies = [kq => 2kp], discrete_events = [[0.5] => [kp ~ 2.0]]) prob = ODEProblem(sys, [x => 0.0, y => 0.0], (0.0, Tf), - [kp => 1.0; z => 3.0; z(k + 1) => 2.0]) + [kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; yd(k - 2) => 2.0]) @test prob.ps[kp] == 1.0 @test prob.ps[kq] == 2.0 @test_nowarn solve(prob, Tsit5(), kwargshandle = KeywordArgSilent) prob = ODEProblem(sys, [x => 0.0, y => 0.0], (0.0, Tf), - [kp => 1.0; z => 3.0; z(k + 1) => 2.0]) + [kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; yd(k - 2) => 2.0]) integ = init(prob, Tsit5(), kwargshandle = KeywordArgSilent) @test integ.ps[kp] == 1.0 @test integ.ps[kq] == 2.0 diff --git a/test/runtests.jl b/test/runtests.jl index c0b44a30de..80c3aa4309 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -68,6 +68,7 @@ end @safetestset "Parameter Dependency Test" include("parameter_dependencies.jl") @safetestset "Generate Custom Function Test" include("generate_custom_function.jl") @safetestset "Initial Values Test" include("initial_values.jl") + @safetestset "Discrete System" include("discrete_system.jl") end end