Skip to content

Commit 5d3d02a

Browse files
authored
Branch on Bool alpha in bidiag matmul (#1257)
Similar to #1256, we may reduce branches in `@stable_muladdmul` for a non-zero `Bool` alpha in `Bidiagonal` matmul. The idea is that if `alpha::Bool` is non-zero, it must be `true`. We may therefore hardcode this value and reduce branches in `@stable_muladdmul`. In addition, if `beta` is unused in a method, we may hardcode `beta = false` as well, which further helps with compilation.

```julia
julia> using LinearAlgebra

julia> B = Bidiagonal(1:4, 1:3, :U); D = Diagonal(1:4); v = (1:4);
```

With this,

```julia
julia> @time B * B;
  0.406036 seconds (2.34 M allocations: 105.696 MiB, 4.31% gc time, 99.99% compilation time) # nightly
  0.141480 seconds (588.18 k allocations: 26.049 MiB, 99.95% compilation time) # This PR
```

The rest are mainly reductions in allocation:

```julia
julia> @time B * D;
  0.141903 seconds (487.43 k allocations: 24.557 MiB, 99.96% compilation time) # nightly
  0.147749 seconds (382.60 k allocations: 19.385 MiB, 99.96% compilation time) # this PR
```

```julia
julia> @time D * B;
  0.136308 seconds (491.83 k allocations: 24.782 MiB, 99.95% compilation time) # nightly
  0.136909 seconds (386.35 k allocations: 19.591 MiB, 99.94% compilation time) # this PR
```

```julia
julia> @time B * v;
  0.087207 seconds (428.64 k allocations: 21.620 MiB, 99.93% compilation time) # master
  0.089002 seconds (342.46 k allocations: 17.306 MiB, 99.92% compilation time) # this PR
```

This also improves performance for small `Bidiagonal` multiplication, as there are fewer operations to carry out.
```julia
julia> n = 1; T = Bidiagonal(ones(n), ones(max(n-1,0)), :U); C = Matrix(T);

julia> @btime mul!($C, $T, $T);
  33.360 ns (0 allocations: 0 bytes) # nightly
  23.773 ns (0 allocations: 0 bytes) # this PR

julia> n = 2; T = Bidiagonal(ones(n), ones(max(n-1,0)), :U); C = Matrix(T);

julia> @btime mul!($C, $T, $T);
  78.685 ns (0 allocations: 0 bytes) # nightly
  31.388 ns (0 allocations: 0 bytes) # this PR

julia> n = 3; T = Bidiagonal(ones(n), ones(max(n-1,0)), :U); C = Matrix(T);

julia> @btime mul!($C, $T, $T);
  161.577 ns (0 allocations: 0 bytes) # nightly
  41.256 ns (0 allocations: 0 bytes) # this PR
```
1 parent e4e8c19 commit 5d3d02a

File tree

1 file changed

+123
-52
lines changed

1 file changed

+123
-52
lines changed

src/bidiag.jl

Lines changed: 123 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -613,6 +613,27 @@ function _diag(A::Bidiagonal, k)
613613
end
614614
end
615615

"""
    _MulAddMul_nonzeroalpha(_add::MulAddMul[, ::Val{false}])

Specialize `_add` for a call site where `alpha` is already known to be non-zero.

When `_add.alpha isa Bool`, a non-zero `alpha` can only be `true`, so the returned
`MulAddMul` carries the literal `alpha = true`. For any other `alpha` type, the
single-argument form hands `_add` back untouched.

Passing `Val(false)` as the second argument additionally replaces `beta` by the
literal `false` in the returned `MulAddMul`.

None of the methods inspect the value of `alpha`; ruling out a zero `alpha` is the
caller's responsibility.
"""
_MulAddMul_nonzeroalpha(_add::MulAddMul) = _add

# Generic `alpha`: keep it as it is, only pin `beta` to `false`.
_MulAddMul_nonzeroalpha(_add::MulAddMul{ais1,bis0,A}, ::Val{false}) where {ais1,bis0,A} =
    MulAddMul{ais1,true,A,Bool}(_add.alpha, false)

# `Bool` alpha: pin `alpha` to `true`, carry `beta` over with its own type.
_MulAddMul_nonzeroalpha(_add::MulAddMul{ais1,bis0,Bool}) where {ais1,bis0} =
    MulAddMul{true,bis0,Bool,typeof(_add.beta)}(true, _add.beta)

# `Bool` alpha and `beta` pinned to `false`: nothing of `_add` is carried over.
function _MulAddMul_nonzeroalpha(_add::MulAddMul{ais1,bis0,Bool}, ::Val{false}) where {ais1,bis0}
    return MulAddMul()
end
636+
616637
# A `BiTriSym` matrix times a `TriSym` matrix is forwarded unchanged to `_bibimul!`.
function _mul!(C::AbstractMatrix, A::BiTriSym, B::TriSym, _add::MulAddMul)
    return _bibimul!(C, A, B, _add)
end
618639
_mul!(C::AbstractMatrix, A::BiTriSym, B::Bidiagonal, _add::MulAddMul) =
@@ -626,36 +647,54 @@ function _bibimul!(C, A, B, _add)
626647
# `_modify!` in the following loop will not update the
627648
# off-diagonal elements for non-zero beta.
628649
_rmul_or_fill!(C, _add.beta)
629-
_iszero_alpha(_add) && return C
630-
if n <= 3
650+
iszero(_add.alpha) && return C
651+
# beta is unused in _bibimul_nonzeroalpha!, so we set it to false
652+
_add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add, Val(false))
653+
_bibimul_nonzeroalpha!(C, A, B, _add_nonzeroalpha)
654+
C
655+
end
656+
"""
    _bibimul_nonzeroalpha!(C, A, B, _add)

Add the contributions of the product `A * B` to `C` in place and return `C`.

Only entries of `A` and `B` on the main diagonal and on the first sub- and
superdiagonal are read, and every update has the form `C[i,j] += _add(x)`.
The caller `_bibimul!` applies `_rmul_or_fill!(C, _add.beta)` and returns early for a
zero `alpha` before dispatching here.

The first two and the last two columns of `C` are written out explicitly, guarded by
`n = size(A, 1)` so that no column is updated twice for small `n`; the columns
`3:n-2` are delegated to `__bibimul!`.

NOTE(review): all indexing is `@inbounds`, so this assumes `n >= 1` and mutually
consistent sizes of `C`, `A` and `B` — presumably guaranteed by the caller; confirm.
"""
function _bibimul_nonzeroalpha!(C, A, B, _add)
    n = size(A, 1)
    if n == 1
        # a 1×1 product is a single scalar multiplication
        @inbounds C[1,1] += _add(A[1,1] * B[1,1])
        return C
    end
    atleast3 = n >= 3
    atleast4 = n >= 4
    @inbounds begin
        # column 1
        C[1,1] += _add(A[1,1]*B[1,1] + A[1, 2]*B[2,1])
        C[2,1] += _add(A[2,1]*B[1,1] + A[2,2]*B[2,1])
        atleast3 && (C[3,1] += _add(A[3,2]*B[2,1]))
        # column 2
        C[1,2] += _add(A[1,1]*B[1,2] + A[1,2]*B[2,2])
        # the part of C[2,2] that exists for every n >= 2
        partial22 = A[2,1]*B[1,2] + A[2,2]*B[2,2]
        if !atleast3
            C[2,2] += _add(partial22)
        else
            C[2,2] += _add(partial22 + A[2,3]*B[3,2])
            C[3,2] += _add(A[3,2]*B[2,2] + A[3,3]*B[3,2])
            atleast4 && (C[4,2] += _add(A[4,3]*B[3,2]))
        end
    end
    # columns 3:n-2
    __bibimul!(C, A, B, _add)
    @inbounds begin
        if atleast4
            # column n-1 (for n == 3 this is column 2, which is already done)
            C[n-3,n-1] += _add(A[n-3,n-2]*B[n-2,n-1])
            C[n-2,n-1] += _add(A[n-2,n-2]*B[n-2,n-1] + A[n-2,n-1]*B[n-1,n-1])
            C[n-1,n-1] += _add(A[n-1,n-2]*B[n-2,n-1] + A[n-1,n-1]*B[n-1,n-1] + A[n-1,n]*B[n,n-1])
            C[n,  n-1] += _add(A[n,n-1]*B[n-1,n-1] + A[n,n]*B[n,n-1])
        end
        if atleast3
            # column n (for n == 2 this is column 2, which is already done)
            C[n-2, n] += _add(A[n-2,n-1]*B[n-1,n])
            C[n-1, n] += _add(A[n-1,n-1]*B[n-1,n ] + A[n-1,n]*B[n,n ])
            C[n,   n] += _add(A[n,n-1]*B[n-1,n ] + A[n,n]*B[n,n ])
        end
    end
    return C
