Skip to content

Commit fde14fe

Browse files
author
wmz7year
committed
Amazon Bedrock Chat adds tool support.
1 parent 49b3326 commit fde14fe

File tree

7 files changed

+671
-63
lines changed

7 files changed

+671
-63
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/*
2+
* Copyright 2023 - 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.bedrock;
17+
18+
import java.util.List;
19+
20+
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
21+
22+
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock;
23+
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock.Type;
24+
import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse;
25+
26+
/**
27+
* Amazon Bedrock Chat model converse interface generation metadata, encapsulating
28+
* information on the completion.
29+
*
30+
* @author Wei Jiang
31+
* @since 1.0.0
32+
*/
33+
public class BedrockConverseChatGenerationMetadata implements ChatGenerationMetadata {
34+
35+
private final String stopReason;
36+
37+
private final List<ContentBlock> contentBlocks;
38+
39+
public BedrockConverseChatGenerationMetadata(String stopReason, List<ContentBlock> contentBlocks) {
40+
super();
41+
42+
this.stopReason = stopReason;
43+
this.contentBlocks = contentBlocks;
44+
}
45+
46+
public static BedrockConverseChatGenerationMetadata from(ConverseResponse response, List<ContentBlock> contents) {
47+
return new BedrockConverseChatGenerationMetadata(response.stopReasonAsString(), contents);
48+
}
49+
50+
@Override
51+
public <T> T getContentFilterMetadata() {
52+
return null;
53+
}
54+
55+
@Override
56+
public String getFinishReason() {
57+
return stopReason;
58+
}
59+
60+
public List<ContentBlock> getContentBlocks() {
61+
return contentBlocks;
62+
}
63+
64+
public boolean isToolUse() {
65+
return contentBlocks.stream().anyMatch(content -> content.type() == Type.TOOL_USE);
66+
67+
}
68+
69+
}

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/Anthropic3ChatOptions.java

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,20 @@
1515
*/
1616
package org.springframework.ai.bedrock.anthropic3;
1717

18+
import com.fasterxml.jackson.annotation.JsonIgnore;
1819
import com.fasterxml.jackson.annotation.JsonInclude;
1920
import com.fasterxml.jackson.annotation.JsonInclude.Include;
2021
import com.fasterxml.jackson.annotation.JsonProperty;
2122
import org.springframework.ai.chat.prompt.ChatOptions;
23+
import org.springframework.ai.model.function.FunctionCallback;
24+
import org.springframework.ai.model.function.FunctionCallingOptions;
25+
import org.springframework.boot.context.properties.NestedConfigurationProperty;
26+
import org.springframework.util.Assert;
2227

28+
import java.util.ArrayList;
29+
import java.util.HashSet;
2330
import java.util.List;
31+
import java.util.Set;
2432

