Skip to content

Commit

Permalink
CPU jit lowering
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Feb 2, 2025
1 parent c208c66 commit 8772298
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 104 deletions.
12 changes: 7 additions & 5 deletions src/enzyme_ad/jax/Dialect/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class ReadOnlyArg final
LogicalResult matchAndRewrite(OpTy launchOp,
PatternRewriter &rewriter) const override {
SymbolTableCollection symbolTable;
symbolTable.getSymbolTable(launchOp->getParentOfType<ModuleOp>());
symbolTable.getSymbolTable(((Operation*)launchOp)->getParentOfType<ModuleOp>());
auto fn = cast<FunctionOpInterface>(
symbolTable.lookupNearestSymbolFrom(launchOp, launchOp.getFnAttr()));

Expand Down Expand Up @@ -167,6 +167,7 @@ class ReadOnlyArg final
}
};

template<>
enzymexla::KernelCallOp ReadOnlyArg<enzymexla::KernelCallOp>::create(PatternRewriter &rewriter, enzymexla::KernelCallOp launchOp, ArrayRef<Type> resTys, ArrayAttr outputAliases) const {
return rewriter.create<enzymexla::KernelCallOp>(
launchOp.getLoc(), resTys, launchOp.getFn(), launchOp.getGridx(),
Expand All @@ -177,9 +178,10 @@ enzymexla::KernelCallOp ReadOnlyArg<enzymexla::KernelCallOp>::create(PatternRewr
outputAliases);
}

template<>
enzymexla::JITCallOp ReadOnlyArg<enzymexla::JITCallOp>::create(PatternRewriter &rewriter, enzymexla::JITCallOp launchOp, ArrayRef<Type> resTys, ArrayAttr outputAliases) const {
return rewriter.create<enzymexla::KernelCallOp>(
launchOp.getLoc(), resTys,
return rewriter.create<enzymexla::JITCallOp>(
launchOp.getLoc(), resTys, launchOp.getFn(),
launchOp.getInputs(), launchOp.getBackendConfigAttr(),
launchOp.getOperandLayoutsAttr(), /*resultLayouts*/ nullptr,
outputAliases);
Expand All @@ -194,7 +196,7 @@ class ReadNoneArg final
LogicalResult matchAndRewrite(OpTy launchOp,
PatternRewriter &rewriter) const override {
SymbolTableCollection symbolTable;
auto mod = launchOp->getParentOfType<ModuleOp>();
auto mod = ((Operation*)launchOp)->getParentOfType<ModuleOp>();
symbolTable.getSymbolTable(mod);
auto fn = cast<FunctionOpInterface>(
symbolTable.lookupNearestSymbolFrom(launchOp, launchOp.getFnAttr()));
Expand All @@ -206,7 +208,7 @@ class ReadNoneArg final
if (!use_opt)
return failure();
for (auto u : *use_opt) {
auto launch2 = dyn_cast<T>(u.getUser());
auto launch2 = dyn_cast<OpTy>(u.getUser());
if (!launch2)
return failure();
calls.push_back(launch2);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,11 @@

#include "mlir/Target/LLVMIR/Export.h"

#define DEBUG_TYPE "lower-call"
#define DEBUG_TYPE "lower-jit"

namespace mlir {
namespace enzyme {
#define GEN_PASS_DEF_LOWERCALLPASS
#define GEN_PASS_DEF_LOWERJITPASS
#include "src/enzyme_ad/jax/Passes/Passes.h.inc"
} // namespace enzyme
} // namespace mlir
Expand Down Expand Up @@ -184,7 +184,7 @@ CallInfo CompileHostModule(std::string &key, mlir::ModuleOp modOp,
llvmModule->setTargetTriple(JIT->getTargetTriple().getTriple());

auto LibA =
JIT->createJITDylib("enzymecudadl_" + std::to_string(kernels.size()));
JIT->createJITDylib("enzymejitdl_" + std::to_string(jitkernels.size()));
if (auto Err = JIT->addIRModule(
LibA.get(),
llvm::orc::ThreadSafeModule(std::move(llvmModule), std::move(ctx)))) {
Expand All @@ -196,17 +196,6 @@ CallInfo CompileHostModule(std::string &key, mlir::ModuleOp modOp,
return {};
}

if (cuLaunchPtr && cuLaunchPtr[0] == 0) {
// Look up the JIT'd code entry point.
auto LaunchSym = JIT->lookup(LibA.get(), "cuLaunchKernel");
if (!LaunchSym) {
llvm::errs() << " lookupError[cuLaunchKernel] " << LaunchSym.takeError()
<< "\n";
return {};
}
*cuLaunchPtr = (size_t)LaunchSym->getValue();
}

llvm::Expected<llvm::orc::ExecutorAddr> NVSym(llvm::orc::ExecutorAddr{});
if (compileInit) {
NVSym = JIT->lookup(LibA.get(), "nv_func_init");
Expand All @@ -231,7 +220,7 @@ CallInfo CompileHostModule(std::string &key, mlir::ModuleOp modOp,

CallInfo CompileCall(SymbolTableCollection &symbolTable,
mlir::Location loc, FunctionOpInterface op, bool jit,
enzymexla::JITCallOp) {
enzymexla::JITCallOp, bool openmp) {

OpBuilder builder(op);

Expand Down Expand Up @@ -340,19 +329,19 @@ CallInfo CompileCall(SymbolTableCollection &symbolTable,

op.getFunctionBody().cloneInto(&func.getBody(), map);

auto second = entryBlock->getNextNode();
entryBlock->getOperations().splice(entryBlock->getOperations().end(),
auto second = entryBlock.getNextNode();
entryBlock.getOperations().splice(entryBlock.getOperations().end(),
second->getOperations());

second->erase();

func.getBody()->walk([](LLVM::ReturnOp op) {
func.getBody().walk([](LLVM::ReturnOp op) {
OpBuilder rewriter(op);
rewriter.create<mlir::func::ReturnOp>(op.getLoc());
op.erase();
});

func.getBody()->walk([](LLVM::UnreachableOp op) {
func.getBody().walk([](LLVM::UnreachableOp op) {
OpBuilder rewriter(op);
rewriter.create<mlir::func::ReturnOp>(op.getLoc());
op.erase();
Expand All @@ -368,21 +357,26 @@ CallInfo CompileCall(SymbolTableCollection &symbolTable,

CallInfo ptr;
{
llvm::sys::SmartScopedWriter<true> lock(kernel_mutex);
llvm::sys::SmartScopedWriter<true> jit_lock(jit_kernel_mutex);

auto found = kernels.find(ss.str());
if (found != kernels.end()) {
auto found = jitkernels.find(ss.str());
if (found != jitkernels.end()) {
ptr = found->second;
} else {
PassManager pm(submod.getContext());
if (openmp)
pm.addPass(createConvertSCFToOpenMPPass());
else
pm.addPass(createConvertSCFToCFPass());

buildLowerToCPUPassPipeline(pm);

auto subres = pm.run(submod);
if (!subres.succeeded()) {
return {};
}
ptr = CompileHostModule(ss.str(), submod, false, 0, false);
kernels[ss.str()] = ptr;
ptr = CompileHostModule(ss.str(), submod, false, false);
jitkernels[ss.str()] = ptr;
submod.erase();
}
}
Expand All @@ -394,7 +388,7 @@ namespace {

struct LowerJITPass
: public mlir::enzyme::impl::LowerJITPassBase<LowerJITPass> {
using LowerKernelPassBase::LowerKernelPassBase;
using LowerJITPassBase::LowerJITPassBase;

void getDependentDialects(DialectRegistry &registry) const override {
OpPassManager pm;
Expand All @@ -412,8 +406,6 @@ struct LowerJITPass
SymbolTableCollection symbolTable;
symbolTable.getSymbolTable(getOperation());

llvm::SmallVector<std::string> linkFilesArray =
parseLinkFilesString(linkFiles.getValue());
getOperation()->walk([&](JITCallOp op) {
mlir::ArrayAttr operand_layouts =
op.getOperandLayouts()
Expand All @@ -435,7 +427,7 @@ struct LowerJITPass
return;
}

CallInfo cdata = CompileCall(symbolTable, op.getLoc(), fn, jit, op);
CallInfo cdata = CompileCall(symbolTable, op.getLoc(), fn, jit, op, openmp);

std::string backendinfo((char *)&cdata, sizeof(CallInfo));

Expand Down
72 changes: 14 additions & 58 deletions src/enzyme_ad/jax/Passes/LowerKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,7 @@ extern llvm::orc::SymbolMap MappedSymbols;
bool initJIT();

CallInfo CompileHostModule(std::string &key, mlir::ModuleOp modOp,
bool run_init, size_t *cuLaunchPtr,
bool compileInit = true);
bool run_init, bool compileInit = true);

gpu::ObjectAttr getSelectedObject(gpu::BinaryOp op) {
ArrayRef<Attribute> objects = op.getObjectsAttr().getValue();
Expand Down Expand Up @@ -752,7 +751,7 @@ CallInfo CompileCUDAKernel(
if (!compileLaunch)
return {};

ptr = CompileHostModule(ss.str(), submod, run_init, &cuLaunchKernelPtr);
ptr = CompileHostModule(ss.str(), submod, run_init);
kernels[ss.str()] = ptr;

submod.erase();
Expand All @@ -766,14 +765,12 @@ bool CompileCPUKernel(SymbolTableCollection &symbolTable,
mlir::Location loc, FunctionOpInterface op, bool jit,
size_t gridx, size_t gridy, size_t gridz,
size_t blockx, size_t blocky, size_t blockz,
size_t shmem, enzymexla::KernelCallOp kcall, bool debug,
bool openmp) {
size_t shmem, enzymexla::KernelCallOp kcall, bool debug) {

OpBuilder builder(op);

auto ptrty = LLVM::LLVMPointerType::get(builder.getContext());
mlir::Type intys[] = {ptrty, ptrty, ptrty};
FunctionType calleeType = builder.getFunctionType(intys, {});

FunctionType gpuTy0 = dyn_cast<FunctionType>(op.getFunctionType());
if (!gpuTy0) {
Expand All @@ -794,16 +791,14 @@ bool CompileCPUKernel(SymbolTableCollection &symbolTable,
}
FunctionType gpuTy = builder.getFunctionType(newParams, {});

auto func = builder.create<func::FuncOp>(loc, "entry", calleeType);

static int id = 0;
auto callName = (op.getName() + "$" + "par" + std::to_string(id)).str();
id++;
auto func = builder.create<func::FuncOp>(loc, callName, gpuTy0);
func.setVisibility(SymbolTable::Visibility::Private);
auto &entryBlock = *func.addEntryBlock();
builder.setInsertionPointToStart(&entryBlock);

mlir::Value buffers = entryBlock.getArgument(1);

auto idx = builder.getIntegerType(64);
auto i32 = builder.getIntegerType(32);

SmallVector<mlir::Value> inits;
SmallVector<mlir::Value> finals;
SmallVector<mlir::Value> incs;
Expand All @@ -813,33 +808,8 @@ bool CompileCPUKernel(SymbolTableCollection &symbolTable,
finals.push_back(builder.create<arith::ConstantIndexOp>(loc, val));
}

SmallVector<mlir::Value> arguments;
for (auto arg : op.getArguments()) {
LLVM::GEPArg args[1] = {arg.getArgNumber()};
auto gep =
builder.create<LLVM::GEPOp>(loc, ptrty, ptrty, buffers, args, true);
auto argTy = arg.getType();
if (auto AT = dyn_cast<LLVM::LLVMArrayType>(argTy)) {
argTy = AT.getElementType();
}
auto ld = builder.create<LLVM::LoadOp>(loc, argTy, gep);
arguments.push_back(ld);
}

IRMapping map;
for (auto &&[oldarg, newarg] : zip(op.getArguments(), arguments)) {
Value newval = newarg;

if (auto AT = dyn_cast<LLVM::LLVMArrayType>(oldarg.getType())) {
auto ud =
builder.create<LLVM::UndefOp>(newarg.getLoc(), oldarg.getType());
int64_t c0[1] = {0};
newval = builder.create<LLVM::InsertValueOp>(
newarg.getLoc(), oldarg.getType(), ud, newval, c0);
}

map.map(oldarg, newval);
}
map.map(op.getArguments(), entryBlock.getArguments());

auto par = builder.create<scf::ParallelOp>(loc, inits, finals, incs);

Expand Down Expand Up @@ -917,22 +887,10 @@ bool CompileCPUKernel(SymbolTableCollection &symbolTable,
});
}

PassManager pm(func.getContext());
if (openmp)
pm.addPass(createConvertSCFToOpenMPPass());
else
pm.addPass(createConvertSCFToCFPass());

auto subres = pm.run(func);
if (!subres.succeeded()) {
return false;
}


OpBuilder rewriter(kcall);
auto replacement = rewriter.create<enzymexla::JITCallOp>(kcall.getLoc(), kcall.getResultTypes(), kcall.getInputs(), func, kcall.getBackendConfig(), kcall.getOperandLayouts(), kcall.getResultLayouts(), kcall.getOutputPperandAliases());
kcall.replace(replacement);
op.erase();
auto replacement = rewriter.create<enzymexla::JITCallOp>(kcall.getLoc(), kcall.getResultTypes(), mlir::FlatSymbolRefAttr::get(kcall.getContext(), callName), kcall.getInputs(), kcall.getBackendConfigAttr(), kcall.getOperandLayoutsAttr(), kcall.getResultLayoutsAttr(), kcall.getOutputOperandAliasesAttr());
kcall.replaceAllUsesWith(replacement);
kcall.erase();
return true;
};

Expand All @@ -959,8 +917,6 @@ struct LowerKernelPass
buildLowerToNVVMPassPipeline(pm, options, toolkitPath, linkFiles);
} else if (backend == "cpu") {
buildLowerToCPUPassPipeline(pm);
if (openmp)
registry.insert<mlir::omp::OpenMPDialect>();
}
pm.getDependentDialects(registry);

Expand Down Expand Up @@ -1063,9 +1019,9 @@ struct LowerKernelPass
op.replaceAllUsesWith(replacement);
op.erase();
} else if (backend == "cpu") {
cdata = CompileCPUKernel(symbolTable, op.getLoc(), fn, jit, data[1],
CompileCPUKernel(symbolTable, op.getLoc(), fn, jit, data[1],
data[2], data[3], data[4], data[5], data[6],
data[7], op, debug, openmp);
data[7], op, debug);
}
else {
op->emitError() << "Cannot lower kernel to unknown backend \""
Expand Down
6 changes: 0 additions & 6 deletions src/enzyme_ad/jax/Passes/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -322,12 +322,6 @@ def LowerKernelPass : Pass<"lower-kernel"> {
/*type=*/"std::string",
/*default=*/"\"cuda\"",
/*description=*/"HW backend">,
Option<
/*C++ variable name=*/"openmp",
/*CLI argument=*/"openmp",
/*type=*/"bool",
/*default=*/"true",
/*description=*/"whether to use openmp for lowering">,
];
}

Expand Down
34 changes: 27 additions & 7 deletions test/lit_tests/lowering/cpu.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: enzymexlamlir-opt %s --pass-pipeline="builtin.module(lower-kernel{jit=false backend=cpu})" | FileCheck %s
// RUN: enzymexlamlir-opt %s --pass-pipeline="builtin.module(lower-kernel{jit=false backend=cpu},canonicalize)" | FileCheck %s

module {
llvm.func internal unnamed_addr fastcc @throw_boundserror_2676() attributes {dso_local, no_inline, sym_visibility = "private"} {
Expand Down Expand Up @@ -29,11 +29,31 @@ module {
}
}

// CHECK: func.func @main(%arg0: tensor<64xi64>) -> tensor<64xi64> {
// CHECK-NEXT: stablehlo.constant
// CHECK-NEXT: stablehlo.constant
// CHECK-NEXT: stablehlo.constant
// CHECK-NEXT: %0 = stablehlo.custom_call @enzymexla_compile_cpu(%arg0) {api_version = 3 : i32, backend_config = "\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00", output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 0, operand_tuple_indices = []>]} : (tensor<64xi64>) -> tensor<64xi64>
// CHECK: func.func private @kern$par0(%arg0: !llvm.ptr<1>) {
// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32
// CHECK-NEXT: %0 = llvm.mlir.constant(63 : i32) : i32
// CHECK-NEXT: %c0 = arith.constant 0 : index
// CHECK-NEXT: %c1 = arith.constant 1 : index
// CHECK-NEXT: %c40 = arith.constant 40 : index
// CHECK-NEXT: scf.parallel (%arg1) = (%c0) to (%c40) step (%c1) {
// CHECK-NEXT: scf.execute_region {
// CHECK-NEXT: %1 = llvm.icmp "ugt" %c0_i32, %0 : i32
// CHECK-NEXT: llvm.cond_br %1, ^bb2, ^bb1
// CHECK-NEXT: ^bb1: // pred: ^bb0
// CHECK-NEXT: %2 = llvm.load %arg0 {alignment = 1 : i64} : !llvm.ptr<1> -> i64
// CHECK-NEXT: %3 = llvm.mul %2, %2 : i64
// CHECK-NEXT: llvm.store %3, %arg0 {alignment = 1 : i64} : i64, !llvm.ptr<1>
// CHECK-NEXT: scf.yield
// CHECK-NEXT: ^bb2: // pred: ^bb0
// CHECK-NEXT: llvm.call fastcc @throw_boundserror_2676() : () -> ()
// CHECK-NEXT: scf.yield
// CHECK-NEXT: }
// CHECK-NEXT: scf.reduce
// CHECK-NEXT: }
// CHECK-NEXT: return
// CHECK-NEXT: }

// CHECK: func.func @main(%arg0: tensor<64xi64>) -> tensor<64xi64> {
// CHECK-NEXT: %0 = enzymexla.jit_call @kern$par0 (%arg0) {output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 0, operand_tuple_indices = []>]} : (tensor<64xi64>) -> tensor<64xi64>
// CHECK-NEXT: return %0 : tensor<64xi64>
// CHECK-NEXT: }
// CHECK-NEXT:}
Loading

0 comments on commit 8772298

Please sign in to comment.