From 764be36ec82b14de5b757ecf56988b07e6cf4b2f Mon Sep 17 00:00:00 2001 From: Sanhaoji2 Date: Mon, 2 Dec 2024 13:25:05 +0800 Subject: [PATCH] Fix insert with filtered label --- include/index.h | 7 ++++ src/index.cpp | 107 +++++++++++++++++++++++++++++------------------- 2 files changed, 73 insertions(+), 41 deletions(-) diff --git a/include/index.h b/include/index.h index 05c868e31..a5a2b35cc 100644 --- a/include/index.h +++ b/include/index.h @@ -384,6 +384,13 @@ template clas InMemQueryScratch *scratch, bool use_filter = false, uint32_t filteredLindex = 0); + void search_for_point_and_prune(int location, uint32_t Lindex, std::vector& pruned_list, + const std::vector& labels, + InMemQueryScratch* scratch, + uint32_t filteredLindex); + + void prune_search_result(int location, std::vector& pruned_list, InMemQueryScratch* scratch); + void prune_neighbors(const uint32_t location, std::vector &pool, std::vector &pruned_list, InMemQueryScratch *scratch); diff --git a/src/index.cpp b/src/index.cpp index dd732b9c8..d01017add 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1034,53 +1034,79 @@ void Index::search_for_point_and_prune(int location, uint32_t L { _data_store->get_vector(location, scratch->aligned_query()); iterate_to_fixed_point(scratch, Lindex, init_ids, false, unused_filter_label, false); + prune_search_result(location, pruned_list, scratch); } else { + std::vector labels; std::shared_lock tl(_tag_lock, std::defer_lock); if (_dynamic_index) tl.lock(); - std::vector filter_specific_start_nodes; - for (auto &x : _location_to_labels[location]) - filter_specific_start_nodes.emplace_back(_label_to_start_id[x]); + + labels = _location_to_labels[location]; if (_dynamic_index) tl.unlock(); - _data_store->get_vector(location, scratch->aligned_query()); - iterate_to_fixed_point(scratch, filteredLindex, filter_specific_start_nodes, true, - _location_to_labels[location], false); + search_for_point_and_prune(location, Lindex, pruned_list, labels, scratch, filteredLindex); + } + + assert(_graph_store->get_total_points() == _max_points); +} + +template +void Index::search_for_point_and_prune( + int location, uint32_t Lindex, + std::vector& pruned_list, + const std::vector& labels, + InMemQueryScratch* scratch, + uint32_t filteredLindex) +{ + std::vector filter_specific_start_nodes; + for (auto& x : labels) + filter_specific_start_nodes.emplace_back(_label_to_start_id[x]); + + _data_store->get_vector(location, scratch->aligned_query()); + iterate_to_fixed_point(scratch, filteredLindex, filter_specific_start_nodes, true, + labels, false); - if (Lindex > 0) + if (Lindex > 0) + { + // combine candidate pools obtained with filter and unfiltered criteria. + const std::vector init_ids = get_init_ids(); + const std::vector unused_filter_label; + std::set best_candidate_pool; + for (auto filtered_neighbor : scratch->pool()) { - // combine candidate pools obtained with filter and unfiltered criteria. - std::set best_candidate_pool; - for (auto filtered_neighbor : scratch->pool()) - { - best_candidate_pool.insert(filtered_neighbor); - } + best_candidate_pool.insert(filtered_neighbor); + } - // clear scratch for finding unfiltered candidates - scratch->clear(); + // clear scratch for finding unfiltered candidates + scratch->clear(); - _data_store->get_vector(location, scratch->aligned_query()); - iterate_to_fixed_point(scratch, Lindex, init_ids, false, unused_filter_label, false); + _data_store->get_vector(location, scratch->aligned_query()); + iterate_to_fixed_point(scratch, Lindex, init_ids, false, unused_filter_label, false); - for (auto unfiltered_neighbour : scratch->pool()) + for (auto unfiltered_neighbour : scratch->pool()) + { + // insert if this neighbour is not already in best_candidate_pool + if (best_candidate_pool.find(unfiltered_neighbour) == best_candidate_pool.end()) { - // insert if this neighbour is not already in best_candidate_pool - if (best_candidate_pool.find(unfiltered_neighbour) == best_candidate_pool.end()) - { - best_candidate_pool.insert(unfiltered_neighbour); - } + best_candidate_pool.insert(unfiltered_neighbour); } - - scratch->pool().clear(); - std::copy(best_candidate_pool.begin(), best_candidate_pool.end(), std::back_inserter(scratch->pool())); } + + scratch->pool().clear(); + std::copy(best_candidate_pool.begin(), best_candidate_pool.end(), std::back_inserter(scratch->pool())); } - auto &pool = scratch->pool(); + prune_search_result(location, pruned_list, scratch); +} + +template +void Index::prune_search_result(int location, std::vector& pruned_list, InMemQueryScratch* scratch) +{ + auto& pool = scratch->pool(); for (uint32_t i = 0; i < pool.size(); i++) { @@ -1099,7 +1125,6 @@ void Index::search_for_point_and_prune(int location, uint32_t L prune_neighbors(location, pool, pruned_list, scratch); assert(!pruned_list.empty()); - assert(_graph_store->get_total_points() == _max_points); } template @@ -3194,18 +3219,6 @@ int Index::insert_point(const T *point, const TagT tag, const s } // cant insert as active pts >= max_pts dl.unlock(); - if (_filtered_index) - { - // _location_to_labels[location] = labels; - auto bitsets = _bitmask_buf.get_bitmask(location); - memset(bitsets, 0, _bitmask_buf._bitmask_size); - simple_bitmask bm(bitsets, _bitmask_buf._bitmask_size); - for (LabelT label : labels) - { - bm.set(label); - } - } - // Insert tag and mapping to location if (_enable_tags) { @@ -3230,7 +3243,7 @@ int Index::insert_point(const T *point, const TagT tag, const s if (_filtered_index) { // when filtered the best_candidates will share the same label ( label_present > distance) - search_for_point_and_prune(location, _indexingQueueSize, pruned_list, scratch, true, _filterIndexingQueueSize); + search_for_point_and_prune(location, _indexingQueueSize, pruned_list, labels, scratch, _filterIndexingQueueSize); } else { @@ -3257,6 +3270,18 @@ int Index::insert_point(const T *point, const TagT tag, const s _graph_store->set_neighbours(location, neighbor_links); assert(_graph_store->get_neighbours(location).size() <= _indexingRange); + if (_filtered_index) + { + // _location_to_labels[location] = labels; + auto bitsets = _bitmask_buf.get_bitmask(location); + memset(bitsets, 0, _bitmask_buf._bitmask_size); + simple_bitmask bm(bitsets, _bitmask_buf._bitmask_size); + for (LabelT label : labels) + { + bm.set(label); + } + } + if (_conc_consolidate) tlock.unlock(); }