Skip to content

Commit 10d470d

Browse files
authored
Merge pull request #23912 from JuliaLang/jn/infer-norecur-more
inference: revise recursion detection algorithm
2 parents 546a801 + b89e88e commit 10d470d

File tree

9 files changed

+258
-158
lines changed

9 files changed

+258
-158
lines changed

base/array.jl

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -628,18 +628,28 @@ function _collect_indices(indsA, A)
628628
copy!(B, CartesianRange(indices(B)), A, CartesianRange(indsA))
629629
end
630630

631+
# define this as a macro so that the call to Inference
632+
# gets inlined into the caller before recursion detection
633+
# gets a chance to see it, so that recursive calls to the caller
634+
# don't trigger the inference limiter
631635
if isdefined(Core, :Inference)
632-
_default_eltype(@nospecialize itrt) = Core.Inference.return_type(first, Tuple{itrt})
636+
macro default_eltype(itrt)
637+
return quote
638+
Core.Inference.return_type(first, Tuple{$(esc(itrt))})
639+
end
640+
end
633641
else
634-
_default_eltype(@nospecialize itr) = Any
642+
macro default_eltype(itrt)
643+
return :(Any)
644+
end
635645
end
636646

637647
_array_for(::Type{T}, itr, ::HasLength) where {T} = Array{T,1}(Int(length(itr)::Integer))
638648
_array_for(::Type{T}, itr, ::HasShape) where {T} = similar(Array{T}, indices(itr))
639649

640650
function collect(itr::Generator)
641651
isz = iteratorsize(itr.iter)
642-
et = _default_eltype(typeof(itr))
652+
et = @default_eltype(typeof(itr))
643653
if isa(isz, SizeUnknown)
644654
return grow_to!(Array{et,1}(0), itr)
645655
else
@@ -653,12 +663,12 @@ function collect(itr::Generator)
653663
end
654664

655665
_collect(c, itr, ::EltypeUnknown, isz::SizeUnknown) =
656-
grow_to!(_similar_for(c, _default_eltype(typeof(itr)), itr, isz), itr)
666+
grow_to!(_similar_for(c, @default_eltype(typeof(itr)), itr, isz), itr)
657667

658668
function _collect(c, itr, ::EltypeUnknown, isz::Union{HasLength,HasShape})
659669
st = start(itr)
660670
if done(itr,st)
661-
return _similar_for(c, _default_eltype(typeof(itr)), itr, isz)
671+
return _similar_for(c, @default_eltype(typeof(itr)), itr, isz)
662672
end
663673
v1, st = next(itr, st)
664674
collect_to_with_first!(_similar_for(c, typeof(v1), itr, isz), v1, itr, st)

