Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Porting multi-filter OR search support from the DLVS branch #546

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 49 additions & 12 deletions apps/search_disk_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "common_includes.h"
#include <boost/program_options.hpp>

#include "utils.h"
#include "index.h"
#include "disk_utils.h"
#include "math_utils.h"
Expand Down Expand Up @@ -47,6 +48,44 @@ void print_stats(std::string category, std::vector<float> percentiles, std::vect
diskann::cout << std::endl;
}

template<typename T, typename LabelT>
void parse_labels_of_query(const std::string &filters_for_query,
std::unique_ptr<diskann::PQFlashIndex<T, LabelT>> &pFlashIndex,
std::vector<LabelT> &label_ids_for_query)
{
std::vector<std::string> label_strs_for_query;
diskann::split_string(filters_for_query, MULTIPLE_LABEL_SEPARATOR, label_strs_for_query);
gopalrs marked this conversation as resolved.
Show resolved Hide resolved
for (auto &label_str_for_query : label_strs_for_query)
{
label_ids_for_query.push_back(pFlashIndex->get_converted_label(label_str_for_query));
}
}

template<typename T, typename LabelT>
void populate_label_ids(const std::vector<std::string> &filters_of_queries,
std::unique_ptr<diskann::PQFlashIndex<T, LabelT>> &pFlashIndex,
std::vector<std::vector<LabelT>> &label_ids_of_queries, bool apply_one_to_all, uint32_t query_count)
{
if (apply_one_to_all)
{
std::vector<LabelT> label_ids_of_query;
parse_labels_of_query(filters_of_queries[0], pFlashIndex, label_ids_of_query);
for (auto i = 0; i < query_count; i++)
{
label_ids_of_queries.push_back(label_ids_of_query);
}
}
else
{
for (auto &filters_of_query : filters_of_queries)
{
std::vector<LabelT> label_ids_of_query;
parse_labels_of_query(filters_of_query, pFlashIndex, label_ids_of_query);
label_ids_of_queries.push_back(label_ids_of_query);
}
}
}

template <typename T, typename LabelT = uint32_t>
int search_disk_index(diskann::Metric &metric, const std::string &index_path_prefix,
const std::string &result_output_prefix, const std::string &query_file, std::string &gt_file,
Expand Down Expand Up @@ -173,6 +212,14 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre
diskann::cout << "..done" << std::endl;
}

std::vector<std::vector<LabelT>> per_query_label_ids;
if (filtered_search)
{
populate_label_ids(query_filters, _pFlashIndex, per_query_label_ids, (query_filters.size() == 1), query_num );
}



diskann::cout.setf(std::ios_base::fixed, std::ios_base::floatfield);
diskann::cout.precision(2);

Expand Down Expand Up @@ -236,19 +283,10 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre
}
else
{
LabelT label_for_search;
if (query_filters.size() == 1)
{ // one label for all queries
label_for_search = _pFlashIndex->get_converted_label(query_filters[0]);
}
else
{ // one label for each query
label_for_search = _pFlashIndex->get_converted_label(query_filters[i]);
}
_pFlashIndex->cached_beam_search(
query + (i * query_aligned_dim), recall_at, L, query_result_ids_64.data() + (i * recall_at),
query_result_dists[test_id].data() + (i * recall_at), optimized_beamwidth, true, label_for_search,
use_reorder_data, stats + i);
query_result_dists[test_id].data() + (i * recall_at), optimized_beamwidth, true, per_query_label_ids[i],
search_io_limit, use_reorder_data, stats + i);
gopalrs marked this conversation as resolved.
Show resolved Hide resolved
}
}
auto e = std::chrono::high_resolution_clock::now();
Expand Down Expand Up @@ -443,7 +481,6 @@ int main(int argc, char **argv)
{
query_filters = read_file_to_vector_of_strings(query_filters_file);
}

try
{
if (!query_filters.empty() && label_type == "ushort")
Expand Down
12 changes: 10 additions & 2 deletions include/pq_flash_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT license.

#pragma once
#include <unordered_map>
#include "common_includes.h"

#include "aligned_file_reader.h"
Expand Down Expand Up @@ -35,6 +36,11 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex
DISKANN_DLLEXPORT int load(uint32_t num_threads, const char *index_prefix);
#endif

DISKANN_DLLEXPORT void load_labels(const std::string& disk_index_filepath);
DISKANN_DLLEXPORT void load_label_medoid_map(
const std::string &labels_to_medoids_filepath, std::istream &medoid_stream);
DISKANN_DLLEXPORT void load_dummy_map(const std::string& dummy_map_filepath, std::istream &dummy_map_stream);

#ifdef EXEC_ENV_OLS
DISKANN_DLLEXPORT int load_from_separate_paths(diskann::MemoryMappedFiles &files, uint32_t num_threads,
const char *index_filepath, const char *pivots_filepath,
Expand Down Expand Up @@ -77,7 +83,7 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex

DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search,
uint64_t *res_ids, float *res_dists, const uint64_t beam_width,
const bool use_filter, const LabelT &filter_label,
const bool use_filter, const std::vector<LabelT> &filter_labels,
const uint32_t io_limit, const bool use_reorder_data = false,
QueryStats *stats = nullptr);

Expand Down Expand Up @@ -116,7 +122,9 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex

private:
DISKANN_DLLEXPORT inline bool point_has_label(uint32_t point_id, LabelT label_id);
std::unordered_map<std::string, LabelT> load_label_map(std::basic_istream<char> &infile);
DISKANN_DLLEXPORT inline bool point_has_any_label(uint32_t point_id, const std::vector<LabelT> &label_ids);
void load_label_map(std::basic_istream<char> &map_reader,
std::unordered_map<std::string, LabelT> &string_to_int_map);
DISKANN_DLLEXPORT void parse_label_file(std::basic_istream<char> &infile, size_t &num_pts_labels);
DISKANN_DLLEXPORT void get_label_file_metadata(const std::string &fileContent, uint32_t &num_pts,
uint32_t &num_total_labels);
Expand Down
9 changes: 4 additions & 5 deletions include/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ typedef int FileHandle;

#define PBSTR "||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||"
#define PBWIDTH 60
#define MULTIPLE_LABEL_SEPARATOR "|"

inline bool file_exists_impl(const std::string &name, bool dirCheck = false)
{
Expand Down Expand Up @@ -683,6 +684,9 @@ DISKANN_DLLEXPORT double calculate_range_search_recall(unsigned num_queries,
std::vector<std::vector<uint32_t>> &groundtruth,
std::vector<std::vector<uint32_t>> &our_results);

DISKANN_DLLEXPORT void split_string(const std::string &string_to_split, const std::string &delimiter,
std::vector<std::string> &pieces);

template <typename T>
inline void load_bin(const std::string &bin_file, std::unique_ptr<T[]> &data, size_t &npts, size_t &dim,
size_t offset = 0)
Expand Down Expand Up @@ -1101,11 +1105,6 @@ inline std::vector<std::string> read_file_to_vector_of_strings(const std::string
{
break;
}
if (line.find(',') != std::string::npos)
{
std::cerr << "Every query must have exactly one filter" << std::endl;
exit(-1);
}
if (!line.empty() && (line.back() == '\r' || line.back() == '\n'))
{
line.erase(line.size() - 1);
Expand Down
Loading