Skip to content

Commit

Permalink
Jegao/label hot fix test2 (#469)
Browse files Browse the repository at this point in the history
* read label in one file

* test commit

* fix last label issue

* remove get label number

* fix some issue
  • Loading branch information
Sanhaoji2 authored Oct 8, 2023
1 parent 07938b9 commit daa5a7b
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 56 deletions.
3 changes: 2 additions & 1 deletion include/pq_flash_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex
DISKANN_DLLEXPORT inline bool point_has_label(uint32_t point_id, uint32_t label_id);
std::unordered_map<std::string, LabelT> load_label_map(const std::string &map_file);
DISKANN_DLLEXPORT void parse_label_file(const std::string &map_file, size_t &num_pts_labels);
DISKANN_DLLEXPORT void get_label_file_metadata(std::string map_file, uint32_t &num_pts, uint32_t &num_total_labels);
DISKANN_DLLEXPORT void get_label_file_metadata(const std::string &fileContent, uint32_t &num_pts,
uint32_t &num_total_labels);
DISKANN_DLLEXPORT inline int32_t get_filter_number(const LabelT &filter_label);
DISKANN_DLLEXPORT void generate_random_labels(std::vector<LabelT> &labels, const uint32_t num_labels,
const uint32_t nthreads);
Expand Down
136 changes: 81 additions & 55 deletions src/pq_flash_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -578,30 +578,50 @@ LabelT PQFlashIndex<T, LabelT>::get_converted_label(const std::string &filter_la
}
}


