Skip to content

Commit

Permalink
Annotate kernel params
Browse files Browse the repository at this point in the history
  • Loading branch information
chelini committed Feb 2, 2025
1 parent 81445c4 commit fb7f715
Show file tree
Hide file tree
Showing 4 changed files with 239 additions and 12 deletions.
16 changes: 14 additions & 2 deletions src/enzyme_ad/jax/Passes/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,20 @@ def RemoveDuplicateFuncDefPass
// Module-level pass that pushes launch-bound information known at
// `enzymexla.kernel_call` sites into the called kernels and annotates the
// kernels' pointer arguments.
def PropagateConstantBoundsPass
    : Pass<"propagate-constant-bounds", "ModuleOp"> {
  let summary = "Propagate constant bounds";
  let description = [{
    Propagate constant bounds information:
    1. Thread index
    2. Block dimension
    3. Block index
    When a kernel is called from several sites, index ranges are widened to
    the largest bound seen.

    Additionally, set the following attributes on all kernel pointer
    arguments that do not already carry them:
    1. align 128
    2. noalias
    3. dereferenceable = tensor.size() * sizeof(element_type)
    When a kernel is called from several sites, the smallest dereferenceable
    size is kept.
  }];
  // The pass creates/updates LLVM and NVVM attributes and operations, so both
  // dialects must be loaded before it runs.
  let dependentDialects = [
    "mlir::LLVM::LLVMDialect",
    "mlir::NVVM::NVVMDialect"
  ];
}

def ArithRaisingPass : Pass<"arith-raise"> {
Expand Down
105 changes: 102 additions & 3 deletions src/enzyme_ad/jax/Passes/PropagateConstantBound.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,32 @@ struct PropagateConstantBoundsPass
: public enzyme::impl::PropagateConstantBoundsPassBase<
PropagateConstantBoundsPass> {

// If `maybeCst` is a constant integer, attach an LLVM `range` attribute with
// lower bound 0 and that constant as upper bound to `target`.
//
// If `target` already carries a constant range (e.g. from a previously
// visited call site), the new upper bound is the maximum of the old and the
// new value, so the range stays valid for every call site seen so far.
//
// Does nothing when `maybeCst` is not a constant integer.
static void attachConstantRangeIfConstant(MLIRContext *ctx,
                                          Operation *maybeCst,
                                          Operation *target) {
  APInt intValue;
  if (!matchPattern(maybeCst, m_ConstantInt(&intValue)))
    return;

  const StringRef rangeAttrName = "range";
  int64_t upper = intValue.getSExtValue();

  // Merge with a pre-existing range *before* writing the attribute; writing
  // first would discard the old bound and make the merge a no-op.
  if (Attribute existing = target->getAttr(rangeAttrName)) {
    // Only merge when the existing attribute really is a constant range;
    // anything else is replaced by the freshly computed range.
    if (auto oldRange = dyn_cast<LLVM::ConstantRangeAttr>(existing))
      upper = std::max(upper, oldRange.getUpper().getSExtValue());
  }

  target->setAttr(rangeAttrName,
                  LLVM::ConstantRangeAttr::get(ctx, 32, 0, upper));
}

// Replace the target with a constant if the target is a constant value.
static void replaceWithConstantIfConstant(OpBuilder &builder,
Operation *maybeCst,
Operation *target) {
Expand All @@ -58,6 +75,37 @@ struct PropagateConstantBoundsPass
}
}

// Returns the size in bytes of a scalar integer or floating-point type.
// A 1-bit integer is counted as one byte; every other width is converted
// with an integer division by 8. Asserts when `ty` is neither an integer nor
// a floating-point type (or reports a zero width).
static int32_t getSizeInBytes(Type ty) {
  int32_t bitWidth = 0;
  if (auto fpTy = dyn_cast<FloatType>(ty)) {
    bitWidth = fpTy.getWidth();
  } else if (auto intTy = dyn_cast<IntegerType>(ty)) {
    bitWidth = intTy.getWidth();
    // i1 occupies a full byte.
    if (bitWidth == 1)
      return 1;
  }
  assert(bitWidth != 0);
  constexpr int32_t bitsPerByte = 8;
  return bitWidth / bitsPerByte;
}

// Returns the number of bytes occupied by `operand` when it is a statically
// shaped ranked tensor: number of elements * size of one element.
//
// Returns 0 when the size is not known at compile time (the operand is not a
// ranked tensor, or its shape is dynamic). Sizes that do not fit into an
// int32_t are clamped to INT32_MAX, so the result never overstates the real
// size.
static int32_t getMemRefSizeInBytes(Value operand) {
  auto tensorTy = dyn_cast<RankedTensorType>(operand.getType());
  // getNumElements() is only valid for static shapes.
  if (!tensorTy || !tensorTy.hasStaticShape())
    return 0;

  const int64_t numElems = tensorTy.getNumElements();
  const int64_t elemBytes = getSizeInBytes(tensorTy.getElementType());
  if (numElems <= 0 || elemBytes <= 0)
    return 0;

  // Do the overflow check in 64 bits before narrowing to the return type.
  if (numElems > static_cast<int64_t>(INT32_MAX) / elemBytes)
    return INT32_MAX;
  return static_cast<int32_t>(numElems * elemBytes);
}

// Looks up the attribute called `name` in `attrs`.
// Returns the value of the first matching entry, or failure() when no entry
// has that name.
FailureOr<Attribute> getAttributeWithName(ArrayRef<NamedAttribute> attrs,
                                          StringRef name) {
  const auto *match = llvm::find_if(attrs, [&](const NamedAttribute &entry) {
    return entry.getName() == name;
  });
  if (match == attrs.end())
    return failure();
  return match->getValue();
}

void runOnOperation() override {
auto moduleOp = getOperation();
auto *ctx = moduleOp->getContext();
Expand Down Expand Up @@ -129,6 +177,57 @@ struct PropagateConstantBoundsPass
gridIdzOp.getOperation());
});
});

