Skip to content

Commit 9741f5c

Browse files
vtjnashKristofferC
authored andcommitted
fix collect on stateful iterators
Generalization of #41919 Fixes #42168 (cherry picked from commit 68e0813)
1 parent 85b2b41 commit 9741f5c

File tree

5 files changed

+52
-34
lines changed

5 files changed

+52
-34
lines changed

base/array.jl

Lines changed: 39 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -583,23 +583,38 @@ julia> collect(Float64, 1:2:5)
583583
"""
584584
collect(::Type{T}, itr) where {T} = _collect(T, itr, IteratorSize(itr))
585585

586-
_collect(::Type{T}, itr, isz::HasLength) where {T} = copyto!(Vector{T}(undef, Int(length(itr)::Integer)), itr)
587-
_collect(::Type{T}, itr, isz::HasShape) where {T} = copyto!(similar(Array{T}, axes(itr)), itr)
586+
_collect(::Type{T}, itr, isz::Union{HasLength,HasShape}) where {T} =
587+
copyto!(_array_for(T, isz, _similar_shape(itr, isz)), itr)
588588
function _collect(::Type{T}, itr, isz::SizeUnknown) where T
589589
a = Vector{T}()
590590
for x in itr
591-
push!(a,x)
591+
push!(a, x)
592592
end
593593
return a
594594
end
595595

596596
# make a collection similar to `c` and appropriate for collecting `itr`
597-
_similar_for(c::AbstractArray, ::Type{T}, itr, ::SizeUnknown) where {T} = similar(c, T, 0)
598-
_similar_for(c::AbstractArray, ::Type{T}, itr, ::HasLength) where {T} =
599-
similar(c, T, Int(length(itr)::Integer))
600-
_similar_for(c::AbstractArray, ::Type{T}, itr, ::HasShape) where {T} =
601-
similar(c, T, axes(itr))
602-
_similar_for(c, ::Type{T}, itr, isz) where {T} = similar(c, T)
597+
_similar_for(c, ::Type{T}, itr, isz, shp) where {T} = similar(c, T)
598+
599+
_similar_shape(itr, ::SizeUnknown) = nothing
600+
_similar_shape(itr, ::HasLength) = length(itr)::Integer
601+
_similar_shape(itr, ::HasShape) = axes(itr)
602+
603+
_similar_for(c::AbstractArray, ::Type{T}, itr, ::SizeUnknown, ::Nothing) where {T} =
604+
similar(c, T, 0)
605+
_similar_for(c::AbstractArray, ::Type{T}, itr, ::HasLength, len::Integer) where {T} =
606+
similar(c, T, len)
607+
_similar_for(c::AbstractArray, ::Type{T}, itr, ::HasShape, axs) where {T} =
608+
similar(c, T, axs)
609+
610+
# make a collection appropriate for collecting `itr::Generator`
611+
_array_for(::Type{T}, ::SizeUnknown, ::Nothing) where {T} = Vector{T}(undef, 0)
612+
_array_for(::Type{T}, ::HasLength, len::Integer) where {T} = Vector{T}(undef, Int(len))
613+
_array_for(::Type{T}, ::HasShape{N}, axs) where {T,N} = similar(Array{T,N}, axs)
614+
615+
# used by syntax lowering for simple typed comprehensions
616+
_array_for(::Type{T}, itr, isz) where {T} = _array_for(T, isz, _similar_shape(itr, isz))
617+
603618

604619
"""
605620
collect(collection)
@@ -638,10 +653,10 @@ collect(A::AbstractArray) = _collect_indices(axes(A), A)
638653
collect_similar(cont, itr) = _collect(cont, itr, IteratorEltype(itr), IteratorSize(itr))
639654

640655
_collect(cont, itr, ::HasEltype, isz::Union{HasLength,HasShape}) =
641-
copyto!(_similar_for(cont, eltype(itr), itr, isz), itr)
656+
copyto!(_similar_for(cont, eltype(itr), itr, isz, _similar_shape(itr, isz)), itr)
642657

643658
function _collect(cont, itr, ::HasEltype, isz::SizeUnknown)
644-
a = _similar_for(cont, eltype(itr), itr, isz)
659+
a = _similar_for(cont, eltype(itr), itr, isz, nothing)
645660
for x in itr
646661
push!(a,x)
647662
end
@@ -699,24 +714,19 @@ else
699714
end
700715
end
701716

702-
_array_for(::Type{T}, itr, isz::HasLength) where {T} = _array_for(T, itr, isz, length(itr))
703-
_array_for(::Type{T}, itr, isz::HasShape{N}) where {T,N} = _array_for(T, itr, isz, axes(itr))
704-
_array_for(::Type{T}, itr, ::HasLength, len) where {T} = Vector{T}(undef, len)
705-
_array_for(::Type{T}, itr, ::HasShape{N}, axs) where {T,N} = similar(Array{T,N}, axs)
706-
707717
function collect(itr::Generator)
708718
isz = IteratorSize(itr.iter)
709719
et = @default_eltype(itr)
710720
if isa(isz, SizeUnknown)
711721
return grow_to!(Vector{et}(), itr)
712722
else
713-
shape = isz isa HasLength ? length(itr) : axes(itr)
723+
shp = _similar_shape(itr, isz)
714724
y = iterate(itr)
715725
if y === nothing
716-
return _array_for(et, itr.iter, isz)
726+
return _array_for(et, isz, shp)
717727
end
718728
v1, st = y
719-
dest = _array_for(typeof(v1), itr.iter, isz, shape)
729+
dest = _array_for(typeof(v1), isz, shp)
720730
# The typeassert gives inference a helping hand on the element type and dimensionality
721731
# (work-around for #28382)
722732
et′ = et <: Type ? Type : et
@@ -726,15 +736,22 @@ function collect(itr::Generator)
726736
end
727737

