Skip to content

Commit

Permalink
Add a method to sort L3->L2 in "compatible" pairs
Browse files Browse the repository at this point in the history
  • Loading branch information
Abhishek-Varma committed Sep 6, 2024
1 parent 74f2434 commit c5d3d0c
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,60 +50,6 @@ int64_t calculateNbIterations(int64_t lowerBound, int64_t upperBound,

namespace {

/// Utility affine expression visitor to retrieve the scale and optional bias
/// from the expression.
///
/// Supported expression shapes (operands may appear in either order):
///   - `dim * constant`            -> `scale = constant`
///   - `dim + constant`            -> `scale = 1`, `bias = constant`
///   - `<binary expr> + constant`  -> `bias = constant`, then the binary
///                                    operand is visited recursively.
/// Every other expression makes the visit return failure. Note that on
/// failure `scale`/`bias` may already have been partially populated, so they
/// must only be read after a successful visit.
struct RetrieveScaleAndBias
    : public AffineExprVisitor<RetrieveScaleAndBias, LogicalResult> {
  /// Multiplicative constant applied to the dimension, if one was found.
  std::optional<int64_t> scale;
  /// Non-negative additive constant, if one was found.
  std::optional<int64_t> bias;

  /// Binary operations without a dedicated handler below are not supported.
  LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr /*expr*/) {
    return failure();
  }
  /// A bare constant carries no scale/bias information.
  LogicalResult visitConstantExpr(AffineConstantExpr /*expr*/) {
    return failure();
  }
  /// A bare dimension is rejected.
  LogicalResult visitDimExpr(AffineDimExpr /*expr*/) { return failure(); }
  /// Symbols are rejected.
  LogicalResult visitSymbolExpr(AffineSymbolExpr /*expr*/) { return failure(); }

  /// Handles `dim * constant` and `constant * dim` by recording the constant
  /// as `scale`. Any other multiplication is rejected.
  LogicalResult visitMulExpr(AffineBinaryOpExpr expr) {
    AffineExpr lhs = expr.getLHS();
    AffineExpr rhs = expr.getRHS();
    // The `dyn_cast` result must be checked explicitly: with an
    // `if (init; cond)` statement only `cond` is tested, so a non-constant
    // operand would otherwise be dereferenced as a null AffineConstantExpr.
    if (auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
        rhsConst && isa<AffineDimExpr>(lhs)) {
      scale = rhsConst.getValue();
      return success();
    }
    if (auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
        lhsConst && isa<AffineDimExpr>(rhs)) {
      scale = lhsConst.getValue();
      return success();
    }
    // Not a `dim * constant` product: no scale can be retrieved, so report
    // failure instead of succeeding with `scale` left unset.
    return failure();
  }

  /// Handles `<expr> + constant` and `constant + <expr>`. The constant becomes
  /// `bias` (negative values are rejected); `<expr>` must be either a plain
  /// dimension (`scale = 1`) or a binary expression that is visited in turn.
  LogicalResult visitAddExpr(AffineBinaryOpExpr expr) {
    // Only a single additive constant is supported; a second addition fails.
    if (bias) return failure();
    AffineExpr lhs = expr.getLHS();
    AffineExpr rhs = expr.getRHS();
    // Prefer the RHS as the constant operand, then fall back to the LHS.
    auto constant = dyn_cast<AffineConstantExpr>(rhs);
    AffineExpr other = lhs;
    if (!constant) {
      constant = dyn_cast<AffineConstantExpr>(lhs);
      other = rhs;
    }
    if (!constant) return failure();
    bias = constant.getValue();
    if (bias.value() < 0) return failure();
    if (isa<AffineBinaryOpExpr>(other)) return visit(other);
    if (isa<AffineDimExpr>(other)) {
      scale = 1;
      return success();
    }
    return failure();
  }
};

struct SubsumeLoopIntoDMA
: public OpInterfaceRewritePattern<AMDAIE::DoublyStridedOpInterface> {
using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
#include "iree-amd-aie/IR/AMDAIEOps.h"
#include "iree-amd-aie/aie_runtime/iree_aie_runtime.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
Expand All @@ -21,6 +23,60 @@ namespace mlir::iree_compiler::AMDAIE {
/// Utility to retrieve a constant index from an OpFoldResult.
int64_t getConstantIndexOrAssert(OpFoldResult dim);

/// Utility affine expression visitor to retrieve the scale and optional bias
/// from the expression.
///
/// Supported expression shapes (operands may appear in either order):
///   - `dim * constant`            -> `scale = constant`
///   - `dim + constant`            -> `scale = 1`, `bias = constant`
///   - `<binary expr> + constant`  -> `bias = constant`, then the binary
///                                    operand is visited recursively.
/// Every other expression makes the visit return failure. Note that on
/// failure `scale`/`bias` may already have been partially populated, so they
/// must only be read after a successful visit.
struct RetrieveScaleAndBias
    : public AffineExprVisitor<RetrieveScaleAndBias, LogicalResult> {
  /// Multiplicative constant applied to the dimension, if one was found.
  std::optional<int64_t> scale;
  /// Non-negative additive constant, if one was found.
  std::optional<int64_t> bias;

  /// Binary operations without a dedicated handler below are not supported.
  LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr /*expr*/) {
    return failure();
  }
  /// A bare constant carries no scale/bias information.
  LogicalResult visitConstantExpr(AffineConstantExpr /*expr*/) {
    return failure();
  }
  /// A bare dimension is rejected.
  LogicalResult visitDimExpr(AffineDimExpr /*expr*/) { return failure(); }
  /// Symbols are rejected.
  LogicalResult visitSymbolExpr(AffineSymbolExpr /*expr*/) { return failure(); }

  /// Handles `dim * constant` and `constant * dim` by recording the constant
  /// as `scale`. Any other multiplication is rejected.
  LogicalResult visitMulExpr(AffineBinaryOpExpr expr) {
    AffineExpr lhs = expr.getLHS();
    AffineExpr rhs = expr.getRHS();
    // The `dyn_cast` result must be checked explicitly: with an
    // `if (init; cond)` statement only `cond` is tested, so a non-constant
    // operand would otherwise be dereferenced as a null AffineConstantExpr.
    if (auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
        rhsConst && isa<AffineDimExpr>(lhs)) {
      scale = rhsConst.getValue();
      return success();
    }
    if (auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
        lhsConst && isa<AffineDimExpr>(rhs)) {
      scale = lhsConst.getValue();
      return success();
    }
    // Not a `dim * constant` product: no scale can be retrieved, so report
    // failure instead of succeeding with `scale` left unset.
    return failure();
  }

  /// Handles `<expr> + constant` and `constant + <expr>`. The constant becomes
  /// `bias` (negative values are rejected); `<expr>` must be either a plain
  /// dimension (`scale = 1`) or a binary expression that is visited in turn.
  LogicalResult visitAddExpr(AffineBinaryOpExpr expr) {
    // Only a single additive constant is supported; a second addition fails.
    if (bias) return failure();
    AffineExpr lhs = expr.getLHS();
    AffineExpr rhs = expr.getRHS();
    // Prefer the RHS as the constant operand, then fall back to the LHS.
    auto constant = dyn_cast<AffineConstantExpr>(rhs);
    AffineExpr other = lhs;
    if (!constant) {
      constant = dyn_cast<AffineConstantExpr>(lhs);
      other = rhs;
    }
    if (!constant) return failure();
    bias = constant.getValue();
    if (bias.value() < 0) return failure();
    if (isa<AffineBinaryOpExpr>(other)) return visit(other);
    if (isa<AffineDimExpr>(other)) {
      scale = 1;
      return success();
    }
    return failure();
  }
};

// Constant specifying the number of inter-iteration dimensions for DMA
// operations.
//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,63 @@ static CoreOp fetchUniqueCoreOp(DmaCpyNdOp &l2ToL1DmaOp) {
return coreOps[0];
}

/// "Less than" comparator for two L3->L2 DMA ops, comparing their source
/// offsets dimension by dimension (lexicographically).
///
/// For each dimension:
///   - if both offsets are constants, the constants are compared;
///   - if both offsets are dynamic, each is expected to be produced by an
///     `affine.apply` over the same operands, and the constant biases
///     extracted from the two affine maps are compared (a missing bias counts
///     as 0);
///   - a mix of constant and dynamic offsets is not supported (asserts).
///
/// Returns true if `a` orders strictly before `b`, false otherwise (including
/// when all compared offsets are equal).
///
/// NOTE: assumes `b` has at least as many source offsets as `a`; this is not
/// checked here and must be ensured by the caller.
static bool compareL3ToL2DmaPair(DmaCpyNdOp &a, DmaCpyNdOp &b) {
  SmallVector<OpFoldResult> sourceOffsetsA = a.getSourceMixedOffsets();
  SmallVector<OpFoldResult> sourceOffsetsB = b.getSourceMixedOffsets();
  for (int64_t i = 0, n = sourceOffsetsA.size(); i < n; i++) {
    std::optional<int64_t> offsetA = getConstantIntValue(sourceOffsetsA[i]);
    std::optional<int64_t> offsetB = getConstantIntValue(sourceOffsetsB[i]);
    if (offsetA && offsetB) {
      if (offsetA < offsetB) return true;
      if (offsetA > offsetB) return false;
      continue;
    }
    if (!offsetA && !offsetB) {
      auto offsetValA = cast<Value>(sourceOffsetsA[i]);
      auto offsetValB = cast<Value>(sourceOffsetsB[i]);
      auto affineApplyOpA = dyn_cast_if_present<affine::AffineApplyOp>(
          offsetValA.getDefiningOp());
      auto affineApplyOpB = dyn_cast_if_present<affine::AffineApplyOp>(
          offsetValB.getDefiningOp());
      // TODO(avarma): This should be handled better. The overall possibility
      // here already makes this complex enough.
      assert(affineApplyOpA && "expected affine.apply op");
      assert(affineApplyOpB && "expected affine.apply op");
      // Both offsets must be computed from the same base values, otherwise
      // comparing only the biases would be meaningless.
      for (auto &&[valA, valB] :
           llvm::zip_equal(affineApplyOpA.getMapOperands(),
                           affineApplyOpB.getMapOperands())) {
        assert((valA == valB) &&
               "different base values being operated on between the L3->L2 Dma "
               "op pair");
        // Only used by the assertion above; silence unused warnings when
        // assertions are compiled out.
        (void)valA;
        (void)valB;
      }
      AffineMap affineMapA = affineApplyOpA.getAffineMap();
      AffineMap affineMapB = affineApplyOpB.getAffineMap();
      RetrieveScaleAndBias retrieverA, retrieverB;
      // The visits must run outside of `assert`: with NDEBUG the asserted
      // expression is not evaluated at all, which would leave the biases
      // unset and make every dynamic offset compare as equal.
      LogicalResult visitedA = retrieverA.visit(affineMapA.getResult(0));
      assert(!failed(visitedA) && "failed to retrieve scale and bias");
      (void)visitedA;
      LogicalResult visitedB = retrieverB.visit(affineMapB.getResult(0));
      assert(!failed(visitedB) && "failed to retrieve scale and bias");
      (void)visitedB;
      // A missing bias is treated as 0.
      int64_t biasA = retrieverA.bias.value_or(0);
      int64_t biasB = retrieverB.bias.value_or(0);
      // TODO(avarma): We should also check the scale value as well.
      if (biasA < biasB) return true;
      if (biasA > biasB) return false;
      continue;
    }
    // Exactly one of the two offsets is a constant. When assertions are
    // disabled this dimension is skipped and the comparison moves on.
    assert(false &&
           "unexpected combination of offset val amongst L3->L2 Dma pair");
  }
  return false;
}

LogicalResult combineLogicalObjectFifos(
IRRewriter &rewriter, SmallVector<AMDAIE::DmaCpyNdOp> &l2ToL1DmaOps,
MLIRContext *context) {
Expand Down Expand Up @@ -568,6 +625,8 @@ LogicalResult combineLogicalObjectFifos(
SmallVector<OpFoldResult> sourceSizes =
l3ToL2DmaOps[i].getSourceMixedSizes();
unsigned j = 0, m = sourceOffsets.size();
// Traverse through the i-th L3->L2 Dma op's source offset/size to find a
// continuous sequence of 0 offset dims with size as 1.
while (j < m) {
std::optional<int64_t> constantOffset =
getConstantIntValue(sourceOffsets[j]);
Expand All @@ -593,6 +652,26 @@ LogicalResult combineLogicalObjectFifos(
SmallVector<int64_t> nonSplitDims(maxSplitDimIndex + 1);
std::iota(nonSplitDims.begin(), nonSplitDims.end(), splitDims.size());

// Sort the L3->L2 Dma ops based on their "overlapping" source offsets (using
// an insertion sort), and apply the same reordering to the corresponding
// L2->L1 Dma ops so that both lists stay paired index by index.
for (int64_t i = 1, n = l3ToL2DmaOps.size(); i < n; i++) {
DmaCpyNdOp currL3ToL2DmaOp = l3ToL2DmaOps[i];
DmaCpyNdOp currL2ToL1DmaOp = l2ToL1DmaOps[i];
int64_t j = i - 1;
while (j >= 0 && compareL3ToL2DmaPair(currL3ToL2DmaOp, l3ToL2DmaOps[j])) {
l3ToL2DmaOps[j + 1] = l3ToL2DmaOps[j];
l2ToL1DmaOps[j + 1] = l2ToL1DmaOps[j];
j--;
}
l3ToL2DmaOps[j + 1] = currL3ToL2DmaOp;
l2ToL1DmaOps[j + 1] = currL2ToL1DmaOp;
}

for (auto x : l3ToL2DmaOps) {
llvm::outs() << "===> " << x << "\n";
llvm::outs().flush();
}
// For now pick the first two L3->L2 Dma op and try to combine them. Later
// we'll implement the selector.
////////////////////////////////////////////////
Expand All @@ -612,13 +691,7 @@ LogicalResult combineLogicalObjectFifos(
/////// COMBINE the picked L3->L2 pair /////////
////////////////////////////////////////////////
{
/// The maximum number of addressing dimensions on the source side of the
/// DMA.
// int64_t sourceMaxNbDims{0};
// /// The maximum number of addressing dimensions on the target side of
// the DMA. int64_t targetMaxNbDims{0};
OpBuilder::InsertionGuard guard(rewriter);
// rewriter.setInsertionPoint(op);
SmallVector<OpFoldResult> sourceOffsetsA = op.getSourceMixedOffsets();
SmallVector<OpFoldResult> sourceSizesA = op.getSourceMixedSizes();
SmallVector<OpFoldResult> sourceStridesA = op.getSourceMixedStrides();
Expand Down Expand Up @@ -759,8 +832,6 @@ LogicalResult combineLogicalObjectFifos(
}
});
}
// llvm::outs() << "NOT Compatible\n";
// llvm::outs().flush();
}
}

Expand Down

0 comments on commit c5d3d0c

Please sign in to comment.