From c6d7e9fd142954592b42cc6e2d363ea0c19cb2cd Mon Sep 17 00:00:00 2001 From: Andre Duarte Date: Sun, 20 Aug 2023 05:43:57 -0300 Subject: [PATCH] Add: added embedding search function to store also cleaned up some code --- ...vironmentVariableNotDeclaredException.java | 8 +++ .../ellmental/core/schema/Embedding.java | 4 ++ .../ellmental/helloworld/Main.java | 17 +++---- .../openai/OpenAIEmbeddingsModel.java | 21 ++++---- .../embeddingsstore/EmbeddingsStore.java | 2 +- .../pinecone/PineconeEmbeddingsStore.java | 51 +++++++++---------- .../pinecone/QueryVectorRequestSchema.java | 3 -- .../pinecone/QueryVectorResponseSchema.java | 2 +- 8 files changed, 58 insertions(+), 50 deletions(-) create mode 100644 core/src/main/java/com/theagilemonkeys/ellmental/core/errors/EnvironmentVariableNotDeclaredException.java diff --git a/core/src/main/java/com/theagilemonkeys/ellmental/core/errors/EnvironmentVariableNotDeclaredException.java b/core/src/main/java/com/theagilemonkeys/ellmental/core/errors/EnvironmentVariableNotDeclaredException.java new file mode 100644 index 0000000..3429cb9 --- /dev/null +++ b/core/src/main/java/com/theagilemonkeys/ellmental/core/errors/EnvironmentVariableNotDeclaredException.java @@ -0,0 +1,8 @@ +package com.theagilemonkeys.ellmental.core.errors; + + +public class EnvironmentVariableNotDeclaredException extends RuntimeException { + public EnvironmentVariableNotDeclaredException(String message) { + super(message); + } +} \ No newline at end of file diff --git a/core/src/main/java/com/theagilemonkeys/ellmental/core/schema/Embedding.java b/core/src/main/java/com/theagilemonkeys/ellmental/core/schema/Embedding.java index 1781787..e845f13 100644 --- a/core/src/main/java/com/theagilemonkeys/ellmental/core/schema/Embedding.java +++ b/core/src/main/java/com/theagilemonkeys/ellmental/core/schema/Embedding.java @@ -3,6 +3,10 @@ import java.util.List; public class Embedding { + + public Embedding(List vector) { + this.vector = vector; + } public List vector; } 
diff --git a/examples/simplejava/src/main/java/com/theagilemonkeys/ellmental/helloworld/Main.java b/examples/simplejava/src/main/java/com/theagilemonkeys/ellmental/helloworld/Main.java index 7b9866d..2a1f27e 100644 --- a/examples/simplejava/src/main/java/com/theagilemonkeys/ellmental/helloworld/Main.java +++ b/examples/simplejava/src/main/java/com/theagilemonkeys/ellmental/helloworld/Main.java @@ -1,31 +1,30 @@ package com.theagilemonkeys.ellmental.helloworld; import com.theagilemonkeys.ellmental.core.schema.Embedding; -import com.theagilemonkeys.examplemodule.PrintHelloLibrary; import com.theagilemonkeys.ellmental.embeddingsgeneration.openai.OpenAIEmbeddingsModel; import com.theagilemonkeys.ellmental.embeddingsstore.EmbeddingsStore; import com.theagilemonkeys.ellmental.embeddingsstore.pinecone.PineconeEmbeddingsStore; - import java.util.HashMap; import java.util.Map; +import java.util.List; public class Main { public static void main(String[] args) { + // Step 1: generate embeddings from input string OpenAIEmbeddingsModel openAI = new OpenAIEmbeddingsModel(); Embedding embedding = openAI.generateEmbedding("Test"); - - - EmbeddingsStore embeddingStore = new PineconeEmbeddingsStore(); - Map metadata = new HashMap<>(); metadata.put("key1", "value1"); metadata.put("key2", "value2"); - + // Step 2: save the generated embeddings to a store (Pinecone in this case) + EmbeddingsStore embeddingStore = new PineconeEmbeddingsStore(); embeddingStore.store(embedding, metadata); - PrintHelloLibrary a = new PrintHelloLibrary(); - a.printHello(); + // Step 3: search for the embedding in the store + List searchEmbeddings = embeddingStore.similaritySearch(embedding, 5); + + System.out.println("Embedding generation and storage finished."); } } diff --git a/modules/embeddingsgeneration/src/main/java/com/theagilemonkeys/ellmental/embeddingsgeneration/openai/OpenAIEmbeddingsModel.java 
b/modules/embeddingsgeneration/src/main/java/com/theagilemonkeys/ellmental/embeddingsgeneration/openai/OpenAIEmbeddingsModel.java index 88dd076..97f90d3 100644 --- a/modules/embeddingsgeneration/src/main/java/com/theagilemonkeys/ellmental/embeddingsgeneration/openai/OpenAIEmbeddingsModel.java +++ b/modules/embeddingsgeneration/src/main/java/com/theagilemonkeys/ellmental/embeddingsgeneration/openai/OpenAIEmbeddingsModel.java @@ -1,5 +1,6 @@ package com.theagilemonkeys.ellmental.embeddingsgeneration.openai; +import com.theagilemonkeys.ellmental.core.errors.EnvironmentVariableNotDeclaredException; import com.theagilemonkeys.ellmental.core.schema.Embedding; import com.theagilemonkeys.ellmental.embeddingsgeneration.EmbeddingsGenerationModel; @@ -12,30 +13,30 @@ public class OpenAIEmbeddingsModel extends EmbeddingsGenerationModel { - private static OpenAiService service; - static String embeddingOpenAiModel = "text-embedding-ada-002"; + public static String embeddingOpenAiModel = "text-embedding-ada-002"; public OpenAIEmbeddingsModel() { - service = new OpenAiService(""); + String openAIKey = System.getenv("OPEN_AI_API_KEY"); + if (openAIKey == null) { + throw new EnvironmentVariableNotDeclaredException("Environment variable OPEN_AI_API_KEY is not declared."); + } + service = new OpenAiService(openAIKey); + } public Embedding generateEmbedding(String inputString) { - // TODO: the embedding function uses array as input. Check if we should implement an array option. + // TODO: the embeddings function from the library uses an array as input. We are only using a length 1 array. + // Check if we should implement an array option. 
List embeddingsInput = new ArrayList<>(); embeddingsInput.add(inputString); - EmbeddingRequest embeddingRequest = EmbeddingRequest.builder() .model(embeddingOpenAiModel) .input(embeddingsInput) .build(); - List embedding = service.createEmbeddings(embeddingRequest).getData().get(0).getEmbedding(); - Embedding embeddingReturn = new Embedding(); - embeddingReturn.vector = embedding; - - return embeddingReturn; + return new Embedding(embedding); } } diff --git a/modules/embeddingsstore/src/main/java/com/theagilemonkeys/ellmental/embeddingsstore/EmbeddingsStore.java b/modules/embeddingsstore/src/main/java/com/theagilemonkeys/ellmental/embeddingsstore/EmbeddingsStore.java index b59b092..7bb7977 100644 --- a/modules/embeddingsstore/src/main/java/com/theagilemonkeys/ellmental/embeddingsstore/EmbeddingsStore.java +++ b/modules/embeddingsstore/src/main/java/com/theagilemonkeys/ellmental/embeddingsstore/EmbeddingsStore.java @@ -6,7 +6,7 @@ import com.theagilemonkeys.ellmental.core.schema.Embedding; public abstract class EmbeddingsStore { - // TODO: check if there is an issue using a map instead of a metadata class + //TODO: check if there is an issue using a map instead of a metadata class public abstract void store(Embedding embedding, Map metadata); public abstract List similaritySearch(Embedding reference, int limit); } \ No newline at end of file diff --git a/modules/embeddingsstore/src/main/java/com/theagilemonkeys/ellmental/embeddingsstore/pinecone/PineconeEmbeddingsStore.java b/modules/embeddingsstore/src/main/java/com/theagilemonkeys/ellmental/embeddingsstore/pinecone/PineconeEmbeddingsStore.java index edfbf22..66c2554 100644 --- a/modules/embeddingsstore/src/main/java/com/theagilemonkeys/ellmental/embeddingsstore/pinecone/PineconeEmbeddingsStore.java +++ b/modules/embeddingsstore/src/main/java/com/theagilemonkeys/ellmental/embeddingsstore/pinecone/PineconeEmbeddingsStore.java @@ -2,32 +2,37 @@ import com.google.gson.Gson; import 
com.theagilemonkeys.ellmental.embeddingsstore.EmbeddingsStore; - - import java.io.IOException; -import java.util.Arrays; +import java.util.ArrayList; import java.util.List; import java.util.Map; - - - import com.theagilemonkeys.ellmental.core.schema.Embedding; +import com.theagilemonkeys.ellmental.core.errors.EnvironmentVariableNotDeclaredException; import okhttp3.*; - - import static java.net.HttpURLConnection.HTTP_BAD_REQUEST; public class PineconeEmbeddingsStore extends EmbeddingsStore { - private static final String url = "https://andre-index-ae44658.svc.gcp-starter.pinecone.io"; + private final String url; + private final String apiKey; + private final String namespace; private static OkHttpClient client; - public PineconeEmbeddingsStore( + public PineconeEmbeddingsStore() { + url = System.getenv("PINECONE_URL"); + apiKey = System.getenv("PINECONE_API_KEY"); + namespace = System.getenv("PINECONE_NAMESPACE"); + + if (url == null) { + throw new EnvironmentVariableNotDeclaredException("Environement variable PINECONE_URL is not declared."); + } else if (apiKey == null) { + throw new EnvironmentVariableNotDeclaredException("Environement variable PINECONE_API_KEY is not declared."); + } else if (namespace == null) { + // TODO: need to test code using pinecone with namespace + throw new EnvironmentVariableNotDeclaredException("Environement variable PINECONE_NAMESPACE is not declared."); + } - ) { client = new OkHttpClient(); - - } private String post(String path, String bodyString) throws IOException { @@ -40,11 +45,10 @@ private String post(String path, String bodyString) throws IOException { .url(url + path) .header("accept", "application/json") .header("content-type", "application/json") - .header("Api-Key", "1c13987f-1d13-4373-8655-f739089de6af") + .header("Api-Key", apiKey) .post(body) .build(); - // TODO: need to improve error handling message try(Response response = client.newCall(request).execute()) { if (response.code() >= HTTP_BAD_REQUEST) { throw new 
IOException(url); @@ -62,12 +66,9 @@ private String post(String path, String bodyString) throws IOException { // TODO: using store for the upsert. Check if this is the correct path. public void store(Embedding embedding, Map metadata) { - UpsertVectorSchema schema = new UpsertVectorSchema(embedding, metadata); - String requestBodyJson = new Gson().toJson(schema); - try { this.post("/vectors/upsert", requestBodyJson); } catch (IOException e) { @@ -84,23 +85,21 @@ public List similaritySearch(Embedding reference, int limit) { reference.vector, null ); - String requestBodyJson = new Gson().toJson(body); try { String responseString = this.post("/query", requestBodyJson); + QueryVectorResponseSchema response = new Gson().fromJson(responseString, QueryVectorResponseSchema.class); + ArrayList matchEmbeddings = new ArrayList<>(); - QueryVectorRequestSchema response = new Gson().fromJson(responseString, QueryVectorRequestSchema.class); - - + for ( Match m : response.matches) { + matchEmbeddings.add(new Embedding(m.values)); + } - return Arrays.asList(new Embedding(), new Embedding()); + return matchEmbeddings; } catch (IOException e) { System.out.println("VectorStore error on upsert: " + e.getMessage()); return null; } - - - } } diff --git a/modules/embeddingsstore/src/main/java/com/theagilemonkeys/ellmental/embeddingsstore/pinecone/QueryVectorRequestSchema.java b/modules/embeddingsstore/src/main/java/com/theagilemonkeys/ellmental/embeddingsstore/pinecone/QueryVectorRequestSchema.java index 0e2f04b..b8e2984 100644 --- a/modules/embeddingsstore/src/main/java/com/theagilemonkeys/ellmental/embeddingsstore/pinecone/QueryVectorRequestSchema.java +++ b/modules/embeddingsstore/src/main/java/com/theagilemonkeys/ellmental/embeddingsstore/pinecone/QueryVectorRequestSchema.java @@ -23,7 +23,4 @@ public class QueryVectorRequestSchema { public Boolean includeMetadata; public List vector; public String namespace; - - - } diff --git 
a/modules/embeddingsstore/src/main/java/com/theagilemonkeys/ellmental/embeddingsstore/pinecone/QueryVectorResponseSchema.java b/modules/embeddingsstore/src/main/java/com/theagilemonkeys/ellmental/embeddingsstore/pinecone/QueryVectorResponseSchema.java index aaff80e..684a45f 100644 --- a/modules/embeddingsstore/src/main/java/com/theagilemonkeys/ellmental/embeddingsstore/pinecone/QueryVectorResponseSchema.java +++ b/modules/embeddingsstore/src/main/java/com/theagilemonkeys/ellmental/embeddingsstore/pinecone/QueryVectorResponseSchema.java @@ -10,7 +10,7 @@ public class QueryVectorResponseSchema { class Match { String id; - Integer score; + Double score; List values; Map metadata; } \ No newline at end of file