Skip to content

Commit

Permalink
fixed some gt bug and some recall optimizations
Browse files Browse the repository at this point in the history
  • Loading branch information
rakri committed Dec 4, 2024
1 parent 353258f commit 7133b21
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 35 deletions.
77 changes: 53 additions & 24 deletions apps/search_memory_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,9 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path,
const std::string &query_file, const std::string &truthset_file, const uint32_t num_threads,
const uint32_t recall_at, const bool print_all_recalls, const std::vector<uint32_t> &Lvec,
const bool dynamic, const bool tags, const bool show_qps_per_thread,
const std::vector<std::string> &query_filters, const float fail_if_recall_below, const uint32_t max_L_per_seller = 0, const bool post_process = false)
const std::vector<std::string> &query_filters, const float fail_if_recall_below, const uint32_t max_K_per_seller = std::numeric_limits<uint32_t>::max(), const bool diverse_search = false, const bool scale_seller_limits = false, const bool post_process = false)
{
std::cout<<max_K_per_seller <<" " << diverse_search <<" " << scale_seller_limits << " " << post_process << std::endl;
std::vector<uint32_t> location_to_sellers;
std::string seller_file = index_path +"_sellers.txt";
if (file_exists(seller_file)) {
Expand Down Expand Up @@ -209,12 +210,18 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path,
continue;
}

query_result_ids[test_id].resize(recall_at * query_num);
query_result_dists[test_id].resize(recall_at * query_num);
query_result_ids[test_id].resize(recall_at * query_num, std::numeric_limits<uint32_t>::max());
query_result_dists[test_id].resize(recall_at * query_num, std::numeric_limits<float>::max());
std::vector<T *> res = std::vector<T *>();

uint32_t maxLperSeller = (max_L_per_seller > 0) ? max_L_per_seller : L;
//uint32_t maxLperSeller = (max_L_per_seller > 0) ? max_L_per_seller : L;

//maxLperSeller = (maxLperSeller == 0)? 1 : maxLperSeller;
uint32_t maxLperSeller = max_K_per_seller;
if (diverse_search && scale_seller_limits) {
maxLperSeller = (1.0*L* max_K_per_seller)/(1.0*recall_at);
// std::cout<<maxLperSeller<<std::endl;
}

auto s = std::chrono::high_resolution_clock::now();
omp_set_num_threads(num_threads);
Expand Down Expand Up @@ -258,25 +265,37 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path,
}
else
{
if (maxLperSeller != L && !post_process)
std::vector<uint32_t> results(L,std::numeric_limits<uint32_t>::max());
std::vector<float> dists(L,std::numeric_limits<float>::max());
uint32_t K_to_use = (post_process == true) ? L : recall_at;

if (diverse_search) {

cmp_stats[i] = index
->diverse_search(query + i * query_aligned_dim, recall_at, L, maxLperSeller,
query_result_ids[test_id].data() + i * recall_at)
.second;
else {
->diverse_search(query + i * query_aligned_dim, K_to_use, L, maxLperSeller,
results.data(), dists.data())
.second;
} else {
// {
std::vector<uint32_t> results(L,0);
std::vector<float> dists(L,0);
cmp_stats[i] = index
->search(query + i * query_aligned_dim, L, L,
->search(query + i * query_aligned_dim, K_to_use, L,
results.data(), dists.data())
.second;
diskann::bestCandidates final_results(recall_at, maxLperSeller, location_to_sellers);
}
if (post_process) {
diskann::bestCandidates final_results(recall_at, max_K_per_seller, location_to_sellers);
for (uint32_t rr = 0; rr < L; rr++) {
final_results.insert(results[rr], dists[rr]);
}
for (uint32_t ctr = 0; ctr < final_results.best_L_nodes.size(); ctr++) {

for (uint32_t ctr = 0; ctr < std::min(final_results.best_L_nodes.size(), (uint64_t)recall_at); ctr++) {
query_result_ids[test_id][recall_at * i + ctr] = final_results.best_L_nodes._data[ctr].id;
query_result_dists[test_id][recall_at * i + ctr] = final_results.best_L_nodes._data[ctr].distance;
}
} else {
for (uint32_t ctr = 0; ctr < std::min(results.size(),(uint64_t)recall_at); ctr++) {
query_result_ids[test_id][recall_at * i + ctr] = results[ctr];
query_result_dists[test_id][recall_at * i + ctr] = dists[ctr];
}
}
}
Expand All @@ -298,7 +317,10 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path,
for (uint32_t curr_recall = first_recall; curr_recall <= recall_at; curr_recall++)
{
recalls.push_back(diskann::calculate_recall((uint32_t)query_num, gt_ids, gt_dists, (uint32_t)gt_dim,
query_result_ids[test_id].data(), recall_at, curr_recall));
query_result_ids[test_id].data(), recall_at, curr_recall, query_result_dists[test_id].data()));
// recalls.push_back(diskann::calculate_recall((uint32_t)query_num, gt_ids, gt_dists, (uint32_t)gt_dim,
// query_result_ids[test_id].data(), recall_at, curr_recall));

}
}

