Skip to content

Commit

Permalink
Add: added embedding search function to store
Browse files Browse the repository at this point in the history
also cleaned up some code
  • Loading branch information
mnlx committed Aug 20, 2023
1 parent d5d0a8f commit c6d7e9f
Show file tree
Hide file tree
Showing 8 changed files with 58 additions and 50 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package com.theagilemonkeys.ellmental.core.errors;


/**
 * Unchecked exception used to report that an environment variable the code
 * relies on has not been declared.
 */
public class EnvironmentVariableNotDeclaredException extends RuntimeException {

    /**
     * Creates the exception with the given detail message.
     *
     * @param message text describing the problem; returned unchanged by {@code getMessage()}
     */
    public EnvironmentVariableNotDeclaredException(String message) {
        super(message);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
import java.util.List;

/**
 * Simple holder for an embedding vector.
 */
public class Embedding {

    /** Components of the embedding; the very list instance passed to the constructor. */
    public List<Double> vector;

    /**
     * Wraps the supplied list. The list is stored as-is (no copy is made).
     *
     * @param vector components of the embedding
     */
    public Embedding(List<Double> vector) {
        this.vector = vector;
    }
}
Original file line number Diff line number Diff line change
@@ -1,31 +1,30 @@
package com.theagilemonkeys.ellmental.helloworld;

import com.theagilemonkeys.ellmental.core.schema.Embedding;
import com.theagilemonkeys.examplemodule.PrintHelloLibrary;
import com.theagilemonkeys.ellmental.embeddingsgeneration.openai.OpenAIEmbeddingsModel;
import com.theagilemonkeys.ellmental.embeddingsstore.EmbeddingsStore;
import com.theagilemonkeys.ellmental.embeddingsstore.pinecone.PineconeEmbeddingsStore;

import java.util.HashMap;
import java.util.Map;
import java.util.List;

/**
 * Demo entry point: generates an embedding for a sample string, stores it in
 * an {@link EmbeddingsStore} backed by Pinecone, and then runs a similarity
 * search against that store using the same embedding.
 */
public class Main {

    /**
     * Runs the three demo steps in order and prints a completion message.
     *
     * @param args command-line arguments (not used)
     */
    public static void main(String[] args) {
        // Step 1: generate embeddings from input string
        OpenAIEmbeddingsModel openAI = new OpenAIEmbeddingsModel();
        Embedding embedding = openAI.generateEmbedding("Test");

        Map<String, String> metadata = new HashMap<>();
        metadata.put("key1", "value1");
        metadata.put("key2", "value2");

        // Step 2: save the generated embeddings to a store (Pinecone in this case)
        // The store is declared exactly once; a second declaration of the same
        // local name in this method does not compile.
        EmbeddingsStore embeddingStore = new PineconeEmbeddingsStore();
        embeddingStore.store(embedding, metadata);

        PrintHelloLibrary helloPrinter = new PrintHelloLibrary();
        helloPrinter.printHello();

        // Step 3: search for the embedding in the store
        List<Embedding> searchEmbeddings = embeddingStore.similaritySearch(embedding, 5);

        // The Pinecone implementation returns null when the request fails,
        // so guard before reading the size.
        int matchCount = (searchEmbeddings == null) ? 0 : searchEmbeddings.size();
        System.out.println("Similarity search returned " + matchCount + " embeddings.");

        System.out.println("Embedding generation and storage finished.");
    }
}

Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.theagilemonkeys.ellmental.embeddingsgeneration.openai;

import com.theagilemonkeys.ellmental.core.errors.EnvironmentVariableNotDeclaredException;
import com.theagilemonkeys.ellmental.core.schema.Embedding;
import com.theagilemonkeys.ellmental.embeddingsgeneration.EmbeddingsGenerationModel;

Expand All @@ -12,30 +13,30 @@


/**
 * {@link EmbeddingsGenerationModel} that obtains embeddings through an
 * {@code OpenAiService} client, using the model named by
 * {@link #embeddingOpenAiModel}.
 */
public class OpenAIEmbeddingsModel extends EmbeddingsGenerationModel {

    /** Client used for every embeddings request made by this instance. */
    private final OpenAiService service;

    /** Name of the OpenAI embeddings model sent with each request. */
    public static String embeddingOpenAiModel = "text-embedding-ada-002";

    /**
     * Creates the model, reading the API key from the {@code OPEN_AI_API_KEY}
     * environment variable.
     *
     * @throws EnvironmentVariableNotDeclaredException if {@code OPEN_AI_API_KEY} is not set
     */
    public OpenAIEmbeddingsModel() {
        String openAIKey = System.getenv("OPEN_AI_API_KEY");
        if (openAIKey == null) {
            throw new EnvironmentVariableNotDeclaredException("Environment variable OPEN_AI_API_KEY is not declared.");
        }
        // Build the client only once, and only with a real key (never an empty string).
        service = new OpenAiService(openAIKey);
    }

    /**
     * Requests an embedding for a single input string.
     *
     * @param inputString text to embed
     * @return an {@link Embedding} wrapping the first embedding in the service response
     */
    public Embedding generateEmbedding(String inputString) {
        // TODO: the embeddings function from the library uses an array as input. We are only using a length 1 array.
        //  Check if we should implement an array option.
        List<String> embeddingsInput = new ArrayList<>();
        embeddingsInput.add(inputString);

        EmbeddingRequest embeddingRequest = EmbeddingRequest.builder()
                .model(embeddingOpenAiModel)
                .input(embeddingsInput)
                .build();

        // Exactly one input was sent, so only the first entry of the response data is read.
        List<Double> embedding = service.createEmbeddings(embeddingRequest).getData().get(0).getEmbedding();

        return new Embedding(embedding);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import com.theagilemonkeys.ellmental.core.schema.Embedding;

/**
 * Base type for embedding stores: implementations persist embeddings with
 * their metadata and answer similarity queries.
 */
public abstract class EmbeddingsStore {

    // TODO: check if there is an issue using a map instead of a metadata class
    /**
     * Stores an embedding together with its metadata.
     *
     * @param embedding embedding to store
     * @param metadata  string key/value pairs to keep with the embedding
     */
    public abstract void store(Embedding embedding, Map<String, String> metadata);

    /**
     * Searches the store for embeddings similar to a reference embedding.
     *
     * @param reference embedding to compare against
     * @param limit     presumably the maximum number of results to return (confirm against implementations)
     * @return the embeddings found by the search
     */
    public abstract List<Embedding> similaritySearch(Embedding reference, int limit);
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,37 @@

import com.google.gson.Gson;
import com.theagilemonkeys.ellmental.embeddingsstore.EmbeddingsStore;


import java.io.IOException;
import java.util.Arrays;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;



import com.theagilemonkeys.ellmental.core.schema.Embedding;
import com.theagilemonkeys.ellmental.core.errors.EnvironmentVariableNotDeclaredException;
import okhttp3.*;


import static java.net.HttpURLConnection.HTTP_BAD_REQUEST;


public class PineconeEmbeddingsStore extends EmbeddingsStore {

/** Base URL of the Pinecone index, read from {@code PINECONE_URL}. */
private final String url;
/** API key sent with each request, read from {@code PINECONE_API_KEY}. */
private final String apiKey;
/** Namespace read from {@code PINECONE_NAMESPACE}. */
private final String namespace;
private static OkHttpClient client;

/**
 * Creates a store configured entirely from environment variables:
 * {@code PINECONE_URL}, {@code PINECONE_API_KEY} and {@code PINECONE_NAMESPACE}.
 * The variables are validated in that order.
 *
 * @throws EnvironmentVariableNotDeclaredException if any of the three variables is not set
 */
public PineconeEmbeddingsStore() {
    url = requireEnv("PINECONE_URL");
    apiKey = requireEnv("PINECONE_API_KEY");
    // TODO: need to test code using pinecone with namespace
    namespace = requireEnv("PINECONE_NAMESPACE");

    client = new OkHttpClient();
}

/**
 * Returns the value of an environment variable, failing if it is not set.
 *
 * @param name name of the environment variable
 * @return the variable's value, never null
 * @throws EnvironmentVariableNotDeclaredException if the variable is not declared
 */
private static String requireEnv(String name) {
    String value = System.getenv(name);
    if (value == null) {
        throw new EnvironmentVariableNotDeclaredException("Environment variable " + name + " is not declared.");
    }
    return value;
}

private String post(String path, String bodyString) throws IOException {
Expand All @@ -40,11 +45,10 @@ private String post(String path, String bodyString) throws IOException {
.url(url + path)
.header("accept", "application/json")
.header("content-type", "application/json")
.header("Api-Key", "1c13987f-1d13-4373-8655-f739089de6af")
.header("Api-Key", apiKey)
.post(body)
.build();

// TODO: need to improve error handling message
try(Response response = client.newCall(request).execute()) {
if (response.code() >= HTTP_BAD_REQUEST) {
throw new IOException(url);
Expand All @@ -62,12 +66,9 @@ private String post(String path, String bodyString) throws IOException {

// TODO: using store for the upsert. Check if this is the correct path.
public void store(Embedding embedding, Map<String,String> metadata) {

UpsertVectorSchema schema = new UpsertVectorSchema(embedding, metadata);

String requestBodyJson = new Gson().toJson(schema);


try {
this.post("/vectors/upsert", requestBodyJson);
} catch (IOException e) {
Expand All @@ -84,23 +85,21 @@ public List<Embedding> similaritySearch(Embedding reference, int limit) {
reference.vector,
null
);

String requestBodyJson = new Gson().toJson(body);

try {
String responseString = this.post("/query", requestBodyJson);
QueryVectorResponseSchema response = new Gson().fromJson(responseString, QueryVectorResponseSchema.class);
ArrayList<Embedding> matchEmbeddings = new ArrayList<>();

QueryVectorRequestSchema response = new Gson().fromJson(responseString, QueryVectorRequestSchema.class);


for ( Match m : response.matches) {
matchEmbeddings.add(new Embedding(m.values));
}

return Arrays.asList(new Embedding(), new Embedding());
return matchEmbeddings;
} catch (IOException e) {
System.out.println("VectorStore error on upsert: " + e.getMessage());
return null;
}



}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,4 @@ public class QueryVectorRequestSchema {
public Boolean includeMetadata;
public List<Double> vector;
public String namespace;



}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ public class QueryVectorResponseSchema {

/**
 * A single match from a vector query response. Fields are filled in
 * from the response JSON by name.
 */
class Match {
    /** Identifier of the matched vector. */
    String id;
    /**
     * Similarity score of the match. Declared once, as {@code Double}:
     * scores are fractional, so an {@code Integer} field cannot represent them.
     */
    Double score;
    /** Components of the matched vector. */
    List<Double> values;
    /** String key/value metadata attached to the matched vector. */
    Map<String, String> metadata;
}

0 comments on commit c6d7e9f

Please sign in to comment.