Skip to content

Commit

Permalink
Dispatch on block matrix types
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub committed Jul 26, 2024
1 parent 12d6b20 commit a189bd6
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 11 deletions.
18 changes: 10 additions & 8 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,19 +191,21 @@ end
Return the appropriate zero element `A[i, j]` corresponding to a banded matrix `A`.
"""
diagzero(::Diagonal{T}, i, j) where {T} = zero(T)
diagzero(D::Diagonal{<:AbstractMatrix{T}}, i, j) where {T} = diagzero(T, axes(D.diag[i], 1), axes(D.diag[j], 2))
diagzero(D::Diagonal{M, <:AbstractVector{M}}, i, j) where {T,M<:AbstractMatrix{T}} =
diagzero(M, axes(D.diag[i], 1), axes(D.diag[j], 2))
# dispatching on the axes permits specializing on the axis types to return something other than an Array
diagzero(T::Type, ax::Union{AbstractUnitRange, Integer}...) = diagzero(T, ax)
diagzero(T::Type, ::Tuple{}) = zeros(T)
diagzero(M::Type, ax::Union{AbstractUnitRange, Integer}...) = diagzero(M, ax)
diagzero(M::Type, ::Tuple{}) = zeros(eltype(M))
"""
diagzero(T::Type, ax::Tuple{AbstractUnitRange, Vararg{AbstractUnitRange}})
diagzero(::Type{M}, ax::Tuple{AbstractUnitRange, Vararg{AbstractUnitRange}}) where {M<:AbstractMatrix}
Return an appropriate zero-ed array with either the axes `ax`, or the `size` `map(length, ax)`,
which may be used as a structural zero element of a banded matrix. By default, this falls back to
Return an appropriate zero-ed matrix similar to `M`, with either
the axes `ax`, or the `size` `map(length, ax)`.
This will be used as a structural zero element of a banded matrix. By default, `diagzero` falls back to
using the size along each axis to construct the result.
"""
diagzero(T::Type, ax::Tuple{AbstractUnitRange, Vararg{AbstractUnitRange}}) = diagzero(T, map(length, ax))
diagzero(T::Type, sz::Tuple{Integer, Vararg{Integer}}) = zeros(T, sz)
diagzero(M::Type, ax::Tuple{AbstractUnitRange, Vararg{AbstractUnitRange}}) = diagzero(M, map(length, ax))
diagzero(::Type{M}, sz::Tuple{Integer, Vararg{Integer}}) where {M<:AbstractMatrix} = zeros(eltype(M), sz)

@inline function getindex(D::Diagonal, b::BandIndex)
@boundscheck checkbounds(D, b)
Expand Down
3 changes: 0 additions & 3 deletions stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -812,9 +812,6 @@ end
@test fill(S,3,2)' * D == fill(S' * S, 2, 3)

@testset "indexing with non-standard-axes" begin
LinearAlgebra.diagzero(T::Type, ax::Tuple{SizedArrays.SOneTo, Vararg{SizedArrays.SOneTo}}) =
zeros(T, ax)

s = SizedArrays.SizedArray{(2,2)}([1 2; 3 4])
D = Diagonal(fill(s,3))
@test @inferred(D[1,2]) isa typeof(s)
Expand Down
3 changes: 3 additions & 0 deletions test/testhelpers/SizedArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,7 @@ mul!(dest::AbstractMatrix, S1::SizedMatrix, S2::SizedMatrix, α::Number, β::Num
mul!(dest::AbstractVector, M::AbstractMatrix, v::SizedVector, α::Number, β::Number) =
mul!(dest, M, _data(v), α, β)

LinearAlgebra.diagzero(::Type{S}, ax::Tuple{SizedArrays.SOneTo, Vararg{SizedArrays.SOneTo}}) where {S<:SizedArray} =
zeros(eltype(S), ax)

end

0 comments on commit a189bd6

Please sign in to comment.