Skip to content

Commit

Permalink
Adding filtered search. API is going to change still.
Browse files Browse the repository at this point in the history
  • Loading branch information
daxpryce committed Nov 3, 2023
1 parent 58bcc3c commit 9709e66
Show file tree
Hide file tree
Showing 11 changed files with 192 additions and 27 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "diskannpy"
version = "0.6.1"
version = "0.7.0"

description = "DiskANN Python extension module"
readme = "python/README.md"
Expand Down
5 changes: 3 additions & 2 deletions python/include/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ template <typename DT, typename TagT = DynamicIdType, typename LabelT = filterT>
void build_memory_index(diskann::Metric metric, const std::string &vector_bin_path,
const std::string &index_output_path, uint32_t graph_degree, uint32_t complexity,
float alpha, uint32_t num_threads, bool use_pq_build,
size_t num_pq_bytes, bool use_opq, uint32_t filter_complexity,
bool use_tags = false);
size_t num_pq_bytes, bool use_opq, bool use_tags = false,
const std::string& filter_labels_file = "", const std::string& universal_label = "",
uint32_t filter_complexity = 0);

}
4 changes: 4 additions & 0 deletions python/include/static_memory_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ template <typename DT> class StaticMemoryIndex
NeighborsAndDistances<StaticIdType> search(py::array_t<DT, py::array::c_style | py::array::forcecast> &query,
uint64_t knn, uint64_t complexity);

// Single-query search that takes the same query / knn / complexity arguments as search(),
// plus a filter label restricting the search.
// NOTE(review): `filter` looks like the internal integer label id (the Python wrapper passes a
// value taken from its labels map, or 0 for the universal label) - confirm against filterT.
NeighborsAndDistances<StaticIdType> search_with_filter(
py::array_t<DT, py::array::c_style | py::array::forcecast> &query, uint64_t knn, uint64_t complexity,
filterT filter);

