Skip to content

Commit

Permalink
Fix flaky test in Faiss JNI range search (opensearch-project#1705)
Browse files Browse the repository at this point in the history
Signed-off-by: Junqiu Lei <[email protected]>
  • Loading branch information
junqiu-lei authored May 16, 2024
1 parent 6ea376b commit e584822
Showing 1 changed file with 16 additions and 17 deletions.
33 changes: 16 additions & 17 deletions jni/tests/faiss_wrapper_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ using ::testing::Return;

float randomDataMin = -500.0;
float randomDataMax = 500.0;
float rangeSearchRandomDataMin = -50;
float rangeSearchRandomDataMax = 50;
float rangeSearchRadius = 20000;

TEST(FaissCreateIndexTest, BasicAssertions) {
// Define the data
Expand Down Expand Up @@ -621,21 +624,20 @@ TEST(FaissRangeSearchQueryIndexTest, BasicAssertions) {
faiss::idx_t numIds = 200;
int dim = 2;
std::vector<faiss::idx_t> ids = test_util::Range(numIds);
std::vector<float> vectors = test_util::RandomVectors(dim, numIds, randomDataMin, randomDataMax);
std::vector<float> vectors = test_util::RandomVectors(dim, numIds, rangeSearchRandomDataMin, rangeSearchRandomDataMax);

faiss::MetricType metricType = faiss::METRIC_L2;
std::string method = "HNSW32,Flat";

// Define query data
float radius = 100000.0;
int numQueries = 100;
std::vector<std::vector<float>> queries;

for (int i = 0; i < numQueries; i++) {
std::vector<float> query;
query.reserve(dim);
for (int j = 0; j < dim; j++) {
query.push_back(test_util::RandomFloat(randomDataMin, randomDataMax));
query.push_back(test_util::RandomFloat(rangeSearchRandomDataMin, rangeSearchRandomDataMax));
}
queries.push_back(query);
}
Expand All @@ -659,7 +661,7 @@ TEST(FaissRangeSearchQueryIndexTest, BasicAssertions) {
knn_jni::faiss_wrapper::RangeSearch(
&mockJNIUtil, jniEnv,
reinterpret_cast<jlong>(&createdIndexWithData),
reinterpret_cast<jfloatArray>(&query), radius, maxResultWindow, nullptr)));
reinterpret_cast<jfloatArray>(&query), rangeSearchRadius, maxResultWindow, nullptr)));

// assert result size is not 0
ASSERT_NE(0, results->size());
Expand All @@ -677,21 +679,20 @@ TEST(FaissRangeSearchQueryIndexTest_WhenHitMaxWindowResult, BasicAssertions){
faiss::idx_t numIds = 200;
int dim = 2;
std::vector<faiss::idx_t> ids = test_util::Range(numIds);
std::vector<float> vectors = test_util::RandomVectors(dim, numIds, randomDataMin, randomDataMax);
std::vector<float> vectors = test_util::RandomVectors(dim, numIds, rangeSearchRandomDataMin, rangeSearchRandomDataMax);

faiss::MetricType metricType = faiss::METRIC_L2;
std::string method = "HNSW32,Flat";

// Define query data
float radius = 100000.0;
int numQueries = 100;
std::vector<std::vector<float>> queries;

for (int i = 0; i < numQueries; i++) {
std::vector<float> query;
query.reserve(dim);
for (int j = 0; j < dim; j++) {
query.push_back(test_util::RandomFloat(randomDataMin, randomDataMax));
query.push_back(test_util::RandomFloat(rangeSearchRandomDataMin, rangeSearchRandomDataMax));
}
queries.push_back(query);
}
Expand All @@ -715,7 +716,7 @@ TEST(FaissRangeSearchQueryIndexTest_WhenHitMaxWindowResult, BasicAssertions){
knn_jni::faiss_wrapper::RangeSearch(
&mockJNIUtil, jniEnv,
reinterpret_cast<jlong>(&createdIndexWithData),
reinterpret_cast<jfloatArray>(&query), radius, maxResultWindow, nullptr)));
reinterpret_cast<jfloatArray>(&query), rangeSearchRadius, maxResultWindow, nullptr)));

