Skip to content

Commit 97dc858

Browse files
committed
Merge pull request #15505 from jw3126/bitrimatmul
faster A_mul_B! and * involving bidiagonal and tridiagonal matrices
2 parents 5d52f02 + 7182b87 commit 97dc858

File tree

8 files changed

+333
-36
lines changed

8 files changed

+333
-36
lines changed

base/linalg/bidiag.jl

Lines changed: 172 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,8 +224,142 @@ end
224224
/(A::Bidiagonal, B::Number) = Bidiagonal(A.dv/B, A.ev/B, A.isupper)
225225
==(A::Bidiagonal, B::Bidiagonal) = (A.dv==B.dv) && (A.ev==B.ev) && (A.isupper==B.isupper)
226226

227-
SpecialMatrix = Union{Bidiagonal, SymTridiagonal, Tridiagonal, AbstractTriangular}
228-
*(A::SpecialMatrix, B::SpecialMatrix)=full(A)*full(B)
227+
228+
# Union aliases for the banded matrix types served by the specialised kernels below.
# They are declared `const` so that the bindings are not reassignable, untyped globals.
const BiTriSym = Union{Bidiagonal, Tridiagonal, SymTridiagonal}
const BiTri = Union{Bidiagonal, Tridiagonal}

# Three-argument products `C = A*B` in which at least one operand is banded are all
# forwarded to `A_mul_B_td!`.
# NOTE(review): several of these signatures overlap (the SymTridiagonal and BiTri methods
# are both covered by the BiTriSym one). Presumably the narrower methods exist to resolve
# dispatch ambiguities with methods defined elsewhere -- confirm before removing any.
A_mul_B!(C::AbstractMatrix, A::SymTridiagonal, B::BiTriSym) = A_mul_B_td!(C, A, B)
A_mul_B!(C::AbstractMatrix, A::BiTri, B::BiTriSym) = A_mul_B_td!(C, A, B)
A_mul_B!(C::AbstractMatrix, A::BiTriSym, B::BiTriSym) = A_mul_B_td!(C, A, B)
A_mul_B!(C::AbstractMatrix, A::AbstractTriangular, B::BiTriSym) = A_mul_B_td!(C, A, B)
A_mul_B!(C::AbstractMatrix, A::AbstractMatrix, B::BiTriSym) = A_mul_B_td!(C, A, B)
A_mul_B!(C::AbstractVector, A::BiTri, B::AbstractVector) = A_mul_B_td!(C, A, B)
A_mul_B!(C::AbstractMatrix, A::BiTri, B::AbstractVecOrMat) = A_mul_B_td!(C, A, B)
A_mul_B!(C::AbstractVecOrMat, A::BiTri, B::AbstractVecOrMat) = A_mul_B_td!(C, A, B)
238+
239+
240+
# Validate the shapes of a three-argument matrix product `C = A*B`.
#
# All three arguments must be two-dimensional (each `size` is destructured into two
# values). Throws a `DimensionMismatch` naming the offending pair when
#   * the row counts of `A` and `C` differ,
#   * the column count of `A` differs from the row count of `B`, or
#   * the column counts of `B` and `C` differ.
# Returns `nothing` when the shapes are compatible.
function check_A_mul_B!_sizes(C, A, B)
    nA, mA = size(A)
    nB, mB = size(B)
    nC, mC = size(C)
    if nA != nC
        throw(DimensionMismatch("Sizes size(A)=$(size(A)) and size(C) = $(size(C)) must match at first entry."))
    elseif mA != nB
        throw(DimensionMismatch("Second entry of size(A)=$(size(A)) and first entry of size(B) = $(size(B)) must match."))
    elseif mB != mC
        # This branch compares the second entries (column counts) of B and C.
        throw(DimensionMismatch("Sizes size(B)=$(size(B)) and size(C) = $(size(C)) must match at second entry."))
    end
