Skip to content

Commit

Permalink
Merge opaque closure modules with the rest of the workqueue (#50724)
Browse files Browse the repository at this point in the history
This sticks the compiled opaque closure module into the
`compiled_functions` list of modules that we have compiled for the
particular `jl_codegen_params_t`. We probably should manage that vector
in codegen_params, since it lets us see if a particular codeinst has
already been compiled but not yet emitted.
  • Loading branch information
pchintalapudi authored Jul 31, 2023
1 parent f4cb8bc commit 441fcb1
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 89 deletions.
11 changes: 5 additions & 6 deletions src/aotcompile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,6 @@ void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvm
jl_native_code_desc_t *data = new jl_native_code_desc_t;
CompilationPolicy policy = (CompilationPolicy) _policy;
bool imaging = imaging_default() || _imaging_mode == 1;
jl_workqueue_t emitted;
jl_method_instance_t *mi = NULL;
jl_code_info_t *src = NULL;
JL_GC_PUSH1(&src);
Expand Down Expand Up @@ -335,7 +334,7 @@ void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvm
// find and prepare the source code to compile
jl_code_instance_t *codeinst = NULL;
jl_ci_cache_lookup(*cgparams, mi, params.world, &codeinst, &src);
if (src && !emitted.count(codeinst)) {
if (src && !params.compiled_functions.count(codeinst)) {
// now add it to our compilation results
JL_GC_PROMISE_ROOTED(codeinst->rettype);
orc::ThreadSafeModule result_m = jl_create_ts_module(name_from_method_instance(codeinst->def),
Expand All @@ -344,13 +343,13 @@ void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvm
Triple(clone.getModuleUnlocked()->getTargetTriple()));
jl_llvm_functions_t decls = jl_emit_code(result_m, mi, src, codeinst->rettype, params);
if (result_m)
emitted[codeinst] = {std::move(result_m), std::move(decls)};
params.compiled_functions[codeinst] = {std::move(result_m), std::move(decls)};
}
}
}

// finally, make sure all referenced methods also get compiled or fixed up
jl_compile_workqueue(emitted, *clone.getModuleUnlocked(), params, policy);
jl_compile_workqueue(params, *clone.getModuleUnlocked(), policy);
}
JL_UNLOCK(&jl_codegen_lock); // Might GC
JL_GC_POP();
Expand All @@ -369,7 +368,7 @@ void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvm
data->jl_value_to_llvm[idx] = global.first;
idx++;
}
CreateNativeMethods += emitted.size();
CreateNativeMethods += params.compiled_functions.size();

