diff --git a/tao_compiler/mlir/disc/transforms/mhlo_decomp_rewriters.cc b/tao_compiler/mlir/disc/transforms/mhlo_decomp_rewriters.cc
index 6c717f611ec..4a2132c0de1 100644
--- a/tao_compiler/mlir/disc/transforms/mhlo_decomp_rewriters.cc
+++ b/tao_compiler/mlir/disc/transforms/mhlo_decomp_rewriters.cc
@@ -132,6 +132,95 @@ LogicalResult SliceOpConvert::matchAndRewrite(mhlo::SliceOp op,
       op, op.getType(), operand, startIndices, limitIndices, strides);
   return success();
 }
+
+struct ScatterOpConverter : public OpRewritePattern<mhlo::ScatterOp> {
+ public:
+  using OpRewritePattern<mhlo::ScatterOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(mhlo::ScatterOp hloOp,
+                                PatternRewriter& rewriter) const override {
+    Operation* op = hloOp.getOperation();
+    auto operands = hloOp.getOperands();
+
+    if (!operands[0]
+             .getType()
+             .cast<RankedTensorType>()
+             .getElementType()
+             .isBF16()) {
+      return failure();
+    }
+
+    auto& body = hloOp.getUpdateComputation();
+    Operation* calcOp = nullptr;
+    body.walk([&](Operation* op) {
+      if (llvm::isa<mhlo::AddOp, mhlo::MulOp, mhlo::MaxOp, mhlo::MinOp>(op))
+        calcOp = op;
+    });
+    if (!calcOp) {
+      return failure();
+    }
+
+    // Insert type conversions (bf16 -> f32) for the input and the updates.
+    auto selfBf16Type = operands[0].getType().cast<RankedTensorType>();
+    auto selfFp32Type = RankedTensorType::get(selfBf16Type.getShape().vec(),
+                                              rewriter.getF32Type());
+
+    auto srcBf16Type = operands[2].getType().cast<RankedTensorType>();
+    auto srcFp32Type = RankedTensorType::get(srcBf16Type.getShape().vec(),
+                                             rewriter.getF32Type());
+
+    auto self_convert_type = rewriter.create<mhlo::ConvertOp>(
+        hloOp.getLoc(), selfFp32Type, hloOp.getOperand(0));
+    auto src_convert_type = rewriter.create<mhlo::ConvertOp>(
+        hloOp.getLoc(), srcFp32Type, hloOp.getOperand(2));
+
+    SmallVector<Value> newOperands;
+    newOperands.push_back(self_convert_type->getResult(0));
+    newOperands.push_back(operands[1]);
+    newOperands.push_back(src_convert_type->getResult(0));
+
+    auto scatterOp = rewriter.create<mhlo::ScatterOp>(
+        hloOp.getLoc(), selfFp32Type, newOperands, hloOp->getAttrs());
+
+    auto res_convert_back = rewriter.create<mhlo::ConvertOp>(
+        hloOp.getLoc(), selfBf16Type, scatterOp->getResult(0));
+
+    // Construct the updateComputation region as a single update op.
+    Block& block = scatterOp.getUpdateComputation().emplaceBlock();
+    auto blockArgType =
+        RankedTensorType::get({}, selfFp32Type.getElementType());
+    block.addArgument(blockArgType, hloOp.getLoc());
+    block.addArgument(blockArgType, hloOp.getLoc());
+    {
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToStart(&block);
+      auto createOp = [&](auto opType) {
+        return rewriter
+            .create<decltype(opType)>(op->getLoc(), block.getArgument(0),
+                                      block.getArgument(1))
+            .getResult();
+      };
+
+      Value retValue;
+      if (llvm::isa<mhlo::AddOp>(calcOp)) {
+        retValue = createOp(mhlo::AddOp());
+      } else if (llvm::isa<mhlo::MulOp>(calcOp)) {
+        retValue = createOp(mhlo::MulOp());
+      } else if (llvm::isa<mhlo::MaxOp>(calcOp)) {
+        retValue = createOp(mhlo::MaxOp());
+      } else if (llvm::isa<mhlo::MinOp>(calcOp)) {
+        retValue = createOp(mhlo::MinOp());
+      } else {
+        // Should not happen: calcOp was matched against these ops above.
+        retValue = block.getArgument(1);
+      }
+      rewriter.create<mhlo::ReturnOp>(op->getLoc(), retValue);
+    }
+
+    rewriter.replaceOp(hloOp, res_convert_back->getResult(0));
+
+    return success();
+  }
+};
 }  // namespace
 
 struct MhloDecompositionRewriterPass
@@ -148,6 +237,7 @@ struct MhloDecompositionRewriterPass
     patterns.insert<BatchNormInferenceOpConvert>(ctx);
     patterns.insert<BatchNormTrainingOpConvert>(ctx);
     patterns.insert<SliceOpConvert>(ctx);
+    patterns.insert<ScatterOpConverter>(ctx);
     if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) {
       func.emitError("applyPatternsAndFoldGreedily does not converge");
       signalPassFailure();
diff --git a/tao_compiler/mlir/disc/transforms/tests/mhlo_decomp_rewriter.mlir b/tao_compiler/mlir/disc/transforms/tests/mhlo_decomp_rewriter.mlir
old mode 100644
new mode 100755
index 5768a2f224b..4229d957d6b
--- a/tao_compiler/mlir/disc/transforms/tests/mhlo_decomp_rewriter.mlir
+++ b/tao_compiler/mlir/disc/transforms/tests/mhlo_decomp_rewriter.mlir
@@ -35,3 +35,21 @@ func.func @batch_norm_inference(%arg0: tensor<?x128xf32>, %arg1: tensor<128x
   %0 = "mhlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x128xf32>, tensor<128xf32>, tensor<128xf32>, tensor<128xf32>, tensor<128xf32>) -> tensor<?x128xf32>
   return %0: tensor<?x128xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @bf16_scatter
+func.func @bf16_scatter(%arg0: tensor<2x3xbf16>, %arg1: tensor<2x2xi64>, %arg2: tensor<2xbf16>) -> tensor<2x3xbf16> {
+  // CHECK: %0 = mhlo.convert %arg0 : (tensor<2x3xbf16>) -> tensor<2x3xf32>
+  // CHECK: %1 = mhlo.convert %arg2 : (tensor<2xbf16>) -> tensor<2xf32>
+  %1 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
+  ^bb0(%arg6: tensor<bf16>, %arg7: tensor<bf16>):
+    %15 = "mhlo.convert"(%arg6) : (tensor<bf16>) -> tensor<f32>
+    %16 = "mhlo.convert"(%arg7) : (tensor<bf16>) -> tensor<f32>
+    %17 = "mhlo.add"(%15, %16) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    %18 = "mhlo.convert"(%17) : (tensor<f32>) -> tensor<bf16>
+    mhlo.return %18 : tensor<bf16>
+  }) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>, unique_indices = false} : (tensor<2x3xbf16>, tensor<2x2xi64>, tensor<2xbf16>) -> tensor<2x3xbf16>
+  // CHECK: %3 = mhlo.convert %2 : (tensor<2x3xf32>) -> tensor<2x3xbf16>
+  return %1 : tensor<2x3xbf16>
+}
\ No newline at end of file
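
For reference, running the pass over the @bf16_scatter test above should produce IR roughly like the following (a sketch assembled from the new pattern's logic and the CHECK lines; the test only checks the three converts, so the SSA names %2/%4 and the region argument names are assumptions, not verified FileCheck output):

  %0 = mhlo.convert %arg0 : (tensor<2x3xbf16>) -> tensor<2x3xf32>
  %1 = mhlo.convert %arg2 : (tensor<2xbf16>) -> tensor<2xf32>
  %2 = "mhlo.scatter"(%0, %arg1, %1) ({
  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
    %4 = mhlo.add %arg3, %arg4 : tensor<f32>
    mhlo.return %4 : tensor<f32>
  }) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>, unique_indices = false} : (tensor<2x3xf32>, tensor<2x2xi64>, tensor<2xf32>) -> tensor<2x3xf32>
  %3 = mhlo.convert %2 : (tensor<2x3xf32>) -> tensor<2x3xbf16>
  return %3 : tensor<2x3xbf16>

Note that the scatter attributes are carried over verbatim (the pattern forwards hloOp->getAttrs()), and the update region now adds the two f32 scalars directly: the convert/add/convert round-trip through bf16 inside the original region disappears, since the operands are widened to f32 once on entry and the result is narrowed back to bf16 once at the end.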