Skip to content

Commit

Permalink
simplifier scalar tensor shape to scalar i32 (#1310)
Browse files Browse the repository at this point in the history
* simplifier scalar tensor shape to scalar i32

* fix ut

* update
  • Loading branch information
Yancey1989 authored Aug 6, 2024
1 parent 8f21582 commit fbe39bc
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 3 deletions.
95 changes: 93 additions & 2 deletions tao_compiler/mlir/disc/transforms/disc_algebraic_simplifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <stack>

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"
Expand Down Expand Up @@ -423,7 +425,6 @@ struct BroadCastInDimOfReshapeOpCanonicalizationPattern
return success();
}
};

// Simplifier extract and from-element op pattern, an example as following:
// %0 = tensor.extract %arg0[] : tensor<f32>
// %1 = tensor.from_elements %0 : tensor<1xf32>
Expand All @@ -439,6 +440,8 @@ struct SimplifierFromElementsPattern
auto loc = op->getLoc();
Value input = op->getOperand(0);
Value result = op->getResult(0);
// only support scalar tensor
if (op->getNumOperands() != 1) return failure();
auto extractOp = input.getDefiningOp<tensor::ExtractOp>();
if (!extractOp) return failure();

Expand Down Expand Up @@ -530,6 +533,93 @@ struct IndexCastSimplifierPattern
return failure();
}
};
// Simplify get_dimension_size patterns. Examples:
// Case 1):
//   %2 = "mhlo.get_dimension_size"(%1)
//   %3 = tensor.extract %2[] : tensor<i32>
//   %from_elements = tensor.from_elements %3, ...
// Convert to:
//   %d = tensor.dim %1, %c0 : index
//   %2 = arith.index_cast %d : index to i32
//   %from_elements = tensor.from_elements %2, ...
//
// Case 2):
//   %2 = "mhlo.get_dimension_size"(%1)
//   %3 = mhlo.multiply %2, %cst
//   %4 = tensor.extract %3[] : tensor<i32>
//   %from_elements = tensor.from_elements %4, ...
// Convert to:
//   %d = tensor.dim %1, %c0 : index
//   %2 = arith.index_cast %d : index to i32
//   %3 = arith.muli %2, %extracted_cst : i32
//   %from_elements = tensor.from_elements %3, ...

struct SimplifierGetDimensionSizePattern
    : public OpRewritePattern<mhlo::GetDimensionSizeOp> {
  using OpRewritePattern<mhlo::GetDimensionSizeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(mhlo::GetDimensionSizeOp getDimOp,
                                PatternRewriter& rewriter) const override {
    auto loc = getDimOp->getLoc();
    Value tensor = getDimOp->getOperand(0);
    auto dim = getDimOp.getDimension();
    auto resultTy = getDimOp.getResult().getType().cast<RankedTensorType>();
    auto elemTy = resultTy.getElementType();

    // Users that terminate the traversal: they either already consume the
    // scalar tensor directly (tensor.extract) or keep the tensor form
    // (reshape / broadcast / return) and need no rewriting here.
    auto isBoundaryUser = [](Operation* user) {
      return isa<tensor::ExtractOp, func::ReturnOp, mhlo::ReshapeOp,
                 mhlo::DynamicBroadcastInDimOp>(user);
    };

    // Phase 1: collect the transitive non-boundary users of the dimension
    // size and validate them *before* touching the IR -- matchAndRewrite
    // must leave the IR unmodified when it returns failure. The visited
    // set guards against collecting (and later rewriting) the same op
    // twice when it is reachable through several def-use paths.
    SmallVector<Operation*, 4> ops;
    DenseSet<Operation*> visited;
    std::stack<Operation*> stack;
    for (auto user : getDimOp->getUsers()) {
      if (isBoundaryUser(user)) continue;
      if (visited.insert(user).second) stack.push(user);
    }
    while (!stack.empty()) {
      auto user = stack.top();
      stack.pop();
      // Only scalar elementwise binary arithmetic can be lowered below.
      if (!isa<mhlo::MulOp, mhlo::AddOp, mhlo::SubtractOp, mhlo::DivOp>(
              user) ||
          user->getNumOperands() != 2)
        return failure();
      ops.push_back(user);
      for (auto op : user->getUsers()) {
        if (isBoundaryUser(op)) continue;
        if (visited.insert(op).second) stack.push(op);
      }
    }

    // Phase 2: rewrite each collected mhlo binary op into
    // extract + arith op + from_elements, so downstream tensor.extract
    // users can fold away.
    for (auto op : ops) {
      auto opLoc = op->getLoc();
      rewriter.setInsertionPoint(op);
      Value lhs = rewriter.create<tensor::ExtractOp>(opLoc, op->getOperand(0));
      Value rhs = rewriter.create<tensor::ExtractOp>(opLoc, op->getOperand(1));
      Value scalar;
      if (isa<mhlo::MulOp>(op)) {
        scalar = rewriter.create<arith::MulIOp>(opLoc, lhs, rhs).getResult();
      } else if (isa<mhlo::AddOp>(op)) {
        scalar = rewriter.create<arith::AddIOp>(opLoc, lhs, rhs).getResult();
      } else if (isa<mhlo::SubtractOp>(op)) {
        scalar = rewriter.create<arith::SubIOp>(opLoc, lhs, rhs).getResult();
      } else {
        // mhlo::DivOp -- guaranteed by the validation in phase 1.
        scalar = rewriter.create<arith::DivSIOp>(opLoc, lhs, rhs).getResult();
      }
      // Re-pack the scalar into the original tensor type so remaining
      // tensor-typed users keep type-checking. Use the rewriter so the op
      // is erased and the driver notified; a bare replaceAllUsesWith would
      // leave dead mhlo ops behind and bypass rewriter bookkeeping.
      auto packed =
          rewriter.create<tensor::FromElementsOp>(opLoc, resultTy, scalar);
      rewriter.replaceOp(op, packed.getResult());
    }

    // Finally replace get_dimension_size itself with
    // tensor.dim + index_cast packed back into the scalar tensor. Erasing
    // getDimOp (instead of only rewiring its uses) also prevents the
    // greedy driver from endlessly re-matching the now-dead op.
    rewriter.setInsertionPoint(getDimOp);
    Value dimValue = rewriter.create<tensor::DimOp>(loc, tensor, dim);
    Value castValue =
        rewriter.create<arith::IndexCastOp>(loc, elemTy, dimValue);
    rewriter.replaceOpWithNewOp<tensor::FromElementsOp>(getDimOp, resultTy,
                                                        ValueRange{castValue});
    return success();
  }
};

