Skip to content

Commit

Permalink
read from MemoryMappedFile when EXEC_ENV_OLS is defined
Browse files Browse the repository at this point in the history
  • Loading branch information
hliu18 committed Oct 5, 2023
1 parent aa6e33a commit 6ef7421
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 27 deletions.
6 changes: 3 additions & 3 deletions include/pq_flash_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,9 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex

private:
DISKANN_DLLEXPORT inline bool point_has_label(uint32_t point_id, uint32_t label_id);
std::unordered_map<std::string, LabelT> load_label_map(const std::string &map_file);
DISKANN_DLLEXPORT void parse_label_file(const std::string &map_file, size_t &num_pts_labels);
DISKANN_DLLEXPORT void get_label_file_metadata(std::string map_file, uint32_t &num_pts, uint32_t &num_total_labels);
std::unordered_map<std::string, LabelT> load_label_map(std::basic_istream<char> &infile);
DISKANN_DLLEXPORT void parse_label_file(std::basic_istream<char> &infile, size_t &num_pts_labels);
DISKANN_DLLEXPORT void get_label_file_metadata(std::basic_istream<char> &infile, uint32_t &num_pts, uint32_t &num_total_labels);
DISKANN_DLLEXPORT inline int32_t get_filter_number(const LabelT &filter_label);
DISKANN_DLLEXPORT void generate_random_labels(std::vector<LabelT> &labels, const uint32_t num_labels,
const uint32_t nthreads);
Expand Down
91 changes: 67 additions & 24 deletions src/pq_flash_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -541,10 +541,9 @@ void PQFlashIndex<T, LabelT>::generate_random_labels(std::vector<LabelT> &labels
}

