Skip to content

Commit e7cc220

Browse files
committed
feat(core): Add Supplier and Consumer function callback support
Add support for no-argument Supplier and single-argument Consumer function callbacks in the Spring AI core module. This enhancement allows: - Registration of Supplier<O> callbacks with no input (Void) type - Registration of Consumer<I> callbacks with no output (Void) type - Support for Kotlin Function0 (equivalent to Java Supplier) - Handle empty properties for Void input types in schema generation - Enhance FunctionCallback builder to support Supplier/Consumer patterns Additional changes: - Add test coverage for both Supplier and Consumer callbacks in various scenarios - Enhance TypeResolverHelper to support Consumer input type resolution - Support lambda-style function declarations for improved ergonomics - Add test cases for void input/output handling in OpenAI chat model - Include examples of function calls without return values - Add support for parameterless functions through Supplier interface Resolves #1718 , #1277 , #1118, #860
1 parent 018257a commit e7cc220

File tree

11 files changed

+466
-127
lines changed

11 files changed

+466
-127
lines changed

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelFunctionCallingIT.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import java.util.ArrayList;
2020
import java.util.List;
2121
import java.util.Map;
22+
import java.util.concurrent.ConcurrentHashMap;
2223
import java.util.function.BiFunction;
2324
import java.util.stream.Collectors;
2425

@@ -28,6 +29,7 @@
2829
import org.slf4j.LoggerFactory;
2930
import reactor.core.publisher.Flux;
3031

