Skip to content

reflection: move signature union-splitting logic under the control of inference #22144

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 31, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
290 changes: 175 additions & 115 deletions base/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1253,8 +1253,34 @@ end

#### recursing into expression ####

# take a Tuple where one or more parameters are Unions
# and return an array such that those Unions are removed
# and `Union{return...} == ty`
function switchtupleunion(ty::ANY)
tparams = (unwrap_unionall(ty)::DataType).parameters
return _switchtupleunion(Any[tparams...], length(tparams), [], ty)
end

function _switchtupleunion(t::Vector{Any}, i::Int, tunion::Vector{Any}, origt::ANY)
if i == 0
tpl = rewrap_unionall(Tuple{t...}, origt)
push!(tunion, tpl)
else
ti = t[i]
if isa(ti, Union)
for ty in uniontypes(ti::Union)
t[i] = ty
_switchtupleunion(t, i - 1, tunion, origt)
end
t[i] = ti
else
_switchtupleunion(t, i - 1, tunion, origt)
end
end
return tunion
end

function abstract_call_gf_by_type(f::ANY, atype::ANY, sv::InferenceState)
tm = _topmod(sv)
# don't consider more than N methods. this trades off between
# compiler performance and generated code performance.
# typically, considering many methods means spending lots of time
Expand Down Expand Up @@ -1282,136 +1308,165 @@ function abstract_call_gf_by_type(f::ANY, atype::ANY, sv::InferenceState)
end
min_valid = UInt[typemin(UInt)]
max_valid = UInt[typemax(UInt)]
applicable = _methods_by_ftype(argtype, sv.params.MAX_METHODS, sv.params.world, min_valid, max_valid)
rettype = Bottom
if applicable === false
# this means too many methods matched
return Any
splitunions = 1 < countunionsplit(argtypes) <= sv.params.MAX_UNION_SPLITTING
if splitunions
splitsigs = switchtupleunion(argtype)
applicable = Any[]
for sig_n in splitsigs
xapplicable = _methods_by_ftype(sig_n, sv.params.MAX_METHODS, sv.params.world, min_valid, max_valid)
xapplicable === false && return Any
append!(applicable, xapplicable)
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it worth to keep track of how many different methods the entries in applicable actually refer to and bail out if that is more than sv.params.MAX_METHODS, too?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

probably, for now I'm just moving the existing logic rather than trying to improve on it much

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense.

else
applicable = _methods_by_ftype(argtype, sv.params.MAX_METHODS, sv.params.world, min_valid, max_valid)
if applicable === false
# this means too many methods matched
return Any
end
end
applicable = applicable::Array{Any,1}
napplicable = length(applicable)
fullmatch = false
for (m::SimpleVector) in applicable
sig = m[1]
sigtuple = unwrap_unionall(sig)::DataType
method = m[3]::Method
sparams = m[2]::SimpleVector
recomputesvec = false
rettype = Bottom
for i in 1:napplicable
match = applicable[i]::SimpleVector
method = match[3]::Method
if !fullmatch && (argtype <: method.sig)
fullmatch = true
end
sig = match[1]
sigtuple = unwrap_unionall(sig)::DataType
splitunions = false
# TODO: splitunions = 1 < countunionsplit(sigtuple.parameters) * napplicable <= sv.params.MAX_UNION_SPLITTING
# currently this triggers a bug in inference recursion detection
if splitunions
splitsigs = switchtupleunion(sig)
for sig_n in splitsigs
rt = abstract_call_method(method, f, sig_n, svec(), sv)
rettype = tmerge(rettype, rt)
rettype === Any && break
end
rettype === Any && break
else
rt = abstract_call_method(method, f, sig, match[2]::SimpleVector, sv)
rettype = tmerge(rettype, rt)
rettype === Any && break
end
end
if !(fullmatch || rettype === Any)
# also need an edge to the method table in case something gets
# added that did not intersect with any existing method
add_mt_backedge(ftname.mt, argtype, sv)
update_valid_age!(min_valid[1], max_valid[1], sv)
end
#print("=> ", rettype, "\n")
return rettype
end

