Skip to content

Commit a622302

Browse files
authored
Consistently check matrix sizes in matmul (#1152)
Fixes #1147, and makes the error messages more verbose. Before ```julia julia> zeros(0,4) * zeros(1) ERROR: DimensionMismatch: second dimension of matrix, 4, does not match length of input vector, 1 [...] julia> mul!(zeros(0), zeros(2,2), zeros(2)) ERROR: DimensionMismatch: first dimension of matrix, 2, does not match length of output vector, 0 [...] ``` After ```julia julia> zeros(0,4) * zeros(1) ERROR: DimensionMismatch: incompatible dimensions for matrix multiplication: tried to multiply a matrix of size (0, 4) with a vector of length 1. The second dimension of the matrix: 4, does not match the length of the vector: 1. [...] julia> zeros(0,4) * zeros(1,1) ERROR: DimensionMismatch: incompatible dimensions for matrix multiplication: tried to multiply a matrix of size (0, 4) with a matrix of size (1, 1). The second dimension of the first matrix: 4, does not match the first dimension of the second matrix: 1. [...] julia> mul!(zeros(0), zeros(2,2), zeros(2)) ERROR: DimensionMismatch: incompatible destination size: the destination vector of length 0 is incomatible with the product of a matrix of size (2, 2) and a vector of length 2. The destination must be of length 2. [...] julia> mul!(zeros(0,1), zeros(3,2), zeros(2,3)) ERROR: DimensionMismatch: incompatible destination size: the destination matrix of size (0, 1) is incomatible with the product of a matrix of size (3, 2) and a matrix of size (2, 3). The destination must be of size (3, 3). [...] ```
1 parent 6e5ea12 commit a622302

File tree

6 files changed

+128
-118
lines changed

6 files changed

+128
-118
lines changed

src/bidiag.jl

Lines changed: 14 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -497,7 +497,7 @@ end
497497

498498
# B .= A * B
499499
function lmul!(A::Bidiagonal, B::AbstractVecOrMat)
500-
_muldiag_size_check(size(A), size(B))
500+
matmul_size_check(size(A), size(B))
501501
(; dv, ev) = A
502502
if A.uplo == 'U'
503503
for k in axes(B,2)
@@ -518,7 +518,7 @@ function lmul!(A::Bidiagonal, B::AbstractVecOrMat)
518518
end
519519
# B .= D * B
520520
function lmul!(D::Diagonal, B::Bidiagonal)
521-
_muldiag_size_check(size(D), size(B))
521+
matmul_size_check(size(D), size(B))
522522
(; dv, ev) = B
523523
isL = B.uplo == 'L'
524524
dv[1] = D.diag[1] * dv[1]
@@ -530,7 +530,7 @@ function lmul!(D::Diagonal, B::Bidiagonal)
530530
end
531531
# B .= B * A
532532
function rmul!(B::AbstractMatrix, A::Bidiagonal)
533-
_muldiag_size_check(size(A), size(B))
533+
matmul_size_check(size(A), size(B))
534534
(; dv, ev) = A
535535
if A.uplo == 'U'
536536
for k in reverse(axes(dv,1)[2:end])
@@ -555,7 +555,7 @@ function rmul!(B::AbstractMatrix, A::Bidiagonal)
555555
end
556556
# B .= B * D
557557
function rmul!(B::Bidiagonal, D::Diagonal)
558-
_muldiag_size_check(size(B), size(D))
558+
matmul_size_check(size(B), size(D))
559559
(; dv, ev) = B
560560
isU = B.uplo == 'U'
561561
dv[1] *= D.diag[1]
@@ -566,22 +566,6 @@ function rmul!(B::Bidiagonal, D::Diagonal)
566566
return B
567567
end
568568

569-
@noinline function check_A_mul_B!_sizes((mC, nC)::NTuple{2,Integer}, (mA, nA)::NTuple{2,Integer}, (mB, nB)::NTuple{2,Integer})
570-
# check for matching sizes in one column of B and C
571-
check_A_mul_B!_sizes((mC,), (mA, nA), (mB,))
572-
# ensure that the number of columns in B and C match
573-
if nB != nC
574-
throw(DimensionMismatch(lazy"second dimension of output C, $nC, and second dimension of B, $nB, must match"))
575-
end
576-
end
577-
@noinline function check_A_mul_B!_sizes((mC,)::Tuple{Integer}, (mA, nA)::NTuple{2,Integer}, (mB,)::Tuple{Integer})
578-
if mA != mC
579-
throw(DimensionMismatch(lazy"first dimension of A, $mA, and first dimension of output C, $mC, must match"))
580-
elseif nA != mB
581-
throw(DimensionMismatch(lazy"second dimension of A, $nA, and first dimension of B, $mB, must match"))
582-
end
583-
end
584-
585569
# function to get the internally stored vectors for Bidiagonal and [Sym]Tridiagonal
586570
# to avoid allocations in _mul! below (#24324, #24578)
587571
_diag(A::Tridiagonal, k) = k == -1 ? A.dl : k == 0 ? A.d : A.du
@@ -603,7 +587,7 @@ _mul!(C::AbstractMatrix, A::BiTriSym, B::Bidiagonal, _add::MulAddMul) =
603587
_bibimul!(C, A, B, _add)
604588
function _bibimul!(C, A, B, _add)
605589
require_one_based_indexing(C)
606-
check_A_mul_B!_sizes(size(C), size(A), size(B))
590+
matmul_size_check(size(C), size(A), size(B))
607591
n = size(A,1)
608592
iszero(n) && return C
609593
# We use `_rmul_or_fill!` instead of `_modify!` here since using
@@ -851,7 +835,7 @@ _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, alpha::Number, beta::Number)
851835
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
852836
function _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul)
853837
require_one_based_indexing(C)
854-
check_A_mul_B!_sizes(size(C), size(A), size(B))
838+
matmul_size_check(size(C), size(A), size(B))
855839
n = size(A,1)
856840
iszero(n) && return C
857841
_rmul_or_fill!(C, _add.beta) # see the same use above
@@ -894,7 +878,7 @@ end
894878

