Skip to content

Commit b3b0d8e

Browse files
committed
Allow non-static indices
1 parent 103e9d4 commit b3b0d8e

File tree

4 files changed

+64
-4
lines changed

4 files changed

+64
-4
lines changed

src/indexing.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ end
7575
@inline index_size(::Size, ::Int) = Size()
7676
@inline index_size(::Size, a::StaticArray) = Size(a)
7777
@inline index_size(s::Size, ::Colon) = s
78-
@inline index_size(s::Size, a::SOneTo{n}) where n = Size(n,)
78+
@inline index_size(::Size, a::AbstractRange{<:Integer}) = Size(length(a),)
7979

8080
@inline index_sizes(::S, inds...) where {S<:Size} = map(index_size, unpack_size(S), inds)
8181

@@ -92,9 +92,9 @@ linear_index_size(ind_sizes::Type{<:Size}...) = _linear_index_size((), ind_sizes
9292
@inline _linear_index_size(t::Tuple, ::Type{Size{S}}, ind_sizes...) where {S} = _linear_index_size((t..., prod(S)), ind_sizes...)
9393

9494
_ind(i::Int, ::Int, ::Type{Int}) = :(inds[$i])
95-
_ind(i::Int, j::Int, ::Type{<:StaticArray}) = :(inds[$i][$j])
9695
_ind(i::Int, j::Int, ::Type{Colon}) = j
9796
_ind(i::Int, j::Int, ::Type{<:SOneTo}) = j
97+
_ind(i::Int, j::Int, ::Type{<:AbstractArray}) = :(inds[$i][$j])
9898

9999
################################
100100
## Non-scalar linear indexing ##
@@ -215,7 +215,7 @@ end
215215

216216
# getindex
217217

218-
@propagate_inbounds function getindex(a::StaticArray, inds::Union{Int, StaticArray{<:Tuple, Int}, SOneTo, Colon}...)
218+
@propagate_inbounds function getindex(a::StaticArray, inds::Union{Int, StaticArray{<:Tuple, Int}, AbstractRange, Colon}...)
219219
_getindex(a, index_sizes(Size(a), inds...), inds)
220220
end
221221

test/abstractarray.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ using StaticArrays, Test, LinearAlgebra
8585
@test similar(v, SOneTo(3), SOneTo(4)) isa MMatrix{3,4,Int}
8686
@test similar(v, 3, SOneTo(4)) isa Matrix
8787

88-
@test m[:, 1:2] isa Matrix
88+
@test m[:, 1:2] isa SMatrix{2, 2, Int}
8989
@test m[:, [true, false, false]] isa Matrix
9090
@test m[:, SOneTo(2)] isa SMatrix{2, 2, Int}
9191
@test m[:, :] isa SMatrix{2, 3, Int}

test/indexing.jl

+30
Original file line numberDiff line numberDiff line change
@@ -223,4 +223,34 @@ using StaticArrays, Test
223223
@test eltype(Bvv) == Int
224224
@test Bvv[:] == [B[1,2,3,4], B[1,1,3,4]]
225225
end
226+
227+
@testset "Indexing with constants" begin
228+
function SVector_UnitRange()
229+
x = SA[1, 2, 3]
230+
x[2:end]
231+
end
232+
@test SVector_UnitRange() === SA[2, 3]
233+
@test_const_fold SVector_UnitRange()
234+
235+
function SVector_StepRange()
236+
x = SA[1, 2, 3, 4]
237+
x[1:2:end]
238+
end
239+
@test SVector_StepRange() === SA[1, 3]
240+
@test_const_fold SVector_StepRange()
241+
242+
function SMatrix_UnitRange_UnitRange()
243+
x = SA[1 2 3; 4 5 6]
244+
x[1:2, 2:end]
245+
end
246+
@test SMatrix_UnitRange_UnitRange() === SA[2 3; 5 6]
247+
@test_const_fold SMatrix_UnitRange_UnitRange()
248+
249+
function SMatrix_StepRange_StepRange()
250+
x = SA[1 2 3; 4 5 6]
251+
x[1:1:2, 1:2:end]
252+
end
253+
@test SMatrix_StepRange_StepRange() === SA[1 3; 4 6]
254+
@test_const_fold SMatrix_StepRange_StepRange()
255+
end
226256
end

test/testutil.jl

+30
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,36 @@ should_not_be_inlined(x) = _should_not_be_inlined(x)
9696
end
9797

9898

99+
"""
100+
@test_const_fold f(args...)
101+
102+
Test that constant folding works with a function call `f(args...)`.
103+
Do nothing in `julia` < 1.3.
104+
"""
105+
macro test_const_fold(ex)
106+
quote
107+
ir, = $(esc(:($InteractiveUtils.@code_typed optimize = true $ex)))
108+
if :rettype in fieldnames(typeof(ir)) # skip tests in julia < 1.3
109+
@test ir.rettype isa Core.Compiler.Const
110+
if ir.rettype isa Core.Compiler.Const
111+
@test $(esc(ex)) == ir.rettype.val
112+
end
113+
end
114+
end
115+
end
116+
117+
@testset "@test_const_fold" begin
118+
should_const_fold() = (1, 2, 3)
119+
@test_const_fold should_const_fold()
120+
121+
x = Ref(1)
122+
should_not_const_fold() = x[]
123+
ts = @testset ErrorCounterTestSet "" begin
124+
@test_const_fold should_not_const_fold()
125+
end
126+
@test ts.errorcount == 0 && ts.failcount == 1 && ts.passcount == 0
127+
end
128+
99129
"""
100130
@inferred_maybe_allow allow ex
101131

0 commit comments

Comments
 (0)