diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index 91ddcfc..e5c511f 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -57,7 +57,7 @@ jobs: kubectl get svc -A curl $(minikube service chromadb --url)/api/v1 - name: Test with Maven - run: mvn clean test + run: mvn --batch-mode clean test env: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }} diff --git a/pom.xml b/pom.xml index f085cbf..1cddf65 100644 --- a/pom.xml +++ b/pom.xml @@ -115,6 +115,12 @@ ${junit-version} test + + com.github.tomakehurst + wiremock-jre8 + 2.35.1 + test + diff --git a/src/main/java/tech/amikos/Main.java b/src/main/java/tech/amikos/Main.java deleted file mode 100644 index ab79246..0000000 --- a/src/main/java/tech/amikos/Main.java +++ /dev/null @@ -1,19 +0,0 @@ -package tech.amikos; - -// Press Shift twice to open the Search Everywhere dialog and type `show whitespaces`, -// then press Enter. You can now see whitespace characters in your code. -public class Main { - public static void main(String[] args) { - // Press Opt+Enter with your caret at the highlighted text to see how - // IntelliJ IDEA suggests fixing it. - System.out.printf("Hello and welcome!"); - - // Press Ctrl+R or click the green arrow button in the gutter to run the code. - for (int i = 1; i <= 5; i++) { - - // Press Ctrl+D to start debugging your code. We have set one breakpoint - // for you, but you can always add more by pressing Cmd+F8. - System.out.println("i = " + i); - } - } -} \ No newline at end of file diff --git a/src/main/java/tech/amikos/chromadb/Client.java b/src/main/java/tech/amikos/chromadb/Client.java index 024add4..b9f64ac 100644 --- a/src/main/java/tech/amikos/chromadb/Client.java +++ b/src/main/java/tech/amikos/chromadb/Client.java @@ -16,12 +16,29 @@ */ public class Client { final ApiClient apiClient = new ApiClient(); - + private int timeout = 60; DefaultApi api; public Client(String basePath) { apiClient.setBasePath(basePath); api = new DefaultApi(apiClient); + apiClient.setHttpClient(apiClient.getHttpClient().newBuilder() + .readTimeout(this.timeout, java.util.concurrent.TimeUnit.SECONDS) + .writeTimeout(this.timeout, java.util.concurrent.TimeUnit.SECONDS) + .build()); + api.getApiClient().setUserAgent("Chroma-JavaClient/0.1.x"); + } + + /** + * Set the timeout for the client + * @param timeout timeout in seconds + */ + public void setTimeout(int timeout) { + this.timeout = timeout; + apiClient.setHttpClient(apiClient.getHttpClient().newBuilder() + .readTimeout(this.timeout, java.util.concurrent.TimeUnit.SECONDS) + .writeTimeout(this.timeout, java.util.concurrent.TimeUnit.SECONDS) + .build()); } public Collection getCollection(String collectionName, EmbeddingFunction embeddingFunction) throws ApiException { diff --git a/src/test/java/TestAPI.java b/src/test/java/TestAPI.java index ffdaaa2..39c5201 100644 --- a/src/test/java/TestAPI.java +++ b/src/test/java/TestAPI.java @@ -1,4 +1,6 @@ +import com.github.tomakehurst.wiremock.junit.WireMockRule; import com.google.gson.internal.LinkedTreeMap; +import org.junit.Rule; import org.junit.Test; import tech.amikos.chromadb.*; import tech.amikos.chromadb.Collection; @@ -8,11 +10,16 @@ import java.math.BigDecimal; import java.util.*; +import static com.github.tomakehurst.wiremock.client.WireMock.*; +import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; import static org.junit.Assert.*; import static org.junit.Assume.*; public class TestAPI { + @Rule + public WireMockRule wireMockRule = new WireMockRule(8001); + @Test public void testHeartbeat() throws ApiException, IOException { @@ -396,4 +403,36 @@ public void testQueryExampleHF() throws ApiException { } + @Test + public void testTimeoutOk() throws ApiException, IOException { + stubFor(get(urlEqualTo("/api/v1/heartbeat")) + .willReturn(aResponse() + .withHeader("Content-Type", "application/json") + .withBody("{\"nanosecond heartbeat\": 123456789}").withFixedDelay(2000))); + + Utils.loadEnvFile(".env"); + Client client = new Client("http://127.0.0.1:8001"); + client.setTimeout(3); + Map hb = client.heartbeat(); + assertTrue(hb.containsKey("nanosecond heartbeat")); + } + + @Test(expected = ApiException.class) + public void testTimeoutExpires() throws ApiException, IOException{ + stubFor(get(urlEqualTo("/api/v1/heartbeat")) + .willReturn(aResponse() + .withHeader("Content-Type", "application/json") + .withBody("{\"nanosecond heartbeat\": 123456789}").withFixedDelay(2000))); + + Utils.loadEnvFile(".env"); + Client client = new Client("http://127.0.0.1:8001"); + client.setTimeout(1); + try { + client.heartbeat(); + } catch (ApiException e) { + assertTrue(e.getMessage().contains("Read timed out") || e.getMessage().contains("timeout")); + throw e; + } + + } }