support dilation and feature/batch group count in convolution reverse (#181)

* support dilation and feature group count in convolution reverse

* support batch group count

* fix dimensions for post conv transpose (batch group count)
Pangoraw authored and vimarsh6739 committed Dec 14, 2024
1 parent 872d3a0 commit 74e879b
Showing 10 changed files with 360 additions and 38 deletions.
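In brief: the reverse-mode rules for ConvolutionOp below now fold lhs/rhs dilation and the window stride into the transposed-convolution padding via a shared dilateShape helper, and regroup the kernel and result to handle feature_group_count and batch_group_count. As a hedged plain-text summary of the padding arithmetic the visible hunks implement (notation mine, not from the source):

    dilate(s, d) = s                          if d == 1
                 = max(0, 1 + d * (s - 1))    otherwise

    data gradient:   padBefore' = dilate(r, rhsDilation) - padBefore - 1
    filter gradient: padAfter'  = dilate(o, windowStride) - dilate(l, lhsDilation)
                                  + dilate(r, rhsDilation) - padBefore - 1

where l, r, o are the input, kernel, and output extents of one spatial dimension.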
286 changes: 250 additions & 36 deletions src/enzyme_ad/jax/Implementations/HLODerivatives.td
@@ -391,6 +391,145 @@ def ConvBatchGroupCount : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{

// GradData

def GradDataFilterReshape1 : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
auto featureGroupCount = op.getFeatureGroupCount();
auto batchGroupCount = op.getBatchGroupCount();
assert(featureGroupCount == 1 || batchGroupCount == 1);
auto groupCount = featureGroupCount == 1 ? batchGroupCount : featureGroupCount;

auto rhs = op.getRhs();
auto dimensionNumbers = op.getDimensionNumbers();
auto Ty = cast<RankedTensorType>(rhs.getType());
auto shape = Ty.getShape();

auto odim = dimensionNumbers.getKernelOutputFeatureDimension();

SmallVector<int64_t> newShape;
for (int64_t i = 0, e = shape.size(); i < e; ++i) {
if (i == odim) {
newShape.push_back(groupCount);
newShape.push_back(shape[i] / groupCount);
} else {
newShape.push_back(shape[i]);
}
}

RankedTensorType::get(newShape, Ty.getElementType());
}]>;

def GradDataFilterTranspose : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
SmallVector<int64_t> transposes;
auto dimensionNumbers = op.getDimensionNumbers();
auto idim = dimensionNumbers.getKernelInputFeatureDimension();
auto odim = dimensionNumbers.getKernelOutputFeatureDimension();

if (odim < idim)
idim++;

int64_t i = 0, N = op.getType().getShape().size();
while (i <= N) {
if (i == idim) {
transposes.push_back(odim);
transposes.push_back(idim);
} else if (i != odim) {
transposes.push_back(i);
}
i++;
}

getI64Attr(builder, transposes);
}]>;

def GradDataConvOutputType : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
auto Ty = op.getLhs().getType();
auto batchGroupCount = op.getBatchGroupCount();

if (batchGroupCount > 1) {
SmallVector<int64_t> shape(Ty.getShape().begin(), Ty.getShape().end());
auto dimensionNumbers = op.getDimensionNumbers();
shape[dimensionNumbers.getInputFeatureDimension()] *= batchGroupCount;
shape[dimensionNumbers.getInputBatchDimension()] /= batchGroupCount;
Ty = Ty.clone(shape);
}

Ty;
}]>;

def GradDataConvBatchGroupCountType : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
auto Ty = op.getLhs().getType();
auto batchGroupCount = op.getBatchGroupCount();

auto dimensionNumbers = op.getDimensionNumbers();
auto fdim = dimensionNumbers.getInputFeatureDimension();
auto bdim = dimensionNumbers.getInputBatchDimension();

auto shape = Ty.getShape();
SmallVector<int64_t> newShape;
for (int64_t i = 0, e = shape.size(); i < e; ++i) {
if (i == fdim) {
newShape.push_back(batchGroupCount);
newShape.push_back(shape[i]);
} else if (i == bdim) {
newShape.push_back(shape[i] / batchGroupCount);
} else {
newShape.push_back(shape[i]);
}
}

Ty.clone(newShape);
}]>;

def GradDataConvBatchGroupPerm : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
SmallVector<int64_t> transposes;

auto dimensionNumbers = op.getDimensionNumbers();
auto fdim = dimensionNumbers.getInputFeatureDimension();
auto bdim = dimensionNumbers.getInputBatchDimension();

if (fdim < bdim)
bdim++;

int64_t i = 0, N = op.getType().getShape().size();
while (i <= N) {
if (i == bdim) {
transposes.push_back(fdim);
transposes.push_back(bdim);
} else if (i != fdim) {
transposes.push_back(i);
}
i++;
}

getI64Attr(builder, transposes);
}]>;

def GradDataFilterReshape2 : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
auto featureGroupCount = op.getFeatureGroupCount();
auto batchGroupCount = op.getBatchGroupCount();
auto groupCount = featureGroupCount == 1 ? batchGroupCount : featureGroupCount;

auto rhs = op.getRhs();
auto dimensionNumbers = op.getDimensionNumbers();
auto Ty = cast<RankedTensorType>(rhs.getType());
auto shape = Ty.getShape();

auto odim = dimensionNumbers.getKernelOutputFeatureDimension();
auto idim = dimensionNumbers.getKernelInputFeatureDimension();

SmallVector<int64_t> newShape;
for (int64_t i = 0, e = shape.size(); i < e; ++i) {
if (i == idim) {
newShape.push_back(shape[i] * groupCount);
} else if (i == odim) {
newShape.push_back(shape[i] / groupCount);
} else {
newShape.push_back(shape[i]);
}
}

RankedTensorType::get(newShape, Ty.getElementType());
}]>;
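Taken together, GradDataFilterReshape1, GradDataFilterTranspose, and GradDataFilterReshape2 move the kernel's group dimension from the output-feature axis onto the input-feature axis before the kernel enters the reverse convolution. A minimal standalone sketch of the shape arithmetic (plain C++ over std::vector, no MLIR; the HWIO layout and concrete sizes are assumptions for illustration):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Assumed example: HWIO kernel [3, 3, 8, 16], groupCount = 2,
  // kernel input feature dim idim = 2, kernel output feature dim odim = 3.
  std::vector<int64_t> shape = {3, 3, 8, 16};
  int64_t groupCount = 2, idim = 2, odim = 3;

  // GradDataFilterReshape1: split the output-feature dim into (group, out/group).
  std::vector<int64_t> r1;
  for (int64_t i = 0, e = shape.size(); i < e; ++i) {
    if (i == odim) {
      r1.push_back(groupCount);
      r1.push_back(shape[i] / groupCount);
    } else {
      r1.push_back(shape[i]);
    }
  }
  // r1 == [3, 3, 8, 2, 8]

  // GradDataFilterTranspose: move the inserted group dim next to the
  // input-feature dim; the rank is now N + 1, hence the i <= N loop.
  int64_t tIdim = idim;
  if (odim < tIdim) tIdim++;  // the inserted dim shifts idim when odim precedes it
  std::vector<int64_t> perm;
  for (int64_t i = 0, N = shape.size(); i <= N; ++i) {
    if (i == tIdim) {
      perm.push_back(odim);   // group dim first ...
      perm.push_back(tIdim);  // ... then input features
    } else if (i != odim) {
      perm.push_back(i);
    }
  }
  // perm == [0, 1, 3, 2, 4]
  std::vector<int64_t> t(r1.size());
  for (size_t i = 0; i < perm.size(); ++i) t[i] = r1[perm[i]];
  // t == [3, 3, 2, 8, 8]

  // GradDataFilterReshape2: merge (group, in) -> in*group, shrink out by group.
  std::vector<int64_t> r2;
  for (int64_t i = 0, e = shape.size(); i < e; ++i) {
    if (i == idim)      r2.push_back(shape[i] * groupCount);
    else if (i == odim) r2.push_back(shape[i] / groupCount);
    else                r2.push_back(shape[i]);
  }
  for (auto v : r2) std::cout << v << ' ';  // prints: 3 3 16 8
  std::cout << '\n';
}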

def GradDataConvWindowStrides : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
int64_t N = op.getType().getShape().size() - 2;
llvm::SmallVector<int64_t> windowStrides(N, 1);
@@ -410,6 +549,15 @@ def GradDataConvPadding : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
initialPadding.value().value_end<int64_t>());
}

auto dilateShape = [](int64_t shape, int64_t dilation) {
if (dilation == 1) return shape;
int64_t dilated = 1 + dilation * (shape - 1);
return dilated < 0 ? 0 : dilated;
};

auto lhsDilations = op.getLhsDilation();
auto rhsDilations = op.getRhsDilation();
auto windowStrides = op.getWindowStrides();
for (int i = 0; i < N; ++i) {
auto weightDim = dimensionNumbers.getKernelSpatialDimensions()[i];
auto dataDim = dimensionNumbers.getInputSpatialDimensions()[i];
@@ -418,9 +566,19 @@ def GradDataConvPadding : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
auto padBefore = newPaddingValues[2 * i];
auto padAfter = newPaddingValues[2 * i + 1];

auto rhsShape = op.getRhs().getType().getShape()[weightDim];
auto lhsShape = op.getLhs().getType().getShape()[dataDim];
auto outShape = op.getType().getShape()[outputDim];
auto lhsDilation = lhsDilations.has_value() ?
getI64Value(lhsDilations.value(), i) :
1;
auto rhsDilation = rhsDilations.has_value() ?
getI64Value(rhsDilations.value(), i) :
1;
auto windowStride = windowStrides.has_value() ?
getI64Value(windowStrides.value(), i) :
1;

auto lhsShape = dilateShape(op.getLhs().getType().getShape()[dataDim], lhsDilation);
auto rhsShape = dilateShape(op.getRhs().getType().getShape()[weightDim], rhsDilation);
auto outShape = dilateShape(op.getType().getShape()[outputDim], windowStride);

auto newPadBefore = rhsShape - padBefore - 1;
newPaddingValues[2 * i] = newPadBefore;
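The dilateShape helper used above computes the extent of a dilated window: a size-s extent dilated by factor d spans 1 + d*(s - 1) elements, clamped at zero. A small self-contained check (the values are made-up examples):

#include <cstdint>
#include <iostream>

// Same formula as the dilateShape lambda in GradDataConvPadding.
int64_t dilateShape(int64_t shape, int64_t dilation) {
  if (dilation == 1) return shape;
  int64_t dilated = 1 + dilation * (shape - 1);
  return dilated < 0 ? 0 : dilated;
}

int main() {
  // A 3-tap kernel with rhs_dilation = 2 covers 1 + 2*(3-1) = 5 inputs, so the
  // data-gradient padding above becomes dilate(3, 2) - padBefore - 1.
  std::cout << dilateShape(3, 2) << '\n';  // 5
  std::cout << dilateShape(4, 3) << '\n';  // 10
  std::cout << dilateShape(7, 1) << '\n';  // 7 (identity when dilation == 1)
}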
@@ -475,11 +633,14 @@ def GradDataConvDimensionNumbers : GlobalExpr</*needsprimal*/0, /*needsshadow*/0
}]>;

def GradDataConvFeatureGroupCount : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
op.getFeatureGroupCountAttr();
auto featureGroupCount = op.getFeatureGroupCount();
auto batchGroupCount = op.getBatchGroupCount();
auto groupCount = featureGroupCount == 1 ? batchGroupCount : featureGroupCount;
builder.getI64IntegerAttr(groupCount);
}]>;

def GradDataConvBatchGroupCount : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
op.getBatchGroupCountAttr();
builder.getI64IntegerAttr(1);
}]>;

// GradFilter
@@ -500,17 +661,36 @@ def GradFilterConvPadding : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
initialPadding.value().value_end<int64_t>());
}

auto dilateShape = [](int64_t shape, int64_t dilation) {
if (dilation == 1) return shape;
int64_t dilated = 1 + dilation * (shape - 1);
return dilated < 0 ? 0 : dilated;
};

auto lhsDilations = op.getLhsDilation();
auto rhsDilations = op.getRhsDilation();
auto windowStrides = op.getWindowStrides();
for (int i = 0; i < N; ++i) {
auto weightDim = dimensionNumbers.getKernelSpatialDimensions()[i];
auto dataDim = dimensionNumbers.getInputSpatialDimensions()[i];
auto weightDim = dimensionNumbers.getKernelSpatialDimensions()[i];
auto outputDim = dimensionNumbers.getOutputSpatialDimensions()[i];

auto padBefore = newPaddingValues[2 * i];
auto padAfter = newPaddingValues[2 * i + 1];

auto rhsShape = op.getRhs().getType().getShape()[weightDim];
auto lhsShape = op.getLhs().getType().getShape()[dataDim];
auto outShape = op.getType().getShape()[outputDim];
auto lhsDilation = lhsDilations.has_value() ?
getI64Value(lhsDilations.value(), i) :
1;
auto rhsDilation = rhsDilations.has_value() ?
getI64Value(rhsDilations.value(), i) :
1;
auto windowStride = windowStrides.has_value() ?
getI64Value(windowStrides.value(), i) :
1;

auto lhsShape = dilateShape(op.getLhs().getType().getShape()[dataDim], lhsDilation);
auto rhsShape = dilateShape(op.getRhs().getType().getShape()[weightDim], rhsDilation);
auto outShape = dilateShape(op.getType().getShape()[outputDim], windowStride);

newPaddingValues[2 * i] = padBefore;
newPaddingValues[2 * i + 1] = outShape - lhsShape + rhsShape - padBefore - 1;
@@ -522,6 +702,22 @@ def GradFilterConvPadding : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
newPaddingAttr;
}]>;

def GradFilterConvReverseDims : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
auto windowReversals = op.getWindowReversal();

SmallVector<int64_t> reverseDims;

if (windowReversals.has_value()) {
for (auto it : llvm::enumerate(getBoolIter(windowReversals.value()))) {
if (it.value()) {
reverseDims.push_back(it.index());
}
}
}

getI64Attr(builder, reverseDims);
}]>;
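GradFilterConvReverseDims collects the indices of the spatial dimensions whose window_reversal flag is set, so the filter gradient computed below can be flipped back along exactly those axes. A standalone sketch of the index collection (plain loops instead of llvm::enumerate/getBoolIter; the flag values are made up):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Assumed window_reversal flags for three spatial dimensions.
  std::vector<bool> windowReversal = {true, false, true};

  std::vector<int64_t> reverseDims;
  for (size_t i = 0; i < windowReversal.size(); ++i)
    if (windowReversal[i])
      reverseDims.push_back(i);  // record the index, not the flag

  for (auto d : reverseDims) std::cout << d << ' ';  // prints: 0 2
  std::cout << '\n';
}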

def GradFilterConvLhsDilation : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
op.getLhsDilationAttr();
}]>;
@@ -552,8 +748,8 @@ def GradFilterConvDimensionNumbers : GlobalExpr</*needsprimal*/0, /*needsshadow*
}]>;

def GradFilterConvFeatureGroupCount : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
unsigned int newFeatureGroupCount = 1;
newFeatureGroupCount;
auto batchGroupCount = op.getBatchGroupCount();
batchGroupCount;
}]>;

def GradFilterConvBatchGroupCount : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
@@ -564,33 +760,51 @@ def GradFilterConvBatchGroupCount : GlobalExpr</*needsprimal*/0, /*needsshadow*/

def : HLODerivative<"ConvolutionOp", (Op $lhs, $rhs),
[
(Convolution
(Reshape
(TypeOf $lhs),
(DiffeRet),
$rhs,
(GradDataConvWindowStrides),
(GradDataConvPadding),
(GradDataConvLhsDilation),
(GradDataConvRhsDilation),
(GradDataConvWindowReversal),
(GradDataConvDimensionNumbers),
(GradDataConvFeatureGroupCount),
(GradDataConvBatchGroupCount),
(ResultDotPrec)
(Transpose
(Reshape
(GradDataConvBatchGroupCountType),
(Convolution
(GradDataConvOutputType),
(DiffeRet),
(Reshape
(GradDataFilterReshape2),
(Transpose
(Reshape (GradDataFilterReshape1), $rhs),
(GradDataFilterTranspose)
)
),
(GradDataConvWindowStrides),
(GradDataConvPadding),
(GradDataConvLhsDilation),
(GradDataConvRhsDilation),
(GradDataConvWindowReversal),
(GradDataConvDimensionNumbers),
(GradDataConvFeatureGroupCount),
(GradDataConvBatchGroupCount),
(ResultDotPrec)
)
),
(GradDataConvBatchGroupPerm)
)
),
(Convolution
(TypeOf $rhs),
$lhs,
(DiffeRet),
(GradFilterConvWindowStrides),
(GradFilterConvPadding),
(GradFilterConvLhsDilation),
(GradFilterConvRhsDilation),
(GradFilterConvWindowReversal),
(GradFilterConvDimensionNumbers),
(GradFilterConvFeatureGroupCount),
(GradFilterConvBatchGroupCount),
(ResultDotPrec)
(Reverse
(Convolution
(TypeOf $rhs),
$lhs,
(DiffeRet),
(GradFilterConvWindowStrides),
(GradFilterConvPadding),
(GradFilterConvLhsDilation),
(GradFilterConvRhsDilation),
(GradFilterConvWindowReversal),
(GradFilterConvDimensionNumbers),
(GradFilterConvFeatureGroupCount),
(GradFilterConvBatchGroupCount),
(ResultDotPrec)
),
(GradFilterConvReverseDims)
)
],
(Add
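For orientation, a paraphrase of the new pattern above (mine, not text from the source): the data gradient is a grouped convolution of the incoming cotangent with the regrouped kernel, followed by a reshape/transpose that restores the batch-group layout; the filter gradient is a convolution of the primal input with the cotangent, flipped along the window-reversed axes. Schematically:

    grad_data   = Transpose(Reshape(Conv(dout,
                                         Reshape(Transpose(Reshape(rhs))),
                                         featureGroups = the nontrivial one of {fgc, bgc},
                                         batchGroups = 1)),
                            batchGroupPerm)
    grad_filter = Reverse(Conv(lhs, dout, featureGroups = bgc, ...), reverseDims)

fgc/bgc are the forward op's feature/batch group counts; all remaining attributes come from the GradData*/GradFilter* expressions above.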
@@ -37,6 +37,10 @@ static mlir::DenseIntElementsAttr getI64Attr(OpBuilder &builder,
return builder.getI64VectorAttr(vals);
}

static int64_t getI64Value(mlir::DenseIntElementsAttr attr, size_t pos) {
return attr.getValues<int64_t>()[pos];
}

static mlir::DenseElementsAttr getBoolAttr(OpBuilder &builder,
llvm::ArrayRef<bool> vals) {
return builder.getBoolVectorAttr(vals);
@@ -38,6 +38,10 @@ static mlir::DenseI64ArrayAttr getI64Attr(OpBuilder &builder,
return builder.getDenseI64ArrayAttr(vals);
}

static int64_t getI64Value(llvm::ArrayRef<int64_t> attr, size_t pos) {
return attr[pos];
}

static mlir::DenseBoolArrayAttr getBoolAttr(OpBuilder &builder,
llvm::ArrayRef<bool> vals) {
return builder.getDenseBoolArrayAttr(vals);