Skip to content

Commit a2a92bb

Browse files
apappascstzolov
authored andcommitted
feat: Update MistralChatOptions
- Updating `copy()` method, creating new instances of mutable collections (List, Set, Map, Metadata) to prevent shared state. - Adding `MistralChatOptionsTests` to verify `copy()`, builders, setters, and default values. Signed-off-by: Alexandros Pappas <[email protected]>
1 parent beb1d05 commit a2a92bb

File tree

3 files changed

+140
-7
lines changed

3 files changed

+140
-7
lines changed

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -171,16 +171,17 @@ public static MistralAiChatOptions fromOptions(MistralAiChatOptions fromOptions)
171171
.temperature(fromOptions.getTemperature())
172172
.topP(fromOptions.getTopP())
173173
.responseFormat(fromOptions.getResponseFormat())
174-
.stop(fromOptions.getStop())
174+
.stop(fromOptions.getStop() != null ? new ArrayList<>(fromOptions.getStop()) : null)
175175
.frequencyPenalty(fromOptions.getFrequencyPenalty())
176176
.presencePenalty(fromOptions.getPresencePenalty())
177177
.n(fromOptions.getN())
178-
.tools(fromOptions.getTools())
178+
.tools(fromOptions.getTools() != null ? new ArrayList<>(fromOptions.getTools()) : null)
179179
.toolChoice(fromOptions.getToolChoice())
180-
.toolCallbacks(fromOptions.getToolCallbacks())
181-
.toolNames(fromOptions.getToolNames())
180+
.toolCallbacks(
181+
fromOptions.getToolCallbacks() != null ? new ArrayList<>(fromOptions.getToolCallbacks()) : null)
182+
.toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null)
182183
.internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled())
183-
.toolContext(fromOptions.getToolContext())
184+
.toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null)
184185
.build();
185186
}
186187

@@ -366,6 +367,7 @@ public void setToolContext(Map<String, Object> toolContext) {
366367
}
367368

