Skip to content

Commit 9159fbb

Browse files
committed
Bounds-checking in triangular indexing branches
1 parent 95703b5 commit 9159fbb

File tree

3 files changed

+120
-24
lines changed

3 files changed

+120
-24
lines changed

src/diagonal.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -192,8 +192,8 @@ Return the appropriate zero element `A[i, j]` corresponding to a banded matrix `
192192
"""
193193
diagzero(A::AbstractMatrix, i, j) = zero(eltype(A))
194194
@propagate_inbounds diagzero(A::AbstractMatrix{M}, i, j) where {M<:AbstractMatrix} =
195-
zeroslike(M, axes(A[i,i], 1), axes(A[j,j], 2))
196-
diagzero(A::AbstractMatrix, inds...) = diagzero(A, to_indices(A, inds)...)
195+
zeroslike(M, axes(A[BandIndex(0,i)], 1), axes(A[BandIndex(0,j)], 2))
196+
@propagate_inbounds diagzero(A::AbstractMatrix, inds...) = diagzero(A, to_indices(A, inds)...)
197197
# dispatching on the axes permits specializing on the axis types to return something other than an Array
198198
zeroslike(M::Type, ax::Vararg{Union{AbstractUnitRange, Integer}}) = zeroslike(M, ax)
199199
"""

src/triangular.jl

+71-22
Original file line numberDiff line numberDiff line change
@@ -238,10 +238,22 @@ Base.isassigned(A::UpperOrLowerTriangular, i::Int, j::Int) =
238238
Base.isstored(A::UpperOrLowerTriangular, i::Int, j::Int) =
239239
_shouldforwardindex(A, i, j) ? Base.isstored(A.data, i, j) : false
240240

241-
@propagate_inbounds getindex(A::Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}}, i::Int, j::Int) where {T} =
242-
_shouldforwardindex(A, i, j) ? A.data[i,j] : ifelse(i == j, oneunit(T), zero(T))
243-
@propagate_inbounds getindex(A::Union{LowerTriangular, UpperTriangular}, i::Int, j::Int) =
244-
_shouldforwardindex(A, i, j) ? A.data[i,j] : diagzero(A,i,j)
241+
@propagate_inbounds function getindex(A::Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}}, i::Int, j::Int) where {T}
242+
if _shouldforwardindex(A, i, j)
243+
A.data[i,j]
244+
else
245+
@boundscheck checkbounds(A, i, j)
246+
ifelse(i == j, oneunit(T), zero(T))
247+
end
248+
end
249+
@propagate_inbounds function getindex(A::Union{LowerTriangular, UpperTriangular}, i::Int, j::Int)
250+
if _shouldforwardindex(A, i, j)
251+
A.data[i,j]
252+
else
253+
@boundscheck checkbounds(A, i, j)
254+
@inbounds diagzero(A,i,j)
255+
end
256+
end
245257

246258
_shouldforwardindex(U::UpperTriangular, b::BandIndex) = b.band >= 0
247259
_shouldforwardindex(U::LowerTriangular, b::BandIndex) = b.band <= 0
@@ -250,10 +262,20 @@ _shouldforwardindex(U::UnitLowerTriangular, b::BandIndex) = b.band < 0
250262

251263
# these specialized getindex methods enable constant-propagation of the band
252264
Base.@constprop :aggressive @propagate_inbounds function getindex(A::Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}}, b::BandIndex) where {T}
253-
_shouldforwardindex(A, b) ? A.data[b] : ifelse(b.band == 0, oneunit(T), zero(T))
265+
if _shouldforwardindex(A, b)
266+
A.data[b]
267+
else
268+
@boundscheck checkbounds(A, b)
269+
ifelse(b.band == 0, oneunit(T), zero(T))
270+
end
254271
end
255272
Base.@constprop :aggressive @propagate_inbounds function getindex(A::Union{LowerTriangular, UpperTriangular}, b::BandIndex)
256-
_shouldforwardindex(A, b) ? A.data[b] : diagzero(A.data, b)
273+
if _shouldforwardindex(A, b)
274+
A.data[b]
275+
else
276+
@boundscheck checkbounds(A, b)
277+
@inbounds diagzero(A, b)
278+
end
257279
end
258280

259281
_zero_triangular_half_str(::Type{<:UpperOrUnitUpperTriangular}) = "lower"
@@ -265,14 +287,20 @@ _zero_triangular_half_str(::Type{<:LowerOrUnitLowerTriangular}) = "upper"
265287
throw(ArgumentError(
266288
lazy"cannot set index in the $Ts triangular part ($i, $j) of an $Tn matrix to a nonzero value ($x)"))
267289
end
268-
@noinline function throw_nononeerror(T, @nospecialize(x), i, j)
290+
@noinline function throw_nonuniterror(T, @nospecialize(x), i, j)
291+
check_compatible_type(T, x)
269292
Tn = nameof(T)
270293
throw(ArgumentError(
271294
lazy"cannot set index on the diagonal ($i, $j) of an $Tn matrix to a non-unit value ($x)"))
272295
end
296+
function check_compatible_type(T, @nospecialize(x))
297+
ET = eltype(T)
298+
convert(ET, x) # check that the types are compatible with setindex!
299+
end
273300

274301
@propagate_inbounds function setindex!(A::UpperTriangular, x, i::Integer, j::Integer)
275302
if i > j
303+
@boundscheck checkbounds(A, i, j)
276304
iszero(x) || throw_nonzeroerror(typeof(A), x, i, j)
277305
else
278306
A.data[i,j] = x
@@ -282,9 +310,11 @@ end
282310

283311
@propagate_inbounds function setindex!(A::UnitUpperTriangular, x, i::Integer, j::Integer)
284312
if i > j
313+
@boundscheck checkbounds(A, i, j)
285314
iszero(x) || throw_nonzeroerror(typeof(A), x, i, j)
286315
elseif i == j
287-
x == oneunit(x) || throw_nononeerror(typeof(A), x, i, j)
316+
@boundscheck checkbounds(A, i, j)
317+
x == oneunit(eltype(A)) || throw_nonuniterror(typeof(A), x, i, j)
288318
else
289319
A.data[i,j] = x
290320
end
@@ -293,6 +323,7 @@ end
293323

294324
@propagate_inbounds function setindex!(A::LowerTriangular, x, i::Integer, j::Integer)
295325
if i < j
326+
@boundscheck checkbounds(A, i, j)
296327
iszero(x) || throw_nonzeroerror(typeof(A), x, i, j)
297328
else
298329
A.data[i,j] = x
@@ -302,9 +333,11 @@ end
302333

303334
@propagate_inbounds function setindex!(A::UnitLowerTriangular, x, i::Integer, j::Integer)
304335
if i < j
336+
@boundscheck checkbounds(A, i, j)
305337
iszero(x) || throw_nonzeroerror(typeof(A), x, i, j)
306338
elseif i == j
307-
x == oneunit(x) || throw_nononeerror(typeof(A), x, i, j)
339+
@boundscheck checkbounds(A, i, j)
340+
x == oneunit(eltype(A)) || throw_nonuniterror(typeof(A), x, i, j)
308341
else
309342
A.data[i,j] = x
310343
end
@@ -560,7 +593,7 @@ for (T, UT) in ((:UpperTriangular, :UnitUpperTriangular), (:LowerTriangular, :Un
560593
@eval @inline function _copy!(A::$UT, B::$T)
561594
for dind in diagind(A, IndexStyle(A))
562595
if A[dind] != B[dind]
563-
throw_nononeerror(typeof(A), B[dind], Tuple(dind)...)
596+
throw_nonuniterror(typeof(A), B[dind], Tuple(dind)...)
564597
end
565598
end
566599
_copy!($T(parent(A)), B)
@@ -742,7 +775,7 @@ function _triscale!(A::UpperOrUnitUpperTriangular, B::UnitUpperTriangular, c::Nu
742775
checksize1(A, B)
743776
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
744777
for j in axes(B.data,2)
745-
@inbounds _modify!(_add, c, A, (j,j))
778+
@inbounds _modify!(_add, B[BandIndex(0,j)] * c, A, (j,j))
746779
for i in firstindex(B.data,1):(j - 1)
747780
@inbounds _modify!(_add, B.data[i,j] * c, A.data, (i,j))
748781
end
@@ -753,7 +786,7 @@ function _triscale!(A::UpperOrUnitUpperTriangular, c::Number, B::UnitUpperTriang
753786
checksize1(A, B)
754787
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
755788
for j in axes(B.data,2)
756-
@inbounds _modify!(_add, c, A, (j,j))
789+
@inbounds _modify!(_add, c * B[BandIndex(0,j)], A, (j,j))
757790
for i in firstindex(B.data,1):(j - 1)
758791
@inbounds _modify!(_add, c * B.data[i,j], A.data, (i,j))
759792
end
@@ -784,7 +817,7 @@ function _triscale!(A::LowerOrUnitLowerTriangular, B::UnitLowerTriangular, c::Nu
784817
checksize1(A, B)
785818
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
786819
for j in axes(B.data,2)
787-
@inbounds _modify!(_add, c, A, (j,j))
820+
@inbounds _modify!(_add, B[BandIndex(0,j)] *c, A, (j,j))
788821
for i in (j + 1):lastindex(B.data,1)
789822
@inbounds _modify!(_add, B.data[i,j] * c, A.data, (i,j))
790823
end
@@ -795,7 +828,7 @@ function _triscale!(A::LowerOrUnitLowerTriangular, c::Number, B::UnitLowerTriang
795828
checksize1(A, B)
796829
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
797830
for j in axes(B.data,2)
798-
@inbounds _modify!(_add, c, A, (j,j))
831+
@inbounds _modify!(_add, c * B[BandIndex(0,j)], A, (j,j))
799832
for i in (j + 1):lastindex(B.data,1)
800833
@inbounds _modify!(_add, c * B.data[i,j], A.data, (i,j))
801834
end
@@ -805,36 +838,52 @@ end
805838

806839
function _trirdiv!(A::UpperTriangular, B::UpperOrUnitUpperTriangular, c::Number)
807840
checksize1(A, B)
841+
isunit = B isa UnitUpperTriangular
808842
for j in axes(B,2)
809-
for i in firstindex(B,1):j
810-
@inbounds A[i, j] = B[i, j] / c
843+
for i in firstindex(B,1):j-isunit
844+
@inbounds A.data[i, j] = B.data[i, j] / c
845+
end
846+
if isunit
847+
@inbounds A.data[j, j] = B[BandIndex(0,j)] / c
811848
end
812849
end
813850
return A
814851
end
815852
function _trirdiv!(A::LowerTriangular, B::LowerOrUnitLowerTriangular, c::Number)
816853
checksize1(A, B)
854+
isunit = B isa UnitLowerTriangular
817855
for j in axes(B,2)
818-
for i in j:lastindex(B,1)
819-
@inbounds A[i, j] = B[i, j] / c
856+
if isunit
857+
@inbounds A.data[j, j] = B[BandIndex(0,j)] / c
858+
end
859+
for i in j+isunit:lastindex(B,1)
860+
@inbounds A.data[i, j] = B.data[i, j] / c
820861
end
821862
end
822863
return A
823864
end
824865
function _trildiv!(A::UpperTriangular, c::Number, B::UpperOrUnitUpperTriangular)
825866
checksize1(A, B)
867+
isunit = B isa UnitUpperTriangular
826868
for j in axes(B,2)
827-
for i in firstindex(B,1):j
828-
@inbounds A[i, j] = c \ B[i, j]
869+
for i in firstindex(B,1):j-isunit
870+
@inbounds A.data[i, j] = c \ B.data[i, j]
871+
end
872+
if isunit
873+
@inbounds A.data[j, j] = c \ B[BandIndex(0,j)]
829874
end
830875
end
831876
return A
832877
end
833878
function _trildiv!(A::LowerTriangular, c::Number, B::LowerOrUnitLowerTriangular)
834879
checksize1(A, B)
880+
isunit = B isa UnitLowerTriangular
835881
for j in axes(B,2)
836-
for i in j:lastindex(B,1)
837-
@inbounds A[i, j] = c \ B[i, j]
882+
if isunit
883+
@inbounds A.data[j, j] = c \ B[BandIndex(0,j)]
884+
end
885+
for i in j+isunit:lastindex(B,1)
886+
@inbounds A.data[i, j] = c \ B.data[i, j]
838887
end
839888
end
840889
return A

test/triangular.jl

+47
Original file line numberDiff line numberDiff line change
@@ -934,4 +934,51 @@ end
934934
end
935935
end
936936

937+
@testset "indexing checks" begin
938+
@testset "getindex" begin
939+
U = UnitUpperTriangular(P)
940+
@test_throws BoundsError U[0,0]
941+
@test_throws BoundsError U[1,0]
942+
@test_throws BoundsError U[BandIndex(0,0)]
943+
@test_throws BoundsError U[BandIndex(-1,0)]
944+
945+
U = UpperTriangular(P)
946+
@test_throws BoundsError U[1,0]
947+
@test_throws BoundsError U[BandIndex(-1,0)]
948+
949+
L = UnitLowerTriangular(P)
950+
@test_throws BoundsError L[0,0]
951+
@test_throws BoundsError L[0,1]
952+
@test_throws BoundsError U[BandIndex(0,0)]
953+
@test_throws BoundsError U[BandIndex(1,0)]
954+
955+
L = LowerTriangular(P)
956+
@test_throws BoundsError L[0,1]
957+
@test_throws BoundsError L[BandIndex(1,0)]
958+
end
959+
@testset "setindex!" begin
960+
P = [1 2; 3 4]
961+
A = SizedArrays.SizedArray{(2,2)}(P)
962+
M = fill(A, 2, 2)
963+
U = UnitUpperTriangular(M)
964+
@test_throws "Cannot `convert` an object of type Int64" U[1,1] = 1
965+
L = UnitLowerTriangular(M)
966+
@test_throws "Cannot `convert` an object of type Int64" L[1,1] = 1
967+
968+
U = UnitUpperTriangular(P)
969+
@test_throws BoundsError U[0,0] = 1
970+
@test_throws BoundsError U[1,0] = 0
971+
972+
U = UpperTriangular(P)
973+
@test_throws BoundsError U[1,0] = 0
974+
975+
L = UnitLowerTriangular(P)
976+
@test_throws BoundsError L[0,0] = 1
977+
@test_throws BoundsError L[0,1] = 0
978+
979+
L = LowerTriangular(P)
980+
@test_throws BoundsError L[0,1] = 0
981+
end
982+
end
983+
937984
end # module TestTriangular

0 commit comments

Comments
 (0)