Skip to content

Commit

Permalink
introduce duplicate add for hgraph (#251)
Browse files Browse the repository at this point in the history
- add "duplicate build test" for hgraph and hnsw

Signed-off-by: LHT129 <[email protected]>
  • Loading branch information
LHT129 authored Dec 25, 2024
1 parent 4f7e21c commit 55dcb8a
Show file tree
Hide file tree
Showing 10 changed files with 170 additions and 17 deletions.
70 changes: 60 additions & 10 deletions src/algorithm/hgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,17 +128,22 @@ HGraph::Add(const DatasetPtr& data) {
CHECK_ARGUMENT(base_dim == dim_,
fmt::format("base.dim({}) must be equal to index.dim({})", base_dim, dim_));
CHECK_ARGUMENT(data->GetFloat32Vectors() != nullptr, "base.float_vector is nullptr");

this->basic_flatten_codes_->Train(data->GetFloat32Vectors(), data->GetNumElements());
this->basic_flatten_codes_->BatchInsertVector(data->GetFloat32Vectors(),
data->GetNumElements());
if (use_reorder_) {
this->high_precise_codes_->Train(data->GetFloat32Vectors(), data->GetNumElements());
this->high_precise_codes_->BatchInsertVector(data->GetFloat32Vectors(),
data->GetNumElements());
auto split_datasets = this->split_dataset_by_duplicate_label(data, failed_ids);

for (auto& data_ptr : split_datasets) {
this->basic_flatten_codes_->Train(data_ptr->GetFloat32Vectors(),
data_ptr->GetNumElements());
this->basic_flatten_codes_->BatchInsertVector(data_ptr->GetFloat32Vectors(),
data_ptr->GetNumElements());
if (use_reorder_) {
this->high_precise_codes_->Train(data_ptr->GetFloat32Vectors(),
data_ptr->GetNumElements());
this->high_precise_codes_->BatchInsertVector(data_ptr->GetFloat32Vectors(),
data_ptr->GetNumElements());
}
this->hnsw_add(data_ptr);
}
this->hnsw_add(data);
return failed_ids; // TODO(LHT): maybe use copy for small batch filtered by label
return failed_ids;
} catch (const std::invalid_argument& e) {
LOG_ERROR_AND_RETURNS(
ErrorType::INVALID_ARGUMENT, "failed to add(invalid argument): ", e.what());
Expand Down Expand Up @@ -866,5 +871,50 @@ bool
HGraph::CheckFeature(IndexFeature feature) const {
return this->feature_list_.CheckFeature(feature);
}
Vector<DatasetPtr>
HGraph::split_dataset_by_duplicate_label(const DatasetPtr& dataset,
std::vector<LabelType>& failed_ids) const {
Vector<DatasetPtr> return_datasets(0, this->allocator_);
auto count = dataset->GetNumElements();
auto dim = dataset->GetDim();
auto* labels = dataset->GetIds();
auto* vec = dataset->GetFloat32Vectors();
UnorderedSet<LabelType> temp_labels(allocator_);

for (uint64_t i = 0; i < count; ++i) {
if (label_lookup_.find(labels[i]) != label_lookup_.end() or
temp_labels.find(labels[i]) != temp_labels.end()) {
failed_ids.emplace_back(i);
continue;
}
temp_labels.emplace(labels[i]);
}
failed_ids.emplace_back(count);

if (failed_ids.size() == 1) {
return_datasets.emplace_back(dataset);
return return_datasets;
}
int64_t start = -1;
for (auto end : failed_ids) {
if (end - start == 1) {
start = end;
continue;
}
auto new_dataset = Dataset::Make();
new_dataset->NumElements(end - start - 1)
->Dim(dim)
->Ids(labels + start + 1)
->Float32Vectors(vec + dim * (start + 1))
->Owner(false);
return_datasets.emplace_back(new_dataset);
start = end;
}
failed_ids.pop_back();
for (auto& failed_id : failed_ids) {
failed_id = labels[failed_id];
}
return return_datasets;
}

} // namespace vsag
4 changes: 4 additions & 0 deletions src/algorithm/hgraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,10 @@ class HGraph {
void
init_features();

Vector<DatasetPtr>
split_dataset_by_duplicate_label(const DatasetPtr& dataset,
std::vector<LabelType>& failed_ids) const;

private:
FlattenInterfacePtr basic_flatten_codes_{nullptr};
FlattenInterfacePtr high_precise_codes_{nullptr};
Expand Down
4 changes: 2 additions & 2 deletions src/impl/conjugate_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,15 @@ ConjugateGraph::get_neighbors(int64_t from_tag_id) const {
}

tl::expected<uint32_t, Error>
ConjugateGraph::EnhanceResult(std::priority_queue<std::pair<float, size_t>>& results,
ConjugateGraph::EnhanceResult(std::priority_queue<std::pair<float, LabelType>>& results,
const std::function<float(int64_t)>& distance_of_tag) const {
if (this->is_empty()) {
return 0;
}

int64_t k = results.size();
int64_t look_at_k = std::min(LOOK_AT_K, k);
std::priority_queue<std::pair<float, size_t>> old_results(results);
std::priority_queue<std::pair<float, LabelType>> old_results(results);
std::vector<int64_t> to_be_visited(look_at_k);
std::unordered_set<int64_t> visited_set;
uint32_t successfully_enhanced = 0;
Expand Down
2 changes: 1 addition & 1 deletion src/impl/conjugate_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class ConjugateGraph {
AddNeighbor(int64_t from_tag_id, int64_t to_tag_id);

tl::expected<uint32_t, Error>
EnhanceResult(std::priority_queue<std::pair<float, size_t>>& results,
EnhanceResult(std::priority_queue<std::pair<float, LabelType>>& results,
const std::function<float(int64_t)>& distance_of_tag) const;

public:
Expand Down
4 changes: 2 additions & 2 deletions src/index/hnsw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ HNSW::knn_search(const DatasetPtr& query,

// perform search
int64_t original_k = k;
std::priority_queue<std::pair<float, size_t>> results;
std::priority_queue<std::pair<float, LabelType>> results;
double time_cost;
try {
Timer t(time_cost);
Expand Down Expand Up @@ -354,7 +354,7 @@ HNSW::range_search(const DatasetPtr& query,
auto params = HnswSearchParameters::FromJson(parameters);

// perform search
std::priority_queue<std::pair<float, size_t>> results;
std::priority_queue<std::pair<float, LabelType>> results;
double time_cost;
try {
std::shared_lock lock(rw_mutex_);
Expand Down
2 changes: 1 addition & 1 deletion src/typing.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
namespace vsag {

using InnerIdType = uint32_t; // inner id's type; index's vector count may less than 2^31 - 1
using LabelType = uint64_t; // external id's type
using LabelType = int64_t; // external id's type

using JsonType = nlohmann::json; // alias for nlohmann::json type

Expand Down
36 changes: 36 additions & 0 deletions tests/test_hgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -377,3 +377,39 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HgraphTestIndex,
}
vsag::Options::Instance().set_block_size_limit(origin_size);
}

TEST_CASE_PERSISTENT_FIXTURE(fixtures::HgraphTestIndex, "HGraph Duplicate Build", "[ft][hgraph]") {
auto origin_size = vsag::Options::Instance().block_size_limit();
auto size = GENERATE(1024 * 1024 * 2);
auto metric_type = GENERATE("l2", "ip", "cosine");
std::vector<std::pair<std::string, float>> test_cases = {
{"sq8", 0.97}, {"fp32", 0.99}, {"sq8_uniform", 0.95}};
const std::string name = "hgraph";
auto search_param = fmt::format(search_param_tmp, 200);
for (auto& dim : dims) {
for (auto& [base_quantization_str, recall] : test_cases) {
vsag::Options::Instance().set_block_size_limit(size);
auto param =
GenerateHGraphBuildParametersString(metric_type, dim, base_quantization_str);
auto index = TestFactory(name, param, true);
if (index->CheckFeature(vsag::SUPPORT_BUILD)) {
auto dataset = pool.GetDatasetAndCreate(dim, base_count, metric_type);
TestDuplicateAdd(index, dataset);
if (index->CheckFeature(vsag::SUPPORT_KNN_SEARCH)) {
TestKnnSearch(index, dataset, search_param, recall, true);
if (index->CheckFeature(vsag::SUPPORT_SEARCH_CONCURRENT)) {
TestConcurrentKnnSearch(index, dataset, search_param, recall, true);
}
}
if (index->CheckFeature(vsag::SUPPORT_RANGE_SEARCH)) {
TestRangeSearch(index, dataset, search_param, recall, 10, true);
TestRangeSearch(index, dataset, search_param, recall / 2.0, 5, true);
}
if (index->CheckFeature(vsag::SUPPORT_KNN_SEARCH_WITH_ID_FILTER)) {
TestFilterSearch(index, dataset, search_param, recall, true);
}
}
vsag::Options::Instance().set_block_size_limit(origin_size);
}
}
}
25 changes: 24 additions & 1 deletion tests/test_hnsw_new.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -325,4 +325,27 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex,
TestContinueAddIgnoreRequire(index.value(), dataset);
}
vsag::Options::Instance().set_block_size_limit(origin_size);
}
}

TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "HNSW Duplicate Add", "[ft][hnsw]") {
auto origin_size = vsag::Options::Instance().block_size_limit();
auto size = GENERATE(1024 * 1024 * 2);
auto metric_type = GENERATE("l2", "ip", "cosine");
const std::string name = "hnsw";
auto search_param = fmt::format(search_param_tmp, 100);
for (auto& dim : dims) {
vsag::Options::Instance().set_block_size_limit(size);
auto param = GenerateHNSWBuildParametersString(metric_type, dim);
auto index = TestFactory(name, param, true);

auto dataset = pool.GetDatasetAndCreate(dim, base_count, metric_type);
TestDuplicateAdd(index, dataset);
TestKnnSearch(index, dataset, search_param, 0.99, true);
TestConcurrentKnnSearch(index, dataset, search_param, 0.99, true);
TestRangeSearch(index, dataset, search_param, 0.99, 10, true);
TestRangeSearch(index, dataset, search_param, 0.49, 5, true);
TestFilterSearch(index, dataset, search_param, 0.99, true);

vsag::Options::Instance().set_block_size_limit(origin_size);
}
}
38 changes: 38 additions & 0 deletions tests/test_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -368,5 +368,43 @@ TestIndex::TestContinueAddIgnoreRequire(const TestIndex::IndexPtr& index,
auto add_index = index->Add(data_one);
}
}
void
TestIndex::TestDuplicateAdd(const TestIndex::IndexPtr& index, const TestDatasetPtr& dataset) {
auto double_dataset = vsag::Dataset::Make();
uint64_t base_count = dataset->base_->GetNumElements();
uint64_t double_count = base_count * 2;
auto dim = dataset->base_->GetDim();
auto new_data = std::shared_ptr<float[]>(new float[double_count * dim]);
auto new_ids = std::shared_ptr<int64_t[]>(new int64_t[double_count]);
memcpy(new_data.get(), dataset->base_->GetFloat32Vectors(), base_count * dim * sizeof(float));
memcpy(new_data.get() + base_count * dim,
dataset->base_->GetFloat32Vectors(),
base_count * dim * sizeof(float));
memcpy(new_ids.get(), dataset->base_->GetIds(), base_count * sizeof(int64_t));
memcpy(new_ids.get() + base_count, dataset->base_->GetIds(), base_count * sizeof(int64_t));
double_dataset->Dim(dim)
->NumElements(double_count)
->Ids(new_ids.get())
->Float32Vectors(new_data.get())
->Owner(false);

auto check_func = [&](std::vector<int64_t>& failed_ids) -> void {
REQUIRE(failed_ids.size() == base_count);
std::sort(failed_ids.begin(), failed_ids.end());
for (uint64_t i = 0; i < base_count; ++i) {
REQUIRE(failed_ids[i] == dataset->base_->GetIds()[i]);
}
};

// add once with duplicate;
auto add_index = index->Add(double_dataset);
REQUIRE(add_index.has_value());
check_func(add_index.value());

// add twice with duplicate;
auto add_index_2 = index->Add(dataset->base_);
REQUIRE(add_index_2.has_value());
check_func(add_index_2.value());
}

} // namespace fixtures
2 changes: 2 additions & 0 deletions tests/test_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ class TestIndex {
TestConcurrentAdd(const IndexPtr& index,
const TestDatasetPtr& dataset,
bool expected_success = true);
static void
TestDuplicateAdd(const IndexPtr& index, const TestDatasetPtr& dataset);
};

} // namespace fixtures

0 comments on commit 55dcb8a

Please sign in to comment.