Skip to content

Commit e3bb870

Browse files
authored
re-allow vector*(1-row matrix) and transpose thereof (#20423)
* re-allow vector*(1-row matrix) and transpose (closes #20389)
1 parent 7aab89d commit e3bb870

File tree

5 files changed

+24
-39
lines changed

5 files changed

+24
-39
lines changed

base/linalg/diagonal.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -256,16 +256,9 @@ end
256256

257257
# Methods to resolve ambiguities with `Diagonal`
258258
@inline *(rowvec::RowVector, D::Diagonal) = transpose(D * transpose(rowvec))
259-
*(::Diagonal, ::RowVector) = throw(DimensionMismatch("Cannot right-multiply matrix by transposed vector"))
260-
261259
@inline A_mul_Bt(D::Diagonal, rowvec::RowVector) = D*transpose(rowvec)
262-
263-
At_mul_B(rowvec::RowVector, ::Diagonal) = throw(DimensionMismatch("Cannot left-multiply matrix by vector"))
264-
265260
@inline A_mul_Bc(D::Diagonal, rowvec::RowVector) = D*ctranspose(rowvec)
266261

267-
Ac_mul_B(rowvec::RowVector, ::Diagonal) = throw(DimensionMismatch("Cannot left-multiply matrix by vector"))
268-
269262
conj(D::Diagonal) = Diagonal(conj(D.diag))
270263
transpose(D::Diagonal) = D
271264
ctranspose(D::Diagonal) = conj(D)

base/linalg/matmul.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,13 @@ function (*){T,S}(A::AbstractMatrix{T}, x::AbstractVector{S})
8181
A_mul_B!(similar(x,TS,size(A,1)),A,x)
8282
end
8383

84+
# these will throw a DimensionMismatch unless B has 1 row (or 1 col for transposed case):
85+
A_mul_Bt(a::AbstractVector, B::AbstractMatrix) = A_mul_Bt(reshape(a,length(a),1),B)
86+
A_mul_Bt(A::AbstractMatrix, b::AbstractVector) = A_mul_Bt(A,reshape(b,length(b),1))
87+
A_mul_Bc(a::AbstractVector, B::AbstractMatrix) = A_mul_Bc(reshape(a,length(a),1),B)
88+
A_mul_Bc(A::AbstractMatrix, b::AbstractVector) = A_mul_Bc(A,reshape(b,length(b),1))
89+
(*)(a::AbstractVector, B::AbstractMatrix) = reshape(a,length(a),1)*B
90+
8491
A_mul_B!{T<:BlasFloat}(y::StridedVector{T}, A::StridedVecOrMat{T}, x::StridedVector{T}) = gemv!(y, 'N', A, x)
8592
for elty in (Float32,Float64)
8693
@eval begin

base/linalg/rowvector.jl

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -154,26 +154,19 @@ end
154154
sum(@inbounds(return rowvec[i]*vec[i]) for i = 1:length(vec))
155155
end
156156
@inline *(rowvec::RowVector, mat::AbstractMatrix) = transpose(mat.' * transpose(rowvec))
157-
*(vec::AbstractVector, mat::AbstractMatrix) = throw(DimensionMismatch(
158-
"Cannot left-multiply a matrix by a vector")) # Should become a deprecation
159157
*(::RowVector, ::RowVector) = throw(DimensionMismatch("Cannot multiply two transposed vectors"))
160158
@inline *(vec::AbstractVector, rowvec::RowVector) = vec .* rowvec
161159
*(vec::AbstractVector, rowvec::AbstractVector) = throw(DimensionMismatch("Cannot multiply two vectors"))
162-
*(mat::AbstractMatrix, rowvec::RowVector) = throw(DimensionMismatch("Cannot right-multiply matrix by transposed vector"))
163160

164161
# Transposed forms
165162
A_mul_Bt(::RowVector, ::AbstractVector) = throw(DimensionMismatch("Cannot multiply two transposed vectors"))
166163
@inline A_mul_Bt(rowvec::RowVector, mat::AbstractMatrix) = transpose(mat * transpose(rowvec))
167-
A_mul_Bt(vec::AbstractVector, mat::AbstractMatrix) = throw(DimensionMismatch(
168-
"Cannot left-multiply a matrix by a vector"))
169164
@inline A_mul_Bt(rowvec1::RowVector, rowvec2::RowVector) = rowvec1*transpose(rowvec2)
170165
A_mul_Bt(vec::AbstractVector, rowvec::RowVector) = throw(DimensionMismatch("Cannot multiply two vectors"))
171166
@inline A_mul_Bt(vec1::AbstractVector, vec2::AbstractVector) = vec1 * transpose(vec2)
172167
@inline A_mul_Bt(mat::AbstractMatrix, rowvec::RowVector) = mat * transpose(rowvec)
173168

174169
@inline At_mul_Bt(rowvec::RowVector, vec::AbstractVector) = transpose(rowvec) * transpose(vec)
175-
At_mul_Bt(rowvec::RowVector, mat::AbstractMatrix) = throw(DimensionMismatch(
176-
"Cannot left-multiply matrix by vector"))
177170
@inline At_mul_Bt(vec::AbstractVector, mat::AbstractMatrix) = transpose(mat * vec)
178171
At_mul_Bt(rowvec1::RowVector, rowvec2::RowVector) = throw(DimensionMismatch("Cannot multiply two vectors"))
179172
@inline At_mul_Bt(vec::AbstractVector, rowvec::RowVector) = transpose(vec)*transpose(rowvec)
@@ -182,42 +175,34 @@ At_mul_Bt(vec::AbstractVector, rowvec::AbstractVector) = throw(DimensionMismatch
182175
@inline At_mul_Bt(mat::AbstractMatrix, rowvec::RowVector) = mat.' * transpose(rowvec)
183176

184177
At_mul_B(::RowVector, ::AbstractVector) = throw(DimensionMismatch("Cannot multiply two vectors"))
185-
At_mul_B(rowvec::RowVector, mat::AbstractMatrix) = throw(DimensionMismatch(
186-
"Cannot left-multiply matrix by vector"))
187178
@inline At_mul_B(vec::AbstractVector, mat::AbstractMatrix) = transpose(At_mul_B(mat,vec))
188179
@inline At_mul_B(rowvec1::RowVector, rowvec2::RowVector) = transpose(rowvec1) * rowvec2
189180
At_mul_B(vec::AbstractVector, rowvec::RowVector) = throw(DimensionMismatch(
190181
"Cannot multiply two transposed vectors"))
191182
@inline At_mul_B{T<:Real}(vec1::AbstractVector{T}, vec2::AbstractVector{T}) =
192183
reduce(+, map(At_mul_B, vec1, vec2)) # Seems to be overloaded...
193184
@inline At_mul_B(vec1::AbstractVector, vec2::AbstractVector) = transpose(vec1) * vec2
194-
At_mul_B(mat::AbstractMatrix, rowvec::RowVector) = throw(DimensionMismatch(
195-
"Cannot right-multiply matrix by transposed vector"))
196185

197186
# Conjugated forms
198187
A_mul_Bc(::RowVector, ::AbstractVector) = throw(DimensionMismatch("Cannot multiply two transposed vectors"))
199188
@inline A_mul_Bc(rowvec::RowVector, mat::AbstractMatrix) = ctranspose(mat * ctranspose(rowvec))
200-
A_mul_Bc(vec::AbstractVector, mat::AbstractMatrix) = throw(DimensionMismatch("Cannot left-multiply a matrix by a vector"))
201189
@inline A_mul_Bc(rowvec1::RowVector, rowvec2::RowVector) = rowvec1 * ctranspose(rowvec2)
202190
A_mul_Bc(vec::AbstractVector, rowvec::RowVector) = throw(DimensionMismatch("Cannot multiply two vectors"))
203191
@inline A_mul_Bc(vec1::AbstractVector, vec2::AbstractVector) = vec1 * ctranspose(vec2)
204192
@inline A_mul_Bc(mat::AbstractMatrix, rowvec::RowVector) = mat * ctranspose(rowvec)
205193

206194
@inline Ac_mul_Bc(rowvec::RowVector, vec::AbstractVector) = ctranspose(rowvec) * ctranspose(vec)
207-
Ac_mul_Bc(rowvec::RowVector, mat::AbstractMatrix) = throw(DimensionMismatch("Cannot left-multiply matrix by vector"))
208195
@inline Ac_mul_Bc(vec::AbstractVector, mat::AbstractMatrix) = ctranspose(mat * vec)
209196
Ac_mul_Bc(rowvec1::RowVector, rowvec2::RowVector) = throw(DimensionMismatch("Cannot multiply two vectors"))
210197
@inline Ac_mul_Bc(vec::AbstractVector, rowvec::RowVector) = ctranspose(vec)*ctranspose(rowvec)
211198
Ac_mul_Bc(vec::AbstractVector, rowvec::AbstractVector) = throw(DimensionMismatch("Cannot multiply two transposed vectors"))
212199
@inline Ac_mul_Bc(mat::AbstractMatrix, rowvec::RowVector) = mat' * ctranspose(rowvec)
213200

214201
Ac_mul_B(::RowVector, ::AbstractVector) = throw(DimensionMismatch("Cannot multiply two vectors"))
215-
Ac_mul_B(rowvec::RowVector, mat::AbstractMatrix) = throw(DimensionMismatch("Cannot left-multiply matrix by vector"))
216202
@inline Ac_mul_B(vec::AbstractVector, mat::AbstractMatrix) = ctranspose(Ac_mul_B(mat,vec))
217203
@inline Ac_mul_B(rowvec1::RowVector, rowvec2::RowVector) = ctranspose(rowvec1) * rowvec2
218204
Ac_mul_B(vec::AbstractVector, rowvec::RowVector) = throw(DimensionMismatch("Cannot multiply two transposed vectors"))
219205
@inline Ac_mul_B(vec1::AbstractVector, vec2::AbstractVector) = ctranspose(vec1)*vec2
220-
Ac_mul_B(mat::AbstractMatrix, rowvec::RowVector) = throw(DimensionMismatch("Cannot right-multiply matrix by transposed vector"))
221206

222207
# Left Division #
223208

base/linalg/triangular.jl

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1585,7 +1585,7 @@ for (f, g) in ((:\, :A_ldiv_B!), (:Ac_ldiv_B, :Ac_ldiv_B!), (:At_ldiv_B, :At_ldi
15851585
end
15861586
### Multiplication with triangle to the rigth and hence lhs cannot be transposed.
15871587
for (f, g) in ((:*, :A_mul_B!), (:A_mul_Bc, :A_mul_Bc!), (:A_mul_Bt, :A_mul_Bt!))
1588-
@eval begin
1588+
mat != :AbstractVector && @eval begin
15891589
function ($f)(A::$mat, B::AbstractTriangular)
15901590
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
15911591
AA = similar(A, TAB, size(A))
@@ -1637,25 +1637,12 @@ At_mul_Bt(A::AbstractMatrix, B::AbstractTriangular) = A_mul_Bt(A.', B)
16371637

16381638
# Specializations for RowVector
16391639
@inline *(rowvec::RowVector, A::AbstractTriangular) = transpose(A * transpose(rowvec))
1640-
*(::AbstractTriangular, ::RowVector) = throw(DimensionMismatch("Cannot right-multiply matrix by transposed vector"))
1641-
16421640
@inline A_mul_Bt(rowvec::RowVector, A::AbstractTriangular) = transpose(A * transpose(rowvec))
16431641
@inline A_mul_Bt(A::AbstractTriangular, rowvec::RowVector) = A * transpose(rowvec)
1644-
16451642
@inline At_mul_Bt(A::AbstractTriangular, rowvec::RowVector) = A.' * transpose(rowvec)
1646-
At_mul_Bt(::RowVector, ::AbstractTriangular) = throw(DimensionMismatch("Cannot left-multiply matrix by vector"))
1647-
1648-
At_mul_B(::AbstractTriangular, ::RowVector) = throw(DimensionMismatch("Cannot right-multiply matrix by transposed vector"))
1649-
At_mul_B(::RowVector, ::AbstractTriangular) = throw(DimensionMismatch("Cannot left-multiply matrix by vector"))
1650-
16511643
@inline A_mul_Bc(rowvec::RowVector, A::AbstractTriangular) = ctranspose(A * ctranspose(rowvec))
16521644
@inline A_mul_Bc(A::AbstractTriangular, rowvec::RowVector) = A * ctranspose(rowvec)
1653-
16541645
@inline Ac_mul_Bc(A::AbstractTriangular, rowvec::RowVector) = A' * ctranspose(rowvec)
1655-
Ac_mul_Bc(::RowVector, ::AbstractTriangular) = throw(DimensionMismatch("Cannot left-multiply matrix by vector"))
1656-
1657-
Ac_mul_B(::AbstractTriangular, ::RowVector) = throw(DimensionMismatch("Cannot right-multiply matrix by transposed vector"))
1658-
Ac_mul_B(::RowVector, ::AbstractTriangular) = throw(DimensionMismatch("Cannot left-multiply matrix by vector"))
16591646

16601647
@inline /(rowvec::RowVector, A::Union{UpperTriangular,LowerTriangular}) = transpose(transpose(A) \ transpose(rowvec))
16611648
@inline /(rowvec::RowVector, A::Union{UnitUpperTriangular,UnitLowerTriangular}) = transpose(transpose(A) \ transpose(rowvec))

test/linalg/rowvector.jl

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,15 +104,15 @@ end
104104

105105
@test (rv*v) === 14
106106
@test (rv*mat)::RowVector == [1 4 9]
107-
@test_throws DimensionMismatch [1]*reshape([1],(1,1)) # no longer permitted
107+
@test [1]*reshape([1],(1,1)) == reshape([1],(1,1))
108108
@test_throws DimensionMismatch rv*rv
109109
@test (v*rv)::Matrix == [1 2 3; 2 4 6; 3 6 9]
110110
@test_throws DimensionMismatch v*v # Was previously a missing method error, now an error message
111111
@test_throws DimensionMismatch mat*rv
112112

113113
@test_throws DimensionMismatch rv*v.'
114114
@test (rv*mat.')::RowVector == [1 4 9]
115-
@test_throws DimensionMismatch [1]*reshape([1],(1,1)).' # no longer permitted
115+
@test [1]*reshape([1],(1,1)).' == reshape([1],(1,1))
116116
@test rv*rv.' === 14
117117
@test_throws DimensionMismatch v*rv.'
118118
@test (v*v.')::Matrix == [1 2 3; 2 4 6; 3 6 9]
@@ -142,7 +142,7 @@ end
142142

143143
@test_throws DimensionMismatch cz*z'
144144
@test (cz*mat')::RowVector == [-2im 4 9]
145-
@test_throws DimensionMismatch [1]*reshape([1],(1,1))' # no longer permitted
145+
@test [1]*reshape([1],(1,1))' == reshape([1],(1,1))
146146
@test cz*cz' === 15 + 0im
147147
@test_throws DimensionMismatch z*cz'
148148
@test (z*z')::Matrix == [2 2+2im 3+3im; 2-2im 4 6; 3-3im 6 9]
@@ -238,5 +238,18 @@ end
238238
@test_throws DimensionMismatch ut\rv
239239
end
240240

241+
# issue #20389
242+
@testset "1 row/col vec*mat" begin
243+
let x=[1,2,3], A=ones(1,4), y=x', B=A', C=x.*A
244+
@test x*A == y'*A == x*B' == y'*B' == C
245+
@test A'*x' == A'*y == B*x' == B*y == C'
246+
end
247+
end
248+
@testset "complex 1 row/col vec*mat" begin
249+
let x=[1,2,3]*im, A=ones(1,4)*im, y=x', B=A', C=x.*A
250+
@test x*A == y'*A == x*B' == y'*B' == C
251+
@test A'*x' == A'*y == B*x' == B*y == C'
252+
end
253+
end
241254

242255
end # @testset "RowVector"

0 commit comments

Comments
 (0)