Skip to content

Commit 91151ab

Browse files
vchuravyYingboMa
andauthored
fix overflow in CartesianIndices iteration (#31011)
This allows LLVM to vectorize the 1D CartesianIndices case, as well as fixing an overflow bug for: ```julia CartesianIndices(((typemax(Int64)-2):typemax(Int64),)) ``` Co-authored-by: Yingbo Ma <[email protected]> Co-Authored-By: vchuravy <[email protected]>
1 parent a399780 commit 91151ab

File tree

3 files changed

+115
-28
lines changed

3 files changed

+115
-28
lines changed

base/array.jl

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1604,10 +1604,11 @@ CartesianIndex(2, 1)
16041604
function findnext(A, start)
16051605
l = last(keys(A))
16061606
i = start
1607-
while i <= l
1608-
if A[i]
1609-
return i
1610-
end
1607+
i > l && return nothing
1608+
while true
1609+
A[i] && return i
1610+
i == l && break
1611+
# nextind(A, l) can throw/overflow
16111612
i = nextind(A, i)
16121613
end
16131614
return nothing
@@ -1685,10 +1686,11 @@ CartesianIndex(1, 1)
16851686
function findnext(testf::Function, A, start)
16861687
l = last(keys(A))
16871688
i = start
1688-
while i <= l
1689-
if testf(A[i])
1690-
return i
1691-
end
1689+
i > l && return nothing
1690+
while true
1691+
testf(A[i]) && return i
1692+
i == l && break
1693+
# nextind(A, l) can throw/overflow
16921694
i = nextind(A, i)
16931695
end
16941696
return nothing
@@ -1781,8 +1783,12 @@ CartesianIndex(2, 1)
17811783
"""
17821784
function findprev(A, start)
17831785
i = start
1784-
while i >= first(keys(A))
1786+
f = first(keys(A))
1787+
i < f && return nothing
1788+
while true
17851789
A[i] && return i
1790+
i == f && break
1791+
# prevind(A, f) can throw/underflow
17861792
i = prevind(A, i)
17871793
end
17881794
return nothing
@@ -1868,8 +1874,12 @@ CartesianIndex(2, 1)
18681874
"""
18691875
function findprev(testf::Function, A, start)
18701876
i = start
1871-
while i >= first(keys(A))
1877+
f = first(keys(A))
1878+
i < f && return nothing
1879+
while true
18721880
testf(A[i]) && return i
1881+
i == f && break
1882+
# prevind(A, f) can throw/underflow
18731883
i = prevind(A, i)
18741884
end
18751885
return nothing

base/multidimensional.jl

Lines changed: 46 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ module IteratorsMD
9595
# access to index tuple
9696
Tuple(index::CartesianIndex) = index.I
9797

98+
# equality
99+
Base.:(==)(a::CartesianIndex{N}, b::CartesianIndex{N}) where N = a.I == b.I
100+
98101
# zeros and ones
99102
zero(::CartesianIndex{N}) where {N} = zero(CartesianIndex{N})
100103
zero(::Type{CartesianIndex{N}}) where {N} = CartesianIndex(ntuple(x -> 0, Val(N)))
@@ -142,11 +145,15 @@ module IteratorsMD
142145
# nextind and prevind with CartesianIndex
143146
function Base.nextind(a::AbstractArray{<:Any,N}, i::CartesianIndex{N}) where {N}
144147
iter = CartesianIndices(axes(a))
145-
return CartesianIndex(inc(i.I, first(iter).I, last(iter).I))
148+
# might overflow
149+
I = inc(i.I, first(iter).I, last(iter).I)
150+
return I
146151
end
147152
function Base.prevind(a::AbstractArray{<:Any,N}, i::CartesianIndex{N}) where {N}
148153
iter = CartesianIndices(axes(a))
149-
return CartesianIndex(dec(i.I, last(iter).I, first(iter).I))
154+
# might underflow
155+
I = dec(i.I, last(iter).I, first(iter).I)
156+
return I
150157
end
151158

152159
# Iteration over the elements of CartesianIndex cannot be supported until its length can be inferred,
@@ -334,20 +341,30 @@ module IteratorsMD
334341
iterfirst, iterfirst
335342
end
336343
@inline function iterate(iter::CartesianIndices, state)
337-
nextstate = CartesianIndex(inc(state.I, first(iter).I, last(iter).I))
338-
nextstate.I[end] > last(iter.indices[end]) && return nothing
339-
nextstate, nextstate
344+
valid, I = __inc(state.I, first(iter).I, last(iter).I)
345+
valid || return nothing
346+
return CartesianIndex(I...), CartesianIndex(I...)
340347
end
341348

342349
# increment & carry
343-
@inline inc(::Tuple{}, ::Tuple{}, ::Tuple{}) = ()
344-
@inline inc(state::Tuple{Int}, start::Tuple{Int}, stop::Tuple{Int}) = (state[1]+1,)
345350
@inline function inc(state, start, stop)
351+
_, I = __inc(state, start, stop)
352+
return CartesianIndex(I...)
353+
end
354+
355+
# increment post check to avoid integer overflow
356+
@inline __inc(::Tuple{}, ::Tuple{}, ::Tuple{}) = false, ()
357+
@inline function __inc(state::Tuple{Int}, start::Tuple{Int}, stop::Tuple{Int})
358+
valid = state[1] < stop[1]
359+
return valid, (state[1]+1,)
360+
end
361+
362+
@inline function __inc(state, start, stop)
346363
if state[1] < stop[1]
347-
return (state[1]+1,tail(state)...)
364+
return true, (state[1]+1, tail(state)...)
348365
end
349-
newtail = inc(tail(state), tail(start), tail(stop))
350-
(start[1], newtail...)
366+
valid, I = __inc(tail(state), tail(start), tail(stop))
367+
return valid, (start[1], I...)
351368
end
352369

353370
# 0-d cartesian ranges are special-cased to iterate once and only once
@@ -414,21 +431,32 @@ module IteratorsMD
414431
iterfirst, iterfirst
415432
end
416433
@inline function iterate(r::Reverse{<:CartesianIndices}, state)
417-
nextstate = CartesianIndex(dec(state.I, last(r.itr).I, first(r.itr).I))
418-
nextstate.I[end] < first(r.itr.indices[end]) && return nothing
419-
nextstate, nextstate
434+
valid, I = __dec(state.I, last(r.itr).I, first(r.itr).I)
435+
valid || return nothing
436+
return CartesianIndex(I...), CartesianIndex(I...)
420437
end
421438

422439
# decrement & carry
423-
@inline dec(::Tuple{}, ::Tuple{}, ::Tuple{}) = ()
424-
@inline dec(state::Tuple{Int}, start::Tuple{Int}, stop::Tuple{Int}) = (state[1]-1,)
425440
@inline function dec(state, start, stop)
441+
_, I = __dec(state, start, stop)
442+
return CartesianIndex(I...)
443+
end
444+
445+
# decrement post check to avoid integer overflow
446+
@inline __dec(::Tuple{}, ::Tuple{}, ::Tuple{}) = false, ()
447+
@inline function __dec(state::Tuple{Int}, start::Tuple{Int}, stop::Tuple{Int})
448+
valid = state[1] > stop[1]
449+
return valid, (state[1]-1,)
450+
end
451+
452+
@inline function __dec(state, start, stop)
426453
if state[1] > stop[1]
427-
return (state[1]-1,tail(state)...)
454+
return true, (state[1]-1, tail(state)...)
428455
end
429-
newtail = dec(tail(state), tail(start), tail(stop))
430-
(start[1], newtail...)
456+
valid, I = __dec(tail(state), tail(start), tail(stop))
457+
return valid, (start[1], I...)
431458
end
459+
432460
# 0-d cartesian ranges are special-cased to iterate once and only once
433461
iterate(iter::Reverse{<:CartesianIndices{0}}, state=false) = state ? nothing : (CartesianIndex(), true)
434462

test/cartesian.jl

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,52 @@ ex = Base.Cartesian.exprresolve(:(if 5 > 4; :x; else :y; end))
1515
# can't convert higher-dimensional indices to Int
1616
@test_throws MethodError convert(Int, CartesianIndex(42, 1))
1717
end
18+
19+
@testset "CartesianIndices overflow" begin
20+
I = CartesianIndices((1:typemax(Int),))
21+
i = last(I)
22+
@test iterate(I, i) === nothing
23+
24+
I = CartesianIndices((1:(typemax(Int)-1),))
25+
i = CartesianIndex(typemax(Int))
26+
@test iterate(I, i) === nothing
27+
28+
I = CartesianIndices((1:typemax(Int), 1:typemax(Int)))
29+
i = last(I)
30+
@test iterate(I, i) === nothing
31+
32+
i = CartesianIndex(typemax(Int), 1)
33+
@test iterate(I, i) === (CartesianIndex(1, 2), CartesianIndex(1,2))
34+
35+
# reverse cartesian indices
36+
I = CartesianIndices((typemin(Int):(typemin(Int)+3),))
37+
i = last(I)
38+
@test iterate(I, i) === nothing
39+
end
40+
41+
@testset "CartesianIndices iteration" begin
42+
I = CartesianIndices((2:4, 0:1, 1:1, 3:5))
43+
indices = Vector{eltype(I)}()
44+
for i in I
45+
push!(indices, i)
46+
end
47+
@test length(I) == length(indices)
48+
@test vec(I) == indices
49+
50+
empty!(indices)
51+
I = Iterators.reverse(I)
52+
for i in I
53+
push!(indices, i)
54+
end
55+
@test length(I) == length(indices)
56+
@test vec(collect(I)) == indices
57+
58+
# test invalid state
59+
I = CartesianIndices((2:4, 3:5))
60+
@test iterate(I, CartesianIndex(typemax(Int), 3))[1] == CartesianIndex(2,4)
61+
@test iterate(I, CartesianIndex(typemax(Int), 4))[1] == CartesianIndex(2,5)
62+
@test iterate(I, CartesianIndex(typemax(Int), 5)) === nothing
63+
64+
@test iterate(I, CartesianIndex(3, typemax(Int)))[1] == CartesianIndex(4,typemax(Int))
65+
@test iterate(I, CartesianIndex(4, typemax(Int))) === nothing
66+
end

0 commit comments

Comments
 (0)