Skip to content

Commit

Permalink
structure-preserving broadcast for SymTridiagonal (#56001)
Browse files Browse the repository at this point in the history
With this PR, certain broadcasting operations preserve the structure of
a `SymTridiagonal`:
```julia
julia> S = SymTridiagonal([1,2,3,4], [1,2,3])
4×4 SymTridiagonal{Int64, Vector{Int64}}:
 1  1  ⋅  ⋅
 1  2  2  ⋅
 ⋅  2  3  3
 ⋅  ⋅  3  4

julia> S .* 2
4×4 SymTridiagonal{Int64, Vector{Int64}}:
 2  2  ⋅  ⋅
 2  4  4  ⋅
 ⋅  4  6  6
 ⋅  ⋅  6  8
```
This was deliberately disabled on master, but I couldn't find any test
that fails if this is enabled.
  • Loading branch information
jishnub authored Oct 30, 2024
1 parent a9342d6 commit 2fe6562
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
5 changes: 3 additions & 2 deletions stdlib/LinearAlgebra/src/structuredbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ end
function Base.similar(bc::Broadcasted{StructuredMatrixStyle{T}}, ::Type{ElType}) where {T,ElType}
inds = axes(bc)
fzerobc = fzeropreserving(bc)
if isstructurepreserving(bc) || (fzerobc && !(T <: Union{SymTridiagonal,UnitLowerTriangular,UnitUpperTriangular}))
if isstructurepreserving(bc) || (fzerobc && !(T <: Union{UnitLowerTriangular,UnitUpperTriangular}))
return structured_broadcast_alloc(bc, T, ElType, length(inds[1]))
elseif fzerobc && T <: UnitLowerTriangular
return similar(convert(Broadcasted{StructuredMatrixStyle{LowerTriangular}}, bc), ElType)
Expand Down Expand Up @@ -240,7 +240,8 @@ function copyto!(dest::SymTridiagonal, bc::Broadcasted{<:StructuredMatrixStyle})
end
for i = 1:size(dest, 1)-1
v = @inbounds bc[BandIndex(1, i)]
v == (@inbounds bc[BandIndex(-1, i)]) || throw(ArgumentError(lazy"broadcasted assignment breaks symmetry between locations ($i, $(i+1)) and ($(i+1), $i)"))
v == transpose(@inbounds bc[BandIndex(-1, i)]) ||
throw(ArgumentError(lazy"broadcasted assignment breaks symmetry between locations ($i, $(i+1)) and ($(i+1), $i)"))
dest.ev[i] = v
end
return dest
Expand Down
15 changes: 13 additions & 2 deletions stdlib/LinearAlgebra/test/structuredbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
module TestStructuredBroadcast
using Test, LinearAlgebra

const BASE_TEST_PATH = joinpath(Sys.BINDIR, "..", "share", "julia", "test")
isdefined(Main, :SizedArrays) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "SizedArrays.jl"))
using .Main.SizedArrays

@testset "broadcast[!] over combinations of scalars, structured matrices, and dense vectors/matrices" begin
N = 10
s = rand()
Expand All @@ -12,10 +16,11 @@ using Test, LinearAlgebra
D = Diagonal(rand(N))
B = Bidiagonal(rand(N), rand(N - 1), :U)
T = Tridiagonal(rand(N - 1), rand(N), rand(N - 1))
S = SymTridiagonal(rand(N), rand(N - 1))
U = UpperTriangular(rand(N,N))
L = LowerTriangular(rand(N,N))
M = Matrix(rand(N,N))
structuredarrays = (D, B, T, U, L, M)
structuredarrays = (D, B, T, U, L, M, S)
fstructuredarrays = map(Array, structuredarrays)
for (X, fX) in zip(structuredarrays, fstructuredarrays)
@test (Q = broadcast(sin, X); typeof(Q) == typeof(X) && Q == broadcast(sin, fX))
Expand Down Expand Up @@ -166,10 +171,11 @@ end
D = Diagonal(rand(N))
B = Bidiagonal(rand(N), rand(N - 1), :U)
T = Tridiagonal(rand(N - 1), rand(N), rand(N - 1))
S = SymTridiagonal(rand(N), rand(N - 1))
U = UpperTriangular(rand(N,N))
L = LowerTriangular(rand(N,N))
M = Matrix(rand(N,N))
structuredarrays = (M, D, B, T, U, L)
structuredarrays = (M, D, B, T, S, U, L)
fstructuredarrays = map(Array, structuredarrays)
for (X, fX) in zip(structuredarrays, fstructuredarrays)
@test (Q = map(sin, X); typeof(Q) == typeof(X) && Q == map(sin, fX))
Expand Down Expand Up @@ -363,6 +369,11 @@ end
U = UpperTriangular([(i+j)*A for i in 1:3, j in 1:3])
standardbroadcastingtests(U, UpperTriangular)
end
@testset "SymTridiagonal" begin
m = SizedArrays.SizedArray{(2,2)}([1 2; 3 4])
S = SymTridiagonal(fill(m,4), fill(m,3))
standardbroadcastingtests(S, SymTridiagonal)
end
end

end

0 comments on commit 2fe6562

Please sign in to comment.