From daa5a7bb62aee53beb97e48542537195d38dafd6 Mon Sep 17 00:00:00 2001 From: Jerry Gao <109158931+Sanhaoji2@users.noreply.github.com> Date: Sun, 8 Oct 2023 15:43:25 +0800 Subject: [PATCH] Jegao/label hot fix test2 (#469) * read label in one file * test commit * fix last label issue * remove get label number * fix some issue --- include/pq_flash_index.h | 3 +- src/pq_flash_index.cpp | 136 +++++++++++++++++++++++---------------- 2 files changed, 83 insertions(+), 56 deletions(-) diff --git a/include/pq_flash_index.h b/include/pq_flash_index.h index ba76cd47e..d333d0e7c 100644 --- a/include/pq_flash_index.h +++ b/include/pq_flash_index.h @@ -107,7 +107,8 @@ template class PQFlashIndex DISKANN_DLLEXPORT inline bool point_has_label(uint32_t point_id, uint32_t label_id); std::unordered_map load_label_map(const std::string &map_file); DISKANN_DLLEXPORT void parse_label_file(const std::string &map_file, size_t &num_pts_labels); - DISKANN_DLLEXPORT void get_label_file_metadata(std::string map_file, uint32_t &num_pts, uint32_t &num_total_labels); + DISKANN_DLLEXPORT void get_label_file_metadata(const std::string &fileContent, uint32_t &num_pts, + uint32_t &num_total_labels); DISKANN_DLLEXPORT inline int32_t get_filter_number(const LabelT &filter_label); DISKANN_DLLEXPORT void generate_random_labels(std::vector &labels, const uint32_t num_labels, const uint32_t nthreads); diff --git a/src/pq_flash_index.cpp b/src/pq_flash_index.cpp index 943fed44c..b74c96257 100644 --- a/src/pq_flash_index.cpp +++ b/src/pq_flash_index.cpp @@ -578,30 +578,50 @@ LabelT PQFlashIndex::get_converted_label(const std::string &filter_la } } + +// test commit template -void PQFlashIndex::get_label_file_metadata(std::string map_file, uint32_t &num_pts, +void PQFlashIndex::get_label_file_metadata(const std::string &fileContent, uint32_t &num_pts, uint32_t &num_total_labels) { - std::ifstream infile(map_file); - std::string line, token; num_pts = 0; num_total_labels = 0; - while (std::getline(infile, line)) + size_t file_size = fileContent.length(); + + std::string label_str; + size_t cur_pos = 0; + size_t next_pos = 0; + while (cur_pos < file_size && cur_pos != std::string::npos) { - std::istringstream iss(line); - while (getline(iss, token, ',')) + next_pos = fileContent.find('\n', cur_pos); + if (next_pos == std::string::npos) { - token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); - token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); + break; + } + + size_t lbl_pos = cur_pos; + size_t next_lbl_pos = 0; + while (lbl_pos < next_pos && lbl_pos != std::string::npos) + { + next_lbl_pos = fileContent.find(',', lbl_pos); + if (next_lbl_pos == std::string::npos) // the last label + { + next_lbl_pos = next_pos; + } + num_total_labels++; + + lbl_pos = next_lbl_pos + 1; } + + cur_pos = next_pos + 1; + num_pts++; } diskann::cout << "Labels file metadata: num_points: " << num_pts << ", #total_labels: " << num_total_labels << std::endl; - infile.close(); } template @@ -624,77 +644,98 @@ inline bool PQFlashIndex::point_has_label(uint32_t point_id, uint32_t template void PQFlashIndex::parse_label_file(const std::string &label_file, size_t &num_points_labels) { - std::ifstream infile(label_file); + std::ifstream infile(label_file, std::ios::binary); if (infile.fail()) { throw diskann::ANNException(std::string("Failed to open file ") + label_file, -1); } + infile.seekg(0, std::ios::end); + size_t file_size = infile.tellg(); + + std::string buffer(file_size, ' '); + + infile.seekg(0, std::ios::beg); + infile.read(&buffer[0], file_size); + infile.close(); - std::string line, token; uint32_t line_cnt = 0; uint32_t num_pts_in_label_file; uint32_t num_total_labels; - get_label_file_metadata(label_file, num_pts_in_label_file, num_total_labels); + get_label_file_metadata(buffer, num_pts_in_label_file, num_total_labels); _pts_to_label_offsets = new uint32_t[num_pts_in_label_file]; _pts_to_labels = new uint32_t[num_pts_in_label_file + num_total_labels]; uint32_t counter = 0; - while (std::getline(infile, line)) + std::string label_str; + size_t cur_pos = 0; + size_t next_pos = 0; + while (cur_pos < file_size && cur_pos != std::string::npos) { - std::istringstream iss(line); - std::vector lbls(0); + next_pos = buffer.find('\n', cur_pos); + if (next_pos == std::string::npos) + { + break; + } _pts_to_label_offsets[line_cnt] = counter; uint32_t &num_lbls_in_cur_pt = _pts_to_labels[counter]; num_lbls_in_cur_pt = 0; counter++; - getline(iss, token, '\t'); - std::istringstream new_iss(token); - while (getline(new_iss, token, ',')) + + size_t lbl_pos = cur_pos; + size_t next_lbl_pos = 0; + while (lbl_pos < next_pos && lbl_pos != std::string::npos) { - token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); - token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); - LabelT token_as_num = (LabelT)std::stoul(token); - if (_labels.find(token_as_num) == _labels.end()) + next_lbl_pos = buffer.find(',', lbl_pos); + if (next_lbl_pos == std::string::npos) // the last label in the whole file { - _filter_list.emplace_back(token_as_num); + next_lbl_pos = next_pos; } - int32_t filter_num = get_filter_number(token_as_num); - if (filter_num == -1) + + if (next_lbl_pos > next_pos) // the last label in one line + { + next_lbl_pos = next_pos; + } + + label_str.assign(buffer.c_str() + lbl_pos, next_lbl_pos - lbl_pos); + if (label_str[label_str.length() - 1] == '\t') + { + label_str.erase(label_str.length() - 1); + } + + LabelT token_as_num = (LabelT)std::stoul(label_str); + if (_labels.find(token_as_num) == _labels.end()) { - diskann::cout << "Error!! " << std::endl; - exit(-1); + _filter_list.emplace_back(token_as_num); } - _pts_to_labels[counter++] = filter_num; + + _pts_to_labels[counter++] = token_as_num; num_lbls_in_cur_pt++; _labels.insert(token_as_num); + + lbl_pos = next_lbl_pos + 1; } + cur_pos = next_pos + 1; + if (num_lbls_in_cur_pt == 0) { diskann::cout << "No label found for point " << line_cnt << std::endl; exit(-1); } + line_cnt++; } - infile.close(); + num_points_labels = line_cnt; } template void PQFlashIndex::set_universal_label(const LabelT &label) { - int32_t temp_filter_num = get_filter_number(label); - if (temp_filter_num == -1) - { - diskann::cout << "Error, could not find universal label." << std::endl; - } - else - { - _use_universal_label = true; - _universal_filter_num = (uint32_t)temp_filter_num; - } + _use_universal_label = true; + _universal_filter_num = (uint32_t)label; } #ifdef EXEC_ENV_OLS @@ -1150,22 +1191,7 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t const uint32_t io_limit, const bool use_reorder_data, QueryStats *stats) { - int32_t filter_num = 0; - if (use_filter) - { - filter_num = get_filter_number(filter_label); - if (filter_num < 0) - { - if (!_use_universal_label) - { - return; - } - else - { - filter_num = _universal_filter_num; - } - } - } + int32_t filter_num = filter_label; if (beam_width > MAX_N_SECTOR_READS) throw ANNException("Beamwidth can not be higher than MAX_N_SECTOR_READS", -1, __FUNCSIG__, __FILE__, __LINE__);