From a189bd609ab9888df90241fd43b882cce2f7f2ec Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Fri, 26 Jul 2024 10:32:05 +0530 Subject: [PATCH] Dispatch on block matrix types --- stdlib/LinearAlgebra/src/diagonal.jl | 18 ++++++++++-------- stdlib/LinearAlgebra/test/diagonal.jl | 3 --- test/testhelpers/SizedArrays.jl | 3 +++ 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index c33cf01ab9124..1aadf3713d2a1 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -191,19 +191,21 @@ end Return the appropriate zero element `A[i, j]` corresponding to a banded matrix `A`. """ diagzero(::Diagonal{T}, i, j) where {T} = zero(T) -diagzero(D::Diagonal{<:AbstractMatrix{T}}, i, j) where {T} = diagzero(T, axes(D.diag[i], 1), axes(D.diag[j], 2)) +diagzero(D::Diagonal{M, <:AbstractVector{M}}, i, j) where {T,M<:AbstractMatrix{T}} = + diagzero(M, axes(D.diag[i], 1), axes(D.diag[j], 2)) # dispatching on the axes permits specializing on the axis types to return something other than an Array -diagzero(T::Type, ax::Union{AbstractUnitRange, Integer}...) = diagzero(T, ax) -diagzero(T::Type, ::Tuple{}) = zeros(T) +diagzero(M::Type, ax::Union{AbstractUnitRange, Integer}...) = diagzero(M, ax) +diagzero(M::Type, ::Tuple{}) = zeros(eltype(M)) """ - diagzero(T::Type, ax::Tuple{AbstractUnitRange, Vararg{AbstractUnitRange}}) + diagzero(::Type{M}, ax::Tuple{AbstractUnitRange, Vararg{AbstractUnitRange}}) where {M<:AbstractMatrix} -Return an appropriate zero-ed array with either the axes `ax`, or the `size` `map(length, ax)`, -which may be used as a structural zero element of a banded matrix. By default, this falls back to +Return an appropriate zero-ed matrix similar to `M`, with either +the axes `ax`, or the `size` `map(length, ax)`. +This will be used as a structural zero element of a banded matrix. By default, `diagzero` falls back to using the size along each axis to construct the result. """ -diagzero(T::Type, ax::Tuple{AbstractUnitRange, Vararg{AbstractUnitRange}}) = diagzero(T, map(length, ax)) -diagzero(T::Type, sz::Tuple{Integer, Vararg{Integer}}) = zeros(T, sz) +diagzero(M::Type, ax::Tuple{AbstractUnitRange, Vararg{AbstractUnitRange}}) = diagzero(M, map(length, ax)) +diagzero(::Type{M}, sz::Tuple{Integer, Vararg{Integer}}) where {M<:AbstractMatrix} = zeros(eltype(M), sz) @inline function getindex(D::Diagonal, b::BandIndex) @boundscheck checkbounds(D, b) diff --git a/stdlib/LinearAlgebra/test/diagonal.jl b/stdlib/LinearAlgebra/test/diagonal.jl index 513b6c4c376d2..534e82c391114 100644 --- a/stdlib/LinearAlgebra/test/diagonal.jl +++ b/stdlib/LinearAlgebra/test/diagonal.jl @@ -812,9 +812,6 @@ end @test fill(S,3,2)' * D == fill(S' * S, 2, 3) @testset "indexing with non-standard-axes" begin - LinearAlgebra.diagzero(T::Type, ax::Tuple{SizedArrays.SOneTo, Vararg{SizedArrays.SOneTo}}) = - zeros(T, ax) - s = SizedArrays.SizedArray{(2,2)}([1 2; 3 4]) D = Diagonal(fill(s,3)) @test @inferred(D[1,2]) isa typeof(s) diff --git a/test/testhelpers/SizedArrays.jl b/test/testhelpers/SizedArrays.jl index bc02fb5cbbd20..f8c586e5ed951 100644 --- a/test/testhelpers/SizedArrays.jl +++ b/test/testhelpers/SizedArrays.jl @@ -89,4 +89,7 @@ mul!(dest::AbstractMatrix, S1::SizedMatrix, S2::SizedMatrix, α::Number, β::Num mul!(dest::AbstractVector, M::AbstractMatrix, v::SizedVector, α::Number, β::Number) = mul!(dest, M, _data(v), α, β) +LinearAlgebra.diagzero(::Type{S}, ax::Tuple{SizedArrays.SOneTo, Vararg{SizedArrays.SOneTo}}) where {S<:SizedArray} = + zeros(eltype(S), ax) + end