16
16
package org .springframework .ai .bedrock .anthropic3 ;
17
17
18
18
import reactor .core .publisher .Flux ;
19
- import software .amazon .awssdk .services .bedrockruntime .model .ConverseResponse ;
20
19
import software .amazon .awssdk .services .bedrockruntime .model .ConverseStreamOutput ;
20
+ import software .amazon .awssdk .services .bedrockruntime .model .Tool ;
21
+ import software .amazon .awssdk .services .bedrockruntime .model .ToolConfiguration ;
22
+ import software .amazon .awssdk .services .bedrockruntime .model .ToolInputSchema ;
23
+ import software .amazon .awssdk .services .bedrockruntime .model .ToolSpecification ;
21
24
25
+ import java .util .HashSet ;
26
+ import java .util .List ;
27
+ import java .util .Set ;
28
+
29
+ import org .springframework .ai .bedrock .BedrockConverseChatGenerationMetadata ;
22
30
import org .springframework .ai .bedrock .api .BedrockConverseApi ;
31
+ import org .springframework .ai .bedrock .api .BedrockConverseApi .BedrockConverseRequest ;
23
32
import org .springframework .ai .bedrock .api .BedrockConverseApiUtils ;
33
+ import org .springframework .ai .chat .messages .Message ;
24
34
import org .springframework .ai .chat .model .ChatModel ;
25
35
import org .springframework .ai .chat .model .ChatResponse ;
36
+ import org .springframework .ai .chat .model .Generation ;
26
37
import org .springframework .ai .chat .model .StreamingChatModel ;
27
38
import org .springframework .ai .chat .prompt .ChatOptions ;
28
39
import org .springframework .ai .chat .prompt .Prompt ;
29
40
import org .springframework .ai .model .ModelDescription ;
41
+ import org .springframework .ai .model .ModelOptionsUtils ;
42
+ import org .springframework .ai .model .function .AbstractFunctionCallSupport ;
43
+ import org .springframework .ai .model .function .FunctionCallbackContext ;
30
44
import org .springframework .util .Assert ;
45
+ import org .springframework .util .CollectionUtils ;
31
46
32
47
/**
33
48
* Java {@link ChatModel} and {@link StreamingChatModel} for the Bedrock Anthropic3 chat
38
53
* @author Wei Jiang
39
54
* @since 1.0.0
40
55
*/
41
- public class BedrockAnthropic3ChatModel implements ChatModel , StreamingChatModel {
56
+ public class BedrockAnthropic3ChatModel
57
+ extends AbstractFunctionCallSupport <Message , BedrockConverseRequest , ChatResponse >
58
+ implements ChatModel , StreamingChatModel {
42
59
43
60
private final String modelId ;
44
61
@@ -56,6 +73,13 @@ public BedrockAnthropic3ChatModel(BedrockConverseApi converseApi, Anthropic3Chat
56
73
}
57
74
58
75
public BedrockAnthropic3ChatModel (String modelId , BedrockConverseApi converseApi , Anthropic3ChatOptions options ) {
76
+ this (modelId , converseApi , options , null );
77
+ }
78
+
79
+ public BedrockAnthropic3ChatModel (String modelId , BedrockConverseApi converseApi , Anthropic3ChatOptions options ,
80
+ FunctionCallbackContext functionCallbackContext ) {
81
+ super (functionCallbackContext );
82
+
59
83
Assert .notNull (modelId , "modelId must not be null." );
60
84
Assert .notNull (converseApi , "BedrockConverseApi must not be null." );
61
85
Assert .notNull (options , "Anthropic3ChatOptions must not be null." );
@@ -69,29 +93,125 @@ public BedrockAnthropic3ChatModel(String modelId, BedrockConverseApi converseApi
69
93
public ChatResponse call (Prompt prompt ) {
70
94
Assert .notNull (prompt , "Prompt must not be null." );
71
95
72
- var request = BedrockConverseApiUtils .createConverseRequest (modelId , prompt , defaultOptions );
73
-
74
- ConverseResponse response = this .converseApi .converse (request );
96
+ var request = createBedrockConverseRequest (prompt );
75
97
76
- return BedrockConverseApiUtils . convertConverseResponse ( response );
98
+ return this . callWithFunctionSupport ( request );
77
99
}
78
100
79
101
@ Override
80
102
public Flux <ChatResponse > stream (Prompt prompt ) {
81
103
Assert .notNull (prompt , "Prompt must not be null." );
82
104
105
+ // TODO
83
106
var request = BedrockConverseApiUtils .createConverseStreamRequest (modelId , prompt , defaultOptions );
84
107
85
108
Flux <ConverseStreamOutput > fluxResponse = this .converseApi .converseStream (request );
86
109
87
110
return fluxResponse .map (output -> BedrockConverseApiUtils .convertConverseStreamOutput (output ));
88
111
}
89
112
113
+ private BedrockConverseRequest createBedrockConverseRequest (Prompt prompt ) {
114
+ var request = BedrockConverseApiUtils .createBedrockConverseRequest (modelId , prompt , defaultOptions );
115
+
116
+ ToolConfiguration toolConfiguration = createToolConfiguration (prompt );
117
+ request .setToolConfiguration (toolConfiguration );
118
+
119
+ return request ;
120
+ }
121
+
122
+ private ToolConfiguration createToolConfiguration (Prompt prompt ) {
123
+ Set <String > functionsForThisRequest = new HashSet <>();
124
+
125
+ if (this .defaultOptions != null ) {
126
+ Set <String > promptEnabledFunctions = this .handleFunctionCallbackConfigurations (this .defaultOptions ,
127
+ !IS_RUNTIME_CALL );
128
+ functionsForThisRequest .addAll (promptEnabledFunctions );
129
+ }
130
+
131
+ if (prompt .getOptions () != null ) {
132
+ if (prompt .getOptions () instanceof ChatOptions runtimeOptions ) {
133
+ Anthropic3ChatOptions updatedRuntimeOptions = ModelOptionsUtils .copyToTarget (runtimeOptions ,
134
+ ChatOptions .class , Anthropic3ChatOptions .class );
135
+
136
+ Set <String > defaultEnabledFunctions = this .handleFunctionCallbackConfigurations (updatedRuntimeOptions ,
137
+ IS_RUNTIME_CALL );
138
+ functionsForThisRequest .addAll (defaultEnabledFunctions );
139
+ }
140
+ else {
141
+ throw new IllegalArgumentException ("Prompt options are not of type ChatOptions: "
142
+ + prompt .getOptions ().getClass ().getSimpleName ());
143
+ }
144
+ }
145
+
146
+ if (CollectionUtils .isEmpty (functionsForThisRequest )) {
147
+ return null ;
148
+ }
149
+ else {
150
+ return ToolConfiguration .builder ().tools (getFunctionTools (functionsForThisRequest )).build ();
151
+ }
152
+ }
153
+
154
+ private List <Tool > getFunctionTools (Set <String > functionNames ) {
155
+ return this .resolveFunctionCallbacks (functionNames ).stream ().map (functionCallback -> {
156
+ var description = functionCallback .getDescription ();
157
+ var name = functionCallback .getName ();
158
+ String inputSchema = functionCallback .getInputTypeSchema ();
159
+
160
+ return Tool .builder ()
161
+ .toolSpec (ToolSpecification .builder ()
162
+ .name (name )
163
+ .description (description )
164
+ .inputSchema (ToolInputSchema .builder ()
165
+ .json (BedrockConverseApiUtils .convertObjectToDocument (ModelOptionsUtils .jsonToMap (inputSchema )))
166
+ .build ())
167
+ .build ())
168
+ .build ();
169
+ }).toList ();
170
+ }
171
+
90
172
@ Override
91
173
public ChatOptions getDefaultOptions () {
92
174
return Anthropic3ChatOptions .fromOptions (this .defaultOptions );
93
175
}
94
176
177
+ @ Override
178
+ protected BedrockConverseRequest doCreateToolResponseRequest (BedrockConverseRequest previousRequest ,
179
+ Message responseMessage , List <Message > conversationHistory ) {
180
+ // TODO
181
+ return null ;
182
+ }
183
+
184
+ @ Override
185
+ protected List <Message > doGetUserMessages (BedrockConverseRequest request ) {
186
+ return BedrockConverseApiUtils .getMessagesInstructions (request .getMessages ());
187
+ }
188
+
189
+ @ Override
190
+ protected Message doGetToolResponseMessage (ChatResponse response ) {
191
+ return response .getResult ().getOutput ();
192
+ }
193
+
194
+ @ Override
195
+ protected ChatResponse doChatCompletion (BedrockConverseRequest request ) {
196
+ return converseApi .converse (request );
197
+ }
198
+
199
+ @ Override
200
+ protected Flux <ChatResponse > doChatCompletionStream (BedrockConverseRequest request ) {
201
+ return converseApi .converseStream (request );
202
+ }
203
+
204
+ @ Override
205
+ protected boolean isToolFunctionCall (ChatResponse response ) {
206
+ Generation result = response .getResult ();
207
+
208
+ if (result .getMetadata () instanceof BedrockConverseChatGenerationMetadata metadata ) {
209
+ return metadata .isToolUse ();
210
+ }
211
+
212
+ return false ;
213
+ }
214
+
95
215
/**
96
216
* Anthropic3 models version.
97
217
*/
0 commit comments