Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Default EF implementation #39

Merged
merged 8 commits into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 2 additions & 35 deletions .github/workflows/integration-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
integration-test:
strategy:
matrix:
chroma-version: [ 0.4.3, 0.4.4 ]
chroma-version: [0.4.24, 0.5.0, 0.5.5 ]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
Expand All @@ -23,43 +23,10 @@ jobs:
java-version: '8'
distribution: 'adopt'
cache: maven
- name: Install Helm
uses: azure/setup-helm@v1
with:
version: v3.4.0

- name: start minikube
id: minikube
uses: medyagh/setup-minikube@latest
with:
kubernetes-version: 1.27.3
- name: Add helm repo
run: |
set -e
helm repo add chromadb https://amikos-tech.github.io/chromadb-chart/
helm repo update
- name: Install chromadb
run: |
set -e
helm install chromadb chromadb/chromadb --set chromadb.allowReset=true,chromadb.apiVersion=${{ matrix.chroma-version }}
- name: Wait for deployment to be ready
id: wait-and-set
run: |
set -e
kubectl wait \
--for=condition=ready pod \
--selector=app.kubernetes.io/name=chromadb \
--timeout=120s
echo "chroma-url=$(minikube service chromadb --url)" >> $GITHUB_OUTPUT
- name: Heartbeat
run: |
set -e
kubectl get svc -A
curl $(minikube service chromadb --url)/api/v1
- name: Test with Maven
run: mvn --batch-mode clean test
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
HF_API_KEY: ${{ secrets.HF_API_KEY }}
CHROMA_URL: ${{steps.wait-and-set.outputs.chroma-url}}
CHROMA_VERSION: ${{ matrix.chroma-version }}
70 changes: 69 additions & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
<maven.compiler.target>${java.version}</maven.compiler.target>
<gson-fire-version>1.8.5</gson-fire-version>
<swagger-core-version>1.6.9</swagger-core-version>
<okhttp-version>4.10.0</okhttp-version>
<okhttp-version>4.12.0</okhttp-version>
<gson-version>2.10.1</gson-version>
<threetenbp-version>1.6.5</threetenbp-version>
<maven-plugin-version>1.0.0</maven-plugin-version>
Expand Down Expand Up @@ -112,6 +112,31 @@
<artifactId>threetenbp</artifactId>
<version>${threetenbp-version}</version>
</dependency>
<dependency>
<groupId>ai.djl.huggingface</groupId>
<artifactId>tokenizers</artifactId>
<version>0.29.0</version>
</dependency>
<dependency>
<groupId>com.microsoft.onnxruntime</groupId>
<artifactId>onnxruntime</artifactId>
<version>1.18.0</version>
</dependency>
<dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
<version>2.16.1</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native-platform</artifactId>
<version>1.0.0-M2</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-compress</artifactId>
<version>1.27.0</version>
</dependency>
<!-- test dependencies -->
<dependency>
<groupId>junit</groupId>
Expand All @@ -125,6 +150,19 @@
<version>2.35.1</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>testcontainers-bom</artifactId>
<version>1.20.1</version>
<type>pom</type>
<scope>import</scope>
</dependency>
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>chromadb</artifactId>
<version>1.20.1</version>
<scope>test</scope>
</dependency>
</dependencies>

<build>
Expand Down Expand Up @@ -395,6 +433,36 @@
<autoReleaseAfterClose>true</autoReleaseAfterClose>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-enforcer-plugin</artifactId>
<version>3.5.0</version>
<executions>
<execution>
<id>enforce-java</id>
<goals>
<goal>enforce</goal>
</goals>
<configuration>
<rules>
<requireJavaVersion>
<version>[1.8,)</version>
</requireJavaVersion>
</rules>
<fail>true</fail>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.8.1</version>
<configuration>
<source>${maven.compiler.source}</source>
<target>${maven.compiler.target}</target>
</configuration>
</plugin>
</plugins>
</build>

Expand Down
222 changes: 222 additions & 0 deletions src/main/java/tech/amikos/chromadb/DefaultEmbeddingFunction.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
package tech.amikos.chromadb;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.onnxruntime.*;

import java.util.zip.GZIPInputStream;

import org.apache.commons.compress.archivers.tar.*;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.guava.primitives.Floats;