728738
_collect(c, itr, ::EltypeUnknown, isz::SizeUnknown) =
729-
grow_to!(_similar_for(c, @default_eltype(itr), itr, isz), itr)
739+
grow_to!(_similar_for(c, @default_eltype(itr), itr, isz, nothing), itr)
730740

731741
function _collect(c, itr, ::EltypeUnknown, isz::Union{HasLength,HasShape})
742+
et = @default_eltype(itr)
743+
shp = _similar_shape(itr, isz)
732744
y = iterate(itr)
733745
if y === nothing
734-
return _similar_for(c, @default_eltype(itr), itr, isz)
746+
return _similar_for(c, et, itr, isz, shp)
735747
end
736748
v1, st = y
737-
collect_to_with_first!(_similar_for(c, typeof(v1), itr, isz), v1, itr, st)
749+
dest = _similar_for(c, typeof(v1), itr, isz, shp)
750+
# The typeassert gives inference a helping hand on the element type and dimensionality
751+
# (work-around for #28382)
752+
et′ = et <: Type ? Type : et
753+
RT = dest isa AbstractArray ? AbstractArray{<:et′, ndims(dest)} : Any
754+
collect_to_with_first!(dest, v1, itr, st)::RT
738755
end
739756

740757
function collect_to_with_first!(dest::AbstractArray, v1, itr, st)

base/dict.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -826,6 +826,6 @@ length(t::ImmutableDict) = count(Returns(true), t)
826826
isempty(t::ImmutableDict) = !isdefined(t, :parent)
827827
empty(::ImmutableDict, ::Type{K}, ::Type{V}) where {K, V} = ImmutableDict{K,V}()
828828

829-
_similar_for(c::Dict, ::Type{Pair{K,V}}, itr, isz) where {K, V} = empty(c, K, V)
830-
_similar_for(c::AbstractDict, ::Type{T}, itr, isz) where {T} =
829+
_similar_for(c::AbstractDict, ::Type{Pair{K,V}}, itr, isz, len) where {K, V} = empty(c, K, V)
830+
_similar_for(c::AbstractDict, ::Type{T}, itr, isz, len) where {T} =
831831
throw(ArgumentError("for AbstractDicts, similar requires an element type of Pair;\n if calling map, consider a comprehension instead"))

base/set.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ empty(s::AbstractSet{T}, ::Type{U}=T) where {T,U} = Set{U}()
3636
# by default, a Set is returned
3737
emptymutable(s::AbstractSet{T}, ::Type{U}=T) where {T,U} = Set{U}()
3838

39-
_similar_for(c::AbstractSet, ::Type{T}, itr, isz) where {T} = empty(c, T)
39+
_similar_for(c::AbstractSet, ::Type{T}, itr, isz, len) where {T} = empty(c, T)
4040

4141
function show(io::IO, s::Set)
4242
if isempty(s)

src/julia-syntax.scm

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2734,7 +2734,7 @@
27342734
(check-no-return expr)
27352735
(if (has-break-or-continue? expr)
27362736
(error "break or continue outside loop"))
2737-
(let ((result (gensy))
2737+
(let ((result (make-ssavalue))
27382738
(idx (gensy))
27392739
(oneresult (make-ssavalue))
27402740
(prod (make-ssavalue))
@@ -2758,16 +2758,14 @@
27582758
(let ((overall-itr (if (length= itrs 1) (car iv) prod)))
27592759
`(scope-block
27602760
(block
2761-
(local ,result) (local ,idx)
2761+
(local ,idx)
27622762
,.(map (lambda (v r) `(= ,v ,(caddr r))) iv itrs)
27632763
,.(if (length= itrs 1)
27642764
'()
27652765
`((= ,prod (call (top product) ,@iv))))
27662766
(= ,isz (call (top IteratorSize) ,overall-itr))
27672767
(= ,szunk (call (core isa) ,isz (top SizeUnknown)))
2768-
(if ,szunk
2769-
(= ,result (call (curly (core Array) ,ty 1) (core undef) 0))
2770-
(= ,result (call (top _array_for) ,ty ,overall-itr ,isz)))
2768+
(= ,result (call (top _array_for) ,ty ,overall-itr ,isz))
27712769
(= ,idx (call (top first) (call (top LinearIndices) ,result)))
27722770
,(construct-loops (reverse itrs) (reverse iv))
27732771
,result)))))

test/iterators.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -293,11 +293,14 @@ let (a, b) = (1:3, [4 6;
293293
end
294294

295295
# collect stateful iterator
296-
let
297-
itr = (i+1 for i in Base.Stateful([1,2,3]))
296+
let itr
297+
itr = Iterators.Stateful(Iterators.map(identity, 1:5))
298+
@test collect(itr) == 1:5
299+
@test collect(itr) == Int[] # Stateful do not preserve shape
300+
itr = (i+1 for i in Base.Stateful([1, 2, 3]))
298301
@test collect(itr) == [2, 3, 4]
299-
A = zeros(Int, 0, 0)
300-
itr = (i-1 for i in Base.Stateful(A))
302+
@test collect(itr) == Int[] # Stateful do not preserve shape
303+
itr = (i-1 for i in Base.Stateful(zeros(Int, 0, 0)))
301304
@test collect(itr) == Int[] # Stateful do not preserve shape
302305
end
303306

0 commit comments

Comments
 (0)