Skip to content

Commit

Permalink
Specialize isbanded for StridedMatrix (#56487)
Browse files Browse the repository at this point in the history
This improves performance, as the loops in `istriu` and `istril` may be
fused to improve cache-locality.
This also changes the quick-return behavior, and only returns after the
check over all the upper or lower bands for a column is complete.

```julia
julia> using LinearAlgebra

julia> A = zeros(2, 10_000);

julia> @btime isdiag($A);
  32.682 μs (0 allocations: 0 bytes) # nightly v"1.12.0-DEV.1593"
  9.481 μs (0 allocations: 0 bytes) # this PR

julia> A = zeros(10_000, 2);

julia> @btime isdiag($A);
  10.288 μs (0 allocations: 0 bytes)  # nightly 
  2.579 μs (0 allocations: 0 bytes) # this PR

julia> A = zeros(100, 100);

julia> @btime isdiag($A);
  6.616 μs (0 allocations: 0 bytes) # nightly
  3.075 μs (0 allocations: 0 bytes) # this PR

julia> A = diagm(0=>1:100);  A[3,4] = 1;

julia> @btime isdiag($A);
  2.759 μs (0 allocations: 0 bytes) # nightly
  85.371 ns (0 allocations: 0 bytes) # this PR
```

A similar change is added to `istriu`/`istril` as well, so that
```julia
julia> A = zeros(2, 10_000);

julia> @btime istriu($A); # trivial
  7.358 ns (0 allocations: 0 bytes) # nightly
  13.779 ns (0 allocations: 0 bytes) # this PR

julia> @btime istril($A);
  33.464 μs (0 allocations: 0 bytes) # nightly
  9.476 μs (0 allocations: 0 bytes) # this PR

julia> A = zeros(10_000, 2);

julia> @btime istriu($A);
  10.020 μs (0 allocations: 0 bytes) # nightly
  2.620 μs (0 allocations: 0 bytes) # this PR

julia> @btime istril($A); # trivial
  6.793 ns (0 allocations: 0 bytes) # nightly
  14.473 ns (0 allocations: 0 bytes) # this PR

julia> A = zeros(100, 100);

julia> @btime istriu($A);
  3.435 μs (0 allocations: 0 bytes) # nightly
  1.637 μs (0 allocations: 0 bytes) # this PR

julia> @btime istril($A);
  3.353 μs (0 allocations: 0 bytes) # nightly
  1.661 μs (0 allocations: 0 bytes) # this PR
```

---------

Co-authored-by: Daniel Karrasch <[email protected]>
  • Loading branch information
jishnub and dkarrasch authored Nov 11, 2024
1 parent 38e3d14 commit f93138e
Show file tree
Hide file tree
Showing 6 changed files with 170 additions and 28 deletions.
95 changes: 70 additions & 25 deletions stdlib/LinearAlgebra/src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1353,6 +1353,14 @@ end

ishermitian(x::Number) = (x == conj(x))

# `_iszero(V)` answers the same question as `iszero(V)`. The generic method just forwards;
# the specialised method further down trades the early exit of `all` for a plain reduction
# where that is expected to be faster.
function _iszero(V)
    return iszero(V)
end
# A Base.FastContiguousSubArray view whose parent is a StridedArray
FastContiguousSubArrayStrided{T,N,P<:StridedArray,I<:Tuple{AbstractUnitRange, Vararg{Any}}} =
    Base.SubArray{T,N,P,I,true}
# reducing with `&` via mapreduce (rather than calling `all`) permits vectorization;
# `init=true` makes an empty view count as zero
function _iszero(V::FastContiguousSubArrayStrided)
    return mapreduce(iszero, &, V; init=true)
end

"""
istriu(A::AbstractMatrix, k::Integer = 0) -> Bool
Expand Down Expand Up @@ -1384,20 +1392,9 @@ julia> istriu(c, -1)
true
```
"""
function istriu(A::AbstractMatrix, k::Integer = 0)
    # Upper-triangularity from the `k`th diagonal is a banded check: every band below `k`
    # must vanish. The upper limit `size(A,2)-1` is the topmost superdiagonal, so no band
    # lies above it and nothing is constrained on that side.
    # `_isbanded_impl` performs the one-based-indexing check itself.
    return _isbanded_impl(A, k, size(A,2)-1)
end
# a scalar is trivially upper triangular
istriu(x::Number) = true

# Column-by-column check that every entry of `A` strictly below the `k`th diagonal is zero.
# Only columns that actually contain such entries (up to column `m + k - 1`) are visited,
# and the scan stops at the first column holding a nonzero.
@inline function _istriu(A::AbstractMatrix, k)
    nrows, ncols = size(A)
    lastcol = min(ncols, nrows + k - 1)
    for col in 1:lastcol
        # first row of this column that lies below the `k`th diagonal
        firstrow = max(1, col - k + 1)
        below = view(A, firstrow:nrows, col)
        if !all(iszero, below)
            return false
        end
    end
    return true
end

"""
istril(A::AbstractMatrix, k::Integer = 0) -> Bool
Expand Down Expand Up @@ -1429,20 +1426,9 @@ julia> istril(c, 1)
true
```
"""
function istril(A::AbstractMatrix, k::Integer = 0)
    # Lower-triangularity up to the `k`th diagonal is a banded check: every band above `k`
    # must vanish. The lower limit `-size(A,1)+1` is the bottommost subdiagonal, so no band
    # lies below it and nothing is constrained on that side.
    # `_isbanded_impl` performs the one-based-indexing check itself.
    return _isbanded_impl(A, -size(A,1)+1, k)
end
# a scalar is trivially lower triangular
istril(x::Number) = true

# Column-by-column check that every entry of `A` strictly above the `k`th diagonal is zero.
# Columns before `k + 2` hold no such entries and are skipped; the scan stops at the first
# column holding a nonzero.
@inline function _istril(A::AbstractMatrix, k)
    nrows, ncols = size(A)
    firstcol = max(1, k + 2)
    for col in firstcol:ncols
        # last row of this column that lies above the `k`th diagonal
        lastrow = min(col - k - 1, nrows)
        above = view(A, 1:lastrow, col)
        if !all(iszero, above)
            return false
        end
    end
    return true
end

"""
isbanded(A::AbstractMatrix, kl::Integer, ku::Integer) -> Bool
Expand Down Expand Up @@ -1474,7 +1460,66 @@ julia> LinearAlgebra.isbanded(b, -1, 0)
true
```
"""
isbanded(A::AbstractMatrix, kl::Integer, ku::Integer) = _isbanded(A, kl, ku)
# Generic fallback: a matrix is banded within `kl:ku` when nothing lies below the `kl`th
# diagonal and nothing lies above the `ku`th one.
_isbanded(A::AbstractMatrix, kl::Integer, ku::Integer) = istriu(A, kl) && istril(A, ku)
# Performance optimization for StridedMatrix by better utilizing cache locality
# The istriu and istril loops are merged
# the additional indirection allows us to reuse the isbanded loop within istriu/istril
# without encountering cycles
_isbanded(A::StridedMatrix, kl::Integer, ku::Integer) = _isbanded_impl(A, kl, ku)
function _isbanded_impl(A, kl, ku)
    Base.require_one_based_indexing(A)

    #=
    Return `true` when every entry of `A` below the `kl`th diagonal and every entry above
    the `ku`th diagonal is zero, scanning `A` one column at a time.

    We split the column range into four possible groups, depending on the values of kl and ku.
    The first is the bottom left triangle, where bands below kl must be zero,
    but there are no bands above ku in that column.
    The second is where there are both bands below kl and above ku in the column.
    These are the middle columns typically.
    The third is the top right, where there are bands above ku but no bands below kl
    in the column.
    The fourth is mainly relevant for wide matrices, where there is a block to the right
    beyond ku, where the elements should all be zero. The reason we separate this from the
    third group is that we may loop over all the rows using A[:, col] instead of A[rowrange, col],
    which is usually faster.

    Any of the groups may be empty; each range is intersected with `axes(A,2)` before
    looping, so out-of-range column numbers are harmless.
    =#

    last_col_nonzeroblocks = size(A,1) + ku # fully zero rectangular block beyond this column
    last_col_emptytoprows = ku + 1 # empty top rows before this column
    last_col_nonemptybottomrows = size(A,1) + kl - 1 # empty bottom rows after this column

    colrange_onlybottomrows = firstindex(A,2):min(last_col_nonemptybottomrows, last_col_emptytoprows)
    # The second group must stop at the last column that still has rows below the band.
    # Without this clipping it would extend to `last_col_nonzeroblocks`, leaving the third
    # group permanently empty and building an empty bottom view for each of those columns.
    col_topbotrows_start = max(last_col_emptytoprows, last(colrange_onlybottomrows)) + 1
    col_topbotrows_end = min(last_col_nonzeroblocks, last_col_nonemptybottomrows)
    colrange_topbottomrows = col_topbotrows_start:col_topbotrows_end
    # `last` of an empty unit range is `start - 1`, so the third group begins right where
    # the second one ends (or where it would have begun, if it is empty)
    colrange_onlytoprows_nonzero = last(colrange_topbottomrows)+1:last_col_nonzeroblocks
    colrange_zero_block = last_col_nonzeroblocks+1:lastindex(A,2)

    for col in intersect(axes(A,2), colrange_onlybottomrows) # only loop over the bottom rows
        botrowinds = max(firstindex(A,1), col-kl+1):lastindex(A,1)
        bottomrows = @view A[botrowinds, col]
        _iszero(bottomrows) || return false
    end
    for col in intersect(axes(A,2), colrange_topbottomrows) # rows above and below the band
        toprowinds = firstindex(A,1):min(col-ku-1, lastindex(A,1))
        toprows = @view A[toprowinds, col]
        _iszero(toprows) || return false
        botrowinds = max(firstindex(A,1), col-kl+1):lastindex(A,1)
        bottomrows = @view A[botrowinds, col]
        _iszero(bottomrows) || return false
    end
    for col in intersect(axes(A,2), colrange_onlytoprows_nonzero) # only rows above the band
        toprowinds = firstindex(A,1):min(col-ku-1, lastindex(A,1))
        toprows = @view A[toprowinds, col]
        _iszero(toprows) || return false
    end
    for col in intersect(axes(A,2), colrange_zero_block) # whole column lies outside the band
        _iszero(@view A[:, col]) || return false
    end
    return true
end

"""
isdiag(A) -> Bool
Expand Down
10 changes: 10 additions & 0 deletions stdlib/LinearAlgebra/src/hessenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,16 @@ Base.@constprop :aggressive function istriu(A::UpperHessenberg, k::Integer=0)
k <= -1 && return true
return _istriu(A, k)
end
# additional indirection to dispatch to optimized method for banded parents (defined in special.jl)
# Scan the parent of an `UpperHessenberg` column by column. In each column only the rows
# from the first one below the `k`th diagonal down to row `col + 1` (the first subdiagonal)
# are read; parent entries further down are never touched.
@inline function _istriu(A::UpperHessenberg, k)
    data = parent(A)
    nrows = size(A, 1)
    lastcol = min(nrows + k - 1, lastindex(data, 2))
    for col in firstindex(data, 2):lastcol
        rowlo = max(firstindex(data, 1), col - k + 1)
        rowhi = min(col + 1, lastindex(data, 1))
        _iszero(view(data, rowlo:rowhi, col)) || return false
    end
    return true
end

function Matrix{T}(H::UpperHessenberg) where T
m,n = size(H)
Expand Down
1 change: 1 addition & 0 deletions stdlib/LinearAlgebra/src/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -592,3 +592,4 @@ end
# istriu/istril for triangular and upper-Hessenberg wrappers of structured matrices:
# when the parent is a `BandedMatrix`, forward the query to the parent's own
# `istril`/`istriu` method instead of scanning the parent column by column
_istril(A::LowerTriangular{<:Any, <:BandedMatrix}, k) = istril(parent(A), k)
_istriu(A::UpperTriangular{<:Any, <:BandedMatrix}, k) = istriu(parent(A), k)
_istriu(A::UpperHessenberg{<:Any, <:BandedMatrix}, k) = istriu(parent(A), k)
8 changes: 6 additions & 2 deletions stdlib/LinearAlgebra/src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -348,25 +348,29 @@ Base.@constprop :aggressive function istril(A::LowerTriangular, k::Integer=0)
k >= 0 && return true
return _istril(A, k)
end
# additional indirection to dispatch to optimized method for banded parents (defined in special.jl)
@inline function _istril(A::LowerTriangular, k)
    P = parent(A)
    for j in max(firstindex(P,2), k + 2):lastindex(P,2)
        # Only rows from the diagonal (row `j`) down to row `j - k - 1` of the parent are
        # read; both ends are clamped to the parent's row range. Each column is scanned
        # once, via `_iszero`.
        _iszero(@view P[max(j, begin):min(j - k - 1, end), j]) || return false
    end
    return true
end

# For `k <= 0` the answer is `true` without looking at any data; only positive `k` needs
# the scan in `_istriu`.
Base.@constprop :aggressive function istriu(A::UpperTriangular, k::Integer=0)
    if k <= 0
        return true
    end
    return _istriu(A, k)
end
# additional indirection to dispatch to optimized method for banded parents (defined in special.jl)
@inline function _istriu(A::UpperTriangular, k)
    P = parent(A)
    m = size(A, 1)
    for j in firstindex(P,2):min(m + k - 1, lastindex(P,2))
        # Only rows from the first one below the `k`th diagonal down to the diagonal
        # (row `j`) of the parent are read; both ends are clamped to the parent's row
        # range. Each column is scanned once, via `_iszero`.
        _iszero(@view P[max(begin, j - k + 1):min(j, end), j]) || return false
    end
    return true
end

# Transposition (plain or conjugate) sends the `k`th diagonal of the wrapper to the `-k`th
# diagonal of its parent and swaps the upper and lower triangles, so each query is answered
# by the opposite predicate on the parent with the offset negated.
function istril(A::Adjoint, k::Integer=0)
    return istriu(parent(A), -k)
end
function istril(A::Transpose, k::Integer=0)
    return istriu(parent(A), -k)
end
function istriu(A::Adjoint, k::Integer=0)
    return istril(parent(A), -k)
end
Expand Down
58 changes: 57 additions & 1 deletion stdlib/LinearAlgebra/test/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
module TestGeneric

using Test, LinearAlgebra, Random
using Test: GenericArray
using LinearAlgebra: isbanded

const BASE_TEST_PATH = joinpath(Sys.BINDIR, "..", "share", "julia", "test")

Expand Down Expand Up @@ -511,56 +513,110 @@ end
end

@testset "generic functions for checking whether matrices have banded structure" begin
    # `isbanded` is not exported; it is brought into scope by the module-level
    # `using LinearAlgebra: isbanded`.
    # Each fixture comes in three flavours: a dense `Matrix`, a `GenericArray` copy
    # (suffix `G`) and the matching structured LinearAlgebra type (capitalised name).
    pentadiag = [1 2 3; 4 5 6; 7 8 9]
    tridiag = [1 2 0; 4 5 6; 0 8 9]
    tridiagG = GenericArray([1 2 0; 4 5 6; 0 8 9])
    Tridiag = Tridiagonal(tridiag)
    ubidiag = [1 2 0; 0 5 6; 0 0 9]
    ubidiagG = GenericArray([1 2 0; 0 5 6; 0 0 9])
    uBidiag = Bidiagonal(ubidiag, :U)
    lbidiag = [1 0 0; 4 5 0; 0 8 9]
    lbidiagG = GenericArray([1 0 0; 4 5 0; 0 8 9])
    lBidiag = Bidiagonal(lbidiag, :L)
    adiag = [1 0 0; 0 5 0; 0 0 9]
    adiagG = GenericArray([1 0 0; 0 5 0; 0 0 9])
    aDiag = Diagonal(adiag)
    # In every group below, the first test of a pair pins the expected value for the dense
    # matrix and the second checks that all three flavours agree.
    @testset "istriu" begin
        @test !istriu(pentadiag)
        @test istriu(pentadiag, -2)
        @test !istriu(tridiag)
        @test istriu(tridiag) == istriu(tridiagG) == istriu(Tridiag)
        @test istriu(tridiag, -1)
        @test istriu(tridiag, -1) == istriu(tridiagG, -1) == istriu(Tridiag, -1)
        @test istriu(ubidiag)
        @test istriu(ubidiag) == istriu(ubidiagG) == istriu(uBidiag)
        @test !istriu(ubidiag, 1)
        @test istriu(ubidiag, 1) == istriu(ubidiagG, 1) == istriu(uBidiag, 1)
        @test !istriu(lbidiag)
        @test istriu(lbidiag) == istriu(lbidiagG) == istriu(lBidiag)
        @test istriu(lbidiag, -1)
        @test istriu(lbidiag, -1) == istriu(lbidiagG, -1) == istriu(lBidiag, -1)
        @test istriu(adiag)
        @test istriu(adiag) == istriu(adiagG) == istriu(aDiag)
    end
    @testset "istril" begin
        @test !istril(pentadiag)
        @test istril(pentadiag, 2)
        @test !istril(tridiag)
        @test istril(tridiag) == istril(tridiagG) == istril(Tridiag)
        @test istril(tridiag, 1)
        @test istril(tridiag, 1) == istril(tridiagG, 1) == istril(Tridiag, 1)
        @test !istril(ubidiag)
        # the third operand must be the structured `uBidiag`, not the generic copy again
        @test istril(ubidiag) == istril(ubidiagG) == istril(uBidiag)
        @test istril(ubidiag, 1)
        @test istril(ubidiag, 1) == istril(ubidiagG, 1) == istril(uBidiag, 1)
        @test istril(lbidiag)
        @test istril(lbidiag) == istril(lbidiagG) == istril(lBidiag)
        @test !istril(lbidiag, -1)
        @test istril(lbidiag, -1) == istril(lbidiagG, -1) == istril(lBidiag, -1)
        @test istril(adiag)
        @test istril(adiag) == istril(adiagG) == istril(aDiag)
    end
    @testset "isbanded" begin
        @test isbanded(pentadiag, -2, 2)
        @test !isbanded(pentadiag, -1, 2)
        @test !isbanded(pentadiag, -2, 1)
        @test isbanded(tridiag, -1, 1)
        @test isbanded(tridiag, -1, 1) == isbanded(tridiagG, -1, 1) == isbanded(Tridiag, -1, 1)
        @test !isbanded(tridiag, 0, 1)
        @test isbanded(tridiag, 0, 1) == isbanded(tridiagG, 0, 1) == isbanded(Tridiag, 0, 1)
        @test !isbanded(tridiag, -1, 0)
        @test isbanded(tridiag, -1, 0) == isbanded(tridiagG, -1, 0) == isbanded(Tridiag, -1, 0)
        @test isbanded(ubidiag, 0, 1)
        @test isbanded(ubidiag, 0, 1) == isbanded(ubidiagG, 0, 1) == isbanded(uBidiag, 0, 1)
        @test !isbanded(ubidiag, 1, 1)
        @test isbanded(ubidiag, 1, 1) == isbanded(ubidiagG, 1, 1) == isbanded(uBidiag, 1, 1)
        @test !isbanded(ubidiag, 0, 0)
        @test isbanded(ubidiag, 0, 0) == isbanded(ubidiagG, 0, 0) == isbanded(uBidiag, 0, 0)
        @test isbanded(lbidiag, -1, 0)
        @test isbanded(lbidiag, -1, 0) == isbanded(lbidiagG, -1, 0) == isbanded(lBidiag, -1, 0)
        @test !isbanded(lbidiag, 0, 0)
        @test isbanded(lbidiag, 0, 0) == isbanded(lbidiagG, 0, 0) == isbanded(lBidiag, 0, 0)
        @test !isbanded(lbidiag, -1, -1)
        @test isbanded(lbidiag, -1, -1) == isbanded(lbidiagG, -1, -1) == isbanded(lBidiag, -1, -1)
        @test isbanded(adiag, 0, 0)
        @test isbanded(adiag, 0, 0) == isbanded(adiagG, 0, 0) == isbanded(aDiag, 0, 0)
        @test !isbanded(adiag, -1, -1)
        @test isbanded(adiag, -1, -1) == isbanded(adiagG, -1, -1) == isbanded(aDiag, -1, -1)
        @test !isbanded(adiag, 1, 1)
        @test isbanded(adiag, 1, 1) == isbanded(adiagG, 1, 1) == isbanded(aDiag, 1, 1)
    end
    @testset "isdiag" begin
        @test !isdiag(tridiag)
        @test isdiag(tridiag) == isdiag(tridiagG) == isdiag(Tridiag)
        @test !isdiag(ubidiag)
        @test isdiag(ubidiag) == isdiag(ubidiagG) == isdiag(uBidiag)
        @test !isdiag(lbidiag)
        @test isdiag(lbidiag) == isdiag(lbidiagG) == isdiag(lBidiag)
        @test isdiag(adiag)
        @test isdiag(adiag) == isdiag(adiagG) == isdiag(aDiag)
    end
end

@testset "isbanded/istril/istriu with rectangular matrices" begin
    # For each shape, place ones on a single diagonal `m` and check every band limit in
    # -6:6 against the closed-form answer, for both the dense matrix and a `GenericArray`.
    @testset "$(size(A))" for A in [zeros(0,4), zeros(2,5), zeros(5,2), zeros(4,0)]
        @testset for m in -(size(A,1)-1):(size(A,2)-1)
            A .= 0
            A[diagind(A, m)] .= 1
            G = GenericArray(A)
            # The expected value is parenthesised on purpose: `==` binds tighter than `||`,
            # so without the parentheses the test would pass whenever the right-hand
            # condition is true, regardless of what the functions return.
            @testset for (kl,ku) in Iterators.product(-6:6, -6:6)
                @test isbanded(A, kl, ku) == isbanded(G, kl, ku) == (isempty(A) || (m in (kl:ku)))
            end
            @testset for k in -6:6
                @test istriu(A,k) == istriu(G,k) == (isempty(A) || (k <= m))
                @test istril(A,k) == istril(G,k) == (isempty(A) || (k >= m))
            end
        end
    end
end

Expand Down
26 changes: 26 additions & 0 deletions stdlib/LinearAlgebra/test/hessenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -279,4 +279,30 @@ end
@test H.H == D
end

# Check that `istriu`/`istril` on an `UpperHessenberg` wrapper give the same answer as on
# its dense `Array` copy, for every diagonal offset `k` in `-n:n`.
@testset "istriu/istril forwards to parent" begin
n = 10
# structured parents: random and all-zero `Tridiagonal` and `Diagonal` matrices
@testset "$(nameof(typeof(M)))" for M in [Tridiagonal(rand(n-1), rand(n), rand(n-1)),
Tridiagonal(zeros(n-1), zeros(n), zeros(n-1)),
Diagonal(randn(n)),
Diagonal(zeros(n)),
]
U = UpperHessenberg(M)
A = Array(U)
for k in -n:n
@test istriu(U, k) == istriu(A, k)
@test istril(U, k) == istril(A, k)
end
end
# Dense parent allocated with `undef`: only the part filled by `copytrito!(P, z, 'U')` and
# the first subdiagonal are assigned (all zeros) before wrapping.
# NOTE(review): presumably the rest of `P` is left unassigned on purpose, so that a check
# reading parent entries outside the Hessenberg band would fail — confirm.
z = zeros(n,n)
P = Matrix{BigFloat}(undef, n, n)
copytrito!(P, z, 'U')
P[diagind(P,-1)] .= 0
U = UpperHessenberg(P)
A = Array(U)
@testset for k in -n:n
@test istriu(U, k) == istriu(A, k)
@test istril(U, k) == istril(A, k)
end
end

end # module TestHessenberg

0 comments on commit f93138e

Please sign in to comment.