Skip to content

Commit

Permalink
Add filter streamging interface
Browse files Browse the repository at this point in the history
  • Loading branch information
Sanhaoji2 committed Oct 19, 2024
1 parent ce63c03 commit 09dda84
Showing 1 changed file with 39 additions and 2 deletions.
41 changes: 39 additions & 2 deletions src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1806,9 +1806,46 @@ void Index<T, TagT, LabelT>::build(const std::string &data_file, const size_t nu
size_t points_to_load = num_points_to_load == 0 ? _max_points : num_points_to_load;

auto s = std::chrono::high_resolution_clock::now();

std::vector<TagT> tags;

if (_enable_tags)
{
if (filter_params.tags_file.empty())
{
throw ANNException("Tag filename isn't set, while _enable_tags is set", -1, __FUNCSIG__, __FILE__, __LINE__);
}
else
{
if (file_exists(filter_params.tags_file))
{
diskann::cout << "Loading tags from " << filter_params.tags_file << " for vamana index build" << std::endl;
TagT* tag_data = nullptr;
size_t npts, ndim;
diskann::load_bin(filter_params.tags_file, tag_data, npts, ndim);
if (npts < num_points_to_load)
{
std::stringstream sstream;
sstream << "Loaded " << npts << " tags, insufficient to populate tags for " << num_points_to_load
<< " points to load";
throw diskann::ANNException(sstream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
}
tags.resize(num_points_to_load);
memcpy(tags.data(), tag_data, sizeof(TagT) * num_points_to_load);

delete[] tag_data;
}
else
{
throw diskann::ANNException(std::string("Tag file") + filter_params.tags_file + " does not exist", -1, __FUNCSIG__,
__FILE__, __LINE__);
}
}
}

if (filter_params.label_file == "")
{
this->build(data_file.c_str(), points_to_load);
this->build(data_file.c_str(), points_to_load, tags);
}
else
{
Expand All @@ -1823,7 +1860,7 @@ void Index<T, TagT, LabelT>::build(const std::string &data_file, const size_t nu
// LabelT unv_label_as_num = 0;
this->set_universal_label(unv_label_as_num);
}
this->build_filtered_index(data_file.c_str(), labels_file_to_use, points_to_load);
this->build_filtered_index(data_file.c_str(), labels_file_to_use, points_to_load, tags);
}
std::chrono::duration<double> diff = std::chrono::high_resolution_clock::now() - s;
std::cout << "Indexing time: " << diff.count() << "\n";
Expand Down

0 comments on commit 09dda84

Please sign in to comment.