Skip to content

Commit 0487033

Browse files
committed
Add optional hooks to inference parameters.
This allows controling inference decisions, for now only method calling.
1 parent 3b33217 commit 0487033

File tree

2 files changed

+110
-2
lines changed

2 files changed

+110
-2
lines changed

base/inference.jl

+34-2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,13 @@ import Core: _apply, svec, apply_type, Builtin, IntrinsicFunction, MethodInstanc
66
const MAX_TYPEUNION_LEN = 3
77
const MAX_TYPE_DEPTH = 7
88

9+
immutable InferenceHooks
10+
call
11+
12+
InferenceHooks(call) = new(call)
13+
InferenceHooks() = new(nothing)
14+
end
15+
916
immutable InferenceParams
1017
# optimization
1118
inlining::Bool
@@ -16,12 +23,15 @@ immutable InferenceParams
1623
MAX_TUPLE_SPLAT::Int
1724
MAX_UNION_SPLITTING::Int
1825

26+
hooks::InferenceHooks
27+
1928
# reasonable defaults
2029
InferenceParams(;inlining::Bool=inlining_enabled(),
2130
tupletype_len::Int=15, tuple_depth::Int=4,
22-
tuple_splat::Int=16, union_splitting::Int=4) =
31+
tuple_splat::Int=16, union_splitting::Int=4,
32+
hooks::InferenceHooks=InferenceHooks()) =
2333
new(inlining, tupletype_len,
24-
tuple_depth, tuple_splat, union_splitting)
34+
tuple_depth, tuple_splat, union_splitting, hooks)
2535
end
2636

2737
const UNION_SPLIT_MISMATCH_ERROR = false
@@ -1087,6 +1097,15 @@ end
10871097
argtypes_to_type(argtypes::Array{Any,1}) = Tuple{map(widenconst, argtypes)...}
10881098

10891099
function abstract_call(f::ANY, fargs, argtypes::Vector{Any}, vtypes::VarTable, sv::InferenceState)
1100+
if sv.params.hooks.call != nothing
1101+
hack = sv.params.hooks.call(f, argtypes_to_type(argtypes))
1102+
if hack != nothing
1103+
println("NOTICE: overriding abstract_call of ", f, " with ", hack)
1104+
f = hack
1105+
argtypes[1] = typeof(f)
1106+
end
1107+
end
1108+
10901109
if f === _apply
10911110
length(fargs)>1 || return Any
10921111
aft = argtypes[2]
@@ -3109,6 +3128,19 @@ function inlining_pass(e::Expr, sv::InferenceState)
31093128
end
31103129
end
31113130

3131+
if sv.params.hooks.call != nothing
3132+
argtypes = Vector{Any}(length(e.args))
3133+
argtypes[1] = ft
3134+
argtypes[2:end] = map(a->exprtype(a, sv.src, sv.mod), e.args[2:end])
3135+
3136+
hack = sv.params.hooks.call(f, argtypes_to_type(argtypes))
3137+
if hack != nothing
3138+
println("NOTICE: overriding inlining_pass of ", f, " with ", hack)
3139+
f = hack
3140+
ft = typeof(hack)
3141+
end
3142+
end
3143+
31123144
if sv.params.inlining
31133145
if isdefined(Main, :Base) &&
31143146
((isdefined(Main.Base, :^) && f === Main.Base.:^) ||

demo.jl

+76
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
using LLVM
2+
3+
#
4+
# Functions
5+
#
6+
7+
@inline function child(x)
8+
x+1
9+
end
10+
11+
@inline function hacked_child(x)
12+
x+2
13+
end
14+
15+
function parent(x)
16+
return child(x)
17+
end
18+
19+
20+
#
21+
# Inference
22+
#
23+
24+
f = parent
25+
t = Tuple{Int}
26+
tt = Base.to_tuple_type(t)
27+
28+
ms = Base._methods(f, tt, -1)
29+
@assert length(ms) == 1
30+
(sig, spvals, m) = first(ms)
31+
32+
# given a function and the argument tuple type (incl. the function type)
33+
# return a tuple of the replacement function and its type, or nothing
34+
function call_hook(f, tt)
35+
if f == child
36+
return hacked_child
37+
end
38+
return nothing
39+
end
40+
# alternatively, call_hook(f::typeof(child), tt) = return hacked_child
41+
hooks = Core.Inference.InferenceHooks(call_hook)
42+
43+
# raise limits on inference parameters, performing a more exhaustive search
44+
params = Core.Inference.InferenceParams(tuple_depth=32, hooks=hooks)
45+
46+
(code, rettyp) = Core.Inference.typeinf_code(m, sig, spvals, true, true, params)
47+
code === nothing && error("inference not successful")
48+
println("Returns: $rettyp")
49+
print(Base.uncompressed_ast(m, code))
50+
println()
51+
52+
53+
#
54+
# IRgen
55+
#
56+
57+
# module set-up
58+
mod = LLVM.Module("my_module")
59+
60+
# irgen
61+
# TODO
62+
exit()
63+
fun = get(functions(mod), "parent")
64+
65+
# execution
66+
ExecutionEngine(mod) do engine
67+
args = [GenericValue(LLVM.Int32Type(), x)]
68+
69+
res = LLVM.run(engine, fun, args)
70+
println(convert(Int, res))
71+
72+
dispose.(args)
73+
dispose(res)
74+
end
75+
76+
# jl_get_llvmf_defn vs jl_compile_linfo?

0 commit comments

Comments
 (0)