Skip to content

Commit

Permalink
Improve inference for common reflection operations
Browse files Browse the repository at this point in the history
Make inference be able to infer the type constructor. This is a little
tricky, since we don't really have a good way to represent this. It's
not quite `Const(TypeVar(:T,lb,ub))`, because a) the bounds may only
be accurate up to type equality and b) TypeVar is not `isbits`, so
it's not actually egal to the value we'll have at runtime. Additionally
Type{T} already has meaning as a partially constructed type (e.g. an
unwrapped UnionAll), so using T::Type{T} runs the risk of confusing
this with an unwrapped type. Instead, introduce a new Const-like type,
`PartialTypeVar`, which carries the type var and also keeps track of
whether the bounds were egal or only typequal (we don't take advantage
of that yet, but we could in the future).

Additionally, improve the inference of `typename` and allow constant folding
of field accesses on TypeName and SimpleVectors (to be able to constant
fold T.parameters[1]).
  • Loading branch information
Keno committed Jan 25, 2017
1 parent cf385fe commit 21a751f
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 14 deletions.
120 changes: 106 additions & 14 deletions base/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,13 @@ immutable Const
Const(v::ANY) = new(v)
end

immutable PartialTypeVar
tv::TypeVar
lb_certain::Bool
ub_certain::Bool
PartialTypeVar(tv::TypeVar, lb_certain::ANY, ub_certain::ANY) = new(tv, lb_certain, ub_certain)
end

