From 7b60682c8cc0c94a5714b898befba9e70c94f05a Mon Sep 17 00:00:00 2001 From: Aaron Silvas Date: Wed, 20 Sep 2023 19:07:15 -0700 Subject: [PATCH] Support reset & dispose --- lib/index.d.ts | 23 +++++++---- src/faiss.cc | 97 ++++++++++++++++++++++++++++++++++++++++++++-- test/Index.test.js | 51 +++++++++++++++++++++++- 3 files changed, 159 insertions(+), 12 deletions(-) diff --git a/lib/index.d.ts b/lib/index.d.ts index 0bca395..e566d0b 100644 --- a/lib/index.d.ts +++ b/lib/index.d.ts @@ -10,16 +10,16 @@ export interface SearchResult { export enum MetricType { METRIC_INNER_PRODUCT = 0, ///< maximum inner product search METRIC_L2 = 1, ///< squared L2 search - METRIC_L1, ///< L1 (aka cityblock) - METRIC_Linf, ///< infinity distance - METRIC_Lp, ///< L_p distance, p is given by a faiss::Index + METRIC_L1 = 2, ///< L1 (aka cityblock) + METRIC_Linf = 3, ///< infinity distance + METRIC_Lp = 4, ///< L_p distance, p is given by a faiss::Index /// metric_arg /// some additional metrics defined in scipy.spatial.distance METRIC_Canberra = 20, - METRIC_BrayCurtis, - METRIC_JensenShannon, - METRIC_Jaccard, ///< defined as: sum_i(min(a_i, b_i)) / sum_i(max(a_i, b_i)) + METRIC_BrayCurtis = 21, + METRIC_JensenShannon = 22, + METRIC_Jaccard = 23, ///< defined as: sum_i(min(a_i, b_i)) / sum_i(max(a_i, b_i)) ///< where a_i, b_i > 0 } @@ -106,8 +106,15 @@ export class Index { * @param {number[]} ids IDs to read. * @return {number} number of IDs removed. */ - removeIds(ids: number[]): number - + removeIds(ids: number[]): number; + /** + * Reset the index, resulting in a ntotal of 0. + */ + reset(): void; + /** + * Free all resources associated with the index. Further calls to the index will throw. + */ + dispose(): void; } /** diff --git a/src/faiss.cc b/src/faiss.cc index f32b6b7..3a90996 100644 --- a/src/faiss.cc +++ b/src/faiss.cc @@ -142,13 +142,25 @@ class IndexBase : public Napi::ObjectWrap Napi::Value isTrained(const Napi::CallbackInfo &info) { - return Napi::Boolean::New(info.Env(), index_->is_trained); + Napi::Env env = info.Env(); + if (!index_) + { + Napi::Error::New(env, "Index has been disposed").ThrowAsJavaScriptException(); + return env.Undefined(); + } + + return Napi::Boolean::New(env, index_->is_trained); } Napi::Value add(const Napi::CallbackInfo &info) { Napi::Env env = info.Env(); + if (!index_) + { + Napi::Error::New(env, "Index has been disposed").ThrowAsJavaScriptException(); + return env.Undefined(); + } if (info.Length() != 1) { Napi::Error::New(env, "Expected 1 argument, but got " + std::to_string(info.Length()) + ".") @@ -190,10 +202,45 @@ class IndexBase : public Napi::ObjectWrap return env.Undefined(); } + Napi::Value reset(const Napi::CallbackInfo &info) + { + Napi::Env env = info.Env(); + if (!index_) + { + Napi::Error::New(env, "Index has been disposed").ThrowAsJavaScriptException(); + return env.Undefined(); + } + + index_->reset(); + + return env.Undefined(); + } + + Napi::Value dispose(const Napi::CallbackInfo &info) + { + Napi::Env env = info.Env(); + if (!index_) + { + Napi::Error::New(env, "Index has been disposed").ThrowAsJavaScriptException(); + return env.Undefined(); + } + + auto idx = index_.release(); + delete idx; + index_ = nullptr; + + return env.Undefined(); + } + Napi::Value train(const Napi::CallbackInfo &info) { Napi::Env env = info.Env(); + if (!index_) + { + Napi::Error::New(env, "Index has been disposed").ThrowAsJavaScriptException(); + return env.Undefined(); + } if (info.Length() != 1) { Napi::Error::New(env, "Expected 1 argument, but got " + std::to_string(info.Length()) + ".") @@ -239,6 +286,11 @@ class IndexBase : public Napi::ObjectWrap { Napi::Env env = info.Env(); + if (!index_) + { + Napi::Error::New(env, "Index has been disposed").ThrowAsJavaScriptException(); + return env.Undefined(); + } if (info.Length() != 2) { Napi::Error::New(env, "Expected 2 arguments, but got " + std::to_string(info.Length()) + ".") @@ -314,18 +366,36 @@ class IndexBase : public Napi::ObjectWrap Napi::Value ntotal(const Napi::CallbackInfo &info) { - return Napi::Number::New(info.Env(), index_->ntotal); + Napi::Env env = info.Env(); + if (!index_) + { + Napi::Error::New(env, "Index has been disposed").ThrowAsJavaScriptException(); + return env.Undefined(); + } + + return Napi::Number::New(env, index_->ntotal); } Napi::Value getDimension(const Napi::CallbackInfo &info) { - return Napi::Number::New(info.Env(), index_->d); + Napi::Env env = info.Env(); + if (!index_) + { + Napi::Error::New(env, "Index has been disposed").ThrowAsJavaScriptException(); + return env.Undefined(); + } + return Napi::Number::New(env, index_->d); } Napi::Value write(const Napi::CallbackInfo &info) { Napi::Env env = info.Env(); + if (!index_) + { + Napi::Error::New(env, "Index has been disposed").ThrowAsJavaScriptException(); + return env.Undefined(); + } if (info.Length() != 1) { Napi::Error::New(env, "Expected 1 argument, but got " + std::to_string(info.Length()) + ".") @@ -349,6 +419,11 @@ class IndexBase : public Napi::ObjectWrap { Napi::Env env = info.Env(); + if (!index_) + { + Napi::Error::New(env, "Index has been disposed").ThrowAsJavaScriptException(); + return env.Undefined(); + } if (info.Length() != 1) { Napi::Error::New(env, "Expected 1 argument, but got " + std::to_string(info.Length()) + ".") @@ -387,6 +462,11 @@ class IndexBase : public Napi::ObjectWrap { Napi::Env env = info.Env(); + if (!index_) + { + Napi::Error::New(env, "Index has been disposed").ThrowAsJavaScriptException(); + return env.Undefined(); + } if (info.Length() != 1) { Napi::Error::New(env, "Expected 1 argument, but got " + std::to_string(info.Length()) + ".") @@ -425,6 +505,11 @@ class IndexBase : public Napi::ObjectWrap { Napi::Env env = info.Env(); + if (!index_) + { + Napi::Error::New(env, "Index has been disposed").ThrowAsJavaScriptException(); + return env.Undefined(); + } if (info.Length() != 0) { Napi::Error::New(env, "Expected 0 arguments, but got " + std::to_string(info.Length()) + ".") @@ -461,6 +546,8 @@ class Index : public IndexBase InstanceMethod("getDimension", &Index::getDimension), InstanceMethod("isTrained", &Index::isTrained), InstanceMethod("add", &Index::add), + InstanceMethod("reset", &Index::reset), + InstanceMethod("dispose", &Index::dispose), InstanceMethod("train", &Index::train), InstanceMethod("search", &Index::search), InstanceMethod("write", &Index::write), @@ -496,6 +583,8 @@ class IndexFlatL2 : public IndexBase InstanceMethod("getDimension", &IndexFlatL2::getDimension), InstanceMethod("isTrained", &IndexFlatL2::isTrained), InstanceMethod("add", &IndexFlatL2::add), + InstanceMethod("reset", &IndexFlatL2::reset), + InstanceMethod("dispose", &IndexFlatL2::dispose), InstanceMethod("train", &IndexFlatL2::train), InstanceMethod("search", &IndexFlatL2::search), InstanceMethod("write", &IndexFlatL2::write), @@ -530,6 +619,8 @@ class IndexFlatIP : public IndexBase InstanceMethod("getDimension", &IndexFlatIP::getDimension), InstanceMethod("isTrained", &IndexFlatIP::isTrained), InstanceMethod("add", &IndexFlatIP::add), + InstanceMethod("reset", &IndexFlatIP::reset), + InstanceMethod("dispose", &IndexFlatIP::dispose), InstanceMethod("train", &IndexFlatIP::train), InstanceMethod("search", &IndexFlatIP::search), InstanceMethod("write", &IndexFlatIP::write), diff --git a/test/Index.test.js b/test/Index.test.js index 00b0cd4..f60cadb 100644 --- a/test/Index.test.js +++ b/test/Index.test.js @@ -30,7 +30,6 @@ describe('Index', () => { }); }); - describe('#toBuffer', () => { it('new index is same size as old', () => { const index = Index.fromFactory(2, 'Flat'); @@ -43,4 +42,54 @@ describe('Index', () => { expect(index.ntotal()).toBe(newIndex.ntotal()); }); }); + + describe('#reset', () => { + let index; + + beforeEach(() => { + index = Index.fromFactory(2, 'Flat'); + index.add([1, 0, 0, 1]); + }); + + it('reset the index', () => { + expect(index.ntotal()).toBe(2); + index.reset(); + expect(index.ntotal()).toBe(0); + }); + + it('reset the index and add new elements', () => { + expect(index.ntotal()).toBe(2); + index.reset(); + expect(index.ntotal()).toBe(0); + + index.add([1, 0]); + index.add([1, 2]); + expect(index.ntotal()).toBe(2); + }); + }); + + describe('#dispose', () => { + let index; + + beforeEach(() => { + index = Index.fromFactory(2, 'Flat'); + index.add([1, 0, 0, 1]); + }); + + it('disposing an index does not throw', () => { + index.dispose(); + }); + + it('disposing twice will throw', () => { + index.dispose(); + + expect(() => index.dispose()).toThrow('Index has been disposed'); + }); + + it('invoking a function after dispose will throw', () => { + index.dispose(); + + expect(() => index.ntotal()).toThrow('Index has been disposed'); + }); + }); });