From a622302e59a99606936d33dab9da0b51e085e95f Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Wed, 25 Dec 2024 14:06:19 +0530 Subject: [PATCH] Consistently check matrix sizes in matmul (#1152) Fixes https://github.com/JuliaLang/LinearAlgebra.jl/issues/1147, and makes the error messages more verbose. Before ```julia julia> zeros(0,4) * zeros(1) ERROR: DimensionMismatch: second dimension of matrix, 4, does not match length of input vector, 1 [...] julia> mul!(zeros(0), zeros(2,2), zeros(2)) ERROR: DimensionMismatch: first dimension of matrix, 2, does not match length of output vector, 0 [...] ``` After ```julia julia> zeros(0,4) * zeros(1) ERROR: DimensionMismatch: incompatible dimensions for matrix multiplication: tried to multiply a matrix of size (0, 4) with a vector of length 1. The second dimension of the matrix: 4, does not match the length of the vector: 1. [...] julia> zeros(0,4) * zeros(1,1) ERROR: DimensionMismatch: incompatible dimensions for matrix multiplication: tried to multiply a matrix of size (0, 4) with a matrix of size (1, 1). The second dimension of the first matrix: 4, does not match the first dimension of the second matrix: 1. [...] julia> mul!(zeros(0), zeros(2,2), zeros(2)) ERROR: DimensionMismatch: incompatible destination size: the destination vector of length 0 is incomatible with the product of a matrix of size (2, 2) and a vector of length 2. The destination must be of length 2. [...] julia> mul!(zeros(0,1), zeros(3,2), zeros(2,3)) ERROR: DimensionMismatch: incompatible destination size: the destination matrix of size (0, 1) is incomatible with the product of a matrix of size (3, 2) and a matrix of size (2, 3). The destination must be of size (3, 3). [...] 
``` --- src/bidiag.jl | 44 ++++++----------- src/diagonal.jl | 43 +++++------------ src/matmul.jl | 114 +++++++++++++++++++++++++-------------------- src/triangular.jl | 12 ++--- test/matmul.jl | 19 ++++++++ test/triangular.jl | 14 ++++++ 6 files changed, 128 insertions(+), 118 deletions(-) diff --git a/src/bidiag.jl b/src/bidiag.jl index 2fb1415f..bb5b8830 100644 --- a/src/bidiag.jl +++ b/src/bidiag.jl @@ -497,7 +497,7 @@ end # B .= A * B function lmul!(A::Bidiagonal, B::AbstractVecOrMat) - _muldiag_size_check(size(A), size(B)) + matmul_size_check(size(A), size(B)) (; dv, ev) = A if A.uplo == 'U' for k in axes(B,2) @@ -518,7 +518,7 @@ function lmul!(A::Bidiagonal, B::AbstractVecOrMat) end # B .= D * B function lmul!(D::Diagonal, B::Bidiagonal) - _muldiag_size_check(size(D), size(B)) + matmul_size_check(size(D), size(B)) (; dv, ev) = B isL = B.uplo == 'L' dv[1] = D.diag[1] * dv[1] @@ -530,7 +530,7 @@ function lmul!(D::Diagonal, B::Bidiagonal) end # B .= B * A function rmul!(B::AbstractMatrix, A::Bidiagonal) - _muldiag_size_check(size(A), size(B)) + matmul_size_check(size(A), size(B)) (; dv, ev) = A if A.uplo == 'U' for k in reverse(axes(dv,1)[2:end]) @@ -555,7 +555,7 @@ function rmul!(B::AbstractMatrix, A::Bidiagonal) end # B .= B * D function rmul!(B::Bidiagonal, D::Diagonal) - _muldiag_size_check(size(B), size(D)) + matmul_size_check(size(B), size(D)) (; dv, ev) = B isU = B.uplo == 'U' dv[1] *= D.diag[1] @@ -566,22 +566,6 @@ function rmul!(B::Bidiagonal, D::Diagonal) return B end -@noinline function check_A_mul_B!_sizes((mC, nC)::NTuple{2,Integer}, (mA, nA)::NTuple{2,Integer}, (mB, nB)::NTuple{2,Integer}) - # check for matching sizes in one column of B and C - check_A_mul_B!_sizes((mC,), (mA, nA), (mB,)) - # ensure that the number of columns in B and C match - if nB != nC - throw(DimensionMismatch(lazy"second dimension of output C, $nC, and second dimension of B, $nB, must match")) - end -end -@noinline function check_A_mul_B!_sizes((mC,)::Tuple{Integer}, (mA, 
nA)::NTuple{2,Integer}, (mB,)::Tuple{Integer}) - if mA != mC - throw(DimensionMismatch(lazy"first dimension of A, $mA, and first dimension of output C, $mC, must match")) - elseif nA != mB - throw(DimensionMismatch(lazy"second dimension of A, $nA, and first dimension of B, $mB, must match")) - end -end - # function to get the internally stored vectors for Bidiagonal and [Sym]Tridiagonal # to avoid allocations in _mul! below (#24324, #24578) _diag(A::Tridiagonal, k) = k == -1 ? A.dl : k == 0 ? A.d : A.du @@ -603,7 +587,7 @@ _mul!(C::AbstractMatrix, A::BiTriSym, B::Bidiagonal, _add::MulAddMul) = _bibimul!(C, A, B, _add) function _bibimul!(C, A, B, _add) require_one_based_indexing(C) - check_A_mul_B!_sizes(size(C), size(A), size(B)) + matmul_size_check(size(C), size(A), size(B)) n = size(A,1) iszero(n) && return C # We use `_rmul_or_fill!` instead of `_modify!` here since using @@ -851,7 +835,7 @@ _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, alpha::Number, beta::Number) @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta)) function _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul) require_one_based_indexing(C) - check_A_mul_B!_sizes(size(C), size(A), size(B)) + matmul_size_check(size(C), size(A), size(B)) n = size(A,1) iszero(n) && return C _rmul_or_fill!(C, _add.beta) # see the same use above @@ -894,7 +878,7 @@ end function _mul!(C::AbstractMatrix, A::Bidiagonal, B::Diagonal, _add::MulAddMul) require_one_based_indexing(C) - check_A_mul_B!_sizes(size(C), size(A), size(B)) + matmul_size_check(size(C), size(A), size(B)) n = size(A,1) iszero(n) && return C _rmul_or_fill!(C, _add.beta) # see the same use above @@ -924,7 +908,7 @@ function _mul!(C::AbstractMatrix, A::Bidiagonal, B::Diagonal, _add::MulAddMul) end function _mul!(C::Bidiagonal, A::Bidiagonal, B::Diagonal, _add::MulAddMul) - check_A_mul_B!_sizes(size(C), size(A), size(B)) + matmul_size_check(size(C), size(A), size(B)) n = size(A,1) iszero(n) && return C iszero(_add.alpha) && return 
_rmul_or_fill!(C, _add.beta) @@ -957,7 +941,7 @@ end function _mul!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat, _add::MulAddMul) require_one_based_indexing(C, B) - check_A_mul_B!_sizes(size(C), size(A), size(B)) + matmul_size_check(size(C), size(A), size(B)) nA = size(A,1) nB = size(B,2) (iszero(nA) || iszero(nB)) && return C @@ -1027,7 +1011,7 @@ end function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::TriSym, _add::MulAddMul) require_one_based_indexing(C, A) - check_A_mul_B!_sizes(size(C), size(A), size(B)) + matmul_size_check(size(C), size(A), size(B)) n = size(A,1) m = size(B,2) (iszero(_add.alpha) || iszero(m)) && return _rmul_or_fill!(C, _add.beta) @@ -1063,7 +1047,7 @@ end function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::Bidiagonal, _add::MulAddMul) require_one_based_indexing(C, A) - check_A_mul_B!_sizes(size(C), size(A), size(B)) + matmul_size_check(size(C), size(A), size(B)) m, n = size(A) (iszero(m) || iszero(n)) && return C iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta) @@ -1093,7 +1077,7 @@ _mul!(C::AbstractMatrix, A::Diagonal, B::TriSym, _add::MulAddMul) = _dibimul!(C, A, B, _add) function _dibimul!(C, A, B, _add) require_one_based_indexing(C) - check_A_mul_B!_sizes(size(C), size(A), size(B)) + matmul_size_check(size(C), size(A), size(B)) n = size(A,1) iszero(n) && return C # ensure that we fill off-band elements in the destination @@ -1137,7 +1121,7 @@ function _dibimul!(C, A, B, _add) end function _dibimul!(C::AbstractMatrix, A::Diagonal, B::Bidiagonal, _add) require_one_based_indexing(C) - check_A_mul_B!_sizes(size(C), size(A), size(B)) + matmul_size_check(size(C), size(A), size(B)) n = size(A,1) iszero(n) && return C # ensure that we fill off-band elements in the destination @@ -1168,7 +1152,7 @@ function _dibimul!(C::AbstractMatrix, A::Diagonal, B::Bidiagonal, _add) C end function _dibimul!(C::Bidiagonal, A::Diagonal, B::Bidiagonal, _add) - check_A_mul_B!_sizes(size(C), size(A), size(B)) + matmul_size_check(size(C), 
size(A), size(B)) n = size(A,1) n == 0 && return C iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta) diff --git a/src/diagonal.jl b/src/diagonal.jl index 280caec1..9f8d54e5 100644 --- a/src/diagonal.jl +++ b/src/diagonal.jl @@ -322,39 +322,18 @@ Base.literal_pow(::typeof(^), D::Diagonal, valp::Val) = Diagonal(Base.literal_pow.(^, D.diag, valp)) # for speed Base.literal_pow(::typeof(^), D::Diagonal, ::Val{-1}) = inv(D) # for disambiguation -function _muldiag_size_check(szA::NTuple{2,Integer}, szB::Tuple{Integer,Vararg{Integer}}) - nA = szA[2] - mB = szB[1] - @noinline throw_dimerr(szB::NTuple{2}, nA, mB) = throw(DimensionMismatch(lazy"second dimension of A, $nA, does not match first dimension of B, $mB")) - @noinline throw_dimerr(szB::NTuple{1}, nA, mB) = throw(DimensionMismatch(lazy"second dimension of D, $nA, does not match length of V, $mB")) - nA == mB || throw_dimerr(szB, nA, mB) - return nothing -end -# the output matrix should have the same size as the non-diagonal input matrix or vector -@noinline throw_dimerr(szC, szA) = throw(DimensionMismatch(lazy"output matrix has size: $szC, but should have size $szA")) -function _size_check_out(szC::NTuple{2}, szA::NTuple{2}, szB::NTuple{2}) - (szC[1] == szA[1] && szC[2] == szB[2]) || throw_dimerr(szC, (szA[1], szB[2])) -end -function _size_check_out(szC::NTuple{1}, szA::NTuple{2}, szB::NTuple{1}) - szC[1] == szA[1] || throw_dimerr(szC, (szA[1],)) -end -function _muldiag_size_check(szC::Tuple{Vararg{Integer}}, szA::Tuple{Vararg{Integer}}, szB::Tuple{Vararg{Integer}}) - _muldiag_size_check(szA, szB) - _size_check_out(szC, szA, szB) -end - function (*)(Da::Diagonal, Db::Diagonal) - _muldiag_size_check(size(Da), size(Db)) + matmul_size_check(size(Da), size(Db)) return Diagonal(Da.diag .* Db.diag) end function (*)(D::Diagonal, V::AbstractVector) - _muldiag_size_check(size(D), size(V)) + matmul_size_check(size(D), size(V)) return D.diag .* V end function rmul!(A::AbstractMatrix, D::Diagonal) - 
_muldiag_size_check(size(A), size(D)) + matmul_size_check(size(A), size(D)) for I in CartesianIndices(A) row, col = Tuple(I) @inbounds A[row, col] *= D.diag[col] @@ -363,7 +342,7 @@ function rmul!(A::AbstractMatrix, D::Diagonal) end # T .= T * D function rmul!(T::Tridiagonal, D::Diagonal) - _muldiag_size_check(size(T), size(D)) + matmul_size_check(size(T), size(D)) (; dl, d, du) = T d[1] *= D.diag[1] for i in axes(dl,1) @@ -375,7 +354,7 @@ function rmul!(T::Tridiagonal, D::Diagonal) end function lmul!(D::Diagonal, B::AbstractVecOrMat) - _muldiag_size_check(size(D), size(B)) + matmul_size_check(size(D), size(B)) for I in CartesianIndices(B) row = I[1] @inbounds B[I] = D.diag[row] * B[I] @@ -386,7 +365,7 @@ end # in-place multiplication with a diagonal # T .= D * T function lmul!(D::Diagonal, T::Tridiagonal) - _muldiag_size_check(size(D), size(T)) + matmul_size_check(size(D), size(T)) (; dl, d, du) = T d[1] = D.diag[1] * d[1] for i in axes(dl,1) @@ -507,7 +486,7 @@ end # specialize the non-trivial case function _mul_diag!(out, A, B, alpha, beta) require_one_based_indexing(out, A, B) - _muldiag_size_check(size(out), size(A), size(B)) + matmul_size_check(size(out), size(A), size(B)) if iszero(alpha) _rmul_or_fill!(out, beta) else @@ -532,14 +511,14 @@ _mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, alpha::Number, beta::Number _mul_diag!(C, Da, Db, alpha, beta) function (*)(Da::Diagonal, A::AbstractMatrix, Db::Diagonal) - _muldiag_size_check(size(Da), size(A)) - _muldiag_size_check(size(A), size(Db)) + matmul_size_check(size(Da), size(A)) + matmul_size_check(size(A), size(Db)) return broadcast(*, Da.diag, A, permutedims(Db.diag)) end function (*)(Da::Diagonal, Db::Diagonal, Dc::Diagonal) - _muldiag_size_check(size(Da), size(Db)) - _muldiag_size_check(size(Db), size(Dc)) + matmul_size_check(size(Da), size(Db)) + matmul_size_check(size(Db), size(Dc)) return Diagonal(Da.diag .* Db.diag .* Dc.diag) end diff --git a/src/matmul.jl b/src/matmul.jl index b923cd59..76719c53 
100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -408,6 +408,55 @@ julia> lmul!(F.Q, B) """ lmul!(A, B) +_vec_or_mat_str(s::Tuple{Any}) = "vector" +_vec_or_mat_str(s::Tuple{Any,Any}) = "matrix" +@noinline function matmul_size_check(sizeA::Tuple{Integer,Vararg{Integer}}, sizeB::Tuple{Integer,Vararg{Integer}}) + szA2 = get(sizeA, 2, 1) + if szA2 != sizeB[1] + strA = _vec_or_mat_str(sizeA) + strB = _vec_or_mat_str(sizeB) + B_size_len = length(sizeB) == 1 ? sizeB[1] : sizeB + size_or_len_str_B = B_size_len isa Integer ? "length" : "size" + dim_or_len_str_B = B_size_len isa Integer ? "length" : "first dimension" + pos_str_A = LazyString(length(sizeA) == length(sizeB) ? "first " : "", strA) + pos_str_B = LazyString(length(sizeA) == length(sizeB) ? "second " : "", strB) + throw(DimensionMismatch( + LazyString( + lazy"incompatible dimensions for matrix multiplication: ", + lazy"tried to multiply a $strA of size $sizeA with a $strB of $size_or_len_str_B $B_size_len. ", + lazy"The second dimension of the $pos_str_A: $szA2, does not match the $dim_or_len_str_B of the $pos_str_B: $(sizeB[1])." + ) + ) + ) + end + return nothing +end +@noinline function matmul_size_check(sizeC::Tuple{Integer,Vararg{Integer}}, sizeA::Tuple{Integer,Vararg{Integer}}, sizeB::Tuple{Integer,Vararg{Integer}}) + matmul_size_check(sizeA, sizeB) + szB2 = get(sizeB, 2, 1) + szC2 = get(sizeC, 2, 1) + if sizeC[1] != sizeA[1] || szC2 != szB2 + strA = _vec_or_mat_str(sizeA) + strB = _vec_or_mat_str(sizeB) + strC = _vec_or_mat_str(sizeC) + C_size_len = length(sizeC) == 1 ? sizeC[1] : sizeC + size_or_len_str_C = C_size_len isa Integer ? "length" : "size" + B_size_len = length(sizeB) == 1 ? sizeB[1] : sizeB + size_or_len_str_B = B_size_len isa Integer ? "length" : "size" + destsize = length(sizeB) == length(sizeC) == 1 ? sizeA[1] : (sizeA[1], szB2) + size_or_len_str_dest = destsize isa Integer ? 
"length" : "size" + throw(DimensionMismatch( + LazyString( + "incompatible destination size: ", + lazy"the destination $strC of $size_or_len_str_C $C_size_len is incompatible with the product of a $strA of size $sizeA and a $strB of $size_or_len_str_B $B_size_len. ", + lazy"The destination must be of $size_or_len_str_dest $destsize." + ) + ) + ) + end + return nothing +end + # We may inline the matmul2x2! and matmul3x3! calls for `α == true` # to simplify the @stable_muladdmul branches function matmul2x2or3x3_nonzeroalpha!(C, tA, tB, A, B, α, β) @@ -441,9 +490,7 @@ Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix mA, nA = lapack_size(tA, A) mB, nB = lapack_size(tB, B) if any(iszero, size(A)) || any(iszero, size(B)) || iszero(α) - if size(C) != (mA, nB) - throw(DimensionMismatch(lazy"C has dimensions $(size(C)), should have ($mA,$nB)")) - end + matmul_size_check(size(C), (mA, nA), (mB, nB)) return _rmul_or_fill!(C, β) end matmul2x2or3x3_nonzeroalpha!(C, tA, tB, A, B, α, β) && return C @@ -474,10 +521,8 @@ Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix α::Number, β::Number, val::BlasFlag.SymmHemmGeneric) where {T<:BlasFloat} mA, nA = lapack_size(tA, A) mB, nB = lapack_size(tB, B) + matmul_size_check(size(C), (mA, nA), (mB, nB)) if any(iszero, size(A)) || any(iszero, size(B)) || iszero(α) - if size(C) != (mA, nB) - throw(DimensionMismatch(lazy"C has dimensions $(size(C)), should have ($mA,$nB)")) - end return _rmul_or_fill!(C, β) end matmul2x2or3x3_nonzeroalpha!(C, tA, tB, A, B, α, β) && return C @@ -571,10 +616,7 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{T}, tA::AbstractChar A::StridedVecOrMat{T}, x::StridedVector{T}, α::Number=true, β::Number=false) where {T<:BlasFloat} mA, nA = lapack_size(tA, A) - nA != length(x) && - throw(DimensionMismatch(lazy"second dimension of matrix, $nA, does not match length of input vector, $(length(x))")) - mA != length(y) && - 
throw(DimensionMismatch(lazy"first dimension of matrix, $mA, does not match length of output vector, $(length(y))")) + matmul_size_check(size(y), (mA, nA), size(x)) mA == 0 && return y nA == 0 && return _rmul_or_fill!(y, β) alpha, beta = promote(α, β, zero(T)) @@ -602,10 +644,7 @@ end Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::AbstractChar, A::StridedVecOrMat{Complex{T}}, x::StridedVector{T}, α::Number = true, β::Number = false) where {T<:BlasReal} mA, nA = lapack_size(tA, A) - nA != length(x) && - throw(DimensionMismatch(lazy"second dimension of matrix, $nA, does not match length of input vector, $(length(x))")) - mA != length(y) && - throw(DimensionMismatch(lazy"first dimension of matrix, $mA, does not match length of output vector, $(length(y))")) + matmul_size_check(size(y), (mA, nA), size(x)) mA == 0 && return y nA == 0 && return _rmul_or_fill!(y, β) alpha, beta = promote(α, β, zero(T)) @@ -626,10 +665,7 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs A::StridedVecOrMat{T}, x::StridedVector{Complex{T}}, α::Number = true, β::Number = false) where {T<:BlasReal} mA, nA = lapack_size(tA, A) - nA != length(x) && - throw(DimensionMismatch(lazy"second dimension of matrix, $nA, does not match length of input vector, $(length(x))")) - mA != length(y) && - throw(DimensionMismatch(lazy"first dimension of matrix, $mA, does not match length of output vector, $(length(y))")) + matmul_size_check(size(y), (mA, nA), size(x)) mA == 0 && return y nA == 0 && return _rmul_or_fill!(y, β) alpha, beta = promote(α, β, zero(T)) @@ -665,7 +701,7 @@ Base.@constprop :aggressive function syrk_wrapper!(C::StridedMatrix{T}, tA::Abst tAt = 'T' end if nC != mA - throw(DimensionMismatch(lazy"output matrix has size: $(nC), but should have size $(mA)")) + throw(DimensionMismatch(lazy"output matrix has size: $(size(C)), but should have size $((mA, mA))")) end # BLAS.syrk! 
only updates symmetric C @@ -699,7 +735,7 @@ Base.@constprop :aggressive function herk_wrapper!(C::Union{StridedMatrix{T}, St tAt = 'C' end if nC != mA - throw(DimensionMismatch(lazy"output matrix has size: $(nC), but should have size $(mA)")) + throw(DimensionMismatch(lazy"output matrix has size: $(size(C)), but should have size $((mA, mA))")) end # Result array does not need to be initialized as long as beta==0 @@ -748,9 +784,7 @@ Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{T}, tA::Ab mA, nA = lapack_size(tA, A) mB, nB = lapack_size(tB, B) - if nA != mB - throw(DimensionMismatch(lazy"A has dimensions ($mA,$nA) but B has dimensions ($mB,$nB)")) - end + matmul_size_check(size(C), (mA, nA), (mB, nB)) if C === A || B === C throw(ArgumentError("output matrix must not be aliased with input matrix")) @@ -778,9 +812,7 @@ Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{Complex{T} mA, nA = lapack_size(tA, A) mB, nB = lapack_size(tB, B) - if nA != mB - throw(DimensionMismatch(lazy"A has dimensions ($mA,$nA) but B has dimensions ($mB,$nB)")) - end + matmul_size_check(size(C), (mA, nA), (mB, nB)) if C === A || B === C throw(ArgumentError("output matrix must not be aliased with input matrix")) @@ -940,14 +972,8 @@ function _generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::Abst alpha::Number, beta::Number) require_one_based_indexing(C, A, B) @assert tA in ('N', 'T', 'C') - mB = length(B) mA, nA = lapack_size(tA, A) - if mB != nA - throw(DimensionMismatch(lazy"matrix A has dimensions ($mA,$nA), vector B has length $mB")) - end - if mA != length(C) - throw(DimensionMismatch(lazy"result C has length $(length(C)), needs length $mA")) - end + matmul_size_check(size(C), (mA, nA), size(B)) if tA == 'T' # fastest case __generic_matvecmul!(transpose, C, A, B, alpha, beta) @@ -979,21 +1005,7 @@ _generic_matmatmul!(C::AbstractVecOrMat, A::AbstractVecOrMat, B::AbstractVecOrMa @noinline function 
_generic_matmatmul!(C::AbstractVecOrMat{R}, A::AbstractVecOrMat, B::AbstractVecOrMat, alpha::Number, beta::Number) where {R} - AxM = axes(A, 1) - AxK = axes(A, 2) # we use two `axes` calls in case of `AbstractVector` - BxK = axes(B, 1) - BxN = axes(B, 2) - CxM = axes(C, 1) - CxN = axes(C, 2) - if AxM != CxM - throw(DimensionMismatch(lazy"matrix A has axes ($AxM,$AxK), matrix C has axes ($CxM,$CxN)")) - end - if AxK != BxK - throw(DimensionMismatch(lazy"matrix A has axes ($AxM,$AxK), matrix B has axes ($BxK,$CxN)")) - end - if BxN != CxN - throw(DimensionMismatch(lazy"matrix B has axes ($BxK,$BxN), matrix C has axes ($CxM,$CxN)")) - end + matmul_size_check(size(C), size(A), size(B)) __generic_matmatmul!(C, A, B, alpha, beta, Val(isbitstype(R) && sizeof(R) ≤ 16)) return C end @@ -1055,11 +1067,13 @@ end function __matmul_checks(C, A, B, sz) require_one_based_indexing(C, A, B) + matmul_size_check(size(C), size(A), size(B)) if C === A || B === C throw(ArgumentError("output matrix must not be aliased with input matrix")) end - if !(size(A) == size(B) == size(C) == sz) - throw(DimensionMismatch(lazy"A has size $(size(A)), B has size $(size(B)), C has size $(size(C))")) + if !(size(A) == size(B) == sz) # if A and B are both of size sz, C must also be of size sz for the matmul_size_check to pass + pos, mismatched_sz = size(A) != sz ? 
("first", size(A)) : ("second", size(B)) + throw(DimensionMismatch(lazy"expected size: $sz, but got size $mismatched_sz for the $pos matrix")) end return nothing end diff --git a/src/triangular.jl b/src/triangular.jl index abe3471c..5c6b5188 100644 --- a/src/triangular.jl +++ b/src/triangular.jl @@ -1094,7 +1094,7 @@ end for TC in (:AbstractVector, :AbstractMatrix) @eval @inline function _mul!(C::$TC, A::AbstractTriangular, B::AbstractVector, alpha::Number, beta::Number) - check_A_mul_B!_sizes(size(C), size(A), size(B)) + matmul_size_check(size(C), size(A), size(B)) if isone(alpha) && iszero(beta) return _trimul!(C, A, B) else @@ -1107,7 +1107,7 @@ for (TA, TB) in ((:AbstractTriangular, :AbstractMatrix), (:AbstractTriangular, :AbstractTriangular) ) @eval @inline function _mul!(C::AbstractMatrix, A::$TA, B::$TB, alpha::Number, beta::Number) - check_A_mul_B!_sizes(size(C), size(A), size(B)) + matmul_size_check(size(C), size(A), size(B)) if isone(alpha) && iszero(beta) return _trimul!(C, A, B) else @@ -1341,7 +1341,7 @@ end ## Generic triangular multiplication function generic_trimatmul!(C::AbstractVecOrMat, uploc, isunitc, tfun::Function, A::AbstractMatrix, B::AbstractVecOrMat) require_one_based_indexing(C, A, B) - check_A_mul_B!_sizes(size(C), size(A), size(B)) + matmul_size_check(size(C), size(A), size(B)) oA = oneunit(eltype(A)) unit = isunitc == 'U' @inbounds if uploc == 'U' @@ -1394,7 +1394,7 @@ end # conjugate cases function generic_trimatmul!(C::AbstractVecOrMat, uploc, isunitc, ::Function, xA::AdjOrTrans, B::AbstractVecOrMat) require_one_based_indexing(C, xA, B) - check_A_mul_B!_sizes(size(C), size(xA), size(B)) + matmul_size_check(size(C), size(xA), size(B)) A = parent(xA) oA = oneunit(eltype(A)) unit = isunitc == 'U' @@ -1424,7 +1424,7 @@ end function generic_mattrimul!(C::AbstractMatrix, uploc, isunitc, tfun::Function, A::AbstractMatrix, B::AbstractMatrix) require_one_based_indexing(C, A, B) - check_A_mul_B!_sizes(size(C), size(A), size(B)) + 
matmul_size_check(size(C), size(A), size(B)) oB = oneunit(eltype(B)) unit = isunitc == 'U' @inbounds if uploc == 'U' @@ -1477,7 +1477,7 @@ end # conjugate cases function generic_mattrimul!(C::AbstractMatrix, uploc, isunitc, ::Function, A::AbstractMatrix, xB::AdjOrTrans) require_one_based_indexing(C, A, xB) - check_A_mul_B!_sizes(size(C), size(A), size(xB)) + matmul_size_check(size(C), size(A), size(xB)) B = parent(xB) oB = oneunit(eltype(B)) unit = isunitc == 'U' diff --git a/test/matmul.jl b/test/matmul.jl index 1294e97c..805edeac 100644 --- a/test/matmul.jl +++ b/test/matmul.jl @@ -1148,4 +1148,23 @@ end @test A * A ≈ M * M end +@testset "issue #1147: error messages in matmul" begin + for T in (Int, Float64, ComplexF64) + for f in (identity, Symmetric) + @test_throws "incompatible dimensions for matrix multiplication" f(zeros(T,0,0)) * zeros(T,1,5) + @test_throws "incompatible dimensions for matrix multiplication" f(zeros(T,0,0)) * zeros(T,1) + @test_throws "incompatible dimensions for matrix multiplication" zeros(T,0) * f(zeros(T,2,2)) + @test_throws "incompatible dimensions for matrix multiplication" mul!(zeros(T,0,0), zeros(T,5), zeros(T,5)) + @test_throws "incompatible dimensions for matrix multiplication" mul!(zeros(T,0,0), f(zeros(T,1,1)), zeros(T,0,0)) + @test_throws "incompatible destination size" mul!(zeros(T,0,2), f(zeros(T,1,1)), zeros(T,1,2)) + @test_throws "incompatible destination size" mul!(zeros(T,1,0), f(zeros(T,1,1)), zeros(T,1,2)) + @test_throws "incompatible destination size" mul!(zeros(T,0,0), f(zeros(T,1,1)), zeros(T,1)) + @test_throws "incompatible destination size" mul!(zeros(T,0), f(zeros(T,1,1)), zeros(T,1)) + end + + @test_throws "expected size: (2, 2)" LinearAlgebra.matmul2x2!(zeros(T,2,2), 'N', 'N', zeros(T,2,3), zeros(T,3,2)) + @test_throws "expected size: (2, 2)" LinearAlgebra.matmul2x2!(zeros(T,2,3), 'N', 'N', zeros(T,2,2), zeros(T,2,3)) + end +end + end # module TestMatmul diff --git a/test/triangular.jl b/test/triangular.jl index 
ca172229..499e052e 100644 --- a/test/triangular.jl +++ b/test/triangular.jl @@ -1429,4 +1429,18 @@ end @test rdiv!(B2v, U) ≈ rdiv!(B2vc, U) end +@testset "error messages in matmul with mismatched matrix sizes" begin + for T in (Int, Float64) + A = UpperTriangular(ones(T,2,2)) + B = ones(T,3,3) + C = similar(B) + @test_throws "incompatible dimensions for matrix multiplication" mul!(C, A, B) + @test_throws "incompatible dimensions for matrix multiplication" mul!(C, B, A) + B = Array(A) + C = similar(B, (4,4)) + @test_throws "incompatible destination size" mul!(C, A, B) + @test_throws "incompatible destination size" mul!(C, B, A) + end +end + end # module TestTriangular