Skip to content

Commit

Permalink
CPU handler
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jan 29, 2025
1 parent 8e7bfb4 commit 264115e
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 22 deletions.
9 changes: 9 additions & 0 deletions src/enzyme_ad/jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@ cc_library(
],
)

cc_library(
name = "cpu",
srcs = ["cpu.cc"],
deps = [
"@xla//xla/service:custom_call_target_registry",
],
)


pybind_library(
name = "clang_compile",
srcs = ["clang_compile.cc"],
Expand Down
4 changes: 0 additions & 4 deletions src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
#include "src/enzyme_ad/jax/Dialect/Ops.h"
#include "src/enzyme_ad/jax/Passes/Passes.h"

//#include "mlir/../../lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
Expand Down Expand Up @@ -50,7 +49,6 @@
#include "mlir/Target/LLVMIR/Import.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
//#include "enzymexla/Passes/Passes.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
Expand All @@ -61,8 +59,6 @@
#include <map>
#include <numeric>

//#include "RuntimeWrapperUtils.h"

#define DEBUG_TYPE "convert-enzymexla-to-llvm"
#define DBGS() ::llvm::dbgs() << "[" DEBUG_TYPE ":" << PATTERN << "] "

Expand Down
56 changes: 39 additions & 17 deletions src/enzyme_ad/jax/Passes/LowerKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
#include "mlir/Pass/PassOptions.h"
#include "mlir/Transforms/Passes.h"

#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h"

#include "mlir/Target/LLVMIR/Export.h"
Expand Down Expand Up @@ -855,7 +856,8 @@ CallInfo CompileCPUKernel(SymbolTableCollection &symbolTable,
mlir::Location loc, FunctionOpInterface op, bool jit,
size_t gridx, size_t gridy, size_t gridz,
size_t blockx, size_t blocky, size_t blockz,
size_t shmem, enzymexla::KernelCallOp, bool debug) {
size_t shmem, enzymexla::KernelCallOp, bool debug,
bool openmp) {

OpBuilder builder(op);

Expand Down Expand Up @@ -1064,7 +1066,10 @@ CallInfo CompileCPUKernel(SymbolTableCollection &symbolTable,
ptr = found->second;
} else {
PassManager pm(submod.getContext());
pm.addPass(createConvertSCFToOpenMPPass());
if (openmp)
pm.addPass(createConvertSCFToOpenMPPass());
else
pm.addPass(createConvertSCFToCFPass());
buildLowerToCPUPassPipeline(pm);

auto subres = pm.run(submod);
Expand Down Expand Up @@ -1104,7 +1109,8 @@ struct LowerKernelPass
buildLowerToNVVMPassPipeline(pm, options, toolkitPath, linkFiles);
} else if (backend == "cpu") {
buildLowerToCPUPassPipeline(pm);
registry.insert<mlir::omp::OpenMPDialect>();
if (openmp)
registry.insert<mlir::omp::OpenMPDialect>();
}
pm.getDependentDialects(registry);

Expand Down Expand Up @@ -1187,7 +1193,7 @@ struct LowerKernelPass
else if (backend == "cpu")
cdata = CompileCPUKernel(symbolTable, op.getLoc(), fn, jit, data[1],
data[2], data[3], data[4], data[5], data[6],
data[7], op, debug);
data[7], op, debug, openmp);
else {
op->emitError() << "Cannot lower kernel to unknown backend \""
<< backend << "\"";
Expand All @@ -1197,22 +1203,38 @@ struct LowerKernelPass

OpBuilder rewriter(op);

auto backendstr = rewriter.getStringAttr(backendinfo);
SmallVector<NamedAttribute> names;
names.push_back(NamedAttribute(rewriter.getStringAttr("attr"),
rewriter.getStringAttr(backendinfo)));
names.push_back(
NamedAttribute(rewriter.getStringAttr("attr"), backendstr));
auto dattr = DictionaryAttr::get(op.getContext(), names);

auto replacement = rewriter.create<stablehlo::CustomCallOp>(
op.getLoc(), op.getResultTypes(), op.getInputs(),
rewriter.getStringAttr("enzymexla_compile_gpu"),
/* has_side_effect*/ rewriter.getBoolAttr(false),
/*backend_config*/ dattr,
/* api_version*/
CustomCallApiVersionAttr::get(
rewriter.getContext(),
mlir::stablehlo::CustomCallApiVersion::API_VERSION_TYPED_FFI),
/*calledcomputations*/ nullptr, operand_layouts, result_layouts,
output_operand_aliases);
Operation *replacement;
if (backend == "cuda")
replacement = rewriter.create<stablehlo::CustomCallOp>(
op.getLoc(), op.getResultTypes(), op.getInputs(),
rewriter.getStringAttr("enzymexla_compile_gpu"),
/* has_side_effect*/ rewriter.getBoolAttr(false),
/*backend_config*/ dattr,
/* api_version*/
CustomCallApiVersionAttr::get(
rewriter.getContext(),
mlir::stablehlo::CustomCallApiVersion::API_VERSION_TYPED_FFI),
/*calledcomputations*/ nullptr, operand_layouts, result_layouts,
output_operand_aliases);
else if (backend == "cpu")
replacement = rewriter.create<stablehlo::CustomCallOp>(
op.getLoc(), op.getResultTypes(), op.getInputs(),
rewriter.getStringAttr("enzymexla_compile_cpu"),
/* has_side_effect*/ rewriter.getBoolAttr(false),
/*backend_config*/ backendstr,
/* api_version*/
CustomCallApiVersionAttr::get(
rewriter.getContext(),
mlir::stablehlo::CustomCallApiVersion::
API_VERSION_STATUS_RETURNING_UNIFIED),
/*calledcomputations*/ nullptr, operand_layouts, result_layouts,
output_operand_aliases);

op.replaceAllUsesWith(replacement);
op.erase();
Expand Down
7 changes: 7 additions & 0 deletions src/enzyme_ad/jax/Passes/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,13 @@ def LowerKernelPass : Pass<"lower-kernel"> {
/*default=*/"\"cuda\"",
/*description=*/"HW backend"
>,
Option<
/*C++ variable name=*/"openmp",
/*CLI argument=*/"openmp",
/*type=*/"bool",
/*default=*/"true",
/*description=*/"whether to use openmp for lowering"
>,
];
}

Expand Down
16 changes: 16 additions & 0 deletions src/enzyme_ad/jax/cpu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#include "xla/service/custom_call_target_registry.h"

struct CallInfo {
void (*run)(const void **);
};

void forwarding_custom_call(void *out, const void **in, const CallInfo *opaque,
size_t opaque_len, void *status) {
opaque->run(in);
}
void CallOpaque(void* out, void** in

extern "C" void RegisterEnzymeXLACPUHandler() {
CustomCallTargetRegistry::Global()->Register("enzymexla_compile_cpu",
*forwarding_custom_call, "CPU");
}
4 changes: 4 additions & 0 deletions src/enzyme_ad/jax/enzyme_call.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1018,6 +1018,7 @@ void Callback(void *out, void **ins) {
}

extern "C" void RegisterEnzymeXLAGPUHandler();
extern "C" void RegisterEnzymeXLACPUHandler();

PYBIND11_MODULE(enzyme_call, m) {
llvm::InitializeAllTargets();
Expand Down Expand Up @@ -1245,6 +1246,9 @@ PYBIND11_MODULE(enzyme_call, m) {
return run_pass_pipeline(oldsyms, mlir, pass_pipeline);
});

m.def("register_enzymexla_cpu_handler",
[]() { RegisterEnzymeXLACPUHandler(); });

m.def("register_enzymexla_gpu_handler",
[]() { RegisterEnzymeXLAGPUHandler(); });

Expand Down
2 changes: 1 addition & 1 deletion test/lit_tests/lowering/cpu.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ module {
// CHECK-NEXT: stablehlo.constant
// CHECK-NEXT: stablehlo.constant
// CHECK-NEXT: stablehlo.constant
// CHECK-NEXT: %0 = stablehlo.custom_call @enzymexla_compile_gpu(%arg0) {api_version = 4 : i32, backend_config = {attr = "\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00"}, output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 0, operand_tuple_indices = []>]} : (tensor<64xi64>) -> tensor<64xi64>
// CHECK-NEXT: %0 = stablehlo.custom_call @enzymexla_compile_cpu(%arg0) {api_version = 3 : i32, backend_config = "\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00", output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 0, operand_tuple_indices = []>]} : (tensor<64xi64>) -> tensor<64xi64>
// CHECK-NEXT: return %0 : tensor<64xi64>
// CHECK-NEXT: }
// CHECK-NEXT:}

0 comments on commit 264115e

Please sign in to comment.