# limit argument type tuple growth
msig = unwrap_unionall(method.sig)
lsig = length(msig.parameters)
ls = length(sigtuple.parameters)
td = type_depth(sig)
mightlimitlength = ls > lsig + 1
mightlimitdepth = td > 2
limitlength = false
if mightlimitlength || mightlimitdepth
# TODO: FIXME: this heuristic depends on non-local state making type-inference unpredictable
cyclei = 0
infstate = sv
while infstate !== nothing
infstate = infstate::InferenceState
if isdefined(infstate.linfo, :def) && method === infstate.linfo.def
if mightlimitlength && ls > length(unwrap_unionall(infstate.linfo.specTypes).parameters)
limitlength = true
end
if mightlimitdepth && td > type_depth(infstate.linfo.specTypes)
# impose limit if we recur and the argument types grow beyond MAX_TYPE_DEPTH
if td > MAX_TYPE_DEPTH
sig = limit_type_depth(sig, 0)
sigtuple = unwrap_unionall(sig)
recomputesvec = true
break
else
p1, p2 = sigtuple.parameters, unwrap_unionall(infstate.linfo.specTypes).parameters
if length(p2) == ls
limitdepth = false
newsig = Vector{Any}(ls)
for i = 1:ls
if p1[i] <: Function && type_depth(p1[i]) > type_depth(p2[i]) &&
isa(p1[i],DataType)
# if a Function argument is growing (e.g. nested closures)
# then widen to the outermost function type. without this
# inference fails to terminate on do_quadgk.
newsig[i] = p1[i].name.wrapper
limitdepth = true
else
newsig[i] = limit_type_depth(p1[i], 1)
end
end
if limitdepth
sigtuple = Tuple{newsig...}
sig = rewrap_unionall(sigtuple, sig)
recomputesvec = true
break
function abstract_call_method(method::Method, f::ANY, sig::ANY, sparams::SimpleVector, sv::InferenceState)
sigtuple = unwrap_unionall(sig)::DataType
recomputesvec = false

# limit argument type tuple growth
msig = unwrap_unionall(method.sig)
lsig = length(msig.parameters)
ls = length(sigtuple.parameters)
td = type_depth(sig)
mightlimitlength = ls > lsig + 1
mightlimitdepth = td > 2
limitlength = false
if mightlimitlength || mightlimitdepth
# TODO: FIXME: this heuristic depends on non-local state making type-inference unpredictable
cyclei = 0
infstate = sv
while infstate !== nothing
infstate = infstate::InferenceState
if isdefined(infstate.linfo, :def) && method === infstate.linfo.def
if mightlimitlength && ls > length(unwrap_unionall(infstate.linfo.specTypes).parameters)
limitlength = true
end
if mightlimitdepth && td > type_depth(infstate.linfo.specTypes)
# impose limit if we recur and the argument types grow beyond MAX_TYPE_DEPTH
if td > MAX_TYPE_DEPTH
sig = limit_type_depth(sig, 0)
sigtuple = unwrap_unionall(sig)
recomputesvec = true
break
else
p1, p2 = sigtuple.parameters, unwrap_unionall(infstate.linfo.specTypes).parameters
if length(p2) == ls
limitdepth = false
newsig = Vector{Any}(ls)
for i = 1:ls
if p1[i] <: Function && type_depth(p1[i]) > type_depth(p2[i]) &&
isa(p1[i],DataType)
# if a Function argument is growing (e.g. nested closures)
# then widen to the outermost function type. without this
# inference fails to terminate on do_quadgk.
newsig[i] = p1[i].name.wrapper
limitdepth = true
else
newsig[i] = limit_type_depth(p1[i], 1)
end
end
if limitdepth
sigtuple = Tuple{newsig...}
sig = rewrap_unionall(sigtuple, sig)
recomputesvec = true
break
end
end
end
end
# iterate through the cycle before walking to the parent
if cyclei < length(infstate.callers_in_cycle)
cyclei += 1
infstate = infstate.callers_in_cycle[cyclei]
else
cyclei = 0
infstate = infstate.parent
end
end
end

# limit length based on size of definition signature.
# for example, given function f(T, Any...), limit to 3 arguments
# instead of the default (MAX_TUPLETYPE_LEN)
if limitlength
if !istopfunction(tm, f, :promote_typeof)
fst = sigtuple.parameters[lsig + 1]
allsame = true
# allow specializing on longer arglists if all the trailing
# arguments are the same, since there is no exponential
# blowup in this case.
for i = (lsig + 2):ls
if sigtuple.parameters[i] != fst
allsame = false
break
end
end
if !allsame
sigtuple = limit_tuple_type_n(sigtuple, lsig + 1)
sig = rewrap_unionall(sigtuple, sig)
recomputesvec = true
# iterate through the cycle before walking to the parent
if cyclei < length(infstate.callers_in_cycle)
cyclei += 1
infstate = infstate.callers_in_cycle[cyclei]
else
cyclei = 0
infstate = infstate.parent
end
end
end

# limit length based on size of definition signature.
# for example, given function f(T, Any...), limit to 3 arguments
# instead of the default (MAX_TUPLETYPE_LEN)
if limitlength
tm = _topmod(sv)
if !istopfunction(tm, f, :promote_typeof)
fst = sigtuple.parameters[lsig + 1]
allsame = true
# allow specializing on longer arglists if all the trailing
# arguments are the same, since there is no exponential
# blowup in this case.
for i = (lsig + 2):ls
if sigtuple.parameters[i] != fst
allsame = false
break
end
end
end

