From 14e14a9ae8dd38a924c958ea49d11066f9f5be0c Mon Sep 17 00:00:00 2001 From: nihui Date: Wed, 25 Oct 2023 17:12:03 +0800 Subject: [PATCH] slice with indices (#5103) --- docs/developer-guide/operators.md | 1 + src/layer/arm/slice_arm.cpp | 324 ++++++++++-- src/layer/loongarch/slice_loongarch.cpp | 162 +++++- src/layer/mips/slice_mips.cpp | 162 +++++- src/layer/slice.cpp | 163 +++++- src/layer/slice.h | 1 + src/layer/vulkan/slice_vulkan.cpp | 464 +++++++++++++++--- src/layer/x86/slice_x86.cpp | 162 +++++- tests/test_slice.cpp | 41 +- .../pass_ncnn/convert_torch_tensor_split.cpp | 42 +- 10 files changed, 1319 insertions(+), 203 deletions(-) diff --git a/docs/developer-guide/operators.md b/docs/developer-guide/operators.md index 4cabb049340..11d41295f91 100644 --- a/docs/developer-guide/operators.md +++ b/docs/developer-guide/operators.md @@ -1722,6 +1722,7 @@ split x along axis into slices, each part slice size is based on slices array | --------- | ------------- | ----- | --------- | ----------------- | | 0 | slices | array | [ ] | | | 1 | axis | int | 0 | | +| 2 | indices | array | [ ] | | # Softmax ``` diff --git a/src/layer/arm/slice_arm.cpp b/src/layer/arm/slice_arm.cpp index 7d8c8bef763..a1c04e3f73d 100644 --- a/src/layer/arm/slice_arm.cpp +++ b/src/layer/arm/slice_arm.cpp @@ -51,6 +51,7 @@ int Slice_arm::forward(const std::vector& bottom_blobs, std::vector& t size_t elemsize = bottom_blob.elemsize; int elempack = bottom_blob.elempack; const int* slices_ptr = slices; + const int* indices_ptr = indices; int positive_axis = axis < 0 ? dims + axis : axis; if (dims == 1) // positive_axis == 0 @@ -60,10 +61,27 @@ int Slice_arm::forward(const std::vector& bottom_blobs, std::vector& t int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) { - slice = (w - q) / (top_blobs.size() - i); + if (i == top_blobs.size() - 1) + { + slice = w - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? w + indice : indice; + slice = positive_indice - q; + } + } + else + { + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((w - q) / (top_blobs.size() - i)); + } } int out_elempack = 1; @@ -97,10 +115,27 @@ int Slice_arm::forward(const std::vector& bottom_blobs, std::vector& t int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = h - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? h + indice : indice; + slice = positive_indice - q; + } + } + else { - slice = (h - q) / (top_blobs.size() - i); + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((h - q) / (top_blobs.size() - i)); + } } int out_elempack = 1; @@ -184,10 +219,27 @@ int Slice_arm::forward(const std::vector& bottom_blobs, std::vector& t int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = w - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? w + indice : indice; + slice = positive_indice - q; + } + } + else { - slice = (w - q) / (top_blobs.size() - i); + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((w - q) / (top_blobs.size() - i)); + } } Mat& top_blob = top_blobs[i]; @@ -225,10 +277,27 @@ int Slice_arm::forward(const std::vector& bottom_blobs, std::vector& t int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = channels - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? channels + indice : indice; + slice = positive_indice - q; + } + } + else { - slice = (channels - q) / (top_blobs.size() - i); + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((channels - q) / (top_blobs.size() - i)); + } } int out_elempack = 1; @@ -319,10 +388,27 @@ int Slice_arm::forward(const std::vector& bottom_blobs, std::vector& t int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = h - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? h + indice : indice; + slice = positive_indice - q; + } + } + else { - slice = (h - q) / (top_blobs.size() - i); + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((h - q) / (top_blobs.size() - i)); + } } Mat& top_blob = top_blobs[i]; @@ -368,10 +454,27 @@ int Slice_arm::forward(const std::vector& bottom_blobs, std::vector& t int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = w - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? w + indice : indice; + slice = positive_indice - q; + } + } + else { - slice = (w - q) / (top_blobs.size() - i); + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((w - q) / (top_blobs.size() - i)); + } } Mat& top_blob = top_blobs[i]; @@ -417,10 +520,27 @@ int Slice_arm::forward(const std::vector& bottom_blobs, std::vector& t int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = d - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? d + indice : indice; + slice = positive_indice - q; + } + } + else { - slice = (d - q) / (top_blobs.size() - i); + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((d - q) / (top_blobs.size() - i)); + } } Mat& top_blob = top_blobs[i]; @@ -460,6 +580,7 @@ int Slice_arm::forward_bf16s_fp16s(const std::vector& bottom_blobs, std::ve size_t elemsize = bottom_blob.elemsize; int elempack = bottom_blob.elempack; const int* slices_ptr = slices; + const int* indices_ptr = indices; int positive_axis = axis < 0 ? dims + axis : axis; if (dims == 1) // positive_axis == 0 @@ -469,10 +590,27 @@ int Slice_arm::forward_bf16s_fp16s(const std::vector& bottom_blobs, std::ve int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) { - slice = (w - q) / (top_blobs.size() - i); + if (i == top_blobs.size() - 1) + { + slice = w - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? w + indice : indice; + slice = positive_indice - q; + } + } + else + { + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((w - q) / (top_blobs.size() - i)); + } } int out_elempack = 1; @@ -510,10 +648,27 @@ int Slice_arm::forward_bf16s_fp16s(const std::vector& bottom_blobs, std::ve int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = h - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? h + indice : indice; + slice = positive_indice - q; + } + } + else { - slice = (h - q) / (top_blobs.size() - i); + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((h - q) / (top_blobs.size() - i)); + } } int out_elempack = 1; @@ -663,10 +818,27 @@ int Slice_arm::forward_bf16s_fp16s(const std::vector& bottom_blobs, std::ve int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = w - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? w + indice : indice; + slice = positive_indice - q; + } + } + else { - slice = (w - q) / (top_blobs.size() - i); + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((w - q) / (top_blobs.size() - i)); + } } Mat& top_blob = top_blobs[i]; @@ -704,10 +876,27 @@ int Slice_arm::forward_bf16s_fp16s(const std::vector& bottom_blobs, std::ve int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = channels - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? channels + indice : indice; + slice = positive_indice - q; + } + } + else { - slice = (channels - q) / (top_blobs.size() - i); + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((channels - q) / (top_blobs.size() - i)); + } } int out_elempack = 1; @@ -868,10 +1057,27 @@ int Slice_arm::forward_bf16s_fp16s(const std::vector& bottom_blobs, std::ve int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = h - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? h + indice : indice; + slice = positive_indice - q; + } + } + else { - slice = (h - q) / (top_blobs.size() - i); + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((h - q) / (top_blobs.size() - i)); + } } Mat& top_blob = top_blobs[i]; @@ -917,10 +1123,27 @@ int Slice_arm::forward_bf16s_fp16s(const std::vector& bottom_blobs, std::ve int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = w - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? w + indice : indice; + slice = positive_indice - q; + } + } + else { - slice = (w - q) / (top_blobs.size() - i); + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((w - q) / (top_blobs.size() - i)); + } } Mat& top_blob = top_blobs[i]; @@ -966,10 +1189,27 @@ int Slice_arm::forward_bf16s_fp16s(const std::vector& bottom_blobs, std::ve int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = d - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? d + indice : indice; + slice = positive_indice - q; + } + } + else { - slice = (d - q) / (top_blobs.size() - i); + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((d - q) / (top_blobs.size() - i)); + } } Mat& top_blob = top_blobs[i]; diff --git a/src/layer/loongarch/slice_loongarch.cpp b/src/layer/loongarch/slice_loongarch.cpp index 7fceb481231..2da903253f7 100644 --- a/src/layer/loongarch/slice_loongarch.cpp +++ b/src/layer/loongarch/slice_loongarch.cpp @@ -30,6 +30,7 @@ int Slice_loongarch::forward(const std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector((w - q) / (top_blobs.size() - i)); + } } int out_elempack = 1; @@ -74,10 +92,27 @@ int Slice_loongarch::forward(const std::vector& bottom_blobs, std::vector((h - q) / (top_blobs.size() - i)); + } } int out_elempack = 1; @@ -159,10 +194,27 @@ int Slice_loongarch::forward(const std::vector& bottom_blobs, std::vector((w - q) / (top_blobs.size() - i)); + } } Mat& top_blob = top_blobs[i]; @@ -200,10 +252,27 @@ int Slice_loongarch::forward(const std::vector& bottom_blobs, std::vector((channels - q) / (top_blobs.size() - i)); + } } int out_elempack = 1; @@ -292,10 +361,27 @@ int Slice_loongarch::forward(const std::vector& bottom_blobs, std::vector((h - q) / (top_blobs.size() - i)); + } } Mat& top_blob = top_blobs[i]; @@ -341,10 +427,27 @@ int Slice_loongarch::forward(const std::vector& bottom_blobs, std::vector((w - q) / (top_blobs.size() - i)); + } } Mat& top_blob = top_blobs[i]; @@ -390,10 +493,27 @@ int Slice_loongarch::forward(const std::vector& bottom_blobs, std::vector((d - q) / (top_blobs.size() - i)); + } } Mat& top_blob = top_blobs[i]; diff --git a/src/layer/mips/slice_mips.cpp b/src/layer/mips/slice_mips.cpp index 799371a1f9b..dead2610bdb 100644 --- a/src/layer/mips/slice_mips.cpp +++ b/src/layer/mips/slice_mips.cpp @@ -30,6 +30,7 @@ int Slice_mips::forward(const std::vector& bottom_blobs, std::vector& size_t elemsize = bottom_blob.elemsize; int elempack = bottom_blob.elempack; const int* slices_ptr = slices; + const int* indices_ptr = indices; int positive_axis = axis < 0 ? dims + axis : axis; if (dims == 1) // positive_axis == 0 @@ -39,10 +40,27 @@ int Slice_mips::forward(const std::vector& bottom_blobs, std::vector& int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) { - slice = (w - q) / (top_blobs.size() - i); + if (i == top_blobs.size() - 1) + { + slice = w - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? w + indice : indice; + slice = positive_indice - q; + } + } + else + { + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((w - q) / (top_blobs.size() - i)); + } } int out_elempack = 1; @@ -74,10 +92,27 @@ int Slice_mips::forward(const std::vector& bottom_blobs, std::vector& int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = h - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? h + indice : indice; + slice = positive_indice - q; + } + } + else { - slice = (h - q) / (top_blobs.size() - i); + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((h - q) / (top_blobs.size() - i)); + } } int out_elempack = 1; @@ -159,10 +194,27 @@ int Slice_mips::forward(const std::vector& bottom_blobs, std::vector& int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = w - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? w + indice : indice; + slice = positive_indice - q; + } + } + else { - slice = (w - q) / (top_blobs.size() - i); + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((w - q) / (top_blobs.size() - i)); + } } Mat& top_blob = top_blobs[i]; @@ -200,10 +252,27 @@ int Slice_mips::forward(const std::vector& bottom_blobs, std::vector& int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = channels - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? channels + indice : indice; + slice = positive_indice - q; + } + } + else { - slice = (channels - q) / (top_blobs.size() - i); + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((channels - q) / (top_blobs.size() - i)); + } } int out_elempack = 1; @@ -292,10 +361,27 @@ int Slice_mips::forward(const std::vector& bottom_blobs, std::vector& int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = h - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? h + indice : indice; + slice = positive_indice - q; + } + } + else { - slice = (h - q) / (top_blobs.size() - i); + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((h - q) / (top_blobs.size() - i)); + } } Mat& top_blob = top_blobs[i]; @@ -341,10 +427,27 @@ int Slice_mips::forward(const std::vector& bottom_blobs, std::vector& int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = w - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? w + indice : indice; + slice = positive_indice - q; + } + } + else { - slice = (w - q) / (top_blobs.size() - i); + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((w - q) / (top_blobs.size() - i)); + } } Mat& top_blob = top_blobs[i]; @@ -390,10 +493,27 @@ int Slice_mips::forward(const std::vector& bottom_blobs, std::vector& int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = d - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? d + indice : indice; + slice = positive_indice - q; + } + } + else { - slice = (d - q) / (top_blobs.size() - i); + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((d - q) / (top_blobs.size() - i)); + } } Mat& top_blob = top_blobs[i]; diff --git a/src/layer/slice.cpp b/src/layer/slice.cpp index ff0f6fc67de..cf7a2a37622 100644 --- a/src/layer/slice.cpp +++ b/src/layer/slice.cpp @@ -24,6 +24,7 @@ int Slice::load_param(const ParamDict& pd) { slices = pd.get(0, Mat()); axis = pd.get(1, 0); + indices = pd.get(2, Mat()); return 0; } @@ -34,6 +35,7 @@ int Slice::forward(const std::vector& bottom_blobs, std::vector& top_b int dims = bottom_blob.dims; size_t elemsize = bottom_blob.elemsize; const int* slices_ptr = slices; + const int* indices_ptr = indices; int positive_axis = axis < 0 ? dims + axis : axis; if (dims == 1) // positive_axis == 0 @@ -43,10 +45,27 @@ int Slice::forward(const std::vector& bottom_blobs, std::vector& top_b int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) { - slice = static_cast((w - q) / (top_blobs.size() - i)); + if (i == top_blobs.size() - 1) + { + slice = w - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? w + indice : indice; + slice = positive_indice - q; + } + } + else + { + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((w - q) / (top_blobs.size() - i)); + } } Mat& top_blob = top_blobs[i]; @@ -72,10 +91,27 @@ int Slice::forward(const std::vector& bottom_blobs, std::vector& top_b int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = h - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? h + indice : indice; + slice = positive_indice - q; + } + } + else { - slice = static_cast((h - q) / (top_blobs.size() - i)); + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((h - q) / (top_blobs.size() - i)); + } } Mat& top_blob = top_blobs[i]; @@ -103,10 +139,27 @@ int Slice::forward(const std::vector& bottom_blobs, std::vector& top_b int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = w - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? w + indice : indice; + slice = positive_indice - q; + } + } + else { - slice = static_cast((w - q) / (top_blobs.size() - i)); + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((w - q) / (top_blobs.size() - i)); + } } Mat& top_blob = top_blobs[i]; @@ -138,10 +191,27 @@ int Slice::forward(const std::vector& bottom_blobs, std::vector& top_b int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = channels - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? channels + indice : indice; + slice = positive_indice - q; + } + } + else { - slice = static_cast((channels - q) / (top_blobs.size() - i)); + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((channels - q) / (top_blobs.size() - i)); + } } Mat& top_blob = top_blobs[i]; @@ -173,10 +243,27 @@ int Slice::forward(const std::vector& bottom_blobs, std::vector& top_b int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = h - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? h + indice : indice; + slice = positive_indice - q; + } + } + else { - slice = static_cast((h - q) / (top_blobs.size() - i)); + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((h - q) / (top_blobs.size() - i)); + } } Mat& top_blob = top_blobs[i]; @@ -215,10 +302,27 @@ int Slice::forward(const std::vector& bottom_blobs, std::vector& top_b int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = w - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? w + indice : indice; + slice = positive_indice - q; + } + } + else { - slice = static_cast((w - q) / (top_blobs.size() - i)); + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((w - q) / (top_blobs.size() - i)); + } } Mat& top_blob = top_blobs[i]; @@ -262,10 +366,27 @@ int Slice::forward(const std::vector& bottom_blobs, std::vector& top_b int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = d - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? d + indice : indice; + slice = positive_indice - q; + } + } + else { - slice = static_cast((d - q) / (top_blobs.size() - i)); + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((d - q) / (top_blobs.size() - i)); + } } Mat& top_blob = top_blobs[i]; diff --git a/src/layer/slice.h b/src/layer/slice.h index 4da3c803697..c3bcacb95e1 100644 --- a/src/layer/slice.h +++ b/src/layer/slice.h @@ -30,6 +30,7 @@ class Slice : public Layer public: Mat slices; + Mat indices; int axis; }; diff --git a/src/layer/vulkan/slice_vulkan.cpp b/src/layer/vulkan/slice_vulkan.cpp index 47caa86877d..a44ce35520f 100644 --- a/src/layer/vulkan/slice_vulkan.cpp +++ b/src/layer/vulkan/slice_vulkan.cpp @@ -238,6 +238,7 @@ int Slice_vulkan::forward(const std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector((w - q) / (top_blobs.size() - i)); + } } int out_elempack = opt.use_shader_pack8 && slice % 8 == 0 ? 8 : slice % 4 == 0 ? 4 : 1; @@ -350,10 +368,27 @@ int Slice_vulkan::forward(const std::vector& bottom_blobs, std::vector((h - q) / (top_blobs.size() - i)); + } } int out_elempack = opt.use_shader_pack8 && slice % 8 == 0 ? 8 : slice % 4 == 0 ? 4 : 1; @@ -453,10 +488,27 @@ int Slice_vulkan::forward(const std::vector& bottom_blobs, std::vector((w - q) / (top_blobs.size() - i)); + } } VkMat& top_blob = top_blobs[i]; @@ -513,10 +565,27 @@ int Slice_vulkan::forward(const std::vector& bottom_blobs, std::vector((channels - q) / (top_blobs.size() - i)); + } } int out_elempack = opt.use_shader_pack8 && slice % 8 == 0 ? 8 : slice % 4 == 0 ? 4 : 1; @@ -617,10 +686,27 @@ int Slice_vulkan::forward(const std::vector& bottom_blobs, std::vector((h - q) / (top_blobs.size() - i)); + } } VkMat& top_blob = top_blobs[i]; @@ -677,10 +763,27 @@ int Slice_vulkan::forward(const std::vector& bottom_blobs, std::vector((w - q) / (top_blobs.size() - i)); + } } VkMat& top_blob = top_blobs[i]; @@ -737,10 +840,27 @@ int Slice_vulkan::forward(const std::vector& bottom_blobs, std::vector((channels - q) / (top_blobs.size() - i)); + } } int out_elempack = opt.use_shader_pack8 && slice % 8 == 0 ? 8 : slice % 4 == 0 ? 4 : 1; @@ -841,10 +961,27 @@ int Slice_vulkan::forward(const std::vector& bottom_blobs, std::vector((d - q) / (top_blobs.size() - i)); + } } VkMat& top_blob = top_blobs[i]; @@ -901,10 +1038,27 @@ int Slice_vulkan::forward(const std::vector& bottom_blobs, std::vector((h - q) / (top_blobs.size() - i)); + } } VkMat& top_blob = top_blobs[i]; @@ -961,10 +1115,27 @@ int Slice_vulkan::forward(const std::vector& bottom_blobs, std::vector((w - q) / (top_blobs.size() - i)); + } } VkMat& top_blob = top_blobs[i]; @@ -1021,6 +1192,7 @@ int Slice_vulkan::forward(const std::vector& bottom_blobs, std::vect size_t elemsize = bottom_blob.elemsize; int elempack = bottom_blob.elempack; const int* slices_ptr = slices; + const int* indices_ptr = indices; int positive_axis = axis < 0 ? dims + axis : axis; if (dims == 1) // positive_axis == 0 @@ -1030,10 +1202,27 @@ int Slice_vulkan::forward(const std::vector& bottom_blobs, std::vect int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) { - slice = (w - q) / (top_blobs.size() - i); + if (i == top_blobs.size() - 1) + { + slice = w - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? w + indice : indice; + slice = positive_indice - q; + } + } + else + { + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((w - q) / (top_blobs.size() - i)); + } } int out_elempack = opt.use_shader_pack8 && slice % 8 == 0 ? 8 : slice % 4 == 0 ? 4 : 1; @@ -1133,10 +1322,27 @@ int Slice_vulkan::forward(const std::vector& bottom_blobs, std::vect int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) { - slice = (h - q) / (top_blobs.size() - i); + if (i == top_blobs.size() - 1) + { + slice = h - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? h + indice : indice; + slice = positive_indice - q; + } + } + else + { + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((h - q) / (top_blobs.size() - i)); + } } int out_elempack = opt.use_shader_pack8 && slice % 8 == 0 ? 8 : slice % 4 == 0 ? 4 : 1; @@ -1236,10 +1442,27 @@ int Slice_vulkan::forward(const std::vector& bottom_blobs, std::vect int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = w - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? w + indice : indice; + slice = positive_indice - q; + } + } + else { - slice = (w - q) / (top_blobs.size() - i); + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((w - q) / (top_blobs.size() - i)); + } } VkImageMat& top_blob = top_blobs[i]; @@ -1296,10 +1519,27 @@ int Slice_vulkan::forward(const std::vector& bottom_blobs, std::vect int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = channels - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? channels + indice : indice; + slice = positive_indice - q; + } + } + else { - slice = (channels - q) / (top_blobs.size() - i); + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((channels - q) / (top_blobs.size() - i)); + } } int out_elempack = opt.use_shader_pack8 && slice % 8 == 0 ? 8 : slice % 4 == 0 ? 4 : 1; @@ -1400,10 +1640,27 @@ int Slice_vulkan::forward(const std::vector& bottom_blobs, std::vect int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = h - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? h + indice : indice; + slice = positive_indice - q; + } + } + else { - slice = (h - q) / (top_blobs.size() - i); + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((h - q) / (top_blobs.size() - i)); + } } VkImageMat& top_blob = top_blobs[i]; @@ -1460,10 +1717,27 @@ int Slice_vulkan::forward(const std::vector& bottom_blobs, std::vect int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) { - slice = (w - q) / (top_blobs.size() - i); + if (i == top_blobs.size() - 1) + { + slice = w - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? w + indice : indice; + slice = positive_indice - q; + } + } + else + { + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((w - q) / (top_blobs.size() - i)); + } } VkImageMat& top_blob = top_blobs[i]; @@ -1520,10 +1794,27 @@ int Slice_vulkan::forward(const std::vector& bottom_blobs, std::vect int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) { - slice = (channels - q) / (top_blobs.size() - i); + if (i == top_blobs.size() - 1) + { + slice = channels - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? channels + indice : indice; + slice = positive_indice - q; + } + } + else + { + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((channels - q) / (top_blobs.size() - i)); + } } int out_elempack = opt.use_shader_pack8 && slice % 8 == 0 ? 8 : slice % 4 == 0 ? 4 : 1; @@ -1624,10 +1915,27 @@ int Slice_vulkan::forward(const std::vector& bottom_blobs, std::vect int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) { - slice = (d - q) / (top_blobs.size() - i); + if (i == top_blobs.size() - 1) + { + slice = d - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? d + indice : indice; + slice = positive_indice - q; + } + } + else + { + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((d - q) / (top_blobs.size() - i)); + } } VkImageMat& top_blob = top_blobs[i]; @@ -1684,10 +1992,27 @@ int Slice_vulkan::forward(const std::vector& bottom_blobs, std::vect int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = h - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? h + indice : indice; + slice = positive_indice - q; + } + } + else { - slice = (h - q) / (top_blobs.size() - i); + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((h - q) / (top_blobs.size() - i)); + } } VkImageMat& top_blob = top_blobs[i]; @@ -1744,10 +2069,27 @@ int Slice_vulkan::forward(const std::vector& bottom_blobs, std::vect int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) - { - slice = (w - q) / (top_blobs.size() - i); + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = w - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? w + indice : indice; + slice = positive_indice - q; + } + } + else + { + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((w - q) / (top_blobs.size() - i)); + } } VkImageMat& top_blob = top_blobs[i]; diff --git a/src/layer/x86/slice_x86.cpp b/src/layer/x86/slice_x86.cpp index 8b7ec0632e9..14764f0f030 100644 --- a/src/layer/x86/slice_x86.cpp +++ b/src/layer/x86/slice_x86.cpp @@ -37,6 +37,7 @@ int Slice_x86::forward(const std::vector& bottom_blobs, std::vector& t size_t elemsize = bottom_blob.elemsize; int elempack = bottom_blob.elempack; const int* slices_ptr = slices; + const int* indices_ptr = indices; int positive_axis = axis < 0 ? dims + axis : axis; if (dims == 1) // positive_axis == 0 @@ -46,10 +47,27 @@ int Slice_x86::forward(const std::vector& bottom_blobs, std::vector& t int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) { - slice = (w - q) / (top_blobs.size() - i); + if (i == top_blobs.size() - 1) + { + slice = w - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? w + indice : indice; + slice = positive_indice - q; + } + } + else + { + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((w - q) / (top_blobs.size() - i)); + } } int out_elempack = 1; @@ -89,10 +107,27 @@ int Slice_x86::forward(const std::vector& bottom_blobs, std::vector& t int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = h - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? h + indice : indice; + slice = positive_indice - q; + } + } + else { - slice = (h - q) / (top_blobs.size() - i); + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((h - q) / (top_blobs.size() - i)); + } } int out_elempack = 1; @@ -373,10 +408,27 @@ int Slice_x86::forward(const std::vector& bottom_blobs, std::vector& t int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = w - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? w + indice : indice; + slice = positive_indice - q; + } + } + else { - slice = (w - q) / (top_blobs.size() - i); + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((w - q) / (top_blobs.size() - i)); + } } Mat& top_blob = top_blobs[i]; @@ -414,10 +466,27 @@ int Slice_x86::forward(const std::vector& bottom_blobs, std::vector& t int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = channels - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? channels + indice : indice; + slice = positive_indice - q; + } + } + else { - slice = (channels - q) / (top_blobs.size() - i); + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((channels - q) / (top_blobs.size() - i)); + } } int out_elempack = 1; @@ -715,10 +784,27 @@ int Slice_x86::forward(const std::vector& bottom_blobs, std::vector& t int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = h - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? h + indice : indice; + slice = positive_indice - q; + } + } + else { - slice = (h - q) / (top_blobs.size() - i); + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((h - q) / (top_blobs.size() - i)); + } } Mat& top_blob = top_blobs[i]; @@ -764,10 +850,27 @@ int Slice_x86::forward(const std::vector& bottom_blobs, std::vector& t int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = w - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? w + indice : indice; + slice = positive_indice - q; + } + } + else { - slice = (w - q) / (top_blobs.size() - i); + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((w - q) / (top_blobs.size() - i)); + } } Mat& top_blob = top_blobs[i]; @@ -813,10 +916,27 @@ int Slice_x86::forward(const std::vector& bottom_blobs, std::vector& t int q = 0; for (size_t i = 0; i < top_blobs.size(); i++) { - int slice = slices_ptr[i]; - if (slice == -233) + int slice; + if (indices_ptr) + { + if (i == top_blobs.size() - 1) + { + slice = d - q; + } + else + { + int indice = indices_ptr[i]; + int positive_indice = indice < 0 ? d + indice : indice; + slice = positive_indice - q; + } + } + else { - slice = (d - q) / (top_blobs.size() - i); + slice = slices_ptr[i]; + if (slice == -233) + { + slice = static_cast((d - q) / (top_blobs.size() - i)); + } } Mat& top_blob = top_blobs[i]; diff --git a/tests/test_slice.cpp b/tests/test_slice.cpp index 852f6f144f3..59cf10e8d68 100644 --- a/tests/test_slice.cpp +++ b/tests/test_slice.cpp @@ -88,6 +88,29 @@ static int test_slice(const ncnn::Mat& a, const ncnn::Mat& slices, int axis) return ret; } +static int test_slice_indices(const ncnn::Mat& a, const ncnn::Mat& indices, int axis) +{ + ncnn::ParamDict pd; + pd.set(1, axis); + pd.set(2, indices); + + std::vector weights(0); + + std::vector a0(1); + a0[0] = a; + + int ret = test_layer("Slice", pd, weights, a0, indices.w); + if (ret != 0) + { + fprintf(stderr, "test_slice_indices failed a.dims=%d a=(%d %d %d %d)", a.dims, a.w, a.h, a.d, a.c); + fprintf(stderr, " indices="); + print_int_array(indices); + fprintf(stderr, " axis=%d\n", axis); + } + + return ret; +} + static int test_slice_0() { ncnn::Mat a[] = { @@ -108,7 +131,11 @@ static int test_slice_0() || test_slice(a[i], IntArrayMat(32, 8, -233), 0) || test_slice(a[i], IntArrayMat(2, 12, 16, -233), 1) || test_slice(a[i], IntArrayMat(16, 4, 5, -233), -2) - || test_slice(a[i], IntArrayMat(8, 2, 16, -233), 3); + || test_slice(a[i], IntArrayMat(8, 2, 16, -233), 3) + || test_slice_indices(a[i], IntArrayMat(2, -24, -8), 0) + || test_slice_indices(a[i], IntArrayMat(4, 20, 4), 1) + || test_slice_indices(a[i], IntArrayMat(16, -16), -2) + || test_slice_indices(a[i], IntArrayMat(1, -12), 3); if (ret != 0) return ret; @@ -135,7 +162,10 @@ static int test_slice_1() || test_slice(a[i], IntArrayMat(12, 16, -233), 0) || test_slice(a[i], IntArrayMat(32, 8, -233), 0) || test_slice(a[i], IntArrayMat(2, 12, 16, -233), 1) - || test_slice(a[i], IntArrayMat(16, 4, 5, -233), -1); + || test_slice(a[i], IntArrayMat(16, 4, 5, -233), -1) + || test_slice_indices(a[i], IntArrayMat(2, -24, -8), 0) + || test_slice_indices(a[i], IntArrayMat(4, 20, 4), 1) + || test_slice_indices(a[i], IntArrayMat(1, -12), 2); if (ret != 0) return ret; @@ -160,7 +190,9 @@ static int test_slice_2() || test_slice(a[i], IntArrayMat(3, 12, 16, -233), 0) || test_slice(a[i], IntArrayMat(12, 16, -233), 0) || test_slice(a[i], IntArrayMat(32, 8, -233), -2) - || test_slice(a[i], IntArrayMat(2, 12, 16, -233), -1); + || test_slice(a[i], IntArrayMat(2, 12, 16, -233), -1) + || test_slice_indices(a[i], IntArrayMat(2, -24, -8), 0) + || test_slice_indices(a[i], IntArrayMat(1, -12), 1); if (ret != 0) return ret; @@ -183,7 +215,8 @@ static int test_slice_3() || test_slice(a[i], IntArrayMat(-233, -233, -233), 0) || test_slice(a[i], IntArrayMat(3, 12, 16, -233), 0) || test_slice(a[i], IntArrayMat(12, 16, -233), 0) - || test_slice(a[i], IntArrayMat(32, 8, -233), -1); + || test_slice(a[i], IntArrayMat(32, 8, -233), -1) + || test_slice_indices(a[i], IntArrayMat(2, -24, -8), 0); if (ret != 0) return ret; diff --git a/tools/pnnx/src/pass_ncnn/convert_torch_tensor_split.cpp b/tools/pnnx/src/pass_ncnn/convert_torch_tensor_split.cpp index 989104caa87..04fd7bfd29e 100644 --- a/tools/pnnx/src/pass_ncnn/convert_torch_tensor_split.cpp +++ b/tools/pnnx/src/pass_ncnn/convert_torch_tensor_split.cpp @@ -67,22 +67,40 @@ void convert_torch_tensor_split(Graph& graph) { const std::vector& indices = op->params.at("indices").ai; - op->params["0"].type = 5; - op->params["0"].ai.resize(indices.size() + 1); - - for (size_t i = 0; i < indices.size() + 1; i++) + bool has_negative_indice = false; + for (auto x : indices) { - if (i == 0) + if (x < 0) { - op->params["0"].ai[i] = indices[i]; + // negative indice + has_negative_indice = true; + break; } - else if (i == indices.size()) - { - op->params["0"].ai[i] = -233; - } - else + } + + if (has_negative_indice) + { + op->params["2"] = indices; + } + else + { + op->params["0"].type = 5; + op->params["0"].ai.resize(indices.size() + 1); + + for (size_t i = 0; i < indices.size() + 1; i++) { - op->params["0"].ai[i] = indices[i] - indices[i - 1]; + if (i == 0) + { + op->params["0"].ai[i] = indices[i]; + } + else if (i == indices.size()) + { + op->params["0"].ai[i] = -233; + } + else + { + op->params["0"].ai[i] = indices[i] - indices[i - 1]; + } } }