Skip to content

Commit

Permalink
Incorporating code review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
gopal-msr committed Sep 6, 2023
1 parent b8aeafd commit 604bd92
Show file tree
Hide file tree
Showing 9 changed files with 29 additions and 87 deletions.
1 change: 0 additions & 1 deletion include/abstract_data_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ template <typename data_t> class AbstractDataStore
virtual void get_distance(const data_t *query, const location_t *locations, const uint32_t location_count,
float *distances, AbstractScratch<data_t> *scratch_space = nullptr) const = 0;
// Specific overload for index.cpp.
// REFACTOR TODO: Check if the default implementation is sufficient for most cases.
virtual void get_distance(const data_t *preprocessed_query, const std::vector<location_t> &ids,
std::vector<float> &distances, AbstractScratch<data_t> *scratch_space) const = 0;
virtual float get_distance(const location_t loc1, const location_t loc2) const = 0;
Expand Down
2 changes: 1 addition & 1 deletion include/abstract_scratch.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ template <typename data_t> struct AbstractScratch
AbstractScratch() = default;
// This class does not take any responsibilty for memory management of
// its members. It is the responsibility of the derived classes to do so.
virtual ~AbstractScratch(){};
virtual ~AbstractScratch() = default;

// Scratch objects should not be copied
AbstractScratch(const AbstractScratch &) = delete;
Expand Down
7 changes: 3 additions & 4 deletions include/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
#include "in_mem_graph_store.h"
#include "abstract_index.h"

// REFACTOR
#include "quantized_distance.h"
#include "pq_data_store.h"

Expand Down Expand Up @@ -63,9 +62,9 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
// Constructor for incremental index
DISKANN_DLLEXPORT Index(Metric m, const size_t dim, const size_t max_points,
const std::shared_ptr<IndexWriteParameters> index_parameters,
const std::shared_ptr<IndexSearchParams> index_search_params, const size_t num_frozen_pts,
const bool dynamic_index, const bool enable_tags, const bool concurrent_consolidate,
const bool pq_dist_build, const size_t num_pq_chunks, const bool use_opq);
const std::shared_ptr<IndexSearchParams> index_search_params, const size_t num_frozen_pts = 0,
const bool dynamic_index = false, const bool enable_tags = false, const bool concurrent_consolidate = false,
const bool pq_dist_build = false, const size_t num_pq_chunks = 0, const bool use_opq = false);

DISKANN_DLLEXPORT ~Index();

Expand Down
10 changes: 6 additions & 4 deletions include/pq_data_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ template <typename data_t> class PQDataStore : public AbstractDataStore<data_t>
{

public:
PQDataStore(size_t dim, location_t num_points, size_t num_pq_chunks, std::shared_ptr<Distance<data_t>> distance_fn,
std::shared_ptr<QuantizedDistance<data_t>> pq_distance_fn);
PQDataStore(size_t dim, location_t num_points, size_t num_pq_chunks, std::unique_ptr<Distance<data_t>> distance_fn,
std::unique_ptr<QuantizedDistance<data_t>> pq_distance_fn);
PQDataStore(const PQDataStore&) = delete;
PQDataStore &operator=(const PQDataStore&) = delete;
~PQDataStore();

// Load quantized vectors from a set of files. Here filename is treated
Expand Down Expand Up @@ -89,7 +91,7 @@ template <typename data_t> class PQDataStore : public AbstractDataStore<data_t>
bool _use_opq = false;

Metric _distance_metric;
std::shared_ptr<Distance<data_t>> _distance_fn = nullptr;
std::shared_ptr<QuantizedDistance<data_t>> _pq_distance_fn = nullptr;
std::unique_ptr<Distance<data_t>> _distance_fn = nullptr;
std::unique_ptr<QuantizedDistance<data_t>> _pq_distance_fn = nullptr;
};
} // namespace diskann
2 changes: 1 addition & 1 deletion include/quantized_distance.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ template <typename data_t> class QuantizedDistance
QuantizedDistance() = default;
QuantizedDistance(const QuantizedDistance &) = delete;
QuantizedDistance &operator=(const QuantizedDistance &) = delete;
virtual ~QuantizedDistance(){};
virtual ~QuantizedDistance() = default;

