Skip to content

Commit

Permalink
first cut code for post processing after single filter search
Browse files Browse the repository at this point in the history
  • Loading branch information
rakri committed Nov 2, 2023
1 parent 3d58ceb commit 97e0e47
Show file tree
Hide file tree
Showing 6 changed files with 228 additions and 7 deletions.
24 changes: 19 additions & 5 deletions apps/search_memory_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <unistd.h>
#endif

#include "filter_utils.h"
#include "index.h"
#include "memory_mapper.h"
#include "utils.h"
Expand All @@ -30,7 +31,7 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path,
const std::string &query_file, const std::string &truthset_file, const uint32_t num_threads,
const uint32_t recall_at, const bool print_all_recalls, const std::vector<uint32_t> &Lvec,
const bool dynamic, const bool tags, const bool show_qps_per_thread,
const std::vector<std::string> &query_filters, const float fail_if_recall_below)
const std::vector<label_set> &query_filters, const float fail_if_recall_below)
{
using TagT = uint32_t;
// Load the query file
Expand Down Expand Up @@ -165,12 +166,20 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path,
auto qs = std::chrono::high_resolution_clock::now();
if (filtered_search)
{
std::string raw_filter = query_filters.size() == 1 ? query_filters[0] : query_filters[i];
auto raw_filters = query_filters.size() == 1 ? query_filters[0] : query_filters[i];

if (raw_filters.size() == 1) {
auto raw_filter = *(raw_filters.begin());
auto retval = index->search_with_filters(query + i * query_aligned_dim, raw_filter, recall_at, L,
query_result_ids[test_id].data() + i * recall_at,
query_result_dists[test_id].data() + i * recall_at);
cmp_stats[i] = retval.second;
} else {
auto retval = index->conjunctive_search_by_postprocessing(query + i * query_aligned_dim, raw_filters, recall_at, L,
query_result_ids[test_id].data() + i * recall_at,
query_result_dists[test_id].data() + i * recall_at);
cmp_stats[i] = retval.second;
}
}
else if (metric == diskann::FAST_L2)
{
Expand Down Expand Up @@ -392,14 +401,19 @@ int main(int argc, char **argv)
return -1;
}

std::vector<std::string> query_filters;
std::vector<label_set> query_filters;
if (filter_label != "")
{
query_filters.push_back(filter_label);
label_set single_filter_set;
single_filter_set.insert(filter_label);
query_filters.push_back(single_filter_set);
}
else if (query_filters_file != "")
{
query_filters = read_file_to_vector_of_strings(query_filters_file);
tsl::robin_map<std::string, uint32_t> label_counts;
label_set unique_labels;
std::tie(query_filters, label_counts, unique_labels) = diskann::parse_label_file(query_filters_file);
// query_filters = read_file_to_vector_of_strings(query_filters_file);
}

try
Expand Down
12 changes: 12 additions & 0 deletions include/abstract_index.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once
#include "distance.h"
#include "filter_utils.h"
#include "parameters.h"
#include "utils.h"
#include "types.h"
Expand Down Expand Up @@ -78,6 +79,12 @@ class AbstractIndex
const size_t K, const uint32_t L, IndexType *indices,
float *distances);

// Filtered search for a query that carries a set of labels rather than a single label.
// The definition in src/abstract_index.cpp wraps `indices` in a std::any and forwards every
// argument to the virtual _conjunctive_search_by_postprocessing hook; it is explicitly
// instantiated there for IndexType = uint32_t and uint64_t.
// apps/search_memory_index.cpp calls this instead of search_with_filters when a query has
// more than one filter label, and stores `.second` of the returned pair in its cmp_stats array.
// NOTE(review): the meaning of the two pair members is not visible from this declaration --
// presumably the same (hops, comparisons) convention as search_with_filters; confirm.
template <typename IndexType>
std::pair<uint32_t, uint32_t> conjunctive_search_by_postprocessing(const DataType &query, const label_set &raw_label,
const size_t K, const uint32_t L, IndexType *indices,
float *distances);


// insert points with labels, labels should be present for filtered index
template <typename data_type, typename tag_type, typename label_type>
int insert_point(const data_type *point, const tag_type tag, const std::vector<label_type> &labels);
Expand Down Expand Up @@ -112,6 +119,11 @@ class AbstractIndex
virtual std::pair<uint32_t, uint32_t> _search_with_filters(const DataType &query, const std::string &filter_label,
const size_t K, const uint32_t L, std::any &indices,
float *distances) = 0;
// Type-erased hook behind the public conjunctive_search_by_postprocessing template.
// `indices` arrives as a std::any holding the caller's IndexType* result buffer pointer.
// Pure virtual: every concrete index type has to provide an implementation.
virtual std::pair<uint32_t, uint32_t> _conjunctive_search_by_postprocessing(const DataType &query,
const label_set &raw_filter_labels, const size_t K,
const uint32_t L, std::any &indices,
float *distances) = 0;

virtual int _insert_point(const DataType &data_point, const TagType tag, Labelvector &labels) = 0;
virtual int _insert_point(const DataType &data_point, const TagType tag) = 0;
virtual int _lazy_delete(const TagType &tag) = 0;
Expand Down
2 changes: 1 addition & 1 deletion include/filter_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ template <typename LabelT>
DISKANN_DLLEXPORT std::tuple<std::vector<std::vector<LabelT>>, tsl::robin_set<LabelT>> parse_formatted_label_file(
path label_file);

