Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Get and delete embeddings space #20

Merged
merged 11 commits into from
Sep 1, 2023
15 changes: 13 additions & 2 deletions docs-site/docs/03_components/01_core_abstractions.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,23 @@ public abstract class EmbeddingsStore {

### `PineconeEmbeddingsStore`

eLLMental provides a concrete implementation for Pinecone, which requires defining an URL, an API Key and a space.
eLLMental provides a concrete implementation for Pinecone, which requires defining a URL, an API Key, and a namespace.

```java
EmbeddingsStore pineconeStore = new PineconeEmbeddingsStore("YOUR_PINECONE_URL", "YOUR_PINECONE_API_KEY", "YOUR_PINECONE_NAMESPACE");

// You can now insert or perform similarity searches using the pineconeStore instance:
// You can now insert, fetch, delete or perform similarity searches using the pineconeStore instance:
pineconeStore.store(someEmbedding);
List<Embedding> similarEmbeddings = pineconeStore.similaritySearch(referenceEmbedding, 5);

// Get back an embedding by id without a namespace
UUID someEmbeddingId = someEmbedding.id()
Embedding embedding = pineconeStore.get(someEmbeddingId);

// Get a namespaced embedding by id
UUID anotherEmbeddingId = UUID.fromString("01870603-f211-7b9a-a7ea-4a98f5320ff8")
Embedding anotherEmbedding = pineconeStore.get(anotherEmbeddingId, "my-namespace");

// Delete an embedding by id
pineconeStore.delete(anotherEmbeddingId);
```
3 changes: 3 additions & 0 deletions modules/embeddingsspace/build.gradle
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
plugins {
id 'java'
id "io.freefair.lombok" version "8.3"
}

group = "com.theagilemonkeys.ellmental"
Expand All @@ -17,6 +18,8 @@ dependencies {

testImplementation platform('org.junit:junit-bom:5.9.1')
testImplementation 'org.junit.jupiter:junit-jupiter'
testImplementation 'org.mockito:mockito-core:3.+'
testImplementation 'org.mockito:mockito-junit-jupiter:5.5.0'
}

test {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,17 @@
import com.theagilemonkeys.ellmental.core.schema.Embedding;
import com.theagilemonkeys.ellmental.embeddingsgeneration.EmbeddingsGenerationModel;
import com.theagilemonkeys.ellmental.embeddingsstore.EmbeddingsStore;
import lombok.RequiredArgsConstructor;

import java.util.List;
import java.util.UUID;

@RequiredArgsConstructor
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😲

public class EmbeddingsSpaceComponent {

private final EmbeddingsGenerationModel embeddingsGenerationModel;
private final EmbeddingsStore embeddingsStore;

public EmbeddingsSpaceComponent(EmbeddingsGenerationModel embeddingsGenerationModel, EmbeddingsStore embeddingsStore) {
this.embeddingsGenerationModel = embeddingsGenerationModel;
this.embeddingsStore = embeddingsStore;
}

private Embedding generate(String text) {
return embeddingsGenerationModel.generateEmbedding(text);
}
Expand All @@ -36,4 +34,11 @@ public List<Embedding> mostSimilarEmbeddings(String referenceText, int limit) {
return mostSimilarEmbeddings(embedding, limit);
}

public Embedding get(UUID uuid) {
return embeddingsStore.get(uuid);
}

public void delete(UUID uuid) {
embeddingsStore.delete(uuid);
}
}
Original file line number Diff line number Diff line change
@@ -1,21 +1,100 @@
package com.theagilemonkeys.ellmental.embeddingsspace;

import com.theagilemonkeys.ellmental.core.schema.Embedding;
import com.theagilemonkeys.ellmental.embeddingsgeneration.EmbeddingsGenerationModel;
import com.theagilemonkeys.ellmental.embeddingsstore.EmbeddingsStore;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.*;
import org.mockito.junit.jupiter.MockitoExtension;

import static org.junit.jupiter.api.Assertions.assertTrue;
import java.util.List;
import java.util.UUID;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.*;

@ExtendWith(MockitoExtension.class)
public class EmbeddingsSpaceComponentTest {

// TODO: Use mockito to Mock these
// @Spy
// private final EmbeddingsGenerationModel model;
// @Spy
// private final EmbeddingsStore store;
// @Inject
// private final EmbeddingsSpaceComponent embeddingsSpaceComponent;
@Mock
private EmbeddingsGenerationModel embeddingsGenerationModel;
@Mock
private EmbeddingsStore embeddingsStore;
@InjectMocks
private EmbeddingsSpaceComponent embeddingsSpaceComponent;
@Captor
private ArgumentCaptor<String> sampleTextCaptor;
Comment on lines +20 to +27
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


private final String sampleText = "sample text";
private final UUID embeddingId = UUID.randomUUID();
private final Embedding embeddingMock = new Embedding(embeddingId, List.of(1.0, 2.0, 3.0), null);

@Test
public void testSaveEmbedding() {
when(embeddingsGenerationModel.generateEmbedding(sampleText)).thenReturn(embeddingMock);

embeddingsSpaceComponent.save(sampleText);

verify(embeddingsGenerationModel).generateEmbedding(sampleTextCaptor.capture());
String capturedSampleText = sampleTextCaptor.getValue();
assertEquals(sampleText, capturedSampleText);

verify(embeddingsStore).store(embeddingMock);

verifyNoMoreInteractions(embeddingsGenerationModel, embeddingsStore);
}

@Test
public void testMostSimilarEmbeddings() {
int limit = 1;
Embedding similarEmbeddingMock = new Embedding(UUID.randomUUID(), List.of(4.0, 5.0, 6.0), null);

when(embeddingsStore.similaritySearch(embeddingMock, limit)).thenReturn(List.of(similarEmbeddingMock));

List<Embedding> embeddings = embeddingsSpaceComponent.mostSimilarEmbeddings(embeddingMock, limit);

assertEquals(embeddings.get(0), similarEmbeddingMock);

verifyNoMoreInteractions(embeddingsGenerationModel, embeddingsStore);
}
juanjoman marked this conversation as resolved.
Show resolved Hide resolved

@Test
public void testMostSimilarEmbeddingsWhenUsingString() {
int limit = 1;
Embedding similarEmbeddingMock = new Embedding(UUID.randomUUID(), List.of(4.0, 5.0, 6.0), null);

when(embeddingsGenerationModel.generateEmbedding(sampleText)).thenReturn(embeddingMock);
when(embeddingsStore.similaritySearch(embeddingMock, limit)).thenReturn(List.of(similarEmbeddingMock));

List<Embedding> embeddings = embeddingsSpaceComponent.mostSimilarEmbeddings(sampleText, limit);

assertEquals(embeddings.get(0), similarEmbeddingMock);

verify(embeddingsGenerationModel).generateEmbedding(sampleTextCaptor.capture());
String capturedSampleText = sampleTextCaptor.getValue();
assertEquals(sampleText, capturedSampleText);

verifyNoMoreInteractions(embeddingsGenerationModel, embeddingsStore);
}

@Test
public void testGetEmbedding() {
when(embeddingsStore.get(embeddingId)).thenReturn(embeddingMock);

Embedding embedding = embeddingsSpaceComponent.get(embeddingId);

assertEquals(embeddingMock, embedding);

verifyNoMoreInteractions(embeddingsGenerationModel, embeddingsStore);
}

@Test
public void todo() {
assertTrue(true);
public void testDeleteEmbedding() {
embeddingsSpaceComponent.delete(embeddingId);

verify(embeddingsStore).delete(embeddingId);

verifyNoMoreInteractions(embeddingsGenerationModel, embeddingsStore);
}
}
}
3 changes: 3 additions & 0 deletions modules/embeddingsstore/build.gradle
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
plugins {
id 'java'
id "io.freefair.lombok" version "8.3"
}

group = 'com.theagilemonkeys.ellemental'
Expand All @@ -16,6 +17,8 @@ dependencies {

testImplementation platform('org.junit:junit-bom:5.9.1')
testImplementation 'org.junit.jupiter:junit-jupiter'
testImplementation 'org.mockito:mockito-core:3.+'
testImplementation 'org.mockito:mockito-junit-jupiter:5.5.0'
}

test {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import java.util.List;
import java.util.Map;
import java.util.UUID;

import com.theagilemonkeys.ellmental.core.schema.Embedding;

Expand All @@ -11,5 +12,7 @@
public abstract class EmbeddingsStore {
//TODO: check if there is an issue using a map instead of a metadata class
public abstract void store(Embedding embedding);
public abstract Embedding get(UUID uuid);
public abstract void delete(UUID uuid);
public abstract List<Embedding> similaritySearch(Embedding reference, int limit);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package com.theagilemonkeys.ellmental.embeddingsstore.pinecone;

import lombok.Data;

import java.util.List;
import java.util.Map;
import java.util.UUID;

@Data
public class DeleteVectorSchema {
private final List<UUID> ids;
private boolean deleteAll;
private String namespace;
private Map<String, String> filterMetadata;
}
Comment on lines +9 to +15
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that this @Dataannotation serves the same purpose as Java records. What would you consider more idiomatic? I think it could be worth deciding on one or the other and using a unique and consistent style in the whole project 😜

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ooh I just checked and they are almost the same. In this case @DaTa includes @requiredargsconstructor, which creates a constructor only for final properties. For now we just need the list of ids, so we can keep it like it is or create a record that only contains this list for now, what do you think is the best approach? @javiertoledo

Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package com.theagilemonkeys.ellmental.embeddingsstore.pinecone;

import lombok.Data;

import java.util.List;

@Data
public class FetchVectorResponseSchema {
private final List<Vector> vectors;
private final String namespace;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package com.theagilemonkeys.ellmental.embeddingsstore.pinecone;

import lombok.AllArgsConstructor;
import lombok.Data;

import java.util.List;
import java.util.UUID;

@Data
@AllArgsConstructor
public class FetchVectorSchema {
private final List<UUID> ids;
private String namespace;
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
import com.google.gson.Gson;
import com.theagilemonkeys.ellmental.embeddingsstore.EmbeddingsStore;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.*;

import com.theagilemonkeys.ellmental.core.schema.Embedding;
import com.theagilemonkeys.ellmental.core.errors.MissingRequiredCredentialException;
import lombok.Data;
import lombok.RequiredArgsConstructor;
import okhttp3.*;
import static java.net.HttpURLConnection.HTTP_BAD_REQUEST;

Expand Down Expand Up @@ -57,6 +59,51 @@ public void store(Embedding embedding) {

}

/**
* Retrieves and embedding in the embeddings store based on its UUID.
* @param uuid UUID to retrieve.
*/
public Embedding get(UUID uuid) {
Embedding embedding = null;
try {
embedding = this.fetch(List.of(uuid), null);
} catch (IOException e) {
System.out.println("VectorStore error on fetch: " + e.getMessage());
}
return embedding;
}

/**
* Retrieves and embedding in the embeddings store based on its UUID.
* @param uuid UUID to retrieve.
* @param namespace String to look into.
*/
public Embedding get(UUID uuid, String namespace) {
Embedding embedding = null;
try {
embedding = this.fetch(List.of(uuid), namespace);
} catch (IOException e) {
System.out.println("VectorStore error on fetch: " + e.getMessage());
}
return embedding;
}

/**
* Deletes an embedding in the embeddings store based on its UUID.
* @param uuid UUID to delete.
*/
public void delete(UUID uuid) {
DeleteVectorSchema request = new DeleteVectorSchema(List.of(uuid));
String requestBodyJson = new Gson().toJson(request);

try {
this.post("/vectors/delete", requestBodyJson);
} catch (IOException e) {
System.out.println("VectorStore error on delete: " + e.getMessage());
}

}

public List<Embedding> similaritySearch(Embedding reference, int limit) {
QueryVectorRequestSchema body = new QueryVectorRequestSchema(
limit,
Expand Down Expand Up @@ -92,6 +139,45 @@ private boolean validateEnvironment() {
return true;
}

private Embedding fetch(List<UUID> ids, String namespace) throws IOException {
if (!validateEnvironment()) {
return null;
}

HttpUrl.Builder urlBuilder = Objects.requireNonNull(HttpUrl.parse(this.url + "/vectors/fetch")).newBuilder();
urlBuilder.addQueryParameter("ids", ids.toString());
if (namespace != null) {
urlBuilder.addQueryParameter("namespace", namespace);
}

HttpUrl url = urlBuilder.build();

Request request = new Request.Builder()
.url(url)
.header("accept", "application/json")
.header("content-type", "application/json")
.header("Api-Key", apiKey)
.get()
.build();

try (Response response = new OkHttpClient().newCall(request).execute()) {
if (response.code() >= HTTP_BAD_REQUEST) {
throw new IOException(this.url);
}

ResponseBody responseBody = response.body();

if (responseBody != null) {
String json = responseBody.string();
FetchVectorResponseSchema responseSchema = new Gson().fromJson(json, FetchVectorResponseSchema.class);
Vector vector = responseSchema.getVectors().get(0);
return new Embedding(vector.getId(), vector.getValues(), vector.getMetadata());
} else {
return null;
}
}
}

private String post(String path, String bodyString) throws IOException {
if (!validateEnvironment()) {
return null;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package com.theagilemonkeys.ellmental.embeddingsstore.pinecone;

import java.util.List;

public class SparseValues {
public List<Integer> indices;
public List<Integer> values;
}
Loading
Loading