Skip to content

Commit

Permalink
up
Browse files Browse the repository at this point in the history
  • Loading branch information
TorkelE committed May 14, 2024
1 parent d023e54 commit 3bb4ee0
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 60 deletions.
1 change: 1 addition & 0 deletions src/Symbolics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ import MacroTools: splitdef, combinedef, postwalk, striplines
include("wrapper-types.jl")

include("num.jl")

include("rewrite-helpers.jl")
include("complex.jl")

Expand Down
51 changes: 23 additions & 28 deletions src/rewrite-helpers.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
replace(expr::Symbolic, rules...)
replacenode(expr::Symbolic, rules...)
Walk the expression and replace subexpressions according to `rules`. `rules`
could be rules constructed with `@rule`, a function, or a pair where the
left hand side is matched with equality (using `isequal`) and is replaced by the right hand side.
Expand All @@ -12,24 +12,24 @@ not the replacements.
Set `fixpoint = true` to repeatedly apply rules until no
change to the expression remains to be made.
"""
function Base.replace(expr::Num, r::Pair, rules::Pair...)
_replace(unwrap(expr), r, rules...)
function replacenode(expr::Num, r::Pair, rules::Pair...)
_replacenode(unwrap(expr), r, rules...)
end

# Fix ambiguity
function Base.replace(expr::Num, rules...)
_replace(unwrap(expr), rules...)
function replacenode(expr::Num, rules...)
_replacenode(unwrap(expr), rules...)
end

function Base.replace(expr::Symbolic, rules...)
_replace(unwrap(expr), rules...)
function replacenode(expr::Symbolic, rules...)
_replacenode(unwrap(expr), rules...)
end

function Base.replace(expr::Symbolic, r::Pair, rules::Pair...)
_replace(expr, r, rules...)
function replacenode(expr::Symbolic, r::Pair, rules::Pair...)
_replacenode(expr, r, rules...)
end

function _replace(expr::Symbolic, rules...; fixpoint=false)
function _replacenode(expr::Symbolic, rules...; fixpoint=false)
rs = map(r -> r isa Pair ? (x -> isequal(x, r[1]) ? r[2] : nothing) : r, rules)
R = Prewalk(Chain(rs))
if fixpoint
Expand All @@ -40,37 +40,31 @@ function _replace(expr::Symbolic, rules...; fixpoint=false)
end

"""
occursin(c, x)
hasnode(c, x)
Returns true if any part of `x` fufills the condition given in c. c can be a function or an expression.
If it is a function, returns true if x is true for any part of x. If c is an expression, returns
true if x contains c.
Examples:
```julia
@syms x y
Symbolics.occursin(x, log(x) + x + 1) # returns `true`.
Symbolics.occursin(x, log(y) + y + 1) # returns `false`.
hasnode(x, log(x) + x + 1) # returns `true`.
hasnode(x, log(y) + y + 1) # returns `false`.
```
```julia
@variables t X(t)
D = Differential(t)
Symbolics.occursin(Symbolics.is_derivative, X + D(X) + D(X^2)) # returns `true`.
hasnode(Symbolics.is_derivative, X + D(X) + D(X^2)) # returns `true`.
```
"""
function Base.occursin(r::Function, y::Num)
Symbolics._occursin(r, y)
function hasnode(r::Function, y::Union{Num, Symbolic})
_hasnode(r, y)
end
# Initially both these were created using `y::Union{Num, Symbolic}`. However, this produced
# ambiguity error due to something in SymbolicsBase. Hence the dual declarations here.
function Base.occursin(r::Function, y::Symbolics.Symbolic)
Symbolics._occursin(r, y)
end

Base.occursin(r::Num, y::Num) = occursin(unwrap(r), unwrap(y))
Base.occursin(r::Num, y::Symbolic) = occursin(unwrap(r), unwrap(y))
hasnode(r::Num, y::Union{Num, Symbolic}) = occursin(unwrap(r), unwrap(y))
hasnode(r::Symbolic, y::Union{Num, Symbolic}) = occursin(unwrap(r), unwrap(y))

function _occursin(r, y)
function _hasnode(r, y)
y = unwrap(y)
if r isa Function
if r(y)
Expand All @@ -80,7 +74,7 @@ function _occursin(r, y)

if istree(y)
return r(operation(y)) ||
any(y->_occursin(r, y), arguments(y))
any(y->_hasnode(r, y), arguments(y))
else
return false
end
Expand Down Expand Up @@ -134,6 +128,7 @@ function filterchildren!(r::Any, y, acc)
end

module RewriteHelpers
import Symbolics: filterchildren, unwrap
export replace, occursin, filterchildren, unwrap
import Symbolics: replacenode, hasnode, filterchildren, unwrap
export replacenode, hasnode, filterchildren, unwrap

end
64 changes: 32 additions & 32 deletions test/rewrite_helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ using Test
D = Differential(t)
my_f(x, y) = x^3 + 2y

# Check replace function.
# Check replacenode function.
let
@test isequal(replace(X + X + X, X =>1), 3)
@test isequal(replace(X + X + X, Y => 1), 3X)
@test isequal(replace(X + X + X, X => Y), 3Y)
@test isequal(replace(X + Y^2 - Z, Y^2 => Z), X)
@test isequal(replacenode(X + X + X, X =>1), 3)
@test isequal(replacenode(X + X + X, Y => 1), 3X)
@test isequal(replacenode(X + X + X, X => Y), 3Y)
@test isequal(replacenode(X + Y^2 - Z, Y^2 => Z), X)
end

# Test occursin function.
# Test hasnode function.
let
ex1 = 2X^a - log(b + my_f(Y,Y)) - 3
ex2 = X^(Y^(Z-a)) +log(log(log(b)))
Expand All @@ -26,36 +26,36 @@ let
ex5 = a + 5b^2

# Test for variables.
@test occursin(X, ex1)
@test occursin(X, ex2)
@test occursin(X, ex3)
@test !occursin(X, ex4)
@test occursin(Y, ex1)
@test occursin(Y, ex2)
@test occursin(Y, ex3)
@test occursin(Y, ex4)
@test !occursin(Z, ex1)
@test occursin(Z, ex2)
@test !occursin(Z, ex3)
@test occursin(Z, ex4)
@test hasnode(X, ex1)
@test hasnode(X, ex2)
@test hasnode(X, ex3)
@test !hasnode(X, ex4)
@test hasnode(Y, ex1)
@test hasnode(Y, ex2)
@test hasnode(Y, ex3)
@test hasnode(Y, ex4)
@test !hasnode(Z, ex1)
@test hasnode(Z, ex2)
@test !hasnode(Z, ex3)
@test hasnode(Z, ex4)

# Test for variables.
@test_broken occursin(a, ex1)
@test_broken occursin(a, ex2)
@test_broken occursin(a, ex3)
@test_broken occursin(a, ex4)
@test occursin(a, ex5)
@test_broken occursin(b, ex1)
@test_broken occursin(b, ex2)
@test !occursin(b, ex3)
@test !occursin(b, ex4)
@test occursin(b, ex5)
@test hasnode(a, ex1)
@test hasnode(a, ex2)
@test hasnode(a, ex3)
@test hasnode(a, ex4)
@test hasnode(a, ex5)
@test hasnode(b, ex1)
@test hasnode(b, ex2)
@test !hasnode(b, ex3)
@test !hasnode(b, ex4)
@test hasnode(b, ex5)

# Test for function.
@test !occursin(is_derivative, ex1)
@test !occursin(is_derivative, ex2)
@test !occursin(is_derivative, ex3)
@test occursin(is_derivative, ex4)
@test !hasnode(is_derivative, ex1)
@test !hasnode(is_derivative, ex2)
@test !hasnode(is_derivative, ex3)
@test hasnode(is_derivative, ex4)
end

# Check filterchildren function.
Expand Down

0 comments on commit 3bb4ee0

Please sign in to comment.