diff --git a/include/pq_flash_index.h b/include/pq_flash_index.h index 5872a0ebf..48e187aca 100644 --- a/include/pq_flash_index.h +++ b/include/pq_flash_index.h @@ -101,11 +101,10 @@ template class PQFlashIndex DISKANN_DLLEXPORT void set_universal_label(const LabelT &label); private: - DISKANN_DLLEXPORT inline bool point_has_label(uint32_t point_id, uint32_t label_id); - std::unordered_map load_label_map(const std::string &map_file); - DISKANN_DLLEXPORT void parse_label_file(const std::string &map_file, size_t &num_pts_labels); - DISKANN_DLLEXPORT void get_label_file_metadata(std::string map_file, uint32_t &num_pts, uint32_t &num_total_labels); - DISKANN_DLLEXPORT inline int32_t get_filter_number(const LabelT &filter_label); + DISKANN_DLLEXPORT inline bool point_has_label(uint32_t point_id, LabelT label_id); + std::unordered_map load_label_map(std::basic_istream &infile); + DISKANN_DLLEXPORT void parse_label_file(std::basic_istream &infile, size_t &num_pts_labels); + DISKANN_DLLEXPORT void get_label_file_metadata(std::basic_istream &infile, uint32_t &num_pts, uint32_t &num_total_labels); DISKANN_DLLEXPORT void generate_random_labels(std::vector &labels, const uint32_t num_labels, const uint32_t nthreads); diff --git a/src/pq_flash_index.cpp b/src/pq_flash_index.cpp index 5bd23ecb0..b97d5acce 100644 --- a/src/pq_flash_index.cpp +++ b/src/pq_flash_index.cpp @@ -541,10 +541,9 @@ void PQFlashIndex::generate_random_labels(std::vector &labels } template -std::unordered_map PQFlashIndex::load_label_map(const std::string &labels_map_file) +std::unordered_map PQFlashIndex::load_label_map(std::basic_istream &map_reader) { std::unordered_map string_to_int_mp; - std::ifstream map_reader(labels_map_file); std::string line, token; LabelT token_as_num; std::string label_str; @@ -574,10 +573,9 @@ LabelT PQFlashIndex::get_converted_label(const std::string &filter_la } template -void PQFlashIndex::get_label_file_metadata(std::string map_file, uint32_t &num_pts, +void PQFlashIndex::get_label_file_metadata(std::basic_istream &infile, uint32_t &num_pts, uint32_t &num_total_labels) { - std::ifstream infile(map_file); std::string line, token; num_pts = 0; num_total_labels = 0; @@ -596,7 +594,7 @@ void PQFlashIndex::get_label_file_metadata(std::string map_file, uint diskann::cout << "Labels file metadata: num_points: " << num_pts << ", #total_labels: " << num_total_labels << std::endl; - infile.close(); + infile.seekg(0); } template @@ -617,20 +615,14 @@ inline bool PQFlashIndex::point_has_label(uint32_t point_id, uint32_t } template -void PQFlashIndex::parse_label_file(const std::string &label_file, size_t &num_points_labels) +void PQFlashIndex::parse_label_file(std::basic_istream &infile, size_t &num_points_labels) { - std::ifstream infile(label_file); - if (infile.fail()) - { - throw diskann::ANNException(std::string("Failed to open file ") + label_file, -1); - } - std::string line, token; uint32_t line_cnt = 0; uint32_t num_pts_in_label_file; uint32_t num_total_labels; - get_label_file_metadata(label_file, num_pts_in_label_file, num_total_labels); + get_label_file_metadata(infile, num_pts_in_label_file, num_total_labels); _pts_to_label_offsets = new uint32_t[num_pts_in_label_file]; _pts_to_labels = new uint32_t[num_pts_in_label_file + num_total_labels]; @@ -766,16 +758,46 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons diskann::load_bin(pq_compressed_vectors, this->data, npts_u64, nchunks_u64); #endif - this->num_points = npts_u64; - this->n_chunks = nchunks_u64; + this->_num_points = npts_u64; + this->_n_chunks = nchunks_u64; +#ifdef EXEC_ENV_OLS + if (files.fileExists(labels_file)) + { + FileContent &content = files.getContent(labels_file); + std::stringstream infile( + std::string((const char *) content._content, content._size)); +#else if (file_exists(labels_file)) { - parse_label_file(labels_file, num_pts_in_label_file); - assert(num_pts_in_label_file == this->num_points); - _label_map = load_label_map(labels_map_file); + std::ifstream infile(labels_file); + if (infile.fail()) + { + throw diskann::ANNException(std::string("Failed to open file ") + labels_file, -1); + } +#endif + parse_label_file(infile, num_pts_in_label_file); + assert(num_pts_in_label_file == this->_num_points); + +#ifdef EXEC_ENV_OLS + FileContent &content = files.getContent(labels_map_file); + std::stringstream map_reader( + std::string((const char *) content._content, content._size)); +#else + std::ifstream map_reader(labels_map_file); +#endif + _label_map = load_label_map(map_reader); + +#ifdef EXEC_ENV_OLS + if (files.fileExists(labels_to_medoids)) + { + FileContent &content = files.getContent(labels_to_medoids); + std::stringstream medoid_stream( + std::string((const char *) content._content, content._size)); +#else if (file_exists(labels_to_medoids)) { std::ifstream medoid_stream(labels_to_medoids); +#endif assert(medoid_stream.is_open()); std::string line, token; @@ -804,10 +826,23 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons throw FileException(labels_to_medoids, e, __FUNCSIG__, __FILE__, __LINE__); } } +<<<<<<< HEAD std::string univ_label_file = std ::string(disk_index_file) + "_universal_label.txt"; +======= + std::string univ_label_file = std ::string(_disk_index_file) + "_universal_label.txt"; + +#ifdef EXEC_ENV_OLS + if (files.fileExists(univ_label_file)) + { + FileContent &content = files.getContent(univ_label_file); + std::stringstream universal_label_reader( + std::string((const char *) content._content, content._size)); +#else +>>>>>>> 0b6d5aa (read from MemoryMappedFile when EXEC_ENV_OLS is defined) if (file_exists(univ_label_file)) { std::ifstream universal_label_reader(univ_label_file); +#endif assert(universal_label_reader.is_open()); std::string univ_label; universal_label_reader >> univ_label; @@ -815,9 +850,18 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons LabelT label_as_num = (LabelT)std::stoul(univ_label); set_universal_label(label_as_num); } + +#ifdef EXEC_ENV_OLS + if (files.fileExists(dummy_map_file)) + { + FileContent &content = files.getContent(dummy_map_file); + std::stringstream dummy_map_stream( + std::string((const char *) content._content, content._size)); +#else if (file_exists(dummy_map_file)) { std::ifstream dummy_map_stream(dummy_map_file); +#endif assert(dummy_map_stream.is_open()); std::string line, token; @@ -867,16 +911,32 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); } +<<<<<<< HEAD std::string disk_pq_pivots_path = this->disk_index_file + "_pq_pivots.bin"; if (file_exists(disk_pq_pivots_path)) { use_disk_index_pq = true; #ifdef EXEC_ENV_OLS // giving 0 chunks to make the pq_table infer from the +======= + std::string disk_pq_pivots_path = this->_disk_index_file + "_pq_pivots.bin"; +#ifdef EXEC_ENV_OLS + if (files.fileExists(disk_pq_pivots_path)) + { + _use_disk_index_pq = true; + // giving 0 chunks to make the _pq_table infer from the +>>>>>>> 0b6d5aa (read from MemoryMappedFile when EXEC_ENV_OLS is defined) // chunk_offsets file the correct value disk_pq_table.load_pq_centroid_bin(files, disk_pq_pivots_path.c_str(), 0); #else +<<<<<<< HEAD // giving 0 chunks to make the pq_table infer from the +======= + if (file_exists(disk_pq_pivots_path)) + { + _use_disk_index_pq = true; + // giving 0 chunks to make the _pq_table infer from the +>>>>>>> 0b6d5aa (read from MemoryMappedFile when EXEC_ENV_OLS is defined) // chunk_offsets file the correct value disk_pq_table.load_pq_centroid_bin(disk_pq_pivots_path.c_str(), 0); #endif