2533
/**
2634
* Java {@link ChatOptions} for the Bedrock Anthropic chat generative model chat options.
@@ -31,7 +39,7 @@
3139
* @since 1.0.0
3240
*/
3341
@JsonInclude(Include.NON_NULL)
34-
public class Anthropic3ChatOptions implements ChatOptions {
42+
public class Anthropic3ChatOptions implements ChatOptions, FunctionCallingOptions {
3543

3644
// @formatter:off
3745
/**
@@ -66,6 +74,31 @@ public class Anthropic3ChatOptions implements ChatOptions {
6674
*/
6775
private @JsonProperty("stop_sequences") List<String> stopSequences;
6876

77+
/**
78+
* Tool Function Callbacks to register with the ChatModel. For Prompt
79+
* Options the functionCallbacks are automatically enabled for the duration of the
80+
* prompt execution. For Default Options the functionCallbacks are registered but
81+
* disabled by default. Use the enableFunctions to set the functions from the registry
82+
* to be used by the ChatModel chat completion requests.
83+
*/
84+
@NestedConfigurationProperty
85+
@JsonIgnore
86+
private List<FunctionCallback> functionCallbacks = new ArrayList<>();
87+
88+
/**
89+
* List of functions, identified by their names, to configure for function calling in
90+
* the chat completion requests. Functions with those names must exist in the
91+
* functionCallbacks registry. The {@link #functionCallbacks} from the PromptOptions
92+
* are automatically enabled for the duration of the prompt execution.
93+
*
94+
* Note that function enabled with the default options are enabled for all chat
95+
* completion requests. This could impact the token count and the billing. If the
96+
* functions is set in a prompt options, then the enabled functions are only active
97+
* for the duration of this prompt execution.
98+
*/
99+
@NestedConfigurationProperty
100+
@JsonIgnore
101+
private Set<String> functions = new HashSet<>();
69102
// @formatter:on
70103

71104
public static Builder builder() {
@@ -101,6 +134,23 @@ public Builder withStopSequences(List<String> stopSequences) {
101134
return this;
102135
}
103136

137+
public Builder withFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
138+
this.options.functionCallbacks = functionCallbacks;
139+
return this;
140+
}
141+
142+
public Builder withFunctions(Set<String> functionNames) {
143+
Assert.notNull(functionNames, "Function names must not be null");
144+
this.options.functions = functionNames;
145+
return this;
146+
}
147+
148+
public Builder withFunction(String functionName) {
149+
Assert.hasText(functionName, "Function name must not be empty");
150+
this.options.functions.add(functionName);
151+
return this;
152+
}
153+
104154
public Anthropic3ChatOptions build() {
105155
return this.options;
106156
}
@@ -150,12 +200,36 @@ public void setStopSequences(List<String> stopSequences) {
150200
this.stopSequences = stopSequences;
151201
}
152202

203+
@Override
204+
public List<FunctionCallback> getFunctionCallbacks() {
205+
return this.functionCallbacks;
206+
}
207+
208+
@Override
209+
public void setFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
210+
Assert.notNull(functionCallbacks, "FunctionCallbacks must not be null");
211+
this.functionCallbacks = functionCallbacks;
212+
}
213+
214+
@Override
215+
public Set<String> getFunctions() {
216+
return this.functions;
217+
}
218+
219+
@Override
220+
public void setFunctions(Set<String> functions) {
221+
Assert.notNull(functions, "Function must not be null");
222+
this.functions = functions;
223+
}
224+
153225
public static Anthropic3ChatOptions fromOptions(Anthropic3ChatOptions fromOptions) {
154226
return builder().withTemperature(fromOptions.getTemperature())
155227
.withMaxTokens(fromOptions.getMaxTokens())
156228
.withTopK(fromOptions.getTopK())
157229
.withTopP(fromOptions.getTopP())
158230
.withStopSequences(fromOptions.getStopSequences())
231+
.withFunctionCallbacks(fromOptions.getFunctionCallbacks())
232+
.withFunctions(fromOptions.getFunctions())
159233
.build();
160234
}
161235

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java

Lines changed: 126 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,33 @@
1616
package org.springframework.ai.bedrock.anthropic3;
1717

1818
import reactor.core.publisher.Flux;
19-
import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse;
2019
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput;
20+
import software.amazon.awssdk.services.bedrockruntime.model.Tool;
21+
import software.amazon.awssdk.services.bedrockruntime.model.ToolConfiguration;
22+
import software.amazon.awssdk.services.bedrockruntime.model.ToolInputSchema;
23+
import software.amazon.awssdk.services.bedrockruntime.model.ToolSpecification;
2124

25+
import java.util.HashSet;
26+
import java.util.List;
27+
import java.util.Set;
28+
29+
import org.springframework.ai.bedrock.BedrockConverseChatGenerationMetadata;
2230
import org.springframework.ai.bedrock.api.BedrockConverseApi;
31+
import org.springframework.ai.bedrock.api.BedrockConverseApi.BedrockConverseRequest;
2332
import org.springframework.ai.bedrock.api.BedrockConverseApiUtils;
33+
import org.springframework.ai.chat.messages.Message;
2434
import org.springframework.ai.chat.model.ChatModel;
2535
import org.springframework.ai.chat.model.ChatResponse;
36+
import org.springframework.ai.chat.model.Generation;
2637
import org.springframework.ai.chat.model.StreamingChatModel;
2738
import org.springframework.ai.chat.prompt.ChatOptions;
2839
import org.springframework.ai.chat.prompt.Prompt;
2940
import org.springframework.ai.model.ModelDescription;
41+
import org.springframework.ai.model.ModelOptionsUtils;
42+
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
43+
import org.springframework.ai.model.function.FunctionCallbackContext;
3044
import org.springframework.util.Assert;
45+
import org.springframework.util.CollectionUtils;
3146

3247
/**
3348
* Java {@link ChatModel} and {@link StreamingChatModel} for the Bedrock Anthropic3 chat
@@ -38,7 +53,9 @@
3853
* @author Wei Jiang
3954
* @since 1.0.0
4055
*/
41-
public class BedrockAnthropic3ChatModel implements ChatModel, StreamingChatModel {
56+
public class BedrockAnthropic3ChatModel
57+
extends AbstractFunctionCallSupport<Message, BedrockConverseRequest, ChatResponse>
58+
implements ChatModel, StreamingChatModel {
4259

4360
private final String modelId;
4461

@@ -56,6 +73,13 @@ public BedrockAnthropic3ChatModel(BedrockConverseApi converseApi, Anthropic3Chat
5673
}
5774

5875
public BedrockAnthropic3ChatModel(String modelId, BedrockConverseApi converseApi, Anthropic3ChatOptions options) {
76+
this(modelId, converseApi, options, null);
77+
}
78+
79+
public BedrockAnthropic3ChatModel(String modelId, BedrockConverseApi converseApi, Anthropic3ChatOptions options,
80+
FunctionCallbackContext functionCallbackContext) {
81+
super(functionCallbackContext);
82+
5983
Assert.notNull(modelId, "modelId must not be null.");
6084
Assert.notNull(converseApi, "BedrockConverseApi must not be null.");
6185
Assert.notNull(options, "Anthropic3ChatOptions must not be null.");
@@ -69,29 +93,125 @@ public BedrockAnthropic3ChatModel(String modelId, BedrockConverseApi converseApi
6993
public ChatResponse call(Prompt prompt) {
7094
Assert.notNull(prompt, "Prompt must not be null.");
7195

72-
var request = BedrockConverseApiUtils.createConverseRequest(modelId, prompt, defaultOptions);
73-
74-
ConverseResponse response = this.converseApi.converse(request);
96+
var request = createBedrockConverseRequest(prompt);
7597

76-
return BedrockConverseApiUtils.convertConverseResponse(response);
98+
return this.callWithFunctionSupport(request);
7799
}
78100

79101
@Override
80102
public Flux<ChatResponse> stream(Prompt prompt) {
81103
Assert.notNull(prompt, "Prompt must not be null.");
82104

105+
// TODO
83106
var request = BedrockConverseApiUtils.createConverseStreamRequest(modelId, prompt, defaultOptions);
84107

85108
Flux<ConverseStreamOutput> fluxResponse = this.converseApi.converseStream(request);
86109

87110
return fluxResponse.map(output -> BedrockConverseApiUtils.convertConverseStreamOutput(output));
88111
}
89112

113+
private BedrockConverseRequest createBedrockConverseRequest(Prompt prompt) {
114+
var request = BedrockConverseApiUtils.createBedrockConverseRequest(modelId, prompt, defaultOptions);
115+
116+
ToolConfiguration toolConfiguration = createToolConfiguration(prompt);
117+
request.setToolConfiguration(toolConfiguration);
118+
119+
return request;
120+
}
121+
122+
private ToolConfiguration createToolConfiguration(Prompt prompt) {
123+
Set<String> functionsForThisRequest = new HashSet<>();
124+
125+
if (this.defaultOptions != null) {
126+
Set<String> promptEnabledFunctions = this.handleFunctionCallbackConfigurations(this.defaultOptions,
127+
!IS_RUNTIME_CALL);
128+
functionsForThisRequest.addAll(promptEnabledFunctions);
129+
}
130+
131+
if (prompt.getOptions() != null) {
132+
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
133+
Anthropic3ChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
134+
ChatOptions.class, Anthropic3ChatOptions.class);
135+
136+
Set<String> defaultEnabledFunctions = this.handleFunctionCallbackConfigurations(updatedRuntimeOptions,
137+
IS_RUNTIME_CALL);
138+
functionsForThisRequest.addAll(defaultEnabledFunctions);
139+
}
140+
else {
141+
throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
142+
+ prompt.getOptions().getClass().getSimpleName());
143+
}
144+
}
145+
146+
if (CollectionUtils.isEmpty(functionsForThisRequest)) {
147+
return null;
148+
}
149+
else {
150+
return ToolConfiguration.builder().tools(getFunctionTools(functionsForThisRequest)).build();
151+
}
152+
}
153+
154+
private List<Tool> getFunctionTools(Set<String> functionNames) {
155+
return this.resolveFunctionCallbacks(functionNames).stream().map(functionCallback -> {
156+
var description = functionCallback.getDescription();
157+
var name = functionCallback.getName();
158+
String inputSchema = functionCallback.getInputTypeSchema();
159+
160+
return Tool.builder()
161+
.toolSpec(ToolSpecification.builder()
162+
.name(name)
163+
.description(description)
164+
.inputSchema(ToolInputSchema.builder()
165+
.json(BedrockConverseApiUtils.convertObjectToDocument(ModelOptionsUtils.jsonToMap(inputSchema)))
166+
.build())
167+
.build())
168+
.build();
169+
}).toList();
170+
}
171+
90172
@Override
91173
public ChatOptions getDefaultOptions() {
92174
return Anthropic3ChatOptions.fromOptions(this.defaultOptions);
93175
}
94176

177+
@Override
178+
protected BedrockConverseRequest doCreateToolResponseRequest(BedrockConverseRequest previousRequest,
179+
Message responseMessage, List<Message> conversationHistory) {
180+
// TODO
181+
return null;
182+
}
183+
184+
@Override
185+
protected List<Message> doGetUserMessages(BedrockConverseRequest request) {
186+
return BedrockConverseApiUtils.getMessagesInstructions(request.getMessages());
187+
}
188+
189+
@Override
190+
protected Message doGetToolResponseMessage(ChatResponse response) {
191+
return response.getResult().getOutput();
192+
}
193+
194+
@Override
195+
protected ChatResponse doChatCompletion(BedrockConverseRequest request) {
196+
return converseApi.converse(request);
197+
}
198+
199+
@Override
200+
protected Flux<ChatResponse> doChatCompletionStream(BedrockConverseRequest request) {
201+
return converseApi.converseStream(request);
202+
}
203+
204+
@Override
205+
protected boolean isToolFunctionCall(ChatResponse response) {
206+
Generation result = response.getResult();
207+
208+
if (result.getMetadata() instanceof BedrockConverseChatGenerationMetadata metadata) {
209+
return metadata.isToolUse();
210+
}
211+
212+
return false;
213+
}
214+
95215
/**
96216
* Anthropic3 models version.
97217
*/

0 commit comments

Comments
 (0)