From f93138ed0791799bf4bd33649cb3269054474a24 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 11 Nov 2024 20:48:37 +0530 Subject: [PATCH] Specialize `isbanded` for `StridedMatrix` (#56487) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This improves performance, as the loops in `istriu` and `istril` may be fused to improve cache-locality. This also changes the quick-return behavior, and only returns after the check over all the upper or lower bands for a column is complete. ```julia julia> using LinearAlgebra julia> A = zeros(2, 10_000); julia> @btime isdiag($A); 32.682 μs (0 allocations: 0 bytes) # nightly v"1.12.0-DEV.1593" 9.481 μs (0 allocations: 0 bytes) # this PR julia> A = zeros(10_000, 2); julia> @btime isdiag($A); 10.288 μs (0 allocations: 0 bytes) # nightly 2.579 μs (0 allocations: 0 bytes) # this PR julia> A = zeros(100, 100); julia> @btime isdiag($A); 6.616 μs (0 allocations: 0 bytes) # nightly 3.075 μs (0 allocations: 0 bytes) # this PR julia> A = diagm(0=>1:100); A[3,4] = 1; julia> @btime isdiag($A); 2.759 μs (0 allocations: 0 bytes) # nightly 85.371 ns (0 allocations: 0 bytes) # this PR ``` A similar change is added to `istriu`/`istril` as well, so that ```julia julia> A = zeros(2, 10_000); julia> @btime istriu($A); # trivial 7.358 ns (0 allocations: 0 bytes) # nightly 13.779 ns (0 allocations: 0 bytes) # this PR julia> @btime istril($A); 33.464 μs (0 allocations: 0 bytes) # nightly 9.476 μs (0 allocations: 0 bytes) # this PR julia> A = zeros(10_000, 2); julia> @btime istriu($A); 10.020 μs (0 allocations: 0 bytes) # nightly 2.620 μs (0 allocations: 0 bytes) # this PR julia> @btime istril($A); # trivial 6.793 ns (0 allocations: 0 bytes) # nightly 14.473 ns (0 allocations: 0 bytes) # this PR julia> A = zeros(100, 100); julia> @btime istriu($A); 3.435 μs (0 allocations: 0 bytes) # nightly 1.637 μs (0 allocations: 0 bytes) # this PR julia> @btime istril($A); 3.353 μs (0 allocations: 0 bytes) # 
nightly 1.661 μs (0 allocations: 0 bytes) # this PR ``` --------- Co-authored-by: Daniel Karrasch --- stdlib/LinearAlgebra/src/generic.jl | 95 ++++++++++++++++++------- stdlib/LinearAlgebra/src/hessenberg.jl | 10 +++ stdlib/LinearAlgebra/src/special.jl | 1 + stdlib/LinearAlgebra/src/triangular.jl | 8 ++- stdlib/LinearAlgebra/test/generic.jl | 58 ++++++++++++++- stdlib/LinearAlgebra/test/hessenberg.jl | 26 +++++++ 6 files changed, 170 insertions(+), 28 deletions(-) diff --git a/stdlib/LinearAlgebra/src/generic.jl b/stdlib/LinearAlgebra/src/generic.jl index 21719c0c50127..666ad631f919a 100644 --- a/stdlib/LinearAlgebra/src/generic.jl +++ b/stdlib/LinearAlgebra/src/generic.jl @@ -1353,6 +1353,14 @@ end ishermitian(x::Number) = (x == conj(x)) +# helper function equivalent to `iszero(v)`, but potentially without the fast exit feature +# of `all` if this improves performance +_iszero(V) = iszero(V) +# A Base.FastContiguousSubArray view of a StridedArray +FastContiguousSubArrayStrided{T,N,P<:StridedArray,I<:Tuple{AbstractUnitRange, Vararg{Any}}} = Base.SubArray{T,N,P,I,true} +# using mapreduce instead of all permits vectorization +_iszero(V::FastContiguousSubArrayStrided) = mapreduce(iszero, &, V, init=true) + """ istriu(A::AbstractMatrix, k::Integer = 0) -> Bool @@ -1384,20 +1392,9 @@ julia> istriu(c, -1) true ``` """ -function istriu(A::AbstractMatrix, k::Integer = 0) - require_one_based_indexing(A) - return _istriu(A, k) -end +istriu(A::AbstractMatrix, k::Integer = 0) = _isbanded_impl(A, k, size(A,2)-1) istriu(x::Number) = true -@inline function _istriu(A::AbstractMatrix, k) - m, n = size(A) - for j in 1:min(n, m + k - 1) - all(iszero, view(A, max(1, j - k + 1):m, j)) || return false - end - return true -end - """ istril(A::AbstractMatrix, k::Integer = 0) -> Bool @@ -1429,20 +1426,9 @@ julia> istril(c, 1) true ``` """ -function istril(A::AbstractMatrix, k::Integer = 0) - require_one_based_indexing(A) - return _istril(A, k) -end +istril(A::AbstractMatrix, k::Integer = 
0) = _isbanded_impl(A, -size(A,1)+1, k) istril(x::Number) = true -@inline function _istril(A::AbstractMatrix, k) - m, n = size(A) - for j in max(1, k + 2):n - all(iszero, view(A, 1:min(j - k - 1, m), j)) || return false - end - return true -end - """ isbanded(A::AbstractMatrix, kl::Integer, ku::Integer) -> Bool @@ -1474,7 +1460,66 @@ julia> LinearAlgebra.isbanded(b, -1, 0) true ``` """ -isbanded(A::AbstractMatrix, kl::Integer, ku::Integer) = istriu(A, kl) && istril(A, ku) +isbanded(A::AbstractMatrix, kl::Integer, ku::Integer) = _isbanded(A, kl, ku) +_isbanded(A::AbstractMatrix, kl::Integer, ku::Integer) = istriu(A, kl) && istril(A, ku) +# Performance optimization for StridedMatrix by better utilizing cache locality +# The istriu and istril loops are merged +# the additional indirection allows us to reuse the isbanded loop within istriu/istril +# without encountering cycles +_isbanded(A::StridedMatrix, kl::Integer, ku::Integer) = _isbanded_impl(A, kl, ku) +function _isbanded_impl(A, kl, ku) + Base.require_one_based_indexing(A) + + #= + We split the column range into four possible groups, depending on the values of kl and ku. + + The first is the bottom left triangle, where bands below kl must be zero, + but there are no bands above ku in that column. + + The second is where there are both bands below kl and above ku in the column. + These are the middle columns typically. + + The third is the top right, where there are bands above ku but no bands below kl + in the column. + + The fourth is mainly relevant for wide matrices, where there is a block to the right + beyond ku, where the elements should all be zero. The reason we separate this from the + third group is that we may loop over all the rows using A[:, col] instead of A[rowrange, col], + which is usually faster. 
+ =# + + last_col_nonzeroblocks = size(A,1) + ku # fully zero rectangular block beyond this column + last_col_emptytoprows = ku + 1 # empty top rows before this column + last_col_nonemptybottomrows = size(A,1) + kl - 1 # empty bottom rows after this column + + colrange_onlybottomrows = firstindex(A,2):min(last_col_nonemptybottomrows, last_col_emptytoprows) + colrange_topbottomrows = max(last_col_emptytoprows, last(colrange_onlybottomrows))+1:last_col_nonzeroblocks + colrange_onlytoprows_nonzero = last(colrange_topbottomrows)+1:last_col_nonzeroblocks + colrange_zero_block = last_col_nonzeroblocks+1:lastindex(A,2) + + for col in intersect(axes(A,2), colrange_onlybottomrows) # only loop over the bottom rows + botrowinds = max(firstindex(A,1), col-kl+1):lastindex(A,1) + bottomrows = @view A[botrowinds, col] + _iszero(bottomrows) || return false + end + for col in intersect(axes(A,2), colrange_topbottomrows) + toprowinds = firstindex(A,1):min(col-ku-1, lastindex(A,1)) + toprows = @view A[toprowinds, col] + _iszero(toprows) || return false + botrowinds = max(firstindex(A,1), col-kl+1):lastindex(A,1) + bottomrows = @view A[botrowinds, col] + _iszero(bottomrows) || return false + end + for col in intersect(axes(A,2), colrange_onlytoprows_nonzero) + toprowinds = firstindex(A,1):min(col-ku-1, lastindex(A,1)) + toprows = @view A[toprowinds, col] + _iszero(toprows) || return false + end + for col in intersect(axes(A,2), colrange_zero_block) + _iszero(@view A[:, col]) || return false + end + return true +end """ isdiag(A) -> Bool diff --git a/stdlib/LinearAlgebra/src/hessenberg.jl b/stdlib/LinearAlgebra/src/hessenberg.jl index bfe2fdd41aace..ed654c33aba55 100644 --- a/stdlib/LinearAlgebra/src/hessenberg.jl +++ b/stdlib/LinearAlgebra/src/hessenberg.jl @@ -77,6 +77,16 @@ Base.@constprop :aggressive function istriu(A::UpperHessenberg, k::Integer=0) k <= -1 && return true return _istriu(A, k) end +# additional indirection to dispatch to optimized method for banded parents (defined 
in special.jl) +@inline function _istriu(A::UpperHessenberg, k) + P = parent(A) + m = size(A, 1) + for j in firstindex(P,2):min(m + k - 1, lastindex(P,2)) + Prows = @view P[max(begin, j - k + 1):min(j+1,end), j] + _iszero(Prows) || return false + end + return true +end function Matrix{T}(H::UpperHessenberg) where T m,n = size(H) diff --git a/stdlib/LinearAlgebra/src/special.jl b/stdlib/LinearAlgebra/src/special.jl index 6d25540ee3f07..c61586a810140 100644 --- a/stdlib/LinearAlgebra/src/special.jl +++ b/stdlib/LinearAlgebra/src/special.jl @@ -592,3 +592,4 @@ end # istriu/istril for triangular wrappers of structured matrices _istril(A::LowerTriangular{<:Any, <:BandedMatrix}, k) = istril(parent(A), k) _istriu(A::UpperTriangular{<:Any, <:BandedMatrix}, k) = istriu(parent(A), k) +_istriu(A::UpperHessenberg{<:Any, <:BandedMatrix}, k) = istriu(parent(A), k) diff --git a/stdlib/LinearAlgebra/src/triangular.jl b/stdlib/LinearAlgebra/src/triangular.jl index 49ff5d7f9c3ec..76d97133de796 100644 --- a/stdlib/LinearAlgebra/src/triangular.jl +++ b/stdlib/LinearAlgebra/src/triangular.jl @@ -348,25 +348,29 @@ Base.@constprop :aggressive function istril(A::LowerTriangular, k::Integer=0) k >= 0 && return true return _istril(A, k) end +# additional indirection to dispatch to optimized method for banded parents (defined in special.jl) @inline function _istril(A::LowerTriangular, k) P = parent(A) for j in max(firstindex(P,2), k + 2):lastindex(P,2) - all(iszero, @view(P[j:min(j - k - 1, end), j])) || return false + _iszero(@view P[max(j, begin):min(j - k - 1, end), j]) || return false end return true end + Base.@constprop :aggressive function istriu(A::UpperTriangular, k::Integer=0) k <= 0 && return true return _istriu(A, k) end +# additional indirection to dispatch to optimized method for banded parents (defined in special.jl) @inline function _istriu(A::UpperTriangular, k) P = parent(A) m = size(A, 1) for j in firstindex(P,2):min(m + k - 1, lastindex(P,2)) - all(iszero, 
@view(P[max(begin, j - k + 1):j, j])) || return false + _iszero(@view P[max(begin, j - k + 1):min(j, end), j]) || return false end return true end + istril(A::Adjoint, k::Integer=0) = istriu(A.parent, -k) istril(A::Transpose, k::Integer=0) = istriu(A.parent, -k) istriu(A::Adjoint, k::Integer=0) = istril(A.parent, -k) diff --git a/stdlib/LinearAlgebra/test/generic.jl b/stdlib/LinearAlgebra/test/generic.jl index 725f9b3497db8..6d11ec824e538 100644 --- a/stdlib/LinearAlgebra/test/generic.jl +++ b/stdlib/LinearAlgebra/test/generic.jl @@ -3,6 +3,8 @@ module TestGeneric using Test, LinearAlgebra, Random +using Test: GenericArray +using LinearAlgebra: isbanded const BASE_TEST_PATH = joinpath(Sys.BINDIR, "..", "share", "julia", "test") @@ -511,56 +513,110 @@ end end @testset "generic functions for checking whether matrices have banded structure" begin - using LinearAlgebra: isbanded pentadiag = [1 2 3; 4 5 6; 7 8 9] tridiag = [1 2 0; 4 5 6; 0 8 9] + tridiagG = GenericArray([1 2 0; 4 5 6; 0 8 9]) + Tridiag = Tridiagonal(tridiag) ubidiag = [1 2 0; 0 5 6; 0 0 9] + ubidiagG = GenericArray([1 2 0; 0 5 6; 0 0 9]) + uBidiag = Bidiagonal(ubidiag, :U) lbidiag = [1 0 0; 4 5 0; 0 8 9] + lbidiagG = GenericArray([1 0 0; 4 5 0; 0 8 9]) + lBidiag = Bidiagonal(lbidiag, :L) adiag = [1 0 0; 0 5 0; 0 0 9] + adiagG = GenericArray([1 0 0; 0 5 0; 0 0 9]) + aDiag = Diagonal(adiag) @testset "istriu" begin @test !istriu(pentadiag) @test istriu(pentadiag, -2) @test !istriu(tridiag) + @test istriu(tridiag) == istriu(tridiagG) == istriu(Tridiag) @test istriu(tridiag, -1) + @test istriu(tridiag, -1) == istriu(tridiagG, -1) == istriu(Tridiag, -1) @test istriu(ubidiag) + @test istriu(ubidiag) == istriu(ubidiagG) == istriu(uBidiag) @test !istriu(ubidiag, 1) + @test istriu(ubidiag, 1) == istriu(ubidiagG, 1) == istriu(uBidiag, 1) @test !istriu(lbidiag) + @test istriu(lbidiag) == istriu(lbidiagG) == istriu(lBidiag) @test istriu(lbidiag, -1) + @test istriu(lbidiag, -1) == istriu(lbidiagG, -1) == 
istriu(lBidiag, -1) @test istriu(adiag) + @test istriu(adiag) == istriu(adiagG) == istriu(aDiag) end @testset "istril" begin @test !istril(pentadiag) @test istril(pentadiag, 2) @test !istril(tridiag) + @test istril(tridiag) == istril(tridiagG) == istril(Tridiag) @test istril(tridiag, 1) + @test istril(tridiag, 1) == istril(tridiagG, 1) == istril(Tridiag, 1) @test !istril(ubidiag) + @test istril(ubidiag) == istril(ubidiagG) == istril(uBidiag) @test istril(ubidiag, 1) + @test istril(ubidiag, 1) == istril(ubidiagG, 1) == istril(uBidiag, 1) @test istril(lbidiag) + @test istril(lbidiag) == istril(lbidiagG) == istril(lBidiag) @test !istril(lbidiag, -1) + @test istril(lbidiag, -1) == istril(lbidiagG, -1) == istril(lBidiag, -1) @test istril(adiag) + @test istril(adiag) == istril(adiagG) == istril(aDiag) end @testset "isbanded" begin @test isbanded(pentadiag, -2, 2) @test !isbanded(pentadiag, -1, 2) @test !isbanded(pentadiag, -2, 1) @test isbanded(tridiag, -1, 1) + @test isbanded(tridiag, -1, 1) == isbanded(tridiagG, -1, 1) == isbanded(Tridiag, -1, 1) @test !isbanded(tridiag, 0, 1) + @test isbanded(tridiag, 0, 1) == isbanded(tridiagG, 0, 1) == isbanded(Tridiag, 0, 1) @test !isbanded(tridiag, -1, 0) + @test isbanded(tridiag, -1, 0) == isbanded(tridiagG, -1, 0) == isbanded(Tridiag, -1, 0) @test isbanded(ubidiag, 0, 1) + @test isbanded(ubidiag, 0, 1) == isbanded(ubidiagG, 0, 1) == isbanded(uBidiag, 0, 1) @test !isbanded(ubidiag, 1, 1) + @test isbanded(ubidiag, 1, 1) == isbanded(ubidiagG, 1, 1) == isbanded(uBidiag, 1, 1) @test !isbanded(ubidiag, 0, 0) + @test isbanded(ubidiag, 0, 0) == isbanded(ubidiagG, 0, 0) == isbanded(uBidiag, 0, 0) @test isbanded(lbidiag, -1, 0) + @test isbanded(lbidiag, -1, 0) == isbanded(lbidiagG, -1, 0) == isbanded(lBidiag, -1, 0) @test !isbanded(lbidiag, 0, 0) + @test isbanded(lbidiag, 0, 0) == isbanded(lbidiagG, 0, 0) == isbanded(lBidiag, 0, 0) @test !isbanded(lbidiag, -1, -1) + @test isbanded(lbidiag, -1, -1) == isbanded(lbidiagG, -1, -1) ==
isbanded(lBidiag, -1, -1) @test isbanded(adiag, 0, 0) + @test isbanded(adiag, 0, 0) == isbanded(adiagG, 0, 0) == isbanded(aDiag, 0, 0) @test !isbanded(adiag, -1, -1) + @test isbanded(adiag, -1, -1) == isbanded(adiagG, -1, -1) == isbanded(aDiag, -1, -1) @test !isbanded(adiag, 1, 1) + @test isbanded(adiag, 1, 1) == isbanded(adiagG, 1, 1) == isbanded(aDiag, 1, 1) end @testset "isdiag" begin @test !isdiag(tridiag) + @test isdiag(tridiag) == isdiag(tridiagG) == isdiag(Tridiag) @test !isdiag(ubidiag) + @test isdiag(ubidiag) == isdiag(ubidiagG) == isdiag(uBidiag) @test !isdiag(lbidiag) + @test isdiag(lbidiag) == isdiag(lbidiagG) == isdiag(lBidiag) @test isdiag(adiag) + @test isdiag(adiag) == isdiag(adiagG) == isdiag(aDiag) + end +end + +@testset "isbanded/istril/istriu with rectangular matrices" begin + @testset "$(size(A))" for A in [zeros(0,4), zeros(2,5), zeros(5,2), zeros(4,0)] + @testset for m in -(size(A,1)-1):(size(A,2)-1) + A .= 0 + A[diagind(A, m)] .= 1 + G = GenericArray(A) + @testset for (kl,ku) in Iterators.product(-6:6, -6:6) + @test isbanded(A, kl, ku) == isbanded(G, kl, ku) == (isempty(A) || (m in (kl:ku))) + end + @testset for k in -6:6 + @test istriu(A,k) == istriu(G,k) == (isempty(A) || (k <= m)) + @test istril(A,k) == istril(G,k) == (isempty(A) || (k >= m)) + end + end end end diff --git a/stdlib/LinearAlgebra/test/hessenberg.jl b/stdlib/LinearAlgebra/test/hessenberg.jl index 54dbb70aa2065..de58fea9fb27e 100644 --- a/stdlib/LinearAlgebra/test/hessenberg.jl +++ b/stdlib/LinearAlgebra/test/hessenberg.jl @@ -279,4 +279,30 @@ end @test H.H == D end +@testset "istriu/istril forwards to parent" begin + n = 10 + @testset "$(nameof(typeof(M)))" for M in [Tridiagonal(rand(n-1), rand(n), rand(n-1)), + Tridiagonal(zeros(n-1), zeros(n), zeros(n-1)), + Diagonal(randn(n)), + Diagonal(zeros(n)), + ] + U = UpperHessenberg(M) + A = Array(U) + for k in -n:n + @test istriu(U, k) == istriu(A, k) + @test istril(U, k) == istril(A, k) + end + end + z = zeros(n,n) + P =
Matrix{BigFloat}(undef, n, n) + copytrito!(P, z, 'U') + P[diagind(P,-1)] .= 0 + U = UpperHessenberg(P) + A = Array(U) + @testset for k in -n:n + @test istriu(U, k) == istriu(A, k) + @test istril(U, k) == istril(A, k) + end +end + end # module TestHessenberg