Skip to content

Commit

Permalink
added multifilter gt
Browse files Browse the repository at this point in the history
  • Loading branch information
rakri committed Feb 12, 2024
1 parent 654ceb0 commit 1c46db0
Showing 1 changed file with 77 additions and 62 deletions.
139 changes: 77 additions & 62 deletions apps/utils/compute_filtered_groundtruth.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <omp.h>
#include <mkl.h>
#include <boost/program_options.hpp>
#include <boost/dynamic_bitset.hpp>
#include <unordered_map>
#include <tsl/robin_map.h>
#include <tsl/robin_set.h>
Expand Down Expand Up @@ -132,7 +133,7 @@ void exact_knn(const size_t dim, const size_t k,
size_t npoints,
float *points_in, // points in Col major
size_t nqueries, float *queries_in,
diskann::Metric metric, std::vector<tsl::robin_set<uint32_t>>::iterator matching_points) // queries in Col major
diskann::Metric metric, std::vector<boost::dynamic_bitset<>> &matching_points) // queries in Col major
{
float *points_l2sq = new float[npoints];
float *queries_l2sq = new float[nqueries];
Expand Down Expand Up @@ -213,15 +214,15 @@ void exact_knn(const size_t dim, const size_t k,
for (long long q = q_b; q < q_e; q++)
{
maxPQIFCS point_dist;
for (size_t p = 0; p < k; p++) {
if ((*(matching_points + p)).find(q) != (*(matching_points + p)).end())
point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]);
}
for (size_t p = k; p < npoints; p++)
// for (size_t p = 0; p < k; p++) {
// if (matching_points[q][p] == true)
// point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]);
// }
for (size_t p = 0; p < npoints; p++)
{
if ((*(matching_points + p)).find(q) == (*(matching_points + p)).end())
if (matching_points[q][p] == false)
continue;
if (point_dist.top().second > dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints])
if (point_dist.size() < k || point_dist.top().second > dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints])
point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]);
if (point_dist.size() > k)
point_dist.pop();
Expand Down Expand Up @@ -346,61 +347,21 @@ inline void save_groundtruth_as_one_file(const std::string filename, int32_t *da
std::cout << "Finished writing truthset" << std::endl;
}

/**
 * Computes per-query nearest-neighbour candidates over every part of the base
 * file, restricted to the points allowed by `matching_points`.
 *
 * For each part the base vectors are loaded as floats, `exact_knn` is run with
 * `part_k = min(k, npoints)`, and the resulting (id, distance) pairs are
 * appended to the per-query result list with ids shifted by the part's
 * starting offset.
 *
 * @param base_file        Path of the base vector file, read part by part.
 * @param nqueries         Number of queries; sizes the returned outer vector.
 * @param npoints          Passed to load_bin_as_float for each part and then
 *                         used as the number of points in that part
 *                         (presumably written by the loader -- confirm).
 * @param dim              Passed to load_bin_as_float and exact_knn.
 * @param k                Requested number of neighbours per query.
 * @param query_data       Query vectors forwarded to exact_knn.
 * @param metric           Distance metric forwarded to exact_knn.
 * @param location_to_tag  When non-empty, candidates whose entry equals 0 are
 *                         dropped from the result.
 * @param matching_points  Filter data; an iterator advanced by the part's
 *                         starting offset is handed to exact_knn.
 * @return One vector of (global point id, distance) pairs per query, holding
 *         the candidates gathered from all parts.
 */
