diff --git a/include/in_mem_filter_store.h b/include/in_mem_filter_store.h index 0458dd564..5ac1b1b26 100644 --- a/include/in_mem_filter_store.h +++ b/include/in_mem_filter_store.h @@ -1,90 +1,123 @@ #pragma once #include +#include "windows_customizations.h" #include "tsl/robin_map.h" #include "tsl/robin_set.h" #include -namespace diskann -{ +namespace diskann { template - class InMemFilterStore : public AbstractFilterStore - { - public: - /// - /// Returns the filters for a data point. Only valid for base points - /// - /// base point id - /// list of filters of the base point - virtual const std::vector &get_filters_for_point(location_t point) const override; - - /// - /// Adds filters for a point. - /// - /// - /// - virtual void add_filters_for_point(location_t point, const std::vector &filters) override; - - /// - /// Returns a score between [0,1] indicating how many points in the dataset - /// matched the predicate - /// - /// Predicate to match - /// Score between [0,1] indicate %age of points matching pred - virtual float get_predicate_selectivity(const AbstractPredicate &pred) const override; - - - virtual const std::unordered_map>& get_label_to_medoids() const; - - virtual const std::vector &get_medoids_of_label(const LabelT label) const; - - virtual void set_universal_label(const LabelT univ_label); - - inline bool point_has_label(location_t point_id, const LabelT label_id) const; - - inline bool is_dummy_point(location_t id) const; - - inline bool point_has_label_or_universal_label(location_t point_id, const LabelT label_id) const; - - inline LabelT get_converted_label(const std::string &filter_label) const; - - //Returns true if the index is filter-enabled and all files were loaded correctly. - //false otherwise. Note that "false" can mean that the index does not have filter support, - //or that some index files do not exist, or that they exist and could not be opened. - bool load(const std::string& disk_index_file); + class InMemFilterStore : public AbstractFilterStore { + public: + // Do nothing constructor because all the work is done in load() + DISKANN_DLLEXPORT InMemFilterStore() { + } + + /// + /// Destructor + /// + DISKANN_DLLEXPORT virtual ~InMemFilterStore(); + + // No copy, no assignment. + DISKANN_DLLEXPORT InMemFilterStore &operator=( + const InMemFilterStore &v) = delete; + DISKANN_DLLEXPORT InMemFilterStore(const InMemFilterStore &v) = + delete; + + DISKANN_DLLEXPORT virtual bool has_filter_support() const; + + /// + /// Returns the filters for a data point. Only valid for base points + /// + /// base point id + /// list of filters of the base point + DISKANN_DLLEXPORT virtual const std::vector &get_filters_for_point( + location_t point) const override; + + /// + /// Adds filters for a point. 
+ /// + /// + /// + DISKANN_DLLEXPORT virtual void add_filters_for_point( + location_t point, const std::vector &filters) override; + + /// + /// Returns a score between [0,1] indicating how many points in the dataset + /// matched the predicate + /// + /// Predicate to match + /// Score between [0,1] indicate %age of points matching + /// pred + DISKANN_DLLEXPORT virtual float get_predicate_selectivity( + const AbstractPredicate &pred) const override; + + DISKANN_DLLEXPORT virtual const std::unordered_map> + &get_label_to_medoids() const; + + DISKANN_DLLEXPORT virtual const std::vector + &get_medoids_of_label(const LabelT label) ; + + DISKANN_DLLEXPORT virtual void set_universal_label(const LabelT univ_label); + + DISKANN_DLLEXPORT inline bool point_has_label(location_t point_id, + const LabelT label_id) const; + + DISKANN_DLLEXPORT inline bool is_dummy_point(location_t id) const; + + DISKANN_DLLEXPORT inline location_t get_real_point_for_dummy( + location_t dummy_id); + + DISKANN_DLLEXPORT inline bool point_has_label_or_universal_label( + location_t point_id, const LabelT label_id) const; + + DISKANN_DLLEXPORT inline LabelT get_converted_label( + const std::string &filter_label); + + // Returns true if the index is filter-enabled and all files were loaded + // correctly. false otherwise. Note that "false" can mean that the index + // does not have filter support, or that some index files do not exist, or + // that they exist and could not be opened. + DISKANN_DLLEXPORT bool load(const std::string &disk_index_file); + + DISKANN_DLLEXPORT void generate_random_labels(std::vector &labels, + const uint32_t num_labels, + const uint32_t nthreads); private: - - // Load functions for search START - void load_label_file(const std::string_view& file_content); - void load_label_map(std::basic_istream &map_reader); - void load_labels_to_medoids(std::basic_istream &reader); - void load_dummy_map(std::basic_istream &dummy_map_stream); - - bool load_file_and_parse( - const std::string &filename, - void (*parse_fn)(const std::string_view &content)); - - bool load_file_and_parse( - const std::string &filename, - void (*parse_fn)(std::basic_istream &stream)) - - - // Load functions for search END - - // filter support - uint32_t *_pts_to_label_offsets = nullptr; - uint32_t *_pts_to_label_counts = nullptr; - LabelT *_pts_to_labels = nullptr; - std::unordered_map> _filter_to_medoid_ids; - bool _use_universal_label = false; - LabelT _universal_filter_label; - tsl::robin_set _dummy_pts; - tsl::robin_set _has_dummy_pts; - tsl::robin_map _dummy_to_real_map; - tsl::robin_map> _real_to_dummy_map; - std::unordered_map _label_map; - + // Load functions for search START + void load_label_file(const std::string_view &file_content); + void load_label_map(std::basic_istream &map_reader); + void load_labels_to_medoids(std::basic_istream &reader); + void load_dummy_map(std::basic_istream &dummy_map_stream); + void parse_universal_label(const std::string_view &content); + void get_label_file_metadata(const std::string_view &fileContent, + uint32_t &num_pts, uint32_t &num_total_labels); + + bool load_file_and_parse(const std::string &filename, + void (InMemFilterStore::*parse_fn)(const std::string_view &content)); + bool parse_stream( + const std::string &filename, + void (InMemFilterStore::*parse_fn)(std::basic_istream &stream)); + + void reset_stream_for_reading(std::basic_istream &infile); + // Load functions for search END + + location_t _num_points = 0; + location_t *_pts_to_label_offsets = nullptr; + location_t 
*_pts_to_label_counts = nullptr; + LabelT *_pts_to_labels = nullptr; + bool _use_universal_label = false; + LabelT _universal_filter_label; + tsl::robin_set _dummy_pts; + tsl::robin_set _has_dummy_pts; + tsl::robin_map _dummy_to_real_map; + tsl::robin_map> _real_to_dummy_map; + std::unordered_map _label_map; + std::unordered_map> _filter_to_medoid_ids; + bool _is_valid = false; }; -} +} // namespace diskann diff --git a/include/pq_flash_index.h b/include/pq_flash_index.h index e199710fa..7f451371f 100644 --- a/include/pq_flash_index.h +++ b/include/pq_flash_index.h @@ -83,7 +83,6 @@ template class PQFlashIndex const uint32_t io_limit, const bool use_reorder_data = false, QueryStats *stats = nullptr); - DISKANN_DLLEXPORT LabelT get_converted_label(const std::string &filter_label); DISKANN_DLLEXPORT uint32_t range_search(const T *query1, const double range, const uint64_t min_l_search, const uint64_t max_l_search, std::vector &indices, @@ -114,18 +113,7 @@ template class PQFlashIndex DISKANN_DLLEXPORT void use_medoids_data_as_centroids(); DISKANN_DLLEXPORT void setup_thread_data(uint64_t nthreads, uint64_t visited_reserve = 4096); - DISKANN_DLLEXPORT void set_universal_label(const LabelT &label); - private: - DISKANN_DLLEXPORT inline bool point_has_label(uint32_t point_id, LabelT label_id); - std::unordered_map load_label_map(std::basic_istream &infile); - DISKANN_DLLEXPORT void parse_label_file(std::basic_istream &infile, size_t &num_pts_labels); - DISKANN_DLLEXPORT void get_label_file_metadata(const std::string &fileContent, uint32_t &num_pts, - uint32_t &num_total_labels); - DISKANN_DLLEXPORT void generate_random_labels(std::vector &labels, const uint32_t num_labels, - const uint32_t nthreads); - void reset_stream_for_reading(std::basic_istream &infile); - // sector # on disk where node_id is present with in the graph part DISKANN_DLLEXPORT uint64_t get_node_sector(uint64_t node_id); @@ -225,7 +213,8 @@ template class PQFlashIndex //Moved filter-specific data structures to in_mem_filter_store. 
//TODO: Make this a unique pointer - InMemFilterStore* _filter_store; + bool _filter_index = false; + std::unique_ptr> _filter_store; #ifdef EXEC_ENV_OLS diff --git a/src/dll/CMakeLists.txt b/src/dll/CMakeLists.txt index 096d1b76e..633e8edb7 100644 --- a/src/dll/CMakeLists.txt +++ b/src/dll/CMakeLists.txt @@ -4,7 +4,8 @@ add_library(${PROJECT_NAME} SHARED dllmain.cpp ../abstract_data_store.cpp ../partition.cpp ../pq.cpp ../pq_flash_index.cpp ../logger.cpp ../utils.cpp ../windows_aligned_file_reader.cpp ../distance.cpp ../pq_l2_distance.cpp ../memory_mapper.cpp ../index.cpp ../in_mem_data_store.cpp ../pq_data_store.cpp ../in_mem_graph_store.cpp ../math_utils.cpp ../disk_utils.cpp ../filter_utils.cpp - ../ann_exception.cpp ../natural_number_set.cpp ../natural_number_map.cpp ../scratch.cpp ../index_factory.cpp ../abstract_index.cpp) + ../ann_exception.cpp ../natural_number_set.cpp ../natural_number_map.cpp ../scratch.cpp ../index_factory.cpp ../abstract_index.cpp + ../in_mem_filter_store.cpp) set(TARGET_DIR "$<$:${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}>$<$:${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}>") diff --git a/src/in_mem_filter_store.cpp b/src/in_mem_filter_store.cpp index 2bf2c79ed..67e26b6b7 100644 --- a/src/in_mem_filter_store.cpp +++ b/src/in_mem_filter_store.cpp @@ -1,18 +1,49 @@ #include +#include +#include +#include #include -#include -#include -#include +#include +#include +#include +#include "tsl/robin_map.h" +#include "tsl/robin_set.h" +#include "utils.h" +#include "ann_exception.h" +#include "in_mem_filter_store.h" +#include "multi_filter/abstract_predicate.h" +#include "multi_filter/simple_boolean_predicate.h" namespace diskann { + //TODO: Move to utils.h + DISKANN_DLLEXPORT std::unique_ptr get_file_content(const std::string &filename, + uint64_t & file_size); + + template + InMemFilterStore::~InMemFilterStore() { + if (_pts_to_label_offsets != nullptr) { + delete[] _pts_to_label_offsets; + _pts_to_label_offsets = nullptr; + } + if (_pts_to_label_counts != nullptr) { + delete[] _pts_to_label_counts; + _pts_to_label_counts = nullptr; + } + if (_pts_to_labels != nullptr) { + delete[] _pts_to_labels; + _pts_to_labels = nullptr; + } + } template const std::vector &InMemFilterStore::get_filters_for_point( location_t point) const { + throw ANNException("Not implemented!", -1); } template void InMemFilterStore::add_filters_for_point( location_t point, const std::vector &filters) { + throw ANNException("Not implemented!", -1); } template @@ -24,11 +55,21 @@ namespace diskann { template const std::unordered_map> &InMemFilterStore::get_label_to_medoids() const { + return this->_filter_to_medoid_ids; } template const std::vector &InMemFilterStore::get_medoids_of_label( - const LabelT label) const { + const LabelT label) { + if (_filter_to_medoid_ids.find(label) != _filter_to_medoid_ids.end()) { + return this->_filter_to_medoid_ids[label]; + } else { + std::stringstream ss; + ss << "Could not find " << label << " in filters_to_medoid_ids map." + << std::endl; + diskann::cerr << ss.str(); + throw ANNException(ss.str(), -1); + } } template @@ -57,16 +98,27 @@ namespace diskann { return _dummy_pts.find(id) != _dummy_pts.end(); } + template + inline location_t InMemFilterStore::get_real_point_for_dummy( + location_t dummy_id) { + if (is_dummy_point(dummy_id)) { + return _dummy_to_real_map[dummy_id]; + } else { + return dummy_id; // it is a real point. 
+ } + } + template inline bool InMemFilterStore::point_has_label_or_universal_label( location_t id, const LabelT filter_label) const { - return point_has_label(id, filter_label) || - (_use_universal_label && point_has_label(id, _universal_filter_label)); + return point_has_label(id, filter_label) || + (_use_universal_label && + point_has_label(id, _universal_filter_label)); } template inline LabelT InMemFilterStore::get_converted_label( - const std::string &filter_label) const { + const std::string &filter_label) { if (_label_map.find(filter_label) != _label_map.end()) { return _label_map[filter_label]; } @@ -95,116 +147,127 @@ namespace diskann { // reading them as bytes, not sure if that can cause an issue with UTF // encodings. bool has_filters = true; - if (false == load_file_and_parse(labels_file, &load_label_file)) { + if (false == load_file_and_parse( + labels_file, &InMemFilterStore::load_label_file)) { diskann::cout << "Index does not have filter data. " << std::endl; return false; } - if (false == load_file_and_parse(labels_map_file, &load_label_map)) { + if (false == parse_stream(labels_map_file, + &InMemFilterStore::load_label_map)) { diskann::cerr << "Failed to find file: " << labels_map_file << " while labels_file exists." << std::endl; return false; } + if (false == - load_file_and_parse(labels_to_medoids, &load_labels_to_medoids)) { + parse_stream(labels_to_medoids, + &InMemFilterStore::load_labels_to_medoids)) { diskann::cerr << "Failed to find file: " << labels_to_medoids << " while labels file exists." << std::endl; return false; } // missing universal label file is NOT an error. load_file_and_parse(univ_label_file, - [this](const std::string_view &content) { - label_as_num = (LabelT) std::strtoul(univ_label); - this->set_universal_label(label_as_num); - }); + &InMemFilterStore::parse_universal_label); // missing dummy map file is also NOT an error. 
- load_file_and_parse(dummy_map_file, &load_dummy_map); - return true; + parse_stream(dummy_map_file, &InMemFilterStore::load_dummy_map); + _is_valid = true; + return _is_valid; } + template + bool InMemFilterStore::has_filter_support() const { + return _is_valid; + } -// TODO: Improve this to not load the entire file in memory -template -void InMemFilterStore::load_label_file( - const std::string_view &label_file_content) { - std::string line; - uint32_t line_cnt = 0; - - uint32_t num_pts_in_label_file; - uint32_t num_total_labels; - get_label_file_metadata(label_file_content, num_pts_in_label_file, - num_total_labels); - - _pts_to_label_offsets = new uint32_t[num_pts_in_label_file]; - _pts_to_label_counts = new uint32_t[num_pts_in_label_file]; - _pts_to_labels = new LabelT[num_total_labels]; - uint32_t labels_seen_so_far = 0; - - std::string label_str; - size_t cur_pos = 0; - size_t next_pos = 0; - while (cur_pos < file_size && cur_pos != std::string::npos) { - next_pos = label_file_content.find('\n', cur_pos); - if (next_pos == std::string::npos) { - break; - } - - _pts_to_label_offsets[line_cnt] = labels_seen_so_far; - uint32_t &num_lbls_in_cur_pt = _pts_to_label_counts[line_cnt]; - num_lbls_in_cur_pt = 0; - - size_t lbl_pos = cur_pos; - size_t next_lbl_pos = 0; - while (lbl_pos < next_pos && lbl_pos != std::string::npos) { - next_lbl_pos = label_file_content.find(',', lbl_pos); - if (next_lbl_pos == - std::string::npos) // the last label in the whole file - { - next_lbl_pos = next_pos; + // TODO: Improve this to not load the entire file in memory + template + void InMemFilterStore::load_label_file( + const std::string_view &label_file_content) { + std::string line; + uint32_t line_cnt = 0; + + uint32_t num_pts_in_label_file; + uint32_t num_total_labels; + get_label_file_metadata(label_file_content, num_pts_in_label_file, + num_total_labels); + + _pts_to_label_offsets = new uint32_t[num_pts_in_label_file]; + _pts_to_label_counts = new uint32_t[num_pts_in_label_file]; + _pts_to_labels = new LabelT[num_total_labels]; + uint32_t labels_seen_so_far = 0; + + std::string label_str; + size_t cur_pos = 0; + size_t next_pos = 0; + size_t file_size = label_file_content.size(); + + while (cur_pos < file_size && cur_pos != std::string_view::npos) { + next_pos = label_file_content.find('\n', cur_pos); + if (next_pos == std::string_view::npos) { + break; } - if (next_lbl_pos > - next_pos) // the last label in one line, just read to the end - { - next_lbl_pos = next_pos; - } + _pts_to_label_offsets[line_cnt] = labels_seen_so_far; + uint32_t &num_lbls_in_cur_pt = _pts_to_label_counts[line_cnt]; + num_lbls_in_cur_pt = 0; + + size_t lbl_pos = cur_pos; + size_t next_lbl_pos = 0; + while (lbl_pos < next_pos && lbl_pos != std::string_view::npos) { + next_lbl_pos = label_file_content.find(',', lbl_pos); + if (next_lbl_pos == + std::string_view::npos) // the last label in the whole file + { + next_lbl_pos = next_pos; + } - label_str.assign(label_file_content.c_str() + lbl_pos, - next_lbl_pos - lbl_pos); - if (label_str[label_str.length() - 1] == - '\t') // '\t' won't exist in label file? 
- { - label_str.erase(label_str.length() - 1); - } + if (next_lbl_pos > + next_pos) // the last label in one line, just read to the end + { + next_lbl_pos = next_pos; + } - LabelT token_as_num = (LabelT) std::stoul(label_str); - _pts_to_labels[labels_seen_so_far++] = (LabelT) token_as_num; - num_lbls_in_cur_pt++; + // TODO: SHOULD NOT EXPECT label_file_content TO BE NULL_TERMINATED + label_str.assign(label_file_content.data() + lbl_pos, + next_lbl_pos - lbl_pos); + if (label_str[label_str.length() - 1] == + '\t') // '\t' won't exist in label file? + { + label_str.erase(label_str.length() - 1); + } - // move to next label - lbl_pos = next_lbl_pos + 1; - } + LabelT token_as_num = (LabelT) std::stoul(label_str); + _pts_to_labels[labels_seen_so_far++] = (LabelT) token_as_num; + num_lbls_in_cur_pt++; - // move to next line - cur_pos = next_pos + 1; + // move to next label + lbl_pos = next_lbl_pos + 1; + } + + // move to next line + cur_pos = next_pos + 1; - if (num_lbls_in_cur_pt == 0) { - diskann::cout << "No label found for point " << line_cnt << std::endl; - exit(-1); + if (num_lbls_in_cur_pt == 0) { + diskann::cout << "No label found for point " << line_cnt << std::endl; + exit(-1); + } + + line_cnt++; } - line_cnt++; + // TODO: We need to check if the number of labels and the number of points + // is as expected. Maybe add the check in PQFlashIndex? + // num_points_labels = line_cnt; } - num_points_labels = line_cnt; -} - -template -void InMemFilterStore::load_labels_to_medoids(std::basic_istream& medoid_stream) { - std::string line, token; + template + void InMemFilterStore::load_labels_to_medoids( + std::basic_istream &medoid_stream) { + std::string line, token; - _filter_to_medoid_ids.clear(); - try { + _filter_to_medoid_ids.clear(); while (std::getline(medoid_stream, line)) { std::istringstream iss(line); uint32_t cnt = 0; @@ -219,258 +282,189 @@ void InMemFilterStore::load_labels_to_medoids(std::basic_istream& } _filter_to_medoid_ids[label].swap(medoids); } - } catch (std::system_error &e) { - throw FileException(labels_to_medoids, e, __FUNCSIG__, __FILE__, __LINE__); - } -} - -template -void InMemFilterStore::load_label_map( - std::basic_istream &map_reader) { - std::string line, token; - LabelT token_as_num; - std::string label_str; - while (std::getline(map_reader, line)) { - std::istringstream iss(line); - getline(iss, token, '\t'); - label_str = token; - getline(iss, token, '\t'); - token_as_num = (LabelT) std::stoul(token); - _label_map[label_str] = token_as_num; } - return _label_map; -} - -template -void InMemFilterStore::load_dummy_map( - std::basic_istream &dummy_map_stream) { - std::string line, token; - - while (std::getline(dummy_map_stream, line)) { - std::istringstream iss(line); - uint32_t cnt = 0; - uint32_t dummy_id; - uint32_t real_id; - while (std::getline(iss, token, ',')) { - if (cnt == 0) - dummy_id = (uint32_t) stoul(token); - else - real_id = (uint32_t) stoul(token); - cnt++; - } - _dummy_pts.insert(dummy_id); - _has_dummy_pts.insert(real_id); - _dummy_to_real_map[dummy_id] = real_id; - if (_real_to_dummy_map.find(real_id) == _real_to_dummy_map.end()) - _real_to_dummy_map[real_id] = std::vector(); + template + void InMemFilterStore::load_label_map( + std::basic_istream &map_reader) { + std::string line, token; + LabelT token_as_num; + std::string label_str; + while (std::getline(map_reader, line)) { + std::istringstream iss(line); + getline(iss, token, '\t'); + label_str = token; + getline(iss, token, '\t'); + token_as_num = (LabelT) std::stoul(token); + 
_label_map[label_str] = token_as_num; + } + } - _real_to_dummy_map[real_id].emplace_back(dummy_id); + template + void InMemFilterStore::parse_universal_label( + const std::string_view &content) { + LabelT label_as_num = (LabelT) std::stoul(std::string(content)); + this->set_universal_label(label_as_num); } - diskann::cout << "Loaded dummy map" << std::endl; -} + template + void InMemFilterStore::load_dummy_map( + std::basic_istream &dummy_map_stream) { + std::string line, token; + + while (std::getline(dummy_map_stream, line)) { + std::istringstream iss(line); + uint32_t cnt = 0; + uint32_t dummy_id; + uint32_t real_id; + while (std::getline(iss, token, ',')) { + if (cnt == 0) + dummy_id = (uint32_t) stoul(token); + else + real_id = (uint32_t) stoul(token); + cnt++; + } + _dummy_pts.insert(dummy_id); + _has_dummy_pts.insert(real_id); + _dummy_to_real_map[dummy_id] = real_id; + + if (_real_to_dummy_map.find(real_id) == _real_to_dummy_map.end()) + _real_to_dummy_map[real_id] = std::vector(); -template -bool InMemFilterStore::load_file_and_parse(const std::string &filename, - void (*parse_fn)(std::basic_istream &stream)) { - if (file_exists(filename)) { - std::basic_istream stream(filename); - if (false == stream.fail()) { - parse_fn(stream); - return true; - } else { - std::stringstream ss; - ss << "Could not open file: " << filename << std::endl; - throw diskann::ANNException(ss.str(), -1); + _real_to_dummy_map[real_id].emplace_back(dummy_id); } - } else { - return false; + diskann::cout << "Loaded dummy map" << std::endl; } -} - -template -bool InMemFilterStore::load_file_and_parse( - const std::string &filename, - void (*parse_fn)(const std::string_view &content)) { - if (file_exists(filename)) { - size_t file_size = 0; - auto file_content_ptr = get_file_content(filename, file_size); - std::string_view content_as_str(file_content_ptr.get(), file_size); - parse_fn(content_as_str); - return true; - } else { - return false; + + template + void InMemFilterStore::generate_random_labels( + std::vector &labels, const uint32_t num_labels, + const uint32_t nthreads) { + std::random_device rd; + labels.clear(); + labels.resize(num_labels); + + uint64_t num_total_labels = _pts_to_label_offsets[_num_points - 1] + + _pts_to_label_counts[_num_points - 1]; + std::mt19937 gen(rd()); + if (num_total_labels == 0) { + std::stringstream stream; + stream << "No labels found in data. 
Not sampling random labels "; + diskann::cerr << stream.str() << std::endl; + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, + __LINE__); + } + std::uniform_int_distribution dis(0, num_total_labels - 1); + +#pragma omp parallel for schedule(dynamic, 1) num_threads(nthreads) + for (int64_t i = 0; i < num_labels; i++) { + uint64_t rnd_loc = dis(gen); + labels[i] = (LabelT) _pts_to_labels[rnd_loc]; + } } -} - -template -std::unique_ptr get_file_content(const std::string &filename, - uint64_t & file_size) { - std::ifstream infile(filename, std::ios::binary); - if (infile.fail()) { - throw diskann::ANNException(std::string("Failed to open file ") + filename, - -1); + + template + void InMemFilterStore::reset_stream_for_reading( + std::basic_istream &infile) { + infile.clear(); + infile.seekg(0); } - infile.seekg(0, std::ios::end); - file_size = infile.tellg(); - buffer = new char[file_size]; - infile.seekg(0, std::ios::beg); - infile.read(buffer, file_size); + template + void InMemFilterStore::get_label_file_metadata( + const std::string_view &fileContent, uint32_t &num_pts, + uint32_t &num_total_labels) { + num_pts = 0; + num_total_labels = 0; + + size_t file_size = fileContent.length(); + + std::string label_str; + size_t cur_pos = 0; + size_t next_pos = 0; + while (cur_pos < file_size && cur_pos != std::string::npos) { + next_pos = fileContent.find('\n', cur_pos); + if (next_pos == std::string::npos) { + break; + } - return std::unique_ptr(buffer); -} + size_t lbl_pos = cur_pos; + size_t next_lbl_pos = 0; + while (lbl_pos < next_pos && lbl_pos != std::string::npos) { + next_lbl_pos = fileContent.find(',', lbl_pos); + if (next_lbl_pos == std::string::npos) // the last label + { + next_lbl_pos = next_pos; + } -// Load functions for SEARCH END -} + num_total_labels++; + lbl_pos = next_lbl_pos + 1; + } -/* - template -#ifdef EXEC_ENV_OLS - bool InMemFilterStore::load(MemoryMappedFiles &files, - const std::string &disk_index_prefix) { -#else - bool InMemFilterStore::load(const std::string &label_files_prefix) { -#endif - std::string labels_file = std ::string(_disk_index_file) + "_labels.txt"; - std::string labels_to_medoids = - std ::string(_disk_index_file) + "_labels_to_medoids.txt"; - std::string dummy_map_file = - std ::string(_disk_index_file) + "_dummy_map.txt"; - std::string labels_map_file = - std ::string(_disk_index_file) + "_labels_map.txt"; - size_t num_pts_in_label_file = 0; + cur_pos = next_pos + 1; - // TODO: Ideally we don't want to read entire data files into memory for - // processing them. Fortunately for us, the most restrictive client in terms - // of runtime memory already loads the data into blobs. So we'll go ahead - // and do the same. But this needs to be fixed, maybe with separate code - // paths. - // TODO: Check for encoding issues here. We are opening files as binary and - // reading them as bytes, not sure if that can cause an issue with UTF - // encodings. 
- bool has_filters = true; -#ifdef EXEC_ENV_OLS - if (files.fileExists(labels_file)) { - FileContent & content_labels = files.getContent(labels_file); - std::stringstream infile(std::string( - (const char *) content_labels._content, content_labels._size)); -#else - if (file_exists(labels_file)) { - size_t file_size; - auto label_file_content = get_file_content(labels_file, file_size); - std::string_view content_as_str(label_file_content.get(), file_size); -#endif - load_label_file(content_as_str); - assert(num_pts_in_label_file == this->_num_points); + num_pts++; + } + + diskann::cout << "Labels file metadata: num_points: " << num_pts + << ", #total_labels: " << num_total_labels << std::endl; + } + + template + bool InMemFilterStore::parse_stream( + const std::string &filename, + void (InMemFilterStore::*parse_fn)(std::basic_istream &stream)) { + if (file_exists(filename)) { + std::ifstream stream(filename); + if (false == stream.fail()) { + std::invoke(parse_fn, this, stream); + return true; + } else { + std::stringstream ss; + ss << "Could not open file: " << filename << std::endl; + throw diskann::ANNException(ss.str(), -1); + } } else { - diskann::cerr << "Index does not have filter data." << std::endl; return false; } + } - // If we have come here, it means that the labels_file exists. This means - // the other files must also exist, and them missing is a bug. -#ifdef EXEC_ENV_OLS - if (files.fileExists(labels_map_file)) { - FileContent & content_labels_map = files.getContent(labels_map_file); - std::stringstream map_reader( - std::string((const char *) content_labels_map._content, - content_labels_map._size)); -#else - if (file_exists(labels_map_file)) { - std::ifstream map_reader(labels_map_file); -#endif - _label_map = load_label_map(map_reader); - -#ifndef EXEC_ENV_OLS - map_reader.close(); -#endif + template + bool InMemFilterStore::load_file_and_parse( + const std::string &filename, + void (InMemFilterStore::*parse_fn)(const std::string_view &content)) { + if (file_exists(filename)) { + size_t file_size = 0; + auto file_content_ptr = get_file_content(filename, file_size); + std::string_view content_as_str(file_content_ptr.get(), file_size); + std::invoke(parse_fn, this, content_as_str); + return true; } else { - std::stringstream ss; - ss << "Index is filter enabled (labels file exists) but label map file: " - << labels_map_file << " could not be opened"; - diskann::cerr << ss.str() << std::endl; - throw diskann::ANNException(ss.str(), -1); + return false; } + } -#ifdef EXEC_ENV_OLS - if (files.fileExists(labels_to_medoids)) { - FileContent &content_labels_to_meoids = - files.getContent(labels_to_medoids); - std::stringstream medoid_stream( - std::string((const char *) content_labels_to_meoids._content, - content_labels_to_meoids._size)); -#else - if (file_exists(labels_to_medoids)) { - std::ifstream medoid_stream(labels_to_medoids); - assert(medoid_stream.is_open()); -#endif - load_labels_to_medoids(medoid_stream); - } - std::string univ_label_file = - std ::string(_disk_index_file) + "_universal_label.txt"; - -#ifdef EXEC_ENV_OLS - if (files.fileExists(univ_label_file)) { - FileContent & content_univ_label = files.getContent(univ_label_file); - std::stringstream universal_label_reader( - std::string((const char *) content_univ_label._content, - content_univ_label._size)); -#else - if (file_exists(univ_label_file)) { - std::ifstream universal_label_reader(univ_label_file); - assert(universal_label_reader.is_open()); -#endif - std::string univ_label; - universal_label_reader >> 
univ_label; -#ifndef EXEC_ENV_OLS - universal_label_reader.close(); -#endif - LabelT label_as_num = (LabelT) std::stoul(univ_label); - set_universal_label(label_as_num); + std::unique_ptr get_file_content(const std::string &filename, + uint64_t & file_size) { + std::ifstream infile(filename, std::ios::binary); + if (infile.fail()) { + throw diskann::ANNException( + std::string("Failed to open file ") + filename, -1); } + infile.seekg(0, std::ios::end); + file_size = infile.tellg(); -#ifdef EXEC_ENV_OLS - if (files.fileExists(dummy_map_file)) { - FileContent & content_dummy_map = files.getContent(dummy_map_file); - std::stringstream dummy_map_stream(std::string( - (const char *) content_dummy_map._content, content_dummy_map._size)); -#else - if (file_exists(dummy_map_file)) { - std::ifstream dummy_map_stream(dummy_map_file); - assert(dummy_map_stream.is_open()); -#endif - std::string line, token; - - while (std::getline(dummy_map_stream, line)) { - std::istringstream iss(line); - uint32_t cnt = 0; - uint32_t dummy_id; - uint32_t real_id; - while (std::getline(iss, token, ',')) { - if (cnt == 0) - dummy_id = (uint32_t) stoul(token); - else - real_id = (uint32_t) stoul(token); - cnt++; - } - _dummy_pts.insert(dummy_id); - _has_dummy_pts.insert(real_id); - _dummy_to_real_map[dummy_id] = real_id; - - if (_real_to_dummy_map.find(real_id) == _real_to_dummy_map.end()) - _real_to_dummy_map[real_id] = std::vector(); + auto buffer = new char[file_size]; + infile.seekg(0, std::ios::beg); + infile.read(buffer, file_size); - _real_to_dummy_map[real_id].emplace_back(dummy_id); - } -#ifndef EXEC_ENV_OLS - dummy_map_stream.close(); -#endif - diskann::cout << "Loaded dummy map" << std::endl; - } + return std::unique_ptr(buffer); } + // Load functions for SEARCH END + template class InMemFilterStore; + template class InMemFilterStore; + template class InMemFilterStore; + -} -*/ \ No newline at end of file +} // namespace diskann diff --git a/src/pq_flash_index.cpp b/src/pq_flash_index.cpp index ea3ed1981..a220382ec 100644 --- a/src/pq_flash_index.cpp +++ b/src/pq_flash_index.cpp @@ -15,197 +15,191 @@ #include "linux_aligned_file_reader.h" #endif -#define READ_U64(stream, val) stream.read((char *)&val, sizeof(uint64_t)) -#define READ_U32(stream, val) stream.read((char *)&val, sizeof(uint32_t)) -#define READ_UNSIGNED(stream, val) stream.read((char *)&val, sizeof(unsigned)) +#define READ_U64(stream, val) stream.read((char *) &val, sizeof(uint64_t)) +#define READ_U32(stream, val) stream.read((char *) &val, sizeof(uint32_t)) +#define READ_UNSIGNED(stream, val) stream.read((char *) &val, sizeof(unsigned)) // sector # beyond the end of graph where data for id is present for reordering -#define VECTOR_SECTOR_NO(id) (((uint64_t)(id)) / _nvecs_per_sector + _reorder_data_start_sector) +#define VECTOR_SECTOR_NO(id) \ + (((uint64_t) (id)) / _nvecs_per_sector + _reorder_data_start_sector) // sector # beyond the end of graph where data for id is present for reordering -#define VECTOR_SECTOR_OFFSET(id) ((((uint64_t)(id)) % _nvecs_per_sector) * _data_dim * sizeof(float)) +#define VECTOR_SECTOR_OFFSET(id) \ + ((((uint64_t) (id)) % _nvecs_per_sector) * _data_dim * sizeof(float)) -namespace diskann -{ +namespace diskann { -template -PQFlashIndex::PQFlashIndex(std::shared_ptr &fileReader, diskann::Metric m) - : reader(fileReader), metric(m), _thread_data(nullptr) -{ + template + PQFlashIndex::PQFlashIndex( + std::shared_ptr &fileReader, diskann::Metric m) + : reader(fileReader), metric(m), _thread_data(nullptr) { diskann::Metric 
metric_to_invoke = m; - if (m == diskann::Metric::COSINE || m == diskann::Metric::INNER_PRODUCT) - { - if (std::is_floating_point::value) - { - diskann::cout << "Since data is floating point, we assume that it has been appropriately pre-processed " - "(normalization for cosine, and convert-to-l2 by adding extra dimension for MIPS). So we " - "shall invoke an l2 distance function." - << std::endl; - metric_to_invoke = diskann::Metric::L2; - } - else - { - diskann::cerr << "WARNING: Cannot normalize integral data types." - << " This may result in erroneous results or poor recall." - << " Consider using L2 distance with integral data types." << std::endl; - } + if (m == diskann::Metric::COSINE || m == diskann::Metric::INNER_PRODUCT) { + if (std::is_floating_point::value) { + diskann::cout << "Since data is floating point, we assume that it has " + "been appropriately pre-processed " + "(normalization for cosine, and convert-to-l2 by " + "adding extra dimension for MIPS). So we " + "shall invoke an l2 distance function." + << std::endl; + metric_to_invoke = diskann::Metric::L2; + } else { + diskann::cerr << "WARNING: Cannot normalize integral data types." + << " This may result in erroneous results or poor recall." + << " Consider using L2 distance with integral data types." + << std::endl; + } } this->_dist_cmp.reset(diskann::get_distance_function(metric_to_invoke)); - this->_dist_cmp_float.reset(diskann::get_distance_function(metric_to_invoke)); -} + this->_dist_cmp_float.reset( + diskann::get_distance_function(metric_to_invoke)); + } -template PQFlashIndex::~PQFlashIndex() -{ + template + PQFlashIndex::~PQFlashIndex() { #ifndef EXEC_ENV_OLS - if (data != nullptr) - { - delete[] data; + if (data != nullptr) { + delete[] data; } #endif if (_centroid_data != nullptr) - aligned_free(_centroid_data); + aligned_free(_centroid_data); // delete backing bufs for nhood and coord cache - if (_nhood_cache_buf != nullptr) - { - delete[] _nhood_cache_buf; - diskann::aligned_free(_coord_cache_buf); + if (_nhood_cache_buf != nullptr) { + delete[] _nhood_cache_buf; + diskann::aligned_free(_coord_cache_buf); } - if (_load_flag) - { - diskann::cout << "Clearing scratch" << std::endl; - ScratchStoreManager> manager(this->_thread_data); - manager.destroy(); - this->reader->deregister_all_threads(); - reader->close(); - } - if (_pts_to_label_offsets != nullptr) - { - delete[] _pts_to_label_offsets; + if (_medoids != nullptr) { + delete[] _medoids; + _medoids = nullptr; } - if (_pts_to_label_counts != nullptr) - { - delete[] _pts_to_label_counts; - } - if (_pts_to_labels != nullptr) - { - delete[] _pts_to_labels; - } - if (_medoids != nullptr) - { - delete[] _medoids; + + if (_load_flag) { + diskann::cout << "Clearing scratch" << std::endl; + ScratchStoreManager> manager(this->_thread_data); + manager.destroy(); + this->reader->deregister_all_threads(); + reader->close(); } -} - -template inline uint64_t PQFlashIndex::get_node_sector(uint64_t node_id) -{ - return 1 + (_nnodes_per_sector > 0 ? node_id / _nnodes_per_sector - : node_id * DIV_ROUND_UP(_max_node_len, defaults::SECTOR_LEN)); -} - -template -inline char *PQFlashIndex::offset_to_node(char *sector_buf, uint64_t node_id) -{ - return sector_buf + (_nnodes_per_sector == 0 ? 
0 : (node_id % _nnodes_per_sector) * _max_node_len); -} - -template inline uint32_t *PQFlashIndex::offset_to_node_nhood(char *node_buf) -{ - return (unsigned *)(node_buf + _disk_bytes_per_point); -} - -template inline T *PQFlashIndex::offset_to_node_coords(char *node_buf) -{ - return (T *)(node_buf); -} - -template -void PQFlashIndex::setup_thread_data(uint64_t nthreads, uint64_t visited_reserve) -{ - diskann::cout << "Setting up thread-specific contexts for nthreads: " << nthreads << std::endl; + } + + template + inline uint64_t PQFlashIndex::get_node_sector(uint64_t node_id) { + return 1 + + (_nnodes_per_sector > 0 + ? node_id / _nnodes_per_sector + : node_id * DIV_ROUND_UP(_max_node_len, defaults::SECTOR_LEN)); + } + + template + inline char *PQFlashIndex::offset_to_node(char * sector_buf, + uint64_t node_id) { + return sector_buf + (_nnodes_per_sector == 0 + ? 0 + : (node_id % _nnodes_per_sector) * _max_node_len); + } + + template + inline uint32_t *PQFlashIndex::offset_to_node_nhood( + char *node_buf) { + return (unsigned *) (node_buf + _disk_bytes_per_point); + } + + template + inline T *PQFlashIndex::offset_to_node_coords(char *node_buf) { + return (T *) (node_buf); + } + + template + void PQFlashIndex::setup_thread_data(uint64_t nthreads, + uint64_t visited_reserve) { + diskann::cout << "Setting up thread-specific contexts for nthreads: " + << nthreads << std::endl; // omp parallel for to generate unique thread IDs -#pragma omp parallel for num_threads((int)nthreads) - for (int64_t thread = 0; thread < (int64_t)nthreads; thread++) - { +#pragma omp parallel for num_threads((int) nthreads) + for (int64_t thread = 0; thread < (int64_t) nthreads; thread++) { #pragma omp critical - { - SSDThreadData *data = new SSDThreadData(this->_aligned_dim, visited_reserve); - this->reader->register_thread(); - data->ctx = this->reader->get_ctx(); - this->_thread_data.push(data); - } + { + SSDThreadData *data = + new SSDThreadData(this->_aligned_dim, visited_reserve); + this->reader->register_thread(); + data->ctx = this->reader->get_ctx(); + this->_thread_data.push(data); + } } _load_flag = true; -} + } -template -std::vector PQFlashIndex::read_nodes(const std::vector &node_ids, - std::vector &coord_buffers, - std::vector> &nbr_buffers) -{ + template + std::vector PQFlashIndex::read_nodes( + const std::vector &node_ids, std::vector &coord_buffers, + std::vector> &nbr_buffers) { std::vector read_reqs; - std::vector retval(node_ids.size(), true); + std::vector retval(node_ids.size(), true); char *buf = nullptr; - auto num_sectors = _nnodes_per_sector > 0 ? 1 : DIV_ROUND_UP(_max_node_len, defaults::SECTOR_LEN); - alloc_aligned((void **)&buf, node_ids.size() * num_sectors * defaults::SECTOR_LEN, defaults::SECTOR_LEN); + auto num_sectors = _nnodes_per_sector > 0 + ? 
1 + : DIV_ROUND_UP(_max_node_len, defaults::SECTOR_LEN); + alloc_aligned((void **) &buf, + node_ids.size() * num_sectors * defaults::SECTOR_LEN, + defaults::SECTOR_LEN); // create read requests - for (size_t i = 0; i < node_ids.size(); ++i) - { - auto node_id = node_ids[i]; - - AlignedRead read; - read.len = num_sectors * defaults::SECTOR_LEN; - read.buf = buf + i * num_sectors * defaults::SECTOR_LEN; - read.offset = get_node_sector(node_id) * defaults::SECTOR_LEN; - read_reqs.push_back(read); + for (size_t i = 0; i < node_ids.size(); ++i) { + auto node_id = node_ids[i]; + + AlignedRead read; + read.len = num_sectors * defaults::SECTOR_LEN; + read.buf = buf + i * num_sectors * defaults::SECTOR_LEN; + read.offset = get_node_sector(node_id) * defaults::SECTOR_LEN; + read_reqs.push_back(read); } // borrow thread data and issue reads ScratchStoreManager> manager(this->_thread_data); - auto this_thread_data = manager.scratch_space(); + auto this_thread_data = manager.scratch_space(); IOContext &ctx = this_thread_data->ctx; reader->read(read_reqs, ctx); // copy reads into buffers - for (uint32_t i = 0; i < read_reqs.size(); i++) - { -#if defined(_WINDOWS) && defined(USE_BING_INFRA) // this block is to handle failed reads in - // production settings - if ((*ctx.m_pRequestsStatus)[i] != IOContext::READ_SUCCESS) - { - retval[i] = false; - continue; - } + for (uint32_t i = 0; i < read_reqs.size(); i++) { +#if defined(_WINDOWS) && \ + defined(USE_BING_INFRA) // this block is to handle failed reads in + // production settings + if ((*ctx.m_pRequestsStatus)[i] != IOContext::READ_SUCCESS) { + retval[i] = false; + continue; + } #endif - char *node_buf = offset_to_node((char *)read_reqs[i].buf, node_ids[i]); + char *node_buf = offset_to_node((char *) read_reqs[i].buf, node_ids[i]); - if (coord_buffers[i] != nullptr) - { - T *node_coords = offset_to_node_coords(node_buf); - memcpy(coord_buffers[i], node_coords, _disk_bytes_per_point); - } + if (coord_buffers[i] != nullptr) { + T *node_coords = offset_to_node_coords(node_buf); + memcpy(coord_buffers[i], node_coords, _disk_bytes_per_point); + } - if (nbr_buffers[i].second != nullptr) - { - uint32_t *node_nhood = offset_to_node_nhood(node_buf); - auto num_nbrs = *node_nhood; - nbr_buffers[i].first = num_nbrs; - memcpy(nbr_buffers[i].second, node_nhood + 1, num_nbrs * sizeof(uint32_t)); - } + if (nbr_buffers[i].second != nullptr) { + uint32_t *node_nhood = offset_to_node_nhood(node_buf); + auto num_nbrs = *node_nhood; + nbr_buffers[i].first = num_nbrs; + memcpy(nbr_buffers[i].second, node_nhood + 1, + num_nbrs * sizeof(uint32_t)); + } } aligned_free(buf); return retval; -} + } -template void PQFlashIndex::load_cache_list(std::vector &node_list) -{ + template + void PQFlashIndex::load_cache_list( + std::vector &node_list) { diskann::cout << "Loading the cache list into memory.." 
<< std::flush; size_t num_cached_nodes = node_list.size(); @@ -215,462 +209,377 @@ template void PQFlashIndex::load_cache_ // Allocate space for coordinate cache size_t coord_cache_buf_len = num_cached_nodes * _aligned_dim; - diskann::alloc_aligned((void **)&_coord_cache_buf, coord_cache_buf_len * sizeof(T), 8 * sizeof(T)); + diskann::alloc_aligned((void **) &_coord_cache_buf, + coord_cache_buf_len * sizeof(T), 8 * sizeof(T)); memset(_coord_cache_buf, 0, coord_cache_buf_len * sizeof(T)); size_t BLOCK_SIZE = 8; size_t num_blocks = DIV_ROUND_UP(num_cached_nodes, BLOCK_SIZE); - for (size_t block = 0; block < num_blocks; block++) - { - size_t start_idx = block * BLOCK_SIZE; - size_t end_idx = (std::min)(num_cached_nodes, (block + 1) * BLOCK_SIZE); - - // Copy offset into buffers to read into - std::vector nodes_to_read; - std::vector coord_buffers; - std::vector> nbr_buffers; - for (size_t node_idx = start_idx; node_idx < end_idx; node_idx++) - { - nodes_to_read.push_back(node_list[node_idx]); - coord_buffers.push_back(_coord_cache_buf + node_idx * _aligned_dim); - nbr_buffers.emplace_back(0, _nhood_cache_buf + node_idx * (_max_degree + 1)); - } - - // issue the reads - auto read_status = read_nodes(nodes_to_read, coord_buffers, nbr_buffers); - - // check for success and insert into the cache. - for (size_t i = 0; i < read_status.size(); i++) - { - if (read_status[i] == true) - { - _coord_cache.insert(std::make_pair(nodes_to_read[i], coord_buffers[i])); - _nhood_cache.insert(std::make_pair(nodes_to_read[i], nbr_buffers[i])); - } + for (size_t block = 0; block < num_blocks; block++) { + size_t start_idx = block * BLOCK_SIZE; + size_t end_idx = (std::min) (num_cached_nodes, (block + 1) * BLOCK_SIZE); + + // Copy offset into buffers to read into + std::vector nodes_to_read; + std::vector coord_buffers; + std::vector> nbr_buffers; + for (size_t node_idx = start_idx; node_idx < end_idx; node_idx++) { + nodes_to_read.push_back(node_list[node_idx]); + coord_buffers.push_back(_coord_cache_buf + node_idx * _aligned_dim); + nbr_buffers.emplace_back( + 0, _nhood_cache_buf + node_idx * (_max_degree + 1)); + } + + // issue the reads + auto read_status = read_nodes(nodes_to_read, coord_buffers, nbr_buffers); + + // check for success and insert into the cache. + for (size_t i = 0; i < read_status.size(); i++) { + if (read_status[i] == true) { + _coord_cache.insert( + std::make_pair(nodes_to_read[i], coord_buffers[i])); + _nhood_cache.insert(std::make_pair(nodes_to_read[i], nbr_buffers[i])); } + } } diskann::cout << "..done." 
<< std::endl; -} + } #ifdef EXEC_ENV_OLS -template -void PQFlashIndex::generate_cache_list_from_sample_queries(MemoryMappedFiles &files, std::string sample_bin, - uint64_t l_search, uint64_t beamwidth, - uint64_t num_nodes_to_cache, uint32_t nthreads, - std::vector &node_list) -{ + template + void PQFlashIndex::generate_cache_list_from_sample_queries( + MemoryMappedFiles &files, std::string sample_bin, uint64_t l_search, + uint64_t beamwidth, uint64_t num_nodes_to_cache, uint32_t nthreads, + std::vector &node_list) { #else -template -void PQFlashIndex::generate_cache_list_from_sample_queries(std::string sample_bin, uint64_t l_search, - uint64_t beamwidth, uint64_t num_nodes_to_cache, - uint32_t nthreads, - std::vector &node_list) -{ + template + void PQFlashIndex::generate_cache_list_from_sample_queries( + std::string sample_bin, uint64_t l_search, uint64_t beamwidth, + uint64_t num_nodes_to_cache, uint32_t nthreads, + std::vector &node_list) { #endif - if (num_nodes_to_cache >= this->_num_points) - { - // for small num_points and big num_nodes_to_cache, use below way to get the node_list quickly - node_list.resize(this->_num_points); - for (uint32_t i = 0; i < this->_num_points; ++i) - { - node_list[i] = i; - } - return; + if (num_nodes_to_cache >= this->_num_points) { + // for small num_points and big num_nodes_to_cache, use below way to get + // the node_list quickly + node_list.resize(this->_num_points); + for (uint32_t i = 0; i < this->_num_points; ++i) { + node_list[i] = i; + } + return; } this->_count_visited_nodes = true; this->_node_visit_counter.clear(); this->_node_visit_counter.resize(this->_num_points); - for (uint32_t i = 0; i < _node_visit_counter.size(); i++) - { - this->_node_visit_counter[i].first = i; - this->_node_visit_counter[i].second = 0; + for (uint32_t i = 0; i < _node_visit_counter.size(); i++) { + this->_node_visit_counter[i].first = i; + this->_node_visit_counter[i].second = 0; } uint64_t sample_num, sample_dim, sample_aligned_dim; - T *samples; + T * samples; #ifdef EXEC_ENV_OLS - if (files.fileExists(sample_bin)) - { - diskann::load_aligned_bin(files, sample_bin, samples, sample_num, sample_dim, sample_aligned_dim); + if (files.fileExists(sample_bin)) { + diskann::load_aligned_bin(files, sample_bin, samples, sample_num, + sample_dim, sample_aligned_dim); } #else - if (file_exists(sample_bin)) - { - diskann::load_aligned_bin(sample_bin, samples, sample_num, sample_dim, sample_aligned_dim); + if (file_exists(sample_bin)) { + diskann::load_aligned_bin(sample_bin, samples, sample_num, sample_dim, + sample_aligned_dim); } #endif - else - { - diskann::cerr << "Sample bin file not found. Not generating cache." << std::endl; - return; + else { + diskann::cerr << "Sample bin file not found. Not generating cache." 
+ << std::endl; + return; } std::vector tmp_result_ids_64(sample_num, 0); - std::vector tmp_result_dists(sample_num, 0); + std::vector tmp_result_dists(sample_num, 0); - bool filtered_search = false; + bool filtered_search = false; std::vector random_query_filters(sample_num); - if (_filter_to_medoid_ids.size() != 0) - { - filtered_search = true; - generate_random_labels(random_query_filters, (uint32_t)sample_num, nthreads); + if (this->_filter_index) { + filtered_search = true; + _filter_store->generate_random_labels(random_query_filters, + (uint32_t) sample_num, nthreads); } #pragma omp parallel for schedule(dynamic, 1) num_threads(nthreads) - for (int64_t i = 0; i < (int64_t)sample_num; i++) - { - auto &label_for_search = random_query_filters[i]; - // run a search on the sample query with a random label (sampled from base label distribution), and it will - // concurrently update the node_visit_counter to track most visited nodes. The last false is to not use the - // "use_reorder_data" option which enables a final reranking if the disk index itself contains only PQ data. - cached_beam_search(samples + (i * sample_aligned_dim), 1, l_search, tmp_result_ids_64.data() + i, - tmp_result_dists.data() + i, beamwidth, filtered_search, label_for_search, false); + for (int64_t i = 0; i < (int64_t) sample_num; i++) { + auto &label_for_search = random_query_filters[i]; + // run a search on the sample query with a random label (sampled from base + // label distribution), and it will concurrently update the + // node_visit_counter to track most visited nodes. The last false is to + // not use the "use_reorder_data" option which enables a final reranking + // if the disk index itself contains only PQ data. + cached_beam_search(samples + (i * sample_aligned_dim), 1, l_search, + tmp_result_ids_64.data() + i, + tmp_result_dists.data() + i, beamwidth, + filtered_search, label_for_search, false); } std::sort(this->_node_visit_counter.begin(), _node_visit_counter.end(), - [](std::pair &left, std::pair &right) { - return left.second > right.second; + [](std::pair &left, + std::pair &right) { + return left.second > right.second; }); node_list.clear(); node_list.shrink_to_fit(); - num_nodes_to_cache = std::min(num_nodes_to_cache, this->_node_visit_counter.size()); + num_nodes_to_cache = + std::min(num_nodes_to_cache, this->_node_visit_counter.size()); node_list.reserve(num_nodes_to_cache); - for (uint64_t i = 0; i < num_nodes_to_cache; i++) - { - node_list.push_back(this->_node_visit_counter[i].first); + for (uint64_t i = 0; i < num_nodes_to_cache; i++) { + node_list.push_back(this->_node_visit_counter[i].first); } this->_count_visited_nodes = false; diskann::aligned_free(samples); -} + } -template -void PQFlashIndex::cache_bfs_levels(uint64_t num_nodes_to_cache, std::vector &node_list, - const bool shuffle) -{ + template + void PQFlashIndex::cache_bfs_levels( + uint64_t num_nodes_to_cache, std::vector &node_list, + const bool shuffle) { std::random_device rng; - std::mt19937 urng(rng()); + std::mt19937 urng(rng()); tsl::robin_set node_set; // Do not cache more than 10% of the nodes in the index - uint64_t tenp_nodes = (uint64_t)(std::round(this->_num_points * 0.1)); - if (num_nodes_to_cache > tenp_nodes) - { - diskann::cout << "Reducing nodes to cache from: " << num_nodes_to_cache << " to: " << tenp_nodes - << "(10 percent of total nodes:" << this->_num_points << ")" << std::endl; - num_nodes_to_cache = tenp_nodes == 0 ? 
1 : tenp_nodes; + uint64_t tenp_nodes = (uint64_t) (std::round(this->_num_points * 0.1)); + if (num_nodes_to_cache > tenp_nodes) { + diskann::cout << "Reducing nodes to cache from: " << num_nodes_to_cache + << " to: " << tenp_nodes + << "(10 percent of total nodes:" << this->_num_points << ")" + << std::endl; + num_nodes_to_cache = tenp_nodes == 0 ? 1 : tenp_nodes; } diskann::cout << "Caching " << num_nodes_to_cache << "..." << std::endl; // borrow thread data ScratchStoreManager> manager(this->_thread_data); - auto this_thread_data = manager.scratch_space(); + auto this_thread_data = manager.scratch_space(); IOContext &ctx = this_thread_data->ctx; std::unique_ptr> cur_level, prev_level; cur_level = std::make_unique>(); prev_level = std::make_unique>(); - for (uint64_t miter = 0; miter < _num_medoids && cur_level->size() < num_nodes_to_cache; miter++) - { - cur_level->insert(_medoids[miter]); + for (uint64_t miter = 0; + miter < _num_medoids && cur_level->size() < num_nodes_to_cache; + miter++) { + cur_level->insert(_medoids[miter]); } auto filter_to_medoid_ids = _filter_store->get_label_to_medoids(); - if ((filter_to_medoid_ids.size() > 0) && (cur_level->size() < num_nodes_to_cache)) - { - for (auto &x : filter_to_medoid_ids) - { - for (auto &y : x.second) - { - cur_level->insert(y); - if (cur_level->size() == num_nodes_to_cache) - break; - } - if (cur_level->size() == num_nodes_to_cache) - break; + if ((filter_to_medoid_ids.size() > 0) && + (cur_level->size() < num_nodes_to_cache)) { + for (auto &x : filter_to_medoid_ids) { + for (auto &y : x.second) { + cur_level->insert(y); + if (cur_level->size() == num_nodes_to_cache) + break; } + if (cur_level->size() == num_nodes_to_cache) + break; + } } uint64_t lvl = 1; uint64_t prev_node_set_size = 0; - while ((node_set.size() + cur_level->size() < num_nodes_to_cache) && cur_level->size() != 0) - { - // swap prev_level and cur_level - std::swap(prev_level, cur_level); - // clear cur_level - cur_level->clear(); - - std::vector nodes_to_expand; - - for (const uint32_t &id : *prev_level) - { - if (node_set.find(id) != node_set.end()) - { - continue; - } - node_set.insert(id); - nodes_to_expand.push_back(id); + while ((node_set.size() + cur_level->size() < num_nodes_to_cache) && + cur_level->size() != 0) { + // swap prev_level and cur_level + std::swap(prev_level, cur_level); + // clear cur_level + cur_level->clear(); + + std::vector nodes_to_expand; + + for (const uint32_t &id : *prev_level) { + if (node_set.find(id) != node_set.end()) { + continue; } + node_set.insert(id); + nodes_to_expand.push_back(id); + } + + if (shuffle) + std::shuffle(nodes_to_expand.begin(), nodes_to_expand.end(), urng); + else + std::sort(nodes_to_expand.begin(), nodes_to_expand.end()); + + diskann::cout << "Level: " << lvl << std::flush; + bool finish_flag = false; + + uint64_t BLOCK_SIZE = 1024; + uint64_t nblocks = DIV_ROUND_UP(nodes_to_expand.size(), BLOCK_SIZE); + for (size_t block = 0; block < nblocks && !finish_flag; block++) { + diskann::cout << "." 
<< std::flush; + size_t start = block * BLOCK_SIZE; + size_t end = + (std::min) ((block + 1) * BLOCK_SIZE, nodes_to_expand.size()); - if (shuffle) - std::shuffle(nodes_to_expand.begin(), nodes_to_expand.end(), urng); - else - std::sort(nodes_to_expand.begin(), nodes_to_expand.end()); - - diskann::cout << "Level: " << lvl << std::flush; - bool finish_flag = false; - - uint64_t BLOCK_SIZE = 1024; - uint64_t nblocks = DIV_ROUND_UP(nodes_to_expand.size(), BLOCK_SIZE); - for (size_t block = 0; block < nblocks && !finish_flag; block++) - { - diskann::cout << "." << std::flush; - size_t start = block * BLOCK_SIZE; - size_t end = (std::min)((block + 1) * BLOCK_SIZE, nodes_to_expand.size()); - - std::vector nodes_to_read; - std::vector coord_buffers(end - start, nullptr); - std::vector> nbr_buffers; - - for (size_t cur_pt = start; cur_pt < end; cur_pt++) - { - nodes_to_read.push_back(nodes_to_expand[cur_pt]); - nbr_buffers.emplace_back(0, new uint32_t[_max_degree + 1]); - } + std::vector nodes_to_read; + std::vector coord_buffers(end - start, nullptr); + std::vector> nbr_buffers; + + for (size_t cur_pt = start; cur_pt < end; cur_pt++) { + nodes_to_read.push_back(nodes_to_expand[cur_pt]); + nbr_buffers.emplace_back(0, new uint32_t[_max_degree + 1]); + } - // issue read requests - auto read_status = read_nodes(nodes_to_read, coord_buffers, nbr_buffers); - - // process each nhood buf - for (uint32_t i = 0; i < read_status.size(); i++) - { - if (read_status[i] == false) - { - continue; - } - else - { - uint32_t nnbrs = nbr_buffers[i].first; - uint32_t *nbrs = nbr_buffers[i].second; - - // explore next level - for (uint32_t j = 0; j < nnbrs && !finish_flag; j++) - { - if (node_set.find(nbrs[j]) == node_set.end()) - { - cur_level->insert(nbrs[j]); - } - if (cur_level->size() + node_set.size() >= num_nodes_to_cache) - { - finish_flag = true; - } - } - } - delete[] nbr_buffers[i].second; + // issue read requests + auto read_status = + read_nodes(nodes_to_read, coord_buffers, nbr_buffers); + + // process each nhood buf + for (uint32_t i = 0; i < read_status.size(); i++) { + if (read_status[i] == false) { + continue; + } else { + uint32_t nnbrs = nbr_buffers[i].first; + uint32_t *nbrs = nbr_buffers[i].second; + + // explore next level + for (uint32_t j = 0; j < nnbrs && !finish_flag; j++) { + if (node_set.find(nbrs[j]) == node_set.end()) { + cur_level->insert(nbrs[j]); + } + if (cur_level->size() + node_set.size() >= num_nodes_to_cache) { + finish_flag = true; + } } + } + delete[] nbr_buffers[i].second; } + } - diskann::cout << ". #nodes: " << node_set.size() - prev_node_set_size - << ", #nodes thus far: " << node_set.size() << std::endl; - prev_node_set_size = node_set.size(); - lvl++; + diskann::cout << ". #nodes: " << node_set.size() - prev_node_set_size + << ", #nodes thus far: " << node_set.size() << std::endl; + prev_node_set_size = node_set.size(); + lvl++; } - assert(node_set.size() + cur_level->size() == num_nodes_to_cache || cur_level->size() == 0); + assert(node_set.size() + cur_level->size() == num_nodes_to_cache || + cur_level->size() == 0); node_list.clear(); node_list.reserve(node_set.size() + cur_level->size()); for (auto node : node_set) - node_list.push_back(node); + node_list.push_back(node); for (auto node : *cur_level) - node_list.push_back(node); + node_list.push_back(node); diskann::cout << "Level: " << lvl << std::flush; - diskann::cout << ". #nodes: " << node_list.size() - prev_node_set_size << ", #nodes thus far: " << node_list.size() - << std::endl; + diskann::cout << ". 
#nodes: " << node_list.size() - prev_node_set_size + << ", #nodes thus far: " << node_list.size() << std::endl; diskann::cout << "done" << std::endl; -} + } -template void PQFlashIndex::use_medoids_data_as_centroids() -{ + template + void PQFlashIndex::use_medoids_data_as_centroids() { if (_centroid_data != nullptr) - aligned_free(_centroid_data); - alloc_aligned(((void **)&_centroid_data), _num_medoids * _aligned_dim * sizeof(float), 32); + aligned_free(_centroid_data); + alloc_aligned(((void **) &_centroid_data), + _num_medoids * _aligned_dim * sizeof(float), 32); std::memset(_centroid_data, 0, _num_medoids * _aligned_dim * sizeof(float)); - diskann::cout << "Loading centroid data from medoids vector data of " << _num_medoids << " medoid(s)" << std::endl; + diskann::cout << "Loading centroid data from medoids vector data of " + << _num_medoids << " medoid(s)" << std::endl; - std::vector nodes_to_read; - std::vector medoid_bufs; + std::vector nodes_to_read; + std::vector medoid_bufs; std::vector> nbr_bufs; - for (uint64_t cur_m = 0; cur_m < _num_medoids; cur_m++) - { - nodes_to_read.push_back(_medoids[cur_m]); - medoid_bufs.push_back(new T[_data_dim]); - nbr_bufs.emplace_back(0, nullptr); + for (uint64_t cur_m = 0; cur_m < _num_medoids; cur_m++) { + nodes_to_read.push_back(_medoids[cur_m]); + medoid_bufs.push_back(new T[_data_dim]); + nbr_bufs.emplace_back(0, nullptr); } auto read_status = read_nodes(nodes_to_read, medoid_bufs, nbr_bufs); - for (uint64_t cur_m = 0; cur_m < _num_medoids; cur_m++) - { - if (read_status[cur_m] == true) - { - if (!_use_disk_index_pq) - { - for (uint32_t i = 0; i < _data_dim; i++) - _centroid_data[cur_m * _aligned_dim + i] = medoid_bufs[cur_m][i]; - } - else - { - _disk_pq_table.inflate_vector((uint8_t *)medoid_bufs[cur_m], (_centroid_data + cur_m * _aligned_dim)); - } - } - else - { - throw ANNException("Unable to read a medoid", -1, __FUNCSIG__, __FILE__, __LINE__); - } - delete[] medoid_bufs[cur_m]; - } -} - -template -void PQFlashIndex::generate_random_labels(std::vector &labels, const uint32_t num_labels, - const uint32_t nthreads) -{ - std::random_device rd; - labels.clear(); - labels.resize(num_labels); - - uint64_t num_total_labels = _pts_to_label_offsets[_num_points - 1] + _pts_to_label_counts[_num_points - 1]; - std::mt19937 gen(rd()); - if (num_total_labels == 0) - { - std::stringstream stream; - stream << "No labels found in data. 
Not sampling random labels "; - diskann::cerr << stream.str() << std::endl; - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); - } - std::uniform_int_distribution dis(0, num_total_labels - 1); - -#pragma omp parallel for schedule(dynamic, 1) num_threads(nthreads) - for (int64_t i = 0; i < num_labels; i++) - { - uint64_t rnd_loc = dis(gen); - labels[i] = (LabelT)_pts_to_labels[rnd_loc]; - } -} - - - -template -void PQFlashIndex::reset_stream_for_reading(std::basic_istream &infile) -{ - infile.clear(); - infile.seekg(0); -} - -template -void PQFlashIndex::get_label_file_metadata(const std::string &fileContent, uint32_t &num_pts, - uint32_t &num_total_labels) -{ - num_pts = 0; - num_total_labels = 0; - - size_t file_size = fileContent.length(); - - std::string label_str; - size_t cur_pos = 0; - size_t next_pos = 0; - while (cur_pos < file_size && cur_pos != std::string::npos) - { - next_pos = fileContent.find('\n', cur_pos); - if (next_pos == std::string::npos) - { - break; + for (uint64_t cur_m = 0; cur_m < _num_medoids; cur_m++) { + if (read_status[cur_m] == true) { + if (!_use_disk_index_pq) { + for (uint32_t i = 0; i < _data_dim; i++) + _centroid_data[cur_m * _aligned_dim + i] = medoid_bufs[cur_m][i]; + } else { + _disk_pq_table.inflate_vector( + (uint8_t *) medoid_bufs[cur_m], + (_centroid_data + cur_m * _aligned_dim)); } - - size_t lbl_pos = cur_pos; - size_t next_lbl_pos = 0; - while (lbl_pos < next_pos && lbl_pos != std::string::npos) - { - next_lbl_pos = fileContent.find(',', lbl_pos); - if (next_lbl_pos == std::string::npos) // the last label - { - next_lbl_pos = next_pos; - } - - num_total_labels++; - - lbl_pos = next_lbl_pos + 1; - } - - cur_pos = next_pos + 1; - - num_pts++; + } else { + throw ANNException("Unable to read a medoid", -1, __FUNCSIG__, __FILE__, + __LINE__); + } + delete[] medoid_bufs[cur_m]; } - - diskann::cout << "Labels file metadata: num_points: " << num_pts << ", #total_labels: " << num_total_labels - << std::endl; -} - + } #ifdef EXEC_ENV_OLS -template -int PQFlashIndex::load(MemoryMappedFiles &files, uint32_t num_threads, const char *index_prefix) -{ + template + int PQFlashIndex::load(MemoryMappedFiles &files, + uint32_t num_threads, + const char * index_prefix) { #else -template int PQFlashIndex::load(uint32_t num_threads, const char *index_prefix) -{ + template + int PQFlashIndex::load(uint32_t num_threads, + const char *index_prefix) { #endif std::string pq_table_bin = std::string(index_prefix) + "_pq_pivots.bin"; - std::string pq_compressed_vectors = std::string(index_prefix) + "_pq_compressed.bin"; + std::string pq_compressed_vectors = + std::string(index_prefix) + "_pq_compressed.bin"; std::string _disk_index_file = std::string(index_prefix) + "_disk.index"; #ifdef EXEC_ENV_OLS - return load_from_separate_paths(files, num_threads, _disk_index_file.c_str(), pq_table_bin.c_str(), - pq_compressed_vectors.c_str()); + return load_from_separate_paths( + files, num_threads, _disk_index_file.c_str(), pq_table_bin.c_str(), + pq_compressed_vectors.c_str()); #else - return load_from_separate_paths(num_threads, _disk_index_file.c_str(), pq_table_bin.c_str(), + return load_from_separate_paths(num_threads, _disk_index_file.c_str(), + pq_table_bin.c_str(), pq_compressed_vectors.c_str()); #endif -} + } #ifdef EXEC_ENV_OLS -template -int PQFlashIndex::load_from_separate_paths(diskann::MemoryMappedFiles &files, uint32_t num_threads, - const char *index_filepath, const char *pivots_filepath, - const char *compressed_filepath) -{ + template 
+ int PQFlashIndex::load_from_separate_paths( + diskann::MemoryMappedFiles &files, uint32_t num_threads, + const char *index_filepath, const char *pivots_filepath, + const char *compressed_filepath) { #else -template -int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, const char *index_filepath, - const char *pivots_filepath, const char *compressed_filepath) -{ + template + int PQFlashIndex::load_from_separate_paths( + uint32_t num_threads, const char *index_filepath, + const char *pivots_filepath, const char *compressed_filepath) { #endif std::string pq_table_bin = pivots_filepath; std::string pq_compressed_vectors = compressed_filepath; std::string _disk_index_file = index_filepath; std::string medoids_file = std::string(_disk_index_file) + "_medoids.bin"; - std::string centroids_file = std::string(_disk_index_file) + "_centroids.bin"; + std::string centroids_file = + std::string(_disk_index_file) + "_centroids.bin"; size_t pq_file_dim, pq_file_num_centroids; #ifdef EXEC_ENV_OLS - get_bin_metadata(files, pq_table_bin, pq_file_num_centroids, pq_file_dim, METADATA_SIZE); + get_bin_metadata(files, pq_table_bin, pq_file_num_centroids, pq_file_dim, + METADATA_SIZE); #else - get_bin_metadata(pq_table_bin, pq_file_num_centroids, pq_file_dim, METADATA_SIZE); + get_bin_metadata(pq_table_bin, pq_file_num_centroids, pq_file_dim, + METADATA_SIZE); #endif this->_disk_index_file = _disk_index_file; - if (pq_file_num_centroids != 256) - { - diskann::cout << "Error. Number of PQ centroids is not 256. Exiting." << std::endl; - return -1; + if (pq_file_num_centroids != 256) { + diskann::cout << "Error. Number of PQ centroids is not 256. Exiting." + << std::endl; + return -1; } this->_data_dim = pq_file_dim; @@ -681,20 +590,31 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons size_t npts_u64, nchunks_u64; #ifdef EXEC_ENV_OLS - diskann::load_bin(files, pq_compressed_vectors, this->data, npts_u64, nchunks_u64); + diskann::load_bin(files, pq_compressed_vectors, this->data, + npts_u64, nchunks_u64); #else - diskann::load_bin(pq_compressed_vectors, this->data, npts_u64, nchunks_u64); + diskann::load_bin(pq_compressed_vectors, this->data, npts_u64, + nchunks_u64); #endif this->_num_points = npts_u64; this->_n_chunks = nchunks_u64; _filter_store = std::make_unique>(); - if (_filter_store->load(_disk_index_file) == false) { - diskann::cout << "Index does not have filter support." << std::endl; - } else { - diskann::cout << "Index has filter support. " << std::endl; + try { + _filter_index = _filter_store->load(_disk_index_file); + if (_filter_index) { + diskann::cout << "Index has filter support. " << std::endl; + } else { + diskann::cout << "Index does not have filter support." << std::endl; + } + } catch (diskann::ANNException &ex) { + // If filter_store=>load() returns false, it means filters are not + // enabled. If it throws, it means there was an error in processing a + // filter index. + diskann::cerr << "Filter index load failed because: " << ex.what() + << std::endl; + return false; } - #ifdef EXEC_ENV_OLS _pq_table.load_pq_centroid_bin(files, pq_table_bin.c_str(), nchunks_u64); @@ -702,40 +622,43 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons _pq_table.load_pq_centroid_bin(pq_table_bin.c_str(), nchunks_u64); #endif - diskann::cout << "Loaded PQ centroids and in-memory compressed vectors. 
#points: " << _num_points - << " #dim: " << _data_dim << " #aligned_dim: " << _aligned_dim << " #chunks: " << _n_chunks - << std::endl; - - if (_n_chunks > MAX_PQ_CHUNKS) - { - std::stringstream stream; - stream << "Error loading index. Ensure that max PQ bytes for in-memory " - "PQ data does not exceed " - << MAX_PQ_CHUNKS << std::endl; - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + diskann::cout + << "Loaded PQ centroids and in-memory compressed vectors. #points: " + << _num_points << " #dim: " << _data_dim + << " #aligned_dim: " << _aligned_dim << " #chunks: " << _n_chunks + << std::endl; + + if (_n_chunks > MAX_PQ_CHUNKS) { + std::stringstream stream; + stream << "Error loading index. Ensure that max PQ bytes for in-memory " + "PQ data does not exceed " + << MAX_PQ_CHUNKS << std::endl; + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, + __LINE__); } std::string disk_pq_pivots_path = this->_disk_index_file + "_pq_pivots.bin"; #ifdef EXEC_ENV_OLS - if (files.fileExists(disk_pq_pivots_path)) - { - _use_disk_index_pq = true; - // giving 0 chunks to make the _pq_table infer from the - // chunk_offsets file the correct value - _disk_pq_table.load_pq_centroid_bin(files, disk_pq_pivots_path.c_str(), 0); + if (files.fileExists(disk_pq_pivots_path)) { + _use_disk_index_pq = true; + // giving 0 chunks to make the _pq_table infer from the + // chunk_offsets file the correct value + _disk_pq_table.load_pq_centroid_bin(files, disk_pq_pivots_path.c_str(), + 0); #else - if (file_exists(disk_pq_pivots_path)) - { - _use_disk_index_pq = true; - // giving 0 chunks to make the _pq_table infer from the - // chunk_offsets file the correct value - _disk_pq_table.load_pq_centroid_bin(disk_pq_pivots_path.c_str(), 0); + if (file_exists(disk_pq_pivots_path)) { + _use_disk_index_pq = true; + // giving 0 chunks to make the _pq_table infer from the + // chunk_offsets file the correct value + _disk_pq_table.load_pq_centroid_bin(disk_pq_pivots_path.c_str(), 0); #endif - _disk_pq_n_chunks = _disk_pq_table.get_num_chunks(); - _disk_bytes_per_point = - _disk_pq_n_chunks * sizeof(uint8_t); // revising disk_bytes_per_point since DISK PQ is used. - diskann::cout << "Disk index uses PQ data compressed down to " << _disk_pq_n_chunks << " bytes per point." - << std::endl; + _disk_pq_n_chunks = _disk_pq_table.get_num_chunks(); + _disk_bytes_per_point = + _disk_pq_n_chunks * + sizeof( + uint8_t); // revising disk_bytes_per_point since DISK PQ is used. + diskann::cout << "Disk index uses PQ data compressed down to " + << _disk_pq_n_chunks << " bytes per point." 
<< std::endl; } // read index metadata @@ -749,44 +672,44 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons this->setup_thread_data(num_threads); this->_max_nthreads = num_threads; - char *bytes = getHeaderBytes(); - ContentBuf buf(bytes, HEADER_SIZE); + char * bytes = getHeaderBytes(); + ContentBuf buf(bytes, HEADER_SIZE); std::basic_istream index_metadata(&buf); #else std::ifstream index_metadata(_disk_index_file, std::ios::binary); #endif - uint32_t nr, nc; // metadata itself is stored as bin format (nr is number of - // metadata, nc should be 1) + uint32_t nr, nc; // metadata itself is stored as bin format (nr is number + // of metadata, nc should be 1) READ_U32(index_metadata, nr); READ_U32(index_metadata, nc); uint64_t disk_nnodes; - uint64_t disk_ndims; // can be disk PQ dim if disk_PQ is set to true + uint64_t disk_ndims; // can be disk PQ dim if disk_PQ is set to true READ_U64(index_metadata, disk_nnodes); READ_U64(index_metadata, disk_ndims); - if (disk_nnodes != _num_points) - { - diskann::cout << "Mismatch in #points for compressed data file and disk " - "index file: " - << disk_nnodes << " vs " << _num_points << std::endl; - return -1; + if (disk_nnodes != _num_points) { + diskann::cout << "Mismatch in #points for compressed data file and disk " + "index file: " + << disk_nnodes << " vs " << _num_points << std::endl; + return -1; } size_t medoid_id_on_file; READ_U64(index_metadata, medoid_id_on_file); READ_U64(index_metadata, _max_node_len); READ_U64(index_metadata, _nnodes_per_sector); - _max_degree = ((_max_node_len - _disk_bytes_per_point) / sizeof(uint32_t)) - 1; - - if (_max_degree > defaults::MAX_GRAPH_DEGREE) - { - std::stringstream stream; - stream << "Error loading index. Ensure that max graph degree (R) does " - "not exceed " - << defaults::MAX_GRAPH_DEGREE << std::endl; - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + _max_degree = + ((_max_node_len - _disk_bytes_per_point) / sizeof(uint32_t)) - 1; + + if (_max_degree > defaults::MAX_GRAPH_DEGREE) { + std::stringstream stream; + stream << "Error loading index. Ensure that max graph degree (R) does " + "not exceed " + << defaults::MAX_GRAPH_DEGREE << std::endl; + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, + __LINE__); } // setting up concept of frozen points in disk index for streaming-DiskANN @@ -794,25 +717,24 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons uint64_t file_frozen_id; READ_U64(index_metadata, file_frozen_id); if (this->_num_frozen_points == 1) - this->_frozen_location = file_frozen_id; - if (this->_num_frozen_points == 1) - { - diskann::cout << " Detected frozen point in index at location " << this->_frozen_location - << ". Will not output it at search time." << std::endl; + this->_frozen_location = file_frozen_id; + if (this->_num_frozen_points == 1) { + diskann::cout << " Detected frozen point in index at location " + << this->_frozen_location + << ". Will not output it at search time." 
<< std::endl; } READ_U64(index_metadata, this->_reorder_data_exists); - if (this->_reorder_data_exists) - { - if (this->_use_disk_index_pq == false) - { - throw ANNException("Reordering is designed for used with disk PQ " - "compression option", - -1, __FUNCSIG__, __FILE__, __LINE__); - } - READ_U64(index_metadata, this->_reorder_data_start_sector); - READ_U64(index_metadata, this->_ndims_reorder_vecs); - READ_U64(index_metadata, this->_nvecs_per_sector); + if (this->_reorder_data_exists) { + if (this->_use_disk_index_pq == false) { + throw ANNException( + "Reordering is designed for used with disk PQ " + "compression option", + -1, __FUNCSIG__, __FILE__, __LINE__); + } + READ_U64(index_metadata, this->_reorder_data_start_sector); + READ_U64(index_metadata, this->_ndims_reorder_vecs); + READ_U64(index_metadata, this->_nvecs_per_sector); } diskann::cout << "Disk-Index File Meta-data: "; @@ -836,181 +758,175 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons #endif #ifdef EXEC_ENV_OLS - if (files.fileExists(medoids_file)) - { - size_t tmp_dim; - diskann::load_bin(files, norm_file, medoids_file, _medoids, _num_medoids, tmp_dim); + if (files.fileExists(medoids_file)) { + size_t tmp_dim; + diskann::load_bin(files, norm_file, medoids_file, _medoids, + _num_medoids, tmp_dim); #else - if (file_exists(medoids_file)) - { - size_t tmp_dim; - diskann::load_bin(medoids_file, _medoids, _num_medoids, tmp_dim); + if (file_exists(medoids_file)) { + size_t tmp_dim; + diskann::load_bin(medoids_file, _medoids, _num_medoids, + tmp_dim); #endif - if (tmp_dim != 1) - { - std::stringstream stream; - stream << "Error loading medoids file. Expected bin format of m times " - "1 vector of uint32_t." - << std::endl; - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); - } + if (tmp_dim != 1) { + std::stringstream stream; + stream << "Error loading medoids file. Expected bin format of m times " + "1 vector of uint32_t." + << std::endl; + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, + __LINE__); + } #ifdef EXEC_ENV_OLS - if (!files.fileExists(centroids_file)) - { + if (!files.fileExists(centroids_file)) { #else - if (!file_exists(centroids_file)) - { + if (!file_exists(centroids_file)) { #endif - diskann::cout << "Centroid data file not found. Using corresponding vectors " - "for the medoids " - << std::endl; - use_medoids_data_as_centroids(); - } - else - { - size_t num_centroids, aligned_tmp_dim; + diskann::cout + << "Centroid data file not found. Using corresponding vectors " + "for the medoids " + << std::endl; + use_medoids_data_as_centroids(); + } else { + size_t num_centroids, aligned_tmp_dim; #ifdef EXEC_ENV_OLS - diskann::load_aligned_bin(files, centroids_file, _centroid_data, num_centroids, tmp_dim, - aligned_tmp_dim); + diskann::load_aligned_bin(files, centroids_file, _centroid_data, + num_centroids, tmp_dim, + aligned_tmp_dim); #else - diskann::load_aligned_bin(centroids_file, _centroid_data, num_centroids, tmp_dim, aligned_tmp_dim); + diskann::load_aligned_bin(centroids_file, _centroid_data, + num_centroids, tmp_dim, + aligned_tmp_dim); #endif - if (aligned_tmp_dim != _aligned_dim || num_centroids != _num_medoids) - { - std::stringstream stream; - stream << "Error loading centroids data file. 
Expected bin format " - "of " - "m times data_dim vector of float, where m is number of " - "medoids " - "in medoids file."; - diskann::cerr << stream.str() << std::endl; - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); - } + if (aligned_tmp_dim != _aligned_dim || num_centroids != _num_medoids) { + std::stringstream stream; + stream << "Error loading centroids data file. Expected bin format " + "of " + "m times data_dim vector of float, where m is number of " + "medoids " + "in medoids file."; + diskann::cerr << stream.str() << std::endl; + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, + __LINE__); } - } - else - { - _num_medoids = 1; - _medoids = new uint32_t[1]; - _medoids[0] = (uint32_t)(medoid_id_on_file); - use_medoids_data_as_centroids(); + } + } else { + _num_medoids = 1; + _medoids = new uint32_t[1]; + _medoids[0] = (uint32_t) (medoid_id_on_file); + use_medoids_data_as_centroids(); } - std::string norm_file = std::string(_disk_index_file) + "_max_base_norm.bin"; + std::string norm_file = + std::string(_disk_index_file) + "_max_base_norm.bin"; #ifdef EXEC_ENV_OLS - if (files.fileExists(norm_file) && metric == diskann::Metric::INNER_PRODUCT) - { - uint64_t dumr, dumc; - float *norm_val; - diskann::load_bin(files, norm_val, dumr, dumc); + if (files.fileExists(norm_file) && + metric == diskann::Metric::INNER_PRODUCT) { + uint64_t dumr, dumc; + float * norm_val; + diskann::load_bin(files, norm_val, dumr, dumc); #else - if (file_exists(norm_file) && metric == diskann::Metric::INNER_PRODUCT) - { - uint64_t dumr, dumc; - float *norm_val; - diskann::load_bin(norm_file, norm_val, dumr, dumc); + if (file_exists(norm_file) && metric == diskann::Metric::INNER_PRODUCT) { + uint64_t dumr, dumc; + float * norm_val; + diskann::load_bin(norm_file, norm_val, dumr, dumc); #endif - this->_max_base_norm = norm_val[0]; - diskann::cout << "Setting re-scaling factor of base vectors to " << this->_max_base_norm << std::endl; - delete[] norm_val; + this->_max_base_norm = norm_val[0]; + diskann::cout << "Setting re-scaling factor of base vectors to " + << this->_max_base_norm << std::endl; + delete[] norm_val; } diskann::cout << "done.." << std::endl; return 0; -} + } #ifdef USE_BING_INFRA -bool getNextCompletedRequest(std::shared_ptr &reader, IOContext &ctx, size_t size, - int &completedIndex) -{ - if ((*ctx.m_pRequests)[0].m_callback) - { - bool waitsRemaining = false; - long completeCount = ctx.m_completeCount; - do - { - for (int i = 0; i < size; i++) - { - auto ithStatus = (*ctx.m_pRequestsStatus)[i]; - if (ithStatus == IOContext::Status::READ_SUCCESS) - { - completedIndex = i; - return true; - } - else if (ithStatus == IOContext::Status::READ_WAIT) - { - waitsRemaining = true; - } - } + bool getNextCompletedRequest(std::shared_ptr &reader, + IOContext &ctx, size_t size, + int &completedIndex) { + if ((*ctx.m_pRequests)[0].m_callback) { + bool waitsRemaining = false; + long completeCount = ctx.m_completeCount; + do { + for (int i = 0; i < size; i++) { + auto ithStatus = (*ctx.m_pRequestsStatus)[i]; + if (ithStatus == IOContext::Status::READ_SUCCESS) { + completedIndex = i; + return true; + } else if (ithStatus == IOContext::Status::READ_WAIT) { + waitsRemaining = true; + } + } - // if we didn't find one in READ_SUCCESS, wait for one to complete. - if (waitsRemaining) - { - WaitOnAddress(&ctx.m_completeCount, &completeCount, sizeof(completeCount), 100); - // this assumes the knowledge of the reader behavior (implicit - // contract). 
need better factoring? - } - } while (waitsRemaining); + // if we didn't find one in READ_SUCCESS, wait for one to complete. + if (waitsRemaining) { + WaitOnAddress(&ctx.m_completeCount, &completeCount, + sizeof(completeCount), 100); + // this assumes the knowledge of the reader behavior (implicit + // contract). need better factoring? + } + } while (waitsRemaining); - completedIndex = -1; - return false; - } - else - { - reader->wait(ctx, completedIndex); - return completedIndex != -1; + completedIndex = -1; + return false; + } else { + reader->wait(ctx, completedIndex); + return completedIndex != -1; } -} + } #endif -template -void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t k_search, const uint64_t l_search, - uint64_t *indices, float *distances, const uint64_t beam_width, - const bool use_reorder_data, QueryStats *stats) -{ - cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, std::numeric_limits::max(), + template + void PQFlashIndex::cached_beam_search( + const T *query1, const uint64_t k_search, const uint64_t l_search, + uint64_t *indices, float *distances, const uint64_t beam_width, + const bool use_reorder_data, QueryStats *stats) { + cached_beam_search(query1, k_search, l_search, indices, distances, + beam_width, std::numeric_limits::max(), use_reorder_data, stats); -} - -template -void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t k_search, const uint64_t l_search, - uint64_t *indices, float *distances, const uint64_t beam_width, - const bool use_filter, const LabelT &filter_label, - const bool use_reorder_data, QueryStats *stats) -{ - cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, use_filter, filter_label, - std::numeric_limits::max(), use_reorder_data, stats); -} - -template -void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t k_search, const uint64_t l_search, - uint64_t *indices, float *distances, const uint64_t beam_width, - const uint32_t io_limit, const bool use_reorder_data, - QueryStats *stats) -{ + } + + template + void PQFlashIndex::cached_beam_search( + const T *query1, const uint64_t k_search, const uint64_t l_search, + uint64_t *indices, float *distances, const uint64_t beam_width, + const bool use_filter, const LabelT &filter_label, + const bool use_reorder_data, QueryStats *stats) { + cached_beam_search(query1, k_search, l_search, indices, distances, + beam_width, use_filter, filter_label, + std::numeric_limits::max(), use_reorder_data, + stats); + } + + template + void PQFlashIndex::cached_beam_search( + const T *query1, const uint64_t k_search, const uint64_t l_search, + uint64_t *indices, float *distances, const uint64_t beam_width, + const uint32_t io_limit, const bool use_reorder_data, QueryStats *stats) { LabelT dummy_filter = 0; - cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, false, dummy_filter, io_limit, + cached_beam_search(query1, k_search, l_search, indices, distances, + beam_width, false, dummy_filter, io_limit, use_reorder_data, stats); -} - -template -void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t k_search, const uint64_t l_search, - uint64_t *indices, float *distances, const uint64_t beam_width, - const bool use_filter, const LabelT &filter_label, - const uint32_t io_limit, const bool use_reorder_data, - QueryStats *stats) -{ - - uint64_t num_sector_per_nodes = DIV_ROUND_UP(_max_node_len, defaults::SECTOR_LEN); + } + + template + void PQFlashIndex::cached_beam_search( + const 
T *query1, const uint64_t k_search, const uint64_t l_search, + uint64_t *indices, float *distances, const uint64_t beam_width, + const bool use_filter, const LabelT &filter_label, + const uint32_t io_limit, const bool use_reorder_data, QueryStats *stats) { + uint64_t num_sector_per_nodes = + DIV_ROUND_UP(_max_node_len, defaults::SECTOR_LEN); if (beam_width > num_sector_per_nodes * defaults::MAX_N_SECTOR_READS) - throw ANNException("Beamwidth can not be higher than defaults::MAX_N_SECTOR_READS", -1, __FUNCSIG__, __FILE__, - __LINE__); + throw ANNException( + "Beamwidth can not be higher than defaults::MAX_N_SECTOR_READS", -1, + __FUNCSIG__, __FILE__, __LINE__); ScratchStoreManager> manager(this->_thread_data); - auto data = manager.scratch_space(); - IOContext &ctx = data->ctx; - auto query_scratch = &(data->scratch); + auto data = manager.scratch_space(); + IOContext & ctx = data->ctx; + auto query_scratch = &(data->scratch); auto pq_query_scratch = query_scratch->pq_scratch(); // reset query scratch @@ -1018,113 +934,113 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t // copy query to thread specific aligned and allocated memory (for distance // calculations we need aligned data) - float query_norm = 0; - T *aligned_query_T = query_scratch->aligned_query_T(); + float query_norm = 0; + T * aligned_query_T = query_scratch->aligned_query_T(); float *query_float = pq_query_scratch->aligned_query_float; float *query_rotated = pq_query_scratch->rotated_query; // normalization step. for cosine, we simply normalize the query - // for mips, we normalize the first d-1 dims, and add a 0 for last dim, since an extra coordinate was used to - // convert MIPS to L2 search - if (metric == diskann::Metric::INNER_PRODUCT || metric == diskann::Metric::COSINE) - { - uint64_t inherent_dim = (metric == diskann::Metric::COSINE) ? this->_data_dim : (uint64_t)(this->_data_dim - 1); - for (size_t i = 0; i < inherent_dim; i++) - { - aligned_query_T[i] = query1[i]; - query_norm += query1[i] * query1[i]; - } - if (metric == diskann::Metric::INNER_PRODUCT) - aligned_query_T[this->_data_dim - 1] = 0; - - query_norm = std::sqrt(query_norm); - - for (size_t i = 0; i < inherent_dim; i++) - { - aligned_query_T[i] = (T)(aligned_query_T[i] / query_norm); - } - pq_query_scratch->initialize(this->_data_dim, aligned_query_T); - } - else - { - for (size_t i = 0; i < this->_data_dim; i++) - { - aligned_query_T[i] = query1[i]; - } - pq_query_scratch->initialize(this->_data_dim, aligned_query_T); + // for mips, we normalize the first d-1 dims, and add a 0 for last dim, + // since an extra coordinate was used to convert MIPS to L2 search + if (metric == diskann::Metric::INNER_PRODUCT || + metric == diskann::Metric::COSINE) { + uint64_t inherent_dim = (metric == diskann::Metric::COSINE) + ? 
this->_data_dim + : (uint64_t) (this->_data_dim - 1); + for (size_t i = 0; i < inherent_dim; i++) { + aligned_query_T[i] = query1[i]; + query_norm += query1[i] * query1[i]; + } + if (metric == diskann::Metric::INNER_PRODUCT) + aligned_query_T[this->_data_dim - 1] = 0; + + query_norm = std::sqrt(query_norm); + + for (size_t i = 0; i < inherent_dim; i++) { + aligned_query_T[i] = (T) (aligned_query_T[i] / query_norm); + } + pq_query_scratch->initialize(this->_data_dim, aligned_query_T); + } else { + for (size_t i = 0; i < this->_data_dim; i++) { + aligned_query_T[i] = query1[i]; + } + pq_query_scratch->initialize(this->_data_dim, aligned_query_T); } // pointers to buffers for data T *data_buf = query_scratch->coord_scratch; - _mm_prefetch((char *)data_buf, _MM_HINT_T1); + _mm_prefetch((char *) data_buf, _MM_HINT_T1); // sector scratch - char *sector_scratch = query_scratch->sector_scratch; - uint64_t §or_scratch_idx = query_scratch->sector_idx; + char * sector_scratch = query_scratch->sector_scratch; + uint64_t & sector_scratch_idx = query_scratch->sector_idx; const uint64_t num_sectors_per_node = - _nnodes_per_sector > 0 ? 1 : DIV_ROUND_UP(_max_node_len, defaults::SECTOR_LEN); + _nnodes_per_sector > 0 + ? 1 + : DIV_ROUND_UP(_max_node_len, defaults::SECTOR_LEN); // query <-> PQ chunk centers distances - _pq_table.preprocess_query(query_rotated); // center the query and rotate if - // we have a rotation matrix + _pq_table.preprocess_query(query_rotated); // center the query and rotate + // if we have a rotation matrix float *pq_dists = pq_query_scratch->aligned_pqtable_dist_scratch; _pq_table.populate_chunk_distances(query_rotated, pq_dists); // query <-> neighbor list - float *dist_scratch = pq_query_scratch->aligned_dist_scratch; + float * dist_scratch = pq_query_scratch->aligned_dist_scratch; uint8_t *pq_coord_scratch = pq_query_scratch->aligned_pq_coord_scratch; // lambda to batch compute query<-> node distances in PQ space - auto compute_dists = [this, pq_coord_scratch, pq_dists](const uint32_t *ids, const uint64_t n_ids, - float *dists_out) { - diskann::aggregate_coords(ids, n_ids, this->data, this->_n_chunks, pq_coord_scratch); - diskann::pq_dist_lookup(pq_coord_scratch, n_ids, this->_n_chunks, pq_dists, dists_out); + auto compute_dists = [this, pq_coord_scratch, pq_dists]( + const uint32_t *ids, const uint64_t n_ids, + float *dists_out) { + diskann::aggregate_coords(ids, n_ids, this->data, this->_n_chunks, + pq_coord_scratch); + diskann::pq_dist_lookup(pq_coord_scratch, n_ids, this->_n_chunks, + pq_dists, dists_out); }; Timer query_timer, io_timer, cpu_timer; tsl::robin_set &visited = query_scratch->visited; - NeighborPriorityQueue &retset = query_scratch->retset; + NeighborPriorityQueue & retset = query_scratch->retset; retset.reserve(l_search); std::vector &full_retset = query_scratch->full_retset; uint32_t best_medoid = 0; - float best_dist = (std::numeric_limits::max)(); - if (!use_filter) - { - for (uint64_t cur_m = 0; cur_m < _num_medoids; cur_m++) - { - float cur_expanded_dist = - _dist_cmp_float->compare(query_float, _centroid_data + _aligned_dim * cur_m, (uint32_t)_aligned_dim); - if (cur_expanded_dist < best_dist) - { - best_medoid = _medoids[cur_m]; - best_dist = cur_expanded_dist; - } - } - } - else - { - if (_filter_to_medoid_ids.find(filter_label) != _filter_to_medoid_ids.end()) - { - //const auto &medoid_ids = _filter_to_medoid_ids[filter_label]; - const auto &medoid_ids = _filter_store->get_medoids_of_label(filter_label); - for (uint64_t cur_m = 0; cur_m < 
medoid_ids.size(); cur_m++) - { - // for filtered index, we dont store global centroid data as for unfiltered index, so we use PQ distance - // as approximation to decide closest medoid matching the query filter. - compute_dists(&medoid_ids[cur_m], 1, dist_scratch); - float cur_expanded_dist = dist_scratch[0]; - if (cur_expanded_dist < best_dist) - { - best_medoid = medoid_ids[cur_m]; - best_dist = cur_expanded_dist; - } - } + float best_dist = (std::numeric_limits::max) (); + if (!use_filter) { + for (uint64_t cur_m = 0; cur_m < _num_medoids; cur_m++) { + float cur_expanded_dist = _dist_cmp_float->compare( + query_float, _centroid_data + _aligned_dim * cur_m, + (uint32_t) _aligned_dim); + if (cur_expanded_dist < best_dist) { + best_medoid = _medoids[cur_m]; + best_dist = cur_expanded_dist; } - else - { - throw ANNException("Cannot find medoid for specified filter.", -1, __FUNCSIG__, __FILE__, __LINE__); + } + } else { + const auto &medoid_ids = + _filter_store->get_medoids_of_label(filter_label); + if (medoid_ids.size() > 0) + // if (_filter_to_medoid_ids.find(filter_label) != + // _filter_to_medoid_ids.end()) + { + // const auto &medoid_ids = _filter_to_medoid_ids[filter_label]; + + for (uint64_t cur_m = 0; cur_m < medoid_ids.size(); cur_m++) { + // for filtered index, we dont store global centroid data as for + // unfiltered index, so we use PQ distance as approximation to decide + // closest medoid matching the query filter. + compute_dists(&medoid_ids[cur_m], 1, dist_scratch); + float cur_expanded_dist = dist_scratch[0]; + if (cur_expanded_dist < best_dist) { + best_medoid = medoid_ids[cur_m]; + best_dist = cur_expanded_dist; + } } + } else { + throw ANNException("Cannot find medoid for specified filter.", -1, + __FUNCSIG__, __FILE__, __LINE__); + } } compute_dists(&best_medoid, 1, dist_scratch); @@ -1142,356 +1058,346 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t frontier_nhoods.reserve(2 * beam_width); std::vector frontier_read_reqs; frontier_read_reqs.reserve(2 * beam_width); - std::vector>> cached_nhoods; + std::vector>> + cached_nhoods; cached_nhoods.reserve(2 * beam_width); - while (retset.has_unexpanded_node() && num_ios < io_limit) - { - // clear iteration state - frontier.clear(); - frontier_nhoods.clear(); - frontier_read_reqs.clear(); - cached_nhoods.clear(); - sector_scratch_idx = 0; - // find new beam - uint32_t num_seen = 0; - while (retset.has_unexpanded_node() && frontier.size() < beam_width && num_seen < beam_width) - { - auto nbr = retset.closest_unexpanded(); - num_seen++; - auto iter = _nhood_cache.find(nbr.id); - if (iter != _nhood_cache.end()) - { - cached_nhoods.push_back(std::make_pair(nbr.id, iter->second)); - if (stats != nullptr) - { - stats->n_cache_hits++; - } - } - else - { - frontier.push_back(nbr.id); - } - if (this->_count_visited_nodes) - { - reinterpret_cast &>(this->_node_visit_counter[nbr.id].second).fetch_add(1); - } + while (retset.has_unexpanded_node() && num_ios < io_limit) { + // clear iteration state + frontier.clear(); + frontier_nhoods.clear(); + frontier_read_reqs.clear(); + cached_nhoods.clear(); + sector_scratch_idx = 0; + // find new beam + uint32_t num_seen = 0; + while (retset.has_unexpanded_node() && frontier.size() < beam_width && + num_seen < beam_width) { + auto nbr = retset.closest_unexpanded(); + num_seen++; + auto iter = _nhood_cache.find(nbr.id); + if (iter != _nhood_cache.end()) { + cached_nhoods.push_back(std::make_pair(nbr.id, iter->second)); + if (stats != nullptr) { + stats->n_cache_hits++; + } 
+ } else { + frontier.push_back(nbr.id); } + if (this->_count_visited_nodes) { + reinterpret_cast &>( + this->_node_visit_counter[nbr.id].second) + .fetch_add(1); + } + } - // read nhoods of frontier ids - if (!frontier.empty()) - { - if (stats != nullptr) - stats->n_hops++; - for (uint64_t i = 0; i < frontier.size(); i++) - { - auto id = frontier[i]; - std::pair fnhood; - fnhood.first = id; - fnhood.second = sector_scratch + num_sectors_per_node * sector_scratch_idx * defaults::SECTOR_LEN; - sector_scratch_idx++; - frontier_nhoods.push_back(fnhood); - frontier_read_reqs.emplace_back(get_node_sector((size_t)id) * defaults::SECTOR_LEN, - num_sectors_per_node * defaults::SECTOR_LEN, fnhood.second); - if (stats != nullptr) - { - stats->n_4k++; - stats->n_ios++; - } - num_ios++; - } - io_timer.reset(); + // read nhoods of frontier ids + if (!frontier.empty()) { + if (stats != nullptr) + stats->n_hops++; + for (uint64_t i = 0; i < frontier.size(); i++) { + auto id = frontier[i]; + std::pair fnhood; + fnhood.first = id; + fnhood.second = sector_scratch + num_sectors_per_node * + sector_scratch_idx * + defaults::SECTOR_LEN; + sector_scratch_idx++; + frontier_nhoods.push_back(fnhood); + frontier_read_reqs.emplace_back( + get_node_sector((size_t) id) * defaults::SECTOR_LEN, + num_sectors_per_node * defaults::SECTOR_LEN, fnhood.second); + if (stats != nullptr) { + stats->n_4k++; + stats->n_ios++; + } + num_ios++; + } + io_timer.reset(); #ifdef USE_BING_INFRA - reader->read(frontier_read_reqs, ctx, - true); // asynhronous reader for Bing. + reader->read(frontier_read_reqs, ctx, + true); // asynhronous reader for Bing. #else - reader->read(frontier_read_reqs, ctx); // synchronous IO linux + reader->read(frontier_read_reqs, ctx); // synchronous IO linux #endif - if (stats != nullptr) - { - stats->io_us += (float)io_timer.elapsed(); - } + if (stats != nullptr) { + stats->io_us += (float) io_timer.elapsed(); + } + } + + // process cached nhoods + for (auto &cached_nhood : cached_nhoods) { + auto global_cache_iter = _coord_cache.find(cached_nhood.first); + T * node_fp_coords_copy = global_cache_iter->second; + float cur_expanded_dist; + if (!_use_disk_index_pq) { + cur_expanded_dist = _dist_cmp->compare( + aligned_query_T, node_fp_coords_copy, (uint32_t) _aligned_dim); + } else { + if (metric == diskann::Metric::INNER_PRODUCT) + cur_expanded_dist = _disk_pq_table.inner_product( + query_float, (uint8_t *) node_fp_coords_copy); + else + cur_expanded_dist = + _disk_pq_table.l2_distance( // disk_pq does not support OPQ yet + query_float, (uint8_t *) node_fp_coords_copy); + } + full_retset.push_back( + Neighbor((uint32_t) cached_nhood.first, cur_expanded_dist)); + + uint64_t nnbrs = cached_nhood.second.first; + uint32_t *node_nbrs = cached_nhood.second.second; + + // compute node_nbrs <-> query dists in PQ space + cpu_timer.reset(); + compute_dists(node_nbrs, nnbrs, dist_scratch); + if (stats != nullptr) { + stats->n_cmps += (uint32_t) nnbrs; + stats->cpu_us += (float) cpu_timer.elapsed(); } - // process cached nhoods - for (auto &cached_nhood : cached_nhoods) - { - auto global_cache_iter = _coord_cache.find(cached_nhood.first); - T *node_fp_coords_copy = global_cache_iter->second; - float cur_expanded_dist; - if (!_use_disk_index_pq) - { - cur_expanded_dist = _dist_cmp->compare(aligned_query_T, node_fp_coords_copy, (uint32_t)_aligned_dim); - } - else - { - if (metric == diskann::Metric::INNER_PRODUCT) - cur_expanded_dist = _disk_pq_table.inner_product(query_float, (uint8_t *)node_fp_coords_copy); - else - 
cur_expanded_dist = _disk_pq_table.l2_distance( // disk_pq does not support OPQ yet - query_float, (uint8_t *)node_fp_coords_copy); - } - full_retset.push_back(Neighbor((uint32_t)cached_nhood.first, cur_expanded_dist)); - - uint64_t nnbrs = cached_nhood.second.first; - uint32_t *node_nbrs = cached_nhood.second.second; - - // compute node_nbrs <-> query dists in PQ space - cpu_timer.reset(); - compute_dists(node_nbrs, nnbrs, dist_scratch); - if (stats != nullptr) - { - stats->n_cmps += (uint32_t)nnbrs; - stats->cpu_us += (float)cpu_timer.elapsed(); - } - - // process prefetched nhood - for (uint64_t m = 0; m < nnbrs; ++m) - { - uint32_t id = node_nbrs[m]; - if (visited.insert(id).second) - { - //if (!use_filter && _dummy_pts.find(id) != _dummy_pts.end()) - //unfiltered search, but filtered index! - if (!use_filter && _filter_store->is_dummy_point(id)) - continue; - - //if (use_filter && !(point_has_label(id, filter_label)) && - // (!_use_universal_label || !point_has_label(id, _universal_filter_label))) - if (use_filter && !_filter_store->point_has_label_or_universal_label(id, filter_label)) - continue; - cmps++; - float dist = dist_scratch[m]; - Neighbor nn(id, dist); - retset.insert(nn); - } - } + // process prefetched nhood + for (uint64_t m = 0; m < nnbrs; ++m) { + uint32_t id = node_nbrs[m]; + if (visited.insert(id).second) { + // if (!use_filter && _dummy_pts.find(id) != _dummy_pts.end()) + // unfiltered search, but filtered index! + if (!use_filter && _filter_store->is_dummy_point(id)) + continue; + + // if (use_filter && !(point_has_label(id, filter_label)) && + // (!_use_universal_label || !point_has_label(id, + // _universal_filter_label))) + if (use_filter && + !_filter_store->point_has_label_or_universal_label( + id, filter_label)) + continue; + cmps++; + float dist = dist_scratch[m]; + Neighbor nn(id, dist); + retset.insert(nn); + } } + } #ifdef USE_BING_INFRA - // process each frontier nhood - compute distances to unvisited nodes - int completedIndex = -1; - long requestCount = static_cast(frontier_read_reqs.size()); - // If we issued read requests and if a read is complete or there are - // reads in wait state, then enter the while loop. - while (requestCount > 0 && getNextCompletedRequest(reader, ctx, requestCount, completedIndex)) - { - assert(completedIndex >= 0); - auto &frontier_nhood = frontier_nhoods[completedIndex]; - (*ctx.m_pRequestsStatus)[completedIndex] = IOContext::PROCESS_COMPLETE; + // process each frontier nhood - compute distances to unvisited nodes + int completedIndex = -1; + long requestCount = static_cast(frontier_read_reqs.size()); + // If we issued read requests and if a read is complete or there are + // reads in wait state, then enter the while loop. 
+ while ( + requestCount > 0 && + getNextCompletedRequest(reader, ctx, requestCount, completedIndex)) { + assert(completedIndex >= 0); + auto &frontier_nhood = frontier_nhoods[completedIndex]; + (*ctx.m_pRequestsStatus)[completedIndex] = IOContext::PROCESS_COMPLETE; #else - for (auto &frontier_nhood : frontier_nhoods) - { + for (auto &frontier_nhood : frontier_nhoods) { #endif - char *node_disk_buf = offset_to_node(frontier_nhood.second, frontier_nhood.first); - uint32_t *node_buf = offset_to_node_nhood(node_disk_buf); - uint64_t nnbrs = (uint64_t)(*node_buf); - T *node_fp_coords = offset_to_node_coords(node_disk_buf); - memcpy(data_buf, node_fp_coords, _disk_bytes_per_point); - float cur_expanded_dist; - if (!_use_disk_index_pq) - { - cur_expanded_dist = _dist_cmp->compare(aligned_query_T, data_buf, (uint32_t)_aligned_dim); - } - else - { - if (metric == diskann::Metric::INNER_PRODUCT) - cur_expanded_dist = _disk_pq_table.inner_product(query_float, (uint8_t *)data_buf); - else - cur_expanded_dist = _disk_pq_table.l2_distance(query_float, (uint8_t *)data_buf); - } - full_retset.push_back(Neighbor(frontier_nhood.first, cur_expanded_dist)); - uint32_t *node_nbrs = (node_buf + 1); - // compute node_nbrs <-> query dist in PQ space - cpu_timer.reset(); - compute_dists(node_nbrs, nnbrs, dist_scratch); - if (stats != nullptr) - { - stats->n_cmps += (uint32_t)nnbrs; - stats->cpu_us += (float)cpu_timer.elapsed(); - } + char *node_disk_buf = + offset_to_node(frontier_nhood.second, frontier_nhood.first); + uint32_t *node_buf = offset_to_node_nhood(node_disk_buf); + uint64_t nnbrs = (uint64_t) (*node_buf); + T * node_fp_coords = offset_to_node_coords(node_disk_buf); + memcpy(data_buf, node_fp_coords, _disk_bytes_per_point); + float cur_expanded_dist; + if (!_use_disk_index_pq) { + cur_expanded_dist = _dist_cmp->compare(aligned_query_T, data_buf, + (uint32_t) _aligned_dim); + } else { + if (metric == diskann::Metric::INNER_PRODUCT) + cur_expanded_dist = + _disk_pq_table.inner_product(query_float, (uint8_t *) data_buf); + else + cur_expanded_dist = + _disk_pq_table.l2_distance(query_float, (uint8_t *) data_buf); + } + full_retset.push_back( + Neighbor(frontier_nhood.first, cur_expanded_dist)); + uint32_t *node_nbrs = (node_buf + 1); + // compute node_nbrs <-> query dist in PQ space + cpu_timer.reset(); + compute_dists(node_nbrs, nnbrs, dist_scratch); + if (stats != nullptr) { + stats->n_cmps += (uint32_t) nnbrs; + stats->cpu_us += (float) cpu_timer.elapsed(); + } - cpu_timer.reset(); - // process prefetch-ed nhood - for (uint64_t m = 0; m < nnbrs; ++m) - { - uint32_t id = node_nbrs[m]; - if (visited.insert(id).second) - { - //if (!use_filter && _dummy_pts.find(id) != _dummy_pts.end()) - if (!use_filter && _filter_store->is_dummy_point(id)) - continue; - - //if (use_filter && !(point_has_label(id, filter_label)) && - // (!_use_universal_label || !point_has_label(id, _universal_filter_label))) - if (use_filter && _filter_store->point_has_label_or_universal_label(id, filter_label)) - continue; - cmps++; - float dist = dist_scratch[m]; - if (stats != nullptr) - { - stats->n_cmps++; - } - - Neighbor nn(id, dist); - retset.insert(nn); - } + cpu_timer.reset(); + // process prefetch-ed nhood + for (uint64_t m = 0; m < nnbrs; ++m) { + uint32_t id = node_nbrs[m]; + if (visited.insert(id).second) { + // if (!use_filter && _dummy_pts.find(id) != _dummy_pts.end()) + if (!use_filter && _filter_store->is_dummy_point(id)) + continue; + + // if (use_filter && !(point_has_label(id, filter_label)) && + // 
(!_use_universal_label || !point_has_label(id, + // _universal_filter_label))) + if (use_filter && _filter_store->point_has_label_or_universal_label( + id, filter_label)) + continue; + cmps++; + float dist = dist_scratch[m]; + if (stats != nullptr) { + stats->n_cmps++; } - if (stats != nullptr) - { - stats->cpu_us += (float)cpu_timer.elapsed(); - } + Neighbor nn(id, dist); + retset.insert(nn); + } + } + + if (stats != nullptr) { + stats->cpu_us += (float) cpu_timer.elapsed(); } + } - hops++; + hops++; } // re-sort by distance std::sort(full_retset.begin(), full_retset.end()); - if (use_reorder_data) - { - if (!(this->_reorder_data_exists)) - { - throw ANNException("Requested use of reordering data which does " - "not exist in index " - "file", - -1, __FUNCSIG__, __FILE__, __LINE__); - } - - std::vector vec_read_reqs; - - if (full_retset.size() > k_search * FULL_PRECISION_REORDER_MULTIPLIER) - full_retset.erase(full_retset.begin() + k_search * FULL_PRECISION_REORDER_MULTIPLIER, full_retset.end()); - - for (size_t i = 0; i < full_retset.size(); ++i) - { - // MULTISECTORFIX - vec_read_reqs.emplace_back(VECTOR_SECTOR_NO(((size_t)full_retset[i].id)) * defaults::SECTOR_LEN, - defaults::SECTOR_LEN, sector_scratch + i * defaults::SECTOR_LEN); - - if (stats != nullptr) - { - stats->n_4k++; - stats->n_ios++; - } + if (use_reorder_data) { + if (!(this->_reorder_data_exists)) { + throw ANNException( + "Requested use of reordering data which does " + "not exist in index " + "file", + -1, __FUNCSIG__, __FILE__, __LINE__); + } + + std::vector vec_read_reqs; + + if (full_retset.size() > k_search * FULL_PRECISION_REORDER_MULTIPLIER) + full_retset.erase( + full_retset.begin() + k_search * FULL_PRECISION_REORDER_MULTIPLIER, + full_retset.end()); + + for (size_t i = 0; i < full_retset.size(); ++i) { + // MULTISECTORFIX + vec_read_reqs.emplace_back( + VECTOR_SECTOR_NO(((size_t) full_retset[i].id)) * + defaults::SECTOR_LEN, + defaults::SECTOR_LEN, sector_scratch + i * defaults::SECTOR_LEN); + + if (stats != nullptr) { + stats->n_4k++; + stats->n_ios++; } + } - io_timer.reset(); + io_timer.reset(); #ifdef USE_BING_INFRA - reader->read(vec_read_reqs, ctx, true); // async reader windows. + reader->read(vec_read_reqs, ctx, true); // async reader windows. 
#else - reader->read(vec_read_reqs, ctx); // synchronous IO linux + reader->read(vec_read_reqs, ctx); // synchronous IO linux #endif - if (stats != nullptr) - { - stats->io_us += io_timer.elapsed(); - } - - for (size_t i = 0; i < full_retset.size(); ++i) - { - auto id = full_retset[i].id; - // MULTISECTORFIX - auto location = (sector_scratch + i * defaults::SECTOR_LEN) + VECTOR_SECTOR_OFFSET(id); - full_retset[i].distance = _dist_cmp->compare(aligned_query_T, (T *)location, (uint32_t)this->_data_dim); - } - - std::sort(full_retset.begin(), full_retset.end()); + if (stats != nullptr) { + stats->io_us += io_timer.elapsed(); + } + + for (size_t i = 0; i < full_retset.size(); ++i) { + auto id = full_retset[i].id; + // MULTISECTORFIX + auto location = (sector_scratch + i * defaults::SECTOR_LEN) + + VECTOR_SECTOR_OFFSET(id); + full_retset[i].distance = _dist_cmp->compare( + aligned_query_T, (T *) location, (uint32_t) this->_data_dim); + } + + std::sort(full_retset.begin(), full_retset.end()); } // copy k_search values - for (uint64_t i = 0; i < k_search; i++) - { - indices[i] = full_retset[i].id; - auto key = (uint32_t)indices[i]; - if (_dummy_pts.find(key) != _dummy_pts.end()) - { - indices[i] = _dummy_to_real_map[key]; - } - - if (distances != nullptr) - { - distances[i] = full_retset[i].distance; - if (metric == diskann::Metric::INNER_PRODUCT) - { - // flip the sign to convert min to max - distances[i] = (-distances[i]); - // rescale to revert back to original norms (cancelling the - // effect of base and query pre-processing) - if (_max_base_norm != 0) - distances[i] *= (_max_base_norm * query_norm); - } + for (uint64_t i = 0; i < k_search; i++) { + indices[i] = full_retset[i].id; + auto key = (uint32_t) indices[i]; + if (_filter_store->is_dummy_point(key)) { + indices[i] = _filter_store->get_real_point_for_dummy(key); + } + + if (distances != nullptr) { + distances[i] = full_retset[i].distance; + if (metric == diskann::Metric::INNER_PRODUCT) { + // flip the sign to convert min to max + distances[i] = (-distances[i]); + // rescale to revert back to original norms (cancelling the + // effect of base and query pre-processing) + if (_max_base_norm != 0) + distances[i] *= (_max_base_norm * query_norm); } + } } #ifdef USE_BING_INFRA ctx.m_completeCount = 0; #endif - if (stats != nullptr) - { - stats->total_us = (float)query_timer.elapsed(); + if (stats != nullptr) { + stats->total_us = (float) query_timer.elapsed(); } -} - -// range search returns results of all neighbors within distance of range. -// indices and distances need to be pre-allocated of size l_search and the -// return value is the number of matching hits. -template -uint32_t PQFlashIndex::range_search(const T *query1, const double range, const uint64_t min_l_search, - const uint64_t max_l_search, std::vector &indices, - std::vector &distances, const uint64_t min_beam_width, - QueryStats *stats) -{ + } + + // range search returns results of all neighbors within distance of range. + // indices and distances need to be pre-allocated of size l_search and the + // return value is the number of matching hits. 
+ template + uint32_t PQFlashIndex::range_search( + const T *query1, const double range, const uint64_t min_l_search, + const uint64_t max_l_search, std::vector &indices, + std::vector &distances, const uint64_t min_beam_width, + QueryStats *stats) { uint32_t res_count = 0; bool stop_flag = false; - uint32_t l_search = (uint32_t)min_l_search; // starting size of the candidate list - while (!stop_flag) - { - indices.resize(l_search); - distances.resize(l_search); - uint64_t cur_bw = min_beam_width > (l_search / 5) ? min_beam_width : l_search / 5; - cur_bw = (cur_bw > 100) ? 100 : cur_bw; - for (auto &x : distances) - x = std::numeric_limits::max(); - this->cached_beam_search(query1, l_search, l_search, indices.data(), distances.data(), cur_bw, false, stats); - for (uint32_t i = 0; i < l_search; i++) - { - if (distances[i] > (float)range) - { - res_count = i; - break; - } - else if (i == l_search - 1) - res_count = l_search; - } - if (res_count < (uint32_t)(l_search / 2.0)) - stop_flag = true; - l_search = l_search * 2; - if (l_search > max_l_search) - stop_flag = true; + uint32_t l_search = + (uint32_t) min_l_search; // starting size of the candidate list + while (!stop_flag) { + indices.resize(l_search); + distances.resize(l_search); + uint64_t cur_bw = + min_beam_width > (l_search / 5) ? min_beam_width : l_search / 5; + cur_bw = (cur_bw > 100) ? 100 : cur_bw; + for (auto &x : distances) + x = std::numeric_limits::max(); + this->cached_beam_search(query1, l_search, l_search, indices.data(), + distances.data(), cur_bw, false, stats); + for (uint32_t i = 0; i < l_search; i++) { + if (distances[i] > (float) range) { + res_count = i; + break; + } else if (i == l_search - 1) + res_count = l_search; + } + if (res_count < (uint32_t) (l_search / 2.0)) + stop_flag = true; + l_search = l_search * 2; + if (l_search > max_l_search) + stop_flag = true; } indices.resize(res_count); distances.resize(res_count); return res_count; -} + } -template uint64_t PQFlashIndex::get_data_dim() -{ + template + uint64_t PQFlashIndex::get_data_dim() { return _data_dim; -} + } -template diskann::Metric PQFlashIndex::get_metric() -{ + template + diskann::Metric PQFlashIndex::get_metric() { return this->metric; -} + } #ifdef EXEC_ENV_OLS -template char *PQFlashIndex::getHeaderBytes() -{ - IOContext &ctx = reader->get_ctx(); + template + char *PQFlashIndex::getHeaderBytes() { + IOContext & ctx = reader->get_ctx(); AlignedRead readReq; readReq.buf = new char[PQFlashIndex::HEADER_SIZE]; readReq.len = PQFlashIndex::HEADER_SIZE; @@ -1502,28 +1408,28 @@ template char *PQFlashIndex::getHeaderB reader->read(readReqs, ctx, false); - return (char *)readReq.buf; -} + return (char *) readReq.buf; + } #endif -template -std::vector PQFlashIndex::get_pq_vector(std::uint64_t vid) -{ + template + std::vector PQFlashIndex::get_pq_vector( + std::uint64_t vid) { std::uint8_t *pqVec = &this->data[vid * this->_n_chunks]; return std::vector(pqVec, pqVec + this->_n_chunks); -} + } -template std::uint64_t PQFlashIndex::get_num_points() -{ + template + std::uint64_t PQFlashIndex::get_num_points() { return _num_points; -} + } -// instantiations -template class PQFlashIndex; -template class PQFlashIndex; -template class PQFlashIndex; -template class PQFlashIndex; -template class PQFlashIndex; -template class PQFlashIndex; + // instantiations + template class PQFlashIndex; + template class PQFlashIndex; + template class PQFlashIndex; + template class PQFlashIndex; + template class PQFlashIndex; + template class PQFlashIndex; -} // namespace diskann 
+} // namespace diskann
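
A minimal caller-side sketch of the load-and-search flow this hunk touches. It is illustrative only: the header name, index construction (AlignedFileReader, metric), the helper names cache_bfs_levels/load_cache_list, and the default LabelT=uint32_t are assumptions from the surrounding DiskANN codebase, not part of this patch; load() and the cached_beam_search overloads are the ones defined above.

// Sketch under the assumptions stated above; index construction is omitted.
#include "pq_flash_index.h"  // assumed header for PQFlashIndex
#include <cstdint>
#include <vector>

void example_search(diskann::PQFlashIndex<float> &index,  // LabelT defaults to uint32_t (assumption)
                    const float *query, uint32_t filter_label) {
  // load() returns 0 on success, -1 on failure (see load_from_separate_paths above).
  if (index.load(/*num_threads=*/8, "my_index") != 0)
    return;

  // Optionally warm the neighborhood cache; the BFS from the medoids that fills
  // node_list is the routine whose body opens this hunk (cache_bfs_levels in DiskANN).
  std::vector<uint32_t> node_list;
  index.cache_bfs_levels(/*num_nodes_to_cache=*/100000, node_list);
  index.load_cache_list(node_list);

  const uint64_t K = 10, L = 100, beam_width = 4;
  std::vector<uint64_t> ids(K);
  std::vector<float> dists(K);

  // Unfiltered search: the start medoid is the one closest to the query in
  // full-precision centroid space.
  index.cached_beam_search(query, K, L, ids.data(), dists.data(), beam_width,
                           /*use_reorder_data=*/false, /*stats=*/nullptr);

  // Filtered search: start medoids come from _filter_store->get_medoids_of_label(),
  // and neighbors failing point_has_label_or_universal_label() are skipped.
  index.cached_beam_search(query, K, L, ids.data(), dists.data(), beam_width,
                           /*use_filter=*/true, filter_label,
                           /*use_reorder_data=*/false, /*stats=*/nullptr);
}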
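
Likewise, a hedged sketch of driving range_search() as implemented above: the candidate list starts at min_l_search and doubles until fewer than half of the returned candidates fall inside the radius, or max_l_search is reached. Same assumed header and index setup as the previous sketch; the radius and list sizes below are arbitrary illustration values.

#include "pq_flash_index.h"  // assumed header for PQFlashIndex
#include <cstdint>
#include <vector>

uint32_t example_range_search(diskann::PQFlashIndex<float> &index,
                              const float *query, double radius) {
  std::vector<uint64_t> ids;
  std::vector<float> dists;
  // range_search() resizes ids/dists internally and returns the hit count;
  // each pass uses a beam width of max(min_beam_width, l_search / 5), capped at 100.
  uint32_t hits = index.range_search(query, radius,
                                     /*min_l_search=*/32, /*max_l_search=*/256,
                                     ids, dists, /*min_beam_width=*/4,
                                     /*stats=*/nullptr);
  return hits;  // ids[0..hits) and dists[0..hits) lie within 'radius' of the query
}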