Skip to content

[WIP] feat: Add support for Consumer and Supplier function callbacks #1719

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ void functionCallTest() {
// @formatter:off
String response = ChatClient.create(this.chatModel).prompt()
.user(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius."))
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
.function("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService())
.call()
.content();
// @formatter:on
Expand All @@ -226,7 +226,7 @@ void defaultFunctionCallTest() {

// @formatter:off
String response = ChatClient.builder(this.chatModel)
.defaultFunction("getCurrentWeather", "Get the weather in location", new MockWeatherService())
.defaultFunction("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService())
.defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius."))
.build()
.prompt()
Expand All @@ -245,7 +245,7 @@ void streamFunctionCallTest() {
// @formatter:off
Flux<String> response = ChatClient.create(this.chatModel).prompt()
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
.function("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService())
.stream()
.content();
// @formatter:on
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ void functionCallTest() {
// @formatter:off
String response = ChatClient.create(this.chatModel)
.prompt("What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.")
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
.function("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService())
.call()
.content();
// @formatter:on
Expand All @@ -228,7 +228,7 @@ void functionCallWithAdvisorTest() {
// @formatter:off
String response = ChatClient.create(this.chatModel)
.prompt("What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.")
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
.function("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService())
.advisors(new SimpleLoggerAdvisor())
.call()
.content();
Expand All @@ -244,7 +244,7 @@ void defaultFunctionCallTest() {

// @formatter:off
String response = ChatClient.builder(this.chatModel)
.defaultFunction("getCurrentWeather", "Get the weather in location", new MockWeatherService())
.defaultFunction("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService())
.defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius."))
.build()
.prompt()
Expand All @@ -263,7 +263,7 @@ void streamFunctionCallTest() {
// @formatter:off
Flux<String> response = ChatClient.create(this.chatModel).prompt()
.user("What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.")
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
.function("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService())
.stream()
.content();
// @formatter:on
Expand All @@ -280,7 +280,7 @@ void singularStreamFunctionCallTest() {
// @formatter:off
Flux<String> response = ChatClient.create(this.chatModel).prompt()
.user("What's the weather like in Paris? Return the temperature in Celsius.")
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
.function("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService())
.stream()
.content();
// @formatter:on
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ void functionCallTest() {
String response = ChatClient.create(this.chatModel).prompt()
.options(MistralAiChatOptions.builder().withModel(MistralAiApi.ChatModel.SMALL).withToolChoice(ToolChoice.AUTO).build())
.user(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Use parallel function calling if required. Response should be in Celsius."))
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
.function("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService())
.call()
.content();
// @formatter:on
Expand All @@ -242,7 +242,7 @@ void defaultFunctionCallTest() {
// @formatter:off
String response = ChatClient.builder(this.chatModel)
.defaultOptions(MistralAiChatOptions.builder().withModel(MistralAiApi.ChatModel.SMALL).build())
.defaultFunction("getCurrentWeather", "Get the weather in location", new MockWeatherService())
.defaultFunction("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService())
.defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Use parallel function calling if required. Response should be in Celsius."))
.build()
.prompt().call().content();
Expand All @@ -262,7 +262,7 @@ void streamFunctionCallTest() {
Flux<String> response = ChatClient.create(this.chatModel).prompt()
.options(MistralAiChatOptions.builder().withModel(MistralAiApi.ChatModel.SMALL).build())
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use parallel function calling if required. Response should be in Celsius.")
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
.function("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService())
.stream()
.content();
// @formatter:on
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ public OpenAiAudioSpeechModel openAiAudioSpeechModel(OpenAiAudioApi api) {
@Bean
public OpenAiImageModel openAiImageModel(OpenAiImageApi imageApi) {
OpenAiImageModel openAiImageModel = new OpenAiImageModel(imageApi);
// openAiImageModel.setModel("foobar");
return openAiImageModel;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import java.util.stream.Collectors;

import org.junit.jupiter.api.Disabled;
Expand Down Expand Up @@ -247,7 +249,7 @@ void functionCallTest() {
// @formatter:off
String response = ChatClient.create(this.chatModel).prompt()
.user(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?"))
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
.function("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService())
.call()
.content();
// @formatter:on
Expand All @@ -257,12 +259,54 @@ void functionCallTest() {
assertThat(response).contains("30", "10", "15");
}

@Test
void functionCallSupplier() {

Map<String, Object> state = new ConcurrentHashMap<>();

// @formatter:off
String response = ChatClient.create(this.chatModel).prompt()
.user("Turn the light on in the living room")
.function("turnLivingRoomLightOnSupplier", "Turns light on in the living room", () -> state.put("foo", "bar"))
.call()
.content();
// @formatter:on

logger.info("Response: {}", response);
assertThat(state).containsEntry("foo", "bar");
}

@Test
void functionCallConsumer() {

Map<String, Object> state = new ConcurrentHashMap<>();

record LightInfo(String roomName, boolean isOn) {
}

// @formatter:off
String response = ChatClient.create(this.chatModel).prompt()
.user("Turn the light on in the kitchen and in the living room")
.function("turnLight", "Turn light on or off in a room", LightInfo.class, (LightInfo lightInfo) -> {
logger.info("Turning light to [" + lightInfo.isOn + "] in " + lightInfo.roomName());
state.put(lightInfo.roomName(), lightInfo.isOn());
})
.call()
.content();
// @formatter:on

logger.info("Response: {}", response);
assertThat(state).containsEntry("kitchen", Boolean.TRUE);
assertThat(state).containsEntry("living room", Boolean.TRUE);
}

@Test
void defaultFunctionCallTest() {

// @formatter:off
String response = ChatClient.builder(this.chatModel)
.defaultFunction("getCurrentWeather", "Get the weather in location", new MockWeatherService())
.defaultFunction("getCurrentWeather", "Get the weather in location",
MockWeatherService.Request.class, new MockWeatherService())
.defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?"))
.build()
.prompt().call().content();
Expand All @@ -279,7 +323,8 @@ void streamFunctionCallTest() {
// @formatter:off
Flux<String> response = ChatClient.create(this.chatModel).prompt()
.user("What's the weather like in San Francisco, Tokyo, and Paris?")
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
.function("getCurrentWeather", "Get the weather in location",
MockWeatherService.Request.class, new MockWeatherService())
.stream()
.content();
// @formatter:on
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ void turnFunctionsOnAndOffTest() {
// @formatter:off
response = chatClientBuilder.build().prompt()
.user(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?"))
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
.function("getCurrentWeather", "Get the weather in location",
MockWeatherService.Request.class, new MockWeatherService())
.call()
.content();
// @formatter:on
Expand All @@ -110,7 +111,8 @@ void defaultFunctionCallTest() {

// @formatter:off
String response = ChatClient.builder(this.chatModel)
.defaultFunction("getCurrentWeather", "Get the weather in location", new MockWeatherService())
.defaultFunction("getCurrentWeather", "Get the weather in location",
MockWeatherService.Request.class, new MockWeatherService())
.defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?"))
.build()
.prompt().call().content();
Expand Down Expand Up @@ -149,7 +151,7 @@ else if (request.location().contains("San Francisco")) {

// @formatter:off
String response = ChatClient.builder(this.chatModel)
.defaultFunction("getCurrentWeather", "Get the weather in location", biFunction)
.defaultFunction("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, biFunction)
.defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?"))
.defaultToolContext(Map.of("sessionId", "123"))
.build()
Expand Down Expand Up @@ -189,7 +191,7 @@ else if (request.location().contains("San Francisco")) {

// @formatter:off
String response = ChatClient.builder(this.chatModel)
.defaultFunction("getCurrentWeather", "Get the weather in location", biFunction)
.defaultFunction("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, biFunction)
.defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?"))
.build()
.prompt()
Expand All @@ -208,7 +210,7 @@ void streamFunctionCallTest() {
// @formatter:off
Flux<String> response = ChatClient.create(this.chatModel).prompt()
.user("What's the weather like in San Francisco, Tokyo, and Paris?")
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
.function("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService())
.stream()
.content();
// @formatter:on
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,16 +212,38 @@ interface ChatClientRequestSpec {

<T extends ChatOptions> ChatClientRequestSpec options(T options);

/**
* @deprecated Use
* {@link #function(String, String, Class, java.util.function.Function)} instead.
* Because of JVM type erasure, for lambda to work, the inputType class is
* required to be provided explicitly.
*/
@Deprecated
<I, O> ChatClientRequestSpec function(String name, String description,
java.util.function.Function<I, O> function);

<I, O> ChatClientRequestSpec function(String name, String description, Class<I> inputType,
java.util.function.Function<I, O> function);

/**
* @deprecated Use
* {@link #functions(String, String, Class,java.util.function.BiFunction)} Because
* of JVM type erasure, for lambda to work, the inputType class is required to be
* provided explicitly.
*/
@Deprecated
<I, O> ChatClientRequestSpec function(String name, String description,
java.util.function.BiFunction<I, ToolContext, O> function);

<I, O> ChatClientRequestSpec function(String name, String description, Class<I> inputType,
java.util.function.BiFunction<I, ToolContext, O> function);

<I, O> ChatClientRequestSpec functions(FunctionCallback... functionCallbacks);

<I, O> ChatClientRequestSpec function(String name, String description, java.util.function.Supplier<O> supplier);

<I, O> ChatClientRequestSpec function(String name, String description, Class<I> inputType,
java.util.function.Function<I, O> function);
java.util.function.Consumer<I> consumer);

ChatClientRequestSpec functions(String... functionBeanNames);

Expand Down Expand Up @@ -278,11 +300,36 @@ interface Builder {

Builder defaultSystem(Consumer<PromptSystemSpec> systemSpecConsumer);

/**
* @deprecated Use
* {@link #defaultFunction(String, String, Class, java.util.function.Function)}
* instead. Because of JVM type erasure, for lambda to work, the inputType class
* is required to be provided explicitly.
*/
@Deprecated
<I, O> Builder defaultFunction(String name, String description, java.util.function.Function<I, O> function);

<I, O> Builder defaultFunction(String name, String description, Class<I> inputType,
java.util.function.Function<I, O> function);

/**
* @deprecated Use
* {@link #defaultFunction(String, String, Class, java.util.function.BiFunction)}
* instead. Because of JVM type erasure, for lambda to work, the inputType class
* is required to be provided explicitly.
*/
@Deprecated
<I, O> Builder defaultFunction(String name, String description,
java.util.function.BiFunction<I, ToolContext, O> function);

<I, O> Builder defaultFunction(String name, String description, Class<I> inputType,
java.util.function.BiFunction<I, ToolContext, O> function);

<I, O> Builder defaultFunction(String name, String description, java.util.function.Supplier<O> supplier);

<I, O> Builder defaultFunction(String name, String description, Class<I> inputType,
java.util.function.Consumer<I> consumer);

Builder defaultFunctions(String... functionNames);

Builder defaultFunctions(FunctionCallback... functionCallbacks);
Expand Down
Loading