-
Notifications
You must be signed in to change notification settings - Fork 59
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add naive Base julia AbstractArray implementation #171
Changes from all commits
a3551e2
0e6f895
71b41b4
1266d6e
97309d3
3a23b9b
bb13e7e
3cb9aec
c0f8eb0
7bf566c
043378e
0a5e3f4
34f7a53
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,33 +9,54 @@ function checkcontractible(A::AbstractArray, iA, conjA::Bool, | |
return nothing | ||
end | ||
|
||
# TODO | ||
# add check for stridedness of abstract arrays and add a pure implementation as fallback | ||
|
||
const StridedNative = Backend{:StridedNative} | ||
const StridedBLAS = Backend{:StridedBLAS} | ||
const BaseView = Backend{:BaseView} | ||
const BaseCopy = Backend{:BaseCopy} | ||
|
||
function tensoradd!(C::AbstractArray, | ||
A::AbstractArray, pA::Index2Tuple, conjA::Bool, | ||
α::Number, β::Number) | ||
return tensoradd!(C, A, pA, conjA, α, β, StridedNative()) | ||
if isstrided(A) && isstrided(C) | ||
return tensoradd!(C, A, pA, conjA, α, β, StridedNative()) | ||
else | ||
return tensoradd!(C, A, pA, conjA, α, β, BaseView()) | ||
end | ||
end | ||
|
||
function tensortrace!(C::AbstractArray, | ||
A::AbstractArray, p::Index2Tuple, q::Index2Tuple, conjA::Bool, | ||
α::Number, β::Number) | ||
return tensortrace!(C, A, p, q, conjA, α, β, StridedNative()) | ||
if isstrided(A) && isstrided(C) | ||
return tensortrace!(C, A, p, q, conjA, α, β, StridedNative()) | ||
else | ||
return tensortrace!(C, A, p, q, conjA, α, β, BaseView()) | ||
end | ||
end | ||
|
||
function tensorcontract!(C::AbstractArray, | ||
A::AbstractArray, pA::Index2Tuple, conjA::Bool, | ||
B::AbstractArray, pB::Index2Tuple, conjB::Bool, | ||
pAB::Index2Tuple, | ||
α::Number, β::Number) | ||
if eltype(C) <: LinearAlgebra.BlasFloat && !isa(B, Diagonal) && !isa(A, Diagonal) | ||
return tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, StridedBLAS()) | ||
if eltype(C) <: LinearAlgebra.BlasFloat | ||
if isstrided(A) && isstrided(B) && isstrided(C) | ||
return tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, StridedBLAS()) | ||
elseif (isstrided(A) || isa(A, Diagonal)) && (isstrided(B) || isa(B, Diagonal)) && | ||
isstrided(C) | ||
return tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, | ||
StridedNative()) | ||
else | ||
return tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, BaseCopy()) | ||
end | ||
else | ||
return tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, StridedNative()) | ||
if (isstrided(A) || isa(A, Diagonal)) && (isstrided(B) || isa(B, Diagonal)) && | ||
isstrided(C) | ||
return tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, | ||
StridedNative()) | ||
else | ||
return tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, BaseView()) | ||
end | ||
end | ||
end | ||
|
||
|
@@ -81,6 +102,209 @@ function tensorcontract!(C::AbstractArray, | |
return C | ||
end | ||
|
||
#------------------------------------------------------------------------------------------- | ||
# Implementation based on Base + LinearAlgebra | ||
#------------------------------------------------------------------------------------------- | ||
# Note that this is mostly for convenience + checking, and not for performance | ||
function tensoradd!(C::AbstractArray, | ||
A::AbstractArray, pA::Index2Tuple, conjA::Bool, | ||
α::Number, β::Number, ::BaseView) | ||
argcheck_tensoradd(C, A, pA) | ||
dimcheck_tensoradd(C, A, pA) | ||
|
||
# can we assume that C is mutable? | ||
# is there more functionality in base that we can use? | ||
if conjA | ||
if iszero(β) | ||
C .= α .* conj.(PermutedDimsArray(A, linearize(pA))) | ||
else | ||
C .= β .* C .+ α .* conj.(PermutedDimsArray(A, linearize(pA))) | ||
end | ||
else | ||
if iszero(β) | ||
C .= α .* PermutedDimsArray(A, linearize(pA)) | ||
else | ||
C .= β .* C .+ α .* PermutedDimsArray(A, linearize(pA)) | ||
end | ||
end | ||
return C | ||
end | ||
function tensoradd!(C::AbstractArray, | ||
A::AbstractArray, pA::Index2Tuple, conjA::Bool, | ||
α::Number, β::Number, ::BaseCopy) | ||
argcheck_tensoradd(C, A, pA) | ||
dimcheck_tensoradd(C, A, pA) | ||
|
||
# can we assume that C is mutable? | ||
# is there more functionality in base that we can use? | ||
if conjA | ||
if iszero(β) | ||
C .= α .* conj.(permutedims(A, linearize(pA))) | ||
else | ||
C .= β .* C .+ α .* conj.(permutedims(A, linearize(pA))) | ||
end | ||
else | ||
if iszero(β) | ||
C .= α .* permutedims(A, linearize(pA)) | ||
else | ||
C .= β .* C .+ α .* permutedims(A, linearize(pA)) | ||
end | ||
end | ||
return C | ||
end | ||
|
||
# tensortrace | ||
function tensortrace!(C::AbstractArray, | ||
A::AbstractArray, p::Index2Tuple, q::Index2Tuple, conjA::Bool, | ||
α::Number, β::Number, ::BaseView) | ||
argcheck_tensortrace(C, A, p, q) | ||
dimcheck_tensortrace(C, A, p, q) | ||
|
||
szA = size(A) | ||
so = TupleTools.getindices(szA, linearize(p)) | ||
st = prod(TupleTools.getindices(szA, q[1])) | ||
à = reshape(PermutedDimsArray(A, (linearize(p)..., linearize(q)...)), | ||
(prod(so), st, st)) | ||
|
||
if conjA | ||
if iszero(β) | ||
C .= α .* conj.(reshape(view(Ã, :, 1, 1), so)) | ||
else | ||
C .= β .* C .+ α .* conj.(reshape(view(Ã, :, 1, 1), so)) | ||
end | ||
for i in 2:st | ||
C .+= α .* conj.(reshape(view(Ã, :, i, i), so)) | ||
end | ||
else | ||
if iszero(β) | ||
C .= α .* reshape(view(Ã, :, 1, 1), so) | ||
else | ||
C .= β .* C .+ α .* reshape(view(Ã, :, 1, 1), so) | ||
end | ||
for i in 2:st | ||
C .+= α .* reshape(view(Ã, :, i, i), so) | ||
end | ||
end | ||
return C | ||
end | ||
function tensortrace!(C::AbstractArray, | ||
A::AbstractArray, p::Index2Tuple, q::Index2Tuple, conjA::Bool, | ||
α::Number, β::Number, ::BaseCopy) | ||
argcheck_tensortrace(C, A, p, q) | ||
dimcheck_tensortrace(C, A, p, q) | ||
|
||
szA = size(A) | ||
so = TupleTools.getindices(szA, linearize(p)) | ||
st = prod(TupleTools.getindices(szA, q[1])) | ||
à = reshape(permutedims(A, (linearize(p)..., linearize(q)...)), (prod(so), st, st)) | ||
|
||
if conjA | ||
if iszero(β) | ||
C .= α .* conj.(reshape(view(Ã, :, 1, 1), so)) | ||
else | ||
C .= β .* C .+ α .* conj.(reshape(view(Ã, :, 1, 1), so)) | ||
end | ||
for i in 2:st | ||
C .+= α .* conj.(reshape(view(Ã, :, i, i), so)) | ||
end | ||
Comment on lines
+202
to
+209
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we rewrite this with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't know? In a way that does not cause additional allocations? Is there an issue with the current approach? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wrote this in reply to the comment about -- |
||
else | ||
if iszero(β) | ||
C .= α .* reshape(view(Ã, :, 1, 1), so) | ||
else | ||
C .= β .* C .+ α .* reshape(view(Ã, :, 1, 1), so) | ||
end | ||
for i in 2:st | ||
C .+= α .* reshape(view(Ã, :, i, i), so) | ||
end | ||
end | ||
return C | ||
end | ||
|
||
function tensorcontract!(C::AbstractArray, | ||
A::AbstractArray, pA::Index2Tuple, conjA::Bool, | ||
B::AbstractArray, pB::Index2Tuple, conjB::Bool, | ||
pAB::Index2Tuple, | ||
α::Number, β::Number, ::BaseView) | ||
argcheck_tensorcontract(C, A, pA, B, pB, pAB) | ||
dimcheck_tensorcontract(C, A, pA, B, pB, pAB) | ||
|
||
szA = size(A) | ||
szB = size(B) | ||
soA = TupleTools.getindices(szA, pA[1]) | ||
soB = TupleTools.getindices(szB, pB[2]) | ||
sc = TupleTools.getindices(szA, pA[2]) | ||
soA1 = prod(soA) | ||
soB1 = prod(soB) | ||
sc1 = prod(sc) | ||
pC = invperm(linearize(pAB)) | ||
C̃ = reshape(PermutedDimsArray(C, pC), (soA1, soB1)) | ||
|
||
if conjA && conjB | ||
à = reshape(PermutedDimsArray(A, linearize(reverse(pA))), (sc1, soA1)) | ||
B̃ = reshape(PermutedDimsArray(B, linearize(reverse(pB))), (soB1, sc1)) | ||
C̃ = mul!(C̃, adjoint(Ã), adjoint(B̃), α, β) | ||
elseif conjA | ||
à = reshape(PermutedDimsArray(A, linearize(reverse(pA))), (sc1, soA1)) | ||
B̃ = reshape(PermutedDimsArray(B, linearize(pB)), (sc1, soB1)) | ||
C̃ = mul!(C̃, adjoint(Ã), B̃, α, β) | ||
elseif conjB | ||
à = reshape(PermutedDimsArray(A, linearize(pA)), (soA1, sc1)) | ||
B̃ = reshape(PermutedDimsArray(B, linearize(reverse(pB))), (soB1, sc1)) | ||
C̃ = mul!(C̃, Ã, adjoint(B̃), α, β) | ||
else | ||
à = reshape(PermutedDimsArray(A, linearize(pA)), (soA1, sc1)) | ||
B̃ = reshape(PermutedDimsArray(B, linearize(pB)), (sc1, soB1)) | ||
C̃ = mul!(C̃, Ã, B̃, α, β) | ||
end | ||
return C | ||
end | ||
function tensorcontract!(C::AbstractArray, | ||
A::AbstractArray, pA::Index2Tuple, conjA::Bool, | ||
B::AbstractArray, pB::Index2Tuple, conjB::Bool, | ||
pAB::Index2Tuple, | ||
α::Number, β::Number, ::BaseCopy) | ||
argcheck_tensorcontract(C, A, pA, B, pB, pAB) | ||
dimcheck_tensorcontract(C, A, pA, B, pB, pAB) | ||
|
||
szA = size(A) | ||
szB = size(B) | ||
soA = TupleTools.getindices(szA, pA[1]) | ||
soB = TupleTools.getindices(szB, pB[2]) | ||
sc = TupleTools.getindices(szA, pA[2]) | ||
soA1 = prod(soA) | ||
soB1 = prod(soB) | ||
sc1 = prod(sc) | ||
|
||
if conjA && conjB | ||
à = reshape(permutedims(A, linearize(reverse(pA))), (sc1, soA1)) | ||
B̃ = reshape(permutedims(B, linearize(reverse(pB))), (soB1, sc1)) | ||
ÃB̃ = reshape(adjoint(Ã) * adjoint(B̃), (soA..., soB...)) | ||
elseif conjA | ||
à = reshape(permutedims(A, linearize(reverse(pA))), (sc1, soA1)) | ||
B̃ = reshape(permutedims(B, linearize(pB)), (sc1, soB1)) | ||
ÃB̃ = reshape(adjoint(Ã) * B̃, (soA..., soB...)) | ||
elseif conjB | ||
à = reshape(permutedims(A, linearize(pA)), (soA1, sc1)) | ||
B̃ = reshape(permutedims(B, linearize(reverse(pB))), (soB1, sc1)) | ||
ÃB̃ = reshape(Ã * adjoint(B̃), (soA..., soB...)) | ||
else | ||
à = reshape(permutedims(A, linearize(pA)), (soA1, sc1)) | ||
B̃ = reshape(permutedims(B, linearize(pB)), (sc1, soB1)) | ||
ÃB̃ = reshape(Ã * B̃, (soA..., soB...)) | ||
end | ||
if istrivialpermutation(linearize(pAB)) | ||
pÃB̃ = ÃB̃ | ||
else | ||
pÃB̃ = permutedims(ÃB̃, linearize(pAB)) | ||
end | ||
if iszero(β) | ||
C .= α .* pÃB̃ | ||
else | ||
C .= β .* C .+ α .* pÃB̃ | ||
end | ||
return C | ||
end | ||
|
||
# ------------------------------------------------------------------------------------------ | ||
# Argument Checking: can be used by backends to check the validity of the arguments | ||
# ------------------------------------------------------------------------------------------ | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think there is an in-place
permutedims!
, such that we could use the allocation interface to also hijack into allocating these temporary arraysThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That might be possible indeed. The question is whether it is worth it. This will probably only be used for types which are very different from strided arrays, e.g. sparse arrays.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that's probably a very valid point. It's probably fair to assume that we cannot guarantee optimal performance without extra information about the specific type anyways, and the system does allow to easily implement custom backends if necessary. The base backend should serve mostly as a catch-all implementation that ensures that it works for most types.