Skip to content

Commit

Permalink
make subcentroids_random better (submode 3)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jakub Tarnawski committed Oct 23, 2023
1 parent 78761ff commit ccecf58
Showing 1 changed file with 59 additions and 3 deletions.
62 changes: 59 additions & 3 deletions tests/utils/partition_hmetis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ int aux_main(const std::string &input_file,
std::make_unique<float[]>(num_shards * num_subcentroids * dim);
std::unique_ptr<uint32_t[]> closest_centers_ivf =
std::make_unique<uint32_t[]>(num_shards * num_subcentroids);
// for each shard centers, find the `num_subcentroids` many closest shard centers
// for each shard center, find the `num_subcentroids` many closest shard centers
// (including itself, which will be the closest)
math_utils::compute_closest_centers(
centroids.get(), num_shards, dim, centroids.get(), num_shards,
Expand Down Expand Up @@ -493,7 +493,7 @@ int aux_main(const std::string &input_file,
} else if (mode == "multicentroids" || mode == "multicentroids-random" ||
mode == "multicentroids-neighbors") {

constexpr int submode = 1;
constexpr int submode = 3;

std::unique_ptr<float[]> queries_float =
std::make_unique<float[]>(num_queries * dim);
Expand Down Expand Up @@ -531,7 +531,7 @@ int aux_main(const std::string &input_file,
// compute score of shard_id for query_id
float score = 0.0;
for (int i = 0; i < num_subcentroids; ++i) {
const float dist = sqrt(math_utils::calc_distance(
const float dist = sqrt(math_utils::calc_distance(
queries_float.get() + query_id * dim,
subcentroids.get() + shard_id * num_subcentroids * dim +
i * dim,
Expand All @@ -552,6 +552,62 @@ int aux_main(const std::string &input_file,
// wasn't present
}
}
} else if (submode == 3) {
// Submode 3 treats every subcentroid as a subsampled point of its shard (see
// the description below), so it is only meaningful for "multicentroids-random".
// The enclosing branch has already narrowed `mode` to the multicentroids*
// family; comparing against a name outside that family (previously
// "subcentroids_random") made this guard reject every run.
if (mode != "multicentroids-random") {
  diskann::cout << "Error: submode 3 only works with "
                   "multicentroids-random"
                << std::endl;
  return -1;
}
// 3: Depends on a specific K.
// For each query,
// we add the subsampled points one by one, from the closest to the farthest.
// As we do this, we maintain for each shard a `worth` value:
// the expected number of points at distance <= d = current distance
// (thought of as the GT points) that are in this shard.
// Worth is computed as: (number of subsampled points from this shard seen so far)
// * (shard size) / num_subcentroids.
// We stop once the sum of worths of all shards is >= K
// (at this point we estimate that we have seen all the subsampled GT points).
// Then we sort the shards by worth (and later take the top query_fanout shards).


const size_t num_subcenters = num_shards * num_subcentroids;
std::unique_ptr<uint32_t[]> closest_centroids_ivf =
std::make_unique<uint32_t[]>(num_queries * num_subcenters);
math_utils::compute_closest_centers(
queries_float.get(), num_queries, dim, subcentroids.get(),
num_subcenters, num_subcenters, closest_centroids_ivf.get());
for (size_t query_id = 0; query_id < num_queries; ++query_id) {
std::vector<float> worth_of_shard(num_shards, 0.0);
float sum_of_worths = 0.0;
for (int i = 0; i < num_subcenters; ++i) {
const size_t shard_id =
closest_centroids_ivf[query_id * num_subcenters + i] /
num_subcentroids;
const float worth_increase =
points_routed_to_shard[shard_id].size() * 1.0 / num_subcentroids;
worth_of_shard[shard_id] += worth_increase;
sum_of_worths += worth_increase;
if (sum_of_worths >= K) {
break;
}
}
std::vector<std::pair<float, size_t>> shards_with_worths;
for (size_t shard_id = 0; shard_id < num_shards; ++shard_id) {
shards_with_worths.emplace_back(-worth_of_shard[shard_id],
shard_id);
}
sort(shards_with_worths.begin(), shards_with_worths.end());
query_to_shards.emplace_back();
for (int i = 0; i < num_shards; ++i) {
const size_t shard_id = shards_with_worths[i].second;
query_to_shards[query_id].emplace_back(
shard_id, shard_to_count_of_GT_pts[query_id][shard_id]);
// shard_to_count_of_GT_pts[query_id][shard_id] will be(come) 0 if
// wasn't present
}
}
} else {
diskann::cout << "what submode?" << std::endl;
return -1;
Expand Down

0 comments on commit ccecf58

Please sign in to comment.