Skip to content

Commit

Permalink
add 4d dch
Browse files Browse the repository at this point in the history
  • Loading branch information
Baiyuetribe committed Jan 13, 2025
1 parent 5bd1679 commit 5ba56f3
Showing 1 changed file with 24 additions and 1 deletion.
25 changes: 24 additions & 1 deletion src/layer/flip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ int Flip::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons
}
else if (axis.w == 3)
{
// dch3、dcw4、chw6
// dch3、dcw4、dhw5,chw6
int axis0 = axis_ptr[0] < 0 ? 4 + axis_ptr[0] : axis_ptr[0];
int axis1 = axis_ptr[1] < 0 ? 4 + axis_ptr[1] : axis_ptr[1];
int axis2 = axis_ptr[2] < 0 ? 4 + axis_ptr[2] : axis_ptr[2];
Expand Down Expand Up @@ -510,6 +510,29 @@ int Flip::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons
}
}
}
else if (axis_sum == 5)
{
// 对应dhw,除了d外全翻转
for (int c = 0; c < channels; c++)
{
int flipped_c = channels - 1 - c; // 翻转c维度

for (int z = 0; z < d; z++) // d维度保持不变
{
for (int i = 0; i < h; i++)
{
const float* ptr = bottom_blob.channel(c).depth(z).row(i);
float* outptr = const_cast<float*>(top_blob.channel(flipped_c).depth(z).row(h - 1 - i)); // 翻转h维度

// 翻转w维度
for (int k = 0; k < w; k++)
{
outptr[k] = ptr[w - 1 - k];
}
}
}
}
}
else if (axis_sum == 6)
{
// 对应chw,除了c外全翻转
Expand Down

0 comments on commit 5ba56f3

Please sign in to comment.