diff --git a/python_bindings/bindings.cpp b/python_bindings/bindings.cpp index 5153bb58..d095e7ca 100644 --- a/python_bindings/bindings.cpp +++ b/python_bindings/bindings.cpp @@ -684,6 +684,62 @@ class Index { free_when_done_d)); } + py::object knnQuery_return_lists( + py::object input, + size_t k = 1, + int num_threads = -1, + const std::function& filter = nullptr) { + + py::array_t items(input); + auto buffer = items.request(); + size_t rows, features; + + if (num_threads <= 0) + num_threads = num_threads_default; + + get_input_array_shapes(buffer, &rows, &features); + + if (rows <= num_threads * 4) { + num_threads = 1; + } + + std::vector> data_lists_l(rows); + std::vector> data_lists_d(rows); + + CustomFilterFunctor idFilter(filter); + CustomFilterFunctor* p_idFilter = filter ? &idFilter : nullptr; + + auto process_row = [&](size_t row) { + std::priority_queue> result = + appr_alg->searchKnn((void*)items.data(row), k, p_idFilter); + size_t result_size = result.size(); + data_lists_l[row].resize(result_size); + data_lists_d[row].resize(result_size); + for (int i = result_size - 1; i >= 0; i--) { + auto& result_tuple = result.top(); + data_lists_d[row][i] = result_tuple.first; + data_lists_l[row][i] = result_tuple.second; + result.pop(); + } + }; + + if (normalize == false) { + ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) { + process_row(row); + }); + } else { + std::vector norm_array(num_threads * features); + ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) { + float* data = (float*)items.data(row); + size_t start_idx = threadId * dim; + normalize_vector((float*)items.data(row), (norm_array.data() + start_idx)); + process_row(row); + }); + } + + return py::make_tuple(data_lists_l, data_lists_d); + } + void markDeleted(size_t label) { appr_alg->markDelete(label); @@ -894,6 +950,12 @@ PYBIND11_PLUGIN(hnswlib) { py::arg("k") = 1, py::arg("num_threads") = -1, py::arg("filter") = py::none()) + .def("knn_query_return_lists", + 
import os
import random
import unittest

import numpy as np

import hnswlib


class RandomSelfTestCase(unittest.TestCase):
    """Checks that ``knn_query_return_lists`` returns per-query Python lists."""

    def testRandomSelf(self):
        """Query a small index through a randomised filter and check the result shape."""
        dim = 16
        num_elements = 100

        # Generating sample data
        data = np.float32(np.random.random((num_elements, dim)))

        # Declaring index
        hnsw_index = hnswlib.Index(space='l2', dim=dim)  # possible options are l2, cosine or ip
        bf_index = hnswlib.BFIndex(space='l2', dim=dim)

        # Initiating index
        # max_elements - the maximum number of elements, should be known beforehand
        #     (probably will be made optional in the future)
        #
        # ef_construction - controls index search speed/build speed tradeoff
        # M - is tightly connected with internal dimensionality of the data
        #     strongly affects the memory consumption

        hnsw_index.init_index(max_elements=num_elements, ef_construction=100, M=16)
        bf_index.init_index(max_elements=num_elements)

        # Controlling the recall by setting ef:
        # higher ef leads to better accuracy, but slower search
        hnsw_index.set_ef(10)

        hnsw_index.set_num_threads(4)  # by default using all available cores

        print("Adding %d elements" % (len(data)))
        hnsw_index.add_items(data)
        bf_index.add_items(data)

        # The filter flips a coin on every call: it admits either the last two
        # labels or only the last one, so different queries can end up with a
        # different number of results.
        def filter_function(index):
            if random.random() < 0.5:
                return index > (num_elements - 3)
            return index >= (num_elements - 1)

        labels, distances = hnsw_index.knn_query_return_lists(
            data, k=2, num_threads=1, filter=filter_function)

        # Both outputs must be plain Python lists with one entry per query row.
        self.assertIsInstance(labels, list)
        self.assertIsInstance(distances, list)
        self.assertEqual(len(labels), num_elements)
        self.assertEqual(len(distances), num_elements)

        # Every query returns one or two hits, and its distances line up
        # one-to-one with its labels.
        for row_labels, row_distances in zip(labels, distances):
            self.assertTrue(1 <= len(row_labels) <= 2)
            self.assertEqual(len(row_distances), len(row_labels))