Expand Down Expand Up @@ -357,7 +379,7 @@ int main(int argc, char **argv)
query_filters_file;
uint32_t num_threads, K, max_L_per_seller;
std::vector<uint32_t> Lvec;
bool print_all_recalls, dynamic, tags, show_qps_per_thread, post_process;
bool print_all_recalls, dynamic, tags, show_qps_per_thread, post_process, diverse_search, scale_seller_limits;
float fail_if_recall_below = 0.0f;

po::options_description desc{
Expand Down Expand Up @@ -399,12 +421,19 @@ int main(int argc, char **argv)
optional_configs.add_options()("num_threads,T",
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
program_options_utils::NUMBER_THREADS_DESCRIPTION);
optional_configs.add_options()("max_L_per_seller",
optional_configs.add_options()("max_K_per_seller",
po::value<uint32_t>(&max_L_per_seller)->default_value(0),
"How many results per seller we want search results to contain");
optional_configs.add_options()("diverse_search",
po::value<bool>(&diverse_search)->default_value(false),
"Whether to run diverse search or baseline search");
optional_configs.add_options()("scale_seller_limits",
po::value<bool>(&scale_seller_limits)->default_value(false),
"Whether to run scale the max_L_per_seller based on the L value");
optional_configs.add_options()("post_process",
po::value<bool>(&post_process)->default_value(false),
"Whether to do vanilla search + post-processing for diversity");
"Whether to post-processing to ensure correct diversity");


optional_configs.add_options()(
"dynamic", po::value<bool>(&dynamic)->default_value(false),
Expand Down Expand Up @@ -504,19 +533,19 @@ int main(int argc, char **argv)
{
return search_memory_index<int8_t, uint16_t>(
metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls,
Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller, post_process);
Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller, diverse_search, scale_seller_limits, post_process);
}
else if (data_type == std::string("uint8"))
{
return search_memory_index<uint8_t, uint16_t>(
metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls,
Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller, post_process);
Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller, diverse_search, scale_seller_limits, post_process);
}
else if (data_type == std::string("float"))
{
return search_memory_index<float, uint16_t>(metric, index_path_prefix, result_path, query_file, gt_file,
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller, post_process);
show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller, diverse_search, scale_seller_limits, post_process);
}
else
{
Expand All @@ -530,19 +559,19 @@ int main(int argc, char **argv)
{
return search_memory_index<int8_t>(metric, index_path_prefix, result_path, query_file, gt_file,
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller, post_process);
show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller, diverse_search, scale_seller_limits, post_process);
}
else if (data_type == std::string("uint8"))
{
return search_memory_index<uint8_t>(metric, index_path_prefix, result_path, query_file, gt_file,
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller, post_process);
show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller, diverse_search, scale_seller_limits, post_process);
}
else if (data_type == std::string("float"))
{
return search_memory_index<float>(metric, index_path_prefix, result_path, query_file, gt_file,
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller, post_process);
show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller, diverse_search, scale_seller_limits, post_process);
}
else
{
Expand Down
20 changes: 14 additions & 6 deletions apps/utils/compute_diverse_groundtruth.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,13 @@ std::vector<std::vector<std::pair<uint32_t, float>>> processUnfilteredParts(cons
exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, npoints, start_id, base_data, nqueries, query_data, kperseller, running_results,
metric);


delete[] closest_points_part;
delete[] dist_closest_points_part;

diskann::aligned_free(base_data);
}