size_t offset = gvars.size();
data->jl_external_to_llvm.resize(params.external_fns.size());
Expand All @@ -394,7 +393,7 @@ void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvm
{
JL_TIMING(NATIVE_AOT, NATIVE_Merge);
Linker L(*clone.getModuleUnlocked());
for (auto &def : emitted) {
for (auto &def : params.compiled_functions) {
jl_merge_module(clone, std::move(std::get<0>(def.second)));
jl_code_instance_t *this_code = def.first;
jl_llvm_functions_t decls = std::get<1>(def.second);
Expand Down
112 changes: 48 additions & 64 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1616,7 +1616,6 @@ class jl_codectx_t {
std::vector<std::tuple<jl_cgval_t, BasicBlock *, AllocaInst *, PHINode *, jl_value_t *>> PhiNodes;
std::vector<bool> ssavalue_assigned;
std::vector<int> ssavalue_usecount;
std::vector<orc::ThreadSafeModule> oc_modules;
jl_module_t *module = NULL;
jl_typecache_t type_cache;
jl_tbaacache_t tbaa_cache;
Expand Down Expand Up @@ -4460,7 +4459,7 @@ static jl_cgval_t emit_invoke(jl_codectx_t &ctx, const jl_cgval_t &lival, const
// Check if we already queued this up
auto it = ctx.call_targets.find(codeinst);
if (need_to_emit && it != ctx.call_targets.end()) {
protoname = std::get<2>(it->second)->getName();
protoname = it->second.decl->getName();
need_to_emit = cache_valid = false;
}

Expand Down Expand Up @@ -4504,7 +4503,7 @@ static jl_cgval_t emit_invoke(jl_codectx_t &ctx, const jl_cgval_t &lival, const
handled = true;
if (need_to_emit) {
Function *trampoline_decl = cast<Function>(jl_Module->getNamedValue(protoname));
ctx.call_targets[codeinst] = std::make_tuple(cc, return_roots, trampoline_decl, specsig);
ctx.call_targets[codeinst] = {cc, return_roots, trampoline_decl, specsig};
}
}
}
Expand Down Expand Up @@ -5369,8 +5368,7 @@ static std::pair<Function*, Function*> get_oc_function(jl_codectx_t &ctx, jl_met
{
jl_svec_t *sig_args = NULL;
jl_value_t *sigtype = NULL;
jl_code_info_t *ir = NULL;
JL_GC_PUSH3(&sig_args, &sigtype, &ir);
JL_GC_PUSH2(&sig_args, &sigtype);

size_t nsig = 1 + jl_svec_len(argt_typ->parameters);
sig_args = jl_alloc_svec_uninit(nsig);
Expand All @@ -5392,16 +5390,25 @@ static std::pair<Function*, Function*> get_oc_function(jl_codectx_t &ctx, jl_met
JL_GC_POP();
return std::make_pair((Function*)NULL, (Function*)NULL);
}
++EmittedOpaqueClosureFunctions;

ir = jl_uncompress_ir(closure_method, ci, (jl_value_t*)inferred);
auto it = ctx.emission_context.compiled_functions.find(ci);

// TODO: Emit this inline and outline it late using LLVM's coroutine support.
orc::ThreadSafeModule closure_m = jl_create_ts_module(
name_from_method_instance(mi), ctx.emission_context.tsctx,
ctx.emission_context.imaging,
jl_Module->getDataLayout(), Triple(jl_Module->getTargetTriple()));
jl_llvm_functions_t closure_decls = emit_function(closure_m, mi, ir, rettype, ctx.emission_context);
if (it == ctx.emission_context.compiled_functions.end()) {
++EmittedOpaqueClosureFunctions;
jl_code_info_t *ir = jl_uncompress_ir(closure_method, ci, (jl_value_t*)inferred);
JL_GC_PUSH1(&ir);
// TODO: Emit this inline and outline it late using LLVM's coroutine support.
orc::ThreadSafeModule closure_m = jl_create_ts_module(
name_from_method_instance(mi), ctx.emission_context.tsctx,
ctx.emission_context.imaging,
jl_Module->getDataLayout(), Triple(jl_Module->getTargetTriple()));
jl_llvm_functions_t closure_decls = emit_function(closure_m, mi, ir, rettype, ctx.emission_context);
JL_GC_POP();
it = ctx.emission_context.compiled_functions.insert(std::make_pair(ci, std::make_pair(std::move(closure_m), std::move(closure_decls)))).first;
}

auto &closure_m = it->second.first;
auto &closure_decls = it->second.second;

assert(closure_decls.functionObject != "jl_fptr_sparam");
bool isspecsig = closure_decls.functionObject != "jl_fptr_args";
Expand Down Expand Up @@ -5432,7 +5439,6 @@ static std::pair<Function*, Function*> get_oc_function(jl_codectx_t &ctx, jl_met
specF = cast<Function>(returninfo.decl.getCallee());
}
}
ctx.oc_modules.push_back(std::move(closure_m));
JL_GC_POP();
return std::make_pair(F, specF);
}
Expand Down Expand Up @@ -5715,7 +5721,7 @@ static jl_cgval_t emit_expr(jl_codectx_t &ctx, jl_value_t *expr, ssize_t ssaidx_
if (jl_is_concrete_type(env_t)) {
jl_tupletype_t *argt_typ = (jl_tupletype_t*)argt.constant;
Function *F, *specF;
std::tie(F, specF) = get_oc_function(ctx, (jl_method_t*)source.constant, (jl_datatype_t*)env_t, argt_typ, ub.constant);
std::tie(F, specF) = get_oc_function(ctx, (jl_method_t*)source.constant, (jl_tupletype_t*)env_t, argt_typ, ub.constant);
if (F) {
jl_cgval_t jlcall_ptr = mark_julia_type(ctx, F, false, jl_voidpointer_type);
jl_aliasinfo_t ai = jl_aliasinfo_t::fromTBAA(ctx, ctx.tbaa().tbaa_gcframe);
Expand All @@ -5725,7 +5731,7 @@ static jl_cgval_t emit_expr(jl_codectx_t &ctx, jl_value_t *expr, ssize_t ssaidx_
if (specF)
fptr = mark_julia_type(ctx, specF, false, jl_voidpointer_type);
else
fptr = mark_julia_type(ctx, (llvm::Value*)Constant::getNullValue(ctx.types().T_size), false, jl_voidpointer_type);
fptr = mark_julia_type(ctx, Constant::getNullValue(ctx.types().T_size), false, jl_voidpointer_type);

// TODO: Inline the env at the end of the opaque closure and generate a descriptor for GC
jl_cgval_t env = emit_new_struct(ctx, env_t, nargs-4, &argv.data()[4]);
Expand Down Expand Up @@ -8757,19 +8763,6 @@ static jl_llvm_functions_t
jl_Module->getFunction(FN)->setLinkage(GlobalVariable::InternalLinkage);
}

// link in opaque closure modules
for (auto &TSMod : ctx.oc_modules) {
SmallVector<std::string, 1> Exports;
TSMod.withModuleDo([&](Module &Mod) {
for (const auto &F: Mod.functions())
if (!F.isDeclaration())
Exports.push_back(F.getName().str());
});
jl_merge_module(TSM, std::move(TSMod));
for (auto FN: Exports)
jl_Module->getFunction(FN)->setLinkage(GlobalVariable::InternalLinkage);
}

JL_GC_POP();
return declarations;
}
Expand Down Expand Up @@ -8931,22 +8924,18 @@ jl_llvm_functions_t jl_emit_codeinst(


void jl_compile_workqueue(
jl_workqueue_t &emitted,
jl_codegen_params_t &params,
Module &original,
jl_codegen_params_t &params, CompilationPolicy policy)
CompilationPolicy policy)
{
JL_TIMING(CODEGEN, CODEGEN_Workqueue);
jl_code_info_t *src = NULL;
JL_GC_PUSH1(&src);
while (!params.workqueue.empty()) {
jl_code_instance_t *codeinst;
Function *protodecl;
jl_returninfo_t::CallingConv proto_cc;
bool proto_specsig;
unsigned proto_return_roots;
auto it = params.workqueue.back();
codeinst = it.first;
std::tie(proto_cc, proto_return_roots, protodecl, proto_specsig) = it.second;
auto proto = it.second;
params.workqueue.pop_back();
// try to emit code for this item from the workqueue
assert(codeinst->min_world <= params.world && codeinst->max_world >= params.world &&
Expand Down Expand Up @@ -8974,12 +8963,8 @@ void jl_compile_workqueue(
}
}
else {
auto &result = emitted[codeinst];
jl_llvm_functions_t *decls = NULL;
if (std::get<0>(result)) {
decls = &std::get<1>(result);
}
else {
auto it = params.compiled_functions.find(codeinst);
if (it == params.compiled_functions.end()) {
// Reinfer the function. The JIT came along and removed the inferred
// method body. See #34993
if (policy != CompilationPolicy::Default &&
Expand All @@ -8990,47 +8975,46 @@ void jl_compile_workqueue(
jl_create_ts_module(name_from_method_instance(codeinst->def),
params.tsctx, params.imaging,
original.getDataLayout(), Triple(original.getTargetTriple()));
result.second = jl_emit_code(result_m, codeinst->def, src, src->rettype, params);
result.first = std::move(result_m);
auto decls = jl_emit_code(result_m, codeinst->def, src, src->rettype, params);
if (result_m)
it = params.compiled_functions.insert(std::make_pair(codeinst, std::make_pair(std::move(result_m), std::move(decls)))).first;
}
}
else {
orc::ThreadSafeModule result_m =
jl_create_ts_module(name_from_method_instance(codeinst->def),
params.tsctx, params.imaging,
original.getDataLayout(), Triple(original.getTargetTriple()));
result.second = jl_emit_codeinst(result_m, codeinst, NULL, params);
result.first = std::move(result_m);
auto decls = jl_emit_codeinst(result_m, codeinst, NULL, params);
if (result_m)
it = params.compiled_functions.insert(std::make_pair(codeinst, std::make_pair(std::move(result_m), std::move(decls)))).first;
}
if (std::get<0>(result))
decls = &std::get<1>(result);
else
emitted.erase(codeinst); // undo the insert above
}
if (decls) {
if (decls->functionObject == "jl_fptr_args") {
preal_decl = decls->specFunctionObject;
if (it != params.compiled_functions.end()) {
auto &decls = it->second.second;
if (decls.functionObject == "jl_fptr_args") {
preal_decl = decls.specFunctionObject;
}
else if (decls->functionObject != "jl_fptr_sparam") {
preal_decl = decls->specFunctionObject;
else if (decls.functionObject != "jl_fptr_sparam") {
preal_decl = decls.specFunctionObject;
preal_specsig = true;
}
}
}
// patch up the prototype we emitted earlier
Module *mod = protodecl->getParent();
assert(protodecl->isDeclaration());
if (proto_specsig) {
Module *mod = proto.decl->getParent();
assert(proto.decl->isDeclaration());
if (proto.specsig) {
// expected specsig
if (!preal_specsig) {
// emit specsig-to-(jl)invoke conversion
Function *preal = emit_tojlinvoke(codeinst, mod, params);
protodecl->setLinkage(GlobalVariable::InternalLinkage);
proto.decl->setLinkage(GlobalVariable::InternalLinkage);
//protodecl->setAlwaysInline();
jl_init_function(protodecl, params.TargetTriple);
jl_init_function(proto.decl, params.TargetTriple);
size_t nrealargs = jl_nparams(codeinst->def->specTypes); // number of actual arguments being passed
// TODO: maybe this can be cached in codeinst->specfptr?
emit_cfunc_invalidate(protodecl, proto_cc, proto_return_roots, codeinst->def->specTypes, codeinst->rettype, false, nrealargs, params, preal);
emit_cfunc_invalidate(proto.decl, proto.cc, proto.return_roots, codeinst->def->specTypes, codeinst->rettype, false, nrealargs, params, preal);
preal_decl = ""; // no need to fixup the name
}
else {
Expand All @@ -9047,11 +9031,11 @@ void jl_compile_workqueue(
if (!preal_decl.empty()) {
// merge and/or rename this prototype to the real function
if (Value *specfun = mod->getNamedValue(preal_decl)) {
if (protodecl != specfun)
protodecl->replaceAllUsesWith(specfun);
if (proto.decl != specfun)
proto.decl->replaceAllUsesWith(specfun);
}
else {
protodecl->setName(preal_decl);
proto.decl->setName(preal_decl);
}
}
}
Expand Down
19 changes: 9 additions & 10 deletions src/jitlayers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,16 +215,15 @@ static jl_callptr_t _jl_compile_codeinst(
params.world = world;
params.imaging = imaging_default();
params.debug_level = jl_options.debug_level;
jl_workqueue_t emitted;
{
orc::ThreadSafeModule result_m =
jl_create_ts_module(name_from_method_instance(codeinst->def), params.tsctx, params.imaging, params.DL, params.TargetTriple);
jl_llvm_functions_t decls = jl_emit_codeinst(result_m, codeinst, src, params);
if (result_m)
emitted[codeinst] = {std::move(result_m), std::move(decls)};
params.compiled_functions[codeinst] = {std::move(result_m), std::move(decls)};
{
auto temp_module = jl_create_llvm_module(name_from_method_instance(codeinst->def), params.getContext(), params.imaging);
jl_compile_workqueue(emitted, *temp_module, params, CompilationPolicy::Default);
jl_compile_workqueue(params, *temp_module, CompilationPolicy::Default);
}

if (params._shared_module)
Expand All @@ -241,7 +240,7 @@ static jl_callptr_t _jl_compile_codeinst(
for (auto &global : params.global_targets) {
NewGlobals[global.second->getName()] = global.first;
}
for (auto &def : emitted) {
for (auto &def : params.compiled_functions) {
auto M = std::get<0>(def.second).getModuleUnlocked();
for (auto &GV : M->globals()) {
auto InitValue = NewGlobals.find(GV.getName());
Expand All @@ -252,14 +251,14 @@ static jl_callptr_t _jl_compile_codeinst(
}
}

// Collect the exported functions from the emitted modules,
// Collect the exported functions from the params.compiled_functions modules,
// which form dependencies on which functions need to be
// compiled first. Cycles of functions are compiled together.
// (essentially we compile a DAG of SCCs in reverse topological order,
// if we treat declarations of external functions as edges from declaration
// to definition)
StringMap<orc::ThreadSafeModule*> NewExports;
for (auto &def : emitted) {
for (auto &def : params.compiled_functions) {
orc::ThreadSafeModule &TSM = std::get<0>(def.second);
//The underlying context object is still locked because params is not destroyed yet
auto M = TSM.getModuleUnlocked();
Expand All @@ -271,19 +270,19 @@ static jl_callptr_t _jl_compile_codeinst(
}
DenseMap<orc::ThreadSafeModule*, int> Queued;
std::vector<orc::ThreadSafeModule*> Stack;
for (auto &def : emitted) {
for (auto &def : params.compiled_functions) {
// Add the results to the execution engine now
orc::ThreadSafeModule &M = std::get<0>(def.second);
jl_add_to_ee(M, NewExports, Queued, Stack);
assert(Queued.empty() && Stack.empty() && !M);
}
++CompiledCodeinsts;
MaxWorkqueueSize.updateMax(emitted.size());
IndirectCodeinsts += emitted.size() - 1;
MaxWorkqueueSize.updateMax(params.compiled_functions.size());
IndirectCodeinsts += params.compiled_functions.size() - 1;
}

size_t i = 0;
for (auto &def : emitted) {
for (auto &def : params.compiled_functions) {
jl_code_instance_t *this_code = def.first;
if (i < jl_timing_print_limit)
jl_timing_show_func_sig(this_code->def->specTypes, JL_TIMING_DEFAULT_BLOCK);
Expand Down
Loading

0 comments on commit 441fcb1

Please sign in to comment.