Skip to content

feat: EF move to embeddings sub-package #49

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 87 additions & 85 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,72 +73,72 @@ package tech.amikos;

import tech.amikos.chromadb.*;
import tech.amikos.chromadb.Collection;
import tech.amikos.chromadb.embeddings.DefaultEmbeddingFunction;

import java.util.*;

public class Main {
public static void main(String[] args) {
try {
Client client = new Client(System.getenv("CHROMA_URL"));
client.reset();
EmbeddingFunction ef = new DefaultEmbeddingFunction();
Collection collection = client.createCollection("test-collection", null, true, ef);
List<Map<String, String>> metadata = new ArrayList<>();
metadata.add(new HashMap<String, String>() {{
put("type", "scientist");
}});
metadata.add(new HashMap<String, String>() {{
put("type", "spy");
}});
collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2"));
Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null);
System.out.println(qr);
} catch (Exception e) {
System.out.println(e);
}
public static void main(String[] args) {
try {
Client client = new Client(System.getenv("CHROMA_URL"));
client.reset();
EmbeddingFunction ef = new DefaultEmbeddingFunction();
Collection collection = client.createCollection("test-collection", null, true, ef);
List<Map<String, String>> metadata = new ArrayList<>();
metadata.add(new HashMap<String, String>() {{
put("type", "scientist");
}});
metadata.add(new HashMap<String, String>() {{
put("type", "spy");
}});
collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2"));
Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null);
System.out.println(qr);
} catch (Exception e) {
System.out.println(e);
}
}
}
```

### Example OpenAI Embedding Function

In this example we rely on `tech.amikos.chromadb.OpenAIEmbeddingFunction` to generate embeddings for our documents.
In this example we rely on `tech.amikos.chromadb.embeddings.openai.OpenAIEmbeddingFunction` to generate embeddings for our documents.

| **Important**: Ensure you have the `OPENAI_API_KEY` environment variable set.

```java
package tech.amikos;

import com.google.gson.internal.LinkedTreeMap;
import tech.amikos.chromadb.Client;
import tech.amikos.chromadb.Collection;
import tech.amikos.chromadb.EmbeddingFunction;
import tech.amikos.chromadb.OpenAIEmbeddingFunction;
import tech.amikos.chromadb.embeddings.openai.OpenAIEmbeddingFunction;

import java.util.*;

public class Main {
public static void main(String[] args) {
try {
Client client = new Client(System.getenv("CHROMA_URL"));
String apiKey = System.getenv("OPENAI_API_KEY");
EmbeddingFunction ef = new OpenAIEmbeddingFunction(apiKey, "text-embedding-3-small");
Collection collection = client.createCollection("test-collection", null, true, ef);
List<Map<String, String>> metadata = new ArrayList<>();
metadata.add(new HashMap<String, String>() {{
put("type", "scientist");
}});
metadata.add(new HashMap<String, String>() {{
put("type", "spy");
}});
collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2"));
Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null);
System.out.println(qr);
} catch (Exception e) {
e.printStackTrace();
System.out.println(e);
}
public static void main(String[] args) {
try {
Client client = new Client(System.getenv("CHROMA_URL"));
String apiKey = System.getenv("OPENAI_API_KEY");
EmbeddingFunction ef = new OpenAIEmbeddingFunction(apiKey, "text-embedding-3-small");
Collection collection = client.createCollection("test-collection", null, true, ef);
List<Map<String, String>> metadata = new ArrayList<>();
metadata.add(new HashMap<String, String>() {{
put("type", "scientist");
}});
metadata.add(new HashMap<String, String>() {{
put("type", "spy");
}});
collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2"));
Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null);
System.out.println(qr);
} catch (Exception e) {
e.printStackTrace();
System.out.println(e);
}
}
}
```

Expand Down Expand Up @@ -174,7 +174,7 @@ curl http://localhost:11434/api/embeddings -d '{\n "model": "llama2",\n "promp

### Example Cohere Embedding Function

In this example we rely on `tech.amikos.chromadb.CohereEmbeddingFunction` to generate embeddings for our documents.
In this example we rely on `tech.amikos.chromadb.embeddings.cohere.CohereEmbeddingFunction` to generate embeddings for our documents.

| **Important**: Ensure you have the `COHERE_API_KEY` environment variable set.

Expand All @@ -183,32 +183,33 @@ package tech.amikos;

import tech.amikos.chromadb.*;
import tech.amikos.chromadb.Collection;
import tech.amikos.chromadb.embeddings.cohere.CohereEmbeddingFunction;

import java.util.*;

public class Main {
public static void main(String[] args) {
try {
Client client = new Client(System.getenv("CHROMA_URL"));
client.reset();
String apiKey = System.getenv("COHERE_API_KEY");
EmbeddingFunction ef = new CohereEmbeddingFunction(apiKey);
Collection collection = client.createCollection("test-collection", null, true, ef);
List<Map<String, String>> metadata = new ArrayList<>();
metadata.add(new HashMap<String, String>() {{
put("type", "scientist");
}});
metadata.add(new HashMap<String, String>() {{
put("type", "spy");
}});
collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2"));
Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null);
System.out.println(qr);
} catch (Exception e) {
e.printStackTrace();
System.out.println(e);
}
public static void main(String[] args) {
try {
Client client = new Client(System.getenv("CHROMA_URL"));
client.reset();
String apiKey = System.getenv("COHERE_API_KEY");
EmbeddingFunction ef = new CohereEmbeddingFunction(apiKey);
Collection collection = client.createCollection("test-collection", null, true, ef);
List<Map<String, String>> metadata = new ArrayList<>();
metadata.add(new HashMap<String, String>() {{
put("type", "scientist");
}});
metadata.add(new HashMap<String, String>() {{
put("type", "spy");
}});
collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2"));
Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null);
System.out.println(qr);
} catch (Exception e) {
e.printStackTrace();
System.out.println(e);
}
}
}
```

