Skip to content

Commit b89e88e

Browse files
committed
correctly limit depth and length
remove code to handle exponential blowup, since there isn't any
1 parent 813525c commit b89e88e

File tree

4 files changed

+89
-88
lines changed

4 files changed

+89
-88
lines changed

base/inference.jl

+50-60
Original file line numberDiff line numberDiff line change
@@ -909,8 +909,8 @@ end
909909
function limit_type_size(@nospecialize(t), @nospecialize(compare), @nospecialize(source), allowed_tuplelen::Int)
910910
source = svec(unwrap_unionall(compare), unwrap_unionall(source))
911911
source[1] === source[2] && (source = svec(source[1]))
912-
type_more_complex(t, compare, source, TUPLE_COMPLEXITY_LIMIT_DEPTH, allowed_tuplelen) || return t
913-
r = _limit_type_size(t, compare, source, allowed_tuplelen)
912+
type_more_complex(t, compare, source, 1, TUPLE_COMPLEXITY_LIMIT_DEPTH, allowed_tuplelen) || return t
913+
r = _limit_type_size(t, compare, source, 1, allowed_tuplelen)
914914
@assert t <: r
915915
#@assert r === _limit_type_size(r, t, source) # this monotonicity constraint is slightly stronger than actually required,
916916
# since we only actually need to demonstrate that repeated application would reaches a fixed point,
@@ -920,7 +920,7 @@ end
920920

921921
sym_isless(a::Symbol, b::Symbol) = ccall(:strcmp, Int32, (Ptr{UInt8}, Ptr{UInt8}), a, b) < 0
922922

