Skip to content

Commit

Permalink
Merge pull request #2 from Dannyj1/embedding
Browse files Browse the repository at this point in the history
Implemented Support for the Embeddings Endpoint
  • Loading branch information
Dannyj1 authored Feb 25, 2024
2 parents b1264a8 + 2f00b3e commit 60d8363
Show file tree
Hide file tree
Showing 20 changed files with 459 additions and 67 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ on:
branches:
- master
pull_request:
types: [opened, synchronize, reopened]
types: [ opened, synchronize, reopened ]
jobs:
build:
name: Build and analyze
Expand Down
57 changes: 47 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@ Currently supports all chat completion models. At the time of writing these are:
- mistral-tiny
- mistral-small
- mistral-medium
- mistral-embed

The embedding endpoint will be supported at a later date.
New models or models not listed here may be already supported without any updates to the library.

**NOTE:** This library is currently in **alpha**. It is currently NOT possible to using streaming in message completions
or to use embedding models. These features will be added in the future. The currently supported APIs should be stable
**NOTE:** This library is currently in **alpha**. It is currently NOT possible to using streaming in message
completions. This will be added in the future. The currently supported APIs should be stable
however.

# Supported APIs
Expand All @@ -20,11 +21,13 @@ Mistral-java-client is built against version 0.0.1 of the [Mistral AI API](https

- [Create Chat Completions](https://docs.mistral.ai/api/#operation/createChatCompletion)
- [List Available Models](https://docs.mistral.ai/api/#operation/listModels)
- "Create Embeddings" to be implemented later
- [Create Embeddings](https://docs.mistral.ai/guides/embeddings/)

# Requirements

- Java 17 or higher
- A Mistral AI API Key (see the [Mistral documentation](https://docs.mistral.ai/#api-access) for more details on API access)
- A Mistral AI API Key (see the [Mistral documentation](https://docs.mistral.ai/#api-access) for more details on API
access)

# Installation

Expand Down Expand Up @@ -72,12 +75,15 @@ String apiKey = "API_KEY_HERE";
MistralClient client = new MistralClient(apiKey);

// Get a list of available models
List<Model> models = client.listModels().getModels();
List<Model> models = client.listModels().getModels();

// Loop through all available models and print their ID. The id can be used to specify the model when creating chat completions
for (Model model : models) {
System.out.println(model.getId());
}
for(
Model model :models){
System.out.

println(model.getId());
}
```

Example output:
Expand Down Expand Up @@ -137,10 +143,41 @@ public class HelloWorld {
'''
```

## Embeddings

```java
// You can also put the API key in an environment variable called MISTRAL_API_KEY and remove the apiKey parameter given to the MistralClient constructor
String apiKey = "API_KEY_HERE";

// Initialize the client. This should ideally only be done once. The instance should be re-used for multiple requests
MistralClient client = new MistralClient(apiKey);
List<String> exampleTexts = List.of(
"This is a test sentence.",
"This is another test sentence."
);

EmbeddingRequest embeddingRequest = EmbeddingRequest.builder()
.model("mistral-embed") // mistral-embed is currently the only model available for embedding
.input(exampleTexts)
.build();

EmbeddingResponse embeddingsResponse = client.createEmbedding(embeddingRequest);
// Embeddings are returned as a list of FloatEmbedding objects. FloatEmbedding objects contain a list of floats per input string.
// See the Mistral documentation for more information: https://docs.mistral.ai/guides/embeddings/
List<FloatEmbedding> embeddings = embeddingsResponse.getData();
embeddings.forEach(embedding -> System.out.println(embedding.getEmbedding()));
```

Example output:

```
[-0.028015137, 0.02532959, 0.042785645, ... , -0.020980835, 0.011947632, -0.0035934448]
[-0.02015686, 0.04272461, 0.05529785, ... , -0.006855011, 0.009529114, -0.016448975]
```

# Roadmap

- [ ] Add support for streaming in message completions
- [ ] Add support for embedding models
- [ ] Figure out how Mistral handles rate limiting and create a queue system to handle it
- [ ] Unit tests

Expand Down
75 changes: 68 additions & 7 deletions src/main/java/nl/dannyj/mistral/MistralClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,17 @@
package nl.dannyj.mistral;

import com.fasterxml.jackson.databind.ObjectMapper;
import jakarta.validation.ConstraintViolationException;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import nl.dannyj.mistral.exceptions.UnexpectedResponseException;
import nl.dannyj.mistral.interceptors.MistralHeaderInterceptor;
import nl.dannyj.mistral.models.request.ChatCompletionRequest;
import nl.dannyj.mistral.models.response.ChatCompletionResponse;
import nl.dannyj.mistral.models.response.ListModelsResponse;
import nl.dannyj.mistral.models.completion.ChatCompletionRequest;
import nl.dannyj.mistral.models.completion.ChatCompletionResponse;
import nl.dannyj.mistral.models.embedding.EmbeddingRequest;
import nl.dannyj.mistral.models.embedding.EmbeddingResponse;
import nl.dannyj.mistral.models.model.ListModelsResponse;
import nl.dannyj.mistral.services.HttpService;
import nl.dannyj.mistral.services.MistralService;
import okhttp3.OkHttpClient;
Expand All @@ -50,6 +54,7 @@ public class MistralClient {

/**
* Constructor that initializes the MistralClient with a provided API key.
*
* @param apiKey The API key to be used for the Mistral AI API
*/
public MistralClient(@NonNull String apiKey) {
Expand All @@ -71,8 +76,9 @@ public MistralClient() {

/**
* Constructor that initializes the MistralClient with a provided API key, HTTP client, and object mapper.
* @param apiKey The API key to be used for the Mistral AI API
* @param httpClient The OkHttpClient to be used for making requests to the Mistral AI API
*
* @param apiKey The API key to be used for the Mistral AI API
* @param httpClient The OkHttpClient to be used for making requests to the Mistral AI API
* @param objectMapper The Jackson ObjectMapper to be used for serializing and deserializing JSON
*/
public MistralClient(@NonNull String apiKey, @NonNull OkHttpClient httpClient, @NonNull ObjectMapper objectMapper) {
Expand All @@ -84,7 +90,8 @@ public MistralClient(@NonNull String apiKey, @NonNull OkHttpClient httpClient, @

/**
* Constructor that initializes the MistralClient with a provided API key and HTTP client.
* @param apiKey The API key to be used for the Mistral AI API
*
* @param apiKey The API key to be used for the Mistral AI API
* @param httpClient The OkHttpClient to be used for making requests to the Mistral AI API
*/
public MistralClient(@NonNull String apiKey, @NonNull OkHttpClient httpClient) {
Expand All @@ -96,7 +103,8 @@ public MistralClient(@NonNull String apiKey, @NonNull OkHttpClient httpClient) {

/**
* Constructor that initializes the MistralClient with a provided API key and object mapper.
* @param apiKey The API key to be used for the Mistral AI API
*
* @param apiKey The API key to be used for the Mistral AI API
* @param objectMapper The Jackson ObjectMapper to be used for serializing and deserializing JSON
*/
public MistralClient(@NonNull String apiKey, @NonNull ObjectMapper objectMapper) {
Expand All @@ -109,8 +117,12 @@ public MistralClient(@NonNull String apiKey, @NonNull ObjectMapper objectMapper)
/**
* Use the Mistral AI API to create a chat completion (an assistant reply to the conversation).
* This is a blocking method.
*
* @param request The request to create a chat completion. See {@link ChatCompletionRequest}.
* @return The response from the Mistral AI API containing the generated message. See {@link ChatCompletionResponse}.
* @throws ConstraintViolationException if the request does not pass validation
* @throws UnexpectedResponseException if an unexpected response is received from the Mistral AI API
* @throws IllegalArgumentException if the first message role is not 'user' or 'system'
*/
public ChatCompletionResponse createChatCompletion(@NonNull ChatCompletionRequest request) {
return mistralService.createChatCompletion(request);
Expand All @@ -119,23 +131,70 @@ public ChatCompletionResponse createChatCompletion(@NonNull ChatCompletionReques
/**
* Use the Mistral AI API to create a chat completion (an assistant reply to the conversation).
* This is a non-blocking/asynchronous method.
*
* @param request The request to create a chat completion. See {@link ChatCompletionRequest}.
* @return A CompletableFuture that will complete with generated message from the Mistral AI API. See {@link ChatCompletionResponse}.
* @throws ConstraintViolationException if the request does not pass validation
* @throws UnexpectedResponseException if an unexpected response is received from the Mistral AI API
* @throws IllegalArgumentException if the first message role is not 'user' or 'system'
*/
public CompletableFuture<ChatCompletionResponse> createChatCompletionAsync(@NonNull ChatCompletionRequest request) {
return mistralService.createChatCompletionAsync(request);
}

/**
* This method is used to create an embedding using the Mistral AI API.
* The embeddings for the input strings. See the <a href="https://docs.mistral.ai/guides/embeddings/">mistral documentation</a> for more details on embeddings.
* This is a blocking method.
*
* @param request The request to create an embedding. See {@link EmbeddingRequest}.
* @return The response from the Mistral AI API containing the generated embedding. See {@link EmbeddingResponse}.
* @throws ConstraintViolationException if the request does not pass validation
* @throws UnexpectedResponseException if an unexpected response is received from the Mistral AI API
*/
public EmbeddingResponse createEmbedding(@NonNull EmbeddingRequest request) {
return mistralService.createEmbedding(request);
}

/**
* This method is used to create an embedding using the Mistral AI API.
* The embeddings for the input strings. See the <a href="https://docs.mistral.ai/guides/embeddings/">mistral documentation</a> for more details on embeddings.
* This is a non-blocking/asynchronous method.
*
* @param request The request to create an embedding. See {@link EmbeddingRequest}.
* @return A CompletableFuture that will complete with the generated embedding from the Mistral AI API. See {@link EmbeddingResponse}.
* @throws ConstraintViolationException if the request does not pass validation
* @throws UnexpectedResponseException if an unexpected response is received from the Mistral AI API
*/
public CompletableFuture<EmbeddingResponse> createEmbeddingAsync(@NonNull EmbeddingRequest request) {
return mistralService.createEmbeddingAsync(request);
}

/**
* Lists all models available according to the Mistral AI API.
* This is a blocking method.
*
* @return The response from the Mistral AI API containing the list of models. See {@link ListModelsResponse}.
* @throws UnexpectedResponseException if an unexpected response is received from the Mistral AI API
*/
public ListModelsResponse listModels() {
return mistralService.listModels();
}

/**
* Lists all models available according to the Mistral AI API.
* This is a non-blocking/asynchronous method.
*
* @return A CompletableFuture that will complete with the list of models from the Mistral AI API. See {@link ListModelsResponse}.
* @throws UnexpectedResponseException if an unexpected response is received from the Mistral AI API
*/
public CompletableFuture<ListModelsResponse> listModelsAsync() {
return mistralService.listModelsAsync();
}

/**
* Builds the MistralService.
*
* @return A new instance of MistralService
*/
private MistralService buildMistralService() {
Expand All @@ -144,6 +203,7 @@ private MistralService buildMistralService() {

/**
* Builds the HTTP client.
*
* @return A new instance of OkHttpClient
*/
private OkHttpClient buildHttpClient() {
Expand All @@ -157,6 +217,7 @@ private OkHttpClient buildHttpClient() {

/**
* Builds the object mapper.
*
* @return A new instance of ObjectMapper
*/
private ObjectMapper buildObjectMapper() {
Expand Down
10 changes: 8 additions & 2 deletions src/main/java/nl/dannyj/mistral/builders/MessageListBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

package nl.dannyj.mistral.builders;

import nl.dannyj.mistral.models.Message;
import nl.dannyj.mistral.models.MessageRole;
import nl.dannyj.mistral.models.completion.Message;
import nl.dannyj.mistral.models.completion.MessageRole;

import java.util.ArrayList;
import java.util.List;
Expand All @@ -40,6 +40,7 @@ public MessageListBuilder() {

/**
* Constructor that initializes the list of Message objects with a provided list.
*
* @param messages The initial list of Message objects
*/
public MessageListBuilder(List<Message> messages) {
Expand All @@ -48,6 +49,7 @@ public MessageListBuilder(List<Message> messages) {

/**
* Adds a message with the system role to the list with the provided content.
*
* @param content The content of the system message
* @return The builder instance
*/
Expand All @@ -58,6 +60,7 @@ public MessageListBuilder system(String content) {

/**
* Adds a message with the assistant role to the list with the provided content.
*
* @param content The content of the assistant message
* @return The builder instance
*/
Expand All @@ -68,6 +71,7 @@ public MessageListBuilder assistant(String content) {

/**
* Adds a message with the user role to the list with the provided content.
*
* @param content The content of the user message
* @return The builder instance
*/
Expand All @@ -78,6 +82,7 @@ public MessageListBuilder user(String content) {

/**
* Adds a custom Message object to the list.
*
* @param message The Message object to be added
* @return The builder instance
*/
Expand All @@ -88,6 +93,7 @@ public MessageListBuilder message(Message message) {

/**
* Returns the list of Message objects that have been added.
*
* @return The list of Message objects
*/
public List<Message> build() {
Expand Down
20 changes: 20 additions & 0 deletions src/main/java/nl/dannyj/mistral/models/Request.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/*
* Copyright 2024 Danny Jelsma
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package nl.dannyj.mistral.models;

public interface Request {
}
20 changes: 20 additions & 0 deletions src/main/java/nl/dannyj/mistral/models/Response.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/*
* Copyright 2024 Danny Jelsma
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package nl.dannyj.mistral.models;

public interface Response {
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@
* limitations under the License.
*/

package nl.dannyj.mistral.models.request;
package nl.dannyj.mistral.models.completion;

import com.fasterxml.jackson.annotation.JsonProperty;
import jakarta.validation.constraints.*;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import nl.dannyj.mistral.models.Message;
import nl.dannyj.mistral.models.Request;

import java.util.List;

Expand All @@ -34,7 +34,7 @@
@AllArgsConstructor
@NoArgsConstructor
@Builder
public class ChatCompletionRequest {
public class ChatCompletionRequest implements Request {

/**
* ID of the model to use. You can use the List Available Models API to see all of your available models.
Expand Down
Loading

0 comments on commit 60d8363

Please sign in to comment.