Support bf16 ScatterOp (#1299)
Support bf16 ScatterOp: convert the bf16 input and updates to f32, perform the scatter in f32, and convert the result back to bf16.
eedalong authored Jun 4, 2024
1 parent 6a5228e commit 800c697
Showing 2 changed files with 108 additions and 0 deletions.
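In short, the new pattern brackets an f32 scatter with converts. A minimal before/after sketch on the shapes used in the test below (value names are illustrative and the scatter attributes are elided; this is not IR taken verbatim from the commit):

// Before: the update computation combines bf16 scalars.
%r = "mhlo.scatter"(%self, %indices, %updates) ({
^bb0(%a: tensor<bf16>, %b: tensor<bf16>):
  %s = mhlo.add %a, %b : tensor<bf16>
  mhlo.return %s : tensor<bf16>
}) {...} : (tensor<2x3xbf16>, tensor<2x2xi64>, tensor<2xbf16>) -> tensor<2x3xbf16>

// After: widen to f32, scatter in f32, narrow the result back to bf16.
%self_f32 = mhlo.convert %self : (tensor<2x3xbf16>) -> tensor<2x3xf32>
%updates_f32 = mhlo.convert %updates : (tensor<2xbf16>) -> tensor<2xf32>
%r_f32 = "mhlo.scatter"(%self_f32, %indices, %updates_f32) ({
^bb0(%a: tensor<f32>, %b: tensor<f32>):
  %s = mhlo.add %a, %b : tensor<f32>
  mhlo.return %s : tensor<f32>
}) {...} : (tensor<2x3xf32>, tensor<2x2xi64>, tensor<2xf32>) -> tensor<2x3xf32>
%r = mhlo.convert %r_f32 : (tensor<2x3xf32>) -> tensor<2x3xbf16>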
90 changes: 90 additions & 0 deletions tao_compiler/mlir/disc/transforms/mhlo_decomp_rewriters.cc
@@ -132,6 +132,95 @@ LogicalResult SliceOpConvert::matchAndRewrite(mhlo::SliceOp op,
op, op.getType(), operand, startIndices, limitIndices, strides);
return success();
}

// Rewrites a bf16 mhlo.scatter whose update computation contains a single
// binary combine op (add/mul/max/min): the operands are converted to f32,
// the scatter is re-created in f32, and the result is converted back to bf16.
struct ScatterOpConverter : public OpRewritePattern<mhlo::ScatterOp> {
public:
using OpRewritePattern<mhlo::ScatterOp>::OpRewritePattern;
LogicalResult matchAndRewrite(mhlo::ScatterOp hloOp,
PatternRewriter& rewriter) const override {
Operation* op = hloOp.getOperation();
auto operands = hloOp.getOperands();

if (!operands[0]
.getType()
.cast<RankedTensorType>()
.getElementType()
.isBF16()) {
return failure();
}

    // Find the binary combine op in the update region; only scatters that
    // combine with add/mul/max/min are rewritten.
    auto& body = hloOp.getUpdateComputation();
    Operation* calcOp = nullptr;
    body.walk([&](Operation* innerOp) {
      if (llvm::isa<mhlo::AddOp, mhlo::MulOp, mhlo::MaxOp, mhlo::MinOp>(
              innerOp))
        calcOp = innerOp;
    });
if (!calcOp) {
return failure();
}

    // Insert type conversions: widen the input (operand 0) and the updates
    // (operand 2) to f32, scatter in f32, then narrow back to bf16.
    auto selfBf16Type = operands[0].getType().cast<RankedTensorType>();
    auto selfFp32Type =
        RankedTensorType::get(selfBf16Type.getShape(), rewriter.getF32Type());

    auto srcBf16Type = operands[2].getType().cast<RankedTensorType>();
    auto srcFp32Type =
        RankedTensorType::get(srcBf16Type.getShape(), rewriter.getF32Type());

    auto selfToF32 = rewriter.create<mhlo::ConvertOp>(
        hloOp.getLoc(), selfFp32Type, hloOp.getOperand(0));
    auto srcToF32 = rewriter.create<mhlo::ConvertOp>(
        hloOp.getLoc(), srcFp32Type, hloOp.getOperand(2));

    SmallVector<Value> newOperands;
    newOperands.push_back(selfToF32->getResult(0));
    newOperands.push_back(operands[1]);
    newOperands.push_back(srcToF32->getResult(0));

auto scatterOp = rewriter.create<mhlo::ScatterOp>(
hloOp.getLoc(), selfFp32Type, newOperands, hloOp->getAttrs());

    auto resToBf16 = rewriter.create<mhlo::ConvertOp>(
        hloOp.getLoc(), selfBf16Type, scatterOp->getResult(0));

    // Rebuild the update computation region: one block with two scalar f32
    // arguments, combined by the op kind matched above.
Block& block = scatterOp.getUpdateComputation().emplaceBlock();
auto blockArgType =
RankedTensorType::get({}, selfFp32Type.getElementType());
block.addArgument(blockArgType, hloOp.getLoc());
block.addArgument(blockArgType, hloOp.getLoc());
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
      // `opType` is only a type tag: decltype(opType) names the concrete
      // mhlo op class, so one lambda can create any of the four binary ops.
      auto createOp = [&](auto opType) {
return rewriter
.create<decltype(opType)>(op->getLoc(), block.getArgument(0),
block.getArgument(1))
.getResult();
};

Value retValue;
if (llvm::isa<mhlo::AddOp>(calcOp)) {
retValue = createOp(mhlo::AddOp());
} else if (llvm::isa<mhlo::MulOp>(calcOp)) {
retValue = createOp(mhlo::MulOp());
} else if (llvm::isa<mhlo::MaxOp>(calcOp)) {
retValue = createOp(mhlo::MaxOp());
} else if (llvm::isa<mhlo::MinOp>(calcOp)) {
retValue = createOp(mhlo::MinOp());
} else {
        // Unreachable: calcOp was matched as one of the four ops above;
        // fall back to overwriting with the update value.
retValue = block.getArgument(1);
}
rewriter.create<mhlo::ReturnOp>(op->getLoc(), retValue);
}

    rewriter.replaceOp(hloOp, resToBf16->getResult(0));

return success();
}
};
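// Note: the pattern assumes the canonical non-variadic scatter form
// (operands[0] = input, operands[1] = indices, operands[2] = updates, one
// result); variadic scatters and non-bf16 inputs are left unchanged.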
} // namespace

@@ -148,6 +237,7 @@ struct MhloDecompositionRewriterPass
patterns.insert<BatchNormInferenceOpConvert>(ctx);
patterns.insert<PadOpConvert>(ctx);
patterns.insert<SliceOpConvert>(ctx);
patterns.insert<ScatterOpConverter>(ctx);
if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) {
func.emitError("applyPatternsAndFoldGreedily does not converge");
signalPassFailure();
18 changes: 18 additions & 0 deletions tao_compiler/mlir/disc/transforms/tests/mhlo_decomp_rewriter.mlir
100644 → 100755
@@ -35,3 +35,21 @@ func.func @batch_norm_inference(%arg0: tensor<?x128x?x?xf32>, %arg1: tensor<128x
%0 = "mhlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x128x?x?xf32>, tensor<128xf32>, tensor<128xf32>, tensor<128xf32>, tensor<128xf32>) -> tensor<?x128x?x?xf32>
return %0: tensor<?x128x?x?xf32>
}

// -----

// CHECK-LABEL: @bf16_scatter
func.func @bf16_scatter(%arg0: tensor<2x3xbf16>, %arg1: tensor<2x2xi64>, %arg2: tensor<2xbf16>) -> tensor<2x3xbf16> {
// CHECK: %0 = mhlo.convert %arg0 : (tensor<2x3xbf16>) -> tensor<2x3xf32>
// CHECK: %1 = mhlo.convert %arg2 : (tensor<2xbf16>) -> tensor<2xf32>
%1 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
^bb0(%arg6: tensor<bf16>, %arg7: tensor<bf16>):
%15 = "mhlo.convert"(%arg6) : (tensor<bf16>) -> tensor<f32>
%16 = "mhlo.convert"(%arg7) : (tensor<bf16>) -> tensor<f32>
%17 = "mhlo.add"(%15, %16) : (tensor<f32>, tensor<f32>) -> tensor<f32>
%18 = "mhlo.convert"(%17) : (tensor<f32>) -> tensor<bf16>
mhlo.return %18 : tensor<bf16>
}) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>, unique_indices = false} : (tensor<2x3xbf16>, tensor<2x2xi64>, tensor<2xbf16>) -> tensor<2x3xbf16>
// CHECK: %3 = mhlo.convert %2 : (tensor<2x3xf32>) -> tensor<2x3xbf16>
return %1 : tensor<2x3xbf16>
}
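Between the two checked converts, %2 is the re-created f32 scatter, which the test does not CHECK explicitly. Under the pattern above it should look roughly like this (a sketch with the attributes carried over from the input op, not part of the committed test):

%2 = "mhlo.scatter"(%0, %arg1, %1) ({
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
  %4 = mhlo.add %lhs, %rhs : tensor<f32>
  mhlo.return %4 : tensor<f32>
}) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>, unique_indices = false} : (tensor<2x3xf32>, tensor<2x2xi64>, tensor<2xf32>) -> tensor<2x3xf32>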
