Skip to content

Commit 3da92b5

Browse files
authored
feat: Default EF implementation (#39)
Closes #27 feat: Default EF implementation chore: Adding missing impl
1 parent f6a3c2c commit 3da92b5

File tree

6 files changed

+1203
-123
lines changed

6 files changed

+1203
-123
lines changed

Diff for: .github/workflows/integration-test.yml

+2-35
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ jobs:
1313
integration-test:
1414
strategy:
1515
matrix:
16-
chroma-version: [ 0.4.3, 0.4.4 ]
16+
chroma-version: [0.4.24, 0.5.0, 0.5.5 ]
1717
runs-on: ubuntu-latest
1818
steps:
1919
- uses: actions/checkout@v3
@@ -23,43 +23,10 @@ jobs:
2323
java-version: '8'
2424
distribution: 'adopt'
2525
cache: maven
26-
- name: Install Helm
27-
uses: azure/setup-helm@v1
28-
with:
29-
version: v3.4.0
30-
31-
- name: start minikube
32-
id: minikube
33-
uses: medyagh/setup-minikube@latest
34-
with:
35-
kubernetes-version: 1.27.3
36-
- name: Add helm repo
37-
run: |
38-
set -e
39-
helm repo add chromadb https://amikos-tech.github.io/chromadb-chart/
40-
helm repo update
41-
- name: Install chromadb
42-
run: |
43-
set -e
44-
helm install chromadb chromadb/chromadb --set chromadb.allowReset=true,chromadb.apiVersion=${{ matrix.chroma-version }}
45-
- name: Wait for deployment to be ready
46-
id: wait-and-set
47-
run: |
48-
set -e
49-
kubectl wait \
50-
--for=condition=ready pod \
51-
--selector=app.kubernetes.io/name=chromadb \
52-
--timeout=120s
53-
echo "chroma-url=$(minikube service chromadb --url)" >> $GITHUB_OUTPUT
54-
- name: Hearthbeat
55-
run: |
56-
set -e
57-
kubectl get svc -A
58-
curl $(minikube service chromadb --url)/api/v1
5926
- name: Test with Maven
6027
run: mvn --batch-mode clean test
6128
env:
6229
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
6330
COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
6431
HF_API_KEY: ${{ secrets.HF_API_KEY }}
65-
CHROMA_URL: ${{steps.wait-and-set.outputs.chroma-url}}
32+
CHROMA_VERSION: ${{ matrix.chroma-version }}

Diff for: pom.xml

+69-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
<maven.compiler.target>${java.version}</maven.compiler.target>
6464
<gson-fire-version>1.8.5</gson-fire-version>
6565
<swagger-core-version>1.6.9</swagger-core-version>
66-
<okhttp-version>4.10.0</okhttp-version>
66+
<okhttp-version>4.12.0</okhttp-version>
6767
<gson-version>2.10.1</gson-version>
6868
<threetenbp-version>1.6.5</threetenbp-version>
6969
<maven-plugin-version>1.0.0</maven-plugin-version>
@@ -112,6 +112,31 @@
112112
<artifactId>threetenbp</artifactId>
113113
<version>${threetenbp-version}</version>
114114
</dependency>
115+
<dependency>
116+
<groupId>ai.djl.huggingface</groupId>
117+
<artifactId>tokenizers</artifactId>
118+
<version>0.29.0</version>
119+
</dependency>
120+
<dependency>
121+
<groupId>com.microsoft.onnxruntime</groupId>
122+
<artifactId>onnxruntime</artifactId>
123+
<version>1.18.0</version>
124+
</dependency>
125+
<dependency>
126+
<groupId>commons-io</groupId>
127+
<artifactId>commons-io</artifactId>
128+
<version>2.16.1</version> <!-- Replace with the appropriate version -->
129+
</dependency>
130+
<dependency>
131+
<groupId>org.nd4j</groupId>
132+
<artifactId>nd4j-native-platform</artifactId>
133+
<version>1.0.0-M2</version>
134+
</dependency>
135+
<dependency>
136+
<groupId>org.apache.commons</groupId>
137+
<artifactId>commons-compress</artifactId>
138+
<version>1.27.0</version>
139+
</dependency>
115140
<!-- test dependencies -->
116141
<dependency>
117142
<groupId>junit</groupId>
@@ -125,6 +150,19 @@
125150
<version>2.35.1</version>
126151
<scope>test</scope>
127152
</dependency>
153+
<!-- The Testcontainers BOM is deliberately NOT declared here: Maven honours scope "import"
     only inside <dependencyManagement>; in a plain <dependencies> list it triggers a model
     warning and leaks an invalid entry into the published POM. The chromadb test module
     declared next pins its own version, so no BOM is required. -->
160+
<dependency>
161+
<groupId>org.testcontainers</groupId>
162+
<artifactId>chromadb</artifactId>
163+
<version>1.20.1</version>
164+
<scope>test</scope>
165+
</dependency>
128166
</dependencies>
129167

130168
<build>
@@ -395,6 +433,36 @@
395433
<autoReleaseAfterClose>true</autoReleaseAfterClose>
396434
</configuration>
397435
</plugin>
436+
<plugin>
437+
<groupId>org.apache.maven.plugins</groupId>
438+
<artifactId>maven-enforcer-plugin</artifactId>
439+
<version>3.5.0</version>
440+
<executions>
441+
<execution>
442+
<id>enforce-java</id>
443+
<goals>
444+
<goal>enforce</goal>
445+
</goals>
446+
<configuration>
447+
<rules>
448+
<requireJavaVersion>
449+
<version>[1.8,)</version>
450+
</requireJavaVersion>
451+
</rules>
452+
<fail>true</fail>
453+
</configuration>
454+
</execution>
455+
</executions>
456+
</plugin>
457+
<plugin>
458+
<groupId>org.apache.maven.plugins</groupId>
459+
<artifactId>maven-compiler-plugin</artifactId>
460+
<version>3.8.1</version>
461+
<configuration>
462+
<source>${maven.compiler.source}</source>
463+
<target>${maven.compiler.target}</target>
464+
</configuration>
465+
</plugin>
398466
</plugins>
399467
</build>
400468

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
package tech.amikos.chromadb;
2+
3+
import ai.djl.huggingface.tokenizers.Encoding;
4+
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
5+
import ai.onnxruntime.*;
6+
7+
import java.util.zip.GZIPInputStream;
8+
9+
import org.apache.commons.compress.archivers.tar.*;
10+
import org.nd4j.linalg.api.ndarray.INDArray;
11+
import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue;
12+
import org.nd4j.linalg.factory.Nd4j;
13+
import org.nd4j.shade.guava.primitives.Floats;
14+
15+
import java.io.*;
16+
import java.net.URL;
17+
import java.nio.LongBuffer;
18+
import java.nio.file.Files;
19+
import java.nio.file.Path;
20+
import java.nio.file.Paths;
21+
import java.nio.file.StandardCopyOption;
22+
import java.security.MessageDigest;
23+
import java.security.NoSuchAlgorithmException;
24+
import java.util.*;
25+
26+
public class DefaultEmbeddingFunction implements EmbeddingFunction {
27+
public static final String MODEL_NAME = "all-MiniLM-L6-v2";
28+
private static final String ARCHIVE_FILENAME = "onnx.tar.gz";
29+
private static final String MODEL_DOWNLOAD_URL = "https://chroma-onnx-models.s3.amazonaws.com/all-MiniLM-L6-v2/onnx.tar.gz";
30+
private static final String MODEL_SHA256_CHECKSUM = "913d7300ceae3b2dbc2c50d1de4baacab4be7b9380491c27fab7418616a16ec3";
31+
public static final Path MODEL_CACHE_DIR = Paths.get(System.getProperty("user.home"), ".cache", "chroma", "onnx_models", MODEL_NAME);
32+
private static final Path modelPath = Paths.get(MODEL_CACHE_DIR.toString(), "onnx");
33+
private static final Path modelFile = Paths.get(modelPath.toString(), "model.onnx");
34+
private final HuggingFaceTokenizer tokenizer;
35+
private final OrtEnvironment env;
36+
final OrtSession session;
37+
38+
public static float[][] normalize(float[][] v) {
39+
int rows = v.length;
40+
int cols = v[0].length;
41+
float[] norm = new float[rows];
42+
43+
// Step 1: Compute the L2 norm of each row
44+
for (int i = 0; i < rows; i++) {
45+
float sum = 0;
46+
for (int j = 0; j < cols; j++) {
47+
sum += v[i][j] * v[i][j];
48+
}
49+
norm[i] = (float) Math.sqrt(sum);
50+
}
51+
52+
// Step 2: Handle zero norms
53+
for (int i = 0; i < rows; i++) {
54+
if (norm[i] == 0) {
55+
norm[i] = 1e-12f;
56+
}
57+
}
58+
59+
// Step 3: Normalize each row
60+
float[][] normalized = new float[rows][cols];
61+
for (int i = 0; i < rows; i++) {
62+
for (int j = 0; j < cols; j++) {
63+
normalized[i][j] = v[i][j] / norm[i];
64+
}
65+
}
66+
return normalized;
67+
}
68+
69+
public DefaultEmbeddingFunction() throws EFException {
70+
if (!validateModel()) {
71+
downloadAndSetupModel();
72+
}
73+
74+
Map<String, String> tokenizerConfig = Collections.unmodifiableMap(new HashMap<String, String>() {{
75+
put("padding", "MAX_LENGTH");
76+
put("maxLength", "256");
77+
}});
78+
79+
try {
80+
tokenizer = HuggingFaceTokenizer.newInstance(modelPath, tokenizerConfig);
81+
82+
this.env = OrtEnvironment.getEnvironment();
83+
OrtSession.SessionOptions options = new OrtSession.SessionOptions();
84+
this.session = env.createSession(modelFile.toString(), options);
85+
} catch (OrtException | IOException e) {
86+
throw new EFException(e);
87+
}
88+
}
89+
90+
public List<List<Float>> forward(List<String> documents) throws OrtException {
91+
Encoding[] e = tokenizer.batchEncode(documents, true, false);
92+
ArrayList<Long> inputIds = new ArrayList<>();
93+
ArrayList<Long> attentionMask = new ArrayList<>();
94+
ArrayList<Long> tokenIdtypes = new ArrayList<>();
95+
int maxIds = 0;
96+
for (Encoding encoding : e) {
97+
maxIds = Math.max(maxIds, encoding.getIds().length);
98+
inputIds.addAll(Arrays.asList(Arrays.stream(encoding.getIds()).boxed().toArray(Long[]::new)));
99+
attentionMask.addAll(Arrays.asList(Arrays.stream(encoding.getAttentionMask()).boxed().toArray(Long[]::new)));
100+
tokenIdtypes.addAll(Arrays.asList(Arrays.stream(encoding.getTypeIds()).boxed().toArray(Long[]::new)));
101+
}
102+
long[] inputShape = new long[]{e.length, maxIds};
103+
OnnxTensor inputTensor = OnnxTensor.createTensor(env, LongBuffer.wrap(inputIds.stream().mapToLong(i -> i).toArray()), inputShape);
104+
OnnxTensor attentionTensor = OnnxTensor.createTensor(env, LongBuffer.wrap(attentionMask.stream().mapToLong(i -> i).toArray()), inputShape);
105+
OnnxTensor _tokenIdtypes = OnnxTensor.createTensor(env, LongBuffer.wrap(tokenIdtypes.stream().mapToLong(i -> i).toArray()), inputShape);
106+
// Inputs for all-MiniLM-L6-v2 model
107+
Map<String, ? extends OnnxTensorLike> inputs = Collections.unmodifiableMap(new HashMap<String, OnnxTensorLike>() {{
108+
put("input_ids", inputTensor);
109+
put("attention_mask", attentionTensor);
110+
put("token_type_ids", _tokenIdtypes);
111+
}});
112+
INDArray lastHiddenState = null;
113+
try (OrtSession.Result results = session.run(inputs)) {
114+
lastHiddenState = Nd4j.create((float[][][]) results.get(0).getValue());
115+
116+
}
117+
INDArray attMask = Nd4j.create(attentionMask.stream().mapToDouble(i -> i).toArray(), inputShape, 'c');
118+
INDArray expandedMask = Nd4j.expandDims(attMask, 2).broadcast(lastHiddenState.shape());
119+
INDArray summed = lastHiddenState.mul(expandedMask).sum(1);
120+
INDArray[] clippedSumMask = Nd4j.getExecutioner().exec(
121+
new ClipByValue(expandedMask.sum(1), 1e-9, Double.MAX_VALUE)
122+
);
123+
INDArray embeddings = summed.div(clippedSumMask[0]);
124+
float[][] embeddingsArray = normalize(embeddings.toFloatMatrix());
125+
List<List<Float>> embeddingsList = new ArrayList<>();
126+
for (float[] embedding : embeddingsArray) {
127+
embeddingsList.add(Floats.asList(embedding));
128+
}
129+
return embeddingsList;
130+
}
131+
132+
private static String getSHA256Checksum(String filePath) throws IOException, NoSuchAlgorithmException {
133+
MessageDigest digest = MessageDigest.getInstance("SHA-256");
134+
try (FileInputStream fis = new FileInputStream(filePath)) {
135+
byte[] byteArray = new byte[1024];
136+
int bytesCount;
137+
while ((bytesCount = fis.read(byteArray)) != -1) {
138+
digest.update(byteArray, 0, bytesCount);
139+
}
140+
}
141+
byte[] bytes = digest.digest();
142+
StringBuilder sb = new StringBuilder();
143+
for (byte b : bytes) {
144+
sb.append(String.format("%02x", b));
145+
}
146+
return sb.toString();
147+
}
148+
149+
private static void extractTarGz(Path tarGzPath, Path extractDir) throws IOException {
150+
try (InputStream fileIn = Files.newInputStream(tarGzPath);
151+
GZIPInputStream gzipIn = new GZIPInputStream(fileIn);
152+
TarArchiveInputStream tarIn = new TarArchiveInputStream(gzipIn)) {
153+
154+
TarArchiveEntry entry;
155+
while ((entry = tarIn.getNextTarEntry()) != null) {
156+
Path entryPath = extractDir.resolve(entry.getName());
157+
if (entry.isDirectory()) {
158+
Files.createDirectories(entryPath);
159+
} else {
160+
Files.createDirectories(entryPath.getParent());
161+
try (OutputStream out = Files.newOutputStream(entryPath)) {
162+
byte[] buffer = new byte[1024];
163+
int len;
164+
while ((len = tarIn.read(buffer)) != -1) {
165+
out.write(buffer, 0, len);
166+
}
167+
}
168+
}
169+
}
170+
}
171+
}
172+
173+
private void downloadAndSetupModel() throws EFException {
174+
try (InputStream in = new URL(MODEL_DOWNLOAD_URL).openStream()) {
175+
if (!Files.exists(MODEL_CACHE_DIR)) {
176+
Files.createDirectories(MODEL_CACHE_DIR);
177+
}
178+
Path archivePath = Paths.get(MODEL_CACHE_DIR.toString(), ARCHIVE_FILENAME);
179+
if (!archivePath.toFile().exists()) {
180+
System.err.println("Model not found under " + archivePath + ". Downloading...");
181+
Files.copy(in, archivePath, StandardCopyOption.REPLACE_EXISTING);
182+
}
183+
if (!MODEL_SHA256_CHECKSUM.equals(getSHA256Checksum(archivePath.toString()))) {
184+
throw new RuntimeException("Checksum does not match. Delete the whole directory " + MODEL_CACHE_DIR + " and try again.");
185+
}
186+
extractTarGz(archivePath, MODEL_CACHE_DIR);
187+
archivePath.toFile().delete();
188+
} catch (IOException | NoSuchAlgorithmException e) {
189+
throw new EFException(e);
190+
}
191+
}
192+
193+
194+
/**
195+
* Check if the model is present at the expected location
196+
*
197+
* @return
198+
*/
199+
private boolean validateModel() {
200+
return modelFile.toFile().exists() && modelFile.toFile().isFile();
201+
}
202+
203+
@Override
204+
public List<List<Float>> createEmbedding(List<String> documents) {
205+
try {
206+
return forward(documents);
207+
} catch (OrtException e) {
208+
//TODO not great to throw a runtime exception but we need to update the interface in upcoming release to rethrow
209+
throw new RuntimeException(e);
210+
}
211+
}
212+
213+
@Override
214+
public List<List<Float>> createEmbedding(List<String> documents, String model) {
215+
try {
216+
return forward(documents);
217+
} catch (OrtException e) {
218+
//TODO not great to throw a runtime exception but we need to update the interface in upcoming release to rethrow
219+
throw new RuntimeException(e);
220+
}
221+
}
222+
}

Diff for: src/main/java/tech/amikos/chromadb/EFException.java

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
package tech.amikos.chromadb;
2+
3+
/**
 * Checked exception used to report any failure raised by an embedding function.
 *
 * <p>All three constructors simply forward to the matching {@link Exception} constructor.
 */
public class EFException extends Exception {

    /**
     * Wraps an underlying failure without adding a message of its own.
     *
     * @param cause the failure being wrapped
     */
    public EFException(Throwable cause) {
        super(cause);
    }

    /**
     * Reports a failure described by {@code message} and caused by {@code cause}.
     *
     * @param message description of what went wrong
     * @param cause   the failure being wrapped
     */
    public EFException(String message, Throwable cause) {
        super(message, cause);
    }

    /**
     * Reports a failure that has no underlying cause.
     *
     * @param message description of what went wrong
     */
    public EFException(String message) {
        super(message);
    }
}

0 commit comments

Comments
 (0)