Skip to content

Commit 7e342f8

Browse files
committed
feat: Add support for Consumer and Supplier function callbacks
- Add support for Java Consumer and Supplier functional interfaces in function callbacks - Handle void type inputs and outputs in function callbacks - Add test cases for void responses, Consumer callbacks, and Supplier callbacks - Update ModelOptionsUtils to properly handle void type schemas Resolves #1718 and #1277
1 parent 6acd202 commit 7e342f8

File tree

5 files changed

+287
-121
lines changed

5 files changed

+287
-121
lines changed

spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
import org.springframework.beans.BeanWrapper;
5252
import org.springframework.beans.BeanWrapperImpl;
5353
import org.springframework.util.Assert;
54+
import org.springframework.util.ClassUtils;
5455
import org.springframework.util.CollectionUtils;
5556
import org.springframework.util.ObjectUtils;
5657

@@ -358,8 +359,13 @@ public static String getJsonSchema(Class<?> clazz, boolean toUpperCaseTypeValues
358359
}
359360

360361
ObjectNode node = SCHEMA_GENERATOR_CACHE.get().generateSchema(clazz);
361-
if (toUpperCaseTypeValues) { // Required for OpenAPI 3.0 (at least Vertex AI
362-
// version of it).
362+
363+
if (ClassUtils.isVoidType(clazz) && node.get("properties") == null) {
364+
node.putObject("properties");
365+
}
366+
367+
// Required for OpenAPI 3.0 (at least Vertex AI version of it).
368+
if (toUpperCaseTypeValues) {
363369
toUpperCaseTypeValues(node);
364370
}
365371

spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
package org.springframework.ai.model.function;
1818

1919
import java.util.function.BiFunction;
20+
import java.util.function.Consumer;
2021
import java.util.function.Function;
22+
import java.util.function.Supplier;
2123

2224
import com.fasterxml.jackson.annotation.JsonClassDescription;
2325
import kotlin.jvm.functions.Function1;
@@ -129,6 +131,22 @@ else if (KotlinDelegate.isKotlinBiFunction(functionType.toClass())) {
129131
.withInputType(functionInputClass)
130132
.build();
131133
}
134+
if (bean instanceof Consumer<?> consumer) {
135+
return FunctionCallbackWrapper.builder(consumer)
136+
.withName(beanName)
137+
.withSchemaType(this.schemaType)
138+
.withDescription(functionDescription)
139+
.withInputType(functionInputClass)
140+
.build();
141+
}
142+
if (bean instanceof Supplier<?> supplier) {
143+
return FunctionCallbackWrapper.builder(supplier)
144+
.withName(beanName)
145+
.withSchemaType(this.schemaType)
146+
.withDescription(functionDescription)
147+
.withInputType(functionInputClass)
148+
.build();
149+
}
132150
else if (bean instanceof BiFunction<?, ?, ?>) {
133151
return FunctionCallbackWrapper.builder((BiFunction<?, ToolContext, ?>) bean)
134152
.withName(beanName)

spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
package org.springframework.ai.model.function;
1818

1919
import java.util.function.BiFunction;
20+
import java.util.function.Consumer;
2021
import java.util.function.Function;
22+
import java.util.function.Supplier;
2123

2224
import com.fasterxml.jackson.databind.DeserializationFeature;
2325
import com.fasterxml.jackson.databind.ObjectMapper;
@@ -59,6 +61,19 @@ public static <I, O> Builder<I, O> builder(Function<I, O> function) {
5961
return new Builder<>(function);
6062
}
6163

64+
public static <Void, O> Builder<Void, O> builder(Supplier<O> supplier) {
65+
Function<Void, O> function = (input) -> supplier.get();
66+
return new Builder<>(function);
67+
}
68+
69+
public static <I, Void> Builder<I, Void> builder(Consumer<I> consumer) {
70+
Function<I, Void> function = (input) -> {
71+
consumer.accept(input);
72+
return null;
73+
};
74+
return new Builder<>(function);
75+
}
76+
6277
@Override
6378
public O apply(I input, ToolContext context) {
6479
return this.biFunction.apply(input, context);

spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
import java.lang.reflect.Modifier;
2121
import java.util.Arrays;
2222
import java.util.function.BiFunction;
23+
import java.util.function.Consumer;
2324
import java.util.function.Function;
25+
import java.util.function.Supplier;
2426

2527
import kotlin.jvm.functions.Function1;
2628
import kotlin.jvm.functions.Function2;
@@ -189,6 +191,12 @@ public static ResolvableType getFunctionArgumentType(ResolvableType functionType
189191
else if (BiFunction.class.isAssignableFrom(resolvableClass)) {
190192
functionArgumentResolvableType = functionType.as(BiFunction.class);
191193
}
194+
else if (Supplier.class.isAssignableFrom(resolvableClass)) {
195+
return ResolvableType.forClass(Void.class);
196+
}
197+
else if (Consumer.class.isAssignableFrom(resolvableClass)) {
198+
functionArgumentResolvableType = functionType.as(Consumer.class);
199+
}
192200
else if (KotlinDetector.isKotlinPresent()) {
193201
if (KotlinDelegate.isKotlinFunction(resolvableClass)) {
194202
functionArgumentResolvableType = KotlinDelegate.adaptToKotlinFunctionType(functionType);

0 commit comments

Comments
 (0)