Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce Mistral AI support #32

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions langchain4j-spring-boot-autoconfigure/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <!-- Inherit version and dependency management from the langchain4j-spring aggregator POM. -->
    <parent>
        <groupId>dev.langchain4j</groupId>
        <artifactId>langchain4j-spring</artifactId>
        <version>0.32.0</version>
        <relativePath>../pom.xml</relativePath>
    </parent>

    <artifactId>langchain4j-spring-boot-autoconfigure</artifactId>
    <name>LangChain4j Spring Boot Autoconfigure</name>

    <dependencies>

        <!-- Annotation Processor -->

        <!-- Generates auto-configuration metadata at build time; not needed at runtime. -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-autoconfigure-processor</artifactId>
            <optional>true</optional>
        </dependency>

        <!-- Generates configuration-property metadata (IDE completion); not needed at runtime. -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-configuration-processor</artifactId>
            <optional>true</optional>
        </dependency>

        <!-- API -->

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter</artifactId>
            <scope>compile</scope>
        </dependency>

        <!-- Optional -->

        <!-- Optional model integrations: consumers opt in by adding the matching
             module to their own classpath; the auto-configuration activates conditionally. -->
        <dependency>
            <groupId>dev.langchain4j</groupId>
            <artifactId>langchain4j-spring-core</artifactId>
            <version>${project.parent.version}</version>
            <optional>true</optional>
        </dependency>

        <dependency>
            <groupId>dev.langchain4j</groupId>
            <artifactId>langchain4j-spring-mistral-ai</artifactId>
            <version>${project.parent.version}</version>
            <optional>true</optional>
        </dependency>

        <!-- Test Implementation -->

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>

    </dependencies>

    <licenses>
        <license>
            <name>Apache-2.0</name>
            <url>https://www.apache.org/licenses/LICENSE-2.0.txt</url>
            <distribution>repo</distribution>
            <comments>A business-friendly OSS license</comments>
        </license>
    </licenses>

</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package dev.langchain4j.spring.boot.autoconfigure.models.mistralai;

import dev.langchain4j.model.mistralai.MistralAiChatModel;
import dev.langchain4j.model.mistralai.MistralAiEmbeddingModel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;

/**
 * Auto-configuration for Mistral AI clients and models.
 *
 * <p>Registers a chat model and an embedding model bean, each guarded by its own
 * {@code enabled} property and backed out whenever the user supplies their own bean.
 */
