Add gradient of pad operation
wsmoses committed Mar 15, 2024
1 parent 5db59e2 commit 5b7b514
Showing 5 changed files with 71 additions and 81 deletions.
4 changes: 2 additions & 2 deletions WORKSPACE
@@ -60,8 +60,8 @@ load("@rules_python//python/pip_install:repositories.bzl", "pip_install_dependencies")

pip_install_dependencies()

ENZYME_COMMIT = "fb96efd5d9528e73cb9e69491b872b006799c482"
ENZYME_SHA256 = "184d9633439883407376d80ee5fc58b6b86d4e2c04e589602534cb52f187f4ae"
ENZYME_COMMIT = "9acbc0a667ec8ae76407b5708758667a65ff15aa"
ENZYME_SHA256 = "287143133ccf9501a02f1bdab351c34adcab3bbfc8648b180ebd79d0e058b3af"

http_archive(
name = "enzyme",
26 changes: 25 additions & 1 deletion src/enzyme_ad/jax/Implementations/HLODerivatives.td
@@ -1,4 +1,3 @@

def Add : HLOInst<"AddOp">;
def Sub : HLOInst<"SubtractOp">;
def Neg : HLOInst<"NegOp">;
@@ -96,12 +95,37 @@ def Transpose : HLOInst<"TransposeOp">;
def Reshape : HLOInst<"ReshapeOp">;
def : HLOReadOnlyIdentityOp<"ReshapeOp", [0], (Op $x), [(Reshape (TypeOf $x), (DiffeRet))]>;

def Slice : HLOInst<"SliceOp">;
def : HLOReadOnlyIdentityOp<"SliceOp">;

def Reduce : HLOInst<"ReduceOp">;
def : HLOReadOnlyIdentityOp<"BroadcastInDimOp">;

def : HLOMemoryIdentityOp<"ConcatenateOp", [], [-1]>;

def PadToSliceStart : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
  op.getEdgePaddingLow();
}]>;

def PadToSliceLimit : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
  SmallVector<int64_t> limits;
  for (auto &&[high, dim] : llvm::zip(op.getEdgePaddingHigh(), op.getType().getShape()))
    limits.push_back(to_i64(dim - high));
  getI64Attr(builder, limits);
}]>;

def PadToSliceStride : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
  SmallVector<int64_t> strides;
  for (auto interior : op.getInteriorPadding())
    strides.push_back(to_i64(interior + 1));
  getI64Attr(builder, strides);
}]>;
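// The three GlobalExprs above supply, per dimension, the slice that undoes
// the pad: start = edge_padding_low, limit = result_dim - edge_padding_high,
// stride = interior_padding + 1. The padding value contributes no gradient
// of its own, which AssertingInactiveArg enforces in the rule below.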

def : HLOMemoryIdentityOp<"PadOp", [], [-1], (Op $op, $padval), [
  (Slice (TypeOf $op), (DiffeRet), (PadToSliceStart), (PadToSliceLimit), (PadToSliceStride)),
  (AssertingInactiveArg)
]>;

// convert


8 changes: 8 additions & 0 deletions src/enzyme_ad/jax/Implementations/MHLOAutoDiffOpInterfaceImpl.cpp
@@ -29,6 +29,14 @@ using namespace mlir;
using namespace mlir::enzyme;
using namespace mlir::mhlo;

// Normalize attribute elements to int64_t. MHLO stores its dimension
// attributes as DenseIntElementsAttr, whose iterators yield llvm::APInt,
// so the shared TableGen-generated derivative bodies call to_i64 on both
// plain integers and APInts.
static int64_t to_i64(int64_t x) { return x; }
static int64_t to_i64(llvm::APInt x) { return x.getSExtValue(); }

// MHLO ops expect their i64 dimension lists as DenseIntElementsAttr.
static mlir::DenseIntElementsAttr getI64Attr(OpBuilder &builder,
                                             llvm::ArrayRef<int64_t> vals) {
  return builder.getI64VectorAttr(vals);
}

namespace {
#include "src/enzyme_ad/jax/Implementations/MHLODerivatives.inc"
} // namespace
86 changes: 8 additions & 78 deletions src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp
@@ -29,6 +29,14 @@ using namespace mlir;
using namespace mlir::enzyme;
using namespace mlir::stablehlo;

// Normalize attribute elements to int64_t. StableHLO's DenseI64ArrayAttr
// accessors already yield int64_t; defining both overloads here keeps the
// shared TableGen-generated derivative bodies compiling in this file too.
static int64_t to_i64(int64_t x) { return x; }
static int64_t to_i64(llvm::APInt x) { return x.getSExtValue(); }

// StableHLO ops expect their i64 dimension lists as DenseI64ArrayAttr.
static mlir::DenseI64ArrayAttr getI64Attr(OpBuilder &builder,
                                          llvm::ArrayRef<int64_t> vals) {
  return builder.getDenseI64ArrayAttr(vals);
}

namespace {
#include "src/enzyme_ad/jax/Implementations/StableHLODerivatives.inc"

@@ -295,84 +303,6 @@ class AutoDiffSliceRev

    gutils->addToDiffe(op.getOperand(), red->getResult(0), builder);
    return success();
#if 0

    auto outTy = op.getType();
    auto zero = inTy.cast<AutoDiffTypeInterface>().createNullValue(
        builder, op.getLoc());
    Value idxs;
    {
      SmallVector<int64_t> concat_data;
      for (size_t i = 0; i < outTy.getShape().size(); i++) {
        concat_data.push_back(outTy.getShape()[i]);
      }
      concat_data.push_back(1);
      auto toConcatType =
          RankedTensorType::get(concat_data, builder.getI32Type());
      std::vector<Value> inds;
      size_t idx = 0;
      for (auto &&[start, limit, stride] : llvm::zip(
               op.getStartIndices(), op.getLimitIndices(), op.getStrides())) {
        std::vector<int32_t> data;
        for (int32_t i = start; i < limit; i += stride) {
          data.push_back(i);
        }
        Value ind = builder.create<ConstantOp>(
            op.getLoc(),
            RankedTensorType::get({(int64_t)data.size()},
                                  builder.getI32Type()),
            builder.getI32TensorAttr(data));

        auto bcast_ind = builder.getDenseI64ArrayAttr({(int64_t)idx});
        ind = builder.create<BroadcastInDimOp>(op.getLoc(), toConcatType, ind,
                                               bcast_ind);
        inds.push_back(ind);
        idx++;
      }
      idxs = builder.create<ConcatenateOp>(
          op.getLoc(), inds,
          builder.getI64IntegerAttr(concat_data.size() - 1));
    }

    // empty extra index into the slice
    std::vector<int64_t> update_window_dims;
    std::vector<int64_t> scatter_dims_to_operand_dims;
    std::vector<int64_t> inserted_window_dims;
    for (int i = 0; i < inTy.getShape().size(); i++) {
      scatter_dims_to_operand_dims.push_back(i);
      inserted_window_dims.push_back(i);
    }

    int64_t indexVectorDim = inTy.getShape().size();

    auto dims = ScatterDimensionNumbersAttr::get(
        builder.getContext(), update_window_dims, inserted_window_dims,
        scatter_dims_to_operand_dims, indexVectorDim);

    // auto prev = gutils->diffe(op.getOperand(), builder);

    auto red = builder.create<ScatterOp>(
        op.getLoc(), TypeRange(gutils->getShadowType(inTy)), ValueRange(zero),
        idxs, ValueRange(inDiffe), dims,
        /*indices_are_sorted*/ builder.getBoolAttr(true),
        /*unique_indices*/ builder.getBoolAttr(true));

    red.getUpdateComputation().push_back(new Block());
    Block &body = red.getUpdateComputation().front();
    OpBuilder bodyBuilder(orig->getContext());
    bodyBuilder.setInsertionPointToEnd(&body);

    auto TT = RankedTensorType::get({}, inTy.getElementType());
    body.addArgument(TT, op.getLoc());
    body.addArgument(TT, op.getLoc());
    /*
    auto add = bodyBuilder.create<AddOp>(op.getLoc(), body.getArgument(0),
                                         body.getArgument(1));
    bodyBuilder.create<ReturnOp>(op.getLoc(), ValueRange(add));
    */
    bodyBuilder.create<ReturnOp>(op.getLoc(), ValueRange(body.getArgument(1)));

    gutils->addToDiffe(op.getOperand(), red->getResult(0), builder);
    // gutils->setDiffe(op.getOperand(), red->getResult(0), builder);

    return success();
#endif
  }

SmallVector<Value> cacheValues(Operation *orig,
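Both interface implementations define the same two helper names with dialect-appropriate types: MHLO keeps its dimension lists in DenseIntElementsAttr (iteration yields llvm::APInt), while StableHLO uses DenseI64ArrayAttr (iteration yields plain int64_t). The to_i64 overloads let the single TableGen-generated body from HLODerivatives.td compile unchanged in both files. Below is a minimal, self-contained sketch of that overload trick, assuming only LLVM's ADT headers; the variable names and main() are illustrative, not from the repository:

#include <cstdint>
#include "llvm/ADT/APInt.h"

static int64_t to_i64(int64_t x) { return x; }
static int64_t to_i64(llvm::APInt x) { return x.getSExtValue(); }

int main() {
  // StableHLO-style accessors hand back plain int64_t...
  int64_t fromArrayAttr = 27 - 13;
  // ...while MHLO-style DenseIntElementsAttr iteration hands back llvm::APInt.
  llvm::APInt fromElementsAttr = llvm::APInt(64, 27) - llvm::APInt(64, 13);
  // The same generated expression, to_i64(dim - high), compiles either way.
  return to_i64(fromArrayAttr) == to_i64(fromElementsAttr) ? 0 : 1;
}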
28 changes: 28 additions & 0 deletions test/lit_tests/grad_pad.mlir
@@ -0,0 +1,28 @@
// RUN: enzymexlamlir-opt %s --enzyme-wrap="infn=main outfn= retTys=enzyme_dup argTys=enzyme_dup mode=ForwardMode" | FileCheck %s --check-prefix=FORWARD
// RUN: enzymexlamlir-opt %s --enzyme-wrap="infn=main outfn= retTys=enzyme_out argTys=enzyme_out mode=ReverseModeCombined" --canonicalize --remove-unnecessary-enzyme-ops | FileCheck %s --check-prefix=REVERSE

module {

  func.func @main(%a : tensor<2x3xf32>) -> tensor<18x27xf32> {
    %0 = stablehlo.constant dense<3.140000e+00> : tensor<f32>
    %2 = stablehlo.pad %a, %0, low = [5, 7], high = [11, 13], interior = [0, 2] : (tensor<2x3xf32>, tensor<f32>) -> tensor<18x27xf32>
    return %2 : tensor<18x27xf32>
  }
}

// FORWARD: func.func @main(%arg0: tensor<2x3xf32>, %arg1: tensor<2x3xf32>) -> (tensor<18x27xf32>, tensor<18x27xf32>) {
// FORWARD-NEXT: %0 = stablehlo.constant dense<3.140000e+00> : tensor<f32>
// FORWARD-NEXT: %cst = arith.constant dense<0.000000e+00> : tensor<f32>
// FORWARD-NEXT: %1 = stablehlo.pad %arg1, %cst, low = [5, 7], high = [11, 13], interior = [0, 2] : (tensor<2x3xf32>, tensor<f32>) -> tensor<18x27xf32>
// FORWARD-NEXT: %2 = stablehlo.pad %arg0, %0, low = [5, 7], high = [11, 13], interior = [0, 2] : (tensor<2x3xf32>, tensor<f32>) -> tensor<18x27xf32>
// FORWARD-NEXT: return %2, %1 : tensor<18x27xf32>, tensor<18x27xf32>
// FORWARD-NEXT: }

// REVERSE: func.func @main(%arg0: tensor<2x3xf32>, %arg1: tensor<18x27xf32>) -> tensor<2x3xf32> {
// REVERSE-NEXT: %cst = arith.constant dense<0.000000e+00> : tensor<18x27xf32>
// REVERSE-NEXT: %cst_0 = arith.constant dense<0.000000e+00> : tensor<2x3xf32>
// REVERSE-NEXT: %0 = arith.addf %arg1, %cst : tensor<18x27xf32>
// REVERSE-NEXT: %1 = stablehlo.slice %0 [5:7, 7:14:3] : (tensor<18x27xf32>) -> tensor<2x3xf32>
// REVERSE-NEXT: %2 = arith.addf %1, %cst_0 : tensor<2x3xf32>
// REVERSE-NEXT: return %2 : tensor<2x3xf32>
// REVERSE-NEXT: }
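The REVERSE slice bounds are exactly what the rule in HLODerivatives.td prescribes, per dimension: start = edge_padding_low, limit = result_dim − edge_padding_high, stride = interior_padding + 1. Checking the arithmetic for this test case:

  start  = (5, 7)
  limit  = (18 − 11, 27 − 13) = (7, 14)
  stride = (0 + 1, 2 + 1)     = (1, 3)

which matches the checked stablehlo.slice %0 [5:7, 7:14:3] (a stride of 1 is printed implicitly in the first dimension). In the FORWARD checks, the tangent operand is padded with 0.0 rather than the primal's 3.14, since the constant padding value carries no derivative.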
