Skip to content

Commit 2753b01

Browse files
committed
add ChatClient API support for consumer and supplier funcitons
1 parent 7e342f8 commit 2753b01

File tree

8 files changed

+201
-30
lines changed

8 files changed

+201
-30
lines changed

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@ public OpenAiAudioSpeechModel openAiAudioSpeechModel(OpenAiAudioApi api) {
7979
@Bean
8080
public OpenAiImageModel openAiImageModel(OpenAiImageApi imageApi) {
8181
OpenAiImageModel openAiImageModel = new OpenAiImageModel(imageApi);
82-
// openAiImageModel.setModel("foobar");
8382
return openAiImageModel;
8483
}
8584

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
import java.util.Arrays;
2222
import java.util.List;
2323
import java.util.Map;
24+
import java.util.concurrent.ConcurrentHashMap;
25+
import java.util.function.Consumer;
2426
import java.util.stream.Collectors;
2527

2628
import org.junit.jupiter.api.Disabled;
@@ -257,6 +259,50 @@ void functionCallTest() {
257259
assertThat(response).contains("30", "10", "15");
258260
}
259261

262+
@Test
263+
void functionCallSupplier() {
264+
265+
Map<String, Object> state = new ConcurrentHashMap<>();
266+
267+
// @formatter:off
268+
String response = ChatClient.create(this.chatModel).prompt()
269+
.user("Turn the light on in the living room")
270+
.function("turnLivingRoomLightOnSupplier", "Turns light on in the living room", () -> state.put("foo", "bar"))
271+
.call()
272+
.content();
273+
// @formatter:on
274+
275+
logger.info("Response: {}", response);
276+
assertThat(state).containsEntry("foo", "bar");
277+
}
278+
279+
@Test
280+
void functionCallConsumer() {
281+
282+
Map<String, Object> state = new ConcurrentHashMap<>();
283+
284+
record LightInfo(String roomName, boolean isOn) {
285+
}
286+
287+
// @formatter:off
288+
String response = ChatClient.create(this.chatModel).prompt()
289+
.user("Turn the light on in the kitchen and in the living room")
290+
.function("turnLight", "Turn light on or off in a room", new Consumer<LightInfo>() {
291+
@Override
292+
public void accept(LightInfo lightInfo) {
293+
logger.info("Turning light to [" + lightInfo.isOn + "] in " + lightInfo.roomName());
294+
state.put(lightInfo.roomName(), lightInfo.isOn());
295+
}
296+
})
297+
.call()
298+
.content();
299+
// @formatter:on
300+
301+
logger.info("Response: {}", response);
302+
assertThat(state).containsEntry("kitchen", Boolean.TRUE);
303+
assertThat(state).containsEntry("living room", Boolean.TRUE);
304+
}
305+
260306
@Test
261307
void defaultFunctionCallTest() {
262308

spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,10 @@ <I, O> ChatClientRequestSpec function(String name, String description,
223223
<I, O> ChatClientRequestSpec function(String name, String description, Class<I> inputType,
224224
java.util.function.Function<I, O> function);
225225

226+
<I, O> ChatClientRequestSpec function(String name, String description, java.util.function.Supplier<O> supplier);
227+
228+
<I, O> ChatClientRequestSpec function(String name, String description, java.util.function.Consumer<I> consumer);
229+
226230
ChatClientRequestSpec functions(String... functionBeanNames);
227231

228232
ChatClientRequestSpec toolContext(Map<String, Object> toolContext);

spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -874,6 +874,38 @@ public <I, O> ChatClientRequestSpec function(String name, String description, @N
874874
return this;
875875
}
876876

877+
public <I, O> ChatClientRequestSpec function(String name, String description,
878+
java.util.function.Supplier<O> supplier) {
879+
880+
Assert.hasText(name, "name cannot be null or empty");
881+
Assert.hasText(description, "description cannot be null or empty");
882+
Assert.notNull(supplier, "supplier cannot be null");
883+
884+
var fcw = FunctionCallbackWrapper.builder(supplier)
885+
.withDescription(description)
886+
.withName(name)
887+
.withInputType(Void.class)
888+
.build();
889+
this.functionCallbacks.add(fcw);
890+
return this;
891+
}
892+
893+
public <I, O> ChatClientRequestSpec function(String name, String description,
894+
java.util.function.Consumer<I> consumer) {
895+
896+
Assert.hasText(name, "name cannot be null or empty");
897+
Assert.hasText(description, "description cannot be null or empty");
898+
Assert.notNull(consumer, "consumer cannot be null");
899+
900+
var fcw = FunctionCallbackWrapper.builder(consumer)
901+
.withDescription(description)
902+
.withName(name)
903+
// .withResponseConverter(Object::toString)
904+
.build();
905+
this.functionCallbacks.add(fcw);
906+
return this;
907+
}
908+
877909
public ChatClientRequestSpec functions(String... functionBeanNames) {
878910
Assert.notNull(functionBeanNames, "functionBeanNames cannot be null");
879911
Assert.noNullElements(functionBeanNames, "functionBeanNames cannot contain null elements");

spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,8 @@ public static <Void, O> Builder<Void, O> builder(Supplier<O> supplier) {
6666
return new Builder<>(function);
6767
}
6868

69-
public static <I, Void> Builder<I, Void> builder(Consumer<I> consumer) {
70-
Function<I, Void> function = (input) -> {
71-
consumer.accept(input);
72-
return null;
73-
};
74-
return new Builder<>(function);
69+
public static <I, O> Builder<I, O> builder(Consumer<I> consumer) {
70+
return new Builder<>(consumer);
7571
}
7672

7773
@Override
@@ -85,6 +81,8 @@ public static class Builder<I, O> {
8581

8682
private final Function<I, O> function;
8783

84+
private final Consumer<I> consumer;
85+
8886
private String name;
8987

9088
private String description;
@@ -104,12 +102,21 @@ public Builder(BiFunction<I, ToolContext, O> biFunction) {
104102
Assert.notNull(biFunction, "Function must not be null");
105103
this.biFunction = biFunction;
106104
this.function = null;
105+
this.consumer = null;
107106
}
108107

109108
public Builder(Function<I, O> function) {
110109
Assert.notNull(function, "Function must not be null");
111110
this.biFunction = null;
112111
this.function = function;
112+
this.consumer = null;
113+
}
114+
115+
public Builder(Consumer<I> consumer) {
116+
Assert.notNull(consumer, "Consumer must not be null");
117+
this.biFunction = null;
118+
this.function = null;
119+
this.consumer = consumer;
113120
}
114121

115122
@SuppressWarnings("unchecked")
@@ -123,6 +130,11 @@ private static <I, O> Class<I> resolveInputType(Function<I, O> function) {
123130
return (Class<I>) TypeResolverHelper.getFunctionInputClass((Class<Function<I, O>>) function.getClass());
124131
}
125132

133+
@SuppressWarnings("unchecked")
134+
private static <I> Class<I> resolveInputType(Consumer<I> consumer) {
135+
return (Class<I>) TypeResolverHelper.getConsumerInputClass((Class<Consumer<I>>) consumer.getClass());
136+
}
137+
126138
public Builder<I, O> withName(String name) {
127139
Assert.hasText(name, "Name must not be empty");
128140
this.name = name;
@@ -183,18 +195,36 @@ public FunctionCallbackWrapper<I, O> build() {
183195
if (this.function != null) {
184196
this.inputType = resolveInputType(this.function);
185197
}
186-
else {
198+
else if (this.biFunction != null) {
187199
this.inputType = resolveInputType(this.biFunction);
188200
}
201+
else {
202+
this.inputType = resolveInputType(this.consumer);
203+
}
189204
}
190205

191206
if (this.inputTypeSchema == null) {
192207
boolean upperCaseTypeValues = this.schemaType == SchemaType.OPEN_API_SCHEMA;
193208
this.inputTypeSchema = ModelOptionsUtils.getJsonSchema(this.inputType, upperCaseTypeValues);
194209
}
195210

196-
BiFunction<I, ToolContext, O> finalBiFunction = (this.biFunction != null) ? this.biFunction
197-
: (request, context) -> this.function.apply(request);
211+
BiFunction<I, ToolContext, O> finalBiFunction = null;
212+
if (this.biFunction != null) {
213+
finalBiFunction = this.biFunction;
214+
}
215+
else if (this.function != null) {
216+
finalBiFunction = (request, context) -> this.function.apply(request);
217+
}
218+
else {
219+
finalBiFunction = (request, context) -> {
220+
this.consumer.accept(request);
221+
return null;
222+
};
223+
}
224+
225+
// BiFunction<I, ToolContext, O> finalBiFunction = (this.biFunction != null) ?
226+
// this.biFunction
227+
// : (request, context) -> this.function.apply(request);
198228

199229
return new FunctionCallbackWrapper<>(this.name, this.description, this.inputTypeSchema, this.inputType,
200230
this.responseConverter, this.objectMapper, finalBiFunction);

spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,16 @@ public static Class<?> getFunctionInputClass(Class<? extends Function<?, ?>> fun
6464
return getFunctionArgumentClass(functionClass, 0);
6565
}
6666

67+
/**
68+
* Returns the input class of a given Consumer class.
69+
* @param consumerClass The consumer class.
70+
* @return The input class of the consumer.
71+
*/
72+
public static Class<?> getConsumerInputClass(Class<? extends Consumer<?>> consumerClass) {
73+
ResolvableType resolvableType = ResolvableType.forClass(consumerClass).as(Consumer.class);
74+
return (resolvableType == ResolvableType.NONE ? Object.class : resolvableType.getGeneric(0).toClass());
75+
}
76+
6777
/**
6878
* Returns the output class of a given function class.
6979
* @param functionClass The function class.

spring-ai-core/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1467,6 +1467,25 @@ void whenBiFunctionThenReturn() {
14671467
assertThat(defaultSpec.getFunctionCallbacks()).anyMatch(callback -> callback.getName().equals("name"));
14681468
}
14691469

1470+
@Test
1471+
void whenSupplierFunctionThenReturn() {
1472+
ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build();
1473+
ChatClient.ChatClientRequestSpec spec = chatClient.prompt();
1474+
spec = spec.function("name", "description", () -> "hello");
1475+
DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec;
1476+
assertThat(defaultSpec.getFunctionCallbacks()).anyMatch(callback -> callback.getName().equals("name"));
1477+
}
1478+
1479+
@Test
1480+
void whenConsumerFunctionThenReturn() {
1481+
ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build();
1482+
ChatClient.ChatClientRequestSpec spec = chatClient.prompt();
1483+
Consumer<String> consumer = input -> System.out.println(input);
1484+
spec = spec.function("name", "description", consumer);
1485+
DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec;
1486+
assertThat(defaultSpec.getFunctionCallbacks()).anyMatch(callback -> callback.getName().equals("name"));
1487+
}
1488+
14701489
@Test
14711490
void whenFunctionBeanNamesElementIsNullThenThrow() {
14721491
ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build();

spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPrompt2IT.java

Lines changed: 51 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616

1717
package org.springframework.ai.autoconfigure.openai.tool;
1818

19+
import java.util.Map;
20+
import java.util.concurrent.ConcurrentHashMap;
21+
import java.util.function.Consumer;
1922
import java.util.function.Function;
2023
import java.util.stream.Collectors;
2124

@@ -39,17 +42,17 @@ public class FunctionCallbackInPrompt2IT {
3942
private final Logger logger = LoggerFactory.getLogger(FunctionCallbackInPromptIT.class);
4043

4144
private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
42-
.withPropertyValues("spring.ai.openai.apiKey=" + System.getenv("OPENAI_API_KEY"))
45+
.withPropertyValues("spring.ai.openai.apiKey=" + System.getenv("OPENAI_API_KEY"),
46+
"spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName())
4347
.withConfiguration(AutoConfigurations.of(OpenAiAutoConfiguration.class));
4448

4549
@Test
4650
void functionCallTest() {
47-
this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName())
48-
.run(context -> {
51+
this.contextRunner.run(context -> {
4952

50-
OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class);
53+
OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class);
5154

52-
ChatClient chatClient = ChatClient.builder(chatModel).build();
55+
ChatClient chatClient = ChatClient.builder(chatModel).build();
5356

5457
// @formatter:off
5558
chatClient.prompt()
@@ -62,18 +65,17 @@ void functionCallTest() {
6265
.call().content();
6366
// @formatter:on
6467

65-
logger.info("Response: {}", content);
68+
logger.info("Response: {}", content);
6669

67-
assertThat(content).contains("30", "10", "15");
68-
});
70+
assertThat(content).contains("30", "10", "15");
71+
});
6972
}
7073

7174
@Test
7275
void functionCallTest2() {
73-
this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName())
74-
.run(context -> {
76+
this.contextRunner.run(context -> {
7577

76-
OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class);
78+
OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class);
7779

7880
// @formatter:off
7981
String content = ChatClient.builder(chatModel).build().prompt()
@@ -87,19 +89,48 @@ public String apply(MockWeatherService.Request request) {
8789
})
8890
.call().content();
8991
// @formatter:on
90-
logger.info("Response: {}", content);
92+
logger.info("Response: {}", content);
9193

92-
assertThat(content).contains("18");
93-
});
94+
assertThat(content).contains("18");
95+
});
96+
}
97+
98+
@Test
99+
void functionCallTest21() {
100+
Map<String, Object> state = new ConcurrentHashMap<>();
101+
102+
record LightInfo(String roomName, boolean isOn) {
103+
}
104+
105+
this.contextRunner.run(context -> {
106+
107+
OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class);
108+
109+
// @formatter:off
110+
String content = ChatClient.builder(chatModel).build().prompt()
111+
.user("Turn the light on in the kitchen and in the living room!")
112+
.function("turnLight", "Turn light on or off in a room",
113+
new Consumer<LightInfo>() {
114+
@Override
115+
public void accept(LightInfo lightInfo) {
116+
logger.info("Turning light to [" + lightInfo.isOn + "] in " + lightInfo.roomName());
117+
state.put(lightInfo.roomName(), lightInfo.isOn());
118+
}
119+
})
120+
.call().content();
121+
// @formatter:on
122+
logger.info("Response: {}", content);
123+
assertThat(state).containsEntry("kitchen", Boolean.TRUE);
124+
assertThat(state).containsEntry("living room", Boolean.TRUE);
125+
});
94126
}
95127

96128
@Test
97129
void streamingFunctionCallTest() {
98130

99-
this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName())
100-
.run(context -> {
131+
this.contextRunner.run(context -> {
101132

102-
OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class);
133+
OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class);
103134

104135
// @formatter:off
105136
String content = ChatClient.builder(chatModel).build().prompt()
@@ -109,10 +140,10 @@ void streamingFunctionCallTest() {
109140
.collectList().block().stream().collect(Collectors.joining());
110141
// @formatter:on
111142

112-
logger.info("Response: {}", content);
143+
logger.info("Response: {}", content);
113144

114-
assertThat(content).contains("30", "10", "15");
115-
});
145+
assertThat(content).contains("30", "10", "15");
146+
});
116147
}
117148

118149
}

0 commit comments

Comments
 (0)