Skip to content

Commit 4823351

Browse files
authored
Merge pull request #17212 from JuliaLang/jn/invoke-union-splitting
callsite union splitting
2 parents ef56ff8 + 316bfdf commit 4823351

File tree

6 files changed

+171
-42
lines changed

6 files changed

+171
-42
lines changed

base/inference.jl

Lines changed: 107 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ const MAX_TUPLETYPE_LEN = 15
99
const MAX_TUPLE_DEPTH = 4
1010

1111
const MAX_TUPLE_SPLAT = 16
12+
const MAX_UNION_SPLITTING = 6
1213

1314
# alloc_elim_pass! relies on `Slot_AssignedOnce | Slot_UsedUndef` being
1415
# SSA. This should be true now but can break if we start to track conditional
@@ -2359,6 +2360,16 @@ function inline_as_constant(val::ANY, argexprs, sv)
23592360
return (QuoteNode(val), stmts)
23602361
end
23612362

2363+
function countunionsplit(atypes::Vector{Any})
2364+
nu = 1
2365+
for ti in atypes
2366+
if isa(ti, Union)
2367+
nu *= length((ti::Union).types)
2368+
end
2369+
end
2370+
return nu
2371+
end
2372+
23622373
# inline functions whose bodies are "inline_worthy"
23632374
# where the function body doesn't contain any argument more than once.
23642375
# static parameters are ok if all the static parameter values are leaf types,
@@ -2413,13 +2424,106 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference
24132424
return NF
24142425
end
24152426

2416-
atype_unlimited = argtypes_to_type(atypes)
2427+
local atype_unlimited = argtypes_to_type(atypes)
24172428
function invoke_NF()
24182429
# converts a :call to :invoke
2419-
cache_linfo = ccall(:jl_get_spec_lambda, Any, (Any,), atype_unlimited)
2420-
if cache_linfo !== nothing
2430+
local nu = countunionsplit(atypes)
2431+
nu > MAX_UNION_SPLITTING && return NF
2432+
2433+
if nu > 1
2434+
local spec_hit = nothing
2435+
local spec_miss = nothing
2436+
local error_label = nothing
2437+
local linfo_var = add_slot!(enclosing, LambdaInfo, false)
2438+
local ex = copy(e)
2439+
local stmts = []
2440+
for i = 1:length(atypes); local i
2441+
local ti = atypes[i]
2442+
if isa(ti, Union)
2443+
aei = ex.args[i]
2444+
if !effect_free(aei, sv, false)
2445+
newvar = newvar!(sv, ti)
2446+
push!(stmts, Expr(:(=), newvar, aei))
2447+
ex.args[i] = newvar
2448+
end
2449+
end
2450+
end
2451+
function splitunion(atypes::Vector{Any}, i::Int)
2452+
if i == 0
2453+
local sig = argtypes_to_type(atypes)
2454+
local li = ccall(:jl_get_spec_lambda, Any, (Any,), sig)
2455+
li === nothing && return false
2456+
local stmt = []
2457+
push!(stmt, Expr(:(=), linfo_var, li))
2458+
spec_hit === nothing && (spec_hit = genlabel(sv))
2459+
push!(stmt, GotoNode(spec_hit.label))
2460+
return stmt
2461+
else
2462+
local ti = atypes[i]
2463+
if isa(ti, Union)
2464+
local all = true
2465+
local stmts = []
2466+
local aei = ex.args[i]
2467+
for ty in (ti::Union).types; local ty
2468+
atypes[i] = ty
2469+
local match = splitunion(atypes, i - 1)
2470+
if match !== false
2471+
after = genlabel(sv)
2472+
unshift!(match, Expr(:gotoifnot, Expr(:call, GlobalRef(Core, :isa), aei, ty), after.label))
2473+
append!(stmts, match)
2474+
push!(stmts, after)
2475+
else
2476+
all = false
2477+
end
2478+
end
2479+
if all
2480+
error_label === nothing && (error_label = genlabel(sv))
2481+
push!(stmts, GotoNode(error_label.label))
2482+
else
2483+
spec_miss === nothing && (spec_miss = genlabel(sv))
2484+
push!(stmts, GotoNode(spec_miss.label))
2485+
end
2486+
atypes[i] = ti
2487+
return isempty(stmts) ? false : stmts
2488+
else
2489+
return splitunion(atypes, i - 1)
2490+
end
2491+
end
2492+
end
2493+
local match = splitunion(atypes, length(atypes))
2494+
if match !== false && spec_hit !== nothing
2495+
append!(stmts, match)
2496+
if error_label !== nothing
2497+
push!(stmts, error_label)
2498+
push!(stmts, Expr(:call, GlobalRef(_topmod(sv.mod), :error), "error in type inference due to #265"))
2499+
end
2500+
local ret_var, merge
2501+
if spec_miss !== nothing
2502+
ret_var = add_slot!(enclosing, ex.typ, false)
2503+
merge = genlabel(sv)
2504+
push!(stmts, spec_miss)
2505+
push!(stmts, Expr(:(=), ret_var, ex))
2506+
push!(stmts, GotoNode(merge.label))
2507+
else
2508+
ret_var = newvar!(sv, ex.typ)
2509+
end
2510+
push!(stmts, spec_hit)
2511+
ex = copy(ex)
2512+
ex.head = :invoke
2513+
unshift!(ex.args, linfo_var)
2514+
push!(stmts, Expr(:(=), ret_var, ex))
2515+
if spec_miss !== nothing
2516+
push!(stmts, merge)
2517+
end
2518+
#println(stmts)
2519+
return (ret_var, stmts)
2520+
end
2521+
else
2522+
local cache_linfo = ccall(:jl_get_spec_lambda, Any, (Any,), atype_unlimited)
2523+
cache_linfo === nothing && return NF
24212524
e.head = :invoke
24222525
unshift!(e.args, cache_linfo)
2526+
return e
24232527
end
24242528
return NF
24252529
end