end
252+
253+
# Compute `C = A*B` in place for two banded (bi-/tri-diagonal) operands and return `C`.
#
# Shapes are validated with `check_A_mul_B!_sizes`. Orders up to 3 are delegated to the
# dense product of `full(A)` and `full(B)`. For larger orders `C` is zero-filled and only
# the five central diagonals (offsets -2 to 2) are written, row by row.
function A_mul_B_td!(C::AbstractMatrix, A::BiTriSym, B::BiTriSym)
    check_A_mul_B!_sizes(C, A, B)
    n = size(A,1)
    if n <= 3
        return A_mul_B!(C, full(A), full(B))
    end
    fill!(C, zero(eltype(C)))
    # sub-, main and super-diagonal of each operand
    a_lo, a_di, a_up = diag(A, -1), diag(A, 0), diag(A, 1)
    b_lo, b_di, b_up = diag(B, -1), diag(B, 0), diag(B, 1)
    @inbounds begin
        # Rows 1 and 2: the band is cut off on the left.
        C[1,1] = A[1,1]*B[1,1] + A[1, 2]*B[2, 1]
        C[1,2] = A[1,1]*B[1,2] + A[1,2]*B[2,2]
        C[1,3] = A[1,2]*B[2,3]

        C[2,1] = A[2,1]*B[1,1] + A[2,2]*B[2,1]
        C[2,2] = A[2,1]*B[1,2] + A[2,2]*B[2,2] + A[2,3]*B[3,2]
        C[2,3] = A[2,2]*B[2,3] + A[2,3]*B[3,3]
        C[2,4] = A[2,3]*B[3,4]

        # Interior rows: all five entries of the band are present.
        for i in 3:n-2
            a_w = a_lo[i-1]   # A[i, i-1]
            a_c = a_di[i]     # A[i, i]
            a_e = a_up[i]     # A[i, i+1]
            C[i, i-2] = a_w*b_lo[i-2]
            C[i, i-1] = a_w*b_di[i-1] + a_c*b_lo[i-1]
            C[i, i  ] = a_w*b_up[i-1] + a_c*b_di[i] + a_e*b_lo[i]
            C[i, i+1] = a_c*b_up[i] + a_e*b_di[i+1]
            C[i, i+2] = a_e*b_up[i+1]
        end

        # Rows n-1 and n: the band is cut off on the right.
        C[n-1,n-3] = A[n-1,n-2]*B[n-2,n-3]
        C[n-1,n-2] = A[n-1,n-1]*B[n-1,n-2] + A[n-1,n-2]*B[n-2,n-2]
        C[n-1,n-1] = A[n-1,n-2]*B[n-2,n-1] + A[n-1,n-1]*B[n-1,n-1] + A[n-1,n]*B[n,n-1]
        C[n-1,n  ] = A[n-1,n-1]*B[n-1,n  ] + A[n-1, n]*B[n  ,n  ]

        C[n,n-2] = A[n,n-1]*B[n-1,n-2]
        C[n,n-1] = A[n,n-1]*B[n-1,n-1] + A[n,n]*B[n,n-1]
        C[n,n  ] = A[n,n-1]*B[n-1,n  ] + A[n,n]*B[n,n  ]
    end
    return C
end
305+
306+
# Compute `C = A*B` in place for a banded `A` and a vector or matrix `B`; returns `C`.
#
# Throws `DimensionMismatch` unless `size(C,1) == size(B,1) == size(A,1)` and
# `size(C,2) == size(B,2)`. Orders up to 3 are delegated to the product with `full(A)`.
# Otherwise each column of `B` is swept once, keeping a three-element window of it.
function A_mul_B_td!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat)
    nA = size(A,1)
    nB = size(B,2)
    if !(size(C,1) == size(B,1) == nA)
        throw(DimensionMismatch("A has first dimension $nA, B has $(size(B,1)), C has $(size(C,1)) but all must match"))
    end
    if size(C,2) != nB
        # Only the second dimensions of B and C are compared here, so the message
        # reports exactly those two values.
        throw(DimensionMismatch("B has second dimension $nB, C has $(size(C,2)) but both must match"))
    end
    nA <= 3 && return A_mul_B!(C, full(A), full(B))
    l = diag(A, -1)
    d = diag(A, 0)
    u = diag(A, 1)
    @inbounds begin
        for j = 1:nB
            # b₋, b₀, b₊ hold B[i-1, j], B[i, j], B[i+1, j] while row i is computed.
            b₀, b₊ = B[1, j], B[2, j]
            C[1, j] = d[1]*b₀ + u[1]*b₊
            for i = 2:nA - 1
                b₋, b₀, b₊ = b₀, b₊, B[i + 1, j]
                C[i, j] = l[i - 1]*b₋ + d[i]*b₀ + u[i]*b₊
            end
            # After the loop b₀ is B[nA-1, j] and b₊ is B[nA, j].
            C[nA, j] = l[nA - 1]*b₀ + d[nA]*b₊
        end
    end
    C