32+
import org.springframework.ai.chat.client.ChatClient;
3133
import org.springframework.ai.chat.messages.AssistantMessage;
3234
import org.springframework.ai.chat.messages.Message;
3335
import org.springframework.ai.chat.messages.UserMessage;
@@ -59,6 +61,25 @@ class OpenAiChatModelFunctionCallingIT {
5961
@Autowired
6062
ChatModel chatModel;
6163

64+
@Test
65+
void functionCallSupplier() {
66+
67+
Map<String, Object> state = new ConcurrentHashMap<>();
68+
69+
// @formatter:off
70+
String response = ChatClient.create(this.chatModel).prompt()
71+
.user("Turn the light on in the living room")
72+
.functions(FunctionCallback.builder()
73+
.function("turnsLightOnInTheLivingRoom", () -> state.put("Light", "ON"))
74+
.build())
75+
.call()
76+
.content();
77+
// @formatter:on
78+
79+
logger.info("Response: {}", response);
80+
assertThat(state).containsEntry("Light", "ON");
81+
}
82+
6283
@Test
6384
void functionCallTest() {
6485
functionCallTest(OpenAiChatOptions.builder()

spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,11 @@ public static String getJsonSchema(Type inputType, boolean toUpperCaseTypeValues
395395
}
396396

397397
ObjectNode node = SCHEMA_GENERATOR_CACHE.get().generateSchema(inputType);
398+
399+
if ((inputType == Void.class) && !node.has("properties")) {
400+
node.putObject("properties");
401+
}
402+
398403
if (toUpperCaseTypeValues) { // Required for OpenAPI 3.0 (at least Vertex AI
399404
// version of it).
400405
toUpperCaseTypeValues(node);

spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallbackBuilder.java

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
import java.lang.reflect.Type;
1919
import java.util.Arrays;
2020
import java.util.function.BiFunction;
21+
import java.util.function.Consumer;
2122
import java.util.function.Function;
23+
import java.util.function.Supplier;
2224

2325
import com.fasterxml.jackson.core.JsonProcessingException;
2426
import com.fasterxml.jackson.databind.DeserializationFeature;
@@ -43,7 +45,7 @@
4345

4446
/**
4547
* Default implementation of the {@link FunctionCallback.Builder}.
46-
*
48+
*
4749
* @author Christian Tzolov
4850
* @since 1.0.0
4951
*/
@@ -137,6 +139,20 @@ public <I, O> FunctionInvokingSpec<I, O> function(String name, BiFunction<I, Too
137139
return new DefaultFunctionInvokingSpec<>(name, biFunction);
138140
}
139141

142+
@Override
143+
public <O> FunctionInvokingSpec<Void, O> function(String name, Supplier<O> supplier) {
144+
Function<Void, O> function = (input) -> supplier.get();
145+
return new DefaultFunctionInvokingSpec<>(name, function).inputType(Void.class);
146+
}
147+
148+
public <I> FunctionInvokingSpec<I, Void> function(String name, Consumer<I> consumer) {
149+
Function<I, Void> function = (I input) -> {
150+
consumer.accept(input);
151+
return null;
152+
};
153+
return new DefaultFunctionInvokingSpec<>(name, function);
154+
}
155+
140156
@Override
141157
public MethodInvokingSpec method(String methodName, Class<?>... argumentTypes) {
142158
return new DefaultMethodInvokingSpec(methodName, argumentTypes);

spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallback.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
package org.springframework.ai.model.function;
1818

1919
import java.util.function.BiFunction;
20+
import java.util.function.Consumer;
2021
import java.util.function.Function;
22+
import java.util.function.Supplier;
2123

2224
import com.fasterxml.jackson.databind.ObjectMapper;
2325

@@ -141,6 +143,16 @@ interface Builder {
141143
*/
142144
<I, O> FunctionInvokingSpec<I, O> function(String name, BiFunction<I, ToolContext, O> biFunction);
143145

146+
/**
147+
* Builds a {@link Supplier} invoking {@link FunctionCallback} instance.
148+
*/
149+
<O> FunctionInvokingSpec<Void, O> function(String name, Supplier<O> supplier);
150+
151+
/**
152+
* Builds a {@link Consumer} invoking {@link FunctionCallback} instance.
153+
*/
154+
<I> FunctionInvokingSpec<I, Void> function(String name, Consumer<I> consumer);
155+
144156
/**
145157
* Builds a Method invoking {@link FunctionCallback} instance.
146158
*/

spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,12 @@
1717
package org.springframework.ai.model.function;
1818

1919
import java.util.function.BiFunction;
20+
import java.util.function.Consumer;
2021
import java.util.function.Function;
22+
import java.util.function.Supplier;
2123

2224
import com.fasterxml.jackson.annotation.JsonClassDescription;
25+
import kotlin.jvm.functions.Function0;
2326
import kotlin.jvm.functions.Function1;
2427
import kotlin.jvm.functions.Function2;
2528

@@ -30,6 +33,7 @@
3033
import org.springframework.context.annotation.Description;
3134
import org.springframework.context.support.GenericApplicationContext;
3235
import org.springframework.core.KotlinDetector;
36+
import org.springframework.core.ParameterizedTypeReference;
3337
import org.springframework.core.ResolvableType;
3438
import org.springframework.lang.NonNull;
3539
import org.springframework.lang.Nullable;
@@ -71,7 +75,8 @@ public void setApplicationContext(@NonNull ApplicationContext applicationContext
7175
public FunctionCallback getFunctionCallback(@NonNull String beanName, @Nullable String defaultDescription) {
7276

7377
ResolvableType functionType = TypeResolverHelper.resolveBeanType(this.applicationContext, beanName);
74-
ResolvableType functionInputType = TypeResolverHelper.getFunctionArgumentType(functionType, 0);
78+
ResolvableType functionInputType = (ResolvableType.forType(Supplier.class).isAssignableFrom(functionType))
79+
? ResolvableType.forType(Void.class) : TypeResolverHelper.getFunctionArgumentType(functionType, 0);
7580

7681
Class<?> functionInputClass = functionInputType.toClass();
7782
String functionDescription = defaultDescription;
@@ -109,15 +114,23 @@ public FunctionCallback getFunctionCallback(@NonNull String beanName, @Nullable
109114
.schemaType(this.schemaType)
110115
.description(functionDescription)
111116
.function(beanName, KotlinDelegate.wrapKotlinFunction(bean))
112-
.inputType(functionInputClass)
117+
.inputType(ParameterizedTypeReference.forType(functionInputType.getType()))
113118
.build();
114119
}
115120
else if (KotlinDelegate.isKotlinBiFunction(functionType.toClass())) {
116121
return FunctionCallback.builder()
117122
.description(functionDescription)
118123
.schemaType(this.schemaType)
119124
.function(beanName, KotlinDelegate.wrapKotlinBiFunction(bean))
120-
.inputType(functionInputClass)
125+
.inputType(ParameterizedTypeReference.forType(functionInputType.getType()))
126+
.build();
127+
}
128+
else if (KotlinDelegate.isKotlinSupplier(functionType.toClass())) {
129+
return FunctionCallback.builder()
130+
.description(functionDescription)
131+
.schemaType(this.schemaType)
132+
.function(beanName, KotlinDelegate.wrapKotlinSupplier(bean))
133+
.inputType(ParameterizedTypeReference.forType(functionInputType.getType()))
121134
.build();
122135
}
123136
}
@@ -126,15 +139,31 @@ else if (KotlinDelegate.isKotlinBiFunction(functionType.toClass())) {
126139
.schemaType(this.schemaType)
127140
.description(functionDescription)
128141
.function(beanName, function)
129-
.inputType(functionInputClass)
142+
.inputType(ParameterizedTypeReference.forType(functionInputType.getType()))
130143
.build();
131144
}
132145
else if (bean instanceof BiFunction<?, ?, ?>) {
133146
return FunctionCallback.builder()
134147
.description(functionDescription)
135148
.schemaType(this.schemaType)
136149
.function(beanName, (BiFunction<?, ToolContext, ?>) bean)
137-
.inputType(functionInputClass)
150+
.inputType(ParameterizedTypeReference.forType(functionInputType.getType()))
151+
.build();
152+
}
153+
else if (bean instanceof Supplier<?> supplier) {
154+
return FunctionCallback.builder()
155+
.description(functionDescription)
156+
.schemaType(this.schemaType)
157+
.function(beanName, supplier)
158+
.inputType(ParameterizedTypeReference.forType(functionInputType.getType()))
159+
.build();
160+
}
161+
else if (bean instanceof Consumer<?> consumer) {
162+
return FunctionCallback.builder()
163+
.description(functionDescription)
164+
.schemaType(this.schemaType)
165+
.function(beanName, consumer)
166+
.inputType(ParameterizedTypeReference.forType(functionInputType.getType()))
138167
.build();
139168
}
140169
else {
@@ -150,6 +179,15 @@ public enum SchemaType {
150179

151180
private static class KotlinDelegate {
152181

182+
public static boolean isKotlinSupplier(Class<?> clazz) {
183+
return Function0.class.isAssignableFrom(clazz);
184+
}
185+
186+
@SuppressWarnings("unchecked")
187+
public static Supplier<?> wrapKotlinSupplier(Object function) {
188+
return () -> ((Function0<Object>) function).invoke();
189+
}
190+
153191
public static boolean isKotlinFunction(Class<?> clazz) {
154192
return Function1.class.isAssignableFrom(clazz);
155193
}

spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,11 @@
2020
import java.lang.reflect.Modifier;
2121
import java.util.Arrays;
2222
import java.util.function.BiFunction;
23+
import java.util.function.Consumer;
2324
import java.util.function.Function;
25+
import java.util.function.Supplier;
2426

27+
import kotlin.jvm.functions.Function0;
2528
import kotlin.jvm.functions.Function1;
2629
import kotlin.jvm.functions.Function2;
2730

@@ -44,6 +47,16 @@
4447
*/
4548
public abstract class TypeResolverHelper {
4649

50+
/**
51+
* Returns the input class of a given Consumer class.
52+
* @param consumerClass The consumer class.
53+
* @return The input class of the consumer.
54+
*/
55+
public static Class<?> getConsumerInputClass(Class<? extends Consumer<?>> consumerClass) {
56+
ResolvableType resolvableType = ResolvableType.forClass(consumerClass).as(Consumer.class);
57+
return (resolvableType == ResolvableType.NONE ? Object.class : resolvableType.getGeneric(0).toClass());
58+
}
59+
4760
/**
4861
* Returns the input class of a given function class.
4962
* @param biFunctionClass The function class.
@@ -199,13 +212,22 @@ public static ResolvableType getFunctionArgumentType(ResolvableType functionType
199212
else if (BiFunction.class.isAssignableFrom(resolvableClass)) {
200213
functionArgumentResolvableType = functionType.as(BiFunction.class);
201214
}
215+
else if (Supplier.class.isAssignableFrom(resolvableClass)) {
216+
functionArgumentResolvableType = functionType.as(Supplier.class);
217+
}
218+
else if (Consumer.class.isAssignableFrom(resolvableClass)) {
219+
functionArgumentResolvableType = functionType.as(Consumer.class);
220+
}
202221
else if (KotlinDetector.isKotlinPresent()) {
203222
if (KotlinDelegate.isKotlinFunction(resolvableClass)) {
204223
functionArgumentResolvableType = KotlinDelegate.adaptToKotlinFunctionType(functionType);
205224
}
206225
else if (KotlinDelegate.isKotlinBiFunction(resolvableClass)) {
207226
functionArgumentResolvableType = KotlinDelegate.adaptToKotlinBiFunctionType(functionType);
208227
}
228+
else if (KotlinDelegate.isKotlinSupplier(resolvableClass)) {
229+
functionArgumentResolvableType = KotlinDelegate.adaptToKotlinSupplierType(functionType);
230+
}
209231
}
210232

211233
if (functionArgumentResolvableType == ResolvableType.NONE) {
@@ -218,6 +240,14 @@ else if (KotlinDelegate.isKotlinBiFunction(resolvableClass)) {
218240

219241
private static class KotlinDelegate {
220242

243+
public static boolean isKotlinSupplier(Class<?> clazz) {
244+
return Function0.class.isAssignableFrom(clazz);
245+
}
246+
247+
public static ResolvableType adaptToKotlinSupplierType(ResolvableType resolvableType) {
248+
return resolvableType.as(Function0.class);
249+
}
250+
221251
public static boolean isKotlinFunction(Class<?> clazz) {
222252
return Function1.class.isAssignableFrom(clazz);
223253
}

spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperIT.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package org.springframework.ai.model.function;
1818

19+
import java.util.function.Consumer;
1920
import java.util.function.Function;
2021

2122
import org.junit.jupiter.params.ParameterizedTest;
@@ -39,7 +40,7 @@ public class TypeResolverHelperIT {
3940

4041
@ParameterizedTest(name = "{0} : {displayName} ")
4142
@ValueSource(strings = { "weatherClassDefinition", "weatherFunctionDefinition", "standaloneWeatherFunction",
42-
"scannedStandaloneWeatherFunction", "componentWeatherFunction" })
43+
"scannedStandaloneWeatherFunction", "componentWeatherFunction", "weatherConsumer" })
4344
void beanInputTypeResolutionWithResolvableType(String beanName) {
4445
assertThat(this.applicationContext).isNotNull();
4546
ResolvableType functionType = TypeResolverHelper.resolveBeanType(this.applicationContext, beanName);
@@ -89,6 +90,13 @@ StandaloneWeatherFunction standaloneWeatherFunction() {
8990
return new StandaloneWeatherFunction();
9091
}
9192

93+
@Bean
94+
Consumer<WeatherRequest> weatherConsumer() {
95+
return (weatherRequest) -> {
96+
System.out.println(weatherRequest);
97+
};
98+
}
99+
92100
}
93101

94102
}

spring-ai-core/src/test/java/org/springframework/ai/model/function/TypeResolverHelperTests.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package org.springframework.ai.model.function;
1818

19+
import java.util.function.Consumer;
1920
import java.util.function.Function;
2021

2122
import com.fasterxml.jackson.annotation.JsonClassDescription;
@@ -35,6 +36,12 @@
3536
*/
3637
public class TypeResolverHelperTests {
3738

39+
@Test
40+
public void testGetConsumerInputType() {
41+
Class<?> inputType = TypeResolverHelper.getConsumerInputClass(MyConsumer.class);
42+
assertThat(inputType).isEqualTo(Request.class);
43+
}
44+
3845
@Test
3946
public void testGetFunctionInputType() {
4047
Class<?> inputType = TypeResolverHelper.getFunctionInputClass(MockWeatherService.class);
@@ -63,6 +70,14 @@ public String apply(Response response) {
6370

6471
}
6572

73+
public static class MyConsumer implements Consumer<Request> {
74+
75+
@Override
76+
public void accept(Request request) {
77+
}
78+
79+
}
80+
6681
public static class MockWeatherService implements Function<Request, Response> {
6782

6883
@Override

spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusPromptIT.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@ void functionCallTest() {
6666
var promptOptions = MistralAiChatOptions.builder()
6767
.withFunctionCallbacks(List.of(FunctionCallback.builder()
6868
.description("Get payment status of a transaction")
69-
.function("retrievePaymentStatus", transaction -> new Status(DATA.get(transaction).status()))
69+
.function("retrievePaymentStatus",
70+
(Transaction transaction) -> new Status(DATA.get(transaction).status()))
7071
.inputType(Transaction.class)
7172
.build()))
7273
.build();

0 commit comments

Comments
 (0)