Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A fix for occursin #1139

Merged
merged 7 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/Symbolics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ import MacroTools: splitdef, combinedef, postwalk, striplines
include("wrapper-types.jl")

include("num.jl")

include("rewrite-helpers.jl")
include("complex.jl")

Expand Down
58 changes: 26 additions & 32 deletions src/rewrite-helpers.jl
Original file line number Diff line number Diff line change
@@ -1,35 +1,28 @@
"""
replace(expr::Symbolic, rules...)
Walk the expression and replace subexpressions according to `rules`. `rules`
replacenode(expr::Symbolic, rules...)
Walk the expression and replacenode subexpressions according to `rules`. `rules`
could be rules constructed with `@rule`, a function, or a pair where the
left hand side is matched with equality (using `isequal`) and is replaced by the right hand side.
left hand side is matched with equality (using `isequal`) and is replacenoded by the right hand side.

Rules will be applied left-to-right simultaneously,
so only one pattern will be applied to any subexpression,
and the patterns will only be applied to the input text,
not the replacements.
not the replacenodements.

Set `fixpoint = true` to repeatedly apply rules until no
change to the expression remains to be made.
"""
function Base.replace(expr::Num, r::Pair, rules::Pair...)
_replace(unwrap(expr), r, rules...)
function replacenode(expr::Num, r::Pair, rules::Pair...; fixpoint = false)
_replacenode(unwrap(expr), r, rules...)
end

# Fix ambiguity
function Base.replace(expr::Num, rules...)
_replace(unwrap(expr), rules...)
end

function Base.replace(expr::Symbolic, rules...)
_replace(unwrap(expr), rules...)
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought replace was fine though right? The occursin was the one which we couldn't define the right methods for? Either way, I'm fine with this change, but it would be good to add a deprection method for replace, I believe David Sanders uses it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, my recollection from our discussion was that the function we are writing here is different from the normal replace and we should rename it to replacenode (it was a while ago, so do not remember exactly). Personally, I have no strong opinion.

In an unrelated issue (and more serious), the current replace is also broken (I think after the latest term interface update): #1153

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that applies to occursin, not so much to replace.

Look for the next release that just got merged for the term interface fixes

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should I update this PR to use replace instead of replacenode?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixpoint keyword arg is not in these top level methods, it should be.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll update that as well?

replacenode(expr::Num, rules...; fixpoint = false) = _replacenode(unwrap(expr), rules...; fixpoint)
replacenode(expr::Symbolic, rules...; fixpoint = false) = _replacenode(unwrap(expr), rules...; fixpoint)
replacenode(expr::Symbolic, r::Pair, rules::Pair...; fixpoint = false) = _replacenode(expr, r, rules...; fixpoint)
replacenode(expr::Number, rules...; fixpoint = false) = expr
replacenode(expr::Number, r::Pair, rules::Pair...; fixpoint = false) = expr

function Base.replace(expr::Symbolic, r::Pair, rules::Pair...)
_replace(expr, r, rules...)
end

function _replace(expr::Symbolic, rules...; fixpoint=false)
function _replacenode(expr::Symbolic, rules...; fixpoint = false)
rs = map(r -> r isa Pair ? (x -> isequal(x, r[1]) ? r[2] : nothing) : r, rules)
R = Prewalk(Chain(rs))
if fixpoint
Expand All @@ -40,32 +33,32 @@ function _replace(expr::Symbolic, rules...; fixpoint=false)
end

"""
occursin(c, x)
hasnode(c, x)
Returns true if any part of `x` fufills the condition given in c. c can be a function or an expression.
If it is a function, returns true if x is true for any part of x. If c is an expression, returns
true if x contains c.

Examples:
```julia
@syms x y
Symbolics.occursin(x, log(x) + x + 1) # returns `true`.
Symbolics.occursin(x, log(y) + y + 1) # returns `false`.
hasnode(x, log(x) + x + 1) # returns `true`.
hasnode(x, log(y) + y + 1) # returns `false`.
```

```julia
@variables t X(t)
D = Differential(t)
Symbolics.occursin(Symbolics.is_derivative, X + D(X) + D(X^2)) # returns `true`.
hasnode(Symbolics.is_derivative, X + D(X) + D(X^2)) # returns `true`.
```
"""
function Base.occursin(r::Function, y::Union{Num, Symbolic})
_occursin(r, y)
function hasnode(r::Function, y::Union{Num, Symbolic})
_hasnode(r, y)
end
hasnode(r::Num, y::Union{Num, Symbolic}) = occursin(unwrap(r), unwrap(y))
hasnode(r::Symbolic, y::Union{Num, Symbolic}) = occursin(unwrap(r), unwrap(y))
hasnode(r::Union{Num, Symbolic, Function}, y::Number) = false

Base.occursin(r::Num, y::Num) = occursin(unwrap(r), unwrap(y))
Base.occursin(r::Num, y::Symbolic) = occursin(unwrap(r), unwrap(y))

function _occursin(r, y)
function _hasnode(r, y)
y = unwrap(y)
if r isa Function
if r(y)
Expand All @@ -75,7 +68,7 @@ function _occursin(r, y)