end
332+
333+
# Compute `C = A*B` in place for a general matrix `A` and a banded `B`; returns `C`.
#
# Shapes are validated with `check_A_mul_B!_sizes`. Every column of `C` is a combination
# of at most three neighbouring columns of `A`, weighted by one column of the band of `B`.
function A_mul_B_td!(C::AbstractMatrix, A::AbstractMatrix, B::BiTriSym)
    check_A_mul_B!_sizes(C, A, B)
    n = size(A,1)
    m = size(B,2)
    # The dense fallback must be decided by the order of B, not by the row count of A.
    # The kernel below reads Bl[1], Bu[m-1] and A[i, 2] without bounds checks, which is
    # only valid when B has at least two columns. `A` may have any number of rows.
    m <= 3 && return A_mul_B!(C, full(A), full(B))
    Bl = diag(B, -1)
    Bd = diag(B, 0)
    Bu = diag(B, 1)
    @inbounds begin
        # first and last column of C
        B11 = Bd[1]
        B21 = Bl[1]
        Bmm = Bd[m]
        Bm₋1m = Bu[m-1]
        for i in 1:n
            C[i, 1] = A[i,1] * B11 + A[i, 2] * B21
            C[i, m] = A[i, m-1] * Bm₋1m + A[i, m] * Bmm
        end
        # middle columns of C
        for j = 2:m-1
            Bj₋1j = Bu[j-1]
            Bjj = Bd[j]
            Bj₊1j = Bl[j]
            for i = 1:n
                C[i, j] = A[i, j-1] * Bj₋1j + A[i, j]*Bjj + A[i, j+1] * Bj₊1j
            end
        end
    end # inbounds
    C
end
229363

230364
#Generic multiplication
231365
for func in (:*, :Ac_mul_B, :A_mul_Bc, :/, :A_rdiv_Bc)
@@ -329,3 +463,39 @@ function eigvecs{T}(M::Bidiagonal{T})
329463
Q #Actually Triangular
330464
end
331465
eigfact(M::Bidiagonal) = Eigen(eigvals(M), eigvecs(M))
466+
467+
# fill! methods
468+
# Names of the fields that store the assignable entries of each structured matrix type.
# The argument is the type itself, not an instance.
function _valuefields{T <: Diagonal}(::Type{T})
    return [:diag]
end
function _valuefields{T <: Bidiagonal}(::Type{T})
    return [:dv, :ev]
end
function _valuefields{T <: Tridiagonal}(::Type{T})
    return [:dl, :d, :du]
end
function _valuefields{T <: SymTridiagonal}(::Type{T})
    return [:dv, :ev]
end
function _valuefields{T <: AbstractTriangular}(::Type{T})
    return [:data]
end
473+
474+
# Structured matrix types whose stored fields can be filled by `fillslots!`.
# Declared `const` so that the alias is not a reassignable, untyped global.
const SpecialArrays = Union{Diagonal,
                            Bidiagonal,
                            Tridiagonal,
                            SymTridiagonal,
                            AbstractTriangular}
479+
480+
# Fill every value field of `A` (as listed by `_valuefields`) with `x` and return `A`.
#
# Inside the generator `A` is bound to the argument's type, so the list of fields is
# resolved once per type. The generated code converts `x` to `eltype(A)` a single time
# and then calls `fill!` on each field.
@generated function fillslots!(A::SpecialArrays, x)
    stmts = Any[:(xT = convert(eltype(A), x))]
    for field in _valuefields(A)
        push!(stmts, :(fill!(A.$field, xT)))
    end
    push!(stmts, :(return A))
    Expr(:block, stmts...)
end
487+
488+
# for historical reasons:
489+
# Triangular and diagonal matrices accept any fill value: only their stored fields are
# written, via `fillslots!`.
function fill!(a::AbstractTriangular, x)
    return fillslots!(a, x)
end
function fill!(D::Diagonal, x)
    return fillslots!(D, x)
end
491+
492+
# Size thresholds used by `fill!` to decide whether an arbitrary fill value is accepted:
# order at most 1 for Bidiagonal, at most 2 for Tridiagonal and SymTridiagonal.
function _small_enough(A::Bidiagonal)
    return size(A, 1) <= 1
end
function _small_enough(A::Tridiagonal)
    return size(A, 1) <= 2
end
function _small_enough(A::SymTridiagonal)
    return size(A, 1) <= 2
