Skip to content

Commit ad01a60

Browse files
pytorchbot and swolchok authored
Clean up portable op_topk (#7543)
The comparison function was repeated several times. Factor it out. Differential Revision: [D67875394](https://our.internmc.facebook.com/intern/diff/D67875394/) ghstack-source-id: 260406714 Pull Request resolved: #7528 Co-authored-by: Scott Wolchok <[email protected]>
1 parent 271a277 commit ad01a60

File tree

2 files changed

+49
-61
lines changed

2 files changed

+49
-61
lines changed

kernels/portable/cpu/op_topk.cpp

+19-61
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,14 @@ bool get_topk_target_size(
5050
return true;
5151
}
5252

53+
// Strict "less than" used to order top-k candidates.
//
// Integral types compare with plain `<`. For every other type, a NaN operand
// is ordered after any non-NaN operand, and two NaNs are never less than each
// other, so all NaNs form a single equivalence group at the high end.
//
// Returns true iff `x` is ordered strictly before `y`.
template <typename T>
bool float_less_than(T x, T y) {
  if constexpr (std::is_integral_v<T>) {
    return x < y;
  } else {
    // Kept inside the `else` branch on purpose: statements that follow an
    // `if constexpr` without `else` are still instantiated for every T, which
    // would instantiate std::isnan with integer arguments.
    return (!std::isnan(x) && std::isnan(y)) || x < y;
  }
}
60+
5361
template <typename CTYPE, typename elem_t = std::pair<CTYPE, int64_t>>
5462
void perform_topk(
5563
const Tensor& in,
@@ -101,69 +109,19 @@ void perform_topk(
101109
}
102110

103111
// Perform topk on the queue
112+
const auto elem_greater = [](const elem_t& x, const elem_t& y) -> bool {
113+
return float_less_than(y.first, x.first);
114+
};
115+
const auto elem_less = [](const elem_t& x, const elem_t& y) -> bool {
116+
return float_less_than(x.first, y.first);
117+
};
118+
const auto cmp = largest ? elem_greater : elem_less;
104119
if (use_partial_sort) {
105-
if (largest) {
106-
std::partial_sort(
107-
queue,
108-
queue + k,
109-
queue + dim_size,
110-
[](const elem_t& x, const elem_t& y) -> bool {
111-
return (
112-
(std::isnan(x.first) && !std::isnan(y.first)) ||
113-
(x.first > y.first));
114-
});
115-
} else {
116-
std::partial_sort(
117-
queue,
118-
queue + k,
119-
queue + dim_size,
120-
[](const elem_t& x, const elem_t& y) -> bool {
121-
return (
122-
(!std::isnan(x.first) && std::isnan(y.first)) ||
123-
(x.first < y.first));
124-
});
125-
}
120+
std::partial_sort(queue, queue + k, queue + dim_size, cmp);
126121
} else {
127-
if (largest) {
128-
std::nth_element(
129-
queue,
130-
queue + k - 1,
131-
queue + dim_size,
132-
[](const elem_t& x, const elem_t& y) -> bool {
133-
return (
134-
(std::isnan(x.first) && !std::isnan(y.first)) ||
135-
(x.first > y.first));
136-
});
137-
if (sorted) {
138-
std::sort(
139-
queue,
140-
queue + k - 1,
141-
[](const elem_t& x, const elem_t& y) -> bool {
142-
return (
143-
(std::isnan(x.first) && !std::isnan(y.first)) ||
144-
(x.first > y.first));
145-
});
146-
}
147-
} else {
148-
std::nth_element(
149-
queue,
150-
queue + k - 1,
151-
queue + dim_size,
152-
[](const elem_t& x, const elem_t& y) -> bool {
153-
return (
154-
(!std::isnan(x.first) && std::isnan(y.first)) ||
155-
(x.first < y.first));
156-
});
157-
if (sorted) {
158-
std::sort(
159-
queue,
160-
queue + k - 1,
161-
[](const elem_t& x, const elem_t& y) -> bool {
162-
return (
163-
(!std::isnan(x.first) && std::isnan(y.first)) ||
164-
(x.first < y.first));
165-
});
166-
}
122+
std::nth_element(queue, queue + k - 1, queue + dim_size, cmp);
123+
if (sorted) {
124+
std::sort(queue, queue + k - 1, cmp);
167125
}
168126
}
169127

kernels/test/op_topk_test.cpp

+30
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
#include <gtest/gtest.h>
1717

18+
#include <algorithm>
#include <numeric>
19+
1820
using namespace ::testing;
1921
using exec_aten::IntArrayRef;
2022
using exec_aten::ScalarType;
@@ -135,4 +137,32 @@ TEST_F(OpTopkValuesTest, SmokeTest) {
135137
op_topk_values(input, k, dim, largest, sorted, values, indices);
136138
EXPECT_TENSOR_CLOSE(values, values_expected);
137139
EXPECT_TENSOR_EQ(indices, indices_expected);
140+
141+
largest = false;
142+
values_expected = tfFloat.make({2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
143+
indices_expected = tfLong.make({2, 2, 2}, {0, 0, 0, 0, 1, 1, 1, 1});
144+
op_topk_values(input, k, dim, largest, sorted, values, indices);
145+
EXPECT_TENSOR_CLOSE(values, values_expected);
146+
EXPECT_TENSOR_EQ(indices, indices_expected);
147+
}
148+
149+
// Checks top-1 selection over a 100-element 1-D float tensor holding
// 0, 1, ..., 99, once for the largest element and once for the smallest.
// NOTE(review): the test name suggests the input is sized to take the
// kernel's nth_element (non-partial_sort) path -- confirm against the
// kernel's use_partial_sort heuristic.
TEST_F(OpTopkValuesTest, NonPartialSort) {
  TensorFactory<ScalarType::Float> tfFloat;
  TensorFactory<ScalarType::Long> tfLong;

  // Ascending ramp, so every value equals its own index.
  std::vector<float> data(100);
  std::iota(data.begin(), data.end(), 0);

  for (const bool largest : {true, false}) {
    Tensor input = tfFloat.make({static_cast<int>(data.size())}, data);
    Tensor values = tfFloat.zeros({1});
    Tensor indices = tfLong.zeros({1});
    // The largest element is the last entry of the ramp, the smallest the
    // first; the expected index follows from value == index.
    Tensor values_expected =
        tfFloat.make({1}, {largest ? data.back() : data.front()});
    Tensor indices_expected = tfLong.make(
        {1}, {largest ? static_cast<long>(data.size()) - 1 : 0L});
    // Arguments: k = 1, dim = 0, sorted = true.
    op_topk_values(input, 1, 0, largest, true, values, indices);
    EXPECT_TENSOR_CLOSE(values, values_expected);
    EXPECT_TENSOR_EQ(indices, indices_expected);
  }
}

0 commit comments

Comments
 (0)