Skip to content

Commit

Permalink
Specialize indexing triangular matrices with BandIndex
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub committed Aug 30, 2024
1 parent b6d2155 commit 9dab665
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/src/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ norm2(x::Union{Array{T},StridedVector{T}}) where {T<:BlasFloat} =
# Conservative assessment of types that have zero(T) defined for themselves
haszero(::Type) = false
haszero(::Type{T}) where {T<:Number} = isconcretetype(T)
@propagate_inbounds _zero(M::AbstractArray{T}, i, j) where {T} = haszero(T) ? zero(T) : zero(M[i,j])
@propagate_inbounds _zero(M::AbstractArray{T}, inds...) where {T} = haszero(T) ? zero(T) : zero(M[inds...])

"""
triu!(M, k::Integer)
Expand Down
14 changes: 14 additions & 0 deletions stdlib/LinearAlgebra/src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,20 @@ Base.isstored(A::UpperTriangular, i::Int, j::Int) =
@propagate_inbounds getindex(A::UpperTriangular, i::Int, j::Int) =
i <= j ? A.data[i,j] : _zero(A.data,j,i)

# these specialized getindex methods enable constant-propagation of the band
Base.@constprop :aggressive @propagate_inbounds function getindex(A::UnitLowerTriangular{T}, b::BandIndex) where {T}
b.band < 0 ? A.data[b] : ifelse(b.band == 0, oneunit(T), zero(T))
end
Base.@constprop :aggressive @propagate_inbounds function getindex(A::LowerTriangular, b::BandIndex)
b.band <= 0 ? A.data[b] : _zero(A.data, b)
end
Base.@constprop :aggressive @propagate_inbounds function getindex(A::UnitUpperTriangular{T}, b::BandIndex) where {T}
b.band > 0 ? A.data[b] : ifelse(b.band == 0, oneunit(T), zero(T))
end
Base.@constprop :aggressive @propagate_inbounds function getindex(A::UpperTriangular, b::BandIndex)
b.band >= 0 ? A.data[b] : _zero(A.data, b)
end

_zero_triangular_half_str(::Type{<:UpperOrUnitUpperTriangular}) = "lower"
_zero_triangular_half_str(::Type{<:LowerOrUnitLowerTriangular}) = "upper"

Expand Down

0 comments on commit 9dab665

Please sign in to comment.