Skip to content

Commit

Permalink
Fix insert interface issue
Browse files Browse the repository at this point in the history
  • Loading branch information
Sanhaoji2 committed Nov 28, 2024
1 parent 1ebc6c9 commit f95a526
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 80 deletions.
6 changes: 3 additions & 3 deletions include/abstract_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ class AbstractIndex
float *distances);

// insert points with labels, labels should be present for filtered index
template <typename data_type, typename tag_type, typename label_type>
int insert_point(const data_type *point, const tag_type tag, const std::vector<label_type> &labels);
template <typename data_type, typename tag_type>
int insert_point(const data_type *point, const tag_type tag, const std::vector<std::string> &labels);

// insert point for unfiltered index build. do not use with filtered index
template <typename data_type, typename tag_type> int insert_point(const data_type *point, const tag_type tag);
Expand Down Expand Up @@ -116,7 +116,7 @@ class AbstractIndex
virtual std::pair<uint32_t, uint32_t> _search_with_filters(const DataType &query, const std::string &filter_label,
const size_t K, const uint32_t L, std::any &indices,
float *distances) = 0;
virtual int _insert_point(const DataType &data_point, const TagType tag, Labelvector &labels) = 0;
virtual int _insert_point(const DataType &data_point, const TagType tag, const std::vector<std::string> &labels) = 0;
virtual int _insert_point(const DataType &data_point, const TagType tag) = 0;
virtual int _lazy_delete(const TagType &tag) = 0;
virtual void _lazy_delete(TagVector &tags, TagVector &failed_tags) = 0;
Expand Down
2 changes: 1 addition & 1 deletion include/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
float *distances) override;

virtual int _insert_point(const DataType &data_point, const TagType tag) override;
virtual int _insert_point(const DataType &data_point, const TagType tag, Labelvector &labels) override;
virtual int _insert_point(const DataType &data_point, const TagType tag, const std::vector<std::string> &labels) override;

virtual int _lazy_delete(const TagType &tag) override;

Expand Down
111 changes: 38 additions & 73 deletions src/abstract_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,12 @@ int AbstractIndex::insert_point(const data_type *point, const tag_type tag)
return this->_insert_point(any_point, any_tag);
}

template <typename data_type, typename tag_type, typename label_type>
int AbstractIndex::insert_point(const data_type *point, const tag_type tag, const std::vector<label_type> &labels)
template <typename data_type, typename tag_type>
int AbstractIndex::insert_point(const data_type *point, const tag_type tag, const std::vector<std::string>& labels)
{
auto any_point = std::any(point);
auto any_tag = std::any(tag);
auto any_labels = Labelvector(labels);
return this->_insert_point(any_point, any_tag, any_labels);
return this->_insert_point(any_point, any_tag, labels);
}

