From 2f00b3e74b2359f41da1ceed7674a2e4fd127552 Mon Sep 17 00:00:00 2001 From: Danny Jelsma <73849717+Dannyj1@users.noreply.github.com> Date: Sun, 25 Feb 2024 18:02:39 +0100 Subject: [PATCH] Implemented embeddings --- .github/workflows/build.yml | 2 +- README.md | 57 +++++++-- .../java/nl/dannyj/mistral/MistralClient.java | 75 ++++++++++-- .../mistral/builders/MessageListBuilder.java | 10 +- .../nl/dannyj/mistral/models/Request.java | 20 ++++ .../nl/dannyj/mistral/models/Response.java | 20 ++++ .../ChatCompletionRequest.java | 6 +- .../ChatCompletionResponse.java | 11 +- .../models/{ => completion}/Choice.java | 2 +- .../models/{ => completion}/Message.java | 2 +- .../models/{ => completion}/MessageRole.java | 2 +- .../models/embedding/EmbeddingRequest.java | 64 ++++++++++ .../models/embedding/EmbeddingResponse.java | 63 ++++++++++ .../models/embedding/FloatEmbedding.java | 47 ++++++++ .../ListModelsResponse.java | 9 +- .../mistral/models/{ => model}/Model.java | 2 +- .../models/{ => model}/ModelPermission.java | 2 +- .../mistral/models/{ => usage}/Usage.java | 14 +-- .../dannyj/mistral/services/HttpService.java | 6 +- .../mistral/services/MistralService.java | 112 ++++++++++++++---- 20 files changed, 459 insertions(+), 67 deletions(-) create mode 100644 src/main/java/nl/dannyj/mistral/models/Request.java create mode 100644 src/main/java/nl/dannyj/mistral/models/Response.java rename src/main/java/nl/dannyj/mistral/models/{request => completion}/ChatCompletionRequest.java (96%) rename src/main/java/nl/dannyj/mistral/models/{response => completion}/ChatCompletionResponse.java (84%) rename src/main/java/nl/dannyj/mistral/models/{ => completion}/Choice.java (96%) rename src/main/java/nl/dannyj/mistral/models/{ => completion}/Message.java (95%) rename src/main/java/nl/dannyj/mistral/models/{ => completion}/MessageRole.java (95%) create mode 100644 src/main/java/nl/dannyj/mistral/models/embedding/EmbeddingRequest.java create mode 100644 
src/main/java/nl/dannyj/mistral/models/embedding/EmbeddingResponse.java create mode 100644 src/main/java/nl/dannyj/mistral/models/embedding/FloatEmbedding.java rename src/main/java/nl/dannyj/mistral/models/{response => model}/ListModelsResponse.java (83%) rename src/main/java/nl/dannyj/mistral/models/{ => model}/Model.java (97%) rename src/main/java/nl/dannyj/mistral/models/{ => model}/ModelPermission.java (97%) rename src/main/java/nl/dannyj/mistral/models/{ => usage}/Usage.java (97%) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 94f3051..97ca4d0 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -4,7 +4,7 @@ on: branches: - master pull_request: - types: [opened, synchronize, reopened] + types: [ opened, synchronize, reopened ] jobs: build: name: Build and analyze diff --git a/README.md b/README.md index 574cc6f..604bda9 100644 --- a/README.md +++ b/README.md @@ -7,11 +7,12 @@ Currently supports all chat completion models. At the time of writing these are: - mistral-tiny - mistral-small - mistral-medium +- mistral-embed -The embedding endpoint will be supported at a later date. +New models or models not listed here may already be supported without any updates to the library. -**NOTE:** This library is currently in **alpha**. It is currently NOT possible to using streaming in message completions -or to use embedding models. These features will be added in the future. The currently supported APIs should be stable +**NOTE:** This library is currently in **alpha**. It is currently NOT possible to use streaming in message +completions. This will be added in the future. The currently supported APIs should be stable however. 
# Supported APIs @@ -20,11 +21,13 @@ Mistral-java-client is built against version 0.0.1 of the [Mistral AI API](https - [Create Chat Completions](https://docs.mistral.ai/api/#operation/createChatCompletion) - [List Available Models](https://docs.mistral.ai/api/#operation/listModels) -- "Create Embeddings" to be implemented later +- [Create Embeddings](https://docs.mistral.ai/guides/embeddings/) # Requirements + - Java 17 or higher -- A Mistral AI API Key (see the [Mistral documentation](https://docs.mistral.ai/#api-access) for more details on API access) +- A Mistral AI API Key (see the [Mistral documentation](https://docs.mistral.ai/#api-access) for more details on API + access) # Installation @@ -72,12 +75,15 @@ String apiKey = "API_KEY_HERE"; MistralClient client = new MistralClient(apiKey); // Get a list of available models -List models = client.listModels().getModels(); +List models = client.listModels().getModels(); // Loop through all available models and print their ID. The id can be used to specify the model when creating chat completions -for (Model model : models) { - System.out.println(model.getId()); -} +for( +Model model :models){ + System.out. + +println(model.getId()); + } ``` Example output: @@ -137,10 +143,41 @@ public class HelloWorld { ''' ``` +## Embeddings + +```java +// You can also put the API key in an environment variable called MISTRAL_API_KEY and remove the apiKey parameter given to the MistralClient constructor +String apiKey = "API_KEY_HERE"; + +// Initialize the client. This should ideally only be done once. The instance should be re-used for multiple requests +MistralClient client = new MistralClient(apiKey); +List exampleTexts = List.of( + "This is a test sentence.", + "This is another test sentence." 
+); + +EmbeddingRequest embeddingRequest = EmbeddingRequest.builder() + .model("mistral-embed") // mistral-embed is currently the only model available for embedding + .input(exampleTexts) + .build(); + +EmbeddingResponse embeddingsResponse = client.createEmbedding(embeddingRequest); +// Embeddings are returned as a list of FloatEmbedding objects. FloatEmbedding objects contain a list of floats per input string. +// See the Mistral documentation for more information: https://docs.mistral.ai/guides/embeddings/ +List embeddings = embeddingsResponse.getData(); +embeddings.forEach(embedding -> System.out.println(embedding.getEmbedding())); +``` + +Example output: + +``` +[-0.028015137, 0.02532959, 0.042785645, ... , -0.020980835, 0.011947632, -0.0035934448] +[-0.02015686, 0.04272461, 0.05529785, ... , -0.006855011, 0.009529114, -0.016448975] +``` + # Roadmap - [ ] Add support for streaming in message completions -- [ ] Add support for embedding models - [ ] Figure out how Mistral handles rate limiting and create a queue system to handle it - [ ] Unit tests diff --git a/src/main/java/nl/dannyj/mistral/MistralClient.java b/src/main/java/nl/dannyj/mistral/MistralClient.java index 355ea26..e4dbe6c 100644 --- a/src/main/java/nl/dannyj/mistral/MistralClient.java +++ b/src/main/java/nl/dannyj/mistral/MistralClient.java @@ -17,13 +17,17 @@ package nl.dannyj.mistral; import com.fasterxml.jackson.databind.ObjectMapper; +import jakarta.validation.ConstraintViolationException; import lombok.Getter; import lombok.NonNull; import lombok.Setter; +import nl.dannyj.mistral.exceptions.UnexpectedResponseException; import nl.dannyj.mistral.interceptors.MistralHeaderInterceptor; -import nl.dannyj.mistral.models.request.ChatCompletionRequest; -import nl.dannyj.mistral.models.response.ChatCompletionResponse; -import nl.dannyj.mistral.models.response.ListModelsResponse; +import nl.dannyj.mistral.models.completion.ChatCompletionRequest; +import 
nl.dannyj.mistral.models.completion.ChatCompletionResponse; +import nl.dannyj.mistral.models.embedding.EmbeddingRequest; +import nl.dannyj.mistral.models.embedding.EmbeddingResponse; +import nl.dannyj.mistral.models.model.ListModelsResponse; import nl.dannyj.mistral.services.HttpService; import nl.dannyj.mistral.services.MistralService; import okhttp3.OkHttpClient; @@ -50,6 +54,7 @@ public class MistralClient { /** * Constructor that initializes the MistralClient with a provided API key. + * * @param apiKey The API key to be used for the Mistral AI API */ public MistralClient(@NonNull String apiKey) { @@ -71,8 +76,9 @@ public MistralClient() { /** * Constructor that initializes the MistralClient with a provided API key, HTTP client, and object mapper. - * @param apiKey The API key to be used for the Mistral AI API - * @param httpClient The OkHttpClient to be used for making requests to the Mistral AI API + * + * @param apiKey The API key to be used for the Mistral AI API + * @param httpClient The OkHttpClient to be used for making requests to the Mistral AI API * @param objectMapper The Jackson ObjectMapper to be used for serializing and deserializing JSON */ public MistralClient(@NonNull String apiKey, @NonNull OkHttpClient httpClient, @NonNull ObjectMapper objectMapper) { @@ -84,7 +90,8 @@ public MistralClient(@NonNull String apiKey, @NonNull OkHttpClient httpClient, @ /** * Constructor that initializes the MistralClient with a provided API key and HTTP client. - * @param apiKey The API key to be used for the Mistral AI API + * + * @param apiKey The API key to be used for the Mistral AI API * @param httpClient The OkHttpClient to be used for making requests to the Mistral AI API */ public MistralClient(@NonNull String apiKey, @NonNull OkHttpClient httpClient) { @@ -96,7 +103,8 @@ public MistralClient(@NonNull String apiKey, @NonNull OkHttpClient httpClient) { /** * Constructor that initializes the MistralClient with a provided API key and object mapper. 
- * @param apiKey The API key to be used for the Mistral AI API + * + * @param apiKey The API key to be used for the Mistral AI API * @param objectMapper The Jackson ObjectMapper to be used for serializing and deserializing JSON */ public MistralClient(@NonNull String apiKey, @NonNull ObjectMapper objectMapper) { @@ -109,8 +117,12 @@ public MistralClient(@NonNull String apiKey, @NonNull ObjectMapper objectMapper) /** * Use the Mistral AI API to create a chat completion (an assistant reply to the conversation). * This is a blocking method. + * * @param request The request to create a chat completion. See {@link ChatCompletionRequest}. * @return The response from the Mistral AI API containing the generated message. See {@link ChatCompletionResponse}. + * @throws ConstraintViolationException if the request does not pass validation + * @throws UnexpectedResponseException if an unexpected response is received from the Mistral AI API + * @throws IllegalArgumentException if the first message role is not 'user' or 'system' */ public ChatCompletionResponse createChatCompletion(@NonNull ChatCompletionRequest request) { return mistralService.createChatCompletion(request); @@ -119,23 +131,70 @@ public ChatCompletionResponse createChatCompletion(@NonNull ChatCompletionReques /** * Use the Mistral AI API to create a chat completion (an assistant reply to the conversation). * This is a non-blocking/asynchronous method. + * * @param request The request to create a chat completion. See {@link ChatCompletionRequest}. * @return A CompletableFuture that will complete with generated message from the Mistral AI API. See {@link ChatCompletionResponse}. 
+ * @throws ConstraintViolationException if the request does not pass validation + * @throws UnexpectedResponseException if an unexpected response is received from the Mistral AI API + * @throws IllegalArgumentException if the first message role is not 'user' or 'system' */ public CompletableFuture createChatCompletionAsync(@NonNull ChatCompletionRequest request) { return mistralService.createChatCompletionAsync(request); } + /** + * This method is used to create an embedding using the Mistral AI API. + * Generates the embeddings for the input strings. See the Mistral documentation for more details on embeddings. + * This is a blocking method. + * + * @param request The request to create an embedding. See {@link EmbeddingRequest}. + * @return The response from the Mistral AI API containing the generated embedding. See {@link EmbeddingResponse}. + * @throws ConstraintViolationException if the request does not pass validation + * @throws UnexpectedResponseException if an unexpected response is received from the Mistral AI API + */ + public EmbeddingResponse createEmbedding(@NonNull EmbeddingRequest request) { + return mistralService.createEmbedding(request); + } + + /** + * This method is used to create an embedding using the Mistral AI API. + * Generates the embeddings for the input strings. See the Mistral documentation for more details on embeddings. + * This is a non-blocking/asynchronous method. + * + * @param request The request to create an embedding. See {@link EmbeddingRequest}. + * @return A CompletableFuture that will complete with the generated embedding from the Mistral AI API. See {@link EmbeddingResponse}. 
+ * @throws ConstraintViolationException if the request does not pass validation + * @throws UnexpectedResponseException if an unexpected response is received from the Mistral AI API + */ + public CompletableFuture createEmbeddingAsync(@NonNull EmbeddingRequest request) { + return mistralService.createEmbeddingAsync(request); + } + /** * Lists all models available according to the Mistral AI API. + * This is a blocking method. + * * @return The response from the Mistral AI API containing the list of models. See {@link ListModelsResponse}. + * @throws UnexpectedResponseException if an unexpected response is received from the Mistral AI API */ public ListModelsResponse listModels() { return mistralService.listModels(); } + /** + * Lists all models available according to the Mistral AI API. + * This is a non-blocking/asynchronous method. + * + * @return A CompletableFuture that will complete with the list of models from the Mistral AI API. See {@link ListModelsResponse}. + * @throws UnexpectedResponseException if an unexpected response is received from the Mistral AI API + */ + public CompletableFuture listModelsAsync() { + return mistralService.listModelsAsync(); + } + /** * Builds the MistralService. + * * @return A new instance of MistralService */ private MistralService buildMistralService() { @@ -144,6 +203,7 @@ private MistralService buildMistralService() { /** * Builds the HTTP client. + * * @return A new instance of OkHttpClient */ private OkHttpClient buildHttpClient() { @@ -157,6 +217,7 @@ private OkHttpClient buildHttpClient() { /** * Builds the object mapper. 
+ * * @return A new instance of ObjectMapper */ private ObjectMapper buildObjectMapper() { diff --git a/src/main/java/nl/dannyj/mistral/builders/MessageListBuilder.java b/src/main/java/nl/dannyj/mistral/builders/MessageListBuilder.java index d830aa6..fbbc7e3 100644 --- a/src/main/java/nl/dannyj/mistral/builders/MessageListBuilder.java +++ b/src/main/java/nl/dannyj/mistral/builders/MessageListBuilder.java @@ -16,8 +16,8 @@ package nl.dannyj.mistral.builders; -import nl.dannyj.mistral.models.Message; -import nl.dannyj.mistral.models.MessageRole; +import nl.dannyj.mistral.models.completion.Message; +import nl.dannyj.mistral.models.completion.MessageRole; import java.util.ArrayList; import java.util.List; @@ -40,6 +40,7 @@ public MessageListBuilder() { /** * Constructor that initializes the list of Message objects with a provided list. + * * @param messages The initial list of Message objects */ public MessageListBuilder(List messages) { @@ -48,6 +49,7 @@ public MessageListBuilder(List messages) { /** * Adds a message with the system role to the list with the provided content. + * * @param content The content of the system message * @return The builder instance */ @@ -58,6 +60,7 @@ public MessageListBuilder system(String content) { /** * Adds a message with the assistant role to the list with the provided content. + * * @param content The content of the assistant message * @return The builder instance */ @@ -68,6 +71,7 @@ public MessageListBuilder assistant(String content) { /** * Adds a message with the user role to the list with the provided content. + * * @param content The content of the user message * @return The builder instance */ @@ -78,6 +82,7 @@ public MessageListBuilder user(String content) { /** * Adds a custom Message object to the list. + * * @param message The Message object to be added * @return The builder instance */ @@ -88,6 +93,7 @@ public MessageListBuilder message(Message message) { /** * Returns the list of Message objects that have been added. 
+ * * @return The list of Message objects */ public List build() { diff --git a/src/main/java/nl/dannyj/mistral/models/Request.java b/src/main/java/nl/dannyj/mistral/models/Request.java new file mode 100644 index 0000000..18bfdcb --- /dev/null +++ b/src/main/java/nl/dannyj/mistral/models/Request.java @@ -0,0 +1,20 @@ +/* + * Copyright 2024 Danny Jelsma + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package nl.dannyj.mistral.models; + +public interface Request { +} diff --git a/src/main/java/nl/dannyj/mistral/models/Response.java b/src/main/java/nl/dannyj/mistral/models/Response.java new file mode 100644 index 0000000..34015db --- /dev/null +++ b/src/main/java/nl/dannyj/mistral/models/Response.java @@ -0,0 +1,20 @@ +/* + * Copyright 2024 Danny Jelsma + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package nl.dannyj.mistral.models; + +public interface Response { +} diff --git a/src/main/java/nl/dannyj/mistral/models/request/ChatCompletionRequest.java b/src/main/java/nl/dannyj/mistral/models/completion/ChatCompletionRequest.java similarity index 96% rename from src/main/java/nl/dannyj/mistral/models/request/ChatCompletionRequest.java rename to src/main/java/nl/dannyj/mistral/models/completion/ChatCompletionRequest.java index 1f8ec3d..32a2e92 100644 --- a/src/main/java/nl/dannyj/mistral/models/request/ChatCompletionRequest.java +++ b/src/main/java/nl/dannyj/mistral/models/completion/ChatCompletionRequest.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package nl.dannyj.mistral.models.request; +package nl.dannyj.mistral.models.completion; import com.fasterxml.jackson.annotation.JsonProperty; import jakarta.validation.constraints.*; @@ -22,7 +22,7 @@ import lombok.Builder; import lombok.Data; import lombok.NoArgsConstructor; -import nl.dannyj.mistral.models.Message; +import nl.dannyj.mistral.models.Request; import java.util.List; @@ -34,7 +34,7 @@ @AllArgsConstructor @NoArgsConstructor @Builder -public class ChatCompletionRequest { +public class ChatCompletionRequest implements Request { /** * ID of the model to use. You can use the List Available Models API to see all of your available models. diff --git a/src/main/java/nl/dannyj/mistral/models/response/ChatCompletionResponse.java b/src/main/java/nl/dannyj/mistral/models/completion/ChatCompletionResponse.java similarity index 84% rename from src/main/java/nl/dannyj/mistral/models/response/ChatCompletionResponse.java rename to src/main/java/nl/dannyj/mistral/models/completion/ChatCompletionResponse.java index 9d5c38e..afa8b43 100644 --- a/src/main/java/nl/dannyj/mistral/models/response/ChatCompletionResponse.java +++ b/src/main/java/nl/dannyj/mistral/models/completion/ChatCompletionResponse.java @@ -14,14 +14,14 @@ * limitations under the License. 
*/ -package nl.dannyj.mistral.models.response; +package nl.dannyj.mistral.models.completion; import lombok.AllArgsConstructor; import lombok.Getter; import lombok.NoArgsConstructor; import lombok.ToString; -import nl.dannyj.mistral.models.Choice; -import nl.dannyj.mistral.models.Usage; +import nl.dannyj.mistral.models.Response; +import nl.dannyj.mistral.models.usage.Usage; import java.util.List; @@ -33,13 +33,16 @@ @AllArgsConstructor @NoArgsConstructor @ToString -public class ChatCompletionResponse { +public class ChatCompletionResponse implements Response { /** * Unique identifier for this response. */ private String id; + /** + * Undocumented, seems to be the type of the response. + */ private String object; /** diff --git a/src/main/java/nl/dannyj/mistral/models/Choice.java b/src/main/java/nl/dannyj/mistral/models/completion/Choice.java similarity index 96% rename from src/main/java/nl/dannyj/mistral/models/Choice.java rename to src/main/java/nl/dannyj/mistral/models/completion/Choice.java index d979ac4..b3ca924 100644 --- a/src/main/java/nl/dannyj/mistral/models/Choice.java +++ b/src/main/java/nl/dannyj/mistral/models/completion/Choice.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package nl.dannyj.mistral.models; +package nl.dannyj.mistral.models.completion; import com.fasterxml.jackson.annotation.JsonProperty; import lombok.AllArgsConstructor; diff --git a/src/main/java/nl/dannyj/mistral/models/Message.java b/src/main/java/nl/dannyj/mistral/models/completion/Message.java similarity index 95% rename from src/main/java/nl/dannyj/mistral/models/Message.java rename to src/main/java/nl/dannyj/mistral/models/completion/Message.java index d4b164c..7c69baf 100644 --- a/src/main/java/nl/dannyj/mistral/models/Message.java +++ b/src/main/java/nl/dannyj/mistral/models/completion/Message.java @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package nl.dannyj.mistral.models; +package nl.dannyj.mistral.models.completion; import jakarta.validation.constraints.NotNull; import lombok.*; diff --git a/src/main/java/nl/dannyj/mistral/models/MessageRole.java b/src/main/java/nl/dannyj/mistral/models/completion/MessageRole.java similarity index 95% rename from src/main/java/nl/dannyj/mistral/models/MessageRole.java rename to src/main/java/nl/dannyj/mistral/models/completion/MessageRole.java index abae28b..b47b0d1 100644 --- a/src/main/java/nl/dannyj/mistral/models/MessageRole.java +++ b/src/main/java/nl/dannyj/mistral/models/completion/MessageRole.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package nl.dannyj.mistral.models; +package nl.dannyj.mistral.models.completion; import com.fasterxml.jackson.annotation.JsonValue; diff --git a/src/main/java/nl/dannyj/mistral/models/embedding/EmbeddingRequest.java b/src/main/java/nl/dannyj/mistral/models/embedding/EmbeddingRequest.java new file mode 100644 index 0000000..ded1d68 --- /dev/null +++ b/src/main/java/nl/dannyj/mistral/models/embedding/EmbeddingRequest.java @@ -0,0 +1,64 @@ +/* + * Copyright 2024 Danny Jelsma + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package nl.dannyj.mistral.models.embedding; + +import com.fasterxml.jackson.annotation.JsonProperty; +import jakarta.validation.constraints.NotBlank; +import jakarta.validation.constraints.NotNull; +import jakarta.validation.constraints.Size; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; +import nl.dannyj.mistral.models.Request; + +import java.util.List; + +/** + * The EmbeddingRequest class represents a request to create embedding for a list of strings. + * Most of the field descriptions are taken from the Mistral API documentation. + */ +@Data +@AllArgsConstructor +@NoArgsConstructor +@Builder +public class EmbeddingRequest implements Request { + + /** + * The ID of the model to use for this request. + */ + @NotNull + @NotBlank + private String model; + + /** + * The list of strings to embed. + */ + @NotNull + @Size(min = 1) + private List input; + + /** + * The format of the output data. The valid values for this are not documented, so assume only "float" is valid for now. + */ + @JsonProperty("encoding_format") + @Builder.Default + @NotNull + @NotBlank + private String encodingFormat = "float"; + +} diff --git a/src/main/java/nl/dannyj/mistral/models/embedding/EmbeddingResponse.java b/src/main/java/nl/dannyj/mistral/models/embedding/EmbeddingResponse.java new file mode 100644 index 0000000..19fa5fb --- /dev/null +++ b/src/main/java/nl/dannyj/mistral/models/embedding/EmbeddingResponse.java @@ -0,0 +1,63 @@ +/* + * Copyright 2024 Danny Jelsma + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package nl.dannyj.mistral.models.embedding; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.ToString; +import nl.dannyj.mistral.models.Response; +import nl.dannyj.mistral.models.usage.Usage; + +import java.util.List; + +/** + * The EmbeddingResponse class represents a response from the Mistral API when creating embeddings. + * Most of these fields are undocumented. + */ +@Getter +@AllArgsConstructor +@NoArgsConstructor +@ToString +public class EmbeddingResponse implements Response { + + /** + * Unique identifier for this response. + */ + private String id; + + /** + * Undocumented, seems to be the type of the response. + */ + private String object; + + /** + * The embeddings that were created for the list of input strings. + */ + private List data; + + /** + * The model used to create the embeddings. + */ + private String model; + + /** + * The token usage of the request. + */ + private Usage usage; + +} diff --git a/src/main/java/nl/dannyj/mistral/models/embedding/FloatEmbedding.java b/src/main/java/nl/dannyj/mistral/models/embedding/FloatEmbedding.java new file mode 100644 index 0000000..c4d9fcc --- /dev/null +++ b/src/main/java/nl/dannyj/mistral/models/embedding/FloatEmbedding.java @@ -0,0 +1,47 @@ +/* + * Copyright 2024 Danny Jelsma + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package nl.dannyj.mistral.models.embedding; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.ToString; + +import java.util.List; + +@Getter +@AllArgsConstructor +@NoArgsConstructor +@ToString +public class FloatEmbedding { + + /** + * Undocumented, seems to be the type of the response. + */ + private String object; + + /** + * The embeddings for the input strings. See the mistral documentation for more details on embeddings. + */ + private List embedding; + + /** + * The index of the input string in the input list. + */ + private int index; + +} diff --git a/src/main/java/nl/dannyj/mistral/models/response/ListModelsResponse.java b/src/main/java/nl/dannyj/mistral/models/model/ListModelsResponse.java similarity index 83% rename from src/main/java/nl/dannyj/mistral/models/response/ListModelsResponse.java rename to src/main/java/nl/dannyj/mistral/models/model/ListModelsResponse.java index 1afaa9f..0647d12 100644 --- a/src/main/java/nl/dannyj/mistral/models/response/ListModelsResponse.java +++ b/src/main/java/nl/dannyj/mistral/models/model/ListModelsResponse.java @@ -14,14 +14,14 @@ * limitations under the License. 
*/ -package nl.dannyj.mistral.models.response; +package nl.dannyj.mistral.models.model; import com.fasterxml.jackson.annotation.JsonProperty; import lombok.AllArgsConstructor; import lombok.Getter; import lombok.NoArgsConstructor; import lombok.ToString; -import nl.dannyj.mistral.models.Model; +import nl.dannyj.mistral.models.Response; import java.util.List; @@ -32,8 +32,11 @@ @AllArgsConstructor @NoArgsConstructor @ToString -public class ListModelsResponse { +public class ListModelsResponse implements Response { + /** + * Undocumented, seems to be the type of the response. + */ private String object; /** diff --git a/src/main/java/nl/dannyj/mistral/models/Model.java b/src/main/java/nl/dannyj/mistral/models/model/Model.java similarity index 97% rename from src/main/java/nl/dannyj/mistral/models/Model.java rename to src/main/java/nl/dannyj/mistral/models/model/Model.java index 7ca01a8..19e185e 100644 --- a/src/main/java/nl/dannyj/mistral/models/Model.java +++ b/src/main/java/nl/dannyj/mistral/models/model/Model.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package nl.dannyj.mistral.models; +package nl.dannyj.mistral.models.model; import com.fasterxml.jackson.annotation.JsonProperty; import lombok.AllArgsConstructor; diff --git a/src/main/java/nl/dannyj/mistral/models/ModelPermission.java b/src/main/java/nl/dannyj/mistral/models/model/ModelPermission.java similarity index 97% rename from src/main/java/nl/dannyj/mistral/models/ModelPermission.java rename to src/main/java/nl/dannyj/mistral/models/model/ModelPermission.java index 6f933fa..0e2711b 100644 --- a/src/main/java/nl/dannyj/mistral/models/ModelPermission.java +++ b/src/main/java/nl/dannyj/mistral/models/model/ModelPermission.java @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package nl.dannyj.mistral.models; +package nl.dannyj.mistral.models.model; import com.fasterxml.jackson.annotation.JsonProperty; import lombok.AllArgsConstructor; diff --git a/src/main/java/nl/dannyj/mistral/models/Usage.java b/src/main/java/nl/dannyj/mistral/models/usage/Usage.java similarity index 97% rename from src/main/java/nl/dannyj/mistral/models/Usage.java rename to src/main/java/nl/dannyj/mistral/models/usage/Usage.java index 7355cc0..a594510 100644 --- a/src/main/java/nl/dannyj/mistral/models/Usage.java +++ b/src/main/java/nl/dannyj/mistral/models/usage/Usage.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package nl.dannyj.mistral.models; +package nl.dannyj.mistral.models.usage; import com.fasterxml.jackson.annotation.JsonProperty; import lombok.AllArgsConstructor; @@ -34,16 +34,16 @@ public class Usage { @JsonProperty("prompt_tokens") private int promptTokens; - /** - * The number of tokens used for the completion ("output tokens"). - */ - @JsonProperty("completion_tokens") - private int completionTokens; - /** * The total number of tokens used (prompt tokens + completion tokens). */ @JsonProperty("total_tokens") private int totalTokens; + /** + * The number of tokens used for the completion ("output tokens"). + */ + @JsonProperty("completion_tokens") + private int completionTokens; + } diff --git a/src/main/java/nl/dannyj/mistral/services/HttpService.java b/src/main/java/nl/dannyj/mistral/services/HttpService.java index e0b6dfc..0eedccf 100644 --- a/src/main/java/nl/dannyj/mistral/services/HttpService.java +++ b/src/main/java/nl/dannyj/mistral/services/HttpService.java @@ -35,6 +35,7 @@ public class HttpService { /** * Constructor that initializes the HttpService with a provided MistralClient. 
+ * * @param client The MistralClient to be used for making requests to the Mistral AI API */ public HttpService(@NonNull MistralClient client) { @@ -43,6 +44,7 @@ public HttpService(@NonNull MistralClient client) { /** * Makes a GET request to the specified URL path. + * * @param urlPath The URL path to make the GET request to * @return The response body as a string */ @@ -57,8 +59,9 @@ public String get(@NonNull String urlPath) { /** * Makes a POST request to the specified URL path with the provided body. + * * @param urlPath The URL path to make the POST request to - * @param body The JSON body of the POST request + * @param body The JSON body of the POST request * @return The response body as a string */ public String post(@NonNull String urlPath, @NonNull String body) { @@ -72,6 +75,7 @@ public String post(@NonNull String urlPath, @NonNull String body) { /** * Executes the provided request using the OkHttpClient from the MistralClient. + * * @param request The request to be executed * @return The response body as a string * @throws MistralAPIException If the response is not successful, the response body is null or an IOException occurs in the objectmapper diff --git a/src/main/java/nl/dannyj/mistral/services/MistralService.java b/src/main/java/nl/dannyj/mistral/services/MistralService.java index 45ff6dc..d6b82d3 100644 --- a/src/main/java/nl/dannyj/mistral/services/MistralService.java +++ b/src/main/java/nl/dannyj/mistral/services/MistralService.java @@ -21,11 +21,15 @@ import lombok.NonNull; import nl.dannyj.mistral.MistralClient; import nl.dannyj.mistral.exceptions.UnexpectedResponseException; -import nl.dannyj.mistral.models.Message; -import nl.dannyj.mistral.models.MessageRole; -import nl.dannyj.mistral.models.request.ChatCompletionRequest; -import nl.dannyj.mistral.models.response.ChatCompletionResponse; -import nl.dannyj.mistral.models.response.ListModelsResponse; +import nl.dannyj.mistral.models.Request; +import nl.dannyj.mistral.models.Response; 
+import nl.dannyj.mistral.models.completion.ChatCompletionRequest; +import nl.dannyj.mistral.models.completion.ChatCompletionResponse; +import nl.dannyj.mistral.models.completion.Message; +import nl.dannyj.mistral.models.completion.MessageRole; +import nl.dannyj.mistral.models.embedding.EmbeddingRequest; +import nl.dannyj.mistral.models.embedding.EmbeddingResponse; +import nl.dannyj.mistral.models.model.ListModelsResponse; import java.util.Set; import java.util.concurrent.CompletableFuture; @@ -42,7 +46,8 @@ public class MistralService { /** * Constructor that initializes the MistralService with a provided MistralClient and HttpService. - * @param client The MistralClient to be used for interacting with the Mistral AI API + * + * @param client The MistralClient to be used for interacting with the Mistral AI API * @param httpService The HttpService to be used for making HTTP requests */ public MistralService(@NonNull MistralClient client, @NonNull HttpService httpService) { @@ -57,19 +62,14 @@ public MistralService(@NonNull MistralClient client, @NonNull HttpService httpSe /** * Use the Mistral AI API to create a chat completion (an assistant reply to the conversation). * This is a blocking method. + * * @param request The request to create a chat completion. See {@link ChatCompletionRequest}. * @return The response from the Mistral AI API containing the generated message. See {@link ChatCompletionResponse}. 
* @throws ConstraintViolationException if the request does not pass validation - * @throws UnexpectedResponseException if an unexpected response is received from the Mistral AI API - * @throws IllegalArgumentException if the first message role is not 'user' or 'system' + * @throws UnexpectedResponseException if an unexpected response is received from the Mistral AI API + * @throws IllegalArgumentException if the first message role is not 'user' or 'system' */ public ChatCompletionResponse createChatCompletion(@NonNull ChatCompletionRequest request) { - Set<ConstraintViolation<ChatCompletionRequest>> violations = validator.validate(request); - - if (!violations.isEmpty()) { - throw new ConstraintViolationException(violations); - } - Message firstMessage = request.getMessages().get(0); MessageRole role = firstMessage.getRole(); @@ -77,21 +77,13 @@ public ChatCompletionResponse createChatCompletion(@NonNull ChatCompletionReques throw new IllegalArgumentException("The first message role should be either 'user' or 'system'"); } - String response = null; - - try { - String requestJson = client.getObjectMapper().writeValueAsString(request); - response = httpService.post("/chat/completions", requestJson); - - return client.getObjectMapper().readValue(response, ChatCompletionResponse.class); - } catch (JsonProcessingException e) { - throw new UnexpectedResponseException("Received unexpected response from the Mistral.ai API (mistral-java-client might need to be updated): " + response, e); - } + return validateRequestAndPost("/chat/completions", request, ChatCompletionResponse.class); } /** * Use the Mistral AI API to create a chat completion (an assistant reply to the conversation). * This is a non-blocking/asynchronous method. + * * @param request The request to create a chat completion. See {@link ChatCompletionRequest}. * @return A CompletableFuture that will complete with generated message from the Mistral AI API. See {@link ChatCompletionResponse}.
*/ @@ -101,6 +93,8 @@ public CompletableFuture<ChatCompletionResponse> createChatCompletionAsync(@NonN /** * Lists all models available according to the Mistral AI API. + * This is a blocking method. + * * @return The response from the Mistral AI API containing the list of models. See {@link ListModelsResponse}. * @throws UnexpectedResponseException if an unexpected response is received from the Mistral AI API */ @@ -113,5 +107,75 @@ public ListModelsResponse listModels() { throw new UnexpectedResponseException("Received unexpected response from the Mistral.ai API (mistral-java-client might need to be updated): " + response, e); } } + + /** + * Lists all models available according to the Mistral AI API. + * This is a non-blocking/asynchronous method. + * + * @return A CompletableFuture that will complete with the list of models from the Mistral AI API. See {@link ListModelsResponse}. + */ + public CompletableFuture<ListModelsResponse> listModelsAsync() { + return CompletableFuture.supplyAsync(this::listModels); + } + + /** + * This method is used to create an embedding using the Mistral AI API. + * It returns the embeddings for the input strings. See the Mistral documentation for more details on embeddings. + * This is a blocking method. + * + * @param request The request to create an embedding. See {@link EmbeddingRequest}. + * @return The response from the Mistral AI API containing the generated embedding. See {@link EmbeddingResponse}. + * @throws ConstraintViolationException if the request does not pass validation + * @throws UnexpectedResponseException if an unexpected response is received from the Mistral AI API + */ + public EmbeddingResponse createEmbedding(@NonNull EmbeddingRequest request) { + return validateRequestAndPost("/embeddings", request, EmbeddingResponse.class); + } + + /** + * This method is used to create an embedding using the Mistral AI API. + * It returns the embeddings for the input strings. See the Mistral documentation for more details on embeddings. + * This is a non-blocking/asynchronous method.
+ * + * @param request The request to create an embedding. See {@link EmbeddingRequest}. + * @return A CompletableFuture that will complete with the generated embedding from the Mistral AI API. See {@link EmbeddingResponse}. + */ + public CompletableFuture<EmbeddingResponse> createEmbeddingAsync(@NonNull EmbeddingRequest request) { + return CompletableFuture.supplyAsync(() -> createEmbedding(request)); + } + + /** + * This method is used to validate the request and post it to the specified endpoint. + * It first validates the request using the validator. If there are any constraint violations, it throws a ConstraintViolationException. + * If the request is valid, it converts the request to JSON and sends a POST request to the specified endpoint. + * The response from the endpoint is then converted back to the specified response type and returned. + * + * @param <T> The type of the request. It must extend Request. + * @param <U> The type of the response. It must extend Response. + * @param endpoint The endpoint to which the request should be posted. + * @param request The request to be posted. + * @param responseType The class of the response type. + * @return The response from the endpoint, converted to the specified response type.
+ * @throws ConstraintViolationException if the request does not pass validation + * @throws UnexpectedResponseException if an unexpected response is received from the Mistral AI API + */ + private <T extends Request, U extends Response> U validateRequestAndPost(String endpoint, T request, Class<U> responseType) { + Set<ConstraintViolation<T>> violations = validator.validate(request); + + if (!violations.isEmpty()) { + throw new ConstraintViolationException(violations); + } + + String response = null; + + try { + String requestJson = client.getObjectMapper().writeValueAsString(request); + response = httpService.post(endpoint, requestJson); + + return client.getObjectMapper().readValue(response, responseType); + } catch (JsonProcessingException e) { + throw new UnexpectedResponseException("Received unexpected response from the Mistral.ai API (mistral-java-client might need to be updated): " + response, e); + } + } }