895879
function _mul!(C::AbstractMatrix, A::Bidiagonal, B::Diagonal, _add::MulAddMul)
896880
require_one_based_indexing(C)
897-
check_A_mul_B!_sizes(size(C), size(A), size(B))
881+
matmul_size_check(size(C), size(A), size(B))
898882
n = size(A,1)
899883
iszero(n) && return C
900884
_rmul_or_fill!(C, _add.beta) # see the same use above
@@ -924,7 +908,7 @@ function _mul!(C::AbstractMatrix, A::Bidiagonal, B::Diagonal, _add::MulAddMul)
924908
end
925909

926910
function _mul!(C::Bidiagonal, A::Bidiagonal, B::Diagonal, _add::MulAddMul)
927-
check_A_mul_B!_sizes(size(C), size(A), size(B))
911+
matmul_size_check(size(C), size(A), size(B))
928912
n = size(A,1)
929913
iszero(n) && return C
930914
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
@@ -957,7 +941,7 @@ end
957941

958942
function _mul!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat, _add::MulAddMul)
959943
require_one_based_indexing(C, B)
960-
check_A_mul_B!_sizes(size(C), size(A), size(B))
944+
matmul_size_check(size(C), size(A), size(B))
961945
nA = size(A,1)
962946
nB = size(B,2)
963947
(iszero(nA) || iszero(nB)) && return C
@@ -1027,7 +1011,7 @@ end
10271011

10281012
function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::TriSym, _add::MulAddMul)
10291013
require_one_based_indexing(C, A)
1030-
check_A_mul_B!_sizes(size(C), size(A), size(B))
1014+
matmul_size_check(size(C), size(A), size(B))
10311015
n = size(A,1)
10321016
m = size(B,2)
10331017
(iszero(_add.alpha) || iszero(m)) && return _rmul_or_fill!(C, _add.beta)
@@ -1063,7 +1047,7 @@ end
10631047

