Skip to content

Commit

Permalink
merging multifilter for bann (#543)
Browse files Browse the repository at this point in the history
  • Loading branch information
MS-Renan authored Apr 25, 2024
1 parent facfc28 commit 9c8e88d
Show file tree
Hide file tree
Showing 11 changed files with 593 additions and 294 deletions.
35 changes: 35 additions & 0 deletions include/abstract_scratch.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#pragma once
namespace diskann
{

template <typename data_t> class PQScratch;

// By somewhat more than a coincidence, it seems that both InMemQueryScratch
// and SSDQueryScratch have the aligned query and PQScratch objects. So we
// can put them in a neat hierarchy and keep PQScratch as a standalone class.
//
// AbstractScratch stores two raw pointers (the aligned query buffer and the
// PQScratch object) and exposes read accessors for them. Both pointers start
// out as nullptr; derived classes are expected to assign them.
template <typename data_t> class AbstractScratch
{
  public:
    AbstractScratch() = default;
    // This class does not take any responsibility for memory management of
    // its members. It is the responsibility of the derived classes to do so.
    virtual ~AbstractScratch() = default;

    // Scratch objects should not be copied
    AbstractScratch(const AbstractScratch &) = delete;
    AbstractScratch &operator=(const AbstractScratch &) = delete;

    // Returns the stored aligned query pointer. Declared const so it can be
    // called through a const reference; the pointee itself stays mutable.
    data_t *aligned_query_T() const
    {
        return _aligned_query_T;
    }

    // Returns the stored PQScratch pointer. Declared const for the same
    // reason as aligned_query_T().
    PQScratch<data_t> *pq_scratch() const
    {
        return _pq_scratch;
    }

  protected:
    data_t *_aligned_query_T = nullptr;
    PQScratch<data_t> *_pq_scratch = nullptr;
};
} // namespace diskann
2 changes: 2 additions & 0 deletions include/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
#include "in_mem_data_store.h"
#include "in_mem_graph_store.h"
#include "abstract_index.h"
#include "pq_scratch.h"
#include "pq.h"

#define OVERHEAD_FACTOR 1.1
#ifdef EXEC_ENV_OLS
Expand Down
5 changes: 5 additions & 0 deletions include/neighbor.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,11 @@ class NeighborPriorityQueue
return _cur < _size;
}

// Sorts the first _size entries of _data in ascending order (std::sort with
// the element type's default ordering); entries past _size are not touched.
void sort()
{
    auto first = _data.begin();
    auto last = first + _size;
    std::sort(first, last);
}