auto result = moduleOp->walk([&](enzymexla::KernelCallOp callOp) {
auto symbolName = callOp.getFn();
auto callee = symTable.lookup<FunctionOpInterface>(symbolName);
if (!callee)
return WalkResult::advance();
MLIRContext *ctx = callee->getContext();
for (auto [index, valTy] : llvm::enumerate(callee.getArgumentTypes())) {
ArrayRef<NamedAttribute> operandAttrs = callee.getArgAttrs(index);
if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(valTy)) {

FailureOr<Attribute> noAliasAttr = getAttributeWithName(
operandAttrs, LLVM::LLVMDialect::getNoAliasAttrName());
FailureOr<Attribute> alignAttr = getAttributeWithName(
operandAttrs, LLVM::LLVMDialect::getAlignAttrName());
FailureOr<Attribute> dereferenceableAttr = getAttributeWithName(
operandAttrs, LLVM::LLVMDialect::getDereferenceableAttrName());

if (failed(noAliasAttr)) {
callee.setArgAttr(index, LLVM::LLVMDialect::getNoAliasAttrName(),
UnitAttr::get(ctx));
}

if (failed(alignAttr)) {
callee.setArgAttr(index, LLVM::LLVMDialect::getAlignAttrName(),
IntegerAttr::get(IntegerType::get(ctx, 32), 128));
}

if (failed(dereferenceableAttr)) {
callee.setArgAttr(index,
LLVM::LLVMDialect::getDereferenceableAttrName(),
IntegerAttr::get(IntegerType::get(ctx, 32),
getMemRefSizeInBytes(
callOp.getInputs()[index])));
} else {
// Conservatively update the dereferenceable attribute if the
// current value is less than we already have.
IntegerAttr intAttr = cast<IntegerAttr>(*dereferenceableAttr);
int64_t oldVal = intAttr.getInt();
int64_t currentVal =
getMemRefSizeInBytes(callOp.getInputs()[index]);
if (currentVal < oldVal) {
callee.setArgAttr(
index, LLVM::LLVMDialect::getDereferenceableAttrName(),
IntegerAttr::get(IntegerType::get(ctx, 32), currentVal));
}
}
}
}
return WalkResult::advance();
});
}
};
} // end namespace
87 changes: 87 additions & 0 deletions test/lit_tests/annotate_func_args.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
// RUN: enzymexlamlir-opt %s --propagate-constant-bounds --split-input-file | FileCheck %s


// Case 1: each kernel is called exactly once with a statically shaped tensor.
// The expected dereferenceable size is number of elements * element size:
// tensor<4xf32> -> 16 bytes for @foo, tensor<8xf64> -> 64 bytes for @bar.
// align = 128 and noalias are expected on both pointer arguments.
// CHECK-LABEL: ptx_kernelcc @foo
// CHECK-SAME: llvm.align = 128 : i32, llvm.dereferenceable = 16 : i32, llvm.noalias
llvm.func ptx_kernelcc @foo(%arg0: !llvm.ptr<1> {llvm.nocapture, llvm.nofree}) {
llvm.return
}

// CHECK-LABEL: ptx_kernelcc @bar
// CHECK-SAME: llvm.align = 128 : i32, llvm.dereferenceable = 64 : i32, llvm.noalias
llvm.func ptx_kernelcc @bar(%arg0: !llvm.ptr<1> {llvm.nocapture, llvm.nofree}) {
llvm.return
}

