Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
REDMOND\ninchen committed Apr 22, 2024
1 parent 234aefc commit fd98477
Show file tree
Hide file tree
Showing 4 changed files with 212 additions and 39 deletions.
35 changes: 26 additions & 9 deletions apps/search_memory_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ namespace po = boost::program_options;

template <typename T, typename LabelT = uint32_t>
int search_memory_index(diskann::Metric &metric, const std::string &index_path, const std::string &result_path_prefix,
const std::string &query_file, const std::string &truthset_file, const uint32_t num_threads,
const std::string &query_file, const std::string &truthset_file,
const std::string &codebook_file, const bool use_pq_build, const bool use_opq,
const uint32_t pq_num_chunks, const uint32_t num_threads,
const uint32_t recall_at, const bool print_all_recalls, const std::vector<uint32_t> &Lvec,
const bool dynamic, const bool tags, const bool show_qps_per_thread,
const std::vector<std::string> &query_filters, const float fail_if_recall_below)
Expand Down Expand Up @@ -82,10 +84,11 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path,
.is_dynamic_index(dynamic)
.is_enable_tags(tags)
.is_concurrent_consolidate(false)
.is_pq_dist_build(false)
.is_use_opq(false)
.with_num_pq_chunks(0)
.is_pq_dist_build(use_pq_build)
// OPQ is controlled by its own flag (the use_opq parameter, set from --use_opq);
// use_pq_build only reflects whether build_PQ_bytes > 0.
.is_use_opq(use_opq)
.with_num_pq_chunks(pq_num_chunks)
.with_num_frozen_pts(num_frozen_pts)
.with_pq_codebook_path(codebook_file)
.build();

