Skip to content

Commit

Permalink
add test for already sorted inputs
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 672929428
  • Loading branch information
jan-wassenberg authored and copybara-github committed Sep 10, 2024
1 parent 88a2b66 commit cce7182
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 9 deletions.
2 changes: 2 additions & 0 deletions hwy/contrib/sort/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,8 @@ cc_test(
":vqsort_for_test",
"//:hwy",
"//:hwy_test_util",
"//third_party/highway:thread_pool",
"//third_party/highway:topology",
] + TEST_MAIN,
)

Expand Down
49 changes: 45 additions & 4 deletions hwy/contrib/sort/sort_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,28 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <stdint.h>
#include <stdio.h>

#include <numeric> // std::iota
#include <random>
#include <vector>

#include "hwy/aligned_allocator.h" // IsAligned
#include "hwy/base.h"
#include "hwy/contrib/sort/vqsort.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/contrib/thread_pool/topology.h"
#include "hwy/per_target.h"

#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "hwy/contrib/sort/sort_test.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep
// After foreach_target
#include "hwy/aligned_allocator.h" // IsAligned
#include "hwy/highway.h"
// After highway.h
#include "hwy/contrib/sort/algo-inl.h"
#include "hwy/contrib/sort/result-inl.h"
#include "hwy/contrib/sort/vqsort-inl.h" // BaseCase
#include "hwy/contrib/sort/vqsort.h"
#include "hwy/highway.h"
#include "hwy/print-inl.h"
#include "hwy/tests/test_util-inl.h"

Expand Down Expand Up @@ -59,6 +66,39 @@ using detail::OrderDescendingKV128;
using detail::Traits128;
#endif // !HAVE_INTEL && HWY_TARGET != HWY_SCALAR

template <typename Key>
void TestSortIota(hwy::ThreadPool& pool) {
pool.Run(128, 300, [](uint64_t task, size_t /*thread*/) {
const size_t num = static_cast<size_t>(task);
Key keys[300];
std::iota(keys, keys + num, Key{0});
VQSort(keys, num, hwy::SortAscending());
for (size_t i = 0; i < num; ++i) {
if (keys[i] != i) {
HWY_ABORT("num %zu i %zu: not iota, got %.0f\n", num, i,
static_cast<double>(keys[i]));
}
}
});
}

void TestAllSortIota() {
if constexpr (VQSORT_ENABLED) {
hwy::ThreadPool pool(hwy::HaveThreadingSupport() ? 4 : 0);
TestSortIota<uint32_t>(pool);
TestSortIota<int32_t>(pool);
if (hwy::HaveInteger64()) {
TestSortIota<int64_t>(pool);
TestSortIota<uint64_t>(pool);
}
TestSortIota<float>(pool);
if (hwy::HaveFloat64()) {
TestSortIota<double>(pool);
}
fprintf(stderr, "Iota OK\n");
}
}

// Supports full/partial sort and select.
template <class Traits>
void TestAnySort(const std::vector<Algo>& algos, size_t num_lanes) {
Expand Down Expand Up @@ -232,6 +272,7 @@ HWY_AFTER_NAMESPACE();

namespace hwy {
HWY_BEFORE_TEST(SortTest);
HWY_EXPORT_AND_TEST_P(SortTest, TestAllSortIota);
HWY_EXPORT_AND_TEST_P(SortTest, TestAllSort);
HWY_EXPORT_AND_TEST_P(SortTest, TestAllSelect);
HWY_EXPORT_AND_TEST_P(SortTest, TestAllPartialSort);
Expand Down
8 changes: 4 additions & 4 deletions hwy/contrib/sort/sort_unit_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,22 @@
#include <unordered_map>
#include <vector>

#include "hwy/aligned_allocator.h" // IsAligned
#include "hwy/base.h"
#include "hwy/contrib/sort/vqsort.h"
#include "hwy/detect_compiler_arch.h"

// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "hwy/contrib/sort/sort_unit_test.cc" // NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
// After foreach_target
#include "hwy/aligned_allocator.h" // IsAligned
#include "hwy/highway.h"
// After highway.h
#include "hwy/contrib/sort/algo-inl.h"
#include "hwy/contrib/sort/result-inl.h"
#include "hwy/contrib/sort/traits128-inl.h"
#include "hwy/contrib/sort/vqsort-inl.h" // BaseCase
#include "hwy/contrib/sort/vqsort.h"
#include "hwy/highway.h"
#include "hwy/print-inl.h"
#include "hwy/tests/test_util-inl.h"

Expand Down
6 changes: 6 additions & 0 deletions hwy/per_target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ namespace hwy {
namespace HWY_NAMESPACE {
int64_t GetTarget() { return HWY_TARGET; }
size_t GetVectorBytes() { return Lanes(ScalableTag<uint8_t>()); }
bool GetHaveInteger64() { return HWY_HAVE_INTEGER64 != 0; }
bool GetHaveFloat16() { return HWY_HAVE_FLOAT16 != 0; }
bool GetHaveFloat64() { return HWY_HAVE_FLOAT64 != 0; }
// NOLINTNEXTLINE(google-readability-namespace-comments)
Expand All @@ -45,6 +46,7 @@ namespace hwy {
namespace {
HWY_EXPORT(GetTarget);
HWY_EXPORT(GetVectorBytes);
HWY_EXPORT(GetHaveInteger64);
HWY_EXPORT(GetHaveFloat16);
HWY_EXPORT(GetHaveFloat64);
} // namespace
Expand All @@ -57,6 +59,10 @@ HWY_DLLEXPORT size_t VectorBytes() {
return HWY_DYNAMIC_DISPATCH(GetVectorBytes)();
}

HWY_DLLEXPORT bool HaveInteger64() {
return HWY_DYNAMIC_DISPATCH(GetHaveInteger64)();
}

HWY_DLLEXPORT bool HaveFloat16() {
return HWY_DYNAMIC_DISPATCH(GetHaveFloat16)();
}
Expand Down
3 changes: 2 additions & 1 deletion hwy/per_target.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ HWY_DLLEXPORT int64_t DispatchedTarget();
// unnecessarily.
HWY_DLLEXPORT size_t VectorBytes();

// Returns whether 16/64-bit floats are a supported lane type.
// Returns whether 64-bit integers, 16/64-bit floats are a supported lane type.
HWY_DLLEXPORT bool HaveInteger64();
HWY_DLLEXPORT bool HaveFloat16();
HWY_DLLEXPORT bool HaveFloat64();

Expand Down

0 comments on commit cce7182

Please sign in to comment.