Skip to content

WIP: turn on specialization spoofing mechanism in the compiler #25639

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,16 +189,19 @@ function abstract_call_method(method::Method, @nospecialize(sig), sparams::Simpl
cyclei = 0
infstate = sv
edgecycle = false
spoofed_sig, spoofed_method = method_for_specialization_heuristics(method, sig, sparams, sv.params.world)
spoofed_sv_sig, spoofed_sv_method = method_for_specialization_heuristics(sv)
while !(infstate === nothing)
infstate = infstate::InferenceState
if method === infstate.linfo.def
if infstate.linfo.specTypes == sig
# avoid widening when detecting self-recursion
# TODO: merge call cycle and return right away
topmost = nothing
edgecycle = true
break
end
working_sig, working_method = method_for_specialization_heuristics(infstate)
if working_sig == spoofed_sig
# avoid widening when detecting self-recursion
# TODO: merge call cycle and return right away
topmost = nothing
edgecycle = true
break
end
if spoofed_method === working_method
if topmost === nothing
# inspect the parent of this edge,
# to see if they are the same Method as sv
Expand All @@ -207,7 +210,8 @@ function abstract_call_method(method::Method, @nospecialize(sig), sparams::Simpl
for parent in infstate.callers_in_cycle
# check in the cycle list first
# all items in here are mutual parents of all others
if parent.linfo.def === sv.linfo.def
_, parent_method = method_for_specialization_heuristics(parent)
if parent_method === spoofed_sv_method
topmost = infstate
edgecycle = true
break
Expand All @@ -217,7 +221,8 @@ function abstract_call_method(method::Method, @nospecialize(sig), sparams::Simpl
# then check the parent link
if topmost === nothing && parent !== nothing
parent = parent::InferenceState
if parent.cached && parent.linfo.def === sv.linfo.def
_, parent_method = method_for_specialization_heuristics(parent)
if parent.cached && parent_method === spoofed_sv_method
topmost = infstate
edgecycle = true
end
Expand Down
4 changes: 4 additions & 0 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,10 @@ function is_specializable_vararg_slot(@nospecialize(arg), sv::InferenceState)
isa(sv.vararg_type_container, DataType))
end

function method_for_specialization_heuristics(sv::InferenceState)
return method_for_specialization_heuristics(sv.src, (sv.linfo.specTypes, sv.linfo.def))
end

function print_callstack(sv::InferenceState)
while sv !== nothing
print(sv.linfo)
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ function typeinf_code(linfo::MethodInstance, optimize::Bool, cached::Bool,
method = linfo.def::Method
tree = ccall(:jl_new_code_info_uninit, Ref{CodeInfo}, ())
tree.code = Any[ Expr(:return, quoted(linfo.inferred_const)) ]
tree.signature_for_inference_heuristics = nothing
tree.signature_for_specialization_heuristics = nothing
tree.slotnames = Any[ COMPILER_TEMP_SYM for i = 1:method.nargs ]
tree.slotflags = UInt8[ 0 for i = 1:method.nargs ]
tree.slottypes = nothing
Expand Down
16 changes: 8 additions & 8 deletions base/compiler/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,31 +144,31 @@ end

# TODO: Use these functions instead of directly manipulating
# the "actual" method for appropriate places in inference (see #24676)
function method_for_inference_heuristics(cinfo, default)
function method_for_specialization_heuristics(cinfo, default)
if isa(cinfo, CodeInfo)
# appropriate format for `sig` is svec(ftype, argtypes, world)
sig = cinfo.signature_for_inference_heuristics
sig = cinfo.signature_for_specialization_heuristics
if isa(sig, SimpleVector) && length(sig) == 3
methods = _methods(sig[1], sig[2], -1, sig[3])
if length(methods) == 1
_, _, m = methods[]
if isa(m, Method)
return m
spoofed_sig, _, spoofed_method = methods[]
if isa(spoofed_method, Method)
return spoofed_sig, spoofed_method
end
end
end
end
return default
end

function method_for_inference_heuristics(method::Method, @nospecialize(sig), sparams, world)
function method_for_specialization_heuristics(method::Method, @nospecialize(sig), sparams, world)
if isdefined(method, :generator) && method.generator.expand_early
method_instance = code_for_method(method, sig, sparams, world, false)
if isa(method_instance, MethodInstance)
return method_for_inference_heuristics(get_staged(method_instance), method)
return method_for_specialization_heuristics(get_staged(method_instance), (sig, method))
end
end
return method
return (sig, method)
end

function exprtype(@nospecialize(x), src::CodeInfo, mod::Module)
Expand Down
2 changes: 1 addition & 1 deletion src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -2048,7 +2048,7 @@ void jl_init_types(void)
jl_any_type, jl_emptysvec,
jl_perm_symsvec(10,
"code",
"signature_for_inference_heuristics",
"signature_for_specialization_heuristics",
"slottypes",
"ssavaluetypes",
"slotflags",
Expand Down
2 changes: 1 addition & 1 deletion src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ typedef struct _jl_llvm_functions_t {
// This type describes a single function body
typedef struct _jl_code_info_t {
jl_array_t *code; // Any array of statements
jl_value_t *signature_for_inference_heuristics; // optional method used during inference
jl_value_t *signature_for_specialization_heuristics; // optional method used during inference
jl_value_t *slottypes; // types of variable slots (or `nothing`)
jl_value_t *ssavaluetypes; // types of ssa values (or count of them)
jl_array_t *slotflags; // local var bit flags
Expand Down
4 changes: 2 additions & 2 deletions src/method.c
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ static void jl_code_info_set_ast(jl_code_info_t *li, jl_expr_t *ast)
jl_array_del_end(meta, na - ins);
}
}
li->signature_for_inference_heuristics = jl_nothing;
li->signature_for_specialization_heuristics = jl_nothing;
jl_array_t *vinfo = (jl_array_t*)jl_exprarg(ast, 1);
jl_array_t *vis = (jl_array_t*)jl_array_ptr_ref(vinfo, 0);
size_t nslots = jl_array_len(vis);
Expand Down Expand Up @@ -256,7 +256,7 @@ JL_DLLEXPORT jl_code_info_t *jl_new_code_info_uninit(void)
(jl_code_info_t*)jl_gc_alloc(ptls, sizeof(jl_code_info_t),
jl_code_info_type);
src->code = NULL;
src->signature_for_inference_heuristics = NULL;
src->signature_for_specialization_heuristics = NULL;
src->slotnames = NULL;
src->slotflags = NULL;
src->slottypes = NULL;
Expand Down
2 changes: 1 addition & 1 deletion src/toplevel.c
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ static jl_code_info_t *expr_to_code_info(jl_value_t *expr)
jl_gc_wb(src, src->slotflags);
src->ssavaluetypes = jl_box_long(0);
jl_gc_wb(src, src->ssavaluetypes);
src->signature_for_inference_heuristics = jl_nothing;
src->signature_for_specialization_heuristics = jl_nothing;

JL_GC_POP();
return src;
Expand Down
4 changes: 2 additions & 2 deletions test/compiler/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1342,7 +1342,7 @@ end

function f24852_gen_cinfo_inflated(X, Y, f, x, y)
method, code_info = f24852_kernel_cinfo(x, y)
code_info.signature_for_inference_heuristics = Core.Compiler.svec(f, (x, y), typemax(UInt))
code_info.signature_for_specialization_heuristics = Core.Compiler.svec(f, (x, y), typemax(UInt))
return code_info
end

Expand Down Expand Up @@ -1391,7 +1391,7 @@ result = f24852_kernel(x, y)
@test result === f24852_early_uninflated(x, y)
@test result === f24852_early_inflated(x, y)

# TODO: test that `expand_early = true` + inflated `signature_for_inference_heuristics`
# TODO: test that `expand_early = true` + inflated `signature_for_specialization_heuristics`
# can be used to tighten up some inference result.

# Test that Conditional doesn't get widened to Bool too quickly
Expand Down