auto index_factory = diskann::IndexFactory(config);
Expand Down Expand Up @@ -278,10 +281,10 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path,
int main(int argc, char **argv)
{
std::string data_type, dist_fn, index_path_prefix, result_path, query_file, gt_file, filter_label, label_type,
query_filters_file;
uint32_t num_threads, K;
query_filters_file, codebook_path;
uint32_t num_threads, K, build_PQ_bytes;
std::vector<uint32_t> Lvec;
bool print_all_recalls, dynamic, tags, show_qps_per_thread;
bool print_all_recalls, dynamic, tags, show_qps_per_thread, use_pq_build, use_opq;
float fail_if_recall_below = 0.0f;

po::options_description desc{
Expand Down Expand Up @@ -331,6 +334,12 @@ int main(int argc, char **argv)
optional_configs.add_options()("fail_if_recall_below",
po::value<float>(&fail_if_recall_below)->default_value(0.0f),
program_options_utils::FAIL_IF_RECALL_BELOW);
optional_configs.add_options()("build_PQ_bytes", po::value<uint32_t>(&build_PQ_bytes)->default_value(0),
program_options_utils::BUIlD_GRAPH_PQ_BYTES);
optional_configs.add_options()("codebook_path", po::value<std::string>(&codebook_path)->default_value(""),
program_options_utils::CODEBOOK_PATH);
optional_configs.add_options()("use_opq", po::bool_switch()->default_value(false),
program_options_utils::USE_OPQ);

// Output controls
po::options_description output_controls("Output controls");
Expand All @@ -352,6 +361,8 @@ int main(int argc, char **argv)
return 0;
}
po::notify(vm);
use_pq_build = (build_PQ_bytes > 0);
use_opq = vm["use_opq"].as<bool>();
}
catch (const std::exception &ex)
{
Expand Down Expand Up @@ -420,18 +431,21 @@ int main(int argc, char **argv)
if (data_type == std::string("int8"))
{
return search_memory_index<int8_t, uint16_t>(
metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls,
metric, index_path_prefix, result_path, query_file, gt_file, codebook_path, use_pq_build, use_opq,
build_PQ_bytes, num_threads, K, print_all_recalls,
Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below);
}
else if (data_type == std::string("uint8"))
{
return search_memory_index<uint8_t, uint16_t>(
metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls,
metric, index_path_prefix, result_path, query_file, gt_file, codebook_path, use_pq_build, use_opq,
build_PQ_bytes, num_threads, K, print_all_recalls,
Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below);
}
else if (data_type == std::string("float"))
{
return search_memory_index<float, uint16_t>(metric, index_path_prefix, result_path, query_file, gt_file,
codebook_path, use_pq_build, use_opq, build_PQ_bytes,
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
show_qps_per_thread, query_filters, fail_if_recall_below);
}
Expand All @@ -446,18 +460,21 @@ int main(int argc, char **argv)
if (data_type == std::string("int8"))
{
return search_memory_index<int8_t>(metric, index_path_prefix, result_path, query_file, gt_file,
codebook_path, use_pq_build, use_opq, build_PQ_bytes,
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
show_qps_per_thread, query_filters, fail_if_recall_below);
}
else if (data_type == std::string("uint8"))
{
return search_memory_index<uint8_t>(metric, index_path_prefix, result_path, query_file, gt_file,
codebook_path, use_pq_build, use_opq, build_PQ_bytes,
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
show_qps_per_thread, query_filters, fail_if_recall_below);
}
else if (data_type == std::string("float"))
{
return search_memory_index<float>(metric, index_path_prefix, result_path, query_file, gt_file,
codebook_path, use_pq_build, use_opq, build_PQ_bytes,
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
show_qps_per_thread, query_filters, fail_if_recall_below);
}
Expand Down
27 changes: 20 additions & 7 deletions apps/test_insert_deletes_consolidate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ void build_incremental_index(const std::string &data_path, diskann::IndexWritePa
uint32_t num_start_pts, size_t points_per_checkpoint, size_t checkpoints_per_snapshot,
const std::string &save_path, size_t points_to_delete_from_beginning,
size_t start_deletes_after, bool concurrent, const std::string &label_file,
const std::string &universal_label)
const std::string &universal_label, size_t num_pq_chunks, const std::string& pq_pivot_file)
{
size_t dim, aligned_dim;
size_t num_points;
Expand All @@ -161,7 +161,7 @@ void build_incremental_index(const std::string &data_path, diskann::IndexWritePa
using LabelT = uint32_t;

size_t current_point_offset = points_to_skip;
const size_t last_point_threshold = points_to_skip + max_points_to_insert;
size_t last_point_threshold = points_to_skip + max_points_to_insert;

bool enable_tags = true;
using TagT = uint32_t;
Expand All @@ -182,6 +182,9 @@ void build_incremental_index(const std::string &data_path, diskann::IndexWritePa
.is_filtered(has_labels)
.with_num_frozen_pts(num_start_pts)
.is_concurrent_consolidate(concurrent)
.with_pq_codebook_path(pq_pivot_file)
.is_pq_dist_build(!pq_pivot_file.empty())
.with_num_pq_chunks(num_pq_chunks)
.build();

diskann::IndexFactory index_factory = diskann::IndexFactory(index_config);
Expand All @@ -206,6 +209,7 @@ void build_incremental_index(const std::string &data_path, diskann::IndexWritePa
if (points_to_skip + max_points_to_insert > num_points)
{
max_points_to_insert = num_points - points_to_skip;
last_point_threshold = num_points;
std::cerr << "WARNING: Reducing max_points_to_insert to " << max_points_to_insert
<< " points since the data file has only that many" << std::endl;
}
Expand Down Expand Up @@ -327,6 +331,7 @@ void build_incremental_index(const std::string &data_path, diskann::IndexWritePa
start += points_per_checkpoint, current_point_offset += points_per_checkpoint)
{
const size_t end = std::min(start + points_per_checkpoint, last_point_threshold);
std::cout << std::endl << "Last Point Threshold is:" << last_point_threshold << std::endl;
std::cout << std::endl << "Inserting from " << start << " to " << end << std::endl;

load_aligned_bin_part(data_path, data, start, end - start);
Expand Down Expand Up @@ -377,11 +382,11 @@ int main(int argc, char **argv)
uint32_t num_threads, R, L, num_start_pts;
float alpha, start_point_norm;
size_t points_to_skip, max_points_to_insert, beginning_index_size, points_per_checkpoint, checkpoints_per_snapshot,
points_to_delete_from_beginning, start_deletes_after;
points_to_delete_from_beginning, start_deletes_after, num_pq_chunks;
bool concurrent;

// label options
std::string label_file, label_type, universal_label;
std::string label_file, label_type, universal_label, pq_pivot_file;
std::uint32_t Lf, unique_labels_supported;

po::options_description desc{program_options_utils::make_program_description("test_insert_deletes_consolidate",
Expand Down Expand Up @@ -449,6 +454,11 @@ int main(int argc, char **argv)
optional_configs.add_options()("unique_labels_supported",
po::value<uint32_t>(&unique_labels_supported)->default_value(0),
"Number of unique labels supported by the dynamic index.");
optional_configs.add_options()("pq_pivot_file", po::value<std::string>(&pq_pivot_file)->default_value(""),
"The file stored pq pivot info.");
optional_configs.add_options()("num_pq_chunks", po::value<uint64_t>(&num_pq_chunks)->default_value(0),
"Number of PQ chunks to use.");


optional_configs.add_options()(
"num_start_points",
Expand Down Expand Up @@ -507,17 +517,20 @@ int main(int argc, char **argv)
build_incremental_index<int8_t>(
data_path, params, points_to_skip, max_points_to_insert, beginning_index_size, start_point_norm,
num_start_pts, points_per_checkpoint, checkpoints_per_snapshot, index_path_prefix,
points_to_delete_from_beginning, start_deletes_after, concurrent, label_file, universal_label);
points_to_delete_from_beginning, start_deletes_after, concurrent,
label_file, universal_label, num_pq_chunks, pq_pivot_file);
else if (data_type == std::string("uint8"))
build_incremental_index<uint8_t>(
data_path, params, points_to_skip, max_points_to_insert, beginning_index_size, start_point_norm,
num_start_pts, points_per_checkpoint, checkpoints_per_snapshot, index_path_prefix,
points_to_delete_from_beginning, start_deletes_after, concurrent, label_file, universal_label);
points_to_delete_from_beginning, start_deletes_after, concurrent,
label_file, universal_label, num_pq_chunks, pq_pivot_file);
else if (data_type == std::string("float"))
build_incremental_index<float>(data_path, params, points_to_skip, max_points_to_insert,
beginning_index_size, start_point_norm, num_start_pts, points_per_checkpoint,
checkpoints_per_snapshot, index_path_prefix, points_to_delete_from_beginning,
start_deletes_after, concurrent, label_file, universal_label);
start_deletes_after, concurrent, label_file, universal_label,
num_pq_chunks, pq_pivot_file);
else
std::cout << "Unsupported type. Use float/int8/uint8" << std::endl;
}
Expand Down
11 changes: 9 additions & 2 deletions include/pq_data_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#include <memory>
#include "distance.h"
#include "quantized_distance.h"
#include "pq.h"
#include "fixed_chunk_pq_table.h"
#include "abstract_data_store.h"

namespace diskann
Expand All @@ -13,8 +13,14 @@ template <typename data_t> class PQDataStore : public AbstractDataStore<data_t>
{

public:
#ifdef EXEC_ENV_OLS
PQDataStore(size_t dim, location_t num_points, size_t num_pq_chunks, std::unique_ptr<Distance<data_t>> distance_fn,
std::unique_ptr<QuantizedDistance<data_t>> pq_distance_fn, MemoryMappedFiles &files,
const std::string &codebook_path);
#else
PQDataStore(size_t dim, location_t num_points, size_t num_pq_chunks, std::unique_ptr<Distance<data_t>> distance_fn,
std::unique_ptr<QuantizedDistance<data_t>> pq_distance_fn, const std::string& codebook_path);
std::unique_ptr<QuantizedDistance<data_t>> pq_distance_fn, const std::string &codebook_path);
#endif
PQDataStore(const PQDataStore &) = delete;
PQDataStore &operator=(const PQDataStore &) = delete;
~PQDataStore();
Expand Down Expand Up @@ -85,6 +91,7 @@ template <typename data_t> class PQDataStore : public AbstractDataStore<data_t>
private:
uint8_t *_quantized_data = nullptr;
size_t _num_chunks = 0;
size_t _aligned_dim;

// REFACTOR TODO: Doing this temporarily before refactoring OPQ into
// its own class. Remove later.
Expand Down
Loading

0 comments on commit fd98477

Please sign in to comment.