Skip to content

Commit 62e8227

Browse files
authored
Merge pull request #22019 from JuliaLang/jn/inferrable-functions
improve inferability of base
2 parents a145a59 + 4287761 commit 62e8227

14 files changed

+260
-296
lines changed

base/abstractarray.jl

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -202,13 +202,11 @@ julia> strides(A)
202202
(1, 3, 12)
203203
```
204204
"""
205-
strides(A::AbstractArray) = _strides((1,), A)
206-
_strides(out::Tuple{Int}, A::AbstractArray{<:Any,0}) = ()
207-
_strides(out::NTuple{N,Int}, A::AbstractArray{<:Any,N}) where {N} = out
208-
function _strides(out::NTuple{M,Int}, A::AbstractArray) where M
209-
@_inline_meta
210-
_strides((out..., out[M]*size(A, M)), A)
211-
end
205+
strides(A::AbstractArray) = size_to_strides(1, size(A)...)
206+
@inline size_to_strides(s, d, sz...) = (s, size_to_strides(s * d, sz...)...)
207+
size_to_strides(s, d) = (s,)
208+
size_to_strides(s) = ()
209+
212210

213211
function isassigned(a::AbstractArray, i::Int...)
214212
try
@@ -1160,30 +1158,32 @@ cat_similar(A::AbstractArray, T, shape) = similar(A, T, shape)
11601158

11611159
cat_shape(dims, shape::Tuple) = shape
11621160
@inline cat_shape(dims, shape::Tuple, nshape::Tuple, shapes::Tuple...) =
1163-
cat_shape(dims, _cshp(dims, (), shape, nshape), shapes...)
1161+
cat_shape(dims, _cshp(1, dims, shape, nshape), shapes...)
11641162

1165-
_cshp(::Tuple{}, out, ::Tuple{}, ::Tuple{}) = out
1166-
_cshp(::Tuple{}, out, ::Tuple{}, nshape) = (out..., nshape...)
1167-
_cshp(dims, out, ::Tuple{}, ::Tuple{}) = (out..., map(b -> 1, dims)...)
1168-
@inline _cshp(dims, out, shape, ::Tuple{}) =
1169-
_cshp(tail(dims), (out..., shape[1] + dims[1]), tail(shape), ())
1170-
@inline _cshp(dims, out, ::Tuple{}, nshape) =
1171-
_cshp(tail(dims), (out..., nshape[1]), (), tail(nshape))
1172-
@inline function _cshp(::Tuple{}, out, shape, ::Tuple{})
1173-
_cs(length(out) + 1, false, shape[1], 1)
1174-
_cshp((), (out..., 1), tail(shape), ())
1163+
_cshp(ndim::Int, ::Tuple{}, ::Tuple{}, ::Tuple{}) = ()
1164+
_cshp(ndim::Int, ::Tuple{}, ::Tuple{}, nshape) = nshape
1165+
_cshp(ndim::Int, dims, ::Tuple{}, ::Tuple{}) = ntuple(b -> 1, Val{length(dims)})
1166+
@inline _cshp(ndim::Int, dims, shape, ::Tuple{}) =
1167+
(shape[1] + dims[1], _cshp(ndim + 1, tail(dims), tail(shape), ())...)
1168+
@inline _cshp(ndim::Int, dims, ::Tuple{}, nshape) =
1169+
(nshape[1], _cshp(ndim + 1, tail(dims), (), tail(nshape))...)
1170+
@inline function _cshp(ndim::Int, ::Tuple{}, shape, ::Tuple{})
1171+
_cs(ndim, shape[1], 1)
1172+
(1, _cshp(ndim + 1, (), tail(shape), ())...)
11751173
end
1176-
@inline function _cshp(::Tuple{}, out, shape, nshape)
1177-
next = _cs(length(out) + 1, false, shape[1], nshape[1])
1178-
_cshp((), (out..., next), tail(shape), tail(nshape))
1174+
@inline function _cshp(ndim::Int, ::Tuple{}, shape, nshape)
1175+
next = _cs(ndim, shape[1], nshape[1])
1176+
(next, _cshp(ndim + 1, (), tail(shape), tail(nshape))...)
11791177
end
1180-
@inline function _cshp(dims, out, shape, nshape)
1181-
next = _cs(length(out) + 1, dims[1], shape[1], nshape[1])
1182-
_cshp(tail(dims), (out..., next), tail(shape), tail(nshape))
1178+
@inline function _cshp(ndim::Int, dims, shape, nshape)
1179+
a = shape[1]
1180+
b = nshape[1]
1181+
next = dims[1] ? a + b : _cs(ndim, a, b)
1182+
(next, _cshp(ndim + 1, tail(dims), tail(shape), tail(nshape))...)
11831183
end
11841184

1185-
_cs(d, concat, a, b) = concat ? (a + b) : (a == b ? a : throw(DimensionMismatch(string(
1186-
"mismatch in dimension ", d, " (expected ", a, " got ", b, ")"))))
1185+
_cs(d, a, b) = (a == b ? a : throw(DimensionMismatch(
1186+
"mismatch in dimension $d (expected $a got $b)")))
11871187

11881188
dims2cat{n}(::Type{Val{n}}) = ntuple(i -> (i == n), Val{n})
11891189
dims2cat(dims) = ntuple(i -> (i in dims), maximum(dims))
@@ -1668,15 +1668,15 @@ end
16681668
function _sub2ind!(Iout, inds, Iinds, I)
16691669
@_noinline_meta
16701670
for i in Iinds
1671-
# Iout[i] = sub2ind(inds, map(Ij->Ij[i], I)...)
1671+
# Iout[i] = sub2ind(inds, map(Ij -> Ij[i], I)...)
16721672
Iout[i] = sub2ind_vec(inds, i, I)
16731673
end
16741674
Iout
16751675
end
16761676

1677-
sub2ind_vec(inds, i, I) = (@_inline_meta; _sub2ind_vec(inds, (), i, I...))
1678-
_sub2ind_vec(inds, out, i, I1, I...) = (@_inline_meta; _sub2ind_vec(inds, (out..., I1[i]), i, I...))
1679-
_sub2ind_vec(inds, out, i) = (@_inline_meta; sub2ind(inds, out...))
1677+
sub2ind_vec(inds, i, I) = (@_inline_meta; sub2ind(inds, _sub2ind_vec(i, I...)...))
1678+
_sub2ind_vec(i, I1, I...) = (@_inline_meta; (I1[i], _sub2ind_vec(i, I...)...))
1679+
_sub2ind_vec(i) = ()
16801680

16811681
function ind2sub(inds::Union{DimsInteger{N},Indices{N}}, ind::AbstractVector{<:Integer}) where N
16821682
M = length(ind)

base/array.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,7 @@ end
7373
size(a::Array, d) = arraysize(a, d)
7474
size(a::Vector) = (arraysize(a,1),)
7575
size(a::Matrix) = (arraysize(a,1), arraysize(a,2))
76-
size(a::Array) = (@_inline_meta; _size((), a))
77-
_size(out::NTuple{N}, A::Array{_,N}) where {_,N} = out
78-
function _size(out::NTuple{M}, A::Array{_,N}) where _ where M where N
79-
@_inline_meta
80-
_size((out..., size(A,M+1)), A)
81-
end
76+
size(a::Array{<:Any,N}) where {N} = (@_inline_meta; ntuple(M -> size(a, M), Val{N}))
8277

8378
asize_from(a::Array, n) = n > ndims(a) ? () : (arraysize(a,n), asize_from(a, n+1)...)
8479

base/broadcast.jl

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -45,23 +45,22 @@ promote_containertype(::Type{T}, ::Type{T}) where {T} = T
4545
## Calculate the broadcast indices of the arguments, or error if incompatible
4646
# array inputs
4747
broadcast_indices() = ()
48-
broadcast_indices(A) = broadcast_indices(containertype(A), A)
49-
broadcast_indices(::ScalarType, A) = ()
50-
broadcast_indices(::Type{Tuple}, A) = (OneTo(length(A)),)
51-
broadcast_indices(::Type{Array}, A::Ref) = ()
52-
broadcast_indices(::Type{Array}, A) = indices(A)
53-
@inline broadcast_indices(A, B...) = broadcast_shape((), broadcast_indices(A), map(broadcast_indices, B)...)
48+
broadcast_indices(A) = _broadcast_indices(containertype(A), A)
49+
@inline broadcast_indices(A, B...) = broadcast_shape(broadcast_indices(A), broadcast_indices(B...))
50+
_broadcast_indices(::Type, A) = ()
51+
_broadcast_indices(::Type{Tuple}, A) = (OneTo(length(A)),)
52+
_broadcast_indices(::Type{Array}, A::Ref) = ()
53+
_broadcast_indices(::Type{Array}, A) = indices(A)
5454

5555
# shape (i.e., tuple-of-indices) inputs
5656
broadcast_shape(shape::Tuple) = shape
57-
@inline broadcast_shape(shape::Tuple, shape1::Tuple, shapes::Tuple...) = broadcast_shape(_bcs((), shape, shape1), shapes...)
57+
@inline broadcast_shape(shape::Tuple, shape1::Tuple, shapes::Tuple...) = broadcast_shape(_bcs(shape, shape1), shapes...)
5858
# _bcs consolidates two shapes into a single output shape
59-
_bcs(out, ::Tuple{}, ::Tuple{}) = out
60-
@inline _bcs(out, ::Tuple{}, newshape) = _bcs((out..., newshape[1]), (), tail(newshape))
61-
@inline _bcs(out, shape, ::Tuple{}) = _bcs((out..., shape[1]), tail(shape), ())
62-
@inline function _bcs(out, shape, newshape)
63-
newout = _bcs1(shape[1], newshape[1])
64-
_bcs((out..., newout), tail(shape), tail(newshape))
59+
_bcs(::Tuple{}, ::Tuple{}) = ()
60+
@inline _bcs(::Tuple{}, newshape::Tuple) = (newshape[1], _bcs((), tail(newshape))...)
61+
@inline _bcs(shape::Tuple, ::Tuple{}) = (shape[1], _bcs(tail(shape), ())...)
62+
@inline function _bcs(shape::Tuple, newshape::Tuple)
63+
return (_bcs1(shape[1], newshape[1]), _bcs(tail(shape), tail(newshape))...)
6564
end
6665
# _bcs1 handles the logic for a single dimension
6766
_bcs1(a::Integer, b::Integer) = a == 1 ? b : (b == 1 ? a : (a == b ? a : throw(DimensionMismatch("arrays could not be broadcast to a common size"))))

base/inference.jl

Lines changed: 71 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ struct InferenceParams
1515
inlining::Bool
1616

1717
# parameters limiting potentially-infinite types (configurable)
18+
MAX_METHODS::Int
1819
MAX_TUPLETYPE_LEN::Int
1920
MAX_TUPLE_DEPTH::Int
2021
MAX_TUPLE_SPLAT::Int
@@ -24,12 +25,13 @@ struct InferenceParams
2425
# reasonable defaults
2526
function InferenceParams(world::UInt;
2627
inlining::Bool = inlining_enabled(),
28+
max_methods::Int = 4,
2729
tupletype_len::Int = 15,
2830
tuple_depth::Int = 4,
2931
tuple_splat::Int = 16,
3032
union_splitting::Int = 4,
3133
apply_union_enum::Int = 8)
32-
return new(world, inlining, tupletype_len,
34+
return new(world, inlining, max_methods, tupletype_len,
3335
tuple_depth, tuple_splat, union_splitting, apply_union_enum)
3436
end
3537
end
@@ -1280,7 +1282,7 @@ function abstract_call_gf_by_type(f::ANY, atype::ANY, sv::InferenceState)
12801282
end
12811283
min_valid = UInt[typemin(UInt)]
12821284
max_valid = UInt[typemax(UInt)]
1283-
applicable = _methods_by_ftype(argtype, 4, sv.params.world, min_valid, max_valid)
1285+
applicable = _methods_by_ftype(argtype, sv.params.MAX_METHODS, sv.params.world, min_valid, max_valid)
12841286
rettype = Bottom
12851287
if applicable === false
12861288
# this means too many methods matched
@@ -1431,7 +1433,7 @@ function precise_container_type(arg::ANY, typ::ANY, vtypes::VarTable, sv::Infere
14311433
if isa(typ, Const)
14321434
val = typ.val
14331435
if isa(val, SimpleVector) || isa(val, Tuple)
1434-
return Any[ abstract_eval_constant(x) for x in val ]
1436+
return Any[ Const(val[i]) for i in 1:length(val) ] # avoid making a tuple Generator here!
14351437
end
14361438
end
14371439

@@ -1499,44 +1501,64 @@ function abstract_iteration(itertype::ANY, vtypes::VarTable, sv::InferenceState)
14991501
return Vararg{valtype}
15001502
end
15011503

1504+
function tuple_tail_elem(init::ANY, ct)
1505+
return Vararg{widenconst(foldl((a, b) -> tmerge(a, unwrapva(b)), init, ct))}
1506+
end
1507+
15021508
# do apply(af, fargs...), where af is a function value
1503-
function abstract_apply(af::ANY, fargs::Vector{Any}, aargtypes::Vector{Any}, vtypes::VarTable, sv::InferenceState)
1509+
function abstract_apply(aft::ANY, fargs::Vector{Any}, aargtypes::Vector{Any}, vtypes::VarTable, sv::InferenceState)
1510+
if !isa(aft, Const) && !isconstType(aft)
1511+
if !(isleaftype(aft) || aft <: Type) || (aft <: Builtin) || (aft <: IntrinsicFunction)
1512+
return Any
1513+
end
1514+
# non-constant function, but type is known
1515+
end
15041516
res = Union{}
15051517
nargs = length(fargs)
15061518
assert(nargs == length(aargtypes))
1507-
splitunions = countunionsplit(aargtypes) <= sv.params.MAX_APPLY_UNION_ENUM
1508-
ctypes = Any[Any[]]
1519+
splitunions = 1 < countunionsplit(aargtypes) <= sv.params.MAX_APPLY_UNION_ENUM
1520+
ctypes = Any[Any[aft]]
15091521
for i = 1:nargs
15101522
if aargtypes[i] === Any
15111523
# bail out completely and infer as f(::Any...)
1512-
# instead could keep what we got so far and just append a Vararg{Any} (by just
1513-
# using the normal logic from below), but that makes the time of the subarray
1514-
# test explode
1515-
ctypes = Any[Any[Vararg{Any}]]
1524+
# instead could infer the precise types for the types up to this point and just append a Vararg{Any}
1525+
# (by just using the normal logic from below), but that makes the time of the subarray test explode
1526+
push!(ctypes[1], Vararg{Any})
15161527
break
15171528
end
1518-
ctypes´ = []
1519-
for ti in (splitunions ? uniontypes(aargtypes[i]) : Any[aargtypes[i]])
1520-
cti = precise_container_type(fargs[i], ti, vtypes, sv)
1521-
for ct in ctypes
1522-
if !isempty(ct) && isvarargtype(ct[end])
1523-
tail = foldl((a,b)->tmerge(a,unwrapva(b)), unwrapva(ct[end]), cti)
1524-
push!(ctypes´, push!(ct[1:end-1], Vararg{widenconst(tail)}))
1525-
else
1526-
push!(ctypes´, append_any(ct, cti))
1529+
end
1530+
if length(ctypes[1]) == 1
1531+
for i = 1:nargs
1532+
ctypes´ = []
1533+
for ti in (splitunions ? uniontypes(aargtypes[i]) : Any[aargtypes[i]])
1534+
cti = precise_container_type(fargs[i], ti, vtypes, sv)
1535+
for ct in ctypes
1536+
if !isempty(ct) && isvarargtype(ct[end])
1537+
tail = tuple_tail_elem(unwrapva(ct[end]), cti)
1538+
push!(ctypes´, push!(ct[1:(end - 1)], tail))
1539+
else
1540+
push!(ctypes´, append_any(ct, cti))
1541+
end
15271542
end
15281543
end
1544+
ctypes = ctypes´
15291545
end
1530-
ctypes = ctypes´
15311546
end
15321547
for ct in ctypes
15331548
if length(ct) > sv.params.MAX_TUPLETYPE_LEN
1534-
tail = foldl((a,b)->tmerge(a,unwrapva(b)), Bottom, ct[sv.params.MAX_TUPLETYPE_LEN:end])
1549+
tail = tuple_tail_elem(Bottom, ct[sv.params.MAX_TUPLETYPE_LEN:end])
15351550
resize!(ct, sv.params.MAX_TUPLETYPE_LEN)
1536-
ct[end] = Vararg{widenconst(tail)}
1551+
ct[end] = tail
1552+
end
1553+
if isa(aft, Const)
1554+
rt = abstract_call(aft.val, (), ct, vtypes, sv)
1555+
elseif isconstType(aft)
1556+
rt = abstract_call(aft.parameters[1], (), ct, vtypes, sv)
1557+
else
1558+
astype = argtypes_to_type(ct)
1559+
rt = abstract_call_gf_by_type(nothing, astype, sv)
15371560
end
1538-
at = append_any(Any[Const(af)], ct)
1539-
res = tmerge(res, abstract_call(af, (), at, vtypes, sv))
1561+
res = tmerge(res, rt)
15401562
if res === Any
15411563
break
15421564
end
@@ -1651,20 +1673,7 @@ typename_static(t::ANY) = isType(t) ? _typename(t.parameters[1]) : Any
16511673
function abstract_call(f::ANY, fargs::Union{Tuple{},Vector{Any}}, argtypes::Vector{Any}, vtypes::VarTable, sv::InferenceState)
16521674
if f === _apply
16531675
length(fargs) > 1 || return Any
1654-
aft = argtypes[2]
1655-
if isa(aft, Const)
1656-
af = aft.val
1657-
else
1658-
if isType(aft) && isleaftype(aft.parameters[1])
1659-
af = aft.parameters[1]
1660-
elseif isleaftype(aft) && isdefined(aft, :instance)
1661-
af = aft.instance
1662-
else
1663-
# TODO jb/functions: take advantage of case where non-constant `af`'s type is known
1664-
return Any
1665-
end
1666-
end
1667-
return abstract_apply(af, fargs[3:end], argtypes[3:end], vtypes, sv)
1676+
return abstract_apply(argtypes[2], fargs[3:end], argtypes[3:end], vtypes, sv)
16681677
end
16691678

16701679
la = length(argtypes)
@@ -2508,12 +2517,14 @@ function typeinf_edge(method::Method, atypes::ANY, sparams::SimpleVector, caller
25082517
frame = resolve_call_cycle!(code, caller)
25092518
if frame === nothing
25102519
code.inInference = true
2511-
frame = InferenceState(code, true, true, caller.params) # always optimize and cache edge targets
2520+
frame = InferenceState(code, #=optimize=#true, #=cached=#true, caller.params) # always optimize and cache edge targets
25122521
if frame === nothing
25132522
code.inInference = false
25142523
return Any, nothing
25152524
end
2516-
frame.parent = caller
2525+
if caller.cached # don't involve uncached functions in cycle resolution
2526+
frame.parent = caller
2527+
end
25172528
typeinf(frame)
25182529
return frame.bestguess, frame.inferred ? frame.linfo : nothing
25192530
end
@@ -2849,6 +2860,7 @@ end
28492860
#### finalize and record the result of running type inference ####
28502861

28512862
function isinlineable(m::Method, src::CodeInfo)
2863+
# compute the cost (size) of inlining this code
28522864
inlineable = false
28532865
cost = 1000
28542866
if m.module === _topmod(m.module)
@@ -2941,7 +2953,25 @@ function optimize(me::InferenceState)
29412953
end
29422954

29432955
# determine and cache inlineability
2944-
if !me.src.inlineable && !force_noinline && isdefined(me.linfo, :def)
2956+
if !force_noinline
2957+
# don't keep ASTs for functions specialized on a Union argument
2958+
# TODO: this helps avoid a type-system bug mis-computing sparams during intersection
2959+
sig = unwrap_unionall(me.linfo.specTypes)
2960+
if isa(sig, DataType) && sig.name === Tuple.name
2961+
for P in sig.parameters
2962+
P = unwrap_unionall(P)
2963+
if isa(P, Union)
2964+
force_noinline = true
2965+
break
2966+
end
2967+
end
2968+
else
2969+
force_noinline = true
2970+
end
2971+
end
2972+
if force_noinline
2973+
me.src.inlineable = false
2974+
elseif !me.src.inlineable && isdefined(me.linfo, :def)
29452975
me.src.inlineable = isinlineable(me.linfo.def, me.src)
29462976
end
29472977
me.src.inferred = true

base/int.jl

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -362,22 +362,24 @@ end
362362
# @doc isn't available when running in Core at this point.
363363
# Tuple syntax for documention two function signatures at the same time
364364
# doesn't work either at this point.
365-
isdefined(Main, :Base) && for fname in (:mod, :rem)
366-
@eval @doc """
367-
rem(x::Integer, T::Type{<:Integer}) -> T
368-
mod(x::Integer, T::Type{<:Integer}) -> T
369-
%(x::Integer, T::Type{<:Integer}) -> T
370-
371-
Find `y::T` such that `x` ≡ `y` (mod n), where n is the number of integers representable
372-
in `T`, and `y` is an integer in `[typemin(T),typemax(T)]`.
373-
If `T` can represent any integer (e.g. `T == BigInt`), then this operation corresponds to
374-
a conversion to `T`.
375-
376-
```jldoctest
377-
julia> 129 % Int8
378-
-127
379-
```
380-
""" -> $fname(x::Integer, T::Type{<:Integer})
365+
if module_name(current_module()) === :Base
366+
for fname in (:mod, :rem)
367+
@eval @doc ("""
368+
rem(x::Integer, T::Type{<:Integer}) -> T
369+
mod(x::Integer, T::Type{<:Integer}) -> T
370+
%(x::Integer, T::Type{<:Integer}) -> T
371+
372+
Find `y::T` such that `x` ≡ `y` (mod n), where n is the number of integers representable
373+
in `T`, and `y` is an integer in `[typemin(T),typemax(T)]`.
374+
If `T` can represent any integer (e.g. `T == BigInt`), then this operation corresponds to
375+
a conversion to `T`.
376+
377+
```jldoctest
378+
julia> 129 % Int8
379+
-127
380+
```
381+
""" -> $fname(x::Integer, T::Type{<:Integer}))
382+
end
381383
end
382384

383385
rem(x::T, ::Type{T}) where {T<:Integer} = x

base/multidimensional.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,10 @@ module IteratorsMD
138138
eachindex(::IndexCartesian, A::AbstractArray) = CartesianRange(indices(A))
139139

140140
@inline eachindex(::IndexCartesian, A::AbstractArray, B::AbstractArray...) =
141-
CartesianRange(maxsize((), A, B...))
142-
maxsize(sz) = sz
143-
@inline maxsize(sz, A, B...) = maxsize(maxt(sz, size(A)), B...)
141+
CartesianRange(maxsize(A, B...))
142+
maxsize() = ()
143+
@inline maxsize(A) = size(A)
144+
@inline maxsize(A, B...) = maxt(size(A), maxsize(B...))
144145
@inline maxt(a::Tuple{}, b::Tuple{}) = ()
145146
@inline maxt(a::Tuple{}, b::Tuple) = b
146147
@inline maxt(a::Tuple, b::Tuple{}) = a

0 commit comments

Comments
 (0)