diff --git a/bindings/python/pyproject.toml b/bindings/python/pyproject.toml index 8bc8c14c..aeea5a98 100644 --- a/bindings/python/pyproject.toml +++ b/bindings/python/pyproject.toml @@ -16,7 +16,7 @@ requires = [ "setuptools>=42", "scikit-build", - "cmake>=3.21", # Keep in-sync with `CMakeLists.txt` + "cmake>=3.21, <4", # Keep in-sync with `CMakeLists.txt` "numpy>=1.10.0, <2", # Keep in-sync with `setup.py` "archspec>=0.2.0", # Keep in-sync with `setup.py` "toml>=0.10.2", # Keep in-sync with `setup.py` required for the tests diff --git a/bindings/python/src/vamana.cpp b/bindings/python/src/vamana.cpp index 9801c306..603f5007 100644 --- a/bindings/python/src/vamana.cpp +++ b/bindings/python/src/vamana.cpp @@ -30,6 +30,7 @@ #include "svs/lib/dispatcher.h" #include "svs/lib/float16.h" #include "svs/lib/meta.h" +#include "svs/lib/preprocessor.h" #include "svs/orchestrators/vamana.h" // pybind @@ -420,40 +421,22 @@ void wrap(py::module& m) { size_t window_size, size_t max_candidate_pool_size, size_t prune_to, - size_t num_threads) { - if (num_threads != std::numeric_limits::max()) { - PyErr_WarnEx( - PyExc_DeprecationWarning, - "Constructing VamanaBuildParameters with the \"num_threads\" " - "keyword " - "argument is deprecated, no longer has any effect, and will be " - "removed " - "from future versions of the library. Use the \"num_threads\" " - "keyword " - "argument of \"svs.Vamana.build\" instead!", - 1 - ); - } - - // Default the `prune_to` argument appropriately. - if (prune_to == std::numeric_limits::max()) { - prune_to = graph_max_degree; - } - + bool use_full_search_history) { return svs::index::vamana::VamanaBuildParameters{ alpha, graph_max_degree, window_size, max_candidate_pool_size, prune_to, - true}; + use_full_search_history}; }), - py::arg("alpha") = 1.2, - py::arg("graph_max_degree") = 32, - py::arg("window_size") = 64, - py::arg("max_candidate_pool_size") = 80, - py::arg("prune_to") = std::numeric_limits::max(), - py::arg("num_threads") = std::numeric_limits::max(), + py::arg("alpha") = svs::FLOAT_PLACEHOLDER, + py::arg("graph_max_degree") = svs::VAMANA_GRAPH_MAX_DEGREE_DEFAULT, + py::arg("window_size") = svs::VAMANA_WINDOW_SIZE_DEFAULT, + py::arg("max_candidate_pool_size") = svs::UNSIGNED_INTEGER_PLACEHOLDER, + py::arg("prune_to") = svs::UNSIGNED_INTEGER_PLACEHOLDER, + py::arg("use_full_search_history") = + svs::VAMANA_USE_FULL_SEARCH_HISTORY_DEFAULT, R"( Construct a new instance from keyword arguments. @@ -462,6 +445,7 @@ void wrap(py::module& m) { For distance types favoring minimization, set this to a number greater than 1.0 (typically, 1.2 is sufficient). For distance types preferring maximization, set to a value less than 1.0 (such as 0.95). + The default value is 1.2 for L2 distance type and 0.95 for MIP/Cosine. graph_max_degree: The maximum out-degree in the final graph. Graphs with a higher degree tend to yield better accuracy and performance at the cost of a larger memory footprint. @@ -470,10 +454,15 @@ void wrap(py::module& m) { longer construction time. Should be larger than `graph_max_degree`. max_candidate_pool_size: Limit on the number of candidates to consider for neighbor updates. Should be larger than `window_size`. + The default value is ``graph_max_degree`` * 2. prune_to: Amount candidate lists will be pruned to when exceeding the target max degree. In general, setting this to slightly less than - `graph_max_degree` will yield faster index building times. Default: - `graph_max_degree`. + ``graph_max_degree`` will yield faster index building times. Default: + ` `graph_max_degree`` - 4 if + ``graph_max_degree`` is at least 16, otherwise ``graph_max_degree``. + use_full_search_history: When true, uses the full search history during + graph construction, which can improve graph quality at the expense of + additional memory and potentially longer build times. )" ) .def_readwrite("alpha", &svs::index::vamana::VamanaBuildParameters::alpha) @@ -557,4 +546,4 @@ overwritten when saving the index to this directory. )" ); } -} // namespace svs::python::vamana +} // namespace svs::python::vamana \ No newline at end of file diff --git a/bindings/python/tests/test_dynamic_vamana.py b/bindings/python/tests/test_dynamic_vamana.py index 7fa48640..84d78217 100644 --- a/bindings/python/tests/test_dynamic_vamana.py +++ b/bindings/python/tests/test_dynamic_vamana.py @@ -98,7 +98,7 @@ def test_loop(self): # here, we set an expected mid-point for the recall and allow it to wander up and # down by a little. expected_recall = 0.845 - expected_recall_delta = 0.03 + expected_recall_delta = 0.05 reference = ReferenceDataset(num_threads = num_threads) data, ids = reference.new_ids(5000) diff --git a/bindings/python/tests/test_vamana.py b/bindings/python/tests/test_vamana.py index 763afe88..8b288564 100644 --- a/bindings/python/tests/test_vamana.py +++ b/bindings/python/tests/test_vamana.py @@ -281,13 +281,6 @@ def test_basic(self): self._test_basic(loader, matcher, first_iter = first_iter) first_iter = False - def test_deprecation(self): - with warnings.catch_warnings(record = True) as w: - p = svs.VamanaBuildParameters(num_threads = 1) - self.assertTrue(len(w) == 1) - self.assertTrue(issubclass(w[0].category, DeprecationWarning)) - self.assertTrue("VamanaBuildParameters" in str(w[0].message)) - def _groundtruth_map(self): return { svs.DistanceType.L2: test_groundtruth_l2, diff --git a/include/svs/index/vamana/build_params.h b/include/svs/index/vamana/build_params.h index 11959134..65b5039c 100644 --- a/include/svs/index/vamana/build_params.h +++ b/include/svs/index/vamana/build_params.h @@ -17,6 +17,7 @@ #pragma once // svs +#include "svs/lib/preprocessor.h" #include "svs/lib/saveload.h" // stl @@ -44,33 +45,33 @@ struct VamanaBuildParameters { , use_full_search_history{use_full_search_history_} {} /// The pruning parameter. - float alpha; + float alpha = svs::FLOAT_PLACEHOLDER; /// The maximum degree in the graph. A higher max degree may yield a higher quality /// graph in terms of recall for performance, but the memory footprint of the graph is /// directly proportional to the maximum degree. - size_t graph_max_degree; + size_t graph_max_degree = svs::VAMANA_GRAPH_MAX_DEGREE_DEFAULT; /// The search window size to use during graph construction. A higher search window /// size will yield a higher quality graph since more overall vertices are considered, /// but will increase construction time. - size_t window_size; + size_t window_size = svs::VAMANA_WINDOW_SIZE_DEFAULT; /// Set a limit on the number of neighbors considered during pruning. In practice, set /// this to a high number (at least 5 times greater than the window_size) and forget /// about it. - size_t max_candidate_pool_size; + size_t max_candidate_pool_size = svs::UNSIGNED_INTEGER_PLACEHOLDER; /// This is the amount that candidates will be pruned to after certain pruning /// procedures. Setting this to less than ``graph_max_degree`` can result in significant /// speedups in index building. - size_t prune_to; + size_t prune_to = svs::UNSIGNED_INTEGER_PLACEHOLDER; /// When building, either the contents of the search buffer can be used or the entire /// search history can be used. /// /// The latter case may yield a slightly better graph as the cost of more search time. - bool use_full_search_history = true; + bool use_full_search_history = svs::VAMANA_USE_FULL_SEARCH_HISTORY_DEFAULT; ///// Comparison friend bool @@ -129,4 +130,4 @@ struct VamanaBuildParameters { ); } }; -} // namespace svs::index::vamana +} // namespace svs::index::vamana \ No newline at end of file diff --git a/include/svs/index/vamana/dynamic_index.h b/include/svs/index/vamana/dynamic_index.h index 6a37778b..39891edc 100644 --- a/include/svs/index/vamana/dynamic_index.h +++ b/include/svs/index/vamana/dynamic_index.h @@ -38,6 +38,7 @@ #include "svs/index/vamana/index.h" #include "svs/index/vamana/vamana_build.h" #include "svs/lib/boundscheck.h" +#include "svs/lib/preprocessor.h" #include "svs/lib/threads.h" namespace svs::index::vamana { @@ -157,6 +158,9 @@ class MutableVamanaIndex { float alpha_ = 1.2; bool use_full_search_history_ = true; + // Construction parameters + VamanaBuildParameters build_parameters_{}; + // SVS logger for per index logging svs::logging::logger_ptr logger_; @@ -210,12 +214,19 @@ class MutableVamanaIndex { , distance_(std::move(distance_function)) , threadpool_(threads::as_threadpool(std::move(threadpool_proto))) , search_parameters_(vamana::construct_default_search_parameters(data_)) - , construction_window_size_(parameters.window_size) - , max_candidates_(parameters.max_candidate_pool_size) - , prune_to_(parameters.prune_to) - , alpha_(parameters.alpha) - , use_full_search_history_{parameters.use_full_search_history} + , build_parameters_(parameters) , logger_{std::move(logger)} { + // Verify and set defaults directly on the input parameters + verify_and_set_default_index_parameters(build_parameters_, distance_function); + + // Set graph again as verify function might change graph_max_degree parameter + graph_ = Graph{data_.size(), build_parameters_.graph_max_degree}; + construction_window_size_ = build_parameters_.window_size; + max_candidates_ = build_parameters_.max_candidate_pool_size; + prune_to_ = build_parameters_.prune_to; + alpha_ = build_parameters_.alpha; + use_full_search_history_ = build_parameters_.use_full_search_history; + // Setup the initial translation of external to internal ids. translator_.insert(external_ids, threads::UnitRange(0, external_ids.size())); @@ -227,10 +238,12 @@ class MutableVamanaIndex { auto prefetch_parameters = GreedySearchPrefetchParameters{sp.prefetch_lookahead_, sp.prefetch_step_}; auto builder = VamanaBuilder( - graph_, data_, distance_, parameters, threadpool_, prefetch_parameters + graph_, data_, distance_, build_parameters_, threadpool_, prefetch_parameters ); builder.construct(1.0f, entry_point_[0], logging::Level::Info, logger_); - builder.construct(parameters.alpha, entry_point_[0], logging::Level::Info, logger_); + builder.construct( + build_parameters_.alpha, entry_point_[0], logging::Level::Info, logger_ + ); } /// @brief Post re-load constructor. @@ -1346,4 +1359,4 @@ auto auto_dynamic_assemble( std::move(logger)}; } -} // namespace svs::index::vamana +} // namespace svs::index::vamana \ No newline at end of file diff --git a/include/svs/index/vamana/index.h b/include/svs/index/vamana/index.h index a50ce11d..1ee96cd5 100644 --- a/include/svs/index/vamana/index.h +++ b/include/svs/index/vamana/index.h @@ -404,19 +404,22 @@ class VamanaIndex { if (graph_.n_nodes() != data_.size()) { throw ANNEXCEPTION("Wrong sizes!"); } - build_parameters_ = parameters; + // verify the parameters before set local var + verify_and_set_default_index_parameters(build_parameters_, distance_function); auto builder = VamanaBuilder( graph_, data_, distance_, - parameters, + build_parameters_, threadpool_, extensions::estimate_prefetch_parameters(data_) ); builder.construct(1.0F, entry_point_[0], logging::Level::Info, logger); - builder.construct(parameters.alpha, entry_point_[0], logging::Level::Info, logger); + builder.construct( + build_parameters_.alpha, entry_point_[0], logging::Level::Info, logger + ); } /// @brief Getter method for logger @@ -896,10 +899,13 @@ auto auto_build( auto entry_point = extensions::compute_entry_point(data, threadpool); // Default graph. - auto graph = default_graph(data.size(), parameters.graph_max_degree, graph_allocator); + auto verified_parameters = parameters; + verify_and_set_default_index_parameters(verified_parameters, distance); + auto graph = + default_graph(data.size(), verified_parameters.graph_max_degree, graph_allocator); using I = typename decltype(graph)::index_type; return VamanaIndex{ - parameters, + verified_parameters, std::move(graph), std::move(data), lib::narrow(entry_point), @@ -959,4 +965,57 @@ auto auto_assemble( index.apply(config); return index; } + +/// @brief Verify parameters and set defaults if needed +template +void verify_and_set_default_index_parameters( + VamanaBuildParameters& parameters, Dist distance_function +) { + // Set default values + if (parameters.max_candidate_pool_size == svs::UNSIGNED_INTEGER_PLACEHOLDER) { + parameters.max_candidate_pool_size = 2 * parameters.graph_max_degree; + } + + if (parameters.prune_to == svs::UNSIGNED_INTEGER_PLACEHOLDER) { + if (parameters.graph_max_degree >= 16) { + parameters.prune_to = parameters.graph_max_degree - 4; + } else { + parameters.prune_to = parameters.graph_max_degree; + } + } + + // Check supported distance type using std::is_same type trait + using dist_type = std::decay_t; + // Create type flags for each distance type + constexpr bool is_L2 = std::is_same_v; + constexpr bool is_IP = std::is_same_v; + constexpr bool is_Cosine = + std::is_same_v; + + // Handle alpha based on distance type + if constexpr (is_L2) { + if (parameters.alpha == svs::FLOAT_PLACEHOLDER) { + parameters.alpha = svs::VAMANA_ALPHA_MINIMIZE_DEFAULT; + } else if (parameters.alpha < 1.0f) { + // Check User set values + throw std::invalid_argument("For L2 distance, alpha must be >= 1.0"); + } + } else if constexpr (is_IP || is_Cosine) { + if (parameters.alpha == svs::FLOAT_PLACEHOLDER) { + parameters.alpha = svs::VAMANA_ALPHA_MAXIMIZE_DEFAULT; + } else if (parameters.alpha > 1.0f) { + // Check User set values + throw std::invalid_argument("For MIP/Cosine distance, alpha must be <= 1.0"); + } else if (parameters.alpha <= 0.0f) { + throw std::invalid_argument("alpha must be > 0"); + } + } else { + throw std::invalid_argument("Unsupported distance type"); + } + + // Check prune_to <= graph_max_degree + if (parameters.prune_to > parameters.graph_max_degree) { + throw std::invalid_argument("prune_to must be <= graph_max_degree"); + } +} } // namespace svs::index::vamana diff --git a/include/svs/lib/preprocessor.h b/include/svs/lib/preprocessor.h index f1765cde..e3a1900d 100644 --- a/include/svs/lib/preprocessor.h +++ b/include/svs/lib/preprocessor.h @@ -16,6 +16,9 @@ #pragma once +#include +#include + namespace svs::preprocessor::detail { // consteval functions for working with preprocessor defines. @@ -159,3 +162,14 @@ inline constexpr bool have_avx512_avx2 = true; #endif } // namespace svs::arch + +namespace svs { +// Maximum values used as default initializers +inline constexpr size_t UNSIGNED_INTEGER_PLACEHOLDER = std::numeric_limits::max(); +inline constexpr float FLOAT_PLACEHOLDER = std::numeric_limits::max(); +inline constexpr float VAMANA_GRAPH_MAX_DEGREE_DEFAULT = 32; +inline constexpr float VAMANA_WINDOW_SIZE_DEFAULT = 64; +inline constexpr bool VAMANA_USE_FULL_SEARCH_HISTORY_DEFAULT = true; +inline constexpr float VAMANA_ALPHA_MINIMIZE_DEFAULT = 1.2; +inline constexpr float VAMANA_ALPHA_MAXIMIZE_DEFAULT = 0.95; +} // namespace svs \ No newline at end of file diff --git a/tests/svs/index/vamana/dynamic_index_2.cpp b/tests/svs/index/vamana/dynamic_index_2.cpp index a3acb7f0..f09c2c1e 100644 --- a/tests/svs/index/vamana/dynamic_index_2.cpp +++ b/tests/svs/index/vamana/dynamic_index_2.cpp @@ -19,6 +19,7 @@ #include "svs/core/recall.h" #include "svs/index/flat/flat.h" #include "svs/index/vamana/dynamic_index.h" +#include "svs/lib/preprocessor.h" #include "svs/lib/timing.h" #include "svs/misc/dynamic_helper.h" @@ -476,4 +477,144 @@ CATCH_TEST_CASE("Dynamic MutableVamanaIndex Default Logger Test", "[logging]") { // Verify that the default logger is used auto default_logger = svs::logging::get(); CATCH_REQUIRE(index.get_logger() == default_logger); +} + +CATCH_TEST_CASE("Dynamic Vamana Index Default Parameters", "[parameter][vamana]") { + using Catch::Approx; + std::filesystem::path data_path = test_dataset::data_svs_file(); + + CATCH_SECTION("L2 Distance Defaults") { + auto expected_result = test_dataset::vamana::expected_build_results( + svs::L2, svsbenchmark::Uncompressed(svs::DataType::float32) + ); + auto build_params = expected_result.build_parameters_.value(); + auto data_loader = svs::data::SimpleData::load(data_path); + + // Get IDs for all points in the dataset + std::vector indices(data_loader.size()); + std::iota(indices.begin(), indices.end(), 0); + + // Build dynamic index with L2 distance + auto index = svs::index::vamana::MutableVamanaIndex( + build_params, std::move(data_loader), indices, svs::distance::DistanceL2(), 2 + ); + + CATCH_REQUIRE(index.get_alpha() == Approx(svs::VAMANA_ALPHA_MINIMIZE_DEFAULT)); + } + + CATCH_SECTION("MIP Distance Defaults") { + auto expected_result = test_dataset::vamana::expected_build_results( + svs::MIP, svsbenchmark::Uncompressed(svs::DataType::float32) + ); + auto build_params = expected_result.build_parameters_.value(); + auto data_loader = svs::data::SimpleData::load(data_path); + + // Get IDs for all points in the dataset + std::vector indices(data_loader.size()); + std::iota(indices.begin(), indices.end(), 0); + + // Build dynamic index with MIP distance + auto index = svs::index::vamana::MutableVamanaIndex( + build_params, std::move(data_loader), indices, svs::distance::DistanceIP(), 2 + ); + + CATCH_REQUIRE(index.get_alpha() == Approx(svs::VAMANA_ALPHA_MAXIMIZE_DEFAULT)); + } + + CATCH_SECTION("Invalid Alpha for L2") { + auto expected_result = test_dataset::vamana::expected_build_results( + svs::L2, svsbenchmark::Uncompressed(svs::DataType::float32) + ); + auto build_params = expected_result.build_parameters_.value(); + build_params.alpha = 0.8f; + auto data_loader = svs::data::SimpleData::load(data_path); + + // Get IDs for all points in the dataset + std::vector indices(data_loader.size()); + std::iota(indices.begin(), indices.end(), 0); + + CATCH_REQUIRE_THROWS_WITH( + svs::index::vamana::MutableVamanaIndex( + build_params, + std::move(data_loader), + indices, + svs::distance::DistanceL2(), + 2 + ), + "For L2 distance, alpha must be >= 1.0" + ); + } + + CATCH_SECTION("Invalid Alpha for MIP") { + auto expected_result = test_dataset::vamana::expected_build_results( + svs::MIP, svsbenchmark::Uncompressed(svs::DataType::float32) + ); + auto build_params = expected_result.build_parameters_.value(); + build_params.alpha = 1.2f; + auto data_loader = svs::data::SimpleData::load(data_path); + + // Get IDs for all points in the dataset + std::vector indices(data_loader.size()); + std::iota(indices.begin(), indices.end(), 0); + + CATCH_REQUIRE_THROWS_WITH( + svs::index::vamana::MutableVamanaIndex( + build_params, + std::move(data_loader), + indices, + svs::distance::DistanceIP(), + 2 + ), + "For MIP/Cosine distance, alpha must be <= 1.0" + ); + } + + CATCH_SECTION("Invalid prune_to > graph_max_degree") { + auto expected_result = test_dataset::vamana::expected_build_results( + svs::L2, svsbenchmark::Uncompressed(svs::DataType::float32) + ); + auto build_params = expected_result.build_parameters_.value(); + build_params.prune_to = build_params.graph_max_degree + 10; + auto data_loader = svs::data::SimpleData::load(data_path); + + // Get IDs for all points in the dataset + std::vector indices(data_loader.size()); + std::iota(indices.begin(), indices.end(), 0); + + CATCH_REQUIRE_THROWS_WITH( + svs::index::vamana::MutableVamanaIndex( + build_params, + std::move(data_loader), + indices, + svs::distance::DistanceL2(), + 2 + ), + "prune_to must be <= graph_max_degree" + ); + } + + CATCH_SECTION("L2 Distance Empty Params") { + svs::index::vamana::VamanaBuildParameters params; + std::vector data(32); + for (size_t i = 0; i < data.size(); i++) { + data[i] = static_cast(i + 1); + } + auto data_view = svs::data::SimpleDataView(data.data(), 8, 4); + std::vector indices = {0, 1, 2, 3, 4, 5, 6, 7}; + auto index = svs::index::vamana::MutableVamanaIndex( + params, std::move(data_view), indices, svs::distance::DistanceL2(), 1 + ); + CATCH_REQUIRE(index.get_alpha() == Approx(svs::VAMANA_ALPHA_MINIMIZE_DEFAULT)); + CATCH_REQUIRE(index.get_graph_max_degree() == svs::VAMANA_GRAPH_MAX_DEGREE_DEFAULT); + CATCH_REQUIRE(index.get_prune_to() == svs::VAMANA_GRAPH_MAX_DEGREE_DEFAULT - 4); + CATCH_REQUIRE( + index.get_construction_window_size() == svs::VAMANA_WINDOW_SIZE_DEFAULT + ); + CATCH_REQUIRE( + index.get_max_candidates() == 2 * svs::VAMANA_GRAPH_MAX_DEGREE_DEFAULT + ); + CATCH_REQUIRE( + index.get_full_search_history() == svs::VAMANA_USE_FULL_SEARCH_HISTORY_DEFAULT + ); + } } \ No newline at end of file diff --git a/tests/svs/index/vamana/index.cpp b/tests/svs/index/vamana/index.cpp index cd549299..6ceba9a1 100644 --- a/tests/svs/index/vamana/index.cpp +++ b/tests/svs/index/vamana/index.cpp @@ -16,12 +16,26 @@ // Header under test #include "svs/index/vamana/index.h" + +// Logging #include "spdlog/sinks/callback_sink.h" #include "svs/core/logging.h" +// svs +#include "svs/index/vamana/build_params.h" +#include "svs/lib/preprocessor.h" + // catch2 #include "catch2/catch_test_macros.hpp" +#include +// tests +#include "tests/utils/test_dataset.h" +#include "tests/utils/utils.h" +#include "tests/utils/vamana_reference.h" + +// svsbenchmark +#include "svs-benchmark/benchmark.h" // stl #include @@ -150,4 +164,86 @@ CATCH_TEST_CASE("Static VamanaIndex Per-Index Logging", "[logging]") { // Verify the internal log messages CATCH_REQUIRE(captured_logs[0].find("Number of syncs:") != std::string::npos); CATCH_REQUIRE(captured_logs[1].find("Batch Size:") != std::string::npos); +} + +CATCH_TEST_CASE("Vamana Index Default Parameters", "[parameter][vamana]") { + using Catch::Approx; + std::filesystem::path data_path = test_dataset::data_svs_file(); + + CATCH_SECTION("L2 Distance Defaults") { + auto expected_result = test_dataset::vamana::expected_build_results( + svs::L2, svsbenchmark::Uncompressed(svs::DataType::float32) + ); + auto build_params = expected_result.build_parameters_.value(); + auto data_loader = svs::data::SimpleData::load(data_path); + svs::Vamana index = svs::Vamana::build(build_params, data_loader, svs::L2); + CATCH_REQUIRE(index.get_alpha() == Approx(svs::VAMANA_ALPHA_MINIMIZE_DEFAULT)); + } + + CATCH_SECTION("MIP Distance Defaults") { + auto expected_result = test_dataset::vamana::expected_build_results( + svs::MIP, svsbenchmark::Uncompressed(svs::DataType::float32) + ); + auto build_params = expected_result.build_parameters_.value(); + auto data_loader = svs::data::SimpleData::load(data_path); + svs::Vamana index = svs::Vamana::build(build_params, data_loader, svs::MIP); + CATCH_REQUIRE(index.get_alpha() == Approx(svs::VAMANA_ALPHA_MAXIMIZE_DEFAULT)); + } + + CATCH_SECTION("Invalid Alpha for L2") { + auto expected_result = test_dataset::vamana::expected_build_results( + svs::L2, svsbenchmark::Uncompressed(svs::DataType::float32) + ); + auto build_params = expected_result.build_parameters_.value(); + build_params.alpha = 0.8f; + auto data_loader = svs::data::SimpleData::load(data_path); + CATCH_REQUIRE_THROWS_WITH( + svs::Vamana::build(build_params, data_loader, svs::L2), + "For L2 distance, alpha must be >= 1.0" + ); + } + + CATCH_SECTION("Invalid Alpha for MIP") { + auto expected_result = test_dataset::vamana::expected_build_results( + svs::MIP, svsbenchmark::Uncompressed(svs::DataType::float32) + ); + auto build_params = expected_result.build_parameters_.value(); + build_params.alpha = 1.2f; + auto data_loader = svs::data::SimpleData::load(data_path); + CATCH_REQUIRE_THROWS_WITH( + svs::Vamana::build(build_params, data_loader, svs::MIP), + "For MIP/Cosine distance, alpha must be <= 1.0" + ); + } + + CATCH_SECTION("Invalid prune_to > graph_max_degree") { + auto expected_result = test_dataset::vamana::expected_build_results( + svs::L2, svsbenchmark::Uncompressed(svs::DataType::float32) + ); + auto build_params = expected_result.build_parameters_.value(); + build_params.prune_to = build_params.graph_max_degree + 10; + auto data_loader = svs::data::SimpleData::load(data_path); + CATCH_REQUIRE_THROWS_WITH( + svs::Vamana::build(build_params, data_loader, svs::L2), + "prune_to must be <= graph_max_degree" + ); + } + + CATCH_SECTION("L2 Distance Empty Params") { + svs::index::vamana::VamanaBuildParameters empty_params; + auto data_loader = svs::data::SimpleData::load(data_path); + svs::Vamana index = svs::Vamana::build(empty_params, data_loader, svs::L2); + CATCH_REQUIRE(index.get_alpha() == Approx(svs::VAMANA_ALPHA_MINIMIZE_DEFAULT)); + CATCH_REQUIRE(index.get_graph_max_degree() == svs::VAMANA_GRAPH_MAX_DEGREE_DEFAULT); + CATCH_REQUIRE(index.get_prune_to() == svs::VAMANA_GRAPH_MAX_DEGREE_DEFAULT - 4); + CATCH_REQUIRE( + index.get_construction_window_size() == svs::VAMANA_WINDOW_SIZE_DEFAULT + ); + CATCH_REQUIRE( + index.get_max_candidates() == 2 * svs::VAMANA_GRAPH_MAX_DEGREE_DEFAULT + ); + CATCH_REQUIRE( + index.get_full_search_history() == svs::VAMANA_USE_FULL_SEARCH_HISTORY_DEFAULT + ); + } } \ No newline at end of file