Skip to content

Commit 7c2f45e

Browse files
committed
inference: enable CodeInfo signature_for_inference_heuristics support
allow unbounded inference recursion, as long as the user-generated token in signature_for_inference_heuristics does not match the target frame
1 parent 9660a30 commit 7c2f45e

File tree

9 files changed

+99
-61
lines changed

9 files changed

+99
-61
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ function abstract_call_method(method::Method, @nospecialize(sig), sparams::Simpl
187187
cyclei = 0
188188
infstate = sv
189189
edgecycle = false
190+
method2 = method_for_inference_heuristics(method, sig, sparams, sv.params.world) # Union{Method, Nothing}
190191
while !(infstate === nothing)
191192
infstate = infstate::InferenceState
192193
if method === infstate.linfo.def
@@ -197,7 +198,9 @@ function abstract_call_method(method::Method, @nospecialize(sig), sparams::Simpl
197198
edgecycle = true
198199
break
199200
end
200-
if topmost === nothing
201+
inf_method2 = infstate.src.method_for_inference_limit_heuristics # limit only if user token match
202+
inf_method2 isa Method || (inf_method2 = nothing) # Union{Method, Nothing}
203+
if topmost === nothing && method2 === inf_method2
201204
# inspect the parent of this edge,
202205
# to see if they are the same Method as sv
203206
# in which case we'll need to ensure it is convergent

base/compiler/typeinfer.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ function typeinf_code(linfo::MethodInstance, optimize::Bool, cached::Bool,
168168
method = linfo.def::Method
169169
tree = ccall(:jl_new_code_info_uninit, Ref{CodeInfo}, ())
170170
tree.code = Any[ Expr(:return, quoted(linfo.inferred_const)) ]
171-
tree.signature_for_inference_heuristics = nothing
171+
tree.method_for_inference_limit_heuristics = nothing
172172
tree.slotnames = Any[ COMPILER_TEMP_SYM for i = 1:method.nargs ]
173173
tree.slotflags = UInt8[ 0 for i = 1:method.nargs ]
174174
tree.slottypes = nothing

base/compiler/utilities.jl

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -155,33 +155,21 @@ function code_for_method(method::Method, @nospecialize(atypes), sparams::SimpleV
155155
return ccall(:jl_specializations_get_linfo, Ref{MethodInstance}, (Any, Any, Any, UInt), method, atypes, sparams, world)
156156
end
157157

158-
# TODO: Use these functions instead of directly manipulating
159-
# the "actual" method for appropriate places in inference (see #24676)
160-
function method_for_inference_heuristics(cinfo, default)
161-
if isa(cinfo, CodeInfo)
162-
# appropriate format for `sig` is svec(ftype, argtypes, world)
163-
sig = cinfo.signature_for_inference_heuristics
164-
if isa(sig, SimpleVector) && length(sig) == 3
165-
methods = _methods(sig[1], sig[2], -1, sig[3])
166-
if length(methods) == 1
167-
_, _, m = methods[]
168-
if isa(m, Method)
169-
return m
170-
end
171-
end
172-
end
173-
end
174-
return default
175-
end
176-
177-
function method_for_inference_heuristics(method::Method, @nospecialize(sig), sparams, world)
158+
# This function is used for computing alternate limit heuristics
159+
function method_for_inference_heuristics(method::Method, @nospecialize(sig), sparams::SimpleVector, world::UInt)
178160
if isdefined(method, :generator) && method.generator.expand_early
179161
method_instance = code_for_method(method, sig, sparams, world, false)
180162
if isa(method_instance, MethodInstance)
181-
return method_for_inference_heuristics(get_staged(method_instance), method)
163+
cinfo = get_staged(method_instance)
164+
if isa(cinfo, CodeInfo)
165+
method2 = cinfo.method_for_inference_limit_heuristics
166+
if method2 isa Method
167+
return method2
168+
end
169+
end
182170
end
183171
end
184-
return method
172+
return nothing
185173
end
186174

187175
function exprtype(@nospecialize(x), src, mod::Module)

src/dump.c

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2300,7 +2300,8 @@ JL_DLLEXPORT jl_array_t *jl_compress_ast(jl_method_t *m, jl_code_info_t *code)
23002300

23012301
size_t nf = jl_datatype_nfields(jl_code_info_type);
23022302
for (i = 0; i < nf - 5; i++) {
2303-
jl_serialize_value_(&s, jl_get_nth_field((jl_value_t*)code, i), 1);
2303+
int copy = (i != 2); // don't copy contents of method_for_inference_limit_heuristics field
2304+
jl_serialize_value_(&s, jl_get_nth_field((jl_value_t*)code, i), copy);
23042305
}
23052306

23062307
ios_putc('\0', s.s);

src/jltypes.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2026,7 +2026,7 @@ void jl_init_types(void)
20262026
jl_perm_symsvec(12,
20272027
"code",
20282028
"codelocs",
2029-
"signature_for_inference_heuristics",
2029+
"method_for_inference_limit_heuristics",
20302030
"slottypes",
20312031
"ssavaluetypes",
20322032
"linetable",

src/julia.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ typedef struct _jl_llvm_functions_t {
235235
typedef struct _jl_code_info_t {
236236
jl_array_t *code; // Any array of statements
237237
jl_value_t *codelocs; // Int array of indicies into the line table
238-
jl_value_t *signature_for_inference_heuristics; // optional method used during inference
238+
jl_value_t *method_for_inference_limit_heuristics; // optional method used during inference
239239
jl_value_t *slottypes; // types of variable slots (or `nothing`)
240240
jl_value_t *ssavaluetypes; // types of ssa values (or count of them)
241241
jl_value_t *linetable; // Table of locations

src/method.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ static void jl_code_info_set_ast(jl_code_info_t *li, jl_expr_t *ast)
232232
jl_array_del_end(meta, na - ins);
233233
}
234234
}
235-
li->signature_for_inference_heuristics = jl_nothing;
235+
li->method_for_inference_limit_heuristics = jl_nothing;
236236
jl_array_t *vinfo = (jl_array_t*)jl_exprarg(ast, 1);
237237
jl_array_t *vis = (jl_array_t*)jl_array_ptr_ref(vinfo, 0);
238238
size_t nslots = jl_array_len(vis);
@@ -303,7 +303,7 @@ JL_DLLEXPORT jl_code_info_t *jl_new_code_info_uninit(void)
303303
(jl_code_info_t*)jl_gc_alloc(ptls, sizeof(jl_code_info_t),
304304
jl_code_info_type);
305305
src->code = NULL;
306-
src->signature_for_inference_heuristics = NULL;
306+
src->method_for_inference_limit_heuristics = NULL;
307307
src->slotnames = NULL;
308308
src->slotflags = NULL;
309309
src->slottypes = NULL;

src/toplevel.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,7 @@ static jl_code_info_t *expr_to_code_info(jl_value_t *expr)
615615
jl_gc_wb(src, src->slotflags);
616616
src->ssavaluetypes = jl_box_long(0);
617617
jl_gc_wb(src, src->ssavaluetypes);
618-
src->signature_for_inference_heuristics = jl_nothing;
618+
src->method_for_inference_limit_heuristics = jl_nothing;
619619
src->codelocs = jl_nothing;
620620
src->linetable = jl_nothing;
621621

test/compiler/compiler.jl

Lines changed: 77 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1310,71 +1310,117 @@ function _generated_stub(gen::Symbol, args::Vector{Any}, params::Vector{Any}, li
13101310
return Expr(:meta, :generated, stub)
13111311
end
13121312

1313-
f24852_kernel(x, y) = x * y
1314-
1315-
function f24852_kernel_cinfo(x, y)
1316-
sig, spvals, method = Base._methods_by_ftype(Tuple{typeof(f24852_kernel),x,y}, -1, typemax(UInt))[1]
1313+
f24852_kernel1(x, y::Tuple) = x * y[1][1][1]
1314+
f24852_kernel2(x, y::Tuple) = f24852_kernel1(x, (y,))
1315+
f24852_kernel3(x, y::Tuple) = f24852_kernel2(x, (y,))
1316+
f24852_kernel(x, y::Number) = f24852_kernel3(x, (y,))
1317+
1318+
function f24852_kernel_cinfo(fsig::Type)
1319+
world = typemax(UInt) # FIXME
1320+
sig, spvals, method = Base._methods_by_ftype(fsig, -1, world)[1]
1321+
isdefined(method, :source) || return (nothing, :(f(x, y)))
13171322
code_info = Base.uncompressed_ast(method)
13181323
body = Expr(:block, code_info.code...)
1319-
Base.Core.Compiler.substitute!(body, 0, Any[], sig, Any[spvals...], 0, :propagate)
1324+
Base.Core.Compiler.substitute!(body, 0, Any[], sig, Any[spvals...], 1, :propagate)
1325+
if startswith(String(method.name), "f24852")
1326+
for a in body.args
1327+
if a isa Expr && a.head == :(=)
1328+
a = a.args[2]
1329+
end
1330+
if a isa Expr && length(a.args) === 3 && a.head === :call
1331+
pushfirst!(a.args, Core.SlotNumber(1))
1332+
end
1333+
end
1334+
end
1335+
pushfirst!(code_info.slotnames, Symbol("#self#"))
1336+
pushfirst!(code_info.slotflags, 0x00)
13201337
return method, code_info
13211338
end
13221339

1323-
function f24852_gen_cinfo_uninflated(X, Y, f, x, y)
1324-
_, code_info = f24852_kernel_cinfo(x, y)
1340+
function f24852_gen_cinfo_uninflated(X, Y, _, f, x, y)
1341+
_, code_info = f24852_kernel_cinfo(Tuple{f, x, y})
13251342
return code_info
13261343
end
13271344

1328-
function f24852_gen_cinfo_inflated(X, Y, f, x, y)
1329-
method, code_info = f24852_kernel_cinfo(x, y)
1330-
code_info.signature_for_inference_heuristics = Core.Compiler.svec(f, (x, y), typemax(UInt))
1345+
function f24852_gen_cinfo_inflated(X, Y, _, f, x, y)
1346+
method, code_info = f24852_kernel_cinfo(Tuple{f, x, y})
1347+
code_info.signature_for_inference_heuristics = method
13311348
return code_info
13321349
end
13331350

1334-
function f24852_gen_expr(X, Y, f, x, y)
1335-
return :(f24852_kernel(x::$X, y::$Y))
1351+
function f24852_gen_expr(X, Y, _, f, x, y) # deparse f(x::X, y::Y) where {X, Y}
1352+
if f === typeof(f24852_kernel)
1353+
f2 = :f24852_kernel3
1354+
elseif f === typeof(f24852_kernel3)
1355+
f2 = :f24852_kernel2
1356+
elseif f === typeof(f24852_kernel2)
1357+
f2 = :f24852_kernel1
1358+
elseif f === typeof(f24852_kernel1)
1359+
return :((x::$X) * (y::$Y)[1][1][1])
1360+
else
1361+
return :(error(repr(f)))
1362+
end
1363+
return :(f24852_late_expr($f2, x::$X, (y::$Y,)))
13361364
end
13371365

13381366
@eval begin
1339-
function f24852_late_expr(x::X, y::Y) where {X, Y}
1340-
$(_generated_stub(:f24852_gen_expr, Any[:f24852_late_expr, :x, :y],
1367+
function f24852_late_expr(f, x::X, y::Y) where {X, Y}
1368+
$(_generated_stub(:f24852_gen_expr, Any[:self, :f, :x, :y],
13411369
Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), false))
1370+
$(Expr(:meta, :generated_only))
1371+
#= no body =#
13421372
end
1343-
function f24852_late_inflated(x::X, y::Y) where {X, Y}
1344-
$(_generated_stub(:f24852_gen_cinfo_inflated, Any[:f24852_late_inflated, :x, :y],
1373+
function f24852_late_inflated(f, x::X, y::Y) where {X, Y}
1374+
$(_generated_stub(:f24852_gen_cinfo_inflated, Any[:self, :f, :x, :y],
13451375
Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), false))
1376+
$(Expr(:meta, :generated_only))
1377+
#= no body =#
13461378
end
1347-
function f24852_late_uninflated(x::X, y::Y) where {X, Y}
1348-
$(_generated_stub(:f24852_gen_cinfo_uninflated, Any[:f24852_late_uninflated, :x, :y],
1379+
function f24852_late_uninflated(f, x::X, y::Y) where {X, Y}
1380+
$(_generated_stub(:f24852_gen_cinfo_uninflated, Any[:self, :f, :x, :y],
13491381
Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), false))
1382+
$(Expr(:meta, :generated_only))
1383+
#= no body =#
13501384
end
13511385
end
13521386

13531387
@eval begin
1354-
function f24852_early_expr(x::X, y::Y) where {X, Y}
1355-
$(_generated_stub(:f24852_gen_expr, Any[:f24852_early_expr, :x, :y],
1388+
function f24852_early_expr(f, x::X, y::Y) where {X, Y}
1389+
$(_generated_stub(:f24852_gen_expr, Any[:self, :f, :x, :y],
13561390
Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), true))
1391+
$(Expr(:meta, :generated_only))
1392+
#= no body =#
13571393
end
1358-
function f24852_early_inflated(x::X, y::Y) where {X, Y}
1359-
$(_generated_stub(:f24852_gen_cinfo_inflated, Any[:f24852_early_inflated, :x, :y],
1394+
function f24852_early_inflated(f, x::X, y::Y) where {X, Y}
1395+
$(_generated_stub(:f24852_gen_cinfo_inflated, Any[:self, :f, :x, :y],
13601396
Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), true))
1397+
$(Expr(:meta, :generated_only))
1398+
#= no body =#
13611399
end
1362-
function f24852_early_uninflated(x::X, y::Y) where {X, Y}
1363-
$(_generated_stub(:f24852_gen_cinfo_uninflated, Any[:f24852_early_uninflated, :x, :y],
1400+
function f24852_early_uninflated(f, x::X, y::Y) where {X, Y}
1401+
$(_generated_stub(:f24852_gen_cinfo_uninflated, Any[:self, :f, :x, :y],
13641402
Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), true))
1403+
$(Expr(:meta, :generated_only))
1404+
#= no body =#
13651405
end
13661406
end
13671407

13681408
x, y = rand(), rand()
13691409
result = f24852_kernel(x, y)
13701410

1371-
@test result === f24852_late_expr(x, y)
1372-
@test result === f24852_late_uninflated(x, y)
1373-
@test result === f24852_late_inflated(x, y)
1374-
1375-
@test result === f24852_early_expr(x, y)
1376-
@test result === f24852_early_uninflated(x, y)
1377-
@test result === f24852_early_inflated(x, y)
1411+
@test result === f24852_late_expr(f24852_kernel, x, y)
1412+
@test Base.return_types(f24852_late_expr, typeof((f24852_kernel, x, y))) == Any[Any]
1413+
@test result === f24852_late_uninflated(f24852_kernel, x, y)
1414+
@test Base.return_types(f24852_late_uninflated, typeof((f24852_kernel, x, y))) == Any[Any]
1415+
@test result === f24852_late_uninflated(f24852_kernel, x, y)
1416+
@test Base.return_types(f24852_late_uninflated, typeof((f24852_kernel, x, y))) == Any[Any]
1417+
1418+
@test result === f24852_early_expr(f24852_kernel, x, y)
1419+
@test Base.return_types(f24852_early_expr, typeof((f24852_kernel, x, y))) == Any[Any]
1420+
@test result === f24852_early_uninflated(f24852_kernel, x, y)
1421+
@test Base.return_types(f24852_early_uninflated, typeof((f24852_kernel, x, y))) == Any[Any]
1422+
@test result === @inferred f24852_early_inflated(f24852_kernel, x, y)
1423+
@test Base.return_types(f24852_early_inflated, typeof((f24852_kernel, x, y))) == Any[Float64]
13781424

13791425
# TODO: test that `expand_early = true` + inflated `signature_for_inference_heuristics`
13801426
# can be used to tighten up some inference result.

0 commit comments

Comments
 (0)