Skip to content

Commit ec9e92e

Browse files
authored
add mechanism for spoofing inference work-limiting heuristics (#24852)
1 parent daf1235 commit ec9e92e

File tree

8 files changed

+125
-14
lines changed

8 files changed

+125
-14
lines changed

base/boot.jl

+1
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,7 @@ struct GeneratedFunctionStub
456456
spnames::Union{Nothing, Array{Any,1}}
457457
line::Int
458458
file::Symbol
459+
expand_early::Bool
459460
end
460461

461462
# invoke and wrap the results of @generated

base/inference.jl

+37-8
Original file line numberDiff line numberDiff line change
@@ -371,10 +371,14 @@ function _validate(linfo::MethodInstance, src::CodeInfo, kind::String)
371371
end
372372

373373
function get_staged(li::MethodInstance)
374-
return ccall(:jl_code_for_staged, Any, (Any,), li)::CodeInfo
374+
try
375+
# user code might throw errors – ignore them
376+
return ccall(:jl_code_for_staged, Any, (Any,), li)::CodeInfo
377+
catch
378+
return nothing
379+
end
375380
end
376381

377-
378382
mutable struct OptimizationState
379383
linfo::MethodInstance
380384
vararg_type_container #::Type
@@ -472,12 +476,7 @@ end
472476
function retrieve_code_info(linfo::MethodInstance)
473477
m = linfo.def::Method
474478
if isdefined(m, :generator)
475-
try
476-
# user code might throw errors – ignore them
477-
c = get_staged(linfo)
478-
catch
479-
return nothing
480-
end
479+
return get_staged(linfo)
481480
else
482481
# TODO: post-inference see if we can swap back to the original arrays?
483482
if isa(m.source, Array{UInt8,1})
@@ -489,6 +488,35 @@ function retrieve_code_info(linfo::MethodInstance)
489488
return c
490489
end
491490

491+
# TODO: Use these functions instead of directly manipulating
492+
# the "actual" method for appropriate places in inference (see #24676)
493+
function method_for_inference_heuristics(cinfo, default)
494+
if isa(cinfo, CodeInfo)
495+
# appropriate format for `sig` is svec(ftype, argtypes, world)
496+
sig = cinfo.signature_for_inference_heuristics
497+
if isa(sig, SimpleVector) && length(sig) == 3
498+
methods = _methods(sig[1], sig[2], -1, sig[3])
499+
if length(methods) == 1
500+
_, _, m = methods[]
501+
if isa(m, Method)
502+
return m
503+
end
504+
end
505+
end
506+
end
507+
return default
508+
end
509+
510+
function method_for_inference_heuristics(method::Method, @nospecialize(sig), sparams, world)
511+
if isdefined(method, :generator) && method.generator.expand_early
512+
method_instance = code_for_method(method, sig, sparams, world, false)
513+
if isa(method_instance, MethodInstance)
514+
return method_for_inference_heuristics(get_staged(method_instance), method)
515+
end
516+
end
517+
return method
518+
end
519+
492520
@inline slot_id(s) = isa(s, SlotNumber) ? (s::SlotNumber).id : (s::TypedSlot).id # using a function to ensure we can infer this
493521

494522
# avoid cycle due to over-specializing `any` when used by inference
@@ -3396,6 +3424,7 @@ function typeinf_code(linfo::MethodInstance, optimize::Bool, cached::Bool,
33963424
method = linfo.def::Method
33973425
tree = ccall(:jl_new_code_info_uninit, Ref{CodeInfo}, ())
33983426
tree.code = Any[ Expr(:return, quoted(linfo.inferred_const)) ]
3427+
tree.signature_for_inference_heuristics = nothing
33993428
tree.slotnames = Any[ compiler_temp_sym for i = 1:method.nargs ]
34003429
tree.slotflags = UInt8[ 0 for i = 1:method.nargs ]
34013430
tree.slottypes = nothing

src/jltypes.c

+5-3
Original file line numberDiff line numberDiff line change
@@ -2046,8 +2046,9 @@ void jl_init_types(void)
20462046
jl_code_info_type =
20472047
jl_new_datatype(jl_symbol("CodeInfo"), core,
20482048
jl_any_type, jl_emptysvec,
2049-
jl_perm_symsvec(9,
2049+
jl_perm_symsvec(10,
20502050
"code",
2051+
"signature_for_inference_heuristics",
20512052
"slottypes",
20522053
"ssavaluetypes",
20532054
"slotflags",
@@ -2056,17 +2057,18 @@ void jl_init_types(void)
20562057
"inlineable",
20572058
"propagate_inbounds",
20582059
"pure"),
2059-
jl_svec(9,
2060+
jl_svec(10,
20602061
jl_array_any_type,
20612062
jl_any_type,
20622063
jl_any_type,
2064+
jl_any_type,
20632065
jl_array_uint8_type,
20642066
jl_array_any_type,
20652067
jl_bool_type,
20662068
jl_bool_type,
20672069
jl_bool_type,
20682070
jl_bool_type),
2069-
0, 1, 9);
2071+
0, 1, 10);
20702072

20712073
jl_method_type =
20722074
jl_new_datatype(jl_symbol("Method"), core,

src/julia-syntax.scm

+2-1
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,8 @@
362362
'nothing
363363
(cons 'list (map car sparams)))
364364
,(if (null? loc) 0 (cadr loc))
365-
(inert ,(if (null? loc) 'none (caddr loc))))))))
365+
(inert ,(if (null? loc) 'none (caddr loc)))
366+
false)))))
366367
(list gf))
367368
'()))
368369
(types (llist-types argl))

src/julia.h

+1
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ typedef struct _jl_llvm_functions_t {
229229
// This type describes a single function body
230230
typedef struct _jl_code_info_t {
231231
jl_array_t *code; // Any array of statements
232+
jl_value_t *signature_for_inference_heuristics; // optional method used during inference
232233
jl_value_t *slottypes; // types of variable slots (or `nothing`)
233234
jl_value_t *ssavaluetypes; // types of ssa values (or count of them)
234235
jl_array_t *slotflags; // local var bit flags

src/method.c

+4-2
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ static void jl_code_info_set_ast(jl_code_info_t *li, jl_expr_t *ast)
187187
jl_array_del_end(meta, na - ins);
188188
}
189189
}
190+
li->signature_for_inference_heuristics = jl_nothing;
190191
jl_array_t *vinfo = (jl_array_t*)jl_exprarg(ast, 1);
191192
jl_array_t *vis = (jl_array_t*)jl_array_ptr_ref(vinfo, 0);
192193
size_t nslots = jl_array_len(vis);
@@ -255,6 +256,7 @@ JL_DLLEXPORT jl_code_info_t *jl_new_code_info_uninit(void)
255256
(jl_code_info_t*)jl_gc_alloc(ptls, sizeof(jl_code_info_t),
256257
jl_code_info_type);
257258
src->code = NULL;
259+
src->signature_for_inference_heuristics = NULL;
258260
src->slotnames = NULL;
259261
src->slotflags = NULL;
260262
src->slottypes = NULL;
@@ -442,8 +444,8 @@ static void jl_method_set_source(jl_method_t *m, jl_code_info_t *src)
442444
else if (jl_expr_nargs(st) == 2 && jl_exprarg(st, 0) == (jl_value_t*)generated_sym) {
443445
m->generator = NULL;
444446
jl_value_t *gexpr = jl_exprarg(st, 1);
445-
if (jl_expr_nargs(gexpr) == 6) {
446-
// expects (new (core GeneratedFunctionStub) funcname argnames sp line file)
447+
if (jl_expr_nargs(gexpr) == 7) {
448+
// expects (new (core GeneratedFunctionStub) funcname argnames sp line file expandearly)
447449
jl_value_t *funcname = jl_exprarg(gexpr, 1);
448450
assert(jl_is_symbol(funcname));
449451
if (jl_get_global(m->module, (jl_sym_t*)funcname) != NULL) {

src/toplevel.c

+1
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,7 @@ static jl_code_info_t *expr_to_code_info(jl_value_t *expr)
576576
jl_gc_wb(src, src->slotflags);
577577
src->ssavaluetypes = jl_box_long(0);
578578
jl_gc_wb(src, src->ssavaluetypes);
579+
src->signature_for_inference_heuristics = jl_nothing;
579580

580581
JL_GC_POP();
581582
return src;

test/inference.jl

+74
Original file line numberDiff line numberDiff line change
@@ -1317,3 +1317,77 @@ bar_22708(x) = f_22708(x)
13171317

13181318
@test bar_22708(1) == "x"
13191319

1320+
# mechanism for spoofing work-limiting heuristics and early generator expansion (#24852)
1321+
function _generated_stub(gen::Symbol, args::Vector{Any}, params::Vector{Any}, line, file, expand_early)
1322+
stub = Expr(:new, Core.GeneratedFunctionStub, gen, args, params, line, file, expand_early)
1323+
return Expr(:meta, :generated, stub)
1324+
end
1325+
1326+
f24852_kernel(x, y) = x * y
1327+
1328+
function f24852_kernel_cinfo(x, y)
1329+
sig, spvals, method = Base._methods_by_ftype(Tuple{typeof(f24852_kernel),x,y}, -1, typemax(UInt))[1]
1330+
code_info = Base.uncompressed_ast(method)
1331+
body = Expr(:block, code_info.code...)
1332+
Base.Core.Inference.substitute!(body, 0, Any[], sig, Any[spvals...], 0, :propagate)
1333+
return method, code_info
1334+
end
1335+
1336+
function f24852_gen_cinfo_uninflated(X, Y, f, x, y)
1337+
_, code_info = f24852_kernel_cinfo(x, y)
1338+
return code_info
1339+
end
1340+
1341+
function f24852_gen_cinfo_inflated(X, Y, f, x, y)
1342+
method, code_info = f24852_kernel_cinfo(x, y)
1343+
code_info.signature_for_inference_heuristics = Core.Inference.svec(f, (x, y), typemax(UInt))
1344+
return code_info
1345+
end
1346+
1347+
function f24852_gen_expr(X, Y, f, x, y)
1348+
return :(f24852_kernel(x::$X, y::$Y))
1349+
end
1350+
1351+
@eval begin
1352+
function f24852_late_expr(x::X, y::Y) where {X, Y}
1353+
$(_generated_stub(:f24852_gen_expr, Any[:f24852_late_expr, :x, :y],
1354+
Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), false))
1355+
end
1356+
function f24852_late_inflated(x::X, y::Y) where {X, Y}
1357+
$(_generated_stub(:f24852_gen_cinfo_inflated, Any[:f24852_late_inflated, :x, :y],
1358+
Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), false))
1359+
end
1360+
function f24852_late_uninflated(x::X, y::Y) where {X, Y}
1361+
$(_generated_stub(:f24852_gen_cinfo_uninflated, Any[:f24852_late_uninflated, :x, :y],
1362+
Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), false))
1363+
end
1364+
end
1365+
1366+
@eval begin
1367+
function f24852_early_expr(x::X, y::Y) where {X, Y}
1368+
$(_generated_stub(:f24852_gen_expr, Any[:f24852_early_expr, :x, :y],
1369+
Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), true))
1370+
end
1371+
function f24852_early_inflated(x::X, y::Y) where {X, Y}
1372+
$(_generated_stub(:f24852_gen_cinfo_inflated, Any[:f24852_early_inflated, :x, :y],
1373+
Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), true))
1374+
end
1375+
function f24852_early_uninflated(x::X, y::Y) where {X, Y}
1376+
$(_generated_stub(:f24852_gen_cinfo_uninflated, Any[:f24852_early_uninflated, :x, :y],
1377+
Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), true))
1378+
end
1379+
end
1380+
1381+
x, y = rand(), rand()
1382+
result = f24852_kernel(x, y)
1383+
1384+
@test result === f24852_late_expr(x, y)
1385+
@test result === f24852_late_uninflated(x, y)
1386+
@test result === f24852_late_inflated(x, y)
1387+
1388+
@test result === f24852_early_expr(x, y)
1389+
@test result === f24852_early_uninflated(x, y)
1390+
@test result === f24852_early_inflated(x, y)
1391+
1392+
# TODO: test that `expand_early = true` + inflated `signature_for_inference_heuristics`
1393+
# can be used to tighten up some inference result.

0 commit comments

Comments
 (0)