Add naive Base julia AbstractArray implementation (#171)
* Add naive Base julia AbstractArray implementation

* Use VectorInterface

* Re-enable Strided Diagonal implementation

* finish base implementations

* restrict strided to strided

* restore formatting diagonal

* some fixes

* one more fix

* Fix some stray `pC`s to `pAB`

* Fix some stray `Symbol` to `Bool`

* finally finish and test base implementations

* a fix and more tests

* more tests

---------

Co-authored-by: Jutho Haegeman <[email protected]>
lkdvos and Jutho authored Jun 23, 2024
1 parent 65cad25 commit fabfb08
Showing 10 changed files with 364 additions and 68 deletions.
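In short, the commit routes tensor operations through `isstrided` checks: strided arrays keep using the Strided.jl backends, while generic `AbstractArray`s (e.g. `Diagonal`) fall back to the new Base-only `BaseView`/`BaseCopy` backends. A minimal sketch of what this enables, assuming TensorOperations at this commit:

using TensorOperations, LinearAlgebra
A = randn(4, 4)
D = Diagonal(randn(4))                # not strided in memory
@tensor C[i, j] := A[i, k] * D[k, j]  # dispatches to a non-BLAS path instead of erroring
C ≈ A * D                             # true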
2 changes: 1 addition & 1 deletion Project.toml
@@ -36,7 +36,7 @@ Logging = "1.6"
PackageExtensionCompat = "1"
Random = "1"
Strided = "2.0.4"
StridedViews = "0.2"
StridedViews = "0.3"
Test = "1"
TupleTools = "1.1"
VectorInterface = "0.4.1"
1 change: 1 addition & 0 deletions src/TensorOperations.jl
@@ -6,6 +6,7 @@ using TupleTools: TupleTools, isperm, invperm
using LinearAlgebra
using LinearAlgebra: mul!, BlasFloat
using Strided
using StridedViews: isstrided
using LRUCache

using Base.Meta: isexpr
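The newly imported `isstrided` is the predicate that drives backend selection in the next file. A rough illustration of its behaviour on common array types (assuming StridedViews 0.3):

using StridedViews: isstrided
using LinearAlgebra: Diagonal
A = randn(4, 4)
isstrided(A)                  # true: dense Arrays have a strided memory layout
isstrided(view(A, 1:2, :))    # true: this SubArray still has fixed strides
isstrided(Diagonal(randn(4))) # false: no strided storage, so a Base fallback is used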
240 changes: 232 additions & 8 deletions src/implementation/abstractarray.jl
@@ -9,33 +9,54 @@ function checkcontractible(A::AbstractArray, iA, conjA::Bool,
return nothing
end

# TODO
# add check for stridedness of abstract arrays and add a pure implementation as fallback

const StridedNative = Backend{:StridedNative}
const StridedBLAS = Backend{:StridedBLAS}
const BaseView = Backend{:BaseView}
const BaseCopy = Backend{:BaseCopy}

function tensoradd!(C::AbstractArray,
A::AbstractArray, pA::Index2Tuple, conjA::Bool,
α::Number, β::Number)
return tensoradd!(C, A, pA, conjA, α, β, StridedNative())
if isstrided(A) && isstrided(C)
return tensoradd!(C, A, pA, conjA, α, β, StridedNative())
else
return tensoradd!(C, A, pA, conjA, α, β, BaseView())
end
end

function tensortrace!(C::AbstractArray,
A::AbstractArray, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
α::Number, β::Number)
return tensortrace!(C, A, p, q, conjA, α, β, StridedNative())
if isstrided(A) && isstrided(C)
return tensortrace!(C, A, p, q, conjA, α, β, StridedNative())
else
return tensortrace!(C, A, p, q, conjA, α, β, BaseView())
end
end

function tensorcontract!(C::AbstractArray,
A::AbstractArray, pA::Index2Tuple, conjA::Bool,
B::AbstractArray, pB::Index2Tuple, conjB::Bool,
pAB::Index2Tuple,
α::Number, β::Number)
if eltype(C) <: LinearAlgebra.BlasFloat && !isa(B, Diagonal) && !isa(A, Diagonal)
return tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, StridedBLAS())
if eltype(C) <: LinearAlgebra.BlasFloat
if isstrided(A) && isstrided(B) && isstrided(C)
return tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, StridedBLAS())
elseif (isstrided(A) || isa(A, Diagonal)) && (isstrided(B) || isa(B, Diagonal)) &&
isstrided(C)
return tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β,
StridedNative())
else
return tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, BaseCopy())
end
else
return tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, StridedNative())
if (isstrided(A) || isa(A, Diagonal)) && (isstrided(B) || isa(B, Diagonal)) &&
isstrided(C)
return tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β,
StridedNative())
else
return tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, BaseView())
end
end
end
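The net effect of this dispatch, as a sketch: BLAS-eligible element types on fully strided operands go to StridedBLAS, Diagonal operands stay on StridedNative, and everything else lands on the Base fallbacks. For example:

using TensorOperations
A = randn(4, 4); B = randn(4, 4); C = zeros(4, 4)
@tensor C[i, j] = A[i, k] * B[k, j]      # Float64 + strided           -> StridedBLAS
Ab = big.(A); Bb = big.(B); Cb = zeros(BigFloat, 4, 4)
@tensor Cb[i, j] = Ab[i, k] * Bb[k, j]   # BigFloat is not a BlasFloat -> StridedNative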

@@ -81,6 +102,209 @@ function tensorcontract!(C::AbstractArray,
return C
end

#-------------------------------------------------------------------------------------------
# Implementation based on Base + LinearAlgebra
#-------------------------------------------------------------------------------------------
# Note that this is mostly for convenience + checking, and not for performance
function tensoradd!(C::AbstractArray,
A::AbstractArray, pA::Index2Tuple, conjA::Bool,
α::Number, β::Number, ::BaseView)
argcheck_tensoradd(C, A, pA)
dimcheck_tensoradd(C, A, pA)

# can we assume that C is mutable?
# is there more functionality in base that we can use?
if conjA
if iszero(β)
C .= α .* conj.(PermutedDimsArray(A, linearize(pA)))
else
C .= β .* C .+ α .* conj.(PermutedDimsArray(A, linearize(pA)))
end
else
if iszero(β)
C .= α .* PermutedDimsArray(A, linearize(pA))
else
C .= β .* C .+ α .* PermutedDimsArray(A, linearize(pA))
end
end
return C
end
function tensoradd!(C::AbstractArray,
A::AbstractArray, pA::Index2Tuple, conjA::Bool,
α::Number, β::Number, ::BaseCopy)
argcheck_tensoradd(C, A, pA)
dimcheck_tensoradd(C, A, pA)

# can we assume that C is mutable?
# is there more functionality in base that we can use?
if conjA
if iszero(β)
C .= α .* conj.(permutedims(A, linearize(pA)))
else
C .= β .* C .+ α .* conj.(permutedims(A, linearize(pA)))
end
else
if iszero(β)
C .= α .* permutedims(A, linearize(pA))
else
C .= β .* C .+ α .* permutedims(A, linearize(pA))
end
end
return C
end
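The only difference between the two methods above is the permutation wrapper: `PermutedDimsArray` is a lazy, allocation-free view, while `permutedims` materializes a copy (which generic arrays without fast view support may prefer). For instance:

A = reshape(collect(1:6), 2, 3)
P = PermutedDimsArray(A, (2, 1))  # lazy: indexing is remapped, no data is moved
Q = permutedims(A, (2, 1))        # eager: allocates a new 3×2 Array
P == Q                            # true; they differ only in access cost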

# tensortrace
function tensortrace!(C::AbstractArray,
A::AbstractArray, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
α::Number, β::Number, ::BaseView)
argcheck_tensortrace(C, A, p, q)
dimcheck_tensortrace(C, A, p, q)

szA = size(A)
so = TupleTools.getindices(szA, linearize(p))
st = prod(TupleTools.getindices(szA, q[1]))
Ã = reshape(PermutedDimsArray(A, (linearize(p)..., linearize(q)...)),
(prod(so), st, st))

if conjA
if iszero(β)
C .= α .* conj.(reshape(view(Ã, :, 1, 1), so))
else
C .= β .* C .+ α .* conj.(reshape(view(Ã, :, 1, 1), so))
end
for i in 2:st
C .+= α .* conj.(reshape(view(Ã, :, i, i), so))
end
else
if iszero(β)
C .= α .* reshape(view(Ã, :, 1, 1), so)
else
C .= β .* C .+ α .* reshape(view(Ã, :, 1, 1), so)
end
for i in 2:st
C .+= α .* reshape(view(Ã, :, i, i), so)
end
end
return C
end
function tensortrace!(C::AbstractArray,
A::AbstractArray, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
α::Number, β::Number, ::BaseCopy)
argcheck_tensortrace(C, A, p, q)
dimcheck_tensortrace(C, A, p, q)

szA = size(A)
so = TupleTools.getindices(szA, linearize(p))
st = prod(TupleTools.getindices(szA, q[1]))
Ã = reshape(permutedims(A, (linearize(p)..., linearize(q)...)), (prod(so), st, st))

if conjA
if iszero(β)
C .= α .* conj.(reshape(view(Ã, :, 1, 1), so))
else
C .= β .* C .+ α .* conj.(reshape(view(Ã, :, 1, 1), so))
end
for i in 2:st
C .+= α .* conj.(reshape(view(Ã, :, i, i), so))
end
else
if iszero(β)
C .= α .* reshape(view(Ã, :, 1, 1), so)
else
C .= β .* C .+ α .* reshape(view(Ã, :, 1, 1), so)
end
for i in 2:st
C .+= α .* reshape(view(Ã, :, i, i), so)
end
end
return C
end
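Both trace methods reshape the permuted array into an (open, traced, traced) cuboid Ã and accumulate its diagonal slices, i.e. C = β·C + α·Σᵢ Ã[:, i, i]. A quick cross-check of the `BaseCopy` path (the backend type is unexported, hence the qualified name):

A = randn(3, 4, 4)
C = zeros(3)
tensortrace!(C, A, ((1,), ()), ((2,), (3,)), false, 1.0, 0.0, TensorOperations.BaseCopy())
C ≈ [sum(A[i, j, j] for j in 1:4) for i in 1:3]  # true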

function tensorcontract!(C::AbstractArray,
A::AbstractArray, pA::Index2Tuple, conjA::Bool,
B::AbstractArray, pB::Index2Tuple, conjB::Bool,
pAB::Index2Tuple,
α::Number, β::Number, ::BaseView)
argcheck_tensorcontract(C, A, pA, B, pB, pAB)
dimcheck_tensorcontract(C, A, pA, B, pB, pAB)

szA = size(A)
szB = size(B)
soA = TupleTools.getindices(szA, pA[1])
soB = TupleTools.getindices(szB, pB[2])
sc = TupleTools.getindices(szA, pA[2])
soA1 = prod(soA)
soB1 = prod(soB)
sc1 = prod(sc)
pC = invperm(linearize(pAB))
C̃ = reshape(PermutedDimsArray(C, pC), (soA1, soB1))

if conjA && conjB
Ã = reshape(PermutedDimsArray(A, linearize(reverse(pA))), (sc1, soA1))
B̃ = reshape(PermutedDimsArray(B, linearize(reverse(pB))), (soB1, sc1))
C̃ = mul!(C̃, adjoint(Ã), adjoint(B̃), α, β)
elseif conjA
Ã = reshape(PermutedDimsArray(A, linearize(reverse(pA))), (sc1, soA1))
B̃ = reshape(PermutedDimsArray(B, linearize(pB)), (sc1, soB1))
C̃ = mul!(C̃, adjoint(Ã), B̃, α, β)
elseif conjB
Ã = reshape(PermutedDimsArray(A, linearize(pA)), (soA1, sc1))
B̃ = reshape(PermutedDimsArray(B, linearize(reverse(pB))), (soB1, sc1))
C̃ = mul!(C̃, Ã, adjoint(B̃), α, β)
else
Ã = reshape(PermutedDimsArray(A, linearize(pA)), (soA1, sc1))
B̃ = reshape(PermutedDimsArray(B, linearize(pB)), (sc1, soB1))
C̃ = mul!(C̃, Ã, B̃, α, β)
end
return C
end
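The pattern above is the classical reduction of an arbitrary contraction to a single matrix multiplication: permute the open indices of each operand to one side, flatten with `reshape`, and call `mul!`. A hand-rolled sketch of the same idea for C[i, l] = Σⱼₖ A[j, i, k]·B[j, k, l]:

using LinearAlgebra: mul!
A = randn(3, 2, 4); B = randn(3, 4, 5); C = zeros(2, 5)
Ã = reshape(PermutedDimsArray(A, (2, 1, 3)), (2, 12))  # open × contracted
B̃ = reshape(PermutedDimsArray(B, (1, 2, 3)), (12, 5))  # contracted × open
mul!(C, Ã, B̃)
C ≈ [sum(A[j, i, k] * B[j, k, l] for j in 1:3, k in 1:4) for i in 1:2, l in 1:5]  # true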
function tensorcontract!(C::AbstractArray,
A::AbstractArray, pA::Index2Tuple, conjA::Bool,
B::AbstractArray, pB::Index2Tuple, conjB::Bool,
pAB::Index2Tuple,
α::Number, β::Number, ::BaseCopy)
argcheck_tensorcontract(C, A, pA, B, pB, pAB)
dimcheck_tensorcontract(C, A, pA, B, pB, pAB)

szA = size(A)
szB = size(B)
soA = TupleTools.getindices(szA, pA[1])
soB = TupleTools.getindices(szB, pB[2])
sc = TupleTools.getindices(szA, pA[2])
soA1 = prod(soA)
soB1 = prod(soB)
sc1 = prod(sc)

if conjA && conjB
Ã = reshape(permutedims(A, linearize(reverse(pA))), (sc1, soA1))
B̃ = reshape(permutedims(B, linearize(reverse(pB))), (soB1, sc1))
ÃB̃ = reshape(adjoint(Ã) * adjoint(B̃), (soA..., soB...))
elseif conjA
Ã = reshape(permutedims(A, linearize(reverse(pA))), (sc1, soA1))
B̃ = reshape(permutedims(B, linearize(pB)), (sc1, soB1))
ÃB̃ = reshape(adjoint(Ã) * B̃, (soA..., soB...))
elseif conjB
Ã = reshape(permutedims(A, linearize(pA)), (soA1, sc1))
B̃ = reshape(permutedims(B, linearize(reverse(pB))), (soB1, sc1))
ÃB̃ = reshape(Ã * adjoint(B̃), (soA..., soB...))
else
Ã = reshape(permutedims(A, linearize(pA)), (soA1, sc1))
B̃ = reshape(permutedims(B, linearize(pB)), (sc1, soB1))
ÃB̃ = reshape(Ã * B̃, (soA..., soB...))
end
if istrivialpermutation(linearize(pAB))
pÃB̃ = ÃB̃
else
pÃB̃ = permutedims(ÃB̃, linearize(pAB))
end
if iszero(β)
C .= α .* pÃB̃
else
C .= β .* C .+ α .* pÃB̃
end
return C
end
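A numerical sanity check of the `BaseCopy` contraction against the default dispatch, assuming the unexported backend type:

A = randn(2, 3, 4); B = randn(4, 3, 5)
pA = ((1,), (2, 3)); pB = ((2, 1), (3,)); pAB = ((1, 2), ())
C1 = tensorcontract!(zeros(2, 5), A, pA, false, B, pB, false, pAB, 1.0, 0.0)
C2 = tensorcontract!(zeros(2, 5), A, pA, false, B, pB, false, pAB, 1.0, 0.0,
                     TensorOperations.BaseCopy())
C1 ≈ C2  # true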

# ------------------------------------------------------------------------------------------
# Argument Checking: can be used by backends to check the validity of the arguments
# ------------------------------------------------------------------------------------------
14 changes: 14 additions & 0 deletions src/implementation/diagonal.jl
@@ -1,6 +1,20 @@
#-------------------------------------------------------------------------------------------
# Specialized implementations for contractions involving diagonal matrices
#-------------------------------------------------------------------------------------------

# backend selection:
for (TC, TA, TB) in ((:AbstractArray, :AbstractArray, :Diagonal),
(:AbstractArray, :Diagonal, :AbstractArray), (:AbstractArray, :Diagonal, :Diagonal),
(:Diagonal, :Diagonal, :Diagonal))
@eval function tensorcontract!(C::$TC,
A::$TA, pA::Index2Tuple, conjA::Bool,
B::$TB, pB::Index2Tuple, conjB::Bool,
pAB::Index2Tuple, α::Number, β::Number)
return tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, StridedNative())
end
end
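For reference, the first tuple in this loop, (:AbstractArray, :AbstractArray, :Diagonal), expands to an ordinary method definition; written out by hand it is equivalent to:

function tensorcontract!(C::AbstractArray,
                         A::AbstractArray, pA::Index2Tuple, conjA::Bool,
                         B::Diagonal, pB::Index2Tuple, conjB::Bool,
                         pAB::Index2Tuple, α::Number, β::Number)
    return tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, StridedNative())
end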

# actual implementations:
function tensorcontract!(C::AbstractArray,
A::AbstractArray, pA::Index2Tuple, conjA::Bool,
B::Diagonal, pB::Index2Tuple, conjB::Bool,
2 changes: 1 addition & 1 deletion src/implementation/functions.jl
@@ -7,7 +7,7 @@

"""
tensorcopy([IC=IA], A, IA, [conjA=false, [α=1]])
tensorcopy(A, pA::Index2Tuple, conjA, α) # expert mode
tensorcopy(A, pA::Index2Tuple, conjA, α, [backend]) # expert mode
Create a copy of `A`, where the dimensions of `A` are assigned indices from the
iterable `IA` and the indices of the copy are contained in `IC`. Both iterables
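A sketch of the documented expert-mode call with an explicit backend (backend types are unexported, hence the qualified name):

A = randn(2, 3, 4)
B = tensorcopy(A, ((3, 1, 2), ()), false, 1.0)                               # default backend
B′ = tensorcopy(A, ((3, 1, 2), ()), false, 1.0, TensorOperations.BaseCopy()) # explicit backend
B ≈ B′ && size(B) == (4, 2, 3)  # true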
8 changes: 4 additions & 4 deletions src/implementation/strided.jl
@@ -4,20 +4,20 @@

# default backends
function tensoradd!(C::StridedView,
A::StridedView, pA::Index2Tuple, conjA::Symbol,
A::StridedView, pA::Index2Tuple, conjA::Bool,
α::Number, β::Number)
backend = eltype(C) <: BlasFloat ? StridedBLAS() : StridedNative()
return tensoradd!(C, A, pA, conjA, α, β, backend)
end
function tensortrace!(C::StridedView,
A::StridedView, p::Index2Tuple, q::Index2Tuple, conjA::Symbol,
A::StridedView, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
α::Number, β::Number)
backend = eltype(C) <: BlasFloat ? StridedBLAS() : StridedNative()
return tensortrace!(C, A, p, q, conjA, α, β, backend)
end
function tensorcontract!(C::StridedView,
A::StridedView, pA::Index2Tuple, conjA::Symbol,
B::StridedView, pB::Index2Tuple, conjB::Symbol,
A::StridedView, pA::Index2Tuple, conjA::Bool,
B::StridedView, pB::Index2Tuple, conjB::Bool,
pAB::Index2Tuple, α::Number, β::Number)
backend = eltype(C) <: BlasFloat ? StridedBLAS() : StridedNative()
return tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, backend)
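The subtype test matters here: `eltype(C)` returns a type object, so `<:` is the correct comparison (an `isa` check against the `BlasFloat` union is always false for a type). For example:

using LinearAlgebra: BlasFloat
eltype(randn(2, 2)) <: BlasFloat        # true: Float64 has BLAS kernels
eltype(randn(2, 2)) isa BlasFloat       # false: the type Float64 is not an instance
eltype(zeros(BigFloat, 2)) <: BlasFloat # false: BigFloat takes the native path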
38 changes: 17 additions & 21 deletions src/indexnotation/contractiontrees.jl
@@ -116,14 +116,13 @@ function insertcontractiontrees!(ex, treebuilder, treesorter, costcheck, preexpr
optimalordersym = gensym("optimalorder")
if costcheck == :warn
costcompareex = :(@notensor begin
$currentcostsym = first(TensorOperations.treecost($tree,
$network,
$costmapsym))
$optimaltreesym, $optimalcostsym = TensorOperations.optimaltree($network,
$costmapsym)
$currentcostsym = first(treecost($tree, $network,
$costmapsym))
$optimaltreesym, $optimalcostsym = optimaltree($network,
$costmapsym)
if $currentcostsym > $optimalcostsym
$optimalordersym = tuple(first(TensorOperations.tree2indexorder($optimaltreesym,
$network))...)
$optimalordersym = tuple(first(tree2indexorder($optimaltreesym,
$network))...)
@warn "Tensor network: " *
$(string(ex)) *
":\n" *
@@ -132,21 +131,18 @@ end)
end)
elseif costcheck == :cache
key = Expr(:quote, ex)
cacheref = GlobalRef(TensorOperations, :costcache)
costcompareex = :(@notensor begin
$currentcostsym = first(TensorOperations.treecost($tree,
$network,
$costmapsym))
if !($key in
keys(TensorOperations.costcache)) ||
first(TensorOperations.costcache[$key]) <
$(currentcostsym)
$optimaltreesym, $optimalcostsym = TensorOperations.optimaltree($network,
$costmapsym)
$optimalordersym = tuple(first(TensorOperations.tree2indexorder($optimaltreesym,
$network))...)
TensorOperations.costcache[$key] = ($currentcostsym,
$optimalcostsym,
$optimalordersym)
$currentcostsym = first(treecost($tree, $network,
$costmapsym))
if !($key in keys($cacheref)) ||
first($cacheref[$key]) < $(currentcostsym)
$optimaltreesym, $optimalcostsym = optimaltree($network,
$costmapsym)
$optimalordersym = tuple(first(tree2indexorder($optimaltreesym,
$network))...)
$cacheref[$key] = ($currentcostsym, $optimalcostsym,
$optimalordersym)
end
end)
end
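This machinery is driven by the `costcheck` keyword of `@tensor`; a hedged usage sketch:

using TensorOperations
A = randn(5, 20); B = randn(20, 20); C = randn(20, 5)
# warn if the written contraction order is costlier than the optimal one
@tensor costcheck = :warn D[i, l] := A[i, j] * B[j, k] * C[k, l]
# or record (current cost, optimal cost, optimal order) in TensorOperations.costcache
@tensor costcheck = :cache D[i, l] := A[i, j] * B[j, k] * C[k, l]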