diff --git a/third_party/xla/xla/service/gpu/tests/gpu_cub_sort_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_cub_sort_test.cc index b4124f2673d958..69e2367427c388 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_cub_sort_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_cub_sort_test.cc @@ -42,6 +42,8 @@ bool HloWasRewrittenToUseCubSort(const HloModule& module) { return false; } +constexpr int kTestDataSize = 10000; + // ----- Sort keys class CubSortKeysTest : public HloTestBase, @@ -50,13 +52,18 @@ class CubSortKeysTest : public HloTestBase, public: void SetUp() override { HloTestBase::SetUp(); - SortRewriter::SetSortSizeThresholdForTestingOnly(33000); + SortRewriter::SetSortSizeThresholdForTestingOnly( + 0); // Always use CUB sort. } }; +TEST_F(CubSortKeysTest, AlwaysUsesCubSort) { + EXPECT_EQ(SortRewriter::SortSizeThreshold(), 0); +} + TEST_P(CubSortKeysTest, CompareToReference) { int batch_size = std::get<2>(GetParam()); - int segment_size = SortRewriter::SortSizeThreshold() / batch_size; + int segment_size = kTestDataSize / batch_size; const char* kHloTpl = R"( HloModule TestSortKeys @@ -103,7 +110,7 @@ ENTRY m { })"; int batch_size = std::get<2>(GetParam()); - int segment_size = SortRewriter::SortSizeThreshold() / batch_size; + int segment_size = kTestDataSize / batch_size; std::string hlo_str = absl::Substitute( kHloTpl, primitive_util::LowercasePrimitiveTypeName(std::get<0>(GetParam())), @@ -138,13 +145,18 @@ class CubSortPairsTest public: void SetUp() override { HloTestBase::SetUp(); - SortRewriter::SetSortSizeThresholdForTestingOnly(33000); + SortRewriter::SetSortSizeThresholdForTestingOnly( + 0); // Always use CUB sort. } }; +TEST_F(CubSortPairsTest, AlwaysUsesCubSort) { + EXPECT_EQ(SortRewriter::SortSizeThreshold(), 0); +} + TEST_P(CubSortPairsTest, CompareToReference) { int batch_size = std::get<3>(GetParam()); - int segment_size = SortRewriter::SortSizeThreshold() / batch_size; + int segment_size = kTestDataSize / batch_size; const char* kHloTpl = R"( HloModule TestSortPairs @@ -216,7 +228,7 @@ ENTRY m { })"; int batch_size = std::get<3>(GetParam()); - int segment_size = SortRewriter::SortSizeThreshold() / batch_size; + int segment_size = kTestDataSize / batch_size; std::string hlo_str = absl::Substitute( kHloTpl, primitive_util::LowercasePrimitiveTypeName(std::get<0>(GetParam())), diff --git a/third_party/xla/xla/service/gpu/transforms/sort_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/sort_rewriter_test.cc index 248e6c0525b18c..cb660e747f11ff 100644 --- a/third_party/xla/xla/service/gpu/transforms/sort_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/sort_rewriter_test.cc @@ -39,7 +39,8 @@ class SortRewriterTest : public HloTestBase { public: void SetUp() override { HloTestBase::SetUp(); - SortRewriter::SetSortSizeThresholdForTestingOnly(1000); + SortRewriter::SetSortSizeThresholdForTestingOnly( + 0); // Always use CUB sort. } bool RunModuleAndPass(HloModule* module) { @@ -307,6 +308,7 @@ ENTRY %main { // Small shapes do not see improvement from CUB sort. TEST_F(SortRewriterTest, NoRewriteSmallSize) { + SortRewriter::SetSortSizeThresholdForTestingOnly(16385); constexpr char kHlo[] = R"( HloModule TestModule @@ -398,8 +400,8 @@ ENTRY %main { RunAndFilecheckHloRewrite(kHlo, SortRewriter(), kExpectedPattern); } -TEST_F(SortRewriterTest, SortSizeThresholdIsSet) { - EXPECT_EQ(SortRewriter::SortSizeThreshold(), 1000); +TEST_F(SortRewriterTest, AlwaysUsesCubSort) { + EXPECT_EQ(SortRewriter::SortSizeThreshold(), 0); } } // namespace