Skip to content

Commit 5a4dd99

Browse files
committed
correctly limit depth and length
remove code to handle exponential blowup, since there isn't any
1 parent 41ea4bf commit 5a4dd99

File tree

4 files changed

+89
-88
lines changed

4 files changed

+89
-88
lines changed

base/inference.jl

Lines changed: 50 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -907,8 +907,8 @@ end
907907
function limit_type_size(@nospecialize(t), @nospecialize(compare), @nospecialize(source), allowed_tuplelen::Int)
908908
source = svec(unwrap_unionall(compare), unwrap_unionall(source))
909909
source[1] === source[2] && (source = svec(source[1]))
910-
type_more_complex(t, compare, source, TUPLE_COMPLEXITY_LIMIT_DEPTH, allowed_tuplelen) || return t
911-
r = _limit_type_size(t, compare, source, allowed_tuplelen)
910+
type_more_complex(t, compare, source, 1, TUPLE_COMPLEXITY_LIMIT_DEPTH, allowed_tuplelen) || return t
911+
r = _limit_type_size(t, compare, source, 1, allowed_tuplelen)
912912
@assert t <: r
913913
#@assert r === _limit_type_size(r, t, source) # this monotonicity constraint is slightly stronger than actually required,
914914
# since we only actually need to demonstrate that repeated application would reaches a fixed point,
@@ -918,7 +918,7 @@ end
918918

919919
sym_isless(a::Symbol, b::Symbol) = ccall(:strcmp, Int32, (Ptr{UInt8}, Ptr{UInt8}), a, b) < 0
920920

