Skip to content

Commit

Permalink
Fix compile issue
Browse files Browse the repository at this point in the history
  • Loading branch information
Sanhaoji2 committed Nov 24, 2024
1 parent bc2765e commit bd9daa3
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 17 deletions.
11 changes: 11 additions & 0 deletions include/in_mem_filter_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ template <typename LabelT> class InMemFilterStore : public AbstractFilterStore<L

DISKANN_DLLEXPORT virtual bool has_filter_support() const;

bool is_label_valid(const std::string& filter_label) const;

DISKANN_DLLEXPORT virtual const std::unordered_map<LabelT, std::vector<location_t>> &get_label_to_medoids() const;

DISKANN_DLLEXPORT virtual const std::vector<location_t> &get_medoids_of_label(const LabelT label);
Expand Down Expand Up @@ -102,6 +104,13 @@ template <typename LabelT> class InMemFilterStore : public AbstractFilterStore<L
// that they exist and could not be opened.
DISKANN_DLLEXPORT bool load(const std::string &disk_index_file);

bool load(
const std::string& labels_filepath,
const std::string& labels_to_medoids_filepath,
const std::string& labels_map_filepath,
const std::string& unv_label_filepath,
const std::string& dummy_map_filepath);

DISKANN_DLLEXPORT void generate_random_labels(std::vector<LabelT> &labels, const uint32_t num_labels,
const uint32_t nthreads);

Expand All @@ -122,6 +131,8 @@ template <typename LabelT> class InMemFilterStore : public AbstractFilterStore<L
void reset_stream_for_reading(std::basic_istream<char> &infile);
// Load functions for search END

size_t search_string_range(const std::string_view& str, char ch, size_t start, size_t end);

location_t _num_points = 0;
location_t *_pts_to_label_offsets = nullptr;
location_t *_pts_to_label_counts = nullptr;
Expand Down
4 changes: 0 additions & 4 deletions include/pq_flash_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,6 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex

DISKANN_DLLEXPORT uint64_t get_data_dim();

DISKANN_DLLEXPORT LabelT get_converted_label(const std::string &filter_label);

std::shared_ptr<AlignedFileReader> &reader;

DISKANN_DLLEXPORT diskann::Metric get_metric();
Expand Down Expand Up @@ -136,8 +134,6 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex
// returns region of `node_buf` containing [COORD(T)]
DISKANN_DLLEXPORT T *offset_to_node_coords(char *node_buf);

size_t search_string_range(const std::string& str, char ch, size_t start, size_t end);

// index info for multi-node sectors
// nhood of node `i` is in sector: [i / nnodes_per_sector]
// offset in sector: [(i % nnodes_per_sector) * max_node_len]
Expand Down
2 changes: 1 addition & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ else()
in_mem_data_store.cpp in_mem_graph_store.cpp in_mem_filter_store.cpp
natural_number_set.cpp memory_mapper.cpp partition.cpp pq.cpp
pq_flash_index.cpp scratch.cpp logger.cpp utils.cpp filter_utils.cpp
index_factory.cpp abstract_index.cpp pq_l2_distance.cpp pq_data_store.cpp)
index_factory.cpp abstract_index.cpp pq_l2_distance.cpp pq_data_store.cpp neighbor_list.cpp in_mem_static_graph_store.cpp)
if (RESTAPI)
list(APPEND CPP_SOURCES restapi/search_wrapper.cpp restapi/server.cpp)
endif()
Expand Down
2 changes: 1 addition & 1 deletion src/dll/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ add_library(${PROJECT_NAME} SHARED dllmain.cpp ../abstract_data_store.cpp ../par
../windows_aligned_file_reader.cpp ../distance.cpp ../pq_l2_distance.cpp ../memory_mapper.cpp ../index.cpp
../in_mem_data_store.cpp ../pq_data_store.cpp ../in_mem_graph_store.cpp ../math_utils.cpp ../disk_utils.cpp ../filter_utils.cpp
../ann_exception.cpp ../natural_number_set.cpp ../natural_number_map.cpp ../scratch.cpp ../index_factory.cpp ../abstract_index.cpp
../in_mem_filter_store.cpp)
../in_mem_filter_store.cpp ../neighbor_list.cpp ../in_mem_static_graph_store.cpp)

set(TARGET_DIR "$<$<CONFIG:Debug>:${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}>$<$<CONFIG:Release>:${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}>")

Expand Down
54 changes: 44 additions & 10 deletions src/in_mem_filter_store.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,33 +73,41 @@ template <typename LabelT> bool InMemFilterStore<LabelT>::load(const std::string
std::string labels_map_file = disk_index_file + "_labels_map.txt";
std::string univ_label_file = disk_index_file + "_universal_label.txt";

size_t num_pts_in_label_file = 0;
return load(labels_file, labels_to_medoids, labels_map_file, univ_label_file, dummy_map_file);
}

