Skip to content

Commit

Permalink
Broadcast binary ops involving strided triangular #55798
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub committed Sep 30, 2024
1 parent 2b28354 commit 108bd2e
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 20 deletions.
8 changes: 4 additions & 4 deletions stdlib/LinearAlgebra/src/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -536,10 +536,10 @@ for f in (:+, :-)
@eval begin
$f(A::Hermitian, B::Symmetric{<:Real}) = $f(A, Hermitian(parent(B), sym_uplo(B.uplo)))
$f(A::Symmetric{<:Real}, B::Hermitian) = $f(Hermitian(parent(A), sym_uplo(A.uplo)), B)
$f(A::SymTridiagonal, B::Symmetric) = Symmetric($f(A, B.data), sym_uplo(B.uplo))
$f(A::Symmetric, B::SymTridiagonal) = Symmetric($f(A.data, B), sym_uplo(A.uplo))
$f(A::SymTridiagonal{<:Real}, B::Hermitian) = Hermitian($f(A, B.data), sym_uplo(B.uplo))
$f(A::Hermitian, B::SymTridiagonal{<:Real}) = Hermitian($f(A.data, B), sym_uplo(A.uplo))
$f(A::SymTridiagonal, B::Symmetric) = $f(Symmetric(A, sym_uplo(B.uplo)), B)
$f(A::Symmetric, B::SymTridiagonal) = $f(A, Symmetric(B, sym_uplo(A.uplo)))
$f(A::SymTridiagonal{<:Real}, B::Hermitian) = $f(Hermitian(A, sym_uplo(B.uplo)), B)
$f(A::Hermitian, B::SymTridiagonal{<:Real}) = $f(A, Hermitian(B, sym_uplo(A.uplo)))
end
end

