Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding a new PQ Distance Metric and PQ Data Store #384

Merged
merged 18 commits into from
Dec 5, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,8 @@ if(MSVC)
else()
set(ENV{TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD} 500000000000)
# set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g -DDEBUG -O0 -fsanitize=address -fsanitize=leak -fsanitize=undefined")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g -DDEBUG -Wall -Wextra")
# set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g -DDEBUG -Wall -Wextra")
gopalrs marked this conversation as resolved.
Show resolved Hide resolved
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g -DDEBUG")
if (NOT PYBIND)
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -Ofast -DNDEBUG -march=native -mtune=native -ftree-vectorize")
else()
Expand Down
19 changes: 15 additions & 4 deletions include/abstract_data_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
namespace diskann
{

template <typename data_t> struct AbstractScratch;

template <typename data_t> class AbstractDataStore
{
public:
Expand Down Expand Up @@ -76,19 +78,28 @@ template <typename data_t> class AbstractDataStore
// num_points) to zero
virtual void copy_vectors(const location_t from_loc, const location_t to_loc, const location_t num_points) = 0;

// metric specific operations

// Some datastores like PQ stores need a preprocessing step before querying.
// Optionally, a scratch object can be passed in to avoid memory allocations
// Default implementation does nothing.
// REFACTOR TODO: Currently, we take an aligned_query as parameter, but this
// should change and this function should do the necessary alignment.
virtual void preprocess_query(const data_t *aligned_query, AbstractScratch<data_t> *query_scratch = nullptr) const;
gopalrs marked this conversation as resolved.
Show resolved Hide resolved
// distance functions.
virtual float get_distance(const data_t *query, const location_t loc) const = 0;
virtual void get_distance(const data_t *query, const location_t *locations, const uint32_t location_count,
float *distances) const = 0;
float *distances, AbstractScratch<data_t> *scratch_space = nullptr) const = 0;
// Specific overload for index.cpp.
// REFACTOR TODO: Check if the default implementation is sufficient for most cases.
virtual void get_distance(const data_t *preprocessed_query, const std::vector<location_t> &ids,
std::vector<float> &distances, AbstractScratch<data_t> *scratch_space) const;
virtual float get_distance(const location_t loc1, const location_t loc2) const = 0;

// stats of the data stored in store
// Returns the point in the dataset that is closest to the mean of all points
// in the dataset
virtual location_t calculate_medoid() const = 0;

virtual Distance<data_t> *get_dist_fn() = 0;
virtual std::shared_ptr<Distance<data_t>> get_dist_fn() const = 0;

// search helpers
// if the base data is aligned per the request of the metric, this will tell
Expand Down
34 changes: 34 additions & 0 deletions include/abstract_scratch.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#pragma once
namespace diskann
{

template <typename data_t> struct PQScratch;

// Common base for per-query scratch spaces.
//
// By somewhat more than a coincidence, both InMemQueryScratch and
// SSDQueryScratch carry an aligned query buffer and a PQScratch object, so
// those two handles are kept here in a small hierarchy while PQScratch itself
// remains a standalone class.
//
// This base only stores the two pointers (both start out as nullptr) and
// exposes read accessors for them; it never allocates or frees anything.
template <typename data_t> struct AbstractScratch
{
    AbstractScratch() = default;

    // This class does not take any responsibility for memory management of
    // its members. It is the responsibility of the derived classes to do so.
    // The destructor is virtual so that derived scratch objects can be
    // destroyed through a pointer to this base.
    virtual ~AbstractScratch() = default;

    // Scratch objects should not be copied
    AbstractScratch(const AbstractScratch &) = delete;
    AbstractScratch &operator=(const AbstractScratch &) = delete;

    // Returns the stored query buffer pointer (nullptr until a derived class
    // assigns it).
    data_t *aligned_query_T()
    {
        return _aligned_query_T;
    }

    // Returns the stored PQScratch pointer (nullptr until a derived class
    // assigns it).
    PQScratch<data_t> *pq_scratch()
    {
        return _pq_scratch;
    }

  protected:
    data_t *_aligned_query_T = nullptr;
    PQScratch<data_t> *_pq_scratch = nullptr;
};
} // namespace diskann
5 changes: 3 additions & 2 deletions include/in_mem_data_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,13 @@ template <typename data_t> class InMemDataStore : public AbstractDataStore<data_

virtual float get_distance(const data_t *query, const location_t loc) const override;
virtual float get_distance(const location_t loc1, const location_t loc2) const override;

virtual void get_distance(const data_t *query, const location_t *locations, const uint32_t location_count,
float *distances) const override;
float *distances, AbstractScratch<data_t> *scratch) const override;

virtual location_t calculate_medoid() const override;