// test commit
template <typename T, typename LabelT>
void PQFlashIndex<T, LabelT>::get_label_file_metadata(std::string map_file, uint32_t &num_pts,
void PQFlashIndex<T, LabelT>::get_label_file_metadata(const std::string &fileContent, uint32_t &num_pts,
uint32_t &num_total_labels)
{
std::ifstream infile(map_file);
std::string line, token;
num_pts = 0;
num_total_labels = 0;

while (std::getline(infile, line))
size_t file_size = fileContent.length();

std::string label_str;
size_t cur_pos = 0;
size_t next_pos = 0;
while (cur_pos < file_size && cur_pos != std::string::npos)
{
std::istringstream iss(line);
while (getline(iss, token, ','))
next_pos = fileContent.find('\n', cur_pos);
if (next_pos == std::string::npos)
{
token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
break;
}

size_t lbl_pos = cur_pos;
size_t next_lbl_pos = 0;
while (lbl_pos < next_pos && lbl_pos != std::string::npos)
{
next_lbl_pos = fileContent.find(',', lbl_pos);
if (next_lbl_pos == std::string::npos) // the last label
{
next_lbl_pos = next_pos;
}

num_total_labels++;

lbl_pos = next_lbl_pos + 1;
}

cur_pos = next_pos + 1;

num_pts++;
}

diskann::cout << "Labels file metadata: num_points: " << num_pts << ", #total_labels: " << num_total_labels
<< std::endl;
infile.close();
}

template <typename T, typename LabelT>
Expand All @@ -624,77 +644,98 @@ inline bool PQFlashIndex<T, LabelT>::point_has_label(uint32_t point_id, uint32_t
template <typename T, typename LabelT>
void PQFlashIndex<T, LabelT>::parse_label_file(const std::string &label_file, size_t &num_points_labels)
{
std::ifstream infile(label_file);
std::ifstream infile(label_file, std::ios::binary);
if (infile.fail())
{
throw diskann::ANNException(std::string("Failed to open file ") + label_file, -1);
}
infile.seekg(0, std::ios::end);
size_t file_size = infile.tellg();

std::string buffer(file_size, ' ');

infile.seekg(0, std::ios::beg);
infile.read(&buffer[0], file_size);
infile.close();

std::string line, token;
uint32_t line_cnt = 0;

uint32_t num_pts_in_label_file;
uint32_t num_total_labels;
get_label_file_metadata(label_file, num_pts_in_label_file, num_total_labels);
get_label_file_metadata(buffer, num_pts_in_label_file, num_total_labels);

_pts_to_label_offsets = new uint32_t[num_pts_in_label_file];
_pts_to_labels = new uint32_t[num_pts_in_label_file + num_total_labels];
uint32_t counter = 0;

while (std::getline(infile, line))
std::string label_str;
size_t cur_pos = 0;
size_t next_pos = 0;
while (cur_pos < file_size && cur_pos != std::string::npos)
{
std::istringstream iss(line);
std::vector<uint32_t> lbls(0);
next_pos = buffer.find('\n', cur_pos);
if (next_pos == std::string::npos)
{
break;
}

_pts_to_label_offsets[line_cnt] = counter;
uint32_t &num_lbls_in_cur_pt = _pts_to_labels[counter];
num_lbls_in_cur_pt = 0;
counter++;
getline(iss, token, '\t');
std::istringstream new_iss(token);
while (getline(new_iss, token, ','))

size_t lbl_pos = cur_pos;
size_t next_lbl_pos = 0;
while (lbl_pos < next_pos && lbl_pos != std::string::npos)
{
token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
LabelT token_as_num = (LabelT)std::stoul(token);
if (_labels.find(token_as_num) == _labels.end())
next_lbl_pos = buffer.find(',', lbl_pos);
if (next_lbl_pos == std::string::npos) // the last label in the whole file
{
_filter_list.emplace_back(token_as_num);
next_lbl_pos = next_pos;
}
int32_t filter_num = get_filter_number(token_as_num);
if (filter_num == -1)

if (next_lbl_pos > next_pos) // the last label in one line
{
next_lbl_pos = next_pos;
}

label_str.assign(buffer.c_str() + lbl_pos, next_lbl_pos - lbl_pos);
if (label_str[label_str.length() - 1] == '\t')
{
label_str.erase(label_str.length() - 1);
}

LabelT token_as_num = (LabelT)std::stoul(label_str);
if (_labels.find(token_as_num) == _labels.end())
{
diskann::cout << "Error!! " << std::endl;
exit(-1);
_filter_list.emplace_back(token_as_num);
}
_pts_to_labels[counter++] = filter_num;

_pts_to_labels[counter++] = token_as_num;
num_lbls_in_cur_pt++;
_labels.insert(token_as_num);

lbl_pos = next_lbl_pos + 1;
}

cur_pos = next_pos + 1;

if (num_lbls_in_cur_pt == 0)
{
diskann::cout << "No label found for point " << line_cnt << std::endl;
exit(-1);
}

line_cnt++;
}
infile.close();

num_points_labels = line_cnt;
}

template <typename T, typename LabelT> void PQFlashIndex<T, LabelT>::set_universal_label(const LabelT &label)
{
int32_t temp_filter_num = get_filter_number(label);
if (temp_filter_num == -1)
{
diskann::cout << "Error, could not find universal label." << std::endl;
}
else
{
_use_universal_label = true;
_universal_filter_num = (uint32_t)temp_filter_num;
}
_use_universal_label = true;
_universal_filter_num = (uint32_t)label;
}

#ifdef EXEC_ENV_OLS
Expand Down Expand Up @@ -1150,22 +1191,7 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const uint64_t
const uint32_t io_limit, const bool use_reorder_data,
QueryStats *stats)
{
int32_t filter_num = 0;
if (use_filter)
{
filter_num = get_filter_number(filter_label);
if (filter_num < 0)
{
if (!_use_universal_label)
{
return;
}
else
{
filter_num = _universal_filter_num;
}
}
}
int32_t filter_num = filter_label;

if (beam_width > MAX_N_SECTOR_READS)
throw ANNException("Beamwidth can not be higher than MAX_N_SECTOR_READS", -1, __FUNCSIG__, __FILE__, __LINE__);
Expand Down

0 comments on commit daa5a7b

Please sign in to comment.