virtual bool is_opq() const = 0;
virtual std::string get_quantized_vectors_filename(const std::string &prefix) const = 0;
Expand Down
1 change: 0 additions & 1 deletion include/scratch.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ template <typename T> class InMemQueryScratch : public AbstractScratch<T>
{
public:
~InMemQueryScratch();
// REFACTOR TODO: move all parameters to a new class.
InMemQueryScratch(uint32_t search_l, uint32_t indexing_l, uint32_t r, uint32_t maxc, size_t dim, size_t aligned_dim,
size_t alignment_factor, bool init_pq_scratch = false);
void resize_for_new_L(uint32_t new_search_l);
Expand Down
77 changes: 9 additions & 68 deletions src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@
#ifdef _WINDOWS
#include <xmmintrin.h>
#endif
// REFACTOR TODO: Must move to factory.
#include "pq_scratch.h"
#include "pq_l2_distance.h"

#include "index.h"

#define MAX_POINTS_FOR_USING_BITSET 10000000
Expand Down Expand Up @@ -763,12 +761,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(

float *pq_dists = nullptr;

//REFACTOR PQ: Preprocess the query with the appropriate "pq" datastore. It could
// also be the "actual" datastore in which case this is a no-op.
//if (_pq_dist)
//{
_pq_data_store->preprocess_query(aligned_query, scratch);
//}
_pq_data_store->preprocess_query(aligned_query, scratch);

if (expanded_nodes.size() > 0 || id_scratch.size() > 0)
{
Expand Down Expand Up @@ -798,7 +791,6 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(

// Lambda to batch compute query<-> node distances in PQ space
auto compute_dists = [this, scratch, pq_dists](const std::vector<uint32_t> &ids, std::vector<float> &dists_out) {
// REFACTOR
_pq_data_store->get_distance(scratch->aligned_query(), ids, dists_out, scratch);
};

Expand Down Expand Up @@ -830,20 +822,11 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
}

float distance;
//if (_pq_dist)
//{
// REFACTOR PQ. We pay a small price in efficiency for better code structure.
uint32_t ids[] = {id};
float distances[] = {std::numeric_limits<float>::max()};
_pq_data_store->get_distance(aligned_query, ids, 1, distances, scratch);
distance = distances[0];
// pq_dist_lookup(pq_coord_scratch, 1, this->_num_pq_chunks, pq_dists,
// &distance);
//}
//else
//{
// distance = _data_store->get_distance(aligned_query, id);
//}
uint32_t ids[] = {id};
float distances[] = {std::numeric_limits<float>::max()};
_pq_data_store->get_distance(aligned_query, ids, 1, distances, scratch);
distance = distances[0];

Neighbor nn = Neighbor(id, distance);
best_L_nodes.insert(nn);
}
Expand Down Expand Up @@ -915,29 +898,8 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
}
}

// Compute distances to unvisited nodes in the expansion
// REFACTOR PQ
//if (_pq_dist)
//{
assert(dist_scratch.capacity() >= id_scratch.size());
compute_dists(id_scratch, dist_scratch);
//}
//else
//{
// assert(dist_scratch.size() == 0);
// for (size_t m = 0; m < id_scratch.size(); ++m)
// {
// uint32_t id = id_scratch[m];

// if (m + 1 < id_scratch.size())
// {
// auto nextn = id_scratch[m + 1];
// _data_store->prefetch_vector(nextn);
// }

// dist_scratch.push_back(_data_store->get_distance(aligned_query, id));
// }
//}
assert(dist_scratch.capacity() >= id_scratch.size());
compute_dists(id_scratch, dist_scratch);
cmps += (uint32_t)id_scratch.size();

// Insert <id, dist> pairs into the pool of candidates
Expand Down Expand Up @@ -1586,8 +1548,6 @@ void Index<T, TagT, LabelT>::build(const char *filename, const size_t num_points
<< " points, but "
<< "index can support only " << _max_points << " points as specified in constructor." << std::endl;

// REFACTOR PQDataStore will take care of its memory
// if (_pq_dist) aligned_free(_pq_data);
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
}

Expand All @@ -1597,8 +1557,6 @@ void Index<T, TagT, LabelT>::build(const char *filename, const size_t num_points
stream << "ERROR: Driver requests loading " << num_points_to_load << " points and file has only "
<< file_num_points << " points." << std::endl;

// REFACTOR: PQDataStore will take care of its memory
// if (_pq_dist) aligned_free(_pq_data);
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
}

Expand All @@ -1609,16 +1567,13 @@ void Index<T, TagT, LabelT>::build(const char *filename, const size_t num_points
<< "but file has " << file_dim << " dimension." << std::endl;
diskann::cerr << stream.str() << std::endl;

// REFACTOR: PQDataStore will take care of its memory
// if (_pq_dist) aligned_free(_pq_data);
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
}

