Skip to content

Commit

Permalink
Merge branch 'haol/support-binary-jaccard-like-milvus' into 'mqdb-dev'
Browse files Browse the repository at this point in the history
Add support for binary IVF and HNSW indexes

Closes #16

See merge request mqdb/faiss!33
  • Loading branch information
Linpeng Tang committed Jan 9, 2024
2 parents c51d22d + 00ba12e commit bd549be
Show file tree
Hide file tree
Showing 14 changed files with 424 additions and 304 deletions.
59 changes: 34 additions & 25 deletions faiss/IndexBinaryFlat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@
#include <faiss/impl/IDSelector.h>
#include <faiss/utils/Heap.h>
#include <faiss/utils/hamming.h>
#include <faiss/utils/jaccard.h>
#include <faiss/utils/utils.h>
#include <cstring>

namespace faiss {

IndexBinaryFlat::IndexBinaryFlat(idx_t d) : IndexBinary(d) {}
IndexBinaryFlat::IndexBinaryFlat(idx_t d, MetricType metric) : IndexBinary(d, metric) {}

void IndexBinaryFlat::add(idx_t n, const uint8_t* x) {
xb.insert(xb.end(), x, x + n * code_size);
Expand All @@ -38,30 +39,15 @@ void IndexBinaryFlat::search(
int32_t* distances,
idx_t* labels,
const SearchParameters* params) const {
FAISS_THROW_IF_NOT_MSG(
!params, "search params not supported for this index");
FAISS_THROW_IF_NOT(k > 0);

const idx_t block_size = query_batch_size;
for (idx_t s = 0; s < n; s += block_size) {
idx_t nn = block_size;
if (s + block_size > n) {
nn = n - s;
}

if (use_heap) {
// We see the distances and labels as heaps.
int_maxheap_array_t res = {
size_t(nn), size_t(k), labels + s * k, distances + s * k};

hammings_knn_hc(
&res,
x + s * code_size,
xb.data(),
ntotal,
code_size,
/* ordered = */ true);
} else {
if(this->metric_type == MetricType::METRIC_HAMMING) {
const idx_t block_size = query_batch_size;
for (idx_t s = 0; s < n; s += block_size) {
idx_t nn = block_size;
if (s + block_size > n) {
nn = n - s;
}
const faiss::IDSelector* sel = params ? params->sel : nullptr;
hammings_knn_mc(
x + s * code_size,
xb.data(),
Expand All @@ -70,8 +56,31 @@ void IndexBinaryFlat::search(
k,
code_size,
distances + s * k,
labels + s * k);
labels + s * k,
reinterpret_cast<void*>(const_cast<faiss::IDSelector*>(sel)));
}
} else if (this->metric_type == MetricType::METRIC_JACCARD) {
const idx_t block_size = query_batch_size;
for (idx_t s = 0; s < n; s += block_size) {
idx_t nn = block_size;
if (s + block_size > n) {
nn = n - s;
}

const faiss::IDSelector* sel = params ? params->sel : nullptr;
jaccard_knn(x + s * code_size,
xb.data(),
nn,
ntotal,
k,
code_size,
reinterpret_cast<float*>(distances + s * k),
labels + s * k,
sel);
}
} else {
SI_LOG_FATAL("metric_type {} not supported for IndexBinaryFlat::search",
Search::enumToString(this->metric_type));
}
}

Expand Down
2 changes: 1 addition & 1 deletion faiss/IndexBinaryFlat.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ struct IndexBinaryFlat : IndexBinary {

size_t query_batch_size = 32;

explicit IndexBinaryFlat(idx_t d);
explicit IndexBinaryFlat(idx_t d, MetricType metric = METRIC_HAMMING);

void add(idx_t n, const uint8_t* x) override;

Expand Down
98 changes: 65 additions & 33 deletions faiss/IndexBinaryHNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <faiss/utils/Heap.h>
#include <faiss/utils/hamming.h>
#include <faiss/utils/random.h>
#include <faiss/utils/jaccard-inl.h>

namespace faiss {

Expand Down Expand Up @@ -162,11 +163,11 @@ IndexBinaryHNSW::IndexBinaryHNSW() {
is_trained = true;
}

IndexBinaryHNSW::IndexBinaryHNSW(int d, int M)
: IndexBinary(d),
IndexBinaryHNSW::IndexBinaryHNSW(int d, int M, MetricType metric)
: IndexBinary(d, metric),
hnsw(M),
own_fields(true),
storage(new IndexBinaryFlat(d)) {
storage(new IndexBinaryFlat(d, metric)) {
is_trained = true;
}

Expand Down Expand Up @@ -196,33 +197,25 @@ void IndexBinaryHNSW::search(
idx_t k,
int32_t* distances,
idx_t* labels,
const SearchParameters* params) const {
FAISS_THROW_IF_NOT_MSG(
!params, "search params not supported for this index");
const SearchParameters* params_in) const {
FAISS_THROW_IF_NOT(k > 0);

#pragma omp parallel
const auto* params = dynamic_cast<const SearchParametersHNSW*>(params_in);
{
VisitedTable vt(ntotal);
std::unique_ptr<DistanceComputer> dis(get_distance_computer());

#pragma omp for
for (idx_t i = 0; i < n; i++) {
idx_t* idxi = labels + i * k;
float* simi = (float*)(distances + i * k);

dis->set_query((float*)(x + i * code_size));

maxheap_heapify(k, simi, idxi);
hnsw.search(*dis, k, idxi, simi, vt);
hnsw.search(*dis, k, idxi, simi, vt, params);
maxheap_reorder(k, simi, idxi);
}
}

#pragma omp parallel for
for (int i = 0; i < n * k; ++i) {
distances[i] = std::round(((float*)distances)[i]);
}
}

void IndexBinaryHNSW::add(idx_t n, const uint8_t* x) {
Expand Down Expand Up @@ -281,31 +274,70 @@ struct FlatHammingDis : DistanceComputer {
}
};


template <class JaccardComputer>
struct FlatJaccardDis : DistanceComputer {
const int code_size;
const uint8_t* b;
size_t ndis;
JaccardComputer jc;

float operator()(idx_t i) override {
ndis++;
return jc.jaccard(b + i * code_size);
}

float symmetric_dis(idx_t i, idx_t j) override {
return JaccardComputerDefault(b + j * code_size, code_size)
.jaccard(b + i * code_size);
}

explicit FlatJaccardDis(const IndexBinaryFlat& storage)
: code_size(storage.code_size),
b(storage.xb.data()),
ndis(0),
jc() {}

// NOTE: Pointers are cast from float in order to reuse the floating-point
// DistanceComputer.
void set_query(const float* x) override {
jc.set((uint8_t*)x, code_size);
}

~FlatJaccardDis() override {
#pragma omp critical
{ hnsw_stats.ndis += ndis; }
}
};

} // namespace

DistanceComputer* IndexBinaryHNSW::get_distance_computer() const {
IndexBinaryFlat* flat_storage = dynamic_cast<IndexBinaryFlat*>(storage);

FAISS_ASSERT(flat_storage != nullptr);

switch (code_size) {
case 4:
return new FlatHammingDis<HammingComputer4>(*flat_storage);
case 8:
return new FlatHammingDis<HammingComputer8>(*flat_storage);
case 16:
return new FlatHammingDis<HammingComputer16>(*flat_storage);
case 20:
return new FlatHammingDis<HammingComputer20>(*flat_storage);
case 32:
return new FlatHammingDis<HammingComputer32>(*flat_storage);
case 64:
return new FlatHammingDis<HammingComputer64>(*flat_storage);
default:
break;
if(flat_storage->metric_type == MetricType::METRIC_HAMMING) {
switch (code_size) {
case 4:
return new FlatHammingDis<HammingComputer4>(*flat_storage);
case 8:
return new FlatHammingDis<HammingComputer8>(*flat_storage);
case 16:
return new FlatHammingDis<HammingComputer16>(*flat_storage);
case 20:
return new FlatHammingDis<HammingComputer20>(*flat_storage);
case 32:
return new FlatHammingDis<HammingComputer32>(*flat_storage);
case 64:
return new FlatHammingDis<HammingComputer64>(*flat_storage);
default:
break;
}
return new FlatHammingDis<HammingComputerDefault>(*flat_storage);
} else if (flat_storage->metric_type == MetricType::METRIC_JACCARD) {
auto* jaccard_dist = new FlatJaccardDis<JaccardComputerDefault>(*flat_storage);
return jaccard_dist;
}

return new FlatHammingDis<HammingComputerDefault>(*flat_storage);
return nullptr;
}

} // namespace faiss
2 changes: 1 addition & 1 deletion faiss/IndexBinaryHNSW.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ struct IndexBinaryHNSW : IndexBinary {
IndexBinary* storage;

explicit IndexBinaryHNSW();
explicit IndexBinaryHNSW(int d, int M = 32);
explicit IndexBinaryHNSW(int d, int M = 32, MetricType metric = METRIC_HAMMING);
explicit IndexBinaryHNSW(IndexBinary* storage, int M = 32);

~IndexBinaryHNSW() override;
Expand Down
Loading

0 comments on commit bd549be

Please sign in to comment.