368369
@Override
370+
@SuppressWarnings("unchecked")
369371
public MistralAiChatOptions copy() {
370372
return fromOptions(this);
371373
}

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -754,12 +754,16 @@ public enum ToolChoice {
754754
/**
755755
* An object specifying the format that the model must output.
756756
*
757-
* @param type Must be one of 'text' or 'json_object'.
758-
* @param jsonSchema A specific JSON schema to match, if 'type' is 'json_object'.
757+
* @param type Must be one of 'text', 'json_object' or 'json_schema'.
758+
* @param jsonSchema A specific JSON schema to match, if 'type' is 'json_schema'.
759759
*/
760760
@JsonInclude(Include.NON_NULL)
761761
public record ResponseFormat(@JsonProperty("type") String type,
762762
@JsonProperty("json_schema") Map<String, Object> jsonSchema) {
763+
764+
public ResponseFormat(String type) {
765+
this(type, null);
766+
}
763767
}
764768

765769
}
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
/*
2+
* Copyright 2025-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.mistralai;
18+
19+
import java.util.List;
20+
import java.util.Map;
21+
22+
import static org.assertj.core.api.Assertions.assertThat;
23+
24+
import org.junit.jupiter.api.Test;
25+
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest.ResponseFormat;
26+
27+
import org.springframework.ai.mistralai.api.MistralAiApi;
28+
29+
/**
30+
* Tests for {@link MistralAiChatOptions}.
31+
*
32+
* @author Alexandros Pappas
33+
*/
34+
class MistralAiChatOptionsTests {
35+
36+
@Test
37+
void testBuilderWithAllFields() {
38+
MistralAiChatOptions options = MistralAiChatOptions.builder()
39+
.model("test-model")
40+
.temperature(0.7)
41+
.topP(0.9)
42+
.maxTokens(100)
43+
.safePrompt(true)
44+
.randomSeed(123)
45+
.stop(List.of("stop1", "stop2"))
46+
.responseFormat(new ResponseFormat("json_object"))
47+
.toolChoice(MistralAiApi.ChatCompletionRequest.ToolChoice.AUTO)
48+
.internalToolExecutionEnabled(true)
49+
.toolContext(Map.of("key1", "value1"))
50+
.build();
51+
52+
assertThat(options)
53+
.extracting("model", "temperature", "topP", "maxTokens", "safePrompt", "randomSeed", "stop",
54+
"responseFormat", "toolChoice", "internalToolExecutionEnabled", "toolContext")
55+
.containsExactly("test-model", 0.7, 0.9, 100, true, 123, List.of("stop1", "stop2"),
56+
new ResponseFormat("json_object"), MistralAiApi.ChatCompletionRequest.ToolChoice.AUTO, true,
57+
Map.of("key1", "value1"));
58+
}
59+
60+
@Test
61+
void testBuilderWithEnum() {
62+
MistralAiChatOptions optionsWithEnum = MistralAiChatOptions.builder()
63+
.model(MistralAiApi.ChatModel.MINISTRAL_8B_LATEST)
64+
.build();
65+
assertThat(optionsWithEnum.getModel()).isEqualTo(MistralAiApi.ChatModel.MINISTRAL_8B_LATEST.getValue());
66+
}
67+
68+
@Test
69+
void testCopy() {
70+
MistralAiChatOptions options = MistralAiChatOptions.builder()
71+
.model("test-model")
72+
.temperature(0.7)
73+
.topP(0.9)
74+
.maxTokens(100)
75+
.safePrompt(true)
76+
.randomSeed(123)
77+
.stop(List.of("stop1", "stop2"))
78+
.responseFormat(new ResponseFormat("json_object"))
79+
.toolChoice(MistralAiApi.ChatCompletionRequest.ToolChoice.AUTO)
80+
.internalToolExecutionEnabled(true)
81+
.toolContext(Map.of("key1", "value1"))
82+
.build();
83+
84+
MistralAiChatOptions copiedOptions = options.copy();
85+
assertThat(copiedOptions).isNotSameAs(options).isEqualTo(options);
86+
// Ensure deep copy
87+
assertThat(copiedOptions.getStop()).isNotSameAs(options.getStop());
88+
assertThat(copiedOptions.getToolContext()).isNotSameAs(options.getToolContext());
89+
}
90+
91+
@Test
92+
void testSetters() {
93+
ResponseFormat responseFormat = new ResponseFormat("json_object");
94+
MistralAiChatOptions options = new MistralAiChatOptions();
95+
options.setModel("test-model");
96+
options.setTemperature(0.7);
97+
options.setTopP(0.9);
98+
options.setMaxTokens(100);
99+
options.setSafePrompt(true);
100+
options.setRandomSeed(123);
101+
options.setResponseFormat(responseFormat);
102+
options.setStopSequences(List.of("stop1", "stop2"));
103+
104+
assertThat(options.getModel()).isEqualTo("test-model");
105+
assertThat(options.getTemperature()).isEqualTo(0.7);
106+
assertThat(options.getTopP()).isEqualTo(0.9);
107+
assertThat(options.getMaxTokens()).isEqualTo(100);
108+
assertThat(options.getSafePrompt()).isEqualTo(true);
109+
assertThat(options.getRandomSeed()).isEqualTo(123);
110+
assertThat(options.getStopSequences()).isEqualTo(List.of("stop1", "stop2"));
111+
assertThat(options.getResponseFormat()).isEqualTo(responseFormat);
112+
}
113+
114+
@Test
115+
void testDefaultValues() {
116+
MistralAiChatOptions options = new MistralAiChatOptions();
117+
assertThat(options.getModel()).isNull();
118+
assertThat(options.getTemperature()).isNull();
119+
assertThat(options.getTopP()).isNull();
120+
assertThat(options.getMaxTokens()).isNull();
121+
assertThat(options.getSafePrompt()).isNull();
122+
assertThat(options.getRandomSeed()).isNull();
123+
assertThat(options.getStopSequences()).isNull();
124+
assertThat(options.getResponseFormat()).isNull();
125+
}
126+
127+
}

0 commit comments

Comments
 (0)