Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IndexHNSW support #47

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ $ npm install faiss-node
## Usage

```javascript
const { IndexFlatL2, Index, IndexFlatIP, MetricType } = require('faiss-node');
const { IndexFlatL2, Index, IndexFlatIP, IndexHNSW, MetricType } = require('faiss-node');

const dimension = 2;
const index = new IndexFlatL2(dimension);
Expand Down Expand Up @@ -80,7 +80,9 @@ const deserializedIndex = Index.fromBuffer(index_buf);
console.log(deserializedIndex.ntotal()); // 3

// Factory index
const hnswIndex = Index.fromFactory(2, 'HNSW,Flat', MetricType.METRIC_INNER_PRODUCT);
const hnswIndex = Index.fromFactory(2, 'HNSW32,Flat', MetricType.METRIC_INNER_PRODUCT);
// same as:
// const hnswIndex = new IndexHNSW(2, 32, MetricType.METRIC_INNER_PRODUCT)
const x = [1, 0, 0, 1];
hnswIndex.train(x);
hnswIndex.add(x);
Expand Down
46 changes: 46 additions & 0 deletions lib/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -158,4 +158,50 @@ export class IndexFlatIP extends Index {
* @param {IndexFlatIP} otherIndex The other IndexFlatIP instance to merge from.
*/
mergeFrom(otherIndex: IndexFlatIP): void;
}

/**
* IndexHNSW Index.
* The Hierarchical Navigable Small World indexing method is based on a graph built on the indexed vectors.
* @param {number} d The dimensionality of index.
* @param {number} m The number of neighbors used in the graph (defaults to 32).
* @param {number} metric Metric type (defaults to L2).
*/
export class IndexHNSW extends Index {
IndexHNSW(d?: number, m?: number, metric?: MetricType);
/**
* Read index from a file.
* @param {string} fname File path to read.
* @return {IndexHNSW} The index read.
*/
static read(fname: string): IndexHNSW;
/**
* Read index from buffer.
* @param {Buffer} src Buffer to create index from.
* @return {IndexHNSW} The index read.
*/
static fromBuffer(src: Buffer): IndexHNSW;
/**
* Merge the current index with another IndexHNSW instance.
* @param {IndexHNSW} otherIndex The other IndexHNSW instance to merge from.
*/
mergeFrom(otherIndex: IndexHNSW): void;
/**
* The depth of exploration at add time.
*/
get efConstruction(): number;
/**
* The depth of exploration at add time.
* @param {number} value The value to set.
*/
set efConstruction(value: number);
/**
* The depth of exploration of the search.
*/
get efSearch(): number;
/**
* The depth of exploration of the search.
* @param {number} value The value to set.
*/
set efSearch(value: number);
}
21 changes: 21 additions & 0 deletions lib/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,25 @@ var MetricType;
MetricType[MetricType["METRIC_Jaccard"] = 23] = "METRIC_Jaccard";
})(MetricType || (faiss.MetricType = MetricType = {}));

if (!('efConstruction' in faiss.IndexHNSW.prototype)) { // prevents redefinition in jest
Object.defineProperty(faiss.IndexHNSW.prototype, 'efConstruction', {
get: function () {
return this.getEfConstruction();
},
set: function (value) {
this.setEfConstruction(value);
}
});
}
if (!('efSearch' in faiss.IndexHNSW.prototype)) { // prevents redefinition in jest
Object.defineProperty(faiss.IndexHNSW.prototype, 'efSearch', {
get: function () {
return this.getEfSearch();
},
set: function (value) {
this.setEfSearch(value);
}
});
}

module.exports = faiss;
130 changes: 125 additions & 5 deletions src/faiss.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,46 @@
#include <faiss/index_factory.h>
#include <faiss/MetricType.h>
#include <faiss/impl/IDSelector.h>
#include <faiss/IndexHNSW.h>

using namespace Napi;
using idx_t = faiss::idx_t;

template <class T, typename Y>
enum class IndexType
{
Index,
IndexFlatL2,
IndexFlatIP,
IndexHNSW,
};

template <class T, typename Y, IndexType IT>
class IndexBase : public Napi::ObjectWrap<T>
{
public:
IndexBase(const Napi::CallbackInfo &info) : Napi::ObjectWrap<T>(info)
{
Napi::Env env = info.Env();

if (info.Length() > 0 && info[0].IsNumber())
if constexpr (IT == IndexType::IndexHNSW)
{ // HNSW constructor
if (info.Length() > 0 && info[0].IsNumber())
{
auto n = info[0].As<Napi::Number>().Uint32Value();
auto m = 32; // faiss default
auto metric = faiss::MetricType::METRIC_L2; // faiss default
if (info.Length() > 1 && info[1].IsNumber())
{
m = info[1].As<Napi::Number>().Uint32Value();
}
if (info.Length() > 2 && info[2].IsNumber())
{
metric = static_cast<faiss::MetricType>(info[2].As<Napi::Number>().Uint32Value());
}
index_ = std::unique_ptr<faiss::IndexHNSW>(new faiss::IndexHNSW(n, m, metric));
}
}
else if (info.Length() > 0 && info[0].IsNumber())
{
auto n = info[0].As<Napi::Number>().Uint32Value();
index_ = std::unique_ptr<Y>(new Y(n));
Expand Down Expand Up @@ -446,7 +473,7 @@ class IndexBase : public Napi::ObjectWrap<T>
};

// faiss::Index is abstract so IndexFlatL2 is used as fallback
class Index : public IndexBase<Index, faiss::IndexFlatL2>
class Index : public IndexBase<Index, faiss::IndexFlatL2, IndexType::Index>
{
public:
using IndexBase::IndexBase;
Expand Down Expand Up @@ -481,7 +508,7 @@ class Index : public IndexBase<Index, faiss::IndexFlatL2>
}
};

class IndexFlatL2 : public IndexBase<IndexFlatL2, faiss::IndexFlatL2>
class IndexFlatL2 : public IndexBase<IndexFlatL2, faiss::IndexFlatL2, IndexType::IndexFlatL2>
{
public:
using IndexBase::IndexBase;
Expand Down Expand Up @@ -515,7 +542,7 @@ class IndexFlatL2 : public IndexBase<IndexFlatL2, faiss::IndexFlatL2>
}
};

class IndexFlatIP : public IndexBase<IndexFlatIP, faiss::IndexFlatIP>
class IndexFlatIP : public IndexBase<IndexFlatIP, faiss::IndexFlatIP, IndexType::IndexFlatIP>
{
public:
using IndexBase::IndexBase;
Expand Down Expand Up @@ -549,11 +576,104 @@ class IndexFlatIP : public IndexBase<IndexFlatIP, faiss::IndexFlatIP>
}
};

class IndexHNSW : public IndexBase<IndexHNSW, faiss::IndexHNSW, IndexType::IndexHNSW>
{
public:
using IndexBase::IndexBase;

static constexpr const char *CLASS_NAME = "IndexHNSW";

Napi::Value getEfConstruction(const Napi::CallbackInfo &info)
{
auto index = dynamic_cast<faiss::IndexHNSW *>(index_.get());
return Napi::Number::New(info.Env(), index->hnsw.efConstruction);
}

Napi::Value setEfConstruction(const Napi::CallbackInfo &info)
{
Napi::Env env = info.Env();

if (info.Length() != 1)
{
Napi::Error::New(env, "Expected 1 argument, but got " + std::to_string(info.Length()) + ".")
.ThrowAsJavaScriptException();
return env.Undefined();
}
if (!info[0].IsNumber())
{
Napi::TypeError::New(env, "Invalid the first argument type, must be a Number.").ThrowAsJavaScriptException();
return env.Undefined();
}

auto index = dynamic_cast<faiss::IndexHNSW *>(index_.get());
index->hnsw.efConstruction = info[0].As<Napi::Number>().Int32Value();
return env.Undefined();
}

Napi::Value getEfSearch(const Napi::CallbackInfo &info)
{
auto index = dynamic_cast<faiss::IndexHNSW *>(index_.get());
return Napi::Number::New(info.Env(), index->hnsw.efSearch);
}

Napi::Value setEfSearch(const Napi::CallbackInfo &info)
{
Napi::Env env = info.Env();

if (info.Length() != 1)
{
Napi::Error::New(env, "Expected 1 argument, but got " + std::to_string(info.Length()) + ".")
.ThrowAsJavaScriptException();
return env.Undefined();
}
if (!info[0].IsNumber())
{
Napi::TypeError::New(env, "Invalid the first argument type, must be a Number.").ThrowAsJavaScriptException();
return env.Undefined();
}

auto index = dynamic_cast<faiss::IndexHNSW *>(index_.get());
index->hnsw.efSearch = info[0].As<Napi::Number>().Int32Value();
return env.Undefined();
}

static Napi::Object Init(Napi::Env env, Napi::Object exports)
{
// clang-format off
auto func = DefineClass(env, CLASS_NAME, {
InstanceMethod("getEfConstruction", &IndexHNSW::getEfConstruction),
InstanceMethod("setEfConstruction", &IndexHNSW::setEfConstruction),
InstanceMethod("getEfSearch", &IndexHNSW::getEfSearch),
InstanceMethod("setEfSearch", &IndexHNSW::setEfSearch),
InstanceMethod("ntotal", &IndexHNSW::ntotal),
InstanceMethod("getDimension", &IndexHNSW::getDimension),
InstanceMethod("isTrained", &IndexHNSW::isTrained),
InstanceMethod("add", &IndexHNSW::add),
InstanceMethod("train", &IndexHNSW::train),
InstanceMethod("search", &IndexHNSW::search),
InstanceMethod("write", &IndexHNSW::write),
InstanceMethod("mergeFrom", &IndexHNSW::mergeFrom),
InstanceMethod("removeIds", &IndexHNSW::removeIds),
InstanceMethod("toBuffer", &IndexHNSW::toBuffer),
StaticMethod("read", &IndexHNSW::read),
StaticMethod("fromBuffer", &IndexHNSW::fromBuffer),
});
// clang-format on

constructor = new Napi::FunctionReference();
*constructor = Napi::Persistent(func);

exports.Set(CLASS_NAME, func);
return exports;
}
};

Napi::Object Init(Napi::Env env, Napi::Object exports)
{
Index::Init(env, exports);
IndexFlatL2::Init(env, exports);
IndexFlatIP::Init(env, exports);
IndexHNSW::Init(env, exports);

return exports;
}
Expand Down
20 changes: 20 additions & 0 deletions test/IndexHNSW.test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
const { IndexHNSW, MetricType } = require('../lib');

describe('IndexHNSW', () => {
describe('#constructor', () => {
it('1 arg will result in index with default neighbors & metric', () => {
const index = new IndexHNSW(2);
expect(index.getDimension()).toBe(2);
Copy link
Contributor Author

@asilvas asilvas Oct 1, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally we'd be able to check the value of neighbors & metrics, but indexes don't expose those props so I kept this way for consistency with faiss. This check is only useful to show it doesn't throw.

});

it('2 args will result in index with default metric', () => {
const index = new IndexHNSW(2, 20);
expect(index.getDimension()).toBe(2);
});

it('3 args will result in index', () => {
const index = new IndexHNSW(2, 20, MetricType.METRIC_INNER_PRODUCT);
expect(index.getDimension()).toBe(2);
});
});
});