From 3898330e01080f9f59741524d6ae5bf9e44febdf Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Mon, 11 Nov 2024 12:37:22 +0100 Subject: [PATCH 1/4] feat: Add support for Consumer and Supplier function callbacks - Add support for Java Consumer and Supplier functional interfaces in function callbacks - Handle void type inputs and outputs in function callbacks - Add test cases for void responses, Consumer callbacks, and Supplier callbacks - Update ModelOptionsUtils to properly handle void type schemas Resolves #1718 and #1277 --- .../ai/model/ModelOptionsUtils.java | 10 +- .../function/FunctionCallbackContext.java | 18 + .../function/FunctionCallbackWrapper.java | 15 + .../ai/model/function/TypeResolverHelper.java | 8 + ...nctionCallbackWithPlainFunctionBeanIT.java | 357 ++++++++++++------ 5 files changed, 287 insertions(+), 121 deletions(-) diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java index e049fb17f00..b4edf135657 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java @@ -51,6 +51,7 @@ import org.springframework.beans.BeanWrapper; import org.springframework.beans.BeanWrapperImpl; import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; import org.springframework.util.CollectionUtils; import org.springframework.util.ObjectUtils; @@ -358,8 +359,13 @@ public static String getJsonSchema(Class clazz, boolean toUpperCaseTypeValues } ObjectNode node = SCHEMA_GENERATOR_CACHE.get().generateSchema(clazz); - if (toUpperCaseTypeValues) { // Required for OpenAPI 3.0 (at least Vertex AI - // version of it). + + if (ClassUtils.isVoidType(clazz) && node.get("properties") == null) { + node.putObject("properties"); + } + + // Required for OpenAPI 3.0 (at least Vertex AI version of it). + if (toUpperCaseTypeValues) { toUpperCaseTypeValues(node); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java index 762d33969bf..b96673c8cb7 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java @@ -17,7 +17,9 @@ package org.springframework.ai.model.function; import java.util.function.BiFunction; +import java.util.function.Consumer; import java.util.function.Function; +import java.util.function.Supplier; import com.fasterxml.jackson.annotation.JsonClassDescription; import kotlin.jvm.functions.Function1; @@ -129,6 +131,22 @@ else if (KotlinDelegate.isKotlinBiFunction(functionType.toClass())) { .withInputType(functionInputClass) .build(); } + if (bean instanceof Consumer consumer) { + return FunctionCallbackWrapper.builder(consumer) + .withName(beanName) + .withSchemaType(this.schemaType) + .withDescription(functionDescription) + .withInputType(functionInputClass) + .build(); + } + if (bean instanceof Supplier supplier) { + return FunctionCallbackWrapper.builder(supplier) + .withName(beanName) + .withSchemaType(this.schemaType) + .withDescription(functionDescription) + .withInputType(functionInputClass) + .build(); + } else if (bean instanceof BiFunction) { return FunctionCallbackWrapper.builder((BiFunction) bean) .withName(beanName) diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java index fe9fa0a1533..5342aa9f6d5 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java @@ -17,7 +17,9 @@ package org.springframework.ai.model.function; import java.util.function.BiFunction; +import java.util.function.Consumer; import java.util.function.Function; +import java.util.function.Supplier; import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.ObjectMapper; @@ -59,6 +61,19 @@ public static Builder builder(Function function) { return new Builder<>(function); } + public static Builder builder(Supplier supplier) { + Function function = (input) -> supplier.get(); + return new Builder<>(function); + } + + public static Builder builder(Consumer consumer) { + Function function = (input) -> { + consumer.accept(input); + return null; + }; + return new Builder<>(function); + } + @Override public O apply(I input, ToolContext context) { return this.biFunction.apply(input, context); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java index feaafd80130..39d1323479c 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java @@ -20,7 +20,9 @@ import java.lang.reflect.Modifier; import java.util.Arrays; import java.util.function.BiFunction; +import java.util.function.Consumer; import java.util.function.Function; +import java.util.function.Supplier; import kotlin.jvm.functions.Function1; import kotlin.jvm.functions.Function2; @@ -189,6 +191,12 @@ public static ResolvableType getFunctionArgumentType(ResolvableType functionType else if (BiFunction.class.isAssignableFrom(resolvableClass)) { functionArgumentResolvableType = functionType.as(BiFunction.class); } + else if (Supplier.class.isAssignableFrom(resolvableClass)) { + return ResolvableType.forClass(Void.class); + } + else if (Consumer.class.isAssignableFrom(resolvableClass)) { + functionArgumentResolvableType = functionType.as(Consumer.class); + } else if (KotlinDetector.isKotlinPresent()) { if (KotlinDelegate.isKotlinFunction(resolvableClass)) { functionArgumentResolvableType = KotlinDelegate.adaptToKotlinFunctionType(functionType); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java index 6408edc1b9a..906028807d5 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java @@ -18,10 +18,14 @@ import java.util.List; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import java.util.function.BiFunction; +import java.util.function.Consumer; import java.util.function.Function; +import java.util.function.Supplier; import java.util.stream.Collectors; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; @@ -52,179 +56,252 @@ @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*") class FunctionCallbackWithPlainFunctionBeanIT { - private final Logger logger = LoggerFactory.getLogger(FunctionCallbackWithPlainFunctionBeanIT.class); + private final static Logger logger = LoggerFactory.getLogger(FunctionCallbackWithPlainFunctionBeanIT.class); private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withPropertyValues("spring.ai.openai.apiKey=" + System.getenv("OPENAI_API_KEY")) + .withPropertyValues("spring.ai.openai.apiKey=" + System.getenv("OPENAI_API_KEY"), + "spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName()) .withConfiguration(AutoConfigurations.of(OpenAiAutoConfiguration.class)) .withUserConfiguration(Config.class); + private static Map feedback = new ConcurrentHashMap<>(); + + @BeforeEach + void setUp() { + feedback.clear(); + } + @Test void functionCallWithDirectBiFunction() { - this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName()) - .run(context -> { + this.contextRunner.run(context -> { - OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); + OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); - ChatClient chatClient = ChatClient.builder(chatModel).build(); + ChatClient chatClient = ChatClient.builder(chatModel).build(); - String content = chatClient.prompt("What's the weather like in San Francisco, Tokyo, and Paris?") - .functions("weatherFunctionWithContext") - .toolContext(Map.of("sessionId", "123")) - .call() - .content(); - logger.info(content); + String content = chatClient.prompt("What's the weather like in San Francisco, Tokyo, and Paris?") + .functions("weatherFunctionWithContext") + .toolContext(Map.of("sessionId", "123")) + .call() + .content(); + logger.info(content); - // Test weatherFunction - UserMessage userMessage = new UserMessage( - "What's the weather like in San Francisco, Tokyo, and Paris? You can call the following functions 'weatherFunction'"); + // Test weatherFunction + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo, and Paris? You can call the following functions 'weatherFunction'"); - ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), - OpenAiChatOptions.builder() - .withFunction("weatherFunctionWithContext") - .withToolContext(Map.of("sessionId", "123")) - .build())); + ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), + OpenAiChatOptions.builder() + .withFunction("weatherFunctionWithContext") + .withToolContext(Map.of("sessionId", "123")) + .build())); - logger.info("Response: {}", response); + logger.info("Response: {}", response); - assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); + assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); - }); + }); } @Test void functionCallWithBiFunctionClass() { - this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName()) - .run(context -> { + this.contextRunner.run(context -> { - OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); + OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); - ChatClient chatClient = ChatClient.builder(chatModel).build(); + ChatClient chatClient = ChatClient.builder(chatModel).build(); - String content = chatClient.prompt("What's the weather like in San Francisco, Tokyo, and Paris?") - .functions("weatherFunctionWithClassBiFunction") - .toolContext(Map.of("sessionId", "123")) - .call() - .content(); - logger.info(content); + String content = chatClient.prompt("What's the weather like in San Francisco, Tokyo, and Paris?") + .functions("weatherFunctionWithClassBiFunction") + .toolContext(Map.of("sessionId", "123")) + .call() + .content(); + logger.info(content); - // Test weatherFunction - UserMessage userMessage = new UserMessage( - "What's the weather like in San Francisco, Tokyo, and Paris? You can call the following functions 'weatherFunction'"); + // Test weatherFunction + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo, and Paris? You can call the following functions 'weatherFunction'"); - ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), - OpenAiChatOptions.builder() - .withFunction("weatherFunctionWithClassBiFunction") - .withToolContext(Map.of("sessionId", "123")) - .build())); + ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), + OpenAiChatOptions.builder() + .withFunction("weatherFunctionWithClassBiFunction") + .withToolContext(Map.of("sessionId", "123")) + .build())); - logger.info("Response: {}", response); + logger.info("Response: {}", response); - assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); + assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); - }); + }); } @Test void functionCallTest() { - this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName()) - .run(context -> { + this.contextRunner.run(context -> { - OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); + OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); - // Test weatherFunction - UserMessage userMessage = new UserMessage( - "What's the weather like in San Francisco, Tokyo, and Paris? You can call the following functions 'weatherFunction'"); + // Test weatherFunction + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo, and Paris? You can call the following functions 'weatherFunction'"); - ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), - OpenAiChatOptions.builder().withFunction("weatherFunction").build())); + ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), + OpenAiChatOptions.builder().withFunction("weatherFunction").build())); - logger.info("Response: {}", response); + logger.info("Response: {}", response); - assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); + assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); - // Test weatherFunctionTwo - response = chatModel.call(new Prompt(List.of(userMessage), - OpenAiChatOptions.builder().withFunction("weatherFunctionTwo").build())); + // Test weatherFunctionTwo + response = chatModel.call(new Prompt(List.of(userMessage), + OpenAiChatOptions.builder().withFunction("weatherFunctionTwo").build())); - logger.info("Response: {}", response); + logger.info("Response: {}", response); - assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); + assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); - }); + }); } @Test void functionCallWithPortableFunctionCallingOptions() { - this.contextRunner - .withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName(), - "spring.ai.openai.chat.options.temperature=0.1") - .run(context -> { + this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.temperature=0.1").run(context -> { - OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); + OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); - // Test weatherFunction - UserMessage userMessage = new UserMessage( - "What's the weather like in San Francisco, Tokyo, and Paris?"); + // Test weatherFunction + UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); - PortableFunctionCallingOptions functionOptions = FunctionCallingOptions.builder() - .withFunction("weatherFunction") - .build(); + PortableFunctionCallingOptions functionOptions = FunctionCallingOptions.builder() + .withFunction("weatherFunction") + .build(); - ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions)); + ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions)); - logger.info("Response: {}", response.getResult().getOutput().getContent()); + logger.info("Response: {}", response.getResult().getOutput().getContent()); - assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); - }); + assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); + }); } @Test void streamFunctionCallTest() { - this.contextRunner - .withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName(), - "spring.ai.openai.chat.options.temperature=0.1") - .run(context -> { - - OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); - - // Test weatherFunction - UserMessage userMessage = new UserMessage( - "What's the weather like in San Francisco, Tokyo, and Paris? You can call the following functions 'weatherFunction'"); - - Flux response = chatModel.stream(new Prompt(List.of(userMessage), - OpenAiChatOptions.builder().withFunction("weatherFunction").build())); - - String content = response.collectList() - .block() - .stream() - .map(ChatResponse::getResults) - .flatMap(List::stream) - .map(Generation::getOutput) - .map(AssistantMessage::getContent) - .collect(Collectors.joining()); - logger.info("Response: {}", content); - - assertThat(content).contains("30", "10", "15"); - - // Test weatherFunctionTwo - response = chatModel.stream(new Prompt(List.of(userMessage), - OpenAiChatOptions.builder().withFunction("weatherFunctionTwo").build())); - - content = response.collectList() - .block() - .stream() - .map(ChatResponse::getResults) - .flatMap(List::stream) - .map(Generation::getOutput) - .map(AssistantMessage::getContent) - .collect(Collectors.joining()); - logger.info("Response: {}", content); - - assertThat(content).isNotEmpty().withFailMessage("Content returned from OpenAI model is empty"); - assertThat(content).contains("30", "10", "15"); - - }); + this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.temperature=0.1").run(context -> { + + OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); + + // Test weatherFunction + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo, and Paris? You can call the following functions 'weatherFunction'"); + + Flux response = chatModel.stream(new Prompt(List.of(userMessage), + OpenAiChatOptions.builder().withFunction("weatherFunction").build())); + + String content = response.collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .collect(Collectors.joining()); + logger.info("Response: {}", content); + + assertThat(content).contains("30", "10", "15"); + + // Test weatherFunctionTwo + response = chatModel.stream(new Prompt(List.of(userMessage), + OpenAiChatOptions.builder().withFunction("weatherFunctionTwo").build())); + + content = response.collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .collect(Collectors.joining()); + logger.info("Response: {}", content); + + assertThat(content).isNotEmpty().withFailMessage("Content returned from OpenAI model is empty"); + assertThat(content).contains("30", "10", "15"); + + }); + } + + @Test + void functionCallingVoidResponse() { + this.contextRunner.run(context -> { + + OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); + + // Test weatherFunction + UserMessage userMessage = new UserMessage("Turn the light on in the kitchen and in the living room"); + + ChatResponse response = chatModel + .call(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withFunction("turnLight").build())); + + logger.info("Response: {}", response); + assertThat(feedback).hasSize(2); + assertThat(feedback.get("kitchen")).isEqualTo(Boolean.valueOf(true)); + assertThat(feedback.get("living room")).isEqualTo(Boolean.valueOf(true)); + }); + } + + @Test + void functionCallingConsumer() { + this.contextRunner.run(context -> { + + OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); + + // Test weatherFunction + UserMessage userMessage = new UserMessage("Turn the light on in the kitchen and in the living room"); + + ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), + OpenAiChatOptions.builder().withFunction("turnLightConsumer").build())); + + logger.info("Response: {}", response); + assertThat(feedback).hasSize(2); + assertThat(feedback.get("kitchen")).isEqualTo(Boolean.valueOf(true)); + assertThat(feedback.get("living room")).isEqualTo(Boolean.valueOf(true)); + + }); + } + + @Test + void functionCallingVoidArguments() { + this.contextRunner.run(context -> { + + OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); + + // Test weatherFunction + UserMessage userMessage = new UserMessage("Turn the light on in the living room"); + + ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), + OpenAiChatOptions.builder().withFunction("turnLivingRoomLightOn").build())); + + logger.info("Response: {}", response); + assertThat(feedback).hasSize(1); + assertThat(feedback.get("turnLivingRoomLightOn")).isEqualTo(Boolean.valueOf(true)); + }); + } + + @Test + void functionCallingSupplier() { + this.contextRunner.run(context -> { + + OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); + + // Test weatherFunction + UserMessage userMessage = new UserMessage("Turn the light on in the living room"); + + ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), + OpenAiChatOptions.builder().withFunction("turnLivingRoomLightOnSupplier").build())); + + logger.info("Response: {}", response); + assertThat(feedback).hasSize(1); + assertThat(feedback.get("turnLivingRoomLightOnSupplier")).isEqualTo(Boolean.valueOf(true)); + }); } @Configuration @@ -256,6 +333,48 @@ public Function weather return (weatherService::apply); } + record LightInfo(String roomName, boolean isOn) { + } + + @Bean + @Description("Turn light on or off in a room") + public Function turnLight() { + return (LightInfo lightInfo) -> { + logger.info("Turning light to [" + lightInfo.isOn + "] in " + lightInfo.roomName()); + feedback.put(lightInfo.roomName(), lightInfo.isOn()); + return null; + }; + } + + @Bean + @Description("Turn light on or off in a room") + public Consumer turnLightConsumer() { + return (LightInfo lightInfo) -> { + logger.info("Turning light to [" + lightInfo.isOn + "] in " + lightInfo.roomName()); + feedback.put(lightInfo.roomName(), lightInfo.isOn()); + }; + } + + @Bean + @Description("Turns light on in the living room") + public Function turnLivingRoomLightOn() { + return (Void v) -> { + logger.info("Turning light on in the living room"); + feedback.put("turnLivingRoomLightOn", Boolean.TRUE); + return "Done"; + }; + } + + @Bean + @Description("Turns light on in the living room") + public Supplier turnLivingRoomLightOnSupplier() { + return () -> { + logger.info("Turning light on in the living room"); + feedback.put("turnLivingRoomLightOnSupplier", Boolean.TRUE); + return "Done"; + }; + } + } public static class MyBiFunction From 021bc8ccd33d98de423dac15f1030f4134a6a4c6 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Tue, 12 Nov 2024 10:43:21 +0100 Subject: [PATCH 2/4] add ChatClient API support for consumer and supplier funcitons --- .../ai/openai/OpenAiTestConfiguration.java | 1 - .../chat/client/OpenAiChatClientIT.java | 46 ++++++++++++ .../ai/chat/client/ChatClient.java | 4 ++ .../ai/chat/client/DefaultChatClient.java | 32 +++++++++ .../function/FunctionCallbackWrapper.java | 48 ++++++++++--- .../ai/model/function/TypeResolverHelper.java | 10 +++ .../chat/client/DefaultChatClientTests.java | 19 +++++ .../tool/FunctionCallbackInPrompt2IT.java | 71 +++++++++++++------ 8 files changed, 201 insertions(+), 30 deletions(-) diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java index d9e6b6ca513..38988191179 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java @@ -79,7 +79,6 @@ public OpenAiAudioSpeechModel openAiAudioSpeechModel(OpenAiAudioApi api) { @Bean public OpenAiImageModel openAiImageModel(OpenAiImageApi imageApi) { OpenAiImageModel openAiImageModel = new OpenAiImageModel(imageApi); - // openAiImageModel.setModel("foobar"); return openAiImageModel; } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java index 3c8a724a063..c081052421f 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java @@ -21,6 +21,8 @@ import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Consumer; import java.util.stream.Collectors; import org.junit.jupiter.api.Disabled; @@ -257,6 +259,50 @@ void functionCallTest() { assertThat(response).contains("30", "10", "15"); } + @Test + void functionCallSupplier() { + + Map state = new ConcurrentHashMap<>(); + + // @formatter:off + String response = ChatClient.create(this.chatModel).prompt() + .user("Turn the light on in the living room") + .function("turnLivingRoomLightOnSupplier", "Turns light on in the living room", () -> state.put("foo", "bar")) + .call() + .content(); + // @formatter:on + + logger.info("Response: {}", response); + assertThat(state).containsEntry("foo", "bar"); + } + + @Test + void functionCallConsumer() { + + Map state = new ConcurrentHashMap<>(); + + record LightInfo(String roomName, boolean isOn) { + } + + // @formatter:off + String response = ChatClient.create(this.chatModel).prompt() + .user("Turn the light on in the kitchen and in the living room") + .function("turnLight", "Turn light on or off in a room", new Consumer() { + @Override + public void accept(LightInfo lightInfo) { + logger.info("Turning light to [" + lightInfo.isOn + "] in " + lightInfo.roomName()); + state.put(lightInfo.roomName(), lightInfo.isOn()); + } + }) + .call() + .content(); + // @formatter:on + + logger.info("Response: {}", response); + assertThat(state).containsEntry("kitchen", Boolean.TRUE); + assertThat(state).containsEntry("living room", Boolean.TRUE); + } + @Test void defaultFunctionCallTest() { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java index d53e998967a..424819a6067 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java @@ -223,6 +223,10 @@ ChatClientRequestSpec function(String name, String description, ChatClientRequestSpec function(String name, String description, Class inputType, java.util.function.Function function); + ChatClientRequestSpec function(String name, String description, java.util.function.Supplier supplier); + + ChatClientRequestSpec function(String name, String description, java.util.function.Consumer consumer); + ChatClientRequestSpec functions(String... functionBeanNames); ChatClientRequestSpec toolContext(Map toolContext); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index 6f46f2749b0..852b2c58570 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -874,6 +874,38 @@ public ChatClientRequestSpec function(String name, String description, @N return this; } + public ChatClientRequestSpec function(String name, String description, + java.util.function.Supplier supplier) { + + Assert.hasText(name, "name cannot be null or empty"); + Assert.hasText(description, "description cannot be null or empty"); + Assert.notNull(supplier, "supplier cannot be null"); + + var fcw = FunctionCallbackWrapper.builder(supplier) + .withDescription(description) + .withName(name) + .withInputType(Void.class) + .build(); + this.functionCallbacks.add(fcw); + return this; + } + + public ChatClientRequestSpec function(String name, String description, + java.util.function.Consumer consumer) { + + Assert.hasText(name, "name cannot be null or empty"); + Assert.hasText(description, "description cannot be null or empty"); + Assert.notNull(consumer, "consumer cannot be null"); + + var fcw = FunctionCallbackWrapper.builder(consumer) + .withDescription(description) + .withName(name) + // .withResponseConverter(Object::toString) + .build(); + this.functionCallbacks.add(fcw); + return this; + } + public ChatClientRequestSpec functions(String... functionBeanNames) { Assert.notNull(functionBeanNames, "functionBeanNames cannot be null"); Assert.noNullElements(functionBeanNames, "functionBeanNames cannot contain null elements"); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java index 5342aa9f6d5..fd804b9220f 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java @@ -66,12 +66,8 @@ public static Builder builder(Supplier supplier) { return new Builder<>(function); } - public static Builder builder(Consumer consumer) { - Function function = (input) -> { - consumer.accept(input); - return null; - }; - return new Builder<>(function); + public static Builder builder(Consumer consumer) { + return new Builder<>(consumer); } @Override @@ -85,6 +81,8 @@ public static class Builder { private final Function function; + private final Consumer consumer; + private String name; private String description; @@ -104,12 +102,21 @@ public Builder(BiFunction biFunction) { Assert.notNull(biFunction, "Function must not be null"); this.biFunction = biFunction; this.function = null; + this.consumer = null; } public Builder(Function function) { Assert.notNull(function, "Function must not be null"); this.biFunction = null; this.function = function; + this.consumer = null; + } + + public Builder(Consumer consumer) { + Assert.notNull(consumer, "Consumer must not be null"); + this.biFunction = null; + this.function = null; + this.consumer = consumer; } @SuppressWarnings("unchecked") @@ -123,6 +130,11 @@ private static Class resolveInputType(Function function) { return (Class) TypeResolverHelper.getFunctionInputClass((Class>) function.getClass()); } + @SuppressWarnings("unchecked") + private static Class resolveInputType(Consumer consumer) { + return (Class) TypeResolverHelper.getConsumerInputClass((Class>) consumer.getClass()); + } + public Builder withName(String name) { Assert.hasText(name, "Name must not be empty"); this.name = name; @@ -183,9 +195,12 @@ public FunctionCallbackWrapper build() { if (this.function != null) { this.inputType = resolveInputType(this.function); } - else { + else if (this.biFunction != null) { this.inputType = resolveInputType(this.biFunction); } + else { + this.inputType = resolveInputType(this.consumer); + } } if (this.inputTypeSchema == null) { @@ -193,8 +208,23 @@ public FunctionCallbackWrapper build() { this.inputTypeSchema = ModelOptionsUtils.getJsonSchema(this.inputType, upperCaseTypeValues); } - BiFunction finalBiFunction = (this.biFunction != null) ? this.biFunction - : (request, context) -> this.function.apply(request); + BiFunction finalBiFunction = null; + if (this.biFunction != null) { + finalBiFunction = this.biFunction; + } + else if (this.function != null) { + finalBiFunction = (request, context) -> this.function.apply(request); + } + else { + finalBiFunction = (request, context) -> { + this.consumer.accept(request); + return null; + }; + } + + // BiFunction finalBiFunction = (this.biFunction != null) ? + // this.biFunction + // : (request, context) -> this.function.apply(request); return new FunctionCallbackWrapper<>(this.name, this.description, this.inputTypeSchema, this.inputType, this.responseConverter, this.objectMapper, finalBiFunction); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java index 39d1323479c..14c14f2c9a1 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java @@ -64,6 +64,16 @@ public static Class getFunctionInputClass(Class> fun return getFunctionArgumentClass(functionClass, 0); } + /** + * Returns the input class of a given Consumer class. + * @param consumerClass The consumer class. + * @return The input class of the consumer. + */ + public static Class getConsumerInputClass(Class> consumerClass) { + ResolvableType resolvableType = ResolvableType.forClass(consumerClass).as(Consumer.class); + return (resolvableType == ResolvableType.NONE ? Object.class : resolvableType.getGeneric(0).toClass()); + } + /** * Returns the output class of a given function class. * @param functionClass The function class. diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java index 269fa7812dd..3d803ee4159 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java @@ -1467,6 +1467,25 @@ void whenBiFunctionThenReturn() { assertThat(defaultSpec.getFunctionCallbacks()).anyMatch(callback -> callback.getName().equals("name")); } + @Test + void whenSupplierFunctionThenReturn() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + spec = spec.function("name", "description", () -> "hello"); + DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; + assertThat(defaultSpec.getFunctionCallbacks()).anyMatch(callback -> callback.getName().equals("name")); + } + + @Test + void whenConsumerFunctionThenReturn() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + Consumer consumer = input -> System.out.println(input); + spec = spec.function("name", "description", consumer); + DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; + assertThat(defaultSpec.getFunctionCallbacks()).anyMatch(callback -> callback.getName().equals("name")); + } + @Test void whenFunctionBeanNamesElementIsNullThenThrow() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPrompt2IT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPrompt2IT.java index b6e2c3024e0..fa687338cc4 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPrompt2IT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPrompt2IT.java @@ -16,6 +16,9 @@ package org.springframework.ai.autoconfigure.openai.tool; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Consumer; import java.util.function.Function; import java.util.stream.Collectors; @@ -39,17 +42,17 @@ public class FunctionCallbackInPrompt2IT { private final Logger logger = LoggerFactory.getLogger(FunctionCallbackInPromptIT.class); private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withPropertyValues("spring.ai.openai.apiKey=" + System.getenv("OPENAI_API_KEY")) + .withPropertyValues("spring.ai.openai.apiKey=" + System.getenv("OPENAI_API_KEY"), + "spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName()) .withConfiguration(AutoConfigurations.of(OpenAiAutoConfiguration.class)); @Test void functionCallTest() { - this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName()) - .run(context -> { + this.contextRunner.run(context -> { - OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); + OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); - ChatClient chatClient = ChatClient.builder(chatModel).build(); + ChatClient chatClient = ChatClient.builder(chatModel).build(); // @formatter:off chatClient.prompt() @@ -62,18 +65,17 @@ void functionCallTest() { .call().content(); // @formatter:on - logger.info("Response: {}", content); + logger.info("Response: {}", content); - assertThat(content).contains("30", "10", "15"); - }); + assertThat(content).contains("30", "10", "15"); + }); } @Test void functionCallTest2() { - this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName()) - .run(context -> { + this.contextRunner.run(context -> { - OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); + OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); // @formatter:off String content = ChatClient.builder(chatModel).build().prompt() @@ -87,19 +89,48 @@ public String apply(MockWeatherService.Request request) { }) .call().content(); // @formatter:on - logger.info("Response: {}", content); + logger.info("Response: {}", content); - assertThat(content).contains("18"); - }); + assertThat(content).contains("18"); + }); + } + + @Test + void functionCallTest21() { + Map state = new ConcurrentHashMap<>(); + + record LightInfo(String roomName, boolean isOn) { + } + + this.contextRunner.run(context -> { + + OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); + + // @formatter:off + String content = ChatClient.builder(chatModel).build().prompt() + .user("Turn the light on in the kitchen and in the living room!") + .function("turnLight", "Turn light on or off in a room", + new Consumer() { + @Override + public void accept(LightInfo lightInfo) { + logger.info("Turning light to [" + lightInfo.isOn + "] in " + lightInfo.roomName()); + state.put(lightInfo.roomName(), lightInfo.isOn()); + } + }) + .call().content(); + // @formatter:on + logger.info("Response: {}", content); + assertThat(state).containsEntry("kitchen", Boolean.TRUE); + assertThat(state).containsEntry("living room", Boolean.TRUE); + }); } @Test void streamingFunctionCallTest() { - this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName()) - .run(context -> { + this.contextRunner.run(context -> { - OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); + OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); // @formatter:off String content = ChatClient.builder(chatModel).build().prompt() @@ -109,10 +140,10 @@ void streamingFunctionCallTest() { .collectList().block().stream().collect(Collectors.joining()); // @formatter:on - logger.info("Response: {}", content); + logger.info("Response: {}", content); - assertThat(content).contains("30", "10", "15"); - }); + assertThat(content).contains("30", "10", "15"); + }); } } From 05c86d414e7267f2749b50befba2d6609a9f9a69 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Tue, 12 Nov 2024 14:56:00 +0100 Subject: [PATCH 3/4] feat: adds the ability to specify the input type class when defining a function This change adds the ability to specify the input type class when defining a function callback for the ChatClient. Previously, the input type had to be inferred, which caused issues with lambda expressions due to type erasure. The following new methods have been added: - ChatClient.ChatClientRequestSpec.function(String, String, Class, Function) - ChatClient.ChatClientRequestSpec.function(String, String, Class, BiFunction) - ChatClient.ChatClientRequestSpec.function(String, String, Class, Consumer) - ChatClient.Builder.defaultFunction(String, String, Class, Function) - ChatClient.Builder.defaultFunction(String, String, Class, BiFunction) - ChatClient.Builder.defaultFunction(String, String, Class, Consumer) The deprecated methods without the input type parameter have also been kept for backwards compatibility. This change should make it easier to use the ChatClient API, especially when dealing with lambda expressions. --- .../client/AnthropicChatClientIT.java | 6 +-- .../converse/BedrockConverseChatClientIT.java | 10 ++-- .../ai/mistralai/MistralAiChatClientIT.java | 6 +-- .../chat/client/OpenAiChatClientIT.java | 13 +++-- ...enAiChatClientMultipleFunctionCallsIT.java | 12 +++-- .../ai/chat/client/ChatClient.java | 51 +++++++++++++++++-- .../ai/chat/client/DefaultChatClient.java | 41 ++++++++++++--- .../chat/client/DefaultChatClientBuilder.java | 25 +++++++++ .../ai/chat/client/ChatClientTest.java | 8 +-- .../chat/client/DefaultChatClientTests.java | 26 +++++----- .../modules/ROOT/pages/api/functions.adoc | 2 + .../tool/FunctionCallbackInPrompt2IT.java | 21 ++++---- 12 files changed, 159 insertions(+), 62 deletions(-) diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientIT.java index b4038d6acb0..4ebb2c5bf73 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientIT.java @@ -211,7 +211,7 @@ void functionCallTest() { // @formatter:off String response = ChatClient.create(this.chatModel).prompt() .user(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")) - .function("getCurrentWeather", "Get the weather in location", new MockWeatherService()) + .function("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService()) .call() .content(); // @formatter:on @@ -226,7 +226,7 @@ void defaultFunctionCallTest() { // @formatter:off String response = ChatClient.builder(this.chatModel) - .defaultFunction("getCurrentWeather", "Get the weather in location", new MockWeatherService()) + .defaultFunction("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService()) .defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")) .build() .prompt() @@ -245,7 +245,7 @@ void streamFunctionCallTest() { // @formatter:off Flux response = ChatClient.create(this.chatModel).prompt() .user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.") - .function("getCurrentWeather", "Get the weather in location", new MockWeatherService()) + .function("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService()) .stream() .content(); // @formatter:on diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java index 0bbc99be341..aed619c2dd6 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java @@ -212,7 +212,7 @@ void functionCallTest() { // @formatter:off String response = ChatClient.create(this.chatModel) .prompt("What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.") - .function("getCurrentWeather", "Get the weather in location", new MockWeatherService()) + .function("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService()) .call() .content(); // @formatter:on @@ -228,7 +228,7 @@ void functionCallWithAdvisorTest() { // @formatter:off String response = ChatClient.create(this.chatModel) .prompt("What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.") - .function("getCurrentWeather", "Get the weather in location", new MockWeatherService()) + .function("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService()) .advisors(new SimpleLoggerAdvisor()) .call() .content(); @@ -244,7 +244,7 @@ void defaultFunctionCallTest() { // @formatter:off String response = ChatClient.builder(this.chatModel) - .defaultFunction("getCurrentWeather", "Get the weather in location", new MockWeatherService()) + .defaultFunction("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService()) .defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.")) .build() .prompt() @@ -263,7 +263,7 @@ void streamFunctionCallTest() { // @formatter:off Flux response = ChatClient.create(this.chatModel).prompt() .user("What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.") - .function("getCurrentWeather", "Get the weather in location", new MockWeatherService()) + .function("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService()) .stream() .content(); // @formatter:on @@ -280,7 +280,7 @@ void singularStreamFunctionCallTest() { // @formatter:off Flux response = ChatClient.create(this.chatModel).prompt() .user("What's the weather like in Paris? Return the temperature in Celsius.") - .function("getCurrentWeather", "Get the weather in location", new MockWeatherService()) + .function("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService()) .stream() .content(); // @formatter:on diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatClientIT.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatClientIT.java index 6298c777393..fe1f109a82e 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatClientIT.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatClientIT.java @@ -224,7 +224,7 @@ void functionCallTest() { String response = ChatClient.create(this.chatModel).prompt() .options(MistralAiChatOptions.builder().withModel(MistralAiApi.ChatModel.SMALL).withToolChoice(ToolChoice.AUTO).build()) .user(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Use parallel function calling if required. Response should be in Celsius.")) - .function("getCurrentWeather", "Get the weather in location", new MockWeatherService()) + .function("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService()) .call() .content(); // @formatter:on @@ -242,7 +242,7 @@ void defaultFunctionCallTest() { // @formatter:off String response = ChatClient.builder(this.chatModel) .defaultOptions(MistralAiChatOptions.builder().withModel(MistralAiApi.ChatModel.SMALL).build()) - .defaultFunction("getCurrentWeather", "Get the weather in location", new MockWeatherService()) + .defaultFunction("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService()) .defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Use parallel function calling if required. Response should be in Celsius.")) .build() .prompt().call().content(); @@ -262,7 +262,7 @@ void streamFunctionCallTest() { Flux response = ChatClient.create(this.chatModel).prompt() .options(MistralAiChatOptions.builder().withModel(MistralAiApi.ChatModel.SMALL).build()) .user("What's the weather like in San Francisco, Tokyo, and Paris? Use parallel function calling if required. Response should be in Celsius.") - .function("getCurrentWeather", "Get the weather in location", new MockWeatherService()) + .function("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService()) .stream() .content(); // @formatter:on diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java index c081052421f..c65a2d5f2c2 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java @@ -249,7 +249,7 @@ void functionCallTest() { // @formatter:off String response = ChatClient.create(this.chatModel).prompt() .user(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?")) - .function("getCurrentWeather", "Get the weather in location", new MockWeatherService()) + .function("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService()) .call() .content(); // @formatter:on @@ -287,12 +287,9 @@ record LightInfo(String roomName, boolean isOn) { // @formatter:off String response = ChatClient.create(this.chatModel).prompt() .user("Turn the light on in the kitchen and in the living room") - .function("turnLight", "Turn light on or off in a room", new Consumer() { - @Override - public void accept(LightInfo lightInfo) { + .function("turnLight", "Turn light on or off in a room", LightInfo.class, (LightInfo lightInfo) -> { logger.info("Turning light to [" + lightInfo.isOn + "] in " + lightInfo.roomName()); state.put(lightInfo.roomName(), lightInfo.isOn()); - } }) .call() .content(); @@ -308,7 +305,8 @@ void defaultFunctionCallTest() { // @formatter:off String response = ChatClient.builder(this.chatModel) - .defaultFunction("getCurrentWeather", "Get the weather in location", new MockWeatherService()) + .defaultFunction("getCurrentWeather", "Get the weather in location", + MockWeatherService.Request.class, new MockWeatherService()) .defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?")) .build() .prompt().call().content(); @@ -325,7 +323,8 @@ void streamFunctionCallTest() { // @formatter:off Flux response = ChatClient.create(this.chatModel).prompt() .user("What's the weather like in San Francisco, Tokyo, and Paris?") - .function("getCurrentWeather", "Get the weather in location", new MockWeatherService()) + .function("getCurrentWeather", "Get the weather in location", + MockWeatherService.Request.class, new MockWeatherService()) .stream() .content(); // @formatter:on diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMultipleFunctionCallsIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMultipleFunctionCallsIT.java index 6ba1682e1d5..490163ebe13 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMultipleFunctionCallsIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMultipleFunctionCallsIT.java @@ -83,7 +83,8 @@ void turnFunctionsOnAndOffTest() { // @formatter:off response = chatClientBuilder.build().prompt() .user(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?")) - .function("getCurrentWeather", "Get the weather in location", new MockWeatherService()) + .function("getCurrentWeather", "Get the weather in location", + MockWeatherService.Request.class, new MockWeatherService()) .call() .content(); // @formatter:on @@ -110,7 +111,8 @@ void defaultFunctionCallTest() { // @formatter:off String response = ChatClient.builder(this.chatModel) - .defaultFunction("getCurrentWeather", "Get the weather in location", new MockWeatherService()) + .defaultFunction("getCurrentWeather", "Get the weather in location", + MockWeatherService.Request.class, new MockWeatherService()) .defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?")) .build() .prompt().call().content(); @@ -149,7 +151,7 @@ else if (request.location().contains("San Francisco")) { // @formatter:off String response = ChatClient.builder(this.chatModel) - .defaultFunction("getCurrentWeather", "Get the weather in location", biFunction) + .defaultFunction("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, biFunction) .defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?")) .defaultToolContext(Map.of("sessionId", "123")) .build() @@ -189,7 +191,7 @@ else if (request.location().contains("San Francisco")) { // @formatter:off String response = ChatClient.builder(this.chatModel) - .defaultFunction("getCurrentWeather", "Get the weather in location", biFunction) + .defaultFunction("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, biFunction) .defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?")) .build() .prompt() @@ -208,7 +210,7 @@ void streamFunctionCallTest() { // @formatter:off Flux response = ChatClient.create(this.chatModel).prompt() .user("What's the weather like in San Francisco, Tokyo, and Paris?") - .function("getCurrentWeather", "Get the weather in location", new MockWeatherService()) + .function("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService()) .stream() .content(); // @formatter:on diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java index 424819a6067..bffcacd1e28 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java @@ -212,20 +212,38 @@ interface ChatClientRequestSpec { ChatClientRequestSpec options(T options); + /** + * @deprecated Use + * {@link #function(String, String, Class, java.util.function.Function)} instead. + * Because of JVM type erasure, for lambda to work, the inputType class is + * required to be provided explicitly. + */ + @Deprecated ChatClientRequestSpec function(String name, String description, java.util.function.Function function); + ChatClientRequestSpec function(String name, String description, Class inputType, + java.util.function.Function function); + + /** + * @deprecated Use + * {@link #functions(String, String, Class,java.util.function.BiFunction)} Because + * of JVM type erasure, for lambda to work, the inputType class is required to be + * provided explicitly. + */ + @Deprecated ChatClientRequestSpec function(String name, String description, java.util.function.BiFunction function); - ChatClientRequestSpec functions(FunctionCallback... functionCallbacks); - ChatClientRequestSpec function(String name, String description, Class inputType, - java.util.function.Function function); + java.util.function.BiFunction function); + + ChatClientRequestSpec functions(FunctionCallback... functionCallbacks); ChatClientRequestSpec function(String name, String description, java.util.function.Supplier supplier); - ChatClientRequestSpec function(String name, String description, java.util.function.Consumer consumer); + ChatClientRequestSpec function(String name, String description, Class inputType, + java.util.function.Consumer consumer); ChatClientRequestSpec functions(String... functionBeanNames); @@ -282,11 +300,36 @@ interface Builder { Builder defaultSystem(Consumer systemSpecConsumer); + /** + * @deprecated Use + * {@link #defaultFunction(String, String, Class, java.util.function.Function)} + * instead. Because of JVM type erasure, for lambda to work, the inputType class + * is required to be provided explicitly. + */ + @Deprecated Builder defaultFunction(String name, String description, java.util.function.Function function); + Builder defaultFunction(String name, String description, Class inputType, + java.util.function.Function function); + + /** + * @deprecated Use + * {@link #defaultFunction(String, String, Class, java.util.function.BiFunction)} + * instead. Because of JVM type erasure, for lambda to work, the inputType class + * is required to be provided explicitly. + */ + @Deprecated Builder defaultFunction(String name, String description, java.util.function.BiFunction function); + Builder defaultFunction(String name, String description, Class inputType, + java.util.function.BiFunction function); + + Builder defaultFunction(String name, String description, java.util.function.Supplier supplier); + + Builder defaultFunction(String name, String description, Class inputType, + java.util.function.Consumer consumer); + Builder defaultFunctions(String... functionNames); Builder defaultFunctions(FunctionCallback... functionCallbacks); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index 852b2c58570..529b17f3e5c 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -836,11 +836,40 @@ public ChatClientRequestSpec options(T options) { return this; } + /** + * @deprecated since 1.0.0 in favor of + * {@link #function(String, String, Class, java.util.function.Function)} Because + * of JVM type erasure, the inputType class is required to be provided explicitly. + */ + @Deprecated(since = "1.0.0", forRemoval = true) public ChatClientRequestSpec function(String name, String description, java.util.function.Function function) { return this.function(name, description, null, function); } + public ChatClientRequestSpec function(String name, String description, @Nullable Class inputType, + java.util.function.Function function) { + + Assert.hasText(name, "name cannot be null or empty"); + Assert.hasText(description, "description cannot be null or empty"); + Assert.notNull(function, "function cannot be null"); + + var fcw = FunctionCallbackWrapper.builder(function) + .withDescription(description) + .withName(name) + .withInputType(inputType) + .withResponseConverter(Object::toString) + .build(); + this.functionCallbacks.add(fcw); + return this; + } + + /** + * @deprecated since 1.0.0 in favor of + * {@link #function(String, String, Class, java.util.function.BiFunction)} Because + * of JVM type erasure, the inputType class is required to be provided explicitly. + */ + @Deprecated public ChatClientRequestSpec function(String name, String description, java.util.function.BiFunction biFunction) { @@ -857,14 +886,14 @@ public ChatClientRequestSpec function(String name, String description, return this; } - public ChatClientRequestSpec function(String name, String description, @Nullable Class inputType, - java.util.function.Function function) { + public ChatClientRequestSpec function(String name, String description, Class inputType, + java.util.function.BiFunction biFunction) { Assert.hasText(name, "name cannot be null or empty"); Assert.hasText(description, "description cannot be null or empty"); - Assert.notNull(function, "function cannot be null"); + Assert.notNull(biFunction, "biFunction cannot be null"); - var fcw = FunctionCallbackWrapper.builder(function) + FunctionCallbackWrapper fcw = FunctionCallbackWrapper.builder(biFunction) .withDescription(description) .withName(name) .withInputType(inputType) @@ -890,7 +919,7 @@ public ChatClientRequestSpec function(String name, String description, return this; } - public ChatClientRequestSpec function(String name, String description, + public ChatClientRequestSpec function(String name, String description, Class inputType, java.util.function.Consumer consumer) { Assert.hasText(name, "name cannot be null or empty"); @@ -900,7 +929,7 @@ public ChatClientRequestSpec function(String name, String description, var fcw = FunctionCallbackWrapper.builder(consumer) .withDescription(description) .withName(name) - // .withResponseConverter(Object::toString) + .withInputType(inputType) .build(); this.functionCallbacks.add(fcw); return this; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java index 4ae3833d868..e4e8b56c662 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java @@ -143,17 +143,42 @@ public Builder defaultSystem(Consumer systemSpecConsumer) { return this; } + @Deprecated public Builder defaultFunction(String name, String description, java.util.function.Function function) { this.defaultRequest.function(name, description, function); return this; } + public Builder defaultFunction(String name, String description, Class inputType, + java.util.function.Function function) { + this.defaultRequest.function(name, description, inputType, function); + return this; + } + + @Deprecated public Builder defaultFunction(String name, String description, java.util.function.BiFunction biFunction) { this.defaultRequest.function(name, description, biFunction); return this; } + public Builder defaultFunction(String name, String description, Class inputType, + java.util.function.BiFunction biFunction) { + this.defaultRequest.function(name, description, inputType, biFunction); + return this; + } + + public Builder defaultFunction(String name, String description, java.util.function.Supplier supplier) { + this.defaultRequest.function(name, description, supplier); + return this; + } + + public Builder defaultFunction(String name, String description, Class inputType, + java.util.function.Consumer consumer) { + this.defaultRequest.function(name, description, inputType, consumer); + return this; + } + public Builder defaultFunctions(String... functionNames) { this.defaultRequest.functions(functionNames); return this; diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java index 1f9d407c38a..d085e810f16 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java @@ -217,7 +217,7 @@ void mutateDefaults() { .param("param1", "value1") .param("param2", "value2")) .defaultFunctions("fun1", "fun2") - .defaultFunction("fun3", "fun3description", mockFunction) + .defaultFunction("fun3", "fun3description", String.class, mockFunction) .defaultUser(u -> u.text("Default user text {uparam1}, {uparam2}") .param("uparam1", "value1") .param("uparam2", "value2") @@ -344,7 +344,7 @@ void mutatePrompt() { .param("param1", "value1") .param("param2", "value2")) .defaultFunctions("fun1", "fun2") - .defaultFunction("fun3", "fun3description", mockFunction) + .defaultFunction("fun3", "fun3description", String.class, mockFunction) .defaultUser(u -> u.text("Default user text {uparam1}, {uparam2}") .param("uparam1", "value1") .param("uparam2", "value2") @@ -541,9 +541,9 @@ void complexCall() throws MalformedURLException { assertThat(userMessage.getMedia().iterator().next().getData()) .isEqualTo("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png"); - FunctionCallingOptions runtieOptions = (FunctionCallingOptions) this.promptCaptor.getValue().getOptions(); + FunctionCallingOptions runtimeOptions = (FunctionCallingOptions) this.promptCaptor.getValue().getOptions(); - assertThat(runtieOptions.getFunctions()).containsExactly("function1"); + assertThat(runtimeOptions.getFunctions()).containsExactly("function1"); assertThat(options.getFunctions()).isEmpty(); } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java index 3d803ee4159..a2cbfed91ef 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java @@ -1354,7 +1354,7 @@ void whenOptionsThenReturn() { void whenFunctionNameIsNullThenThrow() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); - assertThatThrownBy(() -> spec.function(null, "description", input -> "hello")) + assertThatThrownBy(() -> spec.function(null, "description", String.class, input -> "hello")) .isInstanceOf(IllegalArgumentException.class) .hasMessage("name cannot be null or empty"); } @@ -1363,7 +1363,7 @@ void whenFunctionNameIsNullThenThrow() { void whenFunctionNameIsEmptyThenThrow() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); - assertThatThrownBy(() -> spec.function("", "description", input -> "hello")) + assertThatThrownBy(() -> spec.function("", "description", String.class, input -> "hello")) .isInstanceOf(IllegalArgumentException.class) .hasMessage("name cannot be null or empty"); } @@ -1372,7 +1372,7 @@ void whenFunctionNameIsEmptyThenThrow() { void whenFunctionDescriptionIsNullThenThrow() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); - assertThatThrownBy(() -> spec.function("name", null, input -> "hello")) + assertThatThrownBy(() -> spec.function("name", null, String.class, input -> "hello")) .isInstanceOf(IllegalArgumentException.class) .hasMessage("description cannot be null or empty"); } @@ -1381,7 +1381,7 @@ void whenFunctionDescriptionIsNullThenThrow() { void whenFunctionDescriptionIsEmptyThenThrow() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); - assertThatThrownBy(() -> spec.function("name", "", input -> "hello")) + assertThatThrownBy(() -> spec.function("name", "", String.class, input -> "hello")) .isInstanceOf(IllegalArgumentException.class) .hasMessage("description cannot be null or empty"); } @@ -1390,7 +1390,7 @@ void whenFunctionDescriptionIsEmptyThenThrow() { void whenFunctionLambdaIsNullThenThrow() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); - assertThatThrownBy(() -> spec.function("name", "description", (Function) null)) + assertThatThrownBy(() -> spec.function("name", "description", String.class, (Function) null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("function cannot be null"); } @@ -1399,7 +1399,7 @@ void whenFunctionLambdaIsNullThenThrow() { void whenFunctionThenReturn() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); - spec = spec.function("name", "description", input -> "hello"); + spec = spec.function("name", "description", String.class, input -> "hello"); DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; assertThat(defaultSpec.getFunctionCallbacks()).anyMatch(callback -> callback.getName().equals("name")); } @@ -1417,7 +1417,7 @@ void whenFunctionAndInputTypeThenReturn() { void whenBiFunctionNameIsNullThenThrow() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); - assertThatThrownBy(() -> spec.function(null, "description", (input, ctx) -> "hello")) + assertThatThrownBy(() -> spec.function(null, "description", String.class, (input, ctx) -> "hello")) .isInstanceOf(IllegalArgumentException.class) .hasMessage("name cannot be null or empty"); } @@ -1426,7 +1426,7 @@ void whenBiFunctionNameIsNullThenThrow() { void whenBiFunctionNameIsEmptyThenThrow() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); - assertThatThrownBy(() -> spec.function("", "description", (input, ctx) -> "hello")) + assertThatThrownBy(() -> spec.function("", "description", String.class, (input, ctx) -> "hello")) .isInstanceOf(IllegalArgumentException.class) .hasMessage("name cannot be null or empty"); } @@ -1435,7 +1435,7 @@ void whenBiFunctionNameIsEmptyThenThrow() { void whenBiFunctionDescriptionIsNullThenThrow() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); - assertThatThrownBy(() -> spec.function("name", null, (input, ctx) -> "hello")) + assertThatThrownBy(() -> spec.function("name", null, String.class, (input, ctx) -> "hello")) .isInstanceOf(IllegalArgumentException.class) .hasMessage("description cannot be null or empty"); } @@ -1444,7 +1444,7 @@ void whenBiFunctionDescriptionIsNullThenThrow() { void whenBiFunctionDescriptionIsEmptyThenThrow() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); - assertThatThrownBy(() -> spec.function("name", "", (input, ctx) -> "hello")) + assertThatThrownBy(() -> spec.function("name", "", String.class, (input, ctx) -> "hello")) .isInstanceOf(IllegalArgumentException.class) .hasMessage("description cannot be null or empty"); } @@ -1453,7 +1453,7 @@ void whenBiFunctionDescriptionIsEmptyThenThrow() { void whenBiFunctionLambdaIsNullThenThrow() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); - assertThatThrownBy(() -> spec.function("name", "description", (BiFunction) null)) + assertThatThrownBy(() -> spec.function("name", "description", String.class, (BiFunction) null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("biFunction cannot be null"); } @@ -1462,7 +1462,7 @@ void whenBiFunctionLambdaIsNullThenThrow() { void whenBiFunctionThenReturn() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); - spec = spec.function("name", "description", (input, ctx) -> "hello"); + spec = spec.function("name", "description", String.class, (input, ctx) -> "hello"); DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; assertThat(defaultSpec.getFunctionCallbacks()).anyMatch(callback -> callback.getName().equals("name")); } @@ -1481,7 +1481,7 @@ void whenConsumerFunctionThenReturn() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); Consumer consumer = input -> System.out.println(input); - spec = spec.function("name", "description", consumer); + spec = spec.function("name", "description", String.class, consumer); DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; assertThat(defaultSpec.getFunctionCallbacks()).anyMatch(callback -> callback.getName().equals("name")); } diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/functions.adoc index 07062e9ede2..2b861dfcd37 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/functions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/functions.adoc @@ -277,6 +277,7 @@ ChatResponse response = this.chatClient.prompt("What's the weather like in San F .functions(new FunctionCallbackWrapper<>( "CurrentWeather", // name "Get the weather in location", // function description + MockWeatherService.Request.class, // input type new MockWeatherService())) .call() .chatResponse(); @@ -405,6 +406,7 @@ BiFunction ChatResponse response = chatClient.prompt("What's the weather like in San Francisco, Tokyo, and Paris?") .functions(FunctionCallbackWrapper.builder(this.weatherFunction) + .withInputType(MockWeatherService.Request.class) .withName("getCurrentWeather") .withDescription("Get the weather in location") .build()) diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPrompt2IT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPrompt2IT.java index fa687338cc4..dacea1edc54 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPrompt2IT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPrompt2IT.java @@ -18,7 +18,6 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import java.util.function.Consumer; import java.util.function.Function; import java.util.stream.Collectors; @@ -61,7 +60,7 @@ void functionCallTest() { String content = ChatClient.builder(chatModel).build().prompt() .user("What's the weather like in San Francisco, Tokyo, and Paris?") - .function("CurrentWeatherService", "Get the weather in location", new MockWeatherService()) + .function("CurrentWeatherService", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService()) .call().content(); // @formatter:on @@ -80,7 +79,7 @@ void functionCallTest2() { // @formatter:off String content = ChatClient.builder(chatModel).build().prompt() .user("What's the weather like in Amsterdam?") - .function("CurrentWeatherService", "Get the weather in location", + .function("CurrentWeatherService", "Get the weather in location", MockWeatherService.Request.class, new Function() { @Override public String apply(MockWeatherService.Request request) { @@ -109,14 +108,11 @@ record LightInfo(String roomName, boolean isOn) { // @formatter:off String content = ChatClient.builder(chatModel).build().prompt() .user("Turn the light on in the kitchen and in the living room!") - .function("turnLight", "Turn light on or off in a room", - new Consumer() { - @Override - public void accept(LightInfo lightInfo) { - logger.info("Turning light to [" + lightInfo.isOn + "] in " + lightInfo.roomName()); - state.put(lightInfo.roomName(), lightInfo.isOn()); - } - }) + .function("turnLight", "Turn light on or off in a room",LightInfo.class, + lightInfo -> { + logger.info("Turning light to [" + lightInfo.isOn + "] in " + lightInfo.roomName()); + state.put(lightInfo.roomName(), lightInfo.isOn()); + }) .call().content(); // @formatter:on logger.info("Response: {}", content); @@ -135,7 +131,8 @@ void streamingFunctionCallTest() { // @formatter:off String content = ChatClient.builder(chatModel).build().prompt() .user("What's the weather like in San Francisco, Tokyo, and Paris?") - .function("CurrentWeatherService", "Get the weather in location", new MockWeatherService()) + .function("CurrentWeatherService", "Get the weather in location", + MockWeatherService.Request.class, new MockWeatherService()) .stream().content() .collectList().block().stream().collect(Collectors.joining()); // @formatter:on From c30fda62fea8456c9a3d91c9bed1a4100eee7678 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Tue, 12 Nov 2024 22:56:10 +0100 Subject: [PATCH 4/4] refactor: migrate from Class to ParameterizedTypeReference for type handling Improves generic type support in function callbacks and JSON schema generation by: - Moving CustomizedTypeReference to ModelOptionsUtils - Updating BeanOutputConverter to use ParameterizedTypeReference - Modifying function callbacks to support generic type resolution - Adding train scheduler test case to validate generic type handling --- .../ai/converter/BeanOutputConverter.java | 21 ++-------- .../ai/model/ModelOptionsUtils.java | 31 +++++++++++--- .../function/AbstractFunctionCallback.java | 12 +++--- .../function/FunctionCallbackContext.java | 12 ++++-- .../function/FunctionCallbackWrapper.java | 36 +++++++++------- ...nctionCallbackWithPlainFunctionBeanIT.java | 42 +++++++++++++++++++ 6 files changed, 108 insertions(+), 46 deletions(-) diff --git a/spring-ai-core/src/main/java/org/springframework/ai/converter/BeanOutputConverter.java b/spring-ai-core/src/main/java/org/springframework/ai/converter/BeanOutputConverter.java index 0e718c1b989..4cb1dfa594a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/converter/BeanOutputConverter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/converter/BeanOutputConverter.java @@ -16,7 +16,6 @@ package org.springframework.ai.converter; -import java.lang.reflect.Type; import java.util.Objects; import com.fasterxml.jackson.core.JsonProcessingException; @@ -37,6 +36,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.model.ModelOptionsUtils.CustomizedTypeReference; import org.springframework.ai.util.JacksonUtils; import org.springframework.core.ParameterizedTypeReference; import org.springframework.lang.NonNull; @@ -94,7 +94,7 @@ public BeanOutputConverter(Class clazz, ObjectMapper objectMapper) { * @param typeRef The target class type reference. */ public BeanOutputConverter(ParameterizedTypeReference typeRef) { - this(new CustomizedTypeReference<>(typeRef), null); + this(CustomizedTypeReference.forType(typeRef), null); } /** @@ -105,7 +105,7 @@ public BeanOutputConverter(ParameterizedTypeReference typeRef) { * @param objectMapper Custom object mapper for JSON operations. endings. */ public BeanOutputConverter(ParameterizedTypeReference typeRef, ObjectMapper objectMapper) { - this(new CustomizedTypeReference<>(typeRef), objectMapper); + this(CustomizedTypeReference.forType(typeRef), objectMapper); } /** @@ -220,19 +220,4 @@ public String getJsonSchema() { return this.jsonSchema; } - private static class CustomizedTypeReference extends TypeReference { - - private final Type type; - - CustomizedTypeReference(ParameterizedTypeReference typeRef) { - this.type = typeRef.getType(); - } - - @Override - public Type getType() { - return this.type; - } - - } - } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java index b4edf135657..f6411e31e92 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java @@ -18,6 +18,7 @@ import java.beans.PropertyDescriptor; import java.lang.reflect.Field; +import java.lang.reflect.Type; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -50,8 +51,8 @@ import org.springframework.ai.util.JacksonUtils; import org.springframework.beans.BeanWrapper; import org.springframework.beans.BeanWrapperImpl; +import org.springframework.core.ParameterizedTypeReference; import org.springframework.util.Assert; -import org.springframework.util.ClassUtils; import org.springframework.util.CollectionUtils; import org.springframework.util.ObjectUtils; @@ -335,11 +336,11 @@ private static String toGetName(String name) { /** * Generates JSON Schema (version 2020_12) for the given class. - * @param clazz the class to generate JSON Schema for. + * @param parameterizedType the class to generate JSON Schema for. * @param toUpperCaseTypeValues if true, the type values are converted to upper case. * @return the generated JSON Schema as a String. */ - public static String getJsonSchema(Class clazz, boolean toUpperCaseTypeValues) { + public static String getJsonSchema(ParameterizedTypeReference parameterizedType, boolean toUpperCaseTypeValues) { if (SCHEMA_GENERATOR_CACHE.get() == null) { @@ -358,9 +359,10 @@ public static String getJsonSchema(Class clazz, boolean toUpperCaseTypeValues SCHEMA_GENERATOR_CACHE.compareAndSet(null, generator); } - ObjectNode node = SCHEMA_GENERATOR_CACHE.get().generateSchema(clazz); + ObjectNode node = SCHEMA_GENERATOR_CACHE.get() + .generateSchema(CustomizedTypeReference.forType(parameterizedType).getType()); - if (ClassUtils.isVoidType(clazz) && node.get("properties") == null) { + if ((parameterizedType.getType() == Void.class) && node.get("properties") == null) { node.putObject("properties"); } @@ -411,4 +413,23 @@ public static T mergeOption(T runtimeValue, T defaultValue) { return ObjectUtils.isEmpty(runtimeValue) ? defaultValue : runtimeValue; } + public static class CustomizedTypeReference extends TypeReference { + + private final Type type; + + CustomizedTypeReference(ParameterizedTypeReference typeRef) { + this.type = typeRef.getType(); + } + + @Override + public Type getType() { + return this.type; + } + + public static CustomizedTypeReference forType(ParameterizedTypeReference typeRef) { + return new CustomizedTypeReference<>(typeRef); + } + + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java index 8a2c84aca52..e01226ddd60 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java @@ -24,6 +24,8 @@ import com.fasterxml.jackson.databind.ObjectMapper; import org.springframework.ai.chat.model.ToolContext; +import org.springframework.ai.model.ModelOptionsUtils.CustomizedTypeReference; +import org.springframework.core.ParameterizedTypeReference; import org.springframework.util.Assert; /** @@ -47,7 +49,7 @@ abstract class AbstractFunctionCallback implements BiFunction inputType; + private final ParameterizedTypeReference inputType; private final String inputTypeSchema; @@ -70,8 +72,8 @@ abstract class AbstractFunctionCallback implements BiFunction inputType, - Function responseConverter, ObjectMapper objectMapper) { + protected AbstractFunctionCallback(String name, String description, String inputTypeSchema, + ParameterizedTypeReference inputType, Function responseConverter, ObjectMapper objectMapper) { Assert.notNull(name, "Name must not be null"); Assert.notNull(description, "Description must not be null"); Assert.notNull(inputType, "InputType must not be null"); @@ -116,9 +118,9 @@ public String call(String functionArguments) { return this.andThen(this.responseConverter).apply(request, null); } - private T fromJson(String json, Class targetClass) { + private T fromJson(String json, ParameterizedTypeReference targetClass) { try { - return this.objectMapper.readValue(json, targetClass); + return this.objectMapper.readValue(json, CustomizedTypeReference.forType(targetClass)); } catch (JsonProcessingException e) { throw new RuntimeException(e); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java index b96673c8cb7..29fe5262c0b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java @@ -124,11 +124,13 @@ else if (KotlinDelegate.isKotlinBiFunction(functionType.toClass())) { } } if (bean instanceof Function function) { + // ResolvableType.forInstance(function); return FunctionCallbackWrapper.builder(function) .withName(beanName) .withSchemaType(this.schemaType) .withDescription(functionDescription) - .withInputType(functionInputClass) + .withInputType(functionInputType) + // .withInputType(functionInputClass) .build(); } if (bean instanceof Consumer consumer) { @@ -136,7 +138,8 @@ else if (KotlinDelegate.isKotlinBiFunction(functionType.toClass())) { .withName(beanName) .withSchemaType(this.schemaType) .withDescription(functionDescription) - .withInputType(functionInputClass) + .withInputType(functionInputType) + // .withInputType(functionInputClass) .build(); } if (bean instanceof Supplier supplier) { @@ -144,7 +147,7 @@ else if (KotlinDelegate.isKotlinBiFunction(functionType.toClass())) { .withName(beanName) .withSchemaType(this.schemaType) .withDescription(functionDescription) - .withInputType(functionInputClass) + .withInputType(functionInputType) .build(); } else if (bean instanceof BiFunction) { @@ -152,7 +155,8 @@ else if (bean instanceof BiFunction) { .withName(beanName) .withSchemaType(this.schemaType) .withDescription(functionDescription) - .withInputType(functionInputClass) + .withInputType(functionInputType) + // .withInputType(functionInputClass) .build(); } else { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java index fd804b9220f..e3b077eb79d 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java @@ -30,6 +30,8 @@ import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.model.function.FunctionCallbackContext.SchemaType; import org.springframework.ai.util.JacksonUtils; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.ResolvableType; import org.springframework.util.Assert; /** @@ -46,8 +48,9 @@ public final class FunctionCallbackWrapper extends AbstractFunctionCallbac private final BiFunction biFunction; - private FunctionCallbackWrapper(String name, String description, String inputTypeSchema, Class inputType, - Function responseConverter, ObjectMapper objectMapper, BiFunction function) { + private FunctionCallbackWrapper(String name, String description, String inputTypeSchema, + ParameterizedTypeReference inputType, Function responseConverter, ObjectMapper objectMapper, + BiFunction function) { super(name, description, inputTypeSchema, inputType, responseConverter, objectMapper); Assert.notNull(function, "Function must not be null"); this.biFunction = function; @@ -87,7 +90,7 @@ public static class Builder { private String description; - private Class inputType; + private ParameterizedTypeReference inputType; private SchemaType schemaType = SchemaType.JSON_SCHEMA; @@ -119,20 +122,20 @@ public Builder(Consumer consumer) { this.consumer = consumer; } - @SuppressWarnings("unchecked") - private static Class resolveInputType(BiFunction biFunction) { - return (Class) TypeResolverHelper - .getBiFunctionInputClass((Class>) biFunction.getClass()); + private static ParameterizedTypeReference resolveInputType(BiFunction biFunction) { + + ResolvableType rt = TypeResolverHelper.getFunctionArgumentType(ResolvableType.forInstance(biFunction), 0); + return ParameterizedTypeReference.forType(rt.getType()); } - @SuppressWarnings("unchecked") - private static Class resolveInputType(Function function) { - return (Class) TypeResolverHelper.getFunctionInputClass((Class>) function.getClass()); + private static ParameterizedTypeReference resolveInputType(Function function) { + ResolvableType rt = TypeResolverHelper.getFunctionArgumentType(ResolvableType.forInstance(function), 0); + return ParameterizedTypeReference.forType(rt.getType()); } - @SuppressWarnings("unchecked") - private static Class resolveInputType(Consumer consumer) { - return (Class) TypeResolverHelper.getConsumerInputClass((Class>) consumer.getClass()); + private static ParameterizedTypeReference resolveInputType(Consumer consumer) { + ResolvableType rt = TypeResolverHelper.getFunctionArgumentType(ResolvableType.forInstance(consumer), 0); + return ParameterizedTypeReference.forType(rt.getType()); } public Builder withName(String name) { @@ -149,7 +152,12 @@ public Builder withDescription(String description) { @SuppressWarnings("unchecked") public Builder withInputType(Class inputType) { - this.inputType = (Class) inputType; + this.inputType = ParameterizedTypeReference.forType((Class) inputType); + return this; + } + + public Builder withInputType(ResolvableType inputType) { + this.inputType = ParameterizedTypeReference.forType(inputType.getType()); return this; } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java index 906028807d5..66b74a197d6 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java @@ -304,6 +304,26 @@ void functionCallingSupplier() { }); } + @Test + void trainScheduler() { + this.contextRunner.run(context -> { + + OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); + + // Test weatherFunction + UserMessage userMessage = new UserMessage( + "Please schedule a train from San Francisco to Los Angeles on 2023-12-25"); + + PortableFunctionCallingOptions functionOptions = FunctionCallingOptions.builder() + .withFunction("trainReservation") + .build(); + + ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions)); + + logger.info("Response: {}", response.getResult().getOutput().getContent()); + }); + } + @Configuration static class Config { @@ -375,6 +395,28 @@ public Supplier turnLivingRoomLightOnSupplier() { }; } + record TrainSearchSchedule(String from, String to, String date) { + } + + record TrainSearchScheduleResponse(String from, String to, String date, String trainNumber) { + } + + record TrainSearchRequest(T data) { + } + + record TrainSearchResponse(T data) { + } + + @Bean + @Description("Schedule a train reservation") + public Function, TrainSearchResponse> trainReservation() { + return (TrainSearchRequest request) -> { + logger.info("Turning light to [" + request.data().from() + "] in " + request.data().to()); + return new TrainSearchResponse<>( + new TrainSearchScheduleResponse(request.data().from(), request.data().to(), "", "123")); + }; + } + } public static class MyBiFunction