921-
function type_more_complex(@nospecialize(t), @nospecialize(c), sources::SimpleVector, tupledepth::Int, allowed_tuplelen::Int)
921+
function type_more_complex(@nospecialize(t), @nospecialize(c), sources::SimpleVector, depth::Int, tupledepth::Int, allowed_tuplelen::Int)
922922
# detect cases where the comparison is trivial
923923
if t === c
924924
return false
@@ -928,7 +928,7 @@ function type_more_complex(@nospecialize(t), @nospecialize(c), sources::SimpleVe
928928
return false # fastpath: unparameterized types are always finite
929929
elseif tupledepth > 0 && isa(unwrap_unionall(t), DataType) && isa(c, Type) && c !== Union{} && c <: t
930930
return false # t is already wider than the comparison in the type lattice
931-
elseif tupledepth > 0 && is_derived_type_from_any(unwrap_unionall(t), sources)
931+
elseif tupledepth > 0 && is_derived_type_from_any(unwrap_unionall(t), sources, depth)
932932
return false # t isn't something new
933933
end
934934
# peel off wrappers
@@ -942,19 +942,20 @@ function type_more_complex(@nospecialize(t), @nospecialize(c), sources::SimpleVe
942942
end
943943
# rules for various comparison types
944944
if isa(c, TypeVar)
945+
tupledepth = 1 # allow replacing a TypeVar with a concrete value (since we know the UnionAll must be in covariant position)
945946
if isa(t, TypeVar)
946947
return !(t.lb === Union{} || t.lb === c.lb) || # simplify lb towards Union{}
947-
type_more_complex(t.ub, c.ub, sources, tupledepth, 0)
948+
type_more_complex(t.ub, c.ub, sources, depth + 1, tupledepth, 0)
948949
end
949950
c.lb === Union{} || return true
950-
return type_more_complex(t, c.ub, sources, max(tupledepth, 1), 0) # allow replacing a TypeVar with a concrete value
951+
return type_more_complex(t, c.ub, sources, depth, tupledepth, 0)
951952
elseif isa(c, Union)
952953
if isa(t, Union)
953-
return type_more_complex(t.a, c.a, sources, tupledepth, allowed_tuplelen) ||
954-
type_more_complex(t.b, c.b, sources, tupledepth, allowed_tuplelen)
954+
return type_more_complex(t.a, c.a, sources, depth, tupledepth, allowed_tuplelen) ||
955+
type_more_complex(t.b, c.b, sources, depth, tupledepth, allowed_tuplelen)
955956
end
956-
return type_more_complex(t, c.a, sources, tupledepth, allowed_tuplelen) &&
957-
type_more_complex(t, c.b, sources, tupledepth, allowed_tuplelen)
957+
return type_more_complex(t, c.a, sources, depth, tupledepth, allowed_tuplelen) &&
958+
type_more_complex(t, c.b, sources, depth, tupledepth, allowed_tuplelen)
958959
elseif isa(t, Int) && isa(c, Int)
959960
return t !== 1 # alternatively, could use !(0 <= t < c)
960961
end
@@ -987,34 +988,41 @@ function type_more_complex(@nospecialize(t), @nospecialize(c), sources::SimpleVe
987988
end
988989
end
989990
end
990-
type_more_complex(tPi, cPi, sources, tupledepth, 0) && return true
991+
type_more_complex(tPi, cPi, sources, depth + 1, tupledepth, 0) && return true
991992
end
992993
return false
993994
elseif isvarargtype(c)
994-
return type_more_complex(t, unwrapva(c), sources, tupledepth, 0)
995+
return type_more_complex(t, unwrapva(c), sources, depth, tupledepth, 0)
995996
end
996997
if isType(t) # allow taking typeof any source type anywhere as Type{...}, as long as it isn't nesting Type{Type{...}}
997998
tt = unwrap_unionall(t.parameters[1])
998999
if isa(tt, DataType) && !isType(tt)
999-
is_derived_type_from_any(tt, sources) || return true
1000+
is_derived_type_from_any(tt, sources, depth) || return true
10001001
return false
10011002
end
10021003
end
10031004
end
10041005
return true
10051006
end
10061007

1007-
function is_derived_type(@nospecialize(t), @nospecialize(c)) # try to find `type` somewhere in `comparison` type
1008-
t === c && return true
1008+
# try to find `type` somewhere in `comparison` type
1009+
# at a minimum nesting depth of `mindepth`
1010+
function is_derived_type(@nospecialize(t), @nospecialize(c), mindepth::Int)
1011+
if mindepth > 0
1012+
mindepth -= 1
1013+
end
1014+
if t === c
1015+
return mindepth == 0
1016+
end
10091017
if isa(c, TypeVar)
10101018
# see if it is replacing a TypeVar upper bound with something simpler
1011-
return is_derived_type(t, c.ub)
1019+
return is_derived_type(t, c.ub, mindepth)
10121020
elseif isa(c, Union)
10131021
# see if it is one of the elements of the union
1014-
return is_derived_type(t, c.a) || is_derived_type(t, c.b)
1022+
return is_derived_type(t, c.a, mindepth + 1) || is_derived_type(t, c.b, mindepth + 1)
10151023
elseif isa(c, UnionAll)
10161024
# see if it is derived from the body
1017-
return is_derived_type(t, c.body)
1025+
return is_derived_type(t, c.body, mindepth)
10181026
elseif isa(c, DataType)
10191027
if isa(t, DataType)
10201028
# see if it is one of the supertypes of a parameter
@@ -1027,7 +1035,7 @@ function is_derived_type(@nospecialize(t), @nospecialize(c)) # try to find `type
10271035
# see if it was extracted from a type parameter
10281036
cP = c.parameters
10291037
for p in cP
1030-
is_derived_type(t, p) && return true
1038+
is_derived_type(t, p, mindepth) && return true
10311039
end
10321040
if isleaftype(c) && isbits(c)
10331041
# see if it was extracted from a fieldtype
@@ -1038,21 +1046,22 @@ function is_derived_type(@nospecialize(t), @nospecialize(c)) # try to find `type
10381046
# it cannot have a reference cycle in the type graph
10391047
cF = c.types
10401048
for f in cF
1041-
is_derived_type(t, f) && return true
1049+
is_derived_type(t, f, mindepth) && return true
10421050
end
10431051
end
10441052
end
10451053
return false
10461054
end
10471055

1048-
function is_derived_type_from_any(@nospecialize(t), sources::SimpleVector)
1056+
function is_derived_type_from_any(@nospecialize(t), sources::SimpleVector, mindepth::Int)
10491057
for s in sources
1050-
is_derived_type(t, s) && return true
1058+
is_derived_type(t, s, mindepth) && return true
10511059
end
10521060
return false
10531061
end
10541062

1055-
function _limit_type_size(@nospecialize(t), @nospecialize(c), sources::SimpleVector, allowed_tuplelen::Int) # type vs. comparison which was derived from source
1063+
# type vs. comparison or which was derived from source
1064+
function _limit_type_size(@nospecialize(t), @nospecialize(c), sources::SimpleVector, depth::Int, allowed_tuplelen::Int)
10561065
if t === c
10571066
return t # quick egal test
10581067
elseif t === Union{}
@@ -1061,7 +1070,7 @@ function _limit_type_size(@nospecialize(t), @nospecialize(c), sources::SimpleVec
10611070
return t # fast path: unparameterized are always simple
10621071
elseif isa(unwrap_unionall(t), DataType) && isa(c, Type) && c !== Union{} && c <: t
10631072
return t # t is already wider than the comparison in the type lattice
1064-
elseif is_derived_type_from_any(unwrap_unionall(t), sources)
1073+
elseif is_derived_type_from_any(unwrap_unionall(t), sources, depth)
10651074
return t # t isn't something new
10661075
end
10671076
if isa(t, TypeVar)
@@ -1072,8 +1081,8 @@ function _limit_type_size(@nospecialize(t), @nospecialize(c), sources::SimpleVec
10721081
end
10731082
elseif isa(t, Union)
10741083
if isa(c, Union)
1075-
a = _limit_type_size(t.a, c.a, sources, allowed_tuplelen)
1076-
b = _limit_type_size(t.b, c.b, sources, allowed_tuplelen)
1084+
a = _limit_type_size(t.a, c.a, sources, depth, allowed_tuplelen)
1085+
b = _limit_type_size(t.b, c.b, sources, depth, allowed_tuplelen)
10771086
return Union{a, b}
10781087
end
10791088
elseif isa(t, UnionAll)
@@ -1082,11 +1091,11 @@ function _limit_type_size(@nospecialize(t), @nospecialize(c), sources::SimpleVec
10821091
cv = c.var
10831092
if tv.ub === cv.ub
10841093
if tv.lb === cv.lb
1085-
return UnionAll(tv, _limit_type_size(t.body, c.body, sources, allowed_tuplelen))
1094+
return UnionAll(tv, _limit_type_size(t.body, c.body, sources, depth + 1, allowed_tuplelen))
10861095
end
10871096
ub = tv.ub
10881097
else
1089-
ub = _limit_type_size(tv.ub, cv.ub, sources, 0)
1098+
ub = _limit_type_size(tv.ub, cv.ub, sources, depth + 1, 0)
10901099
end
10911100
if tv.lb === cv.lb
10921101
lb = tv.lb
@@ -1095,21 +1104,21 @@ function _limit_type_size(@nospecialize(t), @nospecialize(c), sources::SimpleVec
10951104
lb = Bottom
10961105
end
10971106
v2 = TypeVar(tv.name, lb, ub)
1098-
return UnionAll(v2, _limit_type_size(t{v2}, c{v2}, sources, allowed_tuplelen))
1107+
return UnionAll(v2, _limit_type_size(t{v2}, c{v2}, sources, depth + 1, allowed_tuplelen))
10991108
end
1100-
tbody = _limit_type_size(t.body, c, sources, allowed_tuplelen)
1109+
tbody = _limit_type_size(t.body, c, sources, depth + 1, allowed_tuplelen)
11011110
tbody === t.body && return t
11021111
return UnionAll(t.var, tbody)
11031112
elseif isa(c, UnionAll)
11041113
# peel off non-matching wrapper of comparison
1105-
return _limit_type_size(t, c.body, sources, allowed_tuplelen)
1114+
return _limit_type_size(t, c.body, sources, depth, allowed_tuplelen)
11061115
elseif isa(t, DataType)
11071116
if isa(c, DataType)
11081117
tP = t.parameters
11091118
cP = c.parameters
11101119
if t.name === c.name && !isempty(cP)
11111120
if isvarargtype(t)
1112-
VaT = _limit_type_size(tP[1], cP[1], sources, 0)
1121+
VaT = _limit_type_size(tP[1], cP[1], sources, depth + 1, 0)
11131122
N = tP[2]
11141123
if isa(N, TypeVar) || N === cP[2]
11151124
return Vararg{VaT, N}
@@ -1136,19 +1145,19 @@ function _limit_type_size(@nospecialize(t), @nospecialize(c), sources::SimpleVec
11361145
else
11371146
cPi = Any
11381147
end
1139-
Q[i] = _limit_type_size(Q[i], cPi, sources, 0)
1148+
Q[i] = _limit_type_size(Q[i], cPi, sources, depth + 1, 0)
11401149
end
11411150
return Tuple{Q...}
11421151
end
11431152
elseif isvarargtype(c)
11441153
# Tuple{Vararg{T}} --> Tuple{T} is OK
1145-
return _limit_type_size(t, cP[1], sources, 0)
1154+
return _limit_type_size(t, cP[1], sources, depth, 0)
11461155
end
11471156
end
11481157
if isType(t) # allow taking typeof as Type{...}, but ensure it doesn't start nesting
11491158
tt = unwrap_unionall(t.parameters[1])
11501159
if isa(tt, DataType) && !isType(tt)
1151-
is_derived_type_from_any(tt, sources) && return t
1160+
is_derived_type_from_any(tt, sources, depth) && return t
11521161
end
11531162
end
11541163
if isvarargtype(t)
@@ -1864,43 +1873,23 @@ function abstract_call_method(method::Method, @nospecialize(f), @nospecialize(si
18641873
end
18651874

18661875
if limited
1867-
newsig = sig
18681876
sigtuple = unwrap_unionall(sig)::DataType
18691877
msig = unwrap_unionall(method.sig)::DataType
1870-
max_spec_len = length(msig.parameters) + 1
1878+
spec_len = length(msig.parameters) + 1
18711879
ls = length(sigtuple.parameters)
18721880
if method === sv.linfo.def
18731881
# direct self-recursion permits much greater use of reducers
18741882
# without using non-local state (just the total edge)
18751883
# here we assume that complexity(specTypes) :>= complexity(sig)
18761884
comparison = sv.linfo.specTypes
18771885
l_comparison = length(unwrap_unionall(comparison).parameters)
1878-
max_spec_len = max(max_spec_len, l_comparison)
1886+
spec_len = max(spec_len, l_comparison)
18791887
else
18801888
comparison = method.sig
18811889
end
1882-
if method.isva && ls > max_spec_len
1883-
# limit length based on size of definition signature.
1884-
# for example, given function f(T, Any...), limit to 3 arguments
1885-
# instead of the default (MAX_TUPLETYPE_LEN)
1886-
fst = sigtuple.parameters[max_spec_len]
1887-
allsame = true
1888-
# allow specializing on longer arglists if all the trailing
1889-
# arguments are the same, since there is no exponential
1890-
# blowup in this case.
1891-
for i = (max_spec_len + 1):ls
1892-
if sigtuple.parameters[i] != fst
1893-
allsame = false
1894-
break
1895-
end
1896-
end
1897-
if !allsame
1898-
sigtuple = limit_tuple_type_n(sigtuple, max_spec_len)
1899-
newsig = rewrap_unionall(sigtuple, newsig)
1900-
end
1901-
end
1902-
# see if the type is still too big, and limit it further if still required
1903-
newsig = limit_type_size(newsig, comparison, sv.linfo.specTypes, max_spec_len)
1890+
# see if the type is too big, and limit it if required
1891+
newsig = limit_type_size(sig, comparison, sv.linfo.specTypes, spec_len)
1892+
19041893
if newsig !== sig
19051894
if !sv.limited
19061895
# continue inference, but limit parameter complexity to ensure (quick) convergence
@@ -1937,6 +1926,7 @@ function abstract_call_method(method::Method, @nospecialize(f), @nospecialize(si
19371926
end
19381927
sparams = recomputed[2]::SimpleVector
19391928
end
1929+
19401930
rt, edge = typeinf_edge(method, sig, sparams, sv)
19411931
edge !== nothing && add_backedge!(edge::MethodInstance, sv)
19421932
return rt

base/sparse/higherorderfns.jl

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -926,37 +926,34 @@ end
926926
# vectors/matrices in mixedargs in their orginal order, and such that the result of
927927
# broadcast(parevalf, passedargstup...) is broadcast(f, mixedargs...)
928928
@inline function capturescalars(f, mixedargs)
929-
let makeargs = _capturescalars(mixedargs...),
930-
parevalf = (passed...) -> f(makeargs(passed...)...),
931-
passedsrcargstup = _capturenonscalars(mixedargs...)
929+
let (passedsrcargstup, makeargs) = _capturescalars(mixedargs...)
930+
parevalf = (passed...) -> f(makeargs(passed...)...)
932931
return (parevalf, passedsrcargstup)
933932
end
934933
end
935934

936-
@inline _capturenonscalars(nonscalararg::SparseVecOrMat, mixedargs...) =
937-
(nonscalararg, _capturenonscalars(mixedargs...)...)
938-
@inline _capturenonscalars(scalararg, mixedargs...) =
939-
_capturenonscalars(mixedargs...)
940-
@inline _capturenonscalars() = ()
935+
nonscalararg(::SparseVecOrMat) = true
936+
nonscalararg(::Any) = false
941937

942-
@inline _capturescalars(nonscalararg::SparseVecOrMat, mixedargs...) =
943-
let f = _capturescalars(mixedargs...)
944-
(head, tail...) -> (head, f(tail...)...) # pass-through
938+
@inline function _capturescalars()
939+
return (), () -> ()
940+
end
941+
@inline function _capturescalars(arg, mixedargs...)
942+
let (rest, f) = _capturescalars(mixedargs...)
943+
if nonscalararg(arg)
944+
return (arg, rest...), (head, tail...) -> (head, f(tail...)...) # pass-through to broadcast
945+
else
946+
return rest, (tail...) -> (arg, f(tail...)...) # add back scalararg after (in makeargs)
947+
end
945948
end
946-
@inline _capturescalars(scalararg, mixedargs...) =
947-
let f = _capturescalars(mixedargs...)
948-
(tail...) -> (scalararg, f(tail...)...) # add scalararg
949+
end
950+
@inline function _capturescalars(arg) # this definition is just an optimization (to bottom out the recursion slightly sooner)
951+
if nonscalararg(arg)
952+
return (arg,), (head,) -> (head,) # pass-through
953+
else
954+
return (), () -> (arg,) # add scalararg
949955
end
950-
# TODO: use the implicit version once inference can handle it
951-
# handle too-many-arguments explicitly
952-
@inline function _capturescalars()
953-
too_many_arguments() = ()
954-
too_many_arguments(tail...) = throw(ArgumentError("too many"))
955956
end
956-
#@inline _capturescalars(nonscalararg::SparseVecOrMat) =
957-
# (head,) -> (head,) # pass-through
958-
#@inline _capturescalars(scalararg) =
959-
# () -> (scalararg,) # add scalararg
960957

961958
# NOTE: The following two method definitions work around #19096.
962959
broadcast(f::Tf, ::Type{T}, A::SparseMatrixCSC) where {Tf,T} = broadcast(y -> f(T, y), A)

test/core.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3339,10 +3339,6 @@ end
33393339
@test EmptyIIOtherField13175(EmptyImmutable13175(), 1.0) == EmptyIIOtherField13175(EmptyImmutable13175(), 1.0)
33403340
@test EmptyIIOtherField13175(EmptyImmutable13175(), 1.0) != EmptyIIOtherField13175(EmptyImmutable13175(), 2.0)
33413341

3342-
# issue #13183
3343-
gg13183(x::X...) where {X} = 1==0 ? gg13183(x, x) : 0
3344-
@test gg13183(5) == 0
3345-
33463342
# issue 8932 (llvm return type legalizer error)
33473343
struct Vec3_8932
33483344
x::Float32

test/inference.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,20 @@
22

33
# tests for Core.Inference correctness and precision
44
import Core.Inference: Const, Conditional,
5+
const isleaftype = Core.Inference._isleaftype
6+
7+
# demonstrate some of the type-size limits
8+
@test Core.Inference.limit_type_size(Ref{Complex{T} where T}, Ref, Ref, 0) == Ref
9+
@test Core.Inference.limit_type_size(Ref{Complex{T} where T}, Ref{Complex{T} where T}, Ref, 0) == Ref{Complex{T} where T}
10+
let comparison = Tuple{X, X} where X<:Tuple
11+
sig = Tuple{X, X} where X<:comparison
12+
ref = Tuple{X, X} where X
13+
@test Core.Inference.limit_type_size(sig, comparison, comparison, 10) == comparison
14+
@test Core.Inference.limit_type_size(sig, ref, comparison, 10) == comparison
15+
@test Core.Inference.limit_type_size(Tuple{sig}, Tuple{ref}, comparison, 10) == Tuple{comparison}
16+
@test Core.Inference.limit_type_size(sig, ref, Tuple{comparison}, 10) == sig
17+
end
18+
519

620
# issue 9770
721
@noinline x9770() = false
@@ -186,7 +200,6 @@ function find_tvar10930(arg)
186200
end
187201
@test find_tvar10930(Vararg{Int}) === 1
188202

189-
const isleaftype = Base._isleaftype
190203

191204
# issue #12474
192205
@generated function f12474(::Any)
@@ -1225,3 +1238,8 @@ end
12251238
let t = Tuple{Type{T23786{D, N} where N where D<:Tuple{Vararg{Array{T, 1} where T, N} where N}}}
12261239
@test Core.Inference.limit_type_depth(t, 4) >: t
12271240
end
1241+
1242+
# issue #13183
1243+
_false13183 = false
1244+
gg13183(x::X...) where {X} = (_false13183 ? gg13183(x, x) : 0)
1245+
@test gg13183(5) == 0

0 commit comments

Comments
 (0)