923-
function type_more_complex(@nospecialize(t), @nospecialize(c), sources::SimpleVector, tupledepth::Int, allowed_tuplelen::Int)
923+
function type_more_complex(@nospecialize(t), @nospecialize(c), sources::SimpleVector, depth::Int, tupledepth::Int, allowed_tuplelen::Int)
924924
# detect cases where the comparison is trivial
925925
if t === c
926926
return false
@@ -930,7 +930,7 @@ function type_more_complex(@nospecialize(t), @nospecialize(c), sources::SimpleVe
930930
return false # fastpath: unparameterized types are always finite
931931
elseif tupledepth > 0 && isa(unwrap_unionall(t), DataType) && isa(c, Type) && c !== Union{} && c <: t
932932
return false # t is already wider than the comparison in the type lattice
933-
elseif tupledepth > 0 && is_derived_type_from_any(unwrap_unionall(t), sources)
933+
elseif tupledepth > 0 && is_derived_type_from_any(unwrap_unionall(t), sources, depth)
934934
return false # t isn't something new
935935
end
936936
# peel off wrappers
@@ -944,19 +944,20 @@ function type_more_complex(@nospecialize(t), @nospecialize(c), sources::SimpleVe
944944
end
945945
# rules for various comparison types
946946
if isa(c, TypeVar)
947+
tupledepth = 1 # allow replacing a TypeVar with a concrete value (since we know the UnionAll must be in covariant position)
947948
if isa(t, TypeVar)
948949
return !(t.lb === Union{} || t.lb === c.lb) || # simplify lb towards Union{}
949-
type_more_complex(t.ub, c.ub, sources, tupledepth, 0)
950+
type_more_complex(t.ub, c.ub, sources, depth + 1, tupledepth, 0)
950951
end
951952
c.lb === Union{} || return true
952-
return type_more_complex(t, c.ub, sources, max(tupledepth, 1), 0) # allow replacing a TypeVar with a concrete value
953+
return type_more_complex(t, c.ub, sources, depth, tupledepth, 0)
953954
elseif isa(c, Union)
954955
if isa(t, Union)
955-
return type_more_complex(t.a, c.a, sources, tupledepth, allowed_tuplelen) ||
956-
type_more_complex(t.b, c.b, sources, tupledepth, allowed_tuplelen)
956+
return type_more_complex(t.a, c.a, sources, depth, tupledepth, allowed_tuplelen) ||
957+
type_more_complex(t.b, c.b, sources, depth, tupledepth, allowed_tuplelen)
957958
end
958-
return type_more_complex(t, c.a, sources, tupledepth, allowed_tuplelen) &&
959-
type_more_complex(t, c.b, sources, tupledepth, allowed_tuplelen)
959+
return type_more_complex(t, c.a, sources, depth, tupledepth, allowed_tuplelen) &&
960+
type_more_complex(t, c.b, sources, depth, tupledepth, allowed_tuplelen)
960961
elseif isa(t, Int) && isa(c, Int)
961962
return t !== 1 # alternatively, could use !(0 <= t < c)
962963
end
@@ -989,34 +990,41 @@ function type_more_complex(@nospecialize(t), @nospecialize(c), sources::SimpleVe
989990
end
990991
end
991992
end
992-
type_more_complex(tPi, cPi, sources, tupledepth, 0) && return true
993+
type_more_complex(tPi, cPi, sources, depth + 1, tupledepth, 0) && return true
993994
end
994995
return false
995996
elseif isvarargtype(c)
996-
return type_more_complex(t, unwrapva(c), sources, tupledepth, 0)
997+
return type_more_complex(t, unwrapva(c), sources, depth, tupledepth, 0)
997998
end
998999
if isType(t) # allow taking typeof any source type anywhere as Type{...}, as long as it isn't nesting Type{Type{...}}
9991000
tt = unwrap_unionall(t.parameters[1])
10001001
if isa(tt, DataType) && !isType(tt)
1001-
is_derived_type_from_any(tt, sources) || return true
1002+
is_derived_type_from_any(tt, sources, depth) || return true
10021003
return false
10031004
end
10041005
end
10051006
end
10061007
return true
10071008
end
10081009

1009-
function is_derived_type(@nospecialize(t), @nospecialize(c)) # try to find `type` somewhere in `comparison` type
1010-
t === c && return true
1010+
# try to find `type` somewhere in `comparison` type
1011+
# at a minimum nesting depth of `mindepth`
1012+
function is_derived_type(@nospecialize(t), @nospecialize(c), mindepth::Int)
1013+
if mindepth > 0
1014+
mindepth -= 1
1015+
end
1016+
if t === c
1017+
return mindepth == 0
1018+
end
10111019
if isa(c, TypeVar)
10121020
# see if it is replacing a TypeVar upper bound with something simpler
1013-
return is_derived_type(t, c.ub)
1021+
return is_derived_type(t, c.ub, mindepth)
10141022
elseif isa(c, Union)
10151023
# see if it is one of the elements of the union
1016-
return is_derived_type(t, c.a) || is_derived_type(t, c.b)
1024+
return is_derived_type(t, c.a, mindepth + 1) || is_derived_type(t, c.b, mindepth + 1)
10171025
elseif isa(c, UnionAll)
10181026
# see if it is derived from the body
1019-
return is_derived_type(t, c.body)
1027+
return is_derived_type(t, c.body, mindepth)
10201028
elseif isa(c, DataType)
10211029
if isa(t, DataType)
10221030
# see if it is one of the supertypes of a parameter
@@ -1029,7 +1037,7 @@ function is_derived_type(@nospecialize(t), @nospecialize(c)) # try to find `type
10291037
# see if it was extracted from a type parameter
10301038
cP = c.parameters
10311039
for p in cP
1032-
is_derived_type(t, p) && return true
1040+
is_derived_type(t, p, mindepth) && return true
10331041
end
10341042
if isleaftype(c) && isbits(c)
10351043
# see if it was extracted from a fieldtype
@@ -1040,21 +1048,22 @@ function is_derived_type(@nospecialize(t), @nospecialize(c)) # try to find `type
10401048
# it cannot have a reference cycle in the type graph
10411049
cF = c.types
10421050
for f in cF
1043-
is_derived_type(t, f) && return true
1051+
is_derived_type(t, f, mindepth) && return true
10441052
end
10451053
end
10461054
end
10471055
return false
10481056
end
10491057

1050-
function is_derived_type_from_any(@nospecialize(t), sources::SimpleVector)
1058+
function is_derived_type_from_any(@nospecialize(t), sources::SimpleVector, mindepth::Int)
10511059
for s in sources
1052-
is_derived_type(t, s) && return true
1060+
is_derived_type(t, s, mindepth) && return true
10531061
end
10541062
return false
10551063
end
10561064

1057-
function _limit_type_size(@nospecialize(t), @nospecialize(c), sources::SimpleVector, allowed_tuplelen::Int) # type vs. comparison which was derived from source
1065+
# type vs. comparison or which was derived from source
1066+
function _limit_type_size(@nospecialize(t), @nospecialize(c), sources::SimpleVector, depth::Int, allowed_tuplelen::Int)
10581067
if t === c
10591068
return t # quick egal test
10601069
elseif t === Union{}
@@ -1063,7 +1072,7 @@ function _limit_type_size(@nospecialize(t), @nospecialize(c), sources::SimpleVec
10631072
return t # fast path: unparameterized are always simple
10641073
elseif isa(unwrap_unionall(t), DataType) && isa(c, Type) && c !== Union{} && c <: t
10651074
return t # t is already wider than the comparison in the type lattice
1066-
elseif is_derived_type_from_any(unwrap_unionall(t), sources)
1075+
elseif is_derived_type_from_any(unwrap_unionall(t), sources, depth)
10671076
return t # t isn't something new
10681077
end
10691078
if isa(t, TypeVar)
@@ -1074,8 +1083,8 @@ function _limit_type_size(@nospecialize(t), @nospecialize(c), sources::SimpleVec
10741083
end
10751084
elseif isa(t, Union)
10761085
if isa(c, Union)
1077-
a = _limit_type_size(t.a, c.a, sources, allowed_tuplelen)
1078-
b = _limit_type_size(t.b, c.b, sources, allowed_tuplelen)
1086+
a = _limit_type_size(t.a, c.a, sources, depth, allowed_tuplelen)
1087+
b = _limit_type_size(t.b, c.b, sources, depth, allowed_tuplelen)
10791088
return Union{a, b}
10801089
end
10811090
elseif isa(t, UnionAll)
@@ -1084,11 +1093,11 @@ function _limit_type_size(@nospecialize(t), @nospecialize(c), sources::SimpleVec
10841093
cv = c.var
10851094
if tv.ub === cv.ub
10861095
if tv.lb === cv.lb
1087-
return UnionAll(tv, _limit_type_size(t.body, c.body, sources, allowed_tuplelen))
1096+
return UnionAll(tv, _limit_type_size(t.body, c.body, sources, depth + 1, allowed_tuplelen))
10881097
end
10891098
ub = tv.ub
10901099
else
1091-
ub = _limit_type_size(tv.ub, cv.ub, sources, 0)
1100+
ub = _limit_type_size(tv.ub, cv.ub, sources, depth + 1, 0)
10921101
end
10931102
if tv.lb === cv.lb
10941103
lb = tv.lb
@@ -1097,21 +1106,21 @@ function _limit_type_size(@nospecialize(t), @nospecialize(c), sources::SimpleVec
10971106
lb = Bottom
10981107
end
10991108
v2 = TypeVar(tv.name, lb, ub)
1100-
return UnionAll(v2, _limit_type_size(t{v2}, c{v2}, sources, allowed_tuplelen))
1109+
return UnionAll(v2, _limit_type_size(t{v2}, c{v2}, sources, depth + 1, allowed_tuplelen))
11011110
end
1102-
tbody = _limit_type_size(t.body, c, sources, allowed_tuplelen)
1111+
tbody = _limit_type_size(t.body, c, sources, depth + 1, allowed_tuplelen)
11031112
tbody === t.body && return t
11041113
return UnionAll(t.var, tbody)
11051114
elseif isa(c, UnionAll)
11061115
# peel off non-matching wrapper of comparison
1107-
return _limit_type_size(t, c.body, sources, allowed_tuplelen)
1116+
return _limit_type_size(t, c.body, sources, depth, allowed_tuplelen)
11081117
elseif isa(t, DataType)
11091118
if isa(c, DataType)
11101119
tP = t.parameters
11111120
cP = c.parameters
11121121
if t.name === c.name && !isempty(cP)
11131122
if isvarargtype(t)
1114-
VaT = _limit_type_size(tP[1], cP[1], sources, 0)
1123+
VaT = _limit_type_size(tP[1], cP[1], sources, depth + 1, 0)
11151124
N = tP[2]
11161125
if isa(N, TypeVar) || N === cP[2]
11171126
return Vararg{VaT, N}
@@ -1138,19 +1147,19 @@ function _limit_type_size(@nospecialize(t), @nospecialize(c), sources::SimpleVec
11381147
else
11391148
cPi = Any
11401149
end
1141-
Q[i] = _limit_type_size(Q[i], cPi, sources, 0)
1150+
Q[i] = _limit_type_size(Q[i], cPi, sources, depth + 1, 0)
11421151
end
11431152
return Tuple{Q...}
11441153
end
11451154
elseif isvarargtype(c)
11461155
# Tuple{Vararg{T}} --> Tuple{T} is OK
1147-
return _limit_type_size(t, cP[1], sources, 0)
1156+
return _limit_type_size(t, cP[1], sources, depth, 0)
11481157
end
11491158
end
11501159
if isType(t) # allow taking typeof as Type{...}, but ensure it doesn't start nesting
11511160
tt = unwrap_unionall(t.parameters[1])
11521161
if isa(tt, DataType) && !isType(tt)
1153-
is_derived_type_from_any(tt, sources) && return t
1162+
is_derived_type_from_any(tt, sources, depth) && return t
11541163
end
11551164
end
11561165
if isvarargtype(t)
@@ -1866,43 +1875,23 @@ function abstract_call_method(method::Method, @nospecialize(f), @nospecialize(si
18661875
end
18671876

18681877
if limited
1869-
newsig = sig
18701878
sigtuple = unwrap_unionall(sig)::DataType
18711879
msig = unwrap_unionall(method.sig)::DataType
1872-
max_spec_len = length(msig.parameters) + 1
1880+
spec_len = length(msig.parameters) + 1
18731881
ls = length(sigtuple.parameters)
18741882
if method === sv.linfo.def
18751883
# direct self-recursion permits much greater use of reducers
18761884
# without using non-local state (just the total edge)
18771885
# here we assume that complexity(specTypes) :>= complexity(sig)
18781886
comparison = sv.linfo.specTypes
18791887
l_comparison = length(unwrap_unionall(comparison).parameters)
1880-
max_spec_len = max(max_spec_len, l_comparison)
1888+
spec_len = max(spec_len, l_comparison)
18811889
else
18821890
comparison = method.sig
18831891
end
1884-
if method.isva && ls > max_spec_len
1885-
# limit length based on size of definition signature.
1886-
# for example, given function f(T, Any...), limit to 3 arguments
1887-
# instead of the default (MAX_TUPLETYPE_LEN)
1888-
fst = sigtuple.parameters[max_spec_len]
1889-
allsame = true
1890-
# allow specializing on longer arglists if all the trailing
1891-
# arguments are the same, since there is no exponential
1892-
# blowup in this case.
1893-
for i = (max_spec_len + 1):ls
1894-
if sigtuple.parameters[i] != fst
1895-
allsame = false
1896-
break
1897-
end
1898-
end
1899-
if !allsame
1900-
sigtuple = limit_tuple_type_n(sigtuple, max_spec_len)
1901-
newsig = rewrap_unionall(sigtuple, newsig)
1902-
end
1903-
end
1904-
# see if the type is still too big, and limit it further if still required
1905-
newsig = limit_type_size(newsig, comparison, sv.linfo.specTypes, max_spec_len)
1892+
# see if the type is too big, and limit it if required
1893+
newsig = limit_type_size(sig, comparison, sv.linfo.specTypes, spec_len)
1894+
19061895
if newsig !== sig
19071896
if !sv.limited
19081897
# continue inference, but limit parameter complexity to ensure (quick) convergence
@@ -1939,6 +1928,7 @@ function abstract_call_method(method::Method, @nospecialize(f), @nospecialize(si
19391928
end
19401929
sparams = recomputed[2]::SimpleVector
19411930
end
1931+
19421932
rt, edge = typeinf_edge(method, sig, sparams, sv)
19431933
edge !== nothing && add_backedge!(edge::MethodInstance, sv)
19441934
return rt

base/sparse/higherorderfns.jl

+20-23
Original file line numberDiff line numberDiff line change
@@ -926,37 +926,34 @@ end
926926
# vectors/matrices in mixedargs in their orginal order, and such that the result of
927927
# broadcast(parevalf, passedargstup...) is broadcast(f, mixedargs...)
928928
@inline function capturescalars(f, mixedargs)
929-
let makeargs = _capturescalars(mixedargs...),
930-
parevalf = (passed...) -> f(makeargs(passed...)...),
931-
passedsrcargstup = _capturenonscalars(mixedargs...)
929+
let (passedsrcargstup, makeargs) = _capturescalars(mixedargs...)
930+
parevalf = (passed...) -> f(makeargs(passed...)...)
932931
return (parevalf, passedsrcargstup)
933932
end
934933
end
935934

936-
@inline _capturenonscalars(nonscalararg::SparseVecOrMat, mixedargs...) =
937-
(nonscalararg, _capturenonscalars(mixedargs...)...)
938-
@inline _capturenonscalars(scalararg, mixedargs...) =
939-
_capturenonscalars(mixedargs...)
940-
@inline _capturenonscalars() = ()
935+
nonscalararg(::SparseVecOrMat) = true
936+
nonscalararg(::Any) = false
941937

942-
@inline _capturescalars(nonscalararg::SparseVecOrMat, mixedargs...) =
943-
let f = _capturescalars(mixedargs...)
944-
(head, tail...) -> (head, f(tail...)...) # pass-through
938+
@inline function _capturescalars()
939+
return (), () -> ()
940+
end
941+
@inline function _capturescalars(arg, mixedargs...)
942+
let (rest, f) = _capturescalars(mixedargs...)
943+
if nonscalararg(arg)
944+
return (arg, rest...), (head, tail...) -> (head, f(tail...)...) # pass-through to broadcast
945+
else
946+
return rest, (tail...) -> (arg, f(tail...)...) # add back scalararg after (in makeargs)
947+
end
945948
end
946-
@inline _capturescalars(scalararg, mixedargs...) =
947-
let f = _capturescalars(mixedargs...)
948-
(tail...) -> (scalararg, f(tail...)...) # add scalararg
949+
end
950+
@inline function _capturescalars(arg) # this definition is just an optimization (to bottom out the recursion slightly sooner)
951+
if nonscalararg(arg)
952+
return (arg,), (head,) -> (head,) # pass-through
953+
else
954+
return (), () -> (arg,) # add scalararg
949955
end
950-
# TODO: use the implicit version once inference can handle it
951-
# handle too-many-arguments explicitly
952-
@inline function _capturescalars()
953-
too_many_arguments() = ()
954-
too_many_arguments(tail...) = throw(ArgumentError("too many"))
955956
end
956-
#@inline _capturescalars(nonscalararg::SparseVecOrMat) =
957-
# (head,) -> (head,) # pass-through
958-
#@inline _capturescalars(scalararg) =
959-
# () -> (scalararg,) # add scalararg
960957

961958
# NOTE: The following two method definitions work around #19096.
962959
broadcast(f::Tf, ::Type{T}, A::SparseMatrixCSC) where {Tf,T} = broadcast(y -> f(T, y), A)

test/core.jl

-4
Original file line numberDiff line numberDiff line change
@@ -3339,10 +3339,6 @@ end
33393339
@test EmptyIIOtherField13175(EmptyImmutable13175(), 1.0) == EmptyIIOtherField13175(EmptyImmutable13175(), 1.0)
33403340
@test EmptyIIOtherField13175(EmptyImmutable13175(), 1.0) != EmptyIIOtherField13175(EmptyImmutable13175(), 2.0)
33413341

3342-
# issue #13183
3343-
gg13183(x::X...) where {X} = 1==0 ? gg13183(x, x) : 0
3344-
@test gg13183(5) == 0
3345-
33463342
# issue 8932 (llvm return type legalizer error)
33473343
struct Vec3_8932
33483344
x::Float32

test/inference.jl

+19-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,20 @@
22

33
# tests for Core.Inference correctness and precision
44
import Core.Inference: Const, Conditional,
5+
const isleaftype = Core.Inference._isleaftype
6+
7+
# demonstrate some of the type-size limits
8+
@test Core.Inference.limit_type_size(Ref{Complex{T} where T}, Ref, Ref, 0) == Ref
9+
@test Core.Inference.limit_type_size(Ref{Complex{T} where T}, Ref{Complex{T} where T}, Ref, 0) == Ref{Complex{T} where T}
10+
let comparison = Tuple{X, X} where X<:Tuple
11+
sig = Tuple{X, X} where X<:comparison
12+
ref = Tuple{X, X} where X
13+
@test Core.Inference.limit_type_size(sig, comparison, comparison, 10) == comparison
14+
@test Core.Inference.limit_type_size(sig, ref, comparison, 10) == comparison
15+
@test Core.Inference.limit_type_size(Tuple{sig}, Tuple{ref}, comparison, 10) == Tuple{comparison}
16+
@test Core.Inference.limit_type_size(sig, ref, Tuple{comparison}, 10) == sig
17+
end
18+
519

620
# issue 9770
721
@noinline x9770() = false
@@ -186,7 +200,6 @@ function find_tvar10930(arg)
186200
end
187201
@test find_tvar10930(Vararg{Int}) === 1
188202

189-
const isleaftype = Base._isleaftype
190203

191204
# issue #12474
192205
@generated function f12474(::Any)
@@ -1225,3 +1238,8 @@ end
12251238
let t = Tuple{Type{T23786{D, N} where N where D<:Tuple{Vararg{Array{T, 1} where T, N} where N}}}
12261239
@test Core.Inference.limit_type_depth(t, 4) >: t
12271240
end
1241+
1242+
# issue #13183
1243+
_false13183 = false
1244+
gg13183(x::X...) where {X} = (_false13183 ? gg13183(x, x) : 0)
1245+
@test gg13183(5) == 0

0 commit comments

Comments
 (0)