function rewrap(t::ANY, u::ANY)
isa(t, Const) && return t
rewrap_unionall(t, u)
Expand Down Expand Up @@ -638,6 +645,7 @@ function limit_type_depth(t::ANY, d::Int, cov::Bool=true, var::Union{Void,TypeVa
return (cov && !stillcov) ? UnionAll(var, R) : R
end

const DataType_name_fieldindex = fieldindex(DataType, :name)
const DataType_parameters_fieldindex = fieldindex(DataType, :parameters)
const DataType_types_fieldindex = fieldindex(DataType, :types)
const DataType_super_fieldindex = fieldindex(DataType, :super)
Expand Down Expand Up @@ -671,7 +679,8 @@ function getfield_tfunc(s00::ANY, name)
if isa(sv, Module) && isa(nv, Symbol)
return abstract_eval_global(sv, nv)
end
if (isa(sv, DataType) || isimmutable(sv)) && isdefined(sv, nv)
if (isa(sv, DataType) || isa(sv, SimpleVector) || isa(sv, TypeName)
|| isimmutable(sv)) && isdefined(sv, nv)
return abstract_eval_constant(getfield(sv, nv))
end
end
Expand Down Expand Up @@ -729,7 +738,8 @@ function getfield_tfunc(s00::ANY, name)
sp = nothing
end
if (sp !== nothing &&
(fld == DataType_parameters_fieldindex ||
(fld == DataType_name_fieldindex ||
fld == DataType_parameters_fieldindex ||
fld == DataType_types_fieldindex ||
fld == DataType_super_fieldindex))
return Const(getfield(sp, fld))
Expand Down Expand Up @@ -821,15 +831,19 @@ function apply_type_tfunc(headtypetype::ANY, args::ANY...)
return Any
end
uncertain = false
uncertain_typevar = false
tparams = Any[]
for i = 1:largs
ai = args[i]
if isType(ai)
aip1 = ai.parameters[1]
uncertain |= has_free_typevars(aip1)
push!(tparams, aip1)
elseif isa(ai, Const) && (isa(ai.val, Type) || valid_tparam(ai.val))
elseif isa(ai, Const) && (isa(ai.val, Type) || isa(ai.val, TypeVar) || valid_tparam(ai.val))
push!(tparams, ai.val)
elseif isa(ai, PartialTypeVar)
uncertain_typevar = true
push!(tparams, ai.tv)
else
# TODO: return `Bottom` for trying to apply a non-UnionAll
#if !istuple && i-1 > length(headtype.parameters)
Expand All @@ -855,7 +869,8 @@ function apply_type_tfunc(headtypetype::ANY, args::ANY...)
appl = headtype
uncertain = true
end
!uncertain && return Const(appl)
!uncertain && !uncertain_typevar && return Const(appl)
!uncertain && return Type{appl}
if type_too_complex(appl,0)
return Type{_} where _<:headtype
end
Expand Down Expand Up @@ -1352,6 +1367,25 @@ function Pair_name()
return _Pair_name
end

_typename(a) = Union{}
_typename(a::Vararg) = Any
_typename(a::TypeVar) = Any
_typename(a::DataType) = Const(a.name)
function _typename(a::Union)
ta = _typename(a.a)
tb = _typename(a.b)
ta == tb ? tb : (ta === Any || tb == Any) ? Any : Union{}
end
_typename(union::UnionAll) = typename(union.body)
function typename_static(t)
# N.B.: typename maps type equivalence classes to a single value
if isa(t, Const) || isType(t)
return _typename(isa(t, Const) ? t.val : t.parameters[1])
else
return Any
end
end

function abstract_call(f::ANY, fargs, argtypes::Vector{Any}, vtypes::VarTable, sv::InferenceState)
if f === _apply
length(fargs) > 1 || return Any
Expand Down Expand Up @@ -1386,19 +1420,62 @@ function abstract_call(f::ANY, fargs, argtypes::Vector{Any}, vtypes::VarTable, s
end
end
return Any
elseif f === UnionAll
if length(fargs) == 3 && isa(argtypes[2], Const)
tv = argtypes[2].val
if isa(tv, TypeVar)
elseif f === TypeVar
lb = Union{}
ub = Any
ub_certain = lb_certain = true
if length(fargs) >= 2 && isa(argtypes[2], Const)
nv = argtypes[2].val
ubidx = 3
if length(fargs) >= 4
ubidx = 4
if isa(argtypes[3], Const)
body = argtypes[3].val
lb = argtypes[3].val
elseif isType(argtypes[3])
body = argtypes[3].parameters[1]
lb = argtypes[3].parameters[1]
lb_certain = false
else
return Any
return TypeVar
end
return abstract_eval_constant(UnionAll(tv, body))
end
if length(fargs) >= ubidx
if isa(argtypes[ubidx], Const)
ub = argtypes[ubidx].val
elseif isType(argtypes[ubidx])
ub = argtypes[ubidx].parameters[1]
ub_certain = false
else
return TypeVar
end
end
tv = TypeVar(nv, lb, ub)
return PartialTypeVar(tv, lb_certain, ub_certain)
end
return TypeVar
elseif f === UnionAll
if length(fargs) == 3
canconst = true
if isa(argtypes[3], Const)
body = argtypes[3].val
elseif isType(argtypes[3])
body = argtypes[3].parameters[1]
canconst = false
else
return Any
end
if isa(argtypes[2], Const)
tv = argtypes[2].val
elseif isa(argtypes[2], PartialTypeVar)
ptv = argtypes[2]
tv = ptv.tv
canconst = false
else
return Any
end
!isa(tv, TypeVar) && return Any
ret = canconst ? abstract_eval_constant(UnionAll(tv, body)) :
Type{UnionAll(tv, body)}
return ret
end
return Any
elseif f === return_type
Expand All @@ -1411,7 +1488,15 @@ function abstract_call(f::ANY, fargs, argtypes::Vector{Any}, vtypes::VarTable, s
tm = _topmod(sv)
if length(argtypes)>2 && argtypes[3] Int
at2 = widenconst(argtypes[2])
if (at2 <: Tuple ||
if at2 <: SimpleVector && istopfunction(tm, f, :getindex)
if isa(argtypes[2], Const) && isa(argtypes[3], Const)
svecval = argtypes[2].val
idx = argtypes[3].val
if isa(idx, Int) && 1 <= idx <= length(svecval)
return Const(getindex(svecval, idx))
end
end
elseif (at2 <: Tuple ||
(isa(at2, DataType) && (at2::DataType).name === Pair_name()))
# allow tuple indexing functions to take advantage of constant
# index arguments.
Expand All @@ -1433,6 +1518,12 @@ function abstract_call(f::ANY, fargs, argtypes::Vector{Any}, vtypes::VarTable, s

if istopfunction(tm, f, :promote_type) || istopfunction(tm, f, :typejoin)
return Type
elseif length(argtypes) == 2 && istopfunction(tm, f, :typename)
t = argtypes[2]
if isa(t, Const) || isType(t)
return typename_static(t)
end
return Any
end

if sv.params.inlining
Expand Down Expand Up @@ -1683,6 +1774,7 @@ function ⊑(a::ANY, b::ANY)
end

widenconst(c::Const) = isa(c.val, Type) ? Type{c.val} : typeof(c.val)
widenconst(c::PartialTypeVar) = TypeVar
widenconst(t::ANY) = t

issubstate(a::VarState, b::VarState) = (a.typ b.typ && a.undef <= b.undef)
Expand Down Expand Up @@ -3300,7 +3392,7 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference
if method.name == :getindex || method.name == :next || method.name == :indexed_next
if length(atypes) > 2 && atypes[3] Int
at2 = widenconst(atypes[2])
if (at2 <: Tuple ||
if (at2 <: Tuple || at2 <: SimpleVector ||
(isa(at2, DataType) && (at2::DataType).name === Pair_name()))
force_infer = true
end
Expand Down
9 changes: 9 additions & 0 deletions test/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -491,3 +491,12 @@ tpara18457{I}(::Type{AbstractMyType18457{I}}) = I
tpara18457{A<:AbstractMyType18457}(::Type{A}) = tpara18457(supertype(A))
@test tpara18457(MyType18457{true}) === true

fUnionAll{T}(::Type{T}) = Type{S} where S <: T
@inferred fUnionAll(Real) == Type{T} where T <: Real
@inferred fUnionAll(Rational{T} where T <: AbstractFloat) == Type{T} where T<:(Rational{S} where S <: AbstractFloat)

fComplicatedUnionAll{T}(::Type{T}) = Type{Tuple{S,rand() >= 0.5 ? Int : Float64}} where S <: T
let pub = Base.parameter_upper_bound, x = fComplicatedUnionAll(Real)
@test pub(pub(x, 1), 1) == Real
@test pub(pub(x, 1), 2) == Int || pub(pub(x, 1), 2) == Float64
end

0 comments on commit 21a751f

Please sign in to comment.