diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientIT.java index b4038d6acb0..4ebb2c5bf73 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientIT.java @@ -211,7 +211,7 @@ void functionCallTest() { // @formatter:off String response = ChatClient.create(this.chatModel).prompt() .user(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")) - .function("getCurrentWeather", "Get the weather in location", new MockWeatherService()) + .function("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService()) .call() .content(); // @formatter:on @@ -226,7 +226,7 @@ void defaultFunctionCallTest() { // @formatter:off String response = ChatClient.builder(this.chatModel) - .defaultFunction("getCurrentWeather", "Get the weather in location", new MockWeatherService()) + .defaultFunction("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService()) .defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")) .build() .prompt() @@ -245,7 +245,7 @@ void streamFunctionCallTest() { // @formatter:off Flux response = ChatClient.create(this.chatModel).prompt() .user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.") - .function("getCurrentWeather", "Get the weather in location", new MockWeatherService()) + .function("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService()) .stream() .content(); // @formatter:on diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java index 0bbc99be341..aed619c2dd6 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java @@ -212,7 +212,7 @@ void functionCallTest() { // @formatter:off String response = ChatClient.create(this.chatModel) .prompt("What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.") - .function("getCurrentWeather", "Get the weather in location", new MockWeatherService()) + .function("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService()) .call() .content(); // @formatter:on @@ -228,7 +228,7 @@ void functionCallWithAdvisorTest() { // @formatter:off String response = ChatClient.create(this.chatModel) .prompt("What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.") - .function("getCurrentWeather", "Get the weather in location", new MockWeatherService()) + .function("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService()) .advisors(new SimpleLoggerAdvisor()) .call() .content(); @@ -244,7 +244,7 @@ void defaultFunctionCallTest() { // @formatter:off String response = ChatClient.builder(this.chatModel) - .defaultFunction("getCurrentWeather", "Get the weather in location", new MockWeatherService()) + .defaultFunction("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService()) .defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.")) .build() .prompt() @@ -263,7 +263,7 @@ void streamFunctionCallTest() { // @formatter:off Flux response = ChatClient.create(this.chatModel).prompt() .user("What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.") - .function("getCurrentWeather", "Get the weather in location", new MockWeatherService()) + .function("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService()) .stream() .content(); // @formatter:on @@ -280,7 +280,7 @@ void singularStreamFunctionCallTest() { // @formatter:off Flux response = ChatClient.create(this.chatModel).prompt() .user("What's the weather like in Paris? Return the temperature in Celsius.") - .function("getCurrentWeather", "Get the weather in location", new MockWeatherService()) + .function("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService()) .stream() .content(); // @formatter:on diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatClientIT.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatClientIT.java index 6298c777393..fe1f109a82e 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatClientIT.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatClientIT.java @@ -224,7 +224,7 @@ void functionCallTest() { String response = ChatClient.create(this.chatModel).prompt() .options(MistralAiChatOptions.builder().withModel(MistralAiApi.ChatModel.SMALL).withToolChoice(ToolChoice.AUTO).build()) .user(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Use parallel function calling if required. Response should be in Celsius.")) - .function("getCurrentWeather", "Get the weather in location", new MockWeatherService()) + .function("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService()) .call() .content(); // @formatter:on @@ -242,7 +242,7 @@ void defaultFunctionCallTest() { // @formatter:off String response = ChatClient.builder(this.chatModel) .defaultOptions(MistralAiChatOptions.builder().withModel(MistralAiApi.ChatModel.SMALL).build()) - .defaultFunction("getCurrentWeather", "Get the weather in location", new MockWeatherService()) + .defaultFunction("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService()) .defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Use parallel function calling if required. Response should be in Celsius.")) .build() .prompt().call().content(); @@ -262,7 +262,7 @@ void streamFunctionCallTest() { Flux response = ChatClient.create(this.chatModel).prompt() .options(MistralAiChatOptions.builder().withModel(MistralAiApi.ChatModel.SMALL).build()) .user("What's the weather like in San Francisco, Tokyo, and Paris? Use parallel function calling if required. Response should be in Celsius.") - .function("getCurrentWeather", "Get the weather in location", new MockWeatherService()) + .function("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService()) .stream() .content(); // @formatter:on diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java index d9e6b6ca513..38988191179 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java @@ -79,7 +79,6 @@ public OpenAiAudioSpeechModel openAiAudioSpeechModel(OpenAiAudioApi api) { @Bean public OpenAiImageModel openAiImageModel(OpenAiImageApi imageApi) { OpenAiImageModel openAiImageModel = new OpenAiImageModel(imageApi); - // openAiImageModel.setModel("foobar"); return openAiImageModel; } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java index 3c8a724a063..c65a2d5f2c2 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java @@ -21,6 +21,8 @@ import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Consumer; import java.util.stream.Collectors; import org.junit.jupiter.api.Disabled; @@ -247,7 +249,7 @@ void functionCallTest() { // @formatter:off String response = ChatClient.create(this.chatModel).prompt() .user(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?")) - .function("getCurrentWeather", "Get the weather in location", new MockWeatherService()) + .function("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService()) .call() .content(); // @formatter:on @@ -257,12 +259,54 @@ void functionCallTest() { assertThat(response).contains("30", "10", "15"); } + @Test + void functionCallSupplier() { + + Map state = new ConcurrentHashMap<>(); + + // @formatter:off + String response = ChatClient.create(this.chatModel).prompt() + .user("Turn the light on in the living room") + .function("turnLivingRoomLightOnSupplier", "Turns light on in the living room", () -> state.put("foo", "bar")) + .call() + .content(); + // @formatter:on + + logger.info("Response: {}", response); + assertThat(state).containsEntry("foo", "bar"); + } + + @Test + void functionCallConsumer() { + + Map state = new ConcurrentHashMap<>(); + + record LightInfo(String roomName, boolean isOn) { + } + + // @formatter:off + String response = ChatClient.create(this.chatModel).prompt() + .user("Turn the light on in the kitchen and in the living room") + .function("turnLight", "Turn light on or off in a room", LightInfo.class, (LightInfo lightInfo) -> { + logger.info("Turning light to [" + lightInfo.isOn + "] in " + lightInfo.roomName()); + state.put(lightInfo.roomName(), lightInfo.isOn()); + }) + .call() + .content(); + // @formatter:on + + logger.info("Response: {}", response); + assertThat(state).containsEntry("kitchen", Boolean.TRUE); + assertThat(state).containsEntry("living room", Boolean.TRUE); + } + @Test void defaultFunctionCallTest() { // @formatter:off String response = ChatClient.builder(this.chatModel) - .defaultFunction("getCurrentWeather", "Get the weather in location", new MockWeatherService()) + .defaultFunction("getCurrentWeather", "Get the weather in location", + MockWeatherService.Request.class, new MockWeatherService()) .defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?")) .build() .prompt().call().content(); @@ -279,7 +323,8 @@ void streamFunctionCallTest() { // @formatter:off Flux response = ChatClient.create(this.chatModel).prompt() .user("What's the weather like in San Francisco, Tokyo, and Paris?") - .function("getCurrentWeather", "Get the weather in location", new MockWeatherService()) + .function("getCurrentWeather", "Get the weather in location", + MockWeatherService.Request.class, new MockWeatherService()) .stream() .content(); // @formatter:on diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMultipleFunctionCallsIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMultipleFunctionCallsIT.java index 6ba1682e1d5..490163ebe13 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMultipleFunctionCallsIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMultipleFunctionCallsIT.java @@ -83,7 +83,8 @@ void turnFunctionsOnAndOffTest() { // @formatter:off response = chatClientBuilder.build().prompt() .user(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?")) - .function("getCurrentWeather", "Get the weather in location", new MockWeatherService()) + .function("getCurrentWeather", "Get the weather in location", + MockWeatherService.Request.class, new MockWeatherService()) .call() .content(); // @formatter:on @@ -110,7 +111,8 @@ void defaultFunctionCallTest() { // @formatter:off String response = ChatClient.builder(this.chatModel) - .defaultFunction("getCurrentWeather", "Get the weather in location", new MockWeatherService()) + .defaultFunction("getCurrentWeather", "Get the weather in location", + MockWeatherService.Request.class, new MockWeatherService()) .defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?")) .build() .prompt().call().content(); @@ -149,7 +151,7 @@ else if (request.location().contains("San Francisco")) { // @formatter:off String response = ChatClient.builder(this.chatModel) - .defaultFunction("getCurrentWeather", "Get the weather in location", biFunction) + .defaultFunction("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, biFunction) .defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?")) .defaultToolContext(Map.of("sessionId", "123")) .build() @@ -189,7 +191,7 @@ else if (request.location().contains("San Francisco")) { // @formatter:off String response = ChatClient.builder(this.chatModel) - .defaultFunction("getCurrentWeather", "Get the weather in location", biFunction) + .defaultFunction("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, biFunction) .defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?")) .build() .prompt() @@ -208,7 +210,7 @@ void streamFunctionCallTest() { // @formatter:off Flux response = ChatClient.create(this.chatModel).prompt() .user("What's the weather like in San Francisco, Tokyo, and Paris?") - .function("getCurrentWeather", "Get the weather in location", new MockWeatherService()) + .function("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService()) .stream() .content(); // @formatter:on diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java index d53e998967a..bffcacd1e28 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java @@ -212,16 +212,38 @@ interface ChatClientRequestSpec { ChatClientRequestSpec options(T options); + /** + * @deprecated Use + * {@link #function(String, String, Class, java.util.function.Function)} instead. + * Because of JVM type erasure, for lambda to work, the inputType class is + * required to be provided explicitly. + */ + @Deprecated ChatClientRequestSpec function(String name, String description, java.util.function.Function function); + ChatClientRequestSpec function(String name, String description, Class inputType, + java.util.function.Function function); + + /** + * @deprecated Use + * {@link #functions(String, String, Class,java.util.function.BiFunction)} Because + * of JVM type erasure, for lambda to work, the inputType class is required to be + * provided explicitly. + */ + @Deprecated ChatClientRequestSpec function(String name, String description, java.util.function.BiFunction function); + ChatClientRequestSpec function(String name, String description, Class inputType, + java.util.function.BiFunction function); + ChatClientRequestSpec functions(FunctionCallback... functionCallbacks); + ChatClientRequestSpec function(String name, String description, java.util.function.Supplier supplier); + ChatClientRequestSpec function(String name, String description, Class inputType, - java.util.function.Function function); + java.util.function.Consumer consumer); ChatClientRequestSpec functions(String... functionBeanNames); @@ -278,11 +300,36 @@ interface Builder { Builder defaultSystem(Consumer systemSpecConsumer); + /** + * @deprecated Use + * {@link #defaultFunction(String, String, Class, java.util.function.Function)} + * instead. Because of JVM type erasure, for lambda to work, the inputType class + * is required to be provided explicitly. + */ + @Deprecated Builder defaultFunction(String name, String description, java.util.function.Function function); + Builder defaultFunction(String name, String description, Class inputType, + java.util.function.Function function); + + /** + * @deprecated Use + * {@link #defaultFunction(String, String, Class, java.util.function.BiFunction)} + * instead. Because of JVM type erasure, for lambda to work, the inputType class + * is required to be provided explicitly. + */ + @Deprecated Builder defaultFunction(String name, String description, java.util.function.BiFunction function); + Builder defaultFunction(String name, String description, Class inputType, + java.util.function.BiFunction function); + + Builder defaultFunction(String name, String description, java.util.function.Supplier supplier); + + Builder defaultFunction(String name, String description, Class inputType, + java.util.function.Consumer consumer); + Builder defaultFunctions(String... functionNames); Builder defaultFunctions(FunctionCallback... functionCallbacks); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index 6f46f2749b0..529b17f3e5c 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -836,11 +836,40 @@ public ChatClientRequestSpec options(T options) { return this; } + /** + * @deprecated since 1.0.0 in favor of + * {@link #function(String, String, Class, java.util.function.Function)} Because + * of JVM type erasure, the inputType class is required to be provided explicitly. + */ + @Deprecated(since = "1.0.0", forRemoval = true) public ChatClientRequestSpec function(String name, String description, java.util.function.Function function) { return this.function(name, description, null, function); } + public ChatClientRequestSpec function(String name, String description, @Nullable Class inputType, + java.util.function.Function function) { + + Assert.hasText(name, "name cannot be null or empty"); + Assert.hasText(description, "description cannot be null or empty"); + Assert.notNull(function, "function cannot be null"); + + var fcw = FunctionCallbackWrapper.builder(function) + .withDescription(description) + .withName(name) + .withInputType(inputType) + .withResponseConverter(Object::toString) + .build(); + this.functionCallbacks.add(fcw); + return this; + } + + /** + * @deprecated since 1.0.0 in favor of + * {@link #function(String, String, Class, java.util.function.BiFunction)} Because + * of JVM type erasure, the inputType class is required to be provided explicitly. + */ + @Deprecated public ChatClientRequestSpec function(String name, String description, java.util.function.BiFunction biFunction) { @@ -857,14 +886,14 @@ public ChatClientRequestSpec function(String name, String description, return this; } - public ChatClientRequestSpec function(String name, String description, @Nullable Class inputType, - java.util.function.Function function) { + public ChatClientRequestSpec function(String name, String description, Class inputType, + java.util.function.BiFunction biFunction) { Assert.hasText(name, "name cannot be null or empty"); Assert.hasText(description, "description cannot be null or empty"); - Assert.notNull(function, "function cannot be null"); + Assert.notNull(biFunction, "biFunction cannot be null"); - var fcw = FunctionCallbackWrapper.builder(function) + FunctionCallbackWrapper fcw = FunctionCallbackWrapper.builder(biFunction) .withDescription(description) .withName(name) .withInputType(inputType) @@ -874,6 +903,38 @@ public ChatClientRequestSpec function(String name, String description, @N return this; } + public ChatClientRequestSpec function(String name, String description, + java.util.function.Supplier supplier) { + + Assert.hasText(name, "name cannot be null or empty"); + Assert.hasText(description, "description cannot be null or empty"); + Assert.notNull(supplier, "supplier cannot be null"); + + var fcw = FunctionCallbackWrapper.builder(supplier) + .withDescription(description) + .withName(name) + .withInputType(Void.class) + .build(); + this.functionCallbacks.add(fcw); + return this; + } + + public ChatClientRequestSpec function(String name, String description, Class inputType, + java.util.function.Consumer consumer) { + + Assert.hasText(name, "name cannot be null or empty"); + Assert.hasText(description, "description cannot be null or empty"); + Assert.notNull(consumer, "consumer cannot be null"); + + var fcw = FunctionCallbackWrapper.builder(consumer) + .withDescription(description) + .withName(name) + .withInputType(inputType) + .build(); + this.functionCallbacks.add(fcw); + return this; + } + public ChatClientRequestSpec functions(String... functionBeanNames) { Assert.notNull(functionBeanNames, "functionBeanNames cannot be null"); Assert.noNullElements(functionBeanNames, "functionBeanNames cannot contain null elements"); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java index 4ae3833d868..e4e8b56c662 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java @@ -143,17 +143,42 @@ public Builder defaultSystem(Consumer systemSpecConsumer) { return this; } + @Deprecated public Builder defaultFunction(String name, String description, java.util.function.Function function) { this.defaultRequest.function(name, description, function); return this; } + public Builder defaultFunction(String name, String description, Class inputType, + java.util.function.Function function) { + this.defaultRequest.function(name, description, inputType, function); + return this; + } + + @Deprecated public Builder defaultFunction(String name, String description, java.util.function.BiFunction biFunction) { this.defaultRequest.function(name, description, biFunction); return this; } + public Builder defaultFunction(String name, String description, Class inputType, + java.util.function.BiFunction biFunction) { + this.defaultRequest.function(name, description, inputType, biFunction); + return this; + } + + public Builder defaultFunction(String name, String description, java.util.function.Supplier supplier) { + this.defaultRequest.function(name, description, supplier); + return this; + } + + public Builder defaultFunction(String name, String description, Class inputType, + java.util.function.Consumer consumer) { + this.defaultRequest.function(name, description, inputType, consumer); + return this; + } + public Builder defaultFunctions(String... functionNames) { this.defaultRequest.functions(functionNames); return this; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/converter/BeanOutputConverter.java b/spring-ai-core/src/main/java/org/springframework/ai/converter/BeanOutputConverter.java index 0e718c1b989..4cb1dfa594a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/converter/BeanOutputConverter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/converter/BeanOutputConverter.java @@ -16,7 +16,6 @@ package org.springframework.ai.converter; -import java.lang.reflect.Type; import java.util.Objects; import com.fasterxml.jackson.core.JsonProcessingException; @@ -37,6 +36,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.model.ModelOptionsUtils.CustomizedTypeReference; import org.springframework.ai.util.JacksonUtils; import org.springframework.core.ParameterizedTypeReference; import org.springframework.lang.NonNull; @@ -94,7 +94,7 @@ public BeanOutputConverter(Class clazz, ObjectMapper objectMapper) { * @param typeRef The target class type reference. */ public BeanOutputConverter(ParameterizedTypeReference typeRef) { - this(new CustomizedTypeReference<>(typeRef), null); + this(CustomizedTypeReference.forType(typeRef), null); } /** @@ -105,7 +105,7 @@ public BeanOutputConverter(ParameterizedTypeReference typeRef) { * @param objectMapper Custom object mapper for JSON operations. endings. */ public BeanOutputConverter(ParameterizedTypeReference typeRef, ObjectMapper objectMapper) { - this(new CustomizedTypeReference<>(typeRef), objectMapper); + this(CustomizedTypeReference.forType(typeRef), objectMapper); } /** @@ -220,19 +220,4 @@ public String getJsonSchema() { return this.jsonSchema; } - private static class CustomizedTypeReference extends TypeReference { - - private final Type type; - - CustomizedTypeReference(ParameterizedTypeReference typeRef) { - this.type = typeRef.getType(); - } - - @Override - public Type getType() { - return this.type; - } - - } - } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java index e049fb17f00..f6411e31e92 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java @@ -18,6 +18,7 @@ import java.beans.PropertyDescriptor; import java.lang.reflect.Field; +import java.lang.reflect.Type; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -50,6 +51,7 @@ import org.springframework.ai.util.JacksonUtils; import org.springframework.beans.BeanWrapper; import org.springframework.beans.BeanWrapperImpl; +import org.springframework.core.ParameterizedTypeReference; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.ObjectUtils; @@ -334,11 +336,11 @@ private static String toGetName(String name) { /** * Generates JSON Schema (version 2020_12) for the given class. - * @param clazz the class to generate JSON Schema for. + * @param parameterizedType the class to generate JSON Schema for. * @param toUpperCaseTypeValues if true, the type values are converted to upper case. * @return the generated JSON Schema as a String. */ - public static String getJsonSchema(Class clazz, boolean toUpperCaseTypeValues) { + public static String getJsonSchema(ParameterizedTypeReference parameterizedType, boolean toUpperCaseTypeValues) { if (SCHEMA_GENERATOR_CACHE.get() == null) { @@ -357,9 +359,15 @@ public static String getJsonSchema(Class clazz, boolean toUpperCaseTypeValues SCHEMA_GENERATOR_CACHE.compareAndSet(null, generator); } - ObjectNode node = SCHEMA_GENERATOR_CACHE.get().generateSchema(clazz); - if (toUpperCaseTypeValues) { // Required for OpenAPI 3.0 (at least Vertex AI - // version of it). + ObjectNode node = SCHEMA_GENERATOR_CACHE.get() + .generateSchema(CustomizedTypeReference.forType(parameterizedType).getType()); + + if ((parameterizedType.getType() == Void.class) && node.get("properties") == null) { + node.putObject("properties"); + } + + // Required for OpenAPI 3.0 (at least Vertex AI version of it). + if (toUpperCaseTypeValues) { toUpperCaseTypeValues(node); } @@ -405,4 +413,23 @@ public static T mergeOption(T runtimeValue, T defaultValue) { return ObjectUtils.isEmpty(runtimeValue) ? defaultValue : runtimeValue; } + public static class CustomizedTypeReference extends TypeReference { + + private final Type type; + + CustomizedTypeReference(ParameterizedTypeReference typeRef) { + this.type = typeRef.getType(); + } + + @Override + public Type getType() { + return this.type; + } + + public static CustomizedTypeReference forType(ParameterizedTypeReference typeRef) { + return new CustomizedTypeReference<>(typeRef); + } + + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java index 8a2c84aca52..e01226ddd60 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java @@ -24,6 +24,8 @@ import com.fasterxml.jackson.databind.ObjectMapper; import org.springframework.ai.chat.model.ToolContext; +import org.springframework.ai.model.ModelOptionsUtils.CustomizedTypeReference; +import org.springframework.core.ParameterizedTypeReference; import org.springframework.util.Assert; /** @@ -47,7 +49,7 @@ abstract class AbstractFunctionCallback implements BiFunction inputType; + private final ParameterizedTypeReference inputType; private final String inputTypeSchema; @@ -70,8 +72,8 @@ abstract class AbstractFunctionCallback implements BiFunction inputType, - Function responseConverter, ObjectMapper objectMapper) { + protected AbstractFunctionCallback(String name, String description, String inputTypeSchema, + ParameterizedTypeReference inputType, Function responseConverter, ObjectMapper objectMapper) { Assert.notNull(name, "Name must not be null"); Assert.notNull(description, "Description must not be null"); Assert.notNull(inputType, "InputType must not be null"); @@ -116,9 +118,9 @@ public String call(String functionArguments) { return this.andThen(this.responseConverter).apply(request, null); } - private T fromJson(String json, Class targetClass) { + private T fromJson(String json, ParameterizedTypeReference targetClass) { try { - return this.objectMapper.readValue(json, targetClass); + return this.objectMapper.readValue(json, CustomizedTypeReference.forType(targetClass)); } catch (JsonProcessingException e) { throw new RuntimeException(e); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java index 762d33969bf..29fe5262c0b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java @@ -17,7 +17,9 @@ package org.springframework.ai.model.function; import java.util.function.BiFunction; +import java.util.function.Consumer; import java.util.function.Function; +import java.util.function.Supplier; import com.fasterxml.jackson.annotation.JsonClassDescription; import kotlin.jvm.functions.Function1; @@ -122,11 +124,30 @@ else if (KotlinDelegate.isKotlinBiFunction(functionType.toClass())) { } } if (bean instanceof Function function) { + // ResolvableType.forInstance(function); return FunctionCallbackWrapper.builder(function) .withName(beanName) .withSchemaType(this.schemaType) .withDescription(functionDescription) - .withInputType(functionInputClass) + .withInputType(functionInputType) + // .withInputType(functionInputClass) + .build(); + } + if (bean instanceof Consumer consumer) { + return FunctionCallbackWrapper.builder(consumer) + .withName(beanName) + .withSchemaType(this.schemaType) + .withDescription(functionDescription) + .withInputType(functionInputType) + // .withInputType(functionInputClass) + .build(); + } + if (bean instanceof Supplier supplier) { + return FunctionCallbackWrapper.builder(supplier) + .withName(beanName) + .withSchemaType(this.schemaType) + .withDescription(functionDescription) + .withInputType(functionInputType) .build(); } else if (bean instanceof BiFunction) { @@ -134,7 +155,8 @@ else if (bean instanceof BiFunction) { .withName(beanName) .withSchemaType(this.schemaType) .withDescription(functionDescription) - .withInputType(functionInputClass) + .withInputType(functionInputType) + // .withInputType(functionInputClass) .build(); } else { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java index fe9fa0a1533..e3b077eb79d 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java @@ -17,7 +17,9 @@ package org.springframework.ai.model.function; import java.util.function.BiFunction; +import java.util.function.Consumer; import java.util.function.Function; +import java.util.function.Supplier; import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.ObjectMapper; @@ -28,6 +30,8 @@ import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.model.function.FunctionCallbackContext.SchemaType; import org.springframework.ai.util.JacksonUtils; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.ResolvableType; import org.springframework.util.Assert; /** @@ -44,8 +48,9 @@ public final class FunctionCallbackWrapper extends AbstractFunctionCallbac private final BiFunction biFunction; - private FunctionCallbackWrapper(String name, String description, String inputTypeSchema, Class inputType, - Function responseConverter, ObjectMapper objectMapper, BiFunction function) { + private FunctionCallbackWrapper(String name, String description, String inputTypeSchema, + ParameterizedTypeReference inputType, Function responseConverter, ObjectMapper objectMapper, + BiFunction function) { super(name, description, inputTypeSchema, inputType, responseConverter, objectMapper); Assert.notNull(function, "Function must not be null"); this.biFunction = function; @@ -59,6 +64,15 @@ public static Builder builder(Function function) { return new Builder<>(function); } + public static Builder builder(Supplier supplier) { + Function function = (input) -> supplier.get(); + return new Builder<>(function); + } + + public static Builder builder(Consumer consumer) { + return new Builder<>(consumer); + } + @Override public O apply(I input, ToolContext context) { return this.biFunction.apply(input, context); @@ -70,11 +84,13 @@ public static class Builder { private final Function function; + private final Consumer consumer; + private String name; private String description; - private Class inputType; + private ParameterizedTypeReference inputType; private SchemaType schemaType = SchemaType.JSON_SCHEMA; @@ -89,23 +105,37 @@ public Builder(BiFunction biFunction) { Assert.notNull(biFunction, "Function must not be null"); this.biFunction = biFunction; this.function = null; + this.consumer = null; } public Builder(Function function) { Assert.notNull(function, "Function must not be null"); this.biFunction = null; this.function = function; + this.consumer = null; } - @SuppressWarnings("unchecked") - private static Class resolveInputType(BiFunction biFunction) { - return (Class) TypeResolverHelper - .getBiFunctionInputClass((Class>) biFunction.getClass()); + public Builder(Consumer consumer) { + Assert.notNull(consumer, "Consumer must not be null"); + this.biFunction = null; + this.function = null; + this.consumer = consumer; } - @SuppressWarnings("unchecked") - private static Class resolveInputType(Function function) { - return (Class) TypeResolverHelper.getFunctionInputClass((Class>) function.getClass()); + private static ParameterizedTypeReference resolveInputType(BiFunction biFunction) { + + ResolvableType rt = TypeResolverHelper.getFunctionArgumentType(ResolvableType.forInstance(biFunction), 0); + return ParameterizedTypeReference.forType(rt.getType()); + } + + private static ParameterizedTypeReference resolveInputType(Function function) { + ResolvableType rt = TypeResolverHelper.getFunctionArgumentType(ResolvableType.forInstance(function), 0); + return ParameterizedTypeReference.forType(rt.getType()); + } + + private static ParameterizedTypeReference resolveInputType(Consumer consumer) { + ResolvableType rt = TypeResolverHelper.getFunctionArgumentType(ResolvableType.forInstance(consumer), 0); + return ParameterizedTypeReference.forType(rt.getType()); } public Builder withName(String name) { @@ -122,7 +152,12 @@ public Builder withDescription(String description) { @SuppressWarnings("unchecked") public Builder withInputType(Class inputType) { - this.inputType = (Class) inputType; + this.inputType = ParameterizedTypeReference.forType((Class) inputType); + return this; + } + + public Builder withInputType(ResolvableType inputType) { + this.inputType = ParameterizedTypeReference.forType(inputType.getType()); return this; } @@ -168,9 +203,12 @@ public FunctionCallbackWrapper build() { if (this.function != null) { this.inputType = resolveInputType(this.function); } - else { + else if (this.biFunction != null) { this.inputType = resolveInputType(this.biFunction); } + else { + this.inputType = resolveInputType(this.consumer); + } } if (this.inputTypeSchema == null) { @@ -178,8 +216,23 @@ public FunctionCallbackWrapper build() { this.inputTypeSchema = ModelOptionsUtils.getJsonSchema(this.inputType, upperCaseTypeValues); } - BiFunction finalBiFunction = (this.biFunction != null) ? this.biFunction - : (request, context) -> this.function.apply(request); + BiFunction finalBiFunction = null; + if (this.biFunction != null) { + finalBiFunction = this.biFunction; + } + else if (this.function != null) { + finalBiFunction = (request, context) -> this.function.apply(request); + } + else { + finalBiFunction = (request, context) -> { + this.consumer.accept(request); + return null; + }; + } + + // BiFunction finalBiFunction = (this.biFunction != null) ? + // this.biFunction + // : (request, context) -> this.function.apply(request); return new FunctionCallbackWrapper<>(this.name, this.description, this.inputTypeSchema, this.inputType, this.responseConverter, this.objectMapper, finalBiFunction); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java index feaafd80130..14c14f2c9a1 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java @@ -20,7 +20,9 @@ import java.lang.reflect.Modifier; import java.util.Arrays; import java.util.function.BiFunction; +import java.util.function.Consumer; import java.util.function.Function; +import java.util.function.Supplier; import kotlin.jvm.functions.Function1; import kotlin.jvm.functions.Function2; @@ -62,6 +64,16 @@ public static Class getFunctionInputClass(Class> fun return getFunctionArgumentClass(functionClass, 0); } + /** + * Returns the input class of a given Consumer class. + * @param consumerClass The consumer class. + * @return The input class of the consumer. + */ + public static Class getConsumerInputClass(Class> consumerClass) { + ResolvableType resolvableType = ResolvableType.forClass(consumerClass).as(Consumer.class); + return (resolvableType == ResolvableType.NONE ? Object.class : resolvableType.getGeneric(0).toClass()); + } + /** * Returns the output class of a given function class. * @param functionClass The function class. @@ -189,6 +201,12 @@ public static ResolvableType getFunctionArgumentType(ResolvableType functionType else if (BiFunction.class.isAssignableFrom(resolvableClass)) { functionArgumentResolvableType = functionType.as(BiFunction.class); } + else if (Supplier.class.isAssignableFrom(resolvableClass)) { + return ResolvableType.forClass(Void.class); + } + else if (Consumer.class.isAssignableFrom(resolvableClass)) { + functionArgumentResolvableType = functionType.as(Consumer.class); + } else if (KotlinDetector.isKotlinPresent()) { if (KotlinDelegate.isKotlinFunction(resolvableClass)) { functionArgumentResolvableType = KotlinDelegate.adaptToKotlinFunctionType(functionType); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java index 1f9d407c38a..d085e810f16 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java @@ -217,7 +217,7 @@ void mutateDefaults() { .param("param1", "value1") .param("param2", "value2")) .defaultFunctions("fun1", "fun2") - .defaultFunction("fun3", "fun3description", mockFunction) + .defaultFunction("fun3", "fun3description", String.class, mockFunction) .defaultUser(u -> u.text("Default user text {uparam1}, {uparam2}") .param("uparam1", "value1") .param("uparam2", "value2") @@ -344,7 +344,7 @@ void mutatePrompt() { .param("param1", "value1") .param("param2", "value2")) .defaultFunctions("fun1", "fun2") - .defaultFunction("fun3", "fun3description", mockFunction) + .defaultFunction("fun3", "fun3description", String.class, mockFunction) .defaultUser(u -> u.text("Default user text {uparam1}, {uparam2}") .param("uparam1", "value1") .param("uparam2", "value2") @@ -541,9 +541,9 @@ void complexCall() throws MalformedURLException { assertThat(userMessage.getMedia().iterator().next().getData()) .isEqualTo("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png"); - FunctionCallingOptions runtieOptions = (FunctionCallingOptions) this.promptCaptor.getValue().getOptions(); + FunctionCallingOptions runtimeOptions = (FunctionCallingOptions) this.promptCaptor.getValue().getOptions(); - assertThat(runtieOptions.getFunctions()).containsExactly("function1"); + assertThat(runtimeOptions.getFunctions()).containsExactly("function1"); assertThat(options.getFunctions()).isEmpty(); } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java index 269fa7812dd..a2cbfed91ef 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java @@ -1354,7 +1354,7 @@ void whenOptionsThenReturn() { void whenFunctionNameIsNullThenThrow() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); - assertThatThrownBy(() -> spec.function(null, "description", input -> "hello")) + assertThatThrownBy(() -> spec.function(null, "description", String.class, input -> "hello")) .isInstanceOf(IllegalArgumentException.class) .hasMessage("name cannot be null or empty"); } @@ -1363,7 +1363,7 @@ void whenFunctionNameIsNullThenThrow() { void whenFunctionNameIsEmptyThenThrow() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); - assertThatThrownBy(() -> spec.function("", "description", input -> "hello")) + assertThatThrownBy(() -> spec.function("", "description", String.class, input -> "hello")) .isInstanceOf(IllegalArgumentException.class) .hasMessage("name cannot be null or empty"); } @@ -1372,7 +1372,7 @@ void whenFunctionNameIsEmptyThenThrow() { void whenFunctionDescriptionIsNullThenThrow() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); - assertThatThrownBy(() -> spec.function("name", null, input -> "hello")) + assertThatThrownBy(() -> spec.function("name", null, String.class, input -> "hello")) .isInstanceOf(IllegalArgumentException.class) .hasMessage("description cannot be null or empty"); } @@ -1381,7 +1381,7 @@ void whenFunctionDescriptionIsNullThenThrow() { void whenFunctionDescriptionIsEmptyThenThrow() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); - assertThatThrownBy(() -> spec.function("name", "", input -> "hello")) + assertThatThrownBy(() -> spec.function("name", "", String.class, input -> "hello")) .isInstanceOf(IllegalArgumentException.class) .hasMessage("description cannot be null or empty"); } @@ -1390,7 +1390,7 @@ void whenFunctionDescriptionIsEmptyThenThrow() { void whenFunctionLambdaIsNullThenThrow() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); - assertThatThrownBy(() -> spec.function("name", "description", (Function) null)) + assertThatThrownBy(() -> spec.function("name", "description", String.class, (Function) null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("function cannot be null"); } @@ -1399,7 +1399,7 @@ void whenFunctionLambdaIsNullThenThrow() { void whenFunctionThenReturn() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); - spec = spec.function("name", "description", input -> "hello"); + spec = spec.function("name", "description", String.class, input -> "hello"); DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; assertThat(defaultSpec.getFunctionCallbacks()).anyMatch(callback -> callback.getName().equals("name")); } @@ -1417,7 +1417,7 @@ void whenFunctionAndInputTypeThenReturn() { void whenBiFunctionNameIsNullThenThrow() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); - assertThatThrownBy(() -> spec.function(null, "description", (input, ctx) -> "hello")) + assertThatThrownBy(() -> spec.function(null, "description", String.class, (input, ctx) -> "hello")) .isInstanceOf(IllegalArgumentException.class) .hasMessage("name cannot be null or empty"); } @@ -1426,7 +1426,7 @@ void whenBiFunctionNameIsNullThenThrow() { void whenBiFunctionNameIsEmptyThenThrow() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); - assertThatThrownBy(() -> spec.function("", "description", (input, ctx) -> "hello")) + assertThatThrownBy(() -> spec.function("", "description", String.class, (input, ctx) -> "hello")) .isInstanceOf(IllegalArgumentException.class) .hasMessage("name cannot be null or empty"); } @@ -1435,7 +1435,7 @@ void whenBiFunctionNameIsEmptyThenThrow() { void whenBiFunctionDescriptionIsNullThenThrow() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); - assertThatThrownBy(() -> spec.function("name", null, (input, ctx) -> "hello")) + assertThatThrownBy(() -> spec.function("name", null, String.class, (input, ctx) -> "hello")) .isInstanceOf(IllegalArgumentException.class) .hasMessage("description cannot be null or empty"); } @@ -1444,7 +1444,7 @@ void whenBiFunctionDescriptionIsNullThenThrow() { void whenBiFunctionDescriptionIsEmptyThenThrow() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); - assertThatThrownBy(() -> spec.function("name", "", (input, ctx) -> "hello")) + assertThatThrownBy(() -> spec.function("name", "", String.class, (input, ctx) -> "hello")) .isInstanceOf(IllegalArgumentException.class) .hasMessage("description cannot be null or empty"); } @@ -1453,7 +1453,7 @@ void whenBiFunctionDescriptionIsEmptyThenThrow() { void whenBiFunctionLambdaIsNullThenThrow() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); - assertThatThrownBy(() -> spec.function("name", "description", (BiFunction) null)) + assertThatThrownBy(() -> spec.function("name", "description", String.class, (BiFunction) null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("biFunction cannot be null"); } @@ -1462,7 +1462,26 @@ void whenBiFunctionLambdaIsNullThenThrow() { void whenBiFunctionThenReturn() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); - spec = spec.function("name", "description", (input, ctx) -> "hello"); + spec = spec.function("name", "description", String.class, (input, ctx) -> "hello"); + DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; + assertThat(defaultSpec.getFunctionCallbacks()).anyMatch(callback -> callback.getName().equals("name")); + } + + @Test + void whenSupplierFunctionThenReturn() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + spec = spec.function("name", "description", () -> "hello"); + DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; + assertThat(defaultSpec.getFunctionCallbacks()).anyMatch(callback -> callback.getName().equals("name")); + } + + @Test + void whenConsumerFunctionThenReturn() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + Consumer consumer = input -> System.out.println(input); + spec = spec.function("name", "description", String.class, consumer); DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; assertThat(defaultSpec.getFunctionCallbacks()).anyMatch(callback -> callback.getName().equals("name")); } diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/functions.adoc index 07062e9ede2..2b861dfcd37 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/functions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/functions.adoc @@ -277,6 +277,7 @@ ChatResponse response = this.chatClient.prompt("What's the weather like in San F .functions(new FunctionCallbackWrapper<>( "CurrentWeather", // name "Get the weather in location", // function description + MockWeatherService.Request.class, // input type new MockWeatherService())) .call() .chatResponse(); @@ -405,6 +406,7 @@ BiFunction ChatResponse response = chatClient.prompt("What's the weather like in San Francisco, Tokyo, and Paris?") .functions(FunctionCallbackWrapper.builder(this.weatherFunction) + .withInputType(MockWeatherService.Request.class) .withName("getCurrentWeather") .withDescription("Get the weather in location") .build()) diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPrompt2IT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPrompt2IT.java index b6e2c3024e0..dacea1edc54 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPrompt2IT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPrompt2IT.java @@ -16,6 +16,8 @@ package org.springframework.ai.autoconfigure.openai.tool; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import java.util.function.Function; import java.util.stream.Collectors; @@ -39,17 +41,17 @@ public class FunctionCallbackInPrompt2IT { private final Logger logger = LoggerFactory.getLogger(FunctionCallbackInPromptIT.class); private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withPropertyValues("spring.ai.openai.apiKey=" + System.getenv("OPENAI_API_KEY")) + .withPropertyValues("spring.ai.openai.apiKey=" + System.getenv("OPENAI_API_KEY"), + "spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName()) .withConfiguration(AutoConfigurations.of(OpenAiAutoConfiguration.class)); @Test void functionCallTest() { - this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName()) - .run(context -> { + this.contextRunner.run(context -> { - OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); + OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); - ChatClient chatClient = ChatClient.builder(chatModel).build(); + ChatClient chatClient = ChatClient.builder(chatModel).build(); // @formatter:off chatClient.prompt() @@ -58,27 +60,26 @@ void functionCallTest() { String content = ChatClient.builder(chatModel).build().prompt() .user("What's the weather like in San Francisco, Tokyo, and Paris?") - .function("CurrentWeatherService", "Get the weather in location", new MockWeatherService()) + .function("CurrentWeatherService", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService()) .call().content(); // @formatter:on - logger.info("Response: {}", content); + logger.info("Response: {}", content); - assertThat(content).contains("30", "10", "15"); - }); + assertThat(content).contains("30", "10", "15"); + }); } @Test void functionCallTest2() { - this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName()) - .run(context -> { + this.contextRunner.run(context -> { - OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); + OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); // @formatter:off String content = ChatClient.builder(chatModel).build().prompt() .user("What's the weather like in Amsterdam?") - .function("CurrentWeatherService", "Get the weather in location", + .function("CurrentWeatherService", "Get the weather in location", MockWeatherService.Request.class, new Function() { @Override public String apply(MockWeatherService.Request request) { @@ -87,32 +88,59 @@ public String apply(MockWeatherService.Request request) { }) .call().content(); // @formatter:on - logger.info("Response: {}", content); + logger.info("Response: {}", content); - assertThat(content).contains("18"); - }); + assertThat(content).contains("18"); + }); + } + + @Test + void functionCallTest21() { + Map state = new ConcurrentHashMap<>(); + + record LightInfo(String roomName, boolean isOn) { + } + + this.contextRunner.run(context -> { + + OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); + + // @formatter:off + String content = ChatClient.builder(chatModel).build().prompt() + .user("Turn the light on in the kitchen and in the living room!") + .function("turnLight", "Turn light on or off in a room",LightInfo.class, + lightInfo -> { + logger.info("Turning light to [" + lightInfo.isOn + "] in " + lightInfo.roomName()); + state.put(lightInfo.roomName(), lightInfo.isOn()); + }) + .call().content(); + // @formatter:on + logger.info("Response: {}", content); + assertThat(state).containsEntry("kitchen", Boolean.TRUE); + assertThat(state).containsEntry("living room", Boolean.TRUE); + }); } @Test void streamingFunctionCallTest() { - this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName()) - .run(context -> { + this.contextRunner.run(context -> { - OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); + OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); // @formatter:off String content = ChatClient.builder(chatModel).build().prompt() .user("What's the weather like in San Francisco, Tokyo, and Paris?") - .function("CurrentWeatherService", "Get the weather in location", new MockWeatherService()) + .function("CurrentWeatherService", "Get the weather in location", + MockWeatherService.Request.class, new MockWeatherService()) .stream().content() .collectList().block().stream().collect(Collectors.joining()); // @formatter:on - logger.info("Response: {}", content); + logger.info("Response: {}", content); - assertThat(content).contains("30", "10", "15"); - }); + assertThat(content).contains("30", "10", "15"); + }); } } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java index 6408edc1b9a..66b74a197d6 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java @@ -18,10 +18,14 @@ import java.util.List; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import java.util.function.BiFunction; +import java.util.function.Consumer; import java.util.function.Function; +import java.util.function.Supplier; import java.util.stream.Collectors; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; @@ -52,179 +56,272 @@ @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*") class FunctionCallbackWithPlainFunctionBeanIT { - private final Logger logger = LoggerFactory.getLogger(FunctionCallbackWithPlainFunctionBeanIT.class); + private final static Logger logger = LoggerFactory.getLogger(FunctionCallbackWithPlainFunctionBeanIT.class); private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withPropertyValues("spring.ai.openai.apiKey=" + System.getenv("OPENAI_API_KEY")) + .withPropertyValues("spring.ai.openai.apiKey=" + System.getenv("OPENAI_API_KEY"), + "spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName()) .withConfiguration(AutoConfigurations.of(OpenAiAutoConfiguration.class)) .withUserConfiguration(Config.class); + private static Map feedback = new ConcurrentHashMap<>(); + + @BeforeEach + void setUp() { + feedback.clear(); + } + @Test void functionCallWithDirectBiFunction() { - this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName()) - .run(context -> { + this.contextRunner.run(context -> { - OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); + OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); - ChatClient chatClient = ChatClient.builder(chatModel).build(); + ChatClient chatClient = ChatClient.builder(chatModel).build(); - String content = chatClient.prompt("What's the weather like in San Francisco, Tokyo, and Paris?") - .functions("weatherFunctionWithContext") - .toolContext(Map.of("sessionId", "123")) - .call() - .content(); - logger.info(content); + String content = chatClient.prompt("What's the weather like in San Francisco, Tokyo, and Paris?") + .functions("weatherFunctionWithContext") + .toolContext(Map.of("sessionId", "123")) + .call() + .content(); + logger.info(content); - // Test weatherFunction - UserMessage userMessage = new UserMessage( - "What's the weather like in San Francisco, Tokyo, and Paris? You can call the following functions 'weatherFunction'"); + // Test weatherFunction + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo, and Paris? You can call the following functions 'weatherFunction'"); - ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), - OpenAiChatOptions.builder() - .withFunction("weatherFunctionWithContext") - .withToolContext(Map.of("sessionId", "123")) - .build())); + ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), + OpenAiChatOptions.builder() + .withFunction("weatherFunctionWithContext") + .withToolContext(Map.of("sessionId", "123")) + .build())); - logger.info("Response: {}", response); + logger.info("Response: {}", response); - assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); + assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); - }); + }); } @Test void functionCallWithBiFunctionClass() { - this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName()) - .run(context -> { + this.contextRunner.run(context -> { - OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); + OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); - ChatClient chatClient = ChatClient.builder(chatModel).build(); + ChatClient chatClient = ChatClient.builder(chatModel).build(); - String content = chatClient.prompt("What's the weather like in San Francisco, Tokyo, and Paris?") - .functions("weatherFunctionWithClassBiFunction") - .toolContext(Map.of("sessionId", "123")) - .call() - .content(); - logger.info(content); + String content = chatClient.prompt("What's the weather like in San Francisco, Tokyo, and Paris?") + .functions("weatherFunctionWithClassBiFunction") + .toolContext(Map.of("sessionId", "123")) + .call() + .content(); + logger.info(content); - // Test weatherFunction - UserMessage userMessage = new UserMessage( - "What's the weather like in San Francisco, Tokyo, and Paris? You can call the following functions 'weatherFunction'"); + // Test weatherFunction + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo, and Paris? You can call the following functions 'weatherFunction'"); - ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), - OpenAiChatOptions.builder() - .withFunction("weatherFunctionWithClassBiFunction") - .withToolContext(Map.of("sessionId", "123")) - .build())); + ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), + OpenAiChatOptions.builder() + .withFunction("weatherFunctionWithClassBiFunction") + .withToolContext(Map.of("sessionId", "123")) + .build())); - logger.info("Response: {}", response); + logger.info("Response: {}", response); - assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); + assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); - }); + }); } @Test void functionCallTest() { - this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName()) - .run(context -> { + this.contextRunner.run(context -> { - OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); + OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); - // Test weatherFunction - UserMessage userMessage = new UserMessage( - "What's the weather like in San Francisco, Tokyo, and Paris? You can call the following functions 'weatherFunction'"); + // Test weatherFunction + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo, and Paris? You can call the following functions 'weatherFunction'"); - ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), - OpenAiChatOptions.builder().withFunction("weatherFunction").build())); + ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), + OpenAiChatOptions.builder().withFunction("weatherFunction").build())); - logger.info("Response: {}", response); + logger.info("Response: {}", response); - assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); + assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); - // Test weatherFunctionTwo - response = chatModel.call(new Prompt(List.of(userMessage), - OpenAiChatOptions.builder().withFunction("weatherFunctionTwo").build())); + // Test weatherFunctionTwo + response = chatModel.call(new Prompt(List.of(userMessage), + OpenAiChatOptions.builder().withFunction("weatherFunctionTwo").build())); - logger.info("Response: {}", response); + logger.info("Response: {}", response); - assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); + assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); - }); + }); } @Test void functionCallWithPortableFunctionCallingOptions() { - this.contextRunner - .withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName(), - "spring.ai.openai.chat.options.temperature=0.1") - .run(context -> { + this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.temperature=0.1").run(context -> { - OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); + OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); - // Test weatherFunction - UserMessage userMessage = new UserMessage( - "What's the weather like in San Francisco, Tokyo, and Paris?"); + // Test weatherFunction + UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); - PortableFunctionCallingOptions functionOptions = FunctionCallingOptions.builder() - .withFunction("weatherFunction") - .build(); + PortableFunctionCallingOptions functionOptions = FunctionCallingOptions.builder() + .withFunction("weatherFunction") + .build(); - ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions)); + ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions)); - logger.info("Response: {}", response.getResult().getOutput().getContent()); + logger.info("Response: {}", response.getResult().getOutput().getContent()); - assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); - }); + assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); + }); } @Test void streamFunctionCallTest() { - this.contextRunner - .withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName(), - "spring.ai.openai.chat.options.temperature=0.1") - .run(context -> { - - OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); - - // Test weatherFunction - UserMessage userMessage = new UserMessage( - "What's the weather like in San Francisco, Tokyo, and Paris? You can call the following functions 'weatherFunction'"); - - Flux response = chatModel.stream(new Prompt(List.of(userMessage), - OpenAiChatOptions.builder().withFunction("weatherFunction").build())); - - String content = response.collectList() - .block() - .stream() - .map(ChatResponse::getResults) - .flatMap(List::stream) - .map(Generation::getOutput) - .map(AssistantMessage::getContent) - .collect(Collectors.joining()); - logger.info("Response: {}", content); - - assertThat(content).contains("30", "10", "15"); - - // Test weatherFunctionTwo - response = chatModel.stream(new Prompt(List.of(userMessage), - OpenAiChatOptions.builder().withFunction("weatherFunctionTwo").build())); - - content = response.collectList() - .block() - .stream() - .map(ChatResponse::getResults) - .flatMap(List::stream) - .map(Generation::getOutput) - .map(AssistantMessage::getContent) - .collect(Collectors.joining()); - logger.info("Response: {}", content); - - assertThat(content).isNotEmpty().withFailMessage("Content returned from OpenAI model is empty"); - assertThat(content).contains("30", "10", "15"); - - }); + this.contextRunner.withPropertyValues("spring.ai.openai.chat.options.temperature=0.1").run(context -> { + + OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); + + // Test weatherFunction + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo, and Paris? You can call the following functions 'weatherFunction'"); + + Flux response = chatModel.stream(new Prompt(List.of(userMessage), + OpenAiChatOptions.builder().withFunction("weatherFunction").build())); + + String content = response.collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .collect(Collectors.joining()); + logger.info("Response: {}", content); + + assertThat(content).contains("30", "10", "15"); + + // Test weatherFunctionTwo + response = chatModel.stream(new Prompt(List.of(userMessage), + OpenAiChatOptions.builder().withFunction("weatherFunctionTwo").build())); + + content = response.collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .collect(Collectors.joining()); + logger.info("Response: {}", content); + + assertThat(content).isNotEmpty().withFailMessage("Content returned from OpenAI model is empty"); + assertThat(content).contains("30", "10", "15"); + + }); + } + + @Test + void functionCallingVoidResponse() { + this.contextRunner.run(context -> { + + OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); + + // Test weatherFunction + UserMessage userMessage = new UserMessage("Turn the light on in the kitchen and in the living room"); + + ChatResponse response = chatModel + .call(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withFunction("turnLight").build())); + + logger.info("Response: {}", response); + assertThat(feedback).hasSize(2); + assertThat(feedback.get("kitchen")).isEqualTo(Boolean.valueOf(true)); + assertThat(feedback.get("living room")).isEqualTo(Boolean.valueOf(true)); + }); + } + + @Test + void functionCallingConsumer() { + this.contextRunner.run(context -> { + + OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); + + // Test weatherFunction + UserMessage userMessage = new UserMessage("Turn the light on in the kitchen and in the living room"); + + ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), + OpenAiChatOptions.builder().withFunction("turnLightConsumer").build())); + + logger.info("Response: {}", response); + assertThat(feedback).hasSize(2); + assertThat(feedback.get("kitchen")).isEqualTo(Boolean.valueOf(true)); + assertThat(feedback.get("living room")).isEqualTo(Boolean.valueOf(true)); + + }); + } + + @Test + void functionCallingVoidArguments() { + this.contextRunner.run(context -> { + + OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); + + // Test weatherFunction + UserMessage userMessage = new UserMessage("Turn the light on in the living room"); + + ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), + OpenAiChatOptions.builder().withFunction("turnLivingRoomLightOn").build())); + + logger.info("Response: {}", response); + assertThat(feedback).hasSize(1); + assertThat(feedback.get("turnLivingRoomLightOn")).isEqualTo(Boolean.valueOf(true)); + }); + } + + @Test + void functionCallingSupplier() { + this.contextRunner.run(context -> { + + OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); + + // Test weatherFunction + UserMessage userMessage = new UserMessage("Turn the light on in the living room"); + + ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), + OpenAiChatOptions.builder().withFunction("turnLivingRoomLightOnSupplier").build())); + + logger.info("Response: {}", response); + assertThat(feedback).hasSize(1); + assertThat(feedback.get("turnLivingRoomLightOnSupplier")).isEqualTo(Boolean.valueOf(true)); + }); + } + + @Test + void trainScheduler() { + this.contextRunner.run(context -> { + + OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class); + + // Test weatherFunction + UserMessage userMessage = new UserMessage( + "Please schedule a train from San Francisco to Los Angeles on 2023-12-25"); + + PortableFunctionCallingOptions functionOptions = FunctionCallingOptions.builder() + .withFunction("trainReservation") + .build(); + + ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions)); + + logger.info("Response: {}", response.getResult().getOutput().getContent()); + }); } @Configuration @@ -256,6 +353,70 @@ public Function weather return (weatherService::apply); } + record LightInfo(String roomName, boolean isOn) { + } + + @Bean + @Description("Turn light on or off in a room") + public Function turnLight() { + return (LightInfo lightInfo) -> { + logger.info("Turning light to [" + lightInfo.isOn + "] in " + lightInfo.roomName()); + feedback.put(lightInfo.roomName(), lightInfo.isOn()); + return null; + }; + } + + @Bean + @Description("Turn light on or off in a room") + public Consumer turnLightConsumer() { + return (LightInfo lightInfo) -> { + logger.info("Turning light to [" + lightInfo.isOn + "] in " + lightInfo.roomName()); + feedback.put(lightInfo.roomName(), lightInfo.isOn()); + }; + } + + @Bean + @Description("Turns light on in the living room") + public Function turnLivingRoomLightOn() { + return (Void v) -> { + logger.info("Turning light on in the living room"); + feedback.put("turnLivingRoomLightOn", Boolean.TRUE); + return "Done"; + }; + } + + @Bean + @Description("Turns light on in the living room") + public Supplier turnLivingRoomLightOnSupplier() { + return () -> { + logger.info("Turning light on in the living room"); + feedback.put("turnLivingRoomLightOnSupplier", Boolean.TRUE); + return "Done"; + }; + } + + record TrainSearchSchedule(String from, String to, String date) { + } + + record TrainSearchScheduleResponse(String from, String to, String date, String trainNumber) { + } + + record TrainSearchRequest(T data) { + } + + record TrainSearchResponse(T data) { + } + + @Bean + @Description("Schedule a train reservation") + public Function, TrainSearchResponse> trainReservation() { + return (TrainSearchRequest request) -> { + logger.info("Turning light to [" + request.data().from() + "] in " + request.data().to()); + return new TrainSearchResponse<>( + new TrainSearchScheduleResponse(request.data().from(), request.data().to(), "", "123")); + }; + } + } public static class MyBiFunction