Skip to content

Commit

Permalink
No recall diff.
Browse files Browse the repository at this point in the history
  • Loading branch information
REDMOND\ninchen committed Apr 22, 2024
1 parent ab59517 commit 234aefc
Showing 1 changed file with 39 additions and 15 deletions.
54 changes: 39 additions & 15 deletions src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,21 @@ Index<T, TagT, LabelT>::Index(const IndexConfig &index_config,
throw ANNException("ERROR: Dynamic Indexing must have tags enabled.", -1, __FUNCSIG__, __FILE__, __LINE__);
}

_data_store = data_store;
_graph_store = std::move(graph_store);

if (_pq_dist)
{
if (_dynamic_index)
throw ANNException("ERROR: Dynamic Indexing not supported with PQ distance based "
"index construction",
-1, __FUNCSIG__, __FILE__, __LINE__);
if (_dist_metric == diskann::Metric::INNER_PRODUCT)
throw ANNException("ERROR: Inner product metrics not yet supported "
"with PQ distance "
"base index",
-1, __FUNCSIG__, __FILE__, __LINE__);
_pq_data_store = pq_data_store;
}
else
{
_pq_data_store = _data_store;
}

if (_dynamic_index && _num_frozen_pts == 0)
Expand All @@ -82,11 +86,6 @@ Index<T, TagT, LabelT>::Index(const IndexConfig &index_config,
const size_t total_internal_points = _max_points + _num_frozen_pts;

_start = (uint32_t)_max_points;

_data_store = data_store;
_pq_data_store = pq_data_store;
_graph_store = std::move(graph_store);

_locks = std::vector<non_recursive_mutex>(total_internal_points);
if (_enable_tags)
{
Expand Down Expand Up @@ -158,6 +157,9 @@ Index<T, TagT, LabelT>::Index(Metric m, const size_t dim, const size_t max_point
.with_data_type(diskann_type_to_name<T>())
.with_pq_codebook_path(codebook_path)
.build(),
#ifdef EXEC_ENV_OLS
files,
#endif
IndexFactory::construct_datastore<T>(
DataStoreStrategy::MEMORY,
max_points + (dynamic_index && num_frozen_pts == 0 ? (size_t)1 : num_frozen_pts), dim, m),
Expand Down Expand Up @@ -259,6 +261,8 @@ template <typename T, typename TagT, typename LabelT> size_t Index<T, TagT, Labe
// Note: at this point, either _nd == _max_points or any frozen points have
// been temporarily moved to _nd, so _nd + _num_frozen_pts is the valid
// location limit.

// ninchen: Save pq vector as well?
return _data_store->save(data_file, (location_t)(_nd + _num_frozen_pts));
}

Expand Down Expand Up @@ -841,9 +845,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
assert(id_scratch.size() == 0);

T *aligned_query = scratch->aligned_query();

float *pq_dists = nullptr;

_pq_data_store->preprocess_query(aligned_query, scratch);

if (expanded_nodes.size() > 0 || id_scratch.size() > 0)
Expand Down Expand Up @@ -2202,10 +2204,8 @@ size_t Index<T, TagT, LabelT>::search_with_tags(const T *query, const uint64_t K
std::shared_lock<std::shared_timed_mutex> ul(_update_lock);

const std::vector<uint32_t> init_ids = get_init_ids();
_pq_data_store->preprocess_query(query, scratch);

//_distance->preprocess_query(query, _data_store->get_dims(),
// scratch->aligned_query());
_data_store->preprocess_query(query, scratch);
if (!use_filters)
{
const std::vector<LabelT> unused_filter_label;
Expand Down Expand Up @@ -2236,7 +2236,14 @@ size_t Index<T, TagT, LabelT>::search_with_tags(const T *query, const uint64_t K

if (res_vectors.size() > 0)
{
_data_store->get_vector(node.id, res_vectors[pos]);
if (_pq_dist)
{
_pq_data_store->get_vector(node.id, res_vectors[pos]);
}
else
{
_data_store->get_vector(node.id, res_vectors[pos]);
}
}

if (distances != nullptr)
Expand Down Expand Up @@ -2830,6 +2837,12 @@ template <typename T, typename TagT, typename LabelT> void Index<T, TagT, LabelT
assert(_empty_slots.size() == 0); // should not resize if there are empty slots.

_data_store->resize((location_t)new_internal_points);

if (_pq_dist)
{
_pq_data_store->resize((location_t)new_internal_points);
}

_graph_store->resize_graph(new_internal_points);
_locks = std::vector<non_recursive_mutex>(new_internal_points);

Expand Down Expand Up @@ -2940,6 +2953,12 @@ int Index<T, TagT, LabelT>::insert_point(const T *point, const TagT tag, const s
_label_to_start_id[label] = (uint32_t)fz_location;
_location_to_labels[fz_location] = {label};
_data_store->set_vector((location_t)fz_location, point);

if (_pq_dist)
{
_pq_data_store->set_vector((location_t)fz_location, point);
}

_frozen_pts_used++;
}
}
Expand Down Expand Up @@ -3002,6 +3021,11 @@ int Index<T, TagT, LabelT>::insert_point(const T *point, const TagT tag, const s

_data_store->set_vector(location, point); // update datastore

if (_pq_dist)
{
_pq_data_store->set_vector(location, point); // Update PQDataStore
}

// Find and add appropriate graph edges
ScratchStoreManager<InMemQueryScratch<T>> manager(_query_scratch);
auto scratch = manager.scratch_space();
Expand Down

0 comments on commit 234aefc

Please sign in to comment.