template <typename T>
std::vector<std::vector<std::pair<uint32_t, float>>> processUnfilteredParts(const std::string &base_file,
                                                                            size_t &nqueries, size_t &npoints,
                                                                            size_t &dim, size_t &k, float *query_data,
                                                                            const diskann::Metric &metric,
                                                                            std::vector<uint32_t> &location_to_tag, std::vector<tsl::robin_set<uint32_t>> &matching_points)
{
    float *base_data = nullptr;
    const int num_parts = get_num_parts<T>(base_file.c_str());
    std::vector<std::vector<std::pair<uint32_t, float>>> res(nqueries);
    for (int p = 0; p < num_parts; p++)
    {
        // Widen before multiplying so the offset cannot overflow if PARTSIZE
        // is a plain int constant (its definition is not visible here).
        const size_t start_id = static_cast<size_t>(p) * PARTSIZE;
        load_bin_as_float<T>(base_file.c_str(), base_data, npoints, dim, p);

        // Vectors release the scratch buffers even if exact_knn throws.
        // Capacity stays at nqueries * k, an upper bound since part_k <= k.
        std::vector<size_t> closest_points_part(nqueries * k);
        std::vector<float> dist_closest_points_part(nqueries * k);

        const size_t part_k = k < npoints ? k : npoints;
        exact_knn(dim, part_k, closest_points_part.data(), dist_closest_points_part.data(), npoints, base_data,
                  nqueries, query_data, metric, matching_points.begin() + start_id);

        for (size_t i = 0; i < nqueries; i++)
        {
            for (size_t j = 0; j < part_k; j++)
            {
                // exact_knn was called with part_k, so rows have stride part_k.
                // The tag check must use the same slot as the value pushed
                // below (the original used stride k for the check only).
                const size_t slot = i * part_k + j;
                const size_t global_id = closest_points_part[slot] + start_id;

                if (!location_to_tag.empty() && location_to_tag[global_id] == 0)
                    continue;

                res[i].push_back(std::make_pair(static_cast<uint32_t>(global_id), dist_closest_points_part[slot]));
            }
        }

        diskann::aligned_free(base_data);
    }
    return res;
}

