@@ -613,6 +613,27 @@ function _diag(A::Bidiagonal, k)
613
613
end
614
614
end
615
615
616
+ """
617
+ _MulAddMul_nonzeroalpha(_add::MulAddMul[, ::Val{false}])
618
+
619
+ Return a new `MulAddMul` with the value of `alpha` potentially set to a literal non-zero
620
+ value if permitted by the type (e.g., for `_add.alpha isa Bool`, in which case the `alpha` is
621
+ set to `true` in the returned instance).
622
+ In other cases, the single-argument call is a no-op and returns `_add` without modifications.
623
+
624
+ In addition, if `Val(false)` is provided as the second argument,
625
+ `beta` is set to `false` in the returned `MulAddMul` instance.
626
+ """
627
+ _MulAddMul_nonzeroalpha (_add:: MulAddMul ) = _add
628
+ function _MulAddMul_nonzeroalpha (_add:: MulAddMul{ais1,bis0,A} , :: Val{false} ) where {ais1,bis0,A}
629
+ MulAddMul {ais1,true,A,Bool} (_add. alpha, false )
630
+ end
631
+ function _MulAddMul_nonzeroalpha (_add:: MulAddMul{ais1,bis0,Bool} ) where {ais1,bis0}
632
+ (; beta) = _add
633
+ MulAddMul {true,bis0,Bool,typeof(beta)} (true , beta)
634
+ end
635
+ _MulAddMul_nonzeroalpha (_add:: MulAddMul{ais1,bis0,Bool} , :: Val{false} ) where {ais1,bis0} = MulAddMul ()
636
+
616
637
_mul! (C:: AbstractMatrix , A:: BiTriSym , B:: TriSym , _add:: MulAddMul ) =
617
638
_bibimul! (C, A, B, _add)
618
639
_mul! (C:: AbstractMatrix , A:: BiTriSym , B:: Bidiagonal , _add:: MulAddMul ) =
@@ -626,36 +647,54 @@ function _bibimul!(C, A, B, _add)
626
647
# `_modify!` in the following loop will not update the
627
648
# off-diagonal elements for non-zero beta.
628
649
_rmul_or_fill! (C, _add. beta)
629
- _iszero_alpha (_add) && return C
630
- if n <= 3
650
+ iszero (_add. alpha) && return C
651
+ # beta is unused in _bibimul_nonzeroalpha!, so we set it to false
652
+ _add_nonzeroalpha = _MulAddMul_nonzeroalpha (_add, Val (false ))
653
+ _bibimul_nonzeroalpha! (C, A, B, _add_nonzeroalpha)
654
+ C
655
+ end
656
+ function _bibimul_nonzeroalpha! (C, A, B, _add)
657
+ n = size (A,1 )
658
+ if n == 1
631
659
# naive multiplication
632
- for I in CartesianIndices (C)
633
- C[I] += _add (sum (A[I[1 ], k] * B[k, I[2 ]] for k in axes (A,2 )))
634
- end
660
+ @inbounds C[1 ,1 ] += _add (A[1 ,1 ] * B[1 ,1 ])
635
661
return C
636
662
end
637
663
@inbounds begin
638
664
# first column of C
639
665
C[1 ,1 ] += _add (A[1 ,1 ]* B[1 ,1 ] + A[1 , 2 ]* B[2 ,1 ])
640
666
C[2 ,1 ] += _add (A[2 ,1 ]* B[1 ,1 ] + A[2 ,2 ]* B[2 ,1 ])
641
- C[3 ,1 ] += _add (A[3 ,2 ]* B[2 ,1 ])
667
+ if n >= 3
668
+ C[3 ,1 ] += _add (A[3 ,2 ]* B[2 ,1 ])
669
+ end
642
670
# second column of C
643
671
C[1 ,2 ] += _add (A[1 ,1 ]* B[1 ,2 ] + A[1 ,2 ]* B[2 ,2 ])
644
- C[2 ,2 ] += _add (A[2 ,1 ]* B[1 ,2 ] + A[2 ,2 ]* B[2 ,2 ] + A[2 ,3 ]* B[3 ,2 ])
645
- C[3 ,2 ] += _add (A[3 ,2 ]* B[2 ,2 ] + A[3 ,3 ]* B[3 ,2 ])
646
- C[4 ,2 ] += _add (A[4 ,3 ]* B[3 ,2 ])
672
+ C22 = A[2 ,1 ]* B[1 ,2 ] + A[2 ,2 ]* B[2 ,2 ]
673
+ if n >= 3
674
+ C[2 ,2 ] += _add (C22 + A[2 ,3 ]* B[3 ,2 ])
675
+ C[3 ,2 ] += _add (A[3 ,2 ]* B[2 ,2 ] + A[3 ,3 ]* B[3 ,2 ])
676
+ if n >= 4
677
+ C[4 ,2 ] += _add (A[4 ,3 ]* B[3 ,2 ])
678
+ end
679
+ else
680
+ C[2 ,2 ] += _add (C22)
681
+ end
647
682
end # inbounds
648
683
# middle columns
649
684
__bibimul! (C, A, B, _add)
650
685
@inbounds begin
651
- C[n- 3 ,n- 1 ] += _add (A[n- 3 ,n- 2 ]* B[n- 2 ,n- 1 ])
652
- C[n- 2 ,n- 1 ] += _add (A[n- 2 ,n- 2 ]* B[n- 2 ,n- 1 ] + A[n- 2 ,n- 1 ]* B[n- 1 ,n- 1 ])
653
- C[n- 1 ,n- 1 ] += _add (A[n- 1 ,n- 2 ]* B[n- 2 ,n- 1 ] + A[n- 1 ,n- 1 ]* B[n- 1 ,n- 1 ] + A[n- 1 ,n]* B[n,n- 1 ])
654
- C[n, n- 1 ] += _add (A[n,n- 1 ]* B[n- 1 ,n- 1 ] + A[n,n]* B[n,n- 1 ])
686
+ if n >= 4
687
+ C[n- 3 ,n- 1 ] += _add (A[n- 3 ,n- 2 ]* B[n- 2 ,n- 1 ])
688
+ C[n- 2 ,n- 1 ] += _add (A[n- 2 ,n- 2 ]* B[n- 2 ,n- 1 ] + A[n- 2 ,n- 1 ]* B[n- 1 ,n- 1 ])
689
+ C[n- 1 ,n- 1 ] += _add (A[n- 1 ,n- 2 ]* B[n- 2 ,n- 1 ] + A[n- 1 ,n- 1 ]* B[n- 1 ,n- 1 ] + A[n- 1 ,n]* B[n,n- 1 ])
690
+ C[n, n- 1 ] += _add (A[n,n- 1 ]* B[n- 1 ,n- 1 ] + A[n,n]* B[n,n- 1 ])
691
+ end
655
692
# last column of C
656
- C[n- 2 , n] += _add (A[n- 2 ,n- 1 ]* B[n- 1 ,n])
657
- C[n- 1 , n] += _add (A[n- 1 ,n- 1 ]* B[n- 1 ,n ] + A[n- 1 ,n]* B[n,n ])
658
- C[n, n] += _add (A[n,n- 1 ]* B[n- 1 ,n ] + A[n,n]* B[n,n ])
693
+ if n >= 3
694
+ C[n- 2 , n] += _add (A[n- 2 ,n- 1 ]* B[n- 1 ,n])
695
+ C[n- 1 , n] += _add (A[n- 1 ,n- 1 ]* B[n- 1 ,n ] + A[n- 1 ,n]* B[n,n ])
696
+ C[n, n] += _add (A[n,n- 1 ]* B[n- 1 ,n ] + A[n,n]* B[n,n ])
697
+ end
659
698
end # inbounds
660
699
C
661
700
end
@@ -696,9 +735,9 @@ function __bibimul!(C, A, B::Bidiagonal, _add)
696
735
Al = _diag (A, - 1 )
697
736
Ad = _diag (A, 0 )
698
737
Au = _diag (A, 1 )
699
- Bd = _diag (B, 0 )
738
+ Bd = B . dv
700
739
if B. uplo == ' U'
701
- Bu = _diag (B, 1 )
740
+ Bu = B . ev
702
741
@inbounds begin
703
742
for j in 3 : n- 2
704
743
Aj₋2j₋1 = Au[j- 2 ]
@@ -717,7 +756,7 @@ function __bibimul!(C, A, B::Bidiagonal, _add)
717
756
end
718
757
end
719
758
else # B.uplo == 'L'
720
- Bl = _diag (B, - 1 )
759
+ Bl = B . ev
721
760
@inbounds begin
722
761
for j in 3 : n- 2
723
762
Aj₋1j = Au[j- 1 ]
@@ -743,9 +782,9 @@ function __bibimul!(C, A::Bidiagonal, B, _add)
743
782
Bl = _diag (B, - 1 )
744
783
Bd = _diag (B, 0 )
745
784
Bu = _diag (B, 1 )
746
- Ad = _diag (A, 0 )
785
+ Ad = A . dv
747
786
if A. uplo == ' U'
748
- Au = _diag (A, 1 )
787
+ Au = A . ev
749
788
@inbounds begin
750
789
for j in 3 : n- 2
751
790
Aj₋2j₋1 = Au[j- 2 ]
@@ -765,7 +804,7 @@ function __bibimul!(C, A::Bidiagonal, B, _add)
765
804
end
766
805
end
767
806
else # A.uplo == 'L'
768
- Al = _diag (A, - 1 )
807
+ Al = A . ev
769
808
@inbounds begin
770
809
for j in 3 : n- 2
771
810
Aj₋1j₋1 = Ad[j- 1 ]
@@ -789,11 +828,11 @@ function __bibimul!(C, A::Bidiagonal, B, _add)
789
828
end
790
829
function __bibimul! (C, A:: Bidiagonal , B:: Bidiagonal , _add)
791
830
n = size (A,1 )
792
- Ad = _diag (A, 0 )
793
- Bd = _diag (B, 0 )
831
+ Ad = A . dv
832
+ Bd = B . dv
794
833
if A. uplo == ' U' && B. uplo == ' U'
795
- Au = _diag (A, 1 )
796
- Bu = _diag (B, 1 )
834
+ Au = A . ev
835
+ Bu = B . ev
797
836
@inbounds begin
798
837
for j in 3 : n- 2
799
838
Aj₋2j₋1 = Au[j- 2 ]
@@ -809,8 +848,8 @@ function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add)
809
848
end
810
849
end
811
850
elseif A. uplo == ' U' && B. uplo == ' L'
812
- Au = _diag (A, 1 )
813
- Bl = _diag (B, - 1 )
851
+ Au = A . ev
852
+ Bl = B . ev
814
853
@inbounds begin
815
854
for j in 3 : n- 2
816
855
Aj₋1j = Au[j- 1 ]
@@ -826,8 +865,8 @@ function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add)
826
865
end
827
866
end
828
867
elseif A. uplo == ' L' && B. uplo == ' U'
829
- Al = _diag (A, - 1 )
830
- Bu = _diag (B, 1 )
868
+ Al = A . ev
869
+ Bu = B . ev
831
870
@inbounds begin
832
871
for j in 3 : n- 2
833
872
Aj₋1j₋1 = Ad[j- 1 ]
@@ -843,8 +882,8 @@ function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add)
843
882
end
844
883
end
845
884
else # A.uplo == 'L' && B.uplo == 'L'
846
- Al = _diag (A, - 1 )
847
- Bl = _diag (B, - 1 )
885
+ Al = A . ev
886
+ Bl = B . ev
848
887
@inbounds begin
849
888
for j in 3 : n- 2
850
889
Ajj = Ad[j]
@@ -863,15 +902,20 @@ function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add)
863
902
C
864
903
end
865
904
866
- _mul! (C:: AbstractMatrix , A:: BiTriSym , B:: Diagonal , alpha:: Number , beta:: Number ) =
867
- @stable_muladdmul _mul! (C, A, B, MulAddMul (alpha, beta))
868
905
function _mul! (C:: AbstractMatrix , A:: BiTriSym , B:: Diagonal , _add:: MulAddMul )
869
906
require_one_based_indexing (C)
870
907
matmul_size_check (size (C), size (A), size (B))
871
908
n = size (A,1 )
872
909
iszero (n) && return C
873
910
_rmul_or_fill! (C, _add. beta) # see the same use above
874
- _iszero_alpha (_add) && return C
911
+ iszero (_add. alpha) && return C
912
+ # beta is unused in the _bidimul! call, so we set it to false
913
+ _add_nonzeroalpha = _MulAddMul_nonzeroalpha (_add, Val (false ))
914
+ _bidimul! (C, A, B, _add_nonzeroalpha)
915
+ C
916
+ end
917
+ function _bidimul! (C:: AbstractMatrix , A:: BiTriSym , B:: Diagonal , _add:: MulAddMul )
918
+ n = size (A,1 )
875
919
Al = _diag (A, - 1 )
876
920
Ad = _diag (A, 0 )
877
921
Au = _diag (A, 1 )
@@ -907,14 +951,8 @@ function _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul)
907
951
end # inbounds
908
952
C
909
953
end
910
-
911
- function _mul! (C:: AbstractMatrix , A:: Bidiagonal , B:: Diagonal , _add:: MulAddMul )
912
- require_one_based_indexing (C)
913
- matmul_size_check (size (C), size (A), size (B))
954
+ function _bidimul! (C:: AbstractMatrix , A:: Bidiagonal , B:: Diagonal , _add:: MulAddMul )
914
955
n = size (A,1 )
915
- iszero (n) && return C
916
- _rmul_or_fill! (C, _add. beta) # see the same use above
917
- _iszero_alpha (_add) && return C
918
956
(; dv, ev) = A
919
957
Bd = B. diag
920
958
rowshift = A. uplo == ' U' ? - 1 : 1
@@ -943,7 +981,13 @@ function _mul!(C::Bidiagonal, A::Bidiagonal, B::Diagonal, _add::MulAddMul)
943
981
matmul_size_check (size (C), size (A), size (B))
944
982
n = size (A,1 )
945
983
iszero (n) && return C
946
- _iszero_alpha (_add) && return _rmul_or_fill! (C, _add. beta)
984
+ iszero (_add. alpha) && return _rmul_or_fill! (C, _add. beta)
985
+ _add_nonzeroalpha = _MulAddMul_nonzeroalpha (_add)
986
+ _bidimul! (C, A, B, _add_nonzeroalpha)
987
+ C
988
+ end
989
+ function _bidimul! (C:: Bidiagonal , A:: Bidiagonal , B:: Diagonal , _add:: MulAddMul )
990
+ n = size (A,1 )
947
991
Adv, Aev = A. dv, A. ev
948
992
Cdv, Cev = C. dv, C. ev
949
993
Bd = B. diag
@@ -978,14 +1022,22 @@ function _mul!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat, _add::MulA
978
1022
nB = size (B,2 )
979
1023
(iszero (nA) || iszero (nB)) && return C
980
1024
_iszero_alpha (_add) && return _rmul_or_fill! (C, _add. beta)
1025
+ _add_nonzeroalpha = _MulAddMul_nonzeroalpha (_add)
1026
+ _mul_bitrisym_left! (C, A, B, _add_nonzeroalpha)
1027
+ return C
1028
+ end
1029
+ function _mul_bitrisym_left! (C:: AbstractVecOrMat , A:: BiTriSym , B:: AbstractVecOrMat , _add:: MulAddMul )
1030
+ nA = size (A,1 )
1031
+ nB = size (B,2 )
981
1032
if nA == 1
982
1033
A11 = @inbounds A[1 ,1 ]
983
1034
for i in axes (B, 2 )
984
1035
@inbounds _modify! (_add, A11 * B[1 ,i], C, (1 ,i))
985
1036
end
986
- return C
1037
+ else
1038
+ _mul_bitrisym! (C, A, B, _add)
987
1039
end
988
- _mul_bitrisym! (C, A, B, _add)
1040
+ return C
989
1041
end
990
1042
function _mul_bitrisym! (C:: AbstractVecOrMat , A:: Bidiagonal , B:: AbstractVecOrMat , _add:: MulAddMul )
991
1043
nA = size (A,1 )
@@ -1046,6 +1098,13 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::TriSym, _add::MulAddMul)
1046
1098
n = size (A,1 )
1047
1099
m = size (B,2 )
1048
1100
(_iszero_alpha (_add) || iszero (m)) && return _rmul_or_fill! (C, _add. beta)
1101
+ _add_nonzeroalpha = _MulAddMul_nonzeroalpha (_add)
1102
+ _mul_bitrisym_right! (C, A, B, _add_nonzeroalpha)
1103
+ C
1104
+ end
1105
+ function _mul_bitrisym_right! (C:: AbstractMatrix , A:: AbstractMatrix , B:: TriSym , _add:: MulAddMul )
1106
+ n = size (A,1 )
1107
+ m = size (B,2 )
1049
1108
if m == 1
1050
1109
B11 = B[1 ,1 ]
1051
1110
return mul! (C, A, B11, _add. alpha, _add. beta)
@@ -1082,6 +1141,12 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::Bidiagonal, _add::MulAdd
1082
1141
m, n = size (A)
1083
1142
(iszero (m) || iszero (n)) && return C
1084
1143
_iszero_alpha (_add) && return _rmul_or_fill! (C, _add. beta)
1144
+ _add_nonzeroalpha = _MulAddMul_nonzeroalpha (_add)
1145
+ _mul_bitrisym_right! (C, A, B, _add_nonzeroalpha)
1146
+ C
1147
+ end
1148
+ function _mul_bitrisym_right! (C:: AbstractMatrix , A:: AbstractMatrix , B:: Bidiagonal , _add:: MulAddMul )
1149
+ m, n = size (A)
1085
1150
@inbounds if B. uplo == ' U'
1086
1151
for j in n: - 1 : 2 , i in 1 : m
1087
1152
_modify! (_add, A[i,j] * B. dv[j] + A[i,j- 1 ] * B. ev[j- 1 ], C, (i, j))
@@ -1114,6 +1179,13 @@ function _dibimul!(C, A, B, _add)
1114
1179
# ensure that we fill off-band elements in the destination
1115
1180
_rmul_or_fill! (C, _add. beta)
1116
1181
_iszero_alpha (_add) && return C
1182
+ # beta is unused in the _dibimul_nonzeroalpha! call, so we set it to false
1183
+ _add_nonzeroalpha = _MulAddMul_nonzeroalpha (_add, Val (false ))
1184
+ _dibimul_nonzeroalpha! (C, A, B, _add_nonzeroalpha)
1185
+ C
1186
+ end
1187
+ function _dibimul_nonzeroalpha! (C, A, B, _add)
1188
+ n = size (A,1 )
1117
1189
if n <= 3
1118
1190
# For simplicity, use a naive multiplication for small matrices
1119
1191
# that loops over all elements.
@@ -1150,14 +1222,8 @@ function _dibimul!(C, A, B, _add)
1150
1222
end # inbounds
1151
1223
C
1152
1224
end
1153
- function _dibimul! (C:: AbstractMatrix , A:: Diagonal , B:: Bidiagonal , _add)
1154
- require_one_based_indexing (C)
1155
- matmul_size_check (size (C), size (A), size (B))
1225
+ function _dibimul_nonzeroalpha! (C:: AbstractMatrix , A:: Diagonal , B:: Bidiagonal , _add)
1156
1226
n = size (A,1 )
1157
- iszero (n) && return C
1158
- # ensure that we fill off-band elements in the destination
1159
- _rmul_or_fill! (C, _add. beta)
1160
- _iszero_alpha (_add) && return C
1161
1227
Ad = A. diag
1162
1228
Bdv, Bev = B. dv, B. ev
1163
1229
rowshift = B. uplo == ' U' ? - 1 : 1
@@ -1187,6 +1253,11 @@ function _dibimul!(C::Bidiagonal, A::Diagonal, B::Bidiagonal, _add)
1187
1253
n = size (A,1 )
1188
1254
n == 0 && return C
1189
1255
_iszero_alpha (_add) && return _rmul_or_fill! (C, _add. beta)
1256
+ _add_nonzeroalpha = _MulAddMul_nonzeroalpha (_add)
1257
+ _dibimul_nonzeroalpha! (C, A, B, _add_nonzeroalpha)
1258
+ C
1259
+ end
1260
+ function _dibimul_nonzeroalpha! (C:: Bidiagonal , A:: Diagonal , B:: Bidiagonal , _add)
1190
1261
Ad = A. diag
1191
1262
Bdv, Bev = B. dv, B. ev
1192
1263
Cdv, Cev = C. dv, C. ev
0 commit comments