Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin' into disc_remat
Browse files Browse the repository at this point in the history
  • Loading branch information
eedalong committed Aug 13, 2024
2 parents c05a979 + 72eec1c commit 25cbf46
Show file tree
Hide file tree
Showing 4 changed files with 1,432 additions and 115 deletions.
49 changes: 47 additions & 2 deletions tao_compiler/mlir/disc/transforms/disc_argsmutation_expand.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "mlir/disc/IR/disc_shape_ops.h"
#include "mlir/disc/IR/lhlo_disc_ops.h"
#include "mlir/disc/disc_util.h"
#include "mlir/disc/transforms/PassDetail.h"
#include "mlir/disc/transforms/placement_utils.h"
#include "mlir/disc/transforms/rewriters.h"
Expand All @@ -67,8 +68,52 @@ struct LhloDISCArgsMutationOpConverter
PatternRewriter& rewriter) const override {
auto op = lhloOp.getOperation();
auto operands = op->getOperands();
operands[0].replaceAllUsesWith(operands[1]);
rewriter.eraseOp(op);
if (operands[0] == operands[1]) {
rewriter.eraseOp(op);
return success();
}

if (operands[0].getType().cast<MemRefType>() ==
operands[1].getType().cast<MemRefType>()) {
for (auto user : operands[1].getUsers()) {
// Prevent double dealloc
if (isa<memref::DeallocOp>(user)) {
rewriter.eraseOp(user);
}
}
operands[0].replaceAllUsesWith(operands[1]);
rewriter.eraseOp(op);
// rewriter.eraseOp(operands[0].getDefiningOp<memref::AllocOp>());
} else {
llvm::dbgs() << "Reinterprete cast need to be inserted here between "
<< operands[0] << " and " << operands[1] << "\n";

for (auto user : operands[1].getUsers()) {
if (isa<memref::DeallocOp>(user)) {
rewriter.eraseOp(user);
}
}

auto shape_a = operands[0].getType().cast<MemRefType>().getShape();
auto alloc_a = operands[0].getDefiningOp<memref::AllocOp>();
SmallVector<Value> dimSizes;
int dynamic_dim_idx = 0;
for (int i = 0; i < shape_a.size(); ++i) {
if (shape_a[i] == ShapedType::kDynamic) {
dimSizes.push_back(alloc_a->getOperand(dynamic_dim_idx++));
} else {
dimSizes.push_back(rewriter.create<arith::ConstantIndexOp>(
op->getLoc(), shape_a[i]));
}
}

Value newValue = disc_ral::CastMemRefTo(
rewriter, op->getLoc(), operands[1],
operands[0].getType().cast<MemRefType>(), dimSizes);
operands[0].replaceAllUsesWith(newValue);
rewriter.eraseOp(op);
}

return success();
}
};
Expand Down
Loading

0 comments on commit 25cbf46

Please sign in to comment.