Add naive Base julia AbstractArray implementation #171

Merged
merged 13 commits on Jun 23, 2024
2 changes: 1 addition & 1 deletion Project.toml
@@ -36,7 +36,7 @@ Logging = "1.6"
PackageExtensionCompat = "1"
Random = "1"
Strided = "2.0.4"
-StridedViews = "0.2"
+StridedViews = "0.3"
Test = "1"
TupleTools = "1.1"
VectorInterface = "0.4.1"
1 change: 1 addition & 0 deletions src/TensorOperations.jl
@@ -6,6 +6,7 @@ using TupleTools: TupleTools, isperm, invperm
using LinearAlgebra
using LinearAlgebra: mul!, BlasFloat
using Strided
+using StridedViews: isstrided
using LRUCache

using Base.Meta: isexpr
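The newly imported `isstrided` is what drives the backend dispatch below. A quick sketch of its expected behavior (assumed semantics of `StridedViews.isstrided`, worth verifying against the installed version):

using LinearAlgebra: Diagonal
using StridedViews: isstrided

isstrided(rand(3, 4))                # true: dense arrays have strides
isstrided(view(rand(4, 4), 1:2, :))  # true: range-indexed views stay strided
isstrided(Diagonal(rand(3)))         # false: structured matrices have no strides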
217 changes: 183 additions & 34 deletions src/implementation/abstractarray.jl
@@ -11,28 +11,52 @@ end

const StridedNative = Backend{:StridedNative}
const StridedBLAS = Backend{:StridedBLAS}
+const BaseView = Backend{:BaseView}
+const BaseCopy = Backend{:BaseCopy}

-function tensoradd!(C::StridedArray,
-                    A::StridedArray, pA::Index2Tuple, conjA::Bool,
+function tensoradd!(C::AbstractArray,
+                    A::AbstractArray, pA::Index2Tuple, conjA::Bool,
                     α::Number, β::Number)
-    return tensoradd!(C, A, pA, conjA, α, β, StridedNative())
+    if isstrided(A) && isstrided(C)
+        return tensoradd!(C, A, pA, conjA, α, β, StridedNative())
+    else
+        return tensoradd!(C, A, pA, conjA, α, β, BaseView())
+    end
end

-function tensortrace!(C::StridedArray,
-                      A::StridedArray, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
+function tensortrace!(C::AbstractArray,
+                      A::AbstractArray, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
                       α::Number, β::Number)
-    return tensortrace!(C, A, p, q, conjA, α, β, StridedNative())
+    if isstrided(A) && isstrided(C)
+        return tensortrace!(C, A, p, q, conjA, α, β, StridedNative())
+    else
+        return tensortrace!(C, A, p, q, conjA, α, β, BaseView())
+    end
end

-function tensorcontract!(C::StridedArray,
-                         A::StridedArray, pA::Index2Tuple, conjA::Bool,
-                         B::StridedArray, pB::Index2Tuple, conjB::Bool,
+function tensorcontract!(C::AbstractArray,
+                         A::AbstractArray, pA::Index2Tuple, conjA::Bool,
+                         B::AbstractArray, pB::Index2Tuple, conjB::Bool,
                          pAB::Index2Tuple,
                          α::Number, β::Number)
-    if eltype(C) <: LinearAlgebra.BlasFloat && !isa(B, Diagonal) && !isa(A, Diagonal)
-        return tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, StridedBLAS())
+    if eltype(C) <: LinearAlgebra.BlasFloat
+        if isstrided(A) && isstrided(B) && isstrided(C)
+            return tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, StridedBLAS())
+        elseif (isstrided(A) || isa(A, Diagonal)) && (isstrided(B) || isa(B, Diagonal)) &&
+               isstrided(C)
+            return tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β,
+                                   StridedNative())
+        else
+            return tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, BaseCopy())
+        end
     else
-        return tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, StridedNative())
+        if (isstrided(A) || isa(A, Diagonal)) && (isstrided(B) || isa(B, Diagonal)) &&
+           isstrided(C)
+            return tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β,
+                                   StridedNative())
+        else
+            return tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, BaseView())
+        end
    end
end
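As a usage sketch (hypothetical shapes; `BaseView` is qualified because the backend types are presumably not exported): the entry points now accept any `AbstractArray` and pick a backend automatically, but a backend can still be forced explicitly.

using TensorOperations

A = randn(2, 3, 4)
C = zeros(3, 4, 2)

# Dense arrays are strided, so this dispatches to StridedNative automatically:
tensoradd!(C, A, ((2, 3), (1,)), false, 1.0, 0.0)

# The same operation with the Base fallback forced explicitly:
tensoradd!(C, A, ((2, 3), (1,)), false, 1.0, 0.0, TensorOperations.BaseView())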

@@ -82,39 +106,116 @@ end
# Implementation based on Base + LinearAlgebra
#-------------------------------------------------------------------------------------------
# Note that this is mostly for convenience + checking, and not for performance
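The two new backends differ only in how they realize the permutation: `BaseView` goes through lazy `PermutedDimsArray` wrappers, while `BaseCopy` materializes the permutation with `permutedims`. A minimal illustration of that distinction in plain Base Julia:

A = rand(2, 3)
V = PermutedDimsArray(A, (2, 1))  # lazy view: no allocation, reads through to A
P = permutedims(A, (2, 1))        # materialized copy: independent of A
A[1, 2] = 0.0
V[2, 1] == 0.0   # true: the view reflects the mutation
P[2, 1] == 0.0   # false (almost surely): the copy does not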
function tensoradd!(C::AbstractArray,
                    A::AbstractArray, pA::Index2Tuple, conjA::Bool,
                    α::Number, β::Number, ::BaseView)
    argcheck_tensoradd(C, A, pA)
    dimcheck_tensoradd(C, A, pA)

    # can we assume that C is mutable?
    # is there more functionality in base that we can use?
    if conjA
        if iszero(β)
            C .= α .* conj.(PermutedDimsArray(A, linearize(pA)))
        else
            C .= β .* C .+ α .* conj.(PermutedDimsArray(A, linearize(pA)))
        end
    else
        if iszero(β)
            C .= α .* PermutedDimsArray(A, linearize(pA))
        else
            C .= β .* C .+ α .* PermutedDimsArray(A, linearize(pA))
        end
    end
    return C
end
function tensoradd!(C::AbstractArray,
                    A::AbstractArray, pA::Index2Tuple, conjA::Bool,
-                   α::Number, β::Number)
+                   α::Number, β::Number, ::BaseCopy)
    argcheck_tensoradd(C, A, pA)
    dimcheck_tensoradd(C, A, pA)

    # can we assume that C is mutable?
    # is there more functionality in base that we can use?
Review thread (Collaborator Author): I think there is an in-place permutedims!, such that we could use the allocation interface to also hijack into allocating these temporary arrays.

Reply (Owner): That might be possible indeed. The question is whether it is worth it. This will probably only be used for types which are very different from strided arrays, e.g. sparse arrays.

Reply (Collaborator Author): I think that's probably a very valid point. It's probably fair to assume that we cannot guarantee optimal performance without extra information about the specific type anyways, and the system does allow to easily implement custom backends if necessary. The base backend should serve mostly as a catch-all implementation that ensures that it works for most types.

    if conjA
-        C .= β .* C .+ α .* conj.(PermutedDimsArray(A, linearize(pA)))
+        if iszero(β)
+            C .= α .* conj.(permutedims(A, linearize(pA)))
+        else
+            C .= β .* C .+ α .* conj.(permutedims(A, linearize(pA)))
+        end
    else
-        C .= β .* C .+ α .* PermutedDimsArray(A, linearize(pA))
+        if iszero(β)
+            C .= α .* permutedims(A, linearize(pA))
+        else
+            C .= β .* C .+ α .* permutedims(A, linearize(pA))
+        end
    end
    return C
end
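For the in-place `permutedims!` idea from the thread above, a rough sketch of what the `BaseCopy` branch could look like with the temporary allocated explicitly (hypothetical helper name, not part of this PR; the suggestion is to later route the `similar` call through the package's allocation interface):

using TensorOperations: TensorOperations

# Hypothetical sketch of the in-place variant discussed above.
function tensoradd_copy_sketch!(C, A, pA, conjA::Bool, α::Number, β::Number)
    perm = TensorOperations.linearize(pA)  # internal helper, as used in the PR
    Atmp = similar(A, size(C))   # temporary in C's shape
    permutedims!(Atmp, A, perm)  # filled in place, no lazy wrapper
    if conjA
        C .= β .* C .+ α .* conj.(Atmp)
    else
        C .= β .* C .+ α .* Atmp
    end
    return C
end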

-# For now I am giving up on writing a generic tensortrace! that works for all AbstractArray types...
+# tensortrace
function tensortrace!(C::AbstractArray,
                      A::AbstractArray, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
                      α::Number, β::Number, ::BaseView)
    argcheck_tensortrace(C, A, p, q)
    dimcheck_tensortrace(C, A, p, q)

    szA = size(A)
    so = TupleTools.getindices(szA, linearize(p))
    st = prod(TupleTools.getindices(szA, q[1]))
    Ã = reshape(PermutedDimsArray(A, (linearize(p)..., linearize(q)...)),
                (prod(so), st, st))

    if conjA
        if iszero(β)
            C .= α .* conj.(reshape(view(Ã, :, 1, 1), so))
        else
            C .= β .* C .+ α .* conj.(reshape(view(Ã, :, 1, 1), so))
        end
        for i in 2:st
            C .+= α .* conj.(reshape(view(Ã, :, i, i), so))
        end
    else
        if iszero(β)
            C .= α .* reshape(view(Ã, :, 1, 1), so)
        else
            C .= β .* C .+ α .* reshape(view(Ã, :, 1, 1), so)
        end
        for i in 2:st
            C .+= α .* reshape(view(Ã, :, i, i), so)
        end
    end
    return C
end
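A tiny worked check of the reshape trick above (hypothetical sizes; `BaseView` qualified since the backend types are presumably not exported): tracing the last two axes of a 3×2×2 array should sum the diagonal slices `A[:, k, k]`.

using TensorOperations

A = rand(3, 2, 2)
C = zeros(3)
# Keep axis 1 (p), trace axes 2 and 3 against each other (q); β = 0 overwrites C:
tensortrace!(C, A, ((1,), ()), ((2,), (3,)), false, 1.0, 0.0,
             TensorOperations.BaseView())
C ≈ [A[i, 1, 1] + A[i, 2, 2] for i in 1:3]  # expected: true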
function tensortrace!(C::AbstractArray,
                      A::AbstractArray, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
-                     α::Number, β::Number)
+                     α::Number, β::Number, ::BaseCopy)
    argcheck_tensortrace(C, A, p, q)
    dimcheck_tensortrace(C, A, p, q)

    szA = size(A)
    so = TupleTools.getindices(szA, linearize(p))
    st = prod(TupleTools.getindices(szA, q[1]))
-   A = reshape(PermutedDimsArray(A, (linearize(p)..., linearize(q)...)), (so..., st * st))
+   Ã = reshape(permutedims(A, (linearize(p)..., linearize(q)...)), (prod(so), st, st))

    if conjA
-       C .= β .* C .+ α .* conj.(view(A, :, diagind(st, st)))
+       if iszero(β)
+           C .= α .* conj.(reshape(view(Ã, :, 1, 1), so))
+       else
+           C .= β .* C .+ α .* conj.(reshape(view(Ã, :, 1, 1), so))
+       end
+       for i in 2:st
+           C .+= α .* conj.(reshape(view(Ã, :, i, i), so))
+       end
Review thread on lines +202 to +209 (Collaborator Author): can we rewrite this with sum?

Reply (Owner): I don't know? In a way that does not cause additional allocations? Is there an issue with the current approach?

Reply (Collaborator Author): I wrote this in reply to the comment about "# is there more base functionality we can use". I don't think there is any issue with the current approach.

    else
-       C .= β .* C .+ α .* view(A, :, diagind(st, st))
+       if iszero(β)
+           C .= α .* reshape(view(Ã, :, 1, 1), so)
+       else
+           C .= β .* C .+ α .* reshape(view(Ã, :, 1, 1), so)
+       end
+       for i in 2:st
+           C .+= α .* reshape(view(Ã, :, i, i), so)
+       end
    end
    return C
end
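On the `sum` question in the thread above, one allocating alternative could look like the following fragment (a sketch for the non-conjugated `BaseCopy` branch only, with `Ã`, `so`, `st`, `α`, `β` as defined in the function above):

# Gather the diagonal slices Ã[:, i, i] as columns of one view, then sum once.
# Allocates an intermediate for the sum, unlike the loop in the PR.
Ãmat = reshape(Ã, prod(so), st * st)
diagcols = view(Ãmat, :, 1:(st + 1):(st * st))  # column (i - 1) * st + i is slice (i, i)
C .= β .* C .+ α .* reshape(sum(diagcols; dims=2), so)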
@@ -123,7 +224,7 @@
function tensorcontract!(C::AbstractArray,
                         A::AbstractArray, pA::Index2Tuple, conjA::Bool,
                         B::AbstractArray, pB::Index2Tuple, conjB::Bool,
                         pAB::Index2Tuple,
-                        α::Number, β::Number)
+                        α::Number, β::Number, ::BaseView)
    argcheck_tensorcontract(C, A, pA, B, pB, pAB)
    dimcheck_tensorcontract(C, A, pA, B, pB, pAB)

@@ -135,25 +236,73 @@ function tensorcontract!(C::AbstractArray,
    soA1 = prod(soA)
    soB1 = prod(soB)
    sc1 = prod(sc)
+   pC = invperm(linearize(pAB))
+   C̃ = reshape(PermutedDimsArray(C, pC), (soA1, soB1))

    if conjA && conjB
-       A′ = reshape(permutedims(A, linearize(reverse(pA))), (sc1, soA1))
-       B′ = reshape(permutedims(B, linearize(reverse(pB))), (soB1, sc1))
-       C′ = adjoint(A′) * adjoint(B′)
+       Ã = reshape(PermutedDimsArray(A, linearize(reverse(pA))), (sc1, soA1))
+       B̃ = reshape(PermutedDimsArray(B, linearize(reverse(pB))), (soB1, sc1))
+       C̃ = mul!(C̃, adjoint(Ã), adjoint(B̃), α, β)
    elseif conjA
-       A′ = reshape(permutedims(A, linearize(reverse(pA))), (sc1, soA1))
-       B′ = reshape(permutedims(B, linearize(pB)), (sc1, soB1))
-       C′ = adjoint(A′) * B′
+       Ã = reshape(PermutedDimsArray(A, linearize(reverse(pA))), (sc1, soA1))
+       B̃ = reshape(PermutedDimsArray(B, linearize(pB)), (sc1, soB1))
+       C̃ = mul!(C̃, adjoint(Ã), B̃, α, β)
    elseif conjB
-       A′ = reshape(permutedims(A, linearize(pA)), (soA1, sc1))
-       B′ = reshape(permutedims(B, linearize(reverse(pB))), (soB1, sc1))
-       C′ = A′ * adjoint(B′)
+       Ã = reshape(PermutedDimsArray(A, linearize(pA)), (soA1, sc1))
+       B̃ = reshape(PermutedDimsArray(B, linearize(reverse(pB))), (soB1, sc1))
+       C̃ = mul!(C̃, Ã, adjoint(B̃), α, β)
    else
-       A′ = reshape(permutedims(A, linearize(pA)), (soA1, sc1))
-       B′ = reshape(permutedims(B, linearize(pB)), (sc1, soB1))
-       C′ = A′ * B′
+       Ã = reshape(PermutedDimsArray(A, linearize(pA)), (soA1, sc1))
+       B̃ = reshape(PermutedDimsArray(B, linearize(pB)), (sc1, soB1))
+       C̃ = mul!(C̃, Ã, B̃, α, β)
    end
-   return tensoradd!(C, reshape(C′, (soA..., soB...)), pAB, false, α, β)
+   return C
end
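A quick consistency sketch for the view-based contraction path (hypothetical sizes; plain matrix multiplication expressed through the backend API):

using TensorOperations

A = rand(3, 4); B = rand(4, 5); C = zeros(3, 5)
# pA: keep axis 1, contract axis 2; pB: contract axis 1, keep axis 2;
# pAB: trivial output permutation — this is just C = A * B.
tensorcontract!(C, A, ((1,), (2,)), false, B, ((1,), (2,)), false,
                ((1, 2), ()), 1.0, 0.0, TensorOperations.BaseView())
C ≈ A * B  # expected: true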
function tensorcontract!(C::AbstractArray,
                         A::AbstractArray, pA::Index2Tuple, conjA::Bool,
                         B::AbstractArray, pB::Index2Tuple, conjB::Bool,
                         pAB::Index2Tuple,
                         α::Number, β::Number, ::BaseCopy)
    argcheck_tensorcontract(C, A, pA, B, pB, pAB)
    dimcheck_tensorcontract(C, A, pA, B, pB, pAB)

    szA = size(A)
    szB = size(B)
    soA = TupleTools.getindices(szA, pA[1])
    soB = TupleTools.getindices(szB, pB[2])
    sc = TupleTools.getindices(szA, pA[2])
    soA1 = prod(soA)
    soB1 = prod(soB)
    sc1 = prod(sc)

    if conjA && conjB
        Ã = reshape(permutedims(A, linearize(reverse(pA))), (sc1, soA1))
        B̃ = reshape(permutedims(B, linearize(reverse(pB))), (soB1, sc1))
        ÃB̃ = reshape(adjoint(Ã) * adjoint(B̃), (soA..., soB...))
    elseif conjA
        Ã = reshape(permutedims(A, linearize(reverse(pA))), (sc1, soA1))
        B̃ = reshape(permutedims(B, linearize(pB)), (sc1, soB1))
        ÃB̃ = reshape(adjoint(Ã) * B̃, (soA..., soB...))
    elseif conjB
        Ã = reshape(permutedims(A, linearize(pA)), (soA1, sc1))
        B̃ = reshape(permutedims(B, linearize(reverse(pB))), (soB1, sc1))
        ÃB̃ = reshape(Ã * adjoint(B̃), (soA..., soB...))
    else
        Ã = reshape(permutedims(A, linearize(pA)), (soA1, sc1))
        B̃ = reshape(permutedims(B, linearize(pB)), (sc1, soB1))
        ÃB̃ = reshape(Ã * B̃, (soA..., soB...))
    end
    if istrivialpermutation(linearize(pAB))
        pÃB̃ = ÃB̃
    else
        pÃB̃ = permutedims(ÃB̃, linearize(pAB))
    end
    if iszero(β)
        C .= α .* pÃB̃
    else
        C .= β .* C .+ α .* pÃB̃
    end
    return C
end
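And the copy-based path should agree; the same hypothetical example with `BaseCopy` forced:

using TensorOperations

A = rand(3, 4); B = rand(4, 5); C = zeros(3, 5)
tensorcontract!(C, A, ((1,), (2,)), false, B, ((1,), (2,)), false,
                ((1, 2), ()), 1.0, 0.0, TensorOperations.BaseCopy())
C ≈ A * B  # expected: true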

# ------------------------------------------------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion src/implementation/functions.jl
@@ -7,7 +7,7 @@

"""
    tensorcopy([IC=IA], A, IA, [conjA=false, [α=1]])
-   tensorcopy(A, pA::Index2Tuple, conjA, α) # expert mode
+   tensorcopy(A, pA::Index2Tuple, conjA, α, [backend]) # expert mode

Create a copy of `A`, where the dimensions of `A` are assigned indices from the
iterable `IA` and the indices of the copy are contained in `IC`. Both iterables
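Per the updated docstring, the backend can now be passed in expert mode; a hedged example (signature as documented above, `BaseCopy` qualified since the backend types are presumably not exported):

using TensorOperations

A = rand(2, 3)
# Copy with permuted axes, no conjugation, α = 1, explicit Base backend:
At = tensorcopy(A, ((2, 1), ()), false, 1.0, TensorOperations.BaseCopy())
At ≈ permutedims(A, (2, 1))  # expected: true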
38 changes: 17 additions & 21 deletions src/indexnotation/contractiontrees.jl
@@ -116,14 +116,13 @@ function insertcontractiontrees!(ex, treebuilder, treesorter, costcheck, preexpr
optimalordersym = gensym("optimalorder")
if costcheck == :warn
costcompareex = :(@notensor begin
-            $currentcostsym = first(TensorOperations.treecost($tree,
-                                                              $network,
-                                                              $costmapsym))
-            $optimaltreesym, $optimalcostsym = TensorOperations.optimaltree($network,
-                                                                            $costmapsym)
+            $currentcostsym = first(treecost($tree, $network,
+                                             $costmapsym))
+            $optimaltreesym, $optimalcostsym = optimaltree($network,
+                                                           $costmapsym)
if $currentcostsym > $optimalcostsym
-                $optimalordersym = tuple(first(TensorOperations.tree2indexorder($optimaltreesym,
-                                                                                $network))...)
+                $optimalordersym = tuple(first(tree2indexorder($optimaltreesym,
+                                                               $network))...)
@warn "Tensor network: " *
$(string(ex)) *
":\n" *
@@ -132,21 +131,18 @@ function insertcontractiontrees!(ex, treebuilder, treesorter, costcheck, preexpr
end)
elseif costcheck == :cache
key = Expr(:quote, ex)
+        cacheref = GlobalRef(TensorOperations, :costcache)
costcompareex = :(@notensor begin
-            $currentcostsym = first(TensorOperations.treecost($tree,
-                                                              $network,
-                                                              $costmapsym))
-            if !($key in
-                 keys(TensorOperations.costcache)) ||
-               first(TensorOperations.costcache[$key]) <
-               $(currentcostsym)
-                $optimaltreesym, $optimalcostsym = TensorOperations.optimaltree($network,
-                                                                                $costmapsym)
-                $optimalordersym = tuple(first(TensorOperations.tree2indexorder($optimaltreesym,
-                                                                                $network))...)
-                TensorOperations.costcache[$key] = ($currentcostsym,
-                                                    $optimalcostsym,
-                                                    $optimalordersym)
+            $currentcostsym = first(treecost($tree, $network,
+                                             $costmapsym))
+            if !($key in keys($cacheref)) ||
+               first($cacheref[$key]) < $(currentcostsym)
+                $optimaltreesym, $optimalcostsym = optimaltree($network,
+                                                               $costmapsym)
+                $optimalordersym = tuple(first(tree2indexorder($optimaltreesym,
+                                                               $network))...)
+                $cacheref[$key] = ($currentcostsym, $optimalcostsym,
+                                   $optimalordersym)
end
end)
end
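For context, a sketch of the user-facing feature this macro code implements (usage per the `@tensor` keyword documentation; the network below is hypothetical):

using TensorOperations

A = rand(5, 5, 5); B = rand(5, 5, 5); C = rand(5, 5)
# If the written contraction order is costlier than the optimal order found
# by optimaltree, a warning suggesting the better order is emitted:
@tensor costcheck = warn D[a, d] := A[a, b, c] * B[c, b, e] * C[e, d]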
3 changes: 2 additions & 1 deletion src/indexnotation/postprocessors.jl
@@ -49,7 +49,8 @@
const tensoroperationsfunctions = (:tensoralloc, :tensorfree!,
                                   :tensoradd!, :tensortrace!, :tensorcontract!,
                                   :tensorscalar, :tensorcost, :IndexError, :scalartype,
                                   :checkcontractible, :promote_contract, :promote_add,
-                                  :tensoralloc_add, :tensoralloc_contract)
+                                  :tensoralloc_add, :tensoralloc_contract,
+                                  :treecost, :optimaltree, :tree2indexorder)
"""
addtensoroperations(ex)

10 changes: 10 additions & 0 deletions test/macro_kwargs.jl
@@ -19,7 +19,17 @@ using Logging
        tensorscalar(rhoL[a', a] * A1[a, s, b] * A2[b, s', c] * rhoR[c, c'] *
                     H[t, t', s, s'] * conj(A1[a', t, b']) * conj(A2[b', t', c']))
    end
+   E3 = @tensor backend = BaseView begin
+       tensorscalar(rhoL[a', a] * A1[a, s, b] * A2[b, s', c] * rhoR[c, c'] *
+                    H[t, t', s, s'] * conj(A1[a', t, b']) * conj(A2[b', t', c']))
+   end
+   E4 = @tensor backend = BaseCopy begin
+       tensorscalar(rhoL[a', a] * A1[a, s, b] * A2[b, s', c] * rhoR[c, c'] *
+                    H[t, t', s, s'] * conj(A1[a', t, b']) * conj(A2[b', t', c']))
+   end
    @test E1 ≈ E2
+   @test E1 ≈ E3
+   @test E1 ≈ E4
end

@testset "contractcheck" begin