Skip to content

Commit

Permalink
[XLA:GPU] Set the SortSizeThreshold in tests to zero.
Browse files Browse the repository at this point in the history
This threshold defines for which tensor sizes Cub Raddix sort (custom call) will be preferred over XLA GPU's native Bitonic sort. Test for Cub should set this threshold to zero to ensure Cub is used in all test cases. This allows correctness tests on very small tensors that produce user friendly error messages.

PiperOrigin-RevId: 701260355
  • Loading branch information
thomasjoerg authored and tensorflower-gardener committed Nov 29, 2024
1 parent 4680b7f commit 8adb31f
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 9 deletions.
24 changes: 18 additions & 6 deletions third_party/xla/xla/service/gpu/tests/gpu_cub_sort_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ bool HloWasRewrittenToUseCubSort(const HloModule& module) {
return false;
}

constexpr int kTestDataSize = 10000;

// ----- Sort keys

class CubSortKeysTest : public HloTestBase,
Expand All @@ -50,13 +52,18 @@ class CubSortKeysTest : public HloTestBase,
public:
void SetUp() override {
HloTestBase::SetUp();
SortRewriter::SetSortSizeThresholdForTestingOnly(33000);
SortRewriter::SetSortSizeThresholdForTestingOnly(
0); // Always use CUB sort.
}
};

TEST_F(CubSortKeysTest, AlwaysUsesCubSort) {
EXPECT_EQ(SortRewriter::SortSizeThreshold(), 0);
}

TEST_P(CubSortKeysTest, CompareToReference) {
int batch_size = std::get<2>(GetParam());
int segment_size = SortRewriter::SortSizeThreshold() / batch_size;
int segment_size = kTestDataSize / batch_size;

const char* kHloTpl = R"(
HloModule TestSortKeys
Expand Down Expand Up @@ -103,7 +110,7 @@ ENTRY m {
})";

int batch_size = std::get<2>(GetParam());
int segment_size = SortRewriter::SortSizeThreshold() / batch_size;
int segment_size = kTestDataSize / batch_size;
std::string hlo_str = absl::Substitute(
kHloTpl,
primitive_util::LowercasePrimitiveTypeName(std::get<0>(GetParam())),
Expand Down Expand Up @@ -138,13 +145,18 @@ class CubSortPairsTest
public:
void SetUp() override {
HloTestBase::SetUp();
SortRewriter::SetSortSizeThresholdForTestingOnly(33000);
SortRewriter::SetSortSizeThresholdForTestingOnly(
0); // Always use CUB sort.
}
};

TEST_F(CubSortPairsTest, AlwaysUsesCubSort) {
EXPECT_EQ(SortRewriter::SortSizeThreshold(), 0);
}

TEST_P(CubSortPairsTest, CompareToReference) {
int batch_size = std::get<3>(GetParam());
int segment_size = SortRewriter::SortSizeThreshold() / batch_size;
int segment_size = kTestDataSize / batch_size;

const char* kHloTpl = R"(
HloModule TestSortPairs
Expand Down Expand Up @@ -216,7 +228,7 @@ ENTRY m {
})";

int batch_size = std::get<3>(GetParam());
int segment_size = SortRewriter::SortSizeThreshold() / batch_size;
int segment_size = kTestDataSize / batch_size;
std::string hlo_str = absl::Substitute(
kHloTpl,
primitive_util::LowercasePrimitiveTypeName(std::get<0>(GetParam())),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ class SortRewriterTest : public HloTestBase {
public:
void SetUp() override {
HloTestBase::SetUp();
SortRewriter::SetSortSizeThresholdForTestingOnly(1000);
SortRewriter::SetSortSizeThresholdForTestingOnly(
0); // Always use CUB sort.
}

bool RunModuleAndPass(HloModule* module) {
Expand Down Expand Up @@ -307,6 +308,7 @@ ENTRY %main {

// Small shapes do not see improvement from CUB sort.
TEST_F(SortRewriterTest, NoRewriteSmallSize) {
SortRewriter::SetSortSizeThresholdForTestingOnly(16385);
constexpr char kHlo[] = R"(
HloModule TestModule
Expand Down Expand Up @@ -398,8 +400,8 @@ ENTRY %main {
RunAndFilecheckHloRewrite(kHlo, SortRewriter(), kExpectedPattern);
}

TEST_F(SortRewriterTest, SortSizeThresholdIsSet) {
EXPECT_EQ(SortRewriter::SortSizeThreshold(), 1000);
TEST_F(SortRewriterTest, AlwaysUsesCubSort) {
EXPECT_EQ(SortRewriter::SortSizeThreshold(), 0);
}

} // namespace
Expand Down

0 comments on commit 8adb31f

Please sign in to comment.