diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml
index 91ddcfc..e5c511f 100644
--- a/.github/workflows/integration-test.yml
+++ b/.github/workflows/integration-test.yml
@@ -57,7 +57,7 @@ jobs:
kubectl get svc -A
curl $(minikube service chromadb --url)/api/v1
- name: Test with Maven
- run: mvn clean test
+ run: mvn --batch-mode clean test
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
diff --git a/pom.xml b/pom.xml
index f085cbf..1cddf65 100644
--- a/pom.xml
+++ b/pom.xml
@@ -115,6 +115,12 @@
${junit-version}
test
+
+ com.github.tomakehurst
+ wiremock-jre8
+ 2.35.1
+ test
+
diff --git a/src/main/java/tech/amikos/Main.java b/src/main/java/tech/amikos/Main.java
deleted file mode 100644
index ab79246..0000000
--- a/src/main/java/tech/amikos/Main.java
+++ /dev/null
@@ -1,19 +0,0 @@
-package tech.amikos;
-
-// Press Shift twice to open the Search Everywhere dialog and type `show whitespaces`,
-// then press Enter. You can now see whitespace characters in your code.
-public class Main {
- public static void main(String[] args) {
- // Press Opt+Enter with your caret at the highlighted text to see how
- // IntelliJ IDEA suggests fixing it.
- System.out.printf("Hello and welcome!");
-
- // Press Ctrl+R or click the green arrow button in the gutter to run the code.
- for (int i = 1; i <= 5; i++) {
-
- // Press Ctrl+D to start debugging your code. We have set one breakpoint
- // for you, but you can always add more by pressing Cmd+F8.
- System.out.println("i = " + i);
- }
- }
-}
\ No newline at end of file
diff --git a/src/main/java/tech/amikos/chromadb/Client.java b/src/main/java/tech/amikos/chromadb/Client.java
index 024add4..b9f64ac 100644
--- a/src/main/java/tech/amikos/chromadb/Client.java
+++ b/src/main/java/tech/amikos/chromadb/Client.java
@@ -16,12 +16,29 @@
*/
public class Client {
final ApiClient apiClient = new ApiClient();
-
+ private int timeout = 60;
DefaultApi api;
public Client(String basePath) {
apiClient.setBasePath(basePath);
api = new DefaultApi(apiClient);
+ apiClient.setHttpClient(apiClient.getHttpClient().newBuilder()
+ .readTimeout(this.timeout, java.util.concurrent.TimeUnit.SECONDS)
+ .writeTimeout(this.timeout, java.util.concurrent.TimeUnit.SECONDS)
+ .build());
+ api.getApiClient().setUserAgent("Chroma-JavaClient/0.1.x");
+ }
+
+ /**
+ * Set the timeout for the client
+ * @param timeout timeout in seconds
+ */
+ public void setTimeout(int timeout) {
+ this.timeout = timeout;
+ apiClient.setHttpClient(apiClient.getHttpClient().newBuilder()
+ .readTimeout(this.timeout, java.util.concurrent.TimeUnit.SECONDS)
+ .writeTimeout(this.timeout, java.util.concurrent.TimeUnit.SECONDS)
+ .build());
}
public Collection getCollection(String collectionName, EmbeddingFunction embeddingFunction) throws ApiException {
diff --git a/src/test/java/TestAPI.java b/src/test/java/TestAPI.java
index ffdaaa2..39c5201 100644
--- a/src/test/java/TestAPI.java
+++ b/src/test/java/TestAPI.java
@@ -1,4 +1,6 @@
+import com.github.tomakehurst.wiremock.junit.WireMockRule;
import com.google.gson.internal.LinkedTreeMap;
+import org.junit.Rule;
import org.junit.Test;
import tech.amikos.chromadb.*;
import tech.amikos.chromadb.Collection;
@@ -8,11 +10,16 @@
import java.math.BigDecimal;
import java.util.*;
+import static com.github.tomakehurst.wiremock.client.WireMock.*;
+import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
import static org.junit.Assert.*;
import static org.junit.Assume.*;
public class TestAPI {
+ @Rule
+ public WireMockRule wireMockRule = new WireMockRule(8001);
+
@Test
public void testHeartbeat() throws ApiException, IOException {
@@ -396,4 +403,36 @@ public void testQueryExampleHF() throws ApiException {
}
+ @Test
+ public void testTimeoutOk() throws ApiException, IOException {
+ stubFor(get(urlEqualTo("/api/v1/heartbeat"))
+ .willReturn(aResponse()
+ .withHeader("Content-Type", "application/json")
+ .withBody("{\"nanosecond heartbeat\": 123456789}").withFixedDelay(2000)));
+
+ Utils.loadEnvFile(".env");
+ Client client = new Client("http://127.0.0.1:8001");
+ client.setTimeout(3);
+ Map hb = client.heartbeat();
+ assertTrue(hb.containsKey("nanosecond heartbeat"));
+ }
+
+ @Test(expected = ApiException.class)
+ public void testTimeoutExpires() throws ApiException, IOException{
+ stubFor(get(urlEqualTo("/api/v1/heartbeat"))
+ .willReturn(aResponse()
+ .withHeader("Content-Type", "application/json")
+ .withBody("{\"nanosecond heartbeat\": 123456789}").withFixedDelay(2000)));
+
+ Utils.loadEnvFile(".env");
+ Client client = new Client("http://127.0.0.1:8001");
+ client.setTimeout(1);
+ try {
+ client.heartbeat();
+ } catch (ApiException e) {
+ assertTrue(e.getMessage().contains("Read timed out") || e.getMessage().contains("timeout"));
+ throw e;
+ }
+
+ }
}