diff --git a/Project.toml b/Project.toml index 0381235df..1ba267860 100644 --- a/Project.toml +++ b/Project.toml @@ -15,6 +15,7 @@ DynamicPolynomials = "7c1d4256-1411-5781-91ec-d7bc3513ac07" IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Metatheory = "e9d8d322-4543-424a-9be4-0cc815abe26c" MultivariatePolynomials = "102ac46a-7ee4-5c85-9060-abc95bfdeaa3" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" @@ -43,7 +44,7 @@ Setfield = "0.7, 0.8, 1" SpecialFunctions = "0.10, 1.0, 2" StaticArrays = "0.12, 1.0" SymbolicIndexingInterface = "0.3" -TermInterface = "0.4" +TermInterface = "2.0" TimerOutputs = "0.5" Unityper = "0.1.2" julia = "1.3" diff --git a/docs/src/manual/interface.md b/docs/src/manual/interface.md index 899aff658..98518fb52 100644 --- a/docs/src/manual/interface.md +++ b/docs/src/manual/interface.md @@ -13,6 +13,20 @@ You can read the documentation of [TermInterface.jl](https://github.com/JuliaSym ## SymbolicUtils.jl only methods -`promote_symtype(f, arg_symtypes...)` +### `symtype(x)` + +Returns the +[numeric type](https://docs.julialang.org/en/v1/base/numbers/#Standard-Numeric-Types) +of `x`. By default this is just `typeof(x)`. +Define this for your symbolic types if you want [`SymbolicUtils.simplify`](@ref) to apply rules +specific to numbers (such as commutativity of multiplication). Or such +rules that may be implemented in the future. + +### `issym(x)` + +Returns `true` if `x` is a `Sym`. If `true`, `nameof` must be defined +on `x` and must return a `Symbol`. + +### `promote_symtype(f, arg_symtypes...)` Returns the appropriate output type of applying `f` on arguments of type `arg_symtypes`. diff --git a/docs/src/manual/representation.md b/docs/src/manual/representation.md index 997d33f3a..fea21bf1b 100644 --- a/docs/src/manual/representation.md +++ b/docs/src/manual/representation.md @@ -4,7 +4,7 @@ Performance of symbolic simplification depends on the datastructures used to rep The most basic term representation simply holds a function call and stores the function and the arguments it is called with. This is done by the `Term` type in SymbolicUtils. Functions that aren't commutative or associative, such as `sin` or `hypot` are stored as `Term`s. Commutative and associative operations like `+`, `*`, and their supporting operations like `-`, `/` and `^`, when used on terms of type `<:Number`, stand to gain from the use of more efficient datastrucutres. -All term representations must support `operation` and `arguments` functions. And they must define `istree` to return `true` when called with an instance of the type. Generic term-manipulation programs such as the rule-based rewriter make use of this interface to inspect expressions. In this way, the interface wins back the generality lost by having a zoo of term representations instead of one. (see [interface](/interface/) section for more on this.) +All term representations must support `operation` and `arguments` functions. And they must define `iscall` and `isexpr` to return `true` when called with an instance of the type. Generic term-manipulation programs such as the rule-based rewriter make use of this interface to inspect expressions. In this way, the interface wins back the generality lost by having a zoo of term representations instead of one. (see [interface](/interface/) section for more on this.) ### Preliminary representation of arithmetic diff --git a/src/SymbolicUtils.jl b/src/SymbolicUtils.jl index 619c37118..ef9b81f33 100644 --- a/src/SymbolicUtils.jl +++ b/src/SymbolicUtils.jl @@ -16,8 +16,8 @@ using SymbolicIndexingInterface import Base: +, -, *, /, //, \, ^, ImmutableDict using ConstructionBase using TermInterface -import TermInterface: iscall, isexpr, issym, symtype, head, children, - operation, arguments, metadata, maketerm +import TermInterface: iscall, isexpr, head, children, + operation, arguments, metadata, maketerm, sorted_arguments const istree = iscall Base.@deprecate_binding istree iscall diff --git a/src/code.jl b/src/code.jl index 6432bd1f5..c9a51ad19 100644 --- a/src/code.jl +++ b/src/code.jl @@ -9,7 +9,7 @@ export toexpr, Assignment, (←), Let, Func, DestructuredArgs, LiteralExpr, import ..SymbolicUtils import ..SymbolicUtils.Rewriters import SymbolicUtils: @matchable, BasicSymbolic, Sym, Term, iscall, operation, arguments, issym, - symtype, similarterm, sorted_arguments, metadata, isterm, term, maketerm + symtype, sorted_arguments, metadata, isterm, term, maketerm ##== state management ==## @@ -115,7 +115,7 @@ function function_to_expr(op, O, st) (get(st.rewrites, :nanmath, false) && op in NaNMathFuns) || return nothing name = nameof(op) fun = GlobalRef(NaNMath, name) - args = map(Base.Fix2(toexpr, st), arguments(O)) + args = map(Base.Fix2(toexpr, st), sorted_arguments(O)) expr = Expr(:call, fun) append!(expr.args, args) return expr @@ -138,7 +138,7 @@ function function_to_expr(op::Union{typeof(*),typeof(+)}, O, st) end function function_to_expr(::typeof(^), O, st) - args = arguments(O) + args = sorted_arguments(O) if length(args) == 2 && args[2] isa Real && args[2] < 0 ex = args[1] if args[2] == -1 @@ -151,7 +151,7 @@ function function_to_expr(::typeof(^), O, st) end function function_to_expr(::typeof(SymbolicUtils.ifelse), O, st) - args = arguments(O) + args = sorted_arguments(O) :($(toexpr(args[1], st)) ? $(toexpr(args[2], st)) : $(toexpr(args[3], st))) end @@ -183,7 +183,7 @@ function toexpr(O, st) return expr′ else !iscall(O) && return O - args = arguments(O) + args = sorted_arguments(O) return Expr(:call, toexpr(op, st), map(x->toexpr(x, st), args)...) end end @@ -693,8 +693,8 @@ end function _cse!(mem, expr) iscall(expr) || return expr op = _cse!(mem, operation(expr)) - args = map(Base.Fix1(_cse!, mem), arguments(expr)) - t = similarterm(expr, op, args) + args = map(Base.Fix1(_cse!, mem), sorted_arguments(expr)) + t = maketerm(typeof(expr), op, args, nothing) v, dict = mem update! = let v=v, t=t @@ -716,7 +716,7 @@ end function _cse(exprs::AbstractArray) letblock = cse(Term{Any}(tuple, vec(exprs))) - letblock.pairs, reshape(arguments(letblock.body), size(exprs)) + letblock.pairs, reshape(sorted_arguments(letblock.body), size(exprs)) end function cse(x::MakeArray) @@ -763,9 +763,7 @@ function cse_block!(assignments, counter, names, name, state, x) if isterm(x) return term(operation(x), args...) else - return maketerm(typeof(x), operation(x), - args, symtype(x), - metadata(x)) + return maketerm(typeof(x), operation(x), args, metadata(x)) end else return x diff --git a/src/interface.jl b/src/interface.jl deleted file mode 100644 index bea1d47ae..000000000 --- a/src/interface.jl +++ /dev/null @@ -1,84 +0,0 @@ -""" - iscall(x) - -Returns `true` if `x` is a term. If true, `operation`, `arguments` -must also be defined for `x` appropriately. -""" -iscall(x) = false - -""" - symtype(x) - -Returns the symbolic type of `x`. By default this is just `typeof(x)`. -Define this for your symbolic types if you want `SymbolicUtils.simplify` to apply rules -specific to numbers (such as commutativity of multiplication). Or such -rules that may be implemented in the future. -""" -function symtype(x) - typeof(x) -end - -""" - issym(x) - -Returns `true` if `x` is a symbol. If true, `nameof` must be defined -on `x` and must return a Symbol. -""" -issym(x) = false - -""" - operation(x) - -If `x` is a term as defined by `iscall(x)`, `operation(x)` returns the -head of the term if `x` represents a function call, for example, the head -is the function being called. -""" -function operation end - -""" - sorted_arguments(x) - -Get the arguments of `x`, must be defined if `iscall(x)` is `true`. -""" -function sorted_arguments end - -""" - sorted_arguments(x::T) - -If x is a term satisfying `iscall(x)` and your term type `T` provides -an optimized implementation for storing the arguments, this function can -be used to retrieve the arguments when the order of arguments does not matter -but the speed of the operation does. -""" -function arguments end -arity(x) = length(arguments(x)) - -""" - metadata(x) - -Return the metadata attached to `x`. -""" -metadata(x) = nothing - -""" - metadata(x, md) - -Returns a new term which has the structure of `x` but also has -the metadata `md` attached to it. -""" -function metadata(x, data) - error("Setting metadata on $x is not possible") -end - -""" - similarterm(x, head, args, symtype=nothing; metadata=nothing, exprhead=:call) - -Returns a term that is in the same closure of types as `typeof(x)`, -with `head` as the head and `args` as the arguments, `type` as the symtype -and `metadata` as the metadata. By default this will execute `head(args...)`. -`x` parameter can also be a `Type`. The `exprhead` keyword argument is useful -when manipulating `Expr`s. - -`similarterm` is deprecated see help for `maketerm` instead. -""" -function similarterm end diff --git a/src/matchers.jl b/src/matchers.jl index 7f4dea537..91a6c1990 100644 --- a/src/matchers.jl +++ b/src/matchers.jl @@ -85,7 +85,7 @@ function matcher(segment::Segment) end function term_matcher(term) - matchers = (matcher(operation(term)), map(matcher, arguments(term))...,) + matchers = (matcher(operation(term)), map(matcher, sorted_arguments(term))...,) function term_matcher(success, data, bindings) !islist(data) && return nothing diff --git a/src/ordering.jl b/src/ordering.jl index 332f11cf8..0d1a4ce79 100644 --- a/src/ordering.jl +++ b/src/ordering.jl @@ -84,7 +84,7 @@ function <ₑ(a::BasicSymbolic, b::BasicSymbolic) bw = monomial_lt(db, da) if fw === bw && !isequal(a, b) if _arglen(a) == _arglen(b) - return (operation(a), arguments(a)...,) <ₑ (operation(b), arguments(b)...,) + return (operation(a), sorted_arguments(a)...,) <ₑ (operation(b), sorted_arguments(b)...,) else return _arglen(a) < _arglen(b) end diff --git a/src/polyform.jl b/src/polyform.jl index ab8bddfae..9450da554 100644 --- a/src/polyform.jl +++ b/src/polyform.jl @@ -103,7 +103,7 @@ function polyize(x, pvar2sym, sym2term, vtype, pow, Fs, recurse) end op = operation(x) - args = arguments(x) + args = sorted_arguments(x) local_polyize(y) = polyize(y, pvar2sym, sym2term, vtype, pow, Fs, recurse) @@ -121,7 +121,6 @@ function polyize(x, pvar2sym, sym2term, vtype, pow, Fs, recurse) maketerm(typeof(x), op, map(a->PolyForm(a, pvar2sym, sym2term, vtype; Fs, recurse), args), - symtype(x), metadata(x)) else x @@ -176,18 +175,18 @@ isexpr(x::PolyForm) = true iscall(x::Type{<:PolyForm}) = true iscall(x::PolyForm) = true -function maketerm(::Type{<:PolyForm}, f, args, symtype, metadata) - basicsymbolic(t, f, args, symtype, metadata) +function maketerm(t::Type{<:PolyForm}, f, args, metadata) + # TODO: this looks uncovered. + basicsymbolic(f, args, nothing, metadata) end -function maketerm(::Type{<:PolyForm}, f::Union{typeof(*), typeof(+), typeof(^)}, - args, symtype, metadata) +function maketerm(::Type{<:PolyForm}, f::Union{typeof(*), typeof(+), typeof(^)}, args, metadata) f(args...) end head(::PolyForm) = PolyForm operation(x::PolyForm) = MP.nterms(x.p) == 1 ? (*) : (+) -function arguments(x::PolyForm{T}) where {T} +function TermInterface.arguments(x::PolyForm{T}) where {T} function is_var(v) MP.nterms(v) == 1 && @@ -231,10 +230,7 @@ function arguments(x::PolyForm{T}) where {T} PolyForm{T}(t, x.pvar2sym, x.sym2term, nothing)) for t in ts] end end - -sorted_arguments(x::PolyForm) = arguments(x) - -children(x::PolyForm) = [operation(x); arguments(x)] +children(x::PolyForm) = arguments(x) Base.show(io::IO, x::PolyForm) = show_term(io, x) @@ -255,7 +251,7 @@ function unpolyize(x) # we need a special makterm here because the default one used in Postwalk will call # promote_symtype to get the new type, but we just want to forward that in case # promote_symtype is not defined for some of the expressions here. - Postwalk(identity, maketerm=(T,f,args,sT,m) -> maketerm(T, f, args, symtype(x), m))(x) + Postwalk(identity, maketerm=(T,f,args,m) -> maketerm(T, f, args, m))(x) end function toterm(x::PolyForm) @@ -307,7 +303,8 @@ function add_divs(x, y) end end -function frac_maketerm(T, f, args, stype, metadata) +function frac_maketerm(T, f, args, metadata) + # TODO add stype to T? if f in (*, /, \, +, -) f(args...) elseif f == (^) @@ -317,7 +314,7 @@ function frac_maketerm(T, f, args, stype, metadata) args[1]^args[2] end else - maketerm(T, f, args, stype, metadata) + maketerm(T, f, args, metadata) end end @@ -394,7 +391,7 @@ function has_div(x) end flatten_pows(xs) = map(xs) do x - ispow(x) ? Iterators.repeated(arguments(x)...) : (x,) + ispow(x) ? Iterators.repeated(sorted_arguments(x)...) : (x,) end |> Iterators.flatten |> a->collect(Any,a) coefftype(x::PolyForm) = coefftype(x.p) diff --git a/src/rewriters.jl b/src/rewriters.jl index fe5d2bb04..f9c9b603a 100644 --- a/src/rewriters.jl +++ b/src/rewriters.jl @@ -167,11 +167,7 @@ end struct Walk{ord, C, F, threaded} rw::C thread_cutoff::Int - maketerm::F # XXX: for the 2.0 deprecation cycle, we actually store a function - # that behaves like `similarterm` here, we use `compatmaker` to wrap - # maketerm-like input to do this, with a warning if similarterm provided - # we need this workaround to deprecate because similarterm takes value - # but maketerm only knows the type. + maketerm::F end function instrument(x::Walk{ord, C,F,threaded}, f) where {ord,C,F,threaded} @@ -183,25 +179,13 @@ end using .Threads -function compatmaker(similarterm, maketerm) - # XXX: delete this and only use maketerm in a future release. - if similarterm isa Nothing - function (x, f, args, type=_promote_symtype(f, args); metadata) - maketerm(typeof(x), f, args, type, metadata) - end - else - Base.depwarn("Prewalk and Postwalk now take maketerm instead of similarterm keyword argument. similarterm(x, f, args, type; metadata) is now maketerm(typeof(x), f, args, type, metadata)", :similarterm) - similarterm - end -end -function Postwalk(rw; threaded::Bool=false, thread_cutoff=100, maketerm=maketerm, similarterm=nothing) - maker = compatmaker(similarterm, maketerm) - Walk{:post, typeof(rw), typeof(maker), threaded}(rw, thread_cutoff, maker) + +function Postwalk(rw; threaded::Bool=false, thread_cutoff=100, maketerm=maketerm) + Walk{:post, typeof(rw), typeof(maketerm), threaded}(rw, thread_cutoff, maketerm) end -function Prewalk(rw; threaded::Bool=false, thread_cutoff=100, maketerm=maketerm, similarterm=nothing) - maker = compatmaker(similarterm, maketerm) - Walk{:pre, typeof(rw), typeof(maker), threaded}(rw, thread_cutoff, maker) +function Prewalk(rw; threaded::Bool=false, thread_cutoff=100, maketerm=maketerm) + Walk{:pre, typeof(rw), typeof(maketerm), threaded}(rw, thread_cutoff, maketerm) end struct PassThrough{C} @@ -220,8 +204,8 @@ function (p::Walk{ord, C, F, false})(x) where {ord, C, F} end if iscall(x) - x = p.maketerm(x, operation(x), map(PassThrough(p), - arguments(x)), metadata=metadata(x)) + x = p.maketerm(typeof(x), operation(x), map(PassThrough(p), + arguments(x)), metadata(x)) end return ord === :post ? p.rw(x) : x @@ -237,15 +221,15 @@ function (p::Walk{ord, C, F, true})(x) where {ord, C, F} x = p.rw(x) end if iscall(x) - _args = map(arguments(x)) do arg + _args = map(sorted_arguments(x)) do arg if node_count(arg) > p.thread_cutoff Threads.@spawn p(arg) else p(arg) end end - args = map((t,a) -> passthrough(t isa Task ? fetch(t) : t, a), _args, arguments(x)) - t = p.maketerm(x, operation(x), args, metadata=metadata(x)) + args = map((t,a) -> passthrough(t isa Task ? fetch(t) : t, a), _args, sorted_arguments(x)) + t = p.maketerm(typeof(x), operation(x), args, metadata(x)) end return ord === :post ? p.rw(t) : t else diff --git a/src/rule.jl b/src/rule.jl index 13fe86c79..bf84b9cdc 100644 --- a/src/rule.jl +++ b/src/rule.jl @@ -121,7 +121,7 @@ getdepth(r::Rule) = r.depth function rule_depth(rule, d=0, maxdepth=0) if iscall(rule) - maxdepth = reduce(max, (rule_depth(r, d+1, maxdepth) for r in arguments(rule)), init=1) + maxdepth = reduce(max, (rule_depth(r, d+1, maxdepth) for r in sorted_arguments(rule)), init=1) elseif rule isa Slot || rule isa Segment maxdepth = max(d, maxdepth) end @@ -298,24 +298,27 @@ whether the predicate holds or not. _In the consequent pattern_: Use `(@ctx)` to access the context object on the right hand side of an expression. """ -macro rule(expr) - @assert expr.head == :call && expr.args[1] == :(=>) - lhs = expr.args[2] - rhs = rewrite_rhs(expr.args[3]) - keys = Symbol[] - lhs_term = makepattern(lhs, keys) - unique!(keys) - quote - $(__source__) - lhs_pattern = $(lhs_term) - Rule($(QuoteNode(expr)), - lhs_pattern, - matcher(lhs_pattern), - __MATCHES__ -> $(makeconsequent(rhs)), - rule_depth($lhs_term)) - end -end - +# macro rule(expr) +# @assert expr.head == :call && expr.args[1] == :(=>) +# lhs = expr.args[2] +# rhs = rewrite_rhs(expr.args[3]) +# keys = Symbol[] +# lhs_term = makepattern(lhs, keys) +# unique!(keys) +# quote +# $(__source__) +# lhs_pattern = $(lhs_term) +# Rule($(QuoteNode(expr)), +# lhs_pattern, +# matcher(lhs_pattern), +# __MATCHES__ -> $(makeconsequent(rhs)), +# rule_depth($lhs_term)) +# end +# end + +using Metatheory +using Metatheory: @rule +using TermInterface: isexpr """ @capture ex pattern @@ -394,7 +397,7 @@ function (acr::ACRule)(term) else f = operation(term) # Assume that the matcher was formed by closing over a term - if f != operation(r.lhs) # Maybe offer a fallback if m.term errors. + if f != operation(r.left) # Maybe offer a fallback if m.term errors. return nothing end @@ -408,7 +411,7 @@ function (acr::ACRule)(term) if result !== nothing # Assumption: inds are unique length(args) == length(inds) && return result - return maketerm(typeof(term), f, [result, (args[i] for i in eachindex(args) if i ∉ inds)...], symtype(term), metadata(term)) + return maketerm(typeof(term), f, [result, (args[i] for i in eachindex(args) if i ∉ inds)...], metadata(term)) end end end diff --git a/src/simplify.jl b/src/simplify.jl index 695e57c5a..68fe78f83 100644 --- a/src/simplify.jl +++ b/src/simplify.jl @@ -45,6 +45,6 @@ end has_operation(x, op) = (iscall(x) && (operation(x) == op || any(a->has_operation(a, op), - arguments(x)))) + arguments(x)))) Base.@deprecate simplify(x, ctx; kwargs...) simplify(x; rewriter=ctx, kwargs...) diff --git a/src/simplify_rules.jl b/src/simplify_rules.jl index a612036cb..26a37a73d 100644 --- a/src/simplify_rules.jl +++ b/src/simplify_rules.jl @@ -1,4 +1,6 @@ using .Rewriters +using Metatheory: @rule + """ is_operation(f) Returns a single argument anonymous function predicate, that returns `true` if and only if @@ -6,10 +8,15 @@ the argument to the predicate satisfies `iscall` and `operation(x) == f` """ is_operation(f) = @nospecialize(x) -> iscall(x) && (operation(x) == f) +const isnotflatplus = isnotflat(+) +const isnotflattimes = isnotflat(*) +const needs_sorting_plus = needs_sorting(+) +const needs_sorting_times = needs_sorting(*) + let CANONICALIZE_PLUS = [ - @rule(~x::isnotflat(+) => flatten_term(+, ~x)) - @rule(~x::needs_sorting(+) => sort_args(+, ~x)) + @rule(~x::isnotflatplus => flatten_term(+, ~x)) + @rule(~x::needs_sorting_plus => sort_args(+, ~x)) @ordered_acrule(~a::is_literal_number + ~b::is_literal_number => ~a + ~b) @acrule(*(~~x) + *(~β, ~~x) => *(1 + ~β, (~~x)...)) @@ -28,8 +35,8 @@ let ] CANONICALIZE_TIMES = [ - @rule(~x::isnotflat(*) => flatten_term(*, ~x)) - @rule(~x::needs_sorting(*) => sort_args(*, ~x)) + @rule(~x::isnotflattimes => flatten_term(*, ~x)) + @rule(~x::needs_sorting_times => sort_args(*, ~x)) @ordered_acrule(~a::is_literal_number * ~b::is_literal_number => ~a * ~b) @rule(*(~~x::hasrepeats) => *(merge_repeats(^, ~~x)...)) diff --git a/src/substitute.jl b/src/substitute.jl index 51c75e3c4..8fc980c69 100644 --- a/src/substitute.jl +++ b/src/substitute.jl @@ -34,7 +34,6 @@ function substitute(expr, dict; fold=true) maketerm(typeof(expr), op, args, - symtype(expr), metadata(expr)) else expr diff --git a/src/types.jl b/src/types.jl index ba14e34f4..eb9fa04d4 100644 --- a/src/types.jl +++ b/src/types.jl @@ -98,8 +98,18 @@ end ### ### Term interface ### -symtype(x::Number) = typeof(x) + +""" + symtype(x) + +Returns the numeric type of `x`. By default this is just `typeof(x)`. +Define this for your symbolic types if you want [`SymbolicUtils.simplify`](@ref) to apply rules +specific to numbers (such as commutativity of multiplication). Or such +rules that may be implemented in the future. +""" +symtype(x) = typeof(x) @inline symtype(::Symbolic{T}) where T = T +@inline symtype(::Type{<:Symbolic{T}}) where T = T # We're returning a function pointer @inline function operation(x::BasicSymbolic) @@ -116,7 +126,7 @@ end @inline head(x::BasicSymbolic) = operation(x) -function sorted_arguments(x::BasicSymbolic) +function TermInterface.sorted_arguments(x::BasicSymbolic) args = arguments(x) @compactified x::BasicSymbolic begin Add => @goto ADD @@ -138,13 +148,11 @@ function sorted_arguments(x::BasicSymbolic) return args end -children(x::BasicSymbolic) = arguments(x) - -sorted_children(x::BasicSymbolic) = sorted_arguments(x) - @deprecate unsorted_arguments(x) arguments(x) -function arguments(x::BasicSymbolic) +TermInterface.children(x::BasicSymbolic) = arguments(x) +TermInterface.sorted_children(x::BasicSymbolic) = sorted_arguments(x) +function TermInterface.arguments(x::BasicSymbolic) @compactified x::BasicSymbolic begin Term => return x.arguments Add => @goto ADDMUL @@ -166,7 +174,7 @@ function arguments(x::BasicSymbolic) if isadd(x) for (k, v) in x.dict push!(args, applicable(*,k,v) ? k*v : - maketerm(k, *, [k, v])) + maketerm(k, *, [k, v], nothing)) end else # MUL for (k, v) in x.dict @@ -196,7 +204,16 @@ isexpr(s::BasicSymbolic) = !issym(s) iscall(s::BasicSymbolic) = isexpr(s) @inline isa_SymType(T::Val{S}, x) where {S} = x isa BasicSymbolic ? Unityper.isa_type_fun(Val(SymbolicUtils.BasicSymbolic), T, x) : false + +""" + issym(x) + +Returns `true` if `x` is a `Sym`. If true, `nameof` must be defined +on `x` and must return a `Symbol`. +""" +issym(x) = false issym(x::BasicSymbolic) = isa_SymType(Val(:Sym), x) + isterm(x) = isa_SymType(Val(:Term), x) ismul(x) = isa_SymType(Val(:Mul), x) isadd(x) = isa_SymType(Val(:Add), x) @@ -239,8 +256,8 @@ function _isequal(a, b, E) elseif E === POW isequal(a.exp, b.exp) && isequal(a.base, b.base) elseif E === TERM - a1 = arguments(a) - a2 = arguments(b) + a1 = sorted_arguments(a) + a2 = sorted_arguments(b) isequal(operation(a), operation(b)) && _allarequal(a1, a2) else error_on_type() @@ -281,7 +298,7 @@ function Base.hash(s::BasicSymbolic, salt::UInt)::UInt !iszero(h) && return h op = operation(s) oph = op isa Function ? nameof(op) : op - h′ = hashvec(arguments(s), hash(oph, salt)) + h′ = hashvec(sorted_arguments(s), hash(oph, salt)) s.hash[] = h′ return h′ else @@ -411,7 +428,7 @@ end @inline function numerators(x) isdiv(x) && return numerators(x.num) - iscall(x) && operation(x) === (*) ? arguments(x) : Any[x] + iscall(x) && operation(x) === (*) ? sorted_arguments(x) : Any[x] end @inline denominators(x) = isdiv(x) ? numerators(x.den) : Any[1] @@ -530,7 +547,7 @@ function unflatten(t::Symbolic{T}) where{T} if iscall(t) f = operation(t) if f == (+) || f == (*) # TODO check out for other n-ary --> binary ops - a = arguments(t) + a = sorted_arguments(t) return foldl((x,y) -> Term{T}(f, Any[x, y]), a) end end @@ -539,8 +556,22 @@ end unflatten(t) = t -function TermInterface.maketerm(::Type{<:BasicSymbolic}, head, args, type, metadata) - basicsymbolic(head, args, type, metadata) +function TermInterface.maketerm(T::Type{<:BasicSymbolic}, head, args, metadata) + st = symtype(T) + pst = _promote_symtype(head, args) + # Use promoted symtype only if not a subtype of the existing symtype of T. + # This is useful when calling `maketerm(BasicSymbolic{Number}, (==), [true, false])` + # Where the result would have a symtype of Bool. + # Please see discussion in https://github.com/JuliaSymbolics/SymbolicUtils.jl/pull/609 + # TODO this should be optimized. + new_st = if pst === Bool + pst + elseif pst === Any || (st === Number && pst <: st) + st + else + pst + end + basicsymbolic(head, args, new_st, metadata) end @@ -639,28 +670,6 @@ function to_symbolic(x) x end -""" - similarterm(x, op, args, symtype=nothing; metadata=nothing) - -""" -function similarterm(x, op, args, symtype=nothing; metadata=nothing) - Base.depwarn("""`similarterm` is deprecated, use `maketerm` instead. - `similarterm(x, op, args, symtype; metadata)` is now - `maketerm(typeof(x), op, args, symtype, metadata)`""", :similarterm) - TermInterface.maketerm(typeof(x), op, args, symtype, metadata) -end - -# Old fallback -function similarterm(T::Type, op, args, symtype=nothing; metadata=nothing) - - Base.depwarn("`similarterm` is deprecated, use `maketerm` instead." * - "See https://github.com/JuliaSymbolics/TermInterface.jl for details.", :similarterm) - op(args...) -end - -export similarterm - - ### ### Pretty printing ### @@ -669,7 +678,7 @@ const show_simplified = Ref(false) isnegative(t::Real) = t < 0 function isnegative(t) if iscall(t) && operation(t) === (*) - coeff = first(arguments(t)) + coeff = first(sorted_arguments(t)) return isnegative(coeff) end return false @@ -701,7 +710,7 @@ end function remove_minus(t) !iscall(t) && return -t @assert operation(t) == (*) - args = arguments(t) + args = sorted_arguments(t) @assert args[1] < 0 Any[-args[1], args[2:end]...] end diff --git a/src/utils.jl b/src/utils.jl index 69b6e8e2d..fb5ceaa36 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -48,12 +48,12 @@ end function fold(t) if iscall(t) - tt = map(fold, arguments(t)) + tt = map(fold, sorted_arguments(t)) if !any(x->x isa Symbolic, tt) # evaluate it return operation(t)(tt...) else - return maketerm(typeof(t), operation(t), tt, symtype(t), metadata(t)) + return maketerm(typeof(t), operation(t), tt, metadata(t)) end else return t @@ -74,12 +74,12 @@ _isinteger(x) = (x isa Number && isinteger(x)) || (x isa Symbolic && symtype(x) _isreal(x) = (x isa Number && isreal(x)) || (x isa Symbolic && symtype(x) <: Real) issortedₑ(args) = issorted(args, lt=<ₑ) -needs_sorting(f) = x -> is_operation(f)(x) && !issortedₑ(arguments(x)) +needs_sorting(f) = x -> is_operation(f)(x) && !issortedₑ(sorted_arguments(x)) # are there nested ⋆ terms? function isnotflat(⋆) function (x) - args = arguments(x) + args = sorted_arguments(x) for t in args if iscall(t) && operation(t) === (⋆) return true @@ -137,29 +137,29 @@ x + 2y ``` """ function flatten_term(⋆, x) - args = arguments(x) + args = sorted_arguments(x) # flatten nested ⋆ flattened_args = [] for t in args if iscall(t) && operation(t) === (⋆) - append!(flattened_args, arguments(t)) + append!(flattened_args, sorted_arguments(t)) else push!(flattened_args, t) end end - maketerm(typeof(x), ⋆, flattened_args, symtype(x), metadata(x)) + maketerm(typeof(x), ⋆, flattened_args, metadata(x)) end function sort_args(f, t) - args = arguments(t) + args = sorted_arguments(t) if length(args) < 2 - return maketerm(typeof(t), f, args, symtype(t), metadata(t)) + return maketerm(typeof(t), f, args, metadata(t)) elseif length(args) == 2 x, y = args - return maketerm(typeof(t), f, x <ₑ y ? [x,y] : [y,x], symtype(t), metadata(t)) + return maketerm(typeof(t), f, x <ₑ y ? [x,y] : [y,x], metadata(t)) end args = args isa Tuple ? [args...] : args - maketerm(typeof(t), f, sort(args, lt=<ₑ), symtype(t), metadata(t)) + maketerm(typeof(t), f, sort(args, lt=<ₑ), metadata(t)) end # Linked List interface @@ -182,12 +182,12 @@ Base.length(l::LL) = length(l.v)-l.i+1 Base.length(t::Term) = length(arguments(t)) + 1 # PIRACY Base.isempty(t::Term) = false @inline car(t::Term) = operation(t) -@inline cdr(t::Term) = arguments(t) +@inline cdr(t::Term) = sorted_arguments(t) @inline car(v) = iscall(v) ? operation(v) : first(v) @inline function cdr(v) if iscall(v) - arguments(v) + sorted_arguments(v) else islist(v) ? LL(v, 2) : error("asked cdr of empty") end @@ -200,7 +200,7 @@ end if n === 0 return ll else - iscall(ll) ? drop_n(arguments(ll), n-1) : drop_n(cdr(ll), n-1) + iscall(ll) ? drop_n(sorted_arguments(ll), n-1) : drop_n(cdr(ll), n-1) end end @inline drop_n(ll::Union{Tuple, AbstractArray}, n) = drop_n(LL(ll, 1), n) @@ -225,7 +225,7 @@ macro matchable(expr) SymbolicUtils.arguments(x::$name) = getfield.((x,), ($(QuoteNode.(fields)...),)) SymbolicUtils.children(x::$name) = [SymbolicUtils.operation(x); SymbolicUtils.children(x)] Base.length(x::$name) = $(length(fields) + 1) - SymbolicUtils.maketerm(x::$name, f, args, type, metadata) = f(args...) + SymbolicUtils.maketerm(x::$name, f, args, metadata) = f(args...) end |> esc end diff --git a/test/basics.jl b/test/basics.jl index 36228324c..58c3cab0d 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -216,23 +216,23 @@ end @testset "maketerm" begin @syms a b c - @test isequal(SymbolicUtils.maketerm(typeof(b + c), +, [a, (b+c)], Number, nothing).dict, Dict(a=>1,b=>1,c=>1)) - @test isequal(SymbolicUtils.maketerm(typeof(b^2), ^, [b^2, 1//2], Number, nothing), b) + @test isequal(SymbolicUtils.maketerm(typeof(b + c), +, [a, (b+c)], nothing).dict, Dict(a=>1,b=>1,c=>1)) + @test isequal(SymbolicUtils.maketerm(typeof(b^2), ^, [b^2, 1//2], nothing), b) # test that maketerm doesn't hard-code BasicSymbolic subtype # and is consistent with BasicSymbolic arithmetic operations - @test isequal(SymbolicUtils.maketerm(typeof(a / b), *, [a / b, c], Number, nothing), (a / b) * c) - @test isequal(SymbolicUtils.maketerm(typeof(a * b), *, [0, c], Number, nothing), 0) - @test isequal(SymbolicUtils.maketerm(typeof(a^b), ^, [a * b, 3], Number, nothing), (a * b)^3) + @test isequal(SymbolicUtils.maketerm(typeof(a / b), *, [a / b, c], nothing), (a / b) * c) + @test isequal(SymbolicUtils.maketerm(typeof(a * b), *, [0, c], nothing), 0) + @test isequal(SymbolicUtils.maketerm(typeof(a^b), ^, [a * b, 3], nothing), (a * b)^3) # test that maketerm sets metadata correctly metadata = Base.ImmutableDict{DataType, Any}(Ctx1, "meta_1") - s = SymbolicUtils.maketerm(typeof(a^b), ^, [a * b, 3], Number, metadata) + s = SymbolicUtils.maketerm(typeof(a^b), ^, [a * b, 3], metadata) @test hasmetadata(s, Ctx1) @test getmetadata(s, Ctx1) == "meta_1" end -toterm(t) = Term{symtype(t)}(operation(t), arguments(t)) +toterm(t) = Term{symtype(t)}(operation(t), sorted_arguments(t)) @testset "diffs" begin @syms a b c @@ -279,7 +279,7 @@ end T = FnType{Tuple{T,S,Int} where {T,S}, Real} s = Sym{T}(:t) @syms a b c::Int - @test isequal(arguments(s(a, b, c)), [a, b, c]) + @test isequal(sorted_arguments(s(a, b, c)), [a, b, c]) end @testset "div" begin diff --git a/test/egraphs.jl b/test/egraphs.jl new file mode 100644 index 000000000..d139d316c --- /dev/null +++ b/test/egraphs.jl @@ -0,0 +1,196 @@ +using Metatheory +using SymbolicUtils +const SU = SymbolicUtils +using SymbolicUtils: Symbolic, BasicSymbolic, unflatten, toterm, Term +using SymbolicUtils: monadic, diadic +using InteractiveUtils + +EGraphs.preprocess(t::Symbolic) = toterm(unflatten(t)) + +""" +Equational rewrite rules for optimizing expressions +""" +opt_theory = @theory a b c x y z begin + a + (b + c) == (a + b) + c + a * (b * c) == (a * b) * c + x + 0 --> x + a + b == b + a + a - a => 0 # is it ok? + + 0 - x --> -x + + a * b == b * a + a * x + a * y == a*(x+y) + -1 * a --> -a + a + (-1 * b) == a - b + x * 1 --> x + x * 0 --> 0 + x/x --> 1 + # fraction rules + x^-1 == 1/x + 1/x * a == a/x # is this needed? + x / (x / y) --> y + x * (y / z) == (x * y) / z + (a/b) + (c/b) --> (a+c)/b + (a / b) / c == a/(b*c) + + # TODO prohibited rule + x / x --> 1 + + # pow rules + a * a == a^2 + (a^b)^c == a^(b*c) + a^b * a^c == a^(b+c) + a^b / a^c == a^(b-c) + (a*b)^c == a^c * b^c + + # logarithmic rules + # TODO variables are non-zero + log(x::Number) => log(x) + log(x * y) == log(x) + log(y) + log(x / y) == log(x) - log(y) + log(x^y) == y * log(x) + x^(log(y)) == y^(log(x)) + + # trig functions + sin(x)/cos(x) == tan(x) + cos(x)/sin(x) == cot(x) + sin(x)^2 + cos(x)^2 --> 1 + sin(2a) == 2sin(a)cos(a) + + sin(x)*cos(y) - cos(x)*sin(y) --> sin(x - y) + # hyperbolic trigonometric + # are these optimizing at all? dont think so + sinh(x) == (ℯ^x - ℯ^(-x))/2 + csch(x) == 1/sinh(x) + cosh(x) == (ℯ^x + ℯ^(-x))/2 + sech(x) == 1/cosh(x) + sech(x) == 2/(ℯ^x + ℯ^(-x)) + tanh(x) == sinh(x)/cosh(x) + tanh(x) == (ℯ^x - ℯ^(-x))/(ℯ^x + ℯ^(-x)) + coth(x) == 1/tanh(x) + coth(x) == (ℯ^x + ℯ^-x)/(ℯ^x - ℯ^(-x)) + + cosh(x)^2 - sinh(x)^2 --> 1 + tanh(x)^2 + sech(x)^2 --> 1 + coth(x)^2 - csch(x)^2 --> 1 + + asinh(z) == log(z + √(z^2 + 1)) + acosh(z) == log(z + √(z^2 - 1)) + atanh(z) == log((1+z)/(1-z))/2 + acsch(z) == log((1+√(1+z^2)) / z ) + asech(z) == log((1 + √(1-z^2)) / z ) + acoth(z) == log( (z+1)/(z-1) )/2 + + # folding + x::Number * y::Number => x*y + x::Number + y::Number => x+y + x::Number / y::Number => x/y + x::Number - y::Number => x-y +end + + +# See +# * https://latkin.org/blog/2014/11/09/a-simple-benchmark-of-various-math-operations/ +# * https://streamhpc.com/blog/2012-07-16/how-expensive-is-an-operation-on-a-cpu/ +# * https://github.com/triscale-innov/GFlops.jl +# Measure the cost of expressions in terms of number of ASM instructions + + +function make_op_costs() + const op_costs = Dict() + + const types = [(Int64, Integer), (Float64, Real), (ComplexF64, Complex)] + + const io = IOBuffer() + + for f in vcat(monadic, [-]) + z = get!(op_costs, nameof(f), Dict()) + for (t, at) in types + try + InteractiveUtils.code_native(io, f, (t,)) + catch e + z[(t,)] = z[(at,)] = 1 + continue + end + str = String(take!(io)) + z[(t,)] = z[(at,)] = length(split(str, "\n")) + end + end + + for f in vcat(diadic, [+, -, *, /, //, ^]) + z = get!(op_costs, nameof(f), Dict()) + for (t1, at1) in types, (t2, at2) in types + try + InteractiveUtils.code_native(io, f, (t1, t2)) + catch e + z[(t1, t2)] = z[(at1, at2)] = z[(at1, t2)] = z[(t1, at2)] = 1 + continue + end + str = String(take!(io)) + z[(t1, t2)] = z[(at1, at2)] = z[(at1, t2)] = z[(t1, at2)] = length(split(str, "\n")) + end + end + + op_costs +end + +function getopcost(op_costs, f::Function, types::Tuple) + sym = nameof(f) + if haskey(op_costs, sym) && haskey(op_costs[sym], types) + return op_costs[sym][types] + end + + # print("$f $types | ") + io = IOBuffer() + try + InteractiveUtils.code_native(io, f, types) + catch e + op_costs[sym][types] = 1 + return 1 + end + str = String(take!(io)) + c = length(split(str, "\n")) + !haskey(op_costs, sym) && (op_costs[sym] = Dict()) + op_costs[sym][types] = c +end + +getopcost(f, types::Tuple) = get(get(op_costs, f, Dict()), types, 1) + +function costfun(n::VecExpr, op, children_costs::Vector{Float64}) + v_isexpr(n) || return 1 + # types = Tuple(map(x -> getdata(g[x], SymtypeAnalysis, Real), args)) + types = Tuple([Float64 for i in 1:v_arity(n)]) + opc = getopcost(op, types) + opc + sum(children_costs) +end + +denoisescalars(x, atol=1e-11) = Postwalk(Chain([ + # 0 - x --> -x + @acrule *(~x::Real, sin(~y)) => 0 where isapprox(x, 0; atol=atol) + @acrule *(~x::Real, cos(~y)) => 0 where isapprox(x, 0; atol=atol) + @acrule +(~x::Real, ~y) => y where isapprox(x, 0; atol=atol) + @acrule +(~x::Real, ~y) => y where isapprox(x, 0; atol=atol) +]))(x) + +const op_costs = make_op_costs() +function optimize(ex::Symbolic; params=SaturationParams(), atol=1e-13, verbose=false, kws...) + # ex = simplify(denoisescalars(ex, atol)) + # println(ex) + # readline() + + g = EGraph{BasicSymbolic}(ex) + + # display(g.classes);println(); + + report = saturate!(g, opt_theory, params) + verbose && @info report + extr = extract!(g, costfun) + return extr +end + +@syms x y z + +t = Term(+, [Term(*, [z, x]), Term(*, [z, y])]) + +optimize(t) \ No newline at end of file diff --git a/test/rewrite.jl b/test/rewrite.jl index 3bb2621e3..ccc754141 100644 --- a/test/rewrite.jl +++ b/test/rewrite.jl @@ -1,5 +1,7 @@ @syms a b c +using Metatheory + @testset "Equality" begin @eqtest a == a @eqtest a != b @@ -65,7 +67,7 @@ using SymbolicUtils: @capture ret = @capture (a + b) (+)(~~z) @test ret @test @isdefined z - @test all(z .=== arguments(a + b)) + @test all(z .=== sorted_arguments(a + b)) #a more typical way to use the @capture macro @@ -84,24 +86,24 @@ end ex1 = ex + c @test SymbolicUtils.isterm(ex1) - @test getmetadata(arguments(ex1)[1], MetaData) == :metadata + @test getmetadata(sorted_arguments(ex1)[1], MetaData) == :metadata ex = a ex = setmetadata(ex, MetaData, :metadata) ex1 = ex + b - @test getmetadata(arguments(ex1)[1], MetaData) == :metadata + @test getmetadata(sorted_arguments(ex1)[1], MetaData) == :metadata ex = a * b ex = setmetadata(ex, MetaData, :metadata) ex1 = ex * c @test SymbolicUtils.isterm(ex1) - @test getmetadata(arguments(ex1)[1], MetaData) == :metadata + @test getmetadata(sorted_arguments(ex1)[1], MetaData) == :metadata ex = a ex = setmetadata(ex, MetaData, :metadata) ex1 = ex * b - @test getmetadata(arguments(ex1)[1], MetaData) == :metadata + @test getmetadata(sorted_arguments(ex1)[1], MetaData) == :metadata end \ No newline at end of file