Skip to content

Commit

Permalink
improve load time
Browse files Browse the repository at this point in the history
  • Loading branch information
NeelamMahapatro committed Oct 16, 2023
1 parent 64c0e8c commit 5ec5861
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 45 deletions.
3 changes: 2 additions & 1 deletion include/pq_flash_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex
DISKANN_DLLEXPORT inline bool point_has_label(uint32_t point_id, LabelT label_id);
std::unordered_map<std::string, LabelT> load_label_map(const std::string &map_file);
DISKANN_DLLEXPORT void parse_label_file(const std::string &map_file, size_t &num_pts_labels);
DISKANN_DLLEXPORT void get_label_file_metadata(std::string map_file, uint32_t &num_pts, uint32_t &num_total_labels);
DISKANN_DLLEXPORT void get_label_file_metadata(const std::string &file_content, uint32_t &num_pts,
uint32_t &num_total_labels);
DISKANN_DLLEXPORT void generate_random_labels(std::vector<LabelT> &labels, const uint32_t num_labels,
const uint32_t nthreads);

Expand Down
67 changes: 46 additions & 21 deletions src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2022,45 +2022,70 @@ void Index<T, TagT, LabelT>::parse_label_file(const std::string &label_file, siz
{
// Format of Label txt file: filters with comma separators

std::ifstream infile(label_file);
std::ifstream infile(label_file, std::ios::binary);
if (infile.fail())
{
throw diskann::ANNException(std::string("Failed to open file ") + label_file, -1);
}

std::string line, token;
infile.seekg(0, std::ios::end);
size_t file_size = infile.tellg();

std::string buffer(file_size, ' ');

infile.seekg(0, std::ios::beg);
infile.read(&buffer[0], file_size);

std::string label_str;
size_t cur_pos = 0;
size_t next_pos = 0;
uint32_t line_cnt = 0;

while (std::getline(infile, line))
// Find total number of points in the labels file to reserve _pts_to_labels
while (cur_pos < file_size && cur_pos != std::string::npos)
{
next_pos = buffer.find('\n', cur_pos);
if (next_pos == std::string::npos)
break;
cur_pos = next_pos + 1;
line_cnt++;
}
cur_pos = 0;
next_pos = 0;
_pts_to_labels.resize(line_cnt, std::vector<LabelT>());

infile.clear();
infile.seekg(0, std::ios::beg);
line_cnt = 0;

while (std::getline(infile, line))
while (cur_pos < file_size && cur_pos != std::string::npos)
{
std::istringstream iss(line);
next_pos = buffer.find('\n', cur_pos);
if (next_pos == std::string::npos)
{
break;
}
size_t lbl_pos = cur_pos;
size_t next_lbl_pos = 0;
std::vector<LabelT> lbls(0);
getline(iss, token, '\t');
std::istringstream new_iss(token);
while (getline(new_iss, token, ','))
while (lbl_pos < next_pos && lbl_pos != std::string::npos)
{
token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
LabelT token_as_num = (LabelT)std::stoul(token);
next_lbl_pos = buffer.find(',', lbl_pos);
if (next_lbl_pos == std::string::npos) // the last label in the whole file
{
next_lbl_pos = next_pos;
}
if (next_lbl_pos > next_pos) // the last label in one line, just read to the end
{
next_lbl_pos = next_pos;
}
label_str.assign(buffer.c_str() + lbl_pos, next_lbl_pos - lbl_pos);
if (label_str[label_str.length() - 1] == '\t') // '\t' won't exist in label file?
{
label_str.erase(label_str.length() - 1);
}
LabelT token_as_num = (LabelT)std::stoul(label_str);
lbls.push_back(token_as_num);
_labels.insert(token_as_num);
// move to next label
lbl_pos = next_lbl_pos + 1;
}
if (lbls.size() <= 0)
{
diskann::cout << "No label found";
exit(-1);
}
std::sort(lbls.begin(), lbls.end());
_pts_to_labels[line_cnt] = lbls;
line_cnt++;
}
Expand Down
99 changes: 76 additions & 23 deletions src/pq_flash_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -563,29 +563,44 @@ LabelT PQFlashIndex<T, LabelT>::get_converted_label(const std::string &filter_la
}

// Scans the in-memory contents of a label file and reports its dimensions.
//
// The content is treated as one point per '\n'-terminated line, each line being a
// comma-separated list of labels.
//
// @param file_content      full text of the label file, already read into memory
// @param num_pts           [out] number of lines found (a final line without a trailing
//                          '\n' is counted as well)
// @param num_total_labels  [out] number of comma-separated tokens summed over all lines;
//                          an empty line contributes zero tokens
//
// The counts are also printed to diskann::cout.
template <typename T, typename LabelT>
void PQFlashIndex<T, LabelT>::get_label_file_metadata(const std::string &file_content, uint32_t &num_pts,
                                                      uint32_t &num_total_labels)
{
    num_pts = 0;
    num_total_labels = 0;

    const size_t file_size = file_content.length();
    const char *const data = file_content.data();
    size_t cur_pos = 0;

    while (cur_pos < file_size)
    {
        // A missing '\n' means this is an unterminated last line: it ends at end-of-content.
        size_t line_end = file_content.find('\n', cur_pos);
        if (line_end == std::string::npos)
        {
            line_end = file_size;
        }

        size_t lbl_pos = cur_pos;
        while (lbl_pos < line_end)
        {
            // Search for the separator inside the current line only. An unbounded
            // find(',') would scan far past the line end (up to the end of the file when
            // lines hold a single label), which makes the whole pass quadratic.
            const size_t next_lbl_pos =
                static_cast<size_t>(std::find(data + lbl_pos, data + line_end, ',') - data);

            num_total_labels++;

            // Step over the separator (or past the line end, which terminates the loop).
            lbl_pos = next_lbl_pos + 1;
        }

        cur_pos = line_end + 1;
        num_pts++;
    }

    diskann::cout << "Labels file metadata: num_points: " << num_pts << ", #total_labels: " << num_total_labels
                  << std::endl;
}

template <typename T, typename LabelT>
Expand All @@ -608,51 +623,89 @@ inline bool PQFlashIndex<T, LabelT>::point_has_label(uint32_t point_id, LabelT l
// Loads a label file and fills the per-point label tables of this index.
//
// File format: one point per line, each line a comma-separated list of unsigned integer
// labels. The whole file is read into memory once and then tokenised in place.
//
// @param label_file         path of the label file to read
// @param num_points_labels  [out] number of points (lines) parsed
//
// Populates _pts_to_label_offsets, _pts_to_label_counts and _pts_to_labels with freshly
// allocated arrays. NOTE(review): this function does not free any previous arrays and the
// release of the new ones is not visible here -- confirm the owner handles both.
//
// @throws diskann::ANNException if the file cannot be opened, sized or fully read.
// @throws std::invalid_argument / std::out_of_range (from std::stoul) if a token is not a
//         valid unsigned integer, e.g. an empty token produced by ",,".
// Calls exit(-1) when a line contains no label at all.
template <typename T, typename LabelT>
void PQFlashIndex<T, LabelT>::parse_label_file(const std::string &label_file, size_t &num_points_labels)
{
    std::ifstream infile(label_file, std::ios::binary);
    if (infile.fail())
    {
        throw diskann::ANNException(std::string("Failed to open file ") + label_file, -1);
    }

    // tellg() reports failure as -1; converting that straight to size_t would request an
    // absurdly large buffer, so check before casting.
    infile.seekg(0, std::ios::end);
    const std::streamoff end_offset = infile.tellg();
    if (end_offset < 0)
    {
        throw diskann::ANNException(std::string("Failed to determine size of file ") + label_file, -1);
    }
    const size_t file_size = static_cast<size_t>(end_offset);

    // Reserve one extra byte so a missing final newline can be appended without reallocating.
    std::string buffer;
    buffer.reserve(file_size + 1);
    buffer.resize(file_size, ' ');

    infile.seekg(0, std::ios::beg);
    infile.read(&buffer[0], static_cast<std::streamsize>(file_size));
    if (static_cast<size_t>(infile.gcount()) != file_size)
    {
        throw diskann::ANNException(std::string("Failed to read file ") + label_file, -1);
    }
    infile.close();

    // Terminate an unterminated last line so that it is parsed like every other line, and
    // so that the metadata pass below and the parsing pass agree on the number of lines.
    if (!buffer.empty() && buffer.back() != '\n')
    {
        buffer.push_back('\n');
    }

    uint32_t num_pts_in_label_file;
    uint32_t num_total_labels;
    get_label_file_metadata(buffer, num_pts_in_label_file, num_total_labels);

    _pts_to_label_offsets = new uint32_t[num_pts_in_label_file];
    _pts_to_label_counts = new uint32_t[num_pts_in_label_file];
    _pts_to_labels = new LabelT[num_total_labels];

    const char *const data = buffer.data();
    const size_t buffer_size = buffer.size();

    std::string label_str;
    uint32_t line_cnt = 0;
    uint32_t labels_seen_so_far = 0;
    size_t cur_pos = 0;

    while (cur_pos < buffer_size)
    {
        const size_t line_end = buffer.find('\n', cur_pos);
        if (line_end == std::string::npos)
        {
            break; // defensive: cannot happen once the buffer is newline-terminated
        }

        _pts_to_label_offsets[line_cnt] = labels_seen_so_far;
        uint32_t &num_lbls_in_cur_pt = _pts_to_label_counts[line_cnt];
        num_lbls_in_cur_pt = 0;

        size_t lbl_pos = cur_pos;
        while (lbl_pos < line_end)
        {
            // Bound the separator search to the current line; the last label of a line
            // simply ends at the line end.
            const size_t next_lbl_pos =
                static_cast<size_t>(std::find(data + lbl_pos, data + line_end, ',') - data);

            label_str.assign(buffer, lbl_pos, next_lbl_pos - lbl_pos);
            // Drop trailing tab / carriage-return characters. The emptiness check matters:
            // indexing the last character of an empty token is undefined behaviour.
            while (!label_str.empty() && (label_str.back() == '\t' || label_str.back() == '\r'))
            {
                label_str.pop_back();
            }

            const LabelT token_as_num = static_cast<LabelT>(std::stoul(label_str));
            _pts_to_labels[labels_seen_so_far++] = token_as_num;
            num_lbls_in_cur_pt++;

            // move to next label
            lbl_pos = next_lbl_pos + 1;
        }
        cur_pos = line_end + 1;

        if (num_lbls_in_cur_pt == 0)
        {
            diskann::cout << "No label found for point " << line_cnt << std::endl;
            exit(-1);
        }

        line_cnt++;
    }
    num_points_labels = line_cnt;
}

Expand Down

0 comments on commit 5ec5861

Please sign in to comment.