From e8f90e787a18ebf9da4306494aed94c0d18fd196 Mon Sep 17 00:00:00 2001
From: rakri
Date: Tue, 12 Nov 2024 20:32:53 -0800
Subject: [PATCH] added some code to do beta search

---
 apps/search_memory_index.cpp | 5 ++++-
 include/index.h              | 2 +-
 src/index.cpp                | 8 ++++----
 3 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/apps/search_memory_index.cpp b/apps/search_memory_index.cpp
index 2f8d06cdd..19cfc902c 100644
--- a/apps/search_memory_index.cpp
+++ b/apps/search_memory_index.cpp
@@ -79,6 +79,7 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path,
     const size_t num_frozen_pts = diskann::get_graph_num_frozen_points(index_path);
 
     std::cout << filter_penalty_threshold << " is value of filter_penalty_threshold at driver file" << std::endl;
+    std::cout << penalty_scale << " is value of beta at driver file" << std::endl;
     auto search_params = diskann::IndexSearchParams(*(std::max_element(Lvec.begin(), Lvec.end())), num_threads,
                                                     filter_penalty_threshold, bruteforce_threshold, clustering_threshold);
 
@@ -502,7 +503,9 @@ int main(int argc, char **argv)
 
         optional_configs.add_options()("num_local_start", po::value(&num_local)->default_value(0),
                                        "How many local start points to use");
-
+        optional_configs.add_options()("beta",
+                                       po::value(&penalty_scale)->default_value(1),
+                                       "Value of scaling parameter beta in float");
         optional_configs.add_options()("label_type", po::value(&label_type)->default_value("uint"),
                                        program_options_utils::LABEL_TYPE_DESCRIPTION);
 
diff --git a/include/index.h b/include/index.h
index 105153deb..6b171808e 100644
--- a/include/index.h
+++ b/include/index.h
@@ -46,7 +46,7 @@ inline uint32_t num_graphs = 0;
 inline uint32_t min_inter_size = 2;
 inline bool print_qstats = false;
 inline int64_t curr_query = -1;
-inline uint32_t penalty_scale = 10;
+inline float penalty_scale = 0.8;
 inline uint32_t num_sp = 2;
 inline bool use_global_start = false;
 inline uint32_t num_start_points = 1;
diff --git a/src/index.cpp b/src/index.cpp
index a27f62bd5..f3a0d2b28 100644
--- a/src/index.cpp
+++ b/src/index.cpp
@@ -1081,7 +1081,7 @@ std::pair Index::iterate_to_fixed_point(
                 uint32_t res = detect_filter_penalty(id, search_invocation, filter_labels);
                 if ((res) > _filter_penalty_threshold)
                     continue;
-                penalty = res * penalty_scale;
+                penalty = (res == 0)? penalty_scale : 1;
                 if (print_qstats)
                 {
                     std::ofstream out("query_stats.txt", std::ios_base::app);
@@ -1116,7 +1116,7 @@ std::pair Index::iterate_to_fixed_point(
             uint32_t ids[] = {id};
             float distances[] = {std::numeric_limits::max()};
             _pq_data_store->get_distance(aligned_query, ids, 1, distances, scratch);
-            distance = distances[0] + penalty;
+            distance = distances[0]*penalty;
 
             Neighbor nn = Neighbor(id, distance);
             best_L_nodes.insert(nn);
@@ -1259,7 +1259,7 @@ std::pair Index::iterate_to_fixed_point(
                     id_iter++;
                     continue;
                 }
-                penalty = res * penalty_scale;
+                penalty = (res == 0) ? penalty_scale : 1;
 
                 if (print_qstats)
                 {
@@ -1305,7 +1305,7 @@ std::pair Index::iterate_to_fixed_point(
         // Insert pairs into the pool of candidates
         for (size_t m = 0; m < id_scratch.size(); ++m)
         {
-            best_L_nodes.insert(Neighbor(id_scratch[m], dist_scratch[m] + dist_pens[m]));
+            best_L_nodes.insert(Neighbor(id_scratch[m], dist_scratch[m]*dist_pens[m]));
         }
     }
     return std::make_pair(hops, cmps);