Skip to content

Commit

Permalink
Set sector scratch to max of (maxdegree,max_filters_per_query,max_sector_reads)
Browse files Browse the repository at this point in the history
  • Loading branch information
gopal-msr committed Jun 20, 2024
1 parent 4871688 commit e94b9a8
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 9 deletions.
4 changes: 4 additions & 0 deletions include/defaults.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,9 @@ const uint32_t MAX_DEGREE = 64;
const uint32_t BUILD_LIST_SIZE = 100;
// NOTE(review): declared as uint32_t but initialized from the bool literal
// `false` (i.e. 0) -- confirm whether an integer flag is intended here.
const uint32_t SATURATE_GRAPH = false;
const uint32_t SEARCH_LIST_SIZE = 100;

// Default for the `visited_reserve` argument of
// PQFlashIndex::setup_thread_data(); forwarded to the per-thread
// SSDThreadData / SSDQueryScratch constructors.
const size_t VISITED_RESERVE = 4096;
// Default for the `max_filters_per_query` argument of
// PQFlashIndex::setup_thread_data(). SSDQueryScratch allocates its sector
// scratch buffer as max(max_degree, max_filters_per_query,
// MAX_N_SECTOR_READS) * SECTOR_LEN bytes, so this value is a lower bound on
// that buffer's size in sectors.
const size_t MAX_FILTERS_PER_QUERY = 4096;

} // namespace defaults
} // namespace diskann
4 changes: 3 additions & 1 deletion include/pq_flash_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex

protected:
DISKANN_DLLEXPORT void use_medoids_data_as_centroids();
DISKANN_DLLEXPORT void setup_thread_data(uint64_t nthreads, uint64_t visited_reserve = 4096);
DISKANN_DLLEXPORT void setup_thread_data(uint64_t nthreads, uint64_t visited_reserve = defaults::VISITED_RESERVE,
uint64_t max_degree = defaults::MAX_DEGREE,
uint64_t max_filters_per_query = defaults::MAX_FILTERS_PER_QUERY);

DISKANN_DLLEXPORT void set_universal_label(const LabelT &label);

Expand Down
4 changes: 2 additions & 2 deletions include/scratch.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ template <typename T> class SSDQueryScratch : public AbstractScratch<T>
NeighborPriorityQueue retset;
std::vector<Neighbor> full_retset;

SSDQueryScratch(size_t aligned_dim, size_t visited_reserve);
SSDQueryScratch(size_t aligned_dim, size_t visited_reserve, size_t max_degree, size_t max_filters_per_query);
~SSDQueryScratch();

void reset();
Expand All @@ -162,7 +162,7 @@ template <typename T> class SSDThreadData
SSDQueryScratch<T> scratch;
IOContext ctx;

SSDThreadData(size_t aligned_dim, size_t visited_reserve);
SSDThreadData(size_t aligned_dim, size_t visited_reserve, size_t max_degree, size_t max_filters_per_query);
void clear();
};

Expand Down
7 changes: 4 additions & 3 deletions src/pq_flash_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ template <typename T, typename LabelT> inline T *PQFlashIndex<T, LabelT>::offset
}

template <typename T, typename LabelT>
void PQFlashIndex<T, LabelT>::setup_thread_data(uint64_t nthreads, uint64_t visited_reserve)
void PQFlashIndex<T, LabelT>::setup_thread_data(uint64_t nthreads, uint64_t visited_reserve, uint64_t max_degree,
uint64_t max_filters_per_query)
{
diskann::cout << "Setting up thread-specific contexts for nthreads: " << nthreads << std::endl;
// omp parallel for to generate unique thread IDs
Expand All @@ -126,7 +127,7 @@ void PQFlashIndex<T, LabelT>::setup_thread_data(uint64_t nthreads, uint64_t visi
{
#pragma omp critical
{
SSDThreadData<T> *data = new SSDThreadData<T>(this->_aligned_dim, visited_reserve);
SSDThreadData<T> *data = new SSDThreadData<T>(this->_aligned_dim, visited_reserve, max_degree, max_filters_per_query);
this->reader->register_thread();
data->ctx = this->reader->get_ctx();
this->_thread_data.push(data);
Expand Down Expand Up @@ -1309,7 +1310,7 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const uint64_t
const bool use_filter, const LabelT &filter_label,
const bool use_reorder_data, QueryStats *stats)
{
std::vector<LabelT> filters(1);
std::vector<LabelT> filters;
filters.push_back(filter_label);
cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, use_filter, filters,
std::numeric_limits<uint32_t>::max(), use_reorder_data, stats);
Expand Down
10 changes: 7 additions & 3 deletions src/scratch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,14 @@ template <typename T> void SSDQueryScratch<T>::reset()
full_retset.clear();
}

template <typename T> SSDQueryScratch<T>::SSDQueryScratch(size_t aligned_dim, size_t visited_reserve)
template <typename T> SSDQueryScratch<T>::SSDQueryScratch(size_t aligned_dim, size_t visited_reserve, size_t max_degree, size_t max_filters_per_query)
{
size_t coord_alloc_size = ROUND_UP(sizeof(T) * aligned_dim, 256);
size_t sector_scratch_size =
std::max(max_degree, std::max(max_filters_per_query, (size_t)defaults::MAX_N_SECTOR_READS));

diskann::alloc_aligned((void **)&coord_scratch, coord_alloc_size, 256);
diskann::alloc_aligned((void **)&sector_scratch, defaults::MAX_N_SECTOR_READS * defaults::SECTOR_LEN,
diskann::alloc_aligned((void **)&sector_scratch, sector_scratch_size * defaults::SECTOR_LEN,
defaults::SECTOR_LEN);
diskann::alloc_aligned((void **)&this->_aligned_query_T, aligned_dim * sizeof(T), 8 * sizeof(T));

Expand All @@ -121,7 +123,9 @@ template <typename T> SSDQueryScratch<T>::~SSDQueryScratch()
}

template <typename T>
SSDThreadData<T>::SSDThreadData(size_t aligned_dim, size_t visited_reserve) : scratch(aligned_dim, visited_reserve)
SSDThreadData<T>::SSDThreadData(size_t aligned_dim, size_t visited_reserve, size_t max_degree,
size_t max_filters_per_query)
: scratch(aligned_dim, visited_reserve, max_degree, max_filters_per_query)
{
}

Expand Down

0 comments on commit e94b9a8

Please sign in to comment.