//REFACTOR PQ TODO: We can remove this if and add a check in the InMemDataStore
//to not populate_data if it has been called once.
if (_pq_dist)
{
// REFACTOR
#ifdef EXEC_ENV_OLS
std::stringstream ss;
ss << "PQ Build is not supported in DLVS environment (i.e. if EXEC_ENV_OLS is defined)" << std::endl;
Expand Down Expand Up @@ -3023,12 +2978,6 @@ template <typename T, typename TagT, typename LabelT> void Index<T, TagT, LabelT
delete[] bfs_sets;
}

// REFACTOR: This should be an OptimizedDataStore class, dummy impl here for
// compiling sake template <typename T, typename TagT, typename LabelT> void
// Index<T, TagT, LabelT>::optimize_index_layout()
//{ // use after build or load
//}

// REFACTOR: This should be an OptimizedDataStore class
template <typename T, typename TagT, typename LabelT> void Index<T, TagT, LabelT>::optimize_index_layout()
{ // use after build or load
Expand Down Expand Up @@ -3065,14 +3014,6 @@ template <typename T, typename TagT, typename LabelT> void Index<T, TagT, LabelT
delete[] cur_vec;
}

// REFACTOR: once optimized layout becomes its own Data+Graph store, we should
// just invoke regular search
// template <typename T, typename TagT, typename LabelT>
// void Index<T, TagT, LabelT>::search_with_optimized_layout(const T *query,
// size_t K, size_t L, uint32_t *indices)
//{
//}

template <typename T, typename TagT, typename LabelT>
void Index<T, TagT, LabelT>::_search_with_optimized_layout(const DataType &query, size_t K, size_t L, uint32_t *indices)
{
Expand Down
8 changes: 4 additions & 4 deletions src/index_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,16 +94,16 @@ std::shared_ptr<PQDataStore<T>> IndexFactory::construct_pq_datastore(DataStoreSt
size_t dimension, Metric m, size_t num_pq_chunks,
bool use_opq)
{
std::shared_ptr<Distance<T>> distance_fn;
std::shared_ptr<QuantizedDistance<T>> quantized_distance_fn;
std::unique_ptr<Distance<T>> distance_fn;
std::unique_ptr<QuantizedDistance<T>> quantized_distance_fn;

quantized_distance_fn = std::make_shared<PQL2Distance<T>>((uint32_t)num_pq_chunks, use_opq);
quantized_distance_fn = std::move(std::make_unique<PQL2Distance<T>>((uint32_t)num_pq_chunks, use_opq));
switch (strategy)
{
case DataStoreStrategy::MEMORY:
distance_fn.reset(construct_inmem_distance_fn<T>(m));
return std::make_shared<diskann::PQDataStore<T>>(dimension, (location_t)(num_points), num_pq_chunks,
distance_fn, quantized_distance_fn);
std::move(distance_fn), std::move(quantized_distance_fn));
default:
// REFACTOR TODO: We do support diskPQ - so we may need to add a new class for SSDPQDataStore!
break;
Expand Down
8 changes: 5 additions & 3 deletions src/pq_data_store.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,16 @@ namespace diskann
// this is true.
template <typename data_t>
PQDataStore<data_t>::PQDataStore(size_t dim, location_t num_points, size_t num_pq_chunks,
std::shared_ptr<Distance<data_t>> distance_fn,
std::shared_ptr<QuantizedDistance<data_t>> pq_distance_fn)
std::unique_ptr<Distance<data_t>> distance_fn,
std::unique_ptr<QuantizedDistance<data_t>> pq_distance_fn)
: AbstractDataStore<data_t>(num_points, dim), _quantized_data(nullptr), _num_chunks(num_pq_chunks),
_distance_metric(distance_fn->get_metric()), _distance_fn(distance_fn), _pq_distance_fn(pq_distance_fn)
_distance_metric(distance_fn->get_metric())
{
if (num_pq_chunks > dim) {
throw diskann::ANNException("ERROR: num_pq_chunks > dim", -1, __FUNCSIG__, __FILE__, __LINE__);
}
_distance_fn = std::move(distance_fn);
_pq_distance_fn = std::move(pq_distance_fn);
}

template <typename data_t> PQDataStore<data_t>::~PQDataStore()
Expand Down

0 comments on commit 604bd92

Please sign in to comment.