end
@@ -696,9 +735,9 @@ function __bibimul!(C, A, B::Bidiagonal, _add)
696735
Al = _diag(A, -1)
697736
Ad = _diag(A, 0)
698737
Au = _diag(A, 1)
699-
Bd = _diag(B, 0)
738+
Bd = B.dv
700739
if B.uplo == 'U'
701-
Bu = _diag(B, 1)
740+
Bu = B.ev
702741
@inbounds begin
703742
for j in 3:n-2
704743
Aj₋2j₋1 = Au[j-2]
@@ -717,7 +756,7 @@ function __bibimul!(C, A, B::Bidiagonal, _add)
717756
end
718757
end
719758
else # B.uplo == 'L'
720-
Bl = _diag(B, -1)
759+
Bl = B.ev
721760
@inbounds begin
722761
for j in 3:n-2
723762
Aj₋1j = Au[j-1]
@@ -743,9 +782,9 @@ function __bibimul!(C, A::Bidiagonal, B, _add)
743782
Bl = _diag(B, -1)
744783
Bd = _diag(B, 0)
745784
Bu = _diag(B, 1)
746-
Ad = _diag(A, 0)
785+
Ad = A.dv
747786
if A.uplo == 'U'
748-
Au = _diag(A, 1)
787+
Au = A.ev
749788
@inbounds begin
750789
for j in 3:n-2
751790
Aj₋2j₋1 = Au[j-2]
@@ -765,7 +804,7 @@ function __bibimul!(C, A::Bidiagonal, B, _add)
765804
end
766805
end
767806
else # A.uplo == 'L'
768-
Al = _diag(A, -1)
807+
Al = A.ev
769808
@inbounds begin
770809
for j in 3:n-2
771810
Aj₋1j₋1 = Ad[j-1]
@@ -789,11 +828,11 @@ function __bibimul!(C, A::Bidiagonal, B, _add)
789828
end
790829
function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add)
791830
n = size(A,1)
792-
Ad = _diag(A, 0)
793-
Bd = _diag(B, 0)
831+
Ad = A.dv
832+
Bd = B.dv
794833
if A.uplo == 'U' && B.uplo == 'U'
795-
Au = _diag(A, 1)
796-
Bu = _diag(B, 1)
834+
Au = A.ev
835+
Bu = B.ev
797836
@inbounds begin
798837
for j in 3:n-2
799838
Aj₋2j₋1 = Au[j-2]
@@ -809,8 +848,8 @@ function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add)
809848
end
810849
end
811850
elseif A.uplo == 'U' && B.uplo == 'L'
812-
Au = _diag(A, 1)
813-
Bl = _diag(B, -1)
851+
Au = A.ev
852+
Bl = B.ev
814853
@inbounds begin
815854
for j in 3:n-2
816855
Aj₋1j = Au[j-1]
@@ -826,8 +865,8 @@ function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add)
826865
end
827866
end
828867
elseif A.uplo == 'L' && B.uplo == 'U'
829-
Al = _diag(A, -1)
830-
Bu = _diag(B, 1)
868+
Al = A.ev
869+
Bu = B.ev
831870
@inbounds begin
832871
for j in 3:n-2
833872
Aj₋1j₋1 = Ad[j-1]
@@ -843,8 +882,8 @@ function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add)
843882
end
844883
end
845884
else # A.uplo == 'L' && B.uplo == 'L'
846-
Al = _diag(A, -1)
847-
Bl = _diag(B, -1)
885+
Al = A.ev
886+
Bl = B.ev
848887
@inbounds begin
849888
for j in 3:n-2
850889
Ajj = Ad[j]
@@ -863,15 +902,20 @@ function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add)
863902
C
864903
end
865904