if iscall(y)
return r(operation(y)) ||
any(y->_occursin(r, y), arguments(y))
any(y->_hasnode(r, y), arguments(y))
else
return false
end
Expand Down Expand Up @@ -129,6 +122,7 @@ function filterchildren!(r::Any, y, acc)
end

module RewriteHelpers
import Symbolics: filterchildren, unwrap
export replace, occursin, filterchildren, unwrap
import Symbolics: replacenode, hasnode, filterchildren, unwrap
export replacenode, hasnode, filterchildren, unwrap

end
85 changes: 50 additions & 35 deletions test/rewrite_helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,26 @@ using Test
D = Differential(t)
my_f(x, y) = x^3 + 2y

# Check replace function.
# Check `replacenode` function.
let
@test isequal(replace(X + X + X, X =>1), 3)
@test isequal(replace(X + X + X, Y => 1), 3X)
@test isequal(replace(X + X + X, X => Y), 3Y)
@test isequal(replace(X + Y^2 - Z, Y^2 => Z), X)
# Simple replacements.
@test isequal(replacenode(X + X + X, X =>1), 3)
@test isequal(replacenode(X + X + X, Y => 1), 3X)
@test isequal(replacenode(X + X + my_f(X, Z), X => Y), Y^3 + 2Y + 2Z)
@test isequal(replacenode(X + Y^2 - Z, Y^2 => Z), X)

# When the rule is a function.
rep_func(expr) = Symbolics.is_derivative(expr) ? b : expr
@test isequal(replacenode(D(X + Y) - log(a*Z), rep_func), b - log(a*Z))
@test isequal(replacenode(D(Z^2) + my_f(D(X), D(Y)) + Z, rep_func), b^3 + 3b + Z)
@test isequal(replacenode(X + sin(Y + a) + a, rep_func), X + sin(Y + a) + a)

# On non-symbolic inputs.
@test isequal(replacenode(1, X =>2.0), 1)
@test isequal(replacenode(1, rep_func), 1)
end

# Test occursin function.
# Test `hasnode` function.
let
ex1 = 2X^a - log(b + my_f(Y,Y)) - 3
ex2 = X^(Y^(Z-a)) +log(log(log(b)))
Expand All @@ -26,39 +37,44 @@ let
ex5 = a + 5b^2

# Test for variables.
@test occursin(X, ex1)
@test occursin(X, ex2)
@test occursin(X, ex3)
@test !occursin(X, ex4)
@test occursin(Y, ex1)
@test occursin(Y, ex2)
@test occursin(Y, ex3)
@test occursin(Y, ex4)
@test !occursin(Z, ex1)
@test occursin(Z, ex2)
@test !occursin(Z, ex3)
@test occursin(Z, ex4)
@test hasnode(X, ex1)
@test hasnode(X, ex2)
@test hasnode(X, ex3)
@test !hasnode(X, ex4)
@test hasnode(Y, ex1)
@test hasnode(Y, ex2)
@test hasnode(Y, ex3)
@test hasnode(Y, ex4)
@test !hasnode(Z, ex1)
@test hasnode(Z, ex2)
@test !hasnode(Z, ex3)
@test hasnode(Z, ex4)

# Test for variables.
@test_broken occursin(a, ex1)
@test_broken occursin(a, ex2)
@test_broken occursin(a, ex3)
@test_broken occursin(a, ex4)
@test occursin(a, ex5)
@test_broken occursin(b, ex1)
@test_broken occursin(b, ex2)
@test !occursin(b, ex3)
@test !occursin(b, ex4)
@test occursin(b, ex5)
@test hasnode(a, ex1)
@test hasnode(a, ex2)
@test hasnode(a, ex3)
@test hasnode(a, ex4)
@test hasnode(a, ex5)
@test hasnode(b, ex1)
@test hasnode(b, ex2)
@test !hasnode(b, ex3)
@test !hasnode(b, ex4)
@test hasnode(b, ex5)

# Test for function.
@test !occursin(is_derivative, ex1)
@test !occursin(is_derivative, ex2)
@test !occursin(is_derivative, ex3)
@test occursin(is_derivative, ex4)
@test !hasnode(is_derivative, ex1)
@test !hasnode(is_derivative, ex2)
@test !hasnode(is_derivative, ex3)
@test hasnode(is_derivative, ex4)

# On non symbolic inputs:
@test !hasnode(X, 1)
@test !hasnode(a, 1)
@test !hasnode(is_derivative, 1)
end

# Check filterchildren function.
# Check `filterchildren` function.
let
ex1 = 2X^a - log(b + my_f(Y,Y)) - 3
ex2 = X^(Y^(Z-a)) +log(log(log(b)))
Expand All @@ -80,8 +96,7 @@ let
@test isequal(filterchildren(Z, ex3), [])
@test isequal(filterchildren(Z, ex4), [Z])

# Test for variables.

# Test for syms.
@test isequal(filterchildren(a, ex1), [a])
@test isequal(filterchildren(a, ex2), [a])
@test isequal(filterchildren(a, ex3), [a])
Expand Down
Loading