// assert result size is not 0
ASSERT_NE(0, results->size());
Expand All @@ -734,21 +735,20 @@ TEST(FaissRangeSearchQueryIndexTestWithFilterTest, BasicAssertions) {
faiss::idx_t numIds = 200;
int dim = 2;
std::vector<faiss::idx_t> ids = test_util::Range(numIds);
std::vector<float> vectors = test_util::RandomVectors(dim, numIds, randomDataMin, randomDataMax);
std::vector<float> vectors = test_util::RandomVectors(dim, numIds, rangeSearchRandomDataMin, rangeSearchRandomDataMax);

faiss::MetricType metricType = faiss::METRIC_L2;
std::string method = "HNSW32,Flat";

// Define query data
float radius = 100000.0;
int numQueries = 100;
std::vector<std::vector<float>> queries;

for (int i = 0; i < numQueries; i++) {
std::vector<float> query;
query.reserve(dim);
for (int j = 0; j < dim; j++) {
query.push_back(test_util::RandomFloat(randomDataMin, randomDataMax));
query.push_back(test_util::RandomFloat(rangeSearchRandomDataMin, rangeSearchRandomDataMax));
}
queries.push_back(query);
}
Expand All @@ -767,7 +767,7 @@ TEST(FaissRangeSearchQueryIndexTestWithFilterTest, BasicAssertions) {
std::vector<jlong> bitmap(num_bits,0);
std::vector<int64_t> filterIds;

for (int64_t i = 154; i < 163; i++) {
for (int64_t i = 1; i < 50; i++) {
filterIds.push_back(i);
test_util::setBitSet(i, bitmap.data(), bitmap.size());
}
Expand All @@ -782,7 +782,7 @@ TEST(FaissRangeSearchQueryIndexTestWithFilterTest, BasicAssertions) {
knn_jni::faiss_wrapper::RangeSearchWithFilter(
&mockJNIUtil, jniEnv,
reinterpret_cast<jlong>(&createdIndexWithData),
reinterpret_cast<jfloatArray>(&query), radius, maxResultWindow,
reinterpret_cast<jfloatArray>(&query), rangeSearchRadius, maxResultWindow,
reinterpret_cast<jlongArray>(&bitmap), 0, nullptr)));

// assert result size is not 0
Expand Down Expand Up @@ -814,23 +814,22 @@ TEST(FaissRangeSearchQueryIndexTestWithParentFilterTest, BasicAssertions) {
}
ids.push_back(i);
for (int j = 0; j < dim; j++) {
vectors.push_back(test_util::RandomFloat(-500.0, 500.0));
vectors.push_back(test_util::RandomFloat(rangeSearchRandomDataMin, rangeSearchRandomDataMax));
}
}

faiss::MetricType metricType = faiss::METRIC_L2;
std::string method = "HNSW32,Flat";

// Define query data
float radius = 100000.0;
int numQueries = 1;
std::vector<std::vector<float>> queries;

for (int i = 0; i < numQueries; i++) {
std::vector<float> query;
query.reserve(dim);
for (int j = 0; j < dim; j++) {
query.push_back(test_util::RandomFloat(-500.0, 500.0));
query.push_back(test_util::RandomFloat(rangeSearchRandomDataMin, rangeSearchRandomDataMax));
}
queries.push_back(query);
}
Expand Down Expand Up @@ -858,7 +857,7 @@ TEST(FaissRangeSearchQueryIndexTestWithParentFilterTest, BasicAssertions) {
knn_jni::faiss_wrapper::RangeSearchWithFilter(
&mockJNIUtil, jniEnv,
reinterpret_cast<jlong>(&createdIndexWithData),
reinterpret_cast<jfloatArray>(&query), radius, maxResultWindow, nullptr, 0,
reinterpret_cast<jfloatArray>(&query), rangeSearchRadius, maxResultWindow, nullptr, 0,
reinterpret_cast<jintArray>(&parentIds))));

// assert result size is not 0
Expand Down

0 comments on commit e584822

Please sign in to comment.