From 974425c104aa9f7f3baf7ab93379248ea47a480f Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sat, 10 Aug 2024 21:43:30 +0530 Subject: [PATCH] Cherry-pick changes from jishnub/bidigamul_empty --- stdlib/LinearAlgebra/src/bidiag.jl | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index 5eaf5e5a41001..18b8ff926e667 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -535,12 +535,18 @@ function rmul!(B::Bidiagonal, D::Diagonal) end @noinline function check_A_mul_B!_sizes((mC, nC)::NTuple{2,Integer}, (mA, nA)::NTuple{2,Integer}, (mB, nB)::NTuple{2,Integer}) + # check for matching sizes in one column of B and C + check_A_mul_B!_sizes((mC,), (mA, nA), (mB,)) + # ensure that the number of columns in B and C match + if nB != nC + throw(DimensionMismatch(lazy"second dimension of output C, $nC, and second dimension of B, $nB, must match")) + end +end +@noinline function check_A_mul_B!_sizes((mC,)::Tuple{Integer}, (mA, nA)::NTuple{2,Integer}, (mB,)::Tuple{Integer}) if mA != mC throw(DimensionMismatch(lazy"first dimension of A, $mA, and first dimension of output C, $mC, must match")) elseif nA != mB throw(DimensionMismatch(lazy"second dimension of A, $nA, and first dimension of B, $mB, must match")) - elseif nB != nC - throw(DimensionMismatch(lazy"second dimension of output C, $nC, and second dimension of B, $nB, must match")) end end @@ -563,6 +569,7 @@ _mul!(C::AbstractMatrix, A::BiTriSym, B::TriSym, _add::MulAddMul) = _mul!(C::AbstractMatrix, A::BiTriSym, B::Bidiagonal, _add::MulAddMul) = _bibimul!(C, A, B, _add) function _bibimul!(C, A, B, _add) + require_one_based_indexing(C) check_A_mul_B!_sizes(size(C), size(A), size(B)) n = size(A,1) if n <= 3 @@ -733,14 +740,9 @@ end function _mul!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat, _add::MulAddMul) require_one_based_indexing(C, B) + check_A_mul_B!_sizes(size(C), size(A), size(B)) nA = size(A,1) nB = size(B,2) - if !(size(C,1) == size(B,1) == nA) - throw(DimensionMismatch(lazy"A has first dimension $nA, B has $(size(B,1)), C has $(size(C,1)) but all must match")) - end - if size(C,2) != nB - throw(DimensionMismatch(lazy"A has second dimension $nA, B has $(size(B,2)), C has $(size(C,2)) but all must match")) - end iszero(nA) && return C iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta) if nA <= 3