inline void parse_base_label_file(const std::string &map_file,
std::vector<tsl::robin_set<std::string>> &pts_to_labels)
std::vector<tsl::robin_set<std::string>> &pts_to_labels, uint32_t start_id = 0)
{
pts_to_labels.clear();
std::ifstream infile(map_file);
std::string line, token;
std::set<std::string> labels;
infile.clear();
infile.seekg(0, std::ios::beg);
uint32_t line_no = 0;
while (std::getline(infile, line))
{
{
if (line_no < start_id)
continue;
std::istringstream iss(line);
tsl::robin_set<std::string> lbls;

Expand All @@ -415,6 +376,8 @@ inline void parse_base_label_file(const std::string &map_file,
}
// std::sort(lbls.begin(), lbls.end());
pts_to_labels.push_back(lbls);
if (pts_to_labels.size() >= PARTSIZE)
break;
}
std::cout << "Identified " << labels.size() << " distinct label(s), and populated labels for "
<< pts_to_labels.size() << " points" << std::endl;
Expand Down Expand Up @@ -451,22 +414,24 @@ inline void parse_query_label_file(const std::string &query_label_file,
<< query_labels.size() << " queries" << std::endl;
}


//template<typename A, typename B>
// add UNIVERSAL LABEL SUPPORT
int identify_matching_points(const std::string &base, const std::string &query, const std::string &unv_label, std::vector<tsl::robin_set<uint32_t>> &matching_points) {
int identify_matching_points(const std::string &base, const size_t start_id, const std::string &query, const std::string &unv_label, std::vector<boost::dynamic_bitset<>> &matching_points) {
std::vector<tsl::robin_set<std::string>> base_labels;
std::vector<std::vector<std::string>> query_labels;
parse_base_label_file(base, base_labels);
parse_base_label_file(base, base_labels, start_id);
parse_query_label_file(query, query_labels);
matching_points.clear();
uint32_t num_query = query_labels.size();
uint32_t num_base = base_labels.size();
matching_points.resize(num_query);
for (auto &x : matching_points)
x.reserve(0.4*num_base);
x.resize(num_base);
#pragma omp parallel for schedule(dynamic, 128)
for (uint32_t i = 0; i < num_query; i++) {
if (i % 100 == 0)
std::cout<<"."<< std::flush;
// if (i % 100 == 0)
// std::cout<<"."<< std::flush;
// tsl::robin_set<uint32_t> matches;
for (uint32_t j = 0; j < num_base; j++) {
bool pass = true;
Expand All @@ -479,13 +444,65 @@ int identify_matching_points(const std::string &base, const std::string &query,
}
}
if (pass) {
matching_points[i].insert(j);
matching_points[i][j] = 1;
}
}
}
return 0;
}



/**
 * Computes per-query filtered nearest-neighbour candidates over every part of
 * the base file.
 *
 * For each part the base vectors are loaded as floats, the per-query match
 * bitsets are rebuilt via identify_matching_points for that part's starting
 * offset, and `exact_knn` is run with `part_k = min(k, npoints)`. The
 * resulting (id, distance) pairs are appended to the per-query result list
 * with ids shifted by the part's starting offset.
 *
 * @param base_file        Path of the base vector file, read part by part.
 * @param base_labels      Base label file forwarded to identify_matching_points.
 * @param query_labels     Query label file forwarded to identify_matching_points.
 * @param unv_label        Universal label forwarded to identify_matching_points.
 * @param nqueries         Number of queries; sizes the returned outer vector.
 * @param npoints          Passed to load_bin_as_float for each part and then
 *                         used as the number of points in that part
 *                         (presumably written by the loader -- confirm).
 * @param dim              Passed to load_bin_as_float and exact_knn.
 * @param k                Requested number of neighbours per query.
 * @param query_data       Query vectors forwarded to exact_knn.
 * @param metric           Distance metric forwarded to exact_knn.
 * @param location_to_tag  When non-empty, candidates whose entry equals 0 are
 *                         dropped from the result.
 * @return One vector of (global point id, distance) pairs per query, holding
 *         the candidates gathered from all parts.
 */
template <typename T>
std::vector<std::vector<std::pair<uint32_t, float>>> processUnfilteredParts(const std::string &base_file, const std::string &base_labels, const std::string &query_labels, const std::string &unv_label,
                                                                            size_t &nqueries, size_t &npoints,
                                                                            size_t &dim, size_t &k, float *query_data,
                                                                            const diskann::Metric &metric,
                                                                            std::vector<uint32_t> &location_to_tag)
{
    float *base_data = nullptr;
    const int num_parts = get_num_parts<T>(base_file.c_str());
    std::vector<std::vector<std::pair<uint32_t, float>>> res(nqueries);
    for (int p = 0; p < num_parts; p++)
    {
        // Widen before multiplying so the offset cannot overflow if PARTSIZE
        // is a plain int constant (its definition is not visible here).
        const size_t start_id = static_cast<size_t>(p) * PARTSIZE;
        load_bin_as_float<T>(base_file.c_str(), base_data, npoints, dim, p);

        // Match bitsets are recomputed for each part, starting at start_id.
        std::vector<boost::dynamic_bitset<>> matching_points;
        identify_matching_points(base_labels, start_id, query_labels, unv_label, matching_points);

        // Vectors release the scratch buffers even if exact_knn throws.
        // Capacity stays at nqueries * k, an upper bound since part_k <= k.
        std::vector<size_t> closest_points_part(nqueries * k);
        std::vector<float> dist_closest_points_part(nqueries * k);

        const size_t part_k = k < npoints ? k : npoints;
        exact_knn(dim, part_k, closest_points_part.data(), dist_closest_points_part.data(), npoints, base_data,
                  nqueries, query_data, metric, matching_points);

        for (size_t i = 0; i < nqueries; i++)
        {
            for (size_t j = 0; j < part_k; j++)
            {
                // exact_knn was called with part_k, so rows have stride part_k.
                // The tag check must use the same slot as the value pushed
                // below (the original used stride k for the check only).
                const size_t slot = i * part_k + j;
                const size_t global_id = closest_points_part[slot] + start_id;

                if (!location_to_tag.empty() && location_to_tag[global_id] == 0)
                    continue;

                res[i].push_back(std::make_pair(static_cast<uint32_t>(global_id), dist_closest_points_part[slot]));
            }
        }

        diskann::aligned_free(base_data);
    }
    return res;
}



// add UNIVERSAL LABEL SUPPORT
template <typename T>
int aux_main(const std::string &base_file, const std::string &query_file, const std::string &gt_file, size_t k,
Expand All @@ -494,8 +511,6 @@ int aux_main(const std::string &base_file, const std::string &query_file, const
size_t npoints, nqueries, dim;

float *query_data;
std::vector<tsl::robin_set<uint32_t>> matching_points;
identify_matching_points(base_labels, query_labels, unv_label, matching_points);
load_bin_as_float<T>(query_file.c_str(), query_data, nqueries, dim, 0);
if (nqueries > PARTSIZE)
std::cerr << "WARNING: #Queries provided (" << nqueries << ") is greater than " << PARTSIZE
Expand All @@ -509,7 +524,7 @@ int aux_main(const std::string &base_file, const std::string &query_file, const
float *dist_closest_points = new float[nqueries * k];

std::vector<std::vector<std::pair<uint32_t, float>>> results =
processUnfilteredParts<T>(base_file, nqueries, npoints, dim, k, query_data, metric, location_to_tag, matching_points);
processUnfilteredParts<T>(base_file, base_labels, query_labels, unv_label, nqueries, npoints, dim, k, query_data, metric, location_to_tag);

for (size_t i = 0; i < nqueries; i++)
{
Expand Down

0 comments on commit 1c46db0

Please sign in to comment.