Skip to content

Commit

Permalink
LinearAlgebra: diagzero for non-standard axes
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub committed Jul 25, 2024
1 parent 8e5ca0b commit dca91a4
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 9 deletions.
3 changes: 2 additions & 1 deletion stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,8 @@ public AbstractTriangular,
isbanded,
peakflops,
symmetric,
symmetric_type
symmetric_type,
diagzero

const BlasFloat = Union{Float64,Float32,ComplexF64,ComplexF32}
const BlasReal = Union{Float64,Float32}
Expand Down
14 changes: 7 additions & 7 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,15 @@ Bidiagonal(A::Bidiagonal) = A
Bidiagonal{T}(A::Bidiagonal{T}) where {T} = A
Bidiagonal{T}(A::Bidiagonal) where {T} = Bidiagonal{T}(A.dv, A.ev, A.uplo)

bidiagzero(::Bidiagonal{T}, i, j) where {T} = zero(T)
function bidiagzero(A::Bidiagonal{<:AbstractMatrix}, i, j)
diagzero(::Bidiagonal{T}, i, j) where {T} = zero(T)
function diagzero(A::Bidiagonal{<:AbstractMatrix}, i, j)
Tel = eltype(eltype(A.dv))
if i < j && A.uplo == 'U' #= top right zeros =#
return zeros(Tel, size(A.ev[i], 1), size(A.ev[j-1], 2))
return diagzero(Tel, axes(A.ev[i], 1), axes(A.ev[j-1], 2))
elseif j < i && A.uplo == 'L' #= bottom left zeros =#
return zeros(Tel, size(A.ev[i-1], 1), size(A.ev[j], 2))
return diagzero(Tel, axes(A.ev[i-1], 1), axes(A.ev[j], 2))
else
return zeros(Tel, size(A.dv[i], 1), size(A.dv[j], 2))
return diagzero(Tel, axes(A.dv[i], 1), axes(A.dv[j], 2))
end
end

Expand Down Expand Up @@ -165,7 +165,7 @@ end
elseif A.uplo == 'L' && (i == j + 1)
return @inbounds A.ev[j]
else
return bidiagzero(A, i, j)
return diagzero(A, i, j)
end
end

Expand All @@ -178,7 +178,7 @@ end
elseif A.uplo == 'L' && b.band == -1
return @inbounds A.ev[b.index]
else
return bidiagzero(A, Tuple(_cartinds(b))...)
return diagzero(A, Tuple(_cartinds(b))...)
end
end

Expand Down
19 changes: 18 additions & 1 deletion stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,25 @@ end
end
r
end
"""
diagzero(A::AbstractMatrix, i, j)
Return the appropriate zero element `A[i, j]` corresponding to a banded matrix `A`.
"""
diagzero(::Diagonal{T}, i, j) where {T} = zero(T)
diagzero(D::Diagonal{<:AbstractMatrix{T}}, i, j) where {T} = zeros(T, size(D.diag[i], 1), size(D.diag[j], 2))
diagzero(D::Diagonal{<:AbstractMatrix{T}}, i, j) where {T} = diagzero(T, axes(D.diag[i], 1), axes(D.diag[j], 2))
# dispatching on the axes permits specializing on the axis types to return something other than an Array
diagzero(T::Type, ax::Union{AbstractUnitRange, Integer}...) = diagzero(T, ax)
diagzero(T::Type, ::Tuple{}) = zeros(T)
"""
diagzero(T::Type, ax::Tuple{AbstractUnitRange, Vararg{AbstractUnitRange}})
Return an appropriate zero-ed array with either the axes `ax`, or the `size` `map(length, ax)`,
which may be used as a structural zero element of a banded matrix. By default, this falls back to
using the size along each axis to construct the result.
"""
diagzero(T::Type, ax::Tuple{AbstractUnitRange, Vararg{AbstractUnitRange}}) = diagzero(T, map(length, ax))
diagzero(T::Type, sz::Tuple{Integer, Vararg{Integer}}) = zeros(T, sz)

@inline function getindex(D::Diagonal, b::BandIndex)
@boundscheck checkbounds(D, b)
Expand Down
10 changes: 10 additions & 0 deletions stdlib/LinearAlgebra/test/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,16 @@ end
B = Bidiagonal(dv, ev, :U)
@test B == Matrix{eltype(B)}(B)
end

@testset "non-standard axes" begin
LinearAlgebra.diagzero(T::Type, ax::Tuple{SizedArrays.SOneTo, Vararg{SizedArrays.SOneTo}}) =
zeros(T, ax)

s = SizedArrays.SizedArray{(2,2)}([1 2; 3 4])
B = Bidiagonal(fill(s,4), fill(s,3), :U)
@test B[2,1] isa typeof(s)
@test all(iszero, B[2,1])
end
end

@testset "copyto!" begin
Expand Down
12 changes: 12 additions & 0 deletions stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -810,6 +810,16 @@ end
D = Diagonal(fill(S,3))
@test D * fill(S,2,3)' == fill(S * S', 3, 2)
@test fill(S,3,2)' * D == fill(S' * S, 2, 3)

@testset "indexing with non-standard-axes" begin
LinearAlgebra.diagzero(T::Type, ax::Tuple{SizedArrays.SOneTo, Vararg{SizedArrays.SOneTo}}) =
zeros(T, ax)

s = SizedArrays.SizedArray{(2,2)}([1 2; 3 4])
D = Diagonal(fill(s,3))
@test D[1,2] isa typeof(s)
@test all(iszero, D[1,2])
end
end

@testset "Eigensystem for block diagonal (issue #30681)" begin
Expand Down Expand Up @@ -1335,4 +1345,6 @@ end
end
end



end # module TestDiagonal

0 comments on commit dca91a4

Please sign in to comment.