Skip to content

Commit

Permalink
Merge pull request #18 from theam/llm-463/unit-tests
Browse files Browse the repository at this point in the history
Llm-463 - Fixed unit tests adding mocks to avoid making actual requests ✨
  • Loading branch information
juanjoman authored Sep 1, 2023
2 parents 80e99c3 + de099e5 commit fa1a8f7
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 57 deletions.
2 changes: 2 additions & 0 deletions modules/embeddingsgeneration/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ dependencies {

testImplementation platform('org.junit:junit-bom:5.9.1')
testImplementation 'org.junit.jupiter:junit-jupiter'
testImplementation 'org.mockito:mockito-core:3.+'
testImplementation 'org.mockito:mockito-junit-jupiter:5.5.0'
}

test {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@
* OpenAI `EmbeddingsGenerationModel` implementation.
*/
public class OpenAIEmbeddingsModel extends EmbeddingsGenerationModel {
private static String openAIKey;
private static OpenAiService cachedService;
private final String openAIKey;
public static String embeddingOpenAiModel = "text-embedding-ada-002";

// This attribute needs no modifier to allow injection from tests,
// It is accessible for other classes in this package, but won't be accessible to end users.
OpenAiService openAiService;

/**
* Constructor that initializes the OpenAI embeddings model with an explicit API Key.
*
Expand Down Expand Up @@ -61,12 +64,12 @@ public Embedding generateEmbedding(String inputString) {
}

private OpenAiService getService() {
if (cachedService == null) {
if (openAiService == null) {
if (openAIKey == null) {
throw new MissingRequiredCredentialException("OpenAI API key is required.");
}
cachedService = new OpenAiService(openAIKey);
openAiService = new OpenAiService(openAIKey);
}
return cachedService;
return openAiService;
}
}

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

@ExtendWith(MockitoExtension.class)
public class EmbeddingsSpaceComponentTest {

@Mock
private EmbeddingsGenerationModel embeddingsGenerationModel;
@Mock
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,25 @@
import com.theagilemonkeys.ellmental.embeddingsstore.EmbeddingsStore;
import java.io.IOException;
import java.util.*;
import java.util.stream.Collectors;

import com.theagilemonkeys.ellmental.core.schema.Embedding;
import com.theagilemonkeys.ellmental.core.errors.MissingRequiredCredentialException;
import lombok.Data;
import lombok.RequiredArgsConstructor;
import okhttp3.*;
import static java.net.HttpURLConnection.HTTP_BAD_REQUEST;

/**
* Implementation for Pinecone EmbeddingsStore
*/
@RequiredArgsConstructor
public class PineconeEmbeddingsStore extends EmbeddingsStore {
private final String url;
private final String apiKey;
private final String namespace;

/**
* Constructor that initializes the Pinecone embeddings store with an explicit url and an apiKey without using a namespace.
*
* @param url Pinecone URL.
* @param apiKey Pinecone API key.
*/
public PineconeEmbeddingsStore(String url, String apiKey) {
this(url, apiKey, null);
}
private String namespace;
// This attribute has no modifier to be package-private, so it can be mocked in tests.
// This field will be accessible within this package, but not outside of it.
OkHttpClient httpClient = new OkHttpClient();

/**
* Constructor that initializes the Pinecone embeddings store with an explicit url, apiKey and namespace.
Expand All @@ -55,6 +49,7 @@ public void store(Embedding embedding) {
this.post("/vectors/upsert", requestBodyJson);
} catch (IOException e) {
System.out.println("VectorStore error on upsert: " + e.getMessage());
throw new RuntimeException(e);
}

}
Expand Down Expand Up @@ -116,16 +111,17 @@ public List<Embedding> similaritySearch(Embedding reference, int limit) {
try {
String responseString = this.post("/query", requestBodyJson);
QueryVectorResponseSchema response = new Gson().fromJson(responseString, QueryVectorResponseSchema.class);
ArrayList<Embedding> matchEmbeddings = new ArrayList<>();

for (Match match : response.matches) { // TODO: Make sure this array is sorted by similarity using the score field
matchEmbeddings.add(new Embedding(match.id, match.values, match.metadata));
}

return matchEmbeddings;
// Sort the matches by similarity using the score field and map them to Embedding objects
return response
.matches
.stream()
.sorted((a, b) -> Double.compare(b.score, a.score))
.map(match -> new Embedding(match.id, match.values, match.metadata))
.collect(Collectors.toList());
} catch (IOException e) {
System.out.println("VectorStore error on upsert: " + e.getMessage());
return null;
throw new RuntimeException(e);
}
}

Expand Down Expand Up @@ -194,7 +190,7 @@ private String post(String path, String bodyString) throws IOException {
.post(body)
.build();

try (Response response = new OkHttpClient().newCall(request).execute()) {
try (Response response = httpClient.newCall(request).execute()) {
if (response.code() >= HTTP_BAD_REQUEST) {
throw new IOException(url);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,11 @@
import static org.mockito.Mockito.*;

import com.theagilemonkeys.ellmental.embeddingsstore.EmbeddingsStore;
import okhttp3.Call;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;
import okhttp3.*;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.Spy;
import org.mockito.*;
import org.mockito.junit.jupiter.MockitoExtension;

import java.io.IOException;
Expand All @@ -27,14 +22,11 @@

@ExtendWith(MockitoExtension.class)
public class PineconeEmbeddingsStoreTest {
//TODO: add mock reponse to improve store test

@InjectMocks
private PineconeEmbeddingsStore pineconeEmbeddingsStore;
@Spy
private OkHttpClient okHttpClient;
@Mock
private Call httpCall;
private OkHttpClient httpClient;
@Mock
private Call remoteCall;

private final String url = "https://pinecone.url";
private final String apiKey = "API_KEY";
Expand All @@ -47,26 +39,92 @@ public class PineconeEmbeddingsStoreTest {
@BeforeEach
void setup() {
pineconeEmbeddingsStore = new PineconeEmbeddingsStore(url, apiKey, namespace);
pineconeEmbeddingsStore.httpClient = httpClient;
}


@Test
public void testStore() throws IOException {
Response response = new Response.Builder()
.request(new Request.Builder().url(url + "/vectors/upsert").build())
.protocol(Protocol.HTTP_1_1)
.code(200).message("").body(ResponseBody.create(MediaType.parse("application/json"), ""))
.build();
when(remoteCall.execute()).thenReturn(response);
when(httpClient.newCall(any())).thenReturn(remoteCall);

Map<String, String> metadata = new HashMap<>();
metadata.put("key1", "value1");
metadata.put("key2", "value2");

//when(okHttpClient.newCall(any(Request.class)).execute()).thenReturn(null);

Embedding embedding = new Embedding(UUID.randomUUID(), vectorExpectedValue, metadata);

pineconeEmbeddingsStore.store(embedding);
assertDoesNotThrow(() -> pineconeEmbeddingsStore.store(embedding));
}

@Test
public void testStoreBadRequest() throws IOException {
Response response = new Response.Builder()
.request(new Request.Builder().url(url + "/vectors/upsert").build())
.protocol(Protocol.HTTP_1_1)
.code(400).message("").body(ResponseBody.create(MediaType.parse("application/json"), ""))
.build();
when(remoteCall.execute()).thenReturn(response);
when(httpClient.newCall(any())).thenReturn(remoteCall);

Map<String, String> metadata = new HashMap<>();
metadata.put("key1", "value1");
metadata.put("key2", "value2");

assertEquals(embedding.vector().size(), vectorExpectedValue.size());
assertArrayEquals(embedding.vector().toArray(), vectorExpectedValue.toArray());
Embedding embedding = new Embedding(UUID.randomUUID(), vectorExpectedValue, metadata);

verify(okHttpClient).newCall(any(Request.class));
assertThrows(RuntimeException.class, () -> pineconeEmbeddingsStore.store(embedding));
}

// TODO: get and delete methods
}
@Test
public void testSimilaritySearch() throws IOException {
UUID fakeUUID1 = UUID.randomUUID();
UUID fakeUUID2 = UUID.randomUUID();
Response response = new Response.Builder()
.request(new Request.Builder().url(url + "/query").build())
.protocol(Protocol.HTTP_1_1)
.code(200).message("").body(ResponseBody.create(
MediaType.parse("application/json"),
"{ \"matches\": [ { \"id\": \"" + fakeUUID1 + "\", \"score\": 0.5, \"values\": [ 0.1, 0.2, 0.3 ], \"metadata\": { \"key1\": \"value1\" } }, { \"id\": \"" + fakeUUID2 + "\", \"score\": 0.7, \"values\": [ 0.4, 0.5, 0.6 ], \"metadata\": { \"key2\": \"value2\" } } ] }"
))
.build();
when(remoteCall.execute()).thenReturn(response);
when(httpClient.newCall(any())).thenReturn(remoteCall);
Embedding embedding = new Embedding(UUID.randomUUID(), List.of(0.1, 0.2, 0.3), Map.of("key1", "value1", "key2", "value2"));

List<Embedding> embeddings = pineconeEmbeddingsStore.similaritySearch(embedding, 2);

// It includes all the embeddings in the response
assertEquals(2, embeddings.size());

// Embeddings come in the right order and the UUIDs were parsed correctly
assertEquals(fakeUUID2, embeddings.get(0).id());
assertEquals(fakeUUID1, embeddings.get(1).id());

// Embeddings values were parsed correctly
assertEquals(List.of(0.4, 0.5, 0.6), embeddings.get(0).vector());
assertEquals(List.of(0.1, 0.2, 0.3), embeddings.get(1).vector());

// Embeddings metadata was parsed correctly
assertEquals(Map.of("key2", "value2"), embeddings.get(0).metadata());
assertEquals(Map.of("key1", "value1"), embeddings.get(1).metadata());
}

@Test
public void testSimilaritySearchBadRequest() throws IOException {
Response response = new Response.Builder()
.request(new Request.Builder().url(url + "/query").build())
.protocol(Protocol.HTTP_1_1)
.code(400).message("").body(ResponseBody.create(MediaType.parse("application/json"), ""))
.build();
when(remoteCall.execute()).thenReturn(response);
when(httpClient.newCall(any())).thenReturn(remoteCall);
Embedding embedding = new Embedding(UUID.randomUUID(), List.of(0.1, 0.2, 0.3), Map.of("key1", "value1", "key2", "value2"));

assertThrows(RuntimeException.class, () -> pineconeEmbeddingsStore.similaritySearch(embedding, 2));
}
}

0 comments on commit fa1a8f7

Please sign in to comment.