Skip to content

Commit

Permalink
Add a couple of convenience methods for converting embeddings (#13329)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #13329

The patch adds two convenience methods `ConvertFloatsToSlice` and `ConvertSliceToFloats` that can be used to convert embeddings from a contiguous range of floats to a RocksDB `Slice` or vice versa. The methods are added to the public API so they can be utilized by applications as well.

Reviewed By: jowlyzhang

Differential Revision: D68581494

fbshipit-source-id: 2207fa3e668a6546b7de6d8ab78be2ba9f2ffd8c
  • Loading branch information
ltamasi authored and facebook-github-bot committed Jan 25, 2025
1 parent a8bd6a3 commit ac6c671
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 25 deletions.
21 changes: 21 additions & 0 deletions include/rocksdb/utilities/secondary_index_faiss.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <string>

#include "rocksdb/rocksdb_namespace.h"
#include "rocksdb/slice.h"
#include "rocksdb/utilities/secondary_index.h"

namespace faiss {
Expand All @@ -29,4 +30,24 @@ namespace ROCKSDB_NAMESPACE {
std::unique_ptr<SecondaryIndex> NewFaissIVFIndex(
std::unique_ptr<faiss::IndexIVF>&& index, std::string primary_column_name);

// Helper methods to convert embeddings from a span of floats to Slice or vice
// versa

// Convert the given span of floats of size dim to a Slice.
// PRE: embedding points to a contiguous span of floats of size dim
inline Slice ConvertFloatsToSlice(const float* embedding, size_t dim) {
return Slice(reinterpret_cast<const char*>(embedding), dim * sizeof(float));
}

// Convert the given Slice to a span of floats of size dim.
// PRE: embedding.size() == dim * sizeof(float)
// Returns nullptr if the precondition is violated.
inline const float* ConvertSliceToFloats(const Slice& embedding, size_t dim) {
if (embedding.size() != dim * sizeof(float)) {
return nullptr;
}

return reinterpret_cast<const float*>(embedding.data());
}

} // namespace ROCKSDB_NAMESPACE
23 changes: 14 additions & 9 deletions utilities/secondary_index/faiss_ivf_index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ class FaissIVFIndex::KNNIterator : public Iterator {
pos_ = 0;
keys_.clear();

if (target.size() != index_->d * sizeof(float)) {
const float* const embedding = ConvertSliceToFloats(target, index_->d);
if (!embedding) {
status_ = Status::InvalidArgument(
"Incorrectly sized vector passed to FaissIVFIndex");
return;
Expand All @@ -83,8 +84,8 @@ class FaissIVFIndex::KNNIterator : public Iterator {
constexpr faiss::idx_t n = 1;

try {
index_->search(n, reinterpret_cast<const float*>(target.data()), k_,
distances_.data(), labels_.data(), &params);
index_->search(n, embedding, k_, distances_.data(), labels_.data(),
&params);
} catch (const std::exception& e) {
status_ = Status::InvalidArgument(e.what());
}
Expand Down Expand Up @@ -364,7 +365,9 @@ Status FaissIVFIndex::UpdatePrimaryColumnValue(
const {
assert(updated_column_value);

if (primary_column_value.size() != index_->d * sizeof(float)) {
const float* const embedding =
ConvertSliceToFloats(primary_column_value, index_->d);
if (!embedding) {
return Status::InvalidArgument(
"Incorrectly sized vector passed to FaissIVFIndex");
}
Expand All @@ -373,8 +376,7 @@ Status FaissIVFIndex::UpdatePrimaryColumnValue(
faiss::idx_t label = -1;

try {
index_->quantizer->assign(
n, reinterpret_cast<const float*>(primary_column_value.data()), &label);
index_->quantizer->assign(n, embedding, &label);
} catch (const std::exception& e) {
return Status::InvalidArgument(e.what());
}
Expand Down Expand Up @@ -420,13 +422,16 @@ Status FaissIVFIndex::GetSecondaryValue(
assert(label < index_->nlist);

constexpr faiss::idx_t n = 1;

const float* const embedding =
ConvertSliceToFloats(original_column_value, index_->d);
assert(embedding);

constexpr faiss::idx_t* xids = nullptr;
std::string code_str;

try {
index_->add_core(
n, reinterpret_cast<const float*>(original_column_value.data()), xids,
&label, &code_str);
index_->add_core(n, embedding, xids, &label, &code_str);
} catch (const std::exception& e) {
return Status::InvalidArgument(e.what());
}
Expand Down
24 changes: 8 additions & 16 deletions utilities/secondary_index/faiss_ivf_index_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,7 @@ TEST(FaissIVFIndexTest, Basic) {
cfh1, primary_key,
WideColumns{
{primary_column_name,
Slice(reinterpret_cast<const char*>(embeddings.data() + i * dim),
dim * sizeof(float))}}));
ConvertFloatsToSlice(embeddings.data() + i * dim, dim)}}));
}

ASSERT_OK(txn->Commit());
Expand All @@ -102,10 +101,8 @@ TEST(FaissIVFIndexTest, Basic) {

// Since we use IndexIVFFlat, there is no fine quantization, so the code
// is actually just the original embedding
ASSERT_EQ(
it->value(),
Slice(reinterpret_cast<const char*>(embeddings.data() + id * dim),
dim * sizeof(float)));
ASSERT_EQ(it->value(),
ConvertFloatsToSlice(embeddings.data() + id * dim, dim));

++num_found;
}
Expand Down Expand Up @@ -159,9 +156,7 @@ TEST(FaissIVFIndexTest, Basic) {
// Search for a vector from the original set; we expect to find the vector
// itself as the closest match, since we're performing an exhaustive search
{
it->Seek(
Slice(reinterpret_cast<const char*>(embeddings.data() + id * dim),
dim * sizeof(float)));
it->Seek(ConvertFloatsToSlice(embeddings.data() + id * dim, dim));
ASSERT_TRUE(it->Valid());
ASSERT_OK(it->status());
ASSERT_EQ(get_id(), id);
Expand Down Expand Up @@ -225,8 +220,7 @@ TEST(FaissIVFIndexTest, Basic) {
ASSERT_FALSE(it->Valid());
ASSERT_TRUE(it->status().IsNotSupported());

it->SeekForPrev(Slice(reinterpret_cast<const char*>(embeddings.data()),
dim * sizeof(float)));
it->SeekForPrev(ConvertFloatsToSlice(embeddings.data(), dim));
ASSERT_FALSE(it->Valid());
ASSERT_TRUE(it->status().IsNotSupported());

Expand Down Expand Up @@ -354,8 +348,7 @@ TEST(FaissIVFIndexTest, Compare) {

const std::string primary_key = std::to_string(i);
ASSERT_OK(db->Put(WriteOptions(), cfh1, primary_key,
Slice(reinterpret_cast<const char*>(embedding),
dim * sizeof(float))));
ConvertFloatsToSlice(embedding, dim)));
}
}

Expand Down Expand Up @@ -413,9 +406,8 @@ TEST(FaissIVFIndexTest, Compare) {
}

size_t num_found = 0;
for (it->Seek(Slice(reinterpret_cast<const char*>(embedding),
dim * sizeof(float)));
it->Valid(); it->Next()) {
for (it->Seek(ConvertFloatsToSlice(embedding, dim)); it->Valid();
it->Next()) {
const faiss::idx_t id = get_id();
ASSERT_GE(id, 0);
ASSERT_LT(id, num_db);
Expand Down

0 comments on commit ac6c671

Please sign in to comment.