Skip to content

Commit

Permalink
untested, added diversity code to disk index searcher
Browse files Browse the repository at this point in the history
  • Loading branch information
rakri committed Dec 4, 2024
1 parent 0ae9b41 commit 7970aa3
Showing 1 changed file with 124 additions and 12 deletions.
136 changes: 124 additions & 12 deletions apps/search_disk_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,52 @@

namespace po = boost::program_options;

// Reads a seller file and builds the point-location -> seller-id mapping.
//
// File format: one line per point, in point order. Only the text before the
// first tab on each line is used; it is split on commas and every token is
// parsed as an unsigned integer seller id. If a line lists several ids, the
// last one is the id stored for that point (all of them are counted in the
// distinct-seller total that gets logged).
//
// Parameters:
//   label_file         - path of the seller file to read.
//   num_points         - out: number of lines (points) read.
//   location_to_seller - out: replaced with one seller id per line.
//
// Throws diskann::ANNException if the file cannot be opened, if a token is
// not a valid unsigned 32-bit integer, or if a line contains no seller id.
void parse_seller_file(const std::string &label_file, size_t &num_points, std::vector<uint32_t> &location_to_seller)
{
    std::ifstream infile(label_file);
    if (infile.fail())
    {
        throw diskann::ANNException(std::string("Failed to open file ") + label_file, -1);
    }

    std::set<uint32_t> sellers;

    // Single pass: the vector grows as lines are read, so no separate
    // line-counting pass (and no rewind of the stream) is needed.
    location_to_seller.clear();

    std::string line;
    while (std::getline(infile, line))
    {
        // 1-based line number, used only for error messages.
        const std::string line_no = std::to_string(location_to_seller.size() + 1);

        // Keep only the first tab-separated field of the line.
        std::string field;
        std::istringstream line_stream(line);
        std::getline(line_stream, field, '\t');

        std::istringstream field_stream(field);
        std::string token;
        uint32_t seller = 0;
        bool found_seller = false;
        while (std::getline(field_stream, token, ','))
        {
            // Strip stray line-ending characters (e.g. from CRLF files).
            token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
            token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());

            unsigned long parsed = 0;
            try
            {
                parsed = std::stoul(token);
            }
            catch (const std::exception &)
            {
                // Re-throw with file/line context instead of a bare "stoul" error.
                throw diskann::ANNException(
                    std::string("Invalid seller id '") + token + "' on line " + line_no + " of " + label_file, -1);
            }
            if (parsed > std::numeric_limits<uint32_t>::max())
            {
                // Refuse to silently truncate ids that do not fit in 32 bits.
                throw diskann::ANNException(
                    std::string("Seller id '") + token + "' on line " + line_no + " of " + label_file +
                        " does not fit in 32 bits",
                    -1);
            }

            seller = static_cast<uint32_t>(parsed);
            sellers.insert(seller);
            found_seller = true;
        }

        if (!found_seller)
        {
            // A blank line would otherwise store an indeterminate seller id
            // for this point.
            throw diskann::ANNException(std::string("No seller id on line ") + line_no + " of " + label_file, -1);
        }

        location_to_seller.push_back(seller);
    }

    num_points = location_to_seller.size();
    diskann::cout << " Search code: Identified " << sellers.size() << " distinct seller(s) across " << num_points <<" points." << std::endl;
}


