Skip to content

Commit

Permalink
Merge pull request #9 from Maxxen/v0.10.2-dev
Browse files Browse the repository at this point in the history
  • Loading branch information
Maxxen authored May 3, 2024
2 parents f50ffe2 + d8a80db commit 0d5e1d4
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 4 deletions.
22 changes: 20 additions & 2 deletions src/hnsw/hnsw_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,12 +299,25 @@ struct HNSWIndexScanState : public IndexScanState {
unique_array<row_t> row_ids = nullptr;
};

unique_ptr<IndexScanState> HNSWIndex::InitializeScan(float *query_vector, idx_t limit) {
unique_ptr<IndexScanState> HNSWIndex::InitializeScan(float *query_vector, idx_t limit, ClientContext &context) {
auto state = make_uniq<HNSWIndexScanState>();

// Try to get the ef_search parameter from the database or use the default value
auto ef_search = index.expansion_search();

Value hnsw_ef_search_opt;
if(context.TryGetCurrentSetting("hnsw_ef_search", hnsw_ef_search_opt)) {
if(!hnsw_ef_search_opt.IsNull() && hnsw_ef_search_opt.type() == LogicalType::BIGINT) {
auto val = hnsw_ef_search_opt.GetValue<int64_t>();
if(val > 0) {
ef_search = static_cast<idx_t>(val);
}
}
}

// Acquire a shared lock to search the index
auto lock = rwlock.GetSharedLock();
auto search_result = index.search(query_vector, limit);
auto search_result = index.ef_search(query_vector, limit, ef_search);

state->current_row = 0;
state->total_rows = search_result.size();
Expand Down Expand Up @@ -525,6 +538,11 @@ void HNSWModule::RegisterIndex(DatabaseInstance &db) {
return std::move(res);
};

// Register scan option
db.config.AddExtensionOption("hnsw_ef_search",
"experimental: override the ef_search parameter when scanning HNSW indexes",
LogicalType::BIGINT);

// Register the index type
db.config.GetIndexTypes().RegisterIndexType(index_type);
}
Expand Down
2 changes: 1 addition & 1 deletion src/hnsw/hnsw_index_scan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ static unique_ptr<GlobalTableFunctionState> HNSWIndexScanInitGlobal(ClientContex
local_storage.InitializeScan(bind_data.table.GetStorage(), result->local_storage_state.local_state, input.filters);

// Initialize the scan state for the index
result->index_state = bind_data.index.Cast<HNSWIndex>().InitializeScan(bind_data.query.get(), bind_data.limit);
result->index_state = bind_data.index.Cast<HNSWIndex>().InitializeScan(bind_data.query.get(), bind_data.limit, context);

return std::move(result);
}
Expand Down
2 changes: 1 addition & 1 deletion src/include/hnsw/hnsw_index.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class HNSWIndex : public Index {
//! The allocator used to persist linked blocks
unique_ptr<FixedSizeAllocator> linked_block_allocator;

unique_ptr<IndexScanState> InitializeScan(float *query_vector, idx_t limit);
unique_ptr<IndexScanState> InitializeScan(float *query_vector, idx_t limit, ClientContext &context);
idx_t Scan(IndexScanState &state, Vector &result);

idx_t GetVectorSize() const;
Expand Down
30 changes: 30 additions & 0 deletions src/include/usearch/index_dense.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,12 @@ class index_dense_gt {
search_result_t search(f32_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, thread, exact, casts_.from_f32); }
search_result_t search(f64_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, thread, exact, casts_.from_f64); }

search_result_t ef_search(b1x8_t const* vector, std::size_t wanted, std::size_t ef_search, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, thread, exact, casts_.from_b1x8, ef_search); }
search_result_t ef_search(i8_t const* vector, std::size_t wanted, std::size_t ef_search, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, thread, exact, casts_.from_i8, ef_search); }
search_result_t ef_search(f16_t const* vector, std::size_t wanted, std::size_t ef_search, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, thread, exact, casts_.from_f16, ef_search); }
search_result_t ef_search(f32_t const* vector, std::size_t wanted, std::size_t ef_search, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, thread, exact, casts_.from_f32, ef_search); }
search_result_t ef_search(f64_t const* vector, std::size_t wanted, std::size_t ef_search, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, thread, exact, casts_.from_f64, ef_search); }

std::size_t get(vector_key_t key, b1x8_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to_b1x8); }
std::size_t get(vector_key_t key, i8_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to_i8); }
std::size_t get(vector_key_t key, f16_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to_f16); }
Expand Down Expand Up @@ -1791,6 +1797,30 @@ class index_dense_gt {
return typed_->search(vector_data, wanted, metric_proxy_t{*this}, search_config, allow);
}

template <typename scalar_at>
search_result_t search_( //
scalar_at const* vector, std::size_t wanted, //
std::size_t thread, bool exact, cast_t const& cast, std::size_t ef_search) const {

// Cast the vector, if needed for compatibility with `metric_`
thread_lock_t lock = thread_lock_(thread);
byte_t const* vector_data = reinterpret_cast<byte_t const*>(vector);
{
byte_t* casted_data = cast_buffer_.data() + metric_.bytes_per_vector() * lock.thread_id;
bool casted = cast(vector_data, dimensions(), casted_data);
if (casted)
vector_data = casted_data;
}

index_search_config_t search_config;
search_config.thread = lock.thread_id;
search_config.expansion = ef_search;
search_config.exact = exact;

auto allow = [=](member_cref_t const& member) noexcept { return member.key != free_key_; };
return typed_->search(vector_data, wanted, metric_proxy_t{*this}, search_config, allow);
}

template <typename scalar_at>
cluster_result_t cluster_( //
scalar_at const* vector, std::size_t level, //
Expand Down

0 comments on commit 0d5e1d4

Please sign in to comment.