template <typename T, typename LabelT>
std::unordered_map<std::string, LabelT> PQFlashIndex<T, LabelT>::load_label_map(const std::string &labels_map_file)
std::unordered_map<std::string, LabelT> PQFlashIndex<T, LabelT>::load_label_map(std::basic_istream<char> &map_reader)
{
std::unordered_map<std::string, LabelT> string_to_int_mp;
std::ifstream map_reader(labels_map_file);
std::string line, token;
LabelT token_as_num;
std::string label_str;
Expand Down Expand Up @@ -574,10 +573,9 @@ LabelT PQFlashIndex<T, LabelT>::get_converted_label(const std::string &filter_la
}

template <typename T, typename LabelT>
void PQFlashIndex<T, LabelT>::get_label_file_metadata(std::string map_file, uint32_t &num_pts,
void PQFlashIndex<T, LabelT>::get_label_file_metadata(std::basic_istream<char> &infile, uint32_t &num_pts,
uint32_t &num_total_labels)
{
std::ifstream infile(map_file);
std::string line, token;
num_pts = 0;
num_total_labels = 0;
Expand All @@ -596,7 +594,7 @@ void PQFlashIndex<T, LabelT>::get_label_file_metadata(std::string map_file, uint

diskann::cout << "Labels file metadata: num_points: " << num_pts << ", #total_labels: " << num_total_labels
<< std::endl;
infile.close();
infile.seekg(0);
}

template <typename T, typename LabelT>
Expand All @@ -617,20 +615,14 @@ inline bool PQFlashIndex<T, LabelT>::point_has_label(uint32_t point_id, uint32_t
}

template <typename T, typename LabelT>
void PQFlashIndex<T, LabelT>::parse_label_file(const std::string &label_file, size_t &num_points_labels)
void PQFlashIndex<T, LabelT>::parse_label_file(std::basic_istream<char> &infile, size_t &num_points_labels)
{
std::ifstream infile(label_file);
if (infile.fail())
{
throw diskann::ANNException(std::string("Failed to open file ") + label_file, -1);
}

std::string line, token;
uint32_t line_cnt = 0;

uint32_t num_pts_in_label_file;
uint32_t num_total_labels;
get_label_file_metadata(label_file, num_pts_in_label_file, num_total_labels);
get_label_file_metadata(infile, num_pts_in_label_file, num_total_labels);

_pts_to_label_offsets = new uint32_t[num_pts_in_label_file];
_pts_to_labels = new uint32_t[num_pts_in_label_file + num_total_labels];
Expand Down Expand Up @@ -766,16 +758,46 @@ int PQFlashIndex<T, LabelT>::load_from_separate_paths(uint32_t num_threads, cons
diskann::load_bin<uint8_t>(pq_compressed_vectors, this->data, npts_u64, nchunks_u64);
#endif

this->num_points = npts_u64;
this->n_chunks = nchunks_u64;
this->_num_points = npts_u64;
this->_n_chunks = nchunks_u64;
#ifdef EXEC_ENV_OLS
if (files.fileExists(labels_file))
{
FileContent &content = files.getContent(labels_file);
std::stringstream infile(
std::string((const char *) content._content, content._size));
#else
if (file_exists(labels_file))
{
parse_label_file(labels_file, num_pts_in_label_file);
assert(num_pts_in_label_file == this->num_points);
_label_map = load_label_map(labels_map_file);
std::ifstream infile(labels_file);
if (infile.fail())
{
throw diskann::ANNException(std::string("Failed to open file ") + labels_file, -1);
}
#endif
parse_label_file(infile, num_pts_in_label_file);
assert(num_pts_in_label_file == this->_num_points);

#ifdef EXEC_ENV_OLS
FileContent &content = files.getContent(labels_map_file);
std::stringstream map_reader(
std::string((const char *) content._content, content._size));
#else
std::ifstream map_reader(labels_map_file);
#endif
_label_map = load_label_map(map_reader);

#ifdef EXEC_ENV_OLS
if (files.fileExists(labels_to_medoids))
{
FileContent &content = files.getContent(labels_to_medoids);
std::stringstream medoid_stream(
std::string((const char *) content._content, content._size));
#else
if (file_exists(labels_to_medoids))
{
std::ifstream medoid_stream(labels_to_medoids);
#endif
assert(medoid_stream.is_open());
std::string line, token;

Expand Down Expand Up @@ -804,20 +826,38 @@ int PQFlashIndex<T, LabelT>::load_from_separate_paths(uint32_t num_threads, cons
throw FileException(labels_to_medoids, e, __FUNCSIG__, __FILE__, __LINE__);
}
}

std::string univ_label_file = std ::string(disk_index_file) + "_universal_label.txt";
#ifdef EXEC_ENV_OLS
if (files.fileExists(univ_label_file))
{
FileContent &content = files.getContent(univ_label_file);
std::stringstream universal_label_reader(
std::string((const char *) content._content, content._size));
#else
if (file_exists(univ_label_file))
{
std::ifstream universal_label_reader(univ_label_file);
#endif
assert(universal_label_reader.is_open());
std::string univ_label;
universal_label_reader >> univ_label;
universal_label_reader.close();
LabelT label_as_num = (LabelT)std::stoul(univ_label);
set_universal_label(label_as_num);
}

#ifdef EXEC_ENV_OLS
if (files.fileExists(dummy_map_file))
{
FileContent &content = files.getContent(dummy_map_file);
std::stringstream dummy_map_stream(
std::string((const char *) content._content, content._size));
#else
if (file_exists(dummy_map_file))
{
std::ifstream dummy_map_stream(dummy_map_file);
#endif
assert(dummy_map_stream.is_open());
std::string line, token;

Expand Down Expand Up @@ -867,16 +907,19 @@ int PQFlashIndex<T, LabelT>::load_from_separate_paths(uint32_t num_threads, cons
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
}

std::string disk_pq_pivots_path = this->disk_index_file + "_pq_pivots.bin";
if (file_exists(disk_pq_pivots_path))
{
use_disk_index_pq = true;
std::string disk_pq_pivots_path = this->_disk_index_file + "_pq_pivots.bin";
#ifdef EXEC_ENV_OLS
// giving 0 chunks to make the pq_table infer from the
if (files.fileExists(disk_pq_pivots_path))
{
_use_disk_index_pq = true;
// giving 0 chunks to make the _pq_table infer from the
// chunk_offsets file the correct value
disk_pq_table.load_pq_centroid_bin(files, disk_pq_pivots_path.c_str(), 0);
#else
// giving 0 chunks to make the pq_table infer from the
if (file_exists(disk_pq_pivots_path))
{
_use_disk_index_pq = true;
// giving 0 chunks to make the _pq_table infer from the
// chunk_offsets file the correct value
disk_pq_table.load_pq_centroid_bin(disk_pq_pivots_path.c_str(), 0);
#endif
Expand Down

0 comments on commit 6ef7421

Please sign in to comment.