Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AddWithIds & toIDMap support #46

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ const hnswIndex = Index.fromFactory(2, 'HNSW,Flat', MetricType.METRIC_INNER_PROD
const x = [1, 0, 0, 1];
hnswIndex.train(x);
hnswIndex.add(x);

// IDMap'd index
const idIndex = new IndexFlatL2(2).toIDMap2();
idIndex.addWithIds([1, 0, 0, 1], [100n, 200n]);
```

## License
Expand Down
4 changes: 4 additions & 0 deletions lib/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ export class Index {
* Write index to buffer.
*/
toBuffer(): Buffer;
/**
 * Create an IDMap'd index that wraps this index, so that vectors can be
 * added with caller-supplied IDs via `addWithIds`.
 * The name matches the method registered by the native binding.
 * @return {Index} The new IDMap'd index.
 */
toIDMap2(): Index;
/**
* Read index from a file.
* @param {string} fname File path to read.
Expand Down
114 changes: 114 additions & 0 deletions src/faiss.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <faiss/index_factory.h>
#include <faiss/MetricType.h>
#include <faiss/impl/IDSelector.h>
#include <faiss/IndexIDMap.h>

using namespace Napi;
using idx_t = faiss::idx_t;
Expand Down Expand Up @@ -190,6 +191,86 @@ class IndexBase : public Napi::ObjectWrap<T>
return env.Undefined();
}

/**
 * Add vectors to the index under caller-supplied labels.
 *
 * Expected JavaScript arguments:
 *  - info[0]: flat Array of Numbers; its length must be a multiple of the
 *    index dimension (index_->d).
 *  - info[1]: Array holding one Number or BigInt label per vector.
 *
 * Returns undefined. On invalid input, or when faiss rejects the call, a
 * JavaScript exception is left pending and undefined is returned.
 */
Napi::Value addWithIds(const Napi::CallbackInfo &info)
{
  Napi::Env env = info.Env();

  if (info.Length() != 2)
  {
    Napi::Error::New(env, "Expected 2 arguments, but got " + std::to_string(info.Length()) + ".")
        .ThrowAsJavaScriptException();
    return env.Undefined();
  }
  if (!info[0].IsArray())
  {
    Napi::TypeError::New(env, "Invalid the first argument type, must be an Array.").ThrowAsJavaScriptException();
    return env.Undefined();
  }
  if (!info[1].IsArray())
  {
    Napi::TypeError::New(env, "Invalid the second argument type, must be an Array.").ThrowAsJavaScriptException();
    return env.Undefined();
  }

  Napi::Array arr = info[0].As<Napi::Array>();
  const size_t length = arr.Length();

  // Do the shape arithmetic in size_t: std::div(size_t, int) would pick the
  // int overload (truncating the length) and is undefined for a zero divisor.
  const size_t dim = index_->d > 0 ? static_cast<size_t>(index_->d) : 0;
  if (dim == 0 || length % dim != 0)
  {
    Napi::Error::New(env, "Invalid the given array length.")
        .ThrowAsJavaScriptException();
    return env.Undefined();
  }
  const size_t count = length / dim;

  Napi::Array labels = info[1].As<Napi::Array>();
  if (labels.Length() != count)
  {
    Napi::Error::New(env, "Labels array length must match the number of vectors.")
        .ThrowAsJavaScriptException();
    return env.Undefined();
  }

  // Owning buffers: freed on every return path, including the validation
  // failures inside the loops below.
  auto xb = std::make_unique<float[]>(length);
  for (size_t i = 0; i < length; i++)
  {
    Napi::Value val = arr[i];
    if (!val.IsNumber())
    {
      Napi::Error::New(env, "Expected a Number as array item. (at: " + std::to_string(i) + ")")
          .ThrowAsJavaScriptException();
      return env.Undefined();
    }
    xb[i] = val.As<Napi::Number>().FloatValue();
  }

  auto ids = std::make_unique<idx_t[]>(count);
  for (size_t i = 0; i < count; i++)
  {
    Napi::Value val = labels[i];
    if (val.IsNumber())
    {
      ids[i] = val.As<Napi::Number>().Int64Value();
    }
    else if (val.IsBigInt())
    {
      bool lossless = false;
      ids[i] = val.As<Napi::BigInt>().Int64Value(&lossless);
      if (!lossless)
      {
        // Refuse to store a silently truncated ID.
        Napi::RangeError::New(env, "BigInt label does not fit in a signed 64-bit integer. (at: " + std::to_string(i) + ")")
            .ThrowAsJavaScriptException();
        return env.Undefined();
      }
    }
    else
    {
      Napi::Error::New(env, "Expected a Number or BigInt as array item. (at: " + std::to_string(i) + ")")
          .ThrowAsJavaScriptException();
      return env.Undefined();
    }
  }

  try
  {
    index_->add_with_ids(static_cast<idx_t>(count), xb.get(), ids.get());
  }
  catch (const faiss::FaissException &ex)
  {
    // Surface faiss failures as a JavaScript error instead of letting a C++
    // exception escape the native callback.
    Napi::Error::New(env, ex.what()).ThrowAsJavaScriptException();
    return env.Undefined();
  }

  return env.Undefined();
}

Napi::Value train(const Napi::CallbackInfo &info)
{
Napi::Env env = info.Env();
Expand Down Expand Up @@ -440,6 +521,33 @@ class IndexBase : public Napi::ObjectWrap<T>
return Napi::Buffer<uint8_t>::Copy(env, writer->data.data(), writer->data.size());
}

/**
 * Create a new wrapper object whose index is a faiss::IndexIDMap2 layered
 * over this object's index.
 *
 * Takes no JavaScript arguments. Returns the new wrapper instance; if faiss
 * rejects the source index, a JavaScript exception is left pending and
 * undefined is returned.
 */
Napi::Value toIDMap2(const Napi::CallbackInfo &info)
{
  Napi::Env env = info.Env();

  if (info.Length() != 0)
  {
    Napi::Error::New(env, "Expected 0 arguments, but got " + std::to_string(info.Length()) + ".")
        .ThrowAsJavaScriptException();
    return env.Undefined();
  }

  Napi::Object instance = T::constructor->New({});
  T *wrapper = Napi::ObjectWrap<T>::Unwrap(instance);

  try
  {
    // wrap the new IDMap'd index around the old and leave it to faiss to throw if index not compatible
    wrapper->index_ = std::make_unique<faiss::IndexIDMap2>(index_.get());
  }
  catch (const faiss::FaissException &ex)
  {
    Napi::Error::New(env, ex.what()).ThrowAsJavaScriptException();
    // Do not hand back a wrapper that never received its index.
    return env.Undefined();
  }

  // The IDMap only borrows the raw pointer held by this object's index_.
  // Pin the source JS object on the new instance (non-enumerable, read-only)
  // so the garbage collector cannot free the underlying index while the
  // IDMap'd wrapper is still reachable.
  instance.DefineProperty(
      Napi::PropertyDescriptor::Value("__idMapSource", info.This(), napi_default));

  return instance;
}

protected:
std::unique_ptr<faiss::Index> index_;
inline static Napi::FunctionReference *constructor;
Expand All @@ -461,12 +569,14 @@ class Index : public IndexBase<Index, faiss::IndexFlatL2>
InstanceMethod("getDimension", &Index::getDimension),
InstanceMethod("isTrained", &Index::isTrained),
InstanceMethod("add", &Index::add),
InstanceMethod("addWithIds", &Index::addWithIds),
InstanceMethod("train", &Index::train),
InstanceMethod("search", &Index::search),
InstanceMethod("write", &Index::write),
InstanceMethod("mergeFrom", &Index::mergeFrom),
InstanceMethod("removeIds", &Index::removeIds),
InstanceMethod("toBuffer", &Index::toBuffer),
InstanceMethod("toIDMap2", &Index::toIDMap2),
StaticMethod("read", &Index::read),
StaticMethod("fromBuffer", &Index::fromBuffer),
StaticMethod("fromFactory", &Index::fromFactory),
Expand Down Expand Up @@ -496,12 +606,14 @@ class IndexFlatL2 : public IndexBase<IndexFlatL2, faiss::IndexFlatL2>
InstanceMethod("getDimension", &IndexFlatL2::getDimension),
InstanceMethod("isTrained", &IndexFlatL2::isTrained),
InstanceMethod("add", &IndexFlatL2::add),
InstanceMethod("addWithIds", &IndexFlatL2::addWithIds),
InstanceMethod("train", &IndexFlatL2::train),
InstanceMethod("search", &IndexFlatL2::search),
InstanceMethod("write", &IndexFlatL2::write),
InstanceMethod("mergeFrom", &IndexFlatL2::mergeFrom),
InstanceMethod("removeIds", &IndexFlatL2::removeIds),
InstanceMethod("toBuffer", &IndexFlatL2::toBuffer),
InstanceMethod("toIDMap2", &IndexFlatL2::toIDMap2),
StaticMethod("read", &IndexFlatL2::read),
StaticMethod("fromBuffer", &IndexFlatL2::fromBuffer),
});
Expand Down Expand Up @@ -530,12 +642,14 @@ class IndexFlatIP : public IndexBase<IndexFlatIP, faiss::IndexFlatIP>
InstanceMethod("getDimension", &IndexFlatIP::getDimension),
InstanceMethod("isTrained", &IndexFlatIP::isTrained),
InstanceMethod("add", &IndexFlatIP::add),
InstanceMethod("addWithIds", &IndexFlatIP::addWithIds),
InstanceMethod("train", &IndexFlatIP::train),
InstanceMethod("search", &IndexFlatIP::search),
InstanceMethod("write", &IndexFlatIP::write),
InstanceMethod("mergeFrom", &IndexFlatIP::mergeFrom),
InstanceMethod("removeIds", &IndexFlatIP::removeIds),
InstanceMethod("toBuffer", &IndexFlatIP::toBuffer),
InstanceMethod("toIDMap2", &IndexFlatIP::toIDMap2),
StaticMethod("read", &IndexFlatIP::read),
StaticMethod("fromBuffer", &IndexFlatIP::fromBuffer),
});
Expand Down
22 changes: 22 additions & 0 deletions test/Index.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,26 @@ describe('Index', () => {
expect(index.ntotal()).toBe(newIndex.ntotal());
});
});

// Verifies that an index converted with toIDMap2() stores and returns the
// labels passed to addWithIds(), for both Number and BigInt labels.
describe('#toIDMap2', () => {
    it('new index preserves ID\'s', () => {
        const index = Index.fromFactory(2, 'Flat').toIDMap2();
        const x = [1, 0, 0, 1];
        const labels = [100, 200];
        index.addWithIds(x, labels);
        const results = index.search([1, 0], 2);
        expect(results.labels).toEqual(labels);
    });

    it('supports BigInt labels', () => {
        const index = Index.fromFactory(2, 'Flat').toIDMap2();
        const x = [1, 0, 0, 1];
        const labels = [100n, 200n];
        index.addWithIds(x, labels);
        const results = index.search([1, 0], 2);
        expect(results.labels).toEqual([100, 200]);
        // TODO: Once search supports BigInt, use this test instead
        // (can update once #43 is merged)
        // expect(results.labels).toEqual(labels);
    });
});
});