From 1c46db0241d39c5d158608bc274ea0b2f0c5b2e4 Mon Sep 17 00:00:00 2001 From: rakri Date: Mon, 12 Feb 2024 03:30:48 +0000 Subject: [PATCH] added multifilter gt --- apps/utils/compute_filtered_groundtruth.cpp | 139 +++++++++++--------- 1 file changed, 77 insertions(+), 62 deletions(-) diff --git a/apps/utils/compute_filtered_groundtruth.cpp b/apps/utils/compute_filtered_groundtruth.cpp index f35b4dde7..a352df0bc 100644 --- a/apps/utils/compute_filtered_groundtruth.cpp +++ b/apps/utils/compute_filtered_groundtruth.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -132,7 +133,7 @@ void exact_knn(const size_t dim, const size_t k, size_t npoints, float *points_in, // points in Col major size_t nqueries, float *queries_in, - diskann::Metric metric, std::vector>::iterator matching_points) // queries in Col major + diskann::Metric metric, std::vector> &matching_points) // queries in Col major { float *points_l2sq = new float[npoints]; float *queries_l2sq = new float[nqueries]; @@ -213,15 +214,15 @@ void exact_knn(const size_t dim, const size_t k, for (long long q = q_b; q < q_e; q++) { maxPQIFCS point_dist; - for (size_t p = 0; p < k; p++) { - if ((*(matching_points + p)).find(q) != (*(matching_points + p)).end()) - point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]); - } - for (size_t p = k; p < npoints; p++) +// for (size_t p = 0; p < k; p++) { +// if (matching_points[q][p] == true) +// point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]); +// } + for (size_t p = 0; p < npoints; p++) { - if ((*(matching_points + p)).find(q) == (*(matching_points + p)).end()) + if (matching_points[q][p] == false) continue; - if (point_dist.top().second > dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]) + if (point_dist.size() < k || point_dist.top().second > dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]) point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]); if (point_dist.size() > k) point_dist.pop(); @@ -346,52 +347,9 @@ inline void save_groundtruth_as_one_file(const std::string filename, int32_t *da std::cout << "Finished writing truthset" << std::endl; } -template -std::vector>> processUnfilteredParts(const std::string &base_file, - size_t &nqueries, size_t &npoints, - size_t &dim, size_t &k, float *query_data, - const diskann::Metric &metric, - std::vector &location_to_tag, std::vector> &matching_points) -{ - float *base_data = nullptr; - int num_parts = get_num_parts(base_file.c_str()); - std::vector>> res(nqueries); - for (int p = 0; p < num_parts; p++) - { - size_t start_id = p * PARTSIZE; - load_bin_as_float(base_file.c_str(), base_data, npoints, dim, p); - size_t end_id = start_id + npoints; - - size_t *closest_points_part = new size_t[nqueries * k]; - float *dist_closest_points_part = new float[nqueries * k]; - - auto part_k = k < npoints ? k : npoints; - exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, npoints, base_data, nqueries, query_data, - metric, matching_points.begin() + start_id); - - for (size_t i = 0; i < nqueries; i++) - { - for (size_t j = 0; j < part_k; j++) - { - if (!location_to_tag.empty()) - if (location_to_tag[closest_points_part[i * k + j] + start_id] == 0) - continue; - - res[i].push_back(std::make_pair((uint32_t)(closest_points_part[i * part_k + j] + start_id), - dist_closest_points_part[i * part_k + j])); - } - } - - delete[] closest_points_part; - delete[] dist_closest_points_part; - - diskann::aligned_free(base_data); - } - return res; -}; inline void parse_base_label_file(const std::string &map_file, - std::vector> &pts_to_labels) + std::vector> &pts_to_labels, uint32_t start_id = 0) { pts_to_labels.clear(); std::ifstream infile(map_file); @@ -399,8 +357,11 @@ inline void parse_base_label_file(const std::string &map_file, std::set labels; infile.clear(); infile.seekg(0, std::ios::beg); + uint32_t line_no = 0; while (std::getline(infile, line)) - { + { + if (line_no < start_id) + continue; std::istringstream iss(line); tsl::robin_set lbls; @@ -415,6 +376,8 @@ inline void parse_base_label_file(const std::string &map_file, } // std::sort(lbls.begin(), lbls.end()); pts_to_labels.push_back(lbls); + if (pts_to_labels.size() >= PARTSIZE) + break; } std::cout << "Identified " << labels.size() << " distinct label(s), and populated labels for " << pts_to_labels.size() << " points" << std::endl; @@ -451,22 +414,24 @@ inline void parse_query_label_file(const std::string &query_label_file, << query_labels.size() << " queries" << std::endl; } + +//template // add UNIVERSAL LABEL SUPPORT -int identify_matching_points(const std::string &base, const std::string &query, const std::string &unv_label, std::vector> &matching_points) { +int identify_matching_points(const std::string &base, const size_t start_id, const std::string &query, const std::string &unv_label, std::vector> &matching_points) { std::vector> base_labels; std::vector> query_labels; - parse_base_label_file(base, base_labels); + parse_base_label_file(base, base_labels, start_id); parse_query_label_file(query, query_labels); matching_points.clear(); uint32_t num_query = query_labels.size(); uint32_t num_base = base_labels.size(); matching_points.resize(num_query); for (auto &x : matching_points) - x.reserve(0.4*num_base); + x.resize(num_base); #pragma omp parallel for schedule(dynamic, 128) for (uint32_t i = 0; i < num_query; i++) { - if (i % 100 == 0) - std::cout<<"."<< std::flush; +// if (i % 100 == 0) +// std::cout<<"."<< std::flush; // tsl::robin_set matches; for (uint32_t j = 0; j < num_base; j++) { bool pass = true; @@ -479,13 +444,65 @@ int identify_matching_points(const std::string &base, const std::string &query, } } if (pass) { - matching_points[i].insert(j); + matching_points[i][j] = 1; } } } return 0; } + + +template +std::vector>> processUnfilteredParts(const std::string &base_file, const std::string &base_labels, const std::string &query_labels, const std::string &unv_label, + size_t &nqueries, size_t &npoints, + size_t &dim, size_t &k, float *query_data, + const diskann::Metric &metric, + std::vector &location_to_tag) +{ + float *base_data = nullptr; + int num_parts = get_num_parts(base_file.c_str()); + std::vector>> res(nqueries); + for (int p = 0; p < num_parts; p++) + { + size_t start_id = p * PARTSIZE; + load_bin_as_float(base_file.c_str(), base_data, npoints, dim, p); + size_t end_id = start_id + npoints; + + std::vector> matching_points; + identify_matching_points(base_labels, start_id, query_labels, unv_label, matching_points); + + + size_t *closest_points_part = new size_t[nqueries * k]; + float *dist_closest_points_part = new float[nqueries * k]; + + auto part_k = k < npoints ? k : npoints; + exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, npoints, base_data, nqueries, query_data, + metric, matching_points); + + for (size_t i = 0; i < nqueries; i++) + { + for (size_t j = 0; j < part_k; j++) + { + if (!location_to_tag.empty()) + if (location_to_tag[closest_points_part[i * k + j] + start_id] == 0) + continue; + + res[i].push_back(std::make_pair((uint32_t)(closest_points_part[i * part_k + j] + start_id), + dist_closest_points_part[i * part_k + j])); + } + } + + delete[] closest_points_part; + delete[] dist_closest_points_part; + + diskann::aligned_free(base_data); + } + return res; +}; + + + // add UNIVERSAL LABEL SUPPORT template int aux_main(const std::string &base_file, const std::string &query_file, const std::string >_file, size_t k, @@ -494,8 +511,6 @@ int aux_main(const std::string &base_file, const std::string &query_file, const size_t npoints, nqueries, dim; float *query_data; - std::vector> matching_points; - identify_matching_points(base_labels, query_labels, unv_label, matching_points); load_bin_as_float(query_file.c_str(), query_data, nqueries, dim, 0); if (nqueries > PARTSIZE) std::cerr << "WARNING: #Queries provided (" << nqueries << ") is greater than " << PARTSIZE @@ -509,7 +524,7 @@ int aux_main(const std::string &base_file, const std::string &query_file, const float *dist_closest_points = new float[nqueries * k]; std::vector>> results = - processUnfilteredParts(base_file, nqueries, npoints, dim, k, query_data, metric, location_to_tag, matching_points); + processUnfilteredParts(base_file, base_labels, query_labels, unv_label, nqueries, npoints, dim, k, query_data, metric, location_to_tag); for (size_t i = 0; i < nqueries; i++) {