Skip to content

Commit

Permalink
Add support of multi vector in jni
Browse files Browse the repository at this point in the history
Signed-off-by: Heemin Kim <[email protected]>
  • Loading branch information
heemin32 committed Dec 28, 2023
1 parent 7b47bae commit feca6da
Show file tree
Hide file tree
Showing 16 changed files with 604 additions and 18 deletions.
22 changes: 19 additions & 3 deletions jni/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,21 @@ if (${CONFIG_FAISS} STREQUAL ON OR ${CONFIG_ALL} STREQUAL ON OR ${CONFIG_TEST} S
set(FAISS_ENABLE_PYTHON OFF)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/external/faiss EXCLUDE_FROM_ALL)

add_library(${TARGET_LIB_FAISS} SHARED ${CMAKE_CURRENT_SOURCE_DIR}/src/org_opensearch_knn_jni_FaissService.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/faiss_wrapper.cpp)
# Build the Faiss JNI shared library from the JNI service entry points, the Faiss wrapper,
# and the knn_extension sources (BitSet, MultiVectorResultCollector and its factory).
add_library(
${TARGET_LIB_FAISS} SHARED
${CMAKE_CURRENT_SOURCE_DIR}/src/org_opensearch_knn_jni_FaissService.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/faiss_wrapper.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/knn_extension/faiss/utils/BitSet.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/knn_extension/faiss/MultiVectorResultCollector.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/knn_extension/faiss/MultiVectorResultCollectorFactory.cpp)
target_link_libraries(${TARGET_LIB_FAISS} faiss ${TARGET_LIB_COMMON} OpenMP::OpenMP_CXX)
target_include_directories(${TARGET_LIB_FAISS} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include $ENV{JAVA_HOME}/include $ENV{JAVA_HOME}/include/${JVM_OS_TYPE} ${CMAKE_CURRENT_SOURCE_DIR}/external/faiss)
# Private header search path for the Faiss JNI library: project headers (including the
# knn_extension/faiss and knn_extension/faiss/utils subdirectories), the JNI headers found
# under JAVA_HOME, and the bundled faiss sources.
target_include_directories(${TARGET_LIB_FAISS} PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/include
${CMAKE_CURRENT_SOURCE_DIR}/include/knn_extension/faiss
${CMAKE_CURRENT_SOURCE_DIR}/include/knn_extension/faiss/utils
$ENV{JAVA_HOME}/include
$ENV{JAVA_HOME}/include/${JVM_OS_TYPE}
${CMAKE_CURRENT_SOURCE_DIR}/external/faiss)
set_target_properties(${TARGET_LIB_FAISS} PROPERTIES SUFFIX ${LIB_EXT})
set_target_properties(${TARGET_LIB_FAISS} PROPERTIES POSITION_INDEPENDENT_CODE ON)

Expand Down Expand Up @@ -198,7 +210,11 @@ if ("${WIN32}" STREQUAL "")
jni_test
tests/faiss_wrapper_test.cpp
tests/nmslib_wrapper_test.cpp
tests/test_util.cpp)
tests/test_util.cpp
tests/knn_extension/faiss/utils/BitSetTest.cpp
tests/knn_extension/faiss/MultiVectorResultCollectorTest.cpp
tests/knn_extension/faiss/MultiVectorResultCollectorFactoryTest.cpp
)

target_link_libraries(
jni_test
Expand Down
5 changes: 2 additions & 3 deletions jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
#define OPENSEARCH_KNN_FAISS_WRAPPER_H

#include "jni_util.h"

#include <jni.h>

namespace knn_jni {
Expand All @@ -38,13 +37,13 @@ namespace knn_jni {
//
// Return an array of KNNQueryResults
jobjectArray QueryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jfloatArray queryVectorJ, jint kJ);
jfloatArray queryVectorJ, jint kJ, jintArray parentIdsJ);

// Execute a query against the index located in memory at indexPointerJ along with Filters
//
// Return an array of KNNQueryResults
jobjectArray QueryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jfloatArray queryVectorJ, jint kJ, jintArray filterIdsJ);
jfloatArray queryVectorJ, jint kJ, jintArray filterIdsJ, jintArray parentIdsJ);

