From f9c2db1289e87a058d367cbd21d09cc14b84b0d9 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sat, 10 Aug 2024 17:36:47 +0530 Subject: [PATCH 01/10] Avoid materializing arrays in bidiag matmul --- stdlib/LinearAlgebra/src/LinearAlgebra.jl | 4 ++- stdlib/LinearAlgebra/src/bidiag.jl | 37 +++++++++++++------ stdlib/LinearAlgebra/test/bidiag.jl | 33 +++++++++++++++++ stdlib/LinearAlgebra/test/tridiag.jl | 44 +++++++++++++++++++++++ 4 files changed, 107 insertions(+), 11 deletions(-) diff --git a/stdlib/LinearAlgebra/src/LinearAlgebra.jl b/stdlib/LinearAlgebra/src/LinearAlgebra.jl index be59516f086ab..fa3ce3057c9ad 100644 --- a/stdlib/LinearAlgebra/src/LinearAlgebra.jl +++ b/stdlib/LinearAlgebra/src/LinearAlgebra.jl @@ -673,7 +673,9 @@ matprod_dest(A::Diagonal, B::Diagonal, TS) = _matprod_dest_diag(B, TS) _matprod_dest_diag(A, TS) = similar(A, TS) function _matprod_dest_diag(A::SymTridiagonal, TS) n = size(A, 1) - Tridiagonal(similar(A, TS, n-1), similar(A, TS, n), similar(A, TS, n-1)) + ev = similar(A, TS, max(0, n-1)) + dv = similar(A, TS, n) + Tridiagonal(ev, dv, similar(ev)) end # Special handling for adj/trans vec diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index 5aa4314c9ae51..6ee63d3467907 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -565,7 +565,13 @@ _mul!(C::AbstractMatrix, A::BiTriSym, B::Bidiagonal, _add::MulAddMul) = function _bibimul!(C, A, B, _add) check_A_mul_B!_sizes(size(C), size(A), size(B)) n = size(A,1) - n <= 3 && return mul!(C, Array(A), Array(B), _add.alpha, _add.beta) + if n <= 3 + # naive multiplication + for I in CartesianIndices(C) + _modify!(_add, sum(A[I[1], k] * B[k, I[2]] for k in axes(A,2)), C, I) + end + return C + end # We use `_rmul_or_fill!` instead of `_modify!` here since using # `_modify!` in the following loop will not update the # off-diagonal elements for non-zero beta. @@ -737,7 +743,14 @@ function _mul!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat, _add::MulA end iszero(nA) && return C iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta) - nA <= 3 && return mul!(C, Array(A), Array(B), _add.alpha, _add.beta) + if nA <= 3 + # naive multiplication + for I in CartesianIndices(C) + col = Base.tail(Tuple(I)) + _modify!(_add, sum(A[I[1], k] * B[k, col...] for k in axes(A,2)), C, I) + end + return C + end l = _diag(A, -1) d = _diag(A, 0) u = _diag(A, 1) @@ -758,11 +771,12 @@ end function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::TriSym, _add::MulAddMul) require_one_based_indexing(C, A) check_A_mul_B!_sizes(size(C), size(A), size(B)) - iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta) n = size(A,1) m = size(B,2) - if n <= 3 || m <= 1 - return mul!(C, Array(A), Array(B), _add.alpha, _add.beta) + (iszero(_add.alpha) || iszero(m)) && return _rmul_or_fill!(C, _add.beta) + if m == 1 + B11 = B[1,1] + return mul!(C, A, B11, _add.alpha, _add.beta) end Bl = _diag(B, -1) Bd = _diag(B, 0) @@ -793,11 +807,9 @@ end function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::Bidiagonal, _add::MulAddMul) require_one_based_indexing(C, A) check_A_mul_B!_sizes(size(C), size(A), size(B)) - iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta) - if size(A, 1) <= 3 || size(B, 2) <= 1 - return mul!(C, Array(A), Array(B), _add.alpha, _add.beta) - end m, n = size(A) + (iszero(m) || iszero(n)) && return C + iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta) @inbounds if B.uplo == 'U' for i in 1:m for j in n:-1:2 @@ -824,7 +836,12 @@ function _dibimul!(C, A, B, _add) require_one_based_indexing(C) check_A_mul_B!_sizes(size(C), size(A), size(B)) n = size(A,1) - n <= 3 && return mul!(C, Array(A), Array(B), _add.alpha, _add.beta) + if n <= 3 + for I in CartesianIndices(C) + _modify!(_add, A.diag[I[1]] * B[I[1], I[2]], C, I) + end + return C + end _rmul_or_fill!(C, _add.beta) # see the same use above iszero(_add.alpha) && return C Ad = A.diag diff --git a/stdlib/LinearAlgebra/test/bidiag.jl b/stdlib/LinearAlgebra/test/bidiag.jl index e19d890237a26..147683a23c792 100644 --- a/stdlib/LinearAlgebra/test/bidiag.jl +++ b/stdlib/LinearAlgebra/test/bidiag.jl @@ -1020,4 +1020,37 @@ end @test_throws "cannot set entry" B[1,2] = 4 end +@testset "mul for small matrices" begin + @testset for n in 0:4 + D = Diagonal(rand(n)) + v = rand(n) + @testset for uplo in (:L, :U) + B = Bidiagonal(rand(n), rand(max(n-1,0)), uplo) + M = Matrix(B) + + @test B * v ≈ M * v + @test mul!(similar(v), B, v) ≈ M * v + + @test B * B ≈ M * M + @test mul!(similar(B, size(B)), B, B) ≈ M * M + + for m in 1:6 + AL = rand(m,n) + AR = rand(n,m) + @test AL * B ≈ AL * M + @test B * AR ≈ M * AR + @test mul!(similar(AL), AL, B) ≈ AL * M + @test mul!(similar(AR), B, AR) ≈ M * AR + end + + @test B * D ≈ M * D + @test D * B ≈ D * M + @test mul!(similar(B), B, D) ≈ M * D + @test mul!(similar(B), B, D) ≈ M * D + @test mul!(similar(B, size(B)), D, B) ≈ D * M + @test mul!(similar(B, size(B)), B, D) ≈ M * D + end + end +end + end # module TestBidiagonal diff --git a/stdlib/LinearAlgebra/test/tridiag.jl b/stdlib/LinearAlgebra/test/tridiag.jl index e0a8e32d77852..e87f608437c4b 100644 --- a/stdlib/LinearAlgebra/test/tridiag.jl +++ b/stdlib/LinearAlgebra/test/tridiag.jl @@ -930,4 +930,48 @@ end @test sprint(show, S) == "SymTridiagonal($(repr(diag(S))), $(repr(diag(S,1))))" end +@testset "mul for small matrices" begin + @testset for n in 0:4 + for T in ( + Tridiagonal(rand(max(n-1,0)), rand(n), rand(max(n-1,0))), + SymTridiagonal(rand(n), rand(max(n-1,0))), + ) + M = Matrix(T) + @test T * T ≈ M * M + @test mul!(similar(T, size(T)), T, T) ≈ M * M + + for m in 0:6 + AR = rand(n,m) + AL = rand(m,n) + @test AL * T ≈ AL * M + @test T * AR ≈ M * AR + @test mul!(similar(AL), AL, T) ≈ AL * M + @test mul!(similar(AR), T, AR) ≈ M * AR + end + + v = rand(n) + @test T * v ≈ M * v + @test mul!(similar(v), T, v) ≈ M * v + + D = Diagonal(rand(n)) + @test T * D ≈ M * D + @test D * T ≈ D * M + @test mul!(Tridiagonal(similar(T)), D, T) ≈ D * M + @test mul!(Tridiagonal(similar(T)), T, D) ≈ M * D + @test mul!(similar(T, size(T)), D, T) ≈ D * M + @test mul!(similar(T, size(T)), T, D) ≈ M * D + + B = Bidiagonal(rand(n), rand(max(0, n-1)), :U) + @test T * B ≈ M * B + @test B * T ≈ B * M + if n <= 2 + @test mul!(Tridiagonal(similar(T)), B, T) ≈ B * M + @test mul!(Tridiagonal(similar(T)), T, B) ≈ M * B + @test mul!(similar(T, size(T)), B, T) ≈ B * M + @test mul!(similar(T, size(T)), T, B) ≈ M * B + end + end + end +end + end # module TestTridiagonal From 2c56cf877db543028a54b1aeeda62d136a66aba4 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sat, 10 Aug 2024 21:20:45 +0530 Subject: [PATCH 02/10] Make iteration cache-friendly in Bidiagonal matmul --- stdlib/LinearAlgebra/src/bidiag.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index 6ee63d3467907..5eaf5e5a41001 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -811,17 +811,17 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::Bidiagonal, _add::MulAdd (iszero(m) || iszero(n)) && return C iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta) @inbounds if B.uplo == 'U' + for j in n:-1:2, i in 1:m + _modify!(_add, A[i,j] * B.dv[j] + A[i,j-1] * B.ev[j-1], C, (i, j)) + end for i in 1:m - for j in n:-1:2 - _modify!(_add, A[i,j] * B.dv[j] + A[i,j-1] * B.ev[j-1], C, (i, j)) - end _modify!(_add, A[i,1] * B.dv[1], C, (i, 1)) end else # uplo == 'L' + for j in 1:n-1, i in 1:m + _modify!(_add, A[i,j] * B.dv[j] + A[i,j+1] * B.ev[j], C, (i, j)) + end for i in 1:m - for j in 1:n-1 - _modify!(_add, A[i,j] * B.dv[j] + A[i,j+1] * B.ev[j], C, (i, j)) - end _modify!(_add, A[i,n] * B.dv[n], C, (i, n)) end end From 974425c104aa9f7f3baf7ab93379248ea47a480f Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sat, 10 Aug 2024 21:43:30 +0530 Subject: [PATCH 03/10] Cherry-pick changes from jishnub/bidigamul_empty --- stdlib/LinearAlgebra/src/bidiag.jl | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index 5eaf5e5a41001..18b8ff926e667 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -535,12 +535,18 @@ function rmul!(B::Bidiagonal, D::Diagonal) end @noinline function check_A_mul_B!_sizes((mC, nC)::NTuple{2,Integer}, (mA, nA)::NTuple{2,Integer}, (mB, nB)::NTuple{2,Integer}) + # check for matching sizes in one column of B and C + check_A_mul_B!_sizes((mC,), (mA, nA), (mB,)) + # ensure that the number of columns in B and C match + if nB != nC + throw(DimensionMismatch(lazy"second dimension of output C, $nC, and second dimension of B, $nB, must match")) + end +end +@noinline function check_A_mul_B!_sizes((mC,)::Tuple{Integer}, (mA, nA)::NTuple{2,Integer}, (mB,)::Tuple{Integer}) if mA != mC throw(DimensionMismatch(lazy"first dimension of A, $mA, and first dimension of output C, $mC, must match")) elseif nA != mB throw(DimensionMismatch(lazy"second dimension of A, $nA, and first dimension of B, $mB, must match")) - elseif nB != nC - throw(DimensionMismatch(lazy"second dimension of output C, $nC, and second dimension of B, $nB, must match")) end end @@ -563,6 +569,7 @@ _mul!(C::AbstractMatrix, A::BiTriSym, B::TriSym, _add::MulAddMul) = _mul!(C::AbstractMatrix, A::BiTriSym, B::Bidiagonal, _add::MulAddMul) = _bibimul!(C, A, B, _add) function _bibimul!(C, A, B, _add) + require_one_based_indexing(C) check_A_mul_B!_sizes(size(C), size(A), size(B)) n = size(A,1) if n <= 3 @@ -733,14 +740,9 @@ end function _mul!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat, _add::MulAddMul) require_one_based_indexing(C, B) + check_A_mul_B!_sizes(size(C), size(A), size(B)) nA = size(A,1) nB = size(B,2) - if !(size(C,1) == size(B,1) == nA) - throw(DimensionMismatch(lazy"A has first dimension $nA, B has $(size(B,1)), C has $(size(C,1)) but all must match")) - end - if size(C,2) != nB - throw(DimensionMismatch(lazy"A has second dimension $nA, B has $(size(B,2)), C has $(size(C,2)) but all must match")) - end iszero(nA) && return C iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta) if nA <= 3 From fc3387f52384f82c096fc080d11eb821f96c6245 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sat, 10 Aug 2024 23:48:07 +0530 Subject: [PATCH 04/10] Specialize Bidiagonal * AbstractVecOrMat --- stdlib/LinearAlgebra/src/bidiag.jl | 38 +++++++++++++++++++++++++++++ stdlib/LinearAlgebra/test/bidiag.jl | 2 +- 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index 18b8ff926e667..dd1e63fb16491 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -753,6 +753,44 @@ function _mul!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat, _add::MulA end return C end + _mul_bitrisym!(C, A, B, _add) +end +function _mul_bitrisym!(C::AbstractVecOrMat, A::Bidiagonal, B::AbstractVecOrMat, _add::MulAddMul) + nA = size(A,1) + nB = size(B,2) + d = B.dv + if A.uplo == 'U' + u = B.ev + @inbounds begin + for j = 1:nB + b₀, b₊ = B[1, j], B[2, j] + _modify!(_add, d[1]*b₀ + u[1]*b₊, C, (1, j)) + for i = 2:nA - 1 + b₀, b₊ = b₊, B[i + 1, j] + _modify!(_add, d[i]*b₀ + u[i]*b₊, C, (i, j)) + end + _modify!(_add, d[nA]*b₊, C, (nA, j)) + end + end + else + l = B.ev + @inbounds begin + for j = 1:nB + b₀, b₊ = B[1, j], B[2, j] + _modify!(_add, d[1]*b₀, C, (1, j)) + for i = 2:nA - 1 + b₋, b₀, b₊ = b₀, b₊, B[i + 1, j] + _modify!(_add, l[i - 1]*b₋ + d[i]*b₀, C, (i, j)) + end + _modify!(_add, l[nA - 1]*b₀ + d[nA]*b₊, C, (nA, j)) + end + end + end + C +end +function _mul_bitrisym!(C::AbstractVecOrMat, A::TriSym, B::AbstractVecOrMat, _add::MulAddMul) + nA = size(A,1) + nB = size(B,2) l = _diag(A, -1) d = _diag(A, 0) u = _diag(A, 1) diff --git a/stdlib/LinearAlgebra/test/bidiag.jl b/stdlib/LinearAlgebra/test/bidiag.jl index 147683a23c792..6b6cddaf4efef 100644 --- a/stdlib/LinearAlgebra/test/bidiag.jl +++ b/stdlib/LinearAlgebra/test/bidiag.jl @@ -1034,7 +1034,7 @@ end @test B * B ≈ M * M @test mul!(similar(B, size(B)), B, B) ≈ M * M - for m in 1:6 + for m in 0:6 AL = rand(m,n) AR = rand(n,m) @test AL * B ≈ AL * M From 11de30c8425bb0b4136c5f2176dbab5cec0b7e49 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sun, 11 Aug 2024 13:52:18 +0530 Subject: [PATCH 05/10] Fix field access --- stdlib/LinearAlgebra/src/bidiag.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index dd1e63fb16491..c6e403a726e78 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -758,9 +758,9 @@ end function _mul_bitrisym!(C::AbstractVecOrMat, A::Bidiagonal, B::AbstractVecOrMat, _add::MulAddMul) nA = size(A,1) nB = size(B,2) - d = B.dv + d = A.dv if A.uplo == 'U' - u = B.ev + u = A.ev @inbounds begin for j = 1:nB b₀, b₊ = B[1, j], B[2, j] @@ -773,7 +773,7 @@ function _mul_bitrisym!(C::AbstractVecOrMat, A::Bidiagonal, B::AbstractVecOrMat, end end else - l = B.ev + l = A.ev @inbounds begin for j = 1:nB b₀, b₊ = B[1, j], B[2, j] From 2b443dcffa0d42080a9ce342128bf5eef83078cf Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sun, 11 Aug 2024 16:51:24 +0530 Subject: [PATCH 06/10] Split bibimul into cases --- stdlib/LinearAlgebra/src/bidiag.jl | 227 +++++++++++++++++++++++++-- stdlib/LinearAlgebra/test/bidiag.jl | 16 ++ stdlib/LinearAlgebra/test/tridiag.jl | 23 ++- 3 files changed, 243 insertions(+), 23 deletions(-) diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index c6e403a726e78..e90ae50232924 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -584,12 +584,6 @@ function _bibimul!(C, A, B, _add) # off-diagonal elements for non-zero beta. _rmul_or_fill!(C, _add.beta) iszero(_add.alpha) && return C - Al = _diag(A, -1) - Ad = _diag(A, 0) - Au = _diag(A, 1) - Bl = _diag(B, -1) - Bd = _diag(B, 0) - Bu = _diag(B, 1) @inbounds begin # first row of C C[1,1] += _add(A[1,1]*B[1,1] + A[1, 2]*B[2, 1]) @@ -600,6 +594,31 @@ function _bibimul!(C, A, B, _add) C[2,2] += _add(A[2,1]*B[1,2] + A[2,2]*B[2,2] + A[2,3]*B[3,2]) C[2,3] += _add(A[2,2]*B[2,3] + A[2,3]*B[3,3]) C[2,4] += _add(A[2,3]*B[3,4]) + end + # middle rows + __bibimul!(C, A, B, _add) + @inbounds begin + # row before last of C + C[n-1,n-3] += _add(A[n-1,n-2]*B[n-2,n-3]) + C[n-1,n-2] += _add(A[n-1,n-1]*B[n-1,n-2] + A[n-1,n-2]*B[n-2,n-2]) + C[n-1,n-1] += _add(A[n-1,n-2]*B[n-2,n-1] + A[n-1,n-1]*B[n-1,n-1] + A[n-1,n]*B[n,n-1]) + C[n-1,n ] += _add(A[n-1,n-1]*B[n-1,n ] + A[n-1, n]*B[n ,n ]) + # last row of C + C[n,n-2] += _add(A[n,n-1]*B[n-1,n-2]) + C[n,n-1] += _add(A[n,n-1]*B[n-1,n-1] + A[n,n]*B[n,n-1]) + C[n,n ] += _add(A[n,n-1]*B[n-1,n ] + A[n,n]*B[n,n ]) + end # inbounds + C +end +function __bibimul!(C, A, B, _add) + n = size(A,1) + Al = _diag(A, -1) + Ad = _diag(A, 0) + Au = _diag(A, 1) + Bl = _diag(B, -1) + Bd = _diag(B, 0) + Bu = _diag(B, 1) + @inbounds begin for j in 3:n-2 Ajj₋1 = Al[j-1] Ajj = Ad[j] @@ -619,16 +638,192 @@ function _bibimul!(C, A, B, _add) C[j, j+1] += _add(Ajj *Bjj₊1 + Ajj₊1*Bj₊1j₊1) C[j, j+2] += _add(Ajj₊1*Bj₊1j₊2) end - # row before last of C - C[n-1,n-3] += _add(A[n-1,n-2]*B[n-2,n-3]) - C[n-1,n-2] += _add(A[n-1,n-1]*B[n-1,n-2] + A[n-1,n-2]*B[n-2,n-2]) - C[n-1,n-1] += _add(A[n-1,n-2]*B[n-2,n-1] + A[n-1,n-1]*B[n-1,n-1] + A[n-1,n]*B[n,n-1]) - C[n-1,n ] += _add(A[n-1,n-1]*B[n-1,n ] + A[n-1, n]*B[n ,n ]) - # last row of C - C[n,n-2] += _add(A[n,n-1]*B[n-1,n-2]) - C[n,n-1] += _add(A[n,n-1]*B[n-1,n-1] + A[n,n]*B[n,n-1]) - C[n,n ] += _add(A[n,n-1]*B[n-1,n ] + A[n,n]*B[n,n ]) - end # inbounds + end + C +end +function __bibimul!(C, A, B::Bidiagonal, _add) + n = size(A,1) + Al = _diag(A, -1) + Ad = _diag(A, 0) + Au = _diag(A, 1) + if B.uplo == 'U' + Bd = _diag(B, 0) + Bu = _diag(B, 1) + @inbounds begin + for j in 3:n-2 + Ajj₋1 = Al[j-1] + Ajj = Ad[j] + Ajj₊1 = Au[j] + Bj₋1j₋1 = Bd[j-1] + Bj₋1j = Bu[j-1] + Bjj = Bd[j] + Bjj₊1 = Bu[j] + Bj₊1j₊1 = Bd[j+1] + Bj₊1j₊2 = Bu[j+1] + C[j, j-1] += _add(Ajj₋1*Bj₋1j₋1) + C[j, j ] += _add(Ajj₋1*Bj₋1j + Ajj*Bjj + Ajj₊1*Bj₊1j) + C[j, j+1] += _add(Ajj *Bjj₊1 + Ajj₊1*Bj₊1j₊1) + C[j, j+2] += _add(Ajj₊1*Bj₊1j₊2) + end + end + else # B.uplo == 'L' + Bl = _diag(B, -1) + Bd = _diag(B, 0) + @inbounds begin + for j in 3:n-2 + Ajj₋1 = Al[j-1] + Ajj = Ad[j] + Ajj₊1 = Au[j] + Bj₋1j₋2 = Bl[j-2] + Bj₋1j₋1 = Bd[j-1] + Bjj₋1 = Bl[j-1] + Bjj = Bd[j] + Bj₊1j = Bl[j] + Bj₊1j₊1 = Bd[j+1] + C[j,j-2] += _add( Ajj₋1*Bj₋1j₋2) + C[j, j-1] += _add(Ajj₋1*Bj₋1j₋1 + Ajj*Bjj₋1) + C[j, j ] += _add(Ajj*Bjj + Ajj₊1*Bj₊1j) + C[j, j+1] += _add(Ajj₊1*Bj₊1j₊1) + end + end + end + C +end +function __bibimul!(C, A::Bidiagonal, B, _add) + n = size(A,1) + Bl = _diag(B, -1) + Bd = _diag(B, 0) + Bu = _diag(B, 1) + if A.uplo == 'U' + Ad = _diag(A, 0) + Au = _diag(A, 1) + @inbounds begin + for j in 3:n-2 + Ajj = Ad[j] + Ajj₊1 = Au[j] + Bj₋1j₋2 = Bl[j-2] + Bj₋1j₋1 = Bd[j-1] + Bj₋1j = Bu[j-1] + Bjj₋1 = Bl[j-1] + Bjj = Bd[j] + Bjj₊1 = Bu[j] + Bj₊1j = Bl[j] + Bj₊1j₊1 = Bd[j+1] + Bj₊1j₊2 = Bu[j+1] + C[j, j-1] += _add(Ajj*Bjj₋1) + C[j, j ] += _add(Ajj*Bjj + Ajj₊1*Bj₊1j) + C[j, j+1] += _add(Ajj *Bjj₊1 + Ajj₊1*Bj₊1j₊1) + C[j, j+2] += _add(Ajj₊1*Bj₊1j₊2) + end + end + else # A.uplo == 'L' + Al = _diag(A, -1) + Ad = _diag(A, 0) + @inbounds begin + for j in 3:n-2 + Ajj₋1 = Al[j-1] + Ajj = Ad[j] + Bj₋1j₋2 = Bl[j-2] + Bj₋1j₋1 = Bd[j-1] + Bj₋1j = Bu[j-1] + Bjj₋1 = Bl[j-1] + Bjj = Bd[j] + Bjj₊1 = Bu[j] + Bj₊1j = Bl[j] + Bj₊1j₊1 = Bd[j+1] + Bj₊1j₊2 = Bu[j+1] + C[j,j-2] += _add( Ajj₋1*Bj₋1j₋2) + C[j, j-1] += _add(Ajj₋1*Bj₋1j₋1 + Ajj*Bjj₋1) + C[j, j ] += _add(Ajj₋1*Bj₋1j + Ajj*Bjj) + C[j, j+1] += _add(Ajj *Bjj₊1) + end + end + end + C +end +function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add) + n = size(A,1) + if A.uplo == 'U' && B.uplo == 'U' + Ad = _diag(A, 0) + Au = _diag(A, 1) + Bd = _diag(B, 0) + Bu = _diag(B, 1) + @inbounds begin + for j in 3:n-2 + Ajj = Ad[j] + Ajj₊1 = Au[j] + Bj₋1j₋1 = Bd[j-1] + Bj₋1j = Bu[j-1] + Bjj = Bd[j] + Bjj₊1 = Bu[j] + Bj₊1j₊1 = Bd[j+1] + Bj₊1j₊2 = Bu[j+1] + C[j, j ] += _add(Ajj*Bjj) + C[j, j+1] += _add(Ajj *Bjj₊1 + Ajj₊1*Bj₊1j₊1) + C[j, j+2] += _add(Ajj₊1*Bj₊1j₊2) + end + end + elseif A.uplo == 'U' && B.uplo == 'L' + Ad = _diag(A, 0) + Au = _diag(A, 1) + Bl = _diag(B, -1) + Bd = _diag(B, 0) + @inbounds begin + for j in 3:n-2 + Ajj = Ad[j] + Ajj₊1 = Au[j] + Bj₋1j₋2 = Bl[j-2] + Bj₋1j₋1 = Bd[j-1] + Bjj₋1 = Bl[j-1] + Bjj = Bd[j] + Bj₊1j = Bl[j] + Bj₊1j₊1 = Bd[j+1] + C[j, j-1] += _add(Ajj*Bjj₋1) + C[j, j ] += _add(Ajj*Bjj + Ajj₊1*Bj₊1j) + C[j, j+1] += _add(Ajj₊1*Bj₊1j₊1) + end + end + elseif A.uplo == 'L' && B.uplo == 'U' + Al = _diag(A, -1) + Ad = _diag(A, 0) + Bd = _diag(B, 0) + Bu = _diag(B, 1) + @inbounds begin + for j in 3:n-2 + Ajj₋1 = Al[j-1] + Ajj = Ad[j] + Bj₋1j₋1 = Bd[j-1] + Bj₋1j = Bu[j-1] + Bjj = Bd[j] + Bjj₊1 = Bu[j] + Bj₊1j₊1 = Bd[j+1] + Bj₊1j₊2 = Bu[j+1] + C[j, j-1] += _add(Ajj₋1*Bj₋1j₋1) + C[j, j ] += _add(Ajj₋1*Bj₋1j + Ajj*Bjj) + C[j, j+1] += _add(Ajj *Bjj₊1) + end + end + else # A.uplo == 'L' && B.uplo == 'L' + Al = _diag(A, -1) + Ad = _diag(A, 0) + Bl = _diag(B, -1) + Bd = _diag(B, 0) + @inbounds begin + for j in 3:n-2 + Ajj₋1 = Al[j-1] + Ajj = Ad[j] + Bj₋1j₋2 = Bl[j-2] + Bj₋1j₋1 = Bd[j-1] + Bjj₋1 = Bl[j-1] + Bjj = Bd[j] + Bj₊1j = Bl[j] + Bj₊1j₊1 = Bd[j+1] + C[j,j-2] += _add( Ajj₋1*Bj₋1j₋2) + C[j, j-1] += _add(Ajj₋1*Bj₋1j₋1 + Ajj*Bjj₋1) + C[j, j ] += _add(Ajj*Bjj) + end + end + end C end diff --git a/stdlib/LinearAlgebra/test/bidiag.jl b/stdlib/LinearAlgebra/test/bidiag.jl index 6b6cddaf4efef..9afab417f37c1 100644 --- a/stdlib/LinearAlgebra/test/bidiag.jl +++ b/stdlib/LinearAlgebra/test/bidiag.jl @@ -1030,9 +1030,11 @@ end @test B * v ≈ M * v @test mul!(similar(v), B, v) ≈ M * v + @test mul!(ones(size(v)), B, v, 2, 3) ≈ M * v * 2 .+ 3 @test B * B ≈ M * M @test mul!(similar(B, size(B)), B, B) ≈ M * M + @test mul!(ones(size(B)), B, B, 2, 4) ≈ M * M * 2 .+ 4 for m in 0:6 AL = rand(m,n) @@ -1041,6 +1043,8 @@ end @test B * AR ≈ M * AR @test mul!(similar(AL), AL, B) ≈ AL * M @test mul!(similar(AR), B, AR) ≈ M * AR + @test mul!(ones(size(AL)), AL, B, 2, 4) ≈ AL * M * 2 .+ 4 + @test mul!(ones(size(AR)), B, AR, 2, 4) ≈ M * AR * 2 .+ 4 end @test B * D ≈ M * D @@ -1049,7 +1053,19 @@ end @test mul!(similar(B), B, D) ≈ M * D @test mul!(similar(B, size(B)), D, B) ≈ D * M @test mul!(similar(B, size(B)), B, D) ≈ M * D + @test mul!(ones(size(B)), D, B, 2, 4) ≈ D * M * 2 .+ 4 + @test mul!(ones(size(B)), B, D, 2, 4) ≈ M * D * 2 .+ 4 end + BL = Bidiagonal(rand(n), rand(max(0, n-1)), :L) + ML = Matrix(BL) + BU = Bidiagonal(rand(n), rand(max(0, n-1)), :U) + MU = Matrix(BU) + T = Tridiagonal(zeros(max(0, n-1)), zeros(n), zeros(max(0, n-1))) + @test mul!(T, BL, BU) ≈ ML * MU + @test mul!(T, BU, BL) ≈ MU * ML + T = Tridiagonal(ones(max(0, n-1)), ones(n), ones(max(0, n-1))) + @test mul!(copy(T), BL, BU, 2, 3) ≈ ML * MU * 2 + T * 3 + @test mul!(copy(T), BU, BL, 2, 3) ≈ MU * ML * 2 + T * 3 end end diff --git a/stdlib/LinearAlgebra/test/tridiag.jl b/stdlib/LinearAlgebra/test/tridiag.jl index e87f608437c4b..82e34c986ec85 100644 --- a/stdlib/LinearAlgebra/test/tridiag.jl +++ b/stdlib/LinearAlgebra/test/tridiag.jl @@ -939,6 +939,7 @@ end M = Matrix(T) @test T * T ≈ M * M @test mul!(similar(T, size(T)), T, T) ≈ M * M + @test mul!(ones(size(T)), T, T, 2, 4) ≈ M * M * 2 .+ 4 for m in 0:6 AR = rand(n,m) @@ -947,6 +948,8 @@ end @test T * AR ≈ M * AR @test mul!(similar(AL), AL, T) ≈ AL * M @test mul!(similar(AR), T, AR) ≈ M * AR + @test mul!(ones(size(AL)), AL, T, 2, 4) ≈ AL * M * 2 .+ 4 + @test mul!(ones(size(AR)), T, AR, 2, 4) ≈ M * AR * 2 .+ 4 end v = rand(n) @@ -960,15 +963,21 @@ end @test mul!(Tridiagonal(similar(T)), T, D) ≈ M * D @test mul!(similar(T, size(T)), D, T) ≈ D * M @test mul!(similar(T, size(T)), T, D) ≈ M * D - - B = Bidiagonal(rand(n), rand(max(0, n-1)), :U) - @test T * B ≈ M * B - @test B * T ≈ B * M - if n <= 2 - @test mul!(Tridiagonal(similar(T)), B, T) ≈ B * M - @test mul!(Tridiagonal(similar(T)), T, B) ≈ M * B + @test mul!(ones(size(T)), D, T, 2, 4) ≈ D * M * 2 .+ 4 + @test mul!(ones(size(T)), T, D, 2, 4) ≈ M * D * 2 .+ 4 + + for uplo in (:U, :L) + B = Bidiagonal(rand(n), rand(max(0, n-1)), uplo) + @test T * B ≈ M * B + @test B * T ≈ B * M + if n <= 2 + @test mul!(Tridiagonal(similar(T)), B, T) ≈ B * M + @test mul!(Tridiagonal(similar(T)), T, B) ≈ M * B + end @test mul!(similar(T, size(T)), B, T) ≈ B * M @test mul!(similar(T, size(T)), T, B) ≈ M * B + @test mul!(ones(size(T)), B, T, 2, 4) ≈ B * M * 2 .+ 4 + @test mul!(ones(size(T)), T, B, 2, 4) ≈ M * B * 2 .+ 4 end end end From d03767bc2bcdaec0521c4b5de56b29f88f2106a4 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sun, 11 Aug 2024 22:38:16 +0530 Subject: [PATCH 07/10] Fix undef variables in __bibimul --- stdlib/LinearAlgebra/src/bidiag.jl | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index e90ae50232924..7262ae98b7e7b 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -661,7 +661,7 @@ function __bibimul!(C, A, B::Bidiagonal, _add) Bj₊1j₊1 = Bd[j+1] Bj₊1j₊2 = Bu[j+1] C[j, j-1] += _add(Ajj₋1*Bj₋1j₋1) - C[j, j ] += _add(Ajj₋1*Bj₋1j + Ajj*Bjj + Ajj₊1*Bj₊1j) + C[j, j ] += _add(Ajj₋1*Bj₋1j + Ajj*Bjj) C[j, j+1] += _add(Ajj *Bjj₊1 + Ajj₊1*Bj₊1j₊1) C[j, j+2] += _add(Ajj₊1*Bj₊1j₊2) end @@ -701,9 +701,6 @@ function __bibimul!(C, A::Bidiagonal, B, _add) for j in 3:n-2 Ajj = Ad[j] Ajj₊1 = Au[j] - Bj₋1j₋2 = Bl[j-2] - Bj₋1j₋1 = Bd[j-1] - Bj₋1j = Bu[j-1] Bjj₋1 = Bl[j-1] Bjj = Bd[j] Bjj₊1 = Bu[j] @@ -729,9 +726,6 @@ function __bibimul!(C, A::Bidiagonal, B, _add) Bjj₋1 = Bl[j-1] Bjj = Bd[j] Bjj₊1 = Bu[j] - Bj₊1j = Bl[j] - Bj₊1j₊1 = Bd[j+1] - Bj₊1j₊2 = Bu[j+1] C[j,j-2] += _add( Ajj₋1*Bj₋1j₋2) C[j, j-1] += _add(Ajj₋1*Bj₋1j₋1 + Ajj*Bjj₋1) C[j, j ] += _add(Ajj₋1*Bj₋1j + Ajj*Bjj) @@ -752,8 +746,6 @@ function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add) for j in 3:n-2 Ajj = Ad[j] Ajj₊1 = Au[j] - Bj₋1j₋1 = Bd[j-1] - Bj₋1j = Bu[j-1] Bjj = Bd[j] Bjj₊1 = Bu[j] Bj₊1j₊1 = Bd[j+1] @@ -772,8 +764,6 @@ function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add) for j in 3:n-2 Ajj = Ad[j] Ajj₊1 = Au[j] - Bj₋1j₋2 = Bl[j-2] - Bj₋1j₋1 = Bd[j-1] Bjj₋1 = Bl[j-1] Bjj = Bd[j] Bj₊1j = Bl[j] @@ -796,8 +786,6 @@ function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add) Bj₋1j = Bu[j-1] Bjj = Bd[j] Bjj₊1 = Bu[j] - Bj₊1j₊1 = Bd[j+1] - Bj₊1j₊2 = Bu[j+1] C[j, j-1] += _add(Ajj₋1*Bj₋1j₋1) C[j, j ] += _add(Ajj₋1*Bj₋1j + Ajj*Bjj) C[j, j+1] += _add(Ajj *Bjj₊1) @@ -816,8 +804,6 @@ function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add) Bj₋1j₋1 = Bd[j-1] Bjj₋1 = Bl[j-1] Bjj = Bd[j] - Bj₊1j = Bl[j] - Bj₊1j₊1 = Bd[j+1] C[j,j-2] += _add( Ajj₋1*Bj₋1j₋2) C[j, j-1] += _add(Ajj₋1*Bj₋1j₋1 + Ajj*Bjj₋1) C[j, j ] += _add(Ajj*Bjj) From b70b6939bf557a61b3d6716f96cd81b3fcf462e1 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sun, 11 Aug 2024 23:56:08 +0530 Subject: [PATCH 08/10] _diag for SymTridiagonal --- stdlib/LinearAlgebra/src/bidiag.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index 7262ae98b7e7b..b2907f3fa792f 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -553,7 +553,8 @@ end # function to get the internally stored vectors for Bidiagonal and [Sym]Tridiagonal # to avoid allocations in _mul! below (#24324, #24578) _diag(A::Tridiagonal, k) = k == -1 ? A.dl : k == 0 ? A.d : A.du -_diag(A::SymTridiagonal, k) = k == 0 ? A.dv : A.ev +_diag(A::SymTridiagonal{<:Number}, k) = k == 0 ? A.dv : A.ev +_diag(A::SymTridiagonal, k) = k == 0 ? view(A, diagind(A, IndexStyle(A))) : view(A, diagind(A, 1, IndexStyle(A))) function _diag(A::Bidiagonal, k) if k == 0 return A.dv From 2d15c7071fca6991c1b02dab88af92f9bc481433 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 12 Aug 2024 14:44:35 +0530 Subject: [PATCH 09/10] Column-major iteration in bibimul --- stdlib/LinearAlgebra/src/bidiag.jl | 212 ++++++++++++++--------------- 1 file changed, 106 insertions(+), 106 deletions(-) diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index b2907f3fa792f..ba22f9ea4120e 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -586,28 +586,27 @@ function _bibimul!(C, A, B, _add) _rmul_or_fill!(C, _add.beta) iszero(_add.alpha) && return C @inbounds begin - # first row of C - C[1,1] += _add(A[1,1]*B[1,1] + A[1, 2]*B[2, 1]) - C[1,2] += _add(A[1,1]*B[1,2] + A[1,2]*B[2,2]) - C[1,3] += _add(A[1,2]*B[2,3]) - # second row of C + # first column of C + C[1,1] += _add(A[1,1]*B[1,1] + A[1, 2]*B[2,1]) C[2,1] += _add(A[2,1]*B[1,1] + A[2,2]*B[2,1]) + C[3,1] += _add(A[3,2]*B[2,1]) + # second column of C + C[1,2] += _add(A[1,1]*B[1,2] + A[1,2]*B[2,2]) C[2,2] += _add(A[2,1]*B[1,2] + A[2,2]*B[2,2] + A[2,3]*B[3,2]) - C[2,3] += _add(A[2,2]*B[2,3] + A[2,3]*B[3,3]) - C[2,4] += _add(A[2,3]*B[3,4]) - end - # middle rows + C[3,2] += _add(A[3,2]*B[2,2] + A[3,3]*B[3,2]) + C[4,2] += _add(A[4,3]*B[3,2]) + end # inbounds + # middle columns __bibimul!(C, A, B, _add) @inbounds begin - # row before last of C - C[n-1,n-3] += _add(A[n-1,n-2]*B[n-2,n-3]) - C[n-1,n-2] += _add(A[n-1,n-1]*B[n-1,n-2] + A[n-1,n-2]*B[n-2,n-2]) + C[n-3,n-1] += _add(A[n-3,n-2]*B[n-2,n-1]) + C[n-2,n-1] += _add(A[n-2,n-2]*B[n-2,n-1] + A[n-2,n-1]*B[n-1,n-1]) C[n-1,n-1] += _add(A[n-1,n-2]*B[n-2,n-1] + A[n-1,n-1]*B[n-1,n-1] + A[n-1,n]*B[n,n-1]) - C[n-1,n ] += _add(A[n-1,n-1]*B[n-1,n ] + A[n-1, n]*B[n ,n ]) - # last row of C - C[n,n-2] += _add(A[n,n-1]*B[n-1,n-2]) - C[n,n-1] += _add(A[n,n-1]*B[n-1,n-1] + A[n,n]*B[n,n-1]) - C[n,n ] += _add(A[n,n-1]*B[n-1,n ] + A[n,n]*B[n,n ]) + C[n, n-1] += _add(A[n,n-1]*B[n-1,n-1] + A[n,n]*B[n,n-1]) + # last column of C + C[n-2, n] += _add(A[n-2,n-1]*B[n-1,n]) + C[n-1, n] += _add(A[n-1,n-1]*B[n-1,n ] + A[n-1,n]*B[n,n ]) + C[n, n] += _add(A[n,n-1]*B[n-1,n ] + A[n,n]*B[n,n ]) end # inbounds C end @@ -621,23 +620,24 @@ function __bibimul!(C, A, B, _add) Bu = _diag(B, 1) @inbounds begin for j in 3:n-2 - Ajj₋1 = Al[j-1] - Ajj = Ad[j] + Aj₋2j₋1 = Au[j-2] + Aj₋1j = Au[j-1] Ajj₊1 = Au[j] - Bj₋1j₋2 = Bl[j-2] - Bj₋1j₋1 = Bd[j-1] + Aj₋1j₋1 = Ad[j-1] + Ajj = Ad[j] + Aj₊1j₊1 = Ad[j+1] + Ajj₋1 = Al[j-1] + Aj₊1j = Al[j] + Aj₊2j₊1 = Al[j+1] Bj₋1j = Bu[j-1] - Bjj₋1 = Bl[j-1] Bjj = Bd[j] - Bjj₊1 = Bu[j] Bj₊1j = Bl[j] - Bj₊1j₊1 = Bd[j+1] - Bj₊1j₊2 = Bu[j+1] - C[j,j-2] += _add( Ajj₋1*Bj₋1j₋2) - C[j, j-1] += _add(Ajj₋1*Bj₋1j₋1 + Ajj*Bjj₋1) - C[j, j ] += _add(Ajj₋1*Bj₋1j + Ajj*Bjj + Ajj₊1*Bj₊1j) - C[j, j+1] += _add(Ajj *Bjj₊1 + Ajj₊1*Bj₊1j₊1) - C[j, j+2] += _add(Ajj₊1*Bj₊1j₊2) + + C[j-2, j] += _add(Aj₋2j₋1*Bj₋1j) + C[j-1, j] += _add(Aj₋1j₋1*Bj₋1j + Aj₋1j*Bjj) + C[j, j] += _add(Ajj₋1*Bj₋1j + Ajj*Bjj + Ajj₊1*Bj₊1j) + C[j+1, j] += _add(Aj₊1j*Bjj + Aj₊1j₊1*Bj₊1j) + C[j+2, j] += _add(Aj₊2j₊1*Bj₊1j) end end C @@ -647,44 +647,43 @@ function __bibimul!(C, A, B::Bidiagonal, _add) Al = _diag(A, -1) Ad = _diag(A, 0) Au = _diag(A, 1) + Bd = _diag(B, 0) if B.uplo == 'U' - Bd = _diag(B, 0) Bu = _diag(B, 1) @inbounds begin for j in 3:n-2 - Ajj₋1 = Al[j-1] + Aj₋2j₋1 = Au[j-2] + Aj₋1j = Au[j-1] + Aj₋1j₋1 = Ad[j-1] Ajj = Ad[j] - Ajj₊1 = Au[j] - Bj₋1j₋1 = Bd[j-1] + Ajj₋1 = Al[j-1] + Aj₊1j = Al[j] Bj₋1j = Bu[j-1] Bjj = Bd[j] - Bjj₊1 = Bu[j] - Bj₊1j₊1 = Bd[j+1] - Bj₊1j₊2 = Bu[j+1] - C[j, j-1] += _add(Ajj₋1*Bj₋1j₋1) - C[j, j ] += _add(Ajj₋1*Bj₋1j + Ajj*Bjj) - C[j, j+1] += _add(Ajj *Bjj₊1 + Ajj₊1*Bj₊1j₊1) - C[j, j+2] += _add(Ajj₊1*Bj₊1j₊2) + + C[j-2, j] += _add(Aj₋2j₋1*Bj₋1j) + C[j-1, j] += _add(Aj₋1j₋1*Bj₋1j + Aj₋1j*Bjj) + C[j, j] += _add(Ajj₋1*Bj₋1j + Ajj*Bjj) + C[j+1, j] += _add(Aj₊1j*Bjj) end end else # B.uplo == 'L' Bl = _diag(B, -1) - Bd = _diag(B, 0) @inbounds begin for j in 3:n-2 - Ajj₋1 = Al[j-1] - Ajj = Ad[j] + Aj₋1j = Au[j-1] Ajj₊1 = Au[j] - Bj₋1j₋2 = Bl[j-2] - Bj₋1j₋1 = Bd[j-1] - Bjj₋1 = Bl[j-1] + Ajj = Ad[j] + Aj₊1j₊1 = Ad[j+1] + Aj₊1j = Al[j] + Aj₊2j₊1 = Al[j+1] Bjj = Bd[j] Bj₊1j = Bl[j] - Bj₊1j₊1 = Bd[j+1] - C[j,j-2] += _add( Ajj₋1*Bj₋1j₋2) - C[j, j-1] += _add(Ajj₋1*Bj₋1j₋1 + Ajj*Bjj₋1) - C[j, j ] += _add(Ajj*Bjj + Ajj₊1*Bj₊1j) - C[j, j+1] += _add(Ajj₊1*Bj₊1j₊1) + + C[j-1, j] += _add(Aj₋1j*Bjj) + C[j, j] += _add(Ajj*Bjj + Ajj₊1*Bj₊1j) + C[j+1, j] += _add(Aj₊1j*Bjj + Aj₊1j₊1*Bj₊1j) + C[j+2, j] += _add(Aj₊2j₊1*Bj₊1j) end end end @@ -695,42 +694,45 @@ function __bibimul!(C, A::Bidiagonal, B, _add) Bl = _diag(B, -1) Bd = _diag(B, 0) Bu = _diag(B, 1) + Ad = _diag(A, 0) if A.uplo == 'U' - Ad = _diag(A, 0) Au = _diag(A, 1) @inbounds begin for j in 3:n-2 - Ajj = Ad[j] + Aj₋2j₋1 = Au[j-2] + Aj₋1j = Au[j-1] Ajj₊1 = Au[j] - Bjj₋1 = Bl[j-1] + Aj₋1j₋1 = Ad[j-1] + Ajj = Ad[j] + Aj₊1j₊1 = Ad[j+1] + Bj₋1j = Bu[j-1] Bjj = Bd[j] - Bjj₊1 = Bu[j] Bj₊1j = Bl[j] - Bj₊1j₊1 = Bd[j+1] - Bj₊1j₊2 = Bu[j+1] - C[j, j-1] += _add(Ajj*Bjj₋1) - C[j, j ] += _add(Ajj*Bjj + Ajj₊1*Bj₊1j) - C[j, j+1] += _add(Ajj *Bjj₊1 + Ajj₊1*Bj₊1j₊1) - C[j, j+2] += _add(Ajj₊1*Bj₊1j₊2) + + C[j-2, j] += _add(Aj₋2j₋1*Bj₋1j) + C[j-1, j] += _add(Aj₋1j₋1*Bj₋1j + Aj₋1j*Bjj) + C[j, j] += _add(Ajj*Bjj + Ajj₊1*Bj₊1j) + C[j+1, j] += _add(Aj₊1j₊1*Bj₊1j) end end else # A.uplo == 'L' Al = _diag(A, -1) - Ad = _diag(A, 0) @inbounds begin for j in 3:n-2 - Ajj₋1 = Al[j-1] + Aj₋1j₋1 = Ad[j-1] Ajj = Ad[j] - Bj₋1j₋2 = Bl[j-2] - Bj₋1j₋1 = Bd[j-1] + Aj₊1j₊1 = Ad[j+1] + Ajj₋1 = Al[j-1] + Aj₊1j = Al[j] + Aj₊2j₊1 = Al[j+1] Bj₋1j = Bu[j-1] - Bjj₋1 = Bl[j-1] Bjj = Bd[j] - Bjj₊1 = Bu[j] - C[j,j-2] += _add( Ajj₋1*Bj₋1j₋2) - C[j, j-1] += _add(Ajj₋1*Bj₋1j₋1 + Ajj*Bjj₋1) - C[j, j ] += _add(Ajj₋1*Bj₋1j + Ajj*Bjj) - C[j, j+1] += _add(Ajj *Bjj₊1) + Bj₊1j = Bl[j] + + C[j-1, j] += _add(Aj₋1j₋1*Bj₋1j) + C[j, j] += _add(Ajj₋1*Bj₋1j + Ajj*Bjj) + C[j+1, j] += _add(Aj₊1j*Bjj + Aj₊1j₊1*Bj₊1j) + C[j+2, j] += _add(Aj₊2j₊1*Bj₊1j) end end end @@ -738,76 +740,74 @@ function __bibimul!(C, A::Bidiagonal, B, _add) end function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add) n = size(A,1) + Ad = _diag(A, 0) + Bd = _diag(B, 0) if A.uplo == 'U' && B.uplo == 'U' - Ad = _diag(A, 0) Au = _diag(A, 1) - Bd = _diag(B, 0) Bu = _diag(B, 1) @inbounds begin for j in 3:n-2 + Aj₋2j₋1 = Au[j-2] + Aj₋1j = Au[j-1] + Aj₋1j₋1 = Ad[j-1] Ajj = Ad[j] - Ajj₊1 = Au[j] + Bj₋1j = Bu[j-1] Bjj = Bd[j] - Bjj₊1 = Bu[j] - Bj₊1j₊1 = Bd[j+1] - Bj₊1j₊2 = Bu[j+1] - C[j, j ] += _add(Ajj*Bjj) - C[j, j+1] += _add(Ajj *Bjj₊1 + Ajj₊1*Bj₊1j₊1) - C[j, j+2] += _add(Ajj₊1*Bj₊1j₊2) + + C[j-2, j] += _add(Aj₋2j₋1*Bj₋1j) + C[j-1, j] += _add(Aj₋1j₋1*Bj₋1j + Aj₋1j*Bjj) + C[j, j] += _add(Ajj*Bjj) end end elseif A.uplo == 'U' && B.uplo == 'L' - Ad = _diag(A, 0) Au = _diag(A, 1) Bl = _diag(B, -1) - Bd = _diag(B, 0) @inbounds begin for j in 3:n-2 - Ajj = Ad[j] + Aj₋1j = Au[j-1] Ajj₊1 = Au[j] - Bjj₋1 = Bl[j-1] + Ajj = Ad[j] + Aj₊1j₊1 = Ad[j+1] Bjj = Bd[j] Bj₊1j = Bl[j] - Bj₊1j₊1 = Bd[j+1] - C[j, j-1] += _add(Ajj*Bjj₋1) - C[j, j ] += _add(Ajj*Bjj + Ajj₊1*Bj₊1j) - C[j, j+1] += _add(Ajj₊1*Bj₊1j₊1) + + C[j-1, j] += _add(Aj₋1j*Bjj) + C[j, j] += _add(Ajj*Bjj + Ajj₊1*Bj₊1j) + C[j+1, j] += _add(Aj₊1j₊1*Bj₊1j) end end elseif A.uplo == 'L' && B.uplo == 'U' Al = _diag(A, -1) - Ad = _diag(A, 0) - Bd = _diag(B, 0) Bu = _diag(B, 1) @inbounds begin for j in 3:n-2 - Ajj₋1 = Al[j-1] + Aj₋1j₋1 = Ad[j-1] Ajj = Ad[j] - Bj₋1j₋1 = Bd[j-1] + Ajj₋1 = Al[j-1] + Aj₊1j = Al[j] Bj₋1j = Bu[j-1] Bjj = Bd[j] - Bjj₊1 = Bu[j] - C[j, j-1] += _add(Ajj₋1*Bj₋1j₋1) - C[j, j ] += _add(Ajj₋1*Bj₋1j + Ajj*Bjj) - C[j, j+1] += _add(Ajj *Bjj₊1) + + C[j-1, j] += _add(Aj₋1j₋1*Bj₋1j) + C[j, j] += _add(Ajj₋1*Bj₋1j + Ajj*Bjj) + C[j+1, j] += _add(Aj₊1j*Bjj) end end else # A.uplo == 'L' && B.uplo == 'L' Al = _diag(A, -1) - Ad = _diag(A, 0) Bl = _diag(B, -1) - Bd = _diag(B, 0) @inbounds begin for j in 3:n-2 - Ajj₋1 = Al[j-1] Ajj = Ad[j] - Bj₋1j₋2 = Bl[j-2] - Bj₋1j₋1 = Bd[j-1] - Bjj₋1 = Bl[j-1] + Aj₊1j₊1 = Ad[j+1] + Aj₊1j = Al[j] + Aj₊2j₊1 = Al[j+1] Bjj = Bd[j] - C[j,j-2] += _add( Ajj₋1*Bj₋1j₋2) - C[j, j-1] += _add(Ajj₋1*Bj₋1j₋1 + Ajj*Bjj₋1) - C[j, j ] += _add(Ajj*Bjj) + Bj₊1j = Bl[j] + + C[j, j] += _add(Ajj*Bjj) + C[j+1, j] += _add(Aj₊1j*Bjj + Aj₊1j₊1*Bj₊1j) + C[j+2, j] += _add(Aj₊2j₊1*Bj₊1j) end end end From 8f30cdcbd386d44db3b5b8e3ef0ad2689b90c350 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 12 Aug 2024 15:00:53 +0530 Subject: [PATCH 10/10] Tests for a matrix-valued eltype --- stdlib/LinearAlgebra/test/bidiag.jl | 20 +++++++++++++++++++- stdlib/LinearAlgebra/test/tridiag.jl | 20 +++++++++++++++++++- 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/stdlib/LinearAlgebra/test/bidiag.jl b/stdlib/LinearAlgebra/test/bidiag.jl index 9afab417f37c1..5b750bb9c63c7 100644 --- a/stdlib/LinearAlgebra/test/bidiag.jl +++ b/stdlib/LinearAlgebra/test/bidiag.jl @@ -1021,7 +1021,7 @@ end end @testset "mul for small matrices" begin - @testset for n in 0:4 + @testset for n in 0:6 D = Diagonal(rand(n)) v = rand(n) @testset for uplo in (:L, :U) @@ -1067,6 +1067,24 @@ end @test mul!(copy(T), BL, BU, 2, 3) ≈ ML * MU * 2 + T * 3 @test mul!(copy(T), BU, BL, 2, 3) ≈ MU * ML * 2 + T * 3 end + + n = 4 + arr = SizedArrays.SizedArray{(2,2)}(reshape([1:4;],2,2)) + for B in ( + Bidiagonal(fill(arr,n), fill(arr,n-1), :L), + Bidiagonal(fill(arr,n), fill(arr,n-1), :U), + ) + @test B * B ≈ Matrix(B) * Matrix(B) + BL = Bidiagonal(fill(arr,n), fill(arr,n-1), :L) + BU = Bidiagonal(fill(arr,n), fill(arr,n-1), :U) + @test BL * B ≈ Matrix(BL) * Matrix(B) + @test BU * B ≈ Matrix(BU) * Matrix(B) + @test B * BL ≈ Matrix(B) * Matrix(BL) + @test B * BU ≈ Matrix(B) * Matrix(BU) + D = Diagonal(fill(arr,n)) + @test D * B ≈ Matrix(D) * Matrix(B) + @test B * D ≈ Matrix(B) * Matrix(D) + end end end # module TestBidiagonal diff --git a/stdlib/LinearAlgebra/test/tridiag.jl b/stdlib/LinearAlgebra/test/tridiag.jl index 82e34c986ec85..0398726fe6503 100644 --- a/stdlib/LinearAlgebra/test/tridiag.jl +++ b/stdlib/LinearAlgebra/test/tridiag.jl @@ -931,7 +931,7 @@ end end @testset "mul for small matrices" begin - @testset for n in 0:4 + @testset for n in 0:6 for T in ( Tridiagonal(rand(max(n-1,0)), rand(n), rand(max(n-1,0))), SymTridiagonal(rand(n), rand(max(n-1,0))), @@ -981,6 +981,24 @@ end end end end + + n = 4 + arr = SizedArrays.SizedArray{(2,2)}(reshape([1:4;],2,2)) + for T in ( + SymTridiagonal(fill(arr,n), fill(arr,n-1)), + Tridiagonal(fill(arr,n-1), fill(arr,n), fill(arr,n-1)), + ) + @test T * T ≈ Matrix(T) * Matrix(T) + BL = Bidiagonal(fill(arr,n), fill(arr,n-1), :L) + BU = Bidiagonal(fill(arr,n), fill(arr,n-1), :U) + @test BL * T ≈ Matrix(BL) * Matrix(T) + @test BU * T ≈ Matrix(BU) * Matrix(T) + @test T * BL ≈ Matrix(T) * Matrix(BL) + @test T * BU ≈ Matrix(T) * Matrix(BU) + D = Diagonal(fill(arr,n)) + @test D * T ≈ Matrix(D) * Matrix(T) + @test T * D ≈ Matrix(T) * Matrix(D) + end end end # module TestTridiagonal