Skip to content

Commit

Permalink
revert memory index parse file
Browse files Browse the repository at this point in the history
  • Loading branch information
NeelamMahapatro committed Oct 26, 2023
1 parent 17009c2 commit d1cb7e1
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 73 deletions.
8 changes: 4 additions & 4 deletions include/pq_flash_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,13 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex

private:
DISKANN_DLLEXPORT inline bool point_has_label(uint32_t point_id, LabelT label_id);
std::unordered_map<std::string, LabelT> load_label_map(const std::string &map_file);
DISKANN_DLLEXPORT void parse_label_file(const std::string &map_file, size_t &num_pts_labels);
DISKANN_DLLEXPORT void get_label_file_metadata(const std::string &file_content, uint32_t &num_pts,
std::unordered_map<std::string, LabelT> load_label_map(std::basic_istream<char> &infile);
DISKANN_DLLEXPORT void parse_label_file(std::basic_istream<char> &infile, size_t &num_pts_labels);
DISKANN_DLLEXPORT void get_label_file_metadata(const std::string &fileContent, uint32_t &num_pts,
uint32_t &num_total_labels);
DISKANN_DLLEXPORT void generate_random_labels(std::vector<LabelT> &labels, const uint32_t num_labels,
const uint32_t nthreads);

void reset_stream_for_reading(std::basic_istream<char> &infile);
// index info
// nhood of node `i` is in sector: [i / nnodes_per_sector]
// offset in sector: [(i % nnodes_per_sector) * max_node_len]
Expand Down
64 changes: 17 additions & 47 deletions src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2022,71 +2022,41 @@ void Index<T, TagT, LabelT>::parse_label_file(const std::string &label_file, siz
{
// Format of Label txt file: filters with comma separators

std::ifstream infile(label_file, std::ios::binary);
std::ifstream infile(label_file);
if (infile.fail())
{
throw diskann::ANNException(std::string("Failed to open file ") + label_file, -1);
}

infile.seekg(0, std::ios::end);
size_t file_size = infile.tellg();

std::string buffer(file_size, ' ');

infile.seekg(0, std::ios::beg);
infile.read(&buffer[0], file_size);

std::string label_str;
size_t cur_pos = 0;
size_t next_pos = 0;
std::string line, token;
uint32_t line_cnt = 0;

// Find total number of points in the labels file to reserve _pts_to_labels
while (cur_pos < file_size && cur_pos != std::string::npos)
while (std::getline(infile, line))
{
next_pos = buffer.find('\n', cur_pos);
if (next_pos == std::string::npos)
break;
cur_pos = next_pos + 1;
line_cnt++;
}
cur_pos = 0;
next_pos = 0;
_pts_to_labels.resize(line_cnt, std::vector<LabelT>());

infile.clear();
infile.seekg(0, std::ios::beg);
line_cnt = 0;
while (cur_pos < file_size && cur_pos != std::string::npos)

while (std::getline(infile, line))
{
next_pos = buffer.find('\n', cur_pos);
if (next_pos == std::string::npos)
{
break;
}
size_t lbl_pos = cur_pos;
size_t next_lbl_pos = 0;
std::istringstream iss(line);
std::vector<LabelT> lbls(0);
while (lbl_pos < next_pos && lbl_pos != std::string::npos)
getline(iss, token, '\t');
std::istringstream new_iss(token);
while (getline(new_iss, token, ','))
{
next_lbl_pos = buffer.find(',', lbl_pos);
if (next_lbl_pos == std::string::npos) // the last label in the whole file
{
next_lbl_pos = next_pos;
}
if (next_lbl_pos > next_pos) // the last label in one line, just read to the end
{
next_lbl_pos = next_pos;
}
label_str.assign(buffer.c_str() + lbl_pos, next_lbl_pos - lbl_pos);
if (label_str[label_str.length() - 1] == '\t') // '\t' won't exist in label file?
{
label_str.erase(label_str.length() - 1);
}
LabelT token_as_num = (LabelT)std::stoul(label_str);
token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
LabelT token_as_num = (LabelT)std::stoul(token);
lbls.push_back(token_as_num);
_labels.insert(token_as_num);
// move to next label
lbl_pos = next_lbl_pos + 1;
}
cur_pos = next_pos + 1;

std::sort(lbls.begin(), lbls.end());
_pts_to_labels[line_cnt] = lbls;
line_cnt++;
}
Expand Down
105 changes: 83 additions & 22 deletions src/pq_flash_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -526,10 +526,9 @@ void PQFlashIndex<T, LabelT>::generate_random_labels(std::vector<LabelT> &labels
}

template <typename T, typename LabelT>
std::unordered_map<std::string, LabelT> PQFlashIndex<T, LabelT>::load_label_map(const std::string &labels_map_file)
std::unordered_map<std::string, LabelT> PQFlashIndex<T, LabelT>::load_label_map(std::basic_istream<char> &map_reader)
{
std::unordered_map<std::string, LabelT> string_to_int_mp;
std::ifstream map_reader(labels_map_file);
std::string line, token;
LabelT token_as_num;
std::string label_str;
Expand Down Expand Up @@ -563,20 +562,27 @@ LabelT PQFlashIndex<T, LabelT>::get_converted_label(const std::string &filter_la
}

template <typename T, typename LabelT>
void PQFlashIndex<T, LabelT>::get_label_file_metadata(const std::string &file_content, uint32_t &num_pts,
void PQFlashIndex<T, LabelT>::reset_stream_for_reading(std::basic_istream<char> &infile)
{
// Rewinds `infile` so that it can be read again from the beginning.
// clear() is called first to reset the stream's error-state flags (for example eofbit/failbit
// left behind by a previous read that hit end-of-stream); a seek on a stream that is still in a
// failed state would not take effect.
infile.clear();
// Move the input position back to offset 0 (the start of the stream).
infile.seekg(0);
}

template <typename T, typename LabelT>
void PQFlashIndex<T, LabelT>::get_label_file_metadata(const std::string &fileContent, uint32_t &num_pts,
uint32_t &num_total_labels)
{
num_pts = 0;
num_total_labels = 0;

size_t file_size = file_content.length();
size_t file_size = fileContent.length();

std::string label_str;
size_t cur_pos = 0, next_pos = 0;

size_t cur_pos = 0;
size_t next_pos = 0;
while (cur_pos < file_size && cur_pos != std::string::npos)
{
next_pos = file_content.find('\n', cur_pos);
next_pos = fileContent.find('\n', cur_pos);
if (next_pos == std::string::npos)
{
break;
Expand All @@ -586,7 +592,7 @@ void PQFlashIndex<T, LabelT>::get_label_file_metadata(const std::string &file_co
size_t next_lbl_pos = 0;
while (lbl_pos < next_pos && lbl_pos != std::string::npos)
{
next_lbl_pos = file_content.find(',', lbl_pos);
next_lbl_pos = fileContent.find(',', lbl_pos);
if (next_lbl_pos == std::string::npos) // the last label
{
next_lbl_pos = next_pos;
Expand All @@ -596,9 +602,12 @@ void PQFlashIndex<T, LabelT>::get_label_file_metadata(const std::string &file_co

lbl_pos = next_lbl_pos + 1;
}

cur_pos = next_pos + 1;

num_pts++;
}

diskann::cout << "Labels file metadata: num_points: " << num_pts << ", #total_labels: " << num_total_labels
<< std::endl;
}
Expand All @@ -621,14 +630,8 @@ inline bool PQFlashIndex<T, LabelT>::point_has_label(uint32_t point_id, LabelT l
}

template <typename T, typename LabelT>
void PQFlashIndex<T, LabelT>::parse_label_file(const std::string &label_file, size_t &num_points_labels)
void PQFlashIndex<T, LabelT>::parse_label_file(std::basic_istream<char> &infile, size_t &num_points_labels)
{
std::ifstream infile(label_file, std::ios::binary);
if (infile.fail())
{
throw diskann::ANNException(std::string("Failed to open file ") + label_file, -1);
}

infile.seekg(0, std::ios::end);
size_t file_size = infile.tellg();

Expand All @@ -637,8 +640,6 @@ void PQFlashIndex<T, LabelT>::parse_label_file(const std::string &label_file, si
infile.seekg(0, std::ios::beg);
infile.read(&buffer[0], file_size);

infile.close();

std::string line;
uint32_t line_cnt = 0;

Expand All @@ -654,7 +655,6 @@ void PQFlashIndex<T, LabelT>::parse_label_file(const std::string &label_file, si
std::string label_str;
size_t cur_pos = 0;
size_t next_pos = 0;

while (cur_pos < file_size && cur_pos != std::string::npos)
{
next_pos = buffer.find('\n', cur_pos);
Expand All @@ -669,7 +669,6 @@ void PQFlashIndex<T, LabelT>::parse_label_file(const std::string &label_file, si

size_t lbl_pos = cur_pos;
size_t next_lbl_pos = 0;

while (lbl_pos < next_pos && lbl_pos != std::string::npos)
{
next_lbl_pos = buffer.find(',', lbl_pos);
Expand All @@ -696,6 +695,8 @@ void PQFlashIndex<T, LabelT>::parse_label_file(const std::string &label_file, si
// move to next label
lbl_pos = next_lbl_pos + 1;
}

// move to next line
cur_pos = next_pos + 1;

if (num_lbls_in_cur_pt == 0)
Expand All @@ -706,7 +707,9 @@ void PQFlashIndex<T, LabelT>::parse_label_file(const std::string &label_file, si

line_cnt++;
}

num_points_labels = line_cnt;
reset_stream_for_reading(infile);
}

template <typename T, typename LabelT> void PQFlashIndex<T, LabelT>::set_universal_label(const LabelT &label)
Expand Down Expand Up @@ -792,15 +795,51 @@ int PQFlashIndex<T, LabelT>::load_from_separate_paths(uint32_t num_threads, cons
this->num_points = npts_u64;
this->n_chunks = nchunks_u64;
memory_in_bytes += npts_u64 * nchunks_u64;
#ifdef EXEC_ENV_OLS
if (files.fileExists(labels_file))
{
FileContent &content_labels = files.getContent(labels_file);
std::stringstream infile(std::string((const char *)content_labels._content, content_labels._size));
#else
if (file_exists(labels_file))
{
parse_label_file(labels_file, num_pts_in_label_file);
assert(num_pts_in_label_file == this->num_points);
_label_map = load_label_map(labels_map_file);
std::ifstream infile(labels_file, std::ios::binary);
if (infile.fail())
{
throw diskann::ANNException(std::string("Failed to open file ") + labels_file, -1);
}
#endif
parse_label_file(infile, num_pts_in_label_file);
assert(num_pts_in_label_file == this->_num_points);

#ifndef EXEC_ENV_OLS
infile.close();
#endif

#ifdef EXEC_ENV_OLS
FileContent &content_labels_map = files.getContent(labels_map_file);
std::stringstream map_reader(std::string((const char *)content_labels_map._content, content_labels_map._size));
#else
std::ifstream map_reader(labels_map_file);
#endif
_label_map = load_label_map(map_reader);

#ifndef EXEC_ENV_OLS
map_reader.close();
#endif

#ifdef EXEC_ENV_OLS
if (files.fileExists(labels_to_medoids))
{
FileContent &content_labels_to_meoids = files.getContent(labels_to_medoids);
std::stringstream medoid_stream(
std::string((const char *)content_labels_to_meoids._content, content_labels_to_meoids._size));
#else
if (file_exists(labels_to_medoids))
{
std::ifstream medoid_stream(labels_to_medoids);
assert(medoid_stream.is_open());
#endif
std::string line, token;

_filter_to_medoid_ids.clear();
Expand Down Expand Up @@ -829,20 +868,40 @@ int PQFlashIndex<T, LabelT>::load_from_separate_paths(uint32_t num_threads, cons
}
}
std::string univ_label_file = std ::string(disk_index_file) + "_universal_label.txt";

#ifdef EXEC_ENV_OLS
if (files.fileExists(univ_label_file))
{
FileContent &content_univ_label = files.getContent(univ_label_file);
std::stringstream universal_label_reader(
std::string((const char *)content_univ_label._content, content_univ_label._size));
#else
if (file_exists(univ_label_file))
{
std::ifstream universal_label_reader(univ_label_file);
assert(universal_label_reader.is_open());
#endif
std::string univ_label;
universal_label_reader >> univ_label;
#ifndef EXEC_ENV_OLS
universal_label_reader.close();
#endif
LabelT label_as_num = (LabelT)std::stoul(univ_label);
set_universal_label(label_as_num);
}

#ifdef EXEC_ENV_OLS
if (files.fileExists(dummy_map_file))
{
FileContent &content_dummy_map = files.getContent(dummy_map_file);
std::stringstream dummy_map_stream(
std::string((const char *)content_dummy_map._content, content_dummy_map._size));
#else
if (file_exists(dummy_map_file))
{
std::ifstream dummy_map_stream(dummy_map_file);
assert(dummy_map_stream.is_open());
#endif
std::string line, token;

while (std::getline(dummy_map_stream, line))
Expand All @@ -868,7 +927,9 @@ int PQFlashIndex<T, LabelT>::load_from_separate_paths(uint32_t num_threads, cons

_real_to_dummy_map[real_id].emplace_back(dummy_id);
}
#ifndef EXEC_ENV_OLS
dummy_map_stream.close();
#endif
diskann::cout << "Loaded dummy map" << std::endl;
}
}
Expand Down

0 comments on commit d1cb7e1

Please sign in to comment.