From 5fa5d6eda0e2524fcf8cc2deb5157d9ecf3df74a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=B0=E9=98=85?= <43716063+Baiyuetribe@users.noreply.github.com> Date: Tue, 31 Dec 2024 13:38:22 +0800 Subject: [PATCH] =?UTF-8?q?ref=20argmax=EF=BC=8Cfix=20?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/layer/topk.cpp | 26 +++++--------------------- 1 file changed, 5 insertions(+), 21 deletions(-) diff --git a/src/layer/topk.cpp b/src/layer/topk.cpp index 4150db63aab..f4a3a640d35 100644 --- a/src/layer/topk.cpp +++ b/src/layer/topk.cpp @@ -13,8 +13,7 @@ // specific language governing permissions and limitations under the License. #include "topk.h" -#include -#include +#include namespace ncnn { @@ -52,40 +51,25 @@ int TopK::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons } // [](const std::pair& a, const std::pair& b) {return a.first > b.first;}); // fix Lambda with lower version of C++ - struct CompareGreater - { - bool operator()(const std::pair& a, const std::pair& b) const - { - return a.first > b.first; - } - }; - - struct CompareLess - { - bool operator()(const std::pair& a, const std::pair& b) const - { - return a.first < b.first; - } - }; if (largest == 1) { - std::partial_sort(vec.begin(), vec.begin() + k_, vec.end(), CompareGreater()); + std::partial_sort(vec.begin(), vec.begin() + k_, vec.end(), std::greater >()); } else { - std::partial_sort(vec.begin(), vec.begin() + k_, vec.end(), CompareLess()); + std::partial_sort(vec.begin(), vec.begin() + k_, vec.end(), std::less >()); } if (sorted) { if (largest == 1) { - std::sort(vec.begin(), vec.begin() + k_, CompareGreater()); + std::sort(vec.begin(), vec.begin() + k_, std::greater >()); } else { - std::sort(vec.begin(), vec.begin() + k_, CompareLess()); + std::sort(vec.begin(), vec.begin() + k_, std::less >()); } }