From dca91a468fa748d1b24191cb2682bb42d1502807 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Thu, 25 Jul 2024 23:28:58 +0530 Subject: [PATCH] LinearAlgebra: diagzero for non-standard axes --- stdlib/LinearAlgebra/src/LinearAlgebra.jl | 3 ++- stdlib/LinearAlgebra/src/bidiag.jl | 14 +++++++------- stdlib/LinearAlgebra/src/diagonal.jl | 19 ++++++++++++++++++- stdlib/LinearAlgebra/test/bidiag.jl | 10 ++++++++++ stdlib/LinearAlgebra/test/diagonal.jl | 12 ++++++++++++ 5 files changed, 49 insertions(+), 9 deletions(-) diff --git a/stdlib/LinearAlgebra/src/LinearAlgebra.jl b/stdlib/LinearAlgebra/src/LinearAlgebra.jl index bad0431755e98..164b50fc88dbd 100644 --- a/stdlib/LinearAlgebra/src/LinearAlgebra.jl +++ b/stdlib/LinearAlgebra/src/LinearAlgebra.jl @@ -174,7 +174,8 @@ public AbstractTriangular, isbanded, peakflops, symmetric, - symmetric_type + symmetric_type, + diagzero const BlasFloat = Union{Float64,Float32,ComplexF64,ComplexF32} const BlasReal = Union{Float64,Float32} diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index 04d54911d88aa..01719484dd6e9 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -118,15 +118,15 @@ Bidiagonal(A::Bidiagonal) = A Bidiagonal{T}(A::Bidiagonal{T}) where {T} = A Bidiagonal{T}(A::Bidiagonal) where {T} = Bidiagonal{T}(A.dv, A.ev, A.uplo) -bidiagzero(::Bidiagonal{T}, i, j) where {T} = zero(T) -function bidiagzero(A::Bidiagonal{<:AbstractMatrix}, i, j) +diagzero(::Bidiagonal{T}, i, j) where {T} = zero(T) +function diagzero(A::Bidiagonal{<:AbstractMatrix}, i, j) Tel = eltype(eltype(A.dv)) if i < j && A.uplo == 'U' #= top right zeros =# - return zeros(Tel, size(A.ev[i], 1), size(A.ev[j-1], 2)) + return diagzero(Tel, axes(A.ev[i], 1), axes(A.ev[j-1], 2)) elseif j < i && A.uplo == 'L' #= bottom left zeros =# - return zeros(Tel, size(A.ev[i-1], 1), size(A.ev[j], 2)) + return diagzero(Tel, axes(A.ev[i-1], 1), axes(A.ev[j], 2)) else - return zeros(Tel, size(A.dv[i], 1), size(A.dv[j], 2)) + return diagzero(Tel, axes(A.dv[i], 1), axes(A.dv[j], 2)) end end @@ -165,7 +165,7 @@ end elseif A.uplo == 'L' && (i == j + 1) return @inbounds A.ev[j] else - return bidiagzero(A, i, j) + return diagzero(A, i, j) end end @@ -178,7 +178,7 @@ end elseif A.uplo == 'L' && b.band == -1 return @inbounds A.ev[b.index] else - return bidiagzero(A, Tuple(_cartinds(b))...) + return diagzero(A, Tuple(_cartinds(b))...) end end diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index b3826a2aa7f82..c33cf01ab9124 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -185,8 +185,25 @@ end end r end +""" + diagzero(A::AbstractMatrix, i, j) + +Return the appropriate zero element `A[i, j]` corresponding to a banded matrix `A`. +""" diagzero(::Diagonal{T}, i, j) where {T} = zero(T) -diagzero(D::Diagonal{<:AbstractMatrix{T}}, i, j) where {T} = zeros(T, size(D.diag[i], 1), size(D.diag[j], 2)) +diagzero(D::Diagonal{<:AbstractMatrix{T}}, i, j) where {T} = diagzero(T, axes(D.diag[i], 1), axes(D.diag[j], 2)) +# dispatching on the axes permits specializing on the axis types to return something other than an Array +diagzero(T::Type, ax::Union{AbstractUnitRange, Integer}...) = diagzero(T, ax) +diagzero(T::Type, ::Tuple{}) = zeros(T) +""" + diagzero(T::Type, ax::Tuple{AbstractUnitRange, Vararg{AbstractUnitRange}}) + +Return an appropriate zero-ed array with either the axes `ax`, or the `size` `map(length, ax)`, +which may be used as a structural zero element of a banded matrix. By default, this falls back to +using the size along each axis to construct the result. +""" +diagzero(T::Type, ax::Tuple{AbstractUnitRange, Vararg{AbstractUnitRange}}) = diagzero(T, map(length, ax)) +diagzero(T::Type, sz::Tuple{Integer, Vararg{Integer}}) = zeros(T, sz) @inline function getindex(D::Diagonal, b::BandIndex) @boundscheck checkbounds(D, b) diff --git a/stdlib/LinearAlgebra/test/bidiag.jl b/stdlib/LinearAlgebra/test/bidiag.jl index 2ff3e9b423702..3d92e98e9144c 100644 --- a/stdlib/LinearAlgebra/test/bidiag.jl +++ b/stdlib/LinearAlgebra/test/bidiag.jl @@ -835,6 +835,16 @@ end B = Bidiagonal(dv, ev, :U) @test B == Matrix{eltype(B)}(B) end + + @testset "non-standard axes" begin + LinearAlgebra.diagzero(T::Type, ax::Tuple{SizedArrays.SOneTo, Vararg{SizedArrays.SOneTo}}) = + zeros(T, ax) + + s = SizedArrays.SizedArray{(2,2)}([1 2; 3 4]) + B = Bidiagonal(fill(s,4), fill(s,3), :U) + @test B[2,1] isa typeof(s) + @test all(iszero, B[2,1]) + end end @testset "copyto!" begin diff --git a/stdlib/LinearAlgebra/test/diagonal.jl b/stdlib/LinearAlgebra/test/diagonal.jl index 1a3b8d4fd0ea7..949c5148da2d6 100644 --- a/stdlib/LinearAlgebra/test/diagonal.jl +++ b/stdlib/LinearAlgebra/test/diagonal.jl @@ -810,6 +810,16 @@ end D = Diagonal(fill(S,3)) @test D * fill(S,2,3)' == fill(S * S', 3, 2) @test fill(S,3,2)' * D == fill(S' * S, 2, 3) + + @testset "indexing with non-standard-axes" begin + LinearAlgebra.diagzero(T::Type, ax::Tuple{SizedArrays.SOneTo, Vararg{SizedArrays.SOneTo}}) = + zeros(T, ax) + + s = SizedArrays.SizedArray{(2,2)}([1 2; 3 4]) + D = Diagonal(fill(s,3)) + @test D[1,2] isa typeof(s) + @test all(iszero, D[1,2]) + end end @testset "Eigensystem for block diagonal (issue #30681)" begin @@ -1335,4 +1345,6 @@ end end end + + end # module TestDiagonal