NeighborsAndDistances<StaticIdType> batch_search(
py::array_t<DT, py::array::c_style | py::array::forcecast> &queries, uint64_t num_queries, uint64_t knn,
uint64_t complexity, uint32_t num_threads);
Expand Down
8 changes: 6 additions & 2 deletions python/src/_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,10 @@ def build_memory_index(
num_pq_bytes: int = defaults.NUM_PQ_BYTES,
use_opq: bool = defaults.USE_OPQ,
vector_dtype: Optional[VectorDType] = None,
filter_complexity: int = defaults.FILTER_COMPLEXITY,
tags: Union[str, VectorIdentifierBatch] = "",
filter_labels_file: str = "",
universal_label: int = 0,
filter_complexity: int = defaults.FILTER_COMPLEXITY,
index_prefix: str = "ann",
) -> None:
"""
Expand Down Expand Up @@ -299,8 +301,10 @@ def build_memory_index(
use_pq_build=use_pq_build,
num_pq_bytes=num_pq_bytes,
use_opq=use_opq,
filter_complexity=filter_complexity,
use_tags=use_tags,
filter_labels_file=filter_labels_file,
universal_label="0", # a string that must be a number for some inexplicable reason.
filter_complexity=filter_complexity,
)

_write_index_metadata(
Expand Down
16 changes: 8 additions & 8 deletions python/src/_builder.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@ def build_memory_index(
use_pq_build: bool,
num_pq_bytes: int,
use_opq: bool,
label_file: str,
universal_label: str,
tags: Union[str, VectorIdentifierBatch],
filter_labels_file: str,
universal_label: int,
filter_complexity: int,
tags: Optional[VectorIdentifierBatch],
index_prefix: str,
index_prefix: str
) -> None: ...
@overload
def build_memory_index(
Expand All @@ -66,9 +66,9 @@ def build_memory_index(
num_pq_bytes: int,
use_opq: bool,
vector_dtype: VectorDType,
label_file: str,
universal_label: str,
tags: Union[str, VectorIdentifierBatch],
filter_labels_file: str,
universal_label: int,
filter_complexity: int,
tags: Optional[str],
index_prefix: str,
index_prefix: str
) -> None: ...
44 changes: 40 additions & 4 deletions python/src/_static_memory_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import os
import warnings
from typing import Optional
from typing import Hashable, Optional

import numpy as np

Expand Down Expand Up @@ -43,6 +43,8 @@ def __init__(
distance_metric: Optional[DistanceMetric] = None,
vector_dtype: Optional[VectorDType] = None,
dimensions: Optional[int] = None,
enable_filters: bool = False,
universal_label: str = ""
):
"""
### Parameters
Expand Down Expand Up @@ -73,8 +75,20 @@ def __init__(
- **dimensions**: The vector dimensionality of this index. All new vectors inserted must be the same
dimensionality. **This value is only used if a `{index_prefix}_metadata.bin` file does not exist.** If it
does not exist, you are required to provide it.
- **enable_filters**: Indexes built with filters can also be used for filtered search.
"""
index_prefix = _valid_index_prefix(index_directory, index_prefix)
self._labels_map = {}
self._universal_label = universal_label
if enable_filters:
try:
with open(index_prefix + "_labels_map.txt", "r") as labels_map_if:
for line in labels_map_if:
(key, val) = line.split("\t")
self._labels_map[key] = int(val)
except: # noqa: E722
# exceptions are basically presumed to be either file not found or file not formatted correctly
raise RuntimeException("Filter labels file was unable to be processed.")
vector_dtype, metric, num_points, dims = _ensure_index_metadata(
index_prefix,
vector_dtype,
Expand Down Expand Up @@ -109,7 +123,7 @@ def __init__(
)

def search(
self, query: VectorLike, k_neighbors: int, complexity: int
self, query: VectorLike, k_neighbors: int, complexity: int, filter_label: str = ""
) -> QueryResponse:
"""
Searches the index by a single query vector.
Expand All @@ -121,13 +135,24 @@ def search(
- **complexity**: Size of distance ordered list of candidate neighbors to use while searching. List size
increases accuracy at the cost of latency. Must be at least k_neighbors in size.
"""
if filter_label != "":
if len(self._labels_map) == 0:
raise ValueError(
f"A filter label of {filter_label} was provided, but this class was not initialized with filters "
"enabled, e.g. StaticDiskMemory(..., enable_filters=True)"
)
if filter_label != self._universal_label and filter_label not in self._labels_map:
raise ValueError(
f"A filter label of {filter_label} was provided, but the external(str)->internal(np.uint32) labels map "
f"does not include that label."
)
_query = _castable_dtype_or_raise(query, expected=self._vector_dtype)
_assert(len(_query.shape) == 1, "query vector must be 1-d")
_assert(
_query.shape[0] == self._dimensions,
f"query vector must have the same dimensionality as the index; index dimensionality: {self._dimensions}, "
f"query dimensionality: {_query.shape[0]}",
)
)
_assert_is_positive_uint32(k_neighbors, "k_neighbors")
_assert_is_nonnegative_uint32(complexity, "complexity")

Expand All @@ -136,9 +161,20 @@ def search(
f"k_neighbors={k_neighbors} asked for, but list_size={complexity} was smaller. Increasing {complexity} to {k_neighbors}"
)
complexity = k_neighbors
neighbors, distances = self._index.search(query=_query, knn=k_neighbors, complexity=complexity)

if filter_label == "":
neighbors, distances = self._index.search(query=_query, knn=k_neighbors, complexity=complexity)
else:
filter = 0 if filter_label == self._universal_label else self._labels_map[filter_label]
neighbors, distances = self._index.search_with_filter(
query=query,
knn=k_neighbors,
complexity=complexity,
filter=filter
)
return QueryResponse(identifiers=neighbors, distances=distances)


def batch_search(
self,
queries: VectorLikeBatch,
Expand Down
62 changes: 55 additions & 7 deletions python/src/builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,47 @@ template void build_disk_index<uint8_t>(diskann::Metric, const std::string &, co
template void build_disk_index<int8_t>(diskann::Metric, const std::string &, const std::string &, uint32_t, uint32_t,
double, double, uint32_t, uint32_t);

// Prepares the label files needed for a filtered build and returns the path of the
// integer-formatted labels file that should be handed to the index builder.
//
// Two paths are derived from index_output_path:
//   <index_output_path>_label_formatted.txt  - passed as the output labels file
//   <index_output_path>_labels_map.txt       - passed as the string->int map file
// Both are given to convert_labels_string_to_int together with the caller's raw labels file.
//
// When universal_label is non-empty the index is told that the universal label id is 0.
// NOTE(review): this assumes convert_labels_string_to_int always maps the universal label to 0;
// the numeric value is not derived from universal_label here - confirm.
template <typename T, typename TagT, typename LabelT>
std::string prepare_filtered_label_map(diskann::Index<T, TagT, LabelT> &index, const std::string &index_output_path,
                                       const std::string &filter_labels_file, const std::string &universal_label)
{
    std::string formatted_labels_path = index_output_path + "_label_formatted.txt";
    std::string labels_map_path = index_output_path + "_labels_map.txt";

    convert_labels_string_to_int(filter_labels_file, formatted_labels_path, labels_map_path, universal_label);

    // Nothing more to configure when no universal label was requested.
    if (universal_label.empty())
    {
        return formatted_labels_path;
    }

    uint32_t universal_label_id = 0;
    index.set_universal_label(universal_label_id);
    return formatted_labels_path;
}

// Explicit instantiations of prepare_filtered_label_map for the float, int8_t and uint8_t
// vector types. Only T is spelled out; TagT and LabelT are both uint32_t, as fixed by the
// diskann::Index<..., uint32_t, uint32_t> parameter type.
template std::string prepare_filtered_label_map<float>(diskann::Index<float, uint32_t, uint32_t>&,
const std::string&,
const std::string&,
const std::string&);

template std::string prepare_filtered_label_map<int8_t>(diskann::Index<int8_t, uint32_t, uint32_t>&,
const std::string&,
const std::string&,
const std::string&);

template std::string prepare_filtered_label_map<uint8_t>(diskann::Index<uint8_t, uint32_t, uint32_t>&,
const std::string&,
const std::string&,
const std::string&);

template <typename T, typename TagT, typename LabelT>
void build_memory_index(const diskann::Metric metric, const std::string &vector_bin_path,
const std::string &index_output_path, const uint32_t graph_degree, const uint32_t complexity,
const float alpha, const uint32_t num_threads, const bool use_pq_build,
const size_t num_pq_bytes, const bool use_opq, const uint32_t filter_complexity,
const bool use_tags)
const size_t num_pq_bytes, const bool use_opq, const bool use_tags,
const std::string& filter_labels_file, const std::string& universal_label,
const uint32_t filter_complexity)
{
diskann::IndexWriteParameters index_build_params = diskann::IndexWriteParametersBuilder(complexity, graph_degree)
.with_filter_list_size(filter_complexity)
Expand Down Expand Up @@ -65,23 +100,36 @@ void build_memory_index(const diskann::Metric metric, const std::string &vector_
size_t tag_dims = 1;
diskann::load_bin(tags_file, tags_data, data_num, tag_dims);
std::vector<TagT> tags(tags_data, tags_data + data_num);
index.build(vector_bin_path.c_str(), data_num, tags);
if (filter_labels_file.empty()) {
index.build(vector_bin_path.c_str(), data_num, tags);
} else {
auto labels_file = prepare_filtered_label_map<T, TagT, LabelT>(index, index_output_path, filter_labels_file, universal_label);
index.build_filtered_index(vector_bin_path.c_str(), labels_file, data_num, tags);
}
}
else
{
index.build(vector_bin_path.c_str(), data_num);
if (filter_labels_file.empty()) {
index.build(vector_bin_path.c_str(), data_num);
} else {
auto labels_file = prepare_filtered_label_map<T, TagT, LabelT>(index, index_output_path, filter_labels_file, universal_label);
index.build_filtered_index(vector_bin_path.c_str(), labels_file, data_num);
}
}

index.save(index_output_path.c_str());
}

template void build_memory_index<float>(diskann::Metric, const std::string &, const std::string &, uint32_t, uint32_t,
float, uint32_t, bool, size_t, bool, uint32_t, bool);
float, uint32_t, bool, size_t, bool,bool, const std::string&,
const std::string&, uint32_t);

template void build_memory_index<int8_t>(diskann::Metric, const std::string &, const std::string &, uint32_t, uint32_t,
float, uint32_t, bool, size_t, bool, uint32_t, bool);
float, uint32_t, bool, size_t, bool, bool, const std::string&,
const std::string&, uint32_t);

template void build_memory_index<uint8_t>(diskann::Metric, const std::string &, const std::string &, uint32_t, uint32_t,
float, uint32_t, bool, size_t, bool, uint32_t, bool);
float, uint32_t, bool, size_t, bool, bool, const std::string&,
const std::string&, uint32_t);

} // namespace diskannpy
1 change: 0 additions & 1 deletion python/src/diskann_bindings.cpp

This file was deleted.

5 changes: 4 additions & 1 deletion python/src/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,17 @@ template <typename T> inline void add_variant(py::module_ &m, const Variant &var

m.def(variant.memory_builder_name.c_str(), &diskannpy::build_memory_index<T>, "distance_metric"_a,
"data_file_path"_a, "index_output_path"_a, "graph_degree"_a, "complexity"_a, "alpha"_a, "num_threads"_a,
"use_pq_build"_a, "num_pq_bytes"_a, "use_opq"_a, "filter_complexity"_a = 0, "use_tags"_a = false);
"use_pq_build"_a, "num_pq_bytes"_a, "use_opq"_a, "use_tags"_a = false,
"filter_labels_file"_a = "", "universal_label"_a = "", "filter_complexity"_a = 0);

py::class_<diskannpy::StaticMemoryIndex<T>>(m, variant.static_memory_index_name.c_str())
.def(py::init<const diskann::Metric, const std::string &, const size_t, const size_t, const uint32_t,
const uint32_t>(),
"distance_metric"_a, "index_path"_a, "num_points"_a, "dimensions"_a, "num_threads"_a,
"initial_search_complexity"_a)
.def("search", &diskannpy::StaticMemoryIndex<T>::search, "query"_a, "knn"_a, "complexity"_a)
.def("search_with_filter", &diskannpy::StaticMemoryIndex<T>::search_with_filter, "query"_a, "knn"_a,
"complexity"_a, "filter"_a)
.def("batch_search", &diskannpy::StaticMemoryIndex<T>::batch_search, "queries"_a, "num_queries"_a, "knn"_a,
"complexity"_a, "num_threads"_a);

Expand Down
13 changes: 13 additions & 0 deletions python/src/static_memory_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,19 @@ NeighborsAndDistances<StaticIdType> StaticMemoryIndex<DT>::search(
return std::make_pair(ids, dists);
}

// Runs a single filtered query against the in-memory index.
//
// query      - query vector; its raw data pointer is forwarded to the underlying index
// knn        - number of result slots allocated for ids and distances
// complexity - search list size forwarded to the underlying index
// filter     - label value forwarded to the underlying index's search_with_filters
//
// Returns a pair of arrays (ids, distances), each allocated with knn entries.
// NOTE(review): both arrays are always sized knn; if the filtered search can yield fewer than
// knn matches, the trailing slots are whatever the underlying index leaves there - confirm.
template <typename DT>
NeighborsAndDistances<StaticIdType> StaticMemoryIndex<DT>::search_with_filter(
    py::array_t<DT, py::array::c_style | py::array::forcecast> &query, const uint64_t knn, const uint64_t complexity,
    const filterT filter)
{
    py::array_t<StaticIdType> ids(knn);
    py::array_t<float> dists(knn);
    _index.search_with_filters(query.data(), filter, knn, complexity, ids.mutable_data(), dists.mutable_data());
    return std::make_pair(ids, dists);
}

template <typename DT>
NeighborsAndDistances<StaticIdType> StaticMemoryIndex<DT>::batch_search(
py::array_t<DT, py::array::c_style | py::array::forcecast> &queries, const uint64_t num_queries, const uint64_t knn,
Expand Down
59 changes: 58 additions & 1 deletion python/tests/test_static_memory_index.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import os
import shutil
import unittest

from tempfile import mkdtemp

import diskannpy as dap
import numpy as np
from fixtures import build_random_vectors_and_memory_index, calculate_recall
from fixtures import random_vectors
from sklearn.neighbors import NearestNeighbors


Expand Down Expand Up @@ -185,4 +189,57 @@ def test_zero_threads(self):
)

k = 5
ids, dists = index.batch_search(query_vectors, k_neighbors=k, complexity=5, num_threads=0)
ids, dists = index.batch_search(query_vectors, k_neighbors=k, complexity=5, num_threads=0)


class TestFilteredStaticMemoryIndex(unittest.TestCase):
    """Label-filtered search against a StaticMemoryIndex built from a labels file."""

    def test_common(self):
        """Build a labelled index and check each filtered search only returns matching ids."""
        vectors: np.ndarray = random_vectors(10000, 10, dtype=np.float32, seed=54321)
        query_vectors: np.ndarray = random_vectors(10, 10, dtype=np.float32)
        temp = mkdtemp()
        try:
            labels_path = os.path.join(temp, "labels.txt")
            # One line per vector: a parity label, plus "evenly_by_three" on every third row.
            with open(labels_path, "w") as labels_out:
                for row in range(vectors.shape[0]):
                    row_labels = ["evensies" if row % 2 == 0 else "oddsies"]
                    if row % 3 == 0:
                        row_labels.append("evenly_by_three")
                    print(",".join(row_labels), file=labels_out)

            dap.build_memory_index(
                data=vectors,
                distance_metric="l2",
                index_directory=temp,
                complexity=64,
                graph_degree=32,
                num_threads=16,
                filter_labels_file=labels_path,
                universal_label="all",
                filter_complexity=128,
            )
            index = dap.StaticMemoryIndex(
                index_directory=temp,
                num_threads=16,
                initial_search_complexity=64,
                enable_filters=True,
                universal_label="all",
            )

            k = 50
            query = query_vectors[0]

            # Unfiltered search over twice as many neighbors, used as a reference set below.
            probable_superset, _ = index.search(query, k_neighbors=k * 2, complexity=128)

            response = index.search(query, k_neighbors=k, complexity=64, filter_label="evensies")
            self.assertIsInstance(response, dap.QueryResponse)
            ids_1, _ = response
            self.assertTrue(all(neighbor % 2 == 0 for neighbor in ids_1))

            ids_2, _ = index.search(query, k_neighbors=k, complexity=64, filter_label="oddsies")
            self.assertTrue(all(neighbor % 2 != 0 for neighbor in ids_2))
            self.assertTrue(np.intersect1d(ids_1, ids_2).shape[0] == 0)

            ids_3, _ = index.search(query, k_neighbors=k, complexity=64, filter_label="evenly_by_three")
            self.assertTrue(all(neighbor % 3 == 0 for neighbor in ids_3))

            in_superset = np.intersect1d(probable_superset, np.append(ids_1, ids_2)).shape[0]
            # NOTE(review): `in_superset / k * 2` parses as `(in_superset / k) * 2`, so this passes
            # once roughly a quarter of the 2*k reference ids are found. If the intent was
            # `in_superset / (k * 2)`, the 0.98 threshold needs revisiting too - confirm.
            self.assertTrue(in_superset / k * 2 > 0.98)
        finally:
            shutil.rmtree(temp, ignore_errors=True)

0 comments on commit 9709e66

Please sign in to comment.