diff --git a/src/enzyme_ad/jax/Dialect/Ops.cpp b/src/enzyme_ad/jax/Dialect/Ops.cpp index a96049d94..f87813d8a 100644 --- a/src/enzyme_ad/jax/Dialect/Ops.cpp +++ b/src/enzyme_ad/jax/Dialect/Ops.cpp @@ -82,7 +82,7 @@ class ReadOnlyArg final LogicalResult matchAndRewrite(OpTy launchOp, PatternRewriter &rewriter) const override { SymbolTableCollection symbolTable; - symbolTable.getSymbolTable(launchOp->getParentOfType()); + symbolTable.getSymbolTable(((Operation*)launchOp)->getParentOfType()); auto fn = cast( symbolTable.lookupNearestSymbolFrom(launchOp, launchOp.getFnAttr())); @@ -167,6 +167,7 @@ class ReadOnlyArg final } }; +template<> enzymexla::KernelCallOp ReadOnlyArg::create(PatternRewriter &rewriter, enzymexla::KernelCallOp launchOp, ArrayRef resTys, ArrayAttr outputAliases) const { return rewriter.create( launchOp.getLoc(), resTys, launchOp.getFn(), launchOp.getGridx(), @@ -177,9 +178,10 @@ enzymexla::KernelCallOp ReadOnlyArg::create(PatternRewr outputAliases); } +template<> enzymexla::JITCallOp ReadOnlyArg::create(PatternRewriter &rewriter, enzymexla::JITCallOp launchOp, ArrayRef resTys, ArrayAttr outputAliases) const { - return rewriter.create( - launchOp.getLoc(), resTys, + return rewriter.create( + launchOp.getLoc(), resTys, launchOp.getFn(), launchOp.getInputs(), launchOp.getBackendConfigAttr(), launchOp.getOperandLayoutsAttr(), /*resultLayouts*/ nullptr, outputAliases); @@ -194,7 +196,7 @@ class ReadNoneArg final LogicalResult matchAndRewrite(OpTy launchOp, PatternRewriter &rewriter) const override { SymbolTableCollection symbolTable; - auto mod = launchOp->getParentOfType(); + auto mod = ((Operation*)launchOp)->getParentOfType(); symbolTable.getSymbolTable(mod); auto fn = cast( symbolTable.lookupNearestSymbolFrom(launchOp, launchOp.getFnAttr())); @@ -206,7 +208,7 @@ class ReadNoneArg final if (!use_opt) return failure(); for (auto u : *use_opt) { - auto launch2 = dyn_cast(u.getUser()); + auto launch2 = dyn_cast(u.getUser()); if (!launch2) return failure(); calls.push_back(launch2); diff --git a/src/enzyme_ad/jax/Passes/LowerCall.cpp b/src/enzyme_ad/jax/Passes/LowerJIT.cpp similarity index 92% rename from src/enzyme_ad/jax/Passes/LowerCall.cpp rename to src/enzyme_ad/jax/Passes/LowerJIT.cpp index ee5b15497..17e4a0073 100644 --- a/src/enzyme_ad/jax/Passes/LowerCall.cpp +++ b/src/enzyme_ad/jax/Passes/LowerJIT.cpp @@ -80,11 +80,11 @@ #include "mlir/Target/LLVMIR/Export.h" -#define DEBUG_TYPE "lower-call" +#define DEBUG_TYPE "lower-jit" namespace mlir { namespace enzyme { -#define GEN_PASS_DEF_LOWERCALLPASS +#define GEN_PASS_DEF_LOWERJITPASS #include "src/enzyme_ad/jax/Passes/Passes.h.inc" } // namespace enzyme } // namespace mlir @@ -184,7 +184,7 @@ CallInfo CompileHostModule(std::string &key, mlir::ModuleOp modOp, llvmModule->setTargetTriple(JIT->getTargetTriple().getTriple()); auto LibA = - JIT->createJITDylib("enzymecudadl_" + std::to_string(kernels.size())); + JIT->createJITDylib("enzymejitdl_" + std::to_string(jitkernels.size())); if (auto Err = JIT->addIRModule( LibA.get(), llvm::orc::ThreadSafeModule(std::move(llvmModule), std::move(ctx)))) { @@ -196,17 +196,6 @@ CallInfo CompileHostModule(std::string &key, mlir::ModuleOp modOp, return {}; } - if (cuLaunchPtr && cuLaunchPtr[0] == 0) { - // Look up the JIT'd code entry point. - auto LaunchSym = JIT->lookup(LibA.get(), "cuLaunchKernel"); - if (!LaunchSym) { - llvm::errs() << " lookupError[cuLaunchKernel] " << LaunchSym.takeError() - << "\n"; - return {}; - } - *cuLaunchPtr = (size_t)LaunchSym->getValue(); - } - llvm::Expected NVSym(llvm::orc::ExecutorAddr{}); if (compileInit) { NVSym = JIT->lookup(LibA.get(), "nv_func_init"); @@ -231,7 +220,7 @@ CallInfo CompileHostModule(std::string &key, mlir::ModuleOp modOp, CallInfo CompileCall(SymbolTableCollection &symbolTable, mlir::Location loc, FunctionOpInterface op, bool jit, - enzymexla::JITCallOp) { + enzymexla::JITCallOp, bool openmp) { OpBuilder builder(op); @@ -340,19 +329,19 @@ CallInfo CompileCall(SymbolTableCollection &symbolTable, op.getFunctionBody().cloneInto(&func.getBody(), map); - auto second = entryBlock->getNextNode(); - entryBlock->getOperations().splice(entryBlock->getOperations().end(), + auto second = entryBlock.getNextNode(); + entryBlock.getOperations().splice(entryBlock.getOperations().end(), second->getOperations()); second->erase(); - func.getBody()->walk([](LLVM::ReturnOp op) { + func.getBody().walk([](LLVM::ReturnOp op) { OpBuilder rewriter(op); rewriter.create(op.getLoc()); op.erase(); }); - func.getBody()->walk([](LLVM::UnreachableOp op) { + func.getBody().walk([](LLVM::UnreachableOp op) { OpBuilder rewriter(op); rewriter.create(op.getLoc()); op.erase(); @@ -368,21 +357,26 @@ CallInfo CompileCall(SymbolTableCollection &symbolTable, CallInfo ptr; { - llvm::sys::SmartScopedWriter lock(kernel_mutex); + llvm::sys::SmartScopedWriter jit_lock(jit_kernel_mutex); - auto found = kernels.find(ss.str()); - if (found != kernels.end()) { + auto found = jitkernels.find(ss.str()); + if (found != jitkernels.end()) { ptr = found->second; } else { PassManager pm(submod.getContext()); + if (openmp) + pm.addPass(createConvertSCFToOpenMPPass()); + else + pm.addPass(createConvertSCFToCFPass()); + buildLowerToCPUPassPipeline(pm); auto subres = pm.run(submod); if (!subres.succeeded()) { return {}; } - ptr = CompileHostModule(ss.str(), submod, false, 0, false); - kernels[ss.str()] = ptr; + ptr = CompileHostModule(ss.str(), submod, false, false); + jitkernels[ss.str()] = ptr; submod.erase(); } } @@ -394,7 +388,7 @@ namespace { struct LowerJITPass : public mlir::enzyme::impl::LowerJITPassBase { - using LowerKernelPassBase::LowerKernelPassBase; + using LowerJITPassBase::LowerJITPassBase; void getDependentDialects(DialectRegistry ®istry) const override { OpPassManager pm; @@ -412,8 +406,6 @@ struct LowerJITPass SymbolTableCollection symbolTable; symbolTable.getSymbolTable(getOperation()); - llvm::SmallVector linkFilesArray = - parseLinkFilesString(linkFiles.getValue()); getOperation()->walk([&](JITCallOp op) { mlir::ArrayAttr operand_layouts = op.getOperandLayouts() @@ -435,7 +427,7 @@ struct LowerJITPass return; } - CallInfo cdata = CompileCall(symbolTable, op.getLoc(), fn, jit, op); + CallInfo cdata = CompileCall(symbolTable, op.getLoc(), fn, jit, op, openmp); std::string backendinfo((char *)&cdata, sizeof(CallInfo)); diff --git a/src/enzyme_ad/jax/Passes/LowerKernel.cpp b/src/enzyme_ad/jax/Passes/LowerKernel.cpp index 83b093bf2..359c7ade4 100644 --- a/src/enzyme_ad/jax/Passes/LowerKernel.cpp +++ b/src/enzyme_ad/jax/Passes/LowerKernel.cpp @@ -195,8 +195,7 @@ extern llvm::orc::SymbolMap MappedSymbols; bool initJIT(); CallInfo CompileHostModule(std::string &key, mlir::ModuleOp modOp, - bool run_init, size_t *cuLaunchPtr, - bool compileInit = true); + bool run_init, bool compileInit = true); gpu::ObjectAttr getSelectedObject(gpu::BinaryOp op) { ArrayRef objects = op.getObjectsAttr().getValue(); @@ -752,7 +751,7 @@ CallInfo CompileCUDAKernel( if (!compileLaunch) return {}; - ptr = CompileHostModule(ss.str(), submod, run_init, &cuLaunchKernelPtr); + ptr = CompileHostModule(ss.str(), submod, run_init); kernels[ss.str()] = ptr; submod.erase(); @@ -766,14 +765,12 @@ bool CompileCPUKernel(SymbolTableCollection &symbolTable, mlir::Location loc, FunctionOpInterface op, bool jit, size_t gridx, size_t gridy, size_t gridz, size_t blockx, size_t blocky, size_t blockz, - size_t shmem, enzymexla::KernelCallOp kcall, bool debug, - bool openmp) { + size_t shmem, enzymexla::KernelCallOp kcall, bool debug) { OpBuilder builder(op); auto ptrty = LLVM::LLVMPointerType::get(builder.getContext()); mlir::Type intys[] = {ptrty, ptrty, ptrty}; - FunctionType calleeType = builder.getFunctionType(intys, {}); FunctionType gpuTy0 = dyn_cast(op.getFunctionType()); if (!gpuTy0) { @@ -794,16 +791,14 @@ bool CompileCPUKernel(SymbolTableCollection &symbolTable, } FunctionType gpuTy = builder.getFunctionType(newParams, {}); - auto func = builder.create(loc, "entry", calleeType); - + static int id = 0; + auto callName = (op.getName() + "$" + "par" + std::to_string(id)).str(); + id++; + auto func = builder.create(loc, callName, gpuTy0); + func.setVisibility(SymbolTable::Visibility::Private); auto &entryBlock = *func.addEntryBlock(); builder.setInsertionPointToStart(&entryBlock); - mlir::Value buffers = entryBlock.getArgument(1); - - auto idx = builder.getIntegerType(64); - auto i32 = builder.getIntegerType(32); - SmallVector inits; SmallVector finals; SmallVector incs; @@ -813,33 +808,8 @@ bool CompileCPUKernel(SymbolTableCollection &symbolTable, finals.push_back(builder.create(loc, val)); } - SmallVector arguments; - for (auto arg : op.getArguments()) { - LLVM::GEPArg args[1] = {arg.getArgNumber()}; - auto gep = - builder.create(loc, ptrty, ptrty, buffers, args, true); - auto argTy = arg.getType(); - if (auto AT = dyn_cast(argTy)) { - argTy = AT.getElementType(); - } - auto ld = builder.create(loc, argTy, gep); - arguments.push_back(ld); - } - IRMapping map; - for (auto &&[oldarg, newarg] : zip(op.getArguments(), arguments)) { - Value newval = newarg; - - if (auto AT = dyn_cast(oldarg.getType())) { - auto ud = - builder.create(newarg.getLoc(), oldarg.getType()); - int64_t c0[1] = {0}; - newval = builder.create( - newarg.getLoc(), oldarg.getType(), ud, newval, c0); - } - - map.map(oldarg, newval); - } + map.map(op.getArguments(), entryBlock.getArguments()); auto par = builder.create(loc, inits, finals, incs); @@ -917,22 +887,10 @@ bool CompileCPUKernel(SymbolTableCollection &symbolTable, }); } - PassManager pm(func.getContext()); - if (openmp) - pm.addPass(createConvertSCFToOpenMPPass()); - else - pm.addPass(createConvertSCFToCFPass()); - - auto subres = pm.run(func); - if (!subres.succeeded()) { - return false; - } - - OpBuilder rewriter(kcall); - auto replacement = rewriter.create(kcall.getLoc(), kcall.getResultTypes(), kcall.getInputs(), func, kcall.getBackendConfig(), kcall.getOperandLayouts(), kcall.getResultLayouts(), kcall.getOutputPperandAliases()); - kcall.replace(replacement); - op.erase(); + auto replacement = rewriter.create(kcall.getLoc(), kcall.getResultTypes(), mlir::FlatSymbolRefAttr::get(kcall.getContext(), callName), kcall.getInputs(), kcall.getBackendConfigAttr(), kcall.getOperandLayoutsAttr(), kcall.getResultLayoutsAttr(), kcall.getOutputOperandAliasesAttr()); + kcall.replaceAllUsesWith(replacement); + kcall.erase(); return true; }; @@ -959,8 +917,6 @@ struct LowerKernelPass buildLowerToNVVMPassPipeline(pm, options, toolkitPath, linkFiles); } else if (backend == "cpu") { buildLowerToCPUPassPipeline(pm); - if (openmp) - registry.insert(); } pm.getDependentDialects(registry); @@ -1063,9 +1019,9 @@ struct LowerKernelPass op.replaceAllUsesWith(replacement); op.erase(); } else if (backend == "cpu") { - cdata = CompileCPUKernel(symbolTable, op.getLoc(), fn, jit, data[1], + CompileCPUKernel(symbolTable, op.getLoc(), fn, jit, data[1], data[2], data[3], data[4], data[5], data[6], - data[7], op, debug, openmp); + data[7], op, debug); } else { op->emitError() << "Cannot lower kernel to unknown backend \"" diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td index 5ab1420e1..80a42873b 100644 --- a/src/enzyme_ad/jax/Passes/Passes.td +++ b/src/enzyme_ad/jax/Passes/Passes.td @@ -322,12 +322,6 @@ def LowerKernelPass : Pass<"lower-kernel"> { /*type=*/"std::string", /*default=*/"\"cuda\"", /*description=*/"HW backend">, - Option< - /*C++ variable name=*/"openmp", - /*CLI argument=*/"openmp", - /*type=*/"bool", - /*default=*/"true", - /*description=*/"whether to use openmp for lowering">, ]; } diff --git a/test/lit_tests/lowering/cpu.mlir b/test/lit_tests/lowering/cpu.mlir index 9e0a8e8e5..8cc316e0f 100644 --- a/test/lit_tests/lowering/cpu.mlir +++ b/test/lit_tests/lowering/cpu.mlir @@ -1,4 +1,4 @@ -// RUN: enzymexlamlir-opt %s --pass-pipeline="builtin.module(lower-kernel{jit=false backend=cpu})" | FileCheck %s +// RUN: enzymexlamlir-opt %s --pass-pipeline="builtin.module(lower-kernel{jit=false backend=cpu},canonicalize)" | FileCheck %s module { llvm.func internal unnamed_addr fastcc @throw_boundserror_2676() attributes {dso_local, no_inline, sym_visibility = "private"} { @@ -29,11 +29,31 @@ module { } } -// CHECK: func.func @main(%arg0: tensor<64xi64>) -> tensor<64xi64> { -// CHECK-NEXT: stablehlo.constant -// CHECK-NEXT: stablehlo.constant -// CHECK-NEXT: stablehlo.constant -// CHECK-NEXT: %0 = stablehlo.custom_call @enzymexla_compile_cpu(%arg0) {api_version = 3 : i32, backend_config = "\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00", output_operand_aliases = [#stablehlo.output_operand_alias]} : (tensor<64xi64>) -> tensor<64xi64> +// CHECK: func.func private @kern$par0(%arg0: !llvm.ptr<1>) { +// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32 +// CHECK-NEXT: %0 = llvm.mlir.constant(63 : i32) : i32 +// CHECK-NEXT: %c0 = arith.constant 0 : index +// CHECK-NEXT: %c1 = arith.constant 1 : index +// CHECK-NEXT: %c40 = arith.constant 40 : index +// CHECK-NEXT: scf.parallel (%arg1) = (%c0) to (%c40) step (%c1) { +// CHECK-NEXT: scf.execute_region { +// CHECK-NEXT: %1 = llvm.icmp "ugt" %c0_i32, %0 : i32 +// CHECK-NEXT: llvm.cond_br %1, ^bb2, ^bb1 +// CHECK-NEXT: ^bb1: // pred: ^bb0 +// CHECK-NEXT: %2 = llvm.load %arg0 {alignment = 1 : i64} : !llvm.ptr<1> -> i64 +// CHECK-NEXT: %3 = llvm.mul %2, %2 : i64 +// CHECK-NEXT: llvm.store %3, %arg0 {alignment = 1 : i64} : i64, !llvm.ptr<1> +// CHECK-NEXT: scf.yield +// CHECK-NEXT: ^bb2: // pred: ^bb0 +// CHECK-NEXT: llvm.call fastcc @throw_boundserror_2676() : () -> () +// CHECK-NEXT: scf.yield +// CHECK-NEXT: } +// CHECK-NEXT: scf.reduce +// CHECK-NEXT: } +// CHECK-NEXT: return +// CHECK-NEXT: } + +// CHECK: func.func @main(%arg0: tensor<64xi64>) -> tensor<64xi64> { +// CHECK-NEXT: %0 = enzymexla.jit_call @kern$par0 (%arg0) {output_operand_aliases = [#stablehlo.output_operand_alias]} : (tensor<64xi64>) -> tensor<64xi64> // CHECK-NEXT: return %0 : tensor<64xi64> // CHECK-NEXT: } -// CHECK-NEXT:} diff --git a/test/lit_tests/lowering/cpujit.mlir b/test/lit_tests/lowering/cpujit.mlir new file mode 100644 index 000000000..ffc6ab6e9 --- /dev/null +++ b/test/lit_tests/lowering/cpujit.mlir @@ -0,0 +1,40 @@ +// RUN: enzymexlamlir-opt %s --pass-pipeline="builtin.module(lower-jit{jit=false backend=cpu})" | FileCheck %s + +module { + llvm.func internal unnamed_addr fastcc @throw_boundserror_2676() attributes {dso_local, no_inline, sym_visibility = "private"} { + llvm.unreachable + } + func.func private @kern$par0(%arg0: !llvm.ptr<1>) { + %c0_i32 = arith.constant 0 : i32 + %0 = llvm.mlir.constant(63 : i32) : i32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c40 = arith.constant 40 : index + scf.parallel (%arg1) = (%c0) to (%c40) step (%c1) { + scf.execute_region { + %1 = llvm.icmp "ugt" %c0_i32, %0 : i32 + llvm.cond_br %1, ^bb2, ^bb1 + ^bb1: // pred: ^bb0 + %2 = llvm.load %arg0 {alignment = 1 : i64} : !llvm.ptr<1> -> i64 + %3 = llvm.mul %2, %2 : i64 + llvm.store %3, %arg0 {alignment = 1 : i64} : i64, !llvm.ptr<1> + scf.yield + ^bb2: // pred: ^bb0 + llvm.call fastcc @throw_boundserror_2676() : () -> () + scf.yield + } + scf.reduce + } + return + } + func.func @main(%arg0: tensor<64xi64>) -> tensor<64xi64> { + %0 = enzymexla.jit_call @kern$par0 (%arg0) {output_operand_aliases = [#stablehlo.output_operand_alias]} : (tensor<64xi64>) -> tensor<64xi64> + return %0 : tensor<64xi64> + } +} + +// CHECK: func.func @main(%arg0: tensor<64xi64>) -> tensor<64xi64> { +// CHECK-NEXT: %0 = stablehlo.custom_call @enzymexla_compile_cpu(%arg0) {api_version = 3 : i32, backend_config = "\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00", output_operand_aliases = [#stablehlo.output_operand_alias]} : (tensor<64xi64>) -> tensor<64xi64> +// CHECK-NEXT: return %0 : tensor<64xi64> +// CHECK-NEXT: } +// CHECK-NEXT:}