Skip to content

Commit 7529a15

Browse files
committed
reflection: move signature union-splitting logic under the control of inference
1 parent ec8e8ee commit 7529a15

File tree

2 files changed

+175
-152
lines changed

2 files changed

+175
-152
lines changed

base/inference.jl

+175-115
Original file line numberDiff line numberDiff line change
@@ -1253,8 +1253,34 @@ end
12531253

12541254
#### recursing into expression ####
12551255

1256+
# take a Tuple where one or more parameters are Unions
1257+
# and return an array such that those Unions are removed
1258+
# and `Union{return...} == ty`
1259+
function switchtupleunion(ty::ANY)
1260+
tparams = (unwrap_unionall(ty)::DataType).parameters
1261+
return _switchtupleunion(Any[tparams...], length(tparams), [], ty)
1262+
end
1263+
1264+
function _switchtupleunion(t::Vector{Any}, i::Int, tunion::Vector{Any}, origt::ANY)
1265+
if i == 0
1266+
tpl = rewrap_unionall(Tuple{t...}, origt)
1267+
push!(tunion, tpl)
1268+
else
1269+
ti = t[i]
1270+
if isa(ti, Union)
1271+
for ty in uniontypes(ti::Union)
1272+
t[i] = ty
1273+
_switchtupleunion(t, i - 1, tunion, origt)
1274+
end
1275+
t[i] = ti
1276+
else
1277+
_switchtupleunion(t, i - 1, tunion, origt)
1278+
end
1279+
end
1280+
return tunion
1281+
end
1282+
12561283
function abstract_call_gf_by_type(f::ANY, atype::ANY, sv::InferenceState)
1257-
tm = _topmod(sv)
12581284
# don't consider more than N methods. this trades off between
12591285
# compiler performance and generated code performance.
12601286
# typically, considering many methods means spending lots of time
@@ -1282,136 +1308,165 @@ function abstract_call_gf_by_type(f::ANY, atype::ANY, sv::InferenceState)
12821308
end
12831309
min_valid = UInt[typemin(UInt)]
12841310
max_valid = UInt[typemax(UInt)]
1285-
applicable = _methods_by_ftype(argtype, sv.params.MAX_METHODS, sv.params.world, min_valid, max_valid)
1286-
rettype = Bottom
1287-
if applicable === false
1288-
# this means too many methods matched
1289-
return Any
1311+
splitunions = 1 < countunionsplit(argtypes) <= sv.params.MAX_UNION_SPLITTING
1312+
if splitunions
1313+
splitsigs = switchtupleunion(argtype)
1314+
applicable = Any[]
1315+
for sig_n in splitsigs
1316+
xapplicable = _methods_by_ftype(sig_n, sv.params.MAX_METHODS, sv.params.world, min_valid, max_valid)
1317+
xapplicable === false && return Any
1318+
append!(applicable, xapplicable)
1319+
end
1320+
else
1321+
applicable = _methods_by_ftype(argtype, sv.params.MAX_METHODS, sv.params.world, min_valid, max_valid)
1322+
if applicable === false
1323+
# this means too many methods matched
1324+
return Any
1325+
end
12901326
end
12911327
applicable = applicable::Array{Any,1}
1328+
napplicable = length(applicable)
12921329
fullmatch = false
1293-
for (m::SimpleVector) in applicable
1294-
sig = m[1]
1295-
sigtuple = unwrap_unionall(sig)::DataType
1296-
method = m[3]::Method
1297-
sparams = m[2]::SimpleVector
1298-
recomputesvec = false
1330+
rettype = Bottom
1331+
for i in 1:napplicable
1332+
match = applicable[i]::SimpleVector
1333+
method = match[3]::Method
12991334
if !fullmatch && (argtype <: method.sig)
13001335
fullmatch = true
13011336
end
1337+
sig = match[1]
1338+
sigtuple = unwrap_unionall(sig)::DataType
1339+
splitunions = false
1340+
# TODO: splitunions = 1 < countunionsplit(sigtuple.parameters) * napplicable <= sv.params.MAX_UNION_SPLITTING
1341+
# currently this triggers a bug in inference recursion detection
1342+
if splitunions
1343+
splitsigs = switchtupleunion(sig)
1344+
for sig_n in splitsigs
1345+
rt = abstract_call_method(method, f, sig_n, svec(), sv)
1346+
rettype = tmerge(rettype, rt)
1347+
rettype === Any && break
1348+
end
1349+
rettype === Any && break
1350+
else
1351+
rt = abstract_call_method(method, f, sig, match[2]::SimpleVector, sv)
1352+
rettype = tmerge(rettype, rt)
1353+
rettype === Any && break
1354+
end
1355+
end
1356+
if !(fullmatch || rettype === Any)
1357+
# also need an edge to the method table in case something gets
1358+
# added that did not intersect with any existing method
1359+
add_mt_backedge(ftname.mt, argtype, sv)
1360+
update_valid_age!(min_valid[1], max_valid[1], sv)
1361+
end
1362+
#print("=> ", rettype, "\n")
1363+
return rettype
1364+
end
13021365