virtual Distance<data_t> *get_dist_fn() override;
virtual std::shared_ptr<Distance<data_t>> get_dist_fn() const override;

virtual size_t get_alignment_factor() const override;

Expand Down
9 changes: 8 additions & 1 deletion include/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
#include "in_mem_data_store.h"
#include "abstract_index.h"

// REFACTOR
#include "quantized_distance.h"
#include "pq_data_store.h"

#define OVERHEAD_FACTOR 1.1
#define EXPAND_IF_FULL 0
#define DEFAULT_MAXC 750
Expand Down Expand Up @@ -393,7 +397,10 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
bool _pq_dist = false;
bool _use_opq = false;
size_t _num_pq_chunks = 0;
uint8_t *_pq_data = nullptr;
// REFACTOR
// uint8_t *_pq_data = nullptr;
std::shared_ptr<QuantizedDistance<T>> _pq_distance_fn = nullptr;
rakri marked this conversation as resolved.
Show resolved Hide resolved
std::shared_ptr<PQDataStore<T>> _pq_data_store = nullptr;
bool _pq_generated = false;
FixedChunkPQTable _pq_table;

Expand Down
75 changes: 35 additions & 40 deletions include/pq.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,7 @@
#pragma once

#include "utils.h"

#define NUM_PQ_BITS 8
#define NUM_PQ_CENTROIDS (1 << NUM_PQ_BITS)
#define MAX_OPQ_ITERS 20
#define NUM_KMEANS_REPS_PQ 12
#define MAX_PQ_TRAINING_SET_SIZE 256000
#define MAX_PQ_CHUNKS 512
#include "pq_common.h"

namespace diskann
{
Expand Down Expand Up @@ -53,39 +47,40 @@ class FixedChunkPQTable
void populate_chunk_inner_products(const float *query_vec, float *dist_vec);
};

// Bundle of aligned scratch buffers used while working with PQ-compressed
// data for a single query. The constructor allocates every buffer through
// diskann::alloc_aligned; this struct declares no destructor, so releasing
// the buffers is left to whoever owns the PQScratch.
template <typename T> struct PQScratch
{
    float *aligned_pqtable_dist_scratch = nullptr; // MUST BE AT LEAST [256 * NCHUNKS]
    float *aligned_dist_scratch = nullptr;         // MUST BE AT LEAST diskann MAX_DEGREE
    uint8_t *aligned_pq_coord_scratch = nullptr;   // MUST BE AT LEAST [N_CHUNKS * MAX_DEGREE]
    float *rotated_query = nullptr;
    float *aligned_query_float = nullptr;

    // Allocates all five buffers. The coordinate and distance buffers are
    // sized from graph_degree and MAX_PQ_CHUNKS; the two query buffers hold
    // aligned_dim floats each and are zero-filled after allocation.
    PQScratch(size_t graph_degree, size_t aligned_dim)
    {
        const size_t coord_bytes = (size_t)graph_degree * (size_t)MAX_PQ_CHUNKS * sizeof(uint8_t);
        const size_t table_bytes = 256 * (size_t)MAX_PQ_CHUNKS * sizeof(float);
        const size_t dist_bytes = (size_t)graph_degree * sizeof(float);
        const size_t query_bytes = aligned_dim * sizeof(float);
        const size_t query_alignment = 8 * sizeof(float);

        diskann::alloc_aligned((void **)&aligned_pq_coord_scratch, coord_bytes, 256);
        diskann::alloc_aligned((void **)&aligned_pqtable_dist_scratch, table_bytes, 256);
        diskann::alloc_aligned((void **)&aligned_dist_scratch, dist_bytes, 256);
        diskann::alloc_aligned((void **)&aligned_query_float, query_bytes, query_alignment);
        diskann::alloc_aligned((void **)&rotated_query, query_bytes, query_alignment);

        memset(aligned_query_float, 0, query_bytes);
        memset(rotated_query, 0, query_bytes);
    }

    // Copies the first dim entries of query, converted to float, into both
    // aligned_query_float and rotated_query. When norm differs from 1.0f each
    // entry is divided by norm first.
    void set(size_t dim, T *query, const float norm = 1.0f)
    {
        const bool needs_scaling = (norm != 1.0f);
        for (size_t d = 0; d < dim; ++d)
        {
            const float as_float = static_cast<float>(query[d]);
            const float stored = needs_scaling ? as_float / norm : as_float;
            aligned_query_float[d] = stored;
            rotated_query[d] = stored;
        }
    }
};
// REFACTOR
// template <typename T> struct PQScratch
//{
// float *aligned_pqtable_dist_scratch = nullptr; // MUST BE AT LEAST [256 * NCHUNKS]
// float *aligned_dist_scratch = nullptr; // MUST BE AT LEAST diskann MAX_DEGREE
// uint8_t *aligned_pq_coord_scratch = nullptr; // MUST BE AT LEAST [N_CHUNKS * MAX_DEGREE]
// float *rotated_query = nullptr;
// float *aligned_query_float = nullptr;
//
// PQScratch(size_t graph_degree, size_t aligned_dim)
// {
// diskann::alloc_aligned((void **)&aligned_pq_coord_scratch,
// (size_t)graph_degree * (size_t)MAX_PQ_CHUNKS * sizeof(uint8_t), 256);
// diskann::alloc_aligned((void **)&aligned_pqtable_dist_scratch, 256 * (size_t)MAX_PQ_CHUNKS * sizeof(float),
// 256);
// diskann::alloc_aligned((void **)&aligned_dist_scratch, (size_t)graph_degree * sizeof(float), 256);
// diskann::alloc_aligned((void **)&aligned_query_float, aligned_dim * sizeof(float), 8 * sizeof(float));
// diskann::alloc_aligned((void **)&rotated_query, aligned_dim * sizeof(float), 8 * sizeof(float));
//
// memset(aligned_query_float, 0, aligned_dim * sizeof(float));
// memset(rotated_query, 0, aligned_dim * sizeof(float));
// }
//
// void set(size_t dim, T *query, const float norm = 1.0f)
// {
// for (size_t d = 0; d < dim; ++d)
// {
// if (norm != 1.0f)
// rotated_query[d] = aligned_query_float[d] = static_cast<float>(query[d]) / norm;
// else
// rotated_query[d] = aligned_query_float[d] = static_cast<float>(query[d]);
// }
// }
// };
gopalrs marked this conversation as resolved.
Show resolved Hide resolved