for (size_t i = 0; i < nqueries; i++)
{
auto & cur_results = running_results[i];
Expand All @@ -413,11 +420,6 @@ std::vector<std::vector<std::pair<uint32_t, float>>> processUnfilteredParts(cons
}
}

delete[] closest_points_part;
delete[] dist_closest_points_part;

diskann::aligned_free(base_data);
}
return res;
};

Expand Down Expand Up @@ -521,8 +523,14 @@ int aux_main(const std::string &base_file, const std::string &query_file, const

++j;
}
if (j < k)
if (j < k) {
std::cout << "WARNING: found less than k GT entries for query " << i << std::endl;
while (j<k) {
dist_closest_points[i * k + j] = std::numeric_limits<float>::max();
closest_points[i * k + j] = std::numeric_limits<int32_t>::max();
j++;
}
}
}

save_groundtruth_as_one_file(gt_file, closest_points, dist_closest_points, nqueries, k);
Expand Down
4 changes: 2 additions & 2 deletions include/neighbor.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,12 @@ class NeighborPriorityQueue
{
}

explicit NeighborPriorityQueue(size_t capacity) : _size(0), _capacity(capacity), _cur(0), _data(capacity + 1)
explicit NeighborPriorityQueue(size_t capacity) : _size(0), _capacity(capacity), _cur(0), _data(capacity + 1, Neighbor(std::numeric_limits<uint32_t>::max(), std::numeric_limits<float>::max()))
{
}

void setup(uint32_t capacity) {
_data.resize(capacity+1);
_data.resize(capacity+1,Neighbor(std::numeric_limits<uint32_t>::max(), std::numeric_limits<float>::max()));
_capacity = capacity;
}
// Inserts the item ordered into the set up to the sets capacity.
Expand Down
2 changes: 1 addition & 1 deletion include/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,7 @@ inline void copy_file(std::string in_file, std::string out_file)
}

DISKANN_DLLEXPORT double calculate_recall(unsigned num_queries, unsigned *gold_std, float *gs_dist, unsigned dim_gs,
unsigned *our_results, unsigned dim_or, unsigned recall_at);
unsigned *our_results, unsigned dim_or, unsigned recall_at, float* algo_distances = nullptr);

DISKANN_DLLEXPORT double calculate_recall(unsigned num_queries, unsigned *gold_std, float *gs_dist, unsigned dim_gs,
unsigned *our_results, unsigned dim_or, unsigned recall_at,
Expand Down
17 changes: 15 additions & 2 deletions src/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,19 @@ void normalize_data_file(const std::string &inFileName, const std::string &outFi
diskann::cout << "Wrote normalized points to file: " << outFileName << std::endl;
}


double calculate_recall(uint32_t num_queries, uint32_t *gold_std, float *gs_dist, uint32_t dim_gs,
uint32_t *our_results, uint32_t dim_or, uint32_t recall_at)
uint32_t *our_results, uint32_t dim_or, uint32_t recall_at, float* algo_distances)
{
bool use_distances_to_break_ties = false;
if (algo_distances != nullptr) {
use_distances_to_break_ties = true;
}
double total_recall = 0;
std::set<uint32_t> gt, res;

for (size_t i = 0; i < num_queries; i++)
{
if (!use_distances_to_break_ties) {
gt.clear();
res.clear();
uint32_t *gt_vec = gold_std + dim_gs * i;
Expand All @@ -160,6 +165,14 @@ double calculate_recall(uint32_t num_queries, uint32_t *gold_std, float *gs_dist
}
}
total_recall += cur_recall;
} else { // only works if dim_or == dim_gs. Not for the k-recall@k' regime.
uint32_t cur_recall =0;
for (uint32_t rr = 0; rr < std::min(dim_or, dim_gs); rr++) {
if (algo_distances[i*dim_or + rr] <= gs_dist[i*dim_gs + (recall_at-1)])
cur_recall++;
}
total_recall += cur_recall;
}
}
return total_recall / (num_queries) * (100.0 / recall_at);
}
Expand Down

0 comments on commit 7133b21

Please sign in to comment.