866-
_mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, alpha::Number, beta::Number) =
867-
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
868905
# `BiTriSym` times `Diagonal` into a generic matrix destination.
# This method validates the arguments and handles the trivial cases (empty `A`,
# zero `alpha`); the element-wise accumulation is done by `_bidimul!`.
function _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul)
    require_one_based_indexing(C)
    matmul_size_check(size(C), size(A), size(B))
    if iszero(size(A, 1))
        return C
    end
    # `beta` is applied to the whole of `C` up front
    _rmul_or_fill!(C, _add.beta)
    if iszero(_add.alpha)
        return C
    end
    # `beta` has been consumed above and is unused in `_bidimul!`, so it is set to
    # `false` for the kernel call
    kernel_add = _MulAddMul_nonzeroalpha(_add, Val(false))
    _bidimul!(C, A, B, kernel_add)
    return C
end
917+
function _bidimul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul)
918+
n = size(A,1)
875919
Al = _diag(A, -1)
876920
Ad = _diag(A, 0)
877921
Au = _diag(A, 1)
@@ -907,14 +951,8 @@ function _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul)
907951
end # inbounds
908952
C
909953
end
910-
911-
function _mul!(C::AbstractMatrix, A::Bidiagonal, B::Diagonal, _add::MulAddMul)
912-
require_one_based_indexing(C)
913-
matmul_size_check(size(C), size(A), size(B))
954+
function _bidimul!(C::AbstractMatrix, A::Bidiagonal, B::Diagonal, _add::MulAddMul)
914955
n = size(A,1)
915-
iszero(n) && return C
916-
_rmul_or_fill!(C, _add.beta) # see the same use above
917-
_iszero_alpha(_add) && return C
918956
(; dv, ev) = A
919957
Bd = B.diag
920958
rowshift = A.uplo == 'U' ? -1 : 1
@@ -943,7 +981,13 @@ function _mul!(C::Bidiagonal, A::Bidiagonal, B::Diagonal, _add::MulAddMul)
943981
matmul_size_check(size(C), size(A), size(B))
944982
n = size(A,1)
945983
iszero(n) && return C
946-
_iszero_alpha(_add) && return _rmul_or_fill!(C, _add.beta)
984+
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
985+
_add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add)
986+
_bidimul!(C, A, B, _add_nonzeroalpha)
987+
C
988+
end
989+
function _bidimul!(C::Bidiagonal, A::Bidiagonal, B::Diagonal, _add::MulAddMul)
990+
n = size(A,1)
947991
Adv, Aev = A.dv, A.ev
948992
Cdv, Cev = C.dv, C.ev
949993
Bd = B.diag
@@ -978,14 +1022,22 @@ function _mul!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat, _add::MulA
9781022
nB = size(B,2)
9791023
(iszero(nA) || iszero(nB)) && return C
9801024
_iszero_alpha(_add) && return _rmul_or_fill!(C, _add.beta)
1025+
_add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add)
1026+
_mul_bitrisym_left!(C, A, B, _add_nonzeroalpha)
1027+
return C
1028+
end
1029+
# Kernel for `A * B` with a `BiTriSym` matrix `A` on the left; returns `C`.
# The caller `_mul!` has already returned early for empty operands and for a zero
# `alpha` before dispatching here.
#
# A 1×1 `A` is handled inline by combining `A[1,1] * B[1,i]` into `C[1,i]` through
# `_modify!`; every other size is forwarded to `_mul_bitrisym!`.
#
# NOTE(review): the 1×1 branch indexes `B` and `C` under `@inbounds`, so it relies on
# the sizes having been validated earlier — confirm against the caller.
function _mul_bitrisym_left!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat, _add::MulAddMul)
    if size(A, 1) == 1
        A11 = @inbounds A[1,1]
        for i in axes(B, 2)
            @inbounds _modify!(_add, A11 * B[1,i], C, (1,i))
        end
    else
        _mul_bitrisym!(C, A, B, _add)
    end
    return C
