Skip to content

Commit

Permalink
Re-enable Strided Diagonal implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos committed Apr 21, 2024
1 parent 52910cc commit 959bc09
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions src/implementation/diagonal.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,20 @@
#-------------------------------------------------------------------------------------------
# Specialized implementations for contractions involving diagonal matrices
#-------------------------------------------------------------------------------------------

# backend selection:
for (TC, TA, TB) in ((:AbstractArray, :AbstractArray, :Diagonal),
(:AbstractArray, :Diagonal, :AbstractArray), (:AbstractArray, :Diagonal, :Diagonal),
(:Diagonal, :Diagonal, :Diagonal))
@eval function tensorcontract!(C::$TC, pC::Index2Tuple,
A::$TA, pA::Index2Tuple, conjA::Symbol,
B::$TB, pB::Index2Tuple, conjB::Symbol,
α::Number, β::Number)
return tensorcontract!(C, pC, A, pA, conjA, B, pB, conjB, α, β, StridedNative())
end
end

# actual implementations:
function tensorcontract!(C::AbstractArray, pC::Index2Tuple,
A::AbstractArray, pA::Index2Tuple, conjA::Symbol,
B::Diagonal, pB::Index2Tuple, conjB::Symbol,
Expand All @@ -12,7 +26,6 @@ function tensorcontract!(C::AbstractArray, pC::Index2Tuple,
StridedView(B.diag), pB, conjB, α, β)
return C
end

function tensorcontract!(C::AbstractArray, pC::Index2Tuple,
A::Diagonal, pA::Index2Tuple, conjA::Symbol,
B::AbstractArray, pB::Index2Tuple, conjB::Symbol,
Expand All @@ -33,7 +46,6 @@ function tensorcontract!(C::AbstractArray, pC::Index2Tuple,
StridedView(A.diag), rpA, conjA, α, β)
return C
end

function tensorcontract!(C::AbstractArray, pC::Index2Tuple,
A::Diagonal, pA::Index2Tuple, conjA::Symbol,
B::Diagonal, pB::Index2Tuple, conjB::Symbol,
Expand Down Expand Up @@ -76,7 +88,6 @@ function tensorcontract!(C::AbstractArray, pC::Index2Tuple,

return C
end

function tensorcontract!(C::Diagonal, pC::Index2Tuple,
A::Diagonal, pA::Index2Tuple, conjA::Symbol,
B::Diagonal, pB::Index2Tuple, conjB::Symbol,
Expand All @@ -91,7 +102,6 @@ function tensorcontract!(C::Diagonal, pC::Index2Tuple,
C2 .= C2 .* β .+ A2 .* B2 .* α
return C
end

function _diagtensorcontract!(C::StridedView, pC::Index2Tuple,
A::StridedView, pA::Index2Tuple, conjA::Symbol,
Bdiag::StridedView, pB::Index2Tuple, conjB::Symbol,
Expand Down

0 comments on commit 959bc09

Please sign in to comment.