-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy path TestHuggingFaceEmbeddings.java
47 lines (38 loc) · 1.68 KB
/
TestHuggingFaceEmbeddings.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
package tech.amikos.chromadb.embeddings.hf;
import org.junit.BeforeClass;
import org.junit.Test;
import tech.amikos.chromadb.EmbeddingFunction;
import tech.amikos.chromadb.Utils;
import tech.amikos.chromadb.handler.ApiException;
import java.util.*;
import static org.junit.Assert.assertEquals;
/**
 * JUnit 4 tests for the Hugging Face embedding client and embedding function.
 *
 * <p>Every test obtains its API key via {@code Utils.getEnvOrProperty("HF_API_KEY")}; the
 * {@code .env} file is loaded once before any test in this class runs.
 * NOTE(review): these tests presumably call the remote Hugging Face API, so they likely need
 * network access and a valid key — confirm before running in CI.
 */
public class TestHuggingFaceEmbeddings {

    /** Name of the environment variable / property that holds the Hugging Face API key. */
    private static final String API_KEY_NAME = "HF_API_KEY";

    /** Model selected explicitly on the low-level client in {@link #testEmbeddings()}. */
    private static final String MINILM_MODEL = "sentence-transformers/all-MiniLM-L6-v2";

    /** Model passed per call in {@link #testEmbedWithModel()}. */
    private static final String MPNET_MODEL = "sentence-transformers/all-mpnet-base-v2";

    /** Loads variables from the {@code .env} file before any test in this class runs. */
    @BeforeClass
    public static void setup() {
        Utils.loadEnvFile(".env");
    }

    /** Looks up the API key through the project's environment/property helper. */
    private static String apiKey() {
        return Utils.getEnvOrProperty(API_KEY_NAME);
    }

    /** Returns a freshly allocated array of the two sample sentences used by every test. */
    private static String[] sampleTexts() {
        return new String[] {"Hello world", "How are you?"};
    }

    /**
     * Asserts that exactly two vectors came back and that the first has {@code dimension} entries.
     *
     * @param vectors   embeddings returned by the embedding function
     * @param dimension expected length of the first vector
     */
    private static void assertTwoVectorsOfSize(List<List<Float>> vectors, int dimension) {
        assertEquals(2, vectors.size());
        assertEquals(dimension, vectors.get(0).size());
    }

    /** The low-level client should return one embedding per input sentence. */
    @Test
    public void testEmbeddings() {
        HuggingFaceClient hfClient = new HuggingFaceClient(apiKey());
        hfClient.modelId(MINILM_MODEL);
        CreateEmbeddingResponse reply =
                hfClient.createEmbedding(new CreateEmbeddingRequest().inputs(sampleTexts()));
        assertEquals(2, reply.getEmbeddings().size());
    }

    /** The embedding function's default model is expected to yield 384-dimensional vectors. */
    @Test
    public void testEmbed() throws ApiException {
        EmbeddingFunction defaultModelFn = new HuggingFaceEmbeddingFunction(apiKey());
        List<List<Float>> vectors = defaultModelFn.createEmbedding(Arrays.asList(sampleTexts()));
        assertTwoVectorsOfSize(vectors, 384);
    }

    /** Passing the mpnet model id explicitly is expected to yield 768-dimensional vectors. */
    @Test
    public void testEmbedWithModel() throws ApiException {
        EmbeddingFunction perCallModelFn = new HuggingFaceEmbeddingFunction(apiKey());
        List<List<Float>> vectors =
                perCallModelFn.createEmbedding(Arrays.asList(sampleTexts()), MPNET_MODEL);
        assertTwoVectorsOfSize(vectors, 768);
    }
}