Skip to content

Commit

Permalink
added some code to do beta search
Browse files Browse the repository at this point in the history
  • Loading branch information
rakri committed Nov 13, 2024
1 parent ca2fb3b commit e8f90e7
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 6 deletions.
5 changes: 4 additions & 1 deletion apps/search_memory_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path,
const size_t num_frozen_pts = diskann::get_graph_num_frozen_points(index_path);

std::cout << filter_penalty_threshold << " is value of filter_penalty_threshold at driver file" << std::endl;
std::cout << penalty_scale << " is value of beta at driver file" << std::endl;
auto search_params =
diskann::IndexSearchParams(*(std::max_element(Lvec.begin(), Lvec.end())), num_threads, filter_penalty_threshold,
bruteforce_threshold, clustering_threshold);
Expand Down Expand Up @@ -502,7 +503,9 @@ int main(int argc, char **argv)
optional_configs.add_options()("num_local_start",
po::value<uint32_t>(&num_local)->default_value(0),
"How many local start points to use");

optional_configs.add_options()("beta",
po::value<float>(&penalty_scale)->default_value(1),
"Value of scaling parameter beta in float");

optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
program_options_utils::LABEL_TYPE_DESCRIPTION);
Expand Down
2 changes: 1 addition & 1 deletion include/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ inline uint32_t num_graphs = 0;
// Namespace-scope tuning knobs declared as C++17 inline variables in the header.
// NOTE(review): this span is a diff hunk rendered without its +/- markers, so
// it contains BOTH declarations of penalty_scale. The two cannot coexist in a
// real translation unit; the hunk summary reports 1 addition and 1 deletion.
inline uint32_t min_inter_size = 2;
// When true, iterate_to_fixed_point opens "query_stats.txt" in append mode.
inline bool print_qstats = false;
inline int64_t curr_query = -1;
// Presumably the line DELETED by this commit (old side of the diff): an integer
// scale that the old search code multiplied by the detect_filter_penalty result
// and then added to the candidate distance.
inline uint32_t penalty_scale = 10;
// Presumably the line ADDED by this commit (new side of the diff): "beta", a
// float factor. The new search code sets penalty to penalty_scale when
// detect_filter_penalty returns 0 and to 1 otherwise, then multiplies the
// candidate distance by that penalty.
// The literal 0.8 is a double converted to float here.
// NOTE(review): the driver registers --beta with default_value(1) bound to this
// variable, which presumably replaces 0.8 once options are parsed. Confirm
// which default is intended.
inline float penalty_scale = 0.8;
inline uint32_t num_sp = 2;
inline bool use_global_start = false;
inline uint32_t num_start_points = 1;
Expand Down
8 changes: 4 additions & 4 deletions src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1081,7 +1081,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
uint32_t res = detect_filter_penalty(id, search_invocation, filter_labels);
if ((res) > _filter_penalty_threshold)
continue;
penalty = res * penalty_scale;
penalty = (res == 0)? penalty_scale : 1;
if (print_qstats)
{
std::ofstream out("query_stats.txt", std::ios_base::app);
Expand Down Expand Up @@ -1116,7 +1116,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
uint32_t ids[] = {id};
float distances[] = {std::numeric_limits<float>::max()};
_pq_data_store->get_distance(aligned_query, ids, 1, distances, scratch);
distance = distances[0] + penalty;
distance = distances[0]*penalty;

Neighbor nn = Neighbor(id, distance);
best_L_nodes.insert(nn);
Expand Down Expand Up @@ -1259,7 +1259,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
id_iter++;
continue;
}
penalty = res * penalty_scale;
penalty = (res == 0) ? penalty_scale : 1;

if (print_qstats)
{
Expand Down Expand Up @@ -1305,7 +1305,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
// Insert <id, dist> pairs into the pool of candidates
for (size_t m = 0; m < id_scratch.size(); ++m)
{
best_L_nodes.insert(Neighbor(id_scratch[m], dist_scratch[m] + dist_pens[m]));
best_L_nodes.insert(Neighbor(id_scratch[m], dist_scratch[m]*dist_pens[m]));
}
}
return std::make_pair(hops, cmps);
Expand Down

0 comments on commit e8f90e7

Please sign in to comment.