void aggregate_coords(const std::vector<unsigned> &ids, const uint8_t *all_coords, const uint64_t ndims, uint8_t *out);

Expand Down
30 changes: 30 additions & 0 deletions include/pq_common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#pragma once

#include <string>
#include <sstream>

// Compile-time constants shared by the PQ code.
// NOTE(review): meanings below are read off the macro names unless stated
// otherwise — confirm against the PQ training/search code.

// Presumably the number of bits in one PQ code.
#define NUM_PQ_BITS 8
// Derived from NUM_PQ_BITS: 1 << 8 == 256.
#define NUM_PQ_CENTROIDS (1 << NUM_PQ_BITS)
// Presumably the iteration cap for OPQ training.
#define MAX_OPQ_ITERS 20
// Presumably the number of k-means repetitions used when training PQ pivots.
#define NUM_KMEANS_REPS_PQ 12
// Presumably the maximum number of points sampled for PQ training.
#define MAX_PQ_TRAINING_SET_SIZE 256000
// Upper bound on the number of PQ chunks.
#define MAX_PQ_CHUNKS 512

namespace diskann
{
// Builds the name of the file that holds the compressed (quantized) vectors:
//   <prefix>_pq<num_chunks>_compressed.bin    when use_opq is false
//   <prefix>_opq<num_chunks>_compressed.bin   when use_opq is true
// Both variants are separated from the prefix by an underscore; the PQ
// variant previously lacked it, which produced names like "indexpq32_...".
inline std::string get_quantized_vectors_filename(const std::string &prefix, bool use_opq, uint32_t num_chunks)
{
    return prefix + (use_opq ? "_opq" : "_pq") + std::to_string(num_chunks) + "_compressed.bin";
}

// Builds the name of the file that holds the pivot data:
//   <prefix>_pq<num_chunks>_pivots.bin    when use_opq is false
//   <prefix>_opq<num_chunks>_pivots.bin   when use_opq is true
// Uses the same "<prefix>_<scheme><num_chunks>" stem as
// get_quantized_vectors_filename so the two files of one index share a stem.
inline std::string get_pivot_data_filename(const std::string &prefix, bool use_opq, uint32_t num_chunks)
{
    return prefix + (use_opq ? "_opq" : "_pq") + std::to_string(num_chunks) + "_pivots.bin";
}

// Returns pivot_data_filename with "_rotation_matrix.bin" appended. Despite
// the "suffix" in the name, the result is the complete file name, not just
// the appended part.
inline std::string get_rotation_matrix_suffix(const std::string &pivot_data_filename)
{
    std::string rotation_matrix_filename(pivot_data_filename);
    rotation_matrix_filename.append("_rotation_matrix.bin");
    return rotation_matrix_filename;
}

} // namespace diskann
93 changes: 93 additions & 0 deletions include/pq_data_store.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
#pragma once
#include <memory>
#include "distance.h"
#include "quantized_distance.h"
#include "pq.h"
#include "abstract_data_store.h"

