Skip to content

Commit

Permalink
Adding a new PQ Distance Metric and PQ Data Store (#384)
Browse files Browse the repository at this point in the history
* Added PQ distance hierarchy

Changes to CMakelists

PQDataStore version that builds correctly

Clang-format

* Fixing compile issues after rebase to main

* minor renaming functions

* fixed small bug post rebasing with index factory

* Changes to index factory to support PQDataStore

* Merged graph_store and pq_data_store

* Implementing preprocessing for inmemdatastore

* Incorporating code review comments

* minor bugfix for PQ data allocation

* clang-formatted

* Incorporating CR comments

* Fixing compile error

* minor bug fix + clang-format

* Update pq.h

* Fixing warnings about struct/class incompatibility

---------

Co-authored-by: Gopal Srinivasa <[email protected]>
Co-authored-by: ravishankar <[email protected]>
Co-authored-by: gopalrs <[email protected]>
  • Loading branch information
4 people authored Dec 5, 2023
1 parent 03abc71 commit 5744060
Show file tree
Hide file tree
Showing 28 changed files with 1,179 additions and 276 deletions.
3 changes: 2 additions & 1 deletion apps/build_stitched_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,8 @@ void prune_and_save(path final_index_path_prefix, path full_index_path_prefix, p

diskann::get_bin_metadata(input_data_path, number_of_label_points, dimension);

diskann::Index<T> index(diskann::Metric::L2, dimension, number_of_label_points, nullptr, nullptr, 0, false, false);
diskann::Index<T> index(diskann::Metric::L2, dimension, number_of_label_points, nullptr, nullptr, 0, false, false,
false, false, 0, false);

// not searching this index, set search_l to 0
index.load(full_index_path_prefix.c_str(), num_threads, 1);
Expand Down
3 changes: 2 additions & 1 deletion apps/utils/count_bfs_levels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ template <typename T> void bfs_count(const std::string &index_path, uint32_t dat
{
using TagT = uint32_t;
using LabelT = uint32_t;
diskann::Index<T, TagT, LabelT> index(diskann::Metric::L2, data_dims, 0, nullptr, nullptr, 0, false, false);
diskann::Index<T, TagT, LabelT> index(diskann::Metric::L2, data_dims, 0, nullptr, nullptr, 0, false, false, false,
false, 0, false);
std::cout << "Index class instantiated" << std::endl;
index.load(index_path.c_str(), 1, 100);
std::cout << "Index loaded" << std::endl;
Expand Down
20 changes: 16 additions & 4 deletions include/abstract_data_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
namespace diskann
{

template <typename data_t> class AbstractScratch;

template <typename data_t> class AbstractDataStore
{
public:
Expand Down Expand Up @@ -78,19 +80,29 @@ template <typename data_t> class AbstractDataStore
// num_points) to zero
virtual void copy_vectors(const location_t from_loc, const location_t to_loc, const location_t num_points) = 0;

// metric specific operations

// With the PQ Data Store PR, we have also changed iterate_to_fixed_point to NOT take the query
// from the scratch object. Therefore every data store has to implement preprocess_query which
// at the least will be to copy the query into the scratch object. So making this pure virtual.
virtual void preprocess_query(const data_t *aligned_query,
AbstractScratch<data_t> *query_scratch = nullptr) const = 0;
// distance functions.
virtual float get_distance(const data_t *query, const location_t loc) const = 0;
virtual void get_distance(const data_t *query, const location_t *locations, const uint32_t location_count,
float *distances) const = 0;
float *distances, AbstractScratch<data_t> *scratch_space = nullptr) const = 0;
// Specific overload for index.cpp.
virtual void get_distance(const data_t *preprocessed_query, const std::vector<location_t> &ids,
std::vector<float> &distances, AbstractScratch<data_t> *scratch_space) const = 0;
virtual float get_distance(const location_t loc1, const location_t loc2) const = 0;

// stats of the data stored in store
// Returns the point in the dataset that is closest to the mean of all points
// in the dataset
virtual location_t calculate_medoid() const = 0;

virtual Distance<data_t> *get_dist_fn() = 0;
// REFACTOR PQ TODO: Each data store knows about its distance function, so this is
// redundant. However, we don't have an OptmizedDataStore yet, and to preserve code
// compability, we are exposing this function.
virtual Distance<data_t> *get_dist_fn() const = 0;

// search helpers
// if the base data is aligned per the request of the metric, this will tell
Expand Down
35 changes: 35 additions & 0 deletions include/abstract_scratch.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#pragma once
namespace diskann
{

template <typename data_t> class PQScratch;

// By somewhat more than a coincidence, it seems that both InMemQueryScratch
// and SSDQueryScratch have the aligned query and PQScratch objects. So we
// can put them in a neat hierarchy and keep PQScratch as a standalone class.
template <typename data_t> class AbstractScratch
{
public:
AbstractScratch() = default;
// This class does not take any responsibilty for memory management of
// its members. It is the responsibility of the derived classes to do so.
virtual ~AbstractScratch() = default;

// Scratch objects should not be copied
AbstractScratch(const AbstractScratch &) = delete;
AbstractScratch &operator=(const AbstractScratch &) = delete;

data_t *aligned_query_T()
{
return _aligned_query_T;
}
PQScratch<data_t> *pq_scratch()
{
return _pq_scratch;
}

protected:
data_t *_aligned_query_T = nullptr;
PQScratch<data_t> *_pq_scratch = nullptr;
};
} // namespace diskann
2 changes: 1 addition & 1 deletion include/distance.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ template <typename T> class Distance

// Providing a default implementation for the virtual destructor because we
// don't expect most metric implementations to need it.
DISKANN_DLLEXPORT virtual ~Distance();
DISKANN_DLLEXPORT virtual ~Distance() = default;

protected:
diskann::Metric _distance_metric;
Expand Down
14 changes: 10 additions & 4 deletions include/in_mem_data_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,20 @@ template <typename data_t> class InMemDataStore : public AbstractDataStore<data_
const location_t num_points) override;
virtual void copy_vectors(const location_t from_loc, const location_t to_loc, const location_t num_points) override;

virtual float get_distance(const data_t *query, const location_t loc) const override;
virtual void preprocess_query(const data_t *query, AbstractScratch<data_t> *query_scratch) const override;

virtual float get_distance(const data_t *preprocessed_query, const location_t loc) const override;
virtual float get_distance(const location_t loc1, const location_t loc2) const override;
virtual void get_distance(const data_t *query, const location_t *locations, const uint32_t location_count,
float *distances) const override;

virtual void get_distance(const data_t *preprocessed_query, const location_t *locations,
const uint32_t location_count, float *distances,
AbstractScratch<data_t> *scratch) const override;
virtual void get_distance(const data_t *preprocessed_query, const std::vector<location_t> &ids,
std::vector<float> &distances, AbstractScratch<data_t> *scratch_space) const override;

virtual location_t calculate_medoid() const override;

virtual Distance<data_t> *get_dist_fn() override;
virtual Distance<data_t> *get_dist_fn() const override;

virtual size_t get_alignment_factor() const override;

Expand Down
28 changes: 18 additions & 10 deletions include/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
#include "in_mem_graph_store.h"
#include "abstract_index.h"

#include "quantized_distance.h"
#include "pq_data_store.h"

#define OVERHEAD_FACTOR 1.1
#define EXPAND_IF_FULL 0
#define DEFAULT_MAXC 750
Expand Down Expand Up @@ -50,7 +53,13 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
**************************************************************************/

public:
// Call this when creating and passing Index Config is inconvenient.
// Constructor for Bulk operations and for creating the index object solely
// for loading a prexisting index.
DISKANN_DLLEXPORT Index(const IndexConfig &index_config, std::shared_ptr<AbstractDataStore<T>> data_store,
std::unique_ptr<AbstractGraphStore> graph_store,
std::shared_ptr<AbstractDataStore<T>> pq_data_store = nullptr);

// Constructor for incremental index
DISKANN_DLLEXPORT Index(Metric m, const size_t dim, const size_t max_points,
const std::shared_ptr<IndexWriteParameters> index_parameters,
const std::shared_ptr<IndexSearchParams> index_search_params,
Expand All @@ -59,9 +68,6 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
const bool pq_dist_build = false, const size_t num_pq_chunks = 0,
const bool use_opq = false, const bool filtered_index = false);

DISKANN_DLLEXPORT Index(const IndexConfig &index_config, std::unique_ptr<AbstractDataStore<T>> data_store,
std::unique_ptr<AbstractGraphStore> graph_store);

DISKANN_DLLEXPORT ~Index();

// Saves graph, data, metadata and associated tags.
Expand Down Expand Up @@ -247,9 +253,9 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
// with iterate_to_fixed_point.
std::vector<uint32_t> get_init_ids();

std::pair<uint32_t, uint32_t> iterate_to_fixed_point(const T *node_coords, const uint32_t Lindex,
const std::vector<uint32_t> &init_ids,
InMemQueryScratch<T> *scratch, bool use_filter,
// The query to use is placed in scratch->aligned_query
std::pair<uint32_t, uint32_t> iterate_to_fixed_point(InMemQueryScratch<T> *scratch, const uint32_t Lindex,
const std::vector<uint32_t> &init_ids, bool use_filter,
const std::vector<LabelT> &filters, bool search_invocation);

void search_for_point_and_prune(int location, uint32_t Lindex, std::vector<uint32_t> &pruned_list,
Expand Down Expand Up @@ -329,14 +335,13 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
Metric _dist_metric = diskann::L2;

// Data
std::unique_ptr<AbstractDataStore<T>> _data_store;
std::shared_ptr<AbstractDataStore<T>> _data_store;

// Graph related data structures
std::unique_ptr<AbstractGraphStore> _graph_store;

char *_opt_graph = nullptr;

T *_data = nullptr; // coordinates of all base points
// Dimensions
size_t _dim = 0;
size_t _nd = 0; // number of active points i.e. existing in the graph
Expand Down Expand Up @@ -396,7 +401,10 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
bool _pq_dist = false;
bool _use_opq = false;
size_t _num_pq_chunks = 0;
uint8_t *_pq_data = nullptr;
// REFACTOR
// uint8_t *_pq_data = nullptr;
std::shared_ptr<QuantizedDistance<T>> _pq_distance_fn = nullptr;
std::shared_ptr<AbstractDataStore<T>> _pq_data_store = nullptr;
bool _pq_generated = false;
FixedChunkPQTable _pq_table;

Expand Down
22 changes: 15 additions & 7 deletions include/index_factory.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "index.h"
#include "abstract_graph_store.h"
#include "in_mem_graph_store.h"
#include "pq_data_store.h"

namespace diskann
{
Expand All @@ -10,16 +11,23 @@ class IndexFactory
DISKANN_DLLEXPORT explicit IndexFactory(const IndexConfig &config);
DISKANN_DLLEXPORT std::unique_ptr<AbstractIndex> create_instance();

// Consruct a data store with distance function emplaced within
template <typename T>
DISKANN_DLLEXPORT static std::unique_ptr<AbstractDataStore<T>> construct_datastore(const DataStoreStrategy stratagy,
const size_t num_points,
const size_t dimension,
const Metric m);

DISKANN_DLLEXPORT static std::unique_ptr<AbstractGraphStore> construct_graphstore(
const GraphStoreStrategy stratagy, const size_t size, const size_t reserve_graph_degree);

template <typename T>
DISKANN_DLLEXPORT static std::shared_ptr<AbstractDataStore<T>> construct_datastore(DataStoreStrategy stratagy,
size_t num_points,
size_t dimension, Metric m);
// For now PQDataStore incorporates within itself all variants of quantization that we support. In the
// future it may be necessary to introduce an AbstractPQDataStore class to spearate various quantization
// flavours.
template <typename T>
DISKANN_DLLEXPORT static std::shared_ptr<PQDataStore<T>> construct_pq_datastore(DataStoreStrategy strategy,
size_t num_points, size_t dimension,
Metric m, size_t num_pq_chunks,
bool use_opq);
template <typename T> static Distance<T> *construct_inmem_distance_fn(Metric m);

private:
void check_config();

Expand Down
42 changes: 1 addition & 41 deletions include/pq.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,7 @@
#pragma once

#include "utils.h"

#define NUM_PQ_BITS 8
#define NUM_PQ_CENTROIDS (1 << NUM_PQ_BITS)
#define MAX_OPQ_ITERS 20
#define NUM_KMEANS_REPS_PQ 12
#define MAX_PQ_TRAINING_SET_SIZE 256000
#define MAX_PQ_CHUNKS 512
#include "pq_common.h"

namespace diskann
{
Expand Down Expand Up @@ -53,40 +47,6 @@ class FixedChunkPQTable
void populate_chunk_inner_products(const float *query_vec, float *dist_vec);
};

template <typename T> struct PQScratch
{
float *aligned_pqtable_dist_scratch = nullptr; // MUST BE AT LEAST [256 * NCHUNKS]
float *aligned_dist_scratch = nullptr; // MUST BE AT LEAST diskann MAX_DEGREE
uint8_t *aligned_pq_coord_scratch = nullptr; // MUST BE AT LEAST [N_CHUNKS * MAX_DEGREE]
float *rotated_query = nullptr;
float *aligned_query_float = nullptr;

PQScratch(size_t graph_degree, size_t aligned_dim)
{
diskann::alloc_aligned((void **)&aligned_pq_coord_scratch,
(size_t)graph_degree * (size_t)MAX_PQ_CHUNKS * sizeof(uint8_t), 256);
diskann::alloc_aligned((void **)&aligned_pqtable_dist_scratch, 256 * (size_t)MAX_PQ_CHUNKS * sizeof(float),
256);
diskann::alloc_aligned((void **)&aligned_dist_scratch, (size_t)graph_degree * sizeof(float), 256);
diskann::alloc_aligned((void **)&aligned_query_float, aligned_dim * sizeof(float), 8 * sizeof(float));
diskann::alloc_aligned((void **)&rotated_query, aligned_dim * sizeof(float), 8 * sizeof(float));

memset(aligned_query_float, 0, aligned_dim * sizeof(float));
memset(rotated_query, 0, aligned_dim * sizeof(float));
}

void set(size_t dim, T *query, const float norm = 1.0f)
{
for (size_t d = 0; d < dim; ++d)
{
if (norm != 1.0f)
rotated_query[d] = aligned_query_float[d] = static_cast<float>(query[d]) / norm;
else
rotated_query[d] = aligned_query_float[d] = static_cast<float>(query[d]);
}
}
};

void aggregate_coords(const std::vector<unsigned> &ids, const uint8_t *all_coords, const uint64_t ndims, uint8_t *out);

void pq_dist_lookup(const uint8_t *pq_ids, const size_t n_pts, const size_t pq_nchunks, const float *pq_dists,
Expand Down
30 changes: 30 additions & 0 deletions include/pq_common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#pragma once

#include <string>
#include <sstream>

#define NUM_PQ_BITS 8
#define NUM_PQ_CENTROIDS (1 << NUM_PQ_BITS)
#define MAX_OPQ_ITERS 20
#define NUM_KMEANS_REPS_PQ 12
#define MAX_PQ_TRAINING_SET_SIZE 256000
#define MAX_PQ_CHUNKS 512

namespace diskann
{
inline std::string get_quantized_vectors_filename(const std::string &prefix, bool use_opq, uint32_t num_chunks)
{
return prefix + (use_opq ? "_opq" : "pq") + std::to_string(num_chunks) + "_compressed.bin";
}

inline std::string get_pivot_data_filename(const std::string &prefix, bool use_opq, uint32_t num_chunks)
{
return prefix + (use_opq ? "_opq" : "pq") + std::to_string(num_chunks) + "_pivots.bin";
}

inline std::string get_rotation_matrix_suffix(const std::string &pivot_data_filename)
{
return pivot_data_filename + "_rotation_matrix.bin";
}

} // namespace diskann
Loading

0 comments on commit 5744060

Please sign in to comment.