Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add16bytes tag type #506

Merged
merged 14 commits into from
Feb 6, 2024
3 changes: 0 additions & 3 deletions include/natural_number_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@ template <typename Key, typename Value> class natural_number_map
{
public:
static_assert(std::is_trivial<Key>::value, "Key must be a trivial type");
// Some of the class member prototypes are done with this assumption to
// minimize verbosity since it's the only use case.
static_assert(std::is_trivial<Value>::value, "Value must be a trivial type");
rakri marked this conversation as resolved.
Show resolved Hide resolved

// Represents a reference to a element in the map. Used while iterating
// over map entries.
Expand Down
68 changes: 68 additions & 0 deletions include/tag_uint128.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#pragma once
#include <cstdint>
#include <type_traits>

namespace diskann
{
#pragma pack(push, 1)

struct tag_uint128
{
std::uint64_t _data1 = 0;
std::uint64_t _data2 = 0;

bool operator==(const tag_uint128 &other) const
{
return _data1 == other._data1 && _data2 == other._data2;
}

bool operator==(std::uint64_t other) const
{
return _data1 == other && _data2 == 0;
}

tag_uint128 &operator=(const tag_uint128 &other)
{
_data1 = other._data1;
_data2 = other._data2;

return *this;
}

tag_uint128 &operator=(std::uint64_t other)
{
_data1 = other;
_data2 = 0;

return *this;
}
};

#pragma pack(pop)
} // namespace diskann

namespace std
{
// Hash 128 input bits down to 64 bits of output.
// This is intended to be a reasonably good hash function.
inline std::uint64_t Hash128to64(const std::uint64_t &low, const std::uint64_t &high)
{
// Murmur-inspired hashing.
const std::uint64_t kMul = 0x9ddfea08eb382d69ULL;
std::uint64_t a = (low ^ high) * kMul;
a ^= (a >> 47);
std::uint64_t b = (high ^ a) * kMul;
b ^= (b >> 47);
b *= kMul;
return b;
}

template <> struct hash<diskann::tag_uint128>
{
size_t operator()(const diskann::tag_uint128 &key) const noexcept
{
return Hash128to64(key._data1, key._data2); // map -0 to 0
}
};

} // namespace std
12 changes: 12 additions & 0 deletions include/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ typedef int FileHandle;
#include "windows_customizations.h"
#include "tsl/robin_set.h"
#include "types.h"
#include "tag_uint128.h"
#include <any>

#ifdef EXEC_ENV_OLS
Expand Down Expand Up @@ -1007,6 +1008,17 @@ void block_convert(std::ofstream &writr, std::ifstream &readr, float *read_buf,

DISKANN_DLLEXPORT void normalize_data_file(const std::string &inFileName, const std::string &outFileName);

inline std::string get_tag_string(std::uint64_t tag)
{
return std::to_string(tag);
}

inline std::string get_tag_string(tag_uint128 tag)
Sanhaoji2 marked this conversation as resolved.
Show resolved Hide resolved
{
rakri marked this conversation as resolved.
Show resolved Hide resolved
std::string str = std::to_string(tag._data2) + "_" + std::to_string(tag._data1);
return str;
}

}; // namespace diskann

struct PivotContainer
Expand Down
64 changes: 51 additions & 13 deletions src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "tsl/robin_map.h"
#include "tsl/robin_set.h"
#include "windows_customizations.h"
#include "tag_uint128.h"
#if defined(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS) && defined(DISKANN_BUILD)
#include "gperftools/malloc_extension.h"
#endif
Expand Down Expand Up @@ -717,7 +718,7 @@ template <typename T, typename TagT, typename LabelT> int Index<T, TagT, LabelT>
std::shared_lock<std::shared_timed_mutex> lock(_tag_lock);
if (_tag_to_location.find(tag) == _tag_to_location.end())
{
diskann::cout << "Tag " << tag << " does not exist" << std::endl;
diskann::cout << "Tag " << get_tag_string(tag) << " does not exist" << std::endl;
return -1;
}

Expand Down Expand Up @@ -1991,7 +1992,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::search(const T *query, con
{
// safe because Index uses uint32_t ids internally
// and IDType will be uint32_t or uint64_t
indices[pos] = (IdType)best_L_nodes[i].id;
indices[pos] = best_L_nodes[i].id;
if (distances != nullptr)
{
#ifdef EXEC_ENV_OLS
Expand Down Expand Up @@ -2022,14 +2023,10 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::_search_with_filters(const
float *distances)
{
auto converted_label = this->get_converted_label(raw_label);
if (typeid(uint64_t *) == indices.type())
rakri marked this conversation as resolved.
Show resolved Hide resolved
{
auto ptr = std::any_cast<uint64_t *>(indices);
return this->search_with_filters(std::any_cast<T *>(query), converted_label, K, L, ptr, distances);
}
else if (typeid(uint32_t *) == indices.type())

if (typeid(TagT *) == indices.type())
{
auto ptr = std::any_cast<uint32_t *>(indices);
auto ptr = std::any_cast<TagT *>(indices);
return this->search_with_filters(std::any_cast<T *>(query), converted_label, K, L, ptr, distances);
}
else
Expand Down Expand Up @@ -2109,7 +2106,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::search_with_filters(const
}
else
{
indices[pos] = (IdType)best_L_nodes[i].id;
indices[pos] = best_L_nodes[i].id;
}

if (distances != nullptr)
Expand Down Expand Up @@ -2861,7 +2858,7 @@ int Index<T, TagT, LabelT>::insert_point(const T *point, const TagT tag, const s
{

assert(_has_built);
if (tag == static_cast<TagT>(0))
if (tag == 0)
{
throw diskann::ANNException("Do not insert point with tag 0. That is "
"reserved for points hidden "
Expand All @@ -2879,7 +2876,7 @@ int Index<T, TagT, LabelT>::insert_point(const T *point, const TagT tag, const s
if (labels.empty())
{
release_location(location);
std::cerr << "Error: Can't insert point with tag " + std::to_string(tag) +
std::cerr << "Error: Can't insert point with tag " + get_tag_string(tag) +
" . there are no labels for the point."
<< std::endl;
return -1;
Expand Down Expand Up @@ -3047,7 +3044,7 @@ template <typename T, typename TagT, typename LabelT> int Index<T, TagT, LabelT>

if (_tag_to_location.find(tag) == _tag_to_location.end())
{
diskann::cerr << "Delete tag not found " << tag << std::endl;
diskann::cerr << "Delete tag not found " << get_tag_string(tag) << std::endl;
return -1;
}
assert(_tag_to_location[tag] < _max_points);
Expand Down Expand Up @@ -3336,6 +3333,9 @@ template DISKANN_DLLEXPORT class Index<uint8_t, int64_t, uint32_t>;
template DISKANN_DLLEXPORT class Index<float, uint64_t, uint32_t>;
template DISKANN_DLLEXPORT class Index<int8_t, uint64_t, uint32_t>;
template DISKANN_DLLEXPORT class Index<uint8_t, uint64_t, uint32_t>;
template DISKANN_DLLEXPORT class Index<float, tag_uint128, uint32_t>;
template DISKANN_DLLEXPORT class Index<int8_t, tag_uint128, uint32_t>;
template DISKANN_DLLEXPORT class Index<uint8_t, tag_uint128, uint32_t>;
// Label with short int 2 byte
template DISKANN_DLLEXPORT class Index<float, int32_t, uint16_t>;
template DISKANN_DLLEXPORT class Index<int8_t, int32_t, uint16_t>;
Expand All @@ -3349,6 +3349,9 @@ template DISKANN_DLLEXPORT class Index<uint8_t, int64_t, uint16_t>;
template DISKANN_DLLEXPORT class Index<float, uint64_t, uint16_t>;
template DISKANN_DLLEXPORT class Index<int8_t, uint64_t, uint16_t>;
template DISKANN_DLLEXPORT class Index<uint8_t, uint64_t, uint16_t>;
template DISKANN_DLLEXPORT class Index<float, tag_uint128, uint16_t>;
template DISKANN_DLLEXPORT class Index<int8_t, tag_uint128, uint16_t>;
template DISKANN_DLLEXPORT class Index<uint8_t, tag_uint128, uint16_t>;

template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint64_t, uint32_t>::search<uint64_t>(
const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances);
Expand All @@ -3375,6 +3378,13 @@ template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint32_t,
const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint32_t, uint32_t>::search<uint32_t>(
const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances);
// TagT==uint128
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, tag_uint128, uint32_t>::search<tag_uint128>(
const float *query, const size_t K, const uint32_t L, tag_uint128 *indices, float *distances);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, tag_uint128, uint32_t>::search<tag_uint128>(
const uint8_t *query, const size_t K, const uint32_t L, tag_uint128 *indices, float *distances);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, tag_uint128, uint32_t>::search<tag_uint128>(
const int8_t *query, const size_t K, const uint32_t L, tag_uint128 *indices, float *distances);

template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint64_t, uint32_t>::search_with_filters<
uint64_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
Expand Down Expand Up @@ -3413,6 +3423,16 @@ template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint32_t,
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint32_t, uint32_t>::search_with_filters<
uint32_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
float *distances);
// TagT==uint128
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, tag_uint128, uint32_t>::search_with_filters<
tag_uint128>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L,
tag_uint128 *indices, float *distances);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, tag_uint128, uint32_t>::search_with_filters<
tag_uint128>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L,
tag_uint128 *indices, float *distances);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, tag_uint128, uint32_t>::search_with_filters<
tag_uint128>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L,
tag_uint128 *indices, float *distances);

template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint64_t, uint16_t>::search<uint64_t>(
const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances);
Expand All @@ -3439,6 +3459,13 @@ template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint32_t,
const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint32_t, uint16_t>::search<uint32_t>(
const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances);
// TagT==uint128
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, tag_uint128, uint16_t>::search<tag_uint128>(
const float *query, const size_t K, const uint32_t L, tag_uint128 *indices, float *distances);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, tag_uint128, uint16_t>::search<tag_uint128>(
const uint8_t *query, const size_t K, const uint32_t L, tag_uint128 *indices, float *distances);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, tag_uint128, uint16_t>::search<tag_uint128>(
const int8_t *query, const size_t K, const uint32_t L, tag_uint128 *indices, float *distances);

template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint64_t, uint16_t>::search_with_filters<
uint64_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
Expand Down Expand Up @@ -3477,4 +3504,15 @@ template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint32_t,
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint32_t, uint16_t>::search_with_filters<
uint32_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
float *distances);
// TagT==uint128
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, tag_uint128, uint16_t>::search_with_filters<
tag_uint128>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L,
tag_uint128 *indices, float *distances);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, tag_uint128, uint16_t>::search_with_filters<
tag_uint128>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L,
tag_uint128 *indices, float *distances);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, tag_uint128, uint16_t>::search_with_filters<
tag_uint128>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L,
tag_uint128 *indices, float *distances);

} // namespace diskann
2 changes: 2 additions & 0 deletions src/natural_number_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <boost/dynamic_bitset.hpp>

#include "natural_number_map.h"
#include "tag_uint128.h"

namespace diskann
{
Expand Down Expand Up @@ -111,4 +112,5 @@ template class natural_number_map<uint32_t, int32_t>;
template class natural_number_map<uint32_t, uint32_t>;
template class natural_number_map<uint32_t, int64_t>;
template class natural_number_map<uint32_t, uint64_t>;
template class natural_number_map<uint32_t, tag_uint128>;
} // namespace diskann
Loading