@AutoConfiguration(after = {RestClientAutoConfiguration.class})
@ConditionalOnClass({ MistralAiChatModel.class })
@ConditionalOnProperty(prefix = MistralAiProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", matchIfMissing = true)
@EnableConfigurationProperties({ MistralAiProperties.class, MistralAiChatProperties.class, MistralAiEmbeddingProperties.class })
public class MistralAiAutoConfiguration {

    private static final Logger logger = LoggerFactory.getLogger(MistralAiAutoConfiguration.class);

    @Bean
    @ConditionalOnMissingBean
    @ConditionalOnProperty(prefix = MistralAiChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", matchIfMissing = true)
    MistralAiChatModel mistralAiChatModel(MistralAiProperties properties, MistralAiChatProperties chatProperties) {
        var client = properties.getClient();

        // Shared client settings first, then the chat-specific model options.
        var model = MistralAiChatModel.builder()
                .apiKey(client.getApiKey())
                .baseUrl(client.getBaseUrl().toString())
                .timeout(client.getReadTimeout())
                .maxRetries(client.getMaxRetries())
                .logRequests(client.isLogRequests())
                .logResponses(client.isLogResponses())
                .modelName(chatProperties.getModel())
                .temperature(chatProperties.getTemperature())
                .topP(chatProperties.getTopP())
                .maxTokens(chatProperties.getMaxTokens())
                .safePrompt(chatProperties.isSafePrompt())
                .randomSeed(chatProperties.getRandomSeed())
                .build();

        warnAboutSensitiveInformationExposure(client, MistralAiChatModel.class.getTypeName());

        return model;
    }

    @Bean
    @ConditionalOnMissingBean
    @ConditionalOnProperty(prefix = MistralAiEmbeddingProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", matchIfMissing = true)
    MistralAiEmbeddingModel mistralAiEmbeddingModel(MistralAiProperties properties, MistralAiEmbeddingProperties embeddingProperties) {
        var client = properties.getClient();

        // Shared client settings first, then the embedding-specific model options.
        var model = MistralAiEmbeddingModel.builder()
                .apiKey(client.getApiKey())
                .baseUrl(client.getBaseUrl().toString())
                .timeout(client.getReadTimeout())
                .maxRetries(client.getMaxRetries())
                .logRequests(client.isLogRequests())
                .logResponses(client.isLogResponses())
                .modelName(embeddingProperties.getModel())
                .build();

        warnAboutSensitiveInformationExposure(client, MistralAiEmbeddingModel.class.getTypeName());

        return model;
    }

    /**
     * Logs a warning when request and/or response logging is switched on, since full
     * payload logging may leak sensitive or private information.
     */
    private static void warnAboutSensitiveInformationExposure(MistralAiProperties.Client clientProperties, String modelType) {
        if (clientProperties.isLogRequests()) {
            logger.warn("You have enabled logging for the entire model request in {}, with the risk of exposing sensitive or private information. Please, be careful!", modelType);
        }

        if (clientProperties.isLogResponses()) {
            logger.warn("You have enabled logging for the entire model response in {}, with the risk of exposing sensitive or private information. Please, be careful!", modelType);
        }
    }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package dev.langchain4j.spring.boot.autoconfigure.models.mistralai;

import dev.langchain4j.model.mistralai.MistralAiChatModelName;
import org.springframework.boot.context.properties.ConfigurationProperties;

/**
* Configuration properties for Mistral AI chat models.
*/
@ConfigurationProperties(prefix = MistralAiChatProperties.CONFIG_PREFIX)
public class MistralAiChatProperties {

public static final String CONFIG_PREFIX = "langchain4j.mistralai.chat";
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
public static final String CONFIG_PREFIX = "langchain4j.mistralai.chat";
public static final String CONFIG_PREFIX = "langchain4j.mistral-ai.chat-model";


/**
* Whether to enable the Mistral AI chat models.
*/
private boolean enabled = true;

/**
* ID of the model to use.
*/
private String model = MistralAiChatModelName.OPEN_MISTRAL_7B.toString();
/**
* What sampling temperature to use, between 0.0 and 1.0. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or "top_p" but not both.
*/
private Double temperature = 0.7;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to set default values here? (except for the mandatory params (model)).
I would avoid setting defaults here

/**
* Nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or "temperature" but not both.
*/
private Double topP = 1.0;
/**
* The maximum number of tokens to generate in the completion. The token count of your prompt plus "max_tokens" cannot exceed the model's context length.
*/
private Integer maxTokens;
/**
* Whether to inject a safety prompt before all conversations.
*/
private boolean safePrompt = false;
/**
* The seed to use for random sampling. If set, different calls will generate deterministic results.
*/
private Integer randomSeed;

public boolean isEnabled() {
return enabled;
}

public void setEnabled(boolean enabled) {
this.enabled = enabled;
}

public String getModel() {
return model;
}

public void setModel(String model) {
this.model = model;
}

public Double getTemperature() {
return temperature;
}

public void setTemperature(Double temperature) {
this.temperature = temperature;
}

public Double getTopP() {
return topP;
}

public void setTopP(Double topP) {
this.topP = topP;
}

public Integer getMaxTokens() {
return maxTokens;
}

public void setMaxTokens(Integer maxTokens) {
this.maxTokens = maxTokens;
}

public boolean isSafePrompt() {
return safePrompt;
}

public void setSafePrompt(boolean safePrompt) {
this.safePrompt = safePrompt;
}

public Integer getRandomSeed() {
return randomSeed;
}

public void setRandomSeed(Integer randomSeed) {
this.randomSeed = randomSeed;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package dev.langchain4j.spring.boot.autoconfigure.models.mistralai;

import dev.langchain4j.model.mistralai.MistralAiEmbeddingModelName;
import org.springframework.boot.context.properties.ConfigurationProperties;

/**
* Configuration properties for Mistral AI chat models.
*/
@ConfigurationProperties(prefix = MistralAiEmbeddingProperties.CONFIG_PREFIX)
public class MistralAiEmbeddingProperties {

public static final String CONFIG_PREFIX = "langchain4j.mistralai.embedding";
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
public static final String CONFIG_PREFIX = "langchain4j.mistralai.embedding";
public static final String CONFIG_PREFIX = "langchain4j.mistral-ai.embedding-model";


/**
* Whether to enable the Mistral AI embedding models.
*/
private boolean enabled = true;

/**
* ID of the model to use.
*/
private String model = MistralAiEmbeddingModelName.MISTRAL_EMBED.toString();

public boolean isEnabled() {
return enabled;
}

public void setEnabled(boolean enabled) {
this.enabled = enabled;
}

public String getModel() {
return model;
}

public void setModel(String model) {
this.model = model;
}

}
Loading
Loading