From 77b432fa71a4772a29c88245379cf37a336e98ff Mon Sep 17 00:00:00 2001 From: Neelam Mahapatro Date: Mon, 18 Dec 2023 16:13:43 +0530 Subject: [PATCH 1/7] Allowing less than K nodes in result (Filter and Non Filter In Memory Search) --- src/index.cpp | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/index.cpp b/src/index.cpp index 4157edcef..9655a074e 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2251,14 +2251,13 @@ std::pair Index::search(const T *query, con break; } - if (pos < K) + if (pos <= K) { - context.SetState(State::Failure); - diskann::cerr << "Found pos: " << pos << "fewer than K elements " << K << " for query" << std::endl; + context.SetState(State::Success); } else { - context.SetState(State::Success); + context.SetState(State::Failure); } return retval; @@ -2387,14 +2386,13 @@ std::pair Index::search_with_filters(const if (pos == K) break; } - if (pos < K) + if (pos <= K) { - context.SetState(State::Failure); - diskann::cerr << "Found fewer than K elements for query" << std::endl; + context.SetState(State::Success); } else { - context.SetState(State::Success); + context.SetState(State::Failure); } return retval; From 8a3329e4291d9873c892de07f97246aff0e12a82 Mon Sep 17 00:00:00 2001 From: Neelam Mahapatro Date: Sat, 23 Dec 2023 00:27:17 +0530 Subject: [PATCH 2/7] GetAllowLessThanKResults in context class --- include/parameters.h | 10 ++++++++-- src/index.cpp | 12 ++++++++++-- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/include/parameters.h b/include/parameters.h index edde5df9c..8e5c99455 100644 --- a/include/parameters.h +++ b/include/parameters.h @@ -139,8 +139,8 @@ enum State : uint8_t template class IndexSearchContext { public: - IndexSearchContext(uint32_t time_limit_in_microseconds = 0u, uint32_t io_limit = UINT32_MAX) - : _time_limit_in_microseconds(time_limit_in_microseconds), _io_limit(io_limit), _result_state(State::Unknown) + IndexSearchContext(uint32_t time_limit_in_microseconds = 0u, uint32_t io_limit = UINT32_MAX, bool allowLessThanKResults = false) + : _time_limit_in_microseconds(time_limit_in_microseconds), _io_limit(io_limit), _result_state(State::Unknown), _allowLessThankResults(allowLessThanKResults) { _use_filter = false; _label = (LabelT)0; @@ -198,6 +198,11 @@ template class IndexSearchContext return _stats; } + bool GetAllowLessThanKResults() + { + return _allowLessThankResults; + } + private: uint32_t _time_limit_in_microseconds; uint32_t _io_limit; @@ -206,6 +211,7 @@ template class IndexSearchContext LabelT _label; Timer _timer; QueryStats _stats; + bool _allowLessThankResults; }; } // namespace diskann diff --git a/src/index.cpp b/src/index.cpp index 9655a074e..544d2f31b 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2251,10 +2251,14 @@ std::pair Index::search(const T *query, con break; } - if (pos <= K) + if (pos <= K && context.GetAllowLessThanKResults()) { context.SetState(State::Success); } + else if(pos < K) + { + context.SetState(State::Failure); + } else { context.SetState(State::Failure); @@ -2386,10 +2390,14 @@ std::pair Index::search_with_filters(const if (pos == K) break; } - if (pos <= K) + if (pos <= K && context.GetAllowLessThanKResults()) { context.SetState(State::Success); } + else if(pos < K) + { + context.SetState(State::Failure); + } else { context.SetState(State::Failure); From c252eea630a6df660bbb60e5f6896e71bac96682 Mon Sep 17 00:00:00 2001 From: Neelam Mahapatro Date: Wed, 17 Jan 2024 18:20:28 +0530 Subject: [PATCH 3/7] update result returned --- include/parameters.h | 12 ++++++++++++ src/index.cpp | 7 +++++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/include/parameters.h b/include/parameters.h index 8e5c99455..0cb9ad985 100644 --- a/include/parameters.h +++ b/include/parameters.h @@ -144,6 +144,7 @@ template class IndexSearchContext { _use_filter = false; _label = (LabelT)0; + _total_result_returned = 0; } void SetLabel(LabelT label, bool use_filter) @@ -157,6 +158,16 @@ template class IndexSearchContext _result_state = state; } + void UpdateResultReturned(size_t result_returned) + { + _total_result_returned = result_returned; + } + + size_t GetResultReturned(size_t result_returned) + { + return _total_result_returned; + } + State GetState() const { return _result_state; @@ -212,6 +223,7 @@ template class IndexSearchContext Timer _timer; QueryStats _stats; bool _allowLessThankResults; + size_t _total_result_returned; }; } // namespace diskann diff --git a/src/index.cpp b/src/index.cpp index 544d2f31b..35d05f59b 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2251,17 +2251,20 @@ std::pair Index::search(const T *query, con break; } - if (pos <= K && context.GetAllowLessThanKResults()) + if (pos < K && context.GetAllowLessThanKResults()) { context.SetState(State::Success); + context.UpdateResultReturned(pos); } else if(pos < K) { context.SetState(State::Failure); + context.UpdateResultReturned(pos); } else { - context.SetState(State::Failure); + context.SetState(State::Success); + context.UpdateResultReturned(K); } return retval; From 1ced02261d268cf16c88edaadee6da5e67356579 Mon Sep 17 00:00:00 2001 From: Neelam Mahapatro Date: Wed, 17 Jan 2024 18:36:41 +0530 Subject: [PATCH 4/7] fix argument --- include/parameters.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/parameters.h b/include/parameters.h index 0cb9ad985..c9a701221 100644 --- a/include/parameters.h +++ b/include/parameters.h @@ -163,7 +163,7 @@ template class IndexSearchContext _total_result_returned = result_returned; } - size_t GetResultReturned(size_t result_returned) + size_t GetResultReturned() { return _total_result_returned; } From 8121569963ff11228e3b60ca9cade219b4cbc02c Mon Sep 17 00:00:00 2001 From: Neelam Mahapatro Date: Wed, 17 Jan 2024 19:15:19 +0530 Subject: [PATCH 5/7] update status for SSD index --- src/pq_flash_index.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/pq_flash_index.cpp b/src/pq_flash_index.cpp index 68c54ea65..1d9cce796 100644 --- a/src/pq_flash_index.cpp +++ b/src/pq_flash_index.cpp @@ -1687,6 +1687,7 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t } context.SetState(State::Success); + context.UpdateResultReturned(k_search); } // range search returns results of all neighbors within distance of range. From a820f629259ea87d7cd5e4df3ca162dd66b90aff Mon Sep 17 00:00:00 2001 From: Neelam Mahapatro Date: Wed, 17 Jan 2024 22:47:41 +0530 Subject: [PATCH 6/7] rename to GetResultCount --- include/parameters.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/parameters.h b/include/parameters.h index c9a701221..2c16bf9e1 100644 --- a/include/parameters.h +++ b/include/parameters.h @@ -163,7 +163,7 @@ template class IndexSearchContext _total_result_returned = result_returned; } - size_t GetResultReturned() + size_t GetResultCount() { return _total_result_returned; } From 70264f97a959662e5a2a514533c6b5748ba854cb Mon Sep 17 00:00:00 2001 From: Neelam Mahapatro Date: Fri, 19 Jan 2024 12:09:01 +0530 Subject: [PATCH 7/7] fix index.cpp --- src/index.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/index.cpp b/src/index.cpp index 35d05f59b..aac892573 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2393,17 +2393,20 @@ std::pair Index::search_with_filters(const if (pos == K) break; } - if (pos <= K && context.GetAllowLessThanKResults()) + if (pos < K && context.GetAllowLessThanKResults()) { context.SetState(State::Success); + context.UpdateResultReturned(pos); } else if(pos < K) { context.SetState(State::Failure); + context.UpdateResultReturned(pos); } else { - context.SetState(State::Failure); + context.SetState(State::Success); + context.UpdateResultReturned(K); } return retval;