size_t size() const
{
return _size;
Expand Down
50 changes: 9 additions & 41 deletions include/pq.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,7 @@
#pragma once

#include "utils.h"

#define NUM_PQ_BITS 8
#define NUM_PQ_CENTROIDS (1 << NUM_PQ_BITS)
#define MAX_OPQ_ITERS 20
#define NUM_KMEANS_REPS_PQ 12
#define MAX_PQ_TRAINING_SET_SIZE 256000
#define MAX_PQ_CHUNKS 512
#include "pq_common.h"

namespace diskann
{
Expand Down Expand Up @@ -53,40 +47,6 @@ class FixedChunkPQTable
void populate_chunk_inner_products(const float *query_vec, float *dist_vec);
};

// Bundle of aligned scratch buffers used for PQ distance computations.
// The constructor allocates every buffer through diskann::alloc_aligned;
// this struct declares no destructor, so it does not free them itself.
template <typename T> struct PQScratch
{
    float *aligned_pqtable_dist_scratch = nullptr; // MUST BE AT LEAST [256 * NCHUNKS]
    float *aligned_dist_scratch = nullptr;         // MUST BE AT LEAST diskann MAX_DEGREE
    uint8_t *aligned_pq_coord_scratch = nullptr;   // MUST BE AT LEAST [N_CHUNKS * MAX_DEGREE]
    float *rotated_query = nullptr;
    float *aligned_query_float = nullptr;

    // Allocates all five buffers. The coordinate and distance buffers are
    // sized from graph_degree and MAX_PQ_CHUNKS; the two query buffers hold
    // aligned_dim floats each and are zero-filled.
    PQScratch(size_t graph_degree, size_t aligned_dim)
    {
        const size_t coord_bytes = (size_t)graph_degree * (size_t)MAX_PQ_CHUNKS * sizeof(uint8_t);
        const size_t table_bytes = 256 * (size_t)MAX_PQ_CHUNKS * sizeof(float);
        const size_t dist_bytes = (size_t)graph_degree * sizeof(float);
        const size_t query_bytes = aligned_dim * sizeof(float);
        const size_t query_alignment = 8 * sizeof(float);

        diskann::alloc_aligned((void **)&aligned_pq_coord_scratch, coord_bytes, 256);
        diskann::alloc_aligned((void **)&aligned_pqtable_dist_scratch, table_bytes, 256);
        diskann::alloc_aligned((void **)&aligned_dist_scratch, dist_bytes, 256);
        diskann::alloc_aligned((void **)&aligned_query_float, query_bytes, query_alignment);
        diskann::alloc_aligned((void **)&rotated_query, query_bytes, query_alignment);

        memset(aligned_query_float, 0, query_bytes);
        memset(rotated_query, 0, query_bytes);
    }

    // Copies the first dim entries of query, converted to float, into both
    // aligned_query_float and rotated_query. Each value is divided by norm
    // unless norm is exactly 1.0f.
    void set(size_t dim, T *query, const float norm = 1.0f)
    {
        const bool needs_scaling = (norm != 1.0f);
        for (size_t i = 0; i < dim; ++i)
        {
            float value = static_cast<float>(query[i]);
            if (needs_scaling)
                value = value / norm;
            aligned_query_float[i] = value;
            rotated_query[i] = value;
        }
    }
};

void aggregate_coords(const std::vector<unsigned> &ids, const uint8_t *all_coords, const uint64_t ndims, uint8_t *out);

void pq_dist_lookup(const uint8_t *pq_ids, const size_t n_pts, const size_t pq_nchunks, const float *pq_dists,
Expand All @@ -107,11 +67,19 @@ DISKANN_DLLEXPORT int generate_opq_pivots(const float *train_data, size_t num_tr
unsigned num_pq_chunks, std::string opq_pivots_path,
bool make_zero_mean = false);

DISKANN_DLLEXPORT int generate_pq_pivots_simplified(const float *train_data, size_t num_train, size_t dim,
size_t num_pq_chunks, std::vector<float> &pivot_data_vector);

template <typename T>
int generate_pq_data_from_pivots(const std::string &data_file, unsigned num_centers, unsigned num_pq_chunks,
const std::string &pq_pivots_path, const std::string &pq_compressed_vectors_path,
bool use_opq = false);

DISKANN_DLLEXPORT int generate_pq_data_from_pivots_simplified(const float *data, const size_t num,
const float *pivot_data, const size_t pivots_num,
const size_t dim, const size_t num_pq_chunks,
std::vector<uint8_t> &pq);

template <typename T>
void generate_disk_quantized_data(const std::string &data_file_to_use, const std::string &disk_pq_pivots_path,
const std::string &disk_pq_compressed_vectors_path,
Expand Down
30 changes: 30 additions & 0 deletions include/pq_common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#pragma once

#include <string>
#include <sstream>

#define NUM_PQ_BITS 8
#define NUM_PQ_CENTROIDS (1 << NUM_PQ_BITS)
#define MAX_OPQ_ITERS 20
#define NUM_KMEANS_REPS_PQ 12
#define MAX_PQ_TRAINING_SET_SIZE 256000
#define MAX_PQ_CHUNKS 512

namespace diskann
{
// Builds the compressed-vectors file name:
//   prefix + ("_opq" or "pq") + num_chunks + "_compressed.bin"
// NOTE(review): the non-OPQ infix is "pq" with no leading underscore while the
// OPQ infix is "_opq". Kept as-is because it determines on-disk file names;
// confirm whether callers pass a prefix that already ends in '_'.
inline std::string get_quantized_vectors_filename(const std::string &prefix, bool use_opq, uint32_t num_chunks)
{
    std::string filename(prefix);
    if (use_opq)
        filename += "_opq";
    else
        filename += "pq";
    filename += std::to_string(num_chunks);
    filename += "_compressed.bin";
    return filename;
}

// Builds the pivot-data file name:
//   prefix + ("_opq" or "pq") + num_chunks + "_pivots.bin"
// NOTE(review): the non-OPQ infix is "pq" with no leading underscore while the
// OPQ infix is "_opq". Kept as-is because it determines on-disk file names;
// confirm whether callers pass a prefix that already ends in '_'.
inline std::string get_pivot_data_filename(const std::string &prefix, bool use_opq, uint32_t num_chunks)
{
    std::string filename(prefix);
    if (use_opq)
        filename += "_opq";
    else
        filename += "pq";
    filename += std::to_string(num_chunks);
    filename += "_pivots.bin";
    return filename;
}

// Returns pivot_data_filename with "_rotation_matrix.bin" appended.
inline std::string get_rotation_matrix_suffix(const std::string &pivot_data_filename)
{
    std::string rotation_matrix_filename(pivot_data_filename);
    rotation_matrix_filename.append("_rotation_matrix.bin");
    return rotation_matrix_filename;
}

} // namespace diskann
18 changes: 15 additions & 3 deletions include/pq_flash_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT license.

#pragma once
#include <unordered_map>
#include "common_includes.h"

#include "aligned_file_reader.h"
Expand Down Expand Up @@ -35,6 +36,15 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex
DISKANN_DLLEXPORT int load(uint32_t num_threads, const char *index_prefix);
#endif

#ifdef EXEC_ENV_OLS
DISKANN_DLLEXPORT void load_labels(MemoryMappedFiles &files, const std::string &disk_index_file);
#else
DISKANN_DLLEXPORT void load_labels(const std::string& disk_index_filepath);
#endif
DISKANN_DLLEXPORT void load_label_medoid_map(
const std::string &labels_to_medoids_filepath, std::istream &medoid_stream);
DISKANN_DLLEXPORT void load_dummy_map(const std::string& dummy_map_filepath, std::istream &dummy_map_stream);

#ifdef EXEC_ENV_OLS
DISKANN_DLLEXPORT int load_from_separate_paths(diskann::MemoryMappedFiles &files, uint32_t num_threads,
const char *index_filepath, const char *pivots_filepath,
Expand Down Expand Up @@ -77,7 +87,7 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex

DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search,
uint64_t *res_ids, float *res_dists, const uint64_t beam_width,
const bool use_filter, const LabelT &filter_label,
const bool use_filter, const std::vector<LabelT> &filter_labels,
const uint32_t io_limit, const bool use_reorder_data = false,
QueryStats *stats = nullptr);

Expand Down Expand Up @@ -116,9 +126,11 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex

private:
DISKANN_DLLEXPORT inline bool point_has_label(uint32_t point_id, LabelT label_id);
std::unordered_map<std::string, LabelT> load_label_map(std::basic_istream<char> &infile);
DISKANN_DLLEXPORT inline bool point_has_any_label(uint32_t point_id, const std::vector<LabelT> &label_ids);
void load_label_map(std::basic_istream<char> &map_reader,
std::unordered_map<std::string, LabelT> &string_to_int_map);
DISKANN_DLLEXPORT void parse_label_file(std::basic_istream<char> &infile, size_t &num_pts_labels);
DISKANN_DLLEXPORT void get_label_file_metadata(std::basic_istream<char> &infile, uint32_t &num_pts,
DISKANN_DLLEXPORT void get_label_file_metadata(const std::string &fileContent, uint32_t &num_pts,
uint32_t &num_total_labels);
DISKANN_DLLEXPORT void generate_random_labels(std::vector<LabelT> &labels, const uint32_t num_labels,
const uint32_t nthreads);
Expand Down
22 changes: 22 additions & 0 deletions include/pq_scratch.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#pragma once
#include <cstdint>
#include "pq_common.h"
#include "utils.h"

namespace diskann
{

// Holder for the raw scratch buffers used during PQ computations. All members
// are public pointers that default to nullptr. Only declarations appear here;
// the constructor and initialize() are defined elsewhere.
// NOTE(review): no destructor is declared in this class, so as declared it
// does not release the buffers itself -- confirm which owner frees them.
template <typename T> class PQScratch
{
public:
float *aligned_pqtable_dist_scratch = nullptr; // MUST BE AT LEAST [256 * NCHUNKS]
float *aligned_dist_scratch = nullptr; // MUST BE AT LEAST diskann MAX_DEGREE
uint8_t *aligned_pq_coord_scratch = nullptr; // AT LEAST [N_CHUNKS * MAX_DEGREE]
float *rotated_query = nullptr;
float *aligned_query_float = nullptr;

// NOTE(review): presumably allocates the buffers above, sized from
// graph_degree and aligned_dim -- the definition is not visible here; confirm.
PQScratch(size_t graph_degree, size_t aligned_dim);
// NOTE(review): presumably fills the query buffers from `query` (dim entries),
// optionally scaled by `norm` (default 1.0f) -- confirm against the definition.
void initialize(size_t dim, const T *query, const float norm = 1.0f);
};

} // namespace diskann
28 changes: 10 additions & 18 deletions include/scratch.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,22 @@
#include "tsl/sparse_map.h"

#include "aligned_file_reader.h"
#include "concurrent_queue.h"
#include "defaults.h"
#include "abstract_scratch.h"
#include "neighbor.h"
#include "pq.h"
#include "defaults.h"
#include "concurrent_queue.h"

namespace diskann
{
template <typename T> class PQScratch;

//
// Scratch space for in-memory index based search (derives from AbstractScratch)
//
template <typename T> class InMemQueryScratch
template <typename T> class InMemQueryScratch : public AbstractScratch<T>
{
public:
~InMemQueryScratch();
// REFACTOR TODO: move all parameters to a new class.
InMemQueryScratch(uint32_t search_l, uint32_t indexing_l, uint32_t r, uint32_t maxc, size_t dim, size_t aligned_dim,
size_t alignment_factor, bool init_pq_scratch = false);
void resize_for_new_L(uint32_t new_search_l);
Expand All @@ -47,11 +47,11 @@ template <typename T> class InMemQueryScratch
}
// Returns the aligned query pointer stored in the AbstractScratch<T> base.
// The stale `return _aligned_query;` that preceded this statement made it
// unreachable and has been removed.
inline T *aligned_query()
{
    // `this->` is needed because the member lives in a dependent base class.
    return this->_aligned_query_T;
}
// Returns the PQScratch pointer stored in the AbstractScratch<T> base.
// The duplicate unqualified `return _pq_scratch;` that preceded this statement
// made it unreachable and has been removed.
inline PQScratch<T> *pq_scratch()
{
    // `this->` is needed because the member lives in a dependent base class.
    return this->_pq_scratch;
}
inline std::vector<Neighbor> &pool()
{
Expand Down Expand Up @@ -99,10 +99,6 @@ template <typename T> class InMemQueryScratch
uint32_t _R;
uint32_t _maxc;

T *_aligned_query = nullptr;

PQScratch<T> *_pq_scratch = nullptr;

// _pool stores all neighbors explored from best_L_nodes.
// Usually around L+R, but could be higher.
// Initialized to 3L+R for some slack, expands as needed.
Expand Down Expand Up @@ -139,21 +135,17 @@ template <typename T> class InMemQueryScratch
};

//
// Scratch space for SSD index based search (derives from AbstractScratch)
//

template <typename T> class SSDQueryScratch
template <typename T> class SSDQueryScratch : public AbstractScratch<T>
{
public:
T *coord_scratch = nullptr; // MUST BE AT LEAST [sizeof(T) * data_dim]

char *sector_scratch = nullptr; // MUST BE AT LEAST [MAX_N_SECTOR_READS * SECTOR_LEN]
size_t sector_idx = 0; // index of next [SECTOR_LEN] scratch to use

T *aligned_query_T = nullptr;

PQScratch<T> *_pq_scratch;

tsl::robin_set<size_t> visited;
NeighborPriorityQueue retset;
std::vector<Neighbor> full_retset;
Expand Down
2 changes: 1 addition & 1 deletion src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1063,7 +1063,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
{
query_float[d] = (float)aligned_query[d];
}
pq_query_scratch->set(_dim, aligned_query);
pq_query_scratch->initialize(_dim, aligned_query);

// center the query and rotate if we have a rotation matrix
_pq_table.preprocess_query(query_rotated);
Expand Down
Loading

0 comments on commit 9c8e88d

Please sign in to comment.