From 218ea36ad42bd7bc2af2e5894241071a15df24d8 Mon Sep 17 00:00:00 2001 From: Aaron Silvas Date: Sat, 30 Sep 2023 15:30:53 -0700 Subject: [PATCH] AddWithIds & toIDMap support --- README.md | 4 ++ lib/index.d.ts | 4 ++ src/faiss.cc | 114 +++++++++++++++++++++++++++++++++++++++++++++ test/Index.test.js | 22 +++++++++ 4 files changed, 144 insertions(+) diff --git a/README.md b/README.md index 9efc0b7..16b30e9 100644 --- a/README.md +++ b/README.md @@ -84,6 +84,10 @@ const hnswIndex = Index.fromFactory(2, 'HNSW,Flat', MetricType.METRIC_INNER_PROD const x = [1, 0, 0, 1]; hnswIndex.train(x); hnswIndex.add(x); + +// IDMap'd index +const idIndex = new IndexFlat(2).toIDMap(); +idIndex.addWithIds([1, 0, 0, 1], [100n, 200n]); ``` ## License diff --git a/lib/index.d.ts b/lib/index.d.ts index 0bca395..aec8e6d 100644 --- a/lib/index.d.ts +++ b/lib/index.d.ts @@ -76,6 +76,10 @@ export class Index { * Write index to buffer. */ toBuffer(): Buffer; + /** + * Create an IDMap'd index from source index. + */ + toIDMap(): Index; /** * Read index from a file. * @param {string} fname File path to read. diff --git a/src/faiss.cc b/src/faiss.cc index f32b6b7..890bceb 100644 --- a/src/faiss.cc +++ b/src/faiss.cc @@ -9,6 +9,7 @@ #include #include #include +#include using namespace Napi; using idx_t = faiss::idx_t; @@ -190,6 +191,86 @@ class IndexBase : public Napi::ObjectWrap return env.Undefined(); } + Napi::Value addWithIds(const Napi::CallbackInfo &info) + { + Napi::Env env = info.Env(); + + if (info.Length() != 2) + { + Napi::Error::New(env, "Expected 2 arguments, but got " + std::to_string(info.Length()) + ".") + .ThrowAsJavaScriptException(); + return env.Undefined(); + } + if (!info[0].IsArray()) + { + Napi::TypeError::New(env, "Invalid the first argument type, must be an Array.").ThrowAsJavaScriptException(); + return env.Undefined(); + } + if (!info[1].IsArray()) + { + Napi::TypeError::New(env, "Invalid the second argument type, must be an Array.").ThrowAsJavaScriptException(); + return env.Undefined(); + } + + Napi::Array arr = info[0].As(); + size_t length = arr.Length(); + auto dv = std::div(length, index_->d); + if (dv.rem != 0) + { + Napi::Error::New(env, "Invalid the given array length.") + .ThrowAsJavaScriptException(); + return env.Undefined(); + } + Napi::Array labels = info[1].As(); + size_t labelCount = labels.Length(); + if (labelCount != dv.quot) + { + Napi::Error::New(env, "Labels array length must match the number of vectors.") + .ThrowAsJavaScriptException(); + return env.Undefined(); + } + + float *xb = new float[length]; + for (size_t i = 0; i < length; i++) + { + Napi::Value val = arr[i]; + if (!val.IsNumber()) + { + Napi::Error::New(env, "Expected a Number as array item. (at: " + std::to_string(i) + ")") + .ThrowAsJavaScriptException(); + return env.Undefined(); + } + xb[i] = val.As().FloatValue(); + } + + idx_t *xc = new idx_t[labelCount]; + for (size_t i = 0; i < labelCount; i++) + { + Napi::Value val = labels[i]; + if (val.IsNumber()) + { + xc[i] = val.As().Int64Value(); + } + else if (val.IsBigInt()) + { + auto lossless = false; + xc[i] = val.As().Int64Value(&lossless); + } + else + { + Napi::Error::New(env, "Expected a Number or BigInt as array item. (at: " + std::to_string(i) + ")") + .ThrowAsJavaScriptException(); + return env.Undefined(); + } + } + + index_->add_with_ids(dv.quot, xb, xc); + + delete[] xb; + delete[] xc; + return env.Undefined(); + } + Napi::Value train(const Napi::CallbackInfo &info) { Napi::Env env = info.Env(); @@ -440,6 +521,33 @@ class IndexBase : public Napi::ObjectWrap return Napi::Buffer::Copy(env, writer->data.data(), writer->data.size()); } + Napi::Value toIDMap2(const Napi::CallbackInfo &info) + { + Napi::Env env = info.Env(); + + if (info.Length() != 0) + { + Napi::Error::New(env, "Expected 0 arguments, but got " + std::to_string(info.Length()) + ".") + .ThrowAsJavaScriptException(); + return env.Undefined(); + } + + Napi::Object instance = T::constructor->New({}); + T *index = Napi::ObjectWrap::Unwrap(instance); + + try + { + // wrap the new IDMap'd index around the old and leave it to faiss to throw if index not compatible + index->index_ = std::unique_ptr(new faiss::IndexIDMap2(index_.get())); + } + catch (const faiss::FaissException &ex) + { + Napi::Error::New(env, ex.what()).ThrowAsJavaScriptException(); + } + + return instance; + } + protected: std::unique_ptr index_; inline static Napi::FunctionReference *constructor; @@ -461,12 +569,14 @@ class Index : public IndexBase InstanceMethod("getDimension", &Index::getDimension), InstanceMethod("isTrained", &Index::isTrained), InstanceMethod("add", &Index::add), + InstanceMethod("addWithIds", &Index::addWithIds), InstanceMethod("train", &Index::train), InstanceMethod("search", &Index::search), InstanceMethod("write", &Index::write), InstanceMethod("mergeFrom", &Index::mergeFrom), InstanceMethod("removeIds", &Index::removeIds), InstanceMethod("toBuffer", &Index::toBuffer), + InstanceMethod("toIDMap2", &Index::toIDMap2), StaticMethod("read", &Index::read), StaticMethod("fromBuffer", &Index::fromBuffer), StaticMethod("fromFactory", &Index::fromFactory), @@ -496,12 +606,14 @@ class IndexFlatL2 : public IndexBase InstanceMethod("getDimension", &IndexFlatL2::getDimension), InstanceMethod("isTrained", &IndexFlatL2::isTrained), InstanceMethod("add", &IndexFlatL2::add), + InstanceMethod("addWithIds", &IndexFlatL2::addWithIds), InstanceMethod("train", &IndexFlatL2::train), InstanceMethod("search", &IndexFlatL2::search), InstanceMethod("write", &IndexFlatL2::write), InstanceMethod("mergeFrom", &IndexFlatL2::mergeFrom), InstanceMethod("removeIds", &IndexFlatL2::removeIds), InstanceMethod("toBuffer", &IndexFlatL2::toBuffer), + InstanceMethod("toIDMap2", &IndexFlatL2::toIDMap2), StaticMethod("read", &IndexFlatL2::read), StaticMethod("fromBuffer", &IndexFlatL2::fromBuffer), }); @@ -530,12 +642,14 @@ class IndexFlatIP : public IndexBase InstanceMethod("getDimension", &IndexFlatIP::getDimension), InstanceMethod("isTrained", &IndexFlatIP::isTrained), InstanceMethod("add", &IndexFlatIP::add), + InstanceMethod("addWithIds", &IndexFlatIP::addWithIds), InstanceMethod("train", &IndexFlatIP::train), InstanceMethod("search", &IndexFlatIP::search), InstanceMethod("write", &IndexFlatIP::write), InstanceMethod("mergeFrom", &IndexFlatIP::mergeFrom), InstanceMethod("removeIds", &IndexFlatIP::removeIds), InstanceMethod("toBuffer", &IndexFlatIP::toBuffer), + InstanceMethod("toIDMap2", &IndexFlatIP::toIDMap2), StaticMethod("read", &IndexFlatIP::read), StaticMethod("fromBuffer", &IndexFlatIP::fromBuffer), }); diff --git a/test/Index.test.js b/test/Index.test.js index 00b0cd4..6907ff7 100644 --- a/test/Index.test.js +++ b/test/Index.test.js @@ -43,4 +43,26 @@ describe('Index', () => { expect(index.ntotal()).toBe(newIndex.ntotal()); }); }); + + describe('#toIDMap2', () => { + it('new index preserves ID\'s', () => { + const index = Index.fromFactory(2, 'Flat').toIDMap2(); + const x = [1, 0, 0, 1]; + const labels = [100, 200]; + index.addWithIds(x, labels); + const results = index.search([1, 0], 2); + expect(results.labels).toEqual(labels); + }); + + it('supports BigInt labels', () => { + const index = Index.fromFactory(2, 'Flat').toIDMap2(); + const x = [1, 0, 0, 1]; + const labels = [100n, 200n]; + index.addWithIds(x, labels); + const results = index.search([1, 0], 2); + expect(results.labels).toEqual([100, 200]); + // TODO: Once search supports BigInt, use this test instead + // expect(results.labels).toEqual(labels); + }); + }); });