Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: EF move to embeddings sub-package #49

Merged
merged 2 commits into from
Aug 14, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 87 additions & 85 deletions README.md
Original file line number Diff line number Diff line change
@@ -73,72 +73,72 @@ package tech.amikos;

import tech.amikos.chromadb.*;
import tech.amikos.chromadb.Collection;
import tech.amikos.chromadb.embeddings.DefaultEmbeddingFunction;

import java.util.*;

public class Main {
public static void main(String[] args) {
try {
Client client = new Client(System.getenv("CHROMA_URL"));
client.reset();
EmbeddingFunction ef = new DefaultEmbeddingFunction();
Collection collection = client.createCollection("test-collection", null, true, ef);
List<Map<String, String>> metadata = new ArrayList<>();
metadata.add(new HashMap<String, String>() {{
put("type", "scientist");
}});
metadata.add(new HashMap<String, String>() {{
put("type", "spy");
}});
collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2"));
Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null);
System.out.println(qr);
} catch (Exception e) {
System.out.println(e);
}
public static void main(String[] args) {
try {
Client client = new Client(System.getenv("CHROMA_URL"));
client.reset();
EmbeddingFunction ef = new DefaultEmbeddingFunction();
Collection collection = client.createCollection("test-collection", null, true, ef);
List<Map<String, String>> metadata = new ArrayList<>();
metadata.add(new HashMap<String, String>() {{
put("type", "scientist");
}});
metadata.add(new HashMap<String, String>() {{
put("type", "spy");
}});
collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2"));
Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null);
System.out.println(qr);
} catch (Exception e) {
System.out.println(e);
}
}
}
```

### Example OpenAI Embedding Function

In this example we rely on `tech.amikos.chromadb.OpenAIEmbeddingFunction` to generate embeddings for our documents.
In this example we rely on `tech.amikos.chromadb.embeddings.openai.OpenAIEmbeddingFunction` to generate embeddings for our documents.

| **Important**: Ensure you have the `OPENAI_API_KEY` environment variable set

```java
package tech.amikos;

import com.google.gson.internal.LinkedTreeMap;
import tech.amikos.chromadb.Client;
import tech.amikos.chromadb.Collection;
import tech.amikos.chromadb.EmbeddingFunction;
import tech.amikos.chromadb.OpenAIEmbeddingFunction;
import tech.amikos.chromadb.embeddings.openai.OpenAIEmbeddingFunction;

import java.util.*;

public class Main {
public static void main(String[] args) {
try {
Client client = new Client(System.getenv("CHROMA_URL"));
String apiKey = System.getenv("OPENAI_API_KEY");
EmbeddingFunction ef = new OpenAIEmbeddingFunction(apiKey, "text-embedding-3-small");
Collection collection = client.createCollection("test-collection", null, true, ef);
List<Map<String, String>> metadata = new ArrayList<>();
metadata.add(new HashMap<String, String>() {{
put("type", "scientist");
}});
metadata.add(new HashMap<String, String>() {{
put("type", "spy");
}});
collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2"));
Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null);
System.out.println(qr);
} catch (Exception e) {
e.printStackTrace();
System.out.println(e);
}
public static void main(String[] args) {
try {
Client client = new Client(System.getenv("CHROMA_URL"));
String apiKey = System.getenv("OPENAI_API_KEY");
EmbeddingFunction ef = new OpenAIEmbeddingFunction(apiKey, "text-embedding-3-small");
Collection collection = client.createCollection("test-collection", null, true, ef);
List<Map<String, String>> metadata = new ArrayList<>();
metadata.add(new HashMap<String, String>() {{
put("type", "scientist");
}});
metadata.add(new HashMap<String, String>() {{
put("type", "spy");
}});
collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2"));
Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null);
System.out.println(qr);
} catch (Exception e) {
e.printStackTrace();
System.out.println(e);
}
}
}
```

@@ -174,7 +174,7 @@ curl http://localhost:11434/api/embeddings -d '{\n "model": "llama2",\n "promp

### Example Cohere Embedding Function

In this example we rely on `tech.amikos.chromadb.CohereEmbeddingFunction` to generate embeddings for our documents.
In this example we rely on `tech.amikos.chromadb.embeddings.cohere.CohereEmbeddingFunction` to generate embeddings for our documents.

| **Important**: Ensure you have the `COHERE_API_KEY` environment variable set

@@ -183,32 +183,33 @@ package tech.amikos;

import tech.amikos.chromadb.*;
import tech.amikos.chromadb.Collection;
import tech.amikos.chromadb.embeddings.cohere.CohereEmbeddingFunction;

import java.util.*;

public class Main {
public static void main(String[] args) {
try {
Client client = new Client(System.getenv("CHROMA_URL"));
client.reset();
String apiKey = System.getenv("COHERE_API_KEY");
EmbeddingFunction ef = new CohereEmbeddingFunction(apiKey);
Collection collection = client.createCollection("test-collection", null, true, ef);
List<Map<String, String>> metadata = new ArrayList<>();
metadata.add(new HashMap<String, String>() {{
put("type", "scientist");
}});
metadata.add(new HashMap<String, String>() {{
put("type", "spy");
}});
collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2"));
Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null);
System.out.println(qr);
} catch (Exception e) {
e.printStackTrace();
System.out.println(e);
}
public static void main(String[] args) {
try {
Client client = new Client(System.getenv("CHROMA_URL"));
client.reset();
String apiKey = System.getenv("COHERE_API_KEY");
EmbeddingFunction ef = new CohereEmbeddingFunction(apiKey);
Collection collection = client.createCollection("test-collection", null, true, ef);
List<Map<String, String>> metadata = new ArrayList<>();
metadata.add(new HashMap<String, String>() {{
put("type", "scientist");
}});
metadata.add(new HashMap<String, String>() {{
put("type", "spy");
}});
collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2"));
Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null);
System.out.println(qr);
} catch (Exception e) {
e.printStackTrace();
System.out.println(e);
}
}
}
```

