Skip to content

Commit

Permalink
always try to lower standalone transposeOp to custom call
Browse files Browse the repository at this point in the history
  • Loading branch information
eedalong committed Feb 23, 2024
1 parent 28960f0 commit 99cc4f4
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 72 deletions.
48 changes: 4 additions & 44 deletions tao_compiler/mlir/disc/transforms/disc_input_output_alias.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -141,57 +141,17 @@ struct DiscInputOutputAliasPass
if (outputs[outputs_index[i]] == params[params_index[i]]) {
continue;
}
// Inplace buffer reuse.
bool inplace_reuse = false;
// DISC currently supports only one-hop buffer sharing.
auto defineOp = outputs[outputs_index[i]].getDefiningOp();
for (const auto& value : defineOp->getOperands()) {
if (params[params_index[i]] == value) {
builder.setInsertionPointAfterValue(outputs[outputs_index[i]]);
builder.create<mhlo_disc::ArgsMutationOp>(
outputs[outputs_index[i]].getLoc(), outputs[outputs_index[i]],
params[params_index[i]]);
inplace_reuse = true;
builder.create<mhlo_disc::ArgsMutationOp>(main_func.getLoc(),
outputs[outputs_index[i]],
params[params_index[i]]);
break;
}
}

// Try one-hop buffer sharing propagation.
if (!inplace_reuse) {
OneHopBufferReusePropogation(params[params_index[i]],
outputs[outputs_index[i]], builder);
}
}
}

private:
/*
 * One-hop buffer-reuse propagation from `src` to `dst`.
 *
 * Applies when `src` has exactly one use, that user `A` produces exactly one
 * result, and that result is an operand of the op defining `dst` (`B`):
 *
 *   A = op(src)          A = op(src)
 *   B = op(A)     =>     args_mutation(A, src)
 *                        B = op(A)
 *                        args_mutation(B, A)
 *
 * In every other case (no use or several uses of `src`, a user with zero or
 * multiple results, `dst` without a defining op, or `A` not feeding `B`)
 * the function returns without creating any op.
 */
void OneHopBufferReusePropogation(Value src, Value dst, OpBuilder& builder) {
  // getDefiningOp() yields null when `dst` is not produced by an operation.
  auto* dst_op = dst.getDefiningOp();
  if (dst_op == nullptr) {
    return;
  }

  // Require exactly one use. Checking this first also guarantees that
  // dereferencing user_begin() below is valid (it would be the end iterator
  // for a value without uses).
  if (!src.hasOneUse()) {
    return;
  }

  // The single user must yield exactly one result so getResult(0) is valid
  // and unambiguous.
  auto* user = *src.user_begin();
  if (user->getNumResults() != 1) {
    return;
  }

  auto user_result = user->getResult(0);
  for (const auto& operand : dst_op->getOperands()) {
    if (operand == user_result) {
      // Mark the intermediate result as sharing the buffer of `src` ...
      builder.setInsertionPointAfterValue(user_result);
      builder.create<mhlo_disc::ArgsMutationOp>(user_result.getLoc(),
                                                user_result, src);

      // ... and `dst` as sharing the buffer of the intermediate result.
      builder.setInsertionPointAfterValue(dst);
      builder.create<mhlo_disc::ArgsMutationOp>(dst.getLoc(), dst,
                                                user_result);
      break;
    }
  }
}
};
Expand Down
5 changes: 2 additions & 3 deletions tao_compiler/mlir/disc/transforms/disc_lower_to_library_call.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ struct TransposeConverter : public OpRewritePattern<lmhlo::TransposeOp> {
if (rank != 2 && rank != 3) return failure();
// Only rewrite to the custom library call when dimensions 1 and 2 of
// a 3D tensor are swapped, i.e. permutation = [0, 2, 1].
if (rank == 3 && permutation[1] != 2 && permutation[2] != 1)
if (rank == 3 && (permutation[1] != 2 || permutation[2] != 1))
return failure();
bool on_gpu = placement_utils::isGpuMemRef(op->getOperand(0));
// TODO: support other device
Expand Down Expand Up @@ -914,8 +914,7 @@ struct DiscLowerToLibraryCallPass
SendOutputOpConvertor
>(context);
// clang-format on
if (enableTransposeLibraryCall())
patterns.insert<TransposeConverter>(context);
patterns.insert<TransposeConverter>(context);

// GPU copy related ops
patterns.insert<GpuCopyOpConvertor<H2DOp>>(context, "h2d");
Expand Down

This file was deleted.

This file was deleted.

0 comments on commit 99cc4f4

Please sign in to comment.