diff --git a/apps/search_disk_index.cpp b/apps/search_disk_index.cpp index 7e2a7ac6d..16ae7fbee 100644 --- a/apps/search_disk_index.cpp +++ b/apps/search_disk_index.cpp @@ -232,7 +232,7 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre _pFlashIndex->cached_beam_search(query + (i * query_aligned_dim), recall_at, L, query_result_ids_64.data() + (i * recall_at), query_result_dists[test_id].data() + (i * recall_at), - optimized_beamwidth, use_reorder_data, stats + i); + optimized_beamwidth, std::numeric_limits::max(), use_reorder_data, stats + i); } else { @@ -247,7 +247,7 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre } _pFlashIndex->cached_beam_search( query + (i * query_aligned_dim), recall_at, L, query_result_ids_64.data() + (i * recall_at), - query_result_dists[test_id].data() + (i * recall_at), optimized_beamwidth, true, label_for_search, + query_result_dists[test_id].data() + (i * recall_at), optimized_beamwidth, true, label_for_search, std::numeric_limits::max(), use_reorder_data, stats + i); } } diff --git a/include/pq_flash_index.h b/include/pq_flash_index.h index ba5258e18..f8da87fdd 100644 --- a/include/pq_flash_index.h +++ b/include/pq_flash_index.h @@ -62,23 +62,23 @@ template class PQFlashIndex const bool shuffle = false); DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search, - uint64_t *res_ids, float *res_dists, const uint64_t beam_width, + uint64_t *res_ids, float *res_dists, const uint64_t beam_width, const uint32_t max_l_per_seller = std::numeric_limits::max(), const bool use_reorder_data = false, QueryStats *stats = nullptr); DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search, uint64_t *res_ids, float *res_dists, const uint64_t beam_width, - const bool use_filter, const LabelT &filter_label, + const bool use_filter, const LabelT &filter_label, const uint32_t 
max_l_per_seller = std::numeric_limits::max(), const bool use_reorder_data = false, QueryStats *stats = nullptr); DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search, uint64_t *res_ids, float *res_dists, const uint64_t beam_width, - const uint32_t io_limit, const bool use_reorder_data = false, + const uint32_t io_limit, const uint32_t max_l_per_seller = std::numeric_limits::max(), const bool use_reorder_data = false, QueryStats *stats = nullptr); DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search, uint64_t *res_ids, float *res_dists, const uint64_t beam_width, const bool use_filter, const LabelT &filter_label, - const uint32_t io_limit, const bool use_reorder_data = false, + const uint32_t io_limit, const uint32_t max_l_per_seller = std::numeric_limits::max(), const bool use_reorder_data = false, QueryStats *stats = nullptr); DISKANN_DLLEXPORT LabelT get_converted_label(const std::string &filter_label); @@ -118,10 +118,14 @@ template class PQFlashIndex DISKANN_DLLEXPORT inline bool point_has_label(uint32_t point_id, LabelT label_id); std::unordered_map load_label_map(std::basic_istream &infile); DISKANN_DLLEXPORT void parse_label_file(std::basic_istream &infile, size_t &num_pts_labels); + DISKANN_DLLEXPORT void get_label_file_metadata(const std::string &fileContent, uint32_t &num_pts, uint32_t &num_total_labels); DISKANN_DLLEXPORT void generate_random_labels(std::vector &labels, const uint32_t num_labels, const uint32_t nthreads); + + DISKANN_DLLEXPORT void parse_seller_file(const std::string &label_file, size_t &num_pts_labels); + void reset_stream_for_reading(std::basic_istream &infile); // sector # on disk where node_id is present with in the graph part @@ -234,6 +238,13 @@ template class PQFlashIndex tsl::robin_map> _real_to_dummy_map; std::unordered_map _label_map; + + bool _diverse_index = false; + uint32_t _max_L_per_seller = 0; + std::vector 
_location_to_seller; + std::string _seller_file; + + #ifdef EXEC_ENV_OLS // Set to a larger value than the actual header to accommodate // any additions we make to the header. This is an outer limit diff --git a/include/scratch.h b/include/scratch.h index 83240e77c..af1c6e421 100644 --- a/include/scratch.h +++ b/include/scratch.h @@ -154,8 +154,9 @@ template class SSDQueryScratch : public AbstractScratch tsl::robin_set visited; NeighborPriorityQueue retset; std::vector full_retset; + bestCandidates best_diverse_nodes; - SSDQueryScratch(size_t aligned_dim, size_t visited_reserve); + SSDQueryScratch(size_t aligned_dim, size_t visited_reserve, std::vector &location_to_sellers); ~SSDQueryScratch(); void reset(); @@ -167,7 +168,7 @@ template class SSDThreadData SSDQueryScratch scratch; IOContext ctx; - SSDThreadData(size_t aligned_dim, size_t visited_reserve); + SSDThreadData(size_t aligned_dim, size_t visited_reserve, std::vector &location_to_sellers); void clear(); }; diff --git a/python/src/static_disk_index.cpp b/python/src/static_disk_index.cpp index 9e86b0ad5..6bf307e28 100644 --- a/python/src/static_disk_index.cpp +++ b/python/src/static_disk_index.cpp @@ -65,7 +65,7 @@ NeighborsAndDistances StaticDiskIndex
::search( std::vector u64_ids(knn); diskann::QueryStats stats; - _index.cached_beam_search(query.data(), knn, complexity, u64_ids.data(), dists.mutable_data(), beam_width, false, + _index.cached_beam_search(query.data(), knn, complexity, u64_ids.data(), dists.mutable_data(), beam_width, std::numeric_limits::max(), false, &stats); auto r = ids.mutable_unchecked<1>(); diff --git a/src/disk_utils.cpp b/src/disk_utils.cpp index 22f1e98fd..05caa20f6 100644 --- a/src/disk_utils.cpp +++ b/src/disk_utils.cpp @@ -813,7 +813,7 @@ uint32_t optimize_beamwidth(std::unique_ptr> &p { pFlashIndex->cached_beam_search(tuning_sample + (i * tuning_sample_aligned_dim), 1, L, tuning_sample_result_ids_64.data() + (i * 1), - tuning_sample_result_dists.data() + (i * 1), cur_bw, false, stats + i); + tuning_sample_result_dists.data() + (i * 1), cur_bw, std::numeric_limits::max(), false, stats + i); } auto e = std::chrono::high_resolution_clock::now(); std::chrono::duration diff = e - s; diff --git a/src/index.cpp b/src/index.cpp index 318014461..873c7ceca 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -598,7 +598,7 @@ void Index::load(const char *filename, uint32_t num_threads, ui throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); } - std::string index_seller_file = std::string(filename) + "_sellers.txt"; + std::string index_seller_file = std::string(filename) + "_sellers.txt"; if(file_exists(index_seller_file)) { uint64_t nrows_seller_file; parse_seller_file(index_seller_file, nrows_seller_file); diff --git a/src/pq_flash_index.cpp b/src/pq_flash_index.cpp index d9ad50617..d17204f20 100644 --- a/src/pq_flash_index.cpp +++ b/src/pq_flash_index.cpp @@ -131,7 +131,7 @@ void PQFlashIndex::setup_thread_data(uint64_t nthreads, uint64_t visi { #pragma omp critical { - SSDThreadData *data = new SSDThreadData(this->_aligned_dim, visited_reserve); + SSDThreadData *data = new SSDThreadData(this->_aligned_dim, visited_reserve, this->_location_to_seller); 
this->reader->register_thread(); data->ctx = this->reader->get_ctx(); this->_thread_data.push(data); @@ -326,7 +326,7 @@ void PQFlashIndex::generate_cache_list_from_sample_queries(std::strin // concurrently update the node_visit_counter to track most visited nodes. The last false is to not use the // "use_reorder_data" option which enables a final reranking if the disk index itself contains only PQ data. cached_beam_search(samples + (i * sample_aligned_dim), 1, l_search, tmp_result_ids_64.data() + i, - tmp_result_dists.data() + i, beamwidth, filtered_search, label_for_search, false); + tmp_result_dists.data() + i, beamwidth, filtered_search, label_for_search, std::numeric_limits::max(), false); } std::sort(this->_node_visit_counter.begin(), _node_visit_counter.end(), @@ -747,6 +747,53 @@ void PQFlashIndex::parse_label_file(std::basic_istream &infile, reset_stream_for_reading(infile); }
+template <typename T, typename LabelT>
+void PQFlashIndex<T, LabelT>::parse_seller_file(const std::string &label_file, size_t &num_points)
+{
+    // Fills _location_to_seller with one seller id per line of label_file (entry i comes from line i)
+    // and sets num_points to the number of lines. Per line only the text before the first tab is
+    // read, as a comma-separated list of unsigned ids; the last id becomes the point's seller.
+    // Throws diskann::ANNException if the file cannot be opened or a line lacks a valid uint32_t id.
+    std::ifstream infile(label_file);
+    if (infile.fail())
+    {
+        throw diskann::ANNException(std::string("Failed to open file ") + label_file, -1);
+    }
+
+    std::string line, token;
+    std::set<uint32_t> sellers;
+    _location_to_seller.clear(); // filled by push_back below, so one pass over the file is enough
+    while (std::getline(infile, line))
+    {
+        const std::string where = label_file + ", line " + std::to_string(_location_to_seller.size() + 1);
+        std::istringstream iss(line);
+        getline(iss, token, '\t');
+        std::istringstream new_iss(token);
+        uint32_t seller = 0;
+        bool has_seller = false; // a blank line yields no token and must not store an unset seller
+        while (getline(new_iss, token, ','))
+        {
+            token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
+            token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
+            unsigned long long token_as_num = 0; // parsed wide so values above uint32_t are detected
+            if (!(std::istringstream(token) >> token_as_num) || token_as_num > std::numeric_limits<uint32_t>::max())
+            {
+                throw diskann::ANNException("Invalid seller id '" + token + "' in " + where, -1);
+            }
+            seller = static_cast<uint32_t>(token_as_num);
+            has_seller = true;
+            sellers.insert(seller);
+        }
+        if (!has_seller)
+        {
+            throw diskann::ANNException("No seller id found in " + where, -1);
+        }
+        _location_to_seller.push_back(seller);
+    }
+    num_points = _location_to_seller.size();
+    diskann::cout << "Identified " << sellers.size() << " distinct seller(s) across " << num_points <<" points." << std::endl;
+}
+
 template void PQFlashIndex::set_universal_label(const LabelT &label) { _use_universal_label = true; @@ -1008,6 +1055,18 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons << std::endl; } + +#ifndef EXEC_ENV_OLS +// TODO: Make this friendly for DLVS + this->_seller_file = std ::string(_disk_index_file) + "_sellers.txt"; + if(file_exists(this->_seller_file)) { + uint64_t nrows_seller_file; + parse_seller_file(this->_seller_file, nrows_seller_file); + this->_diverse_index = true; + } +#endif + + // read index metadata #ifdef EXEC_ENV_OLS // This is a bit tricky. We have to read the header from the @@ -1236,31 +1295,31 @@ bool getNextCompletedRequest(std::shared_ptr &reader, IOConte template void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t k_search, const uint64_t l_search, - uint64_t *indices, float *distances, const uint64_t beam_width, + uint64_t *indices, float *distances, const uint64_t beam_width, const uint32_t max_l_per_seller, const bool use_reorder_data, QueryStats *stats) { - cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, std::numeric_limits::max(), + cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, std::numeric_limits::max(), max_l_per_seller, use_reorder_data, stats); } template void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t k_search, const uint64_t l_search, uint64_t *indices, float *distances, const uint64_t beam_width, - const bool use_filter, const LabelT &filter_label, + const bool use_filter, const LabelT &filter_label, const uint32_t max_l_per_seller, const bool use_reorder_data, QueryStats *stats) { cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, use_filter, filter_label, - std::numeric_limits::max(), use_reorder_data, stats); + 
std::numeric_limits::max(), max_l_per_seller, use_reorder_data, stats); } template void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t k_search, const uint64_t l_search, - uint64_t *indices, float *distances, const uint64_t beam_width, - const uint32_t io_limit, const bool use_reorder_data, + uint64_t *indices, float *distances, const uint64_t beam_width, + const uint32_t io_limit, const uint32_t max_l_per_seller, const bool use_reorder_data, QueryStats *stats) { LabelT dummy_filter = 0; - cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, false, dummy_filter, io_limit, + cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, false, dummy_filter, io_limit, max_l_per_seller, use_reorder_data, stats); } @@ -1268,10 +1327,15 @@ template void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t k_search, const uint64_t l_search, uint64_t *indices, float *distances, const uint64_t beam_width, const bool use_filter, const LabelT &filter_label, - const uint32_t io_limit, const bool use_reorder_data, + const uint32_t io_limit, const uint32_t max_l_per_seller, const bool use_reorder_data, QueryStats *stats) { + bool diverse_search = false; + if (max_l_per_seller != std::numeric_limits::max()) + diverse_search = true; + + uint64_t num_sector_per_nodes = DIV_ROUND_UP(_max_node_len, defaults::SECTOR_LEN); if (beam_width > num_sector_per_nodes * defaults::MAX_N_SECTOR_READS) throw ANNException("Beamwidth can not be higher than defaults::MAX_N_SECTOR_READS", -1, __FUNCSIG__, __FILE__, @@ -1353,8 +1417,18 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t Timer query_timer, io_timer, cpu_timer; tsl::robin_set &visited = query_scratch->visited; - NeighborPriorityQueue &retset = query_scratch->retset; - retset.reserve(l_search); + //NeighborPriorityQueue &retset = query_scratch->retset; + bestCandidates &best_diverse_nodes_ref = query_scratch->best_diverse_nodes; + + 
NeighborPriorityQueue* retset; + if(diverse_search) { + best_diverse_nodes_ref.setup(l_search, max_l_per_seller); + retset = &(best_diverse_nodes_ref.best_L_nodes); + } else { + retset = &(query_scratch->retset); + retset->reserve(l_search); + } + std::vector &full_retset = query_scratch->full_retset; uint32_t best_medoid = 0; @@ -1397,7 +1471,13 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t } compute_dists(&best_medoid, 1, dist_scratch); - retset.insert(Neighbor(best_medoid, dist_scratch[0])); + if (diverse_search) { + best_diverse_nodes_ref.insert(best_medoid, dist_scratch[0]); + } else { + retset->insert(Neighbor(best_medoid, dist_scratch[0])); + } + + //retset->insert(Neighbor(best_medoid, dist_scratch[0])); visited.insert(best_medoid); uint32_t cmps = 0; @@ -1414,7 +1494,7 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t std::vector>> cached_nhoods; cached_nhoods.reserve(2 * beam_width); - while (retset.has_unexpanded_node() && num_ios < io_limit) + while (retset->has_unexpanded_node() && num_ios < io_limit) { // clear iteration state frontier.clear(); @@ -1424,9 +1504,9 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t sector_scratch_idx = 0; // find new beam uint32_t num_seen = 0; - while (retset.has_unexpanded_node() && frontier.size() < beam_width && num_seen < beam_width) + while (retset->has_unexpanded_node() && frontier.size() < beam_width && num_seen < beam_width) { - auto nbr = retset.closest_unexpanded(); + auto nbr = retset->closest_unexpanded(); num_seen++; auto iter = _nhood_cache.find(nbr.id); if (iter != _nhood_cache.end()) @@ -1528,8 +1608,13 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t continue; cmps++; float dist = dist_scratch[m]; - Neighbor nn(id, dist); - retset.insert(nn); + +// retset->insert(nn); + if (diverse_search) { + best_diverse_nodes_ref.insert(id, dist); + } else { + retset->insert(Neighbor(id, dist)); + } } } } @@ -1597,7 
+1682,12 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t } Neighbor nn(id, dist); - retset.insert(nn); +// retset->insert(nn); + if (diverse_search) { + best_diverse_nodes_ref.insert(id, dist); + } else { + retset->insert(Neighbor(id, dist)); + } } } @@ -1611,6 +1701,7 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t } // re-sort by distance + std::sort(full_retset.begin(), full_retset.end()); if (use_reorder_data) @@ -1663,6 +1754,15 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t std::sort(full_retset.begin(), full_retset.end()); } + + if (diverse_search) { + best_diverse_nodes_ref.clear(); + for (auto &x : full_retset) { + best_diverse_nodes_ref.insert(x.id, x.distance); + } + full_retset = best_diverse_nodes_ref.best_L_nodes._data; + } + // copy k_search values for (uint64_t i = 0; i < k_search; i++) { @@ -1720,7 +1820,7 @@ uint32_t PQFlashIndex::range_search(const T *query1, const double ran cur_bw = (cur_bw > 100) ? 
100 : cur_bw; for (auto &x : distances) x = std::numeric_limits::max(); - this->cached_beam_search(query1, l_search, l_search, indices.data(), distances.data(), cur_bw, false, stats); + this->cached_beam_search(query1, l_search, l_search, indices.data(), distances.data(), cur_bw, std::numeric_limits::max(), false, stats); for (uint32_t i = 0; i < l_search; i++) { if (distances[i] > (float)range) diff --git a/src/scratch.cpp b/src/scratch.cpp index 650c0a1ce..a7a7e9c98 100644 --- a/src/scratch.cpp +++ b/src/scratch.cpp @@ -93,9 +93,10 @@ template void SSDQueryScratch::reset() visited.clear(); retset.clear(); full_retset.clear(); + best_diverse_nodes.clear(); } -template SSDQueryScratch::SSDQueryScratch(size_t aligned_dim, size_t visited_reserve) +template SSDQueryScratch::SSDQueryScratch(size_t aligned_dim, size_t visited_reserve, std::vector &location_to_sellers) : best_diverse_nodes(location_to_sellers) { size_t coord_alloc_size = ROUND_UP(sizeof(T) * aligned_dim, 256); @@ -123,7 +124,7 @@ template SSDQueryScratch::~SSDQueryScratch() } template -SSDThreadData::SSDThreadData(size_t aligned_dim, size_t visited_reserve) : scratch(aligned_dim, visited_reserve) +SSDThreadData::SSDThreadData(size_t aligned_dim, size_t visited_reserve, std::vector &location_to_sellers) : scratch(aligned_dim, visited_reserve, location_to_sellers) { }