Skip to content

Commit

Permalink
Refactoring compiles correctly now
Browse files Browse the repository at this point in the history
  • Loading branch information
gopal-msr committed Sep 30, 2024
1 parent 34f723f commit a5f3d2e
Show file tree
Hide file tree
Showing 5 changed files with 1,407 additions and 1,484 deletions.
189 changes: 111 additions & 78 deletions include/in_mem_filter_store.h
Original file line number Diff line number Diff line change
@@ -1,90 +1,123 @@
#pragma once

#include <vector>
#include "windows_customizations.h"
#include "tsl/robin_map.h"
#include "tsl/robin_set.h"
#include <abstract_filter_store.h>

namespace diskann
{
namespace diskann {
template<typename LabelT>
class InMemFilterStore : public AbstractFilterStore<LabelT>
{
public:
/// <summary>
/// Returns the filters for a data point. Only valid for base points
/// </summary>
/// <param name="point">base point id</param>
/// <returns>list of filters of the base point</returns>
virtual const std::vector<LabelT> &get_filters_for_point(location_t point) const override;

/// <summary>
/// Adds filters for a point.
/// </summary>
/// <param name="point"></param>
/// <param name="filters"></param>
virtual void add_filters_for_point(location_t point, const std::vector<LabelT> &filters) override;

/// <summary>
/// Returns a score between [0,1] indicating how many points in the dataset
/// matched the predicate
/// </summary>
/// <param name="pred">Predicate to match</param>
/// <returns>Score between [0,1] indicate %age of points matching pred</returns>
virtual float get_predicate_selectivity(const AbstractPredicate &pred) const override;


virtual const std::unordered_map<LabelT, std::vector<location_t>>& get_label_to_medoids() const;

virtual const std::vector<location_t> &get_medoids_of_label(const LabelT label) const;

virtual void set_universal_label(const LabelT univ_label);

inline bool point_has_label(location_t point_id, const LabelT label_id) const;

inline bool is_dummy_point(location_t id) const;

inline bool point_has_label_or_universal_label(location_t point_id, const LabelT label_id) const;

inline LabelT get_converted_label(const std::string &filter_label) const;

//Returns true if the index is filter-enabled and all files were loaded correctly.
//false otherwise. Note that "false" can mean that the index does not have filter support,
//or that some index files do not exist, or that they exist and could not be opened.
bool load(const std::string& disk_index_file);
class InMemFilterStore : public AbstractFilterStore<LabelT> {
public:
// Do nothing constructor because all the work is done in load()
DISKANN_DLLEXPORT InMemFilterStore() {
}

/// <summary>
/// Destructor
/// </summary>
DISKANN_DLLEXPORT virtual ~InMemFilterStore();

// No copy, no assignment.
DISKANN_DLLEXPORT InMemFilterStore<LabelT> &operator=(
const InMemFilterStore<LabelT> &v) = delete;
DISKANN_DLLEXPORT InMemFilterStore(const InMemFilterStore<LabelT> &v) =
delete;

DISKANN_DLLEXPORT virtual bool has_filter_support() const;

/// <summary>
/// Returns the filters for a data point. Only valid for base points
/// </summary>
/// <param name="point">base point id</param>
/// <returns>list of filters of the base point</returns>
DISKANN_DLLEXPORT virtual const std::vector<LabelT> &get_filters_for_point(
location_t point) const override;

/// <summary>
/// Adds filters for a point.
/// </summary>
/// <param name="point"></param>
/// <param name="filters"></param>
DISKANN_DLLEXPORT virtual void add_filters_for_point(
location_t point, const std::vector<LabelT> &filters) override;

/// <summary>
/// Returns a score between [0,1] indicating how many points in the dataset
/// matched the predicate
/// </summary>
/// <param name="pred">Predicate to match</param>
/// <returns>Score between [0,1] indicate %age of points matching
/// pred</returns>
DISKANN_DLLEXPORT virtual float get_predicate_selectivity(
const AbstractPredicate &pred) const override;

DISKANN_DLLEXPORT virtual const std::unordered_map<LabelT,
std::vector<location_t>>
&get_label_to_medoids() const;

DISKANN_DLLEXPORT virtual const std::vector<location_t>
&get_medoids_of_label(const LabelT label) ;

DISKANN_DLLEXPORT virtual void set_universal_label(const LabelT univ_label);

DISKANN_DLLEXPORT inline bool point_has_label(location_t point_id,
const LabelT label_id) const;

DISKANN_DLLEXPORT inline bool is_dummy_point(location_t id) const;

DISKANN_DLLEXPORT inline location_t get_real_point_for_dummy(
location_t dummy_id);

DISKANN_DLLEXPORT inline bool point_has_label_or_universal_label(
location_t point_id, const LabelT label_id) const;

DISKANN_DLLEXPORT inline LabelT get_converted_label(
const std::string &filter_label);

// Returns true if the index is filter-enabled and all files were loaded
// correctly. false otherwise. Note that "false" can mean that the index
// does not have filter support, or that some index files do not exist, or
// that they exist and could not be opened.
DISKANN_DLLEXPORT bool load(const std::string &disk_index_file);

DISKANN_DLLEXPORT void generate_random_labels(std::vector<LabelT> &labels,
const uint32_t num_labels,
const uint32_t nthreads);

private:

// Load functions for search START
void load_label_file(const std::string_view& file_content);
void load_label_map(std::basic_istream<char> &map_reader);
void load_labels_to_medoids(std::basic_istream<char> &reader);
void load_dummy_map(std::basic_istream<char> &dummy_map_stream);

bool load_file_and_parse(
const std::string &filename,
void (*parse_fn)(const std::string_view &content));

bool load_file_and_parse(
const std::string &filename,
void (*parse_fn)(std::basic_istream<char> &stream))


// Load functions for search END

// filter support
uint32_t *_pts_to_label_offsets = nullptr;
uint32_t *_pts_to_label_counts = nullptr;
LabelT *_pts_to_labels = nullptr;
std::unordered_map<LabelT, std::vector<location_t>> _filter_to_medoid_ids;
bool _use_universal_label = false;
LabelT _universal_filter_label;
tsl::robin_set<uint32_t> _dummy_pts;
tsl::robin_set<uint32_t> _has_dummy_pts;
tsl::robin_map<uint32_t, uint32_t> _dummy_to_real_map;
tsl::robin_map<uint32_t, std::vector<uint32_t>> _real_to_dummy_map;
std::unordered_map<std::string, LabelT> _label_map;

// Load functions for search START
void load_label_file(const std::string_view &file_content);
void load_label_map(std::basic_istream<char> &map_reader);
void load_labels_to_medoids(std::basic_istream<char> &reader);
void load_dummy_map(std::basic_istream<char> &dummy_map_stream);
void parse_universal_label(const std::string_view &content);
void get_label_file_metadata(const std::string_view &fileContent,
uint32_t &num_pts, uint32_t &num_total_labels);

bool load_file_and_parse(const std::string &filename,
void (InMemFilterStore::*parse_fn)(const std::string_view &content));
bool parse_stream(
const std::string &filename,
void (InMemFilterStore::*parse_fn)(std::basic_istream<char> &stream));

void reset_stream_for_reading(std::basic_istream<char> &infile);
// Load functions for search END

location_t _num_points = 0;
location_t *_pts_to_label_offsets = nullptr;
location_t *_pts_to_label_counts = nullptr;
LabelT *_pts_to_labels = nullptr;
bool _use_universal_label = false;
LabelT _universal_filter_label;
tsl::robin_set<location_t> _dummy_pts;
tsl::robin_set<location_t> _has_dummy_pts;
tsl::robin_map<location_t, location_t> _dummy_to_real_map;
tsl::robin_map<location_t, std::vector<location_t>> _real_to_dummy_map;
std::unordered_map<std::string, LabelT> _label_map;
std::unordered_map<LabelT, std::vector<location_t>> _filter_to_medoid_ids;
bool _is_valid = false;
};

}
} // namespace diskann
15 changes: 2 additions & 13 deletions include/pq_flash_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex
const uint32_t io_limit, const bool use_reorder_data = false,
QueryStats *stats = nullptr);