@@ -220,7 +221,7 @@ The above should output:

### Example Hugging Face Sentence Transformers Embedding Function

In this example we rely on `tech.amikos.chromadb.HuggingFaceEmbeddingFunction` to generate embeddings for our documents.
In this example we rely on `tech.amikos.chromadb.embeddings.hf.HuggingFaceEmbeddingFunction` to generate embeddings for our documents.

| **Important**: Ensure you have the `HF_API_KEY` environment variable set

@@ -229,31 +230,32 @@ package tech.amikos;

import tech.amikos.chromadb.*;
import tech.amikos.chromadb.Collection;
import tech.amikos.chromadb.embeddings.hf.HuggingFaceEmbeddingFunction;

import java.util.*;

public class Main {
public static void main(String[] args) {
try {
Client client = new Client(System.getenv("CHROMA_URL"));
client.reset();
String apiKey = System.getenv("HF_API_KEY");
EmbeddingFunction ef = new HuggingFaceEmbeddingFunction(apiKey);
Collection collection = client.createCollection("test-collection", null, true, ef);
List<Map<String, String>> metadata = new ArrayList<>();
metadata.add(new HashMap<String, String>() {{
put("type", "scientist");
}});
metadata.add(new HashMap<String, String>() {{
put("type", "spy");
}});
collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2"));
Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null);
System.out.println(qr);
} catch (Exception e) {
System.out.println(e);
}
public static void main(String[] args) {
try {
Client client = new Client(System.getenv("CHROMA_URL"));
client.reset();
String apiKey = System.getenv("HF_API_KEY");
EmbeddingFunction ef = new HuggingFaceEmbeddingFunction(apiKey);
Collection collection = client.createCollection("test-collection", null, true, ef);
List<Map<String, String>> metadata = new ArrayList<>();
metadata.add(new HashMap<String, String>() {{
put("type", "scientist");
}});
metadata.add(new HashMap<String, String>() {{
put("type", "spy");
}});
collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2"));
Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null);
System.out.println(qr);
} catch (Exception e) {
System.out.println(e);
}
}
}
```

2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
@@ -228,7 +228,7 @@
<!-- <additionalProperty>ignoreImportMappings=false</additionalProperty>-->
<!-- </additionalProperties>-->
<importMappings>
<importMapping>CreateEmbeddingRequest=tech.amikos.openai.CreateEmbeddingRequest
<importMapping>CreateEmbeddingRequest=tech.amikos.chromadb.embeddings.openai.CreateEmbeddingRequest
</importMapping>
</importMappings>
<generateApiTests>false</generateApiTests>
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package tech.amikos.chromadb;
package tech.amikos.chromadb.embeddings;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
@@ -11,6 +11,8 @@
import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.guava.primitives.Floats;
import tech.amikos.chromadb.EFException;
import tech.amikos.chromadb.EmbeddingFunction;

import java.io.*;
import java.net.URL;
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package tech.amikos.cohere;
package tech.amikos.chromadb.embeddings.cohere;

import com.google.gson.Gson;
import okhttp3.*;
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
package tech.amikos.chromadb;
package tech.amikos.chromadb.embeddings.cohere;

import tech.amikos.cohere.CohereClient;
import tech.amikos.cohere.CreateEmbeddingRequest;
import tech.amikos.cohere.CreateEmbeddingResponse;
import tech.amikos.chromadb.EmbeddingFunction;

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

public class CohereEmbeddingFunction implements EmbeddingFunction {

Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package tech.amikos.cohere;
package tech.amikos.chromadb.embeddings.cohere;

import com.google.gson.*;
import com.google.gson.annotations.SerializedName;
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package tech.amikos.cohere;
package tech.amikos.chromadb.embeddings.cohere;

import com.google.gson.Gson;
import com.google.gson.annotations.SerializedName;
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package tech.amikos.hf;
package tech.amikos.chromadb.embeddings.hf;

import com.google.gson.Gson;
import com.google.gson.annotations.SerializedName;
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package tech.amikos.hf;
package tech.amikos.chromadb.embeddings.hf;

import com.google.gson.Gson;

Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package tech.amikos.hf;
package tech.amikos.chromadb.embeddings.hf;

import com.google.gson.Gson;
import okhttp3.*;
Original file line number Diff line number Diff line change
@@ -1,31 +1,27 @@
package tech.amikos.chromadb;
package tech.amikos.chromadb.embeddings.hf;


import tech.amikos.hf.CreateEmbeddingRequest;
import tech.amikos.hf.CreateEmbeddingResponse;
import tech.amikos.hf.HuggingFaceClient;
import tech.amikos.chromadb.EmbeddingFunction;

import java.util.List;

public class HuggingFaceEmbeddingFunction implements EmbeddingFunction {

private final String hfAPIKey;

private final HuggingFaceClient client;
public HuggingFaceEmbeddingFunction(String hfAPIKey) {
this.hfAPIKey = hfAPIKey;

this.client = new HuggingFaceClient(this.hfAPIKey);
}

@Override
public List<List<Float>> createEmbedding(List<String> documents) {
HuggingFaceClient client = new HuggingFaceClient(this.hfAPIKey);
CreateEmbeddingResponse response = client.createEmbedding(new CreateEmbeddingRequest().inputs(documents.toArray(new String[0])));

CreateEmbeddingResponse response = this.client.createEmbedding(new CreateEmbeddingRequest().inputs(documents.toArray(new String[0])));
return response.getEmbeddings();
}

@Override
public List<List<Float>> createEmbedding(List<String> documents, String model) {
HuggingFaceClient client = new HuggingFaceClient(this.hfAPIKey);
client.modelId(model);
CreateEmbeddingResponse response = client.createEmbedding(new CreateEmbeddingRequest().inputs(documents.toArray(new String[0])));
return response.getEmbeddings();
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package tech.amikos.openai;
package tech.amikos.chromadb.embeddings.openai;

import com.google.gson.*;
import com.google.gson.annotations.SerializedName;
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package tech.amikos.openai;
package tech.amikos.chromadb.embeddings.openai;

import com.google.gson.Gson;
import com.google.gson.annotations.SerializedName;
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package tech.amikos.openai;
package tech.amikos.chromadb.embeddings.openai;

import com.google.gson.Gson;
import okhttp3.*;
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
package tech.amikos.chromadb;
package tech.amikos.chromadb.embeddings.openai;

import tech.amikos.openai.CreateEmbeddingRequest;
import tech.amikos.openai.CreateEmbeddingResponse;
import tech.amikos.openai.OpenAIClient;
import tech.amikos.chromadb.EmbeddingFunction;

import java.util.List;
import java.util.function.Consumer;
import java.util.stream.Collectors;

public class OpenAIEmbeddingFunction implements EmbeddingFunction {
16 changes: 0 additions & 16 deletions src/test/java/TestEmbeddingFunction.java

This file was deleted.

21 changes: 0 additions & 21 deletions src/test/java/TestHuggingFaceClient.java

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package tech.amikos.chromadb;

import com.github.tomakehurst.wiremock.junit.WireMockRule;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.testcontainers.chromadb.ChromaDBContainer;
import org.testcontainers.containers.wait.strategy.Wait;
import tech.amikos.chromadb.*;
import tech.amikos.chromadb.Collection;
import tech.amikos.chromadb.embeddings.DefaultEmbeddingFunction;
import tech.amikos.chromadb.embeddings.hf.HuggingFaceEmbeddingFunction;
import tech.amikos.chromadb.handler.ApiException;

import java.io.IOException;
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
package tech.amikos.chromadb;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
package tech.amikos.chromadb;
package tech.amikos.chromadb.embeddings;

import org.apache.commons.io.FileUtils;
import org.junit.Test;
import tech.amikos.chromadb.EmbeddingFunction;

import java.util.Arrays;
import java.util.List;

import static org.junit.Assert.*;

public class TestDefaultEmbeddingFunction {
public class TestDefaultEmbeddings {
// this represents the output of sentence-transformers/all-MiniLM-L6-v2 for "Hello, my name is John. I am a Data Scientist.", "Hello, I am Jane and I am an ML researcher."
private static float[][] groundThruth = {{-0.09585458785295486f,
0.00948028638958931f,
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
package tech.amikos.chromadb.embeddings.cohere;

import org.junit.Test;
import tech.amikos.cohere.CohereClient;
import tech.amikos.cohere.CreateEmbeddingRequest;
import tech.amikos.cohere.CreateEmbeddingResponse;
import tech.amikos.chromadb.Utils;

import static org.junit.Assert.*;

public class TestCohereClient {
public class TestCohereEmbeddings {

@Test
public void testEmbeddings() {
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package tech.amikos.chromadb.embeddings.hf;

import org.junit.BeforeClass;
import org.junit.Test;
import tech.amikos.chromadb.EmbeddingFunction;
import tech.amikos.chromadb.Utils;
import tech.amikos.chromadb.handler.ApiException;

import java.util.*;

import static org.junit.Assert.assertEquals;

/**
 * Tests for the Hugging Face embedding support, covering both the low-level
 * {@code HuggingFaceClient} and the {@code HuggingFaceEmbeddingFunction} built on top of it.
 *
 * <p>Every test reads the {@code HF_API_KEY} value through {@code Utils.getEnvOrProperty}.
 * NOTE(review): these presumably call the live Hugging Face API and therefore need a valid
 * key and network access — confirm before running in CI.
 */
public class TestHuggingFaceEmbeddings {

    /**
     * Runs once before the tests in this class and hands the {@code .env} file name to
     * {@code Utils.loadEnvFile}.
     * NOTE(review): presumably this makes {@code HF_API_KEY} resolvable for the tests below —
     * verify against {@code Utils}.
     */
    @BeforeClass
    public static void setup() {
        Utils.loadEnvFile(".env");
    }

    /**
     * Uses the client directly with an explicitly selected model and checks that one
     * embedding comes back per input text.
     */
    @Test
    public void testEmbeddings() {
        String hfKey = Utils.getEnvOrProperty("HF_API_KEY");
        HuggingFaceClient hfClient = new HuggingFaceClient(hfKey);
        hfClient.modelId("sentence-transformers/all-MiniLM-L6-v2");

        String[] inputs = new String[]{"Hello world", "How are you?"};
        CreateEmbeddingResponse reply = hfClient.createEmbedding(new CreateEmbeddingRequest().inputs(inputs));

        int embeddingCount = reply.getEmbeddings().size();
        assertEquals(2, embeddingCount);
    }

    /**
     * Embeds two documents through the embedding function without naming a model and
     * expects two vectors, the first of which has 384 components.
     *
     * @throws ApiException declared by the test signature; not raised explicitly here
     */
    @Test
    public void testEmbed() throws ApiException {
        EmbeddingFunction embedder = new HuggingFaceEmbeddingFunction(Utils.getEnvOrProperty("HF_API_KEY"));
        List<String> documents = Arrays.asList("Hello world", "How are you?");

        List<List<Float>> vectors = embedder.createEmbedding(documents);

        assertEquals(2, vectors.size());
        List<Float> firstVector = vectors.get(0);
        assertEquals(384, firstVector.size());
    }

    /**
     * Embeds two documents while passing an explicit model id and expects two vectors,
     * the first of which has 768 components.
     *
     * @throws ApiException declared by the test signature; not raised explicitly here
     */
    @Test
    public void testEmbedWithModel() throws ApiException {
        EmbeddingFunction embedder = new HuggingFaceEmbeddingFunction(Utils.getEnvOrProperty("HF_API_KEY"));
        List<String> documents = Arrays.asList("Hello world", "How are you?");

        List<List<Float>> vectors = embedder.createEmbedding(documents, "sentence-transformers/all-mpnet-base-v2");

        assertEquals(2, vectors.size());
        List<Float> firstVector = vectors.get(0);
        assertEquals(768, firstVector.size());
    }
}

Original file line number Diff line number Diff line change
@@ -13,7 +13,7 @@

import static org.junit.Assert.*;

public class TestOllamaEmbeddingFunction {
public class TestOllamaEmbeddings {
static GenericContainer ollamaContainer;

@BeforeClass
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
package tech.amikos.chromadb.embeddings.openai;

import org.junit.Test;
import tech.amikos.chromadb.EmbeddingFunction;
import tech.amikos.chromadb.OpenAIEmbeddingFunction;
import tech.amikos.openai.CreateEmbeddingRequest;
import tech.amikos.openai.CreateEmbeddingResponse;
import tech.amikos.openai.OpenAIClient;
import tech.amikos.chromadb.Utils;

import java.util.Arrays;
import java.util.List;