Adding OMEinsumContractionOrders.jl as a backend of TensorOperations.jl for finding the optimal contraction order #185

Open · wants to merge 16 commits into base: master
4 changes: 3 additions & 1 deletion .gitignore
@@ -1,4 +1,6 @@
*.jl.cov
*.jl.*.cov
*.jl.mem
Manifest.toml
Manifest.toml
.vscode
.DS_Store
6 changes: 5 additions & 1 deletion Project.toml
@@ -20,11 +20,13 @@ cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
OMEinsumContractionOrders = "6f22d1fd-8eed-4bb7-9776-e7d684900715"
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"

[extensions]
TensorOperationsBumperExt = "Bumper"
TensorOperationsChainRulesCoreExt = "ChainRulesCore"
TensorOperationsOMEinsumContractionOrdersExt = "OMEinsumContractionOrders"
TensorOperationscuTENSORExt = ["cuTENSOR", "CUDA"]

[compat]
@@ -37,6 +39,7 @@ DynamicPolynomials = "0.5"
LRUCache = "1"
LinearAlgebra = "1.6"
Logging = "1.6"
OMEinsumContractionOrders = "0.9.2"
PackageExtensionCompat = "1"
PtrArrays = "1.2"
Random = "1"
@@ -55,9 +58,10 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
DynamicPolynomials = "7c1d4256-1411-5781-91ec-d7bc3513ac07"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
OMEinsumContractionOrders = "6f22d1fd-8eed-4bb7-9776-e7d684900715"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"

[targets]
test = ["Test", "Random", "DynamicPolynomials", "ChainRulesTestUtils", "CUDA", "cuTENSOR", "Aqua", "Logging", "Bumper"]
test = ["Test", "Random", "DynamicPolynomials", "ChainRulesTestUtils", "CUDA", "cuTENSOR", "Aqua", "Logging", "Bumper", "OMEinsumContractionOrders"]
113 changes: 113 additions & 0 deletions ext/TensorOperationsOMEinsumContractionOrdersExt.jl
@@ -0,0 +1,113 @@
module TensorOperationsOMEinsumContractionOrdersExt

using TensorOperations
using TensorOperations: TensorOperations as TO
using TensorOperations: TreeOptimizer
using OMEinsumContractionOrders
using OMEinsumContractionOrders: EinCode, NestedEinsum, SlicedEinsum, isleaf,
optimize_kahypar_auto

function TO.optimaltree(network, optdata::Dict{TDK,TDV}, ::TreeOptimizer{:GreedyMethod},
verbose::Bool) where {TDK,TDV}
@debug "Using optimizer GreedyMethod from OMEinsumContractionOrders"
ome_optimizer = GreedyMethod()
return optimize(network, optdata, ome_optimizer, verbose)
end

function TO.optimaltree(network, optdata::Dict{TDK,TDV}, ::TreeOptimizer{:KaHyParBipartite},
verbose::Bool) where {TDK,TDV}
@debug "Using optimizer KaHyParBipartite from OMEinsumContractionOrders"
return optimize_kahypar(network, optdata, verbose)
end

function TO.optimaltree(network, optdata::Dict{TDK,TDV}, ::TreeOptimizer{:TreeSA},
verbose::Bool) where {TDK,TDV}
@debug "Using optimizer TreeSA from OMEinsumContractionOrders"
ome_optimizer = TreeSA()
return optimize(network, optdata, ome_optimizer, verbose)
end

function TO.optimaltree(network, optdata::Dict{TDK,TDV}, ::TreeOptimizer{:SABipartite},
verbose::Bool) where {TDK,TDV}
@debug "Using optimizer SABipartite from OMEinsumContractionOrders"
ome_optimizer = SABipartite()
return optimize(network, optdata, ome_optimizer, verbose)
end

function TO.optimaltree(network, optdata::Dict{TDK,TDV}, ::TreeOptimizer{:ExactTreewidth},
verbose::Bool) where {TDK,TDV}
@debug "Using optimizer ExactTreewidth from OMEinsumContractionOrders"
ome_optimizer = ExactTreewidth()
return optimize(network, optdata, ome_optimizer, verbose)
end

function optimize(network, optdata::Dict{TDK,TDV}, ome_optimizer::CodeOptimizer,
verbose::Bool) where {TDK,TDV}
@assert TDV <: Number "The values of the `optdata` dictionary must be `<:Number`"

# transform the network as EinCode
code, size_dict = network2eincode(network, optdata)
# optimize the contraction order using OMEinsumContractionOrders, which gives a NestedEinsum
optcode = optimize_code(code, size_dict, ome_optimizer)

# transform the optimized contraction order back to the network
optimaltree = eincode2contractiontree(optcode)

# calculate the complexity of the contraction
cc = OMEinsumContractionOrders.contraction_complexity(optcode, size_dict)
if verbose
println("Optimal contraction tree: ", optimaltree)
println(cc)

end
return optimaltree, 2.0^(cc.tc)
end

function optimize_kahypar(network, optdata::Dict{TDK,TDV}, verbose::Bool) where {TDK,TDV}
@assert TDV <: Number "The values of the `optdata` dictionary must be `<:Number`"

# transform the network as EinCode
code, size_dict = network2eincode(network, optdata)
# optimize the contraction order using OMEinsumContractionOrders, which gives a NestedEinsum
optcode = optimize_kahypar_auto(code, size_dict)

# transform the optimized contraction order back to the network
optimaltree = eincode2contractiontree(optcode)

# calculate the complexity of the contraction
cc = OMEinsumContractionOrders.contraction_complexity(optcode, size_dict)
if verbose
println("Optimal contraction tree: ", optimaltree)
println(cc)

end
return optimaltree, 2.0^(cc.tc)
end

function network2eincode(network, optdata)
indices = unique(vcat(network...))
new_indices = Dict([i => j for (j, i) in enumerate(indices)])
new_network = [Int[new_indices[i] for i in t] for t in network]
open_edges = Int[]
# if an index appears only once, it is an open index
for i in indices
if sum([i in t for t in network]) == 1
push!(open_edges, new_indices[i])
end
end
size_dict = Dict([new_indices[i] => optdata[i] for i in keys(optdata)])
return EinCode(new_network, open_edges), size_dict
end

function eincode2contractiontree(eincode::NestedEinsum)
if isleaf(eincode)
return eincode.tensorindex
else
return [eincode2contractiontree(arg) for arg in eincode.args]
end
end

# TreeSA returns a SlicedEinsum with nslice = 0, so we directly use its inner eins
function eincode2contractiontree(eincode::SlicedEinsum)
return eincode2contractiontree(eincode.eins)
end

end
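For orientation, here is a minimal, self-contained sketch of what the pipeline above does with OMEinsumContractionOrders directly; the three-tensor network, labels, and sizes are made up for illustration and are not taken from this PR:

```julia
using OMEinsumContractionOrders
using OMEinsumContractionOrders: EinCode

# a network [[:a, :b], [:b, :c], [:c, :d]] with open indices :a and :d, after
# network2eincode has relabelled the indices with integers:
code = EinCode([[1, 2], [2, 3], [3, 4]], [1, 4])
size_dict = Dict(1 => 10, 2 => 10, 3 => 10, 4 => 10)

# optimize_code returns a NestedEinsum (or a SlicedEinsum for TreeSA), which
# eincode2contractiontree above converts back into a nested list of tensor indices
optcode = optimize_code(code, size_dict, GreedyMethod())

# the complexity object whose `tc` field feeds the returned cost 2.0^tc
cc = OMEinsumContractionOrders.contraction_complexity(optcode, size_dict)
```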
5 changes: 5 additions & 0 deletions src/TensorOperations.jl
@@ -29,6 +29,11 @@ export IndexTuple, Index2Tuple, linearize
# export debug functionality
export checkcontractible, tensorcost

# export optimizer
export TreeOptimizer, ExhaustiveSearchOptimizer, GreedyMethodOptimizer,
KaHyParBipartiteOptimizer, TreeSAOptimizer, SABipartiteOptimizer,
ExactTreewidthOptimizer

# Interface and index types
#---------------------------
include("indices.jl")
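Each newly exported name is a convenience constructor for a `TreeOptimizer` singleton, as defined in `src/indexnotation/optimaltree.jl` further down; for example:

```julia
using TensorOperations
GreedyMethodOptimizer() === TreeOptimizer{:GreedyMethod}()  # true
```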
39 changes: 35 additions & 4 deletions src/implementation/ncon.jl
@@ -1,5 +1,5 @@
"""
ncon(tensorlist, indexlist, [conjlist, sym]; order = ..., output = ..., backend = ..., allocator = ...)
ncon(tensorlist, indexlist, [conjlist, sym]; order = ..., output = ..., optimizer = ..., backend = ..., allocator = ...)

Contract the tensors in `tensorlist` (of type `Vector` or `Tuple`) according to the network
as specified by `indexlist`. Here, `indexlist` is a list (i.e. a `Vector` or `Tuple`) with
@@ -20,11 +20,16 @@ over are labelled by increasing integers, i.e. first the contraction correspondi
(negative, so increasing in absolute value) index labels. The keyword arguments `order` and
`output` allow to change these defaults.

Another way to determine the contraction order is through a `TreeOptimizer`, by passing the
`optimizer` keyword argument instead of `order`. By default the `optimizer` can be
`:ExhaustiveSearch`; with the extension `OMEinsumContractionOrders` loaded, it can also be one of
`:GreedyMethod`, `:TreeSA`, `:KaHyParBipartite`, `:SABipartite` or `:ExactTreewidth`.

See also the macro version [`@ncon`](@ref).
"""
function ncon(tensors, network,
conjlist=fill(false, length(tensors));
order=nothing, output=nothing, kwargs...)
order=nothing, output=nothing, optimizer=nothing, kwargs...)
length(tensors) == length(network) == length(conjlist) ||
throw(ArgumentError("number of tensors and of index lists should be the same"))
isnconstyle(network) || throw(ArgumentError("invalid NCON network: $network"))
@@ -39,11 +44,37 @@ function ncon(tensors, network,
end

(tensors, network) = resolve_traces(tensors, network)
tree = order === nothing ? ncontree(network) : indexordertree(network, order)

if isnothing(order)
if isnothing(optimizer)
# neither order nor optimizer specified: build the tree via ncontree
tree = ncontree(network)
else
# determine the contraction order via the optimizer
optdata = Dict{Any,Number}()
for (i, ids) in enumerate(network)
for (j, id) in enumerate(ids)
optdata[id] = tensorstructure(tensors[i], j, conjlist[i])
end
end
tree = optimaltree(network, optdata, optimizer, false)[1]
end
else
if !isnothing(optimizer)
throw(ArgumentError("cannot specify both `order` and `optimizer`"))
else
# with given order, tree via indexordertree
tree = indexordertree(network, order)
end
end

return ncon(tensors, network, conjlist, tree, output′; kwargs...)
end

function ncon(tensors, network, conjlist, tree, output; kwargs...)
A, IA, conjA = contracttree(tensors, network, conjlist, tree[1]; kwargs...)
B, IB, conjB = contracttree(tensors, network, conjlist, tree[2]; kwargs...)
IC = tuple(output...)
IC = tuple(output...)
C = tensorcontract(IC, A, IA, conjA, B, IB, conjB; kwargs...)
allocator = haskey(kwargs, :allocator) ? kwargs[:allocator] : DefaultAllocator()
tree[1] isa Int || tensorfree!(A, allocator)
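A hedged usage sketch of the new `optimizer` keyword: the tensor sizes are arbitrary, and passing an optimizer instance such as `ExhaustiveSearchOptimizer()` follows from the `optimaltree` dispatch below; none of this is taken verbatim from the PR.

```julia
using TensorOperations

A = randn(2, 3, 4); B = randn(4, 3, 2); C = randn(2, 2)

# default: the contraction tree follows the NCON conventions (ncontree)
D1 = ncon([A, B, C], [[-1, 1, 2], [2, 1, 3], [3, -2]])

# the same contraction, with the tree chosen by an optimizer instead
D2 = ncon([A, B, C], [[-1, 1, 2], [2, 1, 3], [3, -2]];
          optimizer=ExhaustiveSearchOptimizer())

D1 ≈ D2  # should hold; only the contraction order differs
```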
22 changes: 21 additions & 1 deletion src/indexnotation/optimaltree.jl
@@ -1,4 +1,24 @@
function optimaltree(network, optdata::Dict; verbose::Bool=false)
struct TreeOptimizer{T} end # T is a Symbol for the algorithm
ExhaustiveSearchOptimizer() = TreeOptimizer{:ExhaustiveSearch}()
GreedyMethodOptimizer() = TreeOptimizer{:GreedyMethod}()
KaHyParBipartiteOptimizer() = TreeOptimizer{:KaHyParBipartite}()
TreeSAOptimizer() = TreeOptimizer{:TreeSA}()
SABipartiteOptimizer() = TreeOptimizer{:SABipartite}()
ExactTreewidthOptimizer() = TreeOptimizer{:ExactTreewidth}()

function optimaltree(network, optdata::Dict;
optimizer::TreeOptimizer{T}=TreeOptimizer{:ExhaustiveSearch}(),
verbose::Bool=false) where {T}
return optimaltree(network, optdata, optimizer, verbose)
end

function optimaltree(network, optdata::Dict, ::TreeOptimizer{T}, verbose::Bool) where {T}
throw(ArgumentError("Unknown optimizer: $T. Hint: may need to load extensions, e.g. `using OMEinsumContractionOrders`"))
end

function optimaltree(network, optdata::Dict, ::TreeOptimizer{:ExhaustiveSearch},
verbose::Bool)
@debug "Using optimizer ExhaustiveSearch"
numtensors = length(network)
allindices = unique(vcat(network...))
numindices = length(allindices)
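The fallback/dispatch pattern above is what lets the extension plug in additional algorithms: it only needs to define methods of `TO.optimaltree` for its own `TreeOptimizer` symbols. A rough sketch of such a hook (the `:MyMethod` symbol and module name are hypothetical; the body just defers to the built-in search as a placeholder):

```julia
module MyContractionOrderExt

using TensorOperations
using TensorOperations: TensorOperations as TO, TreeOptimizer

function TO.optimaltree(network, optdata::Dict, ::TreeOptimizer{:MyMethod},
                        verbose::Bool)
    # translate `network`/`optdata` for the external optimizer, run it, and
    # return (tree, cost); as a placeholder we defer to the exhaustive search
    return TO.optimaltree(network, optdata, TreeOptimizer{:ExhaustiveSearch}(), verbose)
end

end # module
```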
25 changes: 19 additions & 6 deletions src/indexnotation/tensormacros.jl
@@ -71,6 +71,8 @@
end
end
# now handle the remaining keyword arguments
optimizer = TreeOptimizer{:ExhaustiveSearch}() # the default optimizer implemented in TensorOperations.jl
optval = nothing
for (name, val) in kwargs
if name == :order
isexpr(val, :tuple) ||
@@ -86,18 +88,29 @@
throw(ArgumentError("Invalid use of `costcheck`, should be `costcheck=warn` or `costcheck=cache`"))
parser.contractioncostcheck = val
elseif name == :opt
if val isa Bool && val
optdict = optdata(tensorexpr)
elseif val isa Expr
optdict = optdata(val, tensorexpr)
optval = val
elseif name == :opt_algorithm
[Review comment — Collaborator]
I think here you will have to be a little careful; in principle there is no order to the keyword arguments.
If I am not mistaken, if the user first supplies opt=(a = 2, b = 2, ...) and only afterwards opt_algorithm=..., the algorithm will be ignored.

My best guess is that you want to extract an optimizer and optdict, and only construct the contractiontreebuilder after all kwargs have been parsed.

[Reply — Author]
Thank you very much for pointing that out, I did not notice that previously.
In the revised version, the contractiontreebuilder is constructed after all other kwargs have been parsed.

if val isa Symbol
optimizer = TreeOptimizer{val}()
else
throw(ArgumentError("Invalid use of `opt`, should be `opt=true` or `opt=OptExpr`"))
throw(ArgumentError("Invalid use of `opt_algorithm`, should be `opt_algorithm=ExhaustiveSearch` or `opt_algorithm=NameOfAlgorithm`"))
end
parser.contractiontreebuilder = network -> optimaltree(network, optdict)[1]
elseif !(name == :backend || name == :allocator) # these two have been handled
throw(ArgumentError("Unknown keyword argument `name`."))
end
end
# construct the contraction tree builder after all keyword arguments have been processed
if !isnothing(optval)
if optval isa Bool && optval
optdict = optdata(tensorexpr)
elseif optval isa Expr
optdict = optdata(optval, tensorexpr)
else
throw(ArgumentError("Invalid use of `opt`, should be `opt=true` or `opt=OptExpr`"))

end
parser.contractiontreebuilder = network -> optimaltree(network, optdict;
optimizer=optimizer)[1]
end
return parser
end

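With the tree builder now constructed only after all keyword arguments are parsed, the ordering issue raised in the review thread should be gone. A hedged sketch (sizes arbitrary; `opt = true` is used for brevity, and the explicit `opt = (a = 2, ...)` form discussed above is handled the same way):

```julia
using TensorOperations
A = randn(5, 5, 5, 5); B = randn(5, 5, 5); C = randn(5, 5, 5)

@tensor opt = true opt_algorithm = ExhaustiveSearch begin
    D1[a, b, c, d] := A[a, e, c, f] * B[g, d, e] * C[g, f, b]
end
@tensor opt_algorithm = ExhaustiveSearch opt = true begin
    D2[a, b, c, d] := A[a, e, c, f] * B[g, d, e] * C[g, f, b]
end

D1 ≈ D2  # the algorithm is picked up regardless of keyword order
```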
16 changes: 16 additions & 0 deletions test/macro_kwargs.jl
@@ -106,3 +106,19 @@ end
end
@test D1 ≈ D2 ≈ D3 ≈ D4 ≈ D5
end

@testset "opt_algorithm" begin
A = randn(5, 5, 5, 5)
B = randn(5, 5, 5)
C = randn(5, 5, 5)

@tensor opt = true begin
D1[a, b, c, d] := A[a, e, c, f] * B[g, d, e] * C[g, f, b]
end

@tensor opt = true opt_algorithm = ExhaustiveSearch begin
D2[a, b, c, d] := A[a, e, c, f] * B[g, d, e] * C[g, f, b]
end

@test D1 ≈ D2
end
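A hypothetical follow-up test, not part of this diff, exercising one of the extension's algorithms once OMEinsumContractionOrders is loaded (it assumes the `using` statements already present in the test suite):

```julia
using OMEinsumContractionOrders  # triggers the extension

@testset "opt_algorithm with OMEinsumContractionOrders" begin
    A = randn(5, 5, 5, 5)
    B = randn(5, 5, 5)
    C = randn(5, 5, 5)

    @tensor opt = true begin
        D1[a, b, c, d] := A[a, e, c, f] * B[g, d, e] * C[g, f, b]
    end
    @tensor opt = true opt_algorithm = GreedyMethod begin
        D2[a, b, c, d] := A[a, e, c, f] * B[g, d, e] * C[g, f, b]
    end

    @test D1 ≈ D2
end
```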