forked from opensearch-project/k-NN
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Heemin Kim <[email protected]>
- Loading branch information
Showing
16 changed files
with
604 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
26 changes: 26 additions & 0 deletions
26
jni/include/knn_extension/faiss/MultiVectorResultCollector.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
#pragma once | ||
|
||
#include <faiss/impl/ResultCollector.h> | ||
#include <faiss/MetricType.h> | ||
#include "knn_extension/faiss/utils/BitSet.h" | ||
|
||
namespace os_faiss { | ||
|
||
using idx_t = faiss::idx_t; | ||
struct MultiVectorResultCollector:faiss::ResultCollector { | ||
std::unordered_map<idx_t, idx_t> group_id_to_id; | ||
std::unordered_map<idx_t, size_t> group_id_to_index; | ||
BitSet* parent_bit_set; | ||
const std::vector<int64_t>* id_map; | ||
MultiVectorResultCollector(BitSet* parent_bit_set, const std::vector<int64_t>* id_map); | ||
void collect( | ||
int k, | ||
int& nres, | ||
float* bh_val, | ||
int64_t* bh_ids, | ||
float val, | ||
int64_t ids) override; | ||
void post_process(int64_t nres, int64_t* bh_ids) override; | ||
}; | ||
|
||
} |
14 changes: 14 additions & 0 deletions
14
jni/include/knn_extension/faiss/MultiVectorResultCollectorFactory.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
#pragma once | ||
|
||
#include <faiss/impl/ResultCollectorFactory.h> | ||
#include "knn_extension/faiss/utils/BitSet.h" | ||
|
||
namespace os_faiss { | ||
struct MultiVectorResultCollectorFactory:faiss::ResultCollectorFactory { | ||
BitSet* parent_bit_set; | ||
|
||
MultiVectorResultCollectorFactory(BitSet* parent_bit_set); | ||
faiss::ResultCollector* new_collector() override; | ||
void delete_collector(faiss::ResultCollector* resultCollector) override; | ||
}; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
#pragma once | ||
|
||
#include <faiss/MetricType.h> | ||
#include <faiss/impl/platform_macros.h> | ||
#include <limits> | ||
|
||
using idx_t = faiss::idx_t; | ||
class BitSet { | ||
protected: | ||
const int NO_MORE_DOCS = std::numeric_limits<int>::max(); | ||
public: | ||
virtual idx_t nextSetBit(idx_t index) = 0; | ||
virtual ~BitSet() = default; | ||
}; | ||
|
||
|
||
class FixedBitSet : public BitSet { | ||
public: | ||
size_t n; | ||
// using uint64_t to leverage function __builtin_ctzll which is defined in faiss/impl/platform_macros.h | ||
uint64_t* bitmap; | ||
|
||
public: | ||
FixedBitSet(const int* intArray, const int length); | ||
idx_t nextSetBit(idx_t index) override; | ||
~FixedBitSet(); | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
#pragma once | ||
|
||
#include <climits> | ||
#include <cmath> | ||
#include <cstring> | ||
|
||
#include <stdint.h> | ||
#include <cassert> | ||
#include <cstdio> | ||
|
||
#include <limits> | ||
#include <unordered_map> | ||
#include <faiss/utils/ordered_key_value.h> | ||
|
||
// Collection of heap operations | ||
namespace os_faiss { | ||
|
||
template <class C> | ||
inline void up_heap( | ||
size_t k, | ||
typename C::T* bh_val, | ||
typename C::TI* bh_ids, | ||
typename C::T val, | ||
typename C::TI id, | ||
std::unordered_map<typename C::TI, typename C::TI>* group_id_to_id, | ||
std::unordered_map<typename C::TI, size_t>* group_id_to_index, | ||
typename C::TI group_id, | ||
size_t start_index) { | ||
bh_val--; /* Use 1-based indexing for easier node->child translation */ | ||
bh_ids--; | ||
size_t i = start_index + 1, i_father; | ||
|
||
while (i > 1) { | ||
i_father = i >> 1; | ||
if (!C::cmp2(val, bh_val[i_father], group_id, bh_ids[i_father])) { | ||
/* the heap structure is ok */ | ||
break; | ||
} | ||
bh_val[i] = bh_val[i_father]; | ||
bh_ids[i] = bh_ids[i_father]; | ||
(*group_id_to_index)[bh_ids[i]] = i - 1; | ||
i = i_father; | ||
} | ||
bh_val[i] = val; | ||
bh_ids[i] = group_id; | ||
(*group_id_to_id)[group_id] = id; | ||
(*group_id_to_index)[group_id] = i - 1; | ||
} | ||
|
||
template <class C> | ||
inline void down_heap( | ||
size_t k, | ||
typename C::T* bh_val, | ||
typename C::TI* bh_ids, | ||
typename C::T val, | ||
typename C::TI id, | ||
std::unordered_map<typename C::TI, typename C::TI>* group_id_to_id, | ||
std::unordered_map<typename C::TI, size_t>* group_id_to_index, | ||
typename C::TI group_id, | ||
size_t start_index) { | ||
bh_val--; /* Use 1-based indexing for easier node->child translation */ | ||
bh_ids--; | ||
size_t i = start_index + 1, i1, i2; | ||
|
||
while (1) { | ||
i1 = i << 1; | ||
i2 = i1 + 1; | ||
if (i1 > k) { | ||
break; | ||
} | ||
|
||
// Note that C::cmp2() is a bool function answering | ||
// `(a1 > b1) || ((a1 == b1) && (a2 > b2))` for max | ||
// heap and same with the `<` sign for min heap. | ||
if ((i2 == k + 1) || | ||
C::cmp2(bh_val[i1], bh_val[i2], bh_ids[i1], bh_ids[i2])) { | ||
if (C::cmp2(val, bh_val[i1], group_id, bh_ids[i1])) { | ||
break; | ||
} | ||
bh_val[i] = bh_val[i1]; | ||
bh_ids[i] = bh_ids[i1]; | ||
(*group_id_to_index)[bh_ids[i]] = i - 1; | ||
i = i1; | ||
} else { | ||
if (C::cmp2(val, bh_val[i2], group_id, bh_ids[i2])) { | ||
break; | ||
} | ||
bh_val[i] = bh_val[i2]; | ||
bh_ids[i] = bh_ids[i2]; | ||
(*group_id_to_index)[bh_ids[i]] = i - 1; | ||
i = i2; | ||
} | ||
} | ||
bh_val[i] = val; | ||
bh_ids[i] = group_id; | ||
(*group_id_to_id)[group_id] = id; | ||
(*group_id_to_index)[group_id] = i - 1; | ||
} | ||
|
||
template <typename T> | ||
inline void maxheap_push( | ||
size_t k, | ||
T* bh_val, | ||
int64_t* bh_ids, | ||
T val, | ||
int64_t id, | ||
std::unordered_map<int64_t, int64_t>* group_id_to_id, | ||
std::unordered_map<int64_t, size_t>* group_id_to_index, | ||
int64_t group_id) { | ||
up_heap<faiss::CMax<T, int64_t>>( | ||
k, | ||
bh_val, | ||
bh_ids, | ||
val, | ||
id, | ||
group_id_to_id, | ||
group_id_to_index, | ||
group_id, | ||
k - 1); | ||
} | ||
|
||
template <typename T> | ||
inline void maxheap_replace_top( | ||
size_t k, | ||
T* bh_val, | ||
int64_t* bh_ids, | ||
T val, | ||
int64_t id, | ||
std::unordered_map<int64_t, int64_t>* group_id_to_id, | ||
std::unordered_map<int64_t, size_t>* group_id_to_index, | ||
int64_t group_id) { | ||
group_id_to_id->erase(bh_ids[0]); | ||
group_id_to_index->erase(bh_ids[0]); | ||
down_heap<faiss::CMax<T, int64_t>>( | ||
k, | ||
bh_val, | ||
bh_ids, | ||
val, | ||
id, | ||
group_id_to_id, | ||
group_id_to_index, | ||
group_id, | ||
0); | ||
} | ||
|
||
template <typename T> | ||
inline void maxheap_update( | ||
size_t k, | ||
T* bh_val, | ||
int64_t* bh_ids, | ||
T val, | ||
int64_t id, | ||
std::unordered_map<int64_t, int64_t>* group_id_to_id, | ||
std::unordered_map<int64_t, size_t>* group_id_to_index, | ||
int64_t group_id) { | ||
size_t start_index = group_id_to_index->at(group_id); | ||
up_heap<faiss::CMax<T, int64_t>>( | ||
k, | ||
bh_val, | ||
bh_ids, | ||
val, | ||
id, | ||
group_id_to_id, | ||
group_id_to_index, | ||
group_id, | ||
start_index); | ||
down_heap<faiss::CMax<T, int64_t>>( | ||
k, | ||
bh_val, | ||
bh_ids, | ||
val, | ||
id, | ||
group_id_to_id, | ||
group_id_to_index, | ||
group_id, | ||
start_index); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.