Skip to content

add mechanism for spoofing inference work-limiting heuristics #24852

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jan 12, 2018
1 change: 1 addition & 0 deletions base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,7 @@ struct GeneratedFunctionStub
spnames::Union{Nothing, Array{Any,1}}
line::Int
file::Symbol
expand_early::Bool
end

# invoke and wrap the results of @generated
Expand Down
45 changes: 37 additions & 8 deletions base/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -371,10 +371,14 @@ function _validate(linfo::MethodInstance, src::CodeInfo, kind::String)
end

# Ask the runtime (`jl_code_for_staged`) for the code of the staged method
# instance `li` and return it as a `CodeInfo`.
#
# Returns `nothing` when the call throws (or when the result is not a
# `CodeInfo`, since the type assertion sits inside the `try`). Callers are
# expected to check for `nothing` instead of catching exceptions themselves.
function get_staged(li::MethodInstance)
    try
        # user code might throw errors – ignore them
        return ccall(:jl_code_for_staged, Any, (Any,), li)::CodeInfo
    catch
        return nothing
    end
end


mutable struct OptimizationState
linfo::MethodInstance
vararg_type_container #::Type
Expand Down Expand Up @@ -472,12 +476,7 @@ end
function retrieve_code_info(linfo::MethodInstance)
m = linfo.def::Method
if isdefined(m, :generator)
try
# user code might throw errors – ignore them
c = get_staged(linfo)
catch
return nothing
end
return get_staged(linfo)
else
# TODO: post-inference see if we can swap back to the original arrays?
if isa(m.source, Array{UInt8,1})
Expand All @@ -489,6 +488,35 @@ function retrieve_code_info(linfo::MethodInstance)
return c
end

# TODO: Use these functions instead of directly manipulating
# the "actual" method for appropriate places in inference (see #24676)
# Pick the method that inference heuristics should look at for `cinfo`.
#
# If `cinfo` is a `CodeInfo` whose `signature_for_inference_heuristics` field
# holds a 3-element `SimpleVector`, that vector is passed to `_methods`; when
# exactly one match comes back and its third entry is a `Method`, that method
# is returned. In every other case `default` is returned unchanged.
function method_for_inference_heuristics(cinfo, default)
    isa(cinfo, CodeInfo) || return default
    # the stored value is expected to look like svec(ftype, argtypes, world)
    spoof = cinfo.signature_for_inference_heuristics
    (isa(spoof, SimpleVector) && length(spoof) == 3) || return default
    matches = _methods(spoof[1], spoof[2], -1, spoof[3])
    length(matches) == 1 || return default
    _, _, candidate = matches[]
    return isa(candidate, Method) ? candidate : default
end

# Method-based entry point: only methods that have a generator whose
# `expand_early` flag is set are expanded here. For those, the staged code is
# obtained through `code_for_method` + `get_staged` and handed to the
# `(cinfo, default)` method above, with `method` itself as the fallback.
# Every other method is returned as-is.
function method_for_inference_heuristics(method::Method, @nospecialize(sig), sparams, world)
    if !isdefined(method, :generator) || !method.generator.expand_early
        return method
    end
    instance = code_for_method(method, sig, sparams, world, false)
    isa(instance, MethodInstance) || return method
    return method_for_inference_heuristics(get_staged(instance), method)
end

# Return the `id` of a slot reference: `s` is read as a `SlotNumber` when it
# is one, and is otherwise asserted to be a `TypedSlot`.
# using a function to ensure we can infer this
@inline function slot_id(s)
    if isa(s, SlotNumber)
        return (s::SlotNumber).id
    end
    return (s::TypedSlot).id
end

