From 5ba56f30677f06df08b9d96ff10adaf33279d499 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=B0=E9=98=85?= <43716063+Baiyuetribe@users.noreply.github.com> Date: Mon, 13 Jan 2025 23:39:12 +0800 Subject: [PATCH] add 4d dch --- src/layer/flip.cpp | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/src/layer/flip.cpp b/src/layer/flip.cpp index 7c571ea7e2e..dbb278a8955 100644 --- a/src/layer/flip.cpp +++ b/src/layer/flip.cpp @@ -458,7 +458,7 @@ int Flip::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons } else if (axis.w == 3) { - // dch3、dcw4、chw6 + // dch3、dcw4、dhw5,chw6 int axis0 = axis_ptr[0] < 0 ? 4 + axis_ptr[0] : axis_ptr[0]; int axis1 = axis_ptr[1] < 0 ? 4 + axis_ptr[1] : axis_ptr[1]; int axis2 = axis_ptr[2] < 0 ? 4 + axis_ptr[2] : axis_ptr[2]; @@ -510,6 +510,29 @@ int Flip::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons } } } + else if (axis_sum == 5) + { + // 对应dhw,除了d外全翻转 + for (int c = 0; c < channels; c++) + { + int flipped_c = channels - 1 - c; // 翻转c维度 + + for (int z = 0; z < d; z++) // d维度保持不变 + { + for (int i = 0; i < h; i++) + { + const float* ptr = bottom_blob.channel(c).depth(z).row(i); + float* outptr = const_cast(top_blob.channel(flipped_c).depth(z).row(h - 1 - i)); // 翻转h维度 + + // 翻转w维度 + for (int k = 0; k < w; k++) + { + outptr[k] = ptr[w - 1 - k]; + } + } + } + } + } else if (axis_sum == 6) { // 对应chw,除了c外全翻转