1303-
# limit argument type tuple growth
1304-
msig = unwrap_unionall(method.sig)
1305-
lsig = length(msig.parameters)
1306-
ls = length(sigtuple.parameters)
1307-
td = type_depth(sig)
1308-
mightlimitlength = ls > lsig + 1
1309-
mightlimitdepth = td > 2
1310-
limitlength = false
1311-
if mightlimitlength || mightlimitdepth
1312-
# TODO: FIXME: this heuristic depends on non-local state making type-inference unpredictable
1313-
cyclei = 0
1314-
infstate = sv
1315-
while infstate !== nothing
1316-
infstate = infstate::InferenceState
1317-
if isdefined(infstate.linfo, :def) && method === infstate.linfo.def
1318-
if mightlimitlength && ls > length(unwrap_unionall(infstate.linfo.specTypes).parameters)
1319-
limitlength = true
1320-
end
1321-
if mightlimitdepth && td > type_depth(infstate.linfo.specTypes)
1322-
# impose limit if we recur and the argument types grow beyond MAX_TYPE_DEPTH
1323-
if td > MAX_TYPE_DEPTH
1324-
sig = limit_type_depth(sig, 0)
1325-
sigtuple = unwrap_unionall(sig)
1326-
recomputesvec = true
1327-
break
1328-
else
1329-
p1, p2 = sigtuple.parameters, unwrap_unionall(infstate.linfo.specTypes).parameters
1330-
if length(p2) == ls
1331-
limitdepth = false
1332-
newsig = Vector{Any}(ls)
1333-
for i = 1:ls
1334-
if p1[i] <: Function && type_depth(p1[i]) > type_depth(p2[i]) &&
1335-
isa(p1[i],DataType)
1336-
# if a Function argument is growing (e.g. nested closures)
1337-
# then widen to the outermost function type. without this
1338-
# inference fails to terminate on do_quadgk.
1339-
newsig[i] = p1[i].name.wrapper
1340-
limitdepth = true
1341-
else
1342-
newsig[i] = limit_type_depth(p1[i], 1)
1343-
end
1344-
end
1345-
if limitdepth
1346-
sigtuple = Tuple{newsig...}
1347-
sig = rewrap_unionall(sigtuple, sig)
1348-
recomputesvec = true
1349-
break
1366+
function abstract_call_method(method::Method, f::ANY, sig::ANY, sparams::SimpleVector, sv::InferenceState)
1367+
sigtuple = unwrap_unionall(sig)::DataType
1368+
recomputesvec = false
1369+
1370+
# limit argument type tuple growth
1371+
msig = unwrap_unionall(method.sig)
1372+
lsig = length(msig.parameters)
1373+
ls = length(sigtuple.parameters)
1374+
td = type_depth(sig)
1375+
mightlimitlength = ls > lsig + 1
1376+
mightlimitdepth = td > 2
1377+
limitlength = false
1378+
if mightlimitlength || mightlimitdepth
1379+
# TODO: FIXME: this heuristic depends on non-local state making type-inference unpredictable
1380+
cyclei = 0
1381+
infstate = sv
1382+
while infstate !== nothing
1383+
infstate = infstate::InferenceState
1384+
if isdefined(infstate.linfo, :def) && method === infstate.linfo.def
1385+
if mightlimitlength && ls > length(unwrap_unionall(infstate.linfo.specTypes).parameters)
1386+
limitlength = true
1387+
end
1388+
if mightlimitdepth && td > type_depth(infstate.linfo.specTypes)
1389+
# impose limit if we recur and the argument types grow beyond MAX_TYPE_DEPTH
1390+
if td > MAX_TYPE_DEPTH
1391+
sig = limit_type_depth(sig, 0)
1392+
sigtuple = unwrap_unionall(sig)
1393+
recomputesvec = true
1394+
break
1395+
else
1396+
p1, p2 = sigtuple.parameters, unwrap_unionall(infstate.linfo.specTypes).parameters
1397+
if length(p2) == ls
1398+
limitdepth = false
1399+
newsig = Vector{Any}(ls)
1400+
for i = 1:ls
1401+
if p1[i] <: Function && type_depth(p1[i]) > type_depth(p2[i]) &&
1402+
isa(p1[i],DataType)
1403+
# if a Function argument is growing (e.g. nested closures)
1404+
# then widen to the outermost function type. without this
1405+
# inference fails to terminate on do_quadgk.
1406+
newsig[i] = p1[i].name.wrapper
1407+
limitdepth = true
1408+
else
1409+
newsig[i] = limit_type_depth(p1[i], 1)
13501410
end
13511411
end
1412+
if limitdepth
1413+
sigtuple = Tuple{newsig...}
1414+
sig = rewrap_unionall(sigtuple, sig)
1415+
recomputesvec = true
1416+
break
1417+
end
13521418
end
13531419
end
13541420
end
1355-
# iterate through the cycle before walking to the parent
1356-
if cyclei < length(infstate.callers_in_cycle)
1357-
cyclei += 1
1358-
infstate = infstate.callers_in_cycle[cyclei]
1359-
else
1360-
cyclei = 0
1361-
infstate = infstate.parent
1362-
end
13631421
end
1364-
end
1365-
1366-
# limit length based on size of definition signature.
1367-
# for example, given function f(T, Any...), limit to 3 arguments
1368-
# instead of the default (MAX_TUPLETYPE_LEN)
1369-
if limitlength
1370-
if !istopfunction(tm, f, :promote_typeof)
1371-
fst = sigtuple.parameters[lsig + 1]
1372-
allsame = true
1373-
# allow specializing on longer arglists if all the trailing
1374-
# arguments are the same, since there is no exponential
1375-
# blowup in this case.
1376-
for i = (lsig + 2):ls
1377-
if sigtuple.parameters[i] != fst
1378-
allsame = false
1379-
break
1380-
end
1381-
end
1382-
if !allsame
1383-
sigtuple = limit_tuple_type_n(sigtuple, lsig + 1)
1384-
sig = rewrap_unionall(sigtuple, sig)
1385-
recomputesvec = true
1422+
# iterate through the cycle before walking to the parent
1423+
if cyclei < length(infstate.callers_in_cycle)
1424+
cyclei += 1
1425+
infstate = infstate.callers_in_cycle[cyclei]
1426+
else
1427+
cyclei = 0
1428+
infstate = infstate.parent
1429+
end
1430+
end
1431+
end
1432+
1433+
# limit length based on size of definition signature.
1434+
# for example, given function f(T, Any...), limit to 3 arguments
1435+
# instead of the default (MAX_TUPLETYPE_LEN)
1436+
if limitlength
1437+
tm = _topmod(sv)
1438+
if !istopfunction(tm, f, :promote_typeof)
1439+
fst = sigtuple.parameters[lsig + 1]
1440+
allsame = true
1441+
# allow specializing on longer arglists if all the trailing
1442+
# arguments are the same, since there is no exponential
1443+
# blowup in this case.
1444+
for i = (lsig + 2):ls
1445+
if sigtuple.parameters[i] != fst
1446+
allsame = false
1447+
break
13861448
end
13871449
end
1388-
end
1389-
1390-
# if sig changed, may need to recompute the sparams environment
1391-
if recomputesvec && !isempty(sparams)
1392-
recomputed = ccall(:jl_env_from_type_intersection, Ref{SimpleVector}, (Any, Any), sig, method.sig)
1393-
sig = recomputed[1]
1394-
if !isa(unwrap_unionall(sig), DataType) # probably Union{}
1395-
rettype = Any
1396-
break
1450+
if !allsame
1451+
sigtuple = limit_tuple_type_n(sigtuple, lsig + 1)
1452+
sig = rewrap_unionall(sigtuple, sig)
1453+
recomputesvec = true
13971454
end
1398-
sparams = recomputed[2]::SimpleVector
1399-
end
1400-
rt, edge = typeinf_edge(method, sig, sparams, sv)
1401-
edge !== nothing && add_backedge!(edge::MethodInstance, sv)
1402-
rettype = tmerge(rettype, rt)
1403-
if rettype === Any
1404-
break
14051455
end
14061456
end
1407-
if !(fullmatch || rettype === Any)
1408-
# also need an edge to the method table in case something gets
1409-
# added that did not intersect with any existing method
1410-
add_mt_backedge(ftname.mt, argtype, sv)
1411-
update_valid_age!(min_valid[1], max_valid[1], sv)
1457+
1458+
# if sig changed, may need to recompute the sparams environment
1459+
if isa(method.sig, UnionAll) && (recomputesvec || isempty(sparams))
1460+
recomputed = ccall(:jl_env_from_type_intersection, Ref{SimpleVector}, (Any, Any), sig, method.sig)
1461+
sig = recomputed[1]
1462+
if !isa(unwrap_unionall(sig), DataType) # probably Union{}
1463+
return Any
1464+
end
1465+
sparams = recomputed[2]::SimpleVector
14121466
end
1413-
#print("=> ", rettype, "\n")
1414-
return rettype
1467+
rt, edge = typeinf_edge(method, sig, sparams, sv)
1468+
edge !== nothing && add_backedge!(edge::MethodInstance, sv)
1469+
return rt
14151470
end
14161471