end
495+
496+
# Fill a bi-/tri-diagonal matrix with `x` and return it.
#
# `x` is converted to `eltype(A)`. Filling is delegated to `fillslots!` when the converted
# value is zero or when `_small_enough(A)` holds. Otherwise an `ArgumentError` is thrown.
function fill!(A::Union{Bidiagonal, Tridiagonal, SymTridiagonal} ,x)
    xT = convert(eltype(A), x)
    (xT == zero(eltype(A)) || _small_enough(A)) && return fillslots!(A, xT)
    # The message is assembled from two pieces so that no raw line break from the source
    # layout ends up inside it.
    throw(ArgumentError(string("Array A of type $(typeof(A)) and size $(size(A)) can ",
                               "not be filled with x=$x, since some of its entries are constrained.")))
end

base/linalg/diagonal.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@ function size(D::Diagonal,d::Integer)
2929
return d<=2 ? length(D.diag) : 1
3030
end
3131

32-
fill!(D::Diagonal, x) = (fill!(D.diag, x); D)
33-
3432
full(D::Diagonal) = diagm(D.diag)
3533

3634
@inline function getindex(D::Diagonal, i::Int, j::Int)

base/linalg/triangular.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,6 @@ imag(A::UnitUpperTriangular) = UpperTriangular(triu!(imag(A.data),1))
5151
full(A::AbstractTriangular) = convert(Matrix, A)
5252
parent(A::AbstractTriangular) = A.data
5353

54-
fill!(A::AbstractTriangular, x) = (fill!(A.data, x); A)
55-
5654
# then handle all methods that requires specific handling of upper/lower and unit diagonal
5755

5856
function convert{Tret,T,S}(::Type{Matrix{Tret}}, A::LowerTriangular{T,S})
@@ -380,6 +378,8 @@ scale!(c::Number, A::Union{UpperTriangular,LowerTriangular}) = scale!(A,c)
380378
######################
381379

382380
A_mul_B!(A::Tridiagonal, B::AbstractTriangular) = A*full!(B)
381+
# Three-argument product with a triangular left operand and a Tridiagonal right operand:
# `A` is converted with `full` and the call is dispatched again on the converted operand.
A_mul_B!(C::AbstractMatrix, A::AbstractTriangular, B::Tridiagonal) = A_mul_B!(C, full(A), B)
382+
# Mirror case: here the triangular right operand `B` is the one converted with `full`.
A_mul_B!(C::AbstractMatrix, A::Tridiagonal, B::AbstractTriangular) = A_mul_B!(C, A, full(B))
383383
A_mul_B!(C::AbstractVector, A::AbstractTriangular, B::AbstractVector) = A_mul_B!(A, copy!(C, B))
384384
A_mul_B!(C::AbstractMatrix, A::AbstractTriangular, B::AbstractVecOrMat) = A_mul_B!(A, copy!(C, B))
385385
A_mul_B!(C::AbstractVecOrMat, A::AbstractTriangular, B::AbstractVecOrMat) = A_mul_B!(A, copy!(C, B))

base/linalg/tridiag.jl

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -497,33 +497,3 @@ function convert{T}(::Type{SymTridiagonal{T}}, M::Tridiagonal)
497497
throw(ArgumentError("Tridiagonal is not symmetric, cannot convert to SymTridiagonal"))
498498
end
499499
end
500-
501-
A_mul_B!(C::AbstractVector, A::Tridiagonal, B::AbstractVector) = A_mul_B_td!(C, A, B)
502-
A_mul_B!(C::AbstractMatrix, A::Tridiagonal, B::AbstractVecOrMat) = A_mul_B_td!(C, A, B)
503-
A_mul_B!(C::AbstractVecOrMat, A::Tridiagonal, B::AbstractVecOrMat) = A_mul_B_td!(C, A, B)
504-
505-
function A_mul_B_td!(C::AbstractVecOrMat, A::Tridiagonal, B::AbstractVecOrMat)
506-
nA = size(A,1)
507-
nB = size(B,2)
508-
if !(size(C,1) == size(B,1) == nA)
509-
throw(DimensionMismatch("A has first dimension $nA, B has $(size(B,1)), C has $(size(C,1)) but all must match"))
510-
end
511-
if size(C,2) != nB
512-
throw(DimensionMismatch("A has second dimension $nA, B has $(size(B,2)), C has $(size(C,2)) but all must match"))
513-
end
514-
l = A.dl
515-
d = A.d
516-
u = A.du
517-
@inbounds begin
518-
for j = 1:nB
519-
b₀, b₊ = B[1, j], B[2, j]
520-
C[1, j] = d[1]*b₀ + u[1]*b₊
521-
for i = 2:nA - 1
522-
b₋, b₀, b₊ = b₀, b₊, B[i + 1, j]
523-
C[i, j] = l[i - 1]*b₋ + d[i]*b₀ + u[i]*b₊
524-
end
525-
C[nA, j] = l[nA - 1]*b₀ + d[nA]*b₊
526-
end
527-
end
528-
C
529-
end

