Skip to content

Commit c848577

Browse files
committed
feat: adds the ability to specify the input type class when defining a function
This change adds the ability to specify the input type class when defining a function callback for the ChatClient. Previously, the input type had to be inferred, which caused issues with lambda expressions due to type erasure. The following new methods have been added: - ChatClient.ChatClientRequestSpec.function(String, String, Class<I>, Function<I, O>) - ChatClient.ChatClientRequestSpec.function(String, String, Class<I>, BiFunction<I, ToolContext, O>) - ChatClient.ChatClientRequestSpec.function(String, String, Class<I>, Consumer<I>) - ChatClient.Builder.defaultFunction(String, String, Class<I>, Function<I, O>) - ChatClient.Builder.defaultFunction(String, String, Class<I>, BiFunction<I, ToolContext, O>) - ChatClient.Builder.defaultFunction(String, String, Class<I>, Consumer<I>) The deprecated methods without the input type parameter have also been kept for backwards compatibility. This change should make it easier to use the ChatClient API, especially when dealing with lambda expressions.
1 parent 2753b01 commit c848577

File tree

12 files changed

+159
-62
lines changed

12 files changed

+159
-62
lines changed

models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientIT.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ void functionCallTest() {
211211
// @formatter:off
212212
String response = ChatClient.create(this.chatModel).prompt()
213213
.user(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius."))
214-
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
214+
.function("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService())
215215
.call()
216216
.content();
217217
// @formatter:on
@@ -226,7 +226,7 @@ void defaultFunctionCallTest() {
226226

227227
// @formatter:off
228228
String response = ChatClient.builder(this.chatModel)
229-
.defaultFunction("getCurrentWeather", "Get the weather in location", new MockWeatherService())
229+
.defaultFunction("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService())
230230
.defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius."))
231231
.build()
232232
.prompt()
@@ -245,7 +245,7 @@ void streamFunctionCallTest() {
245245
// @formatter:off
246246
Flux<String> response = ChatClient.create(this.chatModel).prompt()
247247
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")
248-
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
248+
.function("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService())
249249
.stream()
250250
.content();
251251
// @formatter:on

models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ void functionCallTest() {
212212
// @formatter:off
213213
String response = ChatClient.create(this.chatModel)
214214
.prompt("What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.")
215-
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
215+
.function("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService())
216216
.call()
217217
.content();
218218
// @formatter:on
@@ -228,7 +228,7 @@ void functionCallWithAdvisorTest() {
228228
// @formatter:off
229229
String response = ChatClient.create(this.chatModel)
230230
.prompt("What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.")
231-
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
231+
.function("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService())
232232
.advisors(new SimpleLoggerAdvisor())
233233
.call()
234234
.content();
@@ -244,7 +244,7 @@ void defaultFunctionCallTest() {
244244

245245
// @formatter:off
246246
String response = ChatClient.builder(this.chatModel)
247-
.defaultFunction("getCurrentWeather", "Get the weather in location", new MockWeatherService())
247+
.defaultFunction("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService())
248248
.defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius."))
249249
.build()
250250
.prompt()
@@ -263,7 +263,7 @@ void streamFunctionCallTest() {
263263
// @formatter:off
264264
Flux<String> response = ChatClient.create(this.chatModel).prompt()
265265
.user("What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.")
266-
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
266+
.function("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService())
267267
.stream()
268268
.content();
269269
// @formatter:on
@@ -280,7 +280,7 @@ void singularStreamFunctionCallTest() {
280280
// @formatter:off
281281
Flux<String> response = ChatClient.create(this.chatModel).prompt()
282282
.user("What's the weather like in Paris? Return the temperature in Celsius.")
283-
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
283+
.function("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService())
284284
.stream()
285285
.content();
286286
// @formatter:on

models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatClientIT.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ void functionCallTest() {
224224
String response = ChatClient.create(this.chatModel).prompt()
225225
.options(MistralAiChatOptions.builder().withModel(MistralAiApi.ChatModel.SMALL).withToolChoice(ToolChoice.AUTO).build())
226226
.user(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Use parallel function calling if required. Response should be in Celsius."))
227-
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
227+
.function("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService())
228228
.call()
229229
.content();
230230
// @formatter:on
@@ -242,7 +242,7 @@ void defaultFunctionCallTest() {
242242
// @formatter:off
243243
String response = ChatClient.builder(this.chatModel)
244244
.defaultOptions(MistralAiChatOptions.builder().withModel(MistralAiApi.ChatModel.SMALL).build())
245-
.defaultFunction("getCurrentWeather", "Get the weather in location", new MockWeatherService())
245+
.defaultFunction("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService())
246246
.defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Use parallel function calling if required. Response should be in Celsius."))
247247
.build()
248248
.prompt().call().content();
@@ -262,7 +262,7 @@ void streamFunctionCallTest() {
262262
Flux<String> response = ChatClient.create(this.chatModel).prompt()
263263
.options(MistralAiChatOptions.builder().withModel(MistralAiApi.ChatModel.SMALL).build())
264264
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use parallel function calling if required. Response should be in Celsius.")
265-
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
265+
.function("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService())
266266
.stream()
267267
.content();
268268
// @formatter:on

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ void functionCallTest() {
249249
// @formatter:off
250250
String response = ChatClient.create(this.chatModel).prompt()
251251
.user(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?"))
252-
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
252+
.function("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService())
253253
.call()
254254
.content();
255255
// @formatter:on
@@ -287,12 +287,9 @@ record LightInfo(String roomName, boolean isOn) {
287287
// @formatter:off
288288
String response = ChatClient.create(this.chatModel).prompt()
289289
.user("Turn the light on in the kitchen and in the living room")
290-
.function("turnLight", "Turn light on or off in a room", new Consumer<LightInfo>() {
291-
@Override
292-
public void accept(LightInfo lightInfo) {
290+
.function("turnLight", "Turn light on or off in a room", LightInfo.class, (LightInfo lightInfo) -> {
293291
logger.info("Turning light to [" + lightInfo.isOn + "] in " + lightInfo.roomName());
294292
state.put(lightInfo.roomName(), lightInfo.isOn());
295-
}
296293
})
297294
.call()
298295
.content();
@@ -308,7 +305,8 @@ void defaultFunctionCallTest() {
308305

309306
// @formatter:off
310307
String response = ChatClient.builder(this.chatModel)
311-
.defaultFunction("getCurrentWeather", "Get the weather in location", new MockWeatherService())
308+
.defaultFunction("getCurrentWeather", "Get the weather in location",
309+
MockWeatherService.Request.class, new MockWeatherService())
312310
.defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?"))
313311
.build()
314312
.prompt().call().content();
@@ -325,7 +323,8 @@ void streamFunctionCallTest() {
325323
// @formatter:off
326324
Flux<String> response = ChatClient.create(this.chatModel).prompt()
327325
.user("What's the weather like in San Francisco, Tokyo, and Paris?")
328-
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
326+
.function("getCurrentWeather", "Get the weather in location",
327+
MockWeatherService.Request.class, new MockWeatherService())
329328
.stream()
330329
.content();
331330
// @formatter:on

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMultipleFunctionCallsIT.java

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,8 @@ void turnFunctionsOnAndOffTest() {
8383
// @formatter:off
8484
response = chatClientBuilder.build().prompt()
8585
.user(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?"))
86-
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
86+
.function("getCurrentWeather", "Get the weather in location",
87+
MockWeatherService.Request.class, new MockWeatherService())
8788
.call()
8889
.content();
8990
// @formatter:on
@@ -110,7 +111,8 @@ void defaultFunctionCallTest() {
110111

111112
// @formatter:off
112113
String response = ChatClient.builder(this.chatModel)
113-
.defaultFunction("getCurrentWeather", "Get the weather in location", new MockWeatherService())
114+
.defaultFunction("getCurrentWeather", "Get the weather in location",
115+
MockWeatherService.Request.class, new MockWeatherService())
114116
.defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?"))
115117
.build()
116118
.prompt().call().content();
@@ -149,7 +151,7 @@ else if (request.location().contains("San Francisco")) {
149151

150152
// @formatter:off
151153
String response = ChatClient.builder(this.chatModel)
152-
.defaultFunction("getCurrentWeather", "Get the weather in location", biFunction)
154+
.defaultFunction("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, biFunction)
153155
.defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?"))
154156
.defaultToolContext(Map.of("sessionId", "123"))
155157
.build()
@@ -189,7 +191,7 @@ else if (request.location().contains("San Francisco")) {
189191

190192
// @formatter:off
191193
String response = ChatClient.builder(this.chatModel)
192-
.defaultFunction("getCurrentWeather", "Get the weather in location", biFunction)
194+
.defaultFunction("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, biFunction)
193195
.defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?"))
194196
.build()
195197
.prompt()
@@ -208,7 +210,7 @@ void streamFunctionCallTest() {
208210
// @formatter:off
209211
Flux<String> response = ChatClient.create(this.chatModel).prompt()
210212
.user("What's the weather like in San Francisco, Tokyo, and Paris?")
211-
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
213+
.function("getCurrentWeather", "Get the weather in location", MockWeatherService.Request.class, new MockWeatherService())
212214
.stream()
213215
.content();
214216
// @formatter:on

spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -212,20 +212,38 @@ interface ChatClientRequestSpec {
212212

213213
<T extends ChatOptions> ChatClientRequestSpec options(T options);
214214

215+
/**
216+
* @deprecated Use
217+
* {@link #function(String, String, Class, java.util.function.Function)} instead.
218+
* Because of JVM type erasure, for lambda to work, the inputType class is
219+
* required to be provided explicitly.
220+
*/
221+
@Deprecated
215222
<I, O> ChatClientRequestSpec function(String name, String description,
216223
java.util.function.Function<I, O> function);
217224

225+
<I, O> ChatClientRequestSpec function(String name, String description, Class<I> inputType,
226+
java.util.function.Function<I, O> function);
227+
228+
/**
229+
* @deprecated Use
230+
* {@link #functions(String, String, Class,java.util.function.BiFunction)} Because
231+
* of JVM type erasure, for lambda to work, the inputType class is required to be
232+
* provided explicitly.
233+
*/
234+
@Deprecated
218235
<I, O> ChatClientRequestSpec function(String name, String description,
219236
java.util.function.BiFunction<I, ToolContext, O> function);
220237

221-
<I, O> ChatClientRequestSpec functions(FunctionCallback... functionCallbacks);
222-
223238
<I, O> ChatClientRequestSpec function(String name, String description, Class<I> inputType,
224-
java.util.function.Function<I, O> function);
239+
java.util.function.BiFunction<I, ToolContext, O> function);
240+
241+
<I, O> ChatClientRequestSpec functions(FunctionCallback... functionCallbacks);
225242

226243
<I, O> ChatClientRequestSpec function(String name, String description, java.util.function.Supplier<O> supplier);
227244

228-
<I, O> ChatClientRequestSpec function(String name, String description, java.util.function.Consumer<I> consumer);
245+
<I, O> ChatClientRequestSpec function(String name, String description, Class<I> inputType,
246+
java.util.function.Consumer<I> consumer);
229247

230248
ChatClientRequestSpec functions(String... functionBeanNames);
231249

@@ -282,11 +300,36 @@ interface Builder {
282300

283301
Builder defaultSystem(Consumer<PromptSystemSpec> systemSpecConsumer);
284302

303+
/**
304+
* @deprecated Use
305+
* {@link #defaultFunction(String, String, Class, java.util.function.Function)}
306+
* instead. Because of JVM type erasure, for lambda to work, the inputType class
307+
* is required to be provided explicitly.
308+
*/
309+
@Deprecated
285310
<I, O> Builder defaultFunction(String name, String description, java.util.function.Function<I, O> function);
286311

312+
<I, O> Builder defaultFunction(String name, String description, Class<I> inputType,
313+
java.util.function.Function<I, O> function);
314+
315+
/**
316+
* @deprecated Use
317+
* {@link #defaultFunction(String, String, Class, java.util.function.BiFunction)}
318+
* instead. Because of JVM type erasure, for lambda to work, the inputType class
319+
* is required to be provided explicitly.
320+
*/
321+
@Deprecated
287322
<I, O> Builder defaultFunction(String name, String description,
288323
java.util.function.BiFunction<I, ToolContext, O> function);
289324

325+
<I, O> Builder defaultFunction(String name, String description, Class<I> inputType,
326+
java.util.function.BiFunction<I, ToolContext, O> function);
327+
328+
<I, O> Builder defaultFunction(String name, String description, java.util.function.Supplier<O> supplier);
329+
330+
<I, O> Builder defaultFunction(String name, String description, Class<I> inputType,
331+
java.util.function.Consumer<I> consumer);
332+
290333
Builder defaultFunctions(String... functionNames);
291334

292335
Builder defaultFunctions(FunctionCallback... functionCallbacks);

spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -836,11 +836,40 @@ public <T extends ChatOptions> ChatClientRequestSpec options(T options) {
836836
return this;
837837
}
838838

839+
/**
840+
* @deprecated since 1.0.0 in favor of
841+
* {@link #function(String, String, Class, java.util.function.Function)} Because
842+
* of JVM type erasure, the inputType class is required to be provided explicitly.
843+
*/
844+
@Deprecated(since = "1.0.0", forRemoval = true)
839845
public <I, O> ChatClientRequestSpec function(String name, String description,
840846
java.util.function.Function<I, O> function) {
841847
return this.function(name, description, null, function);
842848
}
843849

850+
public <I, O> ChatClientRequestSpec function(String name, String description, @Nullable Class<I> inputType,
851+
java.util.function.Function<I, O> function) {
852+
853+
Assert.hasText(name, "name cannot be null or empty");
854+
Assert.hasText(description, "description cannot be null or empty");
855+
Assert.notNull(function, "function cannot be null");
856+
857+
var fcw = FunctionCallbackWrapper.builder(function)
858+
.withDescription(description)
859+
.withName(name)
860+
.withInputType(inputType)
861+
.withResponseConverter(Object::toString)
862+
.build();
863+
this.functionCallbacks.add(fcw);
864+
return this;
865+
}
866+
867+
/**
868+
* @deprecated since 1.0.0 in favor of
869+
* {@link #function(String, String, Class, java.util.function.BiFunction)} Because
870+
* of JVM type erasure, the inputType class is required to be provided explicitly.
871+
*/
872+
@Deprecated
844873
public <I, O> ChatClientRequestSpec function(String name, String description,
845874
java.util.function.BiFunction<I, ToolContext, O> biFunction) {
846875

@@ -857,14 +886,14 @@ public <I, O> ChatClientRequestSpec function(String name, String description,
857886
return this;
858887
}
859888

860-
public <I, O> ChatClientRequestSpec function(String name, String description, @Nullable Class<I> inputType,
861-
java.util.function.Function<I, O> function) {
889+
public <I, O> ChatClientRequestSpec function(String name, String description, Class<I> inputType,
890+
java.util.function.BiFunction<I, ToolContext, O> biFunction) {
862891

863892
Assert.hasText(name, "name cannot be null or empty");
864893
Assert.hasText(description, "description cannot be null or empty");
865-
Assert.notNull(function, "function cannot be null");
894+
Assert.notNull(biFunction, "biFunction cannot be null");
866895

867-
var fcw = FunctionCallbackWrapper.builder(function)
896+
FunctionCallbackWrapper<I, O> fcw = FunctionCallbackWrapper.builder(biFunction)
868897
.withDescription(description)
869898
.withName(name)
870899
.withInputType(inputType)
@@ -890,7 +919,7 @@ public <I, O> ChatClientRequestSpec function(String name, String description,
890919
return this;
891920
}
892921

893-
public <I, O> ChatClientRequestSpec function(String name, String description,
922+
public <I, O> ChatClientRequestSpec function(String name, String description, Class<I> inputType,
894923
java.util.function.Consumer<I> consumer) {
895924

896925
Assert.hasText(name, "name cannot be null or empty");
@@ -900,7 +929,7 @@ public <I, O> ChatClientRequestSpec function(String name, String description,
900929
var fcw = FunctionCallbackWrapper.builder(consumer)
901930
.withDescription(description)
902931
.withName(name)
903-
// .withResponseConverter(Object::toString)
932+
.withInputType(inputType)
904933
.build();
905934
this.functionCallbacks.add(fcw);
906935
return this;

spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,17 +143,42 @@ public Builder defaultSystem(Consumer<PromptSystemSpec> systemSpecConsumer) {
143143
return this;
144144
}
145145

146+
@Deprecated
146147
public <I, O> Builder defaultFunction(String name, String description, java.util.function.Function<I, O> function) {
147148
this.defaultRequest.function(name, description, function);
148149
return this;
149150
}
150151

152+
public <I, O> Builder defaultFunction(String name, String description, Class<I> inputType,
153+
java.util.function.Function<I, O> function) {
154+
this.defaultRequest.function(name, description, inputType, function);
155+
return this;
156+
}
157+
158+
@Deprecated
151159
public <I, O> Builder defaultFunction(String name, String description,
152160
java.util.function.BiFunction<I, ToolContext, O> biFunction) {
153161
this.defaultRequest.function(name, description, biFunction);
154162
return this;
155163
}
156164

165+
public <I, O> Builder defaultFunction(String name, String description, Class<I> inputType,
166+
java.util.function.BiFunction<I, ToolContext, O> biFunction) {
167+
this.defaultRequest.function(name, description, inputType, biFunction);
168+
return this;
169+
}
170+
171+
public <I, O> Builder defaultFunction(String name, String description, java.util.function.Supplier<O> supplier) {
172+
this.defaultRequest.function(name, description, supplier);
173+
return this;
174+
}
175+
176+
public <I, O> Builder defaultFunction(String name, String description, Class<I> inputType,
177+
java.util.function.Consumer<I> consumer) {
178+
this.defaultRequest.function(name, description, inputType, consumer);
179+
return this;
180+
}
181+
157182
public Builder defaultFunctions(String... functionNames) {
158183
this.defaultRequest.functions(functionNames);
159184
return this;

0 commit comments

Comments
 (0)