diff --git a/src/hnsw/hnsw_const.rs b/src/hnsw/hnsw_const.rs index 2a76e21..630b3b8 100644 --- a/src/hnsw/hnsw_const.rs +++ b/src/hnsw/hnsw_const.rs @@ -64,6 +64,17 @@ where params, } } + + pub fn new_with_capacity(metric: Met, params: Params, capacity: usize) -> Self { + Self { + metric, + zero: Vec::with_capacity(capacity), + features: Vec::with_capacity(capacity), + layers: vec![], + prng: R::from_seed(R::Seed::default()), + params, + } + } } impl Knn for Hnsw @@ -374,12 +385,14 @@ where // See Algorithm 5 line 5 of the paper. The paper makes no further comment on why `1` was chosen. let &Neighbor { index, distance } = searcher.nearest.first().unwrap(); searcher.nearest.clear(); + searcher.seen.clear(); // Update the node to the next layer. let new_index = layer[index].next_node as usize; let candidate = Neighbor { index: new_index, distance, }; + searcher.seen.insert(layer[index].zero_node); // Insert the index of the nearest neighbor into the nearest pool for the next layer. searcher.nearest.push(candidate); // Insert the index into the candidate pool as well. diff --git a/tests/simple.rs b/tests/simple.rs index d51af36..edcaa00 100644 --- a/tests/simple.rs +++ b/tests/simple.rs @@ -1,6 +1,7 @@ //! Useful tests for debugging since they are hand-written and easy to see the debugging output. use hnsw::{Hnsw, Searcher}; +use itertools::Itertools; use rand_pcg::Pcg64; use space::{Metric, Neighbor}; @@ -18,13 +19,43 @@ impl Metric<&[f64]> for Euclidean { } } +struct TestBruteForceHelper { + vectors: Vec<(usize, Vec)>, +} + +impl TestBruteForceHelper { + fn new() -> Self { + Self { + vectors: Vec::new(), + } + } + + fn push(&mut self, v: (usize, Vec)) { + self.vectors.push(v); + } + + fn search(&self, query: &[f64], top_k: usize) -> Vec { + let metric = Euclidean; + let mut candidates: Vec<(usize, u64)> = self + .vectors + .iter() + .map(|v| (v.0.clone(), metric.distance(&query, &v.1.as_slice()))) + .collect_vec(); + + candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)); + + candidates.into_iter().take(top_k).map(|v| v.0).collect() + } +} + fn test_hnsw() -> ( Hnsw, Searcher, + TestBruteForceHelper, ) { let mut searcher = Searcher::default(); let mut hnsw = Hnsw::new(Euclidean); - + let mut helper = TestBruteForceHelper::new(); let features = [ &[0.0, 0.0, 0.0, 1.0], &[0.0, 0.0, 1.0, 0.0], @@ -36,11 +67,12 @@ fn test_hnsw() -> ( &[1.0, 0.0, 0.0, 1.0], ]; - for &feature in &features { - hnsw.insert(feature, &mut searcher); + for (index, feature) in features.iter().enumerate() { + helper.push((index, feature.to_vec())); + hnsw.insert(*feature, &mut searcher); } - (hnsw, searcher) + (hnsw, searcher, helper) } #[test] @@ -50,7 +82,7 @@ fn insertion() { #[test] fn nearest_neighbor() { - let (hnsw, mut searcher) = test_hnsw(); + let (hnsw, mut searcher, helper) = test_hnsw(); let searcher = &mut searcher; let mut neighbors = [Neighbor { index: !0, @@ -101,4 +133,17 @@ fn nearest_neighbor() { } ] ); + // test for not panicking + for topk in 0..8 { + let mut neighbors = vec![ + Neighbor { + index: !0, + distance: !0, + }; + topk + ]; + hnsw.nearest(&&[0.0, 0.0, 0.0, 1.0][..], 24, searcher, &mut neighbors); + let result = neighbors.iter().map(|item| item.index).collect_vec(); + assert_eq!(result, helper.search(&[0.0, 0.0, 0.0, 1.0], topk)); + } }