base/sparse/sparsevector.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1680,3 +1680,43 @@ droptol!(x::SparseVector, tol, trim::Bool = true) = fkeep!(x, (i, x) -> abs(x) >
16801680

16811681
dropzeros!(x::SparseVector, trim::Bool = true) = fkeep!(x, (i, x) -> x != 0, trim)
16821682
dropzeros(x::SparseVector, trim::Bool = true) = dropzeros!(copy(x), trim)
1683+
1684+
# Overwrite `arr` so that every one of its `m*n` entries is stored and equal to `val`;
# returns `arr`.
#
# The index and value buffers are resized to `m*n`, column `j` starts at position
# `(j-1)*m + 1`, and the row indices `1:m` are repeated for each column.
function _fillnonzero!(arr::SparseMatrixCSC, val)
    m, n = size(arr)
    Ti = eltype(arr.rowval)
    resize!(arr.colptr, n+1)
    resize!(arr.rowval, m*n)
    resize!(arr.nzval, m*n)
    # The column pointers are written with an explicit loop: the range `1:m:n*m+1` has a
    # zero step when the matrix has no rows (m == 0) and cannot be constructed.
    @inbounds for j in 1:n+1
        arr.colptr[j] = (j - 1)*m + 1
    end
    fill!(arr.nzval, val)
    index = 1
    @inbounds for _ in 1:n
        for i in 1:m
            arr.rowval[index] = Ti(i)
            index += 1
        end
    end
    arr
end
1700+
1701+
# Overwrite `arr` so that all `arr.n` entries are stored and equal to `val`; returns `arr`.
function _fillnonzero!(arr::SparseVector, val)
    len = arr.n
    resize!(arr.nzind, len)
    resize!(arr.nzval, len)
    Ti = eltype(arr.nzind)
    # every position 1:len becomes a stored index
    k = 0
    @inbounds while k < len
        k += 1
        arr.nzind[k] = Ti(k)
    end
    fill!(arr.nzval, val)
    return arr
end
1711+
1712+
import Base.fill!
1713+
# Fill a sparse vector or matrix with `x` (converted to `eltype(A)`) and return `A`.
#
# A zero value only overwrites the entries that are already stored. Any other value is
# handed to `_fillnonzero!`.
function fill!(A::Union{SparseVector, SparseMatrixCSC}, x)
    T = eltype(A)
    xT = convert(T, x)
    xT == zero(T) ? fill!(A.nzval, xT) : _fillnonzero!(A, xT)
    return A
end

test/linalg/bidiag.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,3 +236,49 @@ C = Tridiagonal(rand(Float64,9),rand(Float64,10),rand(Float64,9))
236236
@test promote_rule(Matrix{Float64}, Bidiagonal{Float64}) == Matrix{Float64}
237237
@test promote(B,A) == (B,convert(Matrix{Float64},full(A)))
238238
@test promote(C,A) == (C,Tridiagonal(zeros(Float64,9),convert(Vector{Float64},A.dv),convert(Vector{Float64},A.ev)))
239+
240+
import Base.LinAlg: fillslots!, UnitLowerTriangular
241+
let # fill!
    # fillslots! writes the value into every stored field of the structured matrix
    let
        tri = Tridiagonal(randn(2), randn(3), randn(2))
        @test fillslots!(tri, 3) == Tridiagonal([3.0, 3.0], [3.0, 3.0, 3.0], [3.0, 3.0])
        bi = Bidiagonal(randn(3), randn(2), true)
        @test fillslots!(bi, 2) == Bidiagonal([2.0, 2.0, 2.0], [2.0, 2.0], true)
        sym = SymTridiagonal(randn(3), randn(2))
        @test fillslots!(sym, 1) == SymTridiagonal([1.0, 1.0, 1.0], [1.0, 1.0])
        # the expected matrix keeps ones on the diagonal of the unit triangular wrapper
        ult = UnitLowerTriangular(randn(3,3))
        @test fillslots!(ult, 3) == UnitLowerTriangular([1 0 0; 3 1 0; 3 3 1])
    end

    # fill!(A, 0) is accepted by every structured or sparse array
    let
        exotic_arrays = Any[Tridiagonal(randn(3), randn(4), randn(3)),
                            Bidiagonal(randn(3), randn(2), rand(Bool)),
                            SymTridiagonal(randn(3), randn(2)),
                            sparse(randn(3,4)),
                            Diagonal(randn(5)),
                            sparse(rand(3)),
                            LowerTriangular(randn(3,3)),
                            UpperTriangular(randn(3,3))
                            ]
        for arr in exotic_arrays
            fill!(arr, 0)
            for entry in arr
                @test entry == 0
            end
        end
    end

    # fill!(A, x) with an arbitrary x: accepted for 1x1 matrices, rejected for larger ones
    let
        val = randn()
        b = Bidiagonal(randn(1,1), true)
        st = SymTridiagonal(randn(1,1))
        for x in (b, st)
            @test full(fill!(x, val)) == fill!(full(x), val)
        end
        b = Bidiagonal(randn(2,2), true)
        st = SymTridiagonal(randn(3), randn(2))
        t = Tridiagonal(randn(3,3))
        for x in (b, t, st)
            @test_throws ArgumentError fill!(x, val)
            @test full(fill!(x, 0)) == fill!(full(x), 0)
        end
    end
end

test/linalg/matmul.jl

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,3 +332,59 @@ a = [RootInt(2),RootInt(10)]
332332
@test a*a' == [4 20; 20 100]
333333
A = [RootInt(3) RootInt(5)]
334334
@test A*a == [56]
335+
336+
# Run the in-place product `A_mul_B!(C, A, B)` and check the result in `C` against both
# the dense product of `full(A)` and `full(B)` and the out-of-place product `A*B`.
function test_mul(C, A, B)
    A_mul_B!(C, A, B)
    # The comparison operator had been lost from both checks; `≈` compares approximately.
    @test full(A) * full(B) ≈ C
    @test A*B ≈ C
end
341+
342+
let
    eltypes = [Float32, Float64, Int64]

    # Products involving banded matrices of order k, with a randomly drawn element type.
    for k in [3, 4, 10]
        T = rand(eltypes)
        bi1 = Bidiagonal(rand(T, k), rand(T, k-1), rand(Bool))
        bi2 = Bidiagonal(rand(T, k), rand(T, k-1), rand(Bool))
        tri1 = Tridiagonal(rand(T,k-1), rand(T, k), rand(T, k-1))
        tri2 = Tridiagonal(rand(T,k-1), rand(T, k), rand(T, k-1))
        stri1 = SymTridiagonal(rand(T, k), rand(T, k-1))
        stri2 = SymTridiagonal(rand(T, k), rand(T, k-1))
        C = rand(T, k, k)
        specialmatrices = (bi1, bi2, tri1, tri2, stri1, stri2)

        # banded * banded, right factor drawn at random
        for A in specialmatrices
            B = specialmatrices[rand(1:length(specialmatrices))]
            test_mul(C, A, B)
        end

        # banded * dense and dense * banded, with a random second dimension
        for S in specialmatrices
            l = rand(1:6)
            B = randn(k, l)
            C = randn(k, l)
            test_mul(C, S, B)
            A = randn(l, k)
            C = randn(l, k)
            test_mul(C, A, S)
        end
    end

    # 2x2 bidiagonal factors
    for T in eltypes
        A = Bidiagonal(rand(T, 2), rand(T, 1), rand(Bool))
        B = Bidiagonal(rand(T, 2), rand(T, 1), rand(Bool))
        C = randn(2,2)
        test_mul(C, A, B)
        B = randn(2, 9)
        C = randn(2, 9)
        test_mul(C, A, B)
    end

    # mismatched shapes must be rejected
    let
        tri44 = Tridiagonal(randn(3), randn(4), randn(3))
        tri33 = Tridiagonal(randn(2), randn(3), randn(2))
        full43 = randn(4, 3)
        full24 = randn(2, 4)
        full33 = randn(3, 3)
        full44 = randn(4, 4)
        @test_throws DimensionMismatch A_mul_B!(full43, tri44, tri33)
        @test_throws DimensionMismatch A_mul_B!(full44, tri44, tri33)
        @test_throws DimensionMismatch A_mul_B!(full44, tri44, full43)
        @test_throws DimensionMismatch A_mul_B!(full43, tri33, full43)
        @test_throws DimensionMismatch A_mul_B!(full43, full43, tri44)
    end
end

0 commit comments

Comments
 (0)