diff --git a/tests/utils/partition_hmetis.cpp b/tests/utils/partition_hmetis.cpp index db1ad7272..4db1d0284 100644 --- a/tests/utils/partition_hmetis.cpp +++ b/tests/utils/partition_hmetis.cpp @@ -378,7 +378,7 @@ int aux_main(const std::string &input_file, std::make_unique(num_shards * num_subcentroids * dim); std::unique_ptr closest_centers_ivf = std::make_unique(num_shards * num_subcentroids); - // for each shard centers, find the `num_subcentroids` many closest shard centers + // for each shard center, find the `num_subcentroids` many closest shard centers // (including itself, which will be the closest) math_utils::compute_closest_centers( centroids.get(), num_shards, dim, centroids.get(), num_shards, @@ -493,7 +493,7 @@ int aux_main(const std::string &input_file, } else if (mode == "multicentroids" || mode == "multicentroids-random" || mode == "multicentroids-neighbors") { - constexpr int submode = 1; + constexpr int submode = 3; std::unique_ptr queries_float = std::make_unique(num_queries * dim); @@ -531,7 +531,7 @@ int aux_main(const std::string &input_file, // compute score of shard_id for query_id float score = 0.0; for (int i = 0; i < num_subcentroids; ++i) { - const float dist = sqrt(math_utils::calc_distance( + const float dist = sqrt(math_utils::calc_distance( queries_float.get() + query_id * dim, subcentroids.get() + shard_id * num_subcentroids * dim + i * dim, @@ -552,6 +552,62 @@ int aux_main(const std::string &input_file, // wasn't present } } + } else if (submode == 3) { + if (mode != "subcentroids_random") { + diskann::cout << "Error: submode 3 only works with " + "multicentroids-random" + << std::endl; + return -1; + } + // 3: Depends on a specific K. + // For each query, + // we add the subsampled points one by one, from the closest to the farthest. + // As we do this, we maintain for each shard a `worth` value: + // the expected number of points at distance <= d = current distance + // (thought of as the GT points) that are in this shard. + // Worth is computed as: (number of subsampled points from this shard seen so far) + // * (shard size) / num_subcentroids. + // We stop once the sum of worths of all shards is >= K + // (at this point we estimate that we have seen all the subsampled GT points). + // Then we sort the shards by worth (and later take the top query_fanout shards). + + + const size_t num_subcenters = num_shards * num_subcentroids; + std::unique_ptr closest_centroids_ivf = + std::make_unique(num_queries * num_subcenters); + math_utils::compute_closest_centers( + queries_float.get(), num_queries, dim, subcentroids.get(), + num_subcenters, num_subcenters, closest_centroids_ivf.get()); + for (size_t query_id = 0; query_id < num_queries; ++query_id) { + std::vector worth_of_shard(num_shards, 0.0); + float sum_of_worths = 0.0; + for (int i = 0; i < num_subcenters; ++i) { + const size_t shard_id = + closest_centroids_ivf[query_id * num_subcenters + i] / + num_subcentroids; + const float worth_increase = + points_routed_to_shard[shard_id].size() * 1.0 / num_subcentroids; + worth_of_shard[shard_id] += worth_increase; + sum_of_worths += worth_increase; + if (sum_of_worths >= K) { + break; + } + } + std::vector> shards_with_worths; + for (size_t shard_id = 0; shard_id < num_shards; ++shard_id) { + shards_with_worths.emplace_back(-worth_of_shard[shard_id], + shard_id); + } + sort(shards_with_worths.begin(), shards_with_worths.end()); + query_to_shards.emplace_back(); + for (int i = 0; i < num_shards; ++i) { + const size_t shard_id = shards_with_worths[i].second; + query_to_shards[query_id].emplace_back( + shard_id, shard_to_count_of_GT_pts[query_id][shard_id]); + // shard_to_count_of_GT_pts[query_id][shard_id] will be(come) 0 if + // wasn't present + } + } } else { diskann::cout << "what submode?" << std::endl; return -1;