Skip to content

Commit

Permalink
added untested code for diversity in PQFlashIndex
Browse files Browse the repository at this point in the history
  • Loading branch information
rakri committed Nov 25, 2024
1 parent f07b458 commit 6410548
Show file tree
Hide file tree
Showing 8 changed files with 146 additions and 33 deletions.
4 changes: 2 additions & 2 deletions apps/search_disk_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre
_pFlashIndex->cached_beam_search(query + (i * query_aligned_dim), recall_at, L,
query_result_ids_64.data() + (i * recall_at),
query_result_dists[test_id].data() + (i * recall_at),
optimized_beamwidth, use_reorder_data, stats + i);
optimized_beamwidth, std::numeric_limits<uint32_t>::max(), use_reorder_data, stats + i);
}
else
{
Expand All @@ -247,7 +247,7 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre
}
_pFlashIndex->cached_beam_search(
query + (i * query_aligned_dim), recall_at, L, query_result_ids_64.data() + (i * recall_at),
query_result_dists[test_id].data() + (i * recall_at), optimized_beamwidth, true, label_for_search,
query_result_dists[test_id].data() + (i * recall_at), optimized_beamwidth, true, label_for_search, std::numeric_limits<uint32_t>::max(),
use_reorder_data, stats + i);
}
}
Expand Down
19 changes: 15 additions & 4 deletions include/pq_flash_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,23 +62,23 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex
const bool shuffle = false);

DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search,
uint64_t *res_ids, float *res_dists, const uint64_t beam_width,
uint64_t *res_ids, float *res_dists, const uint64_t beam_width, const uint32_t max_l_per_seller = std::numeric_limits<uint32_t>::max(),
const bool use_reorder_data = false, QueryStats *stats = nullptr);

DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search,
uint64_t *res_ids, float *res_dists, const uint64_t beam_width,
const bool use_filter, const LabelT &filter_label,
const bool use_filter, const LabelT &filter_label, const uint32_t max_l_per_seller = std::numeric_limits<uint32_t>::max(),
const bool use_reorder_data = false, QueryStats *stats = nullptr);

DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search,
uint64_t *res_ids, float *res_dists, const uint64_t beam_width,
const uint32_t io_limit, const bool use_reorder_data = false,
const uint32_t io_limit, const uint32_t max_l_per_seller = std::numeric_limits<uint32_t>::max(), const bool use_reorder_data = false,
QueryStats *stats = nullptr);

DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search,
uint64_t *res_ids, float *res_dists, const uint64_t beam_width,
const bool use_filter, const LabelT &filter_label,
const uint32_t io_limit, const bool use_reorder_data = false,
const uint32_t io_limit, const uint32_t max_l_per_seller = std::numeric_limits<uint32_t>::max(), const bool use_reorder_data = false,
QueryStats *stats = nullptr);

DISKANN_DLLEXPORT LabelT get_converted_label(const std::string &filter_label);
Expand Down Expand Up @@ -118,10 +118,14 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex
DISKANN_DLLEXPORT inline bool point_has_label(uint32_t point_id, LabelT label_id);
std::unordered_map<std::string, LabelT> load_label_map(std::basic_istream<char> &infile);
DISKANN_DLLEXPORT void parse_label_file(std::basic_istream<char> &infile, size_t &num_pts_labels);

DISKANN_DLLEXPORT void get_label_file_metadata(const std::string &fileContent, uint32_t &num_pts,
uint32_t &num_total_labels);
DISKANN_DLLEXPORT void generate_random_labels(std::vector<LabelT> &labels, const uint32_t num_labels,
const uint32_t nthreads);

DISKANN_DLLEXPORT void parse_seller_file(const std::string &label_file, size_t &num_pts_labels);

void reset_stream_for_reading(std::basic_istream<char> &infile);

// sector # on disk where node_id is present within the graph part
Expand Down Expand Up @@ -234,6 +238,13 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex
tsl::robin_map<uint32_t, std::vector<uint32_t>> _real_to_dummy_map;
std::unordered_map<std::string, LabelT> _label_map;


bool _diverse_index = false;
uint32_t _max_L_per_seller = 0;
std::vector<uint32_t> _location_to_seller;
std::string _seller_file;


#ifdef EXEC_ENV_OLS
// Set to a larger value than the actual header to accommodate
// any additions we make to the header. This is an outer limit
Expand Down
5 changes: 3 additions & 2 deletions include/scratch.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,9 @@ template <typename T> class SSDQueryScratch : public AbstractScratch<T>
tsl::robin_set<size_t> visited;
NeighborPriorityQueue retset;
std::vector<Neighbor> full_retset;
bestCandidates best_diverse_nodes;

SSDQueryScratch(size_t aligned_dim, size_t visited_reserve);
SSDQueryScratch(size_t aligned_dim, size_t visited_reserve, std::vector<uint32_t> &location_to_sellers);
~SSDQueryScratch();

void reset();
Expand All @@ -167,7 +168,7 @@ template <typename T> class SSDThreadData
SSDQueryScratch<T> scratch;
IOContext ctx;

SSDThreadData(size_t aligned_dim, size_t visited_reserve);
SSDThreadData(size_t aligned_dim, size_t visited_reserve, std::vector<uint32_t> &location_to_sellers);
void clear();
};

Expand Down
2 changes: 1 addition & 1 deletion python/src/static_disk_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ NeighborsAndDistances<StaticIdType> StaticDiskIndex<DT>::search(
std::vector<uint64_t> u64_ids(knn);
diskann::QueryStats stats;

_index.cached_beam_search(query.data(), knn, complexity, u64_ids.data(), dists.mutable_data(), beam_width, false,
_index.cached_beam_search(query.data(), knn, complexity, u64_ids.data(), dists.mutable_data(), beam_width, std::numeric_limits<uint32_t>::max(), false,
&stats);

auto r = ids.mutable_unchecked<1>();
Expand Down
2 changes: 1 addition & 1 deletion src/disk_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,7 @@ uint32_t optimize_beamwidth(std::unique_ptr<diskann::PQFlashIndex<T, LabelT>> &p
{
pFlashIndex->cached_beam_search(tuning_sample + (i * tuning_sample_aligned_dim), 1, L,
tuning_sample_result_ids_64.data() + (i * 1),
tuning_sample_result_dists.data() + (i * 1), cur_bw, false, stats + i);
tuning_sample_result_dists.data() + (i * 1), cur_bw, std::numeric_limits<uint32_t>::max(), false, stats + i);
}
auto e = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = e - s;
Expand Down
2 changes: 1 addition & 1 deletion src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ void Index<T, TagT, LabelT>::load(const char *filename, uint32_t num_threads, ui
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
}

std::string index_seller_file = std::string(filename) + "_sellers.txt";
std::string index_seller_file = std::string(filename) + "_sellers.txt";
if(file_exists(index_seller_file)) {
uint64_t nrows_seller_file;
parse_seller_file(index_seller_file, nrows_seller_file);
Expand Down
Loading

0 comments on commit 6410548

Please sign in to comment.