Skip to content

Commit 2425ae7

Browse files
dkarrasch authored and andreasnoack committed
add generalized dot product (#32739)
* add generalized dot product * add generalized dot for Adjoint and Transpose * add "generalized" dot for UniformScalings * fix adjoint/transpose in tridiags * improve generic dot, add tests * fix typos, optimize *diag, require_one_based_indexing * add tests * fix typos in triangular and tridiag * fix BigFloat tests in triangular * add sparse tests (and minor fix) * handle block arrays of varying lengths * make generalized dot act recursively * add generalized dot for symmetric/Hermitian matrices * fix triangular case * more complete tests for Symmetric/Hermitian * fix UnitLowerTriangular case * fix complex case in symmetric gendot * interpret dot(x, A, y) as dot(A'x, y), test accordingly * use correct tolerance in triangular tests * add gendot for UpperHessenberg, and tests * fix docstring of 3-arg dot * add generic 3-arg dot for UniformScaling * add generic fallback This should be only relevant to cases like `dot(x, J, y)`, where `x` and `y` are vectors of quaternion vectors, and `J` is a quaternion `UniformScaling`. * add gendot with middle argument Number * attach docstring to generic fallback * simplify scalar/uniform scaling gendot * merge NEWS * use dot(A'x,y) for fallback * use accessor functions in sparse code, generalize to Abstract..., tests * revert fallback definition * remove redundant Number version * write out loops in symmetric/hermitian case * test quaternions in uniformscaling gendot * fix uniformscaling test * add compat note and jldoctest
1 parent f4c23df commit 2425ae7

18 files changed

+412
-1
lines changed

NEWS.md

+1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ Standard library changes
3939

4040
* `qr` and `qr!` functions support `blocksize` keyword argument ([#33053]).
4141

42+
* `dot` now admits a 3-argument method `dot(x, A, y)` to compute generalized dot products `dot(x, A*y)`, but without computing and storing the intermediate result `A*y` ([#32739]).
4243

4344
#### SparseArrays
4445

stdlib/LinearAlgebra/src/bidiag.jl

+30
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,36 @@ function *(A::SymTridiagonal, B::Diagonal)
647647
A_mul_B_td!(Tridiagonal(zeros(TS, size(A, 1)-1), zeros(TS, size(A, 1)), zeros(TS, size(A, 1)-1)), A, B)
648648
end
649649

650+
# Generalized dot product `dot(x, B*y)` for a `Bidiagonal` matrix `B`. The sum is
# accumulated column by column as `dot(B[:, j]' * x, y[j])`, so `B*y` is never stored.
# Throws `DimensionMismatch` unless `length(x) == size(B, 1) == length(y)`.
function dot(x::AbstractVector, B::Bidiagonal, y::AbstractVector)
    require_one_based_indexing(x, y)
    nx, ny = length(x), length(y)
    (nx == size(B, 1) == ny) || throw(DimensionMismatch())
    if nx ≤ 1
        # Empty input: return a zero built from the element types.
        nx == 0 && return dot(zero(eltype(x)), zero(eltype(B)), zero(eltype(y)))
        # 1×1 input: `B.ev` is empty and `x[2]` does not exist, so the general code
        # below would index out of bounds. Only the single diagonal entry contributes.
        return dot(x[1], B.dv[1], y[1])
    end
    ev, dv = B.ev, B.dv
    if B.uplo == 'U'
        # Column j holds ev[j-1] (row j-1) and dv[j] (row j).
        x₀ = x[1]
        r = dot(x[1], dv[1], y[1])
        @inbounds for j in 2:nx-1
            x₋, x₀ = x₀, x[j]
            r += dot(adjoint(ev[j-1])*x₋ + adjoint(dv[j])*x₀, y[j])
        end
        r += dot(adjoint(ev[nx-1])*x₀ + adjoint(dv[nx])*x[nx], y[nx])
        return r
    else # B.uplo == 'L'
        # Column j holds dv[j] (row j) and ev[j] (row j+1).
        x₀ = x[1]
        x₊ = x[2]
        r = dot(adjoint(dv[1])*x₀ + adjoint(ev[1])*x₊, y[1])
        @inbounds for j in 2:nx-1
            x₀, x₊ = x₊, x[j+1]
            r += dot(adjoint(dv[j])*x₀ + adjoint(ev[j])*x₊, y[j])
        end
        # Last column only has its diagonal entry; x₊ is x[nx] at this point.
        r += dot(x₊, dv[nx], y[nx])
        return r
    end
end
679+
650680
#Linear solvers
651681
ldiv!(A::Union{Bidiagonal, AbstractTriangular}, b::AbstractVector) = naivesub!(A, b)
652682
ldiv!(A::Transpose{<:Any,<:Bidiagonal}, b::AbstractVector) = ldiv!(copy(A), b)

stdlib/LinearAlgebra/src/diagonal.jl

+4-1
Original file line numberDiff line numberDiff line change
@@ -637,11 +637,14 @@ end
637637

638638
# disambiguation methods: * of Diagonal and Adj/Trans AbsVec
639639
*(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal) = Adjoint(map((t,s) -> t'*s, D.diag, parent(x)))
640+
*(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal) = Transpose(map((t,s) -> transpose(t)*s, D.diag, parent(x)))
640641
*(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) =
641642
mapreduce(t -> t[1]*t[2]*t[3], +, zip(x, D.diag, y))
642-
*(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal) = Transpose(map((t,s) -> transpose(t)*s, D.diag, parent(x)))
643643
*(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) =
644644
mapreduce(t -> t[1]*t[2]*t[3], +, zip(x, D.diag, y))
645+
# Generalized dot product `dot(x, D*y)` for a `Diagonal` matrix `D`: the sum of
# `dot(x[i], D.diag[i], y[i])` over all entries.
# Throws `DimensionMismatch` unless `x`, `D.diag` and `y` have equal length.
function dot(x::AbstractVector, D::Diagonal, y::AbstractVector)
    # `zip` stops at the shortest argument, so mismatched lengths must be rejected
    # explicitly instead of being silently truncated.
    (length(x) == length(D.diag) == length(y)) || throw(DimensionMismatch())
    # `mapreduce` without `init` throws on an empty collection.
    if isempty(x)
        return dot(zero(eltype(x)), zero(eltype(D)), zero(eltype(y)))
    end
    return mapreduce(t -> dot(t[1], t[2], t[3]), +, zip(x, D.diag, y))
end
645648

646649
function cholesky!(A::Diagonal, ::Val{false} = Val(false); check::Bool = true)
647650
info = 0

stdlib/LinearAlgebra/src/generic.jl

+45
Original file line numberDiff line numberDiff line change
@@ -874,6 +874,51 @@ function dot(x::AbstractArray, y::AbstractArray)
874874
s
875875
end
876876

877+
"""
878+
dot(x, A, y)
879+
880+
Compute the generalized dot product `dot(x, A*y)` between two vectors `x` and `y`,
881+
without storing the intermediate result of `A*y`. As for the two-argument
882+
[`dot(_,_)`](@ref), this acts recursively. Moreover, for complex vectors, the
883+
first vector is conjugated.
884+
885+
!!! compat "Julia 1.4"
886+
Three-argument `dot` requires at least Julia 1.4.
887+
888+
# Examples
889+
```jldoctest
890+
julia> dot([1; 1], [1 2; 3 4], [2; 3])
891+
26
892+
893+
julia> dot(1:5, reshape(1:25, 5, 5), 2:6)
894+
4850
895+
896+
julia> ⋅(1:5, reshape(1:25, 5, 5), 2:6) == dot(1:5, reshape(1:25, 5, 5), 2:6)
897+
true
898+
```
899+
"""
900+
function dot(x, A, y)
    # Generic fallback for argument combinations not covered by specialized methods:
    # form `A*y` explicitly and reduce to the two-argument `dot`.
    return dot(x, A*y)
end
901+
902+
# Generic generalized dot product `dot(x, A*y)` for an arbitrary `AbstractMatrix`.
# Each column contributes `dot(A[:, j]' * x, y[j])`, so `A*y` is never stored.
# Throws `DimensionMismatch` unless the axes of `x` and `y` match those of `A`.
function dot(x::AbstractVector, A::AbstractMatrix, y::AbstractVector)
    (axes(x)..., axes(y)...) == axes(A) || throw(DimensionMismatch())
    if isempty(A)
        # `first(x)`, `first(A)` or `first(y)` below would throw on empty input.
        return dot(zero(eltype(x)), zero(eltype(A)), zero(eltype(y)))
    end
    # Determine the accumulator type from one representative product.
    T = typeof(dot(first(x), first(A), first(y)))
    s = zero(T)
    i₁ = first(eachindex(x))
    x₁ = first(x)
    @inbounds for j in eachindex(y)
        yj = y[j]
        # Columns paired with a zero entry of `y` are skipped.
        if !iszero(yj)
            temp = zero(adjoint(A[i₁,j]) * x₁)
            @simd for i in eachindex(x)
                temp += adjoint(A[i,j]) * x[i]
            end
            s += dot(temp, yj)
        end
    end
    return s
end
920+
# For an adjoint wrapper, x' * A' * y == (y' * A * x)', so delegate to the parent
# matrix with the vector arguments swapped and take the adjoint of the result.
function dot(x::AbstractVector, adjA::Adjoint, y::AbstractVector)
    return adjoint(dot(y, parent(adjA), x))
end
# The same identity holds for the transpose of a real-valued matrix.
function dot(x::AbstractVector, transA::Transpose{<:Real}, y::AbstractVector)
    return adjoint(dot(y, parent(transA), x))
end
877922

878923
###########################################################################################
879924

stdlib/LinearAlgebra/src/hessenberg.jl

+31
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,37 @@ function logabsdet(F::UpperHessenberg; shift::Number=false)
284284
return (logdeterminant, P)
285285
end
286286

287+
# Generalized dot product `dot(x, H*y)` for an `UpperHessenberg` matrix `H`. Column `j`
# only has entries in rows `1:j+1` (rows `1:m` for the last column), so the inner
# loops stop there. Throws `DimensionMismatch` unless
# `length(x) == size(H, 1) == length(y)`.
function dot(x::AbstractVector, H::UpperHessenberg, y::AbstractVector)
    require_one_based_indexing(x, y)
    m = size(H, 1)
    (length(x) == m == length(y)) || throw(DimensionMismatch())
    if iszero(m)
        return dot(zero(eltype(x)), zero(eltype(H)), zero(eltype(y)))
    end
    x₁ = x[1]
    r = dot(x₁, H[1,1], y[1])
    # 1×1 input: there is no subdiagonal entry `H[2,1]` and no `x[2]`, and the
    # last-column block below would count `H[1,1]` a second time.
    m == 1 && return r
    r += dot(x[2], H[2,1], y[1])
    @inbounds for j in 2:m-1
        yj = y[j]
        # Columns paired with a zero entry of `y` are skipped.
        if !iszero(yj)
            temp = adjoint(H[1,j]) * x₁
            @simd for i in 2:j+1
                temp += adjoint(H[i,j]) * x[i]
            end
            r += dot(temp, yj)
        end
    end
    # Last column: rows 1:m, there is no row m+1.
    ym = y[m]
    if !iszero(ym)
        temp = adjoint(H[1,m]) * x₁
        @simd for i in 2:m
            temp += adjoint(H[i,m]) * x[i]
        end
        r += dot(temp, ym)
    end
    return r
end
317+
287318
######################################################################################
288319
# Hessenberg factorizations Q(H+μI)Q' of A+μI:
289320

stdlib/LinearAlgebra/src/symmetric.jl

+25
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,31 @@ end
457457

458458
*(A::HermOrSym, B::HermOrSym) = A * copyto!(similar(parent(B)), B)
459459

460+
# Generalized dot product `dot(x, A*y)` for symmetric/Hermitian wrappers. Only the stored
# triangle of `A.data` is read: each stored off-diagonal entry contributes once directly
# and once through its adjoint for the mirrored position.
# Throws `DimensionMismatch` unless `length(x) == length(y) == size(A, 1)`.
function dot(x::AbstractVector, A::RealHermSymComplexHerm, y::AbstractVector)
    require_one_based_indexing(x, y)
    n = length(y)
    (length(x) == n == size(A, 1)) || throw(DimensionMismatch())
    store = A.data
    acc = zero(eltype(x)) * zero(eltype(A)) * zero(eltype(y))
    upper = A.uplo == 'U'
    @inbounds for col = 1:n
        # Only the real part of a stored diagonal entry is used.
        acc += dot(x[col], real(store[col,col]), y[col])
        # Rows of the stored triangle that lie strictly off the diagonal.
        offdiag = upper ? (1:col-1) : (col+1:n)
        @simd for row in offdiag
            a = store[row,col]
            acc += dot(x[row], a, y[col]) + dot(x[col], adjoint(a), y[row])
        end
    end
    return acc
end
484+
460485
# Fallbacks to avoid generic_matvecmul!/generic_matmatmul!
461486
## Symmetric{<:Number} and Hermitian{<:Real} are invariant to transpose; peel off the t
462487
*(transA::Transpose{<:Any,<:RealHermSymComplexSym}, B::AbstractVector) = transA.parent * B

stdlib/LinearAlgebra/src/triangular.jl

+84
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,90 @@ end
545545
rmul!(A::Union{UpperTriangular,LowerTriangular}, c::Number) = mul!(A, A, c)
546546
lmul!(c::Number, A::Union{UpperTriangular,LowerTriangular}) = mul!(A, c, A)
547547

548+
# Generalized dot product `dot(x, A*y)` for an `UpperTriangular` matrix: column `j`
# contributes rows `1:j` only. Throws `DimensionMismatch` unless
# `length(x) == size(A, 1) == length(y)`.
function dot(x::AbstractVector, A::UpperTriangular, y::AbstractVector)
    require_one_based_indexing(x, y)
    n = size(A, 1)
    (length(x) == n == length(y)) || throw(DimensionMismatch())
    iszero(n) && return dot(zero(eltype(x)), zero(eltype(A)), zero(eltype(y)))
    xhead = x[1]
    acc = dot(xhead, A[1,1], y[1])
    @inbounds for col = 2:n
        ycol = y[col]
        # Columns paired with a zero entry of `y` are skipped.
        iszero(ycol) && continue
        colsum = adjoint(A[1,col]) * xhead
        @simd for row = 2:col
            colsum += adjoint(A[row,col]) * x[row]
        end
        acc += dot(colsum, ycol)
    end
    return acc
end
569+
# Generalized dot product `dot(x, A*y)` for a `UnitUpperTriangular` matrix: column `j`
# contributes the strictly-upper rows `1:j-1` plus the implicit unit diagonal entry.
# Throws `DimensionMismatch` unless `length(x) == size(A, 1) == length(y)`.
function dot(x::AbstractVector, A::UnitUpperTriangular, y::AbstractVector)
    require_one_based_indexing(x, y)
    n = size(A, 1)
    (length(x) == n == length(y)) || throw(DimensionMismatch())
    iszero(n) && return dot(zero(eltype(x)), zero(eltype(A)), zero(eltype(y)))
    xhead = first(x)
    # First column only has the unit diagonal entry.
    acc = dot(xhead, y[1])
    @inbounds for col = 2:n
        ycol = y[col]
        # Columns paired with a zero entry of `y` are skipped.
        iszero(ycol) && continue
        colsum = adjoint(A[1,col]) * xhead
        @simd for row = 2:col-1
            colsum += adjoint(A[row,col]) * x[row]
        end
        acc += dot(colsum, ycol)
        # Unit diagonal entry of this column.
        acc += dot(x[col], ycol)
    end
    return acc
end
591+
# Generalized dot product `dot(x, A*y)` for a `LowerTriangular` matrix: column `j`
# contributes rows `j:m` only. Throws `DimensionMismatch` unless
# `length(x) == size(A, 1) == length(y)`.
function dot(x::AbstractVector, A::LowerTriangular, y::AbstractVector)
    require_one_based_indexing(x, y)
    n = size(A, 1)
    (length(x) == n == length(y)) || throw(DimensionMismatch())
    iszero(n) && return dot(zero(eltype(x)), zero(eltype(A)), zero(eltype(y)))
    # Accumulator type taken from one representative product.
    acc = zero(typeof(dot(first(x), first(A), first(y))))
    @inbounds for col = 1:n
        ycol = y[col]
        # Columns paired with a zero entry of `y` are skipped.
        iszero(ycol) && continue
        colsum = adjoint(A[col,col]) * x[col]
        @simd for row = col+1:n
            colsum += adjoint(A[row,col]) * x[row]
        end
        acc += dot(colsum, ycol)
    end
    return acc
end
611+
# Generalized dot product `dot(x, A*y)` for a `UnitLowerTriangular` matrix: column `j`
# contributes the implicit unit diagonal entry plus the strictly-lower rows `j+1:m`.
# Throws `DimensionMismatch` unless `length(x) == size(A, 1) == length(y)`.
function dot(x::AbstractVector, A::UnitLowerTriangular, y::AbstractVector)
    require_one_based_indexing(x, y)
    n = size(A, 1)
    (length(x) == n == length(y)) || throw(DimensionMismatch())
    iszero(n) && return dot(zero(eltype(x)), zero(eltype(A)), zero(eltype(y)))
    acc = zero(typeof(dot(first(x), first(y))))
    @inbounds for col = 1:n
        ycol = y[col]
        # Columns paired with a zero entry of `y` are skipped.
        iszero(ycol) && continue
        # The unit diagonal entry contributes x[col] unchanged.
        colsum = x[col]
        @simd for row = col+1:n
            colsum += adjoint(A[row,col]) * x[row]
        end
        acc += dot(colsum, ycol)
    end
    return acc
end
631+
548632
fillstored!(A::LowerTriangular, x) = (fillband!(A.data, x, 1-size(A,1), 0); A)
549633
fillstored!(A::UnitLowerTriangular, x) = (fillband!(A.data, x, 1-size(A,1), -1); A)
550634
fillstored!(A::UpperTriangular, x) = (fillband!(A.data, x, 0, size(A,2)-1); A)

stdlib/LinearAlgebra/src/tridiag.jl

+40
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,27 @@ end
202202
return C
203203
end
204204

205+
# Generalized dot product `dot(x, S*y)` for a `SymTridiagonal` matrix `S`, accumulated
# column by column from `S.dv` and `S.ev` without forming `S*y`.
# Throws `DimensionMismatch` unless `length(x) == size(S, 1) == length(y)`.
function dot(x::AbstractVector, S::SymTridiagonal, y::AbstractVector)
    require_one_based_indexing(x, y)
    nx, ny = length(x), length(y)
    (nx == size(S, 1) == ny) || throw(DimensionMismatch())
    if nx ≤ 1
        # Empty input: return a zero built from the element types.
        nx == 0 && return dot(zero(eltype(x)), zero(eltype(S)), zero(eltype(y)))
        # 1×1 input: `x[2]` and `S.ev[1]` do not exist, so the general code below
        # would index out of bounds. Only the single diagonal entry contributes.
        return dot(x[1], S.dv[1], y[1])
    end
    dv, ev = S.dv, S.ev
    x₀ = x[1]
    x₊ = x[2]
    # `sub` is the subdiagonal entry of the current column; the superdiagonal entry
    # of the next column is its transpose.
    sub = transpose(ev[1])
    r = dot(adjoint(dv[1])*x₀ + adjoint(sub)*x₊, y[1])
    @inbounds for j in 2:nx-1
        x₋, x₀, x₊ = x₀, x₊, x[j+1]
        sup, sub = transpose(sub), transpose(ev[j])
        r += dot(adjoint(sup)*x₋ + adjoint(dv[j])*x₀ + adjoint(sub)*x₊, y[j])
    end
    # Last column: superdiagonal and diagonal entries only.
    r += dot(adjoint(transpose(sub))*x₀ + adjoint(dv[nx])*x₊, y[nx])
    return r
end
225+
205226
(\)(T::SymTridiagonal, B::StridedVecOrMat) = ldlt(T)\B
206227

207228
# division with optional shift for use in shifted-Hessenberg solvers (hessenberg.jl):
@@ -657,3 +678,22 @@ end
657678

658679
Base._sum(A::Tridiagonal, ::Colon) = sum(A.d) + sum(A.dl) + sum(A.du)
659680
Base._sum(A::SymTridiagonal, ::Colon) = sum(A.dv) + 2sum(A.ev)
681+
682+
# Generalized dot product `dot(x, A*y)` for a `Tridiagonal` matrix `A`, accumulated
# column by column from `A.dl`, `A.d` and `A.du` without forming `A*y`.
# Throws `DimensionMismatch` unless `length(x) == size(A, 1) == length(y)`.
function dot(x::AbstractVector, A::Tridiagonal, y::AbstractVector)
    require_one_based_indexing(x, y)
    nx, ny = length(x), length(y)
    (nx == size(A, 1) == ny) || throw(DimensionMismatch())
    if nx ≤ 1
        # Empty input: return a zero built from the element types.
        nx == 0 && return dot(zero(eltype(x)), zero(eltype(A)), zero(eltype(y)))
        # 1×1 input: `x[2]`, `A.dl[1]` and `A.du[0]` do not exist, so the general
        # code below would index out of bounds. Only the diagonal entry contributes.
        return dot(x[1], A.d[1], y[1])
    end
    x₀ = x[1]
    x₊ = x[2]
    dl, d, du = A.dl, A.d, A.du
    # Column j holds du[j-1] (row j-1), d[j] (row j) and dl[j] (row j+1).
    r = dot(adjoint(d[1])*x₀ + adjoint(dl[1])*x₊, y[1])
    @inbounds for j in 2:nx-1
        x₋, x₀, x₊ = x₀, x₊, x[j+1]
        r += dot(adjoint(du[j-1])*x₋ + adjoint(d[j])*x₀ + adjoint(dl[j])*x₊, y[j])
    end
    r += dot(adjoint(du[nx-1])*x₀ + adjoint(d[nx])*x₊, y[nx])
    return r
end

stdlib/LinearAlgebra/src/uniformscaling.jl

+4
Original file line numberDiff line numberDiff line change
@@ -400,3 +400,7 @@ Array(s::UniformScaling, dims::Dims{2}) = Matrix(s, dims)
400400
## Diagonal construction from UniformScaling
401401
Diagonal{T}(s::UniformScaling, m::Integer) where {T} = Diagonal{T}(fill(T(s.λ), m))
402402
Diagonal(s::UniformScaling, m::Integer) = Diagonal{eltype(s)}(s, m)
403+
404+
# A `UniformScaling` in the middle reduces to the scalar case with its factor `J.λ`.
dot(x::AbstractVector, J::UniformScaling, y::AbstractVector) = dot(x, J.λ, y)
# General `Number` in the middle: apply the three-argument `dot` elementwise, so the
# scalar is kept between the two entries.
# Throws `DimensionMismatch` unless `x` and `y` have equal length.
function dot(x::AbstractVector, a::Number, y::AbstractVector)
    # `zip` stops at the shorter argument, so mismatched lengths must be rejected
    # explicitly; the `Real`/`Complex` method below already throws via `dot(x, y)`.
    length(x) == length(y) || throw(DimensionMismatch())
    # `sum` over an empty iterator with an arbitrary function throws.
    isempty(x) && return dot(zero(eltype(x)), a, zero(eltype(y)))
    return sum(t -> dot(t[1], a, t[2]), zip(x, y))
end
# Real and complex scalars can be pulled out in front of the two-argument `dot`.
dot(x::AbstractVector, a::Union{Real,Complex}, y::AbstractVector) = a*dot(x, y)

stdlib/LinearAlgebra/test/bidiag.jl

+13
Original file line numberDiff line numberDiff line change
@@ -455,4 +455,17 @@ end
455455
@test A * Tridiagonal(ones(1, 1)) == A
456456
end
457457

458+
# The three-argument `dot` on a `Bidiagonal` must agree with the adjoint-based and
# dense-matrix formulations, for real and complex element types and both `uplo`s.
@testset "generalized dot" begin
    for elty in (Float64, ComplexF64)
        dv = randn(elty, 5)
        ev = randn(elty, 4)
        x = randn(elty, 5)
        y = randn(elty, 5)
        for uplo in (:U, :L)
            B = Bidiagonal(dv, ev, uplo)
            @test dot(x, B, y) ≈ dot(B'x, y) ≈ dot(x, Matrix(B), y)
        end
    end
end
470+
458471
end # module TestBidiagonal

stdlib/LinearAlgebra/test/generic.jl

+28
Original file line numberDiff line numberDiff line change
@@ -409,4 +409,32 @@ end
409409
@test all(!isnan, lmul!(false, Any[NaN]))
410410
end
411411

412+
# The three-argument `dot` must agree with the equivalent matrix-vector formulations
# for plain, adjoint and (real) transposed matrices, and must act recursively on
# arrays of arrays.
@testset "generalized dot #32739" begin
    for elty in (Int, Float32, Float64, BigFloat, Complex{Float32}, Complex{Float64}, Complex{BigFloat})
        n = 10
        if elty <: Int
            A = rand(-n:n, n, n)
            x = rand(-n:n, n)
            y = rand(-n:n, n)
        elseif elty <: Real
            A = convert(Matrix{elty}, randn(n,n))
            x = rand(elty, n)
            y = rand(elty, n)
        else
            A = convert(Matrix{elty}, complex.(randn(n,n), randn(n,n)))
            x = rand(elty, n)
            y = rand(elty, n)
        end
        @test dot(x, A, y) ≈ dot(A'x, y) ≈ *(x', A, y) ≈ (x'A)*y
        @test dot(x, A', y) ≈ dot(A*x, y) ≈ *(x', A', y) ≈ (x'A')*y
        elty <: Real && @test dot(x, transpose(A), y) ≈ dot(x, transpose(A)*y) ≈ *(x', transpose(A), y) ≈ (x'*transpose(A))*y
        # Wrap everything in one more array level to exercise the recursive behavior.
        B = reshape([A], 1, 1)
        x = [x]
        y = [y]
        @test dot(x, B, y) ≈ dot(B'x, y)
        @test dot(x, B', y) ≈ dot(B*x, y)
        elty <: Real && @test dot(x, transpose(B), y) ≈ dot(x, transpose(B)*y)
    end
end
439+
412440
end # module TestGeneric

stdlib/LinearAlgebra/test/hessenberg.jl

+5
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,11 @@ let n = 10
8888
@test det(H + shift*I) det(A + shift*I)
8989
@test logabsdet(H + shift*I) logabsdet(A + shift*I)
9090
end
91+
92+
# The three-argument `dot` on `h` must agree with the adjoint-based formulation and
# with the same computations on the dense copy `HM`.
HM = Matrix(h)
@test dot(b, h, b) ≈ dot(h'b, b) ≈ dot(b, HM, b) ≈ dot(HM'b, b)
c = b .+ 1
@test dot(b, h, c) ≈ dot(h'b, c) ≈ dot(b, HM, c) ≈ dot(HM'b, c)
9196
end
9297
end
9398

0 commit comments

Comments
 (0)