diff --git a/include/defaults.h b/include/defaults.h index ef1750fcf..8aa31fd50 100644 --- a/include/defaults.h +++ b/include/defaults.h @@ -30,5 +30,9 @@ const uint32_t MAX_DEGREE = 64; const uint32_t BUILD_LIST_SIZE = 100; const uint32_t SATURATE_GRAPH = false; const uint32_t SEARCH_LIST_SIZE = 100; + +const size_t VISITED_RESERVE = 4096; +const size_t MAX_FILTERS_PER_QUERY = 4096; + } // namespace defaults } // namespace diskann diff --git a/include/pq_flash_index.h b/include/pq_flash_index.h index b1ec6db87..21381b3fe 100644 --- a/include/pq_flash_index.h +++ b/include/pq_flash_index.h @@ -120,7 +120,9 @@ template class PQFlashIndex protected: DISKANN_DLLEXPORT void use_medoids_data_as_centroids(); - DISKANN_DLLEXPORT void setup_thread_data(uint64_t nthreads, uint64_t visited_reserve = 4096); + DISKANN_DLLEXPORT void setup_thread_data(uint64_t nthreads, uint64_t visited_reserve = defaults::VISITED_RESERVE, + uint64_t max_degree = defaults::MAX_DEGREE, + uint64_t max_filters_per_query = defaults::MAX_FILTERS_PER_QUERY); DISKANN_DLLEXPORT void set_universal_label(const LabelT &label); diff --git a/include/scratch.h b/include/scratch.h index 2f43e3365..8d8105c3b 100644 --- a/include/scratch.h +++ b/include/scratch.h @@ -150,7 +150,7 @@ template class SSDQueryScratch : public AbstractScratch NeighborPriorityQueue retset; std::vector full_retset; - SSDQueryScratch(size_t aligned_dim, size_t visited_reserve); + SSDQueryScratch(size_t aligned_dim, size_t visited_reserve, size_t max_degree, size_t max_filters_per_query); ~SSDQueryScratch(); void reset(); @@ -162,7 +162,7 @@ template class SSDThreadData SSDQueryScratch scratch; IOContext ctx; - SSDThreadData(size_t aligned_dim, size_t visited_reserve); + SSDThreadData(size_t aligned_dim, size_t visited_reserve, size_t max_degree, size_t max_filters_per_query); void clear(); }; diff --git a/src/pq_flash_index.cpp b/src/pq_flash_index.cpp index 491614a42..11985a16f 100644 --- a/src/pq_flash_index.cpp +++ b/src/pq_flash_index.cpp @@ -117,7 +117,8 @@ template inline T *PQFlashIndex::offset } template -void PQFlashIndex::setup_thread_data(uint64_t nthreads, uint64_t visited_reserve) +void PQFlashIndex::setup_thread_data(uint64_t nthreads, uint64_t visited_reserve, uint64_t max_degree, + uint64_t max_filters_per_query) { diskann::cout << "Setting up thread-specific contexts for nthreads: " << nthreads << std::endl; // omp parallel for to generate unique thread IDs @@ -126,7 +127,7 @@ void PQFlashIndex::setup_thread_data(uint64_t nthreads, uint64_t visi { #pragma omp critical { - SSDThreadData *data = new SSDThreadData(this->_aligned_dim, visited_reserve); + SSDThreadData *data = new SSDThreadData(this->_aligned_dim, visited_reserve, max_degree, max_filters_per_query); this->reader->register_thread(); data->ctx = this->reader->get_ctx(); this->_thread_data.push(data); @@ -1309,7 +1310,7 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t const bool use_filter, const LabelT &filter_label, const bool use_reorder_data, QueryStats *stats) { - std::vector filters(1); + std::vector filters; filters.push_back(filter_label); cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, use_filter, filters, std::numeric_limits::max(), use_reorder_data, stats); diff --git a/src/scratch.cpp b/src/scratch.cpp index c3836ccf1..05c9a1553 100644 --- a/src/scratch.cpp +++ b/src/scratch.cpp @@ -93,12 +93,14 @@ template void SSDQueryScratch::reset() full_retset.clear(); } -template SSDQueryScratch::SSDQueryScratch(size_t aligned_dim, size_t visited_reserve) +template SSDQueryScratch::SSDQueryScratch(size_t aligned_dim, size_t visited_reserve, size_t max_degree, size_t max_filters_per_query) { size_t coord_alloc_size = ROUND_UP(sizeof(T) * aligned_dim, 256); + size_t sector_scratch_size = + std::max(max_degree, std::max(max_filters_per_query, (size_t)defaults::MAX_N_SECTOR_READS)); diskann::alloc_aligned((void **)&coord_scratch, coord_alloc_size, 256); - diskann::alloc_aligned((void **)§or_scratch, defaults::MAX_N_SECTOR_READS * defaults::SECTOR_LEN, + diskann::alloc_aligned((void **)§or_scratch, sector_scratch_size * defaults::SECTOR_LEN, defaults::SECTOR_LEN); diskann::alloc_aligned((void **)&this->_aligned_query_T, aligned_dim * sizeof(T), 8 * sizeof(T)); @@ -121,7 +123,9 @@ template SSDQueryScratch::~SSDQueryScratch() } template -SSDThreadData::SSDThreadData(size_t aligned_dim, size_t visited_reserve) : scratch(aligned_dim, visited_reserve) +SSDThreadData::SSDThreadData(size_t aligned_dim, size_t visited_reserve, size_t max_degree, + size_t max_filters_per_query) + : scratch(aligned_dim, visited_reserve, max_degree, max_filters_per_query) { }