From 07e1dfb4dd947e413adbb34baa26c8ed112f042d Mon Sep 17 00:00:00 2001 From: Gopal Srinivasa Date: Thu, 11 Jul 2024 13:12:23 +0530 Subject: [PATCH] Fixed issue with code hanging when num_threads == 1 --- apps/search_disk_index.cpp | 28 ++++++++++++++++++++++++++++ include/percentile_stats.h | 4 ++++ src/pq_flash_index.cpp | 24 ++++++++++++++++++++++++ 3 files changed, 56 insertions(+) diff --git a/apps/search_disk_index.cpp b/apps/search_disk_index.cpp index 99a517379..4b8d91a45 100644 --- a/apps/search_disk_index.cpp +++ b/apps/search_disk_index.cpp @@ -33,6 +33,31 @@ namespace po = boost::program_options; +#ifdef DISKANN_DEBUG_PRINT_RETSET +void dump_retset(uint64_t test_id, uint64_t query_num, diskann::QueryStats *stats, const std::string &result_output_prefix) +{ + std::stringstream ss; + if (stats != nullptr) + { + for (int i = 0; i < query_num; i++) + { + ss << i << "\t"; + for (int j = 0; j < (stats + i)->query_retset.size(); j++) + { + ss << "(" << (stats + i)->query_retset[j].id << ", " << (stats + i)->query_retset[j].distance + << "), "; + } + ss << std::endl; + } + + } + std::string results_file = result_output_prefix + "_L" + std::to_string(test_id) + "_retset.tsv"; + std::ofstream writer(results_file); + writer << ss.str() << std::endl; + writer.close(); +} +#endif + #ifdef DISKANN_DEBUG_INDIVIDUAL_RESULTS void dump_individual_results(uint64_t test_id, uint64_t query_num, uint32_t *gt_ids, float *gt_dists, uint64_t gt_dim, const std::vector &query_result_ids, @@ -414,6 +439,9 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre dump_individual_results(test_id, query_num, gt_ids, gt_dists, gt_dim, query_result_ids[test_id], query_result_dists[test_id], recall_at, result_output_prefix); #endif +#ifdef DISKANN_DEBUG_PRINT_RETSET + dump_retset(test_id, query_num, stats, result_output_prefix); +#endif diskann::cout << std::setw(6) << L << std::setw(12) << optimized_beamwidth << std::setw(16) << qps << std::setw(16) << mean_latency << std::setw(16) << latency_999 << std::setw(16) << mean_ios diff --git a/include/percentile_stats.h b/include/percentile_stats.h index 793257577..361b0109e 100644 --- a/include/percentile_stats.h +++ b/include/percentile_stats.h @@ -33,6 +33,10 @@ struct QueryStats unsigned n_cmps = 0; // # cmps unsigned n_cache_hits = 0; // # cache_hits unsigned n_hops = 0; // # search hops + +#ifdef DISKANN_DEBUG_PRINT_RETSET + std::vector query_retset; //copy of the retset to debug PQ distances. +#endif }; template diff --git a/src/pq_flash_index.cpp b/src/pq_flash_index.cpp index 1067a83cc..4dcfe6497 100644 --- a/src/pq_flash_index.cpp +++ b/src/pq_flash_index.cpp @@ -136,6 +136,7 @@ void PQFlashIndex::setup_thread_data(uint64_t nthreads, uint64_t visi this->_thread_data.push(data); } } + this->_thread_data.push_notify_all(); _load_flag = true; } @@ -1431,6 +1432,9 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t } compute_dists(&best_medoid, 1, dist_scratch); retset.insert(Neighbor(best_medoid, dist_scratch[0])); +#ifdef DISKANN_DEBUG_PRINT_RETSET + stats->query_retset.push_back(Neighbor(best_medoid, dist_scratch[0])); +#endif visited.insert(best_medoid); cur_list_size = 1; } else { @@ -1638,6 +1642,10 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t float dist = dist_scratch[m]; Neighbor nn(id, dist); retset.insert(nn); +#ifdef DISKANN_DEBUG_PRINT_RETSET + stats->query_retset.push_back(nn); +#endif + } } } @@ -1712,6 +1720,10 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t Neighbor nn(id, dist); retset.insert(nn); +#ifdef DISKANN_DEBUG_PRINT_RETSET + stats->query_retset.push_back(nn); +#endif + } } @@ -1725,6 +1737,18 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t // re-sort by distance std::sort(full_retset.begin(), full_retset.end()); +#ifdef DISKANN_DEBUG_PRINT_RETSET + { + for (int i = 0; i < retset.size(); i++) + { + if (stats != nullptr) + { + stats->query_retset.push_back(retset[i]); + } + } + } +#endif + if (use_reorder_data) { if (!(this->_reorder_data_exists))