diff --git a/src/enzyme_ad/jax/Passes/LowerKernel.cpp b/src/enzyme_ad/jax/Passes/LowerKernel.cpp index 4f66c6148..afe61725d 100644 --- a/src/enzyme_ad/jax/Passes/LowerKernel.cpp +++ b/src/enzyme_ad/jax/Passes/LowerKernel.cpp @@ -190,10 +190,9 @@ struct CallInfo { llvm::StringMap kernels; llvm::sys::SmartRWMutex kernel_mutex; std::unique_ptr JIT = nullptr; +llvm::orc::SymbolMap MappedSymbols; -CallInfo CompileHostModule(std::string &key, mlir::ModuleOp modOp, - bool run_init, size_t *cuLaunchPtr, - bool compileInit = true) { +bool initJIT() { if (!JIT) { auto tJIT = llvm::orc::LLJITBuilder() @@ -215,7 +214,7 @@ CallInfo CompileHostModule(std::string &key, mlir::ModuleOp modOp, .create(); if (!tJIT) { llvm::errs() << " jit creating error: " << tJIT.takeError() << "\n"; - return {}; + return false; } JIT = std::move(tJIT.get()); assert(JIT); @@ -230,12 +229,23 @@ CallInfo CompileHostModule(std::string &key, mlir::ModuleOp modOp, if (!ProcessSymsGenerator) { llvm::errs() << " failure creating symbol generator: " << ProcessSymsGenerator.takeError() << "\n"; - return {}; + return false; } JIT->getMainJITDylib().addGenerator(std::move(ProcessSymsGenerator.get())); } + return true; +} + +extern "C" void EnzymeJaXMapSymbol(const char *name, void *symbol) { + initJIT(); + MappedSymbols[JIT->mangleAndIntern(name)] = llvm::orc::ExecutorSymbolDef( + llvm::orc::ExecutorAddr::fromPtr(symbol), llvm::JITSymbolFlags()); +} +CallInfo CompileHostModule(std::string &key, mlir::ModuleOp modOp, + bool run_init, size_t *cuLaunchPtr, + bool compileInit = true) { std::unique_ptr ctx(new llvm::LLVMContext); auto llvmModule = translateModuleToLLVMIR(modOp, *ctx); if (!llvmModule) { @@ -243,6 +253,9 @@ CallInfo CompileHostModule(std::string &key, mlir::ModuleOp modOp, llvm::errs() << "could not convert to LLVM IR\n"; return {}; } + if (!initJIT()) + return {}; + llvmModule->setDataLayout(JIT->getDataLayout()); llvmModule->setTargetTriple(JIT->getTargetTriple().getTriple()); @@ -254,6 +267,10 @@ CallInfo CompileHostModule(std::string &key, mlir::ModuleOp modOp, llvm::errs() << " addIRModuleError " << Err << "\n"; return {}; } + if (auto Err = LibA->define(llvm::orc::absoluteSymbols(MappedSymbols))) { + llvm::errs() << " Symbol define Error " << Err << "\n"; + return {}; + } if (cuLaunchPtr && cuLaunchPtr[0] == 0) { // Look up the JIT'd code entry point. diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td index 4cf563335..388154d2b 100644 --- a/src/enzyme_ad/jax/Passes/Passes.td +++ b/src/enzyme_ad/jax/Passes/Passes.td @@ -11,154 +11,127 @@ include "mlir/Pass/PassBase.td" -def RemoveDuplicateFuncDefPass : Pass<"remove-duplicate-func-def", "mlir::ModuleOp"> { +def RemoveDuplicateFuncDefPass + : Pass<"remove-duplicate-func-def", "mlir::ModuleOp"> { let summary = "Remove duplicate function definitions"; - let dependentDialects = [ - "mlir::LLVM::LLVMDialect" - ]; + let dependentDialects = ["mlir::LLVM::LLVMDialect"]; } -def PropagateConstantBoundsPass : Pass<"propagate-constant-bounds", "ModuleOp"> { +def PropagateConstantBoundsPass + : Pass<"propagate-constant-bounds", "ModuleOp"> { let summary = "Propagate constant bounds"; - let dependentDialects = [ - "mlir::LLVM::LLVMDialect", - "mlir::NVVM::NVVMDialect" - ]; + let dependentDialects = + ["mlir::LLVM::LLVMDialect", "mlir::NVVM::NVVMDialect"]; } def ArithRaisingPass : Pass<"arith-raise"> { let summary = "Raise Arith to mhlo"; let dependentDialects = [ - "arith::ArithDialect", - "mhlo::MhloDialect", - "stablehlo::StablehloDialect", - "chlo::ChloDialect", - "enzyme::EnzymeDialect" + "arith::ArithDialect", "mhlo::MhloDialect", "stablehlo::StablehloDialect", + "chlo::ChloDialect", "enzyme::EnzymeDialect" ]; - let options = [ - Option< + let options = [Option< /*C++ variable name=*/"use_stablehlo", /*CLI argument=*/"stablehlo", /*type=*/"bool", /*default=*/"true", - /*description=*/"Whether to raise to stablehlo vs mhlo" - > - ]; + /*description=*/"Whether to raise to stablehlo vs mhlo">]; } def ConsumingInterpreterPass : Pass<"enzyme-consuming-transform-interpreter"> { let summary = "Run the transform interpreter and remove the script"; - let description = [{ - This pass isolates the transform script in a separate module, making it - possible to apply the script to the anchor operation of the pass. - }]; + let description = + [{This pass isolates the transform script in a separate module, + making it possible to apply the script to the anchor operation of the + pass.}]; } def EnzymeHLOOptPass : Pass<"enzyme-hlo-opt"> { let summary = "Optimize stablehlo"; - let dependentDialects = [ - "stablehlo::StablehloDialect", - "tensor::TensorDialect" - ]; + let dependentDialects = + ["stablehlo::StablehloDialect", "tensor::TensorDialect"]; let options = [ Option< - /*C++ variable name=*/"all_finite", - /*CLI argument=*/"all_finite", - /*type=*/"bool", - /*default=*/"false", - /*description=*/"Whether to raise to assume all variables are finite" - >, + /*C++ variable name=*/"all_finite", + /*CLI argument=*/"all_finite", + /*type=*/"bool", + /*default=*/"false", + /*description=*/"Whether to raise to assume all variables are finite">, Option< - /*C++ variable name=*/"no_nan", - /*CLI argument=*/"no_nan", - /*type=*/"bool", - /*default=*/"false", - /*description=*/"Whether to raise to assume no variables are nan" - >, + /*C++ variable name=*/"no_nan", + /*CLI argument=*/"no_nan", + /*type=*/"bool", + /*default=*/"false", + /*description=*/"Whether to raise to assume no variables are nan">, Option< - /*C++ variable name=*/"max_constant_expansion", - /*CLI argument=*/"max_constant_expansion", - /*type=*/"size_t", - /*default=*/"1024", - /*description=*/"Maximum size to expand constants into" - >, + /*C++ variable name=*/"max_constant_expansion", + /*CLI argument=*/"max_constant_expansion", + /*type=*/"size_t", + /*default=*/"1024", + /*description=*/"Maximum size to expand constants into">, Option< - /*C++ variable name=*/"max_iterations", - /*CLI argument=*/"max_iterations", - /*type=*/"int64_t", - /*default=*/"100", - /*description=*/"Maximum number of pattern iterations" - >, + /*C++ variable name=*/"max_iterations", + /*CLI argument=*/"max_iterations", + /*type=*/"int64_t", + /*default=*/"100", + /*description=*/"Maximum number of pattern iterations">, Option< - /*C++ variable name=*/"top_down", - /*CLI argument=*/"top_down", - /*type=*/"bool", - /*default=*/"false", - /*description=*/"Use top down traversal" - >, + /*C++ variable name=*/"top_down", + /*CLI argument=*/"top_down", + /*type=*/"bool", + /*default=*/"false", + /*description=*/"Use top down traversal">, Option< - /*C++ variable name=*/"cse", - /*CLI argument=*/"cse", - /*type=*/"bool", - /*default=*/"true", - /*description=*/"Run CSE alongside" - >, + /*C++ variable name=*/"cse", + /*CLI argument=*/"cse", + /*type=*/"bool", + /*default=*/"true", + /*description=*/"Run CSE alongside">, Option< - /*C++ variable name=*/"passses", - /*CLI argument=*/"passses", - /*type=*/"uint64_t", - /*default=*/"24575", - /*description=*/"Additional optimization passes" - > - ]; + /*C++ variable name=*/"passses", + /*CLI argument=*/"passses", + /*type=*/"uint64_t", + /*default=*/"24575", + /*description=*/"Additional optimization passes"> + ]; } def EnzymeHLOUnrollPass : Pass<"enzyme-hlo-unroll"> { let summary = "Unroll stablehlo"; - let dependentDialects = [ - "stablehlo::StablehloDialect", - "tensor::TensorDialect" - ]; + let dependentDialects = + ["stablehlo::StablehloDialect", "tensor::TensorDialect"]; } def PrintPass : Pass<"print"> { let summary = "Print the module"; - let options = [ - Option< + let options = [Option< /*C++ variable name=*/"use_stdout", /*CLI argument=*/"stdout", /*type=*/"bool", /*default=*/"true", - /*description=*/"Whether to print to stdout (vs stderr)" - > - ]; + /*description=*/"Whether to print to stdout (vs stderr)">]; } def SROAWrappersPass : Pass<"sroa-wrappers", "mlir::ModuleOp"> { let summary = "Run LLVM SROA (Scalar Replacement of Aggregates)"; let dependentDialects = [ - "mlir::LLVM::LLVMDialect", - "mlir::DLTIDialect", - "mlir::NVVM::NVVMDialect", - "mlir::arith::ArithDialect", - "mlir::math::MathDialect" + "mlir::LLVM::LLVMDialect", "mlir::DLTIDialect", "mlir::NVVM::NVVMDialect", + "mlir::arith::ArithDialect", "mlir::math::MathDialect" ]; let options = [ Option< - /*C++ variable name=*/"dump_prellvm", - /*CLI argument=*/"dump_prellvm", - /*type=*/"bool", - /*default=*/"false", - /*description=*/"Whether to dump LLVM before optimizations" - >, + /*C++ variable name=*/"dump_prellvm", + /*CLI argument=*/"dump_prellvm", + /*type=*/"bool", + /*default=*/"false", + /*description=*/"Whether to dump LLVM before optimizations">, Option< - /*C++ variable name=*/"dump_postllvm", - /*CLI argument=*/"dump_postllvm", - /*type=*/"bool", - /*default=*/"false", - /*description=*/"Whether to dump LLVM after optimizations" - > - ]; + /*C++ variable name=*/"dump_postllvm", + /*CLI argument=*/"dump_postllvm", + /*type=*/"bool", + /*default=*/"false", + /*description=*/"Whether to dump LLVM after optimizations"> + ]; } def LibDeviceFuncsRaisingPass : Pass<"libdevice-funcs-raise"> { @@ -169,33 +142,38 @@ def LibDeviceFuncsRaisingPass : Pass<"libdevice-funcs-raise"> { ]; } - -def ConvertPolygeistToLLVM : Pass<"convert-polygeist-to-llvm", "mlir::ModuleOp"> { +def ConvertPolygeistToLLVM + : Pass<"convert-polygeist-to-llvm", "mlir::ModuleOp"> { let summary = "Convert scalar and vector operations from the Standard to the " "LLVM dialect"; let description = [{ Convert standard operations into the LLVM IR dialect operations. - #### Input invariant + ####Input invariant - - operations including: arithmetic on integers and floats, constants, - direct calls, returns and branches; - - no `tensor` types; - - all `vector` are one-dimensional; - - all blocks are reachable by following the successors of the first basic + - operations including : arithmetic on integers and floats, + constants, direct calls, returns and branches; + -no `tensor` types; + -all `vector` are one - dimensional; + -all blocks are reachable by following the successors of the first basic block; If other operations are present and their results are required by the LLVM - IR dialect operations, the pass will fail. Any LLVM IR operations or types - already present in the IR will be kept as is. + IR dialect operations, + the pass will + fail.Any LLVM IR operations or types already present in the IR + will be kept as is + . - #### Output IR + ####Output IR - Functions converted to LLVM IR. Function arguments types are converted - one-to-one. Function results are converted one-to-one and, in case more than - 1 value is returned, packed into an LLVM IR struct type. Function calls and - returns are updated accordingly. Block argument types are updated to use - LLVM IR types. + Functions converted to LLVM IR.Function arguments types are + converted one - + to - one.Function results are converted one - to - one and, + in case more than 1 value is returned, + packed into an LLVM + IR struct type.Function calls and returns are updated accordingly + .Block argument types are updated to use LLVM IR types. }]; let dependentDialects = [ "func::FuncDialect", @@ -212,7 +190,7 @@ def ConvertPolygeistToLLVM : Pass<"convert-polygeist-to-llvm", "mlir::ModuleOp"> "Replace FuncOp's MemRef arguments with bare pointers to the MemRef " "element types">, Option<"indexBitwidth", "index-bitwidth", "unsigned", - /*default=kDeriveIndexBitwidthFromDataLayout*/"0", + /*default=kDeriveIndexBitwidthFromDataLayout*/ "0", "Bitwidth of the index type, 0 to use size of machine word">, Option<"dataLayout", "data-layout", "std::string", /*default=*/"\"\"", @@ -227,158 +205,130 @@ def ConvertPolygeistToLLVM : Pass<"convert-polygeist-to-llvm", "mlir::ModuleOp"> def LowerKernelPass : Pass<"lower-kernel"> { let summary = "Lower kernel to custom call"; + let dependentDialects = []; let dependentDialects = [ - ]; - let dependentDialects = [ - "stablehlo::StablehloDialect", - "gpu::GPUDialect", - "func::FuncDialect", - "math::MathDialect", - "memref::MemRefDialect", - "scf::SCFDialect", - "vector::VectorDialect", - "nvgpu::NVGPUDialect", - "NVVM::NVVMDialect", - "LLVM::LLVMDialect", - "arith::ArithDialect", - "tensor::TensorDialect" + "stablehlo::StablehloDialect", "gpu::GPUDialect", "func::FuncDialect", + "math::MathDialect", "memref::MemRefDialect", "scf::SCFDialect", + "vector::VectorDialect", "nvgpu::NVGPUDialect", "NVVM::NVVMDialect", + "LLVM::LLVMDialect", "arith::ArithDialect", "tensor::TensorDialect" ]; let options = [ Option< - /*C++ variable name=*/"jit", - /*CLI argument=*/"jit", - /*type=*/"bool", - /*default=*/"true", - /*description=*/"Whether to jit the kernel" - >, + /*C++ variable name=*/"jit", + /*CLI argument=*/"jit", + /*type=*/"bool", + /*default=*/"true", + /*description=*/"Whether to jit the kernel">, Option< - /*C++ variable name=*/"compileLaunch", - /*CLI argument=*/"compileLaunch", - /*type=*/"bool", - /*default=*/"true", - /*description=*/"Whether to jit the host code" - >, + /*C++ variable name=*/"compileLaunch", + /*CLI argument=*/"compileLaunch", + /*type=*/"bool", + /*default=*/"true", + /*description=*/"Whether to jit the host code">, Option< - /*C++ variable name=*/"toolkitPath", - /*CLI argument=*/"toolkitPath", - /*type=*/"std::string", - /*default=*/"", - /*description=*/"The location of the cuda toolkit" - >, + /*C++ variable name=*/"toolkitPath", + /*CLI argument=*/"toolkitPath", + /*type=*/"std::string", + /*default=*/"", + /*description=*/"The location of the cuda toolkit">, Option< - /*C++ variable name=*/"linkFiles", - /*CLI argument=*/"linkFiles", - /*type=*/"std::string", - /*default=*/"", - /*description=*/"Semicolon separated list of files to link" - >, + /*C++ variable name=*/"linkFiles", + /*CLI argument=*/"linkFiles", + /*type=*/"std::string", + /*default=*/"", + /*description=*/"Semicolon separated list of files to link">, Option< - /*C++ variable name=*/"cubinChip", - /*CLI argument=*/"cubinChip", - /*type=*/"std::string", - /*default=*/"\"sm_50\"", - /*description=*/"cubinChip" - >, + /*C++ variable name=*/"cubinChip", + /*CLI argument=*/"cubinChip", + /*type=*/"std::string", + /*default=*/"\"sm_50\"", + /*description=*/"cubinChip">, Option< - /*C++ variable name=*/"cubinFeatures", - /*CLI argument=*/"cubinFeatures", - /*type=*/"std::string", - /*default=*/"\"+ptx60\"", - /*description=*/"cubinChip" - >, + /*C++ variable name=*/"cubinFeatures", + /*CLI argument=*/"cubinFeatures", + /*type=*/"std::string", + /*default=*/"\"+ptx60\"", + /*description=*/"cubinChip">, Option< - /*C++ variable name=*/"indexBitWidth", - /*CLI argument=*/"indexBitWidth", - /*type=*/"int", - /*default=*/"64", - /*description=*/"indexBitWidth" - >, + /*C++ variable name=*/"indexBitWidth", + /*CLI argument=*/"indexBitWidth", + /*type=*/"int", + /*default=*/"64", + /*description=*/"indexBitWidth">, Option< - /*C++ variable name=*/"cuLaunchKernelPtr", - /*CLI argument=*/"cuLaunchKernelPtr", - /*type=*/"size_t", - /*default=*/"0", - /*description=*/"cuLaunchKernelPtr" - >, + /*C++ variable name=*/"cuLaunchKernelPtr", + /*CLI argument=*/"cuLaunchKernelPtr", + /*type=*/"size_t", + /*default=*/"0", + /*description=*/"cuLaunchKernelPtr">, Option< - /*C++ variable name=*/"cuModuleLoadDataPtr", - /*CLI argument=*/"cuModuleLoadDataPtr", - /*type=*/"size_t", - /*default=*/"0", - /*description=*/"cuModuleLoadDataPtr" - >, + /*C++ variable name=*/"cuModuleLoadDataPtr", + /*CLI argument=*/"cuModuleLoadDataPtr", + /*type=*/"size_t", + /*default=*/"0", + /*description=*/"cuModuleLoadDataPtr">, Option< - /*C++ variable name=*/"cuModuleGetFunctionPtr", - /*CLI argument=*/"cuModuleGetFunctionPtr", - /*type=*/"size_t", - /*default=*/"0", - /*description=*/"cuModuleGetFunctionPtr" - >, + /*C++ variable name=*/"cuModuleGetFunctionPtr", + /*CLI argument=*/"cuModuleGetFunctionPtr", + /*type=*/"size_t", + /*default=*/"0", + /*description=*/"cuModuleGetFunctionPtr">, Option< - /*C++ variable name=*/"run_init", - /*CLI argument=*/"run_init", - /*type=*/"bool", - /*default=*/"false", - /*description=*/"Run initialization of cuda module" - >, + /*C++ variable name=*/"run_init", + /*CLI argument=*/"run_init", + /*type=*/"bool", + /*default=*/"false", + /*description=*/"Run initialization of cuda module">, Option< - /*C++ variable name=*/"debug", - /*CLI argument=*/"debug", - /*type=*/"bool", - /*default=*/"false", - /*description=*/"Compile in debug prints" - >, + /*C++ variable name=*/"debug", + /*CLI argument=*/"debug", + /*type=*/"bool", + /*default=*/"false", + /*description=*/"Compile in debug prints">, Option< - /*C++ variable name=*/"cuResultHandlerPtr", - /*CLI argument=*/"cuResultHandlerPtr", - /*type=*/"size_t", - /*default=*/"0", - /*description=*/"Function handler to call with result of curesult" - >, + /*C++ variable name=*/"cuResultHandlerPtr", + /*CLI argument=*/"cuResultHandlerPtr", + /*type=*/"size_t", + /*default=*/"0", + /*description=*/"Function handler to call with result of curesult">, Option< - /*C++ variable name=*/"cuStreamSynchronizePtr", - /*CLI argument=*/"cuStreamSynchronizePtr", - /*type=*/"size_t", - /*default=*/"0", - /*description=*/"Function handler to sync results" - >, + /*C++ variable name=*/"cuStreamSynchronizePtr", + /*CLI argument=*/"cuStreamSynchronizePtr", + /*type=*/"size_t", + /*default=*/"0", + /*description=*/"Function handler to sync results">, Option< - /*C++ variable name=*/"cubinFormat", - /*CLI argument=*/"cubinFormat", - /*type=*/"std::string", - /*default=*/"\"bin\"", - /*description=*/"Binary format" - >, + /*C++ variable name=*/"cubinFormat", + /*CLI argument=*/"cubinFormat", + /*type=*/"std::string", + /*default=*/"\"bin\"", + /*description=*/"Binary format">, Option< - /*C++ variable name=*/"cuOptLevel", - /*CLI argument=*/"cuOptLevel", - /*type=*/"int", - /*default=*/"2", - /*description=*/"Opt level for ptx" - >, + /*C++ variable name=*/"cuOptLevel", + /*CLI argument=*/"cuOptLevel", + /*type=*/"int", + /*default=*/"2", + /*description=*/"Opt level for ptx">, Option< - /*C++ variable name=*/"cubinTriple", - /*CLI argument=*/"cubinTriple", - /*type=*/"std::string", - /*default=*/"\"nvptx64-nvidia-cuda\"", - /*description=*/"Target triple" - >, + /*C++ variable name=*/"cubinTriple", + /*CLI argument=*/"cubinTriple", + /*type=*/"std::string", + /*default=*/"\"nvptx64-nvidia-cuda\"", + /*description=*/"Target triple">, Option< - /*C++ variable name=*/"backend", - /*CLI argument=*/"backend", - /*type=*/"std::string", - /*default=*/"\"cuda\"", - /*description=*/"HW backend" - >, + /*C++ variable name=*/"backend", + /*CLI argument=*/"backend", + /*type=*/"std::string", + /*default=*/"\"cuda\"", + /*description=*/"HW backend">, Option< - /*C++ variable name=*/"openmp", - /*CLI argument=*/"openmp", - /*type=*/"bool", - /*default=*/"true", - /*description=*/"whether to use openmp for lowering" - >, - ]; + /*C++ variable name=*/"openmp", + /*CLI argument=*/"openmp", + /*type=*/"bool", + /*default=*/"true", + /*description=*/"whether to use openmp for lowering">, + ]; } //===----------------------------------------------------------------------===// @@ -413,12 +363,12 @@ def EnzymeLiftControlFlowToSCFPass : Pass<"enzyme-lift-cf-to-scf"> { `CFGToSCFInterface::createUnreachableTerminator` implementation. }]; - let dependentDialects = ["scf::SCFDialect", - "arith::ArithDialect", - "ub::UBDialect", - // TODO: This is only necessary until we have a - // ub.unreachable op. - "func::FuncDialect"]; + let dependentDialects = [ + "scf::SCFDialect", "arith::ArithDialect", "ub::UBDialect", + // TODO: This is only necessary until we have a + // ub.unreachable op. + "func::FuncDialect" + ]; } #endif