From c4c78c8daabd2e25420af51bceaaf26c2f5d1dcb Mon Sep 17 00:00:00 2001
From: a
Date: Sun, 28 Apr 2024 22:59:21 +0200
Subject: [PATCH 1/2] new egraph stuff
---
Project.toml | 1 +
1 file changed, 1 insertion(+)
diff --git a/Project.toml b/Project.toml
index 92e5f0e4a..a94aec529 100644
--- a/Project.toml
+++ b/Project.toml
@@ -15,6 +15,7 @@ DynamicPolynomials = "7c1d4256-1411-5781-91ec-d7bc3513ac07"
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+Metatheory = "e9d8d322-4543-424a-9be4-0cc815abe26c"
MultivariatePolynomials = "102ac46a-7ee4-5c85-9060-abc95bfdeaa3"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
From ef35bf3ff7f2327b752fd35da63e3a6265b0a301 Mon Sep 17 00:00:00 2001
From: a
Date: Thu, 2 May 2024 21:25:58 +0200
Subject: [PATCH 2/2] add egraphs
---
src/egraphs.jl | 88 ++++++++++++++++++++++++++++++++++++++++++++++++++
src/types.jl | 6 ++--
2 files changed, 91 insertions(+), 3 deletions(-)
create mode 100644 src/egraphs.jl
diff --git a/src/egraphs.jl b/src/egraphs.jl
new file mode 100644
index 000000000..a169f2245
--- /dev/null
+++ b/src/egraphs.jl
@@ -0,0 +1,88 @@
+using Metatheory, SymbolicUtils
+using SymbolicUtils: Symbolic, BasicSymbolic, unflatten, toterm
+using Metatheory: @rule
+
+using SymbolicUtils: <ₑ
+
+function EGraphs.preprocess(t::BasicSymbolic)
+ toterm(unflatten(t))
+end
+
+
+"""
+Equational rewrite rules for optimizing expressions
+"""
+opt_theory = @theory a b x y begin
+ a + b == b + a
+ a * b == b * a
+ a * x + a * y == a * (x + y)
+ -1 * a == -a
+ a + (-1 * b) == a - b
+ x^-1 == 1 / x
+ 1 / x * a == a / x
+ # fraction rules
+ # (a/b) + (c/b) => (a+c)/b
+ # trig functions
+ sin(x) / cos(x) == tan(x)
+ cos(x) / sin(x) == cot(x)
+ sin(x)^2 + cos(x)^2 --> 1
+ sin(2a) == 2sin(a)cos(a)
+end
+
+
+"""
+Approximation of costs of operators in number
+of CPU cycles required for the numerical computation
+
+See
+ * https://latkin.org/blog/2014/11/09/a-simple-benchmark-of-various-math-operations/
+ * https://streamhpc.com/blog/2012-07-16/how-expensive-is-an-operation-on-a-cpu/
+ * https://github.com/triscale-innov/GFlops.jl
+"""
+const op_costs = Dict(
+ (+) => 1,
+ (-) => 1,
+ abs => 2,
+ (*) => 3,
+ exp => 18,
+ (/) => 24,
+ (^) => 100,
+ log1p => 124,
+ deg2rad => 125,
+ rad2deg => 125,
+ acos => 127,
+ asind => 128,
+ acsch => 133,
+ sin => 134,
+ cos => 134,
+ atan => 135,
+ tan => 156,
+)
+# TODO some operator costs are in FLOP and not in cycles!!
+
+function costfun(n::VecExpr, op, children_costs::Vector{Float64})::Float64
+ v_isexpr(n) || return 0 #1
+ get(op_costs, op, 1) + sum(children_costs)
+end
+
+
+function optimize(ex; params=SaturationParams(timeout=20))
+ # @show ex
+ g = EGraph{BasicSymbolic}(ex)
+ saturate!(g, opt_theory, params)
+ return extract!(g, costfun)
+end
+
+
+
+# =======================================================================
+
+@syms x y z a b c
+
+expr = ((a * b) + (a * c)) / ((x * y) + (x * z) )
+
+g = EGraph{BasicSymbolic}(expr)
+saturate!(g, opt_theory, SaturationParams())
+return extract!(g, costfun)
+
+@benchmark optimize(expr)
\ No newline at end of file
diff --git a/src/types.jl b/src/types.jl
index f9ea33121..f696a1ca8 100644
--- a/src/types.jl
+++ b/src/types.jl
@@ -535,8 +535,8 @@ end
unflatten(t) = t
-function TermInterface.maketerm(::Type{<:BasicSymbolic}, head, args, type, metadata)
- basicsymbolic(first(args), args[2:end], type, metadata)
+function TermInterface.maketerm(::Type{<:BasicSymbolic}, op, args, type=nothing, metadata=nothing)
+ basicsymbolic(op, args, type, metadata)
end
@@ -649,7 +649,7 @@ function similarterm(x, op, args, symtype=nothing; metadata=nothing)
The present call can be replaced by
`maketerm(typeof(x), $(head(x)), [op, args...], symtype, metadata)`""", :similarterm)
- TermInterface.maketerm(typeof(x), callhead(x), [op, args...], symtype, metadata)
+ TermInterface.maketerm(typeof(x), op, args, symtype, metadata)
end
# Old fallback