14171472
# determine whether `ex` abstractly evals to constant `c`
@@ -1562,6 +1617,9 @@ function abstract_apply(aft::ANY, fargs::Vector{Any}, aargtypes::Vector{Any}, vt
15621617
return res
15631618
end
15641619

1620+
# TODO: this function is a very buggy and poor model of the return_type function
1621+
# since abstract_call_gf_by_type is a very inaccurate model of _method and of typeinf_type,
1622+
# while this assumes that it is a precisely accurate and exact model of both
15651623
function return_type_tfunc(argtypes::ANY, vtypes::VarTable, sv::InferenceState)
15661624
if length(argtypes) == 3
15671625
tt = argtypes[3]
@@ -2112,8 +2170,10 @@ function issubconditional(a::Conditional, b::Conditional)
21122170
end
21132171

21142172
function (a::ANY, b::ANY)
2115-
a === NF && return true
2116-
b === NF && return false
2173+
(a === NF || b === Any) && return true
2174+
(a === Any || b === NF) && return false
2175+
a === Union{} && return true
2176+
b === Union{} && return false
21172177
if isa(a, Conditional)
21182178
if isa(b, Conditional)
21192179
return issubconditional(a, b)
@@ -3483,7 +3543,7 @@ function is_self_quoting(x::ANY)
34833543
return isa(x,Number) || isa(x,AbstractString) || isa(x,Tuple) || isa(x,Type)
34843544
end
34853545

3486-
function countunionsplit(atypes::Vector{Any})
3546+
function countunionsplit(atypes)
34873547
nu = 1
34883548
for ti in atypes
34893549
if isa(ti, Union)

base/reflection.jl

-37
Original file line numberDiff line numberDiff line change
@@ -507,46 +507,9 @@ function _methods_by_ftype(t::ANY, lim::Int, world::UInt)
507507
return _methods_by_ftype(t, lim, world, UInt[typemin(UInt)], UInt[typemax(UInt)])
508508
end
509509
function _methods_by_ftype(t::ANY, lim::Int, world::UInt, min::Array{UInt,1}, max::Array{UInt,1})
510-
tp = unwrap_unionall(t).parameters::SimpleVector
511-
nu = 1
512-
for ti in tp
513-
if isa(ti, Union)
514-
nu *= unionlen(ti::Union)
515-
end
516-
end
517-
if 1 < nu <= 64
518-
return _methods_by_ftype(Any[tp...], t, length(tp), lim, [], world, min, max)
519-
end
520-
# XXX: the following can return incorrect answers that the above branch would have corrected
521510
return ccall(:jl_matching_methods, Any, (Any, Cint, Cint, UInt, Ptr{UInt}, Ptr{UInt}), t, lim, 0, world, min, max)
522511
end
523512

524-
function _methods_by_ftype(t::Array, origt::ANY, i, lim::Integer, matching::Array{Any,1},
525-
world::UInt, min::Array{UInt,1}, max::Array{UInt,1})
526-
if i == 0
527-
world = typemax(UInt)
528-
new = ccall(:jl_matching_methods, Any, (Any, Cint, Cint, UInt, Ptr{UInt}, Ptr{UInt}),
529-
rewrap_unionall(Tuple{t...}, origt), lim, 0, world, min, max)
530-
new === false && return false
531-
append!(matching, new::Array{Any,1})
532-
else
533-
ti = t[i]
534-
if isa(ti, Union)
535-
for ty in uniontypes(ti::Union)
536-
t[i] = ty
537-
if _methods_by_ftype(t, origt, i - 1, lim, matching, world, min, max) === false
538-
t[i] = ti
539-
return false
540-
end
541-
end
542-
t[i] = ti
543-
else
544-
return _methods_by_ftype(t, origt, i - 1, lim, matching, world, min, max)
545-
end
546-
end
547-
return matching
548-
end
549-
550513
# high-level, more convenient method lookup functions
551514

552515
# type for reflecting and pretty-printing a subset of methods

0 commit comments

Comments
 (0)