func.func @main(%arg0: tensor<4xf32>, %arg1: tensor<8xf64>) {
%c_4 = stablehlo.constant dense<1> : tensor<i64>
%c_5 = stablehlo.constant dense<2> : tensor<i64>
%c_6 = stablehlo.constant dense<3> : tensor<i64>
%c_8 = stablehlo.constant dense<4> : tensor<i64>
enzymexla.kernel_call @foo blocks in(%c_5, %c_8, %c_8) threads in(%c_4, %c_8, %c_8) shmem = %c_6 (%arg0) {} : (tensor<4xf32>) -> ()
enzymexla.kernel_call @bar blocks in(%c_5, %c_8, %c_8) threads in(%c_4, %c_8, %c_8) shmem = %c_6 (%arg1) {} : (tensor<8xf64>) -> ()
return
}

// -----

// Case 2: @bar is called twice, first with tensor<4xf32> (16 bytes) and then
// with tensor<f32> (4 bytes). The smaller size must be kept, so the expected
// dereferenceable value is 4. The pre-existing nocapture/nofree attributes
// must be preserved.
// CHECK-LABEL: ptx_kernelcc @bar
// CHECK-SAME: llvm.align = 128 : i32, llvm.dereferenceable = 4 : i32, llvm.noalias, llvm.nocapture, llvm.nofree
llvm.func ptx_kernelcc @bar(%arg0: !llvm.ptr<1> {llvm.nocapture, llvm.nofree}) {
llvm.return
}

func.func @main(%arg0: tensor<4xf32>, %arg1: tensor<f32>) {
%c_1 = stablehlo.constant dense<1> : tensor<i64>
%c_2 = stablehlo.constant dense<2> : tensor<i64>
%c_3 = stablehlo.constant dense<3> : tensor<i64>
%c_4 = stablehlo.constant dense<4> : tensor<i64>
%c_5 = stablehlo.constant dense<5> : tensor<i64>
%c_6 = stablehlo.constant dense<6> : tensor<i64>
%c_8 = stablehlo.constant dense<8> : tensor<i64>
enzymexla.kernel_call @bar blocks in(%c_5, %c_8, %c_8) threads in(%c_4, %c_8, %c_8) shmem = %c_6 (%arg0) {} : (tensor<4xf32>) -> ()
enzymexla.kernel_call @bar blocks in(%c_1, %c_2, %c_3) threads in(%c_4, %c_8, %c_8) shmem = %c_6 (%arg1) {} : (tensor<f32>) -> ()
return
}

// -----

// Case 3: same as the previous case with the call order swapped: @bar is
// called with tensor<f32> (4 bytes) first and tensor<4xf32> (16 bytes)
// second. The result must not depend on the order, so dereferenceable is
// still expected to be 4.
// CHECK-LABEL: ptx_kernelcc @bar
// CHECK-SAME: llvm.align = 128 : i32, llvm.dereferenceable = 4 : i32, llvm.noalias, llvm.nocapture, llvm.nofree
llvm.func ptx_kernelcc @bar(%arg0: !llvm.ptr<1> {llvm.nocapture, llvm.nofree}) {
llvm.return
}

func.func @main(%arg0: tensor<f32>, %arg1: tensor<4xf32>) {
%c_1 = stablehlo.constant dense<1> : tensor<i64>
%c_2 = stablehlo.constant dense<2> : tensor<i64>
%c_3 = stablehlo.constant dense<3> : tensor<i64>
%c_4 = stablehlo.constant dense<4> : tensor<i64>
%c_5 = stablehlo.constant dense<5> : tensor<i64>
%c_6 = stablehlo.constant dense<6> : tensor<i64>
%c_8 = stablehlo.constant dense<8> : tensor<i64>
enzymexla.kernel_call @bar blocks in(%c_5, %c_8, %c_8) threads in(%c_4, %c_8, %c_8) shmem = %c_6 (%arg0) {} : (tensor<f32>) -> ()
enzymexla.kernel_call @bar blocks in(%c_1, %c_2, %c_3) threads in(%c_4, %c_8, %c_8) shmem = %c_6 (%arg1) {} : (tensor<4xf32>) -> ()
return
}

// -----

// Case 4: the kernel argument already carries llvm.noalias before the pass
// runs. The attribute must still appear exactly as in the other cases, and
// align / dereferenceable (smallest call-site size, 4 bytes) must still be
// added.
// CHECK-LABEL: ptx_kernelcc @bar
// CHECK-SAME: llvm.align = 128 : i32, llvm.dereferenceable = 4 : i32, llvm.noalias, llvm.nocapture, llvm.nofree
llvm.func ptx_kernelcc @bar(%arg0: !llvm.ptr<1> {llvm.noalias, llvm.nocapture, llvm.nofree}) {
llvm.return
}

