From 264115eb30f0b23f73040bd68ee6964b634330e9 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Tue, 28 Jan 2025 19:16:07 -0500 Subject: [PATCH] CPU handler --- src/enzyme_ad/jax/BUILD | 9 +++ .../jax/Passes/ConvertPolygeistToLLVM.cpp | 4 -- src/enzyme_ad/jax/Passes/LowerKernel.cpp | 56 +++++++++++++------ src/enzyme_ad/jax/Passes/Passes.td | 7 +++ src/enzyme_ad/jax/cpu.cc | 16 ++++++ src/enzyme_ad/jax/enzyme_call.cc | 4 ++ test/lit_tests/lowering/cpu.mlir | 2 +- 7 files changed, 76 insertions(+), 22 deletions(-) create mode 100644 src/enzyme_ad/jax/cpu.cc diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index ad777cde7..9ad98af57 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -20,6 +20,15 @@ cc_library( ], ) +cc_library( + name = "cpu", + srcs = ["cpu.cc"], + deps = [ + "@xla//xla/service:custom_call_target_registry", + ], +) + + pybind_library( name = "clang_compile", srcs = ["clang_compile.cc"], diff --git a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp index 3d00de4be..b4d3baad7 100644 --- a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp +++ b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp @@ -14,7 +14,6 @@ #include "src/enzyme_ad/jax/Dialect/Ops.h" #include "src/enzyme_ad/jax/Passes/Passes.h" -//#include "mlir/../../lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp" #include "mlir/Analysis/DataLayoutAnalysis.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" @@ -50,7 +49,6 @@ #include "mlir/Target/LLVMIR/Import.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/RegionUtils.h" -//#include "enzymexla/Passes/Passes.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" @@ -61,8 +59,6 @@ #include #include -//#include "RuntimeWrapperUtils.h" - #define DEBUG_TYPE "convert-enzymexla-to-llvm" #define 
DBGS() ::llvm::dbgs() << "[" DEBUG_TYPE ":" << PATTERN << "] " diff --git a/src/enzyme_ad/jax/Passes/LowerKernel.cpp b/src/enzyme_ad/jax/Passes/LowerKernel.cpp index c645ea832..12d3ff0b2 100644 --- a/src/enzyme_ad/jax/Passes/LowerKernel.cpp +++ b/src/enzyme_ad/jax/Passes/LowerKernel.cpp @@ -75,6 +75,7 @@ #include "mlir/Pass/PassOptions.h" #include "mlir/Transforms/Passes.h" +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h" #include "mlir/Target/LLVMIR/Export.h" @@ -855,7 +856,8 @@ CallInfo CompileCPUKernel(SymbolTableCollection &symbolTable, mlir::Location loc, FunctionOpInterface op, bool jit, size_t gridx, size_t gridy, size_t gridz, size_t blockx, size_t blocky, size_t blockz, - size_t shmem, enzymexla::KernelCallOp, bool debug) { + size_t shmem, enzymexla::KernelCallOp, bool debug, + bool openmp) { OpBuilder builder(op); @@ -1064,7 +1066,10 @@ CallInfo CompileCPUKernel(SymbolTableCollection &symbolTable, ptr = found->second; } else { PassManager pm(submod.getContext()); - pm.addPass(createConvertSCFToOpenMPPass()); + if (openmp) + pm.addPass(createConvertSCFToOpenMPPass()); + else + pm.addPass(createConvertSCFToCFPass()); buildLowerToCPUPassPipeline(pm); auto subres = pm.run(submod); @@ -1104,7 +1109,8 @@ struct LowerKernelPass buildLowerToNVVMPassPipeline(pm, options, toolkitPath, linkFiles); } else if (backend == "cpu") { buildLowerToCPUPassPipeline(pm); - registry.insert(); + if (openmp) + registry.insert(); } pm.getDependentDialects(registry); @@ -1187,7 +1193,7 @@ struct LowerKernelPass else if (backend == "cpu") cdata = CompileCPUKernel(symbolTable, op.getLoc(), fn, jit, data[1], data[2], data[3], data[4], data[5], data[6], - data[7], op, debug); + data[7], op, debug, openmp); else { op->emitError() << "Cannot lower kernel to unknown backend \"" << backend << "\""; @@ -1197,22 +1203,38 @@ struct LowerKernelPass OpBuilder rewriter(op); + auto backendstr = 
rewriter.getStringAttr(backendinfo); SmallVector names; - names.push_back(NamedAttribute(rewriter.getStringAttr("attr"), - rewriter.getStringAttr(backendinfo))); + names.push_back( + NamedAttribute(rewriter.getStringAttr("attr"), backendstr)); auto dattr = DictionaryAttr::get(op.getContext(), names); - auto replacement = rewriter.create( - op.getLoc(), op.getResultTypes(), op.getInputs(), - rewriter.getStringAttr("enzymexla_compile_gpu"), - /* has_side_effect*/ rewriter.getBoolAttr(false), - /*backend_config*/ dattr, - /* api_version*/ - CustomCallApiVersionAttr::get( - rewriter.getContext(), - mlir::stablehlo::CustomCallApiVersion::API_VERSION_TYPED_FFI), - /*calledcomputations*/ nullptr, operand_layouts, result_layouts, - output_operand_aliases); + Operation *replacement; + if (backend == "cuda") + replacement = rewriter.create( + op.getLoc(), op.getResultTypes(), op.getInputs(), + rewriter.getStringAttr("enzymexla_compile_gpu"), + /* has_side_effect*/ rewriter.getBoolAttr(false), + /*backend_config*/ dattr, + /* api_version*/ + CustomCallApiVersionAttr::get( + rewriter.getContext(), + mlir::stablehlo::CustomCallApiVersion::API_VERSION_TYPED_FFI), + /*calledcomputations*/ nullptr, operand_layouts, result_layouts, + output_operand_aliases); + else if (backend == "cpu") + replacement = rewriter.create( + op.getLoc(), op.getResultTypes(), op.getInputs(), + rewriter.getStringAttr("enzymexla_compile_cpu"), + /* has_side_effect*/ rewriter.getBoolAttr(false), + /*backend_config*/ backendstr, + /* api_version*/ + CustomCallApiVersionAttr::get( + rewriter.getContext(), + mlir::stablehlo::CustomCallApiVersion:: + API_VERSION_STATUS_RETURNING_UNIFIED), + /*calledcomputations*/ nullptr, operand_layouts, result_layouts, + output_operand_aliases); op.replaceAllUsesWith(replacement); op.erase(); diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td index 011a18d0f..4cf563335 100644 --- a/src/enzyme_ad/jax/Passes/Passes.td +++ 
b/src/enzyme_ad/jax/Passes/Passes.td @@ -371,6 +371,13 @@ def LowerKernelPass : Pass<"lower-kernel"> { /*default=*/"\"cuda\"", /*description=*/"HW backend" >, + Option< + /*C++ variable name=*/"openmp", + /*CLI argument=*/"openmp", + /*type=*/"bool", + /*default=*/"true", + /*description=*/"whether to use openmp for lowering" + >, ]; } diff --git a/src/enzyme_ad/jax/cpu.cc b/src/enzyme_ad/jax/cpu.cc new file mode 100644 index 000000000..4ab2415b4 --- /dev/null +++ b/src/enzyme_ad/jax/cpu.cc @@ -0,0 +1,24 @@ +#include <cstddef> + +#include "xla/service/custom_call_target_registry.h" + +// Callable handed to the custom call through its opaque argument. +// NOTE(review): presumably filled from backend_config; confirm the layout. +struct CallInfo { + void (*run)(const void **); +}; + +// Custom-call target: forwards the input list to opaque->run. The out, +// opaque_len and status arguments are unused. +void forwarding_custom_call(void *out, const void **in, const CallInfo *opaque, + size_t opaque_len, void *status) { + opaque->run(in); +} + +// Registers forwarding_custom_call as "enzymexla_compile_cpu" under "Host", +// the platform name XLA's CPU backend uses for custom-call lookup. +extern "C" void RegisterEnzymeXLACPUHandler() { + xla::CustomCallTargetRegistry::Global()->Register( + "enzymexla_compile_cpu", + reinterpret_cast<void *>(&forwarding_custom_call), "Host"); +} diff --git a/src/enzyme_ad/jax/enzyme_call.cc b/src/enzyme_ad/jax/enzyme_call.cc index d0d2db32a..4ea1a3caf 100644 --- a/src/enzyme_ad/jax/enzyme_call.cc +++ b/src/enzyme_ad/jax/enzyme_call.cc @@ -1018,6 +1018,7 @@ void Callback(void *out, void **ins) { } extern "C" void RegisterEnzymeXLAGPUHandler(); +extern "C" void RegisterEnzymeXLACPUHandler(); PYBIND11_MODULE(enzyme_call, m) { llvm::InitializeAllTargets(); @@ -1245,6 +1246,9 @@ PYBIND11_MODULE(enzyme_call, m) { return run_pass_pipeline(oldsyms, mlir, pass_pipeline); }); + m.def("register_enzymexla_cpu_handler", + []() { RegisterEnzymeXLACPUHandler(); }); + m.def("register_enzymexla_gpu_handler", []() { RegisterEnzymeXLAGPUHandler(); }); diff --git a/test/lit_tests/lowering/cpu.mlir b/test/lit_tests/lowering/cpu.mlir index ba9c4ed76..9e0a8e8e5 100644 --- a/test/lit_tests/lowering/cpu.mlir +++ b/test/lit_tests/lowering/cpu.mlir @@ -33,7 +33,7 @@ module { // CHECK-NEXT: stablehlo.constant // CHECK-NEXT: stablehlo.constant // CHECK-NEXT: stablehlo.constant -// 
CHECK-NEXT: %0 = stablehlo.custom_call @enzymexla_compile_gpu(%arg0) {api_version = 4 : i32, backend_config = {attr = "\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00"}, output_operand_aliases = [#stablehlo.output_operand_alias]} : (tensor<64xi64>) -> tensor<64xi64> +// CHECK-NEXT: %0 = stablehlo.custom_call @enzymexla_compile_cpu(%arg0) {api_version = 3 : i32, backend_config = "\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00", output_operand_aliases = [#stablehlo.output_operand_alias]} : (tensor<64xi64>) -> tensor<64xi64> // CHECK-NEXT: return %0 : tensor<64xi64> // CHECK-NEXT: } // CHECK-NEXT:}