# avoid cycle due to over-specializing `any` when used by inference
Expand Down Expand Up @@ -3396,6 +3424,7 @@ function typeinf_code(linfo::MethodInstance, optimize::Bool, cached::Bool,
method = linfo.def::Method
tree = ccall(:jl_new_code_info_uninit, Ref{CodeInfo}, ())
tree.code = Any[ Expr(:return, quoted(linfo.inferred_const)) ]
tree.signature_for_inference_heuristics = nothing
tree.slotnames = Any[ compiler_temp_sym for i = 1:method.nargs ]
tree.slotflags = UInt8[ 0 for i = 1:method.nargs ]
tree.slottypes = nothing
Expand Down
8 changes: 5 additions & 3 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -2046,8 +2046,9 @@ void jl_init_types(void)
jl_code_info_type =
jl_new_datatype(jl_symbol("CodeInfo"), core,
jl_any_type, jl_emptysvec,
jl_perm_symsvec(9,
jl_perm_symsvec(10,
"code",
"signature_for_inference_heuristics",
"slottypes",
"ssavaluetypes",
"slotflags",
Expand All @@ -2056,17 +2057,18 @@ void jl_init_types(void)
"inlineable",
"propagate_inbounds",
"pure"),
jl_svec(9,
jl_svec(10,
jl_array_any_type,
jl_any_type,
jl_any_type,
jl_any_type,
jl_array_uint8_type,
jl_array_any_type,
jl_bool_type,
jl_bool_type,
jl_bool_type,
jl_bool_type),
0, 1, 9);
0, 1, 10);

jl_method_type =
jl_new_datatype(jl_symbol("Method"), core,
Expand Down
3 changes: 2 additions & 1 deletion src/julia-syntax.scm
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,8 @@
'nothing
(cons 'list (map car sparams)))
,(if (null? loc) 0 (cadr loc))
(inert ,(if (null? loc) 'none (caddr loc))))))))
(inert ,(if (null? loc) 'none (caddr loc)))
false)))))
Copy link
Member Author

@jrevels jrevels Dec 4, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I've now added a Bool flag to GeneratedFunctionStub that inference can check to decide when to expand the generator. Just to get started, I'm hardcoding the flag to false, but now I need to add a mechanism for actually setting it.

@JeffBezanson What would be the most reasonable way to add this setting mechanism? On the front end, I could add an extra argument to the @generated macro, or add a @generated_early macro, etc. I'm also wondering what the easiest/cleanest way is to propagate that setting to here. I guess I could add a meta node to the expression returned from the @generated macro and then destructure it here?

...I'm a noob when it comes to Julia's parser 😛

(list gf))
'()))
(types (llist-types argl))
Expand Down
1 change: 1 addition & 0 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ typedef struct _jl_llvm_functions_t {
// This type describes a single function body
typedef struct _jl_code_info_t {
jl_array_t *code; // Any array of statements
jl_value_t *signature_for_inference_heuristics; // optional method used during inference
jl_value_t *slottypes; // types of variable slots (or `nothing`)
jl_value_t *ssavaluetypes; // types of ssa values (or count of them)
jl_array_t *slotflags; // local var bit flags
Expand Down
6 changes: 4 additions & 2 deletions src/method.c
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ static void jl_code_info_set_ast(jl_code_info_t *li, jl_expr_t *ast)
jl_array_del_end(meta, na - ins);
}
}
li->signature_for_inference_heuristics = jl_nothing;
jl_array_t *vinfo = (jl_array_t*)jl_exprarg(ast, 1);
jl_array_t *vis = (jl_array_t*)jl_array_ptr_ref(vinfo, 0);
size_t nslots = jl_array_len(vis);
Expand Down Expand Up @@ -255,6 +256,7 @@ JL_DLLEXPORT jl_code_info_t *jl_new_code_info_uninit(void)
(jl_code_info_t*)jl_gc_alloc(ptls, sizeof(jl_code_info_t),
jl_code_info_type);
src->code = NULL;
src->signature_for_inference_heuristics = NULL;
src->slotnames = NULL;
src->slotflags = NULL;
src->slottypes = NULL;
Expand Down Expand Up @@ -442,8 +444,8 @@ static void jl_method_set_source(jl_method_t *m, jl_code_info_t *src)
else if (jl_expr_nargs(st) == 2 && jl_exprarg(st, 0) == (jl_value_t*)generated_sym) {
m->generator = NULL;
jl_value_t *gexpr = jl_exprarg(st, 1);
if (jl_expr_nargs(gexpr) == 6) {
// expects (new (core GeneratedFunctionStub) funcname argnames sp line file)
if (jl_expr_nargs(gexpr) == 7) {
// expects (new (core GeneratedFunctionStub) funcname argnames sp line file expandearly)
jl_value_t *funcname = jl_exprarg(gexpr, 1);
assert(jl_is_symbol(funcname));
if (jl_get_global(m->module, (jl_sym_t*)funcname) != NULL) {
Expand Down
1 change: 1 addition & 0 deletions src/toplevel.c
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,7 @@ static jl_code_info_t *expr_to_code_info(jl_value_t *expr)
jl_gc_wb(src, src->slotflags);
src->ssavaluetypes = jl_box_long(0);
jl_gc_wb(src, src->ssavaluetypes);
src->signature_for_inference_heuristics = jl_nothing;

JL_GC_POP();
return src;
Expand Down
74 changes: 74 additions & 0 deletions test/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1317,3 +1317,77 @@ bar_22708(x) = f_22708(x)

@test bar_22708(1) == "x"

# mechanism for spoofing work-limiting heuristics and early generator expansion (#24852)
# Build the `Expr(:meta, :generated, ...)` node that marks a function body as
# generated. The wrapped `:new` expression carries the constructor arguments
# for `Core.GeneratedFunctionStub` in order: generator name, argument names,
# static parameter names, line, file, and the `expand_early` flag.
function _generated_stub(gen::Symbol, args::Vector{Any}, params::Vector{Any}, line, file, expand_early)
    stub_fields = (gen, args, params, line, file, expand_early)
    return Expr(:meta, :generated, Expr(:new, Core.GeneratedFunctionStub, stub_fields...))
end

# Plain (non-generated) kernel; the generated variants below are compared
# against its result.
function f24852_kernel(x, y)
    return x * y
end

# Look up the `f24852_kernel` method matching argument types `x` and `y` and
# return `(method, code_info)`, where `code_info` is that method's
# uncompressed AST. The statements of `code_info` are wrapped in a `:block`
# expression and run through `Core.Inference.substitute!` with the matched
# signature and static parameter values before returning.
function f24852_kernel_cinfo(x, y)
    kernel_sig = Tuple{typeof(f24852_kernel),x,y}
    matches = Base._methods_by_ftype(kernel_sig, -1, typemax(UInt))
    sig, spvals, method = first(matches)
    code_info = Base.uncompressed_ast(method)
    body = Expr(:block, code_info.code...)
    Base.Core.Inference.substitute!(body, 0, Any[], sig, Any[spvals...], 0, :propagate)
    return method, code_info
end

# Generator returning the kernel's `CodeInfo` as produced by
# `f24852_kernel_cinfo`, without touching
# `signature_for_inference_heuristics`.
function f24852_gen_cinfo_uninflated(X, Y, f, x, y)
    return f24852_kernel_cinfo(x, y)[2]
end

# Generator returning the kernel's `CodeInfo` with
# `signature_for_inference_heuristics` set to `svec(f, (x, y), typemax(UInt))`.
function f24852_gen_cinfo_inflated(X, Y, f, x, y)
    code_info = last(f24852_kernel_cinfo(x, y))
    spoofed_sig = Core.Inference.svec(f, (x, y), typemax(UInt))
    code_info.signature_for_inference_heuristics = spoofed_sig
    return code_info
end

# Generator returning an expression (rather than a `CodeInfo`): a call to
# `f24852_kernel` whose arguments `x` and `y` are annotated with `X` and `Y`.
function f24852_gen_expr(X, Y, f, x, y)
    typed_x = Expr(:(::), :x, X)
    typed_y = Expr(:(::), :y, Y)
    return Expr(:call, :f24852_kernel, typed_x, typed_y)
end

# "Late" variants: every stub below is built with `expand_early = false`
# (the last argument to `_generated_stub`). The three definitions differ only
# in which generator the stub names: `f24852_gen_expr`,
# `f24852_gen_cinfo_inflated`, or `f24852_gen_cinfo_uninflated`.
# The stub is interpolated into each function body by `@eval`.
@eval begin
function f24852_late_expr(x::X, y::Y) where {X, Y}
$(_generated_stub(:f24852_gen_expr, Any[:f24852_late_expr, :x, :y],
Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), false))
end
function f24852_late_inflated(x::X, y::Y) where {X, Y}
$(_generated_stub(:f24852_gen_cinfo_inflated, Any[:f24852_late_inflated, :x, :y],
Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), false))
end
function f24852_late_uninflated(x::X, y::Y) where {X, Y}
$(_generated_stub(:f24852_gen_cinfo_uninflated, Any[:f24852_late_uninflated, :x, :y],
Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), false))
end
end

# "Early" variants: same three generators as the `*_late_*` definitions, but
# every stub is built with `expand_early = true` (the last argument to
# `_generated_stub`).
@eval begin
function f24852_early_expr(x::X, y::Y) where {X, Y}
$(_generated_stub(:f24852_gen_expr, Any[:f24852_early_expr, :x, :y],
Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), true))
end
function f24852_early_inflated(x::X, y::Y) where {X, Y}
$(_generated_stub(:f24852_gen_cinfo_inflated, Any[:f24852_early_inflated, :x, :y],
Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), true))
end
function f24852_early_uninflated(x::X, y::Y) where {X, Y}
$(_generated_stub(:f24852_gen_cinfo_uninflated, Any[:f24852_early_uninflated, :x, :y],
Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), true))
end
end

# Reference value: call the plain kernel directly on two random inputs.
x, y = rand(), rand()
result = f24852_kernel(x, y)

# Each generated variant must return a value identical (`===`) to the
# reference. First the stubs built with `expand_early = false` ...
@test result === f24852_late_expr(x, y)
@test result === f24852_late_uninflated(x, y)
@test result === f24852_late_inflated(x, y)

# ... then the stubs built with `expand_early = true`.
@test result === f24852_early_expr(x, y)
@test result === f24852_early_uninflated(x, y)
@test result === f24852_early_inflated(x, y)

# TODO: test that `expand_early = true` + inflated `signature_for_inference_heuristics`
# can be used to tighten up some inference result.