diff --git a/src/Symbolics.jl b/src/Symbolics.jl index 82062a688..6338fddb4 100644 --- a/src/Symbolics.jl +++ b/src/Symbolics.jl @@ -50,6 +50,7 @@ import MacroTools: splitdef, combinedef, postwalk, striplines include("wrapper-types.jl") include("num.jl") +include("rewrite-helpers.jl") include("complex.jl") """ diff --git a/src/rewrite-helpers.jl b/src/rewrite-helpers.jl new file mode 100644 index 000000000..40229a1e0 --- /dev/null +++ b/src/rewrite-helpers.jl @@ -0,0 +1,85 @@ +""" + replace(expr::Symbolic, rules...) +Walk the expression and replace subexpressions according to `rules`. `rules` +could be rules constructed with `@rule`, a function, or a pair where the +left hand side is matched with equality (using `isequal`) and is replaced by the right hand side. + +Rules will be applied left-to-right simultaneously, +so only one pattern will be applied to any subexpression, +and the patterns will only be applied to the input text, +not the replacements. + +Set `fixpoint = true` to repeatedly apply rules until no +change to the expression remains to be made. +""" +function _replace(expr::Symbolic, rules...; fixpoint=false) + rs = map(r -> r isa Pair ? (x -> isequal(x, r[1]) ? r[2] : nothing) : r, rules) + R = Prewalk(Chain(rs)) + if fixpoint + Fixpoint(R)(expr) + else + R(expr) + end +end +# Fix ambiguity +function Base.replace(expr::Num, r::Pair, rules::Pair...) + _replace(unwrap(expr), r, rules...) +end + +function Base.replace(expr::Num, rules...) + _replace(unwrap(expr), rules...) +end + +function Base.replace(expr::Symbolic, r, rules...) + _replace(expr, r, rules) +end + + +Base.occursin(x::Num, y::Num) = occursin(unwrap(x), unwrap(y)) +@wrapped function Base.occursin(r::Any, y::Real) + y = unwrap(y) + if isequal(r, y) + return true + elseif r isa Function + if r(y) + return true + end + end + + if istree(y) + return r(operation(y)) || + any(y->occursin(r, y), arguments(y)) + else + return false + end +end + +function filterchildren!(r::Any, y::Union{Num, Symbolic}, acc) + y = unwrap(y) + if isequal(r, y) + push!(acc, y) + return acc + elseif r isa Function + if r(y) + push!(acc, y) + return acc + end + end + + if istree(y) + if r(operation(y)) + push!(acc, y) + end + foreach(c->filterchildren!(r, c, acc), + arguments(y)) + return acc + end +end + +filterchildren(r, y) = filterchildren!(r, y, []) + +module RewriteHelpers +import Symbolics: is_derivative, filterchildren, unwrap +export replace, occursin, is_derivative, + filterchildren, unwrap +end