diff --git a/src/enzyme_ad/jax/Passes/LowerKernel.cpp b/src/enzyme_ad/jax/Passes/LowerKernel.cpp
index 5ca38ab48..a076f2193 100644
--- a/src/enzyme_ad/jax/Passes/LowerKernel.cpp
+++ b/src/enzyme_ad/jax/Passes/LowerKernel.cpp
@@ -55,18 +55,124 @@
 #include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h"
 #include "llvm/ExecutionEngine/SectionMemoryManager.h"
 
+#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
+#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
+#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
+#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
+#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
+#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
+#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
+#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
+#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
+#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
+#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
+#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h"
+#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/Pipelines/Passes.h"
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Pass/PassOptions.h"
+#include "mlir/Transforms/Passes.h"
+
 #include "mlir/Target/LLVMIR/Export.h"
 
 #define DEBUG_TYPE "lower-kernel"
 
 using namespace mlir;
 using namespace mlir::enzyme;
+using namespace mlir::gpu;
 using namespace enzyme;
 using namespace mlir::enzymexla;
 using namespace enzymexla;
 
 using namespace stablehlo;
 
+namespace {
+
+void buildCommonPassPipeline(
+    OpPassManager &pm, const mlir::gpu::GPUToNVVMPipelineOptions &options) {
+  pm.addPass(createConvertNVGPUToNVVMPass());
+  pm.addPass(createGpuKernelOutliningPass());
+  pm.addPass(createConvertVectorToSCFPass());
+  pm.addPass(createConvertSCFToCFPass());
+  pm.addPass(createConvertNVVMToLLVMPass());
+  pm.addPass(createConvertFuncToLLVMPass());
+  pm.addPass(memref::createExpandStridedMetadataPass());
+
+  GpuNVVMAttachTargetOptions nvvmTargetOptions;
+  nvvmTargetOptions.triple = options.cubinTriple;
+  nvvmTargetOptions.chip = options.cubinChip;
+  nvvmTargetOptions.features = options.cubinFeatures;
+  nvvmTargetOptions.optLevel = options.optLevel;
+  pm.addPass(createGpuNVVMAttachTarget(nvvmTargetOptions));
+  pm.addPass(createLowerAffinePass());
+  pm.addPass(createArithToLLVMConversionPass());
+  ConvertIndexToLLVMPassOptions convertIndexToLLVMPassOpt;
+  convertIndexToLLVMPassOpt.indexBitwidth = options.indexBitWidth;
+  pm.addPass(createConvertIndexToLLVMPass(convertIndexToLLVMPassOpt));
+  pm.addPass(createCanonicalizerPass());
+  pm.addPass(createCSEPass());
+}
+
+//===----------------------------------------------------------------------===//
+// GPUModule-specific stuff.
+//===----------------------------------------------------------------------===//
+void buildGpuPassPipeline(OpPassManager &pm,
+                          const mlir::gpu::GPUToNVVMPipelineOptions &options) {
+  pm.addNestedPass<gpu::GPUModuleOp>(createStripDebugInfoPass());
+  ConvertGpuOpsToNVVMOpsOptions opt;
+  opt.useBarePtrCallConv = options.kernelUseBarePtrCallConv;
+  opt.indexBitwidth = options.indexBitWidth;
+  pm.addNestedPass<gpu::GPUModuleOp>(createConvertGpuOpsToNVVMOps(opt));
+  pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass());
+  pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
+  pm.addNestedPass<gpu::GPUModuleOp>(createReconcileUnrealizedCastsPass());
+}
+
+//===----------------------------------------------------------------------===//
+// Host Post-GPU pipeline
+//===----------------------------------------------------------------------===//
+void buildHostPostPipeline(OpPassManager &pm,
+                           const mlir::gpu::GPUToNVVMPipelineOptions &options,
+                           std::string toolkitPath,
+                           llvm::SmallVectorImpl<std::string> &linkFiles) {
+  GpuToLLVMConversionPassOptions opt;
+  opt.hostBarePtrCallConv = options.hostUseBarePtrCallConv;
+  opt.kernelBarePtrCallConv = options.kernelUseBarePtrCallConv;
+  pm.addPass(createGpuToLLVMConversionPass(opt));
+
+  GpuModuleToBinaryPassOptions gpuModuleToBinaryPassOptions;
+  gpuModuleToBinaryPassOptions.compilationTarget = options.cubinFormat;
+  gpuModuleToBinaryPassOptions.toolkitPath = toolkitPath;
+  gpuModuleToBinaryPassOptions.linkFiles.append(linkFiles);
+  pm.addPass(createGpuModuleToBinaryPass(gpuModuleToBinaryPassOptions));
+  pm.addPass(createConvertMathToLLVMPass());
+  pm.addPass(createCanonicalizerPass());
+  pm.addPass(createCSEPass());
+  pm.addPass(createReconcileUnrealizedCastsPass());
+}
+
+void buildLowerToNVVMPassPipeline(
+    OpPassManager &pm, const GPUToNVVMPipelineOptions &options,
+    std::string toolkitPath, llvm::SmallVectorImpl<std::string> &linkFiles) {
+  // Common pipelines
+  buildCommonPassPipeline(pm, options);
+
+  // GPUModule-specific stuff
+  buildGpuPassPipeline(pm, options);
+
+  // Host post-GPUModule-specific stuff
+  buildHostPostPipeline(pm, options, toolkitPath, linkFiles);
+}
+
+} // namespace
+
 typedef void XlaCustomCallStatus;
 
 llvm::StringMap<void *> kernels;
@@ -104,8 +210,7 @@ void *CompileHostModule(std::string &key, mlir::ModuleOp modOp) {
   std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext);
   auto llvmModule = translateModuleToLLVMIR(modOp, *ctx);
   if (!llvmModule) {
-    llvm::errs() << "could not convert to LLVM IR"
-                 << "\n";
+    llvm::errs() << "could not convert to LLVM IR\n";
     return nullptr;
   }
   llvmModule->setDataLayout(JIT->getDataLayout());
@@ -191,7 +296,10 @@ gpu::ObjectAttr getSelectedObject(gpu::BinaryOp op) {
 void *CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc,
                     FunctionOpInterface op, bool jit, size_t gridx,
                     size_t gridy, size_t gridz, size_t blockx, size_t blocky,
-                    size_t blockz, size_t shmem) {
+                    size_t blockz, size_t shmem, std::string toolkitPath,
+                    llvm::SmallVectorImpl<std::string> &linkFiles,
+                    int indexBitWidth, std::string cubinChip,
+                    std::string cubinFeatures) {
 
   OpBuilder builder(op);
 
@@ -332,15 +440,15 @@ void *CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc,
 
     PassManager pm(submod.getContext());
     mlir::gpu::GPUToNVVMPipelineOptions options;
-    options.indexBitWidth = 64;
+    options.indexBitWidth = indexBitWidth;
     options.cubinTriple = "nvptx64-nvidia-cuda";
-    options.cubinChip = "sm_50";
-    options.cubinFeatures = "+ptx60";
+    options.cubinChip = cubinChip;
+    options.cubinFeatures = cubinFeatures;
     options.cubinFormat = "fatbin";
     options.optLevel = 2;
     options.kernelUseBarePtrCallConv = false;
     options.hostUseBarePtrCallConv = false;
-    mlir::gpu::buildLowerToNVVMPassPipeline(pm, options);
+    buildLowerToNVVMPassPipeline(pm, options, toolkitPath, linkFiles);
 
     pm.run(submod);
 
@@ -491,7 +599,9 @@ struct LowerKernelPass : public LowerKernelPassBase<LowerKernelPass> {
     options.optLevel = 2;
     options.kernelUseBarePtrCallConv = false;
     options.hostUseBarePtrCallConv = false;
-    mlir::gpu::buildLowerToNVVMPassPipeline(pm, options);
+    std::string toolkitPath = "";
+    SmallVector<std::string> linkFiles;
+    buildLowerToNVVMPassPipeline(pm, options, toolkitPath, linkFiles);
     pm.getDependentDialects(registry);
 
     registry.insert<mlir::NVVM::NVVMDialect, mlir::LLVM::LLVMDialect>();
   }
 
+  SmallVector<std::string> parseLinkFilesString(StringRef inp) {
+    if (inp.size() == 0)
+      return {};
+    SmallVector<StringRef> split;
+    SmallVector<std::string> out;
+    StringRef(inp.data(), inp.size()).split(split, ';');
+    for (auto &str : split) {
+      out.push_back(str.str());
+    }
+    return out;
+  }
+
   void runOnOperation() override {
     auto context = getOperation()->getContext();
 
     SymbolTableCollection symbolTable;
     symbolTable.getSymbolTable(getOperation());
 
+    llvm::SmallVector<std::string> linkFilesArray =
+        parseLinkFilesString(linkFiles.getValue());
+
     getOperation()->walk([&](KernelCallOp op) {
       mlir::ArrayAttr operand_layouts =
           op.getOperandLayouts()
@@ -542,9 +666,11 @@ struct LowerKernelPass : public LowerKernelPassBase<LowerKernelPass> {
       }
 
       // Compiled kernel goes here once ready
-      data[0] = (size_t)CompileKernel(symbolTable, op.getLoc(), fn, jit,
-                                      data[1], data[2], data[3], data[4],
-                                      data[5], data[6], data[7]);
+      data[0] = (size_t)CompileKernel(
+          symbolTable, op.getLoc(), fn, jit, data[1], data[2], data[3], data[4],
+          data[5], data[6], data[7], toolkitPath.getValue(), linkFilesArray,
+          indexBitWidth.getValue(), cubinChip.getValue(),
+          cubinFeatures.getValue());
 
       std::string backendinfo((char *)&data, sizeof(void *));
 
diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td
index bc20f6a68..3f8cbdbed 100644
--- a/src/enzyme_ad/jax/Passes/Passes.td
+++ b/src/enzyme_ad/jax/Passes/Passes.td
@@ -143,7 +143,42 @@ def LowerKernelPass : Pass<"lower-kernel"> {
       /*type=*/"bool",
      /*default=*/"true",
      /*description=*/"Whether to jit the kernel"
-    >
+    >,
+    Option<
+      /*C++ variable name=*/"toolkitPath",
+      /*CLI argument=*/"toolkitPath",
+      /*type=*/"std::string",
+      /*default=*/"",
+      /*description=*/"The location of the CUDA toolkit"
+    >,
+    Option<
+      /*C++ variable name=*/"linkFiles",
+      /*CLI argument=*/"linkFiles",
+      /*type=*/"std::string",
+      /*default=*/"",
+      /*description=*/"Semicolon-separated list of files to link"
+    >,
+    Option<
+      /*C++ variable name=*/"cubinChip",
+      /*CLI argument=*/"cubinChip",
+      /*type=*/"std::string",
+      /*default=*/"\"sm_50\"",
+      /*description=*/"The GPU chip to target (e.g. sm_50)"
+    >,
+    Option<
+      /*C++ variable name=*/"cubinFeatures",
+      /*CLI argument=*/"cubinFeatures",
+      /*type=*/"std::string",
+      /*default=*/"\"+ptx60\"",
+      /*description=*/"The PTX features to use (e.g. +ptx60)"
+    >,
+    Option<
+      /*C++ variable name=*/"indexBitWidth",
+      /*CLI argument=*/"indexBitWidth",
+      /*type=*/"int",
+      /*default=*/"64",
+      /*description=*/"Bitwidth used when lowering the index type"
+    >
   ];
 }
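Usage note (not part of the diff): with these options the pass can be driven through the standard MLIR textual pass-pipeline syntax. A minimal sketch follows; the tool name and the toolkit/libdevice paths are illustrative assumptions, not taken from this change:

  enzymexlamlir-opt input.mlir \
    --pass-pipeline="builtin.module(lower-kernel{toolkitPath=/usr/local/cuda linkFiles=/usr/local/cuda/nvvm/libdevice/libdevice.10.bc cubinChip=sm_80 cubinFeatures=+ptx74 indexBitWidth=64})"

linkFiles takes a semicolon-separated list (split by parseLinkFilesString above); an empty toolkitPath should leave toolkit discovery to the gpu-module-to-binary pass defaults.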