From b8be3b853ca5dcbf0985b33b412fe691c512c9a8 Mon Sep 17 00:00:00 2001 From: Sanhaoji2 Date: Fri, 19 Jan 2024 17:14:54 +0800 Subject: [PATCH 01/13] add 16 bytes tag type --- include/natural_number_map.h | 2 +- include/tag_uint128.h | 70 ++++++++++++++++++++++++++ include/utils.h | 13 +++++ src/index.cpp | 96 +++++++++++++++++++----------------- src/natural_number_map.cpp | 3 ++ 5 files changed, 139 insertions(+), 45 deletions(-) create mode 100644 include/tag_uint128.h diff --git a/include/natural_number_map.h b/include/natural_number_map.h index 820ac3fdf..7555e317d 100644 --- a/include/natural_number_map.h +++ b/include/natural_number_map.h @@ -28,7 +28,7 @@ template class natural_number_map static_assert(std::is_trivial::value, "Key must be a trivial type"); // Some of the class member prototypes are done with this assumption to // minimize verbosity since it's the only use case. - static_assert(std::is_trivial::value, "Value must be a trivial type"); +// static_assert(std::is_trivial::value, "Value must be a trivial type"); // Represents a reference to a element in the map. Used while iterating // over map entries. diff --git a/include/tag_uint128.h b/include/tag_uint128.h new file mode 100644 index 000000000..258b73c14 --- /dev/null +++ b/include/tag_uint128.h @@ -0,0 +1,70 @@ +#pragma once +#include +#include + +namespace diskann +{ +#pragma pack(push, 1) + +struct tag_uint128 +{ + std::uint64_t _data1 = 0; + std::uint64_t _data2 = 0; + + bool operator==(const tag_uint128& other) const + { + return _data1 == other._data1 + && _data2 == other._data2; + } + + bool operator==(std::uint64_t other) const + { + return _data1 == other + && _data2 == 0; + } + + tag_uint128& operator=(const tag_uint128& other) + { + _data1 = other._data1; + _data2 = other._data2; + + return *this; + } + + tag_uint128& operator=(std::uint64_t other) + { + _data1 = other; + _data2 = 0; + + return *this; + } +}; + +#pragma pack(pop) +} + +namespace std +{ +// Hash 128 input bits down to 64 bits of output. +// This is intended to be a reasonably good hash function. +inline std::uint64_t Hash128to64(const std::uint64_t& low, const std::uint64_t& high) { + // Murmur-inspired hashing. + const std::uint64_t kMul = 0x9ddfea08eb382d69ULL; + std::uint64_t a = (low ^ high) * kMul; + a ^= (a >> 47); + std::uint64_t b = (high ^ a) * kMul; + b ^= (b >> 47); + b *= kMul; + return b; +} + +template<> +struct hash +{ + _NODISCARD size_t operator()(const diskann::tag_uint128& key) const noexcept + { + return Hash128to64(key._data1, key._data2); // map -0 to 0 + } +}; + +} \ No newline at end of file diff --git a/include/utils.h b/include/utils.h index bb03d13f1..67cca9644 100644 --- a/include/utils.h +++ b/include/utils.h @@ -27,6 +27,7 @@ typedef int FileHandle; #include "windows_customizations.h" #include "tsl/robin_set.h" #include "types.h" +#include "tag_uint128.h" #include #ifdef EXEC_ENV_OLS @@ -1007,6 +1008,18 @@ void block_convert(std::ofstream &writr, std::ifstream &readr, float *read_buf, DISKANN_DLLEXPORT void normalize_data_file(const std::string &inFileName, const std::string &outFileName); + +inline std::string get_tag_string(std::uint64_t tag) +{ + return std::to_string(tag); +} + +inline std::string get_tag_string(tag_uint128 tag) +{ + std::string str = std::to_string(tag._data2) + "_" + std::to_string(tag._data1); + return str; +} + }; // namespace diskann struct PivotContainer diff --git a/src/index.cpp b/src/index.cpp index d906600d1..4a109f484 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -12,6 +12,7 @@ #include "tsl/robin_map.h" #include "tsl/robin_set.h" #include "windows_customizations.h" +#include "tag_uint128.h" #if defined(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS) && defined(DISKANN_BUILD) #include "gperftools/malloc_extension.h" #endif @@ -717,7 +718,7 @@ template int Index std::shared_lock lock(_tag_lock); if (_tag_to_location.find(tag) == _tag_to_location.end()) { - diskann::cout << "Tag " << tag << " does not exist" << std::endl; + diskann::cout << "Tag " << get_tag_string(tag) << " does not exist" << std::endl; return -1; } @@ -1991,7 +1992,7 @@ std::pair Index::search(const T *query, con { // safe because Index uses uint32_t ids internally // and IDType will be uint32_t or uint64_t - indices[pos] = (IdType)best_L_nodes[i].id; + indices[pos] = best_L_nodes[i].id; if (distances != nullptr) { #ifdef EXEC_ENV_OLS @@ -2022,14 +2023,10 @@ std::pair Index::_search_with_filters(const float *distances) { auto converted_label = this->get_converted_label(raw_label); - if (typeid(uint64_t *) == indices.type()) + + if (typeid(TagT*) == indices.type()) { - auto ptr = std::any_cast(indices); - return this->search_with_filters(std::any_cast(query), converted_label, K, L, ptr, distances); - } - else if (typeid(uint32_t *) == indices.type()) - { - auto ptr = std::any_cast(indices); + auto ptr = std::any_cast(indices); return this->search_with_filters(std::any_cast(query), converted_label, K, L, ptr, distances); } else @@ -2100,7 +2097,7 @@ std::pair Index::search_with_filters(const TagT tag; if (_location_to_tag.try_get(best_L_nodes[i].id, tag)) { - indices[pos] = (IdType)tag; + indices[pos] = tag; } else { @@ -2109,7 +2106,7 @@ std::pair Index::search_with_filters(const } else { - indices[pos] = (IdType)best_L_nodes[i].id; + indices[pos] = best_L_nodes[i].id; } if (distances != nullptr) @@ -2861,7 +2858,7 @@ int Index::insert_point(const T *point, const TagT tag, const s { assert(_has_built); - if (tag == static_cast(0)) + if (tag == 0) { throw diskann::ANNException("Do not insert point with tag 0. That is " "reserved for points hidden " @@ -2879,7 +2876,7 @@ int Index::insert_point(const T *point, const TagT tag, const s if (labels.empty()) { release_location(location); - std::cerr << "Error: Can't insert point with tag " + std::to_string(tag) + + std::cerr << "Error: Can't insert point with tag " + get_tag_string(tag) + " . there are no labels for the point." << std::endl; return -1; @@ -3047,7 +3044,7 @@ template int Index if (_tag_to_location.find(tag) == _tag_to_location.end()) { - diskann::cerr << "Delete tag not found " << tag << std::endl; + diskann::cerr << "Delete tag not found " << get_tag_string(tag) << std::endl; return -1; } assert(_tag_to_location[tag] < _max_points); @@ -3336,6 +3333,9 @@ template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; +template DISKANN_DLLEXPORT class Index; +template DISKANN_DLLEXPORT class Index; +template DISKANN_DLLEXPORT class Index; // Label with short int 2 byte template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; @@ -3349,19 +3349,16 @@ template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; +template DISKANN_DLLEXPORT class Index; +template DISKANN_DLLEXPORT class Index; +template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT std::pair Index::search( const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); -template DISKANN_DLLEXPORT std::pair Index::search( - const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); -template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); -template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); // TagT==uint32_t template DISKANN_DLLEXPORT std::pair Index::search( const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); @@ -3375,25 +3372,23 @@ template DISKANN_DLLEXPORT std::pair Index Index::search( const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); +// TagT==uint128 +template DISKANN_DLLEXPORT std::pair Index::search( + const float* query, const size_t K, const uint32_t L, tag_uint128* indices, float* distances); +template DISKANN_DLLEXPORT std::pair Index::search( + const uint8_t* query, const size_t K, const uint32_t L, tag_uint128* indices, float* distances); +template DISKANN_DLLEXPORT std::pair Index::search( + const int8_t* query, const size_t K, const uint32_t L, tag_uint128* indices, float* distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, float *distances); -template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, - float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, float *distances); -template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, - float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, float *distances); -template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, - float *distances); // TagT==uint32_t template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, @@ -3413,19 +3408,23 @@ template DISKANN_DLLEXPORT std::pair Index Index::search_with_filters< uint32_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, float *distances); +// TagT==uint128 +template DISKANN_DLLEXPORT std::pair Index::search_with_filters< + tag_uint128>(const float* query, const uint32_t& filter_label, const size_t K, const uint32_t L, tag_uint128* indices, + float* distances); +template DISKANN_DLLEXPORT std::pair Index::search_with_filters< + tag_uint128>(const uint8_t* query, const uint32_t& filter_label, const size_t K, const uint32_t L, tag_uint128* indices, + float* distances); +template DISKANN_DLLEXPORT std::pair Index::search_with_filters< + tag_uint128>(const int8_t* query, const uint32_t& filter_label, const size_t K, const uint32_t L, tag_uint128* indices, + float* distances); template DISKANN_DLLEXPORT std::pair Index::search( const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); -template DISKANN_DLLEXPORT std::pair Index::search( - const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); -template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); -template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); // TagT==uint32_t template DISKANN_DLLEXPORT std::pair Index::search( const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); @@ -3439,25 +3438,23 @@ template DISKANN_DLLEXPORT std::pair Index Index::search( const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); +// TagT==uint128 +template DISKANN_DLLEXPORT std::pair Index::search( + const float* query, const size_t K, const uint32_t L, tag_uint128* indices, float* distances); +template DISKANN_DLLEXPORT std::pair Index::search( + const uint8_t* query, const size_t K, const uint32_t L, tag_uint128* indices, float* distances); +template DISKANN_DLLEXPORT std::pair Index::search( + const int8_t* query, const size_t K, const uint32_t L, tag_uint128* indices, float* distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, float *distances); -template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, - float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, float *distances); -template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, - float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, float *distances); -template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, - float *distances); // TagT==uint32_t template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, @@ -3477,4 +3474,15 @@ template DISKANN_DLLEXPORT std::pair Index Index::search_with_filters< uint32_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, float *distances); +// TagT==uint128 +template DISKANN_DLLEXPORT std::pair Index::search_with_filters< + tag_uint128>(const float* query, const uint16_t& filter_label, const size_t K, const uint32_t L, tag_uint128* indices, + float* distances); +template DISKANN_DLLEXPORT std::pair Index::search_with_filters< + tag_uint128>(const uint8_t* query, const uint16_t& filter_label, const size_t K, const uint32_t L, tag_uint128* indices, + float* distances); +template DISKANN_DLLEXPORT std::pair Index::search_with_filters< + tag_uint128>(const int8_t* query, const uint16_t& filter_label, const size_t K, const uint32_t L, tag_uint128* indices, + float* distances); + } // namespace diskann diff --git a/src/natural_number_map.cpp b/src/natural_number_map.cpp index 9050831a2..347034481 100644 --- a/src/natural_number_map.cpp +++ b/src/natural_number_map.cpp @@ -5,6 +5,7 @@ #include #include "natural_number_map.h" +#include "tag_uint128.h" namespace diskann { @@ -111,4 +112,6 @@ template class natural_number_map; template class natural_number_map; template class natural_number_map; template class natural_number_map; +template class natural_number_map; +template class natural_number_map; } // namespace diskann From d09c9cb5d1d6ca2f22c694835abc50717a27f817 Mon Sep 17 00:00:00 2001 From: Sanhaoji2 Date: Fri, 19 Jan 2024 17:23:38 +0800 Subject: [PATCH 02/13] clean up code --- include/natural_number_map.h | 3 --- 1 file changed, 3 deletions(-) diff --git a/include/natural_number_map.h b/include/natural_number_map.h index 7555e317d..e846882a8 100644 --- a/include/natural_number_map.h +++ b/include/natural_number_map.h @@ -26,9 +26,6 @@ template class natural_number_map { public: static_assert(std::is_trivial::value, "Key must be a trivial type"); - // Some of the class member prototypes are done with this assumption to - // minimize verbosity since it's the only use case. -// static_assert(std::is_trivial::value, "Value must be a trivial type"); // Represents a reference to a element in the map. Used while iterating // over map entries. From ec6a1a1d2d6d826b15313b2568340bc74048c265 Mon Sep 17 00:00:00 2001 From: Sanhaoji2 Date: Fri, 19 Jan 2024 17:27:29 +0800 Subject: [PATCH 03/13] format doc --- include/tag_uint128.h | 26 ++++++++++++-------------- include/utils.h | 1 - src/index.cpp | 42 +++++++++++++++++++++--------------------- 3 files changed, 33 insertions(+), 36 deletions(-) diff --git a/include/tag_uint128.h b/include/tag_uint128.h index 258b73c14..481f25f15 100644 --- a/include/tag_uint128.h +++ b/include/tag_uint128.h @@ -10,20 +10,18 @@ struct tag_uint128 { std::uint64_t _data1 = 0; std::uint64_t _data2 = 0; - - bool operator==(const tag_uint128& other) const + + bool operator==(const tag_uint128 &other) const { - return _data1 == other._data1 - && _data2 == other._data2; + return _data1 == other._data1 && _data2 == other._data2; } bool operator==(std::uint64_t other) const { - return _data1 == other - && _data2 == 0; + return _data1 == other && _data2 == 0; } - tag_uint128& operator=(const tag_uint128& other) + tag_uint128 &operator=(const tag_uint128 &other) { _data1 = other._data1; _data2 = other._data2; @@ -31,7 +29,7 @@ struct tag_uint128 return *this; } - tag_uint128& operator=(std::uint64_t other) + tag_uint128 &operator=(std::uint64_t other) { _data1 = other; _data2 = 0; @@ -41,13 +39,14 @@ struct tag_uint128 }; #pragma pack(pop) -} +} // namespace diskann namespace std { // Hash 128 input bits down to 64 bits of output. // This is intended to be a reasonably good hash function. -inline std::uint64_t Hash128to64(const std::uint64_t& low, const std::uint64_t& high) { +inline std::uint64_t Hash128to64(const std::uint64_t &low, const std::uint64_t &high) +{ // Murmur-inspired hashing. const std::uint64_t kMul = 0x9ddfea08eb382d69ULL; std::uint64_t a = (low ^ high) * kMul; @@ -58,13 +57,12 @@ inline std::uint64_t Hash128to64(const std::uint64_t& low, const std::uint64_t& return b; } -template<> -struct hash +template <> struct hash { - _NODISCARD size_t operator()(const diskann::tag_uint128& key) const noexcept + _NODISCARD size_t operator()(const diskann::tag_uint128 &key) const noexcept { return Hash128to64(key._data1, key._data2); // map -0 to 0 } }; -} \ No newline at end of file +} // namespace std \ No newline at end of file diff --git a/include/utils.h b/include/utils.h index 67cca9644..28ef47ddb 100644 --- a/include/utils.h +++ b/include/utils.h @@ -1008,7 +1008,6 @@ void block_convert(std::ofstream &writr, std::ifstream &readr, float *read_buf, DISKANN_DLLEXPORT void normalize_data_file(const std::string &inFileName, const std::string &outFileName); - inline std::string get_tag_string(std::uint64_t tag) { return std::to_string(tag); diff --git a/src/index.cpp b/src/index.cpp index 4a109f484..9dc1093b4 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2023,10 +2023,10 @@ std::pair Index::_search_with_filters(const float *distances) { auto converted_label = this->get_converted_label(raw_label); - - if (typeid(TagT*) == indices.type()) + + if (typeid(TagT *) == indices.type()) { - auto ptr = std::any_cast(indices); + auto ptr = std::any_cast(indices); return this->search_with_filters(std::any_cast(query), converted_label, K, L, ptr, distances); } else @@ -3374,11 +3374,11 @@ template DISKANN_DLLEXPORT std::pair Index Index::search( - const float* query, const size_t K, const uint32_t L, tag_uint128* indices, float* distances); + const float *query, const size_t K, const uint32_t L, tag_uint128 *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t* query, const size_t K, const uint32_t L, tag_uint128* indices, float* distances); + const uint8_t *query, const size_t K, const uint32_t L, tag_uint128 *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t* query, const size_t K, const uint32_t L, tag_uint128* indices, float* distances); + const int8_t *query, const size_t K, const uint32_t L, tag_uint128 *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, @@ -3410,14 +3410,14 @@ template DISKANN_DLLEXPORT std::pair Index Index::search_with_filters< - tag_uint128>(const float* query, const uint32_t& filter_label, const size_t K, const uint32_t L, tag_uint128* indices, - float* distances); + tag_uint128>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, + tag_uint128 *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - tag_uint128>(const uint8_t* query, const uint32_t& filter_label, const size_t K, const uint32_t L, tag_uint128* indices, - float* distances); + tag_uint128>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, + tag_uint128 *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - tag_uint128>(const int8_t* query, const uint32_t& filter_label, const size_t K, const uint32_t L, tag_uint128* indices, - float* distances); + tag_uint128>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, + tag_uint128 *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); @@ -3440,11 +3440,11 @@ template DISKANN_DLLEXPORT std::pair Index Index::search( - const float* query, const size_t K, const uint32_t L, tag_uint128* indices, float* distances); + const float *query, const size_t K, const uint32_t L, tag_uint128 *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t* query, const size_t K, const uint32_t L, tag_uint128* indices, float* distances); + const uint8_t *query, const size_t K, const uint32_t L, tag_uint128 *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t* query, const size_t K, const uint32_t L, tag_uint128* indices, float* distances); + const int8_t *query, const size_t K, const uint32_t L, tag_uint128 *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, @@ -3476,13 +3476,13 @@ template DISKANN_DLLEXPORT std::pair Index Index::search_with_filters< - tag_uint128>(const float* query, const uint16_t& filter_label, const size_t K, const uint32_t L, tag_uint128* indices, - float* distances); + tag_uint128>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, + tag_uint128 *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - tag_uint128>(const uint8_t* query, const uint16_t& filter_label, const size_t K, const uint32_t L, tag_uint128* indices, - float* distances); + tag_uint128>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, + tag_uint128 *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - tag_uint128>(const int8_t* query, const uint16_t& filter_label, const size_t K, const uint32_t L, tag_uint128* indices, - float* distances); + tag_uint128>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, + tag_uint128 *indices, float *distances); } // namespace diskann From 59d84d2e9d9d4fdeb64980aae8df734fdc9a9129 Mon Sep 17 00:00:00 2001 From: Sanhaoji2 Date: Fri, 19 Jan 2024 17:36:31 +0800 Subject: [PATCH 04/13] fix compile issue --- include/tag_uint128.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tag_uint128.h b/include/tag_uint128.h index 481f25f15..642de3159 100644 --- a/include/tag_uint128.h +++ b/include/tag_uint128.h @@ -59,7 +59,7 @@ inline std::uint64_t Hash128to64(const std::uint64_t &low, const std::uint64_t & template <> struct hash { - _NODISCARD size_t operator()(const diskann::tag_uint128 &key) const noexcept + size_t operator()(const diskann::tag_uint128 &key) const noexcept { return Hash128to64(key._data1, key._data2); // map -0 to 0 } From 65bd369695cbcc65b281e1ba2553647de3bd1f0b Mon Sep 17 00:00:00 2001 From: Sanhaoji2 Date: Fri, 19 Jan 2024 17:42:11 +0800 Subject: [PATCH 05/13] fix compile issue --- src/natural_number_map.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/natural_number_map.cpp b/src/natural_number_map.cpp index 347034481..a996dcf75 100644 --- a/src/natural_number_map.cpp +++ b/src/natural_number_map.cpp @@ -112,6 +112,5 @@ template class natural_number_map; template class natural_number_map; template class natural_number_map; template class natural_number_map; -template class natural_number_map; template class natural_number_map; } // namespace diskann From 3011fc64f85f1fc0a8296066bfa89657780db176 Mon Sep 17 00:00:00 2001 From: Sanhaoji2 Date: Tue, 23 Jan 2024 21:34:46 +0800 Subject: [PATCH 06/13] revert change --- src/index.cpp | 68 +++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 49 insertions(+), 19 deletions(-) diff --git a/src/index.cpp b/src/index.cpp index 9dc1093b4..75efd90a2 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2097,7 +2097,7 @@ std::pair Index::search_with_filters(const TagT tag; if (_location_to_tag.try_get(best_L_nodes[i].id, tag)) { - indices[pos] = tag; + indices[pos] = (IdType)tag; } else { @@ -3354,11 +3354,17 @@ template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT std::pair Index::search( - const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); + const float* query, const size_t K, const uint32_t L, uint64_t* indices, float* distances); +template DISKANN_DLLEXPORT std::pair Index::search( + const float* query, const size_t K, const uint32_t L, uint32_t* indices, float* distances); template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); + const uint8_t* query, const size_t K, const uint32_t L, uint64_t* indices, float* distances); +template DISKANN_DLLEXPORT std::pair Index::search( + const uint8_t* query, const size_t K, const uint32_t L, uint32_t* indices, float* distances); template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); + const int8_t* query, const size_t K, const uint32_t L, uint64_t* indices, float* distances); +template DISKANN_DLLEXPORT std::pair Index::search( + const int8_t* query, const size_t K, const uint32_t L, uint32_t* indices, float* distances); // TagT==uint32_t template DISKANN_DLLEXPORT std::pair Index::search( const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); @@ -3381,14 +3387,23 @@ template DISKANN_DLLEXPORT std::pair Index Index::search_with_filters< - uint64_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, - float *distances); + uint64_t>(const float* query, const uint32_t& filter_label, const size_t K, const uint32_t L, uint64_t* indices, + float* distances); +template DISKANN_DLLEXPORT std::pair Index::search_with_filters< + uint32_t>(const float* query, const uint32_t& filter_label, const size_t K, const uint32_t L, uint32_t* indices, + float* distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint64_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, - float *distances); + uint64_t>(const uint8_t* query, const uint32_t& filter_label, const size_t K, const uint32_t L, uint64_t* indices, + float* distances); +template DISKANN_DLLEXPORT std::pair Index::search_with_filters< + uint32_t>(const uint8_t* query, const uint32_t& filter_label, const size_t K, const uint32_t L, uint32_t* indices, + float* distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint64_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, - float *distances); + uint64_t>(const int8_t* query, const uint32_t& filter_label, const size_t K, const uint32_t L, uint64_t* indices, + float* distances); +template DISKANN_DLLEXPORT std::pair Index::search_with_filters< + uint32_t>(const int8_t* query, const uint32_t& filter_label, const size_t K, const uint32_t L, uint32_t* indices, + float* distances); // TagT==uint32_t template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, @@ -3420,11 +3435,17 @@ template DISKANN_DLLEXPORT std::pair Index Index::search( - const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); + const float* query, const size_t K, const uint32_t L, uint64_t* indices, float* distances); +template DISKANN_DLLEXPORT std::pair Index::search( + const float* query, const size_t K, const uint32_t L, uint32_t* indices, float* distances); template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); + const uint8_t* query, const size_t K, const uint32_t L, uint64_t* indices, float* distances); +template DISKANN_DLLEXPORT std::pair Index::search( + const uint8_t* query, const size_t K, const uint32_t L, uint32_t* indices, float* distances); template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); + const int8_t* query, const size_t K, const uint32_t L, uint64_t* indices, float* distances); +template DISKANN_DLLEXPORT std::pair Index::search( + const int8_t* query, const size_t K, const uint32_t L, uint32_t* indices, float* distances); // TagT==uint32_t template DISKANN_DLLEXPORT std::pair Index::search( const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); @@ -3447,14 +3468,23 @@ template DISKANN_DLLEXPORT std::pair Index Index::search_with_filters< - uint64_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, - float *distances); + uint64_t>(const float* query, const uint16_t& filter_label, const size_t K, const uint32_t L, uint64_t* indices, + float* distances); +template DISKANN_DLLEXPORT std::pair Index::search_with_filters< + uint32_t>(const float* query, const uint16_t& filter_label, const size_t K, const uint32_t L, uint32_t* indices, + float* distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint64_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, - float *distances); + uint64_t>(const uint8_t* query, const uint16_t& filter_label, const size_t K, const uint32_t L, uint64_t* indices, + float* distances); +template DISKANN_DLLEXPORT std::pair Index::search_with_filters< + uint32_t>(const uint8_t* query, const uint16_t& filter_label, const size_t K, const uint32_t L, uint32_t* indices, + float* distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint64_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, - float *distances); + uint64_t>(const int8_t* query, const uint16_t& filter_label, const size_t K, const uint32_t L, uint64_t* indices, + float* distances); +template DISKANN_DLLEXPORT std::pair Index::search_with_filters< + uint32_t>(const int8_t* query, const uint16_t& filter_label, const size_t K, const uint32_t L, uint32_t* indices, + float* distances); // TagT==uint32_t template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, From 4108870f6d1b25c79e6aac57d2d94b48bbbbfee9 Mon Sep 17 00:00:00 2001 From: Sanhaoji2 Date: Tue, 23 Jan 2024 21:36:06 +0800 Subject: [PATCH 07/13] format doc --- src/index.cpp | 72 +++++++++++++++++++++++++-------------------------- 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/src/index.cpp b/src/index.cpp index 75efd90a2..b96e67fd8 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -3354,17 +3354,17 @@ template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT std::pair Index::search( - const float* query, const size_t K, const uint32_t L, uint64_t* indices, float* distances); + const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( - const float* query, const size_t K, const uint32_t L, uint32_t* indices, float* distances); + const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t* query, const size_t K, const uint32_t L, uint64_t* indices, float* distances); + const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t* query, const size_t K, const uint32_t L, uint32_t* indices, float* distances); + const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t* query, const size_t K, const uint32_t L, uint64_t* indices, float* distances); + const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t* query, const size_t K, const uint32_t L, uint32_t* indices, float* distances); + const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); // TagT==uint32_t template DISKANN_DLLEXPORT std::pair Index::search( const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); @@ -3387,23 +3387,23 @@ template DISKANN_DLLEXPORT std::pair Index Index::search_with_filters< - uint64_t>(const float* query, const uint32_t& filter_label, const size_t K, const uint32_t L, uint64_t* indices, - float* distances); + uint64_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, + float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const float* query, const uint32_t& filter_label, const size_t K, const uint32_t L, uint32_t* indices, - float* distances); + uint32_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, + float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint64_t>(const uint8_t* query, const uint32_t& filter_label, const size_t K, const uint32_t L, uint64_t* indices, - float* distances); + uint64_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, + float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const uint8_t* query, const uint32_t& filter_label, const size_t K, const uint32_t L, uint32_t* indices, - float* distances); + uint32_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, + float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint64_t>(const int8_t* query, const uint32_t& filter_label, const size_t K, const uint32_t L, uint64_t* indices, - float* distances); + uint64_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, + float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const int8_t* query, const uint32_t& filter_label, const size_t K, const uint32_t L, uint32_t* indices, - float* distances); + uint32_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, + float *distances); // TagT==uint32_t template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, @@ -3435,17 +3435,17 @@ template DISKANN_DLLEXPORT std::pair Index Index::search( - const float* query, const size_t K, const uint32_t L, uint64_t* indices, float* distances); + const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( - const float* query, const size_t K, const uint32_t L, uint32_t* indices, float* distances); + const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t* query, const size_t K, const uint32_t L, uint64_t* indices, float* distances); + const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t* query, const size_t K, const uint32_t L, uint32_t* indices, float* distances); + const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t* query, const size_t K, const uint32_t L, uint64_t* indices, float* distances); + const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t* query, const size_t K, const uint32_t L, uint32_t* indices, float* distances); + const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); // TagT==uint32_t template DISKANN_DLLEXPORT std::pair Index::search( const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); @@ -3468,23 +3468,23 @@ template DISKANN_DLLEXPORT std::pair Index Index::search_with_filters< - uint64_t>(const float* query, const uint16_t& filter_label, const size_t K, const uint32_t L, uint64_t* indices, - float* distances); + uint64_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, + float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const float* query, const uint16_t& filter_label, const size_t K, const uint32_t L, uint32_t* indices, - float* distances); + uint32_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, + float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint64_t>(const uint8_t* query, const uint16_t& filter_label, const size_t K, const uint32_t L, uint64_t* indices, - float* distances); + uint64_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, + float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const uint8_t* query, const uint16_t& filter_label, const size_t K, const uint32_t L, uint32_t* indices, - float* distances); + uint32_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, + float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint64_t>(const int8_t* query, const uint16_t& filter_label, const size_t K, const uint32_t L, uint64_t* indices, - float* distances); + uint64_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, + float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const int8_t* query, const uint16_t& filter_label, const size_t K, const uint32_t L, uint32_t* indices, - float* distances); + uint32_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, + float *distances); // TagT==uint32_t template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, From 237f2e4ae8031e21a14ab4409555ce6844b091e4 Mon Sep 17 00:00:00 2001 From: Sanhaoji2 Date: Mon, 29 Jan 2024 18:24:14 +0800 Subject: [PATCH 08/13] reparate static search and streaming search --- include/abstract_index.h | 6 ++-- include/index.h | 6 ++-- src/abstract_index.cpp | 77 ++++++++++++++++++---------------------- src/index.cpp | 71 +++++++++++++----------------------- 4 files changed, 68 insertions(+), 92 deletions(-) diff --git a/include/abstract_index.h b/include/abstract_index.h index 12feec663..4bc15b9ec 100644 --- a/include/abstract_index.h +++ b/include/abstract_index.h @@ -62,7 +62,8 @@ class AbstractIndex // Initialize space for res_vectors before calling. template size_t search_with_tags(const data_type *query, const uint64_t K, const uint32_t L, tag_type *tags, - float *distances, std::vector &res_vectors); + float *distances, std::vector &res_vectors, bool use_filters = false, + const std::string filter_label = ""); // Added search overload that takes L as parameter, so that we // can customize L on a per-query basis without tampering with "Parameters" @@ -120,7 +121,8 @@ class AbstractIndex virtual void _set_start_points_at_random(DataType radius, uint32_t random_seed = 0) = 0; virtual int _get_vector_by_tag(TagType &tag, DataType &vec) = 0; virtual size_t _search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags, - float *distances, DataVector &res_vectors) = 0; + float *distances, DataVector &res_vectors, bool use_filters, + const std::string filter_label) = 0; virtual void _search_with_optimized_layout(const DataType &query, size_t K, size_t L, uint32_t *indices) = 0; virtual void _set_universal_label(const LabelType universal_label) = 0; }; diff --git a/include/index.h b/include/index.h index 199171020..b9bf4f384 100644 --- a/include/index.h +++ b/include/index.h @@ -136,7 +136,8 @@ template clas // Initialize space for res_vectors before calling. DISKANN_DLLEXPORT size_t search_with_tags(const T *query, const uint64_t K, const uint32_t L, TagT *tags, - float *distances, std::vector &res_vectors); + float *distances, std::vector &res_vectors, bool use_filters = false, + const std::string filter_label = ""); // Filter support search template @@ -226,7 +227,8 @@ template clas virtual void _search_with_optimized_layout(const DataType &query, size_t K, size_t L, uint32_t *indices) override; virtual size_t _search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags, - float *distances, DataVector &res_vectors) override; + float *distances, DataVector &res_vectors, bool use_filters = false, + const std::string filter_label = "") override; virtual void _set_universal_label(const LabelType universal_label) override; diff --git a/src/abstract_index.cpp b/src/abstract_index.cpp index a7a5986cc..92665825f 100644 --- a/src/abstract_index.cpp +++ b/src/abstract_index.cpp @@ -24,12 +24,13 @@ std::pair AbstractIndex::search(const data_type *query, cons template size_t AbstractIndex::search_with_tags(const data_type *query, const uint64_t K, const uint32_t L, tag_type *tags, - float *distances, std::vector &res_vectors) + float *distances, std::vector &res_vectors, bool use_filters, + const std::string filter_label) { auto any_query = std::any(query); auto any_tags = std::any(tags); auto any_res_vectors = DataVector(res_vectors); - return this->_search_with_tags(any_query, K, L, any_tags, distances, any_res_vectors); + return this->_search_with_tags(any_query, K, L, any_tags, distances, any_res_vectors, use_filters, filter_label); } template @@ -162,61 +163,53 @@ template DISKANN_DLLEXPORT std::pair AbstractIndex::search_w const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, uint64_t *indices, float *distances); -template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags(const float *query, const uint64_t K, - const uint32_t L, int32_t *tags, - float *distances, - std::vector &res_vectors); +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( + const float *query, const uint64_t K, const uint32_t L, int32_t *tags, float *distances, + std::vector &res_vectors, bool use_filters, const std::string filter_label); -template DISKANN_DLLEXPORT size_t -AbstractIndex::search_with_tags(const uint8_t *query, const uint64_t K, const uint32_t L, - int32_t *tags, float *distances, std::vector &res_vectors); +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( + const uint8_t *query, const uint64_t K, const uint32_t L, int32_t *tags, float *distances, + std::vector &res_vectors, bool use_filters, const std::string filter_label); -template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags(const int8_t *query, - const uint64_t K, const uint32_t L, - int32_t *tags, float *distances, - std::vector &res_vectors); +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( + const int8_t *query, const uint64_t K, const uint32_t L, int32_t *tags, float *distances, + std::vector &res_vectors, bool use_filters, const std::string filter_label); -template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags(const float *query, const uint64_t K, - const uint32_t L, uint32_t *tags, - float *distances, - std::vector &res_vectors); +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( + const float *query, const uint64_t K, const uint32_t L, uint32_t *tags, float *distances, + std::vector &res_vectors, bool use_filters, const std::string filter_label); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( const uint8_t *query, const uint64_t K, const uint32_t L, uint32_t *tags, float *distances, - std::vector &res_vectors); + std::vector &res_vectors, bool use_filters, const std::string filter_label); -template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags(const int8_t *query, - const uint64_t K, const uint32_t L, - uint32_t *tags, float *distances, - std::vector &res_vectors); +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( + const int8_t *query, const uint64_t K, const uint32_t L, uint32_t *tags, float *distances, + std::vector &res_vectors, bool use_filters, const std::string filter_label); -template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags(const float *query, const uint64_t K, - const uint32_t L, int64_t *tags, - float *distances, - std::vector &res_vectors); +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( + const float *query, const uint64_t K, const uint32_t L, int64_t *tags, float *distances, + std::vector &res_vectors, bool use_filters, const std::string filter_label); -template DISKANN_DLLEXPORT size_t -AbstractIndex::search_with_tags(const uint8_t *query, const uint64_t K, const uint32_t L, - int64_t *tags, float *distances, std::vector &res_vectors); +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( + const uint8_t *query, const uint64_t K, const uint32_t L, int64_t *tags, float *distances, + std::vector &res_vectors, bool use_filters, const std::string filter_label); -template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags(const int8_t *query, - const uint64_t K, const uint32_t L, - int64_t *tags, float *distances, - std::vector &res_vectors); +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( + const int8_t *query, const uint64_t K, const uint32_t L, int64_t *tags, float *distances, + std::vector &res_vectors, bool use_filters, const std::string filter_label); -template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags(const float *query, const uint64_t K, - const uint32_t L, uint64_t *tags, - float *distances, - std::vector &res_vectors); +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( + const float *query, const uint64_t K, const uint32_t L, uint64_t *tags, float *distances, + std::vector &res_vectors, bool use_filters, const std::string filter_label); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( const uint8_t *query, const uint64_t K, const uint32_t L, uint64_t *tags, float *distances, - std::vector &res_vectors); + std::vector &res_vectors, bool use_filters, const std::string filter_label); -template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags(const int8_t *query, - const uint64_t K, const uint32_t L, - uint64_t *tags, float *distances, - std::vector &res_vectors); +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( + const int8_t *query, const uint64_t K, const uint32_t L, uint64_t *tags, float *distances, + std::vector &res_vectors, bool use_filters, const std::string filter_label); template DISKANN_DLLEXPORT void AbstractIndex::search_with_optimized_layout(const float *query, size_t K, size_t L, uint32_t *indices); diff --git a/src/index.cpp b/src/index.cpp index b96e67fd8..b39da7c28 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2023,10 +2023,14 @@ std::pair Index::_search_with_filters(const float *distances) { auto converted_label = this->get_converted_label(raw_label); - - if (typeid(TagT *) == indices.type()) + if (typeid(uint64_t *) == indices.type()) + { + auto ptr = std::any_cast(indices); + return this->search_with_filters(std::any_cast(query), converted_label, K, L, ptr, distances); + } + else if (typeid(uint32_t *) == indices.type()) { - auto ptr = std::any_cast(indices); + auto ptr = std::any_cast(indices); return this->search_with_filters(std::any_cast(query), converted_label, K, L, ptr, distances); } else @@ -2090,24 +2094,7 @@ std::pair Index::search_with_filters(const { if (best_L_nodes[i].id < _max_points) { - // safe because Index uses uint32_t ids internally - // and IDType will be uint32_t or uint64_t - if (_enable_tags) - { - TagT tag; - if (_location_to_tag.try_get(best_L_nodes[i].id, tag)) - { - indices[pos] = (IdType)tag; - } - else - { - continue; - } - } - else - { - indices[pos] = best_L_nodes[i].id; - } + indices[pos] = (IdType)best_L_nodes[i].id; if (distances != nullptr) { @@ -2134,12 +2121,13 @@ std::pair Index::search_with_filters(const template size_t Index::_search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, - const TagType &tags, float *distances, DataVector &res_vectors) + const TagType &tags, float *distances, DataVector &res_vectors, + bool use_filters, const std::string filter_label) { try { return this->search_with_tags(std::any_cast(query), K, L, std::any_cast(tags), distances, - res_vectors.get>()); + res_vectors.get>(), use_filters, filter_label); } catch (const std::bad_any_cast &e) { @@ -2153,7 +2141,8 @@ size_t Index::_search_with_tags(const DataType &query, const ui template size_t Index::search_with_tags(const T *query, const uint64_t K, const uint32_t L, TagT *tags, - float *distances, std::vector &res_vectors) + float *distances, std::vector &res_vectors, bool use_filters, + const std::string filter_label) { if (K > (uint64_t)L) { @@ -2173,12 +2162,22 @@ size_t Index::search_with_tags(const T *query, const uint64_t K std::shared_lock ul(_update_lock); const std::vector init_ids = get_init_ids(); - const std::vector unused_filter_label; //_distance->preprocess_query(query, _data_store->get_dims(), // scratch->aligned_query()); _data_store->preprocess_query(query, scratch); - iterate_to_fixed_point(scratch, L, init_ids, false, unused_filter_label, true); + if (!use_filters) + { + const std::vector unused_filter_label; + iterate_to_fixed_point(scratch, L, init_ids, false, unused_filter_label, true); + } + else + { + std::vector filter_vec; + auto converted_label = this->get_converted_label(filter_label); + filter_vec.push_back(converted_label); + iterate_to_fixed_point(scratch, L, init_ids, true, filter_vec, true); + } NeighborPriorityQueue &best_L_nodes = scratch->best_l_nodes(); assert(best_L_nodes.size() <= L); @@ -3423,16 +3422,6 @@ template DISKANN_DLLEXPORT std::pair Index Index::search_with_filters< uint32_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, float *distances); -// TagT==uint128 -template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - tag_uint128>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, - tag_uint128 *indices, float *distances); -template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - tag_uint128>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, - tag_uint128 *indices, float *distances); -template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - tag_uint128>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, - tag_uint128 *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); @@ -3504,15 +3493,5 @@ template DISKANN_DLLEXPORT std::pair Index Index::search_with_filters< uint32_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, float *distances); -// TagT==uint128 -template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - tag_uint128>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, - tag_uint128 *indices, float *distances); -template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - tag_uint128>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, - tag_uint128 *indices, float *distances); -template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - tag_uint128>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, - tag_uint128 *indices, float *distances); } // namespace diskann From 5c014cd06272e72e995f5b655384733a2849efa7 Mon Sep 17 00:00:00 2001 From: Sanhaoji2 Date: Mon, 29 Jan 2024 20:57:06 +0800 Subject: [PATCH 09/13] clean up code --- src/index.cpp | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/src/index.cpp b/src/index.cpp index b39da7c28..486d41e76 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1992,7 +1992,7 @@ std::pair Index::search(const T *query, con { // safe because Index uses uint32_t ids internally // and IDType will be uint32_t or uint64_t - indices[pos] = best_L_nodes[i].id; + indices[pos] = (IdType)best_L_nodes[i].id; if (distances != nullptr) { #ifdef EXEC_ENV_OLS @@ -3377,13 +3377,6 @@ template DISKANN_DLLEXPORT std::pair Index Index::search( const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); -// TagT==uint128 -template DISKANN_DLLEXPORT std::pair Index::search( - const float *query, const size_t K, const uint32_t L, tag_uint128 *indices, float *distances); -template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t *query, const size_t K, const uint32_t L, tag_uint128 *indices, float *distances); -template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t *query, const size_t K, const uint32_t L, tag_uint128 *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, @@ -3448,13 +3441,6 @@ template DISKANN_DLLEXPORT std::pair Index Index::search( const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); -// TagT==uint128 -template DISKANN_DLLEXPORT std::pair Index::search( - const float *query, const size_t K, const uint32_t L, tag_uint128 *indices, float *distances); -template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t *query, const size_t K, const uint32_t L, tag_uint128 *indices, float *distances); -template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t *query, const size_t K, const uint32_t L, tag_uint128 *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, From 332bf60595975d8ea027902ca9e9a377ef0c08a9 Mon Sep 17 00:00:00 2001 From: Sanhaoji2 Date: Fri, 2 Feb 2024 16:14:38 +0800 Subject: [PATCH 10/13] resolve comment --- include/utils.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/utils.h b/include/utils.h index 28ef47ddb..90e0022ae 100644 --- a/include/utils.h +++ b/include/utils.h @@ -1013,7 +1013,7 @@ inline std::string get_tag_string(std::uint64_t tag) return std::to_string(tag); } -inline std::string get_tag_string(tag_uint128 tag) +inline std::string get_tag_string(const tag_uint128& tag) { std::string str = std::to_string(tag._data2) + "_" + std::to_string(tag._data1); return str; From 435aec14c0bd3d621d12756577bd8898a8ca1d15 Mon Sep 17 00:00:00 2001 From: Sanhaoji2 Date: Fri, 2 Feb 2024 16:17:53 +0800 Subject: [PATCH 11/13] format doc --- include/utils.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/utils.h b/include/utils.h index 90e0022ae..d3af5c3a9 100644 --- a/include/utils.h +++ b/include/utils.h @@ -1013,7 +1013,7 @@ inline std::string get_tag_string(std::uint64_t tag) return std::to_string(tag); } -inline std::string get_tag_string(const tag_uint128& tag) +inline std::string get_tag_string(const tag_uint128 &tag) { std::string str = std::to_string(tag._data2) + "_" + std::to_string(tag._data1); return str; From d8d132b8fd564ef125e5dd55866d7ca8ebf74d53 Mon Sep 17 00:00:00 2001 From: Sanhaoji2 Date: Mon, 5 Feb 2024 16:37:59 +0800 Subject: [PATCH 12/13] fix test --- apps/search_memory_index.cpp | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/apps/search_memory_index.cpp b/apps/search_memory_index.cpp index 1bb02c9bc..1a9acc285 100644 --- a/apps/search_memory_index.cpp +++ b/apps/search_memory_index.cpp @@ -163,7 +163,7 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, for (int64_t i = 0; i < (int64_t)query_num; i++) { auto qs = std::chrono::high_resolution_clock::now(); - if (filtered_search) + if (filtered_search && !tags) { std::string raw_filter = query_filters.size() == 1 ? query_filters[0] : query_filters[i]; @@ -179,8 +179,19 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, } else if (tags) { - index->search_with_tags(query + i * query_aligned_dim, recall_at, L, - query_result_tags.data() + i * recall_at, nullptr, res); + if (!filtered_search) + { + index->search_with_tags(query + i * query_aligned_dim, recall_at, L, + query_result_tags.data() + i * recall_at, nullptr, res); + } + else + { + std::string raw_filter = query_filters.size() == 1 ? query_filters[0] : query_filters[i]; + + index->search_with_tags(query + i * query_aligned_dim, recall_at, L, + query_result_tags.data() + i * recall_at, nullptr, res, true, raw_filter); + } + for (int64_t r = 0; r < (int64_t)recall_at; r++) { query_result_ids[test_id][recall_at * i + r] = query_result_tags[recall_at * i + r]; From cf4a8830c4c156181268edec9cae46af330078d3 Mon Sep 17 00:00:00 2001 From: Sanhaoji2 Date: Mon, 5 Feb 2024 19:51:29 +0800 Subject: [PATCH 13/13] resolve comment --- include/abstract_index.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/abstract_index.h b/include/abstract_index.h index 4bc15b9ec..059866f7c 100644 --- a/include/abstract_index.h +++ b/include/abstract_index.h @@ -121,8 +121,8 @@ class AbstractIndex virtual void _set_start_points_at_random(DataType radius, uint32_t random_seed = 0) = 0; virtual int _get_vector_by_tag(TagType &tag, DataType &vec) = 0; virtual size_t _search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags, - float *distances, DataVector &res_vectors, bool use_filters, - const std::string filter_label) = 0; + float *distances, DataVector &res_vectors, bool use_filters = false, + const std::string filter_label = "") = 0; virtual void _search_with_optimized_layout(const DataType &query, size_t K, size_t L, uint32_t *indices) = 0; virtual void _set_universal_label(const LabelType universal_label) = 0; };