Skip to content

Commit

Permalink
[XLA:CPU] Enable emitting of nested calls from ElementalKernelEmitter
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 711376068
  • Loading branch information
WillFroom authored and tensorflower-gardener committed Jan 2, 2025
1 parent ea1ddd0 commit 0e8ab96
Show file tree
Hide file tree
Showing 17 changed files with 439 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ class TargetMachineFeatures {
explicit TargetMachineFeatures(llvm::TargetMachine* target_machine);
virtual ~TargetMachineFeatures() = default;

TargetMachineFeatures(TargetMachineFeatures&&) = default;
TargetMachineFeatures& operator=(TargetMachineFeatures&&) = default;

// Return the vectorization factor, which is the number of bytes of data
// explicitly vectorized routines will try to process at once.
virtual int32_t vectorization_factor_in_bytes() const;
Expand Down
6 changes: 5 additions & 1 deletion third_party/xla/xla/backends/cpu/testlib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ cc_library(
"//xla/codegen:llvm_ir_kernel_source",
"//xla/codegen/testlib:kernel_runner",
"//xla/service/cpu:runtime_symbol_generator",
"//xla/tsl/platform:errors",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
Expand Down Expand Up @@ -96,6 +97,7 @@ cc_library(
"//xla:shape_util",
"//xla:util",
"//xla/backends/cpu/codegen:kernel_api_ir_builder",
"//xla/backends/cpu/codegen:target_machine_features",
"//xla/codegen:kernel_emitter",
"//xla/codegen:kernel_spec",
"//xla/codegen:llvm_ir_kernel_source",
Expand Down Expand Up @@ -152,14 +154,15 @@ tsl_pybind_extension(
"@com_google_absl//absl/strings:string_view",
"@nanobind",
"@local_config_python//:python_headers", # buildcleaner: keep
"//xla/backends/cpu/codegen:jit_compiler",
"//xla/backends/cpu/codegen:target_machine_features",
"//xla/codegen:kernel_emitter",
"//xla/codegen:kernel_spec",
"//xla/codegen/testlib:kernel_runner",
"//xla/hlo/ir:hlo",
"//xla/hlo/parser:hlo_parser",
"//xla/service:buffer_assignment",
"//xla/service/cpu:cpu_compiler_pure",
"//xla/service/cpu:ir_emitter",
"//xla/stream_executor:launch_dim",
],
)
Expand Down Expand Up @@ -223,6 +226,7 @@ py_strict_test(
":testlib",
"//third_party/py/numpy",
"//xla/codegen/testlib",
"//xla/python:xla_extension",
"@absl_py//absl/testing:absltest",
"@absl_py//absl/testing:parameterized",
],
Expand Down
4 changes: 4 additions & 0 deletions third_party/xla/xla/backends/cpu/testlib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@

# go/keep-sorted start
ElementalKernelEmitter = _extension.ElementalKernelEmitter
HloCompiler = _extension.HloCompiler
HloModule = _extension.HloModule
JitCompiler = _extension.JitCompiler
KernelRunner = _extension.KernelRunner
LlvmIrKernelEmitter = _extension.LlvmIrKernelEmitter
LlvmIrKernelSpec = _extension.LlvmIrKernelSpec
TargetMachineFeatures = _extension.TargetMachineFeatures
# go/keep-sorted end
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,11 @@ limitations under the License.
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
#include "xla/backends/cpu/codegen/kernel_api_ir_builder.h"
#include "xla/backends/cpu/codegen/target_machine_features.h"
#include "xla/backends/cpu/testlib/llvm_ir_kernel_spec.h" // Move this outside of testlib?
#include "xla/codegen/kernel_spec.h"
#include "xla/codegen/llvm_ir_kernel_source.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/service/buffer_assignment.h"
Expand Down Expand Up @@ -156,30 +158,73 @@ ParallelPartitionBounds EmitParallelPartitionBounds(
return bounds;
}

// Implementation detail for ComputationsTransitivelyContainCustomCall, which
// recursively checks whether a computation contains a custom call.
//
// Returns true if `computation` is itself a custom-call computation, or if any
// computation called (directly or indirectly) by one of its instructions is.
// The answer for `computation` is always recorded in `custom_call_map` before
// returning, and answers already present in the map are reused instead of
// being recomputed.
bool RecursivelyCheckForCustomCall(
    const HloComputation& computation,
    absl::flat_hash_map<const HloComputation*, bool>& custom_call_map) {
  bool contains_custom_call = computation.IsCustomCallComputation();

  for (const HloInstruction* instruction : computation.instructions()) {
    for (const HloComputation* nested_computation :
         instruction->called_computations()) {
      if (const auto itr = custom_call_map.find(nested_computation);
          itr != custom_call_map.end()) {
        // Fold the memoized answer into the running result and keep going.
        // Returning here would skip the remaining callees, discard what has
        // been accumulated so far, and leave `computation` without an entry
        // in the map.
        contains_custom_call |= itr->second;
        continue;
      }
      contains_custom_call |=
          RecursivelyCheckForCustomCall(*nested_computation, custom_call_map);
    }
  }

  // No early exit once the result is known to be true: the traversal is what
  // populates the map entries for the nested computations.
  custom_call_map[&computation] = contains_custom_call;
  return contains_custom_call;
}

// Builds the "contains a custom call" lookup table for `op_hlo`: every
// computation called by the operation is handed to
// RecursivelyCheckForCustomCall, which fills in the table, and the populated
// table is returned to the caller.
absl::flat_hash_map<const HloComputation*, bool>
ComputationsTransitivelyContainCustomCall(const HloInstruction& op_hlo) {
  absl::flat_hash_map<const HloComputation*, bool> contains_custom_call;

  const auto& callees = op_hlo.called_computations();
  for (const HloComputation* callee : callees) {
    // Only the side effect on the table matters here; the boolean result for
    // the individual callee is available through the table itself.
    RecursivelyCheckForCustomCall(*callee, contains_custom_call);
  }

  return contains_custom_call;
}

} // namespace

// Constructs an emitter for a single HLO instruction, with a freshly created
// LLVM context. Only a reference to `op_hlo` is stored (no copy, no
// ownership), so the instruction must outlive this emitter.
ElementalKernelEmitter::ElementalKernelEmitter(const HloInstruction& op_hlo)
: op_hlo_(op_hlo),
context_(std::make_unique<llvm::LLVMContext>()),
// NOTE(review): this reads context_, so it relies on context_ being declared
// before kernel_api_ir_builder_ in the class -- confirm against the header.
// NOTE(review): the meaning of the positional Options{true, 256} values is
// not visible here; check KernelApiIrBuilder::Options before changing them.
kernel_api_ir_builder_(*context_.getContext(),
KernelApiIrBuilder::Options{true, 256}) {}

ElementalKernelEmitter::ElementalKernelEmitter(
std::unique_ptr<HloInstruction> op_hlo, const HloModule* hlo_module,
const BufferAssignment* buffer_assignment)
: op_hlo_(std::move(op_hlo)),
const HloModule* hlo_module, const BufferAssignment* buffer_assignment,
const TargetMachineFeatures* target_machine)
: op_hlo_(*hlo_module->entry_computation()->root_instruction()),
hlo_module_(hlo_module),
buffer_assignment_(buffer_assignment),
target_machine_(target_machine),
context_(std::make_unique<llvm::LLVMContext>()),
kernel_api_ir_builder_(*context_.getContext(),
KernelApiIrBuilder::Options{true, 256}) {}