10641048
function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::Bidiagonal, _add::MulAddMul)
10651049
require_one_based_indexing(C, A)
1066-
check_A_mul_B!_sizes(size(C), size(A), size(B))
1050+
matmul_size_check(size(C), size(A), size(B))
10671051
m, n = size(A)
10681052
(iszero(m) || iszero(n)) && return C
10691053
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
@@ -1093,7 +1077,7 @@ _mul!(C::AbstractMatrix, A::Diagonal, B::TriSym, _add::MulAddMul) =
10931077
_dibimul!(C, A, B, _add)
10941078
function _dibimul!(C, A, B, _add)
10951079
require_one_based_indexing(C)
1096-
check_A_mul_B!_sizes(size(C), size(A), size(B))
1080+
matmul_size_check(size(C), size(A), size(B))
10971081
n = size(A,1)
10981082
iszero(n) && return C
10991083
# ensure that we fill off-band elements in the destination
@@ -1137,7 +1121,7 @@ function _dibimul!(C, A, B, _add)
11371121
end
11381122
function _dibimul!(C::AbstractMatrix, A::Diagonal, B::Bidiagonal, _add)
11391123
require_one_based_indexing(C)
1140-
check_A_mul_B!_sizes(size(C), size(A), size(B))
1124+
matmul_size_check(size(C), size(A), size(B))
11411125
n = size(A,1)
11421126
iszero(n) && return C
11431127
# ensure that we fill off-band elements in the destination
@@ -1168,7 +1152,7 @@ function _dibimul!(C::AbstractMatrix, A::Diagonal, B::Bidiagonal, _add)
11681152
C
11691153
end
11701154
function _dibimul!(C::Bidiagonal, A::Diagonal, B::Bidiagonal, _add)
1171-
check_A_mul_B!_sizes(size(C), size(A), size(B))
1155+
matmul_size_check(size(C), size(A), size(B))
11721156
n = size(A,1)
11731157
n == 0 && return C
11741158
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)

src/diagonal.jl

Lines changed: 11 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -322,39 +322,18 @@ Base.literal_pow(::typeof(^), D::Diagonal, valp::Val) =
322322
Diagonal(Base.literal_pow.(^, D.diag, valp)) # for speed
323323
Base.literal_pow(::typeof(^), D::Diagonal, ::Val{-1}) = inv(D) # for disambiguation
324324

325-
function _muldiag_size_check(szA::NTuple{2,Integer}, szB::Tuple{Integer,Vararg{Integer}})
326-
nA = szA[2]
327-
mB = szB[1]
328-
@noinline throw_dimerr(szB::NTuple{2}, nA, mB) = throw(DimensionMismatch(lazy"second dimension of A, $nA, does not match first dimension of B, $mB"))
329-
@noinline throw_dimerr(szB::NTuple{1}, nA, mB) = throw(DimensionMismatch(lazy"second dimension of D, $nA, does not match length of V, $mB"))
330-
nA == mB || throw_dimerr(szB, nA, mB)
331-
return nothing
332-
end
333-
# the output matrix should have the same size as the non-diagonal input matrix or vector
334-
@noinline throw_dimerr(szC, szA) = throw(DimensionMismatch(lazy"output matrix has size: $szC, but should have size $szA"))
335-
function _size_check_out(szC::NTuple{2}, szA::NTuple{2}, szB::NTuple{2})
336-
(szC[1] == szA[1] && szC[2] == szB[2]) || throw_dimerr(szC, (szA[1], szB[2]))
337-
end
338-
function _size_check_out(szC::NTuple{1}, szA::NTuple{2}, szB::NTuple{1})
339-
szC[1] == szA[1] || throw_dimerr(szC, (szA[1],))
340-
end
341-
function _muldiag_size_check(szC::Tuple{Vararg{Integer}}, szA::Tuple{Vararg{Integer}}, szB::Tuple{Vararg{Integer}})
342-
_muldiag_size_check(szA, szB)
343-
_size_check_out(szC, szA, szB)
344-
end
345-
346325
function (*)(Da::Diagonal, Db::Diagonal)
347-
_muldiag_size_check(size(Da), size(Db))
326+
matmul_size_check(size(Da), size(Db))
348327
return Diagonal(Da.diag .* Db.diag)
349328
end
350329