base/reflection.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ function _methods_by_ftype(t::ANY, lim)
192192
if 1 < nu <= 64
193193
return _methods(Any[tp...], length(tp), lim, [])
194194
end
195-
# TODO: the following can return incorrect answers that the above branch would have corrected
195+
# XXX: the following can return incorrect answers that the above branch would have corrected
196196
return ccall(:jl_matching_methods, Any, (Any,Cint,Cint), t, lim, 0)
197197
end
198198
function _methods(t::Array,i,lim::Integer,matching::Array{Any,1})
@@ -206,7 +206,7 @@ function _methods(t::Array,i,lim::Integer,matching::Array{Any,1})
206206
for ty in (ti::Union).types
207207
t[i] = ty
208208
if _methods(t,i-1,lim,matching) === false
209-
t[i] = ty
209+
t[i] = ti
210210
return false
211211
end
212212
end

base/show.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,15 @@ show_unquoted(io::IO, ex::LabelNode, ::Int, ::Int) = print(io, ex.label, ":
561561
show_unquoted(io::IO, ex::GotoNode, ::Int, ::Int) = print(io, "goto ", ex.label)
562562
show_unquoted(io::IO, ex::GlobalRef, ::Int, ::Int) = print(io, ex.mod, '.', ex.name)
563563

564+
function show_unquoted(io::IO, ex::LambdaInfo, ::Int, ::Int)
565+
if isdefined(ex, :specTypes)
566+
print(io, "LambdaInfo for ")
567+
show_lambda_types(io, ex.specTypes.parameters)
568+
else
569+
show(io, ex)
570+
end
571+
end
572+
564573
function show_unquoted(io::IO, ex::Slot, ::Int, ::Int)
565574
typ = isa(ex,TypedSlot) ? ex.typ : Any
566575
slotid = ex.id

src/cgutils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -515,7 +515,7 @@ static jl_cgval_t emit_typeof(const jl_cgval_t &p, jl_codectx_t *ctx)
515515
{
516516
// given p, compute its type
517517
if (!p.constant && p.isboxed && !jl_is_leaf_type(p.typ)) {
518-
return mark_julia_type(emit_typeof(p.V), true, jl_datatype_type, ctx);
518+
return mark_julia_type(emit_typeof(p.V), true, jl_datatype_type, ctx, /*needsroot*/false);
519519
}
520520
jl_value_t *aty = p.typ;
521521
if (jl_is_type_type(aty)) // convert Int::Type{Int} ==> typeof(Int) ==> DataType

src/codegen.cpp

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2700,38 +2700,39 @@ static jl_cgval_t emit_invoke(jl_expr_t *ex, jl_codectx_t *ctx)
27002700
size_t arglen = jl_array_dim0(ex->args);
27012701
size_t nargs = arglen - 1;
27022702
assert(arglen >= 2);
2703-
jl_lambda_info_t *li = (jl_lambda_info_t*)args[0];
2704-
assert(jl_is_lambda_info(li));
27052703

2706-
if (li->jlcall_api == 2) {
2707-
assert(li->constval);
2708-
return mark_julia_const(li->constval);
2709-
}
2710-
else if (li->functionObjectsDecls.functionObject == NULL) {
2711-
assert(!li->inCompile);
2712-
if (li->code == jl_nothing && !li->inInference && li->inferred) {
2713-
// XXX: it was inferred in the past, so it's almost valid to re-infer it now
2714-
jl_type_infer(li, 0);
2704+
jl_cgval_t lival = emit_expr(args[0], ctx);
2705+
if (lival.constant) {
2706+
jl_lambda_info_t *li = (jl_lambda_info_t*)lival.constant;
2707+
assert(jl_is_lambda_info(li));
2708+
if (li->jlcall_api == 2) {
2709+
assert(li->constval);
2710+
return mark_julia_const(li->constval);
2711+
}
2712+
if (li->functionObjectsDecls.functionObject == NULL) {
2713+
assert(!li->inCompile);
2714+
if (li->code == jl_nothing && !li->inInference && li->inferred) {
2715+
// XXX: it was inferred in the past, so it's almost valid to re-infer it now
2716+
jl_type_infer(li, 0);
2717+
}
2718+
if (!li->inInference && li->inferred && li->code != jl_nothing) {
2719+
jl_compile_linfo(li);
2720+
}
27152721
}
2716-
if (!li->inInference && li->inferred && li->code != jl_nothing) {
2717-
jl_compile_linfo(li);
2722+
Value *theFptr = (Value*)li->functionObjectsDecls.functionObject;
2723+
if (theFptr && li->jlcall_api == 0) {
2724+
jl_cgval_t fval = emit_expr(args[1], ctx);
2725+
jl_cgval_t result = emit_call_function_object(li, fval, theFptr, &args[1], nargs - 1, (jl_value_t*)ex, ctx);
2726+
if (result.typ == jl_bottom_type)
2727+
CreateTrap(builder);
2728+
return result;
27182729
}
27192730
}
2720-
Value *theFptr = (Value*)li->functionObjectsDecls.functionObject;
2721-
jl_cgval_t result;
2722-
if (theFptr && li->jlcall_api == 0) {
2723-
jl_cgval_t fval = emit_expr(args[1], ctx);
2724-
result = emit_call_function_object(li, fval, theFptr, &args[1], nargs - 1, (jl_value_t*)ex, ctx);
2725-
}
2726-
else {
2727-
result = mark_julia_type(emit_jlcall(prepare_call(jlinvoke_func), literal_pointer_val((jl_value_t*)li),
2728-
&args[1], nargs, ctx),
2729-
true, expr_type((jl_value_t*)ex, ctx), ctx);
2730-
}
2731-
2732-
if (result.typ == jl_bottom_type) {
2731+
jl_cgval_t result = mark_julia_type(emit_jlcall(prepare_call(jlinvoke_func), boxed(lival, ctx, false),
2732+
&args[1], nargs, ctx),
2733+
true, expr_type((jl_value_t*)ex, ctx), ctx);
2734+
if (result.typ == jl_bottom_type)
27332735
CreateTrap(builder);
2734-
}
27352736
return result;
27362737
}
27372738

@@ -3995,7 +3996,7 @@ static Function *gen_jlcall_wrapper(jl_lambda_info_t *lam, Function *f, bool sre
39953996
bool retboxed;
39963997
(void)julia_type_to_llvm(jlretty, &retboxed);
39973998
if (sret) { assert(!retboxed); }
3998-
jl_cgval_t retval = sret ? mark_julia_slot(result, jlretty, tbaa_stack) : mark_julia_type(call, retboxed, jlretty, &ctx);
3999+
jl_cgval_t retval = sret ? mark_julia_slot(result, jlretty, tbaa_stack) : mark_julia_type(call, retboxed, jlretty, &ctx, /*needsroot*/false);
39994000
builder.CreateRet(boxed(retval, &ctx, false)); // no gcroot needed since this on the return path
40004001

40014002
return w;

src/debuginfo.cpp

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -483,8 +483,8 @@ class JuliaJITEventListener: public JITEventListener
483483
else
484484
SectionAddrCheck = SectionLoadAddr;
485485
create_PRUNTIME_FUNCTION(
486-
(uint8_t*)(intptr_t)Addr, (size_t)Size, sName,
487-
(uint8_t*)(intptr_t)SectionLoadAddr, (size_t)SectionSize, UnwindData);
486+
(uint8_t*)(uintptr_t)Addr, (size_t)Size, sName,
487+
(uint8_t*)(uintptr_t)SectionLoadAddr, (size_t)SectionSize, UnwindData);
488488
#endif
489489
StringMap<jl_lambda_info_t*>::iterator linfo_it = linfo_in_flight.find(sName);
490490
jl_lambda_info_t *linfo = NULL;
@@ -559,8 +559,8 @@ class JuliaJITEventListener: public JITEventListener
559559
else
560560
SectionAddrCheck = SectionLoadAddr;
561561
create_PRUNTIME_FUNCTION(
562-
(uint8_t*)(intptr_t)Addr, (size_t)Size, sName,
563-
(uint8_t*)(intptr_t)SectionLoadAddr, (size_t)SectionSize, UnwindData);
562+
(uint8_t*)(uintptr_t)Addr, (size_t)Size, sName,
563+
(uint8_t*)(uintptr_t)SectionLoadAddr, (size_t)SectionSize, UnwindData);
564564
#endif
565565
StringMap<jl_lambda_info_t*>::iterator linfo_it = linfo_in_flight.find(sName);
566566
jl_lambda_info_t *linfo = NULL;
@@ -1256,7 +1256,7 @@ int jl_getFunctionInfo(jl_frame_t **frames_out, size_t pointer, int skipC, int n
12561256
// Without MCJIT we use the FuncInfo structure containing address maps
12571257
std::map<size_t, FuncInfo, revcomp> &info = jl_jit_events->getMap();
12581258
std::map<size_t, FuncInfo, revcomp>::iterator it = info.lower_bound(pointer);
1259-
if (it != info.end() && (intptr_t)(*it).first + (*it).second.lengthAdr >= pointer) {
1259+
if (it != info.end() && (uintptr_t)(*it).first + (*it).second.lengthAdr >= pointer) {
12601260
// We do this to hide the jlcall wrappers when getting julia backtraces,
12611261
// but it is still good to have them for regular lookup of C frames.
12621262
if (skipC && (*it).second.lines.empty()) {
@@ -1330,6 +1330,21 @@ int jl_getFunctionInfo(jl_frame_t **frames_out, size_t pointer, int skipC, int n
13301330
return jl_getDylibFunctionInfo(frames_out, pointer, skipC, noInline);
13311331
}
13321332

1333+
extern "C" jl_lambda_info_t *jl_gdblookuplinfo(void *p)
1334+
{
1335+
#ifndef USE_MCJIT
1336+
std::map<size_t, FuncInfo, revcomp> &info = jl_jit_events->getMap();
1337+
std::map<size_t, FuncInfo, revcomp>::iterator it = info.lower_bound((size_t)p);
1338+
jl_lambda_info_t *li = NULL;
1339+
if (it != info.end() && (uintptr_t)(*it).first + (*it).second.lengthAdr >= (uintptr_t)p)
1340+
li = (*it).second.linfo;
1341+
uv_rwlock_rdunlock(&threadsafe);
1342+
return li;
1343+
#else
1344+
return jl_jit_events->lookupLinfo((size_t)p);
1345+
#endif
1346+
}
1347+
13331348
#if defined(LLVM37) && (defined(_OS_LINUX_) || (defined(_OS_DARWIN_) && defined(LLVM_SHLIB)))
13341349
extern "C" void __register_frame(void*);
13351350
extern "C" void __deregister_frame(void*);
@@ -1745,7 +1760,7 @@ uint64_t jl_getUnwindInfo(uint64_t dwAddr)
17451760
std::map<size_t, ObjectInfo, revcomp>::iterator it = objmap.lower_bound(dwAddr);
17461761
uint64_t ipstart = 0; // ip of the start of the section (if found)
17471762
if (it != objmap.end() && dwAddr < it->first + it->second.SectionSize) {
1748-
ipstart = (uint64_t)(intptr_t)(*it).first;
1763+
ipstart = (uint64_t)(uintptr_t)(*it).first;
17491764
}
17501765
uv_rwlock_rdunlock(&threadsafe);
17511766
return ipstart;
@@ -1758,8 +1773,8 @@ uint64_t jl_getUnwindInfo(uint64_t dwAddr)
17581773
std::map<size_t, FuncInfo, revcomp> &info = jl_jit_events->getMap();
17591774
std::map<size_t, FuncInfo, revcomp>::iterator it = info.lower_bound(dwAddr);
17601775
uint64_t ipstart = 0; // ip of the first instruction in the function (if found)
1761-
if (it != info.end() && (intptr_t)(*it).first + (*it).second.lengthAdr > dwAddr) {
1762-
ipstart = (uint64_t)(intptr_t)(*it).first;
1776+
if (it != info.end() && (uintptr_t)(*it).first + (*it).second.lengthAdr > dwAddr) {
1777+
ipstart = (uint64_t)(uintptr_t)(*it).first;
17631778
}
17641779
uv_rwlock_rdunlock(&threadsafe);
17651780
return ipstart;

0 commit comments

Comments
 (0)