DISKANN_DLLEXPORT parse_label_file_return_values parse_label_file(path label_data_path, std::string universal_label);
// `universal_label` defaults to the empty string so callers can pass only the file path;
// apps/search_memory_index.cpp does so and unpacks the result with std::tie into three values
// (per-line label sets, a label -> count map, and the set of unique labels).
// NOTE(review): how an empty universal_label is interpreted is decided in the definition,
// which is not shown here -- presumably "no universal label"; confirm.
DISKANN_DLLEXPORT parse_label_file_return_values parse_label_file(path label_data_path, std::string universal_label = "");

template <typename T>
DISKANN_DLLEXPORT tsl::robin_map<std::string, std::vector<uint32_t>> generate_label_specific_vector_files_compat(
Expand Down
15 changes: 15 additions & 0 deletions include/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,13 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
const size_t K, const uint32_t L,
IndexType *indices, float *distances);

// Filter support search: variant that takes several filter labels for one query.
// Writes results through `indices` / `distances`, mirroring the single-label search_with_filters
// declaration above it.
// NOTE(review): this overload takes std::vector<LabelT>, whereas the AbstractIndex entry point of
// the same name takes a label_set of raw labels -- presumably these are labels already converted
// to LabelT; confirm against the definition in index.cpp.
template <typename IndexType>
DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> conjunctive_search_by_postprocessing(const T *query, const std::vector<LabelT> &filter_labels,
const size_t K, const uint32_t L,
IndexType *indices, float *distances);


// Will fail if tag already in the index or if tag=0.
DISKANN_DLLEXPORT int insert_point(const T *point, const TagT tag);

Expand Down Expand Up @@ -199,11 +206,18 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas

virtual std::pair<uint32_t, uint32_t> _search(const DataType &query, const size_t K, const uint32_t L,
std::any &indices, float *distances = nullptr) override;

virtual std::pair<uint32_t, uint32_t> _search_with_filters(const DataType &query,
const std::string &filter_label_raw, const size_t K,
const uint32_t L, std::any &indices,
float *distances) override;

// Override of the AbstractIndex hook for multi-label filtered search.
// Receives the raw filter labels as a label_set and the result-id buffer pointer wrapped in
// a std::any, exactly as AbstractIndex::conjunctive_search_by_postprocessing passes them along.
virtual std::pair<uint32_t, uint32_t> _conjunctive_search_by_postprocessing(const DataType &query,
const label_set &raw_filter_labels, const size_t K,
const uint32_t L, std::any &indices,
float *distances) override;


virtual int _insert_point(const DataType &data_point, const TagType tag) override;
virtual int _insert_point(const DataType &data_point, const TagType tag, Labelvector &labels) override;

Expand Down Expand Up @@ -376,6 +390,7 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
std::string _labels_file;
std::unordered_map<LabelT, uint32_t> _label_to_start_id;
std::unordered_map<uint32_t, uint32_t> _medoid_counts;
std::unordered_map<LabelT, uint32_t> _label_counts;

bool _use_universal_label = false;
LabelT _universal_label = 0;
Expand Down
18 changes: 18 additions & 0 deletions src/abstract_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@ std::pair<uint32_t, uint32_t> AbstractIndex::search_with_filters(const DataType
return _search_with_filters(query, raw_label, K, L, any_indices, distances);
}

// Typed front end for filtered search with a set of labels: erases the type of the caller's
// result-id buffer pointer and delegates to the virtual _conjunctive_search_by_postprocessing.
// The returned pair is whatever the concrete index implementation produces.
template <typename IndexType>
std::pair<uint32_t, uint32_t> AbstractIndex::conjunctive_search_by_postprocessing(
    const DataType &query, const label_set &raw_labels, const size_t K, const uint32_t L, IndexType *indices,
    float *distances)
{
    // The virtual hook takes std::any by non-const reference, so a named lvalue is required.
    std::any erased_indices = indices;
    return _conjunctive_search_by_postprocessing(query, raw_labels, K, L, erased_indices, distances);
}


template <typename data_type>
void AbstractIndex::search_with_optimized_layout(const data_type *query, size_t K, size_t L, uint32_t *indices)
{
Expand Down Expand Up @@ -162,6 +172,14 @@ template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::search_w
const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, uint64_t *indices,
float *distances);

// Explicit instantiations of conjunctive_search_by_postprocessing. The template body lives in
// this .cpp file, so the result-id types callers may use are instantiated (and exported) here.
// Instantiation for uint32_t result ids.
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::conjunctive_search_by_postprocessing<uint32_t>(
const DataType &query, const label_set &raw_label, const size_t K, const uint32_t L, uint32_t *indices,
float *distances);

// Instantiation for uint64_t result ids.
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::conjunctive_search_by_postprocessing<uint64_t>(
const DataType &query, const label_set &raw_label, const size_t K, const uint32_t L, uint64_t *indices,
float *distances);

template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<float, int32_t>(const float *query, const uint64_t K,
const uint32_t L, int32_t *tags,
float *distances,
Expand Down
Loading

0 comments on commit 97e0e47

Please sign in to comment.