Skip to content

Commit

Permalink
Add patch to support multi vector in faiss
Browse files Browse the repository at this point in the history
Signed-off-by: Heemin Kim <[email protected]>
  • Loading branch information
heemin32 committed Dec 20, 2023
1 parent 2e3ab95 commit 41f7d58
Show file tree
Hide file tree
Showing 4 changed files with 511 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Upgrade urllib to 1.26.18 [#1319](https://github.com/opensearch-project/k-NN/pull/1319)
* Upgrade guava to 32.1.3 [#1319](https://github.com/opensearch-project/k-NN/pull/1319)
### Refactoring
* Add patch to support multi vector in faiss [#1358](https://github.com/opensearch-project/k-NN/pull/1358)
12 changes: 9 additions & 3 deletions jni/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ list(APPEND TARGET_LIBS ${TARGET_LIB_COMMON})
# ---------------------------------- NMSLIB ----------------------------------
if (${CONFIG_NMSLIB} STREQUAL ON OR ${CONFIG_ALL} STREQUAL ON OR ${CONFIG_TEST} STREQUAL ON)
# Check if nmslib exists
find_path(NMS_REPO_DIR NAMES similarity_search PATHS ${CMAKE_CURRENT_SOURCE_DIR}/external/nmslib)
find_path(NMS_REPO_DIR NAMES similarity_search PATHS ${CMAKE_CURRENT_SOURCE_DIR}/external/nmslib NO_DEFAULT_PATH)

# If not, pull the updated submodule
if (NOT EXISTS ${NMS_REPO_DIR})
Expand Down Expand Up @@ -134,14 +134,20 @@ if (${CONFIG_FAISS} STREQUAL ON OR ${CONFIG_ALL} STREQUAL ON OR ${CONFIG_TEST} S
find_package(LAPACK REQUIRED)

# Check if faiss exists
find_path(FAISS_REPO_DIR NAMES faiss PATHS ${CMAKE_CURRENT_SOURCE_DIR}/external/faiss)
find_path(FAISS_REPO_DIR NAMES faiss PATHS ${CMAKE_CURRENT_SOURCE_DIR}/external/faiss NO_DEFAULT_PATH)

# If not, pull the updated submodule
# If not, pull the updated submodule and apply patches
if (NOT EXISTS ${FAISS_REPO_DIR})
message(STATUS "Could not find faiss. Pulling updated submodule.")
execute_process(COMMAND git submodule update --init -- external/faiss WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
endif ()

message(STATUS "Applying custom patches.")
execute_process(COMMAND git apply --ignore-space-change --3way ${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss/0001-Custom-patch-to-support-multi-vector.patch WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/external/faiss ERROR_VARIABLE ERROR_MSG RESULT_VARIABLE RESULT_CODE)
if(RESULT_CODE)
message(FATAL_ERROR "Failed to apply patch:\n${ERROR_MSG}")
endif()

set(FAISS_ENABLE_GPU OFF)
set(FAISS_ENABLE_PYTHON OFF)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/external/faiss EXCLUDE_FROM_ALL)
Expand Down
221 changes: 221 additions & 0 deletions jni/patches/0001-Custom-patch-to-support-multi-vector.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
From 864c1abe7bdced5d306e871ea2bd73e1e35987fd Mon Sep 17 00:00:00 2001
From: Heemin Kim <[email protected]>
Date: Wed, 6 Dec 2023 16:33:52 -0800
Subject: [PATCH] Custom patch to support multi-vector

Signed-off-by: Heemin Kim <[email protected]>
---
faiss/CMakeLists.txt | 2 +
faiss/Index.h | 6 ++-
faiss/impl/HNSW.cpp | 25 ++++++++-----
faiss/impl/ResultCollector.h | 58 +++++++++++++++++++++++++++++
faiss/impl/ResultCollectorFactory.h | 28 ++++++++++++++
5 files changed, 107 insertions(+), 12 deletions(-)
create mode 100644 faiss/impl/ResultCollector.h
create mode 100644 faiss/impl/ResultCollectorFactory.h

diff --git a/faiss/CMakeLists.txt b/faiss/CMakeLists.txt
index 27701586..af682a05 100644
--- a/faiss/CMakeLists.txt
+++ b/faiss/CMakeLists.txt
@@ -162,6 +162,8 @@ set(FAISS_HEADERS
impl/ProductQuantizer.h
impl/Quantizer.h
impl/ResidualQuantizer.h
+ impl/ResultCollector.h
+ impl/ResultCollectorFactory.h
impl/ResultHandler.h
impl/ScalarQuantizer.h
impl/ThreadedIndex-inl.h
diff --git a/faiss/Index.h b/faiss/Index.h
index 4b4b302b..13eab0c0 100644
--- a/faiss/Index.h
+++ b/faiss/Index.h
@@ -38,11 +38,12 @@

namespace faiss {

-/// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h and
-/// impl/DistanceComputer.h
+/// Forward declarations see impl/AuxIndexStructures.h, impl/IDSelector.h,
+/// impl/DistanceComputer.h, and impl/ResultCollectorFactory.h
struct IDSelector;
struct RangeSearchResult;
struct DistanceComputer;
+struct ResultCollectorFactory;

/** Parent class for the optional search paramenters.
*
@@ -52,6 +53,7 @@ struct DistanceComputer;
struct SearchParameters {
/// if non-null, only these IDs will be considered during search.
IDSelector* sel = nullptr;
+ ResultCollectorFactory* col = nullptr;
/// make sure we can dynamic_cast this
virtual ~SearchParameters() {}
};
diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp
index 9fc201ea..540210a6 100644
--- a/faiss/impl/HNSW.cpp
+++ b/faiss/impl/HNSW.cpp
@@ -14,6 +14,7 @@
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/DistanceComputer.h>
#include <faiss/impl/IDSelector.h>
+#include <faiss/impl/ResultCollectorFactory.h>
#include <faiss/utils/prefetch.h>

#include <faiss/impl/platform_macros.h>
@@ -530,6 +531,15 @@ int search_from_candidates(
int level,
int nres_in = 0,
const SearchParametersHNSW* params = nullptr) {
+ ResultCollectorFactory defaultFactory;
+ ResultCollectorFactory* collectorFactory;
+ if (params == nullptr || params->col == nullptr) {
+ collectorFactory = &defaultFactory;
+ } else {
+ collectorFactory = params->col;
+ }
+ ResultCollector* collector = collectorFactory->newCollector();
+
int nres = nres_in;
int ndis = 0;

@@ -544,11 +554,7 @@ int search_from_candidates(
float d = candidates.dis[i];
FAISS_ASSERT(v1 >= 0);
if (!sel || sel->is_member(v1)) {
- if (nres < k) {
- faiss::maxheap_push(++nres, D, I, d, v1);
- } else if (d < D[0]) {
- faiss::maxheap_replace_top(nres, D, I, d, v1);
- }
+ collector->collect(k, nres, D, I, d, v1);
}
vt.set(v1);
}
@@ -612,11 +618,7 @@ int search_from_candidates(

auto add_to_heap = [&](const size_t idx, const float dis) {
if (!sel || sel->is_member(idx)) {
- if (nres < k) {
- faiss::maxheap_push(++nres, D, I, dis, idx);
- } else if (dis < D[0]) {
- faiss::maxheap_replace_top(nres, D, I, dis, idx);
- }
+ collector->collect(k, nres, D, I, dis, idx);
}
candidates.push(idx, dis);
};
@@ -660,6 +662,9 @@ int search_from_candidates(
}
}

+ collector->finalize(nres, I);
+ collectorFactory->deleteCollector(collector);
+
if (level == 0) {
stats.n1++;
if (candidates.size() == 0) {
diff --git a/faiss/impl/ResultCollector.h b/faiss/impl/ResultCollector.h
new file mode 100644
index 00000000..3e4dac34
--- /dev/null
+++ b/faiss/impl/ResultCollector.h
@@ -0,0 +1,58 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <unordered_set>
+#include <vector>
+
+#include <faiss/MetricType.h>
+#include <faiss/utils/Heap.h>
+
+/** ResultCollector is intended to define how to collect search result */
+
+namespace faiss {
+
+/** Encapsulates a set of ids to handle. */
+struct ResultCollector {
+ // For each result, collect method is called to store result
+ virtual void collect(
+ int k,
+ int& nres,
+ float* bh_val,
+ idx_t* bh_ids,
+ float val,
+ idx_t ids) = 0;
+
+ // This method is called after all result is collected
+ virtual void finalize(idx_t nres, idx_t* bh_ids) = 0;
+ virtual ~ResultCollector() {}
+};
+
+struct DefaultCollector : ResultCollector {
+ void collect(
+ int k,
+ int& nres,
+ float* bh_val,
+ idx_t* bh_ids,
+ float val,
+ idx_t ids) override {
+ if (nres < k) {
+ faiss::maxheap_push(++nres, bh_val, bh_ids, val, ids);
+ } else if (val < bh_val[0]) {
+ faiss::maxheap_replace_top(nres, bh_val, bh_ids, val, ids);
+ }
+ }
+
+ void finalize(idx_t nres, idx_t* bh_ids) override {
+ // Do nothing
+ }
+
+ ~DefaultCollector() override {}
+};
+
+} // namespace faiss
diff --git a/faiss/impl/ResultCollectorFactory.h b/faiss/impl/ResultCollectorFactory.h
new file mode 100644
index 00000000..6a15208a
--- /dev/null
+++ b/faiss/impl/ResultCollectorFactory.h
@@ -0,0 +1,28 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+#include <faiss/impl/ResultCollector.h>
+namespace faiss {
+
+/** ResultCollector is intended to define how to collect search result */
+struct ResultCollectorFactory {
+ DefaultCollector defaultCollector;
+
+ // For each result, collect method is called to store result
+ virtual ResultCollector* newCollector() {
+ return &defaultCollector;
+ }
+
+ virtual void deleteCollector(ResultCollector* collector) {
+ // Do nothing
+ }
+ // This method is called after all result is collected
+ virtual ~ResultCollectorFactory() {}
+};
+
+} // namespace faiss
--
2.39.3 (Apple Git-145)

Loading

0 comments on commit 41f7d58

Please sign in to comment.