Skip to content

feat(core): Add Supplier and Consumer function callback support #1746

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.BiFunction;
import java.util.stream.Collectors;

Expand All @@ -28,6 +29,7 @@
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
Expand Down Expand Up @@ -59,6 +61,25 @@ class OpenAiChatModelFunctionCallingIT {
@Autowired
ChatModel chatModel;

@Test
void functionCallSupplier() {

Map<String, Object> state = new ConcurrentHashMap<>();

// @formatter:off
String response = ChatClient.create(this.chatModel).prompt()
.user("Turn the light on in the living room")
.functions(FunctionCallback.builder()
.function("turnsLightOnInTheLivingRoom", () -> state.put("Light", "ON"))
.build())
.call()
.content();
// @formatter:on

logger.info("Response: {}", response);
assertThat(state).containsEntry("Light", "ON");
}

@Test
void functionCallTest() {
functionCallTest(OpenAiChatOptions.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,11 @@ public static String getJsonSchema(Type inputType, boolean toUpperCaseTypeValues
}

ObjectNode node = SCHEMA_GENERATOR_CACHE.get().generateSchema(inputType);

if ((inputType == Void.class) && !node.has("properties")) {
node.putObject("properties");
}

if (toUpperCaseTypeValues) { // Required for OpenAPI 3.0 (at least Vertex AI
// version of it).
toUpperCaseTypeValues(node);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
import java.lang.reflect.Type;
import java.util.Arrays;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.DeserializationFeature;
Expand All @@ -43,7 +45,7 @@

/**
* Default implementation of the {@link FunctionCallback.Builder}.
*
*
* @author Christian Tzolov
* @since 1.0.0
*/
Expand Down Expand Up @@ -137,6 +139,20 @@ public <I, O> FunctionInvokingSpec<I, O> function(String name, BiFunction<I, Too
return new DefaultFunctionInvokingSpec<>(name, biFunction);
}

@Override
public <O> FunctionInvokingSpec<Void, O> function(String name, Supplier<O> supplier) {
Function<Void, O> function = (input) -> supplier.get();
return new DefaultFunctionInvokingSpec<>(name, function).inputType(Void.class);
}

public <I> FunctionInvokingSpec<I, Void> function(String name, Consumer<I> consumer) {
Function<I, Void> function = (I input) -> {
consumer.accept(input);
return null;
};
return new DefaultFunctionInvokingSpec<>(name, function);
}

@Override
public MethodInvokingSpec method(String methodName, Class<?>... argumentTypes) {
return new DefaultMethodInvokingSpec(methodName, argumentTypes);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
package org.springframework.ai.model.function;

import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;

import com.fasterxml.jackson.databind.ObjectMapper;

Expand Down Expand Up @@ -141,6 +143,16 @@ interface Builder {
*/
<I, O> FunctionInvokingSpec<I, O> function(String name, BiFunction<I, ToolContext, O> biFunction);

/**
* Builds a {@link Supplier} invoking {@link FunctionCallback} instance.
*/
<O> FunctionInvokingSpec<Void, O> function(String name, Supplier<O> supplier);

/**
* Builds a {@link Consumer} invoking {@link FunctionCallback} instance.
*/
<I> FunctionInvokingSpec<I, Void> function(String name, Consumer<I> consumer);

/**
* Builds a Method invoking {@link FunctionCallback} instance.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@
package org.springframework.ai.model.function;

import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;

import com.fasterxml.jackson.annotation.JsonClassDescription;
import kotlin.jvm.functions.Function0;
import kotlin.jvm.functions.Function1;
import kotlin.jvm.functions.Function2;

Expand All @@ -30,6 +33,7 @@
import org.springframework.context.annotation.Description;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.core.KotlinDetector;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.ResolvableType;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;
Expand Down Expand Up @@ -71,7 +75,8 @@ public void setApplicationContext(@NonNull ApplicationContext applicationContext
public FunctionCallback getFunctionCallback(@NonNull String beanName, @Nullable String defaultDescription) {

ResolvableType functionType = TypeResolverHelper.resolveBeanType(this.applicationContext, beanName);
ResolvableType functionInputType = TypeResolverHelper.getFunctionArgumentType(functionType, 0);
ResolvableType functionInputType = (ResolvableType.forType(Supplier.class).isAssignableFrom(functionType))
? ResolvableType.forType(Void.class) : TypeResolverHelper.getFunctionArgumentType(functionType, 0);

Class<?> functionInputClass = functionInputType.toClass();
String functionDescription = defaultDescription;
Expand Down Expand Up @@ -109,15 +114,23 @@ public FunctionCallback getFunctionCallback(@NonNull String beanName, @Nullable
.schemaType(this.schemaType)
.description(functionDescription)
.function(beanName, KotlinDelegate.wrapKotlinFunction(bean))
.inputType(functionInputClass)
.inputType(ParameterizedTypeReference.forType(functionInputType.getType()))
.build();
}
else if (KotlinDelegate.isKotlinBiFunction(functionType.toClass())) {
return FunctionCallback.builder()
.description(functionDescription)
.schemaType(this.schemaType)
.function(beanName, KotlinDelegate.wrapKotlinBiFunction(bean))
.inputType(functionInputClass)
.inputType(ParameterizedTypeReference.forType(functionInputType.getType()))
.build();
}
else if (KotlinDelegate.isKotlinSupplier(functionType.toClass())) {
return FunctionCallback.builder()
.description(functionDescription)
.schemaType(this.schemaType)
.function(beanName, KotlinDelegate.wrapKotlinSupplier(bean))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we will need support for the Kotlin's equivalent of java.util.Consumer?
But this can be implemented in a follow up PR.

.inputType(ParameterizedTypeReference.forType(functionInputType.getType()))
.build();
}
}
Expand All @@ -126,15 +139,31 @@ else if (KotlinDelegate.isKotlinBiFunction(functionType.toClass())) {
.schemaType(this.schemaType)
.description(functionDescription)
.function(beanName, function)
.inputType(functionInputClass)
.inputType(ParameterizedTypeReference.forType(functionInputType.getType()))
.build();
}
else if (bean instanceof BiFunction<?, ?, ?>) {
return FunctionCallback.builder()
.description(functionDescription)
.schemaType(this.schemaType)
.function(beanName, (BiFunction<?, ToolContext, ?>) bean)
.inputType(functionInputClass)
.inputType(ParameterizedTypeReference.forType(functionInputType.getType()))
.build();
}
else if (bean instanceof Supplier<?> supplier) {
return FunctionCallback.builder()
.description(functionDescription)
.schemaType(this.schemaType)
.function(beanName, supplier)
.inputType(ParameterizedTypeReference.forType(functionInputType.getType()))
.build();
}
else if (bean instanceof Consumer<?> consumer) {
return FunctionCallback.builder()
.description(functionDescription)
.schemaType(this.schemaType)
.function(beanName, consumer)
.inputType(ParameterizedTypeReference.forType(functionInputType.getType()))
.build();
}
else {
Expand All @@ -150,6 +179,15 @@ public enum SchemaType {

private static class KotlinDelegate {

public static boolean isKotlinSupplier(Class<?> clazz) {
return Function0.class.isAssignableFrom(clazz);
}

@SuppressWarnings("unchecked")
public static Supplier<?> wrapKotlinSupplier(Object function) {
return () -> ((Function0<Object>) function).invoke();
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Likely we may have to add support for the Kotlin's equivalent of java.util.Consumer() in a follow up PR.


public static boolean isKotlinFunction(Class<?> clazz) {
return Function1.class.isAssignableFrom(clazz);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@
import java.lang.reflect.Modifier;
import java.util.Arrays;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;

import kotlin.jvm.functions.Function0;
import kotlin.jvm.functions.Function1;
import kotlin.jvm.functions.Function2;

Expand All @@ -44,6 +47,16 @@
*/
public abstract class TypeResolverHelper {

/**
* Returns the input class of a given Consumer class.
* @param consumerClass The consumer class.
* @return The input class of the consumer.
*/
public static Class<?> getConsumerInputClass(Class<? extends Consumer<?>> consumerClass) {
ResolvableType resolvableType = ResolvableType.forClass(consumerClass).as(Consumer.class);
return (resolvableType == ResolvableType.NONE ? Object.class : resolvableType.getGeneric(0).toClass());
}

/**
* Returns the input class of a given function class.
* @param biFunctionClass The function class.
Expand Down Expand Up @@ -199,13 +212,22 @@ public static ResolvableType getFunctionArgumentType(ResolvableType functionType
else if (BiFunction.class.isAssignableFrom(resolvableClass)) {
functionArgumentResolvableType = functionType.as(BiFunction.class);
}
else if (Supplier.class.isAssignableFrom(resolvableClass)) {
functionArgumentResolvableType = functionType.as(Supplier.class);
}
else if (Consumer.class.isAssignableFrom(resolvableClass)) {
functionArgumentResolvableType = functionType.as(Consumer.class);
}
else if (KotlinDetector.isKotlinPresent()) {
if (KotlinDelegate.isKotlinFunction(resolvableClass)) {
functionArgumentResolvableType = KotlinDelegate.adaptToKotlinFunctionType(functionType);
}
else if (KotlinDelegate.isKotlinBiFunction(resolvableClass)) {
functionArgumentResolvableType = KotlinDelegate.adaptToKotlinBiFunctionType(functionType);
}
else if (KotlinDelegate.isKotlinSupplier(resolvableClass)) {
functionArgumentResolvableType = KotlinDelegate.adaptToKotlinSupplierType(functionType);
}
}

if (functionArgumentResolvableType == ResolvableType.NONE) {
Expand All @@ -218,6 +240,14 @@ else if (KotlinDelegate.isKotlinBiFunction(resolvableClass)) {

private static class KotlinDelegate {

public static boolean isKotlinSupplier(Class<?> clazz) {
return Function0.class.isAssignableFrom(clazz);
}

public static ResolvableType adaptToKotlinSupplierType(ResolvableType resolvableType) {
return resolvableType.as(Function0.class);
}

public static boolean isKotlinFunction(Class<?> clazz) {
return Function1.class.isAssignableFrom(clazz);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package org.springframework.ai.model.function;

import java.util.function.Consumer;
import java.util.function.Function;

import org.junit.jupiter.params.ParameterizedTest;
Expand All @@ -39,7 +40,7 @@ public class TypeResolverHelperIT {

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "weatherClassDefinition", "weatherFunctionDefinition", "standaloneWeatherFunction",
"scannedStandaloneWeatherFunction", "componentWeatherFunction" })
"scannedStandaloneWeatherFunction", "componentWeatherFunction", "weatherConsumer" })
void beanInputTypeResolutionWithResolvableType(String beanName) {
assertThat(this.applicationContext).isNotNull();
ResolvableType functionType = TypeResolverHelper.resolveBeanType(this.applicationContext, beanName);
Expand Down Expand Up @@ -89,6 +90,13 @@ StandaloneWeatherFunction standaloneWeatherFunction() {
return new StandaloneWeatherFunction();
}

@Bean
Consumer<WeatherRequest> weatherConsumer() {
return (weatherRequest) -> {
System.out.println(weatherRequest);
};
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package org.springframework.ai.model.function;

import java.util.function.Consumer;
import java.util.function.Function;

import com.fasterxml.jackson.annotation.JsonClassDescription;
Expand All @@ -35,6 +36,12 @@
*/
public class TypeResolverHelperTests {

@Test
public void testGetConsumerInputType() {
Class<?> inputType = TypeResolverHelper.getConsumerInputClass(MyConsumer.class);
assertThat(inputType).isEqualTo(Request.class);
}

@Test
public void testGetFunctionInputType() {
Class<?> inputType = TypeResolverHelper.getFunctionInputClass(MockWeatherService.class);
Expand Down Expand Up @@ -63,6 +70,14 @@ public String apply(Response response) {

}

public static class MyConsumer implements Consumer<Request> {

@Override
public void accept(Request request) {
}

}

public static class MockWeatherService implements Function<Request, Response> {

@Override
Expand Down
1 change: 1 addition & 0 deletions spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
* xref:api/prompt.adoc[]
* xref:api/structured-output-converter.adoc[Structured Output]
* xref:api/functions.adoc[Function Calling]
** xref:api/function-callback.adoc[FunctionCallback API]
* xref:api/multimodality.adoc[Multimodality]
* xref:api/etl-pipeline.adoc[]
* xref:api/testing.adoc[AI Model Evaluation]
Expand Down
Loading