@@ -238,10 +238,22 @@ Base.isassigned(A::UpperOrLowerTriangular, i::Int, j::Int) =
238
238
Base. isstored (A:: UpperOrLowerTriangular , i:: Int , j:: Int ) =
239
239
_shouldforwardindex (A, i, j) ? Base. isstored (A. data, i, j) : false
240
240
241
- @propagate_inbounds getindex (A:: Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}} , i:: Int , j:: Int ) where {T} =
242
- _shouldforwardindex (A, i, j) ? A. data[i,j] : ifelse (i == j, oneunit (T), zero (T))
243
- @propagate_inbounds getindex (A:: Union{LowerTriangular, UpperTriangular} , i:: Int , j:: Int ) =
244
- _shouldforwardindex (A, i, j) ? A. data[i,j] : diagzero (A,i,j)
241
+ @propagate_inbounds function getindex (A:: Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}} , i:: Int , j:: Int ) where {T}
242
+ if _shouldforwardindex (A, i, j)
243
+ A. data[i,j]
244
+ else
245
+ @boundscheck checkbounds (A, i, j)
246
+ ifelse (i == j, oneunit (T), zero (T))
247
+ end
248
+ end
249
+ @propagate_inbounds function getindex (A:: Union{LowerTriangular, UpperTriangular} , i:: Int , j:: Int )
250
+ if _shouldforwardindex (A, i, j)
251
+ A. data[i,j]
252
+ else
253
+ @boundscheck checkbounds (A, i, j)
254
+ @inbounds diagzero (A,i,j)
255
+ end
256
+ end
245
257
246
258
_shouldforwardindex (U:: UpperTriangular , b:: BandIndex ) = b. band >= 0
247
259
_shouldforwardindex (U:: LowerTriangular , b:: BandIndex ) = b. band <= 0
@@ -250,10 +262,20 @@ _shouldforwardindex(U::UnitLowerTriangular, b::BandIndex) = b.band < 0
250
262
251
263
# these specialized getindex methods enable constant-propagation of the band
252
264
Base. @constprop :aggressive @propagate_inbounds function getindex (A:: Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}} , b:: BandIndex ) where {T}
253
- _shouldforwardindex (A, b) ? A. data[b] : ifelse (b. band == 0 , oneunit (T), zero (T))
265
+ if _shouldforwardindex (A, b)
266
+ A. data[b]
267
+ else
268
+ @boundscheck checkbounds (A, b)
269
+ ifelse (b. band == 0 , oneunit (T), zero (T))
270
+ end
254
271
end
255
272
Base. @constprop :aggressive @propagate_inbounds function getindex (A:: Union{LowerTriangular, UpperTriangular} , b:: BandIndex )
256
- _shouldforwardindex (A, b) ? A. data[b] : diagzero (A. data, b)
273
+ if _shouldforwardindex (A, b)
274
+ A. data[b]
275
+ else
276
+ @boundscheck checkbounds (A, b)
277
+ @inbounds diagzero (A, b)
278
+ end
257
279
end
258
280
259
281
_zero_triangular_half_str (:: Type{<:UpperOrUnitUpperTriangular} ) = " lower"
@@ -265,14 +287,20 @@ _zero_triangular_half_str(::Type{<:LowerOrUnitLowerTriangular}) = "upper"
265
287
throw (ArgumentError (
266
288
lazy " cannot set index in the $Ts triangular part ($i, $j) of an $Tn matrix to a nonzero value ($x)" ))
267
289
end
268
- @noinline function throw_nononeerror (T, @nospecialize (x), i, j)
290
+ @noinline function throw_nonuniterror (T, @nospecialize (x), i, j)
291
+ check_compatible_type (T, x)
269
292
Tn = nameof (T)
270
293
throw (ArgumentError (
271
294
lazy " cannot set index on the diagonal ($i, $j) of an $Tn matrix to a non-unit value ($x)" ))
272
295
end
296
+ function check_compatible_type (T, @nospecialize (x))
297
+ ET = eltype (T)
298
+ convert (ET, x) # check that the types are compatible with setindex!
299
+ end
273
300
274
301
@propagate_inbounds function setindex! (A:: UpperTriangular , x, i:: Integer , j:: Integer )
275
302
if i > j
303
+ @boundscheck checkbounds (A, i, j)
276
304
iszero (x) || throw_nonzeroerror (typeof (A), x, i, j)
277
305
else
278
306
A. data[i,j] = x
282
310
283
311
@propagate_inbounds function setindex! (A:: UnitUpperTriangular , x, i:: Integer , j:: Integer )
284
312
if i > j
313
+ @boundscheck checkbounds (A, i, j)
285
314
iszero (x) || throw_nonzeroerror (typeof (A), x, i, j)
286
315
elseif i == j
287
- x == oneunit (x) || throw_nononeerror (typeof (A), x, i, j)
316
+ @boundscheck checkbounds (A, i, j)
317
+ x == oneunit (eltype (A)) || throw_nonuniterror (typeof (A), x, i, j)
288
318
else
289
319
A. data[i,j] = x
290
320
end
293
323
294
324
@propagate_inbounds function setindex! (A:: LowerTriangular , x, i:: Integer , j:: Integer )
295
325
if i < j
326
+ @boundscheck checkbounds (A, i, j)
296
327
iszero (x) || throw_nonzeroerror (typeof (A), x, i, j)
297
328
else
298
329
A. data[i,j] = x
302
333
303
334
@propagate_inbounds function setindex! (A:: UnitLowerTriangular , x, i:: Integer , j:: Integer )
304
335
if i < j
336
+ @boundscheck checkbounds (A, i, j)
305
337
iszero (x) || throw_nonzeroerror (typeof (A), x, i, j)
306
338
elseif i == j
307
- x == oneunit (x) || throw_nononeerror (typeof (A), x, i, j)
339
+ @boundscheck checkbounds (A, i, j)
340
+ x == oneunit (eltype (A)) || throw_nonuniterror (typeof (A), x, i, j)
308
341
else
309
342
A. data[i,j] = x
310
343
end
@@ -560,7 +593,7 @@ for (T, UT) in ((:UpperTriangular, :UnitUpperTriangular), (:LowerTriangular, :Un
560
593
@eval @inline function _copy! (A:: $UT , B:: $T )
561
594
for dind in diagind (A, IndexStyle (A))
562
595
if A[dind] != B[dind]
563
- throw_nononeerror (typeof (A), B[dind], Tuple (dind)... )
596
+ throw_nonuniterror (typeof (A), B[dind], Tuple (dind)... )
564
597
end
565
598
end
566
599
_copy! ($ T (parent (A)), B)
@@ -742,7 +775,7 @@ function _triscale!(A::UpperOrUnitUpperTriangular, B::UnitUpperTriangular, c::Nu
742
775
checksize1 (A, B)
743
776
_iszero_alpha (_add) && return _rmul_or_fill! (A, _add. beta)
744
777
for j in axes (B. data,2 )
745
- @inbounds _modify! (_add, c, A, (j,j))
778
+ @inbounds _modify! (_add, B[ BandIndex ( 0 ,j)] * c, A, (j,j))
746
779
for i in firstindex (B. data,1 ): (j - 1 )
747
780
@inbounds _modify! (_add, B. data[i,j] * c, A. data, (i,j))
748
781
end
@@ -753,7 +786,7 @@ function _triscale!(A::UpperOrUnitUpperTriangular, c::Number, B::UnitUpperTriang
753
786
checksize1 (A, B)
754
787
_iszero_alpha (_add) && return _rmul_or_fill! (A, _add. beta)
755
788
for j in axes (B. data,2 )
756
- @inbounds _modify! (_add, c, A, (j,j))
789
+ @inbounds _modify! (_add, c * B[ BandIndex ( 0 ,j)] , A, (j,j))
757
790
for i in firstindex (B. data,1 ): (j - 1 )
758
791
@inbounds _modify! (_add, c * B. data[i,j], A. data, (i,j))
759
792
end
@@ -784,7 +817,7 @@ function _triscale!(A::LowerOrUnitLowerTriangular, B::UnitLowerTriangular, c::Nu
784
817
checksize1 (A, B)
785
818
_iszero_alpha (_add) && return _rmul_or_fill! (A, _add. beta)
786
819
for j in axes (B. data,2 )
787
- @inbounds _modify! (_add, c, A, (j,j))
820
+ @inbounds _modify! (_add, B[ BandIndex ( 0 ,j)] * c, A, (j,j))
788
821
for i in (j + 1 ): lastindex (B. data,1 )
789
822
@inbounds _modify! (_add, B. data[i,j] * c, A. data, (i,j))
790
823
end
@@ -795,7 +828,7 @@ function _triscale!(A::LowerOrUnitLowerTriangular, c::Number, B::UnitLowerTriang
795
828
checksize1 (A, B)
796
829
_iszero_alpha (_add) && return _rmul_or_fill! (A, _add. beta)
797
830
for j in axes (B. data,2 )
798
- @inbounds _modify! (_add, c, A, (j,j))
831
+ @inbounds _modify! (_add, c * B[ BandIndex ( 0 ,j)] , A, (j,j))
799
832
for i in (j + 1 ): lastindex (B. data,1 )
800
833
@inbounds _modify! (_add, c * B. data[i,j], A. data, (i,j))
801
834
end
@@ -805,36 +838,52 @@ end
805
838
806
839
function _trirdiv! (A:: UpperTriangular , B:: UpperOrUnitUpperTriangular , c:: Number )
807
840
checksize1 (A, B)
841
+ isunit = B isa UnitUpperTriangular
808
842
for j in axes (B,2 )
809
- for i in firstindex (B,1 ): j
810
- @inbounds A[i, j] = B[i, j] / c
843
+ for i in firstindex (B,1 ): j- isunit
844
+ @inbounds A. data[i, j] = B. data[i, j] / c
845
+ end
846
+ if isunit
847
+ @inbounds A. data[j, j] = B[BandIndex (0 ,j)] / c
811
848
end
812
849
end
813
850
return A
814
851
end
815
852
function _trirdiv! (A:: LowerTriangular , B:: LowerOrUnitLowerTriangular , c:: Number )
816
853
checksize1 (A, B)
854
+ isunit = B isa UnitLowerTriangular
817
855
for j in axes (B,2 )
818
- for i in j: lastindex (B,1 )
819
- @inbounds A[i, j] = B[i, j] / c
856
+ if isunit
857
+ @inbounds A. data[j, j] = B[BandIndex (0 ,j)] / c
858
+ end
859
+ for i in j+ isunit: lastindex (B,1 )
860
+ @inbounds A. data[i, j] = B. data[i, j] / c
820
861
end
821
862
end
822
863
return A
823
864
end
824
865
function _trildiv! (A:: UpperTriangular , c:: Number , B:: UpperOrUnitUpperTriangular )
825
866
checksize1 (A, B)
867
+ isunit = B isa UnitUpperTriangular
826
868
for j in axes (B,2 )
827
- for i in firstindex (B,1 ): j
828
- @inbounds A[i, j] = c \ B[i, j]
869
+ for i in firstindex (B,1 ): j- isunit
870
+ @inbounds A. data[i, j] = c \ B. data[i, j]
871
+ end
872
+ if isunit
873
+ @inbounds A. data[j, j] = c \ B[BandIndex (0 ,j)]
829
874
end
830
875
end
831
876
return A
832
877
end
833
878
function _trildiv! (A:: LowerTriangular , c:: Number , B:: LowerOrUnitLowerTriangular )
834
879
checksize1 (A, B)
880
+ isunit = B isa UnitLowerTriangular
835
881
for j in axes (B,2 )
836
- for i in j: lastindex (B,1 )
837
- @inbounds A[i, j] = c \ B[i, j]
882
+ if isunit
883
+ @inbounds A. data[j, j] = c \ B[BandIndex (0 ,j)]
884
+ end
885
+ for i in j+ isunit: lastindex (B,1 )
886
+ @inbounds A. data[i, j] = c \ B. data[i, j]
838
887
end
839
888
end
840
889
return A
0 commit comments