From 9136bddb6c36050e03529e2db456e6ea2e380557 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sun, 22 Sep 2024 22:10:08 +0530 Subject: [PATCH] lmul!/rmul! for banded matrices (#55823) This adds fast methods for `lmul!` and `rmul!` between banded matrices and numbers. Performance impact: ```julia julia> T = Tridiagonal(rand(999), rand(1000), rand(999)); julia> @btime rmul!($T, 0.2); 4.686 ms (0 allocations: 0 bytes) # nightly v"1.12.0-DEV.1225" 669.355 ns (0 allocations: 0 bytes) # this PR ``` --- stdlib/LinearAlgebra/src/bidiag.jl | 26 +++++++++++++++ stdlib/LinearAlgebra/src/diagonal.jl | 20 ++++++++++++ stdlib/LinearAlgebra/src/tridiag.jl | 47 +++++++++++++++++++++++++++ stdlib/LinearAlgebra/test/bidiag.jl | 13 ++++++++ stdlib/LinearAlgebra/test/diagonal.jl | 11 +++++++ stdlib/LinearAlgebra/test/tridiag.jl | 13 ++++++++ 6 files changed, 130 insertions(+) diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index 12d638f52add6..0aab9ceeca6b9 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -441,6 +441,32 @@ end -(A::Bidiagonal)=Bidiagonal(-A.dv,-A.ev,A.uplo) *(A::Bidiagonal, B::Number) = Bidiagonal(A.dv*B, A.ev*B, A.uplo) *(B::Number, A::Bidiagonal) = Bidiagonal(B*A.dv, B*A.ev, A.uplo) +function rmul!(B::Bidiagonal, x::Number) + if size(B,1) > 1 + isupper = B.uplo == 'U' + row, col = 1 + isupper, 1 + !isupper + # ensure that zeros are preserved on scaling + y = B[row,col] * x + iszero(y) || throw(ArgumentError(LazyString(lazy"cannot set index ($row, $col) off ", + lazy"the tridiagonal band to a nonzero value ($y)"))) + end + @. B.dv *= x + @. B.ev *= x + return B +end +function lmul!(x::Number, B::Bidiagonal) + if size(B,1) > 1 + isupper = B.uplo == 'U' + row, col = 1 + isupper, 1 + !isupper + # ensure that zeros are preserved on scaling + y = x * B[row,col] + iszero(y) || throw(ArgumentError(LazyString(lazy"cannot set index ($row, $col) off ", + lazy"the tridiagonal band to a nonzero value ($y)"))) + end + @. B.dv = x * B.dv + @. B.ev = x * B.ev + return B +end /(A::Bidiagonal, B::Number) = Bidiagonal(A.dv/B, A.ev/B, A.uplo) \(B::Number, A::Bidiagonal) = Bidiagonal(B\A.dv, B\A.ev, A.uplo) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index 23d2422d13654..d762549a2b228 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -274,6 +274,26 @@ end (*)(x::Number, D::Diagonal) = Diagonal(x * D.diag) (*)(D::Diagonal, x::Number) = Diagonal(D.diag * x) +function lmul!(x::Number, D::Diagonal) + if size(D,1) > 1 + # ensure that zeros are preserved on scaling + y = D[2,1] * x + iszero(y) || throw(ArgumentError(LazyString("cannot set index (2, 1) off ", + lazy"the tridiagonal band to a nonzero value ($y)"))) + end + @. D.diag = x * D.diag + return D +end +function rmul!(D::Diagonal, x::Number) + if size(D,1) > 1 + # ensure that zeros are preserved on scaling + y = x * D[2,1] + iszero(y) || throw(ArgumentError(LazyString("cannot set index (2, 1) off ", + lazy"the tridiagonal band to a nonzero value ($y)"))) + end + @. D.diag *= x + return D +end (/)(D::Diagonal, x::Number) = Diagonal(D.diag / x) (\)(x::Number, D::Diagonal) = Diagonal(x \ D.diag) (^)(D::Diagonal, a::Number) = Diagonal(D.diag .^ a) diff --git a/stdlib/LinearAlgebra/src/tridiag.jl b/stdlib/LinearAlgebra/src/tridiag.jl index 84c79f57debc7..e755ce63e9b2a 100644 --- a/stdlib/LinearAlgebra/src/tridiag.jl +++ b/stdlib/LinearAlgebra/src/tridiag.jl @@ -228,6 +228,29 @@ end -(A::SymTridiagonal) = SymTridiagonal(-A.dv, -A.ev) *(A::SymTridiagonal, B::Number) = SymTridiagonal(A.dv*B, A.ev*B) *(B::Number, A::SymTridiagonal) = SymTridiagonal(B*A.dv, B*A.ev) +function rmul!(A::SymTridiagonal, x::Number) + if size(A,1) > 2 + # ensure that zeros are preserved on scaling + y = A[3,1] * x + iszero(y) || throw(ArgumentError(LazyString("cannot set index (3, 1) off ", + lazy"the tridiagonal band to a nonzero value ($y)"))) + end + A.dv .*= x + _evview(A) .*= x + return A +end +function lmul!(x::Number, B::SymTridiagonal) + if size(B,1) > 2 + # ensure that zeros are preserved on scaling + y = x * B[3,1] + iszero(y) || throw(ArgumentError(LazyString("cannot set index (3, 1) off ", + lazy"the tridiagonal band to a nonzero value ($y)"))) + end + @. B.dv = x * B.dv + ev = _evview(B) + @. ev = x * ev + return B +end /(A::SymTridiagonal, B::Number) = SymTridiagonal(A.dv/B, A.ev/B) \(B::Number, A::SymTridiagonal) = SymTridiagonal(B\A.dv, B\A.ev) ==(A::SymTridiagonal{<:Number}, B::SymTridiagonal{<:Number}) = @@ -836,6 +859,30 @@ tr(M::Tridiagonal) = sum(M.d) -(A::Tridiagonal) = Tridiagonal(-A.dl, -A.d, -A.du) *(A::Tridiagonal, B::Number) = Tridiagonal(A.dl*B, A.d*B, A.du*B) *(B::Number, A::Tridiagonal) = Tridiagonal(B*A.dl, B*A.d, B*A.du) +function rmul!(T::Tridiagonal, x::Number) + if size(T,1) > 2 + # ensure that zeros are preserved on scaling + y = T[3,1] * x + iszero(y) || throw(ArgumentError(LazyString("cannot set index (3, 1) off ", + lazy"the tridiagonal band to a nonzero value ($y)"))) + end + T.dl .*= x + T.d .*= x + T.du .*= x + return T +end +function lmul!(x::Number, T::Tridiagonal) + if size(T,1) > 2 + # ensure that zeros are preserved on scaling + y = x * T[3,1] + iszero(y) || throw(ArgumentError(LazyString("cannot set index (3, 1) off ", + lazy"the tridiagonal band to a nonzero value ($y)"))) + end + @. T.dl = x * T.dl + @. T.d = x * T.d + @. T.du = x * T.du + return T +end /(A::Tridiagonal, B::Number) = Tridiagonal(A.dl/B, A.d/B, A.du/B) \(B::Number, A::Tridiagonal) = Tridiagonal(B\A.dl, B\A.d, B\A.du) diff --git a/stdlib/LinearAlgebra/test/bidiag.jl b/stdlib/LinearAlgebra/test/bidiag.jl index edad29d4ec180..d633a99a2390e 100644 --- a/stdlib/LinearAlgebra/test/bidiag.jl +++ b/stdlib/LinearAlgebra/test/bidiag.jl @@ -969,6 +969,19 @@ end end end +@testset "rmul!/lmul! with numbers" begin + for T in (Bidiagonal(rand(4), rand(3), :U), Bidiagonal(rand(4), rand(3), :L)) + @test rmul!(copy(T), 0.2) ≈ rmul!(Array(T), 0.2) + @test lmul!(0.2, copy(T)) ≈ lmul!(0.2, Array(T)) + @test_throws ArgumentError rmul!(T, NaN) + @test_throws ArgumentError lmul!(NaN, T) + end + for T in (Bidiagonal(rand(1), rand(0), :U), Bidiagonal(rand(1), rand(0), :L)) + @test all(isnan, rmul!(copy(T), NaN)) + @test all(isnan, lmul!(NaN, copy(T))) + end +end + @testset "mul with Diagonal" begin for n in 0:4 dv, ev = rand(n), rand(max(n-1,0)) diff --git a/stdlib/LinearAlgebra/test/diagonal.jl b/stdlib/LinearAlgebra/test/diagonal.jl index 83d5e4fcdf170..dfb901908ba69 100644 --- a/stdlib/LinearAlgebra/test/diagonal.jl +++ b/stdlib/LinearAlgebra/test/diagonal.jl @@ -1345,6 +1345,17 @@ end end end +@testset "rmul!/lmul! with numbers" begin + D = Diagonal(rand(4)) + @test rmul!(copy(D), 0.2) ≈ rmul!(Array(D), 0.2) + @test lmul!(0.2, copy(D)) ≈ lmul!(0.2, Array(D)) + @test_throws ArgumentError rmul!(D, NaN) + @test_throws ArgumentError lmul!(NaN, D) + D = Diagonal(rand(1)) + @test all(isnan, rmul!(copy(D), NaN)) + @test all(isnan, lmul!(NaN, copy(D))) +end + @testset "+/- with block Symmetric/Hermitian" begin for p in ([1 2; 3 4], [1 2+im; 2-im 4+2im]) m = SizedArrays.SizedArray{(2,2)}(p) diff --git a/stdlib/LinearAlgebra/test/tridiag.jl b/stdlib/LinearAlgebra/test/tridiag.jl index 15ac7f9f2147f..826a6e62355d0 100644 --- a/stdlib/LinearAlgebra/test/tridiag.jl +++ b/stdlib/LinearAlgebra/test/tridiag.jl @@ -935,6 +935,19 @@ end end end +@testset "rmul!/lmul! with numbers" begin + for T in (SymTridiagonal(rand(4), rand(3)), Tridiagonal(rand(3), rand(4), rand(3))) + @test rmul!(copy(T), 0.2) ≈ rmul!(Array(T), 0.2) + @test lmul!(0.2, copy(T)) ≈ lmul!(0.2, Array(T)) + @test_throws ArgumentError rmul!(T, NaN) + @test_throws ArgumentError lmul!(NaN, T) + end + for T in (SymTridiagonal(rand(2), rand(1)), Tridiagonal(rand(1), rand(2), rand(1))) + @test all(isnan, rmul!(copy(T), NaN)) + @test all(isnan, lmul!(NaN, copy(T))) + end +end + @testset "mul with empty arrays" begin A = zeros(5,0) T = Tridiagonal(zeros(0), zeros(0), zeros(0))