From 259e82c08d73716307680b2174e7565884c50788 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sat, 31 Aug 2024 00:21:33 +0530 Subject: [PATCH 1/4] Specialize indexing triangular matrices with BandIndex --- stdlib/LinearAlgebra/src/dense.jl | 2 +- stdlib/LinearAlgebra/src/triangular.jl | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/stdlib/LinearAlgebra/src/dense.jl b/stdlib/LinearAlgebra/src/dense.jl index 62096cbb172f2..aacc5479bfa9d 100644 --- a/stdlib/LinearAlgebra/src/dense.jl +++ b/stdlib/LinearAlgebra/src/dense.jl @@ -110,7 +110,7 @@ norm2(x::Union{Array{T},StridedVector{T}}) where {T<:BlasFloat} = # Conservative assessment of types that have zero(T) defined for themselves haszero(::Type) = false haszero(::Type{T}) where {T<:Number} = isconcretetype(T) -@propagate_inbounds _zero(M::AbstractArray{T}, i, j) where {T} = haszero(T) ? zero(T) : zero(M[i,j]) +@propagate_inbounds _zero(M::AbstractArray{T}, inds...) where {T} = haszero(T) ? zero(T) : zero(M[inds...]) """ triu!(M, k::Integer) diff --git a/stdlib/LinearAlgebra/src/triangular.jl b/stdlib/LinearAlgebra/src/triangular.jl index 74eab6a392723..19e2ebe464109 100644 --- a/stdlib/LinearAlgebra/src/triangular.jl +++ b/stdlib/LinearAlgebra/src/triangular.jl @@ -236,6 +236,20 @@ Base.isstored(A::UpperTriangular, i::Int, j::Int) = @propagate_inbounds getindex(A::UpperTriangular, i::Int, j::Int) = i <= j ? A.data[i,j] : _zero(A.data,j,i) +# these specialized getindex methods enable constant-propagation of the band +Base.@constprop :aggressive @propagate_inbounds function getindex(A::UnitLowerTriangular{T}, b::BandIndex) where {T} + b.band < 0 ? A.data[b] : ifelse(b.band == 0, oneunit(T), zero(T)) +end +Base.@constprop :aggressive @propagate_inbounds function getindex(A::LowerTriangular, b::BandIndex) + b.band <= 0 ? A.data[b] : _zero(A.data, b) +end +Base.@constprop :aggressive @propagate_inbounds function getindex(A::UnitUpperTriangular{T}, b::BandIndex) where {T} + b.band > 0 ? A.data[b] : ifelse(b.band == 0, oneunit(T), zero(T)) +end +Base.@constprop :aggressive @propagate_inbounds function getindex(A::UpperTriangular, b::BandIndex) + b.band >= 0 ? A.data[b] : _zero(A.data, b) +end + _zero_triangular_half_str(::Type{<:UpperOrUnitUpperTriangular}) = "lower" _zero_triangular_half_str(::Type{<:LowerOrUnitLowerTriangular}) = "upper" From 09a5386987cc064da768c0276a30d3f8f7c16e64 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sat, 31 Aug 2024 12:44:59 +0530 Subject: [PATCH 2/4] Add tests --- stdlib/LinearAlgebra/test/triangular.jl | 29 ++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/stdlib/LinearAlgebra/test/triangular.jl b/stdlib/LinearAlgebra/test/triangular.jl index a09d0092e9f39..375e2210a61aa 100644 --- a/stdlib/LinearAlgebra/test/triangular.jl +++ b/stdlib/LinearAlgebra/test/triangular.jl @@ -6,7 +6,7 @@ debug = false using Test, LinearAlgebra, Random using LinearAlgebra: BlasFloat, errorbounds, full!, transpose!, UnitUpperTriangular, UnitLowerTriangular, - mul!, rdiv!, rmul!, lmul! + mul!, rdiv!, rmul!, lmul!, BandIndex const BASE_TEST_PATH = joinpath(Sys.BINDIR, "..", "share", "julia", "test") @@ -1270,4 +1270,31 @@ end end end +@testset "indexing with a BandIndex" begin + # these tests should succeed even if the linear index along + # the band isn't a constant, or type-inferred at all + M = rand(Int,2,2) + f(A,j, v::Val{n}) where {n} = Val(A[BandIndex(n,j)]) + function common_tests(M, ind) + j = ind[] + @test @inferred(f(UpperTriangular(M), j, Val(-1))) == Val(0) + @test @inferred(f(UnitUpperTriangular(M), j, Val(-1))) == Val(0) + @test @inferred(f(UnitUpperTriangular(M), j, Val(0))) == Val(1) + @test @inferred(f(LowerTriangular(M), j, Val(1))) == Val(0) + @test @inferred(f(UnitLowerTriangular(M), j, Val(1))) == Val(0) + @test @inferred(f(UnitLowerTriangular(M), j, Val(0))) == Val(1) + end + common_tests(M, Any[1]) + + M = Diagonal(Int[1,2]) + common_tests(M, Any[1]) + # extra tests for banded structure of the parent + for T in (UpperTriangular, UnitUpperTriangular) + @test @inferred(f(T(M), 1, Val(1))) == Val(0) + end + for T in (LowerTriangular, UnitLowerTriangular) + @test @inferred(f(T(M), 1, Val(-1))) == Val(0) + end +end + end # module TestTriangular From ec7dcb629fded56414e033da44257d711240bb79 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sat, 31 Aug 2024 12:51:42 +0530 Subject: [PATCH 3/4] Tests for a Tridiaognal parent --- stdlib/LinearAlgebra/test/triangular.jl | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/stdlib/LinearAlgebra/test/triangular.jl b/stdlib/LinearAlgebra/test/triangular.jl index 375e2210a61aa..b4733c81add31 100644 --- a/stdlib/LinearAlgebra/test/triangular.jl +++ b/stdlib/LinearAlgebra/test/triangular.jl @@ -1286,7 +1286,7 @@ end end common_tests(M, Any[1]) - M = Diagonal(Int[1,2]) + M = Diagonal([1,2]) common_tests(M, Any[1]) # extra tests for banded structure of the parent for T in (UpperTriangular, UnitUpperTriangular) @@ -1295,6 +1295,15 @@ end for T in (LowerTriangular, UnitLowerTriangular) @test @inferred(f(T(M), 1, Val(-1))) == Val(0) end + + M = Tridiagonal([1,2], [1,2,3], [1,2]) + common_tests(M, Any[1]) + for T in (UpperTriangular, UnitUpperTriangular) + @test @inferred(f(T(M), 1, Val(2))) == Val(0) + end + for T in (LowerTriangular, UnitLowerTriangular) + @test @inferred(f(T(M), 1, Val(-2))) == Val(0) + end end end # module TestTriangular From e0732e7bfbd038f759b3b43b4418b40fecabd048 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 23 Sep 2024 14:24:07 +0530 Subject: [PATCH 4/4] Branch elimination in Bidiagonal indexing with BandIndex --- stdlib/LinearAlgebra/src/bidiag.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index 12d638f52add6..7919e13ccccc4 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -166,10 +166,11 @@ end end @inline function getindex(A::Bidiagonal{T}, b::BandIndex) where T - @boundscheck checkbounds(A, _cartinds(b)) + @boundscheck checkbounds(A, b) if b.band == 0 return @inbounds A.dv[b.index] - elseif b.band == _offdiagind(A.uplo) + elseif b.band ∈ (-1,1) && b.band == _offdiagind(A.uplo) + # we explicitly compare the possible bands as b.band may be constant-propagated return @inbounds A.ev[b.index] else return bidiagzero(A, Tuple(_cartinds(b))...)