Skip to content

Commit

Permalink
[xla:cpu] NFC: Construct JitCompiler in xla::cpu::CpuCompiler
Browse files Browse the repository at this point in the history
In preparation for compiling with JitCompiler make sure that we can create it inside CpuCompiler.

PiperOrigin-RevId: 701065399
  • Loading branch information
ezhulenev authored and tensorflower-gardener committed Nov 28, 2024
1 parent 6715fd0 commit 697677f
Showing 1 changed file with 68 additions and 0 deletions.
68 changes: 68 additions & 0 deletions third_party/xla/xla/service/cpu/cpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,13 @@ static AsyncValue::Executor* GetCompilationAsyncExecutor() {
return executor;
}

// Returns task runner that uses the global compilation thread pool.
static cpu::JitCompiler::TaskRunner GetCompilationTaskRunner() {
return [](cpu::JitCompiler::Task task) {
GetCompilationThreadPool()->Schedule(std::move(task));
};
}

// For each computation in the module, determines whether that computation
// calls a custom-call function, either directly or indirectly (e.g. because it
// calls another computation that does).
Expand Down Expand Up @@ -1353,6 +1360,40 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr<HloModule> module) {
debug_options.xla_cpu_parallel_codegen_split_count();
VlogMaxIsa(debug_options.xla_cpu_max_isa());

const HloModuleConfig& config = module->config();

// Options for compiling LLVM IR to machine code.
IrCompiler::Options ir_compiler_options{
/*optimization_level=*/static_cast<int32_t>(CodeGenOptLevel(config)),
/*optimize_for_size=*/options::OptimizeForSizeRequested(config),
/*fast_math_flags=*/llvm_ir::GetCpuFastMathFlags(config),
/*disable_expensive_passes=*/
debug_options.xla_llvm_disable_expensive_passes(),
/*slp_vectorizer_disabled=*/options::SlpVectorizerDisabled(config),
};

// Compiler hooks to intercept compiled LLVM IR modules.
IrCompiler::CompilationHooks ir_compiler_hooks{
pre_optimization_ir_hook,
post_optimization_ir_hook,
CreateOrcJITPostCompilationHook(module.get(), &obj_files),
};

// Options for orchestrating the JIT compilation process.
JitCompiler::Options jit_compiler_options{
std::move(ir_compiler_options),
std::move(ir_compiler_hooks),
/*num_dylibs=*/parallel_codegen_split_count,
/*max_cpu_isa=*/CpuFeatureFromString(debug_options.xla_cpu_max_isa()),
};

TF_ASSIGN_OR_RETURN(
JitCompiler jit_compiler,
JitCompiler::Create(CompilerTargetOptions(module->config()),
CodeGenOptLevel(module->config()),
std::move(jit_compiler_options),
GetCompilationTaskRunner()));

auto jit = SimpleOrcJIT::Create(
CompilerTargetOptions(module->config()),
CodeGenOptLevel(module->config()),
Expand Down Expand Up @@ -2072,6 +2113,33 @@ CpuExecutableAotCompilationResult::LoadExecutable(

const DebugOptions& debug_options = module->config().debug_options();
VlogMaxIsa(debug_options.xla_cpu_max_isa());
const HloModuleConfig& config = module->config();

// Options for compiling LLVM IR to machine code.
IrCompiler::Options ir_compiler_options{
/*optimization_level=*/static_cast<int32_t>(CodeGenOptLevel(config)),
/*optimize_for_size=*/options::OptimizeForSizeRequested(config),
/*fast_math_flags=*/llvm_ir::GetCpuFastMathFlags(config),
/*disable_expensive_passes=*/
debug_options.xla_llvm_disable_expensive_passes(),
/*slp_vectorizer_disabled=*/options::SlpVectorizerDisabled(config),
};

// Options for orchestrating the JIT compilation process.
JitCompiler::Options jit_compiler_options{
std::move(ir_compiler_options),
IrCompiler::CompilationHooks{},
/*num_dylibs=*/1,
/*max_cpu_isa=*/CpuFeatureFromString(debug_options.xla_cpu_max_isa()),
};

TF_ASSIGN_OR_RETURN(
JitCompiler jit_compiler,
JitCompiler::Create(CompilerTargetOptions(module->config()),
CodeGenOptLevel(module->config()),
std::move(jit_compiler_options),
/*task_runner=*/nullptr));

auto jit = SimpleOrcJIT::Create(
CompilerTargetOptions(module->config()),
CodeGenOptLevel(module->config()),
Expand Down

0 comments on commit 697677f

Please sign in to comment.