Skip to content

Commit eafd3a7

Browse files
committed
feat: HFEI support
1 parent c21b146 commit eafd3a7

File tree

2 files changed

+59
-6
lines changed

2 files changed

+59
-6
lines changed

src/main/java/tech/amikos/chromadb/embeddings/hf/HuggingFaceEmbeddingFunction.java

+33-5
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,15 @@
1717
public class HuggingFaceEmbeddingFunction implements EmbeddingFunction {
1818
public static final String DEFAULT_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2";
1919
public static final String DEFAULT_BASE_API = "https://api-inference.huggingface.co/pipeline/feature-extraction/";
20+
public static final String HFEI_API_PATH = "/embed";
2021
public static final String HF_API_KEY_ENV = "HF_API_KEY";
22+
public static final String API_TYPE_CONFIG_KEY = "apiType";
2123
private final OkHttpClient client = new OkHttpClient();
2224
private final Map<String, Object> configParams = new HashMap<>();
2325
private static final Gson gson = new Gson();
2426

2527
private static final List<WithParam> defaults = Arrays.asList(
28+
new WithAPIType(APIType.HF_API),
2629
WithParam.baseAPI(DEFAULT_BASE_API),
2730
WithParam.defaultModel(DEFAULT_MODEL_NAME)
2831
);
@@ -46,14 +49,21 @@ public HuggingFaceEmbeddingFunction(WithParam... params) throws EFException {
4649
}
4750

4851
public CreateEmbeddingResponse createEmbedding(CreateEmbeddingRequest req) throws EFException {
49-
Request request = new Request.Builder()
50-
.url(this.configParams.get(Constants.EF_PARAMS_BASE_API).toString() + this.configParams.get(Constants.EF_PARAMS_MODEL).toString())
52+
Request.Builder rb = new Request.Builder()
53+
5154
.post(RequestBody.create(req.json(), JSON))
5255
.addHeader("Accept", "application/json")
5356
.addHeader("Content-Type", "application/json")
54-
.addHeader("User-Agent", Constants.HTTP_AGENT)
55-
.addHeader("Authorization", "Bearer " + configParams.get(Constants.EF_PARAMS_API_KEY).toString())
56-
.build();
57+
.addHeader("User-Agent", Constants.HTTP_AGENT);
58+
if (configParams.containsKey(API_TYPE_CONFIG_KEY) && configParams.get(API_TYPE_CONFIG_KEY).equals(APIType.HFEI_API)) {
59+
rb.url(this.configParams.get(Constants.EF_PARAMS_BASE_API).toString() + HFEI_API_PATH);
60+
} else {
61+
rb.url(this.configParams.get(Constants.EF_PARAMS_BASE_API).toString() + this.configParams.get(Constants.EF_PARAMS_MODEL).toString());
62+
}
63+
if (configParams.containsKey(Constants.EF_PARAMS_API_KEY)) {
64+
rb.addHeader("Authorization", "Bearer " + configParams.get(Constants.EF_PARAMS_API_KEY).toString());
65+
}
66+
Request request = rb.build();
5767
try (Response response = client.newCall(request).execute()) {
5868
if (!response.isSuccessful()) {
5969
throw new IOException("Unexpected code " + response);
@@ -86,4 +96,22 @@ public List<Embedding> embedDocuments(String[] documents) throws EFException {
8696
CreateEmbeddingResponse response = this.createEmbedding(new CreateEmbeddingRequest().inputs(documents));
8797
return response.getEmbeddings().stream().map(Embedding::fromList).collect(Collectors.toList());
8898
}
99+
100+
public static class WithAPIType extends WithParam {
101+
private final APIType apiType;
102+
103+
public WithAPIType(APIType apitype) {
104+
this.apiType = apitype;
105+
}
106+
107+
@Override
108+
public void apply(Map<String, Object> params) {
109+
params.put(API_TYPE_CONFIG_KEY, apiType);
110+
}
111+
}
112+
113+
/**
 * The API flavours the embedding function can address.
 *
 * <p>The request URL is assembled differently depending on this value: for
 * {@link #HFEI_API} the fixed {@code /embed} path is appended to the base URL, otherwise
 * the model name is appended to the base URL.
 */
public enum APIType {
    /** Default flavour: the request URL is the base API followed by the model name. */
    HF_API,
    /** HFEI flavour: the request URL is the base API followed by {@code /embed}. */
    HFEI_API
}
89117
}

src/test/java/tech/amikos/chromadb/embeddings/hf/TestHuggingFaceEmbeddings.java

+26-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import org.junit.BeforeClass;
44
import org.junit.Test;
5+
import org.testcontainers.containers.GenericContainer;
6+
import org.testcontainers.containers.wait.strategy.Wait;
57
import tech.amikos.chromadb.*;
68
import tech.amikos.chromadb.embeddings.EmbeddingFunction;
79
import tech.amikos.chromadb.embeddings.WithParam;
@@ -13,10 +15,23 @@
1315
import static org.junit.Assert.assertNotNull;
1416

1517
public class TestHuggingFaceEmbeddings {
18+
static GenericContainer hfeiContainer;
1619

1720
@BeforeClass
18-
public static void setup() {
21+
public static void setup() throws Exception {
1922
Utils.loadEnvFile(".env");
23+
24+
try {
25+
hfeiContainer = new GenericContainer("ghcr.io/huggingface/text-embeddings-inference:cpu-1.5.0")
26+
.withCommand("--model-id Snowflake/snowflake-arctic-embed-s --revision main")
27+
.withExposedPorts(80)
28+
.waitingFor(Wait.forHttp("/").forStatusCode(200));
29+
hfeiContainer.start();
30+
System.setProperty("HFEI_URL", "http://" + hfeiContainer.getHost() + ":" + hfeiContainer.getMappedPort(80));
31+
} catch (Exception e) {
32+
System.err.println("HFEI container failed to start");
33+
throw e;
34+
}
2035
}
2136

2237
@Test
@@ -45,5 +60,15 @@ public void testWithModel() throws ApiException, EFException {
4560
assertNotNull(results);
4661
assertEquals(768, results.getDimensions());
4762
}
63+
64+
@Test
65+
public void testWithURL() throws EFException {
66+
EmbeddingFunction ef = new HuggingFaceEmbeddingFunction(
67+
WithParam.baseAPI(System.getProperty("HFEI_URL")),
68+
new HuggingFaceEmbeddingFunction.WithAPIType(HuggingFaceEmbeddingFunction.APIType.HFEI_API));
69+
Embedding results = ef.embedQuery("How are you?");
70+
assertNotNull(results);
71+
assertEquals(384, results.getDimensions());
72+
}
4873
}
4974

0 commit comments

Comments
 (0)