namespace diskann
{
// AbstractDataStore implementation whose backing storage is a byte buffer of
// quantized vectors (_quantized_data) rather than full-precision data_t
// vectors. It holds two distance objects: a full-precision Distance<data_t>
// and a QuantizedDistance<data_t>.
template <typename data_t> class PQDataStore : public AbstractDataStore<data_t>
{

public:
// Takes the dimension, the number of points, the number of PQ chunks and the
// two distance objects (full-precision and quantized) used by this store.
PQDataStore(size_t dim, location_t num_points, size_t num_pq_chunks, std::shared_ptr<Distance<data_t>> distance_fn,
std::shared_ptr<QuantizedDistance<data_t>> pq_distance_fn);
~PQDataStore();

// Load quantized vectors from a set of files. Here the argument is treated
// as a prefix and the files are assumed to be named with DiskANN
// conventions.
virtual location_t load(const std::string &file_prefix) override;

// Save quantized vectors to a set of files whose names start with
// file_prefix.
// Currently, the plan is to save the quantized vectors to the quantized
// vectors file.
virtual size_t save(const std::string &file_prefix, const location_t num_points) override;

// Since the base class function is pure virtual, we need to declare it here,
// even though the alignment concept is not needed for quantized data stores.
virtual size_t get_aligned_dim() const override;

// Populate quantized data from unaligned data using PQ functionality
virtual void populate_data(const data_t *vectors, const location_t num_pts) override;
virtual void populate_data(const std::string &filename, const size_t offset) override;

virtual void extract_data_to_bin(const std::string &filename, const location_t num_pts) override;

// Per-vector accessors required by AbstractDataStore.
// NOTE(review): how get_vector/set_vector map between data_t vectors and the
// quantized buffer is not visible in this header — confirm in the
// implementation before relying on them.
virtual void get_vector(const location_t i, data_t *target) const override;
virtual void set_vector(const location_t i, const data_t *const vector) override;
virtual void prefetch_vector(const location_t loc) override;

virtual void move_vectors(const location_t old_location_start, const location_t new_location_start,
const location_t num_points) override;
virtual void copy_vectors(const location_t from_loc, const location_t to_loc, const location_t num_points) override;

// Query preprocessing step; the scratch object is passed through so the
// implementation can use it.
virtual void preprocess_query(const data_t *query, AbstractScratch<data_t> *scratch) const override;

// Single-pair distance overloads required by AbstractDataStore.
virtual float get_distance(const data_t *query, const location_t loc) const override;
virtual float get_distance(const location_t loc1, const location_t loc2) const override;

// NOTE: Caller must invoke "PQDistance->preprocess_query" ONCE before calling
// this function.
virtual void get_distance(const data_t *preprocessed_query, const location_t *locations,
const uint32_t location_count, float *distances,
AbstractScratch<data_t> *scratch_space) const override;

// NOTE: Caller must invoke "PQDistance->preprocess_query" ONCE before calling
// this function.
virtual void get_distance(const data_t *preprocessed_query, const std::vector<location_t> &ids,
std::vector<float> &distances, AbstractScratch<data_t> *scratch_space) const override;

// We are returning the distance function that is used for full precision
// vectors here, not the PQ distance function. This is because the callers
// all are expecting a Distance<T> not QuantizedDistance<T>.
virtual std::shared_ptr<Distance<data_t>> get_dist_fn() const override;

virtual location_t calculate_medoid() const override;

virtual size_t get_alignment_factor() const override;

protected:
// Capacity changes required by AbstractDataStore.
virtual location_t expand(const location_t new_size) override;
virtual location_t shrink(const location_t new_size) override;

// Loading helpers; the AlignedFileReader overload exists only when
// EXEC_ENV_OLS is defined.
virtual location_t load_impl(const std::string &filename);
#ifdef EXEC_ENV_OLS
virtual location_t load_impl(AlignedFileReader &reader);
#endif

private:
// Byte buffer holding the quantized vectors.
// NOTE(review): presumably num_points * _num_chunks bytes — confirm against
// the implementation.
uint8_t *_quantized_data = nullptr;
// Number of PQ chunks; presumably set from the constructor's num_pq_chunks.
size_t _num_chunks = 0;

// REFACTOR TODO: Doing this temporarily before refactoring OPQ into
// its own class. Remove later.
bool _use_opq = false;

// NOTE(review): no in-class initializer here; presumably assigned by the
// constructor — confirm.
Metric _distance_metric;
// Full-precision distance function; this is what get_dist_fn() is documented
// to return.
std::shared_ptr<Distance<data_t>> _distance_fn = nullptr;
// Distance function that operates on the quantized representation.
std::shared_ptr<QuantizedDistance<data_t>> _pq_distance_fn = nullptr;
};
} // namespace diskann
Loading