func.func @main(%arg0: tensor<f32>, %arg1: tensor<4xf32>) {
%c_1 = stablehlo.constant dense<1> : tensor<i64>
%c_2 = stablehlo.constant dense<2> : tensor<i64>
%c_3 = stablehlo.constant dense<3> : tensor<i64>
%c_4 = stablehlo.constant dense<4> : tensor<i64>
%c_5 = stablehlo.constant dense<5> : tensor<i64>
%c_6 = stablehlo.constant dense<6> : tensor<i64>
%c_8 = stablehlo.constant dense<8> : tensor<i64>
enzymexla.kernel_call @bar blocks in(%c_5, %c_8, %c_8) threads in(%c_4, %c_8, %c_8) shmem = %c_6 (%arg0) {} : (tensor<f32>) -> ()
enzymexla.kernel_call @bar blocks in(%c_1, %c_2, %c_3) threads in(%c_4, %c_8, %c_8) shmem = %c_6 (%arg1) {} : (tensor<4xf32>) -> ()
return
}
43 changes: 36 additions & 7 deletions test/lit_tests/propagate-constant-values.mlir
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
// RUN: enzymexlamlir-opt %s --propagate-constant-bounds | FileCheck %s
// RUN: enzymexlamlir-opt %s --propagate-constant-bounds --split-input-file | FileCheck %s

llvm.func @foo(%arg0: i32) -> i32 {
llvm.return %arg0 : i32
}

// CHECK-LABEL: ptx_kernelcc
llvm.func ptx_kernelcc @"##foo#3846"() {
llvm.func ptx_kernelcc @bar() {
// CHECK: nvvm.read.ptx.sreg.tid.x range <i32, 0, 1> : i32
%0 = nvvm.read.ptx.sreg.tid.x : i32
// CHECK-NEXT: nvvm.read.ptx.sreg.ctaid.x range <i32, 0, 2> : i32
Expand All @@ -18,10 +18,39 @@ llvm.func ptx_kernelcc @"##foo#3846"() {
}

func.func @main() {
%c_4 = stablehlo.constant dense<1> : tensor<i64>
%c_5 = stablehlo.constant dense<2> : tensor<i64>
%c_6 = stablehlo.constant dense<3> : tensor<i64>
%c_8 = stablehlo.constant dense<4> : tensor<i64>
enzymexla.kernel_call @"##foo#3846" blocks in(%c_5, %c_8, %c_8) threads in(%c_4, %c_8, %c_8) shmem = %c_6 () {} : () -> ()
%c_1 = stablehlo.constant dense<1> : tensor<i64>
%c_2 = stablehlo.constant dense<2> : tensor<i64>
%c_4 = stablehlo.constant dense<4> : tensor<i64>
%c_6 = stablehlo.constant dense<6> : tensor<i64>
enzymexla.kernel_call @bar blocks in(%c_2, %c_4, %c_4) threads in(%c_1, %c_4, %c_4) shmem = %c_6 () {} : () -> ()
return
}

// -----

// Kernel @bar is launched from two call sites with different launch bounds.
// The index ranges are expected to use the largest bound seen:
//   threads.x is 4 and 1 -> tid.x range upper bound 4
//   blocks.x  is 4 and 2 -> ctaid.x range upper bound 4
// NOTE(review): ntid.x is expected to fold to the constant 4 although the
// second call site launches with threads.x = 1 - confirm that folding the
// block dimension to a single constant is intended when call sites disagree.
llvm.func @foo(%arg0: i32) -> i32 {
llvm.return %arg0 : i32
}

// CHECK-LABEL: ptx_kernelcc
llvm.func ptx_kernelcc @bar() {
// CHECK: nvvm.read.ptx.sreg.tid.x range <i32, 0, 4> : i32
%0 = nvvm.read.ptx.sreg.tid.x : i32
// CHECK-NEXT: nvvm.read.ptx.sreg.ctaid.x range <i32, 0, 4> : i32
%1 = nvvm.read.ptx.sreg.ctaid.x : i32
// CHECK-NEXT: %[[CST:.+]] = llvm.mlir.constant(4 : i32) : i32
%2 = nvvm.read.ptx.sreg.ntid.x : i32
// CHECK: %{{.+}} = llvm.call @foo(%[[CST]]) : (i32) -> i32
llvm.call @foo(%2) : (i32) -> i32
llvm.return
}

func.func @main() {
%c_1 = stablehlo.constant dense<1> : tensor<i64>
%c_2 = stablehlo.constant dense<2> : tensor<i64>
%c_4 = stablehlo.constant dense<4> : tensor<i64>
%c_6 = stablehlo.constant dense<6> : tensor<i64>
enzymexla.kernel_call @bar blocks in(%c_4, %c_4, %c_4) threads in(%c_4, %c_4, %c_4) shmem = %c_6 () {} : () -> ()
enzymexla.kernel_call @bar blocks in(%c_2, %c_4, %c_4) threads in(%c_1, %c_4, %c_4) shmem = %c_6 () {} : () -> ()
return
}

0 comments on commit fb7f715

Please sign in to comment.