diff --git a/apps/build_disk_index.cpp b/apps/build_disk_index.cpp index f48b61726..41a885993 100644 --- a/apps/build_disk_index.cpp +++ b/apps/build_disk_index.cpp @@ -1,15 +1,15 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include #include +#include -#include "utils.h" #include "disk_utils.h" -#include "math_utils.h" #include "index.h" +#include "math_utils.h" #include "partition.h" #include "program_options_utils.hpp" +#include "utils.h" namespace po = boost::program_options; diff --git a/apps/build_memory_index.cpp b/apps/build_memory_index.cpp index 544e42dee..f0d469f4d 100644 --- a/apps/build_memory_index.cpp +++ b/apps/build_memory_index.cpp @@ -1,13 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include -#include #include +#include +#include #include "index.h" -#include "utils.h" #include "program_options_utils.hpp" +#include "utils.h" #ifndef _WINDOWS #include @@ -16,9 +16,9 @@ #include #endif -#include "memory_mapper.h" #include "ann_exception.h" #include "index_factory.h" +#include "memory_mapper.h" namespace po = boost::program_options; diff --git a/apps/build_stitched_index.cpp b/apps/build_stitched_index.cpp index 60e38c1be..0f385cb88 100644 --- a/apps/build_stitched_index.cpp +++ b/apps/build_stitched_index.cpp @@ -1,15 +1,15 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. +#include "filter_utils.h" #include #include #include #include +#include #include #include #include -#include "filter_utils.h" -#include #ifndef _WINDOWS #include #endif @@ -17,8 +17,8 @@ #include "index.h" #include "memory_mapper.h" #include "parameters.h" -#include "utils.h" #include "program_options_utils.hpp" +#include "utils.h" namespace po = boost::program_options; typedef std::tuple>, uint64_t> stitch_indices_return_values; diff --git a/apps/range_search_disk_index.cpp b/apps/range_search_disk_index.cpp index 31675724b..3975298ae 100644 --- a/apps/range_search_disk_index.cpp +++ b/apps/range_search_disk_index.cpp @@ -2,26 +2,26 @@ // Licensed under the MIT license. #include +#include #include #include #include #include -#include -#include "index.h" #include "disk_utils.h" +#include "index.h" #include "math_utils.h" #include "memory_mapper.h" -#include "pq_flash_index.h" #include "partition.h" -#include "timer.h" +#include "pq_flash_index.h" #include "program_options_utils.hpp" +#include "timer.h" #ifndef _WINDOWS +#include "linux_aligned_file_reader.h" #include #include #include -#include "linux_aligned_file_reader.h" #else #ifdef USE_BING_INFRA #include "bing_aligned_file_reader.h" diff --git a/apps/search_disk_index.cpp b/apps/search_disk_index.cpp index 7e2a7ac6d..925f31775 100644 --- a/apps/search_disk_index.cpp +++ b/apps/search_disk_index.cpp @@ -4,21 +4,21 @@ #include "common_includes.h" #include -#include "index.h" #include "disk_utils.h" +#include "index.h" #include "math_utils.h" #include "memory_mapper.h" #include "partition.h" -#include "pq_flash_index.h" -#include "timer.h" #include "percentile_stats.h" +#include "pq_flash_index.h" #include "program_options_utils.hpp" +#include "timer.h" #ifndef _WINDOWS +#include "linux_aligned_file_reader.h" #include #include #include -#include "linux_aligned_file_reader.h" #else #ifdef USE_BING_INFRA #include "bing_aligned_file_reader.h" @@ -123,8 +123,8 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre diskann::cout << "Caching " << num_nodes_to_cache << " nodes around medoid(s)" << std::endl; _pFlashIndex->cache_bfs_levels(num_nodes_to_cache, node_list); // if (num_nodes_to_cache > 0) - // _pFlashIndex->generate_cache_list_from_sample_queries(warmup_query_file, 15, 6, num_nodes_to_cache, - // num_threads, node_list); + // _pFlashIndex->generate_cache_list_from_sample_queries(warmup_query_file, + // 15, 6, num_nodes_to_cache, num_threads, node_list); _pFlashIndex->load_cache_list(node_list); node_list.clear(); node_list.shrink_to_fit(); diff --git a/apps/search_memory_index.cpp b/apps/search_memory_index.cpp index 1a9acc285..9126ad1fc 100644 --- a/apps/search_memory_index.cpp +++ b/apps/search_memory_index.cpp @@ -1,14 +1,14 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. +#include +#include #include #include -#include #include #include #include #include -#include #ifndef _WINDOWS #include @@ -18,10 +18,10 @@ #endif #include "index.h" +#include "index_factory.h" #include "memory_mapper.h" -#include "utils.h" #include "program_options_utils.hpp" -#include "index_factory.h" +#include "utils.h" namespace po = boost::program_options; @@ -323,9 +323,9 @@ int main(int argc, char **argv) optional_configs.add_options()("num_threads,T", po::value(&num_threads)->default_value(omp_get_num_procs()), program_options_utils::NUMBER_THREADS_DESCRIPTION); - optional_configs.add_options()( - "dynamic", po::value(&dynamic)->default_value(false), - "Whether the index is dynamic. Dynamic indices must have associated tags. Default false."); + optional_configs.add_options()("dynamic", po::value(&dynamic)->default_value(false), + "Whether the index is dynamic. Dynamic indices must have associated " + "tags. Default false."); optional_configs.add_options()("tags", po::value(&tags)->default_value(false), "Whether to search with external identifiers (tags). Default false."); optional_configs.add_options()("fail_if_recall_below", diff --git a/apps/test_insert_deletes_consolidate.cpp b/apps/test_insert_deletes_consolidate.cpp index 97aed1864..21ce4250f 100644 --- a/apps/test_insert_deletes_consolidate.cpp +++ b/apps/test_insert_deletes_consolidate.cpp @@ -1,19 +1,19 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. +#include +#include #include #include #include #include #include #include -#include -#include -#include "utils.h" #include "filter_utils.h" -#include "program_options_utils.hpp" #include "index_factory.h" +#include "program_options_utils.hpp" +#include "utils.h" #ifndef _WINDOWS #include diff --git a/apps/test_streaming_scenario.cpp b/apps/test_streaming_scenario.cpp index 5a43a69f3..d8ea0577c 100644 --- a/apps/test_streaming_scenario.cpp +++ b/apps/test_streaming_scenario.cpp @@ -1,20 +1,20 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. +#include +#include +#include #include +#include #include #include #include #include #include -#include -#include -#include -#include -#include "utils.h" #include "filter_utils.h" #include "program_options_utils.hpp" +#include "utils.h" #ifndef _WINDOWS #include diff --git a/apps/utils/bin_to_fvecs.cpp b/apps/utils/bin_to_fvecs.cpp index e9a6a8ecc..ebd8229ba 100644 --- a/apps/utils/bin_to_fvecs.cpp +++ b/apps/utils/bin_to_fvecs.cpp @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include #include "util.h" +#include void block_convert(std::ifstream &writr, std::ofstream &readr, float *read_buf, float *write_buf, uint64_t npts, uint64_t ndims) diff --git a/apps/utils/bin_to_tsv.cpp b/apps/utils/bin_to_tsv.cpp index 7851bef6d..5c31c8595 100644 --- a/apps/utils/bin_to_tsv.cpp +++ b/apps/utils/bin_to_tsv.cpp @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include #include "utils.h" +#include template void block_convert(std::ofstream &writer, std::ifstream &reader, T *read_buf, size_t npts, size_t ndims) diff --git a/apps/utils/calculate_recall.cpp b/apps/utils/calculate_recall.cpp index dc76252cc..3946bfdf2 100644 --- a/apps/utils/calculate_recall.cpp +++ b/apps/utils/calculate_recall.cpp @@ -9,8 +9,8 @@ #include #include -#include "utils.h" #include "disk_utils.h" +#include "utils.h" int main(int argc, char **argv) { diff --git a/apps/utils/compute_groundtruth.cpp b/apps/utils/compute_groundtruth.cpp index da32fd7c6..b86f28289 100644 --- a/apps/utils/compute_groundtruth.cpp +++ b/apps/utils/compute_groundtruth.cpp @@ -1,25 +1,25 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include -#include -#include #include +#include +#include +#include -#include #include +#include #include #include -#include -#include #include -#include -#include +#include #include -#include -#include +#include +#include +#include #include #include +#include +#include #ifdef _WINDOWS #include diff --git a/apps/utils/compute_groundtruth_for_filters.cpp b/apps/utils/compute_groundtruth_for_filters.cpp index 52e586475..e90da2444 100644 --- a/apps/utils/compute_groundtruth_for_filters.cpp +++ b/apps/utils/compute_groundtruth_for_filters.cpp @@ -1,25 +1,25 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include -#include -#include #include +#include +#include +#include -#include #include +#include #include #include -#include -#include #include -#include -#include +#include #include -#include -#include +#include +#include +#include #include #include +#include +#include #ifdef _WINDOWS #include diff --git a/apps/utils/count_bfs_levels.cpp b/apps/utils/count_bfs_levels.cpp index 6dd2d6233..6e45ef13d 100644 --- a/apps/utils/count_bfs_levels.cpp +++ b/apps/utils/count_bfs_levels.cpp @@ -1,14 +1,14 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. +#include +#include #include #include -#include #include #include #include #include -#include #ifndef _WINDOWS #include @@ -17,9 +17,9 @@ #include #endif -#include "utils.h" #include "index.h" #include "memory_mapper.h" +#include "utils.h" namespace po = boost::program_options; diff --git a/apps/utils/create_disk_layout.cpp b/apps/utils/create_disk_layout.cpp index f494c1227..6d5314fb4 100644 --- a/apps/utils/create_disk_layout.cpp +++ b/apps/utils/create_disk_layout.cpp @@ -8,9 +8,9 @@ #include #include -#include "utils.h" -#include "disk_utils.h" #include "cached_io.h" +#include "disk_utils.h" +#include "utils.h" template int create_disk_layout(char **argv) { diff --git a/apps/utils/float_bin_to_int8.cpp b/apps/utils/float_bin_to_int8.cpp index 1982005af..c3fa8f8ec 100644 --- a/apps/utils/float_bin_to_int8.cpp +++ b/apps/utils/float_bin_to_int8.cpp @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include #include "utils.h" +#include void block_convert(std::ofstream &writer, int8_t *write_buf, std::ifstream &reader, float *read_buf, size_t npts, size_t ndims, float bias, float scale) diff --git a/apps/utils/fvecs_to_bin.cpp b/apps/utils/fvecs_to_bin.cpp index 873ad3b0c..1428a9c6e 100644 --- a/apps/utils/fvecs_to_bin.cpp +++ b/apps/utils/fvecs_to_bin.cpp @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include #include "utils.h" +#include // Convert float types void block_convert_float(std::ifstream &reader, std::ofstream &writer, float *read_buf, float *write_buf, size_t npts, diff --git a/apps/utils/fvecs_to_bvecs.cpp b/apps/utils/fvecs_to_bvecs.cpp index f9c2aa71b..60ac12126 100644 --- a/apps/utils/fvecs_to_bvecs.cpp +++ b/apps/utils/fvecs_to_bvecs.cpp @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include #include "utils.h" +#include void block_convert(std::ifstream &reader, std::ofstream &writer, float *read_buf, uint8_t *write_buf, size_t npts, size_t ndims) diff --git a/apps/utils/gen_random_slice.cpp b/apps/utils/gen_random_slice.cpp index a4cd96e0a..64bc994ef 100644 --- a/apps/utils/gen_random_slice.cpp +++ b/apps/utils/gen_random_slice.cpp @@ -1,7 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include +#include "partition.h" +#include "utils.h" #include #include #include @@ -10,10 +11,9 @@ #include #include #include +#include #include #include -#include "partition.h" -#include "utils.h" #include #include diff --git a/apps/utils/generate_pq.cpp b/apps/utils/generate_pq.cpp index a881b1104..cff7a3526 100644 --- a/apps/utils/generate_pq.cpp +++ b/apps/utils/generate_pq.cpp @@ -2,8 +2,8 @@ // Licensed under the MIT license. #include "math_utils.h" -#include "pq.h" #include "partition.h" +#include "pq.h" #define KMEANS_ITERS_FOR_PQ 15 diff --git a/apps/utils/generate_synthetic_labels.cpp b/apps/utils/generate_synthetic_labels.cpp index 6741760cb..766c297d7 100644 --- a/apps/utils/generate_synthetic_labels.cpp +++ b/apps/utils/generate_synthetic_labels.cpp @@ -1,12 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include -#include +#include "utils.h" #include -#include #include -#include "utils.h" +#include +#include +#include namespace po = boost::program_options; class ZipfDistribution diff --git a/apps/utils/int8_to_float.cpp b/apps/utils/int8_to_float.cpp index dcdfddc0d..8277b9a09 100644 --- a/apps/utils/int8_to_float.cpp +++ b/apps/utils/int8_to_float.cpp @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include #include "utils.h" +#include int main(int argc, char **argv) { diff --git a/apps/utils/int8_to_float_scale.cpp b/apps/utils/int8_to_float_scale.cpp index 19fbc6c43..757e79be1 100644 --- a/apps/utils/int8_to_float_scale.cpp +++ b/apps/utils/int8_to_float_scale.cpp @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include #include "utils.h" +#include void block_convert(std::ofstream &writer, float *write_buf, std::ifstream &reader, int8_t *read_buf, size_t npts, size_t ndims, float bias, float scale) diff --git a/apps/utils/ivecs_to_bin.cpp b/apps/utils/ivecs_to_bin.cpp index ea8a4a3d2..854c06839 100644 --- a/apps/utils/ivecs_to_bin.cpp +++ b/apps/utils/ivecs_to_bin.cpp @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include #include "utils.h" +#include void block_convert(std::ifstream &reader, std::ofstream &writer, uint32_t *read_buf, uint32_t *write_buf, size_t npts, size_t ndims) diff --git a/apps/utils/merge_shards.cpp b/apps/utils/merge_shards.cpp index 106c15eef..be64e6ff9 100644 --- a/apps/utils/merge_shards.cpp +++ b/apps/utils/merge_shards.cpp @@ -10,8 +10,8 @@ #include #include -#include "disk_utils.h" #include "cached_io.h" +#include "disk_utils.h" #include "utils.h" int main(int argc, char **argv) diff --git a/apps/utils/partition_data.cpp b/apps/utils/partition_data.cpp index 2520f3f4a..42c22d231 100644 --- a/apps/utils/partition_data.cpp +++ b/apps/utils/partition_data.cpp @@ -1,10 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include -#include #include "cached_io.h" #include "partition.h" +#include +#include // DEPRECATED: NEED TO REPROGRAM diff --git a/apps/utils/partition_with_ram_budget.cpp b/apps/utils/partition_with_ram_budget.cpp index 937b68d2c..c5b6ed596 100644 --- a/apps/utils/partition_with_ram_budget.cpp +++ b/apps/utils/partition_with_ram_budget.cpp @@ -1,10 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include -#include #include "cached_io.h" #include "partition.h" +#include +#include // DEPRECATED: NEED TO REPROGRAM diff --git a/apps/utils/rand_data_gen.cpp b/apps/utils/rand_data_gen.cpp index e89ede800..799aa0f33 100644 --- a/apps/utils/rand_data_gen.cpp +++ b/apps/utils/rand_data_gen.cpp @@ -1,11 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include +#include +#include #include +#include #include -#include -#include #include "utils.h" @@ -128,7 +128,8 @@ int main(int argc, char **argv) desc.add_options()("norm", po::value(&norm)->default_value(-1.0f), "Norm of the vectors (if not specified, vectors are not normalized)"); desc.add_options()("rand_scaling", po::value(&rand_scaling)->default_value(1.0f), - "Each vector will be scaled (if not explicitly normalized) by a factor randomly chosen from " + "Each vector will be scaled (if not explicitly normalized) by a factor " + "randomly chosen from " "[1, rand_scale]. Only applicable for floating point data"); po::variables_map vm; po::store(po::parse_command_line(argc, argv, desc), vm); @@ -158,13 +159,17 @@ int main(int argc, char **argv) if (rand_scaling < 1.0) { - std::cout << "We will only scale the vector norms randomly in [1, value], so value must be >= 1." << std::endl; + std::cout << "We will only scale the vector norms randomly in [1, value], " + "so value must be >= 1." + << std::endl; return -1; } if ((rand_scaling > 1.0) && (normalization == true)) { - std::cout << "Data cannot be normalized and randomly scaled at same time. Use one or the other." << std::endl; + std::cout << "Data cannot be normalized and randomly scaled at same time. " + "Use one or the other." + << std::endl; return -1; } diff --git a/apps/utils/simulate_aggregate_recall.cpp b/apps/utils/simulate_aggregate_recall.cpp index 73c4ea0f7..30cb24f13 100644 --- a/apps/utils/simulate_aggregate_recall.cpp +++ b/apps/utils/simulate_aggregate_recall.cpp @@ -1,10 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include +#include #include +#include #include -#include inline float aggregate_recall(const uint32_t k_aggr, const uint32_t k, const uint32_t npart, uint32_t *count, const std::vector &recalls) diff --git a/apps/utils/stats_label_data.cpp b/apps/utils/stats_label_data.cpp index 3342672ff..1fad04b61 100644 --- a/apps/utils/stats_label_data.cpp +++ b/apps/utils/stats_label_data.cpp @@ -1,28 +1,28 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include -#include -#include -#include -#include -#include -#include -#include -#include #include +#include +#include #include +#include #include +#include +#include #include -#include +#include +#include +#include +#include +#include #include "utils.h" #ifndef _WINDOWS #include -#include #include #include +#include #else #include #endif diff --git a/apps/utils/tsv_to_bin.cpp b/apps/utils/tsv_to_bin.cpp index c590a8f73..9d52f70a2 100644 --- a/apps/utils/tsv_to_bin.cpp +++ b/apps/utils/tsv_to_bin.cpp @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include #include "utils.h" +#include void block_convert_float(std::ifstream &reader, std::ofstream &writer, size_t npts, size_t ndims) { diff --git a/apps/utils/uint32_to_uint8.cpp b/apps/utils/uint32_to_uint8.cpp index 87b6fb8ed..348dcaa20 100644 --- a/apps/utils/uint32_to_uint8.cpp +++ b/apps/utils/uint32_to_uint8.cpp @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include #include "utils.h" +#include int main(int argc, char **argv) { diff --git a/apps/utils/uint8_to_float.cpp b/apps/utils/uint8_to_float.cpp index 6415b7c92..352aea00c 100644 --- a/apps/utils/uint8_to_float.cpp +++ b/apps/utils/uint8_to_float.cpp @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include #include "utils.h" +#include int main(int argc, char **argv) { diff --git a/apps/utils/vector_analysis.cpp b/apps/utils/vector_analysis.cpp index 009df6d05..63364bc67 100644 --- a/apps/utils/vector_analysis.cpp +++ b/apps/utils/vector_analysis.cpp @@ -1,18 +1,18 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include #include #include #include #include #include +#include #include #include #include +#include #include #include -#include #include #include #include diff --git a/include/abstract_data_store.h b/include/abstract_data_store.h index 165ada696..327f2b109 100644 --- a/include/abstract_data_store.h +++ b/include/abstract_data_store.h @@ -3,12 +3,12 @@ #pragma once -#include #include +#include +#include "distance.h" #include "types.h" #include "windows_customizations.h" -#include "distance.h" namespace diskann { @@ -80,9 +80,10 @@ template class AbstractDataStore // num_points) to zero virtual void copy_vectors(const location_t from_loc, const location_t to_loc, const location_t num_points) = 0; - // With the PQ Data Store PR, we have also changed iterate_to_fixed_point to NOT take the query - // from the scratch object. Therefore every data store has to implement preprocess_query which - // at the least will be to copy the query into the scratch object. So making this pure virtual. + // With the PQ Data Store PR, we have also changed iterate_to_fixed_point to + // NOT take the query from the scratch object. Therefore every data store has + // to implement preprocess_query which at the least will be to copy the query + // into the scratch object. So making this pure virtual. virtual void preprocess_query(const data_t *aligned_query, AbstractScratch *query_scratch = nullptr) const = 0; // distance functions. @@ -99,9 +100,9 @@ template class AbstractDataStore // in the dataset virtual location_t calculate_medoid() const = 0; - // REFACTOR PQ TODO: Each data store knows about its distance function, so this is - // redundant. However, we don't have an OptmizedDataStore yet, and to preserve code - // compability, we are exposing this function. + // REFACTOR PQ TODO: Each data store knows about its distance function, so + // this is redundant. However, we don't have an OptmizedDataStore yet, and to + // preserve code compability, we are exposing this function. virtual Distance *get_dist_fn() const = 0; // search helpers diff --git a/include/abstract_filter_store.h b/include/abstract_filter_store.h new file mode 100644 index 000000000..858c6e283 --- /dev/null +++ b/include/abstract_filter_store.h @@ -0,0 +1,25 @@ +#pragma once +#include "types.h" +#include "windows_customizations.h" +#include + +namespace diskann +{ +template class AbstractFilterStore +{ + public: + DISKANN_DLLEXPORT virtual bool has_filter_support() const = 0; + + DISKANN_DLLEXPORT virtual bool point_has_label(location_t point_id, const LabelT label_id) const = 0; + + // Returns true if the index is filter-enabled and all files were loaded + // correctly. false otherwise. Note that "false" can mean that the index + // does not have filter support, or that some index files do not exist, or + // that they exist and could not be opened. + DISKANN_DLLEXPORT virtual bool load(const std::string &disk_index_file) = 0; + + DISKANN_DLLEXPORT virtual void generate_random_labels(std::vector &labels, const uint32_t num_labels, + const uint32_t nthreads) = 0; +}; + +} // namespace diskann diff --git a/include/abstract_graph_store.h b/include/abstract_graph_store.h index 115d9ed1c..69bda7fd4 100644 --- a/include/abstract_graph_store.h +++ b/include/abstract_graph_store.h @@ -3,9 +3,9 @@ #pragma once +#include "types.h" #include #include -#include "types.h" #include "neighbor_list.h" namespace diskann diff --git a/include/abstract_index.h b/include/abstract_index.h index 7c84a8ec9..552924abb 100644 --- a/include/abstract_index.h +++ b/include/abstract_index.h @@ -1,10 +1,10 @@ #pragma once #include "distance.h" +#include "index_build_params.h" +#include "index_config.h" #include "parameters.h" -#include "utils.h" #include "types.h" -#include "index_config.h" -#include "index_build_params.h" +#include "utils.h" #include namespace diskann @@ -32,8 +32,9 @@ struct consolidation_report } }; -/* A templated independent class for intercation with Index. Uses Type Erasure to add virtual implemetation of methods -that can take any type(using std::any) and Provides a clean API that can be inherited by different type of Index. +/* A templated independent class for intercation with Index. Uses Type Erasure +to add virtual implemetation of methods that can take any type(using std::any) +and Provides a clean API that can be inherited by different type of Index. */ class AbstractIndex { diff --git a/include/aligned_file_reader.h b/include/aligned_file_reader.h index f39d5da39..447b34609 100644 --- a/include/aligned_file_reader.h +++ b/include/aligned_file_reader.h @@ -5,8 +5,8 @@ #define MAX_IO_DEPTH 128 -#include #include +#include #ifndef _WINDOWS #include @@ -63,12 +63,12 @@ struct IOContext #endif -#include +#include "tsl/robin_map.h" +#include "utils.h" #include +#include #include #include -#include "tsl/robin_map.h" -#include "utils.h" // NOTE :: all 3 fields must be 512-aligned struct AlignedRead diff --git a/include/ann_exception.h b/include/ann_exception.h index 6b81373c1..a9b940573 100644 --- a/include/ann_exception.h +++ b/include/ann_exception.h @@ -2,10 +2,10 @@ // Licensed under the MIT license. #pragma once -#include +#include "windows_customizations.h" #include +#include #include -#include "windows_customizations.h" #ifndef _WINDOWS #define __FUNCSIG__ __PRETTY_FUNCTION__ diff --git a/include/any_wrappers.h b/include/any_wrappers.h index da9005cfb..f35ac947c 100644 --- a/include/any_wrappers.h +++ b/include/any_wrappers.h @@ -3,11 +3,11 @@ #pragma once -#include +#include "tsl/robin_set.h" +#include #include +#include #include -#include -#include "tsl/robin_set.h" namespace AnyWrapper { diff --git a/include/cached_io.h b/include/cached_io.h index daef2f2f7..dabe448dc 100644 --- a/include/cached_io.h +++ b/include/cached_io.h @@ -2,13 +2,14 @@ // Licensed under the MIT license. #pragma once +#include #include #include #include #include -#include "logger.h" #include "ann_exception.h" +#include "logger.h" // sequential cached reads class cached_ifstream diff --git a/include/common_includes.h b/include/common_includes.h index e1a51bdec..8703bd10b 100644 --- a/include/common_includes.h +++ b/include/common_includes.h @@ -14,14 +14,14 @@ #include #include #include -#include #include +#include #include #include #include #include #include -#include #include +#include #include #include diff --git a/include/cosine_similarity.h b/include/cosine_similarity.h index dc51f6c0a..af62eb53b 100644 --- a/include/cosine_similarity.h +++ b/include/cosine_similarity.h @@ -3,16 +3,16 @@ #pragma once -#include -#include -#include +#include #include #include #include -#include +#include #include -#include +#include #include +#include +#include #include "simd_utils.h" diff --git a/include/disk_utils.h b/include/disk_utils.h index 08f046dcd..1acb7f981 100644 --- a/include/disk_utils.h +++ b/include/disk_utils.h @@ -3,10 +3,10 @@ #pragma once #include -#include #include #include #include +#include #include #include #include diff --git a/include/distance.h b/include/distance.h index f3b1de25a..7a3ec8b26 100644 --- a/include/distance.h +++ b/include/distance.h @@ -1,5 +1,6 @@ #pragma once #include "windows_customizations.h" +#include #include namespace diskann diff --git a/include/filter_utils.h b/include/filter_utils.h index 55f7aed28..ba5b2d601 100644 --- a/include/filter_utils.h +++ b/include/filter_utils.h @@ -3,19 +3,19 @@ #pragma once #include -#include #include #include #include +#include #include #include #include #include #include -#include #include #include #include +#include #ifdef __APPLE__ #else #include diff --git a/include/in_mem_data_store.h b/include/in_mem_data_store.h index 184b58160..6174d1922 100644 --- a/include/in_mem_data_store.h +++ b/include/in_mem_data_store.h @@ -2,8 +2,8 @@ // Licensed under the MIT license. #pragma once -#include #include +#include #include "tsl/robin_map.h" #include "tsl/robin_set.h" @@ -12,10 +12,10 @@ #include "abstract_data_store.h" +#include "aligned_file_reader.h" #include "distance.h" #include "natural_number_map.h" #include "natural_number_set.h" -#include "aligned_file_reader.h" namespace diskann { diff --git a/include/in_mem_filter_store.h b/include/in_mem_filter_store.h new file mode 100644 index 000000000..4915f37ee --- /dev/null +++ b/include/in_mem_filter_store.h @@ -0,0 +1,140 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "abstract_filter_store.h" +#include "ann_exception.h" +#include "logger.h" +#include "tsl/robin_map.h" +#include "tsl/robin_set.h" +#include "windows_customizations.h" +#include +#include +#include + +namespace diskann +{ +template class InMemFilterStore : public AbstractFilterStore +{ + public: + // Do nothing constructor because all the work is done in load() + DISKANN_DLLEXPORT InMemFilterStore() + { + } + + /// + /// Destructor + /// + DISKANN_DLLEXPORT virtual ~InMemFilterStore(); + + // No copy, no assignment. + DISKANN_DLLEXPORT InMemFilterStore &operator=(const InMemFilterStore &v) = delete; + DISKANN_DLLEXPORT + InMemFilterStore(const InMemFilterStore &v) = delete; + + DISKANN_DLLEXPORT virtual bool has_filter_support() const; + + DISKANN_DLLEXPORT virtual const std::unordered_map> &get_label_to_medoids() const; + + DISKANN_DLLEXPORT virtual const std::vector &get_medoids_of_label(const LabelT label); + + DISKANN_DLLEXPORT virtual void set_universal_label(const LabelT univ_label); + + DISKANN_DLLEXPORT inline bool point_has_label(location_t point_id, const LabelT label_id) const + { + uint32_t start_vec = _pts_to_label_offsets[point_id]; + uint32_t num_lbls = _pts_to_label_counts[point_id]; + bool ret_val = false; + for (uint32_t i = 0; i < num_lbls; i++) + { + if (_pts_to_labels[start_vec + i] == label_id) + { + ret_val = true; + break; + } + } + return ret_val; + } + + DISKANN_DLLEXPORT inline bool is_dummy_point(location_t id) const + { + return _dummy_pts.find(id) != _dummy_pts.end(); + } + + DISKANN_DLLEXPORT inline location_t get_real_point_for_dummy(location_t dummy_id) + { + if (is_dummy_point(dummy_id)) + { + return _dummy_to_real_map[dummy_id]; + } + else + { + return dummy_id; // it is a real point. + } + } + + DISKANN_DLLEXPORT inline bool point_has_label_or_universal_label(location_t id, const LabelT filter_label) const + { + return point_has_label(id, filter_label) || + (_use_universal_label && point_has_label(id, _universal_filter_label)); + } + + DISKANN_DLLEXPORT inline LabelT get_converted_label(const std::string &filter_label) + { + if (_label_map.find(filter_label) != _label_map.end()) + { + return _label_map[filter_label]; + } + if (_use_universal_label) + { + return _universal_filter_label; + } + std::stringstream stream; + stream << "Unable to find label in the Label Map"; + diskann::cerr << stream.str() << std::endl; + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } + + // Returns true if the index is filter-enabled and all files were loaded + // correctly. false otherwise. Note that "false" can mean that the index + // does not have filter support, or that some index files do not exist, or + // that they exist and could not be opened. + DISKANN_DLLEXPORT bool load(const std::string &disk_index_file); + + DISKANN_DLLEXPORT void generate_random_labels(std::vector &labels, const uint32_t num_labels, + const uint32_t nthreads); + + private: + // Load functions for search START + void load_label_file(const std::string_view &file_content); + void load_label_map(std::basic_istream &map_reader); + void load_labels_to_medoids(std::basic_istream &reader); + void load_dummy_map(std::basic_istream &dummy_map_stream); + void parse_universal_label(const std::string_view &content); + void get_label_file_metadata(const std::string_view &fileContent, uint32_t &num_pts, uint32_t &num_total_labels); + + bool load_file_and_parse(const std::string &filename, + void (InMemFilterStore::*parse_fn)(const std::string_view &content)); + bool parse_stream(const std::string &filename, + void (InMemFilterStore::*parse_fn)(std::basic_istream &stream)); + + void reset_stream_for_reading(std::basic_istream &infile); + // Load functions for search END + + location_t _num_points = 0; + location_t *_pts_to_label_offsets = nullptr; + location_t *_pts_to_label_counts = nullptr; + LabelT *_pts_to_labels = nullptr; + bool _use_universal_label = false; + LabelT _universal_filter_label; + tsl::robin_set _dummy_pts; + tsl::robin_set _has_dummy_pts; + tsl::robin_map _dummy_to_real_map; + tsl::robin_map> _real_to_dummy_map; + std::unordered_map _label_map; + std::unordered_map> _filter_to_medoid_ids; + bool _is_valid = false; +}; + +} // namespace diskann diff --git a/include/index.h b/include/index.h index 320942013..7a142b4ac 100644 --- a/include/index.h +++ b/include/index.h @@ -9,22 +9,22 @@ #include "aligned_file_reader.h" #endif +#include "abstract_index.h" #include "distance.h" +#include "in_mem_data_store.h" +#include "in_mem_graph_store.h" #include "locking.h" #include "natural_number_map.h" #include "natural_number_set.h" #include "neighbor.h" #include "parameters.h" +#include "scratch.h" #include "utils.h" #include "windows_customizations.h" -#include "scratch.h" -#include "in_mem_data_store.h" -#include "in_mem_graph_store.h" -#include "abstract_index.h" -#include -#include "quantized_distance.h" #include "pq_data_store.h" +#include "quantized_distance.h" +#include #define OVERHEAD_FACTOR 1.1 #define EXPAND_IF_FULL 0 @@ -167,18 +167,19 @@ template clas public: // Constructor for Bulk operations and for creating the index object solely // for loading a prexisting index. - DISKANN_DLLEXPORT Index(const IndexConfig &index_config, std::shared_ptr> data_store, - std::unique_ptr graph_store, - std::shared_ptr> pq_data_store = nullptr); + DISKANN_DLLEXPORT + Index(const IndexConfig &index_config, std::shared_ptr> data_store, + std::unique_ptr graph_store, + std::shared_ptr> pq_data_store = nullptr); // Constructor for incremental index - DISKANN_DLLEXPORT Index(Metric m, const size_t dim, const size_t max_points, - const std::shared_ptr index_parameters, - const std::shared_ptr index_search_params, - const size_t num_frozen_pts = 0, const bool dynamic_index = false, - const bool enable_tags = false, const bool concurrent_consolidate = false, - const bool pq_dist_build = false, const size_t num_pq_chunks = 0, - const bool use_opq = false, const bool filtered_index = false); + DISKANN_DLLEXPORT + Index(Metric m, const size_t dim, const size_t max_points, + const std::shared_ptr index_parameters, + const std::shared_ptr index_search_params, const size_t num_frozen_pts = 0, + const bool dynamic_index = false, const bool enable_tags = false, const bool concurrent_consolidate = false, + const bool pq_dist_build = false, const size_t num_pq_chunks = 0, const bool use_opq = false, + const bool filtered_index = false); DISKANN_DLLEXPORT ~Index(); @@ -498,8 +499,9 @@ template clas // Filter Support bool _filtered_index = false; - // Location to label is only updated during insert_point(), all other reads are protected by - // default as a location can only be released at end of consolidate deletes + // Location to label is only updated during insert_point(), all other reads + // are protected by default as a location can only be released at end of + // consolidate deletes std::vector> _location_to_labels; tsl::robin_set _labels; std::string _labels_file; @@ -558,7 +560,8 @@ template clas std::shared_timed_mutex // Ensure only one consolidate or compact_data is _consolidate_lock; // ever active std::shared_timed_mutex // RW lock for _tag_to_location, - _tag_lock; // _location_to_tag, _empty_slots, _nd, _max_points, _label_to_start_id + _tag_lock; // _location_to_tag, _empty_slots, _nd, _max_points, + // _label_to_start_id std::shared_timed_mutex // RW Lock on _delete_set and _data_compacted _delete_lock; // variable diff --git a/include/index_build_params.h b/include/index_build_params.h index d4f454830..38434e204 100644 --- a/include/index_build_params.h +++ b/include/index_build_params.h @@ -1,5 +1,6 @@ #pragma once +#include "ann_exception.h" #include "common_includes.h" #include "parameters.h" @@ -32,7 +33,7 @@ class IndexFilterParamsBuilder IndexFilterParamsBuilder &with_save_path_prefix(const std::string &save_path_prefix) { if (save_path_prefix.empty() || save_path_prefix == "") - throw ANNException("Error: save_path_prefix can't be empty", -1); + throw diskann::ANNException("Error: save_path_prefix can't be empty", -1); this->_save_path_prefix = save_path_prefix; return *this; } diff --git a/include/index_config.h b/include/index_config.h index a6fa0d966..20b185c5f 100644 --- a/include/index_config.h +++ b/include/index_config.h @@ -1,7 +1,13 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + #pragma once +#include "ann_exception.h" #include "common_includes.h" +#include "logger.h" #include "parameters.h" +#include namespace diskann { @@ -210,13 +216,17 @@ class IndexConfigBuilder if (_dynamic_index) { if (_index_search_params != nullptr && _index_search_params->initial_search_list_size == 0) - throw ANNException("Error: please pass initial_search_list_size for building dynamic index.", -1); + throw ANNException("Error: please pass initial_search_list_size for " + "building dynamic index.", + -1); } // sanity check if (_dynamic_index && _num_frozen_pts == 0) { - diskann::cout << "_num_frozen_pts passed as 0 for dynamic_index. Setting it to 1 for safety." << std::endl; + diskann::cout << "_num_frozen_pts passed as 0 for dynamic_index. Setting " + "it to 1 for safety." + << std::endl; _num_frozen_pts = 1; } diff --git a/include/index_factory.h b/include/index_factory.h index 76fb0b978..a41c1f50f 100644 --- a/include/index_factory.h +++ b/include/index_factory.h @@ -1,8 +1,8 @@ #pragma once -#include "index.h" #include "abstract_graph_store.h" #include "in_mem_graph_store.h" +#include "index.h" #include "pq_data_store.h" namespace diskann @@ -20,9 +20,9 @@ class IndexFactory DISKANN_DLLEXPORT static std::shared_ptr> construct_datastore(DataStoreStrategy stratagy, size_t num_points, size_t dimension, Metric m); - // For now PQDataStore incorporates within itself all variants of quantization that we support. In the - // future it may be necessary to introduce an AbstractPQDataStore class to spearate various quantization - // flavours. + // For now PQDataStore incorporates within itself all variants of quantization + // that we support. In the future it may be necessary to introduce an + // AbstractPQDataStore class to spearate various quantization flavours. template DISKANN_DLLEXPORT static std::shared_ptr> construct_pq_datastore(DataStoreStrategy strategy, size_t num_points, size_t dimension, diff --git a/include/logger.h b/include/logger.h index 0b17807db..f1c6ee7f3 100644 --- a/include/logger.h +++ b/include/logger.h @@ -2,9 +2,9 @@ // Licensed under the MIT license. #pragma once +#include "windows_customizations.h" #include #include -#include "windows_customizations.h" #ifdef EXEC_ENV_OLS #ifndef ENABLE_CUSTOM_LOGGER diff --git a/include/logger_impl.h b/include/logger_impl.h index 03c65e0ce..d2dfaf573 100644 --- a/include/logger_impl.h +++ b/include/logger_impl.h @@ -3,8 +3,8 @@ #pragma once -#include #include +#include #include "ann_exception.h" #include "logger.h" diff --git a/include/neighbor.h b/include/neighbor.h index d7c0c25ed..61a6932c1 100644 --- a/include/neighbor.h +++ b/include/neighbor.h @@ -3,10 +3,10 @@ #pragma once +#include "utils.h" #include #include #include -#include "utils.h" namespace diskann { diff --git a/include/parameters.h b/include/parameters.h index 0206814bd..50e7e4a1a 100644 --- a/include/parameters.h +++ b/include/parameters.h @@ -6,8 +6,8 @@ #include #include -#include "omp.h" #include "defaults.h" +#include "omp.h" namespace diskann { diff --git a/include/pq.h b/include/pq.h index 3e6119f22..db9226d8b 100644 --- a/include/pq.h +++ b/include/pq.h @@ -3,8 +3,8 @@ #pragma once -#include "utils.h" #include "pq_common.h" +#include "utils.h" namespace diskann { diff --git a/include/pq_common.h b/include/pq_common.h index c6a3a5739..d7a4b60f4 100644 --- a/include/pq_common.h +++ b/include/pq_common.h @@ -1,7 +1,7 @@ #pragma once -#include #include +#include #define NUM_PQ_BITS 8 #define NUM_PQ_CENTROIDS (1 << NUM_PQ_BITS) diff --git a/include/pq_data_store.h b/include/pq_data_store.h index 227b8a6af..385baddde 100644 --- a/include/pq_data_store.h +++ b/include/pq_data_store.h @@ -1,14 +1,15 @@ #pragma once -#include +#include "abstract_data_store.h" #include "distance.h" -#include "quantized_distance.h" #include "pq.h" -#include "abstract_data_store.h" +#include "quantized_distance.h" +#include namespace diskann { -// REFACTOR TODO: By default, the PQDataStore is an in-memory datastore because both Vamana and -// DiskANN treat it the same way. But with DiskPQ, that may need to change. +// REFACTOR TODO: By default, the PQDataStore is an in-memory datastore because +// both Vamana and DiskANN treat it the same way. But with DiskPQ, that may need +// to change. template class PQDataStore : public AbstractDataStore { @@ -30,8 +31,8 @@ template class PQDataStore : public AbstractDataStore // vectors file. virtual size_t save(const std::string &file_prefix, const location_t num_points) override; - // Since base class function is pure virtual, we need to declare it here, even though alignent concept is not needed - // for Quantized data stores. + // Since base class function is pure virtual, we need to declare it here, even + // though alignent concept is not needed for Quantized data stores. virtual size_t get_aligned_dim() const override; // Populate quantized data from unaligned data using PQ functionality diff --git a/include/pq_flash_index.h b/include/pq_flash_index.h index 8165e44fa..460f918a4 100644 --- a/include/pq_flash_index.h +++ b/include/pq_flash_index.h @@ -10,11 +10,13 @@ #include "parameters.h" #include "percentile_stats.h" #include "pq.h" -#include "utils.h" -#include "windows_customizations.h" #include "scratch.h" #include "tsl/robin_map.h" #include "tsl/robin_set.h" +#include "utils.h" +#include "windows_customizations.h" + +#include "in_mem_filter_store.h" #define FULL_PRECISION_REORDER_MULTIPLIER 3 @@ -95,16 +97,20 @@ template class PQFlashIndex DISKANN_DLLEXPORT uint64_t get_data_dim(); + DISKANN_DLLEXPORT LabelT get_converted_label(const std::string &filter_label); + std::shared_ptr &reader; DISKANN_DLLEXPORT diskann::Metric get_metric(); // // node_ids: input list of node_ids to be read - // coord_buffers: pointers to pre-allocated buffers that coords need to copied to. If null, dont copy. - // nbr_buffers: pre-allocated buffers to copy neighbors into + // coord_buffers: pointers to pre-allocated buffers that coords need to copied + // to. If null, dont copy. nbr_buffers: pre-allocated buffers to copy + // neighbors into // - // returns a vector of bool one for each node_id: true if read is success, else false + // returns a vector of bool one for each node_id: true if read is success, + // else false // DISKANN_DLLEXPORT std::vector read_nodes(const std::vector &node_ids, std::vector &coord_buffers, @@ -117,18 +123,7 @@ template class PQFlashIndex DISKANN_DLLEXPORT void use_medoids_data_as_centroids(); DISKANN_DLLEXPORT void setup_thread_data(uint64_t nthreads, uint64_t visited_reserve = 4096); - DISKANN_DLLEXPORT void set_universal_label(const LabelT &label); - private: - DISKANN_DLLEXPORT inline bool point_has_label(uint32_t point_id, LabelT label_id); - std::unordered_map load_label_map(std::basic_istream &infile); - DISKANN_DLLEXPORT void parse_label_file(std::basic_istream &infile, size_t &num_pts_labels); - DISKANN_DLLEXPORT void get_label_file_metadata(const std::string &fileContent, uint32_t &num_pts, - uint32_t &num_total_labels); - DISKANN_DLLEXPORT void generate_random_labels(std::vector &labels, const uint32_t num_labels, - const uint32_t nthreads); - void reset_stream_for_reading(std::basic_istream &infile); - // sector # on disk where node_id is present with in the graph part DISKANN_DLLEXPORT uint64_t get_node_sector(uint64_t node_id); @@ -148,8 +143,8 @@ template class PQFlashIndex // offset in sector: [(i % nnodes_per_sector) * max_node_len] // // index info for multi-sector nodes - // nhood of node `i` is in sector: [i * DIV_ROUND_UP(_max_node_len, SECTOR_LEN)] - // offset in sector: [0] + // nhood of node `i` is in sector: [i * DIV_ROUND_UP(_max_node_len, + // SECTOR_LEN)] offset in sector: [0] // // Common info // coords start at ofsset @@ -228,18 +223,10 @@ template class PQFlashIndex bool _reorder_data_exists = false; uint64_t _reoreder_data_offset = 0; - // filter support - uint32_t *_pts_to_label_offsets = nullptr; - uint32_t *_pts_to_label_counts = nullptr; - LabelT *_pts_to_labels = nullptr; - std::unordered_map> _filter_to_medoid_ids; - bool _use_universal_label = false; - LabelT _universal_filter_label; - tsl::robin_set _dummy_pts; - tsl::robin_set _has_dummy_pts; - tsl::robin_map _dummy_to_real_map; - tsl::robin_map> _real_to_dummy_map; - std::unordered_map _label_map; + // Moved filter-specific data structures to in_mem_filter_store. + // TODO: Make this a unique pointer + bool _filter_index = false; + std::unique_ptr> _filter_store; #ifdef EXEC_ENV_OLS // Set to a larger value than the actual header to accommodate diff --git a/include/pq_scratch.h b/include/pq_scratch.h index 95f1b1395..6b52463eb 100644 --- a/include/pq_scratch.h +++ b/include/pq_scratch.h @@ -1,7 +1,7 @@ #pragma once -#include #include "pq_common.h" #include "utils.h" +#include namespace diskann { diff --git a/include/quantized_distance.h b/include/quantized_distance.h index cc4aea929..44798ac96 100644 --- a/include/quantized_distance.h +++ b/include/quantized_distance.h @@ -1,8 +1,8 @@ #pragma once +#include "abstract_scratch.h" #include #include #include -#include "abstract_scratch.h" namespace diskann { @@ -48,9 +48,10 @@ template class QuantizedDistance virtual void preprocessed_distance(PQScratch &pq_scratch, const uint32_t n_ids, std::vector &dists_out) = 0; - // Currently this function is required for DiskPQ. However, it too can be subsumed - // under preprocessed_distance if we add the appropriate scratch variables to - // PQScratch and initialize them in pq_flash_index.cpp::disk_iterate_to_fixed_point() + // Currently this function is required for DiskPQ. However, it too can be + // subsumed under preprocessed_distance if we add the appropriate scratch + // variables to PQScratch and initialize them in + // pq_flash_index.cpp::disk_iterate_to_fixed_point() virtual float brute_force_distance(const float *query_vec, uint8_t *base_vec) = 0; }; } // namespace diskann diff --git a/include/restapi/search_wrapper.h b/include/restapi/search_wrapper.h index ebd067d8a..d41b2b7cd 100644 --- a/include/restapi/search_wrapper.h +++ b/include/restapi/search_wrapper.h @@ -3,9 +3,9 @@ #pragma once +#include #include #include -#include #include #include diff --git a/include/restapi/server.h b/include/restapi/server.h index 1d75847a2..ddb19d17a 100644 --- a/include/restapi/server.h +++ b/include/restapi/server.h @@ -3,8 +3,8 @@ #pragma once -#include #include +#include namespace diskann { diff --git a/include/scratch.h b/include/scratch.h index bfb5e5a62..d654cf4ef 100644 --- a/include/scratch.h +++ b/include/scratch.h @@ -7,15 +7,15 @@ #include "boost_dynamic_bitset_fwd.h" // #include "boost/dynamic_bitset.hpp" -#include "tsl/robin_set.h" #include "tsl/robin_map.h" +#include "tsl/robin_set.h" #include "tsl/sparse_map.h" -#include "aligned_file_reader.h" #include "abstract_scratch.h" -#include "neighbor.h" -#include "defaults.h" +#include "aligned_file_reader.h" #include "concurrent_queue.h" +#include "defaults.h" +#include "neighbor.h" namespace diskann { diff --git a/include/simd_utils.h b/include/simd_utils.h index 4b0736998..da59c0cde 100644 --- a/include/simd_utils.h +++ b/include/simd_utils.h @@ -2,9 +2,9 @@ #ifdef _WINDOWS #include +#include #include #include -#include #else #include #endif diff --git a/include/types.h b/include/types.h index 953d59a5f..58d8d40a4 100644 --- a/include/types.h +++ b/include/types.h @@ -3,10 +3,10 @@ #pragma once -#include -#include -#include #include "any_wrappers.h" +#include +#include +#include namespace diskann { diff --git a/include/utils.h b/include/utils.h index 0170d7297..5926f77f8 100644 --- a/include/utils.h +++ b/include/utils.h @@ -20,14 +20,14 @@ typedef HANDLE FileHandle; typedef int FileHandle; #endif +#include "ann_exception.h" +#include "cached_io.h" #include "distance.h" #include "logger.h" -#include "cached_io.h" -#include "ann_exception.h" -#include "windows_customizations.h" +#include "tag_uint128.h" #include "tsl/robin_set.h" #include "types.h" -#include "tag_uint128.h" +#include "windows_customizations.h" #include #ifdef EXEC_ENV_OLS @@ -1199,8 +1199,8 @@ template <> inline const char* diskann_type_to_name() } #ifdef _WINDOWS -#include #include +#include extern bool AvxSupportedCPU; extern bool Avx2SupportedCPU; diff --git a/include/windows_aligned_file_reader.h b/include/windows_aligned_file_reader.h index 0d9a3173c..e3a898b9a 100644 --- a/include/windows_aligned_file_reader.h +++ b/include/windows_aligned_file_reader.h @@ -9,13 +9,13 @@ #include #include -#include -#include -#include #include "aligned_file_reader.h" #include "tsl/robin_map.h" #include "utils.h" #include "windows_customizations.h" +#include +#include +#include class WindowsAlignedFileReader : public AlignedFileReader { diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index cbca26440..80b11754f 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -11,9 +11,10 @@ else() set(CPP_SOURCES abstract_data_store.cpp ann_exception.cpp disk_utils.cpp distance.cpp index.cpp in_mem_graph_store.cpp in_mem_data_store.cpp linux_aligned_file_reader.cpp math_utils.cpp natural_number_map.cpp - in_mem_data_store.cpp in_mem_graph_store.cpp + in_mem_data_store.cpp in_mem_graph_store.cpp in_mem_filter_store.cpp natural_number_set.cpp memory_mapper.cpp partition.cpp pq.cpp - pq_flash_index.cpp scratch.cpp logger.cpp utils.cpp filter_utils.cpp index_factory.cpp abstract_index.cpp pq_l2_distance.cpp pq_data_store.cpp) + pq_flash_index.cpp scratch.cpp logger.cpp utils.cpp filter_utils.cpp + index_factory.cpp abstract_index.cpp pq_l2_distance.cpp pq_data_store.cpp) if (RESTAPI) list(APPEND CPP_SOURCES restapi/search_wrapper.cpp restapi/server.cpp) endif() diff --git a/src/abstract_data_store.cpp b/src/abstract_data_store.cpp index 0cff0152e..79efaca45 100644 --- a/src/abstract_data_store.cpp +++ b/src/abstract_data_store.cpp @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include #include "abstract_data_store.h" +#include namespace diskann { diff --git a/src/abstract_index.cpp b/src/abstract_index.cpp index 7550bda3a..21f340a79 100644 --- a/src/abstract_index.cpp +++ b/src/abstract_index.cpp @@ -1,6 +1,6 @@ +#include "abstract_index.h" #include "common_includes.h" #include "windows_customizations.h" -#include "abstract_index.h" namespace diskann { diff --git a/src/disk_utils.cpp b/src/disk_utils.cpp index 3a5ec170d..165da9c1e 100644 --- a/src/disk_utils.cpp +++ b/src/disk_utils.cpp @@ -7,14 +7,14 @@ #include "gperftools/malloc_extension.h" #endif -#include "logger.h" -#include "disk_utils.h" #include "cached_io.h" +#include "disk_utils.h" #include "index.h" +#include "logger.h" #include "mkl.h" #include "omp.h" -#include "percentile_stats.h" #include "partition.h" +#include "percentile_stats.h" #include "pq_flash_index.h" #include "timer.h" #include "tsl/robin_set.h" @@ -1133,7 +1133,8 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const (compareMetric == diskann::Metric::INNER_PRODUCT || compareMetric == diskann::Metric::COSINE)) { std::stringstream stream; - stream << "Disk-index build currently only supports floating point data for Max " + stream << "Disk-index build currently only supports floating point data " + "for Max " "Inner Product Search/ cosine similarity. " << std::endl; throw diskann::ANNException(stream.str(), -1); @@ -1196,9 +1197,8 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const std::string disk_pq_pivots_path = index_prefix_path + "_disk.index_pq_pivots.bin"; // optional, used if disk index must store pq data std::string disk_pq_compressed_vectors_path = index_prefix_path + "_disk.index_pq_compressed.bin"; - std::string prepped_base = - index_prefix_path + - "_prepped_base.bin"; // temp file for storing pre-processed base file for cosine/ mips metrics + std::string prepped_base = index_prefix_path + "_prepped_base.bin"; // temp file for storing pre-processed base file + // for cosine/ mips metrics bool created_temp_file_for_processed_data = false; // output a new base file which contains extra dimension with sqrt(1 - @@ -1210,7 +1210,8 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const std::cout << "Using Inner Product search, so need to pre-process base " "data into temp file. Please ensure there is additional " "(n*(d+1)*4) bytes for storing pre-processed base vectors, " - "apart from the interim indices created by DiskANN and the final index." + "apart from the interim indices created by DiskANN and the " + "final index." << std::endl; data_file_to_use = prepped_base; float max_norm_of_base = diskann::prepare_base_for_inner_products(base_file, prepped_base); @@ -1222,9 +1223,11 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const else if (compareMetric == diskann::Metric::COSINE) { Timer timer; - std::cout << "Normalizing data for cosine to temporary file, please ensure there is additional " + std::cout << "Normalizing data for cosine to temporary file, please ensure " + "there is additional " "(n*d*4) bytes for storing normalized base vectors, " - "apart from the interim indices created by DiskANN and the final index." + "apart from the interim indices created by DiskANN and the " + "final index." << std::endl; data_file_to_use = prepped_base; diskann::normalize_data_file(base_file, prepped_base); @@ -1321,7 +1324,8 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const #if defined(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS) && defined(DISKANN_BUILD) MallocExtension::instance()->ReleaseFreeMemory(); #endif - // Whether it is cosine or inner product, we still L2 metric due to the pre-processing. + // Whether it is cosine or inner product, we still L2 metric due to the + // pre-processing. timer.reset(); diskann::build_merged_vamana_index(data_file_to_use.c_str(), diskann::Metric::L2, L, R, p_val, indexing_ram_budget, mem_index_path, medoids_path, centroids_path, diff --git a/src/distance.cpp b/src/distance.cpp index c2f88c85b..957453ab8 100644 --- a/src/distance.cpp +++ b/src/distance.cpp @@ -3,9 +3,9 @@ #ifdef _WINDOWS #include +#include #include #include -#include #else #include #endif @@ -14,10 +14,10 @@ #include #include +#include "ann_exception.h" #include "distance.h" -#include "utils.h" #include "logger.h" -#include "ann_exception.h" +#include "utils.h" namespace diskann { diff --git a/src/dll/CMakeLists.txt b/src/dll/CMakeLists.txt index 096d1b76e..633e8edb7 100644 --- a/src/dll/CMakeLists.txt +++ b/src/dll/CMakeLists.txt @@ -4,7 +4,8 @@ add_library(${PROJECT_NAME} SHARED dllmain.cpp ../abstract_data_store.cpp ../partition.cpp ../pq.cpp ../pq_flash_index.cpp ../logger.cpp ../utils.cpp ../windows_aligned_file_reader.cpp ../distance.cpp ../pq_l2_distance.cpp ../memory_mapper.cpp ../index.cpp ../in_mem_data_store.cpp ../pq_data_store.cpp ../in_mem_graph_store.cpp ../math_utils.cpp ../disk_utils.cpp ../filter_utils.cpp - ../ann_exception.cpp ../natural_number_set.cpp ../natural_number_map.cpp ../scratch.cpp ../index_factory.cpp ../abstract_index.cpp) + ../ann_exception.cpp ../natural_number_set.cpp ../natural_number_map.cpp ../scratch.cpp ../index_factory.cpp ../abstract_index.cpp + ../in_mem_filter_store.cpp) set(TARGET_DIR "$<$:${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}>$<$:${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}>") diff --git a/src/filter_utils.cpp b/src/filter_utils.cpp index 09d740e35..d5502d361 100644 --- a/src/filter_utils.cpp +++ b/src/filter_utils.cpp @@ -8,11 +8,11 @@ #include #include -#include #include "filter_utils.h" #include "index.h" #include "parameters.h" #include "utils.h" +#include namespace diskann { @@ -266,7 +266,8 @@ parse_label_file_return_values parse_label_file(path label_data_path, std::strin * as either uint16_t or uint32_t * * Returns two objects via std::tuple: - * 1. a vector of vectors of labels, where the outer vector is indexed by point id + * 1. a vector of vectors of labels, where the outer vector is indexed by point + * id * 2. a set of all labels */ template diff --git a/src/in_mem_data_store.cpp b/src/in_mem_data_store.cpp index 1a9e822dc..9fe96689c 100644 --- a/src/in_mem_data_store.cpp +++ b/src/in_mem_data_store.cpp @@ -1,9 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include -#include "abstract_scratch.h" #include "in_mem_data_store.h" +#include "abstract_scratch.h" +#include #include "utils.h" diff --git a/src/in_mem_filter_store.cpp b/src/in_mem_filter_store.cpp new file mode 100644 index 000000000..de581d99d --- /dev/null +++ b/src/in_mem_filter_store.cpp @@ -0,0 +1,410 @@ +#include "in_mem_filter_store.h" +#include "ann_exception.h" +#include "tsl/robin_map.h" +#include "tsl/robin_set.h" +#include "utils.h" +#include +#include +#include +#include +#include +#include +#include +#include + +namespace diskann +{ +// TODO: Move to utils.h +DISKANN_DLLEXPORT std::unique_ptr get_file_content(const std::string &filename, uint64_t &file_size); + +template InMemFilterStore::~InMemFilterStore() +{ + if (_pts_to_label_offsets != nullptr) + { + delete[] _pts_to_label_offsets; + _pts_to_label_offsets = nullptr; + } + if (_pts_to_label_counts != nullptr) + { + delete[] _pts_to_label_counts; + _pts_to_label_counts = nullptr; + } + if (_pts_to_labels != nullptr) + { + delete[] _pts_to_labels; + _pts_to_labels = nullptr; + } +} + +template +const std::unordered_map> &InMemFilterStore::get_label_to_medoids() const +{ + return this->_filter_to_medoid_ids; +} + +template +const std::vector &InMemFilterStore::get_medoids_of_label(const LabelT label) +{ + if (_filter_to_medoid_ids.find(label) != _filter_to_medoid_ids.end()) + { + return this->_filter_to_medoid_ids[label]; + } + else + { + std::stringstream ss; + ss << "Could not find " << label << " in filters_to_medoid_ids map." << std::endl; + diskann::cerr << ss.str(); + throw ANNException(ss.str(), -1); + } +} + +template void InMemFilterStore::set_universal_label(const LabelT univ_label) +{ + _universal_filter_label = univ_label; + _use_universal_label = true; +} + +// Load functions for SEARCH START +template bool InMemFilterStore::load(const std::string &disk_index_file) +{ + std::string labels_file = disk_index_file + "_labels.txt"; + std::string labels_to_medoids = disk_index_file + "_labels_to_medoids.txt"; + std::string dummy_map_file = disk_index_file + "_dummy_map.txt"; + std::string labels_map_file = disk_index_file + "_labels_map.txt"; + std::string univ_label_file = disk_index_file + "_universal_label.txt"; + + size_t num_pts_in_label_file = 0; + + // TODO: Check for encoding issues here. We are opening files as binary and + // reading them as bytes, not sure if that can cause an issue with UTF + // encodings. + bool has_filters = true; + if (false == load_file_and_parse(labels_file, &InMemFilterStore::load_label_file)) + { + diskann::cout << "Index does not have filter data. " << std::endl; + return false; + } + if (false == parse_stream(labels_map_file, &InMemFilterStore::load_label_map)) + { + diskann::cerr << "Failed to find file: " << labels_map_file << " while labels_file exists." << std::endl; + return false; + } + + if (false == parse_stream(labels_to_medoids, &InMemFilterStore::load_labels_to_medoids)) + { + diskann::cerr << "Failed to find file: " << labels_to_medoids << " while labels file exists." << std::endl; + return false; + } + // missing universal label file is NOT an error. + load_file_and_parse(univ_label_file, &InMemFilterStore::parse_universal_label); + + // missing dummy map file is also NOT an error. + parse_stream(dummy_map_file, &InMemFilterStore::load_dummy_map); + _is_valid = true; + return _is_valid; +} + +template bool InMemFilterStore::has_filter_support() const +{ + return _is_valid; +} + +// TODO: Improve this to not load the entire file in memory +template void InMemFilterStore::load_label_file(const std::string_view &label_file_content) +{ + std::string line; + uint32_t line_cnt = 0; + + uint32_t num_pts_in_label_file; + uint32_t num_total_labels; + get_label_file_metadata(label_file_content, num_pts_in_label_file, num_total_labels); + + _num_points = num_pts_in_label_file; + + _pts_to_label_offsets = new uint32_t[num_pts_in_label_file]; + _pts_to_label_counts = new uint32_t[num_pts_in_label_file]; + _pts_to_labels = new LabelT[num_total_labels]; + uint32_t labels_seen_so_far = 0; + + std::string label_str; + size_t cur_pos = 0; + size_t next_pos = 0; + size_t file_size = label_file_content.size(); + + while (cur_pos < file_size && cur_pos != std::string_view::npos) + { + next_pos = label_file_content.find('\n', cur_pos); + if (next_pos == std::string_view::npos) + { + break; + } + + _pts_to_label_offsets[line_cnt] = labels_seen_so_far; + uint32_t &num_lbls_in_cur_pt = _pts_to_label_counts[line_cnt]; + num_lbls_in_cur_pt = 0; + + size_t lbl_pos = cur_pos; + size_t next_lbl_pos = 0; + while (lbl_pos < next_pos && lbl_pos != std::string_view::npos) + { + next_lbl_pos = label_file_content.find(',', lbl_pos); + if (next_lbl_pos == std::string_view::npos) // the last label in the whole file + { + next_lbl_pos = next_pos; + } + + if (next_lbl_pos > next_pos) // the last label in one line, just read to the end + { + next_lbl_pos = next_pos; + } + + // TODO: SHOULD NOT EXPECT label_file_content TO BE NULL_TERMINATED + label_str.assign(label_file_content.data() + lbl_pos, next_lbl_pos - lbl_pos); + if (label_str[label_str.length() - 1] == '\t') // '\t' won't exist in label file? + { + label_str.erase(label_str.length() - 1); + } + + LabelT token_as_num = (LabelT)std::stoul(label_str); + _pts_to_labels[labels_seen_so_far++] = (LabelT)token_as_num; + num_lbls_in_cur_pt++; + + // move to next label + lbl_pos = next_lbl_pos + 1; + } + + // move to next line + cur_pos = next_pos + 1; + + if (num_lbls_in_cur_pt == 0) + { + diskann::cout << "No label found for point " << line_cnt << std::endl; + exit(-1); + } + + line_cnt++; + } + + // TODO: We need to check if the number of labels and the number of points + // is as expected. Maybe add the check in PQFlashIndex? + // num_points_labels = line_cnt; +} + +template +void InMemFilterStore::load_labels_to_medoids(std::basic_istream &medoid_stream) +{ + std::string line, token; + + _filter_to_medoid_ids.clear(); + while (std::getline(medoid_stream, line)) + { + std::istringstream iss(line); + uint32_t cnt = 0; + std::vector medoids; + LabelT label; + while (std::getline(iss, token, ',')) + { + if (cnt == 0) + label = (LabelT)std::stoul(token); + else + medoids.push_back((uint32_t)stoul(token)); + cnt++; + } + _filter_to_medoid_ids[label].swap(medoids); + } +} + +template void InMemFilterStore::load_label_map(std::basic_istream &map_reader) +{ + std::string line, token; + LabelT token_as_num; + std::string label_str; + while (std::getline(map_reader, line)) + { + std::istringstream iss(line); + getline(iss, token, '\t'); + label_str = token; + getline(iss, token, '\t'); + token_as_num = (LabelT)std::stoul(token); + _label_map[label_str] = token_as_num; + } +} + +template void InMemFilterStore::parse_universal_label(const std::string_view &content) +{ + LabelT label_as_num = (LabelT)std::stoul(std::string(content)); + this->set_universal_label(label_as_num); +} + +template void InMemFilterStore::load_dummy_map(std::basic_istream &dummy_map_stream) +{ + std::string line, token; + + while (std::getline(dummy_map_stream, line)) + { + std::istringstream iss(line); + uint32_t cnt = 0; + uint32_t dummy_id; + uint32_t real_id; + while (std::getline(iss, token, ',')) + { + if (cnt == 0) + dummy_id = (uint32_t)stoul(token); + else + real_id = (uint32_t)stoul(token); + cnt++; + } + _dummy_pts.insert(dummy_id); + _has_dummy_pts.insert(real_id); + _dummy_to_real_map[dummy_id] = real_id; + + if (_real_to_dummy_map.find(real_id) == _real_to_dummy_map.end()) + _real_to_dummy_map[real_id] = std::vector(); + + _real_to_dummy_map[real_id].emplace_back(dummy_id); + } + diskann::cout << "Loaded dummy map" << std::endl; +} + +template +void InMemFilterStore::generate_random_labels(std::vector &labels, const uint32_t num_labels, + const uint32_t nthreads) +{ + std::random_device rd; + labels.clear(); + labels.resize(num_labels); + + uint64_t num_total_labels = _pts_to_label_offsets[_num_points - 1] + _pts_to_label_counts[_num_points - 1]; + std::mt19937 gen(rd()); + if (num_total_labels == 0) + { + std::stringstream stream; + stream << "No labels found in data. Not sampling random labels "; + diskann::cerr << stream.str() << std::endl; + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } + std::uniform_int_distribution dis(0, num_total_labels - 1); + +#pragma omp parallel for schedule(dynamic, 1) num_threads(nthreads) + for (int64_t i = 0; i < num_labels; i++) + { + uint64_t rnd_loc = dis(gen); + labels[i] = (LabelT)_pts_to_labels[rnd_loc]; + } +} + +template void InMemFilterStore::reset_stream_for_reading(std::basic_istream &infile) +{ + infile.clear(); + infile.seekg(0); +} + +template +void InMemFilterStore::get_label_file_metadata(const std::string_view &fileContent, uint32_t &num_pts, + uint32_t &num_total_labels) +{ + num_pts = 0; + num_total_labels = 0; + + size_t file_size = fileContent.length(); + + std::string label_str; + size_t cur_pos = 0; + size_t next_pos = 0; + while (cur_pos < file_size && cur_pos != std::string::npos) + { + next_pos = fileContent.find('\n', cur_pos); + if (next_pos == std::string::npos) + { + break; + } + + size_t lbl_pos = cur_pos; + size_t next_lbl_pos = 0; + while (lbl_pos < next_pos && lbl_pos != std::string::npos) + { + next_lbl_pos = fileContent.find(',', lbl_pos); + if (next_lbl_pos == std::string::npos) // the last label + { + next_lbl_pos = next_pos; + } + + num_total_labels++; + + lbl_pos = next_lbl_pos + 1; + } + + cur_pos = next_pos + 1; + + num_pts++; + } + + diskann::cout << "Labels file metadata: num_points: " << num_pts << ", #total_labels: " << num_total_labels + << std::endl; +} + +template +bool InMemFilterStore::parse_stream(const std::string &filename, + void (InMemFilterStore::*parse_fn)(std::basic_istream &stream)) +{ + if (file_exists(filename)) + { + std::ifstream stream(filename); + if (false == stream.fail()) + { + std::invoke(parse_fn, this, stream); + return true; + } + else + { + std::stringstream ss; + ss << "Could not open file: " << filename << std::endl; + throw diskann::ANNException(ss.str(), -1); + } + } + else + { + return false; + } +} + +template +bool InMemFilterStore::load_file_and_parse(const std::string &filename, + void (InMemFilterStore::*parse_fn)(const std::string_view &content)) +{ + if (file_exists(filename)) + { + size_t file_size = 0; + auto file_content_ptr = get_file_content(filename, file_size); + std::string_view content_as_str(file_content_ptr.get(), file_size); + std::invoke(parse_fn, this, content_as_str); + return true; + } + else + { + return false; + } +} + +std::unique_ptr get_file_content(const std::string &filename, uint64_t &file_size) +{ + std::ifstream infile(filename, std::ios::binary); + if (infile.fail()) + { + throw diskann::ANNException(std::string("Failed to open file ") + filename, -1); + } + infile.seekg(0, std::ios::end); + file_size = infile.tellg(); + + auto buffer = new char[file_size]; + infile.seekg(0, std::ios::beg); + infile.read(buffer, file_size); + + return std::unique_ptr(buffer); +} +// Load functions for SEARCH END +template class InMemFilterStore; +template class InMemFilterStore; +template class InMemFilterStore; + +} // namespace diskann diff --git a/src/index.cpp b/src/index.cpp index 0b01afa20..be9d8e881 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1,19 +1,18 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include -#include - -#include - +#include "ann_exception.h" #include "boost/dynamic_bitset.hpp" #include "index_factory.h" #include "memory_mapper.h" +#include "tag_uint128.h" #include "timer.h" #include "tsl/robin_map.h" #include "tsl/robin_set.h" +#include "utils.h" #include "windows_customizations.h" -#include "tag_uint128.h" +#include +#include #if defined(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS) && defined(DISKANN_BUILD) #include "gperftools/malloc_extension.h" #endif @@ -334,7 +333,8 @@ void Index::save(const char *filename, bool compact_before_save } label_writer.close(); - // write compacted raw_labels if data hence _location_to_labels was also compacted + // write compacted raw_labels if data hence _location_to_labels was also + // compacted if (compact_before_save && _dynamic_index) { _label_map = load_label_map(std::string(filename) + "_labels_map.txt"); @@ -735,8 +735,8 @@ template int Index template uint32_t Index::calculate_entry_point() { - // REFACTOR TODO: This function does not support multi-threaded calculation of medoid. - // Must revisit if perf is a concern. + // REFACTOR TODO: This function does not support multi-threaded calculation of + // medoid. Must revisit if perf is a concern. return _data_store->calculate_medoid(); } @@ -1725,13 +1725,15 @@ void Index::build(const char *filename, const size_t num_points throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); } - // REFACTOR PQ TODO: We can remove this if and add a check in the InMemDataStore - // to not populate_data if it has been called once. + // REFACTOR PQ TODO: We can remove this if and add a check in the + // InMemDataStore to not populate_data if it has been called once. if (_pq_dist) { #ifdef EXEC_ENV_OLS std::stringstream ss; - ss << "PQ Build is not supported in DLVS environment (i.e. if EXEC_ENV_OLS is defined)" << std::endl; + ss << "PQ Build is not supported in DLVS environment (i.e. if EXEC_ENV_OLS " + "is defined)" + << std::endl; diskann::cerr << ss.str() << std::endl; throw ANNException(ss.str(), -1, __FUNCSIG__, __FILE__, __LINE__); #else @@ -2076,8 +2078,8 @@ void Index::build_filtered_index(const char *filename, const st size_t num_points_labels = 0; parse_label_file(label_file, - num_points_labels); // determines medoid for each label and identifies - // the points to label mapping + num_points_labels); // determines medoid for each label and + // identifies the points to label mapping convert_pts_label_to_bitmask(_location_to_labels, _bitmask_buf, _labels.size()); @@ -3122,10 +3124,10 @@ int Index::insert_point(const T *point, const TagT tag, const s { if (_frozen_pts_used >= _num_frozen_pts) { - throw ANNException( - "Error: For dynamic filtered index, the number of frozen points should be atleast equal " - "to number of unique labels.", - -1); + throw ANNException("Error: For dynamic filtered index, the number of " + "frozen points should be atleast equal " + "to number of unique labels.", + -1); } auto fz_location = (int)(_max_points) + _frozen_pts_used; // as first _fz_point @@ -3181,7 +3183,8 @@ int Index::insert_point(const T *point, const TagT tag, const s // Insert tag and mapping to location if (_enable_tags) { - // if tags are enabled and tag is already inserted. so we can't reuse that tag. + // if tags are enabled and tag is already inserted. so we can't reuse that + // tag. if (_tag_to_location.find(tag) != _tag_to_location.end()) { release_location(location); @@ -3201,14 +3204,16 @@ int Index::insert_point(const T *point, const TagT tag, const s std::vector pruned_list; // it is the set best candidates to connect to this point if (_filtered_index) { - // when filtered the best_candidates will share the same label ( label_present > distance) + // when filtered the best_candidates will share the same label ( + // label_present > distance) search_for_point_and_prune(location, _indexingQueueSize, pruned_list, scratch, true, _filterIndexingQueueSize); } else { search_for_point_and_prune(location, _indexingQueueSize, pruned_list, scratch); } - assert(pruned_list.size() > 0); // should find atleast one neighbour (i.e frozen point acting as medoid) + assert(pruned_list.size() > 0); // should find atleast one neighbour (i.e + // frozen point acting as medoid) { std::shared_lock tlock(_tag_lock, std::defer_lock); diff --git a/src/index_factory.cpp b/src/index_factory.cpp index 5c7dbee6b..71d6c4cb4 100644 --- a/src/index_factory.cpp +++ b/src/index_factory.cpp @@ -113,7 +113,8 @@ std::shared_ptr> IndexFactory::construct_pq_datastore(DataStoreSt return std::make_shared>(dimension, (location_t)(num_points), num_pq_chunks, std::move(distance_fn), std::move(quantized_distance_fn)); default: - // REFACTOR TODO: We do support diskPQ - so we may need to add a new class for SSDPQDataStore! + // REFACTOR TODO: We do support diskPQ - so we may need to add a new class + // for SSDPQDataStore! break; } return nullptr; @@ -124,7 +125,8 @@ std::unique_ptr IndexFactory::create_instance() { size_t num_points = _config->max_points + _config->num_frozen_pts; size_t dim = _config->dimension; - // auto graph_store = construct_graphstore(_config->graph_strategy, num_points); + // auto graph_store = construct_graphstore(_config->graph_strategy, + // num_points); auto data_store = construct_datastore(_config->data_strategy, num_points, dim, _config->metric); std::shared_ptr> pq_data_store = nullptr; @@ -144,8 +146,9 @@ std::unique_ptr IndexFactory::create_instance() std::unique_ptr graph_store = construct_graphstore(_config->graph_strategy, num_points, max_reserve_degree); - // REFACTOR TODO: Must construct in-memory PQDatastore if strategy == ONDISK and must construct - // in-mem and on-disk PQDataStore if strategy == ONDISK and diskPQ is required. + // REFACTOR TODO: Must construct in-memory PQDatastore if strategy == ONDISK + // and must construct in-mem and on-disk PQDataStore if strategy == ONDISK and + // diskPQ is required. return std::make_unique>(*_config, data_store, std::move(graph_store), pq_data_store); } @@ -193,7 +196,9 @@ std::unique_ptr IndexFactory::create_instance(const std::string & return create_instance(label_type); } else - throw ANNException("Error: unsupported tag_type please choose from [int32/uint32/int64/uint64]", -1); + throw ANNException("Error: unsupported tag_type please choose from " + "[int32/uint32/int64/uint64]", + -1); } template @@ -211,11 +216,17 @@ std::unique_ptr IndexFactory::create_instance(const std::string & throw ANNException("Error: unsupported label_type please choose from [uint/ushort]", -1); } -// template DISKANN_DLLEXPORT std::shared_ptr> IndexFactory::construct_datastore( -// DataStoreStrategy stratagy, size_t num_points, size_t dimension, Metric m); -// template DISKANN_DLLEXPORT std::shared_ptr> IndexFactory::construct_datastore( -// DataStoreStrategy stratagy, size_t num_points, size_t dimension, Metric m); -// template DISKANN_DLLEXPORT std::shared_ptr> IndexFactory::construct_datastore( -// DataStoreStrategy stratagy, size_t num_points, size_t dimension, Metric m); +// template DISKANN_DLLEXPORT std::shared_ptr> +// IndexFactory::construct_datastore( +// DataStoreStrategy stratagy, size_t num_points, size_t dimension, Metric +// m); +// template DISKANN_DLLEXPORT std::shared_ptr> +// IndexFactory::construct_datastore( +// DataStoreStrategy stratagy, size_t num_points, size_t dimension, Metric +// m); +// template DISKANN_DLLEXPORT std::shared_ptr> +// IndexFactory::construct_datastore( +// DataStoreStrategy stratagy, size_t num_points, size_t dimension, Metric +// m); } // namespace diskann diff --git a/src/linux_aligned_file_reader.cpp b/src/linux_aligned_file_reader.cpp index 31bf5f827..94e14dc08 100644 --- a/src/linux_aligned_file_reader.cpp +++ b/src/linux_aligned_file_reader.cpp @@ -3,11 +3,11 @@ #include "linux_aligned_file_reader.h" +#include "tsl/robin_map.h" +#include "utils.h" #include #include #include -#include "tsl/robin_map.h" -#include "utils.h" #define MAX_EVENTS 1024 namespace @@ -149,7 +149,9 @@ void LinuxAlignedFileReader::register_thread() lk.unlock(); if (ret == -EAGAIN) { - std::cerr << "io_setup() failed with EAGAIN: Consider increasing /proc/sys/fs/aio-max-nr" << std::endl; + std::cerr << "io_setup() failed with EAGAIN: Consider increasing " + "/proc/sys/fs/aio-max-nr" + << std::endl; } else { diff --git a/src/math_utils.cpp b/src/math_utils.cpp index 7481da848..5ce66fb2e 100644 --- a/src/math_utils.cpp +++ b/src/math_utils.cpp @@ -1,12 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. +#include "logger.h" +#include "utils.h" #include #include #include #include -#include "logger.h" -#include "utils.h" namespace math_utils { diff --git a/src/memory_mapper.cpp b/src/memory_mapper.cpp index d1c5ef984..819df7fec 100644 --- a/src/memory_mapper.cpp +++ b/src/memory_mapper.cpp @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include "logger.h" #include "memory_mapper.h" +#include "logger.h" #include #include diff --git a/src/partition.cpp b/src/partition.cpp index 570d45c7d..1428eb801 100644 --- a/src/partition.cpp +++ b/src/partition.cpp @@ -7,20 +7,20 @@ #include #include -#include #include "tsl/robin_map.h" #include "tsl/robin_set.h" +#include #if defined(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS) && defined(DISKANN_BUILD) #include "gperftools/malloc_extension.h" #endif -#include "utils.h" -#include "math_utils.h" #include "index.h" -#include "parameters.h" +#include "math_utils.h" #include "memory_mapper.h" +#include "parameters.h" #include "partition.h" +#include "utils.h" #ifdef _WINDOWS #include #endif diff --git a/src/pq.cpp b/src/pq.cpp index d2b545c79..d1cc8e861 100644 --- a/src/pq.cpp +++ b/src/pq.cpp @@ -5,9 +5,9 @@ #if defined(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS) && defined(DISKANN_BUILD) #include "gperftools/malloc_extension.h" #endif -#include "pq.h" -#include "partition.h" #include "math_utils.h" +#include "partition.h" +#include "pq.h" #include "tsl/robin_map.h" // block size for reading/processing large files and matrices in blocks @@ -354,8 +354,9 @@ void pq_dist_lookup(const uint8_t *pq_ids, const size_t n_pts, const size_t pq_n // make_zero_mean is false by default. // These assumptions allow to make the function much simpler and avoid storing // array of chunk_offsets and centroids. -// The compiler pragma for multi-threading support is removed from this implementation -// for the purpose of integration into systems that strictly control resource allocation. +// The compiler pragma for multi-threading support is removed from this +// implementation for the purpose of integration into systems that strictly +// control resource allocation. int generate_pq_pivots_simplified(const float *train_data, size_t num_train, size_t dim, size_t num_pq_chunks, std::vector &pivot_data_vector) { @@ -771,18 +772,19 @@ int generate_opq_pivots(const float *passed_train_data, size_t num_train, uint32 return 0; } -// generate_pq_data_from_pivots_simplified is a simplified version of generate_pq_data_from_pivots. -// Input is provided in the in-memory buffers data and pivot_data. -// Output is stored in the in-memory buffer pq. -// Simplification is based on the following assumptions: +// generate_pq_data_from_pivots_simplified is a simplified version of +// generate_pq_data_from_pivots. Input is provided in the in-memory buffers data +// and pivot_data. Output is stored in the in-memory buffer pq. Simplification +// is based on the following assumptions: // supporting only float data type // dim % num_pq_chunks == 0, which results in a fixed chunk_size // num_centers == 256 by default // make_zero_mean is false by default. // These assumptions allow to make the function much simpler and avoid using // array of chunk_offsets and centroids. -// The compiler pragma for multi-threading support is removed from this implementation -// for the purpose of integration into systems that strictly control resource allocation. +// The compiler pragma for multi-threading support is removed from this +// implementation for the purpose of integration into systems that strictly +// control resource allocation. int generate_pq_data_from_pivots_simplified(const float *data, const size_t num, const float *pivot_data, const size_t pivots_num, const size_t dim, const size_t num_pq_chunks, std::vector &pq) diff --git a/src/pq_data_store.cpp b/src/pq_data_store.cpp index 2136c71e2..55ffbd372 100644 --- a/src/pq_data_store.cpp +++ b/src/pq_data_store.cpp @@ -1,10 +1,10 @@ #include -#include "pq_data_store.h" +#include "distance.h" #include "pq.h" +#include "pq_data_store.h" #include "pq_scratch.h" #include "utils.h" -#include "distance.h" namespace diskann { diff --git a/src/pq_flash_index.cpp b/src/pq_flash_index.cpp index b38879c58..e39efbb71 100644 --- a/src/pq_flash_index.cpp +++ b/src/pq_flash_index.cpp @@ -3,11 +3,12 @@ #include "common_includes.h" -#include "timer.h" +#include "cosine_similarity.h" +#include "in_mem_filter_store.h" #include "pq.h" -#include "pq_scratch.h" #include "pq_flash_index.h" -#include "cosine_similarity.h" +#include "pq_scratch.h" +#include "timer.h" #include #ifdef _WINDOWS @@ -38,8 +39,10 @@ PQFlashIndex::PQFlashIndex(std::shared_ptr &fileRe { if (std::is_floating_point::value) { - diskann::cout << "Since data is floating point, we assume that it has been appropriately pre-processed " - "(normalization for cosine, and convert-to-l2 by adding extra dimension for MIPS). So we " + diskann::cout << "Since data is floating point, we assume that it has " + "been appropriately pre-processed " + "(normalization for cosine, and convert-to-l2 by " + "adding extra dimension for MIPS). So we " "shall invoke an l2 distance function." << std::endl; metric_to_invoke = diskann::Metric::L2; @@ -54,6 +57,7 @@ PQFlashIndex::PQFlashIndex(std::shared_ptr &fileRe this->_dist_cmp.reset(diskann::get_distance_function(metric_to_invoke)); this->_dist_cmp_float.reset(diskann::get_distance_function(metric_to_invoke)); + this->_filter_store = std::make_unique>(); } template PQFlashIndex::~PQFlashIndex() @@ -74,6 +78,12 @@ template PQFlashIndex::~PQFlashIndex() diskann::aligned_free(_coord_cache_buf); } + if (_medoids != nullptr) + { + delete[] _medoids; + _medoids = nullptr; + } + if (_load_flag) { diskann::cout << "Clearing scratch" << std::endl; @@ -82,22 +92,6 @@ template PQFlashIndex::~PQFlashIndex() this->reader->deregister_all_threads(); reader->close(); } - if (_pts_to_label_offsets != nullptr) - { - delete[] _pts_to_label_offsets; - } - if (_pts_to_label_counts != nullptr) - { - delete[] _pts_to_label_counts; - } - if (_pts_to_labels != nullptr) - { - delete[] _pts_to_labels; - } - if (_medoids != nullptr) - { - delete[] _medoids; - } } template inline uint64_t PQFlashIndex::get_node_sector(uint64_t node_id) @@ -270,7 +264,8 @@ void PQFlashIndex::generate_cache_list_from_sample_queries(std::strin #endif if (num_nodes_to_cache >= this->_num_points) { - // for small num_points and big num_nodes_to_cache, use below way to get the node_list quickly + // for small num_points and big num_nodes_to_cache, use below way to get + // the node_list quickly node_list.resize(this->_num_points); for (uint32_t i = 0; i < this->_num_points; ++i) { @@ -313,19 +308,21 @@ void PQFlashIndex::generate_cache_list_from_sample_queries(std::strin bool filtered_search = false; std::vector random_query_filters(sample_num); - if (_filter_to_medoid_ids.size() != 0) + if (this->_filter_index) { filtered_search = true; - generate_random_labels(random_query_filters, (uint32_t)sample_num, nthreads); + _filter_store->generate_random_labels(random_query_filters, (uint32_t)sample_num, nthreads); } #pragma omp parallel for schedule(dynamic, 1) num_threads(nthreads) for (int64_t i = 0; i < (int64_t)sample_num; i++) { auto &label_for_search = random_query_filters[i]; - // run a search on the sample query with a random label (sampled from base label distribution), and it will - // concurrently update the node_visit_counter to track most visited nodes. The last false is to not use the - // "use_reorder_data" option which enables a final reranking if the disk index itself contains only PQ data. + // run a search on the sample query with a random label (sampled from base + // label distribution), and it will concurrently update the + // node_visit_counter to track most visited nodes. The last false is to + // not use the "use_reorder_data" option which enables a final reranking + // if the disk index itself contains only PQ data. cached_beam_search(samples + (i * sample_aligned_dim), 1, l_search, tmp_result_ids_64.data() + i, tmp_result_dists.data() + i, beamwidth, filtered_search, label_for_search, false); } @@ -375,9 +372,10 @@ void PQFlashIndex::cache_bfs_levels(uint64_t num_nodes_to_cache, std: cur_level->insert(_medoids[miter]); } - if ((_filter_to_medoid_ids.size() > 0) && (cur_level->size() < num_nodes_to_cache)) + auto filter_to_medoid_ids = _filter_store->get_label_to_medoids(); + if ((filter_to_medoid_ids.size() > 0) && (cur_level->size() < num_nodes_to_cache)) { - for (auto &x : _filter_to_medoid_ids) + for (auto &x : filter_to_medoid_ids) { for (auto &y : x.second) { @@ -534,237 +532,6 @@ template void PQFlashIndex::use_medoids } } -template -void PQFlashIndex::generate_random_labels(std::vector &labels, const uint32_t num_labels, - const uint32_t nthreads) -{ - std::random_device rd; - labels.clear(); - labels.resize(num_labels); - - uint64_t num_total_labels = _pts_to_label_offsets[_num_points - 1] + _pts_to_label_counts[_num_points - 1]; - std::mt19937 gen(rd()); - if (num_total_labels == 0) - { - std::stringstream stream; - stream << "No labels found in data. Not sampling random labels "; - diskann::cerr << stream.str() << std::endl; - throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); - } - std::uniform_int_distribution dis(0, num_total_labels - 1); - -#pragma omp parallel for schedule(dynamic, 1) num_threads(nthreads) - for (int64_t i = 0; i < num_labels; i++) - { - uint64_t rnd_loc = dis(gen); - labels[i] = (LabelT)_pts_to_labels[rnd_loc]; - } -} - -template -std::unordered_map PQFlashIndex::load_label_map(std::basic_istream &map_reader) -{ - std::unordered_map string_to_int_mp; - std::string line, token; - LabelT token_as_num; - std::string label_str; - while (std::getline(map_reader, line)) - { - std::istringstream iss(line); - getline(iss, token, '\t'); - label_str = token; - getline(iss, token, '\t'); - token_as_num = (LabelT)std::stoul(token); - string_to_int_mp[label_str] = token_as_num; - } - return string_to_int_mp; -} - -template -LabelT PQFlashIndex::get_converted_label(const std::string &filter_label) -{ - if (_label_map.find(filter_label) != _label_map.end()) - { - return _label_map[filter_label]; - } - else if (_use_universal_label) - { - return _universal_filter_label; - } - else - { - return std::numeric_limits::max(); - } -} - -template -void PQFlashIndex::reset_stream_for_reading(std::basic_istream &infile) -{ - infile.clear(); - infile.seekg(0); -} - -template -bool PQFlashIndex::is_label_valid(const std::string& filter_label) -{ - if (_label_map.find(filter_label) != _label_map.end()) - { - return true; - } - - return false; -} - -template -void PQFlashIndex::get_label_file_metadata(const std::string &fileContent, uint32_t &num_pts, - uint32_t &num_total_labels) -{ - num_pts = 0; - num_total_labels = 0; - - size_t file_size = fileContent.length(); - - std::string label_str; - size_t cur_pos = 0; - size_t next_pos = 0; - while (cur_pos < file_size && cur_pos != std::string::npos) - { - next_pos = fileContent.find('\n', cur_pos); - if (next_pos == std::string::npos) - { - break; - } - - size_t lbl_pos = cur_pos; - size_t next_lbl_pos = 0; - while (lbl_pos < next_pos && lbl_pos != std::string::npos) - { - next_lbl_pos = search_string_range(fileContent, ',', lbl_pos, next_pos); - if (next_lbl_pos == std::string::npos) // the last label - { - next_lbl_pos = next_pos; - } - - num_total_labels++; - - lbl_pos = next_lbl_pos + 1; - } - - cur_pos = next_pos + 1; - - num_pts++; - } - - diskann::cout << "Labels file metadata: num_points: " << num_pts << ", #total_labels: " << num_total_labels - << std::endl; -} - -template -inline bool PQFlashIndex::point_has_label(uint32_t point_id, LabelT label_id) -{ - uint32_t start_vec = _pts_to_label_offsets[point_id]; - uint32_t num_lbls = _pts_to_label_counts[point_id]; - bool ret_val = false; - for (uint32_t i = 0; i < num_lbls; i++) - { - if (_pts_to_labels[start_vec + i] == label_id) - { - ret_val = true; - break; - } - } - return ret_val; -} - -template -void PQFlashIndex::parse_label_file(std::basic_istream &infile, size_t &num_points_labels) -{ - infile.seekg(0, std::ios::end); - size_t file_size = infile.tellg(); - - std::string buffer(file_size, ' '); - - infile.seekg(0, std::ios::beg); - infile.read(&buffer[0], file_size); - - std::string line; - uint32_t line_cnt = 0; - - uint32_t num_pts_in_label_file; - uint32_t num_total_labels; - get_label_file_metadata(buffer, num_pts_in_label_file, num_total_labels); - - _pts_to_label_offsets = new uint32_t[num_pts_in_label_file]; - _pts_to_label_counts = new uint32_t[num_pts_in_label_file]; - _pts_to_labels = new LabelT[num_total_labels]; - uint32_t labels_seen_so_far = 0; - - std::string label_str; - size_t cur_pos = 0; - size_t next_pos = 0; - while (cur_pos < file_size && cur_pos != std::string::npos) - { - next_pos = buffer.find('\n', cur_pos); - if (next_pos == std::string::npos) - { - break; - } - - _pts_to_label_offsets[line_cnt] = labels_seen_so_far; - uint32_t &num_lbls_in_cur_pt = _pts_to_label_counts[line_cnt]; - num_lbls_in_cur_pt = 0; - - size_t lbl_pos = cur_pos; - size_t next_lbl_pos = 0; - while (lbl_pos < next_pos && lbl_pos != std::string::npos) - { - next_lbl_pos = search_string_range(buffer, ',', lbl_pos, next_pos); - if (next_lbl_pos == std::string::npos) // the last label in the whole file - { - next_lbl_pos = next_pos; - } - - if (next_lbl_pos > next_pos) // the last label in one line, just read to the end - { - next_lbl_pos = next_pos; - } - - label_str.assign(buffer.c_str() + lbl_pos, next_lbl_pos - lbl_pos); - if (label_str[label_str.length() - 1] == '\t') // '\t' won't exist in label file? - { - label_str.erase(label_str.length() - 1); - } - - LabelT token_as_num = (LabelT)std::stoul(label_str); - _pts_to_labels[labels_seen_so_far++] = (LabelT)token_as_num; - num_lbls_in_cur_pt++; - - // move to next label - lbl_pos = next_lbl_pos + 1; - } - - // move to next line - cur_pos = next_pos + 1; - - if (num_lbls_in_cur_pt == 0) - { - diskann::cout << "No label found for point " << line_cnt << std::endl; - exit(-1); - } - - line_cnt++; - } - - num_points_labels = line_cnt; - reset_stream_for_reading(infile); -} - -template void PQFlashIndex::set_universal_label(const LabelT &label) -{ - _use_universal_label = true; - _universal_filter_label = label; -} - #ifdef EXEC_ENV_OLS template int PQFlashIndex::load(MemoryMappedFiles &files, uint32_t num_threads, const char *index_prefix) @@ -813,12 +580,6 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons std::string medoids_file = std::string(_disk_index_file) + "_medoids.bin"; std::string centroids_file = std::string(_disk_index_file) + "_centroids.bin"; - std::string labels_file = (labels_filepath == nullptr ? "" : labels_filepath); - std::string labels_to_medoids = (labels_to_medoids_filepath == nullptr ? "" : labels_to_medoids_filepath); - std::string dummy_map_file = std ::string(_disk_index_file) + "_dummy_map.txt"; - std::string labels_map_file = (labels_map_filepath == nullptr ? "" : labels_map_filepath); - size_t num_pts_in_label_file = 0; - size_t pq_file_dim, pq_file_num_centroids; #ifdef EXEC_ENV_OLS get_bin_metadata(files, pq_table_bin, pq_file_num_centroids, pq_file_dim, METADATA_SIZE); @@ -846,147 +607,30 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons #else diskann::load_bin(pq_compressed_vectors, this->data, npts_u64, nchunks_u64); #endif - this->_num_points = npts_u64; this->_n_chunks = nchunks_u64; -#ifdef EXEC_ENV_OLS - if (files.fileExists(labels_file)) - { - FileContent &content_labels = files.getContent(labels_file); - std::stringstream infile(std::string((const char *)content_labels._content, content_labels._size)); -#else - if (file_exists(labels_file)) - { - std::ifstream infile(labels_file, std::ios::binary); - if (infile.fail()) - { - throw diskann::ANNException(std::string("Failed to open file ") + labels_file, -1); - } -#endif - parse_label_file(infile, num_pts_in_label_file); - assert(num_pts_in_label_file == this->_num_points); - -#ifndef EXEC_ENV_OLS - infile.close(); -#endif - -#ifdef EXEC_ENV_OLS - FileContent &content_labels_map = files.getContent(labels_map_file); - std::stringstream map_reader(std::string((const char *)content_labels_map._content, content_labels_map._size)); -#else - std::ifstream map_reader(labels_map_file); -#endif - _label_map = load_label_map(map_reader); - -#ifndef EXEC_ENV_OLS - map_reader.close(); -#endif -#ifdef EXEC_ENV_OLS - if (files.fileExists(labels_to_medoids)) - { - FileContent &content_labels_to_meoids = files.getContent(labels_to_medoids); - std::stringstream medoid_stream( - std::string((const char *)content_labels_to_meoids._content, content_labels_to_meoids._size)); -#else - if (file_exists(labels_to_medoids)) - { - std::ifstream medoid_stream(labels_to_medoids); - assert(medoid_stream.is_open()); -#endif - std::string line, token; - - _filter_to_medoid_ids.clear(); - try - { - while (std::getline(medoid_stream, line)) - { - std::istringstream iss(line); - uint32_t cnt = 0; - std::vector medoids; - LabelT label; - while (std::getline(iss, token, ',')) - { - if (cnt == 0) - label = (LabelT)std::stoul(token); - else - medoids.push_back((uint32_t)stoul(token)); - cnt++; - } - _filter_to_medoid_ids[label].swap(medoids); - } - } - catch (std::system_error &e) - { - throw FileException(labels_to_medoids, e, __FUNCSIG__, __FILE__, __LINE__); - } - } - std::string univ_label_file = (unv_label_filepath == nullptr ? "" : unv_label_filepath); - -#ifdef EXEC_ENV_OLS - if (files.fileExists(univ_label_file)) - { - FileContent& content_univ_label = files.getContent(univ_label_file); - std::stringstream universal_label_reader( - std::string((const char*)content_univ_label._content, content_univ_label._size)); -#else - if (file_exists(univ_label_file)) + _filter_store = std::make_unique>(); + try + { + _filter_index = _filter_store->load(_disk_index_file); + if (_filter_index) { - std::ifstream universal_label_reader(univ_label_file); - assert(universal_label_reader.is_open()); -#endif - std::string univ_label; - universal_label_reader >> univ_label; -#ifndef EXEC_ENV_OLS - universal_label_reader.close(); -#endif - LabelT label_as_num = (LabelT)std::stoul(univ_label); - set_universal_label(label_as_num); + diskann::cout << "Index has filter support. " << std::endl; } - -#ifdef EXEC_ENV_OLS - if (files.fileExists(dummy_map_file)) - { - FileContent &content_dummy_map = files.getContent(dummy_map_file); - std::stringstream dummy_map_stream( - std::string((const char *)content_dummy_map._content, content_dummy_map._size)); -#else - if (file_exists(dummy_map_file)) + else { - std::ifstream dummy_map_stream(dummy_map_file); - assert(dummy_map_stream.is_open()); -#endif - std::string line, token; - - while (std::getline(dummy_map_stream, line)) - { - std::istringstream iss(line); - uint32_t cnt = 0; - uint32_t dummy_id; - uint32_t real_id; - while (std::getline(iss, token, ',')) - { - if (cnt == 0) - dummy_id = (uint32_t)stoul(token); - else - real_id = (uint32_t)stoul(token); - cnt++; - } - _dummy_pts.insert(dummy_id); - _has_dummy_pts.insert(real_id); - _dummy_to_real_map[dummy_id] = real_id; - - if (_real_to_dummy_map.find(real_id) == _real_to_dummy_map.end()) - _real_to_dummy_map[real_id] = std::vector(); - - _real_to_dummy_map[real_id].emplace_back(dummy_id); - } -#ifndef EXEC_ENV_OLS - dummy_map_stream.close(); -#endif - diskann::cout << "Loaded dummy map" << std::endl; + diskann::cout << "Index does not have filter support." << std::endl; } } + catch (diskann::ANNException &ex) + { + // If filter_store=>load() returns false, it means filters are not + // enabled. If it throws, it means there was an error in processing a + // filter index. + diskann::cerr << "Filter index load failed because: " << ex.what() << std::endl; + return false; + } #ifdef EXEC_ENV_OLS _pq_table.load_pq_centroid_bin(files, pq_table_bin.c_str(), nchunks_u64); @@ -1048,8 +692,8 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons std::ifstream index_metadata(_disk_index_file, std::ios::binary); #endif - uint32_t nr, nc; // metadata itself is stored as bin format (nr is number of - // metadata, nc should be 1) + uint32_t nr, nc; // metadata itself is stored as bin format (nr is number + // of metadata, nc should be 1) READ_U32(index_metadata, nr); READ_U32(index_metadata, nc); @@ -1293,7 +937,6 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t const uint32_t io_limit, const bool use_reorder_data, QueryStats *stats) { - uint64_t num_sector_per_nodes = DIV_ROUND_UP(_max_node_len, defaults::SECTOR_LEN); if (beam_width > num_sector_per_nodes * defaults::MAX_N_SECTOR_READS) throw ANNException("Beamwidth can not be higher than defaults::MAX_N_SECTOR_READS", -1, __FUNCSIG__, __FILE__, @@ -1316,8 +959,8 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t float *query_rotated = pq_query_scratch->rotated_query; // normalization step. for cosine, we simply normalize the query - // for mips, we normalize the first d-1 dims, and add a 0 for last dim, since an extra coordinate was used to - // convert MIPS to L2 search + // for mips, we normalize the first d-1 dims, and add a 0 for last dim, + // since an extra coordinate was used to convert MIPS to L2 search if (metric == diskann::Metric::INNER_PRODUCT || metric == diskann::Metric::COSINE) { uint64_t inherent_dim = (metric == diskann::Metric::COSINE) ? this->_data_dim : (uint64_t)(this->_data_dim - 1); @@ -1357,8 +1000,8 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t _nnodes_per_sector > 0 ? 1 : DIV_ROUND_UP(_max_node_len, defaults::SECTOR_LEN); // query <-> PQ chunk centers distances - _pq_table.preprocess_query(query_rotated); // center the query and rotate if - // we have a rotation matrix + _pq_table.preprocess_query(query_rotated); // center the query and rotate + // if we have a rotation matrix float *pq_dists = pq_query_scratch->aligned_pqtable_dist_scratch; _pq_table.populate_chunk_distances(query_rotated, pq_dists); @@ -1396,13 +1039,18 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t } else { - if (_filter_to_medoid_ids.find(filter_label) != _filter_to_medoid_ids.end()) + const auto &medoid_ids = _filter_store->get_medoids_of_label(filter_label); + if (medoid_ids.size() > 0) + // if (_filter_to_medoid_ids.find(filter_label) != + // _filter_to_medoid_ids.end()) { - const auto &medoid_ids = _filter_to_medoid_ids[filter_label]; + // const auto &medoid_ids = _filter_to_medoid_ids[filter_label]; + for (uint64_t cur_m = 0; cur_m < medoid_ids.size(); cur_m++) { - // for filtered index, we dont store global centroid data as for unfiltered index, so we use PQ distance - // as approximation to decide closest medoid matching the query filter. + // for filtered index, we dont store global centroid data as for + // unfiltered index, so we use PQ distance as approximation to decide + // closest medoid matching the query filter. compute_dists(&medoid_ids[cur_m], 1, dist_scratch); float cur_expanded_dist = dist_scratch[0]; if (cur_expanded_dist < best_dist) @@ -1542,11 +1190,15 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t uint32_t id = node_nbrs[m]; if (visited.insert(id).second) { - if (!use_filter && _dummy_pts.find(id) != _dummy_pts.end()) + // if (!use_filter && _dummy_pts.find(id) != _dummy_pts.end()) + // unfiltered search, but filtered index! + if (!use_filter && _filter_store->is_dummy_point(id)) continue; - if (use_filter && !(point_has_label(id, filter_label)) && - (!_use_universal_label || !point_has_label(id, _universal_filter_label))) + // if (use_filter && !(point_has_label(id, filter_label)) && + // (!_use_universal_label || !point_has_label(id, + // _universal_filter_label))) + if (use_filter && !_filter_store->point_has_label_or_universal_label(id, filter_label)) continue; cmps++; float dist = dist_scratch[m]; @@ -1605,11 +1257,14 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t uint32_t id = node_nbrs[m]; if (visited.insert(id).second) { - if (!use_filter && _dummy_pts.find(id) != _dummy_pts.end()) + // if (!use_filter && _dummy_pts.find(id) != _dummy_pts.end()) + if (!use_filter && _filter_store->is_dummy_point(id)) continue; - if (use_filter && !(point_has_label(id, filter_label)) && - (!_use_universal_label || !point_has_label(id, _universal_filter_label))) + // if (use_filter && !(point_has_label(id, filter_label)) && + // (!_use_universal_label || !point_has_label(id, + // _universal_filter_label))) + if (use_filter && !_filter_store->point_has_label_or_universal_label(id, filter_label)) continue; cmps++; float dist = dist_scratch[m]; @@ -1692,9 +1347,9 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t { indices[i] = full_retset[i].id; auto key = (uint32_t)indices[i]; - if (_dummy_pts.find(key) != _dummy_pts.end()) + if (_filter_store->is_dummy_point(key)) { - indices[i] = _dummy_to_real_map[key]; + indices[i] = _filter_store->get_real_point_for_dummy(key); } if (distances != nullptr) @@ -1777,17 +1432,9 @@ template diskann::Metric PQFlashIndex:: } template -size_t PQFlashIndex::search_string_range(const std::string& str, char ch, size_t start, size_t end) +LabelT PQFlashIndex::get_converted_label(const std::string &filter_label) { - for (; start != end; start++) - { - if (str[start] == ch) - { - return start; - } - } - - return std::string::npos; + return _filter_store->get_converted_label(filter_label); } #ifdef EXEC_ENV_OLS diff --git a/src/pq_l2_distance.cpp b/src/pq_l2_distance.cpp index c08744c35..9168d26be 100644 --- a/src/pq_l2_distance.cpp +++ b/src/pq_l2_distance.cpp @@ -1,6 +1,6 @@ -#include "pq.h" #include "pq_l2_distance.h" +#include "pq.h" #include "pq_scratch.h" // block size for reading/processing large files and matrices in blocks diff --git a/src/scratch.cpp b/src/scratch.cpp index 8b8427453..5d34d7e1c 100644 --- a/src/scratch.cpp +++ b/src/scratch.cpp @@ -1,11 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include #include +#include -#include "scratch.h" #include "pq_scratch.h" +#include "scratch.h" namespace diskann { diff --git a/src/windows_aligned_file_reader.cpp b/src/windows_aligned_file_reader.cpp index 9e6eef418..266e5b227 100644 --- a/src/windows_aligned_file_reader.cpp +++ b/src/windows_aligned_file_reader.cpp @@ -4,8 +4,8 @@ #ifdef _WINDOWS #ifndef USE_BING_INFRA #include "windows_aligned_file_reader.h" -#include #include "utils.h" +#include #include #define SECTOR_LEN 4096