From c4c78c8daabd2e25420af51bceaaf26c2f5d1dcb Mon Sep 17 00:00:00 2001 From: a Date: Sun, 28 Apr 2024 22:59:21 +0200 Subject: [PATCH 1/2] new egraph stuff --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index 92e5f0e4a..a94aec529 100644 --- a/Project.toml +++ b/Project.toml @@ -15,6 +15,7 @@ DynamicPolynomials = "7c1d4256-1411-5781-91ec-d7bc3513ac07" IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Metatheory = "e9d8d322-4543-424a-9be4-0cc815abe26c" MultivariatePolynomials = "102ac46a-7ee4-5c85-9060-abc95bfdeaa3" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" From ef35bf3ff7f2327b752fd35da63e3a6265b0a301 Mon Sep 17 00:00:00 2001 From: a Date: Thu, 2 May 2024 21:25:58 +0200 Subject: [PATCH 2/2] add egraphs --- src/egraphs.jl | 88 ++++++++++++++++++++++++++++++++++++++++++++++++++ src/types.jl | 6 ++-- 2 files changed, 91 insertions(+), 3 deletions(-) create mode 100644 src/egraphs.jl diff --git a/src/egraphs.jl b/src/egraphs.jl new file mode 100644 index 000000000..a169f2245 --- /dev/null +++ b/src/egraphs.jl @@ -0,0 +1,88 @@ +using Metatheory, SymbolicUtils +using SymbolicUtils: Symbolic, BasicSymbolic, unflatten, toterm +using Metatheory: @rule + +using SymbolicUtils: <ₑ + +function EGraphs.preprocess(t::BasicSymbolic) + toterm(unflatten(t)) +end + + +""" +Equational rewrite rules for optimizing expressions +""" +opt_theory = @theory a b x y begin + a + b == b + a + a * b == b * a + a * x + a * y == a * (x + y) + -1 * a == -a + a + (-1 * b) == a - b + x^-1 == 1 / x + 1 / x * a == a / x + # fraction rules + # (a/b) + (c/b) => (a+c)/b + # trig functions + sin(x) / cos(x) == tan(x) + cos(x) / sin(x) == cot(x) + sin(x)^2 + cos(x)^2 --> 1 + sin(2a) == 2sin(a)cos(a) +end + + +""" +Approximation of costs of operators in number +of CPU cycles required for the numerical computation + +See + * https://latkin.org/blog/2014/11/09/a-simple-benchmark-of-various-math-operations/ + * https://streamhpc.com/blog/2012-07-16/how-expensive-is-an-operation-on-a-cpu/ + * https://github.com/triscale-innov/GFlops.jl +""" +const op_costs = Dict( + (+) => 1, + (-) => 1, + abs => 2, + (*) => 3, + exp => 18, + (/) => 24, + (^) => 100, + log1p => 124, + deg2rad => 125, + rad2deg => 125, + acos => 127, + asind => 128, + acsch => 133, + sin => 134, + cos => 134, + atan => 135, + tan => 156, +) +# TODO some operator costs are in FLOP and not in cycles!! + +function costfun(n::VecExpr, op, children_costs::Vector{Float64})::Float64 + v_isexpr(n) || return 0 #1 + get(op_costs, op, 1) + sum(children_costs) +end + + +function optimize(ex; params=SaturationParams(timeout=20)) + # @show ex + g = EGraph{BasicSymbolic}(ex) + saturate!(g, opt_theory, params) + return extract!(g, costfun) +end + + + +# ======================================================================= + +@syms x y z a b c + +expr = ((a * b) + (a * c)) / ((x * y) + (x * z) ) + +g = EGraph{BasicSymbolic}(expr) +saturate!(g, opt_theory, SaturationParams()) +return extract!(g, costfun) + +@benchmark optimize(expr) \ No newline at end of file diff --git a/src/types.jl b/src/types.jl index f9ea33121..f696a1ca8 100644 --- a/src/types.jl +++ b/src/types.jl @@ -535,8 +535,8 @@ end unflatten(t) = t -function TermInterface.maketerm(::Type{<:BasicSymbolic}, head, args, type, metadata) - basicsymbolic(first(args), args[2:end], type, metadata) +function TermInterface.maketerm(::Type{<:BasicSymbolic}, op, args, type=nothing, metadata=nothing) + basicsymbolic(op, args, type, metadata) end @@ -649,7 +649,7 @@ function similarterm(x, op, args, symtype=nothing; metadata=nothing) The present call can be replaced by `maketerm(typeof(x), $(head(x)), [op, args...], symtype, metadata)`""", :similarterm) - TermInterface.maketerm(typeof(x), callhead(x), [op, args...], symtype, metadata) + TermInterface.maketerm(typeof(x), op, args, symtype, metadata) end # Old fallback