Skip to content

Commit 4c4e504

Browse files
committed
Add optional hooks to inference parameters.
This allows controling inference decisions, for now only method calling.
1 parent 6f45396 commit 4c4e504

File tree

2 files changed

+111
-2
lines changed

2 files changed

+111
-2
lines changed

base/inference.jl

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,13 @@ import Core: _apply, svec, apply_type, Builtin, IntrinsicFunction, MethodInstanc
66
const MAX_TYPEUNION_LEN = 3
77
const MAX_TYPE_DEPTH = 7
88

9+
immutable InferenceHooks
10+
call
11+
12+
InferenceHooks(call) = new(call)
13+
InferenceHooks() = new(nothing)
14+
end
15+
916
immutable InferenceParams
1017
# optimization
1118
optimize::Bool
@@ -18,13 +25,16 @@ immutable InferenceParams
1825
MAX_TUPLE_SPLAT::Int
1926
MAX_UNION_SPLITTING::Int
2027

28+
hooks::InferenceHooks
29+
2130
# default values will be used for regular compilation, as the compiler calls typeinf_ext
2231
# without specifying, or allowing to override, the inference parameters
2332
InferenceParams(;optimize::Bool=true, inlining::Bool=inlining_enabled(), cached::Bool=true,
2433
tupletype_len::Int=15, tuple_depth::Int=16,
25-
tuple_splat::Int=4, union_splitting::Int=4) =
34+
tuple_splat::Int=4, union_splitting::Int=4,
35+
hooks::InferenceHooks=InferenceHooks()) =
2636
new(optimize, inlining, cached, tupletype_len,
27-
tuple_depth, tuple_splat, union_splitting)
37+
tuple_depth, tuple_splat, union_splitting, hooks)
2838

2939
# copy constructor for selectively overriding certain params
3040
InferenceParams(params::InferenceParams; kwargs...) =
@@ -33,6 +43,7 @@ immutable InferenceParams
3343
tuple_depth=params.MAX_TUPLE_DEPTH,
3444
tuple_splat=params.MAX_TUPLE_SPLAT,
3545
union_splitting=params.MAX_UNION_SPLITTING,
46+
hooks=params.hooks,
3647
kwargs...)
3748
end
3849

@@ -1092,6 +1103,15 @@ end
10921103
argtypes_to_type(argtypes::Array{Any,1}) = Tuple{map(widenconst, argtypes)...}
10931104

10941105
function abstract_call(f::ANY, fargs, argtypes::Vector{Any}, vtypes::VarTable, sv::InferenceState)
1106+
if sv.params.hooks.call != nothing
1107+
hack = sv.params.hooks.call(f, argtypes_to_type(argtypes))
1108+
if hack != nothing
1109+
println("NOTICE: overriding abstract_call of ", f, " with ", hack)
1110+
f = hack
1111+
argtypes[1] = typeof(f)
1112+
end
1113+
end
1114+
10951115
if is(f,_apply)
10961116
length(fargs)>1 || return Any
10971117
aft = argtypes[2]
@@ -3100,6 +3120,19 @@ function inlining_pass(e::Expr, sv::InferenceState)
31003120
end
31013121
end
31023122

3123+
if sv.params.hooks.call != nothing
3124+
argtypes = Vector{Any}(length(e.args))
3125+
argtypes[1] = ft
3126+
argtypes[2:end] = map(a->exprtype(a, sv.src, sv.mod), e.args[2:end])
3127+
3128+
hack = sv.params.hooks.call(f, argtypes_to_type(argtypes))
3129+
if hack != nothing
3130+
println("NOTICE: overriding inlining_pass of ", f, " with ", hack)
3131+
f = hack
3132+
ft = typeof(hack)
3133+
end
3134+
end
3135+
31033136
if sv.params.inlining
31043137
if isdefined(Main, :Base) &&
31053138
((isdefined(Main.Base, :^) && is(f, Main.Base.:^)) ||

demo.jl

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
using LLVM
2+
3+
#
4+
# Functions
5+
#
6+
7+
@inline function child(x)
8+
x+1
9+
end
10+
11+
@inline function hacked_child(x)
12+
x+2
13+
end
14+
15+
function parent(x)
16+
return child(x)
17+
end
18+
19+
20+
#
21+
# Inference
22+
#
23+
24+
f = parent
25+
t = Tuple{Int}
26+
tt = Base.to_tuple_type(t)
27+
28+
ms = Base._methods(f, tt, -1)
29+
@assert length(ms) == 1
30+
(sig, spvals, m) = first(ms)
31+
32+
# given a function and the argument tuple type (incl. the function type)
33+
# return a tuple of the replacement function and its type, or nothing
34+
function call_hook(f, tt)
35+
if f == child
36+
return hacked_child
37+
end
38+
return nothing
39+
end
40+
# alternatively, call_hook(f::typeof(child), tt) = return hacked_child
41+
hooks = Core.Inference.InferenceHooks(call_hook)
42+
43+
# raise limits on inference parameters, performing a more exhaustive search
44+
params = Core.Inference.InferenceParams(tuple_depth=32, cached=false, hooks=hooks)
45+
46+
(code, rettyp) = Core.Inference.typeinf_code(m, sig, spvals, params)
47+
code === nothing && error("inference not successful")
48+
println("Returns: $rettyp")
49+
print(code)
50+
println()
51+
52+
53+
#
54+
# IRgen
55+
#
56+
57+
# module set-up
58+
mod = LLVM.Module("my_module")
59+
60+
# irgen
61+
# TODO
62+
exit()
63+
fun = get(functions(mod), "parent")
64+
65+
# execution
66+
ExecutionEngine(mod) do engine
67+
args = [GenericValue(LLVM.Int32Type(), x)]
68+
69+
res = LLVM.run(engine, fun, args)
70+
println(convert(Int, res))
71+
72+
dispose.(args)
73+
dispose(res)
74+
end
75+
76+
# jl_get_llvmf_defn vs jl_compile_linfo?

0 commit comments

Comments
 (0)