end
9901042
function _mul_bitrisym!(C::AbstractVecOrMat, A::Bidiagonal, B::AbstractVecOrMat, _add::MulAddMul)
9911043
nA = size(A,1)
@@ -1046,6 +1098,13 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::TriSym, _add::MulAddMul)
10461098
n = size(A,1)
10471099
m = size(B,2)
10481100
(_iszero_alpha(_add) || iszero(m)) && return _rmul_or_fill!(C, _add.beta)
1101+
_add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add)
1102+
_mul_bitrisym_right!(C, A, B, _add_nonzeroalpha)
1103+
C
1104+
end
1105+
function _mul_bitrisym_right!(C::AbstractMatrix, A::AbstractMatrix, B::TriSym, _add::MulAddMul)
1106+
n = size(A,1)
1107+
m = size(B,2)
10491108
if m == 1
10501109
B11 = B[1,1]
10511110
return mul!(C, A, B11, _add.alpha, _add.beta)
@@ -1082,6 +1141,12 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::Bidiagonal, _add::MulAdd
10821141
m, n = size(A)
10831142
(iszero(m) || iszero(n)) && return C
10841143
_iszero_alpha(_add) && return _rmul_or_fill!(C, _add.beta)
1144+
_add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add)
1145+
_mul_bitrisym_right!(C, A, B, _add_nonzeroalpha)
1146+
C
1147+
end
1148+
function _mul_bitrisym_right!(C::AbstractMatrix, A::AbstractMatrix, B::Bidiagonal, _add::MulAddMul)
1149+
m, n = size(A)
10851150
@inbounds if B.uplo == 'U'
10861151
for j in n:-1:2, i in 1:m
10871152
_modify!(_add, A[i,j] * B.dv[j] + A[i,j-1] * B.ev[j-1], C, (i, j))
@@ -1114,6 +1179,13 @@ function _dibimul!(C, A, B, _add)
11141179
# ensure that we fill off-band elements in the destination
11151180
_rmul_or_fill!(C, _add.beta)
11161181
_iszero_alpha(_add) && return C
1182+
# beta is unused in the _dibimul_nonzeroalpha! call, so we set it to false
1183+
_add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add, Val(false))
1184+
_dibimul_nonzeroalpha!(C, A, B, _add_nonzeroalpha)
1185+
C
1186+
end
1187+
function _dibimul_nonzeroalpha!(C, A, B, _add)
1188+
n = size(A,1)
11171189
if n <= 3
11181190
# For simplicity, use a naive multiplication for small matrices
11191191
# that loops over all elements.
@@ -1150,14 +1222,8 @@ function _dibimul!(C, A, B, _add)
11501222
end # inbounds
11511223
C
11521224
end
1153-
function _dibimul!(C::AbstractMatrix, A::Diagonal, B::Bidiagonal, _add)
1154-
require_one_based_indexing(C)
1155-
matmul_size_check(size(C), size(A), size(B))
1225+
function _dibimul_nonzeroalpha!(C::AbstractMatrix, A::Diagonal, B::Bidiagonal, _add)
11561226
n = size(A,1)
1157-
iszero(n) && return C
1158-
# ensure that we fill off-band elements in the destination
1159-
_rmul_or_fill!(C, _add.beta)
1160-
_iszero_alpha(_add) && return C
11611227
Ad = A.diag
11621228
Bdv, Bev = B.dv, B.ev
11631229
rowshift = B.uplo == 'U' ? -1 : 1
@@ -1187,6 +1253,11 @@ function _dibimul!(C::Bidiagonal, A::Diagonal, B::Bidiagonal, _add)
11871253
n = size(A,1)
11881254
n == 0 && return C
11891255
_iszero_alpha(_add) && return _rmul_or_fill!(C, _add.beta)
1256+
_add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add)
1257+
_dibimul_nonzeroalpha!(C, A, B, _add_nonzeroalpha)
1258+
C
1259+
end
1260+
function _dibimul_nonzeroalpha!(C::Bidiagonal, A::Diagonal, B::Bidiagonal, _add)
11901261
Ad = A.diag
11911262
Bdv, Bev = B.dv, B.ev
11921263
Cdv, Cev = C.dv, C.ev

0 commit comments

Comments
 (0)