Expand Down
81 changes: 65 additions & 16 deletions stdlib/LinearAlgebra/src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -779,24 +779,73 @@ fillstored!(A::UpperTriangular, x) = (fillband!(A.data, x, 0, size(A,2)-1);
fillstored!(A::UnitUpperTriangular, x) = (fillband!(A.data, x, 1, size(A,2)-1); A)

# Binary operations
+(A::UpperTriangular, B::UpperTriangular) = UpperTriangular(A.data + B.data)
+(A::LowerTriangular, B::LowerTriangular) = LowerTriangular(A.data + B.data)
+(A::UpperTriangular, B::UnitUpperTriangular) = UpperTriangular(A.data + triu(B.data, 1) + I)
+(A::LowerTriangular, B::UnitLowerTriangular) = LowerTriangular(A.data + tril(B.data, -1) + I)
+(A::UnitUpperTriangular, B::UpperTriangular) = UpperTriangular(triu(A.data, 1) + B.data + I)
+(A::UnitLowerTriangular, B::LowerTriangular) = LowerTriangular(tril(A.data, -1) + B.data + I)
+(A::UnitUpperTriangular, B::UnitUpperTriangular) = UpperTriangular(triu(A.data, 1) + triu(B.data, 1) + 2I)
+(A::UnitLowerTriangular, B::UnitLowerTriangular) = LowerTriangular(tril(A.data, -1) + tril(B.data, -1) + 2I)
# use broadcasting if the parents are strided, where we loop only over the triangular part
function +(A::UpperTriangular, B::UpperTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .+ B
UpperTriangular(A.data + B.data)
end
function +(A::LowerTriangular, B::LowerTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .+ B
LowerTriangular(A.data + B.data)
end
function +(A::UpperTriangular, B::UnitUpperTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .+ B
UpperTriangular(A.data + triu(B.data, 1) + I)
end
function +(A::LowerTriangular, B::UnitLowerTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .+ B
LowerTriangular(A.data + tril(B.data, -1) + I)
end
function +(A::UnitUpperTriangular, B::UpperTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .+ B
UpperTriangular(triu(A.data, 1) + B.data + I)
end
function +(A::UnitLowerTriangular, B::LowerTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .+ B
LowerTriangular(tril(A.data, -1) + B.data + I)
end
function +(A::UnitUpperTriangular, B::UnitUpperTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .+ B
UpperTriangular(triu(A.data, 1) + triu(B.data, 1) + 2I)
end
function +(A::UnitLowerTriangular, B::UnitLowerTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .+ B
LowerTriangular(tril(A.data, -1) + tril(B.data, -1) + 2I)
end
+(A::AbstractTriangular, B::AbstractTriangular) = copyto!(similar(parent(A)), A) + copyto!(similar(parent(B)), B)

-(A::UpperTriangular, B::UpperTriangular) = UpperTriangular(A.data - B.data)
-(A::LowerTriangular, B::LowerTriangular) = LowerTriangular(A.data - B.data)
-(A::UpperTriangular, B::UnitUpperTriangular) = UpperTriangular(A.data - triu(B.data, 1) - I)
-(A::LowerTriangular, B::UnitLowerTriangular) = LowerTriangular(A.data - tril(B.data, -1) - I)
-(A::UnitUpperTriangular, B::UpperTriangular) = UpperTriangular(triu(A.data, 1) - B.data + I)
-(A::UnitLowerTriangular, B::LowerTriangular) = LowerTriangular(tril(A.data, -1) - B.data + I)
-(A::UnitUpperTriangular, B::UnitUpperTriangular) = UpperTriangular(triu(A.data, 1) - triu(B.data, 1))
-(A::UnitLowerTriangular, B::UnitLowerTriangular) = LowerTriangular(tril(A.data, -1) - tril(B.data, -1))
function -(A::UpperTriangular, B::UpperTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .- B
UpperTriangular(A.data - B.data)
end
function -(A::LowerTriangular, B::LowerTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .- B
LowerTriangular(A.data - B.data)
end
function -(A::UpperTriangular, B::UnitUpperTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .- B
UpperTriangular(A.data - triu(B.data, 1) - I)
end
function -(A::LowerTriangular, B::UnitLowerTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .- B
LowerTriangular(A.data - tril(B.data, -1) - I)
end
function -(A::UnitUpperTriangular, B::UpperTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .- B
UpperTriangular(triu(A.data, 1) - B.data + I)
end
function -(A::UnitLowerTriangular, B::LowerTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .- B
LowerTriangular(tril(A.data, -1) - B.data + I)
end
function -(A::UnitUpperTriangular, B::UnitUpperTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .- B
UpperTriangular(triu(A.data, 1) - triu(B.data, 1))
end
function -(A::UnitLowerTriangular, B::UnitLowerTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .- B
LowerTriangular(tril(A.data, -1) - tril(B.data, -1))
end
-(A::AbstractTriangular, B::AbstractTriangular) = copyto!(similar(parent(A)), A) - copyto!(similar(parent(B)), B)

# use broadcasting if the parents are strided, where we loop only over the triangular part
Expand Down
25 changes: 25 additions & 0 deletions stdlib/LinearAlgebra/test/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1014,4 +1014,29 @@ end
end
end

@testset "partly iniitalized matrices" begin
a = Matrix{BigFloat}(undef, 2,2)
a[1] = 1; a[3] = 1; a[4] = 1
h = Hermitian(a)
s = Symmetric(a)
d = Diagonal([1,1])
symT = SymTridiagonal([1 1;1 1])
@test h+d == Array(h) + Array(d)
@test h+symT == Array(h) + Array(symT)
@test s+d == Array(s) + Array(d)
@test s+symT == Array(s) + Array(symT)
@test h-d == Array(h) - Array(d)
@test h-symT == Array(h) - Array(symT)
@test s-d == Array(s) - Array(d)
@test s-symT == Array(s) - Array(symT)
@test d+h == Array(d) + Array(h)
@test symT+h == Array(symT) + Array(h)
@test d+s == Array(d) + Array(s)
@test symT+s == Array(symT) + Array(s)
@test d-h == Array(d) - Array(h)
@test symT-h == Array(symT) - Array(h)
@test d-s == Array(d) - Array(s)
@test symT-s == Array(symT) - Array(s)
end

end # module TestSymmetric

0 comments on commit 108bd2e

Please sign in to comment.