Add naive Base julia AbstractArray implementation #171

Merged · 13 commits · Jun 23, 2024
2 changes: 1 addition & 1 deletion Project.toml
@@ -36,7 +36,7 @@ Logging = "1.6"
PackageExtensionCompat = "1"
Random = "1"
Strided = "2.0.4"
-StridedViews = "0.2"
+StridedViews = "0.3"
Test = "1"
TupleTools = "1.1"
VectorInterface = "0.4.1"
1 change: 1 addition & 0 deletions src/TensorOperations.jl
@@ -6,6 +6,7 @@ using TupleTools: TupleTools, isperm, invperm
using LinearAlgebra
using LinearAlgebra: mul!, BlasFloat
using Strided
+using StridedViews: isstrided
using LRUCache

using Base.Meta: isexpr
240 changes: 232 additions & 8 deletions src/implementation/abstractarray.jl
@@ -9,33 +9,54 @@ function checkcontractible(A::AbstractArray, iA, conjA::Bool,
return nothing
end

-# TODO
-# add check for stridedness of abstract arrays and add a pure implementation as fallback

const StridedNative = Backend{:StridedNative}
const StridedBLAS = Backend{:StridedBLAS}
+const BaseView = Backend{:BaseView}
+const BaseCopy = Backend{:BaseCopy}

function tensoradd!(C::AbstractArray,
A::AbstractArray, pA::Index2Tuple, conjA::Bool,
α::Number, β::Number)
-    return tensoradd!(C, A, pA, conjA, α, β, StridedNative())
+    if isstrided(A) && isstrided(C)
+        return tensoradd!(C, A, pA, conjA, α, β, StridedNative())
+    else
+        return tensoradd!(C, A, pA, conjA, α, β, BaseView())
+    end
end

function tensortrace!(C::AbstractArray,
A::AbstractArray, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
α::Number, β::Number)
-    return tensortrace!(C, A, p, q, conjA, α, β, StridedNative())
+    if isstrided(A) && isstrided(C)
+        return tensortrace!(C, A, p, q, conjA, α, β, StridedNative())
+    else
+        return tensortrace!(C, A, p, q, conjA, α, β, BaseView())
+    end
end

function tensorcontract!(C::AbstractArray,
A::AbstractArray, pA::Index2Tuple, conjA::Bool,
B::AbstractArray, pB::Index2Tuple, conjB::Bool,
pAB::Index2Tuple,
α::Number, β::Number)
-    if eltype(C) <: LinearAlgebra.BlasFloat && !isa(B, Diagonal) && !isa(A, Diagonal)
-        return tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, StridedBLAS())
+    if eltype(C) <: LinearAlgebra.BlasFloat
+        if isstrided(A) && isstrided(B) && isstrided(C)
+            return tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, StridedBLAS())
+        elseif (isstrided(A) || isa(A, Diagonal)) && (isstrided(B) || isa(B, Diagonal)) &&
+               isstrided(C)
+            return tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β,
+                                   StridedNative())
+        else
+            return tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, BaseCopy())
+        end
    else
-        return tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, StridedNative())
+        if (isstrided(A) || isa(A, Diagonal)) && (isstrided(B) || isa(B, Diagonal)) &&
+           isstrided(C)
+            return tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β,
+                                   StridedNative())
+        else
+            return tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, BaseView())
+        end
end
end
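As a quick illustration of the new dispatch (a sketch, not part of the diff; it assumes the exported expert-mode tensoradd! and isstrided from StridedViews): a SparseMatrixCSC is an AbstractArray but is not strided, so it now falls through to the BaseView backend, whereas previously the call would have attempted the strided path and errored.

    using TensorOperations, SparseArrays
    using StridedViews: isstrided

    S = sprand(4, 4, 0.5)          # AbstractArray, but not strided
    C = zeros(4, 4)
    isstrided(S)                   # false -> the call below routes to BaseView()

    # permute-and-scale: C = 1.0 * transpose(S) (β = 0 overwrites C)
    tensoradd!(C, S, ((2, 1), ()), false, 1.0, 0.0)
    C ≈ Matrix(transpose(S))       # true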

@@ -81,6 +102,209 @@ function tensorcontract!(C::AbstractArray,
return C
end

#-------------------------------------------------------------------------------------------
# Implementation based on Base + LinearAlgebra
#-------------------------------------------------------------------------------------------
# Note that this is mostly for convenience + checking, and not for performance
function tensoradd!(C::AbstractArray,
A::AbstractArray, pA::Index2Tuple, conjA::Bool,
α::Number, β::Number, ::BaseView)
argcheck_tensoradd(C, A, pA)
dimcheck_tensoradd(C, A, pA)

# can we assume that C is mutable?
# is there more functionality in base that we can use?
if conjA
if iszero(β)
C .= α .* conj.(PermutedDimsArray(A, linearize(pA)))
else
C .= β .* C .+ α .* conj.(PermutedDimsArray(A, linearize(pA)))
end
else
if iszero(β)
C .= α .* PermutedDimsArray(A, linearize(pA))
else
C .= β .* C .+ α .* PermutedDimsArray(A, linearize(pA))
end
end
return C
end
function tensoradd!(C::AbstractArray,
A::AbstractArray, pA::Index2Tuple, conjA::Bool,
α::Number, β::Number, ::BaseCopy)
argcheck_tensoradd(C, A, pA)
dimcheck_tensoradd(C, A, pA)

# can we assume that C is mutable?
# is there more functionality in base that we can use?
Comment (Collaborator, author): I think there is an in-place permutedims!, such that we could use the allocation interface to also hijack into allocating these temporary arrays.

Reply (Owner): That might be possible indeed. The question is whether it is worth it. This will probably only be used for types which are very different from strided arrays, e.g. sparse arrays.

Reply (Collaborator, author): That's probably a very valid point. It's probably fair to assume that we cannot guarantee optimal performance without extra information about the specific type anyway, and the system does allow one to easily implement custom backends if necessary. The Base backend should serve mostly as a catch-all implementation that ensures things work for most types.

if conjA
if iszero(β)
C .= α .* conj.(permutedims(A, linearize(pA)))
else
C .= β .* C .+ α .* conj.(permutedims(A, linearize(pA)))
end
else
if iszero(β)
C .= α .* permutedims(A, linearize(pA))
else
C .= β .* C .+ α .* permutedims(A, linearize(pA))
end
end
return C
end
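A sketch of the permutedims!-based variant floated in the review thread above (hypothetical, not part of the PR; the name tensoradd_copy! and the plain similar call stand in for the allocation interface, e.g. tensoralloc_add):

    # Hypothetical variant: reuse a preallocated buffer via Base's in-place
    # permutedims! instead of allocating inside the broadcast expression.
    function tensoradd_copy!(C::AbstractArray, A::AbstractArray, pA, conjA::Bool,
                             α::Number, β::Number)
        perm = (pA[1]..., pA[2]...)        # same as linearize(pA)
        tmp = similar(A, size(C))          # temporary; could come from the allocation interface
        permutedims!(tmp, A, perm)         # in-place permutation from Base
        conjA && (tmp .= conj.(tmp))
        if iszero(β)
            C .= α .* tmp                  # do not read possibly uninitialized C
        else
            C .= β .* C .+ α .* tmp
        end
        return C
    end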

# tensortrace
function tensortrace!(C::AbstractArray,
A::AbstractArray, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
α::Number, β::Number, ::BaseView)
argcheck_tensortrace(C, A, p, q)
dimcheck_tensortrace(C, A, p, q)

szA = size(A)
so = TupleTools.getindices(szA, linearize(p))
st = prod(TupleTools.getindices(szA, q[1]))
à = reshape(PermutedDimsArray(A, (linearize(p)..., linearize(q)...)),
(prod(so), st * st))

if conjA
if iszero(β)
C .= α .* conj.(reshape(view(Ã, :, 1, 1), so))
else
C .= β .* C .+ α .* conj.(reshape(view(Ã, :, 1, 1), so))
end
for i in 2:st
C .+= α .* conj.(reshape(view(Ã, :, i, i), so))
end
else
if iszero(β)
C .= α .* reshape(view(Ã, :, 1, 1), so)
else
C .= β .* C .+ α .* reshape(view(Ã, :, 1, 1), so)
end
for i in 2:st
C .+= α .* reshape(view(Ã, :, i, i), so)
end
end
return C
end
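A small sanity check of this path (a sketch; the backend is passed explicitly through the module since the backend singletons are not exported): tracing the last two dimensions of a rank-3 array.

    using TensorOperations

    A = rand(ComplexF64, 2, 3, 3)
    C = zeros(ComplexF64, 2)
    # p = ((1,), ()): the open index; q = ((2,), (3,)): the traced pair
    tensortrace!(C, A, ((1,), ()), ((2,), (3,)), false, 1.0, 0.0,
                 TensorOperations.BaseView())
    C ≈ [sum(A[i, j, j] for j in 1:3) for i in 1:2]   # true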
function tensortrace!(C::AbstractArray,
A::AbstractArray, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
α::Number, β::Number, ::BaseCopy)
argcheck_tensortrace(C, A, p, q)
dimcheck_tensortrace(C, A, p, q)

szA = size(A)
so = TupleTools.getindices(szA, linearize(p))
st = prod(TupleTools.getindices(szA, q[1]))
à = reshape(permutedims(A, (linearize(p)..., linearize(q)...)), (prod(so), st * st))

if conjA
if iszero(β)
C .= α .* conj.(reshape(view(Ã, :, 1, 1), so))
else
C .= β .* C .+ α .* conj.(reshape(view(Ã, :, 1, 1), so))
end
for i in 2:st
C .+= α .* conj.(reshape(view(Ã, :, i, i), so))
end
Comment on lines +202 to +209 (Collaborator, author): Can we rewrite this with sum?

Reply (Owner): I don't know? In a way that does not cause additional allocations? Is there an issue with the current approach?

Reply (Collaborator, author): I wrote this in reply to the comment about "is there more base functionality we can use?". I don't think there is any issue with the current approach.

else
if iszero(β)
C .= α .* reshape(view(Ã, :, 1, 1), so)
else
C .= β .* C .+ α .* reshape(view(Ã, :, 1, 1), so)
end
for i in 2:st
C .+= α .* reshape(view(Ã, :, i, i), so)
end
end
return C
end
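For reference, the sum-based formulation raised in the thread above could look roughly as follows (hypothetical, not adopted; it reuses the internal helpers linearize and TupleTools.getindices visible in the surrounding code, and it allocates the intermediate that sum produces, which was exactly the owner's reservation):

    function tensortrace_sum!(C::AbstractArray, A::AbstractArray, p, q, conjA::Bool,
                              α::Number, β::Number)
        so = TupleTools.getindices(size(A), linearize(p))
        st = prod(TupleTools.getindices(size(A), q[1]))
        Ã = reshape(permutedims(A, (linearize(p)..., linearize(q)...)),
                    (prod(so), st, st))
        d = sum(view(Ã, :, i, i) for i in 1:st)   # allocates one vector of length prod(so)
        conjA && (d = conj.(d))
        if iszero(β)
            C .= α .* reshape(d, so)
        else
            C .= β .* C .+ α .* reshape(d, so)
        end
        return C
    end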

function tensorcontract!(C::AbstractArray,
A::AbstractArray, pA::Index2Tuple, conjA::Bool,
B::AbstractArray, pB::Index2Tuple, conjB::Bool,
pAB::Index2Tuple,
α::Number, β::Number, ::BaseView)
argcheck_tensorcontract(C, A, pA, B, pB, pAB)
dimcheck_tensorcontract(C, A, pA, B, pB, pAB)

szA = size(A)
szB = size(B)
soA = TupleTools.getindices(szA, pA[1])
soB = TupleTools.getindices(szB, pB[2])
sc = TupleTools.getindices(szA, pA[2])
soA1 = prod(soA)
soB1 = prod(soB)
sc1 = prod(sc)
pC = invperm(linearize(pAB))
C̃ = reshape(PermutedDimsArray(C, pC), (soA1, soB1))

if conjA && conjB
à = reshape(PermutedDimsArray(A, linearize(reverse(pA))), (sc1, soA1))
B̃ = reshape(PermutedDimsArray(B, linearize(reverse(pB))), (soB1, sc1))
C̃ = mul!(C̃, adjoint(Ã), adjoint(B̃), α, β)
elseif conjA
à = reshape(PermutedDimsArray(A, linearize(reverse(pA))), (sc1, soA1))
B̃ = reshape(PermutedDimsArray(B, linearize(pB)), (sc1, soB1))
C̃ = mul!(C̃, adjoint(Ã), B̃, α, β)
elseif conjB
à = reshape(PermutedDimsArray(A, linearize(pA)), (soA1, sc1))
B̃ = reshape(PermutedDimsArray(B, linearize(reverse(pB))), (soB1, sc1))
C̃ = mul!(C̃, Ã, adjoint(B̃), α, β)
else
à = reshape(PermutedDimsArray(A, linearize(pA)), (soA1, sc1))
B̃ = reshape(PermutedDimsArray(B, linearize(pB)), (sc1, soB1))
C̃ = mul!(C̃, Ã, B̃, α, β)
end
return C
end
function tensorcontract!(C::AbstractArray,
A::AbstractArray, pA::Index2Tuple, conjA::Bool,
B::AbstractArray, pB::Index2Tuple, conjB::Bool,
pAB::Index2Tuple,
α::Number, β::Number, ::BaseCopy)
argcheck_tensorcontract(C, A, pA, B, pB, pAB)
dimcheck_tensorcontract(C, A, pA, B, pB, pAB)

szA = size(A)
szB = size(B)
soA = TupleTools.getindices(szA, pA[1])
soB = TupleTools.getindices(szB, pB[2])
sc = TupleTools.getindices(szA, pA[2])
soA1 = prod(soA)
soB1 = prod(soB)
sc1 = prod(sc)

if conjA && conjB
à = reshape(permutedims(A, linearize(reverse(pA))), (sc1, soA1))
B̃ = reshape(permutedims(B, linearize(reverse(pB))), (soB1, sc1))
ÃB̃ = reshape(adjoint(Ã) * adjoint(B̃), (soA..., soB...))
elseif conjA
à = reshape(permutedims(A, linearize(reverse(pA))), (sc1, soA1))
B̃ = reshape(permutedims(B, linearize(pB)), (sc1, soB1))
ÃB̃ = reshape(adjoint(Ã) * B̃, (soA..., soB...))
elseif conjB
à = reshape(permutedims(A, linearize(pA)), (soA1, sc1))
B̃ = reshape(permutedims(B, linearize(reverse(pB))), (soB1, sc1))
ÃB̃ = reshape(Ã * adjoint(B̃), (soA..., soB...))
else
à = reshape(permutedims(A, linearize(pA)), (soA1, sc1))
B̃ = reshape(permutedims(B, linearize(pB)), (sc1, soB1))
ÃB̃ = reshape(Ã * B̃, (soA..., soB...))
end
if istrivialpermutation(linearize(pAB))
pÃB̃ = ÃB̃
else
pÃB̃ = permutedims(ÃB̃, linearize(pAB))
end
if iszero(β)
C .= α .* pÃB̃
else
C .= β .* C .+ α .* pÃB̃
end
return C
end
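A quick consistency check across the two new backends (a sketch; backend singletons again accessed through the module): contracting a (2, 3, 4) array with a (4, 3, 5) array over the size-3 and size-4 dimensions.

    using TensorOperations

    A, B = rand(2, 3, 4), rand(4, 3, 5)
    pA  = ((1,), (2, 3))      # open: dim 1; contracted: dims 2, 3 (sizes 3, 4)
    pB  = ((2, 1), (3,))      # contracted: dims 2, 1 (sizes 3, 4); open: dim 3
    pAB = ((1, 2), ())        # trivial output permutation
    C1, C2 = zeros(2, 5), zeros(2, 5)
    tensorcontract!(C1, A, pA, false, B, pB, false, pAB, 1.0, 0.0,
                    TensorOperations.BaseView())
    tensorcontract!(C2, A, pA, false, B, pB, false, pAB, 1.0, 0.0,
                    TensorOperations.BaseCopy())
    C1 ≈ C2   # true; for these dense inputs both also match the strided result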

# ------------------------------------------------------------------------------------------
# Argument Checking: can be used by backends to check the validity of the arguments
# ------------------------------------------------------------------------------------------
14 changes: 14 additions & 0 deletions src/implementation/diagonal.jl
@@ -1,6 +1,20 @@
#-------------------------------------------------------------------------------------------
# Specialized implementations for contractions involving diagonal matrices
#-------------------------------------------------------------------------------------------

# backend selection:
for (TC, TA, TB) in ((:AbstractArray, :AbstractArray, :Diagonal),
                     (:AbstractArray, :Diagonal, :AbstractArray),
                     (:AbstractArray, :Diagonal, :Diagonal),
                     (:Diagonal, :Diagonal, :Diagonal))
    @eval function tensorcontract!(C::$TC,
                                   A::$TA, pA::Index2Tuple, conjA::Bool,
                                   B::$TB, pB::Index2Tuple, conjB::Bool,
                                   pAB::Index2Tuple, α::Number, β::Number)
        return tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, StridedNative())
    end
end
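Each iteration of this loop simply @evals a concrete forwarding method; e.g. the (:AbstractArray, :Diagonal, :AbstractArray) case expands to:

    function tensorcontract!(C::AbstractArray,
                             A::Diagonal, pA::Index2Tuple, conjA::Bool,
                             B::AbstractArray, pB::Index2Tuple, conjB::Bool,
                             pAB::Index2Tuple, α::Number, β::Number)
        return tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, StridedNative())
    end

Since Diagonal is not strided, these methods presumably exist so that contractions involving Diagonal keep taking the specialized implementations below rather than the new isstrided-based selection.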

# actual implementations:
function tensorcontract!(C::AbstractArray,
A::AbstractArray, pA::Index2Tuple, conjA::Bool,
B::Diagonal, pB::Index2Tuple, conjB::Bool,
2 changes: 1 addition & 1 deletion src/implementation/functions.jl
@@ -7,7 +7,7 @@

"""
tensorcopy([IC=IA], A, IA, [conjA=false, [α=1]])
-    tensorcopy(A, pA::Index2Tuple, conjA, α) # expert mode
+    tensorcopy(A, pA::Index2Tuple, conjA, α, [backend]) # expert mode

Create a copy of `A`, where the dimensions of `A` are assigned indices from the
iterable `IA` and the indices of the copy are contained in `IC`. Both iterables
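In the expert-mode form the backend can now be forced explicitly; a usage sketch, assuming the signature shown in the docstring above:

    using TensorOperations

    A = rand(3, 4)
    B = tensorcopy(A, ((2, 1), ()), false, 1.0, TensorOperations.BaseCopy())
    B ≈ permutedims(A, (2, 1))   # true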
8 changes: 4 additions & 4 deletions src/implementation/strided.jl
@@ -4,20 +4,20 @@

# default backends
function tensoradd!(C::StridedView,
-                    A::StridedView, pA::Index2Tuple, conjA::Symbol,
+                    A::StridedView, pA::Index2Tuple, conjA::Bool,
α::Number, β::Number)
    backend = eltype(C) <: BlasFloat ? StridedBLAS() : StridedNative()
return tensoradd!(C, A, pA, conjA, α, β, backend)
end
function tensortrace!(C::StridedView,
-                     A::StridedView, p::Index2Tuple, q::Index2Tuple, conjA::Symbol,
+                     A::StridedView, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
α::Number, β::Number)
    backend = eltype(C) <: BlasFloat ? StridedBLAS() : StridedNative()
return tensortrace!(C, A, p, q, conjA, α, β, backend)
end
function tensorcontract!(C::StridedView,
-                         A::StridedView, pA::Index2Tuple, conjA::Symbol,
-                         B::StridedView, pB::Index2Tuple, conjB::Symbol,
+                         A::StridedView, pA::Index2Tuple, conjA::Bool,
+                         B::StridedView, pB::Index2Tuple, conjB::Bool,
pAB::Index2Tuple, α::Number, β::Number)
    backend = eltype(C) <: BlasFloat ? StridedBLAS() : StridedNative()
return tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, backend)
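The default here is decided purely by the element type (illustration, not part of the diff):

    using LinearAlgebra: BlasFloat

    eltype(rand(Float64, 2, 2)) <: BlasFloat    # true  -> StridedBLAS
    eltype(rand(BigFloat, 2, 2)) <: BlasFloat   # false -> StridedNative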
38 changes: 17 additions & 21 deletions src/indexnotation/contractiontrees.jl
@@ -116,14 +116,13 @@ function insertcontractiontrees!(ex, treebuilder, treesorter, costcheck, preexpr
optimalordersym = gensym("optimalorder")
if costcheck == :warn
costcompareex = :(@notensor begin
-        $currentcostsym = first(TensorOperations.treecost($tree,
-                                                          $network,
-                                                          $costmapsym))
-        $optimaltreesym, $optimalcostsym = TensorOperations.optimaltree($network,
-                                                                        $costmapsym)
+        $currentcostsym = first(treecost($tree, $network,
+                                         $costmapsym))
+        $optimaltreesym, $optimalcostsym = optimaltree($network,
+                                                       $costmapsym)
        if $currentcostsym > $optimalcostsym
-            $optimalordersym = tuple(first(TensorOperations.tree2indexorder($optimaltreesym,
-                                                                            $network))...)
+            $optimalordersym = tuple(first(tree2indexorder($optimaltreesym,
+                                                           $network))...)
@warn "Tensor network: " *
$(string(ex)) *
":\n" *
@@ -132,21 +131,18 @@
end)
elseif costcheck == :cache
key = Expr(:quote, ex)
+        cacheref = GlobalRef(TensorOperations, :costcache)
costcompareex = :(@notensor begin
-        $currentcostsym = first(TensorOperations.treecost($tree,
-                                                          $network,
-                                                          $costmapsym))
-        if !($key in
-             keys(TensorOperations.costcache)) ||
-           first(TensorOperations.costcache[$key]) <
-           $(currentcostsym)
-            $optimaltreesym, $optimalcostsym = TensorOperations.optimaltree($network,
-                                                                            $costmapsym)
-            $optimalordersym = tuple(first(TensorOperations.tree2indexorder($optimaltreesym,
-                                                                            $network))...)
-            TensorOperations.costcache[$key] = ($currentcostsym,
-                                                $optimalcostsym,
-                                                $optimalordersym)
+        $currentcostsym = first(treecost($tree, $network,
+                                         $costmapsym))
+        if !($key in keys($cacheref)) ||
+           first($cacheref[$key]) < $(currentcostsym)
+            $optimaltreesym, $optimalcostsym = optimaltree($network,
+                                                           $costmapsym)
+            $optimalordersym = tuple(first(tree2indexorder($optimaltreesym,
+                                                           $network))...)
+            $cacheref[$key] = ($currentcostsym, $optimalcostsym,
+                               $optimalordersym)
end
end)
end
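For context, this code path is exercised via the @tensor keyword (a sketch, assuming the v4-style keyword syntax):

    using TensorOperations

    A, B = rand(5, 5, 5), rand(5, 5, 5)
    # With costcheck = warn, the macro compares the cost of the parsed contraction
    # order against the optimal order at runtime and warns when a better one exists;
    # costcheck = cache instead records suboptimal cases in TensorOperations.costcache.
    @tensor costcheck = warn D[i, m] := A[i, j, k] * B[j, k, m]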