Skip to content

Commit

Permalink
Quick return for empty arrays in bidiagonal matrix multiplications (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub authored and KristofferC committed Sep 12, 2024
1 parent 51c50f8 commit e4a2fc4
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 11 deletions.
28 changes: 17 additions & 11 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -535,12 +535,18 @@ function rmul!(B::Bidiagonal, D::Diagonal)
end

@noinline function check_A_mul_B!_sizes((mC, nC)::NTuple{2,Integer}, (mA, nA)::NTuple{2,Integer}, (mB, nB)::NTuple{2,Integer})
# check for matching sizes in one column of B and C
check_A_mul_B!_sizes((mC,), (mA, nA), (mB,))
# ensure that the number of columns in B and C match
if nB != nC
throw(DimensionMismatch(lazy"second dimension of output C, $nC, and second dimension of B, $nB, must match"))
end
end
@noinline function check_A_mul_B!_sizes((mC,)::Tuple{Integer}, (mA, nA)::NTuple{2,Integer}, (mB,)::Tuple{Integer})
if mA != mC
throw(DimensionMismatch(lazy"first dimension of A, $mA, and first dimension of output C, $mC, must match"))
elseif nA != mB
throw(DimensionMismatch(lazy"second dimension of A, $nA, and first dimension of B, $mB, must match"))
elseif nB != nC
throw(DimensionMismatch(lazy"second dimension of output C, $nC, and second dimension of B, $nB, must match"))
end
end

Expand All @@ -563,8 +569,10 @@ _mul!(C::AbstractMatrix, A::BiTriSym, B::TriSym, _add::MulAddMul) =
_mul!(C::AbstractMatrix, A::BiTriSym, B::Bidiagonal, _add::MulAddMul) =
_bibimul!(C, A, B, _add)
function _bibimul!(C, A, B, _add)
require_one_based_indexing(C)
check_A_mul_B!_sizes(size(C), size(A), size(B))
n = size(A,1)
iszero(n) && return C
n <= 3 && return mul!(C, Array(A), Array(B), _add.alpha, _add.beta)
# We use `_rmul_or_fill!` instead of `_modify!` here since using
# `_modify!` in the following loop will not update the
Expand Down Expand Up @@ -727,15 +735,10 @@ end

function _mul!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat, _add::MulAddMul)
require_one_based_indexing(C, B)
check_A_mul_B!_sizes(size(C), size(A), size(B))
nA = size(A,1)
nB = size(B,2)
if !(size(C,1) == size(B,1) == nA)
throw(DimensionMismatch(lazy"A has first dimension $nA, B has $(size(B,1)), C has $(size(C,1)) but all must match"))
end
if size(C,2) != nB
throw(DimensionMismatch(lazy"A has second dimension $nA, B has $(size(B,2)), C has $(size(C,2)) but all must match"))
end
iszero(nA) && return C
(iszero(nA) || iszero(nB)) && return C
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
nA <= 3 && return mul!(C, Array(A), Array(B), _add.alpha, _add.beta)
l = _diag(A, -1)
Expand All @@ -758,9 +761,10 @@ end
function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::TriSym, _add::MulAddMul)
require_one_based_indexing(C, A)
check_A_mul_B!_sizes(size(C), size(A), size(B))
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
n = size(A,1)
m = size(B,2)
(iszero(m) || iszero(n)) && return C
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
if n <= 3 || m <= 1
return mul!(C, Array(A), Array(B), _add.alpha, _add.beta)
end
Expand Down Expand Up @@ -793,11 +797,12 @@ end
function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::Bidiagonal, _add::MulAddMul)
require_one_based_indexing(C, A)
check_A_mul_B!_sizes(size(C), size(A), size(B))
m, n = size(A)
(iszero(m) || iszero(n)) && return C
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
if size(A, 1) <= 3 || size(B, 2) <= 1
return mul!(C, Array(A), Array(B), _add.alpha, _add.beta)
end
m, n = size(A)
@inbounds if B.uplo == 'U'
for i in 1:m
for j in n:-1:2
Expand All @@ -824,6 +829,7 @@ function _dibimul!(C, A, B, _add)
require_one_based_indexing(C)
check_A_mul_B!_sizes(size(C), size(A), size(B))
n = size(A,1)
iszero(n) && return C
n <= 3 && return mul!(C, Array(A), Array(B), _add.alpha, _add.beta)
_rmul_or_fill!(C, _add.beta) # see the same use above
iszero(_add.alpha) && return C
Expand Down
22 changes: 22 additions & 0 deletions stdlib/LinearAlgebra/test/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1023,4 +1023,26 @@ end
@test_throws "cannot set entry" B[1,2] = 4
end

@testset "mul with empty arrays" begin
A = zeros(5,0)
B = Bidiagonal(zeros(0), zeros(0), :U)
BL = Bidiagonal(zeros(5), zeros(4), :U)
@test size(A * B) == size(A)
@test size(BL * A) == size(A)
@test size(B * B) == size(B)
C = similar(A)
@test mul!(C, A, B) == A * B
@test mul!(C, BL, A) == BL * A
@test mul!(similar(B), B, B) == B * B
@test mul!(similar(B, size(B)), B, B) == B * B

v = zeros(size(B,2))
@test size(B * v) == size(v)
@test mul!(similar(v), B, v) == B * v

D = Diagonal(zeros(size(B,2)))
@test size(B * D) == size(D * B) == size(D)
@test mul!(similar(D), B, D) == mul!(similar(D), D, B) == B * D
end

end # module TestBidiagonal
22 changes: 22 additions & 0 deletions stdlib/LinearAlgebra/test/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -919,6 +919,28 @@ end
end
end

@testset "mul with empty arrays" begin
A = zeros(5,0)
T = Tridiagonal(zeros(0), zeros(0), zeros(0))
TL = Tridiagonal(zeros(4), zeros(5), zeros(4))
@test size(A * T) == size(A)
@test size(TL * A) == size(A)
@test size(T * T) == size(T)
C = similar(A)
@test mul!(C, A, T) == A * T
@test mul!(C, TL, A) == TL * A
@test mul!(similar(T), T, T) == T * T
@test mul!(similar(T, size(T)), T, T) == T * T

v = zeros(size(T,2))
@test size(T * v) == size(v)
@test mul!(similar(v), T, v) == T * v

D = Diagonal(zeros(size(T,2)))
@test size(T * D) == size(D * T) == size(D)
@test mul!(similar(D), T, D) == mul!(similar(D), D, T) == T * D
end

@testset "show" begin
T = Tridiagonal(1:3, 1:4, 1:3)
@test sprint(show, T) == "Tridiagonal(1:3, 1:4, 1:3)"
Expand Down

0 comments on commit e4a2fc4

Please sign in to comment.