template <typename tag_type> int AbstractIndex::lazy_delete(const tag_type &tag)
Expand Down Expand Up @@ -259,75 +258,41 @@ template DISKANN_DLLEXPORT int AbstractIndex::insert_point<float, tag_uint128>(c
template DISKANN_DLLEXPORT int AbstractIndex::insert_point<uint8_t, tag_uint128>(const uint8_t* point, const tag_uint128 tag);
template DISKANN_DLLEXPORT int AbstractIndex::insert_point<int8_t, tag_uint128>(const int8_t* point, const tag_uint128 tag);

template DISKANN_DLLEXPORT int AbstractIndex::insert_point<float, int32_t, uint16_t>(
const float *point, const int32_t tag, const std::vector<uint16_t> &labels);
template DISKANN_DLLEXPORT int AbstractIndex::insert_point<uint8_t, int32_t, uint16_t>(
const uint8_t *point, const int32_t tag, const std::vector<uint16_t> &labels);
template DISKANN_DLLEXPORT int AbstractIndex::insert_point<int8_t, int32_t, uint16_t>(
const int8_t *point, const int32_t tag, const std::vector<uint16_t> &labels);

template DISKANN_DLLEXPORT int AbstractIndex::insert_point<float, uint32_t, uint16_t>(
const float *point, const uint32_t tag, const std::vector<uint16_t> &labels);
template DISKANN_DLLEXPORT int AbstractIndex::insert_point<uint8_t, uint32_t, uint16_t>(
const uint8_t *point, const uint32_t tag, const std::vector<uint16_t> &labels);
template DISKANN_DLLEXPORT int AbstractIndex::insert_point<int8_t, uint32_t, uint16_t>(
const int8_t *point, const uint32_t tag, const std::vector<uint16_t> &labels);

template DISKANN_DLLEXPORT int AbstractIndex::insert_point<float, int64_t, uint16_t>(
const float *point, const int64_t tag, const std::vector<uint16_t> &labels);
template DISKANN_DLLEXPORT int AbstractIndex::insert_point<uint8_t, int64_t, uint16_t>(
const uint8_t *point, const int64_t tag, const std::vector<uint16_t> &labels);
template DISKANN_DLLEXPORT int AbstractIndex::insert_point<int8_t, int64_t, uint16_t>(
const int8_t *point, const int64_t tag, const std::vector<uint16_t> &labels);

template DISKANN_DLLEXPORT int AbstractIndex::insert_point<float, uint64_t, uint16_t>(
const float *point, const uint64_t tag, const std::vector<uint16_t> &labels);
template DISKANN_DLLEXPORT int AbstractIndex::insert_point<uint8_t, uint64_t, uint16_t>(
const uint8_t *point, const uint64_t tag, const std::vector<uint16_t> &labels);
template DISKANN_DLLEXPORT int AbstractIndex::insert_point<int8_t, uint64_t, uint16_t>(
const int8_t *point, const uint64_t tag, const std::vector<uint16_t> &labels);

template DISKANN_DLLEXPORT int AbstractIndex::insert_point<float, tag_uint128, uint16_t>(
const float* point, const tag_uint128 tag, const std::vector<uint16_t>& labels);
template DISKANN_DLLEXPORT int AbstractIndex::insert_point<uint8_t, tag_uint128, uint16_t>(
const uint8_t* point, const tag_uint128 tag, const std::vector<uint16_t>& labels);
template DISKANN_DLLEXPORT int AbstractIndex::insert_point<int8_t, tag_uint128, uint16_t>(
const int8_t* point, const tag_uint128 tag, const std::vector<uint16_t>& labels);

template DISKANN_DLLEXPORT int AbstractIndex::insert_point<float, int32_t, uint32_t>(
const float *point, const int32_t tag, const std::vector<uint32_t> &labels);
template DISKANN_DLLEXPORT int AbstractIndex::insert_point<uint8_t, int32_t, uint32_t>(
const uint8_t *point, const int32_t tag, const std::vector<uint32_t> &labels);
template DISKANN_DLLEXPORT int AbstractIndex::insert_point<int8_t, int32_t, uint32_t>(
const int8_t *point, const int32_t tag, const std::vector<uint32_t> &labels);

template DISKANN_DLLEXPORT int AbstractIndex::insert_point<float, uint32_t, uint32_t>(
const float *point, const uint32_t tag, const std::vector<uint32_t> &labels);
template DISKANN_DLLEXPORT int AbstractIndex::insert_point<uint8_t, uint32_t, uint32_t>(
const uint8_t *point, const uint32_t tag, const std::vector<uint32_t> &labels);
template DISKANN_DLLEXPORT int AbstractIndex::insert_point<int8_t, uint32_t, uint32_t>(
const int8_t *point, const uint32_t tag, const std::vector<uint32_t> &labels);

template DISKANN_DLLEXPORT int AbstractIndex::insert_point<float, int64_t, uint32_t>(
const float *point, const int64_t tag, const std::vector<uint32_t> &labels);
template DISKANN_DLLEXPORT int AbstractIndex::insert_point<uint8_t, int64_t, uint32_t>(
const uint8_t *point, const int64_t tag, const std::vector<uint32_t> &labels);
template DISKANN_DLLEXPORT int AbstractIndex::insert_point<int8_t, int64_t, uint32_t>(
const int8_t *point, const int64_t tag, const std::vector<uint32_t> &labels);

template DISKANN_DLLEXPORT int AbstractIndex::insert_point<float, uint64_t, uint32_t>(
const float *point, const uint64_t tag, const std::vector<uint32_t> &labels);
template DISKANN_DLLEXPORT int AbstractIndex::insert_point<uint8_t, uint64_t, uint32_t>(
const uint8_t *point, const uint64_t tag, const std::vector<uint32_t> &labels);
template DISKANN_DLLEXPORT int AbstractIndex::insert_point<int8_t, uint64_t, uint32_t>(
const int8_t *point, const uint64_t tag, const std::vector<uint32_t> &labels);

template DISKANN_DLLEXPORT int AbstractIndex::insert_point<float, tag_uint128, uint32_t>(
const float* point, const tag_uint128 tag, const std::vector<uint32_t>& labels);
template DISKANN_DLLEXPORT int AbstractIndex::insert_point<uint8_t, tag_uint128, uint32_t>(
const uint8_t* point, const tag_uint128 tag, const std::vector<uint32_t>& labels);
template DISKANN_DLLEXPORT int AbstractIndex::insert_point<int8_t, tag_uint128, uint32_t>(
const int8_t* point, const tag_uint128 tag, const std::vector<uint32_t>& labels);
template DISKANN_DLLEXPORT int AbstractIndex::insert_point<float, int32_t>(
const float *point, const int32_t tag, const std::vector<std::string> &labels);
template DISKANN_DLLEXPORT int AbstractIndex::insert_point<uint8_t, int32_t>(
const uint8_t *point, const int32_t tag, const std::vector<std::string> &labels);
template DISKANN_DLLEXPORT int AbstractIndex::insert_point<int8_t, int32_t>(
const int8_t *point, const int32_t tag, const std::vector<std::string> &labels);

template DISKANN_DLLEXPORT int AbstractIndex::insert_point<float, uint32_t>(
const float *point, const uint32_t tag, const std::vector<std::string> &labels);
template DISKANN_DLLEXPORT int AbstractIndex::insert_point<uint8_t, uint32_t>(
const uint8_t *point, const uint32_t tag, const std::vector<std::string> &labels);
template DISKANN_DLLEXPORT int AbstractIndex::insert_point<int8_t, uint32_t>(
const int8_t *point, const uint32_t tag, const std::vector<std::string> &labels);

template DISKANN_DLLEXPORT int AbstractIndex::insert_point<float, int64_t>(
const float *point, const int64_t tag, const std::vector<std::string> &labels);
template DISKANN_DLLEXPORT int AbstractIndex::insert_point<uint8_t, int64_t>(
const uint8_t *point, const int64_t tag, const std::vector<std::string> &labels);
template DISKANN_DLLEXPORT int AbstractIndex::insert_point<int8_t, int64_t>(
const int8_t *point, const int64_t tag, const std::vector<std::string> &labels);

template DISKANN_DLLEXPORT int AbstractIndex::insert_point<float, uint64_t>(
const float *point, const uint64_t tag, const std::vector<std::string> &labels);
template DISKANN_DLLEXPORT int AbstractIndex::insert_point<uint8_t, uint64_t>(
const uint8_t *point, const uint64_t tag, const std::vector<std::string> &labels);
template DISKANN_DLLEXPORT int AbstractIndex::insert_point<int8_t, uint64_t>(
const int8_t *point, const uint64_t tag, const std::vector<std::string> &labels);

template DISKANN_DLLEXPORT int AbstractIndex::insert_point<float, tag_uint128>(
const float* point, const tag_uint128 tag, const std::vector<std::string>& labels);
template DISKANN_DLLEXPORT int AbstractIndex::insert_point<uint8_t, tag_uint128>(
const uint8_t* point, const tag_uint128 tag, const std::vector<std::string>& labels);
template DISKANN_DLLEXPORT int AbstractIndex::insert_point<int8_t, tag_uint128>(
const int8_t* point, const tag_uint128 tag, const std::vector<std::string>& labels);


template DISKANN_DLLEXPORT int AbstractIndex::lazy_delete<int32_t>(const int32_t &tag);
template DISKANN_DLLEXPORT int AbstractIndex::lazy_delete<uint32_t>(const uint32_t &tag);
Expand Down
12 changes: 9 additions & 3 deletions src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3085,12 +3085,18 @@ int Index<T, TagT, LabelT>::_insert_point(const DataType &point, const TagType t
}

template <typename T, typename TagT, typename LabelT>
int Index<T, TagT, LabelT>::_insert_point(const DataType &point, const TagType tag, Labelvector &labels)
int Index<T, TagT, LabelT>::_insert_point(const DataType &point, const TagType tag, const std::vector<std::string>& labels)
{
try
{
return this->insert_point(std::any_cast<const T *>(point), std::any_cast<const TagT>(tag),
labels.get<const std::vector<LabelT>>());
std::vector<LabelT> converted_labels;
converted_labels.reserve(labels.size());
for (const auto& label : labels)
{
auto converted_label = this->get_converted_label(label);
converted_labels.push_back(converted_label);
}
return this->insert_point(std::any_cast<const T *>(point), std::any_cast<const TagT>(tag), converted_labels);
}
catch (const std::bad_any_cast &anycast_e)
{
Expand Down

0 comments on commit f95a526

Please sign in to comment.