Skip to content

Commit 85b2b41

Browse files
nalimilanKristofferC
authored andcommitted
Make return type of map inferrable with heterogeneous arrays (#42046)
Inference is not able to detect the element type automatically, but we can do it manually since we know promote_typejoin is used for widening. This is similar to the approach used for `broadcast` at #30485. (cherry picked from commit 49e3aec)
1 parent eca6c31 commit 85b2b41

File tree

6 files changed

+78
-54
lines changed

6 files changed

+78
-54
lines changed

base/array.jl

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -679,18 +679,19 @@ if isdefined(Core, :Compiler)
679679
I = esc(itr)
680680
return quote
681681
if $I isa Generator && ($I).f isa Type
682-
($I).f
682+
T = ($I).f
683683
else
684-
Core.Compiler.return_type(_iterator_upper_bound, Tuple{typeof($I)})
684+
T = Core.Compiler.return_type(_iterator_upper_bound, Tuple{typeof($I)})
685685
end
686+
promote_typejoin_union(T)
686687
end
687688
end
688689
else
689690
macro default_eltype(itr)
690691
I = esc(itr)
691692
return quote
692693
if $I isa Generator && ($I).f isa Type
693-
($I).f
694+
promote_typejoin_union($I.f)
694695
else
695696
Any
696697
end
@@ -715,8 +716,12 @@ function collect(itr::Generator)
715716
return _array_for(et, itr.iter, isz)
716717
end
717718
v1, st = y
718-
arr = _array_for(typeof(v1), itr.iter, isz, shape)
719-
return collect_to_with_first!(arr, v1, itr, st)
719+
dest = _array_for(typeof(v1), itr.iter, isz, shape)
720+
# The typeassert gives inference a helping hand on the element type and dimensionality
721+
# (work-around for #28382)
722+
et′ = et <: Type ? Type : et
723+
RT = dest isa AbstractArray ? AbstractArray{<:et′, ndims(dest)} : Any
724+
collect_to_with_first!(dest, v1, itr, st)::RT
720725
end
721726
end
722727

base/broadcast.jl

Lines changed: 1 addition & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ Module containing the broadcasting implementation.
88
module Broadcast
99

1010
using .Base.Cartesian
11-
using .Base: Indices, OneTo, tail, to_shape, isoperator, promote_typejoin, @pure,
11+
using .Base: Indices, OneTo, tail, to_shape, isoperator, promote_typejoin, promote_typejoin_union, @pure,
1212
_msk_end, unsafe_bitgetindex, bitcache_chunks, bitcache_size, dumpbitcache, unalias
1313
import .Base: copy, copyto!, axes
1414
export broadcast, broadcast!, BroadcastStyle, broadcast_axes, broadcastable, dotview, @__dot__, BroadcastFunction
@@ -713,50 +713,6 @@ eltypes(t::Tuple{Any}) = Tuple{_broadcast_getindex_eltype(t[1])}
713713
eltypes(t::Tuple{Any,Any}) = Tuple{_broadcast_getindex_eltype(t[1]), _broadcast_getindex_eltype(t[2])}
714714
eltypes(t::Tuple) = Tuple{_broadcast_getindex_eltype(t[1]), eltypes(tail(t)).types...}
715715

716-
function promote_typejoin_union(::Type{T}) where T
717-
if T === Union{}
718-
return Union{}
719-
elseif T isa UnionAll
720-
return Any # TODO: compute more precise bounds
721-
elseif T isa Union
722-
return promote_typejoin(promote_typejoin_union(T.a), promote_typejoin_union(T.b))
723-
elseif T <: Tuple
724-
return typejoin_union_tuple(T)
725-
else
726-
return T
727-
end
728-
end
729-
730-
@pure function typejoin_union_tuple(T::Type)
731-
u = Base.unwrap_unionall(T)
732-
u isa Union && return typejoin(
733-
typejoin_union_tuple(Base.rewrap_unionall(u.a, T)),
734-
typejoin_union_tuple(Base.rewrap_unionall(u.b, T)))
735-
p = (u::DataType).parameters
736-
lr = length(p)::Int
737-
if lr == 0
738-
return Tuple{}
739-
end
740-
c = Vector{Any}(undef, lr)
741-
for i = 1:lr
742-
pi = p[i]
743-
U = Core.Compiler.unwrapva(pi)
744-
if U === Union{}
745-
ci = Union{}
746-
elseif U isa Union
747-
ci = typejoin(U.a, U.b)
748-
else
749-
ci = U
750-
end
751-
if i == lr && Core.Compiler.isvarargtype(pi)
752-
c[i] = isdefined(pi, :N) ? Vararg{ci, pi.N} : Vararg{ci}
753-
else
754-
c[i] = ci
755-
end
756-
end
757-
return Base.rewrap_unionall(Tuple{c...}, T)
758-
end
759-
760716
# Inferred eltype of result of broadcast(f, args...)
761717
combine_eltypes(f, args::Tuple) =
762718
promote_typejoin_union(Base._return_type(f, eltypes(args)))

base/promotion.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,50 @@ function promote_typejoin(@nospecialize(a), @nospecialize(b))
161161
end
162162
_promote_typesubtract(@nospecialize(a)) = typesplit(a, Union{Nothing, Missing})
163163

164+
function promote_typejoin_union(::Type{T}) where T
165+
if T === Union{}
166+
return Union{}
167+
elseif T isa UnionAll
168+
return Any # TODO: compute more precise bounds
169+
elseif T isa Union
170+
return promote_typejoin(promote_typejoin_union(T.a), promote_typejoin_union(T.b))
171+
elseif T <: Tuple
172+
return typejoin_union_tuple(T)
173+
else
174+
return T
175+
end
176+
end
177+
178+
function typejoin_union_tuple(T::Type)
179+
@_pure_meta
180+
u = Base.unwrap_unionall(T)
181+
u isa Union && return typejoin(
182+
typejoin_union_tuple(Base.rewrap_unionall(u.a, T)),
183+
typejoin_union_tuple(Base.rewrap_unionall(u.b, T)))
184+
p = (u::DataType).parameters
185+
lr = length(p)::Int
186+
if lr == 0
187+
return Tuple{}
188+
end
189+
c = Vector{Any}(undef, lr)
190+
for i = 1:lr
191+
pi = p[i]
192+
U = Core.Compiler.unwrapva(pi)
193+
if U === Union{}
194+
ci = Union{}
195+
elseif U isa Union
196+
ci = typejoin(U.a, U.b)
197+
else
198+
ci = U
199+
end
200+
if i == lr && Core.Compiler.isvarargtype(pi)
201+
c[i] = isdefined(pi, :N) ? Vararg{ci, pi.N} : Vararg{ci}
202+
else
203+
c[i] = ci
204+
end
205+
end
206+
return Base.rewrap_unionall(Tuple{c...}, T)
207+
end
164208

165209
# Returns length, isfixed
166210
function full_va_len(p)

test/broadcast.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -991,10 +991,6 @@ end
991991
@test Core.Compiler.return_type(broadcast, Tuple{typeof(+), Vector{Int},
992992
Vector{Union{Float64, Missing}}}) ==
993993
Union{Vector{Missing}, Vector{Union{Missing, Float64}}, Vector{Float64}}
994-
@test isequal([1, 2] + [3.0, missing], [4.0, missing])
995-
@test Core.Compiler.return_type(+, Tuple{Vector{Int},
996-
Vector{Union{Float64, Missing}}}) ==
997-
Union{Vector{Missing}, Vector{Union{Missing, Float64}}, Vector{Float64}}
998994
@test Core.Compiler.return_type(+, Tuple{Vector{Int},
999995
Vector{Union{Float64, Missing}}}) ==
1000996
Union{Vector{Missing}, Vector{Union{Missing, Float64}}, Vector{Float64}}

test/generic_map_tests.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,28 @@ function generic_map_tests(mapf, inplace_mapf=nothing)
5353
@test A == map(x->x*x*x, Float64[1:10...])
5454
@test A === B
5555
end
56+
57+
# Issue #28382: inferrability of map with Union eltype
58+
@test isequal(map(+, [1, 2], [3.0, missing]), [4.0, missing])
59+
@test Core.Compiler.return_type(map, Tuple{typeof(+), Vector{Int},
60+
Vector{Union{Float64, Missing}}}) ==
61+
Union{Vector{Missing}, Vector{Union{Missing, Float64}}, Vector{Float64}}
62+
@test isequal(map(tuple, [1, 2], [3.0, missing]), [(1, 3.0), (2, missing)])
63+
@test Core.Compiler.return_type(map, Tuple{typeof(tuple), Vector{Int},
64+
Vector{Union{Float64, Missing}}}) ==
65+
Vector{<:Tuple{Int, Any}}
66+
# Check that corner cases do not throw an error
67+
@test isequal(map(x -> x === 1 ? nothing : x, [1, 2, missing]),
68+
[nothing, 2, missing])
69+
@test isequal(map(x -> x === 1 ? nothing : x, Any[1, 2, 3.0, missing]),
70+
[nothing, 2, 3, missing])
71+
@test map((x,y)->(x==1 ? 1.0 : x, y), [1, 2, 3], ["a", "b", "c"]) ==
72+
[(1.0, "a"), (2, "b"), (3, "c")]
73+
@test map(typeof, [iszero, isdigit]) == [typeof(iszero), typeof(isdigit)]
74+
@test map(typeof, [iszero, iszero]) == [typeof(iszero), typeof(iszero)]
75+
@test isequal(map(identity, Vector{<:Union{Int, Missing}}[[1, 2],[missing, 1]]),
76+
[[1, 2],[missing, 1]])
77+
@test map(x -> x < 0 ? false : x, Int[]) isa Vector{Integer}
5678
end
5779

5880
function testmap_equivalence(mapf, f, c...)

test/sets.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ using Dates
2222
@test isa(Set(sin(x) for x = 1:3), Set{Float64})
2323
@test isa(Set(f17741(x) for x = 1:3), Set{Int})
2424
@test isa(Set(f17741(x) for x = -1:1), Set{Integer})
25+
@test isa(Set(f17741(x) for x = 1:0), Set{Integer})
2526
end
2627
let s1 = Set(["foo", "bar"]), s2 = Set(s1)
2728
@test s1 == s2

0 commit comments

Comments
 (0)