Skip to content

Commit a3ad6d5

Browse files
author
wmz7year
committed
Amazon Bedrock Chat adds tool support.
1 parent 49b3326 commit a3ad6d5

File tree

6 files changed

+392
-29
lines changed

6 files changed

+392
-29
lines changed

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/Anthropic3ChatOptions.java

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,20 @@
1515
*/
1616
package org.springframework.ai.bedrock.anthropic3;
1717

18+
import com.fasterxml.jackson.annotation.JsonIgnore;
1819
import com.fasterxml.jackson.annotation.JsonInclude;
1920
import com.fasterxml.jackson.annotation.JsonInclude.Include;
2021
import com.fasterxml.jackson.annotation.JsonProperty;
2122
import org.springframework.ai.chat.prompt.ChatOptions;
23+
import org.springframework.ai.model.function.FunctionCallback;
24+
import org.springframework.ai.model.function.FunctionCallingOptions;
25+
import org.springframework.boot.context.properties.NestedConfigurationProperty;
26+
import org.springframework.util.Assert;
2227

28+
import java.util.ArrayList;
29+
import java.util.HashSet;
2330
import java.util.List;
31+
import java.util.Set;
2432

2533
/**
2634
* Java {@link ChatOptions} for the Bedrock Anthropic chat generative model chat options.
@@ -31,7 +39,7 @@
3139
* @since 1.0.0
3240
*/
3341
@JsonInclude(Include.NON_NULL)
34-
public class Anthropic3ChatOptions implements ChatOptions {
42+
public class Anthropic3ChatOptions implements ChatOptions, FunctionCallingOptions {
3543

3644
// @formatter:off
3745
/**
@@ -66,6 +74,31 @@ public class Anthropic3ChatOptions implements ChatOptions {
6674
*/
6775
private @JsonProperty("stop_sequences") List<String> stopSequences;
6876

77+
/**
78+
* Tool Function Callbacks to register with the ChatModel. For Prompt
79+
* Options the functionCallbacks are automatically enabled for the duration of the
80+
* prompt execution. For Default Options the functionCallbacks are registered but
81+
* disabled by default. Use the enableFunctions to set the functions from the registry
82+
* to be used by the ChatModel chat completion requests.
83+
*/
84+
@NestedConfigurationProperty
85+
@JsonIgnore
86+
private List<FunctionCallback> functionCallbacks = new ArrayList<>();
87+
88+
/**
89+
* List of functions, identified by their names, to configure for function calling in
90+
* the chat completion requests. Functions with those names must exist in the
91+
* functionCallbacks registry. The {@link #functionCallbacks} from the PromptOptions
92+
* are automatically enabled for the duration of the prompt execution.
93+
*
94+
* Note that function enabled with the default options are enabled for all chat
95+
* completion requests. This could impact the token count and the billing. If the
96+
* functions is set in a prompt options, then the enabled functions are only active
97+
* for the duration of this prompt execution.
98+
*/
99+
@NestedConfigurationProperty
100+
@JsonIgnore
101+
private Set<String> functions = new HashSet<>();
69102
// @formatter:on
70103

71104
public static Builder builder() {
@@ -101,6 +134,23 @@ public Builder withStopSequences(List<String> stopSequences) {
101134
return this;
102135
}
103136

137+
public Builder withFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
138+
this.options.functionCallbacks = functionCallbacks;
139+
return this;
140+
}
141+
142+
public Builder withFunctions(Set<String> functionNames) {
143+
Assert.notNull(functionNames, "Function names must not be null");
144+
this.options.functions = functionNames;
145+
return this;
146+
}
147+
148+
public Builder withFunction(String functionName) {
149+
Assert.hasText(functionName, "Function name must not be empty");
150+
this.options.functions.add(functionName);
151+
return this;
152+
}
153+
104154
public Anthropic3ChatOptions build() {
105155
return this.options;
106156
}
@@ -150,12 +200,36 @@ public void setStopSequences(List<String> stopSequences) {
150200
this.stopSequences = stopSequences;
151201
}
152202

203+
@Override
204+
public List<FunctionCallback> getFunctionCallbacks() {
205+
return this.functionCallbacks;
206+
}
207+
208+
@Override
209+
public void setFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
210+
Assert.notNull(functionCallbacks, "FunctionCallbacks must not be null");
211+
this.functionCallbacks = functionCallbacks;
212+
}
213+
214+
@Override
215+
public Set<String> getFunctions() {
216+
return this.functions;
217+
}
218+
219+
@Override
220+
public void setFunctions(Set<String> functions) {
221+
Assert.notNull(functions, "Function must not be null");
222+
this.functions = functions;
223+
}
224+
153225
public static Anthropic3ChatOptions fromOptions(Anthropic3ChatOptions fromOptions) {
154226
return builder().withTemperature(fromOptions.getTemperature())
155227
.withMaxTokens(fromOptions.getMaxTokens())
156228
.withTopK(fromOptions.getTopK())
157229
.withTopP(fromOptions.getTopP())
158230
.withStopSequences(fromOptions.getStopSequences())
231+
.withFunctionCallbacks(fromOptions.getFunctionCallbacks())
232+
.withFunctions(fromOptions.getFunctions())
159233
.build();
160234
}
161235

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,22 @@
1616
package org.springframework.ai.bedrock.anthropic3;
1717

1818
import reactor.core.publisher.Flux;
19-
import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse;
2019
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput;
2120

21+
import java.util.List;
22+
2223
import org.springframework.ai.bedrock.api.BedrockConverseApi;
24+
import org.springframework.ai.bedrock.api.BedrockConverseApi.BedrockConverseRequest;
2325
import org.springframework.ai.bedrock.api.BedrockConverseApiUtils;
2426
import org.springframework.ai.chat.model.ChatModel;
2527
import org.springframework.ai.chat.model.ChatResponse;
28+
import org.springframework.ai.chat.model.Generation;
2629
import org.springframework.ai.chat.model.StreamingChatModel;
2730
import org.springframework.ai.chat.prompt.ChatOptions;
2831
import org.springframework.ai.chat.prompt.Prompt;
2932
import org.springframework.ai.model.ModelDescription;
33+
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
34+
import org.springframework.ai.model.function.FunctionCallbackContext;
3035
import org.springframework.util.Assert;
3136

3237
/**
@@ -38,7 +43,9 @@
3843
* @author Wei Jiang
3944
* @since 1.0.0
4045
*/
41-
public class BedrockAnthropic3ChatModel implements ChatModel, StreamingChatModel {
46+
public class BedrockAnthropic3ChatModel
47+
extends AbstractFunctionCallSupport<Generation, BedrockConverseRequest, ChatResponse>
48+
implements ChatModel, StreamingChatModel {
4249

4350
private final String modelId;
4451

@@ -56,6 +63,13 @@ public BedrockAnthropic3ChatModel(BedrockConverseApi converseApi, Anthropic3Chat
5663
}
5764

5865
public BedrockAnthropic3ChatModel(String modelId, BedrockConverseApi converseApi, Anthropic3ChatOptions options) {
66+
this(modelId, converseApi, options, null);
67+
}
68+
69+
public BedrockAnthropic3ChatModel(String modelId, BedrockConverseApi converseApi, Anthropic3ChatOptions options,
70+
FunctionCallbackContext functionCallbackContext) {
71+
super(functionCallbackContext);
72+
5973
Assert.notNull(modelId, "modelId must not be null.");
6074
Assert.notNull(converseApi, "BedrockConverseApi must not be null.");
6175
Assert.notNull(options, "Anthropic3ChatOptions must not be null.");
@@ -69,17 +83,16 @@ public BedrockAnthropic3ChatModel(String modelId, BedrockConverseApi converseApi
6983
public ChatResponse call(Prompt prompt) {
7084
Assert.notNull(prompt, "Prompt must not be null.");
7185

72-
var request = BedrockConverseApiUtils.createConverseRequest(modelId, prompt, defaultOptions);
73-
74-
ConverseResponse response = this.converseApi.converse(request);
86+
var request = BedrockConverseApiUtils.createBedrockConverseRequest(modelId, prompt, defaultOptions);
7587

76-
return BedrockConverseApiUtils.convertConverseResponse(response);
88+
return this.callWithFunctionSupport(request);
7789
}
7890

7991
@Override
8092
public Flux<ChatResponse> stream(Prompt prompt) {
8193
Assert.notNull(prompt, "Prompt must not be null.");
8294

95+
// TODO
8396
var request = BedrockConverseApiUtils.createConverseStreamRequest(modelId, prompt, defaultOptions);
8497

8598
Flux<ConverseStreamOutput> fluxResponse = this.converseApi.converseStream(request);
@@ -92,6 +105,43 @@ public ChatOptions getDefaultOptions() {
92105
return Anthropic3ChatOptions.fromOptions(this.defaultOptions);
93106
}
94107

108+
@Override
109+
protected BedrockConverseRequest doCreateToolResponseRequest(BedrockConverseRequest previousRequest,
110+
Generation responseMessage, List<Generation> conversationHistory) {
111+
// TODO Auto-generated method stub
112+
return null;
113+
}
114+
115+
@Override
116+
protected List<Generation> doGetUserMessages(BedrockConverseRequest request) {
117+
// TODO Auto-generated method stub
118+
return null;
119+
}
120+
121+
@Override
122+
protected Generation doGetToolResponseMessage(ChatResponse response) {
123+
// TODO Auto-generated method stub
124+
return null;
125+
}
126+
127+
@Override
128+
protected ChatResponse doChatCompletion(BedrockConverseRequest request) {
129+
// TODO Auto-generated method stub
130+
return null;
131+
}
132+
133+
@Override
134+
protected Flux<ChatResponse> doChatCompletionStream(BedrockConverseRequest request) {
135+
// TODO Auto-generated method stub
136+
return null;
137+
}
138+
139+
@Override
140+
protected boolean isToolFunctionCall(ChatResponse response) {
141+
// TODO Auto-generated method stub
142+
return false;
143+
}
144+
95145
/**
96146
* Anthropic3 models version.
97147
*/

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/BedrockConverseApi.java

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717
package org.springframework.ai.bedrock.api;
1818

1919
import java.time.Duration;
20+
import java.util.List;
2021

2122
import org.slf4j.Logger;
2223
import org.slf4j.LoggerFactory;
24+
import org.springframework.ai.chat.model.ChatResponse;
2325
import org.springframework.ai.retry.RetryUtils;
2426
import org.springframework.retry.support.RetryTemplate;
2527
import org.springframework.util.Assert;
@@ -30,6 +32,7 @@
3032
import reactor.core.publisher.Sinks.EmitResult;
3133
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
3234
import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider;
35+
import software.amazon.awssdk.core.document.Document;
3336
import software.amazon.awssdk.regions.Region;
3437
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
3538
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
@@ -38,6 +41,8 @@
3841
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput;
3942
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamRequest;
4043
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler;
44+
import software.amazon.awssdk.services.bedrockruntime.model.Message;
45+
import software.amazon.awssdk.services.bedrockruntime.model.SystemContentBlock;
4146

4247
/**
4348
* Amazon Bedrock Converse API, It provides the basic functionality to invoke the Bedrock
@@ -177,6 +182,41 @@ public Region getRegion() {
177182
return this.region;
178183
}
179184

185+
/**
186+
* BedrockConverseRequest encapsulates the request parameters for the Amazon Bedrock
187+
* Converse Api.
188+
*
189+
* @param modelId The Amazon Bedrock Model Id.
190+
* @param messages The messages that you want to send to the model.
191+
* @param systemMessages A system prompt to pass to the model.
192+
* @param additionalModelRequestFields Additional inference parameters that the model
193+
* supports, beyond the base set of inference parameters that Converse supports in the
194+
* inferenceConfig field.
195+
*/
196+
public record BedrockConverseRequest(String modelId, List<Message> messages,
197+
List<SystemContentBlock> systemMessages, Document additionalModelRequestFields) {
198+
199+
}
200+
201+
/**
202+
* Invoke the model and return the response.
203+
*
204+
* https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
205+
* https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
206+
* https://sdk.amazonaws.com/java/api/latest/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient.html#converse
207+
* @param bedrockConverseRequest Model invocation request.
208+
* @return The model invocation response.
209+
*/
210+
public ChatResponse converse(BedrockConverseRequest bedrockConverseRequest) {
211+
Assert.notNull(bedrockConverseRequest, "'bedrockConverseRequest' must not be null");
212+
213+
ConverseRequest converseRequest = BedrockConverseApiUtils.createConverseRequest(bedrockConverseRequest);
214+
215+
ConverseResponse converseResponse = converse(converseRequest);
216+
217+
return BedrockConverseApiUtils.convertConverseResponse(converseResponse);
218+
}
219+
180220
/**
181221
* Invoke the model and return the response.
182222
*
@@ -194,6 +234,26 @@ public ConverseResponse converse(ConverseRequest converseRequest) {
194234
});
195235
}
196236

237+
/**
238+
* Invoke the model and return the response stream.
239+
*
240+
* https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
241+
* https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
242+
* https://sdk.amazonaws.com/java/api/latest/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient.html#converseStream
243+
* @param bedrockConverseRequest Model invocation request.
244+
* @return The model invocation response stream.
245+
*/
246+
public Flux<ChatResponse> converseStream(BedrockConverseRequest bedrockConverseRequest) {
247+
Assert.notNull(bedrockConverseRequest, "'bedrockConverseRequest' must not be null");
248+
249+
ConverseStreamRequest converseStreamRequest = BedrockConverseApiUtils
250+
.createConverseStreamRequest(bedrockConverseRequest);
251+
252+
return converseStream(converseStreamRequest)
253+
.map(output -> BedrockConverseApiUtils.convertConverseStreamOutput(output));
254+
255+
}
256+
197257
/**
198258
* Invoke the model and return the response stream.
199259
*

0 commit comments

Comments
 (0)