Skip to content

Commit bc273da

Browse files
author
wmz7year
committed
Amazon Bedrock Chat adds tool support.
1 parent 49b3326 commit bc273da

File tree

7 files changed

+680
-63
lines changed

7 files changed

+680
-63
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/*
2+
* Copyright 2023 - 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.bedrock;
17+
18+
import java.util.List;
19+
20+
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
21+
22+
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock.Type;
23+
import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse;
24+
import software.amazon.awssdk.services.bedrockruntime.model.Message;
25+
import software.amazon.awssdk.services.bedrockruntime.model.ToolUseBlock;
26+
27+
/**
28+
* Amazon Bedrock Chat model converse interface generation metadata, encapsulating
29+
* information on the completion.
30+
*
31+
* @author Wei Jiang
32+
* @since 1.0.0
33+
*/
34+
public class BedrockConverseChatGenerationMetadata implements ChatGenerationMetadata {
35+
36+
private final String stopReason;
37+
38+
private final Message message;
39+
40+
public BedrockConverseChatGenerationMetadata(String stopReason, Message message) {
41+
super();
42+
43+
this.stopReason = stopReason;
44+
this.message = message;
45+
}
46+
47+
public static BedrockConverseChatGenerationMetadata from(ConverseResponse response, Message message) {
48+
return new BedrockConverseChatGenerationMetadata(response.stopReasonAsString(), message);
49+
}
50+
51+
@Override
52+
public <T> T getContentFilterMetadata() {
53+
return null;
54+
}
55+
56+
@Override
57+
public String getFinishReason() {
58+
return stopReason;
59+
}
60+
61+
public Message getMessage() {
62+
return message;
63+
}
64+
65+
public boolean isToolUse() {
66+
return message.content().stream().anyMatch(content -> content.type() == Type.TOOL_USE);
67+
}
68+
69+
public List<ToolUseBlock> getToolToUseList() {
70+
return message.content()
71+
.stream()
72+
.filter(content -> content.type() == Type.TOOL_USE)
73+
.map(content -> content.toolUse())
74+
.toList();
75+
}
76+
77+
}

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/Anthropic3ChatOptions.java

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,20 @@
1515
*/
1616
package org.springframework.ai.bedrock.anthropic3;
1717

18+
import com.fasterxml.jackson.annotation.JsonIgnore;
1819
import com.fasterxml.jackson.annotation.JsonInclude;
1920
import com.fasterxml.jackson.annotation.JsonInclude.Include;
2021
import com.fasterxml.jackson.annotation.JsonProperty;
2122
import org.springframework.ai.chat.prompt.ChatOptions;
23+
import org.springframework.ai.model.function.FunctionCallback;
24+
import org.springframework.ai.model.function.FunctionCallingOptions;
25+
import org.springframework.boot.context.properties.NestedConfigurationProperty;
26+
import org.springframework.util.Assert;
2227

28+
import java.util.ArrayList;
29+
import java.util.HashSet;
2330
import java.util.List;
31+
import java.util.Set;
2432

2533
/**
2634
* Java {@link ChatOptions} for the Bedrock Anthropic chat generative model chat options.
@@ -31,7 +39,7 @@
3139
* @since 1.0.0
3240
*/
3341
@JsonInclude(Include.NON_NULL)
34-
public class Anthropic3ChatOptions implements ChatOptions {
42+
public class Anthropic3ChatOptions implements ChatOptions, FunctionCallingOptions {
3543

3644
// @formatter:off
3745
/**
@@ -66,6 +74,31 @@ public class Anthropic3ChatOptions implements ChatOptions {
6674
*/
6775
private @JsonProperty("stop_sequences") List<String> stopSequences;
6876

77+
/**
78+
* Tool Function Callbacks to register with the ChatModel. For Prompt
79+
* Options the functionCallbacks are automatically enabled for the duration of the
80+
* prompt execution. For Default Options the functionCallbacks are registered but
81+
* disabled by default. Use the enableFunctions to set the functions from the registry
82+
* to be used by the ChatModel chat completion requests.
83+
*/
84+
@NestedConfigurationProperty
85+
@JsonIgnore
86+
private List<FunctionCallback> functionCallbacks = new ArrayList<>();
87+
88+
/**
89+
* List of functions, identified by their names, to configure for function calling in
90+
* the chat completion requests. Functions with those names must exist in the
91+
* functionCallbacks registry. The {@link #functionCallbacks} from the PromptOptions
92+
* are automatically enabled for the duration of the prompt execution.
93+
*
94+
* Note that function enabled with the default options are enabled for all chat
95+
* completion requests. This could impact the token count and the billing. If the
96+
* functions is set in a prompt options, then the enabled functions are only active
97+
* for the duration of this prompt execution.
98+
*/
99+
@NestedConfigurationProperty
100+
@JsonIgnore
101+
private Set<String> functions = new HashSet<>();
69102
// @formatter:on
70103

71104
public static Builder builder() {
@@ -101,6 +134,23 @@ public Builder withStopSequences(List<String> stopSequences) {
101134
return this;
102135
}
103136

137+
public Builder withFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
138+
this.options.functionCallbacks = functionCallbacks;
139+
return this;
140+
}
141+
142+
public Builder withFunctions(Set<String> functionNames) {
143+
Assert.notNull(functionNames, "Function names must not be null");
144+
this.options.functions = functionNames;
145+
return this;
146+
}
147+
148+
public Builder withFunction(String functionName) {
149+
Assert.hasText(functionName, "Function name must not be empty");
150+
this.options.functions.add(functionName);
151+
return this;
152+
}
153+
104154
public Anthropic3ChatOptions build() {
105155
return this.options;
106156
}
@@ -150,12 +200,36 @@ public void setStopSequences(List<String> stopSequences) {
150200
this.stopSequences = stopSequences;
151201
}
152202

203+
@Override
204+
public List<FunctionCallback> getFunctionCallbacks() {
205+
return this.functionCallbacks;
206+
}
207+
208+
@Override
209+
public void setFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
210+
Assert.notNull(functionCallbacks, "FunctionCallbacks must not be null");
211+
this.functionCallbacks = functionCallbacks;
212+
}
213+
214+
@Override
215+
public Set<String> getFunctions() {
216+
return this.functions;
217+
}
218+
219+
@Override
220+
public void setFunctions(Set<String> functions) {
221+
Assert.notNull(functions, "Function must not be null");
222+
this.functions = functions;
223+
}
224+
153225
public static Anthropic3ChatOptions fromOptions(Anthropic3ChatOptions fromOptions) {
154226
return builder().withTemperature(fromOptions.getTemperature())
155227
.withMaxTokens(fromOptions.getMaxTokens())
156228
.withTopK(fromOptions.getTopK())
157229
.withTopP(fromOptions.getTopP())
158230
.withStopSequences(fromOptions.getStopSequences())
231+
.withFunctionCallbacks(fromOptions.getFunctionCallbacks())
232+
.withFunctions(fromOptions.getFunctions())
159233
.build();
160234
}
161235

0 commit comments

Comments
 (0)