diff --git a/CMakeLists.txt b/CMakeLists.txt index a60bd92..eedf0f5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -32,7 +32,7 @@ set(BUILD_TESTING OFF) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") -set(CMAKE_CXX_STANDARD 11) +set(CMAKE_CXX_STANDARD 17) set(FAISS_ENABLE_GPU OFF) set(FAISS_ENABLE_PYTHON OFF) diff --git a/README.md b/README.md index 9eb4cd6..9efc0b7 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ $ npm install faiss-node ## Usage ```javascript -const { IndexFlatL2 } = require('faiss-node'); +const { IndexFlatL2, Index, IndexFlatIP, MetricType } = require('faiss-node'); const dimension = 2; const index = new IndexFlatL2(dimension); @@ -69,6 +69,21 @@ const removedCount = newIndex.removeIds([0]); console.log(removedCount); // 1 console.log(newIndex.ntotal()); // 3 console.log(newIndex.search([1, 2], 1)); // { distances: [ 0 ], labels: [ 0 ] } + +// IndexFlatIP +const ipIndex = new IndexFlatIP(2); +ipIndex.add([1, 0]); + +// Serialize an index +const index_buf = newIndex.toBuffer(); +const deserializedIndex = Index.fromBuffer(index_buf); +console.log(deserializedIndex.ntotal()); // 3 + +// Factory index +const hnswIndex = Index.fromFactory(2, 'HNSW,Flat', MetricType.METRIC_INNER_PRODUCT); +const x = [1, 0, 0, 1]; +hnswIndex.train(x); +hnswIndex.add(x); ``` ## License diff --git a/examples/index.js b/examples/index.js index 0c2dac6..dfafc4f 100644 --- a/examples/index.js +++ b/examples/index.js @@ -1,38 +1,60 @@ -const { IndexFlatL2 } = require('../'); +const { IndexFlatL2, Index, IndexFlatIP, MetricType } = require('../'); const dimension = 2; const index = new IndexFlatL2(dimension); -console.log(index.getDimension()); -console.log(index.isTrained()); -console.log(index.ntotal()); +console.log(index.getDimension()); // 2 +console.log(index.isTrained()); // true +console.log(index.ntotal()); // 0 + +// inserting data into index. index.add([1, 0]); index.add([1, 2]); index.add([1, 3]); index.add([1, 1]); -console.log(index.ntotal()); + +console.log(index.ntotal()); // 4 const k = 4; const results = index.search([1, 0], k); -console.log(results.labels); -console.log(results.distances); +console.log(results.labels); // [ 0, 3, 1, 2 ] +console.log(results.distances); // [ 0, 1, 4, 9 ] +// Save index const fname = 'faiss.index'; index.write(fname); +// Load saved index const index_loaded = IndexFlatL2.read(fname); -console.log(index_loaded.getDimension()); -console.log(index_loaded.ntotal()); +console.log(index_loaded.getDimension()); //2 +console.log(index_loaded.ntotal()); //4 const results1 = index_loaded.search([1, 1], 4); -console.log(results1.labels); -console.log(results1.distances); +console.log(results1.labels); // [ 3, 0, 1, 2 ] +console.log(results1.distances); // [ 0, 1, 1, 4 ] +// Merge index const newIndex = new IndexFlatL2(dimension); newIndex.mergeFrom(index); -console.log(newIndex.ntotal()); +console.log(newIndex.ntotal()); // 4 -console.log(newIndex.search([1, 2], 1)); +// Remove items +console.log(newIndex.search([1, 2], 1)); // { distances: [ 0 ], labels: [ 1 ] } const removedCount = newIndex.removeIds([0]); -console.log(removedCount); -console.log(newIndex.ntotal()); -console.log(newIndex.search([1, 2], 1)); \ No newline at end of file +console.log(removedCount); // 1 +console.log(newIndex.ntotal()); // 3 +console.log(newIndex.search([1, 2], 1)); // { distances: [ 0 ], labels: [ 0 ] } + +// IndexFlatIP +const ipIndex = new IndexFlatIP(2); +ipIndex.add([1, 0]); + +// Serialize an index +const index_buf = newIndex.toBuffer(); +const deserializedIndex = Index.fromBuffer(index_buf); +console.log(deserializedIndex.ntotal()); // 3 + +// Factory index +const hnswIndex = Index.fromFactory(2, 'HNSW,Flat', MetricType.METRIC_INNER_PRODUCT); +const x = [1, 0, 0, 1]; +hnswIndex.train(x); +hnswIndex.add(x); \ No newline at end of file diff --git a/lib/index.d.ts b/lib/index.d.ts index 30a4ce0..0bca395 100644 --- a/lib/index.d.ts +++ b/lib/index.d.ts @@ -6,12 +6,29 @@ export interface SearchResult { labels: number[] } +// See faiss/MetricType.h +export enum MetricType { + METRIC_INNER_PRODUCT = 0, ///< maximum inner product search + METRIC_L2 = 1, ///< squared L2 search + METRIC_L1, ///< L1 (aka cityblock) + METRIC_Linf, ///< infinity distance + METRIC_Lp, ///< L_p distance, p is given by a faiss::Index + /// metric_arg + + /// some additional metrics defined in scipy.spatial.distance + METRIC_Canberra = 20, + METRIC_BrayCurtis, + METRIC_JensenShannon, + METRIC_Jaccard, ///< defined as: sum_i(min(a_i, b_i)) / sum_i(max(a_i, b_i)) + ///< where a_i, b_i > 0 +} + /** - * IndexFlatL2 Index. + * Index. * Index that stores the full vectors and performs exhaustive search. * @param {number} d The dimensionality of index. */ -export class IndexFlatL2 { +export class Index { constructor(d: number); /** * returns the number of verctors currently indexed. @@ -34,6 +51,12 @@ export class IndexFlatL2 { * @param {number[]} x Input matrix, size n * d */ add(x: number[]): void; + /** + * Train n vectors of dimension d to the index. + * Vectors are implicitly assigned labels ntotal .. ntotal + n - 1 + * @param {number[]} x Input matrix, size n * d + */ + train(x: number[]): void; /** * Query n vectors of dimension d to the index. * return at most k vectors. If there are not enough results for a @@ -48,22 +71,91 @@ export class IndexFlatL2 { * Write index to a file. * @param {string} fname File path to write. */ - write(fname: string): void + write(fname: string): void; + /** + * Write index to buffer. + */ + toBuffer(): Buffer; + /** + * Read index from a file. + * @param {string} fname File path to read. + * @return {Index} The index read. + */ + static read(fname: string): Index; + /** + * Read index from buffer. + * @param {Buffer} src Buffer to create index from. + * @return {Index} The index read. + */ + static fromBuffer(src: Buffer): Index; + /** + * Construct an index from factory descriptor. + * @param {number} dims Buffer to create index from. + * @param {string} descriptor Factory descriptor. + * @param {MetricType} metric Metric type (defaults to L2). + * @return {Index} The index read. + */ + static fromFactory(dims: number, descriptor: string, metric?: MetricType): Index; + /** + * Merge the current index with another Index instance. + * @param {Index} otherIndex The other Index instance to merge from. + */ + mergeFrom(otherIndex: Index): void; + /** + * Remove IDs from the index. + * @param {number[]} ids IDs to read. + * @return {number} number of IDs removed. + */ + removeIds(ids: number[]): number + +} + +/** + * IndexFlatL2 Index. + * IndexFlatL2 that stores the full vectors and performs `squared L2` search. + * @param {number} d The dimensionality of index. + */ +export class IndexFlatL2 extends Index { /** * Read index from a file. * @param {string} fname File path to read. * @return {IndexFlatL2} The index read. */ static read(fname: string): IndexFlatL2; + /** + * Read index from buffer. + * @param {Buffer} src Buffer to create index from. + * @return {IndexFlatL2} The index read. + */ + static fromBuffer(src: Buffer): IndexFlatL2; /** * Merge the current index with another IndexFlatL2 instance. * @param {IndexFlatL2} otherIndex The other IndexFlatL2 instance to merge from. */ mergeFrom(otherIndex: IndexFlatL2): void; +} + +/** + * IndexFlatIP Index. + * Index that stores the full vectors and performs `maximum inner product` search. + * @param {number} d The dimensionality of index. + */ +export class IndexFlatIP extends Index { + /** + * Read index from a file. + * @param {string} fname File path to read. + * @return {IndexFlatIP} The index read. + */ + static read(fname: string): IndexFlatIP; + /** + * Read index from buffer. + * @param {Buffer} src Buffer to create index from. + * @return {IndexFlatIP} The index read. + */ + static fromBuffer(src: Buffer): IndexFlatIP; /** - * Remove IDs from the index. - * @param {string} ids IDs to read. - * @return {IndexFlatL2} number of IDs removed. + * Merge the current index with another IndexFlatIP instance. + * @param {IndexFlatIP} otherIndex The other IndexFlatIP instance to merge from. */ - removeIds(ids: number[]): number + mergeFrom(otherIndex: IndexFlatIP): void; } \ No newline at end of file diff --git a/lib/index.js b/lib/index.js index 21f2efe..ced9f1f 100644 --- a/lib/index.js +++ b/lib/index.js @@ -1,2 +1,17 @@ const faiss = require('bindings')('faiss-node'); + +faiss.MetricType = void 0; +var MetricType; +(function (MetricType) { + MetricType[MetricType["METRIC_INNER_PRODUCT"] = 0] = "METRIC_INNER_PRODUCT"; + MetricType[MetricType["METRIC_L2"] = 1] = "METRIC_L2"; + MetricType[MetricType["METRIC_L1"] = 2] = "METRIC_L1"; + MetricType[MetricType["METRIC_Linf"] = 3] = "METRIC_Linf"; + MetricType[MetricType["METRIC_Lp"] = 4] = "METRIC_Lp"; + MetricType[MetricType["METRIC_Canberra"] = 20] = "METRIC_Canberra"; + MetricType[MetricType["METRIC_BrayCurtis"] = 21] = "METRIC_BrayCurtis"; + MetricType[MetricType["METRIC_JensenShannon"] = 22] = "METRIC_JensenShannon"; + MetricType[MetricType["METRIC_Jaccard"] = 23] = "METRIC_Jaccard"; +})(MetricType || (faiss.MetricType = MetricType = {})); + module.exports = faiss; \ No newline at end of file diff --git a/src/faiss.cc b/src/faiss.cc index 2aaa5be..f32b6b7 100644 --- a/src/faiss.cc +++ b/src/faiss.cc @@ -5,101 +5,141 @@ #include #include #include +#include +#include +#include #include using namespace Napi; using idx_t = faiss::idx_t; -class IndexFlatL2 : public Napi::ObjectWrap +template +class IndexBase : public Napi::ObjectWrap { public: - IndexFlatL2(const Napi::CallbackInfo &info) : Napi::ObjectWrap(info) + IndexBase(const Napi::CallbackInfo &info) : Napi::ObjectWrap(info) { Napi::Env env = info.Env(); - if (info[0].IsExternal()) + + if (info.Length() > 0 && info[0].IsNumber()) { - const std::string fname = *info[0].As>().Data(); - try - { - index_ = std::unique_ptr(dynamic_cast(faiss::read_index(fname.c_str()))); - } - catch (const faiss::FaissException& ex) - { - Napi::Error::New(env, ex.what()).ThrowAsJavaScriptException(); - } + auto n = info[0].As().Uint32Value(); + index_ = std::unique_ptr(new Y(n)); + } + } + + static Napi::Value read(const Napi::CallbackInfo &info) + { + Napi::Env env = info.Env(); + + if (info.Length() != 1) + { + Napi::Error::New(env, "Expected 1 argument, but got " + std::to_string(info.Length()) + ".") + .ThrowAsJavaScriptException(); + return env.Undefined(); } - else + if (!info[0].IsString()) { - if (!info.IsConstructCall()) - { - Napi::Error::New(env, "Class constructors cannot be invoked without 'new'") - .ThrowAsJavaScriptException(); - return; - } + Napi::TypeError::New(env, "Invalid the first argument type, must be a string.").ThrowAsJavaScriptException(); + return env.Undefined(); + } - if (info.Length() != 1) - { - Napi::Error::New(env, "Expected 1 argument, but got " + std::to_string(info.Length()) + ".") - .ThrowAsJavaScriptException(); - return; - } - if (!info[0].IsNumber()) - { - Napi::TypeError::New(env, "Invalid the first argument type, must be a number.").ThrowAsJavaScriptException(); - return; - } + Napi::Object instance = T::constructor->New({}); + T *index = Napi::ObjectWrap::Unwrap(instance); + std::string fname = info[0].As().Utf8Value(); - auto n = info[0].As().Uint32Value(); - index_ = std::unique_ptr(new faiss::IndexFlatL2(n)); + try + { + index->index_ = std::unique_ptr(dynamic_cast(faiss::read_index(fname.c_str()))); } + catch (const faiss::FaissException &ex) + { + Napi::Error::New(env, ex.what()).ThrowAsJavaScriptException(); + } + + return instance; } - static Napi::Object Init(Napi::Env env, Napi::Object exports) + static Napi::Value fromBuffer(const Napi::CallbackInfo &info) { - // clang-format off - Napi::Function func = DefineClass(env, "IndexFlatL2", { - InstanceMethod("ntotal", &IndexFlatL2::ntotal), - InstanceMethod("getDimension", &IndexFlatL2::getDimension), - InstanceMethod("isTrained", &IndexFlatL2::isTrained), - InstanceMethod("add", &IndexFlatL2::add), - InstanceMethod("search", &IndexFlatL2::search), - InstanceMethod("write", &IndexFlatL2::write), - InstanceMethod("mergeFrom", &IndexFlatL2::mergeFrom), - InstanceMethod("removeIds", &IndexFlatL2::removeIds), - StaticMethod("read", &IndexFlatL2::read), - }); - // clang-format on + Napi::Env env = info.Env(); - Napi::FunctionReference *constructor = new Napi::FunctionReference(); - *constructor = Napi::Persistent(func); - env.SetInstanceData(constructor); + if (info.Length() != 1) + { + Napi::Error::New(env, "Expected 1 argument, but got " + std::to_string(info.Length()) + ".") + .ThrowAsJavaScriptException(); + return env.Undefined(); + } + if (!info[0].IsBuffer()) + { + Napi::TypeError::New(env, "Invalid the first argument type, must be a buffer.").ThrowAsJavaScriptException(); + return env.Undefined(); + } - exports.Set("IndexFlatL2", func); - return exports; + Napi::Object instance = T::constructor->New({}); + T *index = Napi::ObjectWrap::Unwrap(instance); + + auto buffer = Napi::Buffer::Copy(env, info[0].As>().Data(), info[0].As>().Length()); + faiss::VectorIOReader *reader = new faiss::VectorIOReader(); + reader->data = std::vector(buffer.Data(), buffer.Data() + buffer.Length()); + + try + { + index->index_ = std::unique_ptr(dynamic_cast(faiss::read_index(reader))); + } + catch (const faiss::FaissException &ex) + { + Napi::Error::New(env, ex.what()).ThrowAsJavaScriptException(); + } + + return instance; } - static Napi::Value read(const Napi::CallbackInfo &info) + static Napi::Value fromFactory(const Napi::CallbackInfo &info) { Napi::Env env = info.Env(); - if (info.Length() != 1) + if (info.Length() < 2) { - Napi::Error::New(env, "Expected 1 argument, but got " + std::to_string(info.Length()) + ".") + Napi::Error::New(env, "Expected 2 or 3 arguments, but got " + std::to_string(info.Length()) + ".") .ThrowAsJavaScriptException(); return env.Undefined(); } - if (!info[0].IsString()) + if (!info[0].IsNumber()) { - Napi::TypeError::New(env, "Invalid the first argument type, must be a string.").ThrowAsJavaScriptException(); + Napi::TypeError::New(env, "Invalid the first argument type, must be a number.").ThrowAsJavaScriptException(); return env.Undefined(); } + if (!info[1].IsString()) + { + Napi::TypeError::New(env, "Invalid the second argument type, must be a string.").ThrowAsJavaScriptException(); + return env.Undefined(); + } + + auto metric = faiss::MetricType::METRIC_L2; + if (info[2].IsNumber()) + { + metric = static_cast(info[2].As().Uint32Value()); + } + + Napi::Object instance = T::constructor->New({}); + T *index = Napi::ObjectWrap::Unwrap(instance); + + const uint32_t d = info[0].As().Uint32Value(); + std::string description = info[1].As().Utf8Value(); + + try + { + index->index_ = std::unique_ptr(dynamic_cast(faiss::index_factory(d, description.c_str(), metric))); + } + catch (const faiss::FaissException &ex) + { + Napi::Error::New(env, ex.what()).ThrowAsJavaScriptException(); + } - Napi::FunctionReference *constructor = env.GetInstanceData(); - return constructor->New({Napi::External::New(env, new std::string(info[0].As()))}); + return instance; } -private: - std::unique_ptr index_; Napi::Value isTrained(const Napi::CallbackInfo &info) { return Napi::Boolean::New(info.Env(), index_->is_trained); @@ -150,6 +190,51 @@ class IndexFlatL2 : public Napi::ObjectWrap return env.Undefined(); } + Napi::Value train(const Napi::CallbackInfo &info) + { + Napi::Env env = info.Env(); + + if (info.Length() != 1) + { + Napi::Error::New(env, "Expected 1 argument, but got " + std::to_string(info.Length()) + ".") + .ThrowAsJavaScriptException(); + return env.Undefined(); + } + if (!info[0].IsArray()) + { + Napi::TypeError::New(env, "Invalid the first argument type, must be an Array.").ThrowAsJavaScriptException(); + return env.Undefined(); + } + + Napi::Array arr = info[0].As(); + size_t length = arr.Length(); + auto dv = std::div(length, index_->d); + if (dv.rem != 0) + { + Napi::Error::New(env, "Invalid the given array length.") + .ThrowAsJavaScriptException(); + return env.Undefined(); + } + + float *xb = new float[length]; + for (size_t i = 0; i < length; i++) + { + Napi::Value val = arr[i]; + if (!val.IsNumber()) + { + Napi::Error::New(env, "Expected a Number as array item. (at: " + std::to_string(i) + ")") + .ThrowAsJavaScriptException(); + return env.Undefined(); + } + xb[i] = val.As().FloatValue(); + } + + index_->train(dv.quot, xb); + + delete[] xb; + return env.Undefined(); + } + Napi::Value search(const Napi::CallbackInfo &info) { Napi::Env env = info.Env(); @@ -277,7 +362,7 @@ class IndexFlatL2 : public Napi::ObjectWrap } Napi::Object otherIndex = info[0].As(); - IndexFlatL2 *otherIndexInstance = Napi::ObjectWrap::Unwrap(otherIndex); + T *otherIndexInstance = Napi::ObjectWrap::Unwrap(otherIndex); if (otherIndexInstance->index_->d != index_->d) { @@ -289,7 +374,7 @@ class IndexFlatL2 : public Napi::ObjectWrap { index_->merge_from(*(otherIndexInstance->index_)); } - catch (const faiss::FaissException& ex) + catch (const faiss::FaissException &ex) { Napi::Error::New(env, ex.what()).ThrowAsJavaScriptException(); return env.Undefined(); @@ -335,11 +420,141 @@ class IndexFlatL2 : public Napi::ObjectWrap delete[] xb; return Napi::Number::New(info.Env(), num); } + + Napi::Value toBuffer(const Napi::CallbackInfo &info) + { + Napi::Env env = info.Env(); + + if (info.Length() != 0) + { + Napi::Error::New(env, "Expected 0 arguments, but got " + std::to_string(info.Length()) + ".") + .ThrowAsJavaScriptException(); + return env.Undefined(); + } + + faiss::VectorIOWriter *writer = new faiss::VectorIOWriter(); + + faiss::write_index(index_.get(), writer); + + // return buffer from IOWriter + return Napi::Buffer::Copy(env, writer->data.data(), writer->data.size()); + } + +protected: + std::unique_ptr index_; + inline static Napi::FunctionReference *constructor; +}; + +// faiss::Index is abstract so IndexFlatL2 is used as fallback +class Index : public IndexBase +{ +public: + using IndexBase::IndexBase; + + static constexpr const char *CLASS_NAME = "Index"; + + static Napi::Object Init(Napi::Env env, Napi::Object exports) + { + // clang-format off + auto func = DefineClass(env, CLASS_NAME, { + InstanceMethod("ntotal", &Index::ntotal), + InstanceMethod("getDimension", &Index::getDimension), + InstanceMethod("isTrained", &Index::isTrained), + InstanceMethod("add", &Index::add), + InstanceMethod("train", &Index::train), + InstanceMethod("search", &Index::search), + InstanceMethod("write", &Index::write), + InstanceMethod("mergeFrom", &Index::mergeFrom), + InstanceMethod("removeIds", &Index::removeIds), + InstanceMethod("toBuffer", &Index::toBuffer), + StaticMethod("read", &Index::read), + StaticMethod("fromBuffer", &Index::fromBuffer), + StaticMethod("fromFactory", &Index::fromFactory), + }); + // clang-format on + + constructor = new Napi::FunctionReference(); + *constructor = Napi::Persistent(func); + + exports.Set(CLASS_NAME, func); + return exports; + } +}; + +class IndexFlatL2 : public IndexBase +{ +public: + using IndexBase::IndexBase; + + static constexpr const char *CLASS_NAME = "IndexFlatL2"; + + static Napi::Object Init(Napi::Env env, Napi::Object exports) + { + // clang-format off + auto func = DefineClass(env, CLASS_NAME, { + InstanceMethod("ntotal", &IndexFlatL2::ntotal), + InstanceMethod("getDimension", &IndexFlatL2::getDimension), + InstanceMethod("isTrained", &IndexFlatL2::isTrained), + InstanceMethod("add", &IndexFlatL2::add), + InstanceMethod("train", &IndexFlatL2::train), + InstanceMethod("search", &IndexFlatL2::search), + InstanceMethod("write", &IndexFlatL2::write), + InstanceMethod("mergeFrom", &IndexFlatL2::mergeFrom), + InstanceMethod("removeIds", &IndexFlatL2::removeIds), + InstanceMethod("toBuffer", &IndexFlatL2::toBuffer), + StaticMethod("read", &IndexFlatL2::read), + StaticMethod("fromBuffer", &IndexFlatL2::fromBuffer), + }); + // clang-format on + + constructor = new Napi::FunctionReference(); + *constructor = Napi::Persistent(func); + + exports.Set(CLASS_NAME, func); + return exports; + } +}; + +class IndexFlatIP : public IndexBase +{ +public: + using IndexBase::IndexBase; + + static constexpr const char *CLASS_NAME = "IndexFlatIP"; + + static Napi::Object Init(Napi::Env env, Napi::Object exports) + { + // clang-format off + auto func = DefineClass(env, CLASS_NAME, { + InstanceMethod("ntotal", &IndexFlatIP::ntotal), + InstanceMethod("getDimension", &IndexFlatIP::getDimension), + InstanceMethod("isTrained", &IndexFlatIP::isTrained), + InstanceMethod("add", &IndexFlatIP::add), + InstanceMethod("train", &IndexFlatIP::train), + InstanceMethod("search", &IndexFlatIP::search), + InstanceMethod("write", &IndexFlatIP::write), + InstanceMethod("mergeFrom", &IndexFlatIP::mergeFrom), + InstanceMethod("removeIds", &IndexFlatIP::removeIds), + InstanceMethod("toBuffer", &IndexFlatIP::toBuffer), + StaticMethod("read", &IndexFlatIP::read), + StaticMethod("fromBuffer", &IndexFlatIP::fromBuffer), + }); + // clang-format on + + constructor = new Napi::FunctionReference(); + *constructor = Napi::Persistent(func); + + exports.Set(CLASS_NAME, func); + return exports; + } }; Napi::Object Init(Napi::Env env, Napi::Object exports) { + Index::Init(env, exports); IndexFlatL2::Init(env, exports); + IndexFlatIP::Init(env, exports); + return exports; } diff --git a/test/Index.test.js b/test/Index.test.js new file mode 100644 index 0000000..00b0cd4 --- /dev/null +++ b/test/Index.test.js @@ -0,0 +1,46 @@ +const { Index, MetricType } = require('../lib'); + +describe('Index', () => { + describe('#fromFactory', () => { + it('Flat', () => { + const index = Index.fromFactory(2, 'Flat'); + const x = [1, 0, 0, 1]; + index.add(x); + + expect(index.ntotal()).toBe(2); + }); + + it('Flat /w IP', () => { + const index = Index.fromFactory(2, 'Flat', MetricType.METRIC_INNER_PRODUCT); + const x = [1, 0, 0, 1]; + index.add(x); + + expect(index.ntotal()).toBe(2); + }); + }); + + describe('#train', () => { + it('HNSW training', () => { + const index = Index.fromFactory(2, 'HNSW,Flat'); + const x = [1, 0, 0, 1]; + index.train(x); + index.add(x); + + expect(index.ntotal()).toBe(2); + }); + }); + + + describe('#toBuffer', () => { + it('new index is same size as old', () => { + const index = Index.fromFactory(2, 'Flat'); + const x = [1, 0, 0, 1]; + index.add(x); + + const buf = index.toBuffer(); + const newIndex = Index.fromBuffer(buf); + + expect(index.ntotal()).toBe(newIndex.ntotal()); + }); + }); +}); diff --git a/test/IndexFlatL2.test.js b/test/IndexFlatL2.test.js index 654c059..fc70c53 100644 --- a/test/IndexFlatL2.test.js +++ b/test/IndexFlatL2.test.js @@ -1,21 +1,6 @@ const { IndexFlatL2 } = require('../lib'); describe('IndexFlatL2', () => { - describe('#constructor', () => { - it('throws an error if the count of given param is not 1', () => { - expect(() => { new IndexFlatL2() }).toThrow('Expected 1 argument, but got 0.'); - expect(() => { new IndexFlatL2(1, 2) }).toThrow('Expected 1 argument, but got 2.'); - }); - - it('throws an error if given a non-Number object to the argument', () => { - expect(() => { new IndexFlatL2('1') }).toThrow('Invalid the first argument type, must be a number.'); - }); - - it('throws an error if functional call constructor', () => { - expect(() => { IndexFlatL2(1) }).toThrow("Class constructors cannot be invoked without 'new'"); - }); - }); - describe('#read', () => { it('throws an error if file does not existed', () => { const fname = 'not_existed_file'