// Free the index located in memory at indexPointerJ
void Free(jlong indexPointer);
Expand Down
26 changes: 26 additions & 0 deletions jni/include/knn_extension/faiss/MultiVectorResultCollector.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#pragma once

#include <faiss/impl/ResultCollector.h>
#include <faiss/MetricType.h>
#include "knn_extension/faiss/utils/BitSet.h"

namespace os_faiss {

using idx_t = faiss::idx_t;
// Result collector that specializes faiss::ResultCollector by overriding collect() and
// post_process(). It carries two lookup tables keyed by a "group id", a parent BitSet and
// an optional-looking id map; how they are combined is defined in the .cpp, not here.
struct MultiVectorResultCollector:faiss::ResultCollector {
// NOTE(review): presumably maps a group id to the id of the entry currently kept for
// that group — confirm against MultiVectorResultCollector.cpp.
std::unordered_map<idx_t, idx_t> group_id_to_id;
// NOTE(review): presumably maps a group id to its position in the result heap
// (the helpers in utils/Heap.h maintain a map with this name and shape) — confirm.
std::unordered_map<idx_t, size_t> group_id_to_index;
// Raw pointer; ownership and lifetime are not expressed in this header.
BitSet* parent_bit_set;
// Raw pointer to a vector of int64_t ids; ownership and lifetime are not expressed here.
// NOTE(review): whether nullptr is accepted cannot be told from this declaration.
const std::vector<int64_t>* id_map;
// Stores nothing visible here; the definition lives in the .cpp.
MultiVectorResultCollector(BitSet* parent_bit_set, const std::vector<int64_t>* id_map);
// Override of faiss::ResultCollector::collect.
// NOTE(review): the names suggest a heap of capacity k (bh_val / bh_ids) holding nres
// entries, to which the candidate (val, ids) is offered — confirm against the faiss header.
void collect(
int k,
int& nres,
float* bh_val,
int64_t* bh_ids,
float val,
int64_t ids) override;
// Override of faiss::ResultCollector::post_process, given the result count and the id array.
// NOTE(review): presumably rewrites bh_ids in place after the search — confirm in the .cpp.
void post_process(int64_t nres, int64_t* bh_ids) override;
};

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#pragma once

#include <faiss/impl/ResultCollectorFactory.h>
#include "knn_extension/faiss/utils/BitSet.h"

namespace os_faiss {
// Factory that specializes faiss::ResultCollectorFactory by overriding new_collector() and
// delete_collector(). It holds the parent BitSet pointer it was constructed with.
struct MultiVectorResultCollectorFactory:faiss::ResultCollectorFactory {
// Raw pointer; ownership and lifetime are not expressed in this header.
BitSet* parent_bit_set;

// NOTE(review): this single-argument constructor is not `explicit`, so a BitSet* converts
// implicitly to a factory; consider `explicit` if no caller relies on that conversion.
MultiVectorResultCollectorFactory(BitSet* parent_bit_set);
// Returns a collector as a raw faiss::ResultCollector*.
// NOTE(review): presumably the result must be handed back to delete_collector() — confirm.
faiss::ResultCollector* new_collector() override;
// Counterpart of new_collector(); receives the collector pointer to dispose of.
void delete_collector(faiss::ResultCollector* resultCollector) override;
};
}
27 changes: 27 additions & 0 deletions jni/include/knn_extension/faiss/utils/BitSet.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#pragma once

#include <faiss/MetricType.h>
#include <faiss/impl/platform_macros.h>
#include <limits>

using idx_t = faiss::idx_t;
// Abstract bit-set interface: concrete implementations provide nextSetBit().
class BitSet {
protected:
// Equal to the largest int value.
// NOTE(review): presumably the sentinel nextSetBit() returns when no further bit is set —
// confirm in BitSet.cpp. It is a non-static const member, so each instance stores a copy.
const int NO_MORE_DOCS = std::numeric_limits<int>::max();
public:
// Pure virtual; implemented by subclasses.
// NOTE(review): presumably returns the position of the first set bit at or after `index` — confirm.
virtual idx_t nextSetBit(idx_t index) = 0;
// Virtual so that derived bit sets can be destroyed through a BitSet pointer.
virtual ~BitSet() = default;
};


// BitSet implementation backed by an array of 64-bit words.
//
// The object holds a raw pointer to its word array and declares its own destructor, so it
// must not be copied member-wise: the copy operations are deleted below.
class FixedBitSet : public BitSet {
public:
    // NOTE(review): presumably the number of 64-bit words in `bitmap` — confirm in BitSet.cpp.
    size_t n;
    // using uint64_t to leverage function __builtin_ctzll which is defined in faiss/impl/platform_macros.h
    uint64_t* bitmap;