// Consant folding the broadcasted constant, for patterns like:
// %0 = mhlo.constant // Scalar or splat constant
Expand Down Expand Up @@ -627,7 +717,8 @@ void populateDiscAlgebraicSimplifierPatterns(RewritePatternSet& patterns) {
IdentityBroadCastInDimOpCanonicalizationPattern<mhlo::DynamicBroadcastInDimOp>,
SimplifierFromElementsPattern,
TrunciSimplifierPattern,
IndexCastSimplifierPattern
IndexCastSimplifierPattern,
SimplifierGetDimensionSizePattern
>(patterns.getContext());
if (isMemIntensiveOptExperimentalEnabled()) {
// Will be enabled by default after a set of robustness testing.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,4 +256,21 @@ func.func @select_simp(%arg0: tensor<16xf16>) -> (tensor<20xf16>, tensor<20xf16>
%10 = mhlo.constant dense<false> : tensor<20xi1>
%11 = "mhlo.select"(%10, %7, %3): (tensor<20xi1>, tensor<20xf16>, tensor<20xf16>) -> tensor<20xf16>
return %8, %3 : tensor<20xf16>, tensor<20xf16>
}
}

// -----

// CHECK-LABEL: @main
// Regression test for SimplifierGetDimensionSizePattern: a
// mhlo.get_dimension_size whose result flows through a scalar
// mhlo.multiply into tensor.extract / tensor.from_elements should be
// lowered to tensor.dim + arith.index_cast + arith.muli on scalar i32,
// removing the intermediate tensor<i32> round-trips.
func.func @main(%arg0: tensor<?x10xf32>, %arg1: tensor<10xf32>) -> tensor<?x10xf32> {
  %c10_i32 = arith.constant 10 : i32
  %c_0 = mhlo.constant dense<4> : tensor<i32>
  // CHECK: %dim = tensor.dim %arg0, %c0 : tensor<?x10xf32>
  // CHECK: %0 = arith.index_cast %dim : index to i32
  %2 = "mhlo.get_dimension_size"(%arg0) {dimension = 0 : i64} : (tensor<?x10xf32>) -> tensor<i32>
  // CHECK: %1 = arith.muli %0, %c4_i32 : i32
  %3 = mhlo.multiply %2, %c_0 : tensor<i32>
  %extracted = tensor.extract %3[] : tensor<i32>
  %from_elements = tensor.from_elements %extracted, %c10_i32 : tensor<2xi32>
  %4 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %from_elements) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<10xf32>, tensor<2xi32>) -> tensor<?x10xf32>
  return %4 : tensor<?x10xf32>
}

0 comments on commit fbe39bc

Please sign in to comment.