Skip to content

Commit d81ac15

Browse files
committed
Improve inference for common reflection operations
Make inference be able to infer the type constructor. This is a little tricky, since we don't really have a good way to represent this. It's not quite `Const(TypeVar(:T,lb,ub))`, because a) the bounds may only be accurate up to type equality and b) TypeVar is not `isbits`, so it's not actually egal to the value we'll have at runtime. Additionally Type{T} already has meaning as a partially constructed type (e.g. an unwrapped UnionAll), so using T::Type{T} runs the risk of confusing this with an unwrapped type. Instead, introduce a new Const-like type, `PartialTypeVar`, which carries the type var and also keeps track of whether the bounds were egal or only typequal (we don't take advantage of that yet, but we could in the future). Additionally, improve the inference of `typename` and allow constant folding of field accesses on TypeName and SimpleVectors (to be able to constant fold T.parameters[1]).
1 parent 9554bec commit d81ac15

File tree

2 files changed

+119
-14
lines changed

2 files changed

+119
-14
lines changed

base/inference.jl

+108-14
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,13 @@ type Conditional
8282
end
8383
end
8484

85+
immutable PartialTypeVar
86+
tv::TypeVar
87+
lb_certain::Bool
88+
ub_certain::Bool
89+
PartialTypeVar(tv::TypeVar, lb_certain::ANY, ub_certain::ANY) = new(tv, lb_certain, ub_certain)
90+
end
91+
8592
function rewrap(t::ANY, u::ANY)
8693
isa(t, Const) && return t
8794
isa(t, Conditional) && return t
@@ -678,6 +685,7 @@ function limit_type_depth(t::ANY, d::Int, cov::Bool=true, var::Union{Void,TypeVa
678685
return (cov && !stillcov) ? UnionAll(var, R) : R
679686
end
680687

688+
const DataType_name_fieldindex = fieldindex(DataType, :name)
681689
const DataType_parameters_fieldindex = fieldindex(DataType, :parameters)
682690
const DataType_types_fieldindex = fieldindex(DataType, :types)
683691
const DataType_super_fieldindex = fieldindex(DataType, :super)
@@ -713,7 +721,8 @@ function getfield_tfunc(s00::ANY, name)
713721
if isa(sv, Module) && isa(nv, Symbol)
714722
return abstract_eval_global(sv, nv)
715723
end
716-
if (isa(sv, DataType) || isimmutable(sv)) && isdefined(sv, nv)
724+
if (isa(sv, DataType) || isa(sv, SimpleVector) || isa(sv, TypeName)
725+
|| isimmutable(sv)) && isdefined(sv, nv)
717726
return abstract_eval_constant(getfield(sv, nv))
718727
end
719728
end
@@ -774,7 +783,8 @@ function getfield_tfunc(s00::ANY, name)
774783
sp = nothing
775784
end
776785
if (sp !== nothing &&
777-
(fld == DataType_parameters_fieldindex ||
786+
(fld == DataType_name_fieldindex ||
787+
fld == DataType_parameters_fieldindex ||
778788
fld == DataType_types_fieldindex ||
779789
fld == DataType_super_fieldindex))
780790
return Const(getfield(sp, fld))
@@ -905,15 +915,19 @@ function apply_type_tfunc(headtypetype::ANY, args::ANY...)
905915
return Any
906916
end
907917
uncertain = false
918+
uncertain_typevar = false
908919
tparams = Any[]
909920
outervars = Any[]
910921
for i = 1:largs
911922
ai = args[i]
912923
if isType(ai)
913924
aip1 = ai.parameters[1]
914925
push!(tparams, aip1)
915-
elseif isa(ai, Const) && (isa(ai.val, Type) || valid_tparam(ai.val))
926+
elseif isa(ai, Const) && (isa(ai.val, Type) || isa(ai.val, TypeVar) || valid_tparam(ai.val))
916927
push!(tparams, ai.val)
928+
elseif isa(ai, PartialTypeVar)
929+
uncertain_typevar = true
930+
push!(tparams, ai.tv)
917931
else
918932
# TODO: return `Bottom` for trying to apply a non-UnionAll
919933
uncertain = true
@@ -956,7 +970,8 @@ function apply_type_tfunc(headtypetype::ANY, args::ANY...)
956970
# doesn't match, which could happen if a type estimate is too coarse
957971
return Type{_} where _<:headtype
958972
end
959-
!uncertain && return Const(appl)
973+
!uncertain && !uncertain_typevar && return Const(appl)
974+
!uncertain && return Type{appl}
960975
if isvarargtype(headtype)
961976
return Type
962977
end
@@ -1476,6 +1491,25 @@ function Pair_name()
14761491
return _Pair_name
14771492
end
14781493

1494+
_typename(a) = Union{}
1495+
_typename(a::Vararg) = Any
1496+
_typename(a::TypeVar) = Any
1497+
_typename(a::DataType) = Const(a.name)
1498+
function _typename(a::Union)
1499+
ta = _typename(a.a)
1500+
tb = _typename(a.b)
1501+
ta == tb ? tb : (ta === Any || tb == Any) ? Any : Union{}
1502+
end
1503+
_typename(union::UnionAll) = typename(union.body)
1504+
function typename_static(t)
1505+
# N.B.: typename maps type equivalence classes to a single value
1506+
if isa(t, Const) || isType(t)
1507+
return _typename(isa(t, Const) ? t.val : t.parameters[1])
1508+
else
1509+
return Any
1510+
end
1511+
end
1512+
14791513
function abstract_call(f::ANY, fargs::Union{Tuple{},Vector{Any}}, argtypes::Vector{Any}, vtypes::VarTable, sv::InferenceState)
14801514
if f === _apply
14811515
length(fargs) > 1 || return Any
@@ -1557,19 +1591,63 @@ function abstract_call(f::ANY, fargs::Union{Tuple{},Vector{Any}}, argtypes::Vect
15571591
end
15581592
end
15591593
return Any
1560-
elseif f === UnionAll
1561-
if length(fargs) == 3 && isa(argtypes[2], Const)
1562-
tv = argtypes[2].val
1563-
if isa(tv, TypeVar)
1594+
elseif f === TypeVar
1595+
lb = Union{}
1596+
ub = Any
1597+
ub_certain = lb_certain = true
1598+
if length(fargs) >= 2 && isa(argtypes[2], Const)
1599+
nv = argtypes[2].val
1600+
ubidx = 3
1601+
if length(fargs) >= 4
1602+
ubidx = 4
15641603
if isa(argtypes[3], Const)
1565-
body = argtypes[3].val
1604+
lb = argtypes[3].val
15661605
elseif isType(argtypes[3])
1567-
body = argtypes[3].parameters[1]
1606+
lb = argtypes[3].parameters[1]
1607+
lb_certain = false
15681608
else
1569-
return Any
1609+
return TypeVar
1610+
end
1611+
end
1612+
if length(fargs) >= ubidx
1613+
if isa(argtypes[ubidx], Const)
1614+
ub = argtypes[ubidx].val
1615+
elseif isType(argtypes[ubidx])
1616+
ub = argtypes[ubidx].parameters[1]
1617+
ub_certain = false
1618+
else
1619+
return TypeVar
15701620
end
1571-
return abstract_eval_constant(UnionAll(tv, body))
15721621
end
1622+
tv = TypeVar(nv, lb, ub)
1623+
return PartialTypeVar(tv, lb_certain, ub_certain)
1624+
end
1625+
return TypeVar
1626+
elseif f === UnionAll
1627+
if length(fargs) == 3
1628+
canconst = true
1629+
if isa(argtypes[3], Const)
1630+
body = argtypes[3].val
1631+
elseif isType(argtypes[3])
1632+
body = argtypes[3].parameters[1]
1633+
canconst = false
1634+
else
1635+
return Any
1636+
end
1637+
if isa(argtypes[2], Const)
1638+
tv = argtypes[2].val
1639+
elseif isa(argtypes[2], PartialTypeVar)
1640+
ptv = argtypes[2]
1641+
tv = ptv.tv
1642+
canconst = false
1643+
else
1644+
return Any
1645+
end
1646+
!isa(tv, TypeVar) && return Any
1647+
(!isa(body, Type) || !isa(body, TypeVar)) && return Any
1648+
theunion = UnionAll(tv, body)
1649+
ret = canconst ? abstract_eval_constant(theunion) : Type{theunion}
1650+
return ret
15731651
end
15741652
return Any
15751653
elseif f === return_type
@@ -1595,7 +1673,16 @@ function abstract_call(f::ANY, fargs::Union{Tuple{},Vector{Any}}, argtypes::Vect
15951673

15961674
if length(argtypes)>2 && argtypes[3] Int
15971675
at2 = widenconst(argtypes[2])
1598-
if (at2 <: Tuple ||
1676+
if at2 <: SimpleVector && istopfunction(tm, f, :getindex)
1677+
if isa(argtypes[2], Const) && isa(argtypes[3], Const)
1678+
svecval = argtypes[2].val
1679+
idx = argtypes[3].val
1680+
if isa(idx, Int) && 1 <= idx <= length(svecval) &
1681+
isassigned(svecval, idx)
1682+
return Const(getindex(svecval, idx))
1683+
end
1684+
end
1685+
elseif (at2 <: Tuple ||
15991686
(isa(at2, DataType) && (at2::DataType).name === Pair_name()))
16001687
# allow tuple indexing functions to take advantage of constant
16011688
# index arguments.
@@ -1617,6 +1704,12 @@ function abstract_call(f::ANY, fargs::Union{Tuple{},Vector{Any}}, argtypes::Vect
16171704

16181705
if istopfunction(tm, f, :promote_type) || istopfunction(tm, f, :typejoin)
16191706
return Type
1707+
elseif length(argtypes) == 2 && istopfunction(tm, f, :typename)
1708+
t = argtypes[2]
1709+
if isa(t, Const) || isType(t)
1710+
return typename_static(t)
1711+
end
1712+
return Any
16201713
end
16211714

16221715
if sv.params.inlining
@@ -1905,6 +1998,7 @@ function widenconst(c::Const)
19051998
return typeof(c.val)
19061999
end
19072000
end
2001+
widenconst(c::PartialTypeVar) = TypeVar
19082002
widenconst(t::ANY) = t
19092003

19102004
issubstate(a::VarState, b::VarState) = (a.typ b.typ && a.undef <= b.undef)
@@ -3552,7 +3646,7 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference
35523646
if method.name == :getindex || method.name == :next || method.name == :indexed_next
35533647
if length(atypes) > 2 && atypes[3] Int
35543648
at2 = widenconst(atypes[2])
3555-
if (at2 <: Tuple ||
3649+
if (at2 <: Tuple || at2 <: SimpleVector ||
35563650
(isa(at2, DataType) && (at2::DataType).name === Pair_name()))
35573651
force_infer = true
35583652
end

test/inference.jl

+11
Original file line numberDiff line numberDiff line change
@@ -588,3 +588,14 @@ f11015(a::AT11015) = g11015(Base.fieldtype(typeof(a), :f), true)
588588
g11015(::Type{Bool}, ::Bool) = 2.0
589589
@test Int <: Base.return_types(f11015, (AT11015,))[1]
590590
@test f11015(AT11015(true)) === 1
591+
592+
# Inference for some type-level computation
593+
fUnionAll{T}(::Type{T}) = Type{S} where S <: T
594+
@inferred fUnionAll(Real) == Type{T} where T <: Real
595+
@inferred fUnionAll(Rational{T} where T <: AbstractFloat) == Type{T} where T<:(Rational{S} where S <: AbstractFloat)
596+
597+
fComplicatedUnionAll{T}(::Type{T}) = Type{Tuple{S,rand() >= 0.5 ? Int : Float64}} where S <: T
598+
let pub = Base.parameter_upper_bound, x = fComplicatedUnionAll(Real)
599+
@test pub(pub(x, 1), 1) == Real
600+
@test pub(pub(x, 1), 2) == Int || pub(pub(x, 1), 2) == Float64
601+
end

0 commit comments

Comments
 (0)