Skip to content

Commit

Permalink
ref argmax,fix <vect>
Browse files Browse the repository at this point in the history
  • Loading branch information
Baiyuetribe committed Dec 31, 2024
1 parent c6edde6 commit 5fa5d6e
Showing 1 changed file with 5 additions and 21 deletions.
26 changes: 5 additions & 21 deletions src/layer/topk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
// specific language governing permissions and limitations under the License.

#include "topk.h"
#include <vector>
#include <algorithm>
#include <functional>

namespace ncnn {

Expand Down Expand Up @@ -52,40 +51,25 @@ int TopK::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons
}

// [](const std::pair<float, int>& a, const std::pair<float, int>& b) {return a.first > b.first;}); // fix Lambda with lower version of C++
struct CompareGreater
{
bool operator()(const std::pair<float, int>& a, const std::pair<float, int>& b) const
{
return a.first > b.first;
}
};

struct CompareLess
{
bool operator()(const std::pair<float, int>& a, const std::pair<float, int>& b) const
{
return a.first < b.first;
}
};

if (largest == 1)
{
std::partial_sort(vec.begin(), vec.begin() + k_, vec.end(), CompareGreater());
std::partial_sort(vec.begin(), vec.begin() + k_, vec.end(), std::greater<std::pair<float, int> >());
}
else
{
std::partial_sort(vec.begin(), vec.begin() + k_, vec.end(), CompareLess());
std::partial_sort(vec.begin(), vec.begin() + k_, vec.end(), std::less<std::pair<float, int> >());
}

if (sorted)
{
if (largest == 1)
{
std::sort(vec.begin(), vec.begin() + k_, CompareGreater());
std::sort(vec.begin(), vec.begin() + k_, std::greater<std::pair<float, int> >());
}
else
{
std::sort(vec.begin(), vec.begin() + k_, CompareLess());
std::sort(vec.begin(), vec.begin() + k_, std::less<std::pair<float, int> >());
}
}

Expand Down

0 comments on commit 5fa5d6e

Please sign in to comment.