diff --git a/README.md b/README.md index 9efc0b7..6440ebf 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ $ npm install faiss-node ## Usage ```javascript -const { IndexFlatL2, Index, IndexFlatIP, MetricType } = require('faiss-node'); +const { IndexFlatL2, Index, IndexFlatIP, IndexHNSW, MetricType } = require('faiss-node'); const dimension = 2; const index = new IndexFlatL2(dimension); @@ -80,7 +80,9 @@ const deserializedIndex = Index.fromBuffer(index_buf); console.log(deserializedIndex.ntotal()); // 3 // Factory index -const hnswIndex = Index.fromFactory(2, 'HNSW,Flat', MetricType.METRIC_INNER_PRODUCT); +const hnswIndex = Index.fromFactory(2, 'HNSW32,Flat', MetricType.METRIC_INNER_PRODUCT); +// same as: +// const hnswIndex = new IndexHNSW(2, 32, MetricType.METRIC_INNER_PRODUCT) const x = [1, 0, 0, 1]; hnswIndex.train(x); hnswIndex.add(x); diff --git a/lib/index.d.ts b/lib/index.d.ts index 0bca395..33d5d8a 100644 --- a/lib/index.d.ts +++ b/lib/index.d.ts @@ -158,4 +158,50 @@ export class IndexFlatIP extends Index { * @param {IndexFlatIP} otherIndex The other IndexFlatIP instance to merge from. */ mergeFrom(otherIndex: IndexFlatIP): void; +} + +/** + * IndexHNSW Index. + * The Hierarchical Navigable Small World indexing method is based on a graph built on the indexed vectors. + * @param {number} d The dimensionality of index. + * @param {number} m The number of neighbors used in the graph (defaults to 32). + * @param {number} metric Metric type (defaults to L2). + */ +export class IndexHNSW extends Index { + IndexHNSW(d?: number, m?: number, metric?: MetricType); + /** + * Read index from a file. + * @param {string} fname File path to read. + * @return {IndexHNSW} The index read. + */ + static read(fname: string): IndexHNSW; + /** + * Read index from buffer. + * @param {Buffer} src Buffer to create index from. + * @return {IndexHNSW} The index read. + */ + static fromBuffer(src: Buffer): IndexHNSW; + /** + * Merge the current index with another IndexHNSW instance. + * @param {IndexHNSW} otherIndex The other IndexHNSW instance to merge from. + */ + mergeFrom(otherIndex: IndexHNSW): void; + /** + * The depth of exploration at add time. + */ + get efConstruction(): number; + /** + * The depth of exploration at add time. + * @param {number} value The value to set. + */ + set efConstruction(value: number); + /** + * The depth of exploration of the search. + */ + get efSearch(): number; + /** + * The depth of exploration of the search. + * @param {number} value The value to set. + */ + set efSearch(value: number); } \ No newline at end of file diff --git a/lib/index.js b/lib/index.js index ced9f1f..02129ae 100644 --- a/lib/index.js +++ b/lib/index.js @@ -14,4 +14,25 @@ var MetricType; MetricType[MetricType["METRIC_Jaccard"] = 23] = "METRIC_Jaccard"; })(MetricType || (faiss.MetricType = MetricType = {})); +if (!('efConstruction' in faiss.IndexHNSW.prototype)) { // prevents redefinition in jest + Object.defineProperty(faiss.IndexHNSW.prototype, 'efConstruction', { + get: function () { + return this.getEfConstruction(); + }, + set: function (value) { + this.setEfConstruction(value); + } + }); +} +if (!('efSearch' in faiss.IndexHNSW.prototype)) { // prevents redefinition in jest + Object.defineProperty(faiss.IndexHNSW.prototype, 'efSearch', { + get: function () { + return this.getEfSearch(); + }, + set: function (value) { + this.setEfSearch(value); + } + }); +} + module.exports = faiss; \ No newline at end of file diff --git a/src/faiss.cc b/src/faiss.cc index f32b6b7..b766853 100644 --- a/src/faiss.cc +++ b/src/faiss.cc @@ -9,11 +9,20 @@ #include #include #include +#include using namespace Napi; using idx_t = faiss::idx_t; -template +enum class IndexType +{ + Index, + IndexFlatL2, + IndexFlatIP, + IndexHNSW, +}; + +template class IndexBase : public Napi::ObjectWrap { public: @@ -21,7 +30,25 @@ class IndexBase : public Napi::ObjectWrap { Napi::Env env = info.Env(); - if (info.Length() > 0 && info[0].IsNumber()) + if constexpr (IT == IndexType::IndexHNSW) + { // HNSW constructor + if (info.Length() > 0 && info[0].IsNumber()) + { + auto n = info[0].As().Uint32Value(); + auto m = 32; // faiss default + auto metric = faiss::MetricType::METRIC_L2; // faiss default + if (info.Length() > 1 && info[1].IsNumber()) + { + m = info[1].As().Uint32Value(); + } + if (info.Length() > 2 && info[2].IsNumber()) + { + metric = static_cast(info[2].As().Uint32Value()); + } + index_ = std::unique_ptr(new faiss::IndexHNSW(n, m, metric)); + } + } + else if (info.Length() > 0 && info[0].IsNumber()) { auto n = info[0].As().Uint32Value(); index_ = std::unique_ptr(new Y(n)); @@ -446,7 +473,7 @@ class IndexBase : public Napi::ObjectWrap }; // faiss::Index is abstract so IndexFlatL2 is used as fallback -class Index : public IndexBase +class Index : public IndexBase { public: using IndexBase::IndexBase; @@ -481,7 +508,7 @@ class Index : public IndexBase } }; -class IndexFlatL2 : public IndexBase +class IndexFlatL2 : public IndexBase { public: using IndexBase::IndexBase; @@ -515,7 +542,7 @@ class IndexFlatL2 : public IndexBase } }; -class IndexFlatIP : public IndexBase +class IndexFlatIP : public IndexBase { public: using IndexBase::IndexBase; @@ -549,11 +576,104 @@ class IndexFlatIP : public IndexBase } }; +class IndexHNSW : public IndexBase +{ +public: + using IndexBase::IndexBase; + + static constexpr const char *CLASS_NAME = "IndexHNSW"; + + Napi::Value getEfConstruction(const Napi::CallbackInfo &info) + { + auto index = dynamic_cast(index_.get()); + return Napi::Number::New(info.Env(), index->hnsw.efConstruction); + } + + Napi::Value setEfConstruction(const Napi::CallbackInfo &info) + { + Napi::Env env = info.Env(); + + if (info.Length() != 1) + { + Napi::Error::New(env, "Expected 1 argument, but got " + std::to_string(info.Length()) + ".") + .ThrowAsJavaScriptException(); + return env.Undefined(); + } + if (!info[0].IsNumber()) + { + Napi::TypeError::New(env, "Invalid the first argument type, must be a Number.").ThrowAsJavaScriptException(); + return env.Undefined(); + } + + auto index = dynamic_cast(index_.get()); + index->hnsw.efConstruction = info[0].As().Int32Value(); + return env.Undefined(); + } + + Napi::Value getEfSearch(const Napi::CallbackInfo &info) + { + auto index = dynamic_cast(index_.get()); + return Napi::Number::New(info.Env(), index->hnsw.efSearch); + } + + Napi::Value setEfSearch(const Napi::CallbackInfo &info) + { + Napi::Env env = info.Env(); + + if (info.Length() != 1) + { + Napi::Error::New(env, "Expected 1 argument, but got " + std::to_string(info.Length()) + ".") + .ThrowAsJavaScriptException(); + return env.Undefined(); + } + if (!info[0].IsNumber()) + { + Napi::TypeError::New(env, "Invalid the first argument type, must be a Number.").ThrowAsJavaScriptException(); + return env.Undefined(); + } + + auto index = dynamic_cast(index_.get()); + index->hnsw.efSearch = info[0].As().Int32Value(); + return env.Undefined(); + } + + static Napi::Object Init(Napi::Env env, Napi::Object exports) + { + // clang-format off + auto func = DefineClass(env, CLASS_NAME, { + InstanceMethod("getEfConstruction", &IndexHNSW::getEfConstruction), + InstanceMethod("setEfConstruction", &IndexHNSW::setEfConstruction), + InstanceMethod("getEfSearch", &IndexHNSW::getEfSearch), + InstanceMethod("setEfSearch", &IndexHNSW::setEfSearch), + InstanceMethod("ntotal", &IndexHNSW::ntotal), + InstanceMethod("getDimension", &IndexHNSW::getDimension), + InstanceMethod("isTrained", &IndexHNSW::isTrained), + InstanceMethod("add", &IndexHNSW::add), + InstanceMethod("train", &IndexHNSW::train), + InstanceMethod("search", &IndexHNSW::search), + InstanceMethod("write", &IndexHNSW::write), + InstanceMethod("mergeFrom", &IndexHNSW::mergeFrom), + InstanceMethod("removeIds", &IndexHNSW::removeIds), + InstanceMethod("toBuffer", &IndexHNSW::toBuffer), + StaticMethod("read", &IndexHNSW::read), + StaticMethod("fromBuffer", &IndexHNSW::fromBuffer), + }); + // clang-format on + + constructor = new Napi::FunctionReference(); + *constructor = Napi::Persistent(func); + + exports.Set(CLASS_NAME, func); + return exports; + } +}; + Napi::Object Init(Napi::Env env, Napi::Object exports) { Index::Init(env, exports); IndexFlatL2::Init(env, exports); IndexFlatIP::Init(env, exports); + IndexHNSW::Init(env, exports); return exports; } diff --git a/test/IndexHNSW.test.js b/test/IndexHNSW.test.js new file mode 100644 index 0000000..c9280c3 --- /dev/null +++ b/test/IndexHNSW.test.js @@ -0,0 +1,20 @@ +const { IndexHNSW, MetricType } = require('../lib'); + +describe('IndexHNSW', () => { + describe('#constructor', () => { + it('1 arg will result in index with default neighbors & metric', () => { + const index = new IndexHNSW(2); + expect(index.getDimension()).toBe(2); + }); + + it('2 args will result in index with default metric', () => { + const index = new IndexHNSW(2, 20); + expect(index.getDimension()).toBe(2); + }); + + it('3 args will result in index', () => { + const index = new IndexHNSW(2, 20, MetricType.METRIC_INNER_PRODUCT); + expect(index.getDimension()).toBe(2); + }); + }); +});