base/dict.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,9 +158,9 @@ associative_with_eltype(DT_apply, kv, ::TP{K,V}) where {K,V} = DT_apply(K, V)(kv
158158
associative_with_eltype(DT_apply, kv::Generator, ::TP{K,V}) where {K,V} = DT_apply(K, V)(kv)
159159
associative_with_eltype(DT_apply, ::Type{Pair{K,V}}) where {K,V} = DT_apply(K, V)()
160160
associative_with_eltype(DT_apply, ::Type) = DT_apply(Any, Any)()
161-
associative_with_eltype(DT_apply::F, kv, t) where {F} = grow_to!(associative_with_eltype(DT_apply, _default_eltype(typeof(kv))), kv)
161+
associative_with_eltype(DT_apply::F, kv, t) where {F} = grow_to!(associative_with_eltype(DT_apply, @default_eltype(typeof(kv))), kv)
162162
function associative_with_eltype(DT_apply::F, kv::Generator, t) where F
163-
T = _default_eltype(typeof(kv))
163+
T = @default_eltype(typeof(kv))
164164
if T <: Union{Pair, Tuple{Any, Any}} && _isleaftype(T)
165165
return associative_with_eltype(DT_apply, kv, T)
166166
end

base/inference.jl

Lines changed: 174 additions & 108 deletions
Large diffs are not rendered by default.

base/reduce.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ Returns the sum of all elements of `A`, using the Kahan-Babuska-Neumaier compens
376376
summation algorithm for additional accuracy.
377377
"""
378378
function sum_kbn(A)
379-
T = _default_eltype(typeof(A))
379+
T = @default_eltype(typeof(A))
380380
c = r_promote(+, zero(T)::T)
381381
i = start(A)
382382
if done(A, i)

base/set.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ for sets of arbitrary objects.
1717
"""
1818
Set(itr) = Set{eltype(itr)}(itr)
1919
function Set(g::Generator)
20-
T = _default_eltype(typeof(g))
20+
T = @default_eltype(typeof(g))
2121
(_isleaftype(T) || T === Union{}) || return grow_to!(Set{T}(), g)
2222
return Set{T}(g)
2323
end
@@ -258,7 +258,7 @@ julia> unique(Real[1, 1.0, 2])
258258
```
259259
"""
260260
function unique(itr)
261-
T = _default_eltype(typeof(itr))
261+
T = @default_eltype(typeof(itr))
262262
out = Vector{T}()
263263
seen = Set{T}()
264264
i = start(itr)

base/sparse/higherorderfns.jl

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -926,37 +926,34 @@ end
926926
# vectors/matrices in mixedargs in their orginal order, and such that the result of
927927
# broadcast(parevalf, passedargstup...) is broadcast(f, mixedargs...)
928928
@inline function capturescalars(f, mixedargs)
929-
let makeargs = _capturescalars(mixedargs...),
930-
parevalf = (passed...) -> f(makeargs(passed...)...),
931-
passedsrcargstup = _capturenonscalars(mixedargs...)
929+
let (passedsrcargstup, makeargs) = _capturescalars(mixedargs...)
930+
parevalf = (passed...) -> f(makeargs(passed...)...)
932931
return (parevalf, passedsrcargstup)
933932
end
934933
end
935934

936-
@inline _capturenonscalars(nonscalararg::SparseVecOrMat, mixedargs...) =
937-
(nonscalararg, _capturenonscalars(mixedargs...)...)
938-
@inline _capturenonscalars(scalararg, mixedargs...) =
939-
_capturenonscalars(mixedargs...)
940-
@inline _capturenonscalars() = ()
935+
nonscalararg(::SparseVecOrMat) = true
936+
nonscalararg(::Any) = false
941937

942-
@inline _capturescalars(nonscalararg::SparseVecOrMat, mixedargs...) =
943-
let f = _capturescalars(mixedargs...)
944-
(head, tail...) -> (head, f(tail...)...) # pass-through
938+
@inline function _capturescalars()
939+
return (), () -> ()
940+
end
941+
@inline function _capturescalars(arg, mixedargs...)
942+
let (rest, f) = _capturescalars(mixedargs...)
943+
if nonscalararg(arg)
944+
return (arg, rest...), (head, tail...) -> (head, f(tail...)...) # pass-through to broadcast
945+
else
946+
return rest, (tail...) -> (arg, f(tail...)...) # add back scalararg after (in makeargs)
947+
end
945948
end
946-
@inline _capturescalars(scalararg, mixedargs...) =
947-
let f = _capturescalars(mixedargs...)
948-
(tail...) -> (scalararg, f(tail...)...) # add scalararg
949+
end
950+
@inline function _capturescalars(arg) # this definition is just an optimization (to bottom out the recursion slightly sooner)
951+
if nonscalararg(arg)
952+
return (arg,), (head,) -> (head,) # pass-through
953+
else
954+
return (), () -> (arg,) # add scalararg
949955
end
950-
# TODO: use the implicit version once inference can handle it
951-
# handle too-many-arguments explicitly
952-
@inline function _capturescalars()
953-
too_many_arguments() = ()
954-
too_many_arguments(tail...) = throw(ArgumentError("too many"))
955956
end
956-
#@inline _capturescalars(nonscalararg::SparseVecOrMat) =
957-
# (head,) -> (head,) # pass-through
958-
#@inline _capturescalars(scalararg) =
959-
# () -> (scalararg,) # add scalararg
960957

961958
# NOTE: The following two method definitions work around #19096.
962959
broadcast(f::Tf, ::Type{T}, A::SparseMatrixCSC) where {Tf,T} = broadcast(y -> f(T, y), A)

src/rtutils.c

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -538,14 +538,12 @@ static size_t jl_static_show_x_(JL_STREAM *out, jl_value_t *v, jl_datatype_t *vt
538538
else if (vt == jl_method_instance_type) {
539539
jl_method_instance_t *li = (jl_method_instance_t*)v;
540540
if (jl_is_method(li->def.method)) {
541-
jl_method_t *m = li->def.method;
542-
n += jl_static_show_x(out, (jl_value_t*)m->module, depth);
543541
if (li->specTypes) {
544-
n += jl_printf(out, ".");
545-
n += jl_show_svec(out, ((jl_datatype_t*)jl_unwrap_unionall(li->specTypes))->parameters,
546-
jl_symbol_name(m->name), "(", ")");
542+
n += jl_static_show_func_sig(out, li->specTypes);
547543
}
548544
else {
545+
jl_method_t *m = li->def.method;
546+
n += jl_static_show_x(out, (jl_value_t*)m->module, depth);
549547
n += jl_printf(out, ".%s(?)", jl_symbol_name(m->name));
550548
}
551549
}
@@ -949,15 +947,15 @@ JL_DLLEXPORT size_t jl_static_show_func_sig(JL_STREAM *s, jl_value_t *type)
949947
if (ftype == NULL)
950948
return jl_static_show(s, type);
951949
size_t n = 0;
952-
if (jl_nparams(ftype)==0 || ftype == ((jl_datatype_t*)ftype)->name->wrapper) {
950+
if (jl_nparams(ftype) == 0 || ftype == ((jl_datatype_t*)ftype)->name->wrapper) {
953951
n += jl_printf(s, "%s", jl_symbol_name(((jl_datatype_t*)ftype)->name->mt->name));
954952
}
955953
else {
956954
n += jl_printf(s, "(::");
957955
n += jl_static_show(s, ftype);
958956
n += jl_printf(s, ")");
959957
}
960-
// TODO: better way to show method parameters
958+
jl_unionall_t *tvars = (jl_unionall_t*)type;
961959
type = jl_unwrap_unionall(type);
962960
if (!jl_is_datatype(type)) {
963961
n += jl_printf(s, " ");
@@ -984,6 +982,19 @@ JL_DLLEXPORT size_t jl_static_show_func_sig(JL_STREAM *s, jl_value_t *type)
984982
}
985983
}
986984
n += jl_printf(s, ")");
985+
if (jl_is_unionall(tvars)) {
986+
int first = 1;
987+
n += jl_printf(s, " where {");
988+
while (jl_is_unionall(tvars)) {
989+
if (first)
990+
first = 0;
991+
else
992+
n += jl_printf(s, ", ");
993+
n += jl_static_show(s, (jl_value_t*)tvars->var);
994+
tvars = (jl_unionall_t*)tvars->body;
995+
}
996+
n += jl_printf(s, "}");
997+
}
987998
return n;
988999
}
9891000

test/core.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3353,10 +3353,6 @@ end
33533353
@test EmptyIIOtherField13175(EmptyImmutable13175(), 1.0) == EmptyIIOtherField13175(EmptyImmutable13175(), 1.0)
33543354
@test EmptyIIOtherField13175(EmptyImmutable13175(), 1.0) != EmptyIIOtherField13175(EmptyImmutable13175(), 2.0)
33553355

3356-
# issue #13183
3357-
gg13183(x::X...) where {X} = 1==0 ? gg13183(x, x) : 0
3358-
@test gg13183(5) == 0
3359-
33603356
# issue 8932 (llvm return type legalizer error)
33613357
struct Vec3_8932
33623358
x::Float32
@@ -5331,7 +5327,8 @@ module UnionOptimizations
53315327
using Test
53325328

53335329
const boxedunions = [Union{}, Union{String, Void}]
5334-
const unboxedunions = [Union{Int8, Void}, Union{Int8, Float16, Void},
5330+
const unboxedunions = [Union{Int8, Void},
5331+
Union{Int8, Float16, Void},
53355332
Union{Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64, Int128, UInt128},
53365333
Union{Char, Date, Int}]
53375334

@@ -5457,6 +5454,7 @@ t4 = vcat(A23567, t2, t3)
54575454
@test t4[11:15] == A23567
54585455

54595456
for U in unboxedunions
5457+
Base.unionlen(U) > 5 && continue # larger values cause subtyping to crash
54605458
local U
54615459
for N in (1, 2, 3, 4)
54625460
A = Array{U}(ntuple(x->0, N)...)

test/inference.jl

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,20 @@
22

33
# tests for Core.Inference correctness and precision
44
import Core.Inference: Const, Conditional,
5+
const isleaftype = Core.Inference._isleaftype
6+
7+
# demonstrate some of the type-size limits
8+
@test Core.Inference.limit_type_size(Ref{Complex{T} where T}, Ref, Ref, 0) == Ref
9+
@test Core.Inference.limit_type_size(Ref{Complex{T} where T}, Ref{Complex{T} where T}, Ref, 0) == Ref{Complex{T} where T}
10+
let comparison = Tuple{X, X} where X<:Tuple
11+
sig = Tuple{X, X} where X<:comparison
12+
ref = Tuple{X, X} where X
13+
@test Core.Inference.limit_type_size(sig, comparison, comparison, 10) == comparison
14+
@test Core.Inference.limit_type_size(sig, ref, comparison, 10) == comparison
15+
@test Core.Inference.limit_type_size(Tuple{sig}, Tuple{ref}, comparison, 10) == Tuple{comparison}
16+
@test Core.Inference.limit_type_size(sig, ref, Tuple{comparison}, 10) == sig
17+
end
18+
519

620
# issue 9770
721
@noinline x9770() = false
@@ -186,7 +200,6 @@ function find_tvar10930(arg)
186200
end
187201
@test find_tvar10930(Vararg{Int}) === 1
188202

189-
const isleaftype = Base._isleaftype
190203

191204
# issue #12474
192205
@generated function f12474(::Any)
@@ -980,13 +993,13 @@ copy_dims_out(out) = ()
980993
copy_dims_out(out, dim::Int, tail...) = copy_dims_out((out..., dim), tail...)
981994
copy_dims_out(out, dim::Colon, tail...) = copy_dims_out((out..., dim), tail...)
982995
@test Base.return_types(copy_dims_out, (Tuple{}, Vararg{Union{Int,Colon}})) == Any[Tuple{}, Tuple{}, Tuple{}]
983-
@test all(m -> 2 < count_specializations(m) < 15, methods(copy_dims_out))
996+
@test all(m -> 10 < count_specializations(m) < 25, methods(copy_dims_out))
984997

985998
copy_dims_pair(out) = ()
986-
copy_dims_pair(out, dim::Int, tail...) = copy_dims_out(out => dim, tail...)
987-
copy_dims_pair(out, dim::Colon, tail...) = copy_dims_out(out => dim, tail...)
999+
copy_dims_pair(out, dim::Int, tail...) = copy_dims_pair(out => dim, tail...)
1000+
copy_dims_pair(out, dim::Colon, tail...) = copy_dims_pair(out => dim, tail...)
9881001
@test Base.return_types(copy_dims_pair, (Tuple{}, Vararg{Union{Int,Colon}})) == Any[Tuple{}, Tuple{}, Tuple{}]
989-
@test all(m -> 5 < count_specializations(m) < 25, methods(copy_dims_out))
1002+
@test all(m -> 5 < count_specializations(m) < 25, methods(copy_dims_pair))
9901003

9911004
# splatting an ::Any should still allow inference to use types of parameters preceding it
9921005
f22364(::Int, ::Any...) = 0
@@ -1225,3 +1238,8 @@ end
12251238
let t = Tuple{Type{T23786{D, N} where N where D<:Tuple{Vararg{Array{T, 1} where T, N} where N}}}
12261239
@test Core.Inference.limit_type_depth(t, 4) >: t
12271240
end
1241+
1242+
# issue #13183
1243+
_false13183 = false
1244+
gg13183(x::X...) where {X} = (_false13183 ? gg13183(x, x) : 0)
1245+
@test gg13183(5) == 0

0 commit comments

Comments
 (0)