# if sig changed, may need to recompute the sparams environment
if recomputesvec && !isempty(sparams)
recomputed = ccall(:jl_env_from_type_intersection, Ref{SimpleVector}, (Any, Any), sig, method.sig)
sig = recomputed[1]
if !isa(unwrap_unionall(sig), DataType) # probably Union{}
rettype = Any
break
if !allsame
sigtuple = limit_tuple_type_n(sigtuple, lsig + 1)
sig = rewrap_unionall(sigtuple, sig)
recomputesvec = true
end
sparams = recomputed[2]::SimpleVector
end
rt, edge = typeinf_edge(method, sig, sparams, sv)
edge !== nothing && add_backedge!(edge::MethodInstance, sv)
rettype = tmerge(rettype, rt)
if rettype === Any
break
end
end
if !(fullmatch || rettype === Any)
# also need an edge to the method table in case something gets
# added that did not intersect with any existing method
add_mt_backedge(ftname.mt, argtype, sv)
update_valid_age!(min_valid[1], max_valid[1], sv)

# if sig changed, may need to recompute the sparams environment
if isa(method.sig, UnionAll) && (recomputesvec || isempty(sparams))
recomputed = ccall(:jl_env_from_type_intersection, Ref{SimpleVector}, (Any, Any), sig, method.sig)
sig = recomputed[1]
if !isa(unwrap_unionall(sig), DataType) # probably Union{}
return Any
end
sparams = recomputed[2]::SimpleVector
end
#print("=> ", rettype, "\n")
return rettype
rt, edge = typeinf_edge(method, sig, sparams, sv)
edge !== nothing && add_backedge!(edge::MethodInstance, sv)
return rt
end

# determine whether `ex` abstractly evals to constant `c`
Expand Down Expand Up @@ -1562,6 +1617,9 @@ function abstract_apply(aft::ANY, fargs::Vector{Any}, aargtypes::Vector{Any}, vt
return res
end

# TODO: this function is a very buggy and poor model of the return_type function
# since abstract_call_gf_by_type is a very inaccurate model of _method and of typeinf_type,
# while this assumes that it is a precisely accurate and exact model of both
function return_type_tfunc(argtypes::ANY, vtypes::VarTable, sv::InferenceState)
if length(argtypes) == 3
tt = argtypes[3]
Expand Down Expand Up @@ -2112,8 +2170,10 @@ function issubconditional(a::Conditional, b::Conditional)
end

function ⊑(a::ANY, b::ANY)
a === NF && return true
b === NF && return false
(a === NF || b === Any) && return true
(a === Any || b === NF) && return false
a === Union{} && return true
b === Union{} && return false
if isa(a, Conditional)
if isa(b, Conditional)
return issubconditional(a, b)
Expand Down Expand Up @@ -3483,7 +3543,7 @@ function is_self_quoting(x::ANY)
return isa(x,Number) || isa(x,AbstractString) || isa(x,Tuple) || isa(x,Type)
end

function countunionsplit(atypes::Vector{Any})
function countunionsplit(atypes)
nu = 1
for ti in atypes
if isa(ti, Union)
Expand Down
37 changes: 0 additions & 37 deletions base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -507,46 +507,9 @@ function _methods_by_ftype(t::ANY, lim::Int, world::UInt)
return _methods_by_ftype(t, lim, world, UInt[typemin(UInt)], UInt[typemax(UInt)])
end
function _methods_by_ftype(t::ANY, lim::Int, world::UInt, min::Array{UInt,1}, max::Array{UInt,1})
tp = unwrap_unionall(t).parameters::SimpleVector
nu = 1
for ti in tp
if isa(ti, Union)
nu *= unionlen(ti::Union)
end
end
if 1 < nu <= 64
return _methods_by_ftype(Any[tp...], t, length(tp), lim, [], world, min, max)
end
# XXX: the following can return incorrect answers that the above branch would have corrected
return ccall(:jl_matching_methods, Any, (Any, Cint, Cint, UInt, Ptr{UInt}, Ptr{UInt}), t, lim, 0, world, min, max)
end

function _methods_by_ftype(t::Array, origt::ANY, i, lim::Integer, matching::Array{Any,1},
world::UInt, min::Array{UInt,1}, max::Array{UInt,1})
if i == 0
world = typemax(UInt)
new = ccall(:jl_matching_methods, Any, (Any, Cint, Cint, UInt, Ptr{UInt}, Ptr{UInt}),
rewrap_unionall(Tuple{t...}, origt), lim, 0, world, min, max)
new === false && return false
append!(matching, new::Array{Any,1})
else
ti = t[i]
if isa(ti, Union)
for ty in uniontypes(ti::Union)
t[i] = ty
if _methods_by_ftype(t, origt, i - 1, lim, matching, world, min, max) === false
t[i] = ti
return false
end
end
t[i] = ti
else
return _methods_by_ftype(t, origt, i - 1, lim, matching, world, min, max)
end
end
return matching
end

# high-level, more convenient method lookup functions

# type for reflecting and pretty-printing a subset of methods
Expand Down
Loading