351330
function (*)(D::Diagonal, V::AbstractVector)
352-
_muldiag_size_check(size(D), size(V))
331+
matmul_size_check(size(D), size(V))
353332
return D.diag .* V
354333
end
355334

356335
function rmul!(A::AbstractMatrix, D::Diagonal)
357-
_muldiag_size_check(size(A), size(D))
336+
matmul_size_check(size(A), size(D))
358337
for I in CartesianIndices(A)
359338
row, col = Tuple(I)
360339
@inbounds A[row, col] *= D.diag[col]
@@ -363,7 +342,7 @@ function rmul!(A::AbstractMatrix, D::Diagonal)
363342
end
364343
# T .= T * D
365344
function rmul!(T::Tridiagonal, D::Diagonal)
366-
_muldiag_size_check(size(T), size(D))
345+
matmul_size_check(size(T), size(D))
367346
(; dl, d, du) = T
368347
d[1] *= D.diag[1]
369348
for i in axes(dl,1)
@@ -375,7 +354,7 @@ function rmul!(T::Tridiagonal, D::Diagonal)
375354
end
376355

377356
function lmul!(D::Diagonal, B::AbstractVecOrMat)
378-
_muldiag_size_check(size(D), size(B))
357+
matmul_size_check(size(D), size(B))
379358
for I in CartesianIndices(B)
380359
row = I[1]
381360
@inbounds B[I] = D.diag[row] * B[I]
@@ -386,7 +365,7 @@ end
386365
# in-place multiplication with a diagonal
387366
# T .= D * T
388367
function lmul!(D::Diagonal, T::Tridiagonal)
389-
_muldiag_size_check(size(D), size(T))
368+
matmul_size_check(size(D), size(T))
390369
(; dl, d, du) = T
391370
d[1] = D.diag[1] * d[1]
392371
for i in axes(dl,1)
@@ -507,7 +486,7 @@ end
507486
# specialize the non-trivial case
508487
function _mul_diag!(out, A, B, alpha, beta)
509488
require_one_based_indexing(out, A, B)
510-
_muldiag_size_check(size(out), size(A), size(B))
489+
matmul_size_check(size(out), size(A), size(B))
511490
if iszero(alpha)
512491
_rmul_or_fill!(out, beta)
513492
else
@@ -532,14 +511,14 @@ _mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, alpha::Number, beta::Number
532511
_mul_diag!(C, Da, Db, alpha, beta)
533512

534513
function (*)(Da::Diagonal, A::AbstractMatrix, Db::Diagonal)
535-
_muldiag_size_check(size(Da), size(A))
536-
_muldiag_size_check(size(A), size(Db))
514+
matmul_size_check(size(Da), size(A))
515+
matmul_size_check(size(A), size(Db))
537516
return broadcast(*, Da.diag, A, permutedims(Db.diag))
538517
end
539518

540519
function (*)(Da::Diagonal, Db::Diagonal, Dc::Diagonal)
541-
_muldiag_size_check(size(Da), size(Db))
542-
_muldiag_size_check(size(Db), size(Dc))
520+
matmul_size_check(size(Da), size(Db))
521+
matmul_size_check(size(Db), size(Dc))
543522
return Diagonal(Da.diag .* Db.diag .* Dc.diag)
544523
end
545524

0 commit comments

Comments
 (0)