From 078928d95db8bde78be04c6ca5dba07717bcee88 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 12 Mar 2024 11:15:59 +0530 Subject: [PATCH] feat: move fast_substitute to Symbolics, implement SII.symbolic_evaluate --- Project.toml | 2 +- src/variable.jl | 52 +++++++++++++++++++++++ test/symbolic_indexing_interface_trait.jl | 8 ++++ 3 files changed, 61 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index d731f1d98..0a0b45cdd 100644 --- a/Project.toml +++ b/Project.toml @@ -80,7 +80,7 @@ SciMLBase = "2" Setfield = "1" SpecialFunctions = "2" StaticArrays = "1.1" -SymbolicIndexingInterface = "0.3" +SymbolicIndexingInterface = "0.3.11" SymbolicUtils = "1.4" julia = "1.10" diff --git a/src/variable.jl b/src/variable.jl index aba9ba502..b4a3285b9 100644 --- a/src/variable.jl +++ b/src/variable.jl @@ -410,6 +410,58 @@ end SymbolicIndexingInterface.getname(x, val=_fail) = _getname(unwrap(x), val) +function SymbolicIndexingInterface.symbolic_evaluate(ex::Union{Num, Arr, Symbolic}, d::Dict) + fixpoint_sub(ex, d) +end + +function fixpoint_sub(x, dict) + y = fast_substitute(x, dict) + while !isequal(x, y) + y = x + x = fast_substitute(y, dict) + end + + return x +end + +const Eq = Union{Equation, Inequality} +# substitute without unwrapping +function fast_substitute(eq::Eq, subs) + if eq isa Inequality + Inequality(fast_substitute(eq.lhs, subs), fast_substitute(eq.rhs, subs), + eq.relational_op) + else + Equation(fast_substitute(eq.lhs, subs), fast_substitute(eq.rhs, subs)) + end +end +function fast_substitute(eq::T, subs::Pair) where {T <: Eq} + T(fast_substitute(eq.lhs, subs), fast_substitute(eq.rhs, subs)) +end +fast_substitute(eqs::AbstractArray, subs) = fast_substitute.(eqs, (subs,)) +fast_substitute(a, b) = substitute(a, b) +function fast_substitute(expr, pair::Pair) + a, b = pair + isequal(expr, a) && return b + + istree(expr) || return expr + op = fast_substitute(operation(expr), pair) + canfold = Ref(!(op isa Symbolic)) + args = let canfold = canfold + map(SymbolicUtils.unsorted_arguments(expr)) do x + x′ = fast_substitute(x, pair) + canfold[] = canfold[] && !(x′ isa Symbolic) + x′ + end + end + canfold[] && return op(args...) + + similarterm(expr, + op, + args, + symtype(expr); + metadata = metadata(expr)) +end + function getparent(x, val=_fail) maybe_parent = getmetadata(x, Symbolics.GetindexParent, nothing) if maybe_parent !== nothing diff --git a/test/symbolic_indexing_interface_trait.jl b/test/symbolic_indexing_interface_trait.jl index 52d1579ae..6c8af1457 100644 --- a/test/symbolic_indexing_interface_trait.jl +++ b/test/symbolic_indexing_interface_trait.jl @@ -10,3 +10,11 @@ using SymbolicIndexingInterface @variables y[1:3] @test symbolic_type(y) == ArraySymbolic() @test all(symbolic_type.(collect(y)) .== (ScalarSymbolic(),)) + +@variables x y z +subs = Dict(x => 0.1, y => 2z) +subs2 = merge(subs, Dict(z => 2x+3)) + +@test symbolic_evaluate(x, subs) == 0.1 +@test isequal(symbolic_evaluate(y, subs), 2z) +@test symbolic_evaluate(y, subs2) == 6.4