diff --git a/pytorch_blade/pytorch_blade/compiler/jit/torch/shape_analysis.cpp b/pytorch_blade/pytorch_blade/compiler/jit/torch/shape_analysis.cpp
index 8cb6d2ed963..43c09bedd51 100644
--- a/pytorch_blade/pytorch_blade/compiler/jit/torch/shape_analysis.cpp
+++ b/pytorch_blade/pytorch_blade/compiler/jit/torch/shape_analysis.cpp
@@ -955,6 +955,8 @@ class ShapePropagator : public PropertyPropBase {
           "aten::index_put.hacked_twin(Tensor self, Tensor[] indices, Tensor values, bool accumulate=False) -> Tensor",
           "aten::index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!)",
           "aten::scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor",
+          "aten::scatter(Tensor self, int dim, Tensor index, Tensor value) -> Tensor",
+          "aten::scatter_add(Tensor self, int dim, Tensor index, Tensor value) -> Tensor",
 #if PYTORCH_VERSION_GE(1, 13)
           "aten::select_scatter(Tensor self, Tensor src, int dim, int index) -> Tensor",
           "aten::slice_scatter(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor",
diff --git a/pytorch_blade/pytorch_blade/compiler/mlir/converters/torch_mlir_op_filter.cpp b/pytorch_blade/pytorch_blade/compiler/mlir/converters/torch_mlir_op_filter.cpp
index 1111b3ba0fd..1c7cc20ca6f 100644
--- a/pytorch_blade/pytorch_blade/compiler/mlir/converters/torch_mlir_op_filter.cpp
+++ b/pytorch_blade/pytorch_blade/compiler/mlir/converters/torch_mlir_op_filter.cpp
@@ -165,6 +165,8 @@ const std::unordered_set<std::string>& GetTorchMlirWhiteList() {
     "aten::sub_inplace", // use aten namespace to work with PyTorch mutation pass
     "aten::mul_inplace", // use aten namespace to work with PyTorch mutation pass
     "aten::div_inplace", // use aten namespace to work with PyTorch mutation pass
+    "aten::scatter",
+    "aten::scatter_add",
     "torch_blade::fake_quant",
     "torch_blade::conv2d_weight_nhwc",
 };
diff --git a/pytorch_blade/pytorch_blade/torch-mlir/lib/Conversion/TorchToMhlo/DiscTorchToMhlo.cpp b/pytorch_blade/pytorch_blade/torch-mlir/lib/Conversion/TorchToMhlo/DiscTorchToMhlo.cpp
index 8ccadd46ca9..040439f7005 100644
--- a/pytorch_blade/pytorch_blade/torch-mlir/lib/Conversion/TorchToMhlo/DiscTorchToMhlo.cpp
+++ b/pytorch_blade/pytorch_blade/torch-mlir/lib/Conversion/TorchToMhlo/DiscTorchToMhlo.cpp
@@ -1542,6 +1542,7 @@ Value getNormalizedDimSizeInternal(
   return rewriter.create<arith::SelectOp>(
       loc, indexPositive, index, dimSizePlusIndex);
 }
+
 template <>
 LogicalResult ConvertAtenOp<AtenSliceScatterOp>::matchAndRewrite(
     AtenSliceScatterOp op,
     OpAdaptor adaptor,
     ConversionPatternRewriter& rewriter) const {
@@ -1590,6 +1591,266 @@ LogicalResult ConvertAtenOp<AtenSliceScatterOp>::matchAndRewrite(
       start_indices);
   return success();
 }
+
+// Reference implementation:
+// https://github.com/pytorch/xla/blob/master/torch_xla/csrc/xla_lower_util.cpp#L139
+template <>
+LogicalResult ConvertAtenOp<AtenScatterSrcOp>::matchAndRewrite(
+    AtenScatterSrcOp op,
+    OpAdaptor adaptor,
+    ConversionPatternRewriter& rewriter) const {
+  Location loc = op.getLoc();
+  Value self = adaptor.getSelf();
+  Value index = adaptor.getIndex();
+  Value src = adaptor.getSrc();
+
+  RankedTensorType selfType = self.getType().cast<RankedTensorType>();
+  RankedTensorType indexType = index.getType().cast<RankedTensorType>();
+  RankedTensorType srcType = src.getType().cast<RankedTensorType>();
+
+  if (selfType.getRank() != indexType.getRank() ||
+      indexType.getRank() != srcType.getRank())
+    return rewriter.notifyMatchFailure(
+        op,
+        "'self', 'index' and 'src' should all "
+        "have the same number of dimensions.");
+
+  int64_t dim;
+  if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
+    return rewriter.notifyMatchFailure(
+        op, "unimplemented: dim is not constant");
+
+  // Insert a SliceOp so that only the part of 'src' covered by 'index' is
+  // scattered.
+  Value zero =
+      rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
+  Value one =
+      rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(1));
+  SmallVector<Value> baseIndices(srcType.getRank());
+  SmallVector<Value> limitIndices(srcType.getRank());
+  SmallVector<Value> strides(srcType.getRank());
+  RankedTensorType indicesType =
+      RankedTensorType::get({srcType.getRank()}, rewriter.getIntegerType(64));
+  for (int i = 0; i < srcType.getRank(); i++) {
+    baseIndices[i] = zero;
+    strides[i] = one;
+    limitIndices[i] = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getI32IntegerAttr(indexType.getShape()[i]));
+  }
+
+  Value baseIndicesValue =
+      rewriter.create<tensor::FromElementsOp>(loc, baseIndices);
+  Value stridesValue = rewriter.create<tensor::FromElementsOp>(loc, strides);
+  Value limitIndicesValue =
+      rewriter.create<tensor::FromElementsOp>(loc, limitIndices);
+
+  auto sliceOpResultType =
+      RankedTensorType::get(indexType.getShape(), srcType.getElementType());
+  src = rewriter.create<mhlo::RealDynamicSliceOp>(
+      loc,
+      getTypeConverter()->convertType(sliceOpResultType),
+      src,
+      baseIndicesValue,
+      limitIndicesValue,
+      stridesValue);
+
+  // Construct ScatterDimensionNumbersAttr
+  int64_t indexVectorDim = srcType.getRank();
+  SmallVector<int64_t> updateWindowDimsVec;
+  SmallVector<int64_t> insertWindowDimsVec;
+  SmallVector<int64_t> scatterDimsToOperationDimsVec;
+  for (int i = 0; i < indexVectorDim; i++) {
+    insertWindowDimsVec.push_back(i);
+    scatterDimsToOperationDimsVec.push_back(i);
+  }
+  auto scatterDimension = mhlo::ScatterDimensionNumbersAttr::get(
+      rewriter.getContext(),
+      updateWindowDimsVec,
+      insertWindowDimsVec,
+      scatterDimsToOperationDimsVec,
+      indexVectorDim);
+
+  // Convert index to scatter_indices
+  limitIndices.push_back(one);
+  auto indexShape = rewriter.create<tensor::FromElementsOp>(loc, limitIndices);
+
+  auto originalShapeVec = indexType.getShape().vec();
+  originalShapeVec.push_back(1);
+
+  auto iotaType =
+      RankedTensorType::get(originalShapeVec, indexType.getElementType());
+  SmallVector<Value> toConcat;
+  for (int i = 0; i < indexVectorDim; i++) {
+    if (i == dim) {
+      toConcat.push_back(rewriter.create<mhlo::DynamicReshapeOp>(
+          loc, getTypeConverter()->convertType(iotaType), index, indexShape));
+    } else {
+      toConcat.push_back(rewriter.create<mhlo::IotaOp>(loc, iotaType, i));
+    }
+  }
+  Value scatter_indices =
+      rewriter.create<mhlo::ConcatenateOp>(loc, toConcat, indexVectorDim);
+
+  // Construct mhlo::ScatterOp
+  auto mhloScatterOp = rewriter.create<mhlo::ScatterOp>(
+      loc,
+      getTypeConverter()->convertType(op.getType()),
+      ValueRange{self},
+      scatter_indices,
+      ValueRange{src},
+      scatterDimension,
+      false,
+      false);
+
+  // Construct the updateComputation region; aten::scatter simply overwrites
+  // the destination element with the update value.
+  Block& block = mhloScatterOp.getUpdateComputation().emplaceBlock();
+  auto blockArg1Type = RankedTensorType::get({}, srcType.getElementType());
+  auto blockArg2Type = RankedTensorType::get({}, srcType.getElementType());
+  block.addArgument(blockArg1Type, loc);
+  block.addArgument(blockArg2Type, loc);
+  {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(&block);
+    rewriter.create<mhlo::ReturnOp>(op->getLoc(), block.getArgument(1));
+  }
+
+  // Replace Op
+  rewriter.replaceOp(op, mhloScatterOp.getResults());
+
+  return success();
+}
+
+// Reference implementation:
+// https://github.com/pytorch/xla/blob/master/torch_xla/csrc/xla_lower_util.cpp#L139
+template <>
+LogicalResult ConvertAtenOp<AtenScatterAddOp>::matchAndRewrite(
+    AtenScatterAddOp op,
+    OpAdaptor adaptor,
+    ConversionPatternRewriter& rewriter) const {
+  Location loc = op.getLoc();
+  Value self = adaptor.getSelf();
+  Value index = adaptor.getIndex();
+  Value src = adaptor.getSrc();
+
+  RankedTensorType selfType = self.getType().cast<RankedTensorType>();
+  RankedTensorType indexType = index.getType().cast<RankedTensorType>();
+  RankedTensorType srcType = src.getType().cast<RankedTensorType>();
+
+  if (selfType.getRank() != indexType.getRank() ||
+      indexType.getRank() != srcType.getRank())
+    return rewriter.notifyMatchFailure(
+        op,
+        "'self', 'index' and 'src' should all "
+        "have the same number of dimensions.");
+
+  int64_t dim;
+  if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
+    return rewriter.notifyMatchFailure(
+        op, "unimplemented: dim is not constant");
+
+  // Insert a SliceOp so that only the part of 'src' covered by 'index' is
+  // scattered.
+  Value zero =
+      rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
+  Value one =
+      rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(1));
+  SmallVector<Value> baseIndices(srcType.getRank());
+  SmallVector<Value> limitIndices(srcType.getRank());
+  SmallVector<Value> strides(srcType.getRank());
+  RankedTensorType indicesType =
+      RankedTensorType::get({srcType.getRank()}, rewriter.getIntegerType(64));
+  for (int i = 0; i < srcType.getRank(); i++) {
+    baseIndices[i] = zero;
+    strides[i] = one;
+    limitIndices[i] = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getI32IntegerAttr(indexType.getShape()[i]));
+  }
+
+  Value baseIndicesValue =
+      rewriter.create<tensor::FromElementsOp>(loc, baseIndices);
+  Value stridesValue = rewriter.create<tensor::FromElementsOp>(loc, strides);
+  Value limitIndicesValue =
+      rewriter.create<tensor::FromElementsOp>(loc, limitIndices);
+
+  auto sliceOpResultType =
+      RankedTensorType::get(indexType.getShape(), srcType.getElementType());
+  src = rewriter.create<mhlo::RealDynamicSliceOp>(
+      loc,
+      getTypeConverter()->convertType(sliceOpResultType),
+      src,
+      baseIndicesValue,
+      limitIndicesValue,
+      stridesValue);
+
+  // Construct ScatterDimensionNumbersAttr
+  int64_t indexVectorDim = srcType.getRank();
+  SmallVector<int64_t> updateWindowDimsVec;
+  SmallVector<int64_t> insertWindowDimsVec;
+  SmallVector<int64_t> scatterDimsToOperationDimsVec;
+  for (int i = 0; i < indexVectorDim; i++) {
+    insertWindowDimsVec.push_back(i);
+    scatterDimsToOperationDimsVec.push_back(i);
+  }
+  auto scatterDimension = mhlo::ScatterDimensionNumbersAttr::get(
+      rewriter.getContext(),
+      updateWindowDimsVec,
+      insertWindowDimsVec,
+      scatterDimsToOperationDimsVec,
+      indexVectorDim);
+
+  // Convert index to scatter_indices
+  limitIndices.push_back(one);
+  auto indexShape = rewriter.create<tensor::FromElementsOp>(loc, limitIndices);
+
+  auto originalShapeVec = indexType.getShape().vec();
+  originalShapeVec.push_back(1);
+
+  auto iotaType =
+      RankedTensorType::get(originalShapeVec, indexType.getElementType());
+  SmallVector<Value> toConcat;
+  for (int i = 0; i < indexVectorDim; i++) {
+    if (i == dim) {
+      toConcat.push_back(rewriter.create<mhlo::DynamicReshapeOp>(
+          loc, getTypeConverter()->convertType(iotaType), index, indexShape));
+    } else {
+      toConcat.push_back(rewriter.create<mhlo::IotaOp>(loc, iotaType, i));
+    }
+  }
+  Value scatter_indices =
+      rewriter.create<mhlo::ConcatenateOp>(loc, toConcat, indexVectorDim);
+
+  // Construct mhlo::ScatterOp
+  auto mhloScatterOp = rewriter.create<mhlo::ScatterOp>(
+      loc,
+      getTypeConverter()->convertType(op.getType()),
+      ValueRange{self},
+      scatter_indices,
+      ValueRange{src},
+      scatterDimension,
+      false,
+      false);
+
+  // Construct the updateComputation region; aten::scatter_add accumulates the
+  // update value into the destination element.
+  Block& block = mhloScatterOp.getUpdateComputation().emplaceBlock();
+  auto blockArg1Type = RankedTensorType::get({}, srcType.getElementType());
+  auto blockArg2Type = RankedTensorType::get({}, srcType.getElementType());
+  block.addArgument(blockArg1Type, loc);
+  block.addArgument(blockArg2Type, loc);
+  {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(&block);
+    auto retValue =
+        rewriter
+            .create<mhlo::AddOp>(
+                op->getLoc(), block.getArgument(0), block.getArgument(1))
+            .getResult();
+    rewriter.create<mhlo::ReturnOp>(op->getLoc(), retValue);
+  }
+
+  // Replace Op
+  rewriter.replaceOp(op, mhloScatterOp.getResults());
+
+  return success();
+}
+
 } // namespace
 
 namespace {
@@ -1807,6 +2068,8 @@ class DiscConvertTorchToMhlo
     INSERT_ATENOP_PATTERN(AtenFillScalarOp);
     INSERT_ATENOP_PATTERN(OverwriteTensorContentsOp);
     INSERT_ATENOP_PATTERN(AtenSliceScatterOp);
+    INSERT_ATENOP_PATTERN(AtenScatterSrcOp);
+    INSERT_ATENOP_PATTERN(AtenScatterAddOp);
 #undef INSERT_ATENOP_PATTERN
 
 #define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, MhloOp) \
diff --git a/pytorch_blade/tests/disc/ops/test_scatter.py b/pytorch_blade/tests/disc/ops/test_scatter.py
new file mode 100644
index 00000000000..879b362488a
--- /dev/null
+++ b/pytorch_blade/tests/disc/ops/test_scatter.py
@@ -0,0 +1,87 @@
+# Copyright 2021 The BladeDISC Authors. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import unittest
+
+from tests.disc.testing_base import DiscTestCase, skipTorchLE
+
+
+@skipTorchLE("1.6.1")
+class TestAtenScatter(DiscTestCase):
+
+    def test_scatter(self):
+        if self.device != torch.device('cuda'):
+            return
+
+        @torch.jit.script
+        def scatter_func(destination, place_at, source):
+            a = torch.scatter(destination, 0, place_at, source)
+            return a
+
+        destination = torch.tensor(
+            [
+                [4.0, 0.0, 3.0, 1.0, 0.0],
+                [0.0, 5.0, 8.0, 0.0, 0.0],
+                [6.0, 0.0, 0.0, 9.0, 0.0]
+            ], dtype=torch.float32, device=self.device)
+
+        source = torch.tensor(
+            [
+                [0.3992, 0.2908, 0.9044, 0.4850, 0.6004],
+                [0.5735, 0.9006, 0.6797, 0.4152, 0.1732]
+            ], dtype=torch.float32, device=self.device)
+
+        place_at = torch.tensor(
+            [
+                [0, 1, 2, 0],
+                [2, 0, 0, 1]
+            ], dtype=torch.int64, device=self.device)
+
+        annotations = [(list(destination.shape), torch.float32), (list(
+            place_at.shape), torch.int64), (list(source.shape), torch.float32)]
+        self._test_disc(scatter_func, annotations,
+                        (destination, place_at, source))
+
+    def test_scatteradd(self):
+        if self.device != torch.device('cuda'):
+            return
+
+        @torch.jit.script
+        def scatter_func(destination, place_at, source):
+            a = torch.scatter_add(destination, 0, place_at, source)
+            return a
+
+        destination = torch.tensor(
+            [
+                [4.0, 0.0, 3.0, 1.0, 0.0],
+                [0.0, 5.0, 8.0, 0.0, 0.0],
+                [6.0, 0.0, 0.0, 9.0, 0.0]
+            ], dtype=torch.float32, device=self.device)
+
+        source = torch.tensor(
+            [
+                [0.3992, 0.2908, 0.9044, 0.4850, 0.6004],
+                [0.5735, 0.9006, 0.6797, 0.4152, 0.1732]
+            ], dtype=torch.float32, device=self.device)
+
+        place_at = torch.tensor(
+            [
+                [0, 1, 2, 0],
+                [2, 0, 0, 1]
+            ], dtype=torch.int64, device=self.device)
+
+        annotations = [(list(destination.shape), torch.float32), (list(
+            place_at.shape), torch.int64), (list(source.shape), torch.float32)]
+        self._test_disc(scatter_func, annotations,
+                        (destination, place_at, source))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tao_compiler/mlir/disc/transforms/codegen_utils.cc b/tao_compiler/mlir/disc/transforms/codegen_utils.cc
old mode 100755
new mode 100644
index 22eb07ddc69..7588d607fb9
--- a/tao_compiler/mlir/disc/transforms/codegen_utils.cc
+++ b/tao_compiler/mlir/disc/transforms/codegen_utils.cc
@@ -23,6 +23,7 @@ limitations under the License.
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/disc/IR/disc_shape_ops.h"
+#include "mlir/disc/IR/lhlo_disc_ops.h"
 #include "mlir/disc/disc_util.h"
 
 using mlir::memref::DimOp;
@@ -108,6 +109,12 @@ Value emitNumElementsComputation(OpBuilder& b, Location loc, Operation* op) {
       op->getOperand(0) == op->getOperand(num_operands - 1)) {
     return emitNumElementsComputation(b, loc, op->getOperand(1));
   }
+
+  if (isa<lmhlo::ScatterOp>(op) &&
+      op->getOperand(0) == op->getOperand(num_operands - 1)) {
+    return emitNumElementsComputation(b, loc, op->getOperand(2));
+  }
+
   Value result_memref = op->getOperand(num_operands - 1);
   return emitNumElementsComputation(b, loc, result_memref);
 }
diff --git a/tao_compiler/mlir/disc/transforms/disc_lhlo_rewriter.cc b/tao_compiler/mlir/disc/transforms/disc_lhlo_rewriter.cc
index b727d21a627..bdc298bf945 100644
--- a/tao_compiler/mlir/disc/transforms/disc_lhlo_rewriter.cc
+++ b/tao_compiler/mlir/disc/transforms/disc_lhlo_rewriter.cc
@@ -141,6 +141,36 @@ struct LhloConcatenateOpConverter
   }
 };
 
+struct LhloScatterOpConverter : public OpRewritePattern<lmhlo::ScatterOp> {
+  explicit LhloScatterOpConverter(MLIRContext* context)
+      : OpRewritePattern(context) {}
+  LogicalResult matchAndRewrite(lmhlo::ScatterOp lhloOp,
+                                PatternRewriter& rewriter) const override {
+    Operation* op = lhloOp.getOperation();
+
+    auto operands = op->getOperands();
+
+    // Already rewritten.
+    if (operands[0] == operands[3]) {
+      return failure();
+    }
+
+    SmallVector<Value> ins(operands.begin(), operands.end());
+    rewriter.create<lmhlo::CopyOp>(op->getLoc(), operands[0], operands[3]);
+    ins[0] = operands[3];
+    auto newOp = rewriter.create<lmhlo::ScatterOp>(op->getLoc(), TypeRange{},
+                                                   ins, op->getAttrs());
+    // Copy over the operations inside the region.
+    rewriter.inlineRegionBefore(lhloOp.getUpdateComputation(),
+                                newOp.getUpdateComputation(),
+                                newOp.getUpdateComputation().end());
+
+    rewriter.eraseOp(op);
+
+    return success();
+  }
+};
+
 struct DiscLhloRewriterPass
     : public DiscLhloRewriterPassBase<DiscLhloRewriterPass> {
   using DiscLhloRewriterPassBase<
@@ -164,6 +194,7 @@ struct DiscLhloRewriterPass
     target.addIllegalOp<lmhlo::ConcatenateOp>();
     patterns.insert<LhloConcatenateOpConverter>(&context);
+    patterns.insert<LhloScatterOpConverter>(&context);
     patterns.insert(&context);
     if (failed(
             applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
diff --git a/tao_compiler/mlir/disc/transforms/disc_lower_to_library_call.cc b/tao_compiler/mlir/disc/transforms/disc_lower_to_library_call.cc
index 025081f15c6..bd6b7296dd2 100644
--- a/tao_compiler/mlir/disc/transforms/disc_lower_to_library_call.cc
+++ b/tao_compiler/mlir/disc/transforms/disc_lower_to_library_call.cc
@@ -436,6 +436,15 @@ Value getCopyRemovableResult(Operation* op) {
     if (!IsSmallCpuBuffer(result) && !IsSmallCpuBuffer(op->getOperand(0)))
       return result;
 #else
+    // If the copy result buffer is used by another inplace op as its result
+    // buffer, it is not removable.
+    for (Operation* user : result.getUsers()) {
+      if (isInplaceOperator(user) &&
+          result == cast<lmhlo::LmhloOp>(user).getResultBuffer()) {
+        return {};
+      }
+    }
+
     if (placement_utils::isGpuMemRef(result)) return result;
 #endif
   }
@@ -446,6 +455,7 @@ Value getCopyRemovableResult(Operation* op) {
 Value getRootMemRefIfSafe(Value memref) {
   Value rootMemRef = memref;
   DenseSet<Operation*> knownSafeOpSet;
+
   while (auto view = dyn_cast_or_null(
       rootMemRef.getDefiningOp())) {
     knownSafeOpSet.insert(view);
diff --git a/tao_compiler/mlir/disc/transforms/fusion_utils.cc b/tao_compiler/mlir/disc/transforms/fusion_utils.cc
index 78f03d765d6..2475f1d94fb 100644
--- a/tao_compiler/mlir/disc/transforms/fusion_utils.cc
+++ b/tao_compiler/mlir/disc/transforms/fusion_utils.cc
@@ -71,7 +71,7 @@ bool enableTransposeLibraryCall() {
 }
 
 bool isInplaceOperator(Operation* op) {
-  if (isa(op)) {
+  if (isa(op) && !isa(op)) {
     Value resultMemref = cast<lmhlo::LmhloOp>(op).getResultBuffer();
     for (int i = 0; i < op->getNumOperands() - 1; ++i) {
       if (op->getOperand(i) == resultMemref) {
diff --git a/tao_compiler/mlir/disc/transforms/lhlo_elemental_utils.cc b/tao_compiler/mlir/disc/transforms/lhlo_elemental_utils.cc
index 5301bc7208b..e47cf39c957 100644
--- a/tao_compiler/mlir/disc/transforms/lhlo_elemental_utils.cc
+++ b/tao_compiler/mlir/disc/transforms/lhlo_elemental_utils.cc
@@ -162,6 +162,18 @@ Operation* getReduceOperator(Region& body) {
   return calc_op;
 }
 
+Operation* tryGetUpdateOperator(Region& body) {
+  Operation* calc_op = nullptr;
+  int64_t num_calc_ops = 0;
+  body.walk([&](Operation* op) {
+    if (isa<lmhlo::AddOp>(op)) {
+      calc_op = op;
+      return;
+    }
+  });
+  return calc_op;
+}
+
 AccumulatorFactory getFactory(OpBuilder& b, Location loc, Region& body) {
   return AccumulatorFactory([&](Value lhs, Value rhs) {
     auto calc_op = getReduceOperator(body);
@@ -440,7 +452,6 @@ Value elementalLower(OpBuilder* b, Location loc,
   mayCreateStore(b, loc, op.getOperation(), result, output_index, lower_config);
   return result;
 }
-
 namespace {
 
 template
@@ -514,6 +525,221 @@ Value elementalLower(
   return result;
 }
 
+Value LowerInplaceScatterOp(OpBuilder* b, Location loc, lmhlo::ScatterOp op,
+                            ValueRange update_index, bool check_cache,
+                            LowerConfig* lower_config) {
+  Value indices_memref = op->getOperand(1);
+  Value update_memref = op->getOperand(2);
+  Value result_memref = op->getOperand(3);
+
+  int update_memref_rank =
+      update_memref.getType().dyn_cast<MemRefType>().getRank();
+  int scatter_indices_rank =
+      indices_memref.getType().dyn_cast<MemRefType>().getRank();
+  int input_rank = result_memref.getType().dyn_cast<MemRefType>().getRank();
+  auto scatter_indices_shape =
+      indices_memref.getType().dyn_cast<MemRefType>().getShape();
+
+  auto scatter_dimension_numbers = op.getScatterDimensionNumbers();
+  auto update_window_dims = scatter_dimension_numbers.getUpdateWindowDims();
+  auto index_vector_dim = scatter_dimension_numbers.getIndexVectorDim();
+  auto scatter_dims_to_operand =
+      scatter_dimension_numbers.getScatterDimsToOperandDims();
+  auto inserted_window_dims = scatter_dimension_numbers.getInsertedWindowDims();
+
+  Value zero = b->create<arith::ConstantIndexOp>(loc, 0);
+  int rank = update_index.size();
+  SmallVector<Value> update_scatter_dims, update_scatter_index, start_index,
+      update_window_index, full_window_index, result_index, full_start_index;
+
+  int inner_index1 = 0;
+  int inner_index2 = 0;
+
+  // update_scatter_dims = [d for d in axes(updates[0]) and d not in
+  // update_window_dims]
+  // update_scatter_index = update_index[update_scatter_dims...]
+  for (inner_index1 = 0; inner_index1 < update_memref_rank; inner_index1++) {
+    bool add_to_update_scatter = true;
+    for (inner_index2 = 0; inner_index2 < update_window_dims.size();
+         inner_index2++) {
+      if (inner_index1 == update_window_dims[inner_index2])
+        add_to_update_scatter = false;
+    }
+
+    if (add_to_update_scatter) {
+      update_scatter_index.push_back(update_index[inner_index1]);
+    }
+  }
+
+  // start_index = scatter_indices[si0, ..., :, ..., siN] if index_vector_dim <
+  // rank(scatter_indices) else [scatter_indices[update_scatter_index]]
+  if (index_vector_dim < scatter_indices_rank) {
+    // Require update_scatter_index.size() == scatter_indices_shape.size() - 1
+    SmallVector<Value> tmpIndex;
+    tmpIndex.resize(scatter_indices_rank);
+    inner_index1 = 0;
+    inner_index2 = 0;
+    for (inner_index1 = 0; inner_index1 < scatter_indices_shape.size();
+         inner_index1++) {
+      if (inner_index1 == index_vector_dim) {
+        continue;
+      }
+      tmpIndex[inner_index1] = update_scatter_index[inner_index2++];
+    }
+
+    for (inner_index1 = 0;
+         inner_index1 < scatter_indices_shape[index_vector_dim];
+         inner_index1++) {
+      tmpIndex[index_vector_dim] =
+          b->create<arith::ConstantIndexOp>(loc, inner_index1);
+      Value idx;
+      if (!check_cache) {
+        idx = createMaySpecificLoad(*b, loc, op.getOperation(), indices_memref,
+                                    tmpIndex, lower_config);
+      } else {
+        idx = createLoadOrUseCachedValue(loc, b, op.getOperation(),
+                                         indices_memref, tmpIndex,
+                                         b->saveInsertionPoint(), lower_config);
+      }
+      start_index.push_back(idx);
+    }
+  } else {
+    Value idx;
+    if (!check_cache) {
+      idx = createMaySpecificLoad(*b, loc, op.getOperation(), indices_memref,
+                                  update_scatter_index, lower_config);
+    } else {
+      idx = createLoadOrUseCachedValue(loc, b, op.getOperation(),
+                                       indices_memref, update_scatter_index,
+                                       b->saveInsertionPoint(), lower_config);
+    }
+    start_index.push_back(idx);
+  }
+
+  // full_start_index[d_input] = start_index[d_start] if d_input =
+  // scatter_dims_to_operand_dims[d_start].
+  // full_start_index[d_input] = 0 otherwise.
+  int d_start = 0;
+  for (int d_input = 0; d_input < input_rank; d_input++) {
+    if (d_start < scatter_dims_to_operand.size() &&
+        d_input == scatter_dims_to_operand[d_start]) {
+      full_start_index.push_back(start_index[d_start++]);
+    } else {
+      full_start_index.push_back(zero);
+    }
+  }
+
+  // update_window_index = update_index[update_window_dims...].
+  // full_window_index = [wi0, ..., 0, ..., wiN] where wi are individual
+  // elements in update_window_index, and 0 is inserted at indices from
+  // inserted_window_dims.
+  // Require rank(inputs[0]) = size(update_window_dims) +
+  // size(inserted_window_dims).
+  inner_index1 = 0;
+  inner_index2 = 0;
+  for (int d_input = 0; d_input < input_rank; d_input++) {
+    // 0 is inserted at indices from inserted_window_dims
+    if (inner_index1 < inserted_window_dims.size() &&
+        d_input == inserted_window_dims[inner_index1]) {
+      full_window_index.push_back(zero);
+      inner_index1 += 1;
+    } else {
+      // wi are individual elements in update_window_index
+      full_window_index.push_back(
+          update_index[update_window_dims[inner_index2++]]);
+    }
+  }
+
+  // result_index = full_start_index + full_window_index
+  Value in_bound = b->create<arith::ConstantIntOp>(loc, 1, 1);
+  for (int d_input = 0; d_input < input_rank; d_input++) {
+    auto dim_res = b->create<arith::IndexCastOp>(
+        loc, b->getIndexType(),
+        b->create<arith::AddIOp>(loc, full_start_index[d_input],
+                                 full_window_index[d_input]));
+    in_bound = b->create<arith::AndIOp>(
+        loc, in_bound,
+        b->create<arith::CmpIOp>(
+            loc, arith::CmpIPredicate::slt, dim_res,
+            disc_ral::getDimSizeValue(b, result_memref, d_input)));
+    result_index.push_back(dim_res);
+  }
+
+  // Combine computation
+  auto calc_op = tryGetUpdateOperator(op.getUpdateComputation());
+  Value update_value, updated_value;
+  if (!check_cache) {
+    update_value = createMaySpecificLoad(
+        *b, loc, op.getOperation(), update_memref, update_index, lower_config);
+  } else {
+    update_value = createLoadOrUseCachedValue(
+        loc, b, op.getOperation(), update_memref, update_index,
+        b->saveInsertionPoint(), lower_config);
+  }
+
+  SmallVector<Type> result_types;
+  result_types.push_back(
+      result_memref.getType().cast<MemRefType>().getElementType());
+
+  auto if_inbound_op = b->create<scf::IfOp>(loc, result_types, in_bound, true);
+  if_inbound_op.getThenRegion().front().clear();
+  if_inbound_op.getElseRegion().front().clear();
+  b->setInsertionPointToEnd(&if_inbound_op.getThenRegion().front());
+  if (calc_op == nullptr) {
+    b->create<memref::StoreOp>(loc, update_value, result_memref, result_index);
+  } else {
+    Value original_value;
+    if (!check_cache) {
+      original_value =
+          createMaySpecificLoad(*b, loc, op.getOperation(), result_memref,
+                                result_index, lower_config);
+    } else {
+      original_value = createLoadOrUseCachedValue(
+          loc, b, op.getOperation(), result_memref, result_index,
+          b->saveInsertionPoint(), lower_config);
+    }
+    SmallVector<Value> operand_values;
+    operand_values.push_back(original_value);
+    operand_values.push_back(update_value);
+    auto num_operands = calc_op->getNumOperands();
+    auto result_type = calc_op->getOperand(num_operands - 1).getType();
+    auto result_elem_type = result_type.cast<MemRefType>().getElementType();
+    if (isa<lmhlo::AddOp>(calc_op)) {
+      updated_value = LhloOpToStdScalarOp::map<lmhlo::AddOp>(
+          llvm::cast<lmhlo::AddOp>(calc_op), result_elem_type, operand_values,
+          b);
+    } else {
+      assert(false && "unexpected update computation in scatter op");
+    }
+
+    b->create<memref::StoreOp>(loc, updated_value, result_memref,
+                               result_index);
+  }
+  b->create<scf::YieldOp>(loc, update_value);
+  b->setInsertionPointToEnd(&if_inbound_op.getElseRegion().front());
+  // Else we do nothing.
+  b->create<scf::YieldOp>(loc, update_value);
+  b->setInsertionPointAfter(if_inbound_op);
+
+  Value result = *(if_inbound_op.getResults().begin());
+
+  return result;
+}
+
+template <>
+Value elementalLower<lmhlo::ScatterOp>(OpBuilder* b, Location loc,
+                                       lmhlo::ScatterOp op,
+                                       ValueRange output_index,
+                                       bool check_cache,
+                                       LowerConfig* lower_config) {
+  int rank = output_index.size();
+  Value result_operand = op->getOperand(op->getNumOperands() - 1);
+  Value src_memref = op->getOperand(0);
+  if (result_operand == src_memref) {
+    return LowerInplaceScatterOp(b, loc, op, output_index, check_cache,
+                                 lower_config);
+  }
+  op->emitError("Out-of-place ScatterOp is not supported yet");
+  return Value(nullptr);
+}
+
 template <>
 Value elementalLower<lmhlo::BroadcastInDimOp>(OpBuilder* b, Location loc,
                                               lmhlo::BroadcastInDimOp op,
diff --git a/tao_compiler/mlir/disc/transforms/lhlo_fusion.cc b/tao_compiler/mlir/disc/transforms/lhlo_fusion.cc
index 2d3dccfc5e0..394d736015a 100644
--- a/tao_compiler/mlir/disc/transforms/lhlo_fusion.cc
+++ b/tao_compiler/mlir/disc/transforms/lhlo_fusion.cc
@@ -307,9 +307,11 @@ class FusionPlanner {
     // Thus these operands are supposed to be updated.
     // Suppose that an op (or its nested ops) can only write the buffers
     // explicit passed in as operands of this op.
-    if (op->getDialect()->getTypeID() != TypeID::get<lmhlo::LmhloDialect>() &&
-        op->getDialect()->getTypeID() !=
-            TypeID::get<lmhlo_disc::LmhloDiscDialect>()) {
+    if (isInplaceOperator(op) ||
+        (op->getDialect()->getTypeID() !=
+             TypeID::get<lmhlo::LmhloDialect>() &&
+         op->getDialect()->getTypeID() !=
+             TypeID::get<lmhlo_disc::LmhloDiscDialect>())) {
       // If an op is not in lmhlo or lmhlo_disc dialect, it may be written
       // multiple times (e.g. multiple memref.store ops for the same
       // underlying buffer).
diff --git a/tao_compiler/mlir/disc/transforms/lhlo_legalize_roots_to_loops.cc b/tao_compiler/mlir/disc/transforms/lhlo_legalize_roots_to_loops.cc
index 7477b8b72be..5cd454b4849 100644
--- a/tao_compiler/mlir/disc/transforms/lhlo_legalize_roots_to_loops.cc
+++ b/tao_compiler/mlir/disc/transforms/lhlo_legalize_roots_to_loops.cc
@@ -168,6 +168,12 @@ LogicalResult miscLowerHelper(OpBuilder& b, Location loc, Operation* opaque_op,
   if (isa(op) && isInplaceOperator(op)) {
     memref = cast(&*op).getOperation()->getOperand(1);
   }
+
+  // for inplace scatter op, output_index according to operand(3)
+  if (isa<lmhlo::ScatterOp>(op) && isInplaceOperator(op)) {
+    memref = cast<lmhlo::ScatterOp>(&*op).getOperation()->getOperand(2);
+  }
+
   for (int64_t i = 0; i < vector_size; i++) {
     Value linear_index = linear_indices[i];
     auto multidim_index = calcMultiDimIndex(&b, loc, linear_index, memref);
@@ -285,6 +291,8 @@ LogicalResult lowerHelper(OpBuilder& b, Location loc, Operation* op,
       succeeded(miscLowerHelper(
           b, loc, op, output_linear_index, shape_analysis, vector_size, lower_config)) ||
       succeeded(miscLowerHelper(
+          b, loc, op, output_linear_index, shape_analysis, vector_size, lower_config)) ||
+      succeeded(miscLowerHelper<lmhlo::ScatterOp>(
          b, loc, op, output_linear_index, shape_analysis, vector_size, lower_config))
   ) {
     return success();
@@ -324,6 +332,7 @@ LogicalResult lowerWithScheduleLoop(
   } else {
     (void)createLoopAndSetInsPt(b, loc, var, zero, thread_number, one, {});
   }
+
   for (Operation* root_op : root_ops) {
     // TODO: vectorize here
     if (failed(lowerHelper(b, loc, root_op, var, shape_analysis, vector_size)))
@@ -5749,6 +5758,7 @@ struct DiscLhloLegalizeRootsToParallelLoops
       // TODO(disc): single nodes with non kLoop schedule like ReduceOp
       // is not implemented yet. Currently ReduceOp is lowered with loop
       // schedule, which means for poor performance.
+
       if (failed(lowerWithScheduleLoop({op}, op, nullptr,
                                        /*non_fusion=*/true,
                                        /*parallel_loop=*/true))) {
diff --git a/tao_compiler/mlir/disc/transforms/placement_utils.cc b/tao_compiler/mlir/disc/transforms/placement_utils.cc
index 8b99c25eafa..183286164eb 100644
--- a/tao_compiler/mlir/disc/transforms/placement_utils.cc
+++ b/tao_compiler/mlir/disc/transforms/placement_utils.cc
@@ -32,7 +32,7 @@ bool isGpuMhlo(Operation* op) {
 }
 
 bool isGpuLmhlo(Operation* op) {
-  if (!op || !isa(op)) {
+  if (!op || !isa(op) || isa<lmhlo::ScatterOp>(op)) {
     return false;
   }
   auto attr =
diff --git a/tf_community b/tf_community
index ef7d949ce74..1f7c9e80e6a 160000
--- a/tf_community
+++ b/tf_community
@@ -1 +1 @@
-Subproject commit ef7d949ce74bcbaa2517fbb120932be626031f4c
+Subproject commit 1f7c9e80e6a9eb786d960b4b84133d645a36e39c
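Reviewer note (not part of the patch): the scatter semantics that the new lowering and the tests in tests/disc/ops/test_scatter.py exercise can be sanity-checked in plain eager-mode PyTorch. The sketch below reuses the same tensor values as the test file and is only an illustration of the documented aten::scatter / aten::scatter_add behaviour for dim = 0.

import torch

# Same example data as tests/disc/ops/test_scatter.py (dim = 0).
destination = torch.tensor(
    [[4.0, 0.0, 3.0, 1.0, 0.0],
     [0.0, 5.0, 8.0, 0.0, 0.0],
     [6.0, 0.0, 0.0, 9.0, 0.0]])
source = torch.tensor(
    [[0.3992, 0.2908, 0.9044, 0.4850, 0.6004],
     [0.5735, 0.9006, 0.6797, 0.4152, 0.1732]])
place_at = torch.tensor(
    [[0, 1, 2, 0],
     [2, 0, 0, 1]])

# For dim = 0:
#   out[place_at[i][j]][j]  = source[i][j]   (aten::scatter)
#   out[place_at[i][j]][j] += source[i][j]   (aten::scatter_add)
scattered = torch.scatter(destination, 0, place_at, source)
scatter_added = torch.scatter_add(destination, 0, place_at, source)

# Reference loop over the indexed prefix, useful for checking the lowering's
# output element by element.
expected = destination.clone()
for i in range(place_at.size(0)):
    for j in range(place_at.size(1)):
        expected[place_at[i, j], j] += source[i, j]
assert torch.allclose(scatter_added, expected)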