void print_stats(std::string category, std::vector<float> percentiles, std::vector<float> results)
{
diskann::cout << std::setw(20) << category << ": " << std::flush;
Expand All @@ -53,7 +99,7 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre
const uint32_t num_threads, const uint32_t recall_at, const uint32_t beamwidth,
const uint32_t num_nodes_to_cache, const uint32_t search_io_limit,
const std::vector<uint32_t> &Lvec, const float fail_if_recall_below,
const std::vector<std::string> &query_filters, const bool use_reorder_data = false)
const std::vector<std::string> &query_filters, const bool use_reorder_data = false, const uint32_t max_K_per_seller = std::numeric_limits<uint32_t>::max(), const bool diverse_search = false, const bool scale_seller_limits = false, const bool post_process = false)
{
diskann::cout << "Search parameters: #threads: " << num_threads << ", ";
if (beamwidth <= 0)
Expand All @@ -67,6 +113,16 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre

std::string warmup_query_file = index_path_prefix + "_sample_data.bin";

std::cout<<max_K_per_seller <<" " << diverse_search <<" " << scale_seller_limits << " " << post_process << std::endl;
std::vector<uint32_t> location_to_sellers;
std::string seller_file = index_path_prefix +"_sellers.txt";
if (file_exists(seller_file)) {
std::cout<<"Here" << std::endl;
uint64_t num_pts_seller_file;
parse_seller_file(seller_file, num_pts_seller_file, location_to_sellers);
}


// load query bin
T *query = nullptr;
uint32_t *gt_ids = nullptr;
Expand Down Expand Up @@ -216,23 +272,61 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre
else
optimized_beamwidth = beamwidth;

query_result_ids[test_id].resize(recall_at * query_num);
query_result_dists[test_id].resize(recall_at * query_num);
query_result_ids[test_id].resize(recall_at * query_num, std::numeric_limits<uint32_t>::max());
query_result_dists[test_id].resize(recall_at * query_num, std::numeric_limits<float>::max());

auto stats = new diskann::QueryStats[query_num];

std::vector<uint64_t> query_result_ids_64(recall_at * query_num);
std::vector<uint64_t> query_result_ids_64(recall_at * query_num, std::numeric_limits<uint32_t>::max());

//uint32_t maxLperSeller = (max_L_per_seller > 0) ? max_L_per_seller : L;

//maxLperSeller = (maxLperSeller == 0)? 1 : maxLperSeller;
uint32_t maxLperSeller = max_K_per_seller;
if (diverse_search && scale_seller_limits) {
maxLperSeller = (1.0*L* max_K_per_seller)/(1.0*recall_at);
// std::cout<<maxLperSeller<<std::endl;
}

auto s = std::chrono::high_resolution_clock::now();

#pragma omp parallel for schedule(dynamic, 1)
for (int64_t i = 0; i < (int64_t)query_num; i++)
{
if (!filtered_search)
{
std::vector<uint64_t> results(L,std::numeric_limits<uint32_t>::max());
std::vector<float> dists(L,std::numeric_limits<float>::max());
uint32_t K_to_use = (post_process == true) ? L : recall_at;

if (diverse_search) {
_pFlashIndex->cached_beam_search(query + (i * query_aligned_dim), K_to_use, L,
results.data(),
dists.data(),
optimized_beamwidth, maxLperSeller, use_reorder_data, stats + i);
} else {
_pFlashIndex->cached_beam_search(query + (i * query_aligned_dim), recall_at, L,
query_result_ids_64.data() + (i * recall_at),
query_result_dists[test_id].data() + (i * recall_at),
results.data(),
dists.data(),
optimized_beamwidth, std::numeric_limits<uint32_t>::max(), use_reorder_data, stats + i);

}
if (post_process) {
diskann::bestCandidates final_results(recall_at, max_K_per_seller, location_to_sellers);
for (uint32_t rr = 0; rr < L; rr++) {
final_results.insert((uint32_t)results[rr], dists[rr]);
}

for (uint32_t ctr = 0; ctr < std::min(final_results.best_L_nodes.size(), (uint64_t)recall_at); ctr++) {
query_result_ids[test_id][recall_at * i + ctr] = final_results.best_L_nodes._data[ctr].id;
query_result_dists[test_id][recall_at * i + ctr] = final_results.best_L_nodes._data[ctr].distance;
}
} else {
for (uint32_t ctr = 0; ctr < std::min(results.size(),(uint64_t)recall_at); ctr++) {
query_result_ids[test_id][recall_at * i + ctr] = results[ctr];
query_result_dists[test_id][recall_at * i + ctr] = dists[ctr];
}
}
}
else
{
Expand All @@ -255,8 +349,10 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre
std::chrono::duration<double> diff = e - s;
double qps = (1.0 * query_num) / (1.0 * diff.count());

if (filtered_search) {
diskann::convert_types<uint64_t, uint32_t>(query_result_ids_64.data(), query_result_ids[test_id].data(),
query_num, recall_at);
}

auto mean_latency = diskann::get_mean_stats<float>(
stats, query_num, [](const diskann::QueryStats &stats) { return stats.total_us; });
Expand Down Expand Up @@ -319,6 +415,9 @@ int main(int argc, char **argv)
bool use_reorder_data = false;
float fail_if_recall_below = 0.0f;

uint32_t max_L_per_seller;
bool post_process, diverse_search, scale_seller_limits;

po::options_description desc{
program_options_utils::make_program_description("search_disk_index", "Searches on-disk DiskANN indexes")};
try
Expand Down Expand Up @@ -373,6 +472,19 @@ int main(int argc, char **argv)
po::value<float>(&fail_if_recall_below)->default_value(0.0f),
program_options_utils::FAIL_IF_RECALL_BELOW);

optional_configs.add_options()("max_K_per_seller",
po::value<uint32_t>(&max_L_per_seller)->default_value(0),
"How many results per seller we want search results to contain");
optional_configs.add_options()("diverse_search",
po::value<bool>(&diverse_search)->default_value(false),
"Whether to run diverse search or baseline search");
optional_configs.add_options()("scale_seller_limits",
po::value<bool>(&scale_seller_limits)->default_value(false),
"Whether to run scale the max_L_per_seller based on the L value");
optional_configs.add_options()("post_process",
po::value<bool>(&post_process)->default_value(false),
"Whether to post-processing to ensure correct diversity");

// Merge required and optional parameters
desc.add(required_configs).add(optional_configs);

Expand Down Expand Up @@ -451,15 +563,15 @@ int main(int argc, char **argv)
if (data_type == std::string("float"))
return search_disk_index<float, uint16_t>(
metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W,
num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data);
num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data, max_L_per_seller, diverse_search, scale_seller_limits, post_process);
else if (data_type == std::string("int8"))
return search_disk_index<int8_t, uint16_t>(
metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W,
num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data);
num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data, max_L_per_seller, diverse_search, scale_seller_limits, post_process);
else if (data_type == std::string("uint8"))
return search_disk_index<uint8_t, uint16_t>(
metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W,
num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data);
num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data, max_L_per_seller, diverse_search, scale_seller_limits, post_process);
else
{
std::cerr << "Unsupported data type. Use float or int8 or uint8" << std::endl;
Expand All @@ -471,15 +583,15 @@ int main(int argc, char **argv)
if (data_type == std::string("float"))
return search_disk_index<float>(metric, index_path_prefix, result_path_prefix, query_file, gt_file,
num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
fail_if_recall_below, query_filters, use_reorder_data);
fail_if_recall_below, query_filters, use_reorder_data, max_L_per_seller, diverse_search, scale_seller_limits, post_process);
else if (data_type == std::string("int8"))
return search_disk_index<int8_t>(metric, index_path_prefix, result_path_prefix, query_file, gt_file,
num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
fail_if_recall_below, query_filters, use_reorder_data);
fail_if_recall_below, query_filters, use_reorder_data, max_L_per_seller, diverse_search, scale_seller_limits, post_process);
else if (data_type == std::string("uint8"))
return search_disk_index<uint8_t>(metric, index_path_prefix, result_path_prefix, query_file, gt_file,
num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
fail_if_recall_below, query_filters, use_reorder_data);
fail_if_recall_below, query_filters, use_reorder_data, max_L_per_seller, diverse_search, scale_seller_limits, post_process);
else
{
std::cerr << "Unsupported data type. Use float or int8 or uint8" << std::endl;
Expand Down

0 comments on commit 7970aa3

Please sign in to comment.