
Commit

support scatter op and update tf_community submodule (#1272)
eedalong authored Dec 19, 2023
1 parent 7286cd9 commit 67c3242
Showing 13 changed files with 647 additions and 7 deletions.
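For context, the new lowerings target the standard torch.scatter / torch.scatter_add semantics: values from src are written to (or accumulated into) self at positions selected by index along dim. A minimal sketch of that behavior, illustrative only and not part of this commit:

import torch

destination = torch.zeros(3, 5)
source = torch.tensor([[1.0, 2.0, 3.0, 4.0],
                       [5.0, 6.0, 7.0, 8.0]])
place_at = torch.tensor([[0, 1, 2, 0],
                         [2, 0, 0, 1]])

# scatter along dim=0: out[place_at[i][j], j] = source[i][j]
scattered = torch.scatter(destination, 0, place_at, source)

# scatter_add along dim=0: out[place_at[i][j], j] += source[i][j]
accumulated = torch.scatter_add(destination, 0, place_at, source)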
@@ -955,6 +955,8 @@ class ShapePropagator : public PropertyPropBase {
"aten::index_put.hacked_twin(Tensor self, Tensor[] indices, Tensor values, bool accumulate=False) -> Tensor",
"aten::index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!)",
"aten::scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor",
"aten::scatter(Tensor self, int dim, Tensor index, Tensor value) -> Tensor",
"aten::scatter_add(Tensor self, int dim, Tensor index, Tensor value) -> Tensor",
#if PYTORCH_VERSION_GE(1, 13)
"aten::select_scatter(Tensor self, Tensor src, int dim, int index) -> Tensor",
"aten::slice_scatter(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor",
@@ -165,6 +165,8 @@ const std::unordered_set<std::string> &GetTorchMlirWhiteList() {
"aten::sub_inplace", // use aten namespace to work with PyTorch mutation pass
"aten::mul_inplace", // use aten namespace to work with PyTorch mutation pass
"aten::div_inplace", // use aten namespace to work with PyTorch mutation pass
"aten::scatter",
"aten::scatter_add",
"torch_blade::fake_quant",
"torch_blade::conv2d_weight_nhwc",
};
@@ -1542,6 +1542,7 @@ Value getNormalizedDimSizeInternal(
return rewriter.create<arith::SelectOp>(
loc, indexPositive, index, dimSizePlusIndex);
}

template <>
LogicalResult ConvertAtenOp<AtenSliceScatterOp>::matchAndRewrite(
AtenSliceScatterOp op,
@@ -1590,6 +1591,266 @@ LogicalResult ConvertAtenOp<AtenSliceScatterOp>::matchAndRewrite(
start_indices);
return success();
}

// Reference implementation:
// https://github.com/pytorch/xla/blob/master/torch_xla/csrc/xla_lower_util.cpp#L139
template <>
LogicalResult ConvertAtenOp<AtenScatterSrcOp>::matchAndRewrite(
AtenScatterSrcOp op,
OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const {
Location loc = op.getLoc();
Value self = adaptor.getSelf();
Value index = adaptor.getIndex();
Value src = adaptor.getSrc();

RankedTensorType selfType = self.getType().cast<RankedTensorType>();
RankedTensorType indexType = index.getType().cast<RankedTensorType>();
RankedTensorType srcType = src.getType().cast<RankedTensorType>();

if (selfType.getRank() != indexType.getRank() ||
indexType.getRank() != srcType.getRank())
return rewriter.notifyMatchFailure(
op,
"'self', 'index' and 'src' should all"
"have the same number of dimensions.");

int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(
op, "unimplemented: dim is not constant");

// Slice src down to the shape of index; the scatter only consumes that many elements of src.
Value zero =
rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
Value one =
rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(1));
SmallVector<Value> baseIndices(srcType.getRank());
SmallVector<Value> limitIndices(srcType.getRank());
SmallVector<Value> strides(srcType.getRank());
RankedTensorType indicesType =
RankedTensorType::get({srcType.getRank()}, rewriter.getIntegerType(64));
for (int i = 0; i < srcType.getRank(); i++) {
baseIndices[i] = zero;
strides[i] = one;
limitIndices[i] = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI32IntegerAttr(indexType.getShape()[i]));
}

Value baseIndicesValue =
rewriter.create<tensor::FromElementsOp>(loc, baseIndices);
Value stridesValue = rewriter.create<tensor::FromElementsOp>(loc, strides);
Value limitIndicesValue =
rewriter.create<tensor::FromElementsOp>(loc, limitIndices);

auto sliceOpResultType =
RankedTensorType::get(indexType.getShape(), srcType.getElementType());
src = rewriter.create<mhlo::RealDynamicSliceOp>(
loc,
getTypeConverter()->convertType(sliceOpResultType),
src,
baseIndicesValue,
limitIndicesValue,
stridesValue);

// Construct ScatterDimensionNumbersAttr
int64_t indexVectorDim = srcType.getRank();
SmallVector<int64_t> updateWindowDimsVec;
SmallVector<int64_t> insertWindowDimsVec;
SmallVector<int64_t> scatterDimsToOperationDimsVec;
for (int i = 0; i < indexVectorDim; i++) {
insertWindowDimsVec.push_back(i);
scatterDimsToOperationDimsVec.push_back(i);
}
auto scatterDimension = mhlo::ScatterDimensionNumbersAttr::get(
rewriter.getContext(),
updateWindowDimsVec,
insertWindowDimsVec,
scatterDimsToOperationDimsVec,
indexVectorDim);

// Convert index to scatter_indices
limitIndices.push_back(one);
auto indexShape = rewriter.create<tensor::FromElementsOp>(loc, limitIndices);

auto originalShapeVec = indexType.getShape().vec();
originalShapeVec.push_back(1);

auto iotaType =
RankedTensorType::get(originalShapeVec, indexType.getElementType());
SmallVector<Value> toConcat;
for (int i = 0; i < indexVectorDim; i++) {
if (i == dim) {
toConcat.push_back(rewriter.create<mhlo::DynamicReshapeOp>(
loc, getTypeConverter()->convertType(iotaType), index, indexShape));
} else {
toConcat.push_back(rewriter.create<mhlo::IotaOp>(loc, iotaType, i));
}
}
Value scatter_indices =
rewriter.create<mhlo::ConcatenateOp>(loc, toConcat, indexVectorDim);

// Construct mhlo::ScatterOp
auto mhloScatterOp = rewriter.create<mhlo::ScatterOp>(
loc,
getTypeConverter()->convertType(op.getType()),
ValueRange{self},
scatter_indices,
ValueRange{src},
scatterDimension,
false,
false);

// Construct the updateComputation region; for scatter it simply returns the update value (overwrite semantics).
Block& block = mhloScatterOp.getUpdateComputation().emplaceBlock();
auto blockArg1Type = RankedTensorType::get({}, srcType.getElementType());
auto blockArg2Type = RankedTensorType::get({}, srcType.getElementType());
block.addArgument(blockArg1Type, loc);
block.addArgument(blockArg2Type, loc);
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
rewriter.create<mhlo::ReturnOp>(op->getLoc(), block.getArgument(1));
}

// Replace Op
rewriter.replaceOp(op, mhloScatterOp.getResults());

return success();
}

// Reference implementation:
// https://github.com/pytorch/xla/blob/master/torch_xla/csrc/xla_lower_util.cpp#L139
template <>
LogicalResult ConvertAtenOp<AtenScatterAddOp>::matchAndRewrite(
AtenScatterAddOp op,
OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const {
Location loc = op.getLoc();
Value self = adaptor.getSelf();
Value index = adaptor.getIndex();
Value src = adaptor.getSrc();

RankedTensorType selfType = self.getType().cast<RankedTensorType>();
RankedTensorType indexType = index.getType().cast<RankedTensorType>();
RankedTensorType srcType = src.getType().cast<RankedTensorType>();

if (selfType.getRank() != indexType.getRank() ||
indexType.getRank() != srcType.getRank())
return rewriter.notifyMatchFailure(
op,
"'self', 'index' and 'src' should all"
"have the same number of dimensions.");

int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(
op, "unimplemented: dim is not constant");

// Slice src down to the shape of index; the scatter only consumes that many elements of src.
Value zero =
rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
Value one =
rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(1));
SmallVector<Value> baseIndices(srcType.getRank());
SmallVector<Value> limitIndices(srcType.getRank());
SmallVector<Value> strides(srcType.getRank());
RankedTensorType indicesType =
RankedTensorType::get({srcType.getRank()}, rewriter.getIntegerType(64));
for (int i = 0; i < srcType.getRank(); i++) {
baseIndices[i] = zero;
strides[i] = one;
limitIndices[i] = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI32IntegerAttr(indexType.getShape()[i]));
}

Value baseIndicesValue =
rewriter.create<tensor::FromElementsOp>(loc, baseIndices);
Value stridesValue = rewriter.create<tensor::FromElementsOp>(loc, strides);
Value limitIndicesValue =
rewriter.create<tensor::FromElementsOp>(loc, limitIndices);

auto sliceOpResultType =
RankedTensorType::get(indexType.getShape(), srcType.getElementType());
src = rewriter.create<mhlo::RealDynamicSliceOp>(
loc,
getTypeConverter()->convertType(sliceOpResultType),
src,
baseIndicesValue,
limitIndicesValue,
stridesValue);

// Construct ScatterDimensionNumbersAttr
int64_t indexVectorDim = srcType.getRank();
SmallVector<int64_t> updateWindowDimsVec;
SmallVector<int64_t> insertWindowDimsVec;
SmallVector<int64_t> scatterDimsToOperationDimsVec;
for (int i = 0; i < indexVectorDim; i++) {
insertWindowDimsVec.push_back(i);
scatterDimsToOperationDimsVec.push_back(i);
}
auto scatterDimension = mhlo::ScatterDimensionNumbersAttr::get(
rewriter.getContext(),
updateWindowDimsVec,
insertWindowDimsVec,
scatterDimsToOperationDimsVec,
indexVectorDim);

// Convert index to scatter_indices
limitIndices.push_back(one);
auto indexShape = rewriter.create<tensor::FromElementsOp>(loc, limitIndices);

auto originalShapeVec = indexType.getShape().vec();
originalShapeVec.push_back(1);

auto iotaType =
RankedTensorType::get(originalShapeVec, indexType.getElementType());
SmallVector<Value> toConcat;
for (int i = 0; i < indexVectorDim; i++) {
if (i == dim) {
toConcat.push_back(rewriter.create<mhlo::DynamicReshapeOp>(
loc, getTypeConverter()->convertType(iotaType), index, indexShape));
} else {
toConcat.push_back(rewriter.create<mhlo::IotaOp>(loc, iotaType, i));
}
}
Value scatter_indices =
rewriter.create<mhlo::ConcatenateOp>(loc, toConcat, indexVectorDim);

// Construct mhlo::ScatterOp
auto mhloScatterOp = rewriter.create<mhlo::ScatterOp>(
loc,
getTypeConverter()->convertType(op.getType()),
ValueRange{self},
scatter_indices,
ValueRange{src},
scatterDimension,
false,
false);

// Construct the updateComputation region; for scatter_add it adds the update value to the existing element.
Block& block = mhloScatterOp.getUpdateComputation().emplaceBlock();
auto blockArg1Type = RankedTensorType::get({}, srcType.getElementType());
auto blockArg2Type = RankedTensorType::get({}, srcType.getElementType());
block.addArgument(blockArg1Type, loc);
block.addArgument(blockArg2Type, loc);
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
auto retValue =
rewriter
.create<mhlo::AddOp>(
op->getLoc(), block.getArgument(0), block.getArgument(1))
.getResult();
rewriter.create<mhlo::ReturnOp>(op->getLoc(), retValue);
}

// Replace Op
rewriter.replaceOp(op, mhloScatterOp.getResults());

return success();
}

} // namespace

namespace {
@@ -1807,6 +2068,8 @@ class DiscConvertTorchToMhlo
INSERT_ATENOP_PATTERN(AtenFillScalarOp);
INSERT_ATENOP_PATTERN(OverwriteTensorContentsOp);
INSERT_ATENOP_PATTERN(AtenSliceScatterOp);
INSERT_ATENOP_PATTERN(AtenScatterSrcOp);
INSERT_ATENOP_PATTERN(AtenScatterAddOp);
#undef INSERT_ATENOP_PATTERN

#define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, MhloOp) \
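As an aside on the conversion patterns above: they form scatter_indices for mhlo::ScatterOp by concatenating, along a trailing index-vector dimension, an iota for every axis other than dim and the reshaped index tensor for dim itself, after first slicing src down to the shape of index. A rough NumPy sketch of that logic, using hypothetical helper names (build_scatter_indices, scatter_reference) that do not exist in the codebase:

import numpy as np

def build_scatter_indices(index, dim):
    # Result shape: index.shape + (rank,). Along the last axis, component j is
    # the positional coordinate for j != dim and the user-supplied index value
    # for j == dim, mirroring the IotaOp/DynamicReshapeOp concatenation above.
    rank = index.ndim
    parts = []
    for i in range(rank):
        if i == dim:
            parts.append(index[..., None])
        else:
            shape = [index.shape[i] if a == i else 1 for a in range(rank)]
            iota = np.arange(index.shape[i]).reshape(shape)
            parts.append(np.broadcast_to(iota, index.shape)[..., None])
    return np.concatenate(parts, axis=-1)

def scatter_reference(self_, dim, index, src, accumulate=False):
    # src is assumed to be pre-sliced to index.shape, mirroring the
    # RealDynamicSliceOp inserted by the patterns above; accumulate=False
    # corresponds to aten::scatter.src, accumulate=True to aten::scatter_add.
    out = self_.copy()
    coords = build_scatter_indices(index, dim)
    for pos in np.ndindex(*index.shape):
        target = tuple(coords[pos])
        if accumulate:
            out[target] += src[pos]
        else:
            out[target] = src[pos]
    return out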
87 changes: 87 additions & 0 deletions pytorch_blade/tests/disc/ops/test_scatter.py
@@ -0,0 +1,87 @@
# Copyright 2021 The BladeDISC Authors. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import unittest
from tests.disc.testing_base import DiscTestCase, skipTorchLE

@skipTorchLE("1.6.1")
class TestAtenScatter(DiscTestCase):

def test_scatter(self):
if self.device != torch.device('cuda'):
return

@torch.jit.script
def scatter_func(destination, place_at, source):
a = torch.scatter(destination, 0, place_at, source)
return a

destination = torch.tensor(
[
[4.0, 0.0, 3.0, 1.0, 0.0],
[0.0, 5.0, 8.0, 0.0, 0.0],
[6.0, 0.0, 0.0, 9.0, 0.0]
], dtype=torch.float32, device=self.device)

source = torch.tensor(
[
[0.3992, 0.2908, 0.9044, 0.4850, 0.6004],
[0.5735, 0.9006, 0.6797, 0.4152, 0.1732]
], dtype=torch.float32, device=self.device)

place_at = torch.tensor(
[
[0, 1, 2, 0],
[2, 0, 0, 1]
], dtype=torch.int64, device=self.device)

annotations = [(list(destination.shape), torch.float32), (list(
place_at.shape), torch.int64), (list(source.shape), torch.float32)]
self._test_disc(scatter_func, annotations,
(destination, place_at, source))

def test_scatteradd(self):
if self.device != torch.device('cuda'):
return

@torch.jit.script
def scatter_func(destination, place_at, source):
a = torch.scatter_add(destination, 0, place_at, source)
return a

destination = torch.tensor(
[
[4.0, 0.0, 3.0, 1.0, 0.0],
[0.0, 5.0, 8.0, 0.0, 0.0],
[6.0, 0.0, 0.0, 9.0, 0.0]
], dtype=torch.float32, device=self.device)

source = torch.tensor(
[
[0.3992, 0.2908, 0.9044, 0.4850, 0.6004],
[0.5735, 0.9006, 0.6797, 0.4152, 0.1732]
], dtype=torch.float32, device=self.device)

place_at = torch.tensor(
[
[0, 1, 2, 0],
[2, 0, 0, 1]
], dtype=torch.int64, device=self.device)

annotations = [(list(destination.shape), torch.float32), (list(
place_at.shape), torch.int64), (list(source.shape), torch.float32)]
self._test_disc(scatter_func, annotations,
(destination, place_at, source))


if __name__ == "__main__":
unittest.main()
