Skip to content

Commit

Permalink
Clean up portable op_topk (#7543)
Browse files Browse the repository at this point in the history
The comparison function was repeated several times. Factor it out.

Differential Revision: [D67875394](https://our.internmc.facebook.com/intern/diff/D67875394/)

ghstack-source-id: 260406714
Pull Request resolved: #7528

Co-authored-by: Scott Wolchok <[email protected]>
  • Loading branch information
pytorchbot and swolchok authored Jan 7, 2025
1 parent 271a277 commit ad01a60
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 61 deletions.
80 changes: 19 additions & 61 deletions kernels/portable/cpu/op_topk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,14 @@ bool get_topk_target_size(
return true;
}

template <typename T>
bool float_less_than(T x, T y) {
if constexpr (std::is_integral_v<T>) {
return x < y;
}
return (!std::isnan(x) && std::isnan(y)) || x < y;
}

template <typename CTYPE, typename elem_t = std::pair<CTYPE, int64_t>>
void perform_topk(
const Tensor& in,
Expand Down Expand Up @@ -101,69 +109,19 @@ void perform_topk(
}

// Perform topk on the queue
const auto elem_greater = [](const elem_t& x, const elem_t& y) -> bool {
return float_less_than(y.first, x.first);
};
const auto elem_less = [](const elem_t& x, const elem_t& y) -> bool {
return float_less_than(x.first, y.first);
};
const auto cmp = largest ? elem_greater : elem_less;
if (use_partial_sort) {
if (largest) {
std::partial_sort(
queue,
queue + k,
queue + dim_size,
[](const elem_t& x, const elem_t& y) -> bool {
return (
(std::isnan(x.first) && !std::isnan(y.first)) ||
(x.first > y.first));
});
} else {
std::partial_sort(
queue,
queue + k,
queue + dim_size,
[](const elem_t& x, const elem_t& y) -> bool {
return (
(!std::isnan(x.first) && std::isnan(y.first)) ||
(x.first < y.first));
});
}
std::partial_sort(queue, queue + k, queue + dim_size, cmp);
} else {
if (largest) {
std::nth_element(
queue,
queue + k - 1,
queue + dim_size,
[](const elem_t& x, const elem_t& y) -> bool {
return (
(std::isnan(x.first) && !std::isnan(y.first)) ||
(x.first > y.first));
});
if (sorted) {
std::sort(
queue,
queue + k - 1,
[](const elem_t& x, const elem_t& y) -> bool {
return (
(std::isnan(x.first) && !std::isnan(y.first)) ||
(x.first > y.first));
});
}
} else {
std::nth_element(
queue,
queue + k - 1,
queue + dim_size,
[](const elem_t& x, const elem_t& y) -> bool {
return (
(!std::isnan(x.first) && std::isnan(y.first)) ||
(x.first < y.first));
});
if (sorted) {
std::sort(
queue,
queue + k - 1,
[](const elem_t& x, const elem_t& y) -> bool {
return (
(!std::isnan(x.first) && std::isnan(y.first)) ||
(x.first < y.first));
});
}
std::nth_element(queue, queue + k - 1, queue + dim_size, cmp);
if (sorted) {
std::sort(queue, queue + k - 1, cmp);
}
}

Expand Down
30 changes: 30 additions & 0 deletions kernels/test/op_topk_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

#include <gtest/gtest.h>

#include <algorithm>

using namespace ::testing;
using exec_aten::IntArrayRef;
using exec_aten::ScalarType;
Expand Down Expand Up @@ -135,4 +137,32 @@ TEST_F(OpTopkValuesTest, SmokeTest) {
op_topk_values(input, k, dim, largest, sorted, values, indices);
EXPECT_TENSOR_CLOSE(values, values_expected);
EXPECT_TENSOR_EQ(indices, indices_expected);

largest = false;
values_expected = tfFloat.make({2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
indices_expected = tfLong.make({2, 2, 2}, {0, 0, 0, 0, 1, 1, 1, 1});
op_topk_values(input, k, dim, largest, sorted, values, indices);
EXPECT_TENSOR_CLOSE(values, values_expected);
EXPECT_TENSOR_EQ(indices, indices_expected);
}

TEST_F(OpTopkValuesTest, NonPartialSort) {
TensorFactory<ScalarType::Float> tfFloat;
TensorFactory<ScalarType::Long> tfLong;

std::vector<float> data(100);
std::iota(data.begin(), data.end(), 0);

for (const bool largest : {true, false}) {
Tensor input = tfFloat.make({(int)data.size()}, data);
Tensor values = tfFloat.zeros({1});
Tensor indices = tfLong.zeros({1});
Tensor values_expected =
tfFloat.make({1}, {largest ? data.back() : data.front()});
Tensor indices_expected =
tfLong.make({1}, {largest ? (long)data.size() - 1 : 0L});
op_topk_values(input, 1, 0, largest, true, values, indices);
EXPECT_TENSOR_CLOSE(values, values_expected);
EXPECT_TENSOR_EQ(indices, indices_expected);
}
}

0 comments on commit ad01a60

Please sign in to comment.