Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: HFEI support #69

Merged
merged 2 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 144 additions & 83 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,32 +78,33 @@ import tech.amikos.chromadb.embeddings.DefaultEmbeddingFunction;
import java.util.*;

public class Main {
public static void main(String[] args) {
try {
Client client = new Client(System.getenv("CHROMA_URL"));
client.reset();
EmbeddingFunction ef = new DefaultEmbeddingFunction();
Collection collection = client.createCollection("test-collection", null, true, ef);
List<Map<String, String>> metadata = new ArrayList<>();
metadata.add(new HashMap<String, String>() {{
put("type", "scientist");
}});
metadata.add(new HashMap<String, String>() {{
put("type", "spy");
}});
collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2"));
Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null);
System.out.println(qr);
} catch (Exception e) {
System.out.println(e);
public static void main(String[] args) {
try {
Client client = new Client(System.getenv("CHROMA_URL"));
client.reset();
EmbeddingFunction ef = new DefaultEmbeddingFunction();
Collection collection = client.createCollection("test-collection", null, true, ef);
List<Map<String, String>> metadata = new ArrayList<>();
metadata.add(new HashMap<String, String>() {{
put("type", "scientist");
}});
metadata.add(new HashMap<String, String>() {{
put("type", "spy");
}});
collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2"));
Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null);
System.out.println(qr);
} catch (Exception e) {
System.out.println(e);
}
}
}
}
```

### Example OpenAI Embedding Function

In this example we rely on `tech.amikos.chromadb.embeddings.openai.OpenAIEmbeddingFunction` to generate embeddings for our documents.
In this example we rely on `tech.amikos.chromadb.embeddings.openai.OpenAIEmbeddingFunction` to generate embeddings for
our documents.

| **Important**: Ensure you have `OPENAI_API_KEY` environment variable set

Expand All @@ -118,27 +119,27 @@ import tech.amikos.chromadb.embeddings.openai.OpenAIEmbeddingFunction;
import java.util.*;

public class Main {
public static void main(String[] args) {
try {
Client client = new Client(System.getenv("CHROMA_URL"));
String apiKey = System.getenv("OPENAI_API_KEY");
EmbeddingFunction ef = new OpenAIEmbeddingFunction(apiKey, "text-embedding-3-small");
Collection collection = client.createCollection("test-collection", null, true, ef);
List<Map<String, String>> metadata = new ArrayList<>();
metadata.add(new HashMap<String, String>() {{
put("type", "scientist");
}});
metadata.add(new HashMap<String, String>() {{
put("type", "spy");
}});
collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2"));
Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null);
System.out.println(qr);
} catch (Exception e) {
e.printStackTrace();
System.out.println(e);
public static void main(String[] args) {
try {
Client client = new Client(System.getenv("CHROMA_URL"));
String apiKey = System.getenv("OPENAI_API_KEY");
EmbeddingFunction ef = new OpenAIEmbeddingFunction(apiKey, "text-embedding-3-small");
Collection collection = client.createCollection("test-collection", null, true, ef);
List<Map<String, String>> metadata = new ArrayList<>();
metadata.add(new HashMap<String, String>() {{
put("type", "scientist");
}});
metadata.add(new HashMap<String, String>() {{
put("type", "spy");
}});
collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2"));
Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null);
System.out.println(qr);
} catch (Exception e) {
e.printStackTrace();
System.out.println(e);
}
}
}
}
```

Expand Down Expand Up @@ -174,7 +175,8 @@ curl http://localhost:11434/api/embeddings -d '{\n "model": "llama2",\n "promp

### Example Cohere Embedding Function

In this example we rely on `tech.amikos.chromadb.embeddings.cohere.CohereEmbeddingFunction` to generate embeddings for our documents.
In this example we rely on `tech.amikos.chromadb.embeddings.cohere.CohereEmbeddingFunction` to generate embeddings for
our documents.

| **Important**: Ensure you have `COHERE_API_KEY` environment variable set

Expand All @@ -188,28 +190,28 @@ import tech.amikos.chromadb.embeddings.cohere.CohereEmbeddingFunction;
import java.util.*;