    // Builds the bit set from `length` ints starting at `intArray`.
    // NOTE(review): presumably each int is the position of a bit to set — confirm in BitSet.cpp.
    FixedBitSet(const int* intArray, const int length);

    // Non-copyable. The user-declared destructor presumably releases `bitmap` (it is defined
    // in BitSet.cpp); an implicit copy would leave two objects pointing at the same buffer.
    FixedBitSet(const FixedBitSet&) = delete;
    FixedBitSet& operator=(const FixedBitSet&) = delete;

    idx_t nextSetBit(idx_t index) override;
    ~FixedBitSet();
};
178 changes: 178 additions & 0 deletions jni/include/knn_extension/faiss/utils/Heap.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
#pragma once

#include <climits>
#include <cmath>
#include <cstring>

#include <stdint.h>
#include <cassert>
#include <cstdio>

#include <limits>
#include <unordered_map>
#include <faiss/utils/ordered_key_value.h>

// Collection of heap operations
namespace os_faiss {

// Sifts an entry toward the root of a binary heap stored in bh_val / bh_ids, starting at
// array slot `start_index`, and keeps the two group lookup tables in sync.
//
//   k                  not used by this routine
//   bh_val, bh_ids     heap arrays; bh_ids holds group ids
//   val, group_id      the entry being placed (compared through C::cmp2)
//   id                 recorded as (*group_id_to_id)[group_id] once the entry is placed
//   group_id_to_index  updated with the 0-based slot of every entry that moves
//   start_index        0-based slot from which the entry starts climbing
template <class C>
inline void up_heap(
        size_t k,
        typename C::T* bh_val,
        typename C::TI* bh_ids,
        typename C::T val,
        typename C::TI id,
        std::unordered_map<typename C::TI, typename C::TI>* group_id_to_id,
        std::unordered_map<typename C::TI, size_t>* group_id_to_index,
        typename C::TI group_id,
        size_t start_index) {
    // Positions are 1-based so that the parent of `node` is `node / 2`;
    // position p lives at array offset p - 1.
    size_t node = start_index + 1;

    while (node > 1) {
        const size_t parent = node >> 1;
        const size_t node_slot = node - 1;
        const size_t parent_slot = parent - 1;

        if (!C::cmp2(val, bh_val[parent_slot], group_id, bh_ids[parent_slot])) {
            // The parent already outranks the new entry: nothing more to move.
            break;
        }

        // Pull the parent down one level and remember where its group now sits.
        bh_val[node_slot] = bh_val[parent_slot];
        bh_ids[node_slot] = bh_ids[parent_slot];
        (*group_id_to_index)[bh_ids[node_slot]] = node_slot;
        node = parent;
    }

    const size_t final_slot = node - 1;
    bh_val[final_slot] = val;
    bh_ids[final_slot] = group_id;
    (*group_id_to_id)[group_id] = id;
    (*group_id_to_index)[group_id] = final_slot;
}

// Sifts an entry away from the root of a k-element binary heap stored in bh_val / bh_ids,
// starting at array slot `start_index`, and keeps the two group lookup tables in sync.
//
//   k                  number of slots in the heap
//   bh_val, bh_ids     heap arrays; bh_ids holds group ids
//   val, group_id      the entry being placed (compared through C::cmp2)
//   id                 recorded as (*group_id_to_id)[group_id] once the entry is placed
//   group_id_to_index  updated with the 0-based slot of every entry that moves
//   start_index        0-based slot from which the entry starts sinking
template <class C>
inline void down_heap(
        size_t k,
        typename C::T* bh_val,
        typename C::TI* bh_ids,
        typename C::T val,
        typename C::TI id,
        std::unordered_map<typename C::TI, typename C::TI>* group_id_to_id,
        std::unordered_map<typename C::TI, size_t>* group_id_to_index,
        typename C::TI group_id,
        size_t start_index) {
    // Positions are 1-based so that the children of `node` are `2 * node` and `2 * node + 1`;
    // position p lives at array offset p - 1.
    size_t node = start_index + 1;

    for (;;) {
        const size_t left = node << 1;
        const size_t right = left + 1;
        if (left > k) {
            // No children: the entry has reached a leaf.
            break;
        }

        // Note that C::cmp2() is a bool function answering
        // `(a1 > b1) || ((a1 == b1) && (a2 > b2))` for max
        // heap and same with the `<` sign for min heap.
        // The left child is chosen when it is the only child or when it outranks the right one.
        const bool take_left = (right == k + 1) ||
                C::cmp2(bh_val[left - 1], bh_val[right - 1], bh_ids[left - 1], bh_ids[right - 1]);
        const size_t child = take_left ? left : right;

        if (C::cmp2(val, bh_val[child - 1], group_id, bh_ids[child - 1])) {
            // The entry outranks the chosen child: it stays at `node`.
            break;
        }

        // Lift the chosen child one level and remember where its group now sits.
        bh_val[node - 1] = bh_val[child - 1];
        bh_ids[node - 1] = bh_ids[child - 1];
        (*group_id_to_index)[bh_ids[node - 1]] = node - 1;
        node = child;
    }

    const size_t final_slot = node - 1;
    bh_val[final_slot] = val;
    bh_ids[final_slot] = group_id;
    (*group_id_to_id)[group_id] = id;
    (*group_id_to_index)[group_id] = final_slot;
}

template <typename T>
inline void maxheap_push(
size_t k,
T* bh_val,
int64_t* bh_ids,
T val,
int64_t id,
std::unordered_map<int64_t, int64_t>* group_id_to_id,
std::unordered_map<int64_t, size_t>* group_id_to_index,
int64_t group_id) {
up_heap<faiss::CMax<T, int64_t>>(
k,
bh_val,
bh_ids,
val,
id,
group_id_to_id,
group_id_to_index,
group_id,
k - 1);
}

template <typename T>
inline void maxheap_replace_top(
size_t k,
T* bh_val,
int64_t* bh_ids,
T val,
int64_t id,
std::unordered_map<int64_t, int64_t>* group_id_to_id,
std::unordered_map<int64_t, size_t>* group_id_to_index,
int64_t group_id) {
group_id_to_id->erase(bh_ids[0]);
group_id_to_index->erase(bh_ids[0]);
down_heap<faiss::CMax<T, int64_t>>(
k,
bh_val,
bh_ids,
val,
id,
group_id_to_id,
group_id_to_index,
group_id,
0);
}

template <typename T>
inline void maxheap_update(
size_t k,
T* bh_val,
int64_t* bh_ids,
T val,
int64_t id,
std::unordered_map<int64_t, int64_t>* group_id_to_id,
std::unordered_map<int64_t, size_t>* group_id_to_index,
int64_t group_id) {
size_t start_index = group_id_to_index->at(group_id);
up_heap<faiss::CMax<T, int64_t>>(
k,
bh_val,
bh_ids,
val,
id,
group_id_to_id,
group_id_to_index,
group_id,
start_index);
down_heap<faiss::CMax<T, int64_t>>(
k,
bh_val,
bh_ids,
val,
id,
group_id_to_id,
group_id_to_index,
group_id,
start_index);
}
}
4 changes: 2 additions & 2 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,15 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndex
* Signature: (J[FI[I)[Lorg/opensearch/knn/index/query/KNNQueryResult;
*/
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndex
(JNIEnv *, jclass, jlong, jfloatArray, jint);
(JNIEnv *, jclass, jlong, jfloatArray, jint, jintArray);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: queryIndexWithFilter
* Signature: (J[FI[I[I)[Lorg/opensearch/knn/index/query/KNNQueryResult;
*/
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndexWithFilter
(JNIEnv *, jclass, jlong, jfloatArray, jint, jintArray);
(JNIEnv *, jclass, jlong, jfloatArray, jint, jintArray, jintArray);

/*
* Class: org_opensearch_knn_jni_FaissService
Expand Down
Loading

0 comments on commit feca6da

Please sign in to comment.