Skip to content

Commit

Permalink
IndexHNSW support
Browse files Browse the repository at this point in the history
  • Loading branch information
asilvas committed Oct 1, 2023
1 parent 05488b7 commit f1c5774
Show file tree
Hide file tree
Showing 5 changed files with 212 additions and 7 deletions.
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ $ npm install faiss-node
## Usage

```javascript
const { IndexFlatL2, Index, IndexFlatIP, MetricType } = require('faiss-node');
const { IndexFlatL2, Index, IndexFlatIP, IndexHNSW, MetricType } = require('faiss-node');

const dimension = 2;
const index = new IndexFlatL2(dimension);
Expand Down Expand Up @@ -80,7 +80,9 @@ const deserializedIndex = Index.fromBuffer(index_buf);
console.log(deserializedIndex.ntotal()); // 3

// Factory index
const hnswIndex = Index.fromFactory(2, 'HNSW,Flat', MetricType.METRIC_INNER_PRODUCT);
const hnswIndex = Index.fromFactory(2, 'HNSW32,Flat', MetricType.METRIC_INNER_PRODUCT);
// same as:
// const hnswIndex = new IndexHNSW(2, 32, MetricType.METRIC_INNER_PRODUCT)
const x = [1, 0, 0, 1];
hnswIndex.train(x);
hnswIndex.add(x);
Expand Down
46 changes: 46 additions & 0 deletions lib/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -158,4 +158,50 @@ export class IndexFlatIP extends Index {
* @param {IndexFlatIP} otherIndex The other IndexFlatIP instance to merge from.
*/
mergeFrom(otherIndex: IndexFlatIP): void;
}

/**
 * IndexHNSW Index.
 * The Hierarchical Navigable Small World indexing method is based on a graph built on the indexed vectors.
 */
export class IndexHNSW extends Index {
    /**
     * Create a new HNSW index.
     * @param {number} d The dimensionality of index.
     * @param {number} m The number of neighbors used in the graph (defaults to 32).
     * @param {MetricType} metric Metric type (defaults to L2).
     */
    constructor(d: number, m?: number, metric?: MetricType);
    /**
     * Read index from a file.
     * @param {string} fname File path to read.
     * @return {IndexHNSW} The index read.
     */
    static read(fname: string): IndexHNSW;
    /**
     * Read index from buffer.
     * @param {Buffer} src Buffer to create index from.
     * @return {IndexHNSW} The index read.
     */
    static fromBuffer(src: Buffer): IndexHNSW;
    /**
     * Merge the current index with another IndexHNSW instance.
     * @param {IndexHNSW} otherIndex The other IndexHNSW instance to merge from.
     */
    mergeFrom(otherIndex: IndexHNSW): void;
    /**
     * The depth of exploration at add time.
     */
    get efConstruction(): number;
    /**
     * The depth of exploration at add time.
     * @param {number} value The value to set.
     */
    set efConstruction(value: number);
    /**
     * The depth of exploration of the search.
     */
    get efSearch(): number;
    /**
     * The depth of exploration of the search.
     * @param {number} value The value to set.
     */
    set efSearch(value: number);
}
17 changes: 17 additions & 0 deletions lib/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,21 @@ var MetricType;
MetricType[MetricType["METRIC_Jaccard"] = 23] = "METRIC_Jaccard";
})(MetricType || (faiss.MetricType = MetricType = {}));

// Expose the HNSW tuning values as plain properties on IndexHNSW instances.
// Each property simply forwards to the matching get*/set* method of the instance.
Object.defineProperties(faiss.IndexHNSW.prototype, {
    efConstruction: {
        get: function () {
            return this.getEfConstruction();
        },
        set: function (depth) {
            this.setEfConstruction(depth);
        }
    },
    efSearch: {
        get: function () {
            return this.getEfSearch();
        },
        set: function (depth) {
            this.setEfSearch(depth);
        }
    }
});

module.exports = faiss;
130 changes: 125 additions & 5 deletions src/faiss.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,46 @@
#include <faiss/index_factory.h>
#include <faiss/MetricType.h>
#include <faiss/impl/IDSelector.h>
#include <faiss/IndexHNSW.h>

using namespace Napi;
using idx_t = faiss::idx_t;

template <class T, typename Y>
// Tag naming each wrapper class built on IndexBase. It is supplied as the
// third template argument of IndexBase so the constructor can select
// wrapper-specific argument handling at compile time (see the
// `if constexpr (IT == IndexType::IndexHNSW)` branch).
enum class IndexType
{
Index,
IndexFlatL2,
IndexFlatIP,
IndexHNSW,
};

template <class T, typename Y, IndexType IT>
class IndexBase : public Napi::ObjectWrap<T>
{
public:
IndexBase(const Napi::CallbackInfo &info) : Napi::ObjectWrap<T>(info)
{
Napi::Env env = info.Env();

if (info.Length() > 0 && info[0].IsNumber())
if constexpr (IT == IndexType::IndexHNSW)
{ // HNSW constructor
if (info.Length() > 0 && info[0].IsNumber())
{
auto n = info[0].As<Napi::Number>().Uint32Value();
auto m = 32; // faiss default
auto metric = faiss::MetricType::METRIC_L2; // faiss default
if (info.Length() > 1 && info[1].IsNumber())
{
m = info[1].As<Napi::Number>().Uint32Value();
}
if (info.Length() > 2 && info[2].IsNumber())
{
metric = static_cast<faiss::MetricType>(info[2].As<Napi::Number>().Uint32Value());
}
index_ = std::unique_ptr<faiss::IndexHNSW>(new faiss::IndexHNSW(n, m, metric));
}
}
else if (info.Length() > 0 && info[0].IsNumber())
{
auto n = info[0].As<Napi::Number>().Uint32Value();
index_ = std::unique_ptr<Y>(new Y(n));
Expand Down Expand Up @@ -446,7 +473,7 @@ class IndexBase : public Napi::ObjectWrap<T>
};

// faiss::Index is abstract so IndexFlatL2 is used as fallback
class Index : public IndexBase<Index, faiss::IndexFlatL2>
class Index : public IndexBase<Index, faiss::IndexFlatL2, IndexType::Index>
{
public:
using IndexBase::IndexBase;
Expand Down Expand Up @@ -481,7 +508,7 @@ class Index : public IndexBase<Index, faiss::IndexFlatL2>
}
};

class IndexFlatL2 : public IndexBase<IndexFlatL2, faiss::IndexFlatL2>
class IndexFlatL2 : public IndexBase<IndexFlatL2, faiss::IndexFlatL2, IndexType::IndexFlatL2>
{
public:
using IndexBase::IndexBase;
Expand Down Expand Up @@ -515,7 +542,7 @@ class IndexFlatL2 : public IndexBase<IndexFlatL2, faiss::IndexFlatL2>
}
};

class IndexFlatIP : public IndexBase<IndexFlatIP, faiss::IndexFlatIP>
class IndexFlatIP : public IndexBase<IndexFlatIP, faiss::IndexFlatIP, IndexType::IndexFlatIP>
{
public:
using IndexBase::IndexBase;
Expand Down Expand Up @@ -549,11 +576,104 @@ class IndexFlatIP : public IndexBase<IndexFlatIP, faiss::IndexFlatIP>
}
};

/**
 * JavaScript binding for faiss::IndexHNSW, exported as `IndexHNSW`.
 *
 * The constructor and the generic index methods registered in Init (ntotal,
 * add, search, read, ...) are inherited from IndexBase. This class only adds
 * accessors for the two HNSW tuning fields, hnsw.efConstruction and
 * hnsw.efSearch.
 */
class IndexHNSW : public IndexBase<IndexHNSW, faiss::IndexHNSW, IndexType::IndexHNSW>
{
public:
  using IndexBase::IndexBase;

  static constexpr const char *CLASS_NAME = "IndexHNSW";

  /**
   * Returns hnsw.efConstruction as a JS Number.
   * Throws a JS Error (and returns undefined) if the wrapped index is not an HNSW index.
   */
  Napi::Value getEfConstruction(const Napi::CallbackInfo &info)
  {
    Napi::Env env = info.Env();

    faiss::IndexHNSW *index = hnswOrThrow(env);
    if (index == nullptr)
    {
      return env.Undefined();
    }
    return Napi::Number::New(env, index->hnsw.efConstruction);
  }

  /**
   * Sets hnsw.efConstruction from the single Number argument.
   * Throws a JS Error on wrong argument count, a TypeError on a non-Number
   * argument, and a JS Error if the wrapped index is not an HNSW index.
   */
  Napi::Value setEfConstruction(const Napi::CallbackInfo &info)
  {
    Napi::Env env = info.Env();

    if (!expectSingleNumber(info))
    {
      return env.Undefined();
    }
    faiss::IndexHNSW *index = hnswOrThrow(env);
    if (index == nullptr)
    {
      return env.Undefined();
    }

    index->hnsw.efConstruction = info[0].As<Napi::Number>().Int32Value();
    return env.Undefined();
  }

  /**
   * Returns hnsw.efSearch as a JS Number.
   * Throws a JS Error (and returns undefined) if the wrapped index is not an HNSW index.
   */
  Napi::Value getEfSearch(const Napi::CallbackInfo &info)
  {
    Napi::Env env = info.Env();

    faiss::IndexHNSW *index = hnswOrThrow(env);
    if (index == nullptr)
    {
      return env.Undefined();
    }
    return Napi::Number::New(env, index->hnsw.efSearch);
  }

  /**
   * Sets hnsw.efSearch from the single Number argument.
   * Throws a JS Error on wrong argument count, a TypeError on a non-Number
   * argument, and a JS Error if the wrapped index is not an HNSW index.
   */
  Napi::Value setEfSearch(const Napi::CallbackInfo &info)
  {
    Napi::Env env = info.Env();

    if (!expectSingleNumber(info))
    {
      return env.Undefined();
    }
    faiss::IndexHNSW *index = hnswOrThrow(env);
    if (index == nullptr)
    {
      return env.Undefined();
    }

    index->hnsw.efSearch = info[0].As<Napi::Number>().Int32Value();
    return env.Undefined();
  }

  /** Defines the `IndexHNSW` JS class, stores its constructor reference and attaches it to exports. */
  static Napi::Object Init(Napi::Env env, Napi::Object exports)
  {
    // clang-format off
    auto func = DefineClass(env, CLASS_NAME, {
      InstanceMethod("getEfConstruction", &IndexHNSW::getEfConstruction),
      InstanceMethod("setEfConstruction", &IndexHNSW::setEfConstruction),
      InstanceMethod("getEfSearch", &IndexHNSW::getEfSearch),
      InstanceMethod("setEfSearch", &IndexHNSW::setEfSearch),
      InstanceMethod("ntotal", &IndexHNSW::ntotal),
      InstanceMethod("getDimension", &IndexHNSW::getDimension),
      InstanceMethod("isTrained", &IndexHNSW::isTrained),
      InstanceMethod("add", &IndexHNSW::add),
      InstanceMethod("train", &IndexHNSW::train),
      InstanceMethod("search", &IndexHNSW::search),
      InstanceMethod("write", &IndexHNSW::write),
      InstanceMethod("mergeFrom", &IndexHNSW::mergeFrom),
      InstanceMethod("removeIds", &IndexHNSW::removeIds),
      InstanceMethod("toBuffer", &IndexHNSW::toBuffer),
      StaticMethod("read", &IndexHNSW::read),
      StaticMethod("fromBuffer", &IndexHNSW::fromBuffer),
    });
    // clang-format on

    constructor = new Napi::FunctionReference();
    *constructor = Napi::Persistent(func);

    exports.Set(CLASS_NAME, func);
    return exports;
  }

private:
  // Returns the wrapped index as faiss::IndexHNSW, or raises a JS Error and
  // returns nullptr when the cast fails (index_ is empty or holds another
  // index type). NOTE(review): presumably reachable when read()/fromBuffer()
  // loads a non-HNSW index or the object is constructed without a dimension;
  // confirm against IndexBase.
  faiss::IndexHNSW *hnswOrThrow(Napi::Env env)
  {
    auto index = dynamic_cast<faiss::IndexHNSW *>(index_.get());
    if (index == nullptr)
    {
      Napi::Error::New(env, "The underlying index is not an initialized IndexHNSW.")
          .ThrowAsJavaScriptException();
    }
    return index;
  }

  // Validates that the call received exactly one Number argument; raises the
  // matching JS exception and returns false otherwise.
  static bool expectSingleNumber(const Napi::CallbackInfo &info)
  {
    Napi::Env env = info.Env();

    if (info.Length() != 1)
    {
      Napi::Error::New(env, "Expected 1 argument, but got " + std::to_string(info.Length()) + ".")
          .ThrowAsJavaScriptException();
      return false;
    }
    if (!info[0].IsNumber())
    {
      Napi::TypeError::New(env, "Invalid the first argument type, must be a Number.").ThrowAsJavaScriptException();
      return false;
    }
    return true;
  }
};

// Addon-level initializer: runs each wrapper class's static Init against the
// same exports object and returns that object to the caller.
Napi::Object Init(Napi::Env env, Napi::Object exports)
{
Index::Init(env, exports);
IndexFlatL2::Init(env, exports);
IndexFlatIP::Init(env, exports);
IndexHNSW::Init(env, exports);

return exports;
}
Expand Down
20 changes: 20 additions & 0 deletions test/IndexHNSW.test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
const { IndexHNSW, MetricType } = require('../lib');

describe('IndexHNSW', () => {
    describe('#constructor', () => {
        it('1 arg will result in index with default neighbors & metric', () => {
            const index = new IndexHNSW(2);
            expect(index.getDimension()).toBe(2);
        });

        it('2 args will result in index with default metric', () => {
            const index = new IndexHNSW(2, 20);
            expect(index.getDimension()).toBe(2);
        });

        it('3 args will result in index', () => {
            const index = new IndexHNSW(2, 20, MetricType.METRIC_INNER_PRODUCT);
            expect(index.getDimension()).toBe(2);
        });
    });

    // The accessors below are thin property wrappers over the native
    // get*/set* methods; verify the value written is the value read back.
    describe('#efConstruction', () => {
        it('returns the value previously assigned', () => {
            const index = new IndexHNSW(2);
            index.efConstruction = 64;
            expect(index.efConstruction).toBe(64);
        });

        it('rejects a non-number value', () => {
            const index = new IndexHNSW(2);
            expect(() => {
                index.efConstruction = 'many';
            }).toThrow('Invalid the first argument type, must be a Number.');
        });
    });

    describe('#efSearch', () => {
        it('returns the value previously assigned', () => {
            const index = new IndexHNSW(2);
            index.efSearch = 48;
            expect(index.efSearch).toBe(48);
        });

        it('rejects a non-number value', () => {
            const index = new IndexHNSW(2);
            expect(() => {
                index.efSearch = 'many';
            }).toThrow('Invalid the first argument type, must be a Number.');
        });
    });
});

0 comments on commit f1c5774

Please sign in to comment.