public class Main {
public static void main(String[] args) {
try {
Client client = new Client(System.getenv("CHROMA_URL"));
client.reset();
String apiKey = System.getenv("COHERE_API_KEY");
EmbeddingFunction ef = new CohereEmbeddingFunction(apiKey);
Collection collection = client.createCollection("test-collection", null, true, ef);
List<Map<String, String>> metadata = new ArrayList<>();
metadata.add(new HashMap<String, String>() {{
put("type", "scientist");
}});
metadata.add(new HashMap<String, String>() {{
put("type", "spy");
}});
collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2"));
Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null);
System.out.println(qr);
} catch (Exception e) {
e.printStackTrace();
System.out.println(e);
public static void main(String[] args) {
try {
Client client = new Client(System.getenv("CHROMA_URL"));
client.reset();
String apiKey = System.getenv("COHERE_API_KEY");
EmbeddingFunction ef = new CohereEmbeddingFunction(apiKey);
Collection collection = client.createCollection("test-collection", null, true, ef);
List<Map<String, String>> metadata = new ArrayList<>();
metadata.add(new HashMap<String, String>() {{
put("type", "scientist");
}});
metadata.add(new HashMap<String, String>() {{
put("type", "spy");
}});
collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2"));
Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null);
System.out.println(qr);
} catch (Exception e) {
e.printStackTrace();
System.out.println(e);
}
}
}
}
```

Expand All @@ -221,7 +223,10 @@ The above should output:

### Example Hugging Face Sentence Transformers Embedding Function

In this example we rely on `tech.amikos.chromadb.embeddings.hf.HuggingFaceEmbeddingFunction` to generate embeddings for our documents.
#### Hugging Face Inference API

In this example we rely on `tech.amikos.chromadb.embeddings.hf.HuggingFaceEmbeddingFunction` to generate embeddings for
our documents using HuggingFace cloud-based inference API.

| **Important**: Ensure you have `HF_API_KEY` environment variable set

Expand All @@ -235,27 +240,26 @@ import tech.amikos.chromadb.embeddings.hf.HuggingFaceEmbeddingFunction;
import java.util.*;

public class Main {
public static void main(String[] args) {
try {
Client client = new Client(System.getenv("CHROMA_URL"));
client.reset();
String apiKey = System.getenv("HF_API_KEY");
EmbeddingFunction ef = new HuggingFaceEmbeddingFunction(apiKey);
Collection collection = client.createCollection("test-collection", null, true, ef);
List<Map<String, String>> metadata = new ArrayList<>();
metadata.add(new HashMap<String, String>() {{
put("type", "scientist");
}});
metadata.add(new HashMap<String, String>() {{
put("type", "spy");
}});
collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2"));
Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null);
System.out.println(qr);
} catch (Exception e) {
System.out.println(e);
public static void main(String[] args) {
try {
Client client = new Client("http://localhost:8000");
String apiKey = System.getenv("HF_API_KEY");
EmbeddingFunction ef = new HuggingFaceEmbeddingFunction(apiKey);
Collection collection = client.createCollection("test-collection", null, true, ef);
List<Map<String, String>> metadata = new ArrayList<>();
metadata.add(new HashMap<String, String>() {{
put("type", "scientist");
}});
metadata.add(new HashMap<String, String>() {{
put("type", "spy");
}});
collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2"));
Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null);
System.out.println(qr);
} catch (Exception e) {
System.out.println(e);
}
}
}
}
```

Expand All @@ -265,6 +269,63 @@ The above should output:
{"documents":[["Hello, my name is Bond. I am a Spy.","Hello, my name is John. I am a Data Scientist."]],"ids":[["2","1"]],"metadatas":[[{"type":"spy"},{"type":"scientist"}]],"distances":[[0.9073759,1.6440368]]}
```

#### Hugging Face Text Embedding Inference (HFEI) API

In this example we'll use a local Docker based server to generate the embeddings with
`Snowflake/snowflake-arctic-embed-s` model.

First let's start the HFEI server:

```bash
docker run -d -p 8008:80 --platform linux/amd64 --name hfei ghcr.io/huggingface/text-embeddings-inference:cpu-1.5.0 --model-id Snowflake/snowflake-arctic-embed-s --revision main
```

> Note: Check the official documentation for more details - https://github.com/huggingface/text-embeddings-inference

Then we can use the following code to generate embeddings. Note the use of
`new HuggingFaceEmbeddingFunction.WithAPIType(HuggingFaceEmbeddingFunction.APIType.HFEI_API)` to define the API type;
this will ensure the client uses the correct endpoint.

```java
package tech.amikos;

import tech.amikos.chromadb.*;
import tech.amikos.chromadb.Collection;
import tech.amikos.chromadb.embeddings.hf.HuggingFaceEmbeddingFunction;

import java.util.*;

