From 3bb4ee098796da46d0ff671831acb1e243c74232 Mon Sep 17 00:00:00 2001 From: Torkel Date: Tue, 14 May 2024 15:01:58 -0400 Subject: [PATCH] up --- src/Symbolics.jl | 1 + src/rewrite-helpers.jl | 51 +++++++++++++++----------------- test/rewrite_helpers.jl | 64 ++++++++++++++++++++--------------------- 3 files changed, 56 insertions(+), 60 deletions(-) diff --git a/src/Symbolics.jl b/src/Symbolics.jl index b48673c71..8fec4628d 100644 --- a/src/Symbolics.jl +++ b/src/Symbolics.jl @@ -52,6 +52,7 @@ import MacroTools: splitdef, combinedef, postwalk, striplines include("wrapper-types.jl") include("num.jl") + include("rewrite-helpers.jl") include("complex.jl") diff --git a/src/rewrite-helpers.jl b/src/rewrite-helpers.jl index 4ac62e11e..a9ad3fe24 100644 --- a/src/rewrite-helpers.jl +++ b/src/rewrite-helpers.jl @@ -1,5 +1,5 @@ """ - replace(expr::Symbolic, rules...) +replacenode(expr::Symbolic, rules...) Walk the expression and replace subexpressions according to `rules`. `rules` could be rules constructed with `@rule`, a function, or a pair where the left hand side is matched with equality (using `isequal`) and is replaced by the right hand side. @@ -12,24 +12,24 @@ not the replacements. Set `fixpoint = true` to repeatedly apply rules until no change to the expression remains to be made. """ -function Base.replace(expr::Num, r::Pair, rules::Pair...) - _replace(unwrap(expr), r, rules...) +function replacenode(expr::Num, r::Pair, rules::Pair...) + _replacenode(unwrap(expr), r, rules...) end # Fix ambiguity -function Base.replace(expr::Num, rules...) - _replace(unwrap(expr), rules...) +function replacenode(expr::Num, rules...) + _replacenode(unwrap(expr), rules...) end -function Base.replace(expr::Symbolic, rules...) - _replace(unwrap(expr), rules...) +function replacenode(expr::Symbolic, rules...) + _replacenode(unwrap(expr), rules...) end -function Base.replace(expr::Symbolic, r::Pair, rules::Pair...) - _replace(expr, r, rules...) +function replacenode(expr::Symbolic, r::Pair, rules::Pair...) + _replacenode(expr, r, rules...) end -function _replace(expr::Symbolic, rules...; fixpoint=false) +function _replacenode(expr::Symbolic, rules...; fixpoint=false) rs = map(r -> r isa Pair ? (x -> isequal(x, r[1]) ? r[2] : nothing) : r, rules) R = Prewalk(Chain(rs)) if fixpoint @@ -40,7 +40,7 @@ function _replace(expr::Symbolic, rules...; fixpoint=false) end """ - occursin(c, x) + hasnode(c, x) Returns true if any part of `x` fufills the condition given in c. c can be a function or an expression. If it is a function, returns true if x is true for any part of x. If c is an expression, returns true if x contains c. @@ -48,29 +48,23 @@ true if x contains c. Examples: ```julia @syms x y -Symbolics.occursin(x, log(x) + x + 1) # returns `true`. -Symbolics.occursin(x, log(y) + y + 1) # returns `false`. +hasnode(x, log(x) + x + 1) # returns `true`. +hasnode(x, log(y) + y + 1) # returns `false`. ``` ```julia @variables t X(t) D = Differential(t) -Symbolics.occursin(Symbolics.is_derivative, X + D(X) + D(X^2)) # returns `true`. +hasnode(Symbolics.is_derivative, X + D(X) + D(X^2)) # returns `true`. ``` """ -function Base.occursin(r::Function, y::Num) - Symbolics._occursin(r, y) +function hasnode(r::Function, y::Union{Num, Symbolic}) + _hasnode(r, y) end -# Initially both these were created using `y::Union{Num, Symbolic}`. However, this produced -# ambiguity error due to something in SymbolicsBase. Hence the dual declarations here. -function Base.occursin(r::Function, y::Symbolics.Symbolic) - Symbolics._occursin(r, y) -end - -Base.occursin(r::Num, y::Num) = occursin(unwrap(r), unwrap(y)) -Base.occursin(r::Num, y::Symbolic) = occursin(unwrap(r), unwrap(y)) +hasnode(r::Num, y::Union{Num, Symbolic}) = occursin(unwrap(r), unwrap(y)) +hasnode(r::Symbolic, y::Union{Num, Symbolic}) = occursin(unwrap(r), unwrap(y)) -function _occursin(r, y) +function _hasnode(r, y) y = unwrap(y) if r isa Function if r(y) @@ -80,7 +74,7 @@ function _occursin(r, y) if istree(y) return r(operation(y)) || - any(y->_occursin(r, y), arguments(y)) + any(y->_hasnode(r, y), arguments(y)) else return false end @@ -134,6 +128,7 @@ function filterchildren!(r::Any, y, acc) end module RewriteHelpers -import Symbolics: filterchildren, unwrap -export replace, occursin, filterchildren, unwrap +import Symbolics: replacenode, hasnode, filterchildren, unwrap +export replacenode, hasnode, filterchildren, unwrap + end diff --git a/test/rewrite_helpers.jl b/test/rewrite_helpers.jl index d74abe8e5..1c96fb4fa 100644 --- a/test/rewrite_helpers.jl +++ b/test/rewrite_helpers.jl @@ -9,15 +9,15 @@ using Test D = Differential(t) my_f(x, y) = x^3 + 2y -# Check replace function. +# Check replacenode function. let - @test isequal(replace(X + X + X, X =>1), 3) - @test isequal(replace(X + X + X, Y => 1), 3X) - @test isequal(replace(X + X + X, X => Y), 3Y) - @test isequal(replace(X + Y^2 - Z, Y^2 => Z), X) + @test isequal(replacenode(X + X + X, X =>1), 3) + @test isequal(replacenode(X + X + X, Y => 1), 3X) + @test isequal(replacenode(X + X + X, X => Y), 3Y) + @test isequal(replacenode(X + Y^2 - Z, Y^2 => Z), X) end -# Test occursin function. +# Test hasnode function. let ex1 = 2X^a - log(b + my_f(Y,Y)) - 3 ex2 = X^(Y^(Z-a)) +log(log(log(b))) @@ -26,36 +26,36 @@ let ex5 = a + 5b^2 # Test for variables. - @test occursin(X, ex1) - @test occursin(X, ex2) - @test occursin(X, ex3) - @test !occursin(X, ex4) - @test occursin(Y, ex1) - @test occursin(Y, ex2) - @test occursin(Y, ex3) - @test occursin(Y, ex4) - @test !occursin(Z, ex1) - @test occursin(Z, ex2) - @test !occursin(Z, ex3) - @test occursin(Z, ex4) + @test hasnode(X, ex1) + @test hasnode(X, ex2) + @test hasnode(X, ex3) + @test !hasnode(X, ex4) + @test hasnode(Y, ex1) + @test hasnode(Y, ex2) + @test hasnode(Y, ex3) + @test hasnode(Y, ex4) + @test !hasnode(Z, ex1) + @test hasnode(Z, ex2) + @test !hasnode(Z, ex3) + @test hasnode(Z, ex4) # Test for variables. - @test_broken occursin(a, ex1) - @test_broken occursin(a, ex2) - @test_broken occursin(a, ex3) - @test_broken occursin(a, ex4) - @test occursin(a, ex5) - @test_broken occursin(b, ex1) - @test_broken occursin(b, ex2) - @test !occursin(b, ex3) - @test !occursin(b, ex4) - @test occursin(b, ex5) + @test hasnode(a, ex1) + @test hasnode(a, ex2) + @test hasnode(a, ex3) + @test hasnode(a, ex4) + @test hasnode(a, ex5) + @test hasnode(b, ex1) + @test hasnode(b, ex2) + @test !hasnode(b, ex3) + @test !hasnode(b, ex4) + @test hasnode(b, ex5) # Test for function. - @test !occursin(is_derivative, ex1) - @test !occursin(is_derivative, ex2) - @test !occursin(is_derivative, ex3) - @test occursin(is_derivative, ex4) + @test !hasnode(is_derivative, ex1) + @test !hasnode(is_derivative, ex2) + @test !hasnode(is_derivative, ex3) + @test hasnode(is_derivative, ex4) end # Check filterchildren function.