template <typename LabelT> bool InMemFilterStore<LabelT>::load(
const std::string& labels_filepath,
const std::string& labels_to_medoids_filepath,
const std::string& labels_map_filepath,
const std::string& unv_label_filepath,
const std::string& dummy_map_filepath)
{
// TODO: Check for encoding issues here. We are opening files as binary and
// reading them as bytes, not sure if that can cause an issue with UTF
// encodings.
bool has_filters = true;
if (false == load_file_and_parse(labels_file, &InMemFilterStore<LabelT>::load_label_file))
if (false == load_file_and_parse(labels_filepath, &InMemFilterStore<LabelT>::load_label_file))
{
diskann::cout << "Index does not have filter data. " << std::endl;
return false;
}
if (false == parse_stream(labels_map_file, &InMemFilterStore<LabelT>::load_label_map))
if (false == parse_stream(labels_map_filepath, &InMemFilterStore<LabelT>::load_label_map))
{
diskann::cerr << "Failed to find file: " << labels_map_file << " while labels_file exists." << std::endl;
diskann::cerr << "Failed to find file: " << labels_map_filepath << " while labels_file exists." << std::endl;
return false;
}

if (false == parse_stream(labels_to_medoids, &InMemFilterStore<LabelT>::load_labels_to_medoids))
if (false == parse_stream(labels_to_medoids_filepath, &InMemFilterStore<LabelT>::load_labels_to_medoids))
{
diskann::cerr << "Failed to find file: " << labels_to_medoids << " while labels file exists." << std::endl;
diskann::cerr << "Failed to find file: " << labels_to_medoids_filepath << " while labels file exists." << std::endl;
return false;
}
// missing universal label file is NOT an error.
load_file_and_parse(univ_label_file, &InMemFilterStore::parse_universal_label);
load_file_and_parse(unv_label_filepath, &InMemFilterStore::parse_universal_label);

// missing dummy map file is also NOT an error.
parse_stream(dummy_map_file, &InMemFilterStore<LabelT>::load_dummy_map);
parse_stream(dummy_map_filepath, &InMemFilterStore<LabelT>::load_dummy_map);
_is_valid = true;
return _is_valid;
}
Expand All @@ -109,6 +117,16 @@ template <typename LabelT> bool InMemFilterStore<LabelT>::has_filter_support() c
return _is_valid;
}

template <typename LabelT> bool InMemFilterStore<LabelT>::is_label_valid(const std::string& filter_label) const
{
if (_label_map.find(filter_label) != _label_map.end())
{
return true;
}

return false;
}

// TODO: Improve this to not load the entire file in memory
template <typename LabelT> void InMemFilterStore<LabelT>::load_label_file(const std::string_view &label_file_content)
{
Expand Down Expand Up @@ -147,7 +165,7 @@ template <typename LabelT> void InMemFilterStore<LabelT>::load_label_file(const
size_t next_lbl_pos = 0;
while (lbl_pos < next_pos && lbl_pos != std::string_view::npos)
{
next_lbl_pos = label_file_content.find(',', lbl_pos);
next_lbl_pos = search_string_range(label_file_content, ',', lbl_pos, next_pos);
if (next_lbl_pos == std::string_view::npos) // the last label in the whole file
{
next_lbl_pos = next_pos;
Expand Down Expand Up @@ -323,7 +341,7 @@ void InMemFilterStore<LabelT>::get_label_file_metadata(const std::string_view &f
size_t next_lbl_pos = 0;
while (lbl_pos < next_pos && lbl_pos != std::string::npos)
{
next_lbl_pos = fileContent.find(',', lbl_pos);
next_lbl_pos = search_string_range(fileContent, ',', lbl_pos, next_pos);
if (next_lbl_pos == std::string::npos) // the last label
{
next_lbl_pos = next_pos;
Expand Down Expand Up @@ -386,6 +404,21 @@ bool InMemFilterStore<LabelT>::load_file_and_parse(const std::string &filename,
}
}

template <typename LabelT>
size_t InMemFilterStore<LabelT>::search_string_range(const std::string_view& str, char ch, size_t start, size_t end)
{
for (; start != end; start++)
{
if (str[start] == ch)
{
return start;
}
}

return std::string::npos;

}

std::unique_ptr<char[]> get_file_content(const std::string &filename, uint64_t &file_size)
{
std::ifstream infile(filename, std::ios::binary);
Expand All @@ -402,6 +435,7 @@ std::unique_ptr<char[]> get_file_content(const std::string &filename, uint64_t &

return std::unique_ptr<char[]>(buffer);
}

// Load functions for SEARCH END
template class InMemFilterStore<uint16_t>;
template class InMemFilterStore<uint32_t>;
Expand Down
13 changes: 12 additions & 1 deletion src/pq_flash_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,11 @@ template <typename T, typename LabelT> void PQFlashIndex<T, LabelT>::use_medoids
}
}

template <typename T, typename LabelT> bool PQFlashIndex<T, LabelT>::is_label_valid(const std::string& filter_label)
{
return _filter_store->is_label_valid(filter_label);
}

#ifdef EXEC_ENV_OLS
template <typename T, typename LabelT>
int PQFlashIndex<T, LabelT>::load(MemoryMappedFiles &files, uint32_t num_threads, const char *index_prefix)
Expand Down Expand Up @@ -613,7 +618,13 @@ int PQFlashIndex<T, LabelT>::load_from_separate_paths(uint32_t num_threads, cons
_filter_store = std::make_unique<InMemFilterStore<LabelT>>();
try
{
_filter_index = _filter_store->load(_disk_index_file);
_filter_index = _filter_store->load(
labels_filepath,
labels_to_medoids_filepath,
labels_map_filepath,
unv_label_filepath,
"");

if (_filter_index)
{
diskann::cout << "Index has filter support. " << std::endl;
Expand Down

0 comments on commit bd9daa3

Please sign in to comment.