Expand All @@ -220,7 +221,7 @@ The above should output:

### Example Hugging Face Sentence Transformers Embedding Function

In this example we rely on `tech.amikos.chromadb.HuggingFaceEmbeddingFunction` to generate embeddings for our documents.
In this example we rely on `tech.amikos.chromadb.embeddings.hf.HuggingFaceEmbeddingFunction` to generate embeddings for our documents.

| **Important**: Ensure you have the `HF_API_KEY` environment variable set.

Expand All @@ -229,31 +230,32 @@ package tech.amikos;

import tech.amikos.chromadb.*;
import tech.amikos.chromadb.Collection;
import tech.amikos.chromadb.embeddings.hf.HuggingFaceEmbeddingFunction;

import java.util.*;

public class Main {
public static void main(String[] args) {
try {
Client client = new Client(System.getenv("CHROMA_URL"));
client.reset();
String apiKey = System.getenv("HF_API_KEY");
EmbeddingFunction ef = new HuggingFaceEmbeddingFunction(apiKey);
Collection collection = client.createCollection("test-collection", null, true, ef);
List<Map<String, String>> metadata = new ArrayList<>();
metadata.add(new HashMap<String, String>() {{
put("type", "scientist");
}});
metadata.add(new HashMap<String, String>() {{
put("type", "spy");
}});
collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2"));
Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null);
System.out.println(qr);
} catch (Exception e) {
System.out.println(e);
}
public static void main(String[] args) {
try {
Client client = new Client(System.getenv("CHROMA_URL"));
client.reset();
String apiKey = System.getenv("HF_API_KEY");
EmbeddingFunction ef = new HuggingFaceEmbeddingFunction(apiKey);
Collection collection = client.createCollection("test-collection", null, true, ef);
List<Map<String, String>> metadata = new ArrayList<>();
metadata.add(new HashMap<String, String>() {{
put("type", "scientist");
}});
metadata.add(new HashMap<String, String>() {{
put("type", "spy");
}});
collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2"));
Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null);
System.out.println(qr);
} catch (Exception e) {
System.out.println(e);
}
}
}
```

Expand Down
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@
<!-- <additionalProperty>ignoreImportMappings=false</additionalProperty>-->
<!-- </additionalProperties>-->
<importMappings>
<importMapping>CreateEmbeddingRequest=tech.amikos.openai.CreateEmbeddingRequest
<importMapping>CreateEmbeddingRequest=tech.amikos.chromadb.embeddings.openai.CreateEmbeddingRequest
</importMapping>
</importMappings>
<generateApiTests>false</generateApiTests>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package tech.amikos.chromadb;
package tech.amikos.chromadb.embeddings;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
Expand All @@ -11,6 +11,8 @@
import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.guava.primitives.Floats;
import tech.amikos.chromadb.EFException;
import tech.amikos.chromadb.EmbeddingFunction;

import java.io.*;
import java.net.URL;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package tech.amikos.cohere;
package tech.amikos.chromadb.embeddings.cohere;

import com.google.gson.Gson;
import okhttp3.*;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
package tech.amikos.chromadb;
package tech.amikos.chromadb.embeddings.cohere;

import tech.amikos.cohere.CohereClient;
import tech.amikos.cohere.CreateEmbeddingRequest;
import tech.amikos.cohere.CreateEmbeddingResponse;
import tech.amikos.chromadb.EmbeddingFunction;

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

public class CohereEmbeddingFunction implements EmbeddingFunction {

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package tech.amikos.cohere;
package tech.amikos.chromadb.embeddings.cohere;

import com.google.gson.*;
import com.google.gson.annotations.SerializedName;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package tech.amikos.cohere;
package tech.amikos.chromadb.embeddings.cohere;

import com.google.gson.Gson;
import com.google.gson.annotations.SerializedName;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package tech.amikos.hf;
package tech.amikos.chromadb.embeddings.hf;

import com.google.gson.Gson;
import com.google.gson.annotations.SerializedName;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package tech.amikos.hf;
package tech.amikos.chromadb.embeddings.hf;

import com.google.gson.Gson;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package tech.amikos.hf;
package tech.amikos.chromadb.embeddings.hf;

import com.google.gson.Gson;
import okhttp3.*;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,31 +1,27 @@
package tech.amikos.chromadb;
package tech.amikos.chromadb.embeddings.hf;


import tech.amikos.hf.CreateEmbeddingRequest;
import tech.amikos.hf.CreateEmbeddingResponse;
import tech.amikos.hf.HuggingFaceClient;
import tech.amikos.chromadb.EmbeddingFunction;

import java.util.List;

public class HuggingFaceEmbeddingFunction implements EmbeddingFunction {

private final String hfAPIKey;

private final HuggingFaceClient client;
public HuggingFaceEmbeddingFunction(String hfAPIKey) {
this.hfAPIKey = hfAPIKey;

this.client = new HuggingFaceClient(this.hfAPIKey);
}

@Override
public List<List<Float>> createEmbedding(List<String> documents) {
HuggingFaceClient client = new HuggingFaceClient(this.hfAPIKey);
CreateEmbeddingResponse response = client.createEmbedding(new CreateEmbeddingRequest().inputs(documents.toArray(new String[0])));

CreateEmbeddingResponse response = this.client.createEmbedding(new CreateEmbeddingRequest().inputs(documents.toArray(new String[0])));
return response.getEmbeddings();
}

@Override
public List<List<Float>> createEmbedding(List<String> documents, String model) {
HuggingFaceClient client = new HuggingFaceClient(this.hfAPIKey);
client.modelId(model);
CreateEmbeddingResponse response = client.createEmbedding(new CreateEmbeddingRequest().inputs(documents.toArray(new String[0])));
return response.getEmbeddings();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package tech.amikos.openai;
package tech.amikos.chromadb.embeddings.openai;

import com.google.gson.*;
import com.google.gson.annotations.SerializedName;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package tech.amikos.openai;
package tech.amikos.chromadb.embeddings.openai;

import com.google.gson.Gson;
import com.google.gson.annotations.SerializedName;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package tech.amikos.openai;
package tech.amikos.chromadb.embeddings.openai;

import com.google.gson.Gson;
import okhttp3.*;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
package tech.amikos.chromadb;
package tech.amikos.chromadb.embeddings.openai;

import tech.amikos.openai.CreateEmbeddingRequest;
import tech.amikos.openai.CreateEmbeddingResponse;
import tech.amikos.openai.OpenAIClient;
import tech.amikos.chromadb.EmbeddingFunction;

import java.util.List;
import java.util.function.Consumer;
import java.util.stream.Collectors;

public class OpenAIEmbeddingFunction implements EmbeddingFunction {
Expand Down
Loading
Loading