From ad01a604294b936fa8e81dccd1d39aa060c4b9c9 Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Tue, 7 Jan 2025 11:28:49 -0600 Subject: [PATCH] Clean up portable op_topk (#7543) The comparison function was repeated several times. Factor it out. Differential Revision: [D67875394](https://our.internmc.facebook.com/intern/diff/D67875394/) ghstack-source-id: 260406714 Pull Request resolved: https://github.com/pytorch/executorch/pull/7528 Co-authored-by: Scott Wolchok --- kernels/portable/cpu/op_topk.cpp | 80 ++++++++------------------------ kernels/test/op_topk_test.cpp | 30 ++++++++++++ 2 files changed, 49 insertions(+), 61 deletions(-) diff --git a/kernels/portable/cpu/op_topk.cpp b/kernels/portable/cpu/op_topk.cpp index 1c862c5761..caaf7e033d 100644 --- a/kernels/portable/cpu/op_topk.cpp +++ b/kernels/portable/cpu/op_topk.cpp @@ -50,6 +50,14 @@ bool get_topk_target_size( return true; } +template +bool float_less_than(T x, T y) { + if constexpr (std::is_integral_v) { + return x < y; + } + return (!std::isnan(x) && std::isnan(y)) || x < y; +} + template > void perform_topk( const Tensor& in, @@ -101,69 +109,19 @@ void perform_topk( } // Perform topk on the queue + const auto elem_greater = [](const elem_t& x, const elem_t& y) -> bool { + return float_less_than(y.first, x.first); + }; + const auto elem_less = [](const elem_t& x, const elem_t& y) -> bool { + return float_less_than(x.first, y.first); + }; + const auto cmp = largest ? elem_greater : elem_less; if (use_partial_sort) { - if (largest) { - std::partial_sort( - queue, - queue + k, - queue + dim_size, - [](const elem_t& x, const elem_t& y) -> bool { - return ( - (std::isnan(x.first) && !std::isnan(y.first)) || - (x.first > y.first)); - }); - } else { - std::partial_sort( - queue, - queue + k, - queue + dim_size, - [](const elem_t& x, const elem_t& y) -> bool { - return ( - (!std::isnan(x.first) && std::isnan(y.first)) || - (x.first < y.first)); - }); - } + std::partial_sort(queue, queue + k, queue + dim_size, cmp); } else { - if (largest) { - std::nth_element( - queue, - queue + k - 1, - queue + dim_size, - [](const elem_t& x, const elem_t& y) -> bool { - return ( - (std::isnan(x.first) && !std::isnan(y.first)) || - (x.first > y.first)); - }); - if (sorted) { - std::sort( - queue, - queue + k - 1, - [](const elem_t& x, const elem_t& y) -> bool { - return ( - (std::isnan(x.first) && !std::isnan(y.first)) || - (x.first > y.first)); - }); - } - } else { - std::nth_element( - queue, - queue + k - 1, - queue + dim_size, - [](const elem_t& x, const elem_t& y) -> bool { - return ( - (!std::isnan(x.first) && std::isnan(y.first)) || - (x.first < y.first)); - }); - if (sorted) { - std::sort( - queue, - queue + k - 1, - [](const elem_t& x, const elem_t& y) -> bool { - return ( - (!std::isnan(x.first) && std::isnan(y.first)) || - (x.first < y.first)); - }); - } + std::nth_element(queue, queue + k - 1, queue + dim_size, cmp); + if (sorted) { + std::sort(queue, queue + k - 1, cmp); } } diff --git a/kernels/test/op_topk_test.cpp b/kernels/test/op_topk_test.cpp index 44a709687f..9e77e9ad5a 100644 --- a/kernels/test/op_topk_test.cpp +++ b/kernels/test/op_topk_test.cpp @@ -15,6 +15,8 @@ #include +#include + using namespace ::testing; using exec_aten::IntArrayRef; using exec_aten::ScalarType; @@ -135,4 +137,32 @@ TEST_F(OpTopkValuesTest, SmokeTest) { op_topk_values(input, k, dim, largest, sorted, values, indices); EXPECT_TENSOR_CLOSE(values, values_expected); EXPECT_TENSOR_EQ(indices, indices_expected); + + largest = false; + values_expected = tfFloat.make({2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); + indices_expected = tfLong.make({2, 2, 2}, {0, 0, 0, 0, 1, 1, 1, 1}); + op_topk_values(input, k, dim, largest, sorted, values, indices); + EXPECT_TENSOR_CLOSE(values, values_expected); + EXPECT_TENSOR_EQ(indices, indices_expected); +} + +TEST_F(OpTopkValuesTest, NonPartialSort) { + TensorFactory tfFloat; + TensorFactory tfLong; + + std::vector data(100); + std::iota(data.begin(), data.end(), 0); + + for (const bool largest : {true, false}) { + Tensor input = tfFloat.make({(int)data.size()}, data); + Tensor values = tfFloat.zeros({1}); + Tensor indices = tfLong.zeros({1}); + Tensor values_expected = + tfFloat.make({1}, {largest ? data.back() : data.front()}); + Tensor indices_expected = + tfLong.make({1}, {largest ? (long)data.size() - 1 : 0L}); + op_topk_values(input, 1, 0, largest, true, values, indices); + EXPECT_TENSOR_CLOSE(values, values_expected); + EXPECT_TENSOR_EQ(indices, indices_expected); + } }