DISKANN_DLLEXPORT LabelT get_converted_label(const std::string &filter_label);

DISKANN_DLLEXPORT uint32_t range_search(const T *query1, const double range, const uint64_t min_l_search,
const uint64_t max_l_search, std::vector<uint64_t> &indices,
Expand Down Expand Up @@ -114,18 +113,7 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex
DISKANN_DLLEXPORT void use_medoids_data_as_centroids();
DISKANN_DLLEXPORT void setup_thread_data(uint64_t nthreads, uint64_t visited_reserve = 4096);

DISKANN_DLLEXPORT void set_universal_label(const LabelT &label);

private:
DISKANN_DLLEXPORT inline bool point_has_label(uint32_t point_id, LabelT label_id);
std::unordered_map<std::string, LabelT> load_label_map(std::basic_istream<char> &infile);
DISKANN_DLLEXPORT void parse_label_file(std::basic_istream<char> &infile, size_t &num_pts_labels);
DISKANN_DLLEXPORT void get_label_file_metadata(const std::string &fileContent, uint32_t &num_pts,
uint32_t &num_total_labels);
DISKANN_DLLEXPORT void generate_random_labels(std::vector<LabelT> &labels, const uint32_t num_labels,
const uint32_t nthreads);
void reset_stream_for_reading(std::basic_istream<char> &infile);

// sector # on disk where node_id is present with in the graph part
DISKANN_DLLEXPORT uint64_t get_node_sector(uint64_t node_id);

Expand Down Expand Up @@ -225,7 +213,8 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex

//Moved filter-specific data structures to in_mem_filter_store.
//TODO: Make this a unique pointer
InMemFilterStore<LabelT>* _filter_store;
bool _filter_index = false;
std::unique_ptr<InMemFilterStore<LabelT>> _filter_store;


#ifdef EXEC_ENV_OLS
Expand Down
3 changes: 2 additions & 1 deletion src/dll/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
add_library(${PROJECT_NAME} SHARED dllmain.cpp ../abstract_data_store.cpp ../partition.cpp ../pq.cpp ../pq_flash_index.cpp ../logger.cpp ../utils.cpp
../windows_aligned_file_reader.cpp ../distance.cpp ../pq_l2_distance.cpp ../memory_mapper.cpp ../index.cpp
../in_mem_data_store.cpp ../pq_data_store.cpp ../in_mem_graph_store.cpp ../math_utils.cpp ../disk_utils.cpp ../filter_utils.cpp
../ann_exception.cpp ../natural_number_set.cpp ../natural_number_map.cpp ../scratch.cpp ../index_factory.cpp ../abstract_index.cpp)
../ann_exception.cpp ../natural_number_set.cpp ../natural_number_map.cpp ../scratch.cpp ../index_factory.cpp ../abstract_index.cpp
../in_mem_filter_store.cpp)

set(TARGET_DIR "$<$<CONFIG:Debug>:${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}>$<$<CONFIG:Release>:${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}>")

Expand Down
Loading

0 comments on commit a5f3d2e

Please sign in to comment.