Skip to content

GH-3540: Allow user-provided embeddings in VectorStore #3541

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -102,17 +102,16 @@ public static SimpleVectorStoreBuilder builder(EmbeddingModel embeddingModel) {
}

@Override
public void doAdd(List<Document> documents) {
public void doAdd(List<Document> documents, List<float[]> embeddings) {
Objects.requireNonNull(documents, "Documents list cannot be null");
if (documents.isEmpty()) {
throw new IllegalArgumentException("Documents list cannot be empty");
}

for (Document document : documents) {
logger.info("Calling EmbeddingModel for document id = {}", document.getId());
float[] embedding = this.embeddingModel.embed(document);
SimpleVectorStoreContent storeContent = new SimpleVectorStoreContent(document.getId(), document.getText(),
document.getMetadata(), embedding);
document.getMetadata(), embeddings.get(documents.indexOf(document)));
this.store.put(document.getId(), storeContent);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,19 @@ default String getName() {
*/
void add(List<Document> documents);

/**
* Adds list of {@link Document}s with their corresponding embeddings to the vector store.
* @param documents the list of documents to store. Throws an exception if the
* underlying provider checks for duplicate IDs.
* @param embeddings the list of float[] embeddings corresponding to each document.
* @throws IllegalArgumentException if there is:
* <ul>
* <li> A mismatch between documents and embeddings
* <li> Dimensional inconsistency between embeddings
* <li> Embeddings contain {@code NaN}, {@code Infinity}, or null/empty vectors.
*/
void add(List<Document> documents, List<float[]> embeddings);

@Override
default void accept(List<Document> documents) {
add(documents);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

/**
* Abstract base class for {@link VectorStore} implementations that provides observation
Expand Down Expand Up @@ -82,7 +84,29 @@ public void add(List<Document> documents) {
VectorStoreObservationDocumentation.AI_VECTOR_STORE
.observation(this.customObservationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
this.observationRegistry)
.observe(() -> this.doAdd(documents));
.observe(() -> this.doAdd(documents, this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(),
this.batchingStrategy)));
}

/**
* Create a new {@link AbstractObservationVectorStore} instance.
* @param documents the documents to add
* @param embeddings the embeddings corresponding to each document
*/
@Override
public void add(List<Document> documents, List<float[]> embeddings) {

VectorStoreObservationContext observationContext = this
.createObservationContextBuilder(VectorStoreObservationContext.Operation.ADD.value())
.build();

VectorStoreObservationDocumentation.AI_VECTOR_STORE
.observation(this.customObservationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
this.observationRegistry)
.observe(() -> {
this.validateEmbeddings(documents, embeddings);
this.doAdd(documents, embeddings);
});
}

@Override
Expand Down Expand Up @@ -132,8 +156,9 @@ public List<Document> similaritySearch(SearchRequest request) {
/**
* Perform the actual add operation.
* @param documents the documents to add
* @param embeddings the embeddings corresponding to each document
*/
public abstract void doAdd(List<Document> documents);
public abstract void doAdd(List<Document> documents, List<float[]> embeddings);

/**
* Perform the actual delete operation.
Expand Down Expand Up @@ -167,4 +192,55 @@ protected void doDelete(Filter.Expression filterExpression) {
*/
public abstract VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName);

/**
* Validates a list of documents and their corresponding embeddings.
*
* @param documents The list of documents. Must not be null.
* @param embeddings The list of float[] embeddings corresponding to each document.
* @throws IllegalArgumentException if validation fails for:
* <ul>
* <li> A mismatch between documents and embeddings
* <li> Dimensional inconsistency between embeddings
* <li> Embeddings contain {@code NaN}, {@code Infinity}, or null/empty vectors.
*/
protected void validateEmbeddings(List<Document> documents, List<float[]> embeddings) {
Assert.notNull(documents, "Documents list cannot be null.");
Assert.notNull(embeddings, "Embeddings list cannot be null.");

int docSize = documents.size();
int embSize = embeddings.size();

if (docSize != embSize) {
throw new IllegalArgumentException(
String.format("Mismatch between documents (%d) and embeddings (%d).", docSize, embSize));
}
if (embSize == 0) return;

float[] first = embeddings.get(0);
if (first == null || first.length == 0) {
throw new IllegalArgumentException("First embedding is null or empty.");
}

final int expectedDim = first.length;

for (int i = 0; i < embSize; i++) {
float[] emb = embeddings.get(i);

if (emb == null) {
throw new IllegalArgumentException("Embedding at index " + i + " is null.");
}
if (emb.length != expectedDim) {
throw new IllegalArgumentException(String.format(
"Embedding at index %d has dimension %d, expected %d.", i, emb.length, expectedDim));
}

for (float val : emb) {
if (Float.isNaN(val) || Float.isInfinite(val)) {
throw new IllegalArgumentException(String.format(
"Embedding at index %d contains NaN or Infinite value.", i));
}
}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,7 @@
import java.io.File;
import java.io.IOException;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.*;
import java.util.stream.Collectors;

import org.junit.jupiter.api.BeforeEach;
Expand Down Expand Up @@ -57,7 +52,8 @@ void setUp() {
this.mockEmbeddingModel = mock(EmbeddingModel.class);
when(this.mockEmbeddingModel.dimensions()).thenReturn(3);
when(this.mockEmbeddingModel.embed(any(String.class))).thenReturn(new float[] { 0.1f, 0.2f, 0.3f });
when(this.mockEmbeddingModel.embed(any(Document.class))).thenReturn(new float[] { 0.1f, 0.2f, 0.3f });
when(this.mockEmbeddingModel.embed(any(), any(), any()))
.thenReturn(List.of(new float[] { 0.1f, 0.2f, 0.3f }, new float[] { 0.1f, 0.2f, 0.3f }));
this.vectorStore = new SimpleVectorStore(SimpleVectorStore.builder(this.mockEmbeddingModel));
}

Expand Down Expand Up @@ -86,6 +82,66 @@ void shouldAddMultipleDocuments() {
assertThat(results).hasSize(2).extracting(Document::getId).containsExactlyInAnyOrder("1", "2");
}

@Test
void shouldAddMultipleDocsWithProvidedEmbeddings() {
List<Document> docs = Arrays.asList(Document.builder().id("1").text("first").build(),
Document.builder().id("2").text("second").build());
List<float[]> embeddings = List.of(new float[] {0.1f, 0.2f, 0.3f}, new float[] {0.4f, 0.5f, 0.6f});

this.vectorStore.add(docs, embeddings);

List<Document> results = this.vectorStore.similaritySearch("first");
assertThat(results).hasSize(2).extracting(Document::getId).containsExactlyInAnyOrder("1", "2");
}

@Test
void shouldHandleNullEmbeddingsList() {
assertThatThrownBy(() -> this.vectorStore.add(Collections.emptyList(), null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Embeddings list cannot be null.");
}

@Test
void shouldHandleMismatchDocsAndEmbeddingsList() {
List<float[]> embeddings = List.of(new float[] {0.1f, 0.2f, 0.3f});

assertThatThrownBy(() -> this.vectorStore.add(Collections.emptyList(), embeddings))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Mismatch between documents (0) and embeddings (1).");
}

@Test
void shouldHandleInvalidEmbeddings() {
List<Document> docs = List.of(Document.builder().id("1").text("first").build());

assertThatThrownBy(() -> this.vectorStore.add(docs, List.of(new float[] {Float.NaN})))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Embedding at index 0 contains NaN or Infinite value.");

List<float[]> nullEmbeddings = new ArrayList<>();
nullEmbeddings.add(null);

assertThatThrownBy(() -> this.vectorStore.add(docs, nullEmbeddings))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("First embedding is null or empty.");

List<Document> newDocs = Arrays.asList(Document.builder().id("1").text("first").build(),
Document.builder().id("2").text("second").build());
List<float[]> invalidEmbeddings = new ArrayList<>();
invalidEmbeddings.add(new float[] {0.1f, 0.2f, 0.3f});
invalidEmbeddings.add(null);

assertThatThrownBy(() -> this.vectorStore.add(newDocs, invalidEmbeddings))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Embedding at index 1 is null.");

List<float[]> invalidEmbeddingsDimensions = List.of(new float[] {0.1f, 0.2f, 0.3f}, new float[] {0.1f, 0.2f});

assertThatThrownBy(() -> this.vectorStore.add(newDocs, invalidEmbeddingsDimensions))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Embedding at index 1 has dimension 2, expected 3.");
}

@Test
void shouldHandleEmptyDocumentList() {
assertThatThrownBy(() -> this.vectorStore.add(Collections.emptyList()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ void setUp() {
this.mockEmbeddingModel = mock(EmbeddingModel.class);
when(this.mockEmbeddingModel.dimensions()).thenReturn(3);
when(this.mockEmbeddingModel.embed(any(String.class))).thenReturn(new float[] { 0.1f, 0.2f, 0.3f });
when(this.mockEmbeddingModel.embed(any(Document.class))).thenReturn(new float[] { 0.1f, 0.2f, 0.3f });
when(this.mockEmbeddingModel.embed(any(), any(), any()))
.thenReturn(List.of(new float[] { 0.1f, 0.2f, 0.3f }, new float[] { 0.1f, 0.2f, 0.3f }));
this.vectorStore = SimpleVectorStore.builder(this.mockEmbeddingModel).build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@

import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
import org.springframework.ai.observation.conventions.VectorStoreProvider;
import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder;
import org.springframework.ai.vectorstore.SearchRequest;
Expand Down Expand Up @@ -226,11 +225,7 @@ private JsonNode mapCosmosDocument(Document document, float[] queryEmbedding) {
}

@Override
public void doAdd(List<Document> documents) {

// Batch the documents based on the batching strategy
List<float[]> embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(),
this.batchingStrategy);
public void doAdd(List<Document> documents, List<float[]> embeddings) {

// Create a list to hold both the CosmosItemOperation and the corresponding
// document ID
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
import org.springframework.ai.model.EmbeddingUtils;
import org.springframework.ai.observation.conventions.VectorStoreProvider;
import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric;
Expand Down Expand Up @@ -151,16 +150,13 @@ public static Builder builder(SearchIndexClient searchIndexClient, EmbeddingMode
}

@Override
public void doAdd(List<Document> documents) {
public void doAdd(List<Document> documents, List<float[]> embeddings) {

Assert.notNull(documents, "The document list should not be null.");
if (CollectionUtils.isEmpty(documents)) {
return; // nothing to do;
}

List<float[]> embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(),
this.batchingStrategy);

final var searchDocuments = documents.stream().map(document -> {
SearchDocument searchDocument = new SearchDocument();
searchDocument.put(ID_FIELD_NAME, document.getId());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
import org.springframework.ai.model.EmbeddingUtils;
import org.springframework.ai.observation.conventions.VectorStoreProvider;
import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric;
Expand Down Expand Up @@ -267,12 +266,9 @@ private static Float[] toFloatArray(float[] embedding) {
}

@Override
public void doAdd(List<Document> documents) {
public void doAdd(List<Document> documents, List<float[]> embeddings) {
var futures = new CompletableFuture[documents.size()];

List<float[]> embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(),
this.batchingStrategy);

int i = 0;
for (Document d : documents) {
futures[i++] = CompletableFuture.runAsync(() -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
import org.springframework.ai.observation.conventions.VectorStoreProvider;
import org.springframework.ai.util.JacksonUtils;
import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder;
Expand Down Expand Up @@ -147,7 +146,7 @@ public void afterPropertiesSet() throws Exception {
}

@Override
public void doAdd(@NonNull List<Document> documents) {
public void doAdd(@NonNull List<Document> documents, List<float[]> documentEmbeddings) {
Assert.notNull(documents, "Documents must not be null");
if (CollectionUtils.isEmpty(documents)) {
return;
Expand All @@ -158,9 +157,6 @@ public void doAdd(@NonNull List<Document> documents) {
List<String> contents = new ArrayList<>();
List<float[]> embeddings = new ArrayList<>();

List<float[]> documentEmbeddings = this.embeddingModel.embed(documents,
EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);

for (Document document : documents) {
ids.add(document.getId());
metadatas.add(document.getMetadata());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,12 +166,12 @@ public static Builder builder(Session session, EmbeddingModel embeddingModel) {
}

@Override
public void doAdd(final List<Document> documents) {
public void doAdd(final List<Document> documents, List<float[]> embeddings) {
Map<DocumentChunk.Id, DocumentChunk> chunks = new HashMap<>((int) Math.ceil(documents.size() / 0.75f));
for (Document doc : documents) {
var id = toChunkId(doc.getId());
var chunk = new DocumentChunk(doc.getText(), doc.getMetadata(),
toFloat32Vector(this.embeddingModel.embed(doc)));
toFloat32Vector(embeddings.get(documents.indexOf(doc))));
chunks.put(id, chunk);
}
this.documentChunks.putAll(chunks);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@

import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
import org.springframework.ai.observation.conventions.VectorStoreProvider;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore;
Expand Down Expand Up @@ -133,12 +132,10 @@ public void afterPropertiesSet() {
}

@Override
public void doAdd(List<Document> documents) {
public void doAdd(List<Document> documents, List<float[]> embeddings) {
logger.info("Trying Add");
logger.info(this.bucketName);
logger.info(this.scopeName);
List<float[]> embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(),
this.batchingStrategy);
for (Document document : documents) {
CouchbaseDocument cbDoc = new CouchbaseDocument(document.getId(), document.getText(),
document.getMetadata(), embeddings.get(documents.indexOf(document)));
Expand Down
Loading