Skip to content

Commit

Permalink
Fix insert with filtered label
Browse files Browse the repository at this point in the history
  • Loading branch information
Sanhaoji2 committed Dec 2, 2024
1 parent f95a526 commit 764be36
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 41 deletions.
7 changes: 7 additions & 0 deletions include/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,13 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
InMemQueryScratch<T> *scratch, bool use_filter = false,
uint32_t filteredLindex = 0);

void search_for_point_and_prune(int location, uint32_t Lindex, std::vector<uint32_t>& pruned_list,
const std::vector<LabelT>& labels,
InMemQueryScratch<T>* scratch,
uint32_t filteredLindex);

void prune_search_result(int location, std::vector<uint32_t>& pruned_list, InMemQueryScratch<T>* scratch);

void prune_neighbors(const uint32_t location, std::vector<Neighbor> &pool, std::vector<uint32_t> &pruned_list,
InMemQueryScratch<T> *scratch);

Expand Down
107 changes: 66 additions & 41 deletions src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1034,53 +1034,79 @@ void Index<T, TagT, LabelT>::search_for_point_and_prune(int location, uint32_t L
{
_data_store->get_vector(location, scratch->aligned_query());
iterate_to_fixed_point(scratch, Lindex, init_ids, false, unused_filter_label, false);
prune_search_result(location, pruned_list, scratch);
}
else
{
std::vector<LabelT> labels;
std::shared_lock<std::shared_timed_mutex> tl(_tag_lock, std::defer_lock);
if (_dynamic_index)
tl.lock();
std::vector<uint32_t> filter_specific_start_nodes;
for (auto &x : _location_to_labels[location])
filter_specific_start_nodes.emplace_back(_label_to_start_id[x]);

labels = _location_to_labels[location];

if (_dynamic_index)
tl.unlock();

_data_store->get_vector(location, scratch->aligned_query());
iterate_to_fixed_point(scratch, filteredLindex, filter_specific_start_nodes, true,
_location_to_labels[location], false);
search_for_point_and_prune(location, Lindex, pruned_list, labels, scratch, filteredLindex);
}

assert(_graph_store->get_total_points() == _max_points);
}

template <typename T, typename TagT, typename LabelT>
void Index<T, TagT, LabelT>::search_for_point_and_prune(
int location, uint32_t Lindex,
std::vector<uint32_t>& pruned_list,
const std::vector<LabelT>& labels,
InMemQueryScratch<T>* scratch,
uint32_t filteredLindex)
{
std::vector<uint32_t> filter_specific_start_nodes;
for (auto& x : labels)
filter_specific_start_nodes.emplace_back(_label_to_start_id[x]);

_data_store->get_vector(location, scratch->aligned_query());
iterate_to_fixed_point(scratch, filteredLindex, filter_specific_start_nodes, true,
labels, false);

if (Lindex > 0)
if (Lindex > 0)
{
// combine candidate pools obtained with filter and unfiltered criteria.
const std::vector<uint32_t> init_ids = get_init_ids();
const std::vector<LabelT> unused_filter_label;
std::set<Neighbor> best_candidate_pool;
for (auto filtered_neighbor : scratch->pool())
{
// combine candidate pools obtained with filter and unfiltered criteria.
std::set<Neighbor> best_candidate_pool;
for (auto filtered_neighbor : scratch->pool())
{
best_candidate_pool.insert(filtered_neighbor);
}
best_candidate_pool.insert(filtered_neighbor);
}

// clear scratch for finding unfiltered candidates
scratch->clear();
// clear scratch for finding unfiltered candidates
scratch->clear();

_data_store->get_vector(location, scratch->aligned_query());
iterate_to_fixed_point(scratch, Lindex, init_ids, false, unused_filter_label, false);
_data_store->get_vector(location, scratch->aligned_query());
iterate_to_fixed_point(scratch, Lindex, init_ids, false, unused_filter_label, false);

for (auto unfiltered_neighbour : scratch->pool())
for (auto unfiltered_neighbour : scratch->pool())
{
// insert if this neighbour is not already in best_candidate_pool
if (best_candidate_pool.find(unfiltered_neighbour) == best_candidate_pool.end())
{
// insert if this neighbour is not already in best_candidate_pool
if (best_candidate_pool.find(unfiltered_neighbour) == best_candidate_pool.end())
{
best_candidate_pool.insert(unfiltered_neighbour);
}
best_candidate_pool.insert(unfiltered_neighbour);
}

scratch->pool().clear();
std::copy(best_candidate_pool.begin(), best_candidate_pool.end(), std::back_inserter(scratch->pool()));
}

scratch->pool().clear();
std::copy(best_candidate_pool.begin(), best_candidate_pool.end(), std::back_inserter(scratch->pool()));
}

auto &pool = scratch->pool();
prune_search_result(location, pruned_list, scratch);
}

template <typename T, typename TagT, typename LabelT>
void Index<T, TagT, LabelT>::prune_search_result(int location, std::vector<uint32_t>& pruned_list, InMemQueryScratch<T>* scratch)
{
auto& pool = scratch->pool();

for (uint32_t i = 0; i < pool.size(); i++)
{
Expand All @@ -1099,7 +1125,6 @@ void Index<T, TagT, LabelT>::search_for_point_and_prune(int location, uint32_t L
prune_neighbors(location, pool, pruned_list, scratch);

assert(!pruned_list.empty());
assert(_graph_store->get_total_points() == _max_points);
}

template <typename T, typename TagT, typename LabelT>
Expand Down Expand Up @@ -3194,18 +3219,6 @@ int Index<T, TagT, LabelT>::insert_point(const T *point, const TagT tag, const s
} // cant insert as active pts >= max_pts
dl.unlock();

if (_filtered_index)
{
// _location_to_labels[location] = labels;
auto bitsets = _bitmask_buf.get_bitmask(location);
memset(bitsets, 0, _bitmask_buf._bitmask_size);
simple_bitmask bm(bitsets, _bitmask_buf._bitmask_size);
for (LabelT label : labels)
{
bm.set(label);
}
}

// Insert tag and mapping to location
if (_enable_tags)
{
Expand All @@ -3230,7 +3243,7 @@ int Index<T, TagT, LabelT>::insert_point(const T *point, const TagT tag, const s
if (_filtered_index)
{
// when filtered the best_candidates will share the same label ( label_present > distance)
search_for_point_and_prune(location, _indexingQueueSize, pruned_list, scratch, true, _filterIndexingQueueSize);
search_for_point_and_prune(location, _indexingQueueSize, pruned_list, labels, scratch, _filterIndexingQueueSize);
}
else
{
Expand All @@ -3257,6 +3270,18 @@ int Index<T, TagT, LabelT>::insert_point(const T *point, const TagT tag, const s
_graph_store->set_neighbours(location, neighbor_links);
assert(_graph_store->get_neighbours(location).size() <= _indexingRange);

if (_filtered_index)
{
// _location_to_labels[location] = labels;
auto bitsets = _bitmask_buf.get_bitmask(location);
memset(bitsets, 0, _bitmask_buf._bitmask_size);
simple_bitmask bm(bitsets, _bitmask_buf._bitmask_size);
for (LabelT label : labels)
{
bm.set(label);
}
}

if (_conc_consolidate)
tlock.unlock();
}
Expand Down

0 comments on commit 764be36

Please sign in to comment.