// Example: embed two documents through a locally running HFEI server and query them via Chroma.
public class Main {
public static void main(String[] args) {
try {
// Chroma client; this example uses a hard-coded local URL.
Client client = new Client("http://localhost:8000");
// Point the embedding function at the HFEI container published on port 8008 (see the docker
// command above) and select the HFEI API type so the client uses the correct endpoint.
EmbeddingFunction ef = new HuggingFaceEmbeddingFunction(
WithParam.baseAPI("http://localhost:8008"),
new HuggingFaceEmbeddingFunction.WithAPIType(HuggingFaceEmbeddingFunction.APIType.HFEI_API));
Collection collection = client.createCollection("test-collection", null, true, ef);
// One metadata map per document, in the same order as the documents passed to add() below.
List<Map<String, String>> metadata = new ArrayList<>();
metadata.add(new HashMap<String, String>() {{
put("type", "scientist");
}});
metadata.add(new HashMap<String, String>() {{
put("type", "spy");
}});
// Adds two documents with ids "1" and "2"; the first argument is null here
// (presumably so the embeddings are produced by `ef` - confirm against Collection.add).
collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2"));
// Query with a single text; 10 is presumably the maximum number of results - confirm against Collection.query.
Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null);
System.out.println(qr);
} catch (Exception e) {
// Sample code: any failure is just printed to stdout.
System.out.println(e);
}
}
}
```

The above should produce output similar to the following:

```bash
{"documents":[["Hello, my name is Bond. I am a Spy.","Hello, my name is John. I am a Data Scientist."]],"ids":[["2","1"]],"metadatas":[[{"type":"spy"},{"type":"scientist"}]],"distances":[[0.19665092,0.42433012]]}
```

### Ollama Embedding Function

In this example we rely on `tech.amikos.chromadb.embeddings.ollama.OllamaEmbeddingFunction` to generate embeddings for
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@
public class HuggingFaceEmbeddingFunction implements EmbeddingFunction {
public static final String DEFAULT_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2";
public static final String DEFAULT_BASE_API = "https://api-inference.huggingface.co/pipeline/feature-extraction/";
public static final String HFEI_API_PATH = "/embed";
public static final String HF_API_KEY_ENV = "HF_API_KEY";
public static final String API_TYPE_CONFIG_KEY = "apiType";
private final OkHttpClient client = new OkHttpClient();
private final Map<String, Object> configParams = new HashMap<>();
private static final Gson gson = new Gson();

private static final List<WithParam> defaults = Arrays.asList(
new WithAPIType(APIType.HF_API),
WithParam.baseAPI(DEFAULT_BASE_API),
WithParam.defaultModel(DEFAULT_MODEL_NAME)
);
Expand All @@ -46,14 +49,21 @@ public HuggingFaceEmbeddingFunction(WithParam... params) throws EFException {
}

public CreateEmbeddingResponse createEmbedding(CreateEmbeddingRequest req) throws EFException {
Request request = new Request.Builder()
.url(this.configParams.get(Constants.EF_PARAMS_BASE_API).toString() + this.configParams.get(Constants.EF_PARAMS_MODEL).toString())
Request.Builder rb = new Request.Builder()

.post(RequestBody.create(req.json(), JSON))
.addHeader("Accept", "application/json")
.addHeader("Content-Type", "application/json")
.addHeader("User-Agent", Constants.HTTP_AGENT)
.addHeader("Authorization", "Bearer " + configParams.get(Constants.EF_PARAMS_API_KEY).toString())
.build();
.addHeader("User-Agent", Constants.HTTP_AGENT);
if (configParams.containsKey(API_TYPE_CONFIG_KEY) && configParams.get(API_TYPE_CONFIG_KEY).equals(APIType.HFEI_API)) {
rb.url(this.configParams.get(Constants.EF_PARAMS_BASE_API).toString() + HFEI_API_PATH);
} else {
rb.url(this.configParams.get(Constants.EF_PARAMS_BASE_API).toString() + this.configParams.get(Constants.EF_PARAMS_MODEL).toString());
}
if (configParams.containsKey(Constants.EF_PARAMS_API_KEY)) {
rb.addHeader("Authorization", "Bearer " + configParams.get(Constants.EF_PARAMS_API_KEY).toString());
}
Request request = rb.build();
try (Response response = client.newCall(request).execute()) {
if (!response.isSuccessful()) {
throw new IOException("Unexpected code " + response);
Expand Down Expand Up @@ -86,4 +96,22 @@ public List<Embedding> embedDocuments(String[] documents) throws EFException {
CreateEmbeddingResponse response = this.createEmbedding(new CreateEmbeddingRequest().inputs(documents));
return response.getEmbeddings().stream().map(Embedding::fromList).collect(Collectors.toList());
}

/**
 * Configuration option that records which {@link APIType} the embedding function should use.
 *
 * <p>When applied, it stores the chosen type in the supplied parameter map under
 * {@code API_TYPE_CONFIG_KEY}.
 */
public static class WithAPIType extends WithParam {
    private final APIType apiType;

    /**
     * Creates the option for the given API type.
     *
     * @param apitype the API type to record; must not be {@code null}
     * @throws NullPointerException if {@code apitype} is {@code null}
     */
    public WithAPIType(APIType apitype) {
        // Fail fast: a null stored under API_TYPE_CONFIG_KEY would presumably only surface later,
        // as an NPE when createEmbedding calls equals() on the stored value.
        this.apiType = java.util.Objects.requireNonNull(apitype, "apitype must not be null");
    }

    /**
     * Writes the configured API type into {@code params} under {@code API_TYPE_CONFIG_KEY}.
     *
     * @param params the parameter map to update
     */
    @Override
    public void apply(Map<String, Object> params) {
        params.put(API_TYPE_CONFIG_KEY, apiType);
    }
}

/**
 * The kinds of Hugging Face API an embedding function can be configured to call.
 */
public enum APIType {
    /** Hugging Face API; the type listed in the embedding function's default parameters. */
    HF_API,

    /** Hugging Face Text Embedding Inference API; requests go to the base API plus {@code HFEI_API_PATH}. */
    HFEI_API
}
}
Loading
Loading