Commit 21ad890

[Torch] enhance fold of aten.slice.Tensor (llvm#3557)
so that it can fold slices with any static shape.
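For illustration, a minimal sketch of the kind of IR this change affects (the function name @slice_fold_sketch and the constant values are made up here, but the pattern mirrors the regression test added in this commit): the old fold required every non-sliced output dimension to be 1, so a slice along dim 1 of a [4,2] constant was left in place, whereas it can now fold to a torch.vtensor.literal.

// Hypothetical input IR: slicing column 1 out of a constant 4x2 tensor.
func.func @slice_fold_sketch() -> !torch.vtensor<[4,1],si64> {
  %int1 = torch.constant.int 1
  %int2 = torch.constant.int 2
  %cst = torch.vtensor.literal(dense<[[1, 2], [3, 4], [5, 6], [7, 8]]> : tensor<4x2xsi64>) : !torch.vtensor<[4,2],si64>
  %0 = torch.aten.slice.Tensor %cst, %int1, %int1, %int2, %int1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,1],si64>
  return %0 : !torch.vtensor<[4,1],si64>
}
// After canonicalization this is expected to become a single literal:
//   torch.vtensor.literal(dense<[[2], [4], [6], [8]]> : tensor<4x1xsi64>)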
Parent commit: 7884642

2 files changed: +50 -24 lines
lib/Dialect/Torch/IR/TorchOps.cpp

+29 -18
@@ -3625,42 +3625,53 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
     return DenseElementsAttr::get(outType.toBuiltinTensor(),
                                   input.getSplatValue<Attribute>());
 
-  int count = 1;
+  int64_t count = 1;
   for (auto dim : outType.getSizes())
     count = count * dim;
-
   if (count == 0)
-    return {};
+    return nullptr;
 
   if (!dim)
     return nullptr;
   int64_t dimInt = dim.getValue().getSExtValue();
   if (dimInt < 0)
     dimInt += inType.getSizes().size();
 
-  bool unaryNonDim = true;
-  for (int i = 0, s = outType.getSizes().size(); i < s; ++i)
-    unaryNonDim &= outType.getSizes()[i] == 1 || i == dimInt;
-
   // Fold the slice if the output tensor is relatively small, currently
   // coded to 16:
-  if (input && start && step && dim && count < 16 && unaryNonDim &&
-      count < 16) {
-    int64_t inCount = input.getNumElements();
+  constexpr int64_t kMaxFold = 16;
+  if (input && start && step && dim && count <= kMaxFold) {
     int64_t begin = start.getValue().getSExtValue();
+    int64_t limit = end.getValue().getSExtValue();
     int64_t stride = step.getValue().getSExtValue();
     if (stride < 1)
-      return {};
-    int64_t limit = end.getValue().getSExtValue();
-    begin = begin < 0 ? begin + inCount : begin;
-    limit = limit < 0 ? limit + inCount : limit;
-    limit = limit < 0 ? inType.getSizes()[dimInt] : limit;
+      return nullptr;
+    begin = begin < 0 ? begin + inType.getSizes()[dimInt] : begin;
+    limit = limit < 0 ? limit + inType.getSizes()[dimInt] : limit;
     limit = std::min(limit, inType.getSizes()[dimInt]);
 
-    llvm::SmallVector<Attribute> values;
-    for (int i = begin; i < limit; i += stride)
-      values.push_back(input.getValues<Attribute>()[i]);
+    int64_t inputRank = inType.getSizes().size();
+    llvm::SmallVector<int64_t> inputStrides(inputRank, 1);
+    for (int64_t i = inputRank - 2; i >= 0; i--) {
+      inputStrides[i] = inputStrides[i + 1] * inType.getSizes()[i + 1];
+    }
 
+    llvm::SmallVector<Attribute> values;
+    values.reserve(count);
+    auto recursiveIter = [&](auto &self, int64_t currDim, int64_t currOffset) {
+      if (currDim >= inputRank)
+        return;
+      size_t _begin = (currDim == dimInt) ? begin : 0;
+      size_t _limit = (currDim == dimInt) ? limit : inType.getSizes()[currDim];
+      size_t _stride = (currDim == dimInt) ? stride : 1;
+      for (size_t i = _begin; i < _limit; i += _stride) {
+        if (currDim == inputRank - 1) {
+          values.push_back(input.getValues<Attribute>()[currOffset + i]);
+        }
+        self(self, currDim + 1, currOffset + inputStrides[currDim] * i);
+      }
+    };
+    recursiveIter(recursiveIter, 0, 0);
     return DenseElementsAttr::get(outType.toBuiltinTensor(), values);
   }
 
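A quick worked example of the new offset arithmetic, using the values from the non-contiguous test added below: for an input of static shape [4, 2], inputStrides becomes [2, 1]. Slicing dim 1 with begin = 1, limit = 2, stride = 1, the recursion walks rows i = 0..3 of dim 0 at offsets 0, 2, 4, 6 and, at the innermost dim, pushes the flat elements 1, 3, 5, 7 (the second column), which is exactly the [[28], [14], [7], [4]] literal the new test expects.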
test/Dialect/Torch/canonicalize.mlir

+21 -6
@@ -2139,15 +2139,15 @@ func.func @torch.aten.broadcast_to$fold_splat() -> !torch.vtensor<[3,4,2],f32> {
 
 // -----
 
-// CHECK-LABEL: @torch.aten.slice.tensor$fold_full_domain_slice
+// CHECK-LABEL: @torch.aten.slice.tensor$not_fold_slice
 // CHECK-SAME: %[[ARG0:.+]]: !torch.vtensor<[4],f32>
-// CHECK: return %[[ARG0]] : !torch.vtensor<[4],f32>
-func.func @torch.aten.slice.tensor$fold_full_domain_slice(%arg0: !torch.vtensor<[4],f32>) -> !torch.vtensor<[4],f32> {
+// CHECK: torch.aten.slice.Tensor
+func.func @torch.aten.slice.tensor$not_fold_slice(%arg0: !torch.vtensor<[4],f32>) -> !torch.vtensor<[3],f32> {
   %int1 = torch.constant.int 1
   %int-1 = torch.constant.int -1
   %int0 = torch.constant.int 0
-  %0 = torch.aten.slice.Tensor %arg0, %int0, %int0, %int-1, %int1 : !torch.vtensor<[4], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4], f32>
-  return %0 : !torch.vtensor<[4],f32>
+  %0 = torch.aten.slice.Tensor %arg0, %int0, %int0, %int-1, %int1 : !torch.vtensor<[4], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3], f32>
+  return %0 : !torch.vtensor<[3],f32>
 }
 
 // CHECK-LABEL: @torch.aten.slice.tensor$fold_full_slice
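A brief note on the renamed test above (inferred from the diff): with the updated negative-end handling, end = -1 on a size-4 dimension resolves to limit 3, so this slice now produces a [3] result rather than covering the full domain; and since %arg0 is not a constant literal there is nothing to fold into, so the CHECK now expects the torch.aten.slice.Tensor op to survive canonicalization.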
@@ -2209,7 +2209,10 @@ func.func @torch.aten.slice.tensor$fold_small() -> (!torch.vtensor<[2],si32>) {
 }
 
 // -----
-
+// CHECK-LABEL: func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>) {
+// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<1.600000e+01> : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32>
+// CHECK: %[[CST0:.+]] = torch.vtensor.literal(dense<6.400000e+01> : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32>
+// CHECK: return %[[CST]], %[[CST0]]
 func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1, 1],f32>, !torch.vtensor<[1, 1],f32>) {
   %tensor = torch.vtensor.literal(dense<[[2.0],[4.0],[8.0],[16.0],[32.0],[64.0],[128.0],[256.0],[512.0],[1024.0]]> : tensor<10x1xf32>) : !torch.vtensor<[10, 1],f32>
   %int0 = torch.constant.int 0
@@ -2224,6 +2227,18 @@ func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1, 1],f32>,
   return %0, %1 : !torch.vtensor<[1, 1],f32>, !torch.vtensor<[1, 1], f32>
 }
 
+// -----
+// CHECK-LABEL: func.func @torch.aten.slice.tensor$fold_dim_0_non_contiguous() -> !torch.vtensor<[4,1],si64> {
+// CHECK{LITERAL}: %0 = torch.vtensor.literal(dense<[[28], [14], [7], [4]]> : tensor<4x1xsi64>) : !torch.vtensor<[4,1],si64>
+// CHECK: return %0
+func.func @torch.aten.slice.tensor$fold_dim_0_non_contiguous() -> (!torch.vtensor<[4,1],si64>) {
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %0 = torch.vtensor.literal(dense<[[28, 28], [14, 14], [7, 7], [4, 4]]> : tensor<4x2xsi64>) : !torch.vtensor<[4,2],si64>
+  %1 = torch.aten.slice.Tensor %0, %int1, %int1, %int2, %int1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,1],si64>
+  return %1 : !torch.vtensor<[4,1],si64>
+}
+
 // -----
 
 // CHECK-LABEL: func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
