From 9c8e88d2da7554a6709571da3f5d38b7abc2c090 Mon Sep 17 00:00:00 2001 From: Renan S Date: Thu, 25 Apr 2024 13:39:25 -0700 Subject: [PATCH] merging multifilter for bann (#543) --- include/abstract_scratch.h | 35 ++ include/index.h | 2 + include/neighbor.h | 5 + include/pq.h | 50 +-- include/pq_common.h | 30 ++ include/pq_flash_index.h | 18 +- include/pq_scratch.h | 22 ++ include/scratch.h | 28 +- src/index.cpp | 2 +- src/pq_flash_index.cpp | 640 ++++++++++++++++++++++++------------- src/scratch.cpp | 55 +++- 11 files changed, 593 insertions(+), 294 deletions(-) create mode 100644 include/abstract_scratch.h create mode 100644 include/pq_common.h create mode 100644 include/pq_scratch.h diff --git a/include/abstract_scratch.h b/include/abstract_scratch.h new file mode 100644 index 000000000..b42a836f6 --- /dev/null +++ b/include/abstract_scratch.h @@ -0,0 +1,35 @@ +#pragma once +namespace diskann +{ + +template class PQScratch; + +// By somewhat more than a coincidence, it seems that both InMemQueryScratch +// and SSDQueryScratch have the aligned query and PQScratch objects. So we +// can put them in a neat hierarchy and keep PQScratch as a standalone class. +template class AbstractScratch +{ + public: + AbstractScratch() = default; + // This class does not take any responsibilty for memory management of + // its members. It is the responsibility of the derived classes to do so. + virtual ~AbstractScratch() = default; + + // Scratch objects should not be copied + AbstractScratch(const AbstractScratch &) = delete; + AbstractScratch &operator=(const AbstractScratch &) = delete; + + data_t *aligned_query_T() + { + return _aligned_query_T; + } + PQScratch *pq_scratch() + { + return _pq_scratch; + } + + protected: + data_t *_aligned_query_T = nullptr; + PQScratch *_pq_scratch = nullptr; +}; +} // namespace diskann diff --git a/include/index.h b/include/index.h index 4a9cbf999..fd5db1488 100644 --- a/include/index.h +++ b/include/index.h @@ -21,6 +21,8 @@ #include "in_mem_data_store.h" #include "in_mem_graph_store.h" #include "abstract_index.h" +#include "pq_scratch.h" +#include "pq.h" #define OVERHEAD_FACTOR 1.1 #ifdef EXEC_ENV_OLS diff --git a/include/neighbor.h b/include/neighbor.h index d7c0c25ed..7e6b58a65 100644 --- a/include/neighbor.h +++ b/include/neighbor.h @@ -109,6 +109,11 @@ class NeighborPriorityQueue return _cur < _size; } + void sort() + { + std::sort(_data.begin(), _data.begin() + _size); + } + size_t size() const { return _size; diff --git a/include/pq.h b/include/pq.h index acfa1b30a..3e6119f22 100644 --- a/include/pq.h +++ b/include/pq.h @@ -4,13 +4,7 @@ #pragma once #include "utils.h" - -#define NUM_PQ_BITS 8 -#define NUM_PQ_CENTROIDS (1 << NUM_PQ_BITS) -#define MAX_OPQ_ITERS 20 -#define NUM_KMEANS_REPS_PQ 12 -#define MAX_PQ_TRAINING_SET_SIZE 256000 -#define MAX_PQ_CHUNKS 512 +#include "pq_common.h" namespace diskann { @@ -53,40 +47,6 @@ class FixedChunkPQTable void populate_chunk_inner_products(const float *query_vec, float *dist_vec); }; -template struct PQScratch -{ - float *aligned_pqtable_dist_scratch = nullptr; // MUST BE AT LEAST [256 * NCHUNKS] - float *aligned_dist_scratch = nullptr; // MUST BE AT LEAST diskann MAX_DEGREE - uint8_t *aligned_pq_coord_scratch = nullptr; // MUST BE AT LEAST [N_CHUNKS * MAX_DEGREE] - float *rotated_query = nullptr; - float *aligned_query_float = nullptr; - - PQScratch(size_t graph_degree, size_t aligned_dim) - { - diskann::alloc_aligned((void **)&aligned_pq_coord_scratch, - (size_t)graph_degree * (size_t)MAX_PQ_CHUNKS * sizeof(uint8_t), 256); - diskann::alloc_aligned((void **)&aligned_pqtable_dist_scratch, 256 * (size_t)MAX_PQ_CHUNKS * sizeof(float), - 256); - diskann::alloc_aligned((void **)&aligned_dist_scratch, (size_t)graph_degree * sizeof(float), 256); - diskann::alloc_aligned((void **)&aligned_query_float, aligned_dim * sizeof(float), 8 * sizeof(float)); - diskann::alloc_aligned((void **)&rotated_query, aligned_dim * sizeof(float), 8 * sizeof(float)); - - memset(aligned_query_float, 0, aligned_dim * sizeof(float)); - memset(rotated_query, 0, aligned_dim * sizeof(float)); - } - - void set(size_t dim, T *query, const float norm = 1.0f) - { - for (size_t d = 0; d < dim; ++d) - { - if (norm != 1.0f) - rotated_query[d] = aligned_query_float[d] = static_cast(query[d]) / norm; - else - rotated_query[d] = aligned_query_float[d] = static_cast(query[d]); - } - } -}; - void aggregate_coords(const std::vector &ids, const uint8_t *all_coords, const uint64_t ndims, uint8_t *out); void pq_dist_lookup(const uint8_t *pq_ids, const size_t n_pts, const size_t pq_nchunks, const float *pq_dists, @@ -107,11 +67,19 @@ DISKANN_DLLEXPORT int generate_opq_pivots(const float *train_data, size_t num_tr unsigned num_pq_chunks, std::string opq_pivots_path, bool make_zero_mean = false); +DISKANN_DLLEXPORT int generate_pq_pivots_simplified(const float *train_data, size_t num_train, size_t dim, + size_t num_pq_chunks, std::vector &pivot_data_vector); + template int generate_pq_data_from_pivots(const std::string &data_file, unsigned num_centers, unsigned num_pq_chunks, const std::string &pq_pivots_path, const std::string &pq_compressed_vectors_path, bool use_opq = false); +DISKANN_DLLEXPORT int generate_pq_data_from_pivots_simplified(const float *data, const size_t num, + const float *pivot_data, const size_t pivots_num, + const size_t dim, const size_t num_pq_chunks, + std::vector &pq); + template void generate_disk_quantized_data(const std::string &data_file_to_use, const std::string &disk_pq_pivots_path, const std::string &disk_pq_compressed_vectors_path, diff --git a/include/pq_common.h b/include/pq_common.h new file mode 100644 index 000000000..c6a3a5739 --- /dev/null +++ b/include/pq_common.h @@ -0,0 +1,30 @@ +#pragma once + +#include +#include + +#define NUM_PQ_BITS 8 +#define NUM_PQ_CENTROIDS (1 << NUM_PQ_BITS) +#define MAX_OPQ_ITERS 20 +#define NUM_KMEANS_REPS_PQ 12 +#define MAX_PQ_TRAINING_SET_SIZE 256000 +#define MAX_PQ_CHUNKS 512 + +namespace diskann +{ +inline std::string get_quantized_vectors_filename(const std::string &prefix, bool use_opq, uint32_t num_chunks) +{ + return prefix + (use_opq ? "_opq" : "pq") + std::to_string(num_chunks) + "_compressed.bin"; +} + +inline std::string get_pivot_data_filename(const std::string &prefix, bool use_opq, uint32_t num_chunks) +{ + return prefix + (use_opq ? "_opq" : "pq") + std::to_string(num_chunks) + "_pivots.bin"; +} + +inline std::string get_rotation_matrix_suffix(const std::string &pivot_data_filename) +{ + return pivot_data_filename + "_rotation_matrix.bin"; +} + +} // namespace diskann diff --git a/include/pq_flash_index.h b/include/pq_flash_index.h index 49a504a07..b1ec6db87 100644 --- a/include/pq_flash_index.h +++ b/include/pq_flash_index.h @@ -2,6 +2,7 @@ // Licensed under the MIT license. #pragma once +#include #include "common_includes.h" #include "aligned_file_reader.h" @@ -35,6 +36,15 @@ template class PQFlashIndex DISKANN_DLLEXPORT int load(uint32_t num_threads, const char *index_prefix); #endif +#ifdef EXEC_ENV_OLS + DISKANN_DLLEXPORT void load_labels(MemoryMappedFiles &files, const std::string &disk_index_file); +#else + DISKANN_DLLEXPORT void load_labels(const std::string& disk_index_filepath); +#endif + DISKANN_DLLEXPORT void load_label_medoid_map( + const std::string &labels_to_medoids_filepath, std::istream &medoid_stream); + DISKANN_DLLEXPORT void load_dummy_map(const std::string& dummy_map_filepath, std::istream &dummy_map_stream); + #ifdef EXEC_ENV_OLS DISKANN_DLLEXPORT int load_from_separate_paths(diskann::MemoryMappedFiles &files, uint32_t num_threads, const char *index_filepath, const char *pivots_filepath, @@ -77,7 +87,7 @@ template class PQFlashIndex DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search, uint64_t *res_ids, float *res_dists, const uint64_t beam_width, - const bool use_filter, const LabelT &filter_label, + const bool use_filter, const std::vector &filter_labels, const uint32_t io_limit, const bool use_reorder_data = false, QueryStats *stats = nullptr); @@ -116,9 +126,11 @@ template class PQFlashIndex private: DISKANN_DLLEXPORT inline bool point_has_label(uint32_t point_id, LabelT label_id); - std::unordered_map load_label_map(std::basic_istream &infile); + DISKANN_DLLEXPORT inline bool point_has_any_label(uint32_t point_id, const std::vector &label_ids); + void load_label_map(std::basic_istream &map_reader, + std::unordered_map &string_to_int_map); DISKANN_DLLEXPORT void parse_label_file(std::basic_istream &infile, size_t &num_pts_labels); - DISKANN_DLLEXPORT void get_label_file_metadata(std::basic_istream &infile, uint32_t &num_pts, + DISKANN_DLLEXPORT void get_label_file_metadata(const std::string &fileContent, uint32_t &num_pts, uint32_t &num_total_labels); DISKANN_DLLEXPORT void generate_random_labels(std::vector &labels, const uint32_t num_labels, const uint32_t nthreads); diff --git a/include/pq_scratch.h b/include/pq_scratch.h new file mode 100644 index 000000000..2aa90dbe1 --- /dev/null +++ b/include/pq_scratch.h @@ -0,0 +1,22 @@ +#pragma once +#include +#include "pq_common.h" +#include "utils.h" + +namespace diskann +{ + +template class PQScratch +{ + public: + float *aligned_pqtable_dist_scratch = nullptr; // MUST BE AT LEAST [256 * NCHUNKS] + float *aligned_dist_scratch = nullptr; // MUST BE AT LEAST diskann MAX_DEGREE + uint8_t *aligned_pq_coord_scratch = nullptr; // AT LEAST [N_CHUNKS * MAX_DEGREE] + float *rotated_query = nullptr; + float *aligned_query_float = nullptr; + + PQScratch(size_t graph_degree, size_t aligned_dim); + void initialize(size_t dim, const T *query, const float norm = 1.0f); +}; + +} // namespace diskann \ No newline at end of file diff --git a/include/scratch.h b/include/scratch.h index f685b36d9..2f43e3365 100644 --- a/include/scratch.h +++ b/include/scratch.h @@ -12,22 +12,22 @@ #include "tsl/sparse_map.h" #include "aligned_file_reader.h" -#include "concurrent_queue.h" -#include "defaults.h" +#include "abstract_scratch.h" #include "neighbor.h" -#include "pq.h" +#include "defaults.h" +#include "concurrent_queue.h" namespace diskann { +template class PQScratch; // -// Scratch space for in-memory index based search +// AbstractScratch space for in-memory index based search // -template class InMemQueryScratch +template class InMemQueryScratch : public AbstractScratch { public: ~InMemQueryScratch(); - // REFACTOR TODO: move all parameters to a new class. InMemQueryScratch(uint32_t search_l, uint32_t indexing_l, uint32_t r, uint32_t maxc, size_t dim, size_t aligned_dim, size_t alignment_factor, bool init_pq_scratch = false); void resize_for_new_L(uint32_t new_search_l); @@ -47,11 +47,11 @@ template class InMemQueryScratch } inline T *aligned_query() { - return _aligned_query; + return this->_aligned_query_T; } inline PQScratch *pq_scratch() { - return _pq_scratch; + return this->_pq_scratch; } inline std::vector &pool() { @@ -99,10 +99,6 @@ template class InMemQueryScratch uint32_t _R; uint32_t _maxc; - T *_aligned_query = nullptr; - - PQScratch *_pq_scratch = nullptr; - // _pool stores all neighbors explored from best_L_nodes. // Usually around L+R, but could be higher. // Initialized to 3L+R for some slack, expands as needed. @@ -139,10 +135,10 @@ template class InMemQueryScratch }; // -// Scratch space for SSD index based search +// AbstractScratch space for SSD index based search // -template class SSDQueryScratch +template class SSDQueryScratch : public AbstractScratch { public: T *coord_scratch = nullptr; // MUST BE AT LEAST [sizeof(T) * data_dim] @@ -150,10 +146,6 @@ template class SSDQueryScratch char *sector_scratch = nullptr; // MUST BE AT LEAST [MAX_N_SECTOR_READS * SECTOR_LEN] size_t sector_idx = 0; // index of next [SECTOR_LEN] scratch to use - T *aligned_query_T = nullptr; - - PQScratch *_pq_scratch; - tsl::robin_set visited; NeighborPriorityQueue retset; std::vector full_retset; diff --git a/src/index.cpp b/src/index.cpp index e199d7ea5..6d61aa45e 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1063,7 +1063,7 @@ std::pair Index::iterate_to_fixed_point( { query_float[d] = (float)aligned_query[d]; } - pq_query_scratch->set(_dim, aligned_query); + pq_query_scratch->initialize(_dim, aligned_query); // center the query and rotate if we have a rotation matrix _pq_table.preprocess_query(query_rotated); diff --git a/src/pq_flash_index.cpp b/src/pq_flash_index.cpp index cb6239f9c..d5fbbff20 100644 --- a/src/pq_flash_index.cpp +++ b/src/pq_flash_index.cpp @@ -4,6 +4,8 @@ #include "common_includes.h" #include "timer.h" +#include "pq.h" +#include "pq_scratch.h" #include "pq_flash_index.h" #include "cosine_similarity.h" @@ -25,19 +27,20 @@ namespace diskann { - template PQFlashIndex::PQFlashIndex(std::shared_ptr &fileReader, diskann::Metric m) : reader(fileReader), metric(m), _thread_data(nullptr) { + diskann::Metric metric_to_invoke = m; if (m == diskann::Metric::COSINE || m == diskann::Metric::INNER_PRODUCT) { if (std::is_floating_point::value) { - diskann::cout << "Cosine metric chosen for (normalized) float data." - "Changing distance to L2 to boost accuracy." + diskann::cout << "Since data is floating point, we assume that it has been appropriately pre-processed " + "(normalization for cosine, and convert-to-l2 by adding extra dimension for MIPS). So we " + "shall invoke an l2 distance function." << std::endl; - metric = diskann::Metric::L2; + metric_to_invoke = diskann::Metric::L2; } else { @@ -47,8 +50,8 @@ PQFlashIndex::PQFlashIndex(std::shared_ptr &fileRe } } - this->_dist_cmp.reset(diskann::get_distance_function(metric)); - this->_dist_cmp_float.reset(diskann::get_distance_function(metric)); + this->_dist_cmp.reset(diskann::get_distance_function(metric_to_invoke)); + this->_dist_cmp_float.reset(diskann::get_distance_function(metric_to_invoke)); } template PQFlashIndex::~PQFlashIndex() @@ -567,9 +570,8 @@ void PQFlashIndex::generate_random_labels(std::vector &labels } template -std::unordered_map PQFlashIndex::load_label_map(std::basic_istream &map_reader) +void PQFlashIndex::load_label_map(std::basic_istream &map_reader, std::unordered_map& string_to_int_map) { - std::unordered_map string_to_int_mp; std::string line, token; LabelT token_as_num; std::string label_str; @@ -580,9 +582,8 @@ std::unordered_map PQFlashIndex::load_label_map( label_str = token; getline(iss, token, '\t'); token_as_num = (LabelT)std::stoul(token); - string_to_int_mp[label_str] = token_as_num; + string_to_int_map[label_str] = token_as_num; } - return string_to_int_mp; } template @@ -610,28 +611,47 @@ void PQFlashIndex::reset_stream_for_reading(std::basic_istream } template -void PQFlashIndex::get_label_file_metadata(std::basic_istream &infile, uint32_t &num_pts, +void PQFlashIndex::get_label_file_metadata(const std::string &fileContent, uint32_t &num_pts, uint32_t &num_total_labels) { - std::string line, token; num_pts = 0; num_total_labels = 0; - while (std::getline(infile, line)) + size_t file_size = fileContent.length(); + + std::string label_str; + size_t cur_pos = 0; + size_t next_pos = 0; + while (cur_pos < file_size && cur_pos != std::string::npos) { - std::istringstream iss(line); - while (getline(iss, token, ',')) + next_pos = fileContent.find('\n', cur_pos); + if (next_pos == std::string::npos) { - token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); - token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); + break; + } + + size_t lbl_pos = cur_pos; + size_t next_lbl_pos = 0; + while (lbl_pos < next_pos && lbl_pos != std::string::npos) + { + next_lbl_pos = fileContent.find(',', lbl_pos); + if (next_lbl_pos == std::string::npos) // the last label + { + next_lbl_pos = next_pos; + } + num_total_labels++; + + lbl_pos = next_lbl_pos + 1; } + + cur_pos = next_pos + 1; + num_pts++; } diskann::cout << "Labels file metadata: num_points: " << num_pts << ", #total_labels: " << num_total_labels << std::endl; - reset_stream_for_reading(infile); } template @@ -651,47 +671,103 @@ inline bool PQFlashIndex::point_has_label(uint32_t point_id, LabelT l return ret_val; } +template +bool PQFlashIndex::point_has_any_label(uint32_t point_id, const std::vector &label_ids) +{ + uint32_t start_vec = _pts_to_label_offsets[point_id]; + uint32_t num_lbls = _pts_to_label_counts[start_vec]; + bool ret_val = false; + for (auto &cur_lbl : label_ids) + { + if (point_has_label(point_id, cur_lbl)) + { + ret_val = true; + break; + } + } + return ret_val; +} + + template void PQFlashIndex::parse_label_file(std::basic_istream &infile, size_t &num_points_labels) { - std::string line, token; + infile.seekg(0, std::ios::end); + size_t file_size = infile.tellg(); + + std::string buffer(file_size, ' '); + + infile.seekg(0, std::ios::beg); + infile.read(&buffer[0], file_size); + + std::string line; uint32_t line_cnt = 0; uint32_t num_pts_in_label_file; uint32_t num_total_labels; - get_label_file_metadata(infile, num_pts_in_label_file, num_total_labels); + get_label_file_metadata(buffer, num_pts_in_label_file, num_total_labels); _pts_to_label_offsets = new uint32_t[num_pts_in_label_file]; _pts_to_label_counts = new uint32_t[num_pts_in_label_file]; _pts_to_labels = new LabelT[num_total_labels]; uint32_t labels_seen_so_far = 0; - while (std::getline(infile, line)) + std::string label_str; + size_t cur_pos = 0; + size_t next_pos = 0; + while (cur_pos < file_size && cur_pos != std::string::npos) { - std::istringstream iss(line); - std::vector lbls(0); + next_pos = buffer.find('\n', cur_pos); + if (next_pos == std::string::npos) + { + break; + } _pts_to_label_offsets[line_cnt] = labels_seen_so_far; uint32_t &num_lbls_in_cur_pt = _pts_to_label_counts[line_cnt]; num_lbls_in_cur_pt = 0; - getline(iss, token, '\t'); - std::istringstream new_iss(token); - while (getline(new_iss, token, ',')) + + size_t lbl_pos = cur_pos; + size_t next_lbl_pos = 0; + while (lbl_pos < next_pos && lbl_pos != std::string::npos) { - token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); - token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); - LabelT token_as_num = (LabelT)std::stoul(token); + next_lbl_pos = buffer.find(',', lbl_pos); + if (next_lbl_pos == std::string::npos) // the last label in the whole file + { + next_lbl_pos = next_pos; + } + + if (next_lbl_pos > next_pos) // the last label in one line, just read to the end + { + next_lbl_pos = next_pos; + } + + label_str.assign(buffer.c_str() + lbl_pos, next_lbl_pos - lbl_pos); + if (label_str[label_str.length() - 1] == '\t') // '\t' won't exist in label file? + { + label_str.erase(label_str.length() - 1); + } + + LabelT token_as_num = (LabelT)std::stoul(label_str); _pts_to_labels[labels_seen_so_far++] = (LabelT)token_as_num; num_lbls_in_cur_pt++; + + // move to next label + lbl_pos = next_lbl_pos + 1; } + // move to next line + cur_pos = next_pos + 1; + if (num_lbls_in_cur_pt == 0) { diskann::cout << "No label found for point " << line_cnt << std::endl; exit(-1); } + line_cnt++; } + num_points_labels = line_cnt; reset_stream_for_reading(infile); } @@ -702,80 +778,85 @@ template void PQFlashIndex::set_univers _universal_filter_label = label; } -#ifdef EXEC_ENV_OLS template -int PQFlashIndex::load(MemoryMappedFiles &files, uint32_t num_threads, const char *index_prefix) +void PQFlashIndex::load_label_medoid_map(const std::string& labels_to_medoids_filepath, std::istream& medoid_stream) { -#else -template int PQFlashIndex::load(uint32_t num_threads, const char *index_prefix) -{ -#endif - std::string pq_table_bin = std::string(index_prefix) + "_pq_pivots.bin"; - std::string pq_compressed_vectors = std::string(index_prefix) + "_pq_compressed.bin"; - std::string _disk_index_file = std::string(index_prefix) + "_disk.index"; -#ifdef EXEC_ENV_OLS - return load_from_separate_paths(files, num_threads, _disk_index_file.c_str(), pq_table_bin.c_str(), - pq_compressed_vectors.c_str()); -#else - return load_from_separate_paths(num_threads, _disk_index_file.c_str(), pq_table_bin.c_str(), - pq_compressed_vectors.c_str()); -#endif -} + std::string line, token; -#ifdef EXEC_ENV_OLS -template -int PQFlashIndex::load_from_separate_paths(diskann::MemoryMappedFiles &files, uint32_t num_threads, - const char *index_filepath, const char *pivots_filepath, - const char *compressed_filepath) -{ -#else + _filter_to_medoid_ids.clear(); + try + { + while (std::getline(medoid_stream, line)) + { + std::istringstream iss(line); + uint32_t cnt = 0; + std::vector medoids; + LabelT label; + while (std::getline(iss, token, ',')) + { + if (cnt == 0) + label = (LabelT)std::stoul(token); + else + medoids.push_back((uint32_t)stoul(token)); + cnt++; + } + _filter_to_medoid_ids[label].swap(medoids); + } + } + catch (std::system_error &e) + { + throw FileException(labels_to_medoids_filepath, e, __FUNCSIG__, __FILE__, __LINE__); + } +} template -int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, const char *index_filepath, - const char *pivots_filepath, const char *compressed_filepath) +void PQFlashIndex::load_dummy_map(const std::string &dummy_map_filepath, std::istream &dummy_map_stream) { -#endif - std::string pq_table_bin = pivots_filepath; - std::string pq_compressed_vectors = compressed_filepath; - std::string _disk_index_file = index_filepath; - std::string medoids_file = std::string(_disk_index_file) + "_medoids.bin"; - std::string centroids_file = std::string(_disk_index_file) + "_centroids.bin"; - - std::string labels_file = std ::string(_disk_index_file) + "_labels.txt"; - std::string labels_to_medoids = std ::string(_disk_index_file) + "_labels_to_medoids.txt"; - std::string dummy_map_file = std ::string(_disk_index_file) + "_dummy_map.txt"; - std::string labels_map_file = std ::string(_disk_index_file) + "_labels_map.txt"; - size_t num_pts_in_label_file = 0; + std::string line, token; - size_t pq_file_dim, pq_file_num_centroids; -#ifdef EXEC_ENV_OLS - get_bin_metadata(files, pq_table_bin, pq_file_num_centroids, pq_file_dim, METADATA_SIZE); -#else - get_bin_metadata(pq_table_bin, pq_file_num_centroids, pq_file_dim, METADATA_SIZE); -#endif + try + { + while (std::getline(dummy_map_stream, line)) + { + std::istringstream iss(line); + uint32_t cnt = 0; + uint32_t dummy_id; + uint32_t real_id; + while (std::getline(iss, token, ',')) + { + if (cnt == 0) + dummy_id = (uint32_t)stoul(token); + else + real_id = (uint32_t)stoul(token); + cnt++; + } + _dummy_pts.insert(dummy_id); + _has_dummy_pts.insert(real_id); + _dummy_to_real_map[dummy_id] = real_id; - this->_disk_index_file = _disk_index_file; + if (_real_to_dummy_map.find(real_id) == _real_to_dummy_map.end()) + _real_to_dummy_map[real_id] = std::vector(); - if (pq_file_num_centroids != 256) + _real_to_dummy_map[real_id].emplace_back(dummy_id); + } + } + catch (std::system_error &e) { - diskann::cout << "Error. Number of PQ centroids is not 256. Exiting." << std::endl; - return -1; + throw FileException (dummy_map_filepath, e, __FUNCSIG__, __FILE__, __LINE__); } - - this->_data_dim = pq_file_dim; - // will change later if we use PQ on disk or if we are using - // inner product without PQ - this->_disk_bytes_per_point = this->_data_dim * sizeof(T); - this->_aligned_dim = ROUND_UP(pq_file_dim, 8); - - size_t npts_u64, nchunks_u64; +} #ifdef EXEC_ENV_OLS - diskann::load_bin(files, pq_compressed_vectors, this->data, npts_u64, nchunks_u64); +template +void PQFlashIndex::load_labels(MemoryMappedFiles &files, const std::string &disk_index_file) #else - diskann::load_bin(pq_compressed_vectors, this->data, npts_u64, nchunks_u64); +template void PQFlashIndex::load_labels(const std::string &disk_index_file) #endif +{ + std::string labels_file = _disk_index_file + "_labels.txt"; + std::string labels_to_medoids = _disk_index_file + "_labels_to_medoids.txt"; + std::string dummy_map_file = _disk_index_file + "_dummy_map.txt"; + std::string labels_map_file = _disk_index_file + "_labels_map.txt"; + size_t num_pts_in_label_file = 0; - this->_num_points = npts_u64; - this->_n_chunks = nchunks_u64; #ifdef EXEC_ENV_OLS if (files.fileExists(labels_file)) { @@ -784,7 +865,7 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons #else if (file_exists(labels_file)) { - std::ifstream infile(labels_file); + std::ifstream infile(labels_file, std::ios::binary); if (infile.fail()) { throw diskann::ANNException(std::string("Failed to open file ") + labels_file, -1); @@ -803,7 +884,7 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons #else std::ifstream map_reader(labels_map_file); #endif - _label_map = load_label_map(map_reader); + load_label_map(map_reader, _label_map); #ifndef EXEC_ENV_OLS map_reader.close(); @@ -821,32 +902,7 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons std::ifstream medoid_stream(labels_to_medoids); assert(medoid_stream.is_open()); #endif - std::string line, token; - - _filter_to_medoid_ids.clear(); - try - { - while (std::getline(medoid_stream, line)) - { - std::istringstream iss(line); - uint32_t cnt = 0; - std::vector medoids; - LabelT label; - while (std::getline(iss, token, ',')) - { - if (cnt == 0) - label = (LabelT)std::stoul(token); - else - medoids.push_back((uint32_t)stoul(token)); - cnt++; - } - _filter_to_medoid_ids[label].swap(medoids); - } - } - catch (std::system_error &e) - { - throw FileException(labels_to_medoids, e, __FUNCSIG__, __FILE__, __LINE__); - } + load_label_medoid_map(labels_to_medoids, medoid_stream); } std::string univ_label_file = std ::string(_disk_index_file) + "_universal_label.txt"; @@ -883,37 +939,87 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons std::ifstream dummy_map_stream(dummy_map_file); assert(dummy_map_stream.is_open()); #endif - std::string line, token; - - while (std::getline(dummy_map_stream, line)) - { - std::istringstream iss(line); - uint32_t cnt = 0; - uint32_t dummy_id; - uint32_t real_id; - while (std::getline(iss, token, ',')) - { - if (cnt == 0) - dummy_id = (uint32_t)stoul(token); - else - real_id = (uint32_t)stoul(token); - cnt++; - } - _dummy_pts.insert(dummy_id); - _has_dummy_pts.insert(real_id); - _dummy_to_real_map[dummy_id] = real_id; - - if (_real_to_dummy_map.find(real_id) == _real_to_dummy_map.end()) - _real_to_dummy_map[real_id] = std::vector(); - - _real_to_dummy_map[real_id].emplace_back(dummy_id); - } + load_dummy_map(dummy_map_file, dummy_map_stream); #ifndef EXEC_ENV_OLS dummy_map_stream.close(); #endif diskann::cout << "Loaded dummy map" << std::endl; } } + else + { + diskann::cout << "Index built without filter support." << std::endl; + } +} + +#ifdef EXEC_ENV_OLS +template +int PQFlashIndex::load(MemoryMappedFiles &files, uint32_t num_threads, const char *index_prefix) +{ +#else +template int PQFlashIndex::load(uint32_t num_threads, const char *index_prefix) +{ +#endif + std::string pq_table_bin = std::string(index_prefix) + "_pq_pivots.bin"; + std::string pq_compressed_vectors = std::string(index_prefix) + "_pq_compressed.bin"; + std::string _disk_index_file = std::string(index_prefix) + "_disk.index"; +#ifdef EXEC_ENV_OLS + return load_from_separate_paths(files, num_threads, _disk_index_file.c_str(), pq_table_bin.c_str(), + pq_compressed_vectors.c_str()); +#else + return load_from_separate_paths(num_threads, _disk_index_file.c_str(), pq_table_bin.c_str(), + pq_compressed_vectors.c_str()); +#endif +} + +#ifdef EXEC_ENV_OLS +template +int PQFlashIndex::load_from_separate_paths(diskann::MemoryMappedFiles &files, uint32_t num_threads, + const char *index_filepath, const char *pivots_filepath, + const char *compressed_filepath) +{ +#else +template +int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, const char *index_filepath, + const char *pivots_filepath, const char *compressed_filepath) +{ +#endif + std::string pq_table_bin = pivots_filepath; + std::string pq_compressed_vectors = compressed_filepath; + std::string _disk_index_file = index_filepath; + std::string medoids_file = std::string(_disk_index_file) + "_medoids.bin"; + std::string centroids_file = std::string(_disk_index_file) + "_centroids.bin"; + + size_t pq_file_dim, pq_file_num_centroids; +#ifdef EXEC_ENV_OLS + get_bin_metadata(files, pq_table_bin, pq_file_num_centroids, pq_file_dim, METADATA_SIZE); +#else + get_bin_metadata(pq_table_bin, pq_file_num_centroids, pq_file_dim, METADATA_SIZE); +#endif + + this->_disk_index_file = _disk_index_file; + + if (pq_file_num_centroids != 256) + { + diskann::cout << "Error. Number of PQ centroids is not 256. Exiting." << std::endl; + return -1; + } + + this->_data_dim = pq_file_dim; + // will change later if we use PQ on disk or if we are using + // inner product without PQ + this->_disk_bytes_per_point = this->_data_dim * sizeof(T); + this->_aligned_dim = ROUND_UP(pq_file_dim, 8); + + size_t npts_u64, nchunks_u64; +#ifdef EXEC_ENV_OLS + diskann::load_bin(files, pq_compressed_vectors, this->data, npts_u64, nchunks_u64); +#else + diskann::load_bin(pq_compressed_vectors, this->data, npts_u64, nchunks_u64); +#endif + + this->_num_points = npts_u64; + this->_n_chunks = nchunks_u64; #ifdef EXEC_ENV_OLS _pq_table.load_pq_centroid_bin(files, pq_table_bin.c_str(), nchunks_u64); @@ -1034,6 +1140,12 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons READ_U64(index_metadata, this->_nvecs_per_sector); } + #ifdef EXEC_ENV_OLS + load_labels(files, _disk_index_file); + #else + load_labels(_disk_index_file); + #endif + diskann::cout << "Disk-Index File Meta-data: "; diskann::cout << "# nodes per sector: " << _nnodes_per_sector; diskann::cout << ", max node len (bytes): " << _max_node_len; @@ -1135,51 +1247,43 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons diskann::cout << "Setting re-scaling factor of base vectors to " << this->_max_base_norm << std::endl; delete[] norm_val; } + diskann::cout << "done.." << std::endl; return 0; } #ifdef USE_BING_INFRA -bool getNextCompletedRequest(std::shared_ptr &reader, IOContext &ctx, size_t size, - int &completedIndex) +bool getNextCompletedRequest(const IOContext &ctx, size_t size, int &completedIndex) { - if ((*ctx.m_pRequests)[0].m_callback) + bool waitsRemaining = false; + long completeCount = ctx.m_completeCount; + do { - bool waitsRemaining = false; - long completeCount = ctx.m_completeCount; - do + for (int i = 0; i < size; i++) { - for (int i = 0; i < size; i++) + auto ithStatus = (*ctx.m_pRequestsStatus)[i]; + if (ithStatus == IOContext::Status::READ_SUCCESS) { - auto ithStatus = (*ctx.m_pRequestsStatus)[i]; - if (ithStatus == IOContext::Status::READ_SUCCESS) - { - completedIndex = i; - return true; - } - else if (ithStatus == IOContext::Status::READ_WAIT) - { - waitsRemaining = true; - } + completedIndex = i; + return true; } - - // if we didn't find one in READ_SUCCESS, wait for one to complete. - if (waitsRemaining) + else if (ithStatus == IOContext::Status::READ_WAIT) { - WaitOnAddress(&ctx.m_completeCount, &completeCount, sizeof(completeCount), 100); - // this assumes the knowledge of the reader behavior (implicit - // contract). need better factoring? + waitsRemaining = true; } - } while (waitsRemaining); + } - completedIndex = -1; - return false; - } - else - { - reader->wait(ctx, completedIndex); - return completedIndex != -1; - } + // if we didn't find one in READ_SUCCESS, wait for one to complete. + if (waitsRemaining) + { + WaitOnAddress(&ctx.m_completeCount, &completeCount, sizeof(completeCount), 100); + // this assumes the knowledge of the reader behavior (implicit + // contract). need better factoring? + } + } while (waitsRemaining); + + completedIndex = -1; + return false; } #endif @@ -1198,7 +1302,9 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t const bool use_filter, const LabelT &filter_label, const bool use_reorder_data, QueryStats *stats) { - cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, use_filter, filter_label, + std::vector filters(1); + filters.push_back(filter_label); + cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, use_filter, filters, std::numeric_limits::max(), use_reorder_data, stats); } @@ -1208,15 +1314,15 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t const uint32_t io_limit, const bool use_reorder_data, QueryStats *stats) { - LabelT dummy_filter = 0; - cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, false, dummy_filter, io_limit, + std::vector dummy_filters(0); + cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, false, dummy_filters, io_limit, use_reorder_data, stats); } template void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t k_search, const uint64_t l_search, uint64_t *indices, float *distances, const uint64_t beam_width, - const bool use_filter, const LabelT &filter_label, + const bool use_filters, const std::vector &filter_labels, const uint32_t io_limit, const bool use_reorder_data, QueryStats *stats) { @@ -1230,7 +1336,7 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t auto data = manager.scratch_space(); IOContext &ctx = data->ctx; auto query_scratch = &(data->scratch); - auto pq_query_scratch = query_scratch->_pq_scratch; + auto pq_query_scratch = query_scratch->pq_scratch(); // reset query scratch query_scratch->reset(); @@ -1238,28 +1344,33 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t // copy query to thread specific aligned and allocated memory (for distance // calculations we need aligned data) float query_norm = 0; - T *aligned_query_T = query_scratch->aligned_query_T; + T *aligned_query_T = query_scratch->aligned_query_T(); float *query_float = pq_query_scratch->aligned_query_float; float *query_rotated = pq_query_scratch->rotated_query; - // if inner product, we laso normalize the query and set the last coordinate - // to 0 (this is the extra coordindate used to convert MIPS to L2 search) - if (metric == diskann::Metric::INNER_PRODUCT) + uint32_t filter_label_count = (uint32_t)filter_labels.size(); + + // normalization step. for cosine, we simply normalize the query + // for mips, we normalize the first d-1 dims, and add a 0 for last dim, since an extra coordinate was used to + // convert MIPS to L2 search + if (metric == diskann::Metric::INNER_PRODUCT || metric == diskann::Metric::COSINE) { - for (size_t i = 0; i < this->_data_dim - 1; i++) + uint64_t inherent_dim = (metric == diskann::Metric::COSINE) ? this->_data_dim : (uint64_t)(this->_data_dim - 1); + for (size_t i = 0; i < inherent_dim; i++) { aligned_query_T[i] = query1[i]; query_norm += query1[i] * query1[i]; } - aligned_query_T[this->_data_dim - 1] = 0; + if (metric == diskann::Metric::INNER_PRODUCT) + aligned_query_T[this->_data_dim - 1] = 0; query_norm = std::sqrt(query_norm); - for (size_t i = 0; i < this->_data_dim - 1; i++) + for (size_t i = 0; i < inherent_dim; i++) { aligned_query_T[i] = (T)(aligned_query_T[i] / query_norm); } - pq_query_scratch->set(this->_data_dim, aligned_query_T); + pq_query_scratch->initialize(this->_data_dim, aligned_query_T); } else { @@ -1267,7 +1378,7 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t { aligned_query_T[i] = query1[i]; } - pq_query_scratch->set(this->_data_dim, aligned_query_T); + pq_query_scratch->initialize(this->_data_dim, aligned_query_T); } // pointers to buffers for data @@ -1300,12 +1411,22 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t tsl::robin_set &visited = query_scratch->visited; NeighborPriorityQueue &retset = query_scratch->retset; - retset.reserve(l_search); std::vector &full_retset = query_scratch->full_retset; + tsl::robin_set full_retset_ids; + if (use_filters) { + uint64_t size_to_reserve = std::max(l_search, (std::min((uint64_t)filter_label_count, this->_max_degree) + 1)); + retset.reserve(size_to_reserve); + full_retset.reserve(4096); + full_retset_ids.reserve(4096); + } else { + retset.reserve(l_search + 1); + } + uint32_t best_medoid = 0; + uint32_t cur_list_size = 0; float best_dist = (std::numeric_limits::max)(); - if (!use_filter) + if (!use_filters) { for (uint64_t cur_m = 0; cur_m < _num_medoids; cur_m++) { @@ -1317,35 +1438,36 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t best_dist = cur_expanded_dist; } } - } - else - { - if (_filter_to_medoid_ids.find(filter_label) != _filter_to_medoid_ids.end()) + compute_dists(&best_medoid, 1, dist_scratch); + retset.insert(Neighbor(best_medoid, dist_scratch[0])); + visited.insert(best_medoid); + cur_list_size = 1; + } else { + std::vector filter_specific_medoids; + filter_specific_medoids.reserve(filter_label_count); + location_t ctr = 0; + for (; ctr < filter_label_count && ctr < this->_max_degree; ctr++) { - const auto &medoid_ids = _filter_to_medoid_ids[filter_label]; - for (uint64_t cur_m = 0; cur_m < medoid_ids.size(); cur_m++) + if (filter_labels[ctr] != -1) { - // for filtered index, we dont store global centroid data as for unfiltered index, so we use PQ distance - // as approximation to decide closest medoid matching the query filter. - compute_dists(&medoid_ids[cur_m], 1, dist_scratch); - float cur_expanded_dist = dist_scratch[0]; - if (cur_expanded_dist < best_dist) + for (auto id : this->_filter_to_medoid_ids[filter_labels[ctr]]) { - best_medoid = medoid_ids[cur_m]; - best_dist = cur_expanded_dist; + filter_specific_medoids.push_back(id); } } } - else + compute_dists(filter_specific_medoids.data(), filter_specific_medoids.size(), dist_scratch); + for (ctr = 0; ctr < filter_specific_medoids.size(); ctr++) { - throw ANNException("Cannot find medoid for specified filter.", -1, __FUNCSIG__, __FILE__, __LINE__); + retset.insert(Neighbor(filter_specific_medoids[ctr], dist_scratch[ctr])); + //retset[ctr].id = filter_specific_medoids[ctr]; + //retset[ctr].distance = dist_scratch[ctr]; + //retset[ctr].expanded = false; + visited.insert(filter_specific_medoids[ctr]); } + cur_list_size = (uint32_t) filter_specific_medoids.size(); } - compute_dists(&best_medoid, 1, dist_scratch); - retset.insert(Neighbor(best_medoid, dist_scratch[0])); - visited.insert(best_medoid); - uint32_t cmps = 0; uint32_t hops = 0; uint32_t num_ios = 0; @@ -1360,7 +1482,15 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t std::vector>> cached_nhoods; cached_nhoods.reserve(2 * beam_width); - while (retset.has_unexpanded_node() && num_ios < io_limit) + //if we are doing multi-filter search we don't want to restrict the number of IOs + //at present. Must revisit this decision later. + uint32_t max_ios_for_query = use_filters || (io_limit == 0) ? std::numeric_limits::max() : io_limit; + const std::vector& label_ids = filter_labels; //avoid renaming. + std::vector lbl_vec; + + retset.sort(); + + while (retset.has_unexpanded_node() && num_ios < max_ios_for_query) { // clear iteration state frontier.clear(); @@ -1370,6 +1500,45 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t sector_scratch_idx = 0; // find new beam uint32_t num_seen = 0; + + + for (const auto &lbl : label_ids) + { // assuming that number of OR labels is + // less than max frontier size allowed + uint32_t lbl_marker = 0; + while (lbl_marker < cur_list_size) + { + lbl_vec.clear(); + lbl_vec.emplace_back(lbl); + + if (!retset[lbl_marker].expanded && point_has_any_label(retset[lbl_marker].id, lbl_vec)) + { + num_seen++; + auto iter = _nhood_cache.find(retset[lbl_marker].id); + if (iter != _nhood_cache.end()) + { + cached_nhoods.push_back(std::make_pair(retset[lbl_marker].id, iter->second)); + if (stats != nullptr) + { + stats->n_cache_hits++; + } + } + else + { + frontier.push_back(retset[lbl_marker].id); + } + retset[lbl_marker].expanded = true; + if (this->_count_visited_nodes) + { + reinterpret_cast &>(this->_node_visit_counter[retset[lbl_marker].id].second) + .fetch_add(1); + } + break; + } + lbl_marker++; + } + } + while (retset.has_unexpanded_node() && frontier.size() < beam_width && num_seen < beam_width) { auto nbr = retset.closest_unexpanded(); @@ -1446,7 +1615,24 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t cur_expanded_dist = _disk_pq_table.l2_distance( // disk_pq does not support OPQ yet query_float, (uint8_t *)node_fp_coords_copy); } - full_retset.push_back(Neighbor((uint32_t)cached_nhood.first, cur_expanded_dist)); + if (use_filters) + { + location_t real_id = cached_nhood.first; + if (_dummy_pts.find(real_id) != _dummy_pts.end()) + { + real_id = _dummy_to_real_map[real_id]; + } + if (full_retset_ids.find(real_id) == full_retset_ids.end()) + { + full_retset.push_back(Neighbor((uint32_t)real_id, cur_expanded_dist)); + full_retset_ids.insert(real_id); + } + } + else + { + full_retset.push_back(Neighbor((unsigned)cached_nhood.first, cur_expanded_dist)); + } + uint64_t nnbrs = cached_nhood.second.first; uint32_t *node_nbrs = cached_nhood.second.second; @@ -1466,10 +1652,10 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t uint32_t id = node_nbrs[m]; if (visited.insert(id).second) { - if (!use_filter && _dummy_pts.find(id) != _dummy_pts.end()) + if (!use_filters && _dummy_pts.find(id) != _dummy_pts.end()) continue; - if (use_filter && !(point_has_label(id, filter_label)) && + if (use_filters && !(point_has_any_label(id, label_ids)) && (!_use_universal_label || !point_has_label(id, _universal_filter_label))) continue; cmps++; @@ -1485,7 +1671,7 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t long requestCount = static_cast(frontier_read_reqs.size()); // If we issued read requests and if a read is complete or there are // reads in wait state, then enter the while loop. - while (requestCount > 0 && getNextCompletedRequest(reader, ctx, requestCount, completedIndex)) + while (requestCount > 0 && getNextCompletedRequest(ctx, requestCount, completedIndex)) { assert(completedIndex >= 0); auto &frontier_nhood = frontier_nhoods[completedIndex]; @@ -1511,7 +1697,25 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t else cur_expanded_dist = _disk_pq_table.l2_distance(query_float, (uint8_t *)data_buf); } - full_retset.push_back(Neighbor(frontier_nhood.first, cur_expanded_dist)); + if (use_filters) + { + location_t real_id = frontier_nhood.first; + if (_dummy_pts.find(real_id) != _dummy_pts.end()) + { + real_id = _dummy_to_real_map[real_id]; + } + + if (full_retset_ids.find(real_id) == full_retset_ids.end()) + { + full_retset.push_back(Neighbor(real_id, cur_expanded_dist)); + full_retset_ids.insert(real_id); + } + } + else + { + full_retset.push_back(Neighbor(frontier_nhood.first, cur_expanded_dist)); + } + uint32_t *node_nbrs = (node_buf + 1); // compute node_nbrs <-> query dist in PQ space cpu_timer.reset(); @@ -1529,10 +1733,10 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t uint32_t id = node_nbrs[m]; if (visited.insert(id).second) { - if (!use_filter && _dummy_pts.find(id) != _dummy_pts.end()) + if (!use_filters && _dummy_pts.find(id) != _dummy_pts.end()) continue; - if (use_filter && !(point_has_label(id, filter_label)) && + if (use_filters && !(point_has_any_label(id, label_ids)) && (!_use_universal_label || !point_has_label(id, _universal_filter_label))) continue; cmps++; @@ -1552,10 +1756,8 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t stats->cpu_us += (float)cpu_timer.elapsed(); } } - hops++; } - // re-sort by distance std::sort(full_retset.begin(), full_retset.end()); diff --git a/src/scratch.cpp b/src/scratch.cpp index 112c65d28..c3836ccf1 100644 --- a/src/scratch.cpp +++ b/src/scratch.cpp @@ -5,6 +5,7 @@ #include #include "scratch.h" +#include "pq_scratch.h" namespace diskann { @@ -24,13 +25,13 @@ InMemQueryScratch::InMemQueryScratch(uint32_t search_l, uint32_t indexing_l, throw diskann::ANNException(ss.str(), -1); } - alloc_aligned(((void **)&_aligned_query), aligned_dim * sizeof(T), alignment_factor * sizeof(T)); - memset(_aligned_query, 0, aligned_dim * sizeof(T)); + alloc_aligned(((void **)&this->_aligned_query_T), aligned_dim * sizeof(T), alignment_factor * sizeof(T)); + memset(this->_aligned_query_T, 0, aligned_dim * sizeof(T)); if (init_pq_scratch) - _pq_scratch = new PQScratch(defaults::MAX_GRAPH_DEGREE, aligned_dim); + this->_pq_scratch = new PQScratch(defaults::MAX_GRAPH_DEGREE, aligned_dim); else - _pq_scratch = nullptr; + this->_pq_scratch = nullptr; _occlude_factor.reserve(maxc); _inserted_into_pool_bs = new boost::dynamic_bitset<>(); @@ -71,12 +72,13 @@ template void InMemQueryScratch::resize_for_new_L(uint32_t new_l template InMemQueryScratch::~InMemQueryScratch() { - if (_aligned_query != nullptr) + if (this->_aligned_query_T != nullptr) { - aligned_free(_aligned_query); + aligned_free(this->_aligned_query_T); + this->_aligned_query_T = nullptr; } - delete _pq_scratch; + delete this->_pq_scratch; delete _inserted_into_pool_bs; } @@ -98,12 +100,12 @@ template SSDQueryScratch::SSDQueryScratch(size_t aligned_dim, si diskann::alloc_aligned((void **)&coord_scratch, coord_alloc_size, 256); diskann::alloc_aligned((void **)§or_scratch, defaults::MAX_N_SECTOR_READS * defaults::SECTOR_LEN, defaults::SECTOR_LEN); - diskann::alloc_aligned((void **)&aligned_query_T, aligned_dim * sizeof(T), 8 * sizeof(T)); + diskann::alloc_aligned((void **)&this->_aligned_query_T, aligned_dim * sizeof(T), 8 * sizeof(T)); - _pq_scratch = new PQScratch(defaults::MAX_GRAPH_DEGREE, aligned_dim); + this->_pq_scratch = new PQScratch(defaults::MAX_GRAPH_DEGREE, aligned_dim); memset(coord_scratch, 0, coord_alloc_size); - memset(aligned_query_T, 0, aligned_dim * sizeof(T)); + memset(this->_aligned_query_T, 0, aligned_dim * sizeof(T)); visited.reserve(visited_reserve); full_retset.reserve(visited_reserve); @@ -113,9 +115,9 @@ template SSDQueryScratch::~SSDQueryScratch() { diskann::aligned_free((void *)coord_scratch); diskann::aligned_free((void *)sector_scratch); - diskann::aligned_free((void *)aligned_query_T); + diskann::aligned_free((void *)this->_aligned_query_T); - delete[] _pq_scratch; + delete[] this->_pq_scratch; } template @@ -128,6 +130,30 @@ template void SSDThreadData::clear() scratch.reset(); } +template PQScratch::PQScratch(size_t graph_degree, size_t aligned_dim) +{ + diskann::alloc_aligned((void **)&aligned_pq_coord_scratch, + (size_t)graph_degree * (size_t)MAX_PQ_CHUNKS * sizeof(uint8_t), 256); + diskann::alloc_aligned((void **)&aligned_pqtable_dist_scratch, 256 * (size_t)MAX_PQ_CHUNKS * sizeof(float), 256); + diskann::alloc_aligned((void **)&aligned_dist_scratch, (size_t)graph_degree * sizeof(float), 256); + diskann::alloc_aligned((void **)&aligned_query_float, aligned_dim * sizeof(float), 8 * sizeof(float)); + diskann::alloc_aligned((void **)&rotated_query, aligned_dim * sizeof(float), 8 * sizeof(float)); + + memset(aligned_query_float, 0, aligned_dim * sizeof(float)); + memset(rotated_query, 0, aligned_dim * sizeof(float)); +} + +template void PQScratch::initialize(size_t dim, const T *query, const float norm) +{ + for (size_t d = 0; d < dim; ++d) + { + if (norm != 1.0f) + rotated_query[d] = aligned_query_float[d] = static_cast(query[d]) / norm; + else + rotated_query[d] = aligned_query_float[d] = static_cast(query[d]); + } +} + template DISKANN_DLLEXPORT class InMemQueryScratch; template DISKANN_DLLEXPORT class InMemQueryScratch; template DISKANN_DLLEXPORT class InMemQueryScratch; @@ -136,7 +162,12 @@ template DISKANN_DLLEXPORT class SSDQueryScratch; template DISKANN_DLLEXPORT class SSDQueryScratch; template DISKANN_DLLEXPORT class SSDQueryScratch; +template DISKANN_DLLEXPORT class PQScratch; +template DISKANN_DLLEXPORT class PQScratch; +template DISKANN_DLLEXPORT class PQScratch; + template DISKANN_DLLEXPORT class SSDThreadData; template DISKANN_DLLEXPORT class SSDThreadData; template DISKANN_DLLEXPORT class SSDThreadData; + } // namespace diskann