diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.cc b/third_party/xla/xla/service/cpu/cpu_compiler.cc
index 02ecaae5d6d8ee..bd15a94d184af4 100644
--- a/third_party/xla/xla/service/cpu/cpu_compiler.cc
+++ b/third_party/xla/xla/service/cpu/cpu_compiler.cc
@@ -261,6 +261,13 @@ static AsyncValue::Executor* GetCompilationAsyncExecutor() {
   return executor;
 }
 
+// Returns a task runner scheduling tasks on the global compilation thread pool.
+static cpu::JitCompiler::TaskRunner GetCompilationTaskRunner() {
+  return [](cpu::JitCompiler::Task task) {
+    GetCompilationThreadPool()->Schedule(std::move(task));
+  };
+}
+
 // For each computation in the module, determines whether that computation
 // calls a custom-call function, either directly or indirectly (e.g. because it
 // calls another computation that does).
@@ -1353,6 +1360,40 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr module) {
       debug_options.xla_cpu_parallel_codegen_split_count();
   VlogMaxIsa(debug_options.xla_cpu_max_isa());
 
+  const HloModuleConfig& config = module->config();
+
+  // Options for compiling LLVM IR to machine code.
+  IrCompiler::Options ir_compiler_options{
+      /*optimization_level=*/static_cast(CodeGenOptLevel(config)),
+      /*optimize_for_size=*/options::OptimizeForSizeRequested(config),
+      /*fast_math_flags=*/llvm_ir::GetCpuFastMathFlags(config),
+      /*disable_expensive_passes=*/
+      debug_options.xla_llvm_disable_expensive_passes(),
+      /*slp_vectorizer_disabled=*/options::SlpVectorizerDisabled(config),
+  };
+
+  // Compiler hooks to intercept compiled LLVM IR modules.
+  IrCompiler::CompilationHooks ir_compiler_hooks{
+      pre_optimization_ir_hook,
+      post_optimization_ir_hook,
+      CreateOrcJITPostCompilationHook(module.get(), &obj_files),
+  };
+
+  // Options for orchestrating the JIT compilation process.
+  JitCompiler::Options jit_compiler_options{
+      std::move(ir_compiler_options),
+      std::move(ir_compiler_hooks),
+      /*num_dylibs=*/parallel_codegen_split_count,
+      /*max_cpu_isa=*/CpuFeatureFromString(debug_options.xla_cpu_max_isa()),
+  };
+
+  TF_ASSIGN_OR_RETURN(
+      JitCompiler jit_compiler,
+      JitCompiler::Create(CompilerTargetOptions(module->config()),
+                          CodeGenOptLevel(module->config()),
+                          std::move(jit_compiler_options),
+                          GetCompilationTaskRunner()));
+
   auto jit = SimpleOrcJIT::Create(
       CompilerTargetOptions(module->config()),
       CodeGenOptLevel(module->config()),
@@ -2072,6 +2113,33 @@ CpuExecutableAotCompilationResult::LoadExecutable(
   const DebugOptions& debug_options = module->config().debug_options();
   VlogMaxIsa(debug_options.xla_cpu_max_isa());
 
+  const HloModuleConfig& config = module->config();
+
+  // Options for compiling LLVM IR to machine code.
+  IrCompiler::Options ir_compiler_options{
+      /*optimization_level=*/static_cast(CodeGenOptLevel(config)),
+      /*optimize_for_size=*/options::OptimizeForSizeRequested(config),
+      /*fast_math_flags=*/llvm_ir::GetCpuFastMathFlags(config),
+      /*disable_expensive_passes=*/
+      debug_options.xla_llvm_disable_expensive_passes(),
+      /*slp_vectorizer_disabled=*/options::SlpVectorizerDisabled(config),
+  };
+
+  // Options for orchestrating the JIT compilation process.
+  JitCompiler::Options jit_compiler_options{
+      std::move(ir_compiler_options),
+      IrCompiler::CompilationHooks{},  // default-constructed hooks
+      /*num_dylibs=*/1,
+      /*max_cpu_isa=*/CpuFeatureFromString(debug_options.xla_cpu_max_isa()),
+  };
+
+  TF_ASSIGN_OR_RETURN(
+      JitCompiler jit_compiler,
+      JitCompiler::Create(CompilerTargetOptions(module->config()),
+                          CodeGenOptLevel(module->config()),
+                          std::move(jit_compiler_options),
+                          /*task_runner=*/nullptr));
+
   auto jit = SimpleOrcJIT::Create(
       CompilerTargetOptions(module->config()),
       CodeGenOptLevel(module->config()),