absl::StatusOr<std::unique_ptr<KernelSpec>>
ElementalKernelEmitter::EmitKernelSpec() {
VLOG(2) << "Emit elemental host kernel: " << op_hlo_->name();
VLOG(2) << "Emit elemental host kernel: " << op_hlo_.name();

llvm::LLVMContext& ctx = *context_.getContext();
auto module = std::make_unique<llvm::Module>(
absl::StrCat(op_hlo_->name(), "_elemental_kernel_module"), ctx);
absl::StrCat(op_hlo_.name(), "_elemental_kernel_module"), ctx);

TF_ASSIGN_OR_RETURN(
KernelApiIrBuilder::KernelPrototype kernel_prototype,
kernel_api_ir_builder_.EmitKernelPrototype(
*module, op_hlo_.get(), buffer_assignment_, "_kernel"));
TF_ASSIGN_OR_RETURN(KernelApiIrBuilder::KernelPrototype kernel_prototype,
kernel_api_ir_builder_.EmitKernelPrototype(
*module, &op_hlo_, buffer_assignment_, "_kernel"));

llvm::IRBuilder<> ir_builder(ctx);
ir_builder.SetInsertPoint(
Expand All @@ -190,8 +235,8 @@ ElementalKernelEmitter::EmitKernelSpec() {
ThreadLocalCallbackFactory(ir_builder, *module));

CpuElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
for (int64_t i = 0; i < op_hlo_->operand_count(); ++i) {
const HloInstruction* operand = op_hlo_->operand(i);
for (int64_t i = 0; i < op_hlo_.operand_count(); ++i) {
const HloInstruction* operand = op_hlo_.operand(i);
operand_to_generator[operand] = [&, i](const llvm_ir::IrArray::Index& idx) {
return kernel_prototype.arguments[i].EmitReadArrayElement(idx,
&ir_builder);
Expand All @@ -202,12 +247,11 @@ ElementalKernelEmitter::EmitKernelSpec() {
module.get(), &ir_builder, std::move(thread_local_call_fn), true, true);

llvm_ir::ElementGenerator element_generator =
elemental_ir_emitter.MakeElementGenerator(op_hlo_.get(),
operand_to_generator);
elemental_ir_emitter.MakeElementGenerator(&op_hlo_, operand_to_generator);

TF_ASSIGN_OR_RETURN(se::ThreadDim thread_dims,
EmitElementalLoops(ir_builder, op_hlo_.get(),
kernel_prototype, element_generator));
EmitElementalLoops(ir_builder, &op_hlo_, kernel_prototype,
element_generator));

auto source = std::make_unique<LlvmIrKernelSource>(
context_, std::move(module),
Expand Down Expand Up @@ -283,18 +327,18 @@ ElementalKernelEmitter::ThreadLocalCallbackFactory(llvm::IRBuilderBase& builder,
absl::flat_hash_map<const HloInstruction*, int64_t>{},
/*computation_to_profile_idx=*/
absl::flat_hash_map<const HloComputation*, int64_t>{},
/*computation_transitively_contains_custom_call=*/
absl::flat_hash_map<const HloComputation*, bool>{},
/*target_machine=*/nullptr,
ComputationsTransitivelyContainCustomCall(op_hlo_), target_machine_,
/*emit_code_for_msan=*/false);
IrEmitter::IRBuilderGuard builder_guard = ir_emitter->WithBuilder(builder);

if (op_hlo_->has_to_apply()) {
HloComputation* nested_computation = op_hlo_->to_apply();
bool is_reducer = op_hlo_->opcode() == HloOpcode::kReduce ||
op_hlo_->opcode() == HloOpcode::kReduceWindow;
TF_RETURN_IF_ERROR(ir_emitter->EmitSmallConstantGlobals());

if (op_hlo_.has_to_apply()) {
HloComputation* nested_computation = op_hlo_.to_apply();
bool is_reducer = op_hlo_.opcode() == HloOpcode::kReduce ||
op_hlo_.opcode() == HloOpcode::kReduceWindow;
TF_RETURN_IF_ERROR(ir_emitter->EmitNestedComputation(
*nested_computation, llvm_ir::IrName(op_hlo_.get()), is_reducer));
*nested_computation, llvm_ir::IrName(&op_hlo_), is_reducer));
}

return [ir_emitter = std::move(ir_emitter), &builder](
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@ limitations under the License.
#ifndef XLA_BACKENDS_CPU_TESTLIB_ELEMENTAL_KERNEL_EMITTER_H_
#define XLA_BACKENDS_CPU_TESTLIB_ELEMENTAL_KERNEL_EMITTER_H_

#include <cstddef>
#include <memory>

#include "absl/status/statusor.h"
#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "xla/backends/cpu/codegen/kernel_api_ir_builder.h"
#include "xla/backends/cpu/codegen/target_machine_features.h"
#include "xla/codegen/kernel_emitter.h"
#include "xla/codegen/kernel_spec.h"
#include "xla/hlo/ir/hlo_instruction.h"
Expand All @@ -35,9 +37,11 @@ namespace xla::cpu {

class ElementalKernelEmitter final : public KernelEmitter {
public:
explicit ElementalKernelEmitter(std::unique_ptr<HloInstruction> op_hlo,
const HloModule* hlo_module,
const BufferAssignment* buffer_assignment);
explicit ElementalKernelEmitter(const HloInstruction& op_hlo);

ElementalKernelEmitter(const HloModule* hlo_module,
const BufferAssignment* buffer_assignment,
const TargetMachineFeatures* target_machine);

absl::StatusOr<std::unique_ptr<KernelSpec>> EmitKernelSpec() override;

Expand All @@ -57,10 +61,11 @@ class ElementalKernelEmitter final : public KernelEmitter {
llvm::Module& module) const;

private:
std::unique_ptr<HloInstruction> op_hlo_;
const HloInstruction& op_hlo_;

const HloModule* hlo_module_;
const BufferAssignment* buffer_assignment_;
const HloModule* hlo_module_ = nullptr;
const BufferAssignment* buffer_assignment_ = nullptr;
const TargetMachineFeatures* target_machine_ = nullptr;

llvm::orc::ThreadSafeContext context_;

Expand Down
Loading

0 comments on commit 0e8ab96

Please sign in to comment.