import java.io.*;
import java.net.URL;
import java.nio.LongBuffer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.*;

public class DefaultEmbeddingFunction implements EmbeddingFunction {
    public static final String MODEL_NAME = "all-MiniLM-L6-v2";
    private static final String ARCHIVE_FILENAME = "onnx.tar.gz";
    private static final String MODEL_DOWNLOAD_URL = "https://chroma-onnx-models.s3.amazonaws.com/all-MiniLM-L6-v2/onnx.tar.gz";
    private static final String MODEL_SHA256_CHECKSUM = "913d7300ceae3b2dbc2c50d1de4baacab4be7b9380491c27fab7418616a16ec3";
    public static final Path MODEL_CACHE_DIR = Paths.get(System.getProperty("user.home"), ".cache", "chroma", "onnx_models", MODEL_NAME);
    private static final Path modelPath = Paths.get(MODEL_CACHE_DIR.toString(), "onnx");
    private static final Path modelFile = Paths.get(modelPath.toString(), "model.onnx");
    private final HuggingFaceTokenizer tokenizer;
    private final OrtEnvironment env;
    final OrtSession session;

    /**
     * L2-normalizes each row of the given matrix in a new array.
     *
     * @param v non-empty matrix; every row must have the same length
     * @return a new matrix where each row has unit L2 norm (rows whose norm is
     *         exactly zero are divided by 1e-12 instead, yielding all zeros)
     */
    public static float[][] normalize(float[][] v) {
        int rows = v.length;
        int cols = v[0].length;
        float[] norm = new float[rows];

        // Compute the L2 norm of each row; substitute a tiny epsilon for
        // zero norms so the division below never produces NaN/Infinity.
        for (int i = 0; i < rows; i++) {
            float sum = 0;
            for (int j = 0; j < cols; j++) {
                sum += v[i][j] * v[i][j];
            }
            norm[i] = (float) Math.sqrt(sum);
            if (norm[i] == 0) {
                norm[i] = 1e-12f;
            }
        }

        float[][] normalized = new float[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                normalized[i][j] = v[i][j] / norm[i];
            }
        }
        return normalized;
    }

    /**
     * Creates the default embedding function backed by the all-MiniLM-L6-v2
     * ONNX model. Downloads and unpacks the model into {@link #MODEL_CACHE_DIR}
     * on first use.
     *
     * @throws EFException if the model cannot be downloaded, verified, or loaded
     */
    public DefaultEmbeddingFunction() throws EFException {
        if (!validateModel()) {
            downloadAndSetupModel();
        }

        // Fixed-length padding to 256 tokens, matching the model's export.
        Map<String, String> tokenizerConfig = Collections.unmodifiableMap(new HashMap<String, String>() {{
            put("padding", "MAX_LENGTH");
            put("maxLength", "256");
        }});

        try {
            tokenizer = HuggingFaceTokenizer.newInstance(modelPath, tokenizerConfig);

            this.env = OrtEnvironment.getEnvironment();
            OrtSession.SessionOptions options = new OrtSession.SessionOptions();
            this.session = env.createSession(modelFile.toString(), options);
        } catch (OrtException | IOException e) {
            throw new EFException(e);
        }
    }

    /**
     * Runs the ONNX model over the given documents and mean-pools the token
     * vectors into one L2-normalized embedding per document.
     *
     * @param documents texts to embed; an empty list yields an empty result
     * @return one embedding (list of floats) per input document, in order
     * @throws OrtException if ONNX Runtime fails during tensor creation or inference
     */
    public List<List<Float>> forward(List<String> documents) throws OrtException {
        if (documents == null || documents.isEmpty()) {
            // Guard: normalize() would NPE on a zero-row matrix otherwise.
            return new ArrayList<>();
        }
        Encoding[] encodings = tokenizer.batchEncode(documents, true, false);
        ArrayList<Long> inputIds = new ArrayList<>();
        ArrayList<Long> attentionMask = new ArrayList<>();
        ArrayList<Long> tokenTypeIds = new ArrayList<>();
        int maxIds = 0;
        for (Encoding encoding : encodings) {
            maxIds = Math.max(maxIds, encoding.getIds().length);
            inputIds.addAll(Arrays.asList(Arrays.stream(encoding.getIds()).boxed().toArray(Long[]::new)));
            attentionMask.addAll(Arrays.asList(Arrays.stream(encoding.getAttentionMask()).boxed().toArray(Long[]::new)));
            tokenTypeIds.addAll(Arrays.asList(Arrays.stream(encoding.getTypeIds()).boxed().toArray(Long[]::new)));
        }
        long[] inputShape = new long[]{encodings.length, maxIds};
        INDArray lastHiddenState;
        // OnnxTensor holds native memory and must be closed; the previous
        // implementation leaked all three input tensors on every call.
        try (OnnxTensor inputTensor = OnnxTensor.createTensor(env, LongBuffer.wrap(inputIds.stream().mapToLong(i -> i).toArray()), inputShape);
             OnnxTensor attentionTensor = OnnxTensor.createTensor(env, LongBuffer.wrap(attentionMask.stream().mapToLong(i -> i).toArray()), inputShape);
             OnnxTensor tokenTypeTensor = OnnxTensor.createTensor(env, LongBuffer.wrap(tokenTypeIds.stream().mapToLong(i -> i).toArray()), inputShape)) {
            // Input names expected by the all-MiniLM-L6-v2 ONNX export.
            Map<String, ? extends OnnxTensorLike> inputs = Collections.unmodifiableMap(new HashMap<String, OnnxTensorLike>() {{
                put("input_ids", inputTensor);
                put("attention_mask", attentionTensor);
                put("token_type_ids", tokenTypeTensor);
            }});
            try (OrtSession.Result results = session.run(inputs)) {
                lastHiddenState = Nd4j.create((float[][][]) results.get(0).getValue());
            }
        }
        // Mean pooling: weight each token vector by its attention mask, sum
        // over the sequence axis, and divide by the (clipped) token count.
        INDArray attMask = Nd4j.create(attentionMask.stream().mapToDouble(i -> i).toArray(), inputShape, 'c');
        INDArray expandedMask = Nd4j.expandDims(attMask, 2).broadcast(lastHiddenState.shape());
        INDArray summed = lastHiddenState.mul(expandedMask).sum(1);
        INDArray[] clippedSumMask = Nd4j.getExecutioner().exec(
                new ClipByValue(expandedMask.sum(1), 1e-9, Double.MAX_VALUE)
        );
        INDArray embeddings = summed.div(clippedSumMask[0]);
        float[][] embeddingsArray = normalize(embeddings.toFloatMatrix());
        List<List<Float>> embeddingsList = new ArrayList<>();
        for (float[] embedding : embeddingsArray) {
            embeddingsList.add(Floats.asList(embedding));
        }
        return embeddingsList;
    }

    /**
     * Computes the SHA-256 digest of a file as a lowercase hex string.
     *
     * @param filePath path of the file to hash
     * @return 64-character lowercase hex digest
     * @throws IOException              if the file cannot be read
     * @throws NoSuchAlgorithmException if SHA-256 is unavailable (never on a conformant JVM)
     */
    private static String getSHA256Checksum(String filePath) throws IOException, NoSuchAlgorithmException {
        MessageDigest digest = MessageDigest.getInstance("SHA-256");
        try (FileInputStream fis = new FileInputStream(filePath)) {
            byte[] byteArray = new byte[8192];
            int bytesCount;
            while ((bytesCount = fis.read(byteArray)) != -1) {
                digest.update(byteArray, 0, bytesCount);
            }
        }
        byte[] bytes = digest.digest();
        StringBuilder sb = new StringBuilder();
        for (byte b : bytes) {
            sb.append(String.format("%02x", b));
        }
        return sb.toString();
    }

    /**
     * Extracts a gzipped tar archive into the given directory.
     *
     * <p>Each entry path is normalized and verified to stay inside
     * {@code extractDir}, blocking "Zip Slip" path-traversal attacks from a
     * malicious or corrupted archive (the archive is fetched over the network).
     *
     * @param tarGzPath  archive to extract
     * @param extractDir destination directory (created as needed)
     * @throws IOException if extraction fails or an entry escapes the destination
     */
    private static void extractTarGz(Path tarGzPath, Path extractDir) throws IOException {
        Path root = extractDir.toAbsolutePath().normalize();
        try (InputStream fileIn = Files.newInputStream(tarGzPath);
             GZIPInputStream gzipIn = new GZIPInputStream(fileIn);
             TarArchiveInputStream tarIn = new TarArchiveInputStream(gzipIn)) {

            TarArchiveEntry entry;
            while ((entry = tarIn.getNextEntry()) != null) {
                Path entryPath = root.resolve(entry.getName()).normalize();
                if (!entryPath.startsWith(root)) {
                    throw new IOException("Blocked path traversal attempt in archive entry: " + entry.getName());
                }
                if (entry.isDirectory()) {
                    Files.createDirectories(entryPath);
                } else {
                    Files.createDirectories(entryPath.getParent());
                    try (OutputStream out = Files.newOutputStream(entryPath)) {
                        byte[] buffer = new byte[8192];
                        int len;
                        while ((len = tarIn.read(buffer)) != -1) {
                            out.write(buffer, 0, len);
                        }
                    }
                }
            }
        }
    }

    /**
     * Downloads the model archive (if not already cached), verifies its
     * SHA-256 checksum, extracts it into {@link #MODEL_CACHE_DIR}, and removes
     * the archive afterwards.
     *
     * @throws EFException if the download, checksum verification, or extraction fails
     */
    private void downloadAndSetupModel() throws EFException {
        try {
            Files.createDirectories(MODEL_CACHE_DIR);
            Path archivePath = MODEL_CACHE_DIR.resolve(ARCHIVE_FILENAME);
            if (!Files.exists(archivePath)) {
                // Only open a network connection when a download is actually needed.
                System.err.println("Model not found under " + archivePath + ". Downloading...");
                try (InputStream in = new URL(MODEL_DOWNLOAD_URL).openStream()) {
                    Files.copy(in, archivePath, StandardCopyOption.REPLACE_EXISTING);
                }
            }
            if (!MODEL_SHA256_CHECKSUM.equals(getSHA256Checksum(archivePath.toString()))) {
                throw new EFException("Checksum does not match. Delete the whole directory " + MODEL_CACHE_DIR + " and try again.");
            }
            extractTarGz(archivePath, MODEL_CACHE_DIR);
            Files.deleteIfExists(archivePath);
        } catch (IOException | NoSuchAlgorithmException e) {
            throw new EFException(e);
        }
    }


    /**
     * Check if the model is present at the expected location.
     *
     * @return true if {@code model.onnx} exists in the cache as a regular file
     */
    private boolean validateModel() {
        return Files.isRegularFile(modelFile);
    }

    /**
     * {@inheritDoc}
     *
     * @throws RuntimeException wrapping any {@link OrtException} from inference
     *         (the interface does not currently declare a checked exception)
     */
    @Override
    public List<List<Float>> createEmbedding(List<String> documents) {
        try {
            return forward(documents);
        } catch (OrtException e) {
            //TODO not great to throw a runtime exception but we need to update the interface in upcoming release to rethrow
            throw new RuntimeException(e);
        }
    }

    /**
     * {@inheritDoc}
     *
     * <p>The {@code model} argument is ignored: this implementation always
     * embeds with the bundled {@value #MODEL_NAME} model.
     *
     * @throws RuntimeException wrapping any {@link OrtException} from inference
     */
    @Override
    public List<List<Float>> createEmbedding(List<String> documents, String model) {
        try {
            return forward(documents);
        } catch (OrtException e) {
            //TODO not great to throw a runtime exception but we need to update the interface in upcoming release to rethrow
            throw new RuntimeException(e);
        }
    }
}
18 changes: 18 additions & 0 deletions src/main/java/tech/amikos/chromadb/EFException.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package tech.amikos.chromadb;

/**
 * Single checked exception type behind which every failure raised by an
 * embedding-function implementation is wrapped.
 */
public class EFException extends Exception {

    /** Creates an exception carrying only a descriptive message. */
    public EFException(String message) {
        super(message);
    }

    /** Creates an exception that preserves the underlying cause. */
    public EFException(Throwable cause) {
        super(cause);
    }

    /** Creates an exception with both a context message and a root cause. */
    public EFException(String message, Throwable cause) {
        super(message, cause);
    }
}
Loading
Loading