From 5ec58615448893f34f26ed93e2d91ba0ca2dc948 Mon Sep 17 00:00:00 2001 From: Neelam Mahapatro Date: Mon, 16 Oct 2023 16:20:14 +0530 Subject: [PATCH] improve load time --- include/pq_flash_index.h | 3 +- src/index.cpp | 67 ++++++++++++++++++--------- src/pq_flash_index.cpp | 99 ++++++++++++++++++++++++++++++---------- 3 files changed, 124 insertions(+), 45 deletions(-) diff --git a/include/pq_flash_index.h b/include/pq_flash_index.h index 499a0ee62..e1f7d37be 100644 --- a/include/pq_flash_index.h +++ b/include/pq_flash_index.h @@ -111,7 +111,8 @@ template class PQFlashIndex DISKANN_DLLEXPORT inline bool point_has_label(uint32_t point_id, LabelT label_id); std::unordered_map load_label_map(const std::string &map_file); DISKANN_DLLEXPORT void parse_label_file(const std::string &map_file, size_t &num_pts_labels); - DISKANN_DLLEXPORT void get_label_file_metadata(std::string map_file, uint32_t &num_pts, uint32_t &num_total_labels); + DISKANN_DLLEXPORT void get_label_file_metadata(const std::string &file_content, uint32_t &num_pts, + uint32_t &num_total_labels); DISKANN_DLLEXPORT void generate_random_labels(std::vector &labels, const uint32_t num_labels, const uint32_t nthreads); diff --git a/src/index.cpp b/src/index.cpp index b8e585648..04203cfc3 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2022,45 +2022,70 @@ void Index::parse_label_file(const std::string &label_file, siz { // Format of Label txt file: filters with comma separators - std::ifstream infile(label_file); + std::ifstream infile(label_file, std::ios::binary); if (infile.fail()) { throw diskann::ANNException(std::string("Failed to open file ") + label_file, -1); } - std::string line, token; + infile.seekg(0, std::ios::end); + size_t file_size = infile.tellg(); + + std::string buffer(file_size, ' '); + + infile.seekg(0, std::ios::beg); + infile.read(&buffer[0], file_size); + + std::string label_str; + size_t cur_pos = 0; + size_t next_pos = 0; uint32_t line_cnt = 0; - while (std::getline(infile, line)) + // Find total number of points in the labels file to reserve _pts_to_labels + while (cur_pos < file_size && cur_pos != std::string::npos) { + next_pos = buffer.find('\n', cur_pos); + if (next_pos == std::string::npos) + break; + cur_pos = next_pos + 1; line_cnt++; } + cur_pos = 0; + next_pos = 0; _pts_to_labels.resize(line_cnt, std::vector()); - - infile.clear(); - infile.seekg(0, std::ios::beg); line_cnt = 0; - - while (std::getline(infile, line)) + while (cur_pos < file_size && cur_pos != std::string::npos) { - std::istringstream iss(line); + next_pos = buffer.find('\n', cur_pos); + if (next_pos == std::string::npos) + { + break; + } + size_t lbl_pos = cur_pos; + size_t next_lbl_pos = 0; std::vector lbls(0); - getline(iss, token, '\t'); - std::istringstream new_iss(token); - while (getline(new_iss, token, ',')) + while (lbl_pos < next_pos && lbl_pos != std::string::npos) { - token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); - token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); - LabelT token_as_num = (LabelT)std::stoul(token); + next_lbl_pos = buffer.find(',', lbl_pos); + if (next_lbl_pos == std::string::npos) // the last label in the whole file + { + next_lbl_pos = next_pos; + } + if (next_lbl_pos > next_pos) // the last label in one line, just read to the end + { + next_lbl_pos = next_pos; + } + label_str.assign(buffer.c_str() + lbl_pos, next_lbl_pos - lbl_pos); + if (label_str[label_str.length() - 1] == '\t') // '\t' won't exist in label file? + { + label_str.erase(label_str.length() - 1); + } + LabelT token_as_num = (LabelT)std::stoul(label_str); lbls.push_back(token_as_num); _labels.insert(token_as_num); + // move to next label + lbl_pos = next_lbl_pos + 1; } - if (lbls.size() <= 0) - { - diskann::cout << "No label found"; - exit(-1); - } - std::sort(lbls.begin(), lbls.end()); _pts_to_labels[line_cnt] = lbls; line_cnt++; } diff --git a/src/pq_flash_index.cpp b/src/pq_flash_index.cpp index fbbf68377..f4c77aa37 100644 --- a/src/pq_flash_index.cpp +++ b/src/pq_flash_index.cpp @@ -563,29 +563,44 @@ LabelT PQFlashIndex::get_converted_label(const std::string &filter_la } template -void PQFlashIndex::get_label_file_metadata(std::string map_file, uint32_t &num_pts, +void PQFlashIndex::get_label_file_metadata(const std::string &file_content, uint32_t &num_pts, uint32_t &num_total_labels) { - std::ifstream infile(map_file); - std::string line, token; num_pts = 0; num_total_labels = 0; - while (std::getline(infile, line)) + size_t file_size = file_content.length(); + + std::string label_str; + size_t cur_pos = 0, next_pos = 0; + + while (cur_pos < file_size && cur_pos != std::string::npos) { - std::istringstream iss(line); - while (getline(iss, token, ',')) + next_pos = file_content.find('\n', cur_pos); + if (next_pos == std::string::npos) { - token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); - token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); + break; + } + + size_t lbl_pos = cur_pos; + size_t next_lbl_pos = 0; + while (lbl_pos < next_pos && lbl_pos != std::string::npos) + { + next_lbl_pos = file_content.find(',', lbl_pos); + if (next_lbl_pos == std::string::npos) // the last label + { + next_lbl_pos = next_pos; + } + num_total_labels++; + + lbl_pos = next_lbl_pos + 1; } + cur_pos = next_pos + 1; num_pts++; } - diskann::cout << "Labels file metadata: num_points: " << num_pts << ", #total_labels: " << num_total_labels << std::endl; - infile.close(); } template @@ -608,51 +623,89 @@ inline bool PQFlashIndex::point_has_label(uint32_t point_id, LabelT l template void PQFlashIndex::parse_label_file(const std::string &label_file, size_t &num_points_labels) { - std::ifstream infile(label_file); + std::ifstream infile(label_file, std::ios::binary); if (infile.fail()) { throw diskann::ANNException(std::string("Failed to open file ") + label_file, -1); } - std::string line, token; + infile.seekg(0, std::ios::end); + size_t file_size = infile.tellg(); + + std::string buffer(file_size, ' '); + + infile.seekg(0, std::ios::beg); + infile.read(&buffer[0], file_size); + + infile.close(); + + std::string line; uint32_t line_cnt = 0; uint32_t num_pts_in_label_file; uint32_t num_total_labels; - get_label_file_metadata(label_file, num_pts_in_label_file, num_total_labels); + get_label_file_metadata(buffer, num_pts_in_label_file, num_total_labels); _pts_to_label_offsets = new uint32_t[num_pts_in_label_file]; _pts_to_label_counts = new uint32_t[num_pts_in_label_file]; _pts_to_labels = new LabelT[num_total_labels]; uint32_t labels_seen_so_far = 0; - while (std::getline(infile, line)) + std::string label_str; + size_t cur_pos = 0; + size_t next_pos = 0; + + while (cur_pos < file_size && cur_pos != std::string::npos) { - std::istringstream iss(line); - std::vector lbls(0); + next_pos = buffer.find('\n', cur_pos); + if (next_pos == std::string::npos) + { + break; + } _pts_to_label_offsets[line_cnt] = labels_seen_so_far; uint32_t &num_lbls_in_cur_pt = _pts_to_label_counts[line_cnt]; num_lbls_in_cur_pt = 0; - getline(iss, token, '\t'); - std::istringstream new_iss(token); - while (getline(new_iss, token, ',')) + + size_t lbl_pos = cur_pos; + size_t next_lbl_pos = 0; + + while (lbl_pos < next_pos && lbl_pos != std::string::npos) { - token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); - token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); - LabelT token_as_num = (LabelT)std::stoul(token); + next_lbl_pos = buffer.find(',', lbl_pos); + if (next_lbl_pos == std::string::npos) // the last label in the whole file + { + next_lbl_pos = next_pos; + } + + if (next_lbl_pos > next_pos) // the last label in one line, just read to the end + { + next_lbl_pos = next_pos; + } + + label_str.assign(buffer.c_str() + lbl_pos, next_lbl_pos - lbl_pos); + if (label_str[label_str.length() - 1] == '\t') // '\t' won't exist in label file? + { + label_str.erase(label_str.length() - 1); + } + + LabelT token_as_num = (LabelT)std::stoul(label_str); _pts_to_labels[labels_seen_so_far++] = (LabelT)token_as_num; num_lbls_in_cur_pt++; + + // move to next label + lbl_pos = next_lbl_pos + 1; } + cur_pos = next_pos + 1; if (num_lbls_in_cur_pt == 0) { diskann::cout << "No label found for point " << line_cnt << std::endl; exit(-1); } + line_cnt++; } - infile.close(); num_points_labels = line_cnt; }