Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added similarity search by string to vector store #153

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
import ai.knowly.langtorch.store.vectordb.integration.pgvector.schema.PGVectorStoreSpec;
import ai.knowly.langtorch.store.vectordb.integration.pgvector.schema.PGVectorValues;
import ai.knowly.langtorch.store.vectordb.integration.pgvector.schema.distance.DistanceStrategy;
import ai.knowly.langtorch.store.vectordb.integration.schema.SimilaritySearchQuery;
import ai.knowly.langtorch.store.vectordb.integration.schema.StringSimilaritySearchQuery;
import ai.knowly.langtorch.store.vectordb.integration.schema.VectorSimilaritySearchQuery;
import com.google.common.flogger.FluentLogger;
import com.google.common.primitives.Floats;
import com.google.inject.Inject;
Expand Down Expand Up @@ -106,11 +107,21 @@ public boolean addDocuments(List<DomainDocument> documents) {
}

/**
* Performs a similarity search using a vector query and returns a list of pairs containing the
* schema documents and their corresponding similarity scores.
* Performs a similarity search using a vector representation of the passed string and returns a
* list of documents containing their corresponding similarity scores.
*/
@Override
public List<DomainDocument> similaritySearch(SimilaritySearchQuery similaritySearchQuery) {
public List<DomainDocument> similaritySearch(StringSimilaritySearchQuery similaritySearchQuery) {
  // Embed the text query, then hand the resulting vector query to the vector-based overload.
  VectorSimilaritySearchQuery vectorQuery =
      similaritySearchQuery.toVectorSimilaritySearchQuery(
          createVector(similaritySearchQuery.getQuery()));
  return similaritySearch(vectorQuery);
}

/**
* Performs a similarity search using a vector query and returns a list of documents
* containing their corresponding similarity scores.
*/
@Override
public List<DomainDocument> similaritySearch(VectorSimilaritySearchQuery similaritySearchQuery) {
float[] queryVectorValuesAsFloats = getFloatVectorValues(similaritySearchQuery.getQuery());
double[] queryVectorValuesAsDoubles = getDoubleVectorValues(queryVectorValuesAsFloats);
List<DomainDocument> documentsWithScores;
Expand Down Expand Up @@ -280,12 +291,16 @@ private PGVectorQueryParameters buildPGVectorQueryParameters(
}

/**
 * Creates the vector for a document from its page content by delegating to
 * {@code createVector(String)}.
 *
 * @param document the document whose page content is turned into a vector
 * @return the vector produced for the document's page content
 */
private List<Double> createVector(DomainDocument document) {
return createVector(document.getPageContent());
}

private List<Double> createVector(String input) {
EmbeddingOutput embeddingOutput =
embeddingsProcessor.run(
EmbeddingInput.builder()
.setModel(pgVectorStoreSpec.getModel())
.setInput(Collections.singletonList(document.getPageContent()))
.build());
embeddingsProcessor.run(
EmbeddingInput.builder()
.setModel(pgVectorStoreSpec.getModel())
.setInput(Collections.singletonList(input))
.build());
return embeddingOutput.getValue().get(0).getVector();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
package ai.knowly.langtorch.store.vectordb.integration;

import ai.knowly.langtorch.schema.io.DomainDocument;
import ai.knowly.langtorch.store.vectordb.integration.schema.SimilaritySearchQuery;
import ai.knowly.langtorch.store.vectordb.integration.schema.StringSimilaritySearchQuery;
import ai.knowly.langtorch.store.vectordb.integration.schema.VectorSimilaritySearchQuery;
import java.util.List;

/** A shared interface for all Vector Store Databases */
public interface VectorStore {

boolean addDocuments(List<DomainDocument> documents);

List<DomainDocument> similaritySearch(SimilaritySearchQuery similaritySearchQuery);
List<DomainDocument> similaritySearch(VectorSimilaritySearchQuery similaritySearchQuery);

List<DomainDocument> similaritySearch(StringSimilaritySearchQuery similaritySearchQuery);

boolean updateDocuments(List<DomainDocument> documents);

Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.update.UpdateResponse;
import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.upsert.UpsertRequest;
import ai.knowly.langtorch.store.vectordb.integration.pinecone.schema.dto.upsert.UpsertResponse;
import ai.knowly.langtorch.store.vectordb.integration.schema.StringSimilaritySearchQuery;
import com.google.common.collect.ImmutableList;

import ai.knowly.langtorch.store.vectordb.integration.schema.SimilaritySearchQuery;
import ai.knowly.langtorch.store.vectordb.integration.schema.VectorSimilaritySearchQuery;
import com.google.common.flogger.FluentLogger;
import lombok.NonNull;

Expand Down Expand Up @@ -85,12 +86,7 @@ private boolean addVectors(List<Vector> vectors) {
* @return an instance of {@link Vector}
*/
private Vector createVector(DomainDocument document) {
EmbeddingOutput embeddingOutput =
embeddingProcessor.run(
EmbeddingInput.builder()
.setModel(pineconeVectorStoreSpec.getModel())
.setInput(Collections.singletonList(document.getPageContent()))
.build());
EmbeddingOutput embeddingOutput = createEmbeddingOutput(document.getPageContent());
return Vector.builder()
.setId(document.getId().orElse(UUID.randomUUID().toString()))
.setMetadata(document.getMetadata().orElse(Metadata.getDefaultInstance()).getValue())
Expand All @@ -99,11 +95,22 @@ private Vector createVector(DomainDocument document) {
}

/**
* Performs a similarity search using a vector query and returns a list of pairs containing the
* schema documents and their corresponding similarity scores.
* Performs a similarity search using a vector representation of the string query and returns a list of documents
* containing their corresponding similarity scores.
*/
@Override
public List<DomainDocument> similaritySearch(SimilaritySearchQuery similaritySearchQuery) {
public List<DomainDocument> similaritySearch(StringSimilaritySearchQuery similaritySearchQuery) {
  // Embed the text query and take the vector of the first (only) embedding returned.
  EmbeddingOutput embedded = createEmbeddingOutput(similaritySearchQuery.getQuery());
  List<Double> queryVector = embedded.getValue().get(0).getVector();

  // Delegate to the vector-based overload with the same filter and topK.
  VectorSimilaritySearchQuery vectorQuery =
      similaritySearchQuery.toVectorSimilaritySearchQuery(queryVector);
  return similaritySearch(vectorQuery);
}

/**
* Performs a similarity search using a vector query and returns a list of documents
* containing their corresponding similarity scores.
*/
@Override
public List<DomainDocument> similaritySearch(VectorSimilaritySearchQuery similaritySearchQuery) {

QueryRequest.QueryRequestBuilder requestBuilder =
QueryRequest.builder()
Expand All @@ -120,7 +127,7 @@ public List<DomainDocument> similaritySearch(SimilaritySearchQuery similaritySea
// create mapping of PineCone metadata to schema meta data
if (response.getMatches() != null) {
for (Match match : response.getMatches()) {
if (!pineconeVectorStoreSpec.getTextKey().isPresent()) {
if (pineconeVectorStoreSpec.getTextKey().isEmpty()) {
continue;
}
Metadata metadata =
Expand All @@ -145,7 +152,7 @@ public List<DomainDocument> similaritySearch(SimilaritySearchQuery similaritySea

@Override
public boolean updateDocuments(List<DomainDocument> documents) {
if (!this.executorService.isPresent()) {
if (this.executorService.isEmpty()) {
this.executorService = Optional.of(Executors.newFixedThreadPool(16));
}
ExecutorService localExecutorService = this.executorService.get();
Expand Down Expand Up @@ -212,4 +219,12 @@ public boolean deleteDocumentsByIds(List<String> documentsIds) {
}
return true;
}

/**
 * Runs the embedding processor on a single input string, using the model configured in the
 * Pinecone vector store spec.
 *
 * @param input the text to embed
 * @return the embedding processor's output for {@code input}
 */
private EmbeddingOutput createEmbeddingOutput(String input) {
  EmbeddingInput request =
      EmbeddingInput.builder()
          .setModel(pineconeVectorStoreSpec.getModel())
          .setInput(Collections.singletonList(input))
          .build();
  return embeddingProcessor.run(request);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package ai.knowly.langtorch.store.vectordb.integration.schema;

import lombok.*;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Similarity search request whose query is given as text rather than as a vector.
 *
 * <p>Use {@link #toVectorSimilaritySearchQuery(List)} to obtain the vector-based equivalent once
 * the text has been turned into a vector.
 */
@Data
@Builder(toBuilder = true, setterPrefix = "set")
@AllArgsConstructor(access = AccessLevel.PRIVATE)
public class StringSimilaritySearchQuery {
  /** Filter entries; an empty map unless the builder sets one. */
  @Builder.Default Map<String, String> filter = new HashMap<>();

  /** Text of the query; must not be null. */
  @NonNull private String query;

  /** The top-K value carried over to the vector query; must not be null. */
  @NonNull private Long topK;

  /**
   * Builds a {@link VectorSimilaritySearchQuery} that keeps this query's filter and topK and uses
   * the supplied vector as its query.
   *
   * @param vector the vector to use as the query of the returned object
   * @return the vector-based counterpart of this query
   */
  public VectorSimilaritySearchQuery toVectorSimilaritySearchQuery(List<Double> vector) {
    return VectorSimilaritySearchQuery.builder()
        .setQuery(vector)
        .setTopK(getTopK())
        .setFilter(getFilter())
        .build();
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
@Data
@Builder(toBuilder = true, setterPrefix = "set")
@AllArgsConstructor(access = AccessLevel.PRIVATE)
public class SimilaritySearchQuery {
public class VectorSimilaritySearchQuery {
@Builder.Default Map<String, String> filter = new HashMap<>();
@NonNull private List<Double> query;
@NonNull private Long topK;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
import ai.knowly.langtorch.store.vectordb.PGVectorStore;
import ai.knowly.langtorch.store.vectordb.integration.pgvector.schema.PGVectorStoreSpec;
import ai.knowly.langtorch.store.vectordb.integration.pgvector.schema.distance.DistanceStrategies;
import ai.knowly.langtorch.store.vectordb.integration.schema.SimilaritySearchQuery;
import ai.knowly.langtorch.store.vectordb.integration.schema.StringSimilaritySearchQuery;
import ai.knowly.langtorch.store.vectordb.integration.schema.VectorSimilaritySearchQuery;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.pgvector.PGvector;
import kotlin.Triple;
Expand Down Expand Up @@ -77,14 +79,93 @@ void testAddDocuments() throws SQLException {
assertThat(isSuccessful).isEqualTo(true);
}


/**
 * Searches by string with the Euclidean distance strategy and checks that three documents come
 * back, that the first has score 0 and a lower score than the second, and that the page contents
 * are returned in the stubbed order.
 */
@Test
void testSimilaritySearchStringWithScoreEuclidean() throws SQLException {
  // Arrange.
  pgVectorStore =
      new PGVectorStore(
          embeddingProcessor, pgVectorStoreSpec, pgVectorService, DistanceStrategies.euclidean());

  Triple<String, String, StringSimilaritySearchQuery> queryData =
      prepareSimilaritySearchStringQuery();
  StringSimilaritySearchQuery query = queryData.getThird();
  String firstPageContent = queryData.getFirst();
  String secondPageContent = queryData.getSecond();

  // Act.
  List<DomainDocument> documentsWithScores = pgVectorStore.similaritySearch(query);

  // Assert.
  // Check the result count first: a short list then fails this assertion with a clear message
  // instead of throwing IndexOutOfBoundsException from get(0)/get(1) below.
  assertThat(documentsWithScores.size()).isEqualTo(3);
  double firstDocumentScore = documentsWithScores.get(0).getSimilarityScore().orElse(-1.0);
  double secondDocumentScore = documentsWithScores.get(1).getSimilarityScore().orElse(-1.0);
  String firstDocumentPageContent = documentsWithScores.get(0).getPageContent();
  String secondDocumentPageContent = documentsWithScores.get(1).getPageContent();
  assertThat(firstDocumentScore).isEqualTo(0);
  assertThat(firstDocumentScore).isLessThan(secondDocumentScore);
  assertThat(firstDocumentPageContent).isEqualTo(firstPageContent);
  assertThat(secondDocumentPageContent).isEqualTo(secondPageContent);
}

/**
 * Searches by string with the inner-product distance strategy and checks that three documents
 * come back, that the first has score 3 and a lower score than the second, and that the page
 * contents are returned in the stubbed order.
 */
@Test
void testSimilaritySearchStringWithScoreInnerProduct() throws SQLException {
  // Arrange.
  pgVectorStore =
      new PGVectorStore(
          embeddingProcessor,
          pgVectorStoreSpec,
          pgVectorService,
          DistanceStrategies.innerProduct());

  Triple<String, String, StringSimilaritySearchQuery> queryData =
      prepareSimilaritySearchStringQuery();
  StringSimilaritySearchQuery query = queryData.getThird();
  String firstPageContent = queryData.getFirst();
  String secondPageContent = queryData.getSecond();

  // Act.
  List<DomainDocument> documentsWithScores = pgVectorStore.similaritySearch(query);

  // Assert.
  // Check the result count first: a short list then fails this assertion with a clear message
  // instead of throwing IndexOutOfBoundsException from get(0)/get(1) below.
  assertThat(documentsWithScores.size()).isEqualTo(3);
  double firstDocumentScore = documentsWithScores.get(0).getSimilarityScore().orElse(-1.0);
  double secondDocumentScore = documentsWithScores.get(1).getSimilarityScore().orElse(-1.0);
  String firstDocumentPageContent = documentsWithScores.get(0).getPageContent();
  String secondDocumentPageContent = documentsWithScores.get(1).getPageContent();
  assertThat(firstDocumentScore).isEqualTo(3);
  assertThat(firstDocumentScore).isLessThan(secondDocumentScore);
  assertThat(firstDocumentPageContent).isEqualTo(firstPageContent);
  assertThat(secondDocumentPageContent).isEqualTo(secondPageContent);
}

/**
 * Searches by string with the cosine distance strategy and checks that three documents come
 * back, that the first score is closer to {@code TOP_VECTOR_VALUE} than the second, and that the
 * page contents are returned in the stubbed order.
 */
@Test
void testSimilaritySearchStringWithScoreCosine() throws SQLException {
  // Arrange.
  pgVectorStore =
      new PGVectorStore(
          embeddingProcessor, pgVectorStoreSpec, pgVectorService, DistanceStrategies.cosine());

  Triple<String, String, StringSimilaritySearchQuery> queryData =
      prepareSimilaritySearchStringQuery();
  StringSimilaritySearchQuery query = queryData.getThird();
  String firstPageContent = queryData.getFirst();
  String secondPageContent = queryData.getSecond();

  // Act.
  List<DomainDocument> documentsWithScores = pgVectorStore.similaritySearch(query);

  // Assert.
  // Check the result count first: a short list then fails this assertion with a clear message
  // instead of throwing IndexOutOfBoundsException from get(0)/get(1) below.
  assertThat(documentsWithScores.size()).isEqualTo(3);
  double firstDocumentScore = documentsWithScores.get(0).getSimilarityScore().orElse(-1.0);
  double secondDocumentScore = documentsWithScores.get(1).getSimilarityScore().orElse(-1.0);
  String firstDocumentPageContent = documentsWithScores.get(0).getPageContent();
  String secondDocumentPageContent = documentsWithScores.get(1).getPageContent();
  assertThat(Math.abs(firstDocumentScore - TOP_VECTOR_VALUE))
      .isLessThan(Math.abs(secondDocumentScore - TOP_VECTOR_VALUE));
  assertThat(firstDocumentPageContent).isEqualTo(firstPageContent);
  assertThat(secondDocumentPageContent).isEqualTo(secondPageContent);
}

@Test
void testSimilaritySearchVectorWithScoreEuclidean() throws SQLException {
pgVectorStore =
new PGVectorStore(
embeddingProcessor, pgVectorStoreSpec, pgVectorService, DistanceStrategies.euclidean());

Triple<String, String, SimilaritySearchQuery> queryData = prepareSimilaritySearchQuery();
SimilaritySearchQuery query = queryData.getThird();
Triple<String, String, VectorSimilaritySearchQuery> queryData = prepareSimilaritySearchQuery();
VectorSimilaritySearchQuery query = queryData.getThird();
String firstPageContent = queryData.getFirst();
String secondPageContent = queryData.getSecond();

Expand All @@ -111,8 +192,8 @@ void testSimilaritySearchVectorWithScoreInnerProduct() throws SQLException {
pgVectorService,
DistanceStrategies.innerProduct());

Triple<String, String, SimilaritySearchQuery> queryData = prepareSimilaritySearchQuery();
SimilaritySearchQuery query = queryData.getThird();
Triple<String, String, VectorSimilaritySearchQuery> queryData = prepareSimilaritySearchQuery();
VectorSimilaritySearchQuery query = queryData.getThird();
String firstPageContent = queryData.getFirst();
String secondPageContent = queryData.getSecond();

Expand All @@ -136,8 +217,8 @@ void testSimilaritySearchVectorWithScoreCosine() throws SQLException {
new PGVectorStore(
embeddingProcessor, pgVectorStoreSpec, pgVectorService, DistanceStrategies.cosine());

Triple<String, String, SimilaritySearchQuery> queryData = prepareSimilaritySearchQuery();
SimilaritySearchQuery query = queryData.getThird();
Triple<String, String, VectorSimilaritySearchQuery> queryData = prepareSimilaritySearchQuery();
VectorSimilaritySearchQuery query = queryData.getThird();
String firstPageContent = queryData.getFirst();
String secondPageContent = queryData.getSecond();

Expand Down Expand Up @@ -200,7 +281,7 @@ private List<Embedding> getEmbeddings() {
return embeddings;
}

private Triple<String, String, SimilaritySearchQuery> prepareSimilaritySearchQuery()
private Triple<String, String, VectorSimilaritySearchQuery> prepareSimilaritySearchQuery()
throws SQLException {
String firstPageContent = "content 0";
String secondPageContent = "content 1";
Expand All @@ -221,8 +302,38 @@ private Triple<String, String, SimilaritySearchQuery> prepareSimilaritySearchQue
when(resultSet.getObject(3)).thenReturn(textKey);
when(resultSet.getObject(4)).thenReturn(firstPageContent, secondPageContent);
double v = TOP_VECTOR_VALUE;
SimilaritySearchQuery query =
SimilaritySearchQuery.builder().setTopK(0L).setQuery(Arrays.asList(v, v, v)).build();
VectorSimilaritySearchQuery query =
VectorSimilaritySearchQuery.builder().setTopK(0L).setQuery(Arrays.asList(v, v, v)).build();
return new Triple<>(firstPageContent, secondPageContent, query);
}

/**
 * Builds the fixtures for the string-query similarity search tests: stubs the embedding
 * processor, the prepared statement and a three-row result set, and creates the string query.
 *
 * @return a triple of (expected first page content, expected second page content, query)
 * @throws SQLException declared because the stubbed JDBC methods declare it
 */
private Triple<String, String, StringSimilaritySearchQuery> prepareSimilaritySearchStringQuery()
throws SQLException {
String firstPageContent = "content 0";
String secondPageContent = "content 1";
ResultSet resultSet = Mockito.mock(ResultSet.class);
double v = TOP_VECTOR_VALUE;
// The embedding processor returns a single embedding (v, v, v) for any input.
List<Embedding> embeddings = ImmutableList.of(Embedding.of(Arrays.asList(v, v, v)));
EmbeddingOutput embeddingOutput = EmbeddingOutput.of(EmbeddingType.OPEN_AI, embeddings);
when(embeddingProcessor.run(ArgumentMatchers.any())).thenReturn(embeddingOutput);
when(pgVectorService.prepareStatement(ArgumentMatchers.any())).thenReturn(preparedStatement);
when(preparedStatement.executeQuery()).thenReturn(resultSet);
// Three rows, then end of results.
when(resultSet.next()).thenReturn(true, true, true, false);
// Column 1: a fresh random UUID string per row.
when(resultSet.getObject(1))
.thenReturn(
UUID.randomUUID().toString(),
UUID.randomUUID().toString(),
UUID.randomUUID().toString());
// Column 2: one vector per row; the first has the same components as the stubbed embedding.
when(resultSet.getObject(2))
.thenReturn(
new PGvector(new float[] {TOP_VECTOR_VALUE, TOP_VECTOR_VALUE, TOP_VECTOR_VALUE}),
new PGvector(new float[] {2.1f, 2.2f, 2.3f}),
new PGvector(new float[] {-3, -3, -3}));
// Column 3: the same textKey for every row.
when(resultSet.getObject(3)).thenReturn(textKey);
// Column 4: only two values are stubbed for three rows; Mockito's consecutive stubbing
// keeps returning the last value (secondPageContent) for any further call.
when(resultSet.getObject(4)).thenReturn(firstPageContent, secondPageContent);

// The query text is irrelevant here because embeddingProcessor.run is stubbed for any argument.
StringSimilaritySearchQuery query =
StringSimilaritySearchQuery.builder().setTopK(0L).setQuery("").build();
return new Triple<>(firstPageContent, secondPageContent, query);
}
}
Loading