Skip to content

Commit

Permalink
lmul!/rmul! for banded matrices (JuliaLang#55823)
Browse files Browse the repository at this point in the history
This adds fast methods for `lmul!` and `rmul!` between banded matrices
and numbers.
Performance impact:
```julia
julia> T = Tridiagonal(rand(999), rand(1000), rand(999));

julia> @Btime rmul!($T, 0.2);
  4.686 ms (0 allocations: 0 bytes) # nightly v"1.12.0-DEV.1225"
  669.355 ns (0 allocations: 0 bytes) # this PR
```
  • Loading branch information
jishnub authored Sep 22, 2024
1 parent 4964c97 commit 9136bdd
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 0 deletions.
26 changes: 26 additions & 0 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,32 @@ end
-(A::Bidiagonal)=Bidiagonal(-A.dv,-A.ev,A.uplo)
*(A::Bidiagonal, B::Number) = Bidiagonal(A.dv*B, A.ev*B, A.uplo)
*(B::Number, A::Bidiagonal) = Bidiagonal(B*A.dv, B*A.ev, A.uplo)
function rmul!(B::Bidiagonal, x::Number)
if size(B,1) > 1
isupper = B.uplo == 'U'
row, col = 1 + isupper, 1 + !isupper
# ensure that zeros are preserved on scaling
y = B[row,col] * x
iszero(y) || throw(ArgumentError(LazyString(lazy"cannot set index ($row, $col) off ",
lazy"the tridiagonal band to a nonzero value ($y)")))
end
@. B.dv *= x
@. B.ev *= x
return B
end
function lmul!(x::Number, B::Bidiagonal)
if size(B,1) > 1
isupper = B.uplo == 'U'
row, col = 1 + isupper, 1 + !isupper
# ensure that zeros are preserved on scaling
y = x * B[row,col]
iszero(y) || throw(ArgumentError(LazyString(lazy"cannot set index ($row, $col) off ",
lazy"the tridiagonal band to a nonzero value ($y)")))
end
@. B.dv = x * B.dv
@. B.ev = x * B.ev
return B
end
/(A::Bidiagonal, B::Number) = Bidiagonal(A.dv/B, A.ev/B, A.uplo)
\(B::Number, A::Bidiagonal) = Bidiagonal(B\A.dv, B\A.ev, A.uplo)

Expand Down
20 changes: 20 additions & 0 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,26 @@ end

(*)(x::Number, D::Diagonal) = Diagonal(x * D.diag)
(*)(D::Diagonal, x::Number) = Diagonal(D.diag * x)
function lmul!(x::Number, D::Diagonal)
if size(D,1) > 1
# ensure that zeros are preserved on scaling
y = D[2,1] * x
iszero(y) || throw(ArgumentError(LazyString("cannot set index (2, 1) off ",
lazy"the tridiagonal band to a nonzero value ($y)")))
end
@. D.diag = x * D.diag
return D
end
function rmul!(D::Diagonal, x::Number)
if size(D,1) > 1
# ensure that zeros are preserved on scaling
y = x * D[2,1]
iszero(y) || throw(ArgumentError(LazyString("cannot set index (2, 1) off ",
lazy"the tridiagonal band to a nonzero value ($y)")))
end
@. D.diag *= x
return D
end
(/)(D::Diagonal, x::Number) = Diagonal(D.diag / x)
(\)(x::Number, D::Diagonal) = Diagonal(x \ D.diag)
(^)(D::Diagonal, a::Number) = Diagonal(D.diag .^ a)
Expand Down
47 changes: 47 additions & 0 deletions stdlib/LinearAlgebra/src/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,29 @@ end
-(A::SymTridiagonal) = SymTridiagonal(-A.dv, -A.ev)
*(A::SymTridiagonal, B::Number) = SymTridiagonal(A.dv*B, A.ev*B)
*(B::Number, A::SymTridiagonal) = SymTridiagonal(B*A.dv, B*A.ev)
function rmul!(A::SymTridiagonal, x::Number)
if size(A,1) > 2
# ensure that zeros are preserved on scaling
y = A[3,1] * x
iszero(y) || throw(ArgumentError(LazyString("cannot set index (3, 1) off ",
lazy"the tridiagonal band to a nonzero value ($y)")))
end
A.dv .*= x
_evview(A) .*= x
return A
end
function lmul!(x::Number, B::SymTridiagonal)
if size(B,1) > 2
# ensure that zeros are preserved on scaling
y = x * B[3,1]
iszero(y) || throw(ArgumentError(LazyString("cannot set index (3, 1) off ",
lazy"the tridiagonal band to a nonzero value ($y)")))
end
@. B.dv = x * B.dv
ev = _evview(B)
@. ev = x * ev
return B
end
/(A::SymTridiagonal, B::Number) = SymTridiagonal(A.dv/B, A.ev/B)
\(B::Number, A::SymTridiagonal) = SymTridiagonal(B\A.dv, B\A.ev)
==(A::SymTridiagonal{<:Number}, B::SymTridiagonal{<:Number}) =
Expand Down Expand Up @@ -836,6 +859,30 @@ tr(M::Tridiagonal) = sum(M.d)
-(A::Tridiagonal) = Tridiagonal(-A.dl, -A.d, -A.du)
*(A::Tridiagonal, B::Number) = Tridiagonal(A.dl*B, A.d*B, A.du*B)
*(B::Number, A::Tridiagonal) = Tridiagonal(B*A.dl, B*A.d, B*A.du)
function rmul!(T::Tridiagonal, x::Number)
if size(T,1) > 2
# ensure that zeros are preserved on scaling
y = T[3,1] * x
iszero(y) || throw(ArgumentError(LazyString("cannot set index (3, 1) off ",
lazy"the tridiagonal band to a nonzero value ($y)")))
end
T.dl .*= x
T.d .*= x
T.du .*= x
return T
end
function lmul!(x::Number, T::Tridiagonal)
if size(T,1) > 2
# ensure that zeros are preserved on scaling
y = x * T[3,1]
iszero(y) || throw(ArgumentError(LazyString("cannot set index (3, 1) off ",
lazy"the tridiagonal band to a nonzero value ($y)")))
end
@. T.dl = x * T.dl
@. T.d = x * T.d
@. T.du = x * T.du
return T
end
/(A::Tridiagonal, B::Number) = Tridiagonal(A.dl/B, A.d/B, A.du/B)
\(B::Number, A::Tridiagonal) = Tridiagonal(B\A.dl, B\A.d, B\A.du)

Expand Down
13 changes: 13 additions & 0 deletions stdlib/LinearAlgebra/test/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -969,6 +969,19 @@ end
end
end

@testset "rmul!/lmul! with numbers" begin
for T in (Bidiagonal(rand(4), rand(3), :U), Bidiagonal(rand(4), rand(3), :L))
@test rmul!(copy(T), 0.2) rmul!(Array(T), 0.2)
@test lmul!(0.2, copy(T)) lmul!(0.2, Array(T))
@test_throws ArgumentError rmul!(T, NaN)
@test_throws ArgumentError lmul!(NaN, T)
end
for T in (Bidiagonal(rand(1), rand(0), :U), Bidiagonal(rand(1), rand(0), :L))
@test all(isnan, rmul!(copy(T), NaN))
@test all(isnan, lmul!(NaN, copy(T)))
end
end

@testset "mul with Diagonal" begin
for n in 0:4
dv, ev = rand(n), rand(max(n-1,0))
Expand Down
11 changes: 11 additions & 0 deletions stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1345,6 +1345,17 @@ end
end
end

@testset "rmul!/lmul! with numbers" begin
D = Diagonal(rand(4))
@test rmul!(copy(D), 0.2) rmul!(Array(D), 0.2)
@test lmul!(0.2, copy(D)) lmul!(0.2, Array(D))
@test_throws ArgumentError rmul!(D, NaN)
@test_throws ArgumentError lmul!(NaN, D)
D = Diagonal(rand(1))
@test all(isnan, rmul!(copy(D), NaN))
@test all(isnan, lmul!(NaN, copy(D)))
end

@testset "+/- with block Symmetric/Hermitian" begin
for p in ([1 2; 3 4], [1 2+im; 2-im 4+2im])
m = SizedArrays.SizedArray{(2,2)}(p)
Expand Down
13 changes: 13 additions & 0 deletions stdlib/LinearAlgebra/test/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -935,6 +935,19 @@ end
end
end

@testset "rmul!/lmul! with numbers" begin
for T in (SymTridiagonal(rand(4), rand(3)), Tridiagonal(rand(3), rand(4), rand(3)))
@test rmul!(copy(T), 0.2) rmul!(Array(T), 0.2)
@test lmul!(0.2, copy(T)) lmul!(0.2, Array(T))
@test_throws ArgumentError rmul!(T, NaN)
@test_throws ArgumentError lmul!(NaN, T)
end
for T in (SymTridiagonal(rand(2), rand(1)), Tridiagonal(rand(1), rand(2), rand(1)))
@test all(isnan, rmul!(copy(T), NaN))
@test all(isnan, lmul!(NaN, copy(T)))
end
end

@testset "mul with empty arrays" begin
A = zeros(5,0)
T = Tridiagonal(zeros(0), zeros(0), zeros(0))
Expand Down

0 comments on commit 9136bdd

Please sign in to comment.