Skip to content

Commit c30fda6

Browse files
committed
refactor: migrate from Class to ParameterizedTypeReference for type handling
Improves generic type support in function callbacks and JSON schema generation by: - Moving CustomizedTypeReference to ModelOptionsUtils - Updating BeanOutputConverter to use ParameterizedTypeReference - Modifying function callbacks to support generic type resolution - Adding train scheduler test case to validate generic type handling
1 parent 05c86d4 commit c30fda6

File tree

6 files changed

+108
-46
lines changed

6 files changed

+108
-46
lines changed

spring-ai-core/src/main/java/org/springframework/ai/converter/BeanOutputConverter.java

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
package org.springframework.ai.converter;
1818

19-
import java.lang.reflect.Type;
2019
import java.util.Objects;
2120

2221
import com.fasterxml.jackson.core.JsonProcessingException;
@@ -37,6 +36,7 @@
3736
import org.slf4j.Logger;
3837
import org.slf4j.LoggerFactory;
3938

39+
import org.springframework.ai.model.ModelOptionsUtils.CustomizedTypeReference;
4040
import org.springframework.ai.util.JacksonUtils;
4141
import org.springframework.core.ParameterizedTypeReference;
4242
import org.springframework.lang.NonNull;
@@ -94,7 +94,7 @@ public BeanOutputConverter(Class<T> clazz, ObjectMapper objectMapper) {
9494
* @param typeRef The target class type reference.
9595
*/
9696
public BeanOutputConverter(ParameterizedTypeReference<T> typeRef) {
97-
this(new CustomizedTypeReference<>(typeRef), null);
97+
this(CustomizedTypeReference.forType(typeRef), null);
9898
}
9999

100100
/**
@@ -105,7 +105,7 @@ public BeanOutputConverter(ParameterizedTypeReference<T> typeRef) {
105105
* @param objectMapper Custom object mapper for JSON operations. endings.
106106
*/
107107
public BeanOutputConverter(ParameterizedTypeReference<T> typeRef, ObjectMapper objectMapper) {
108-
this(new CustomizedTypeReference<>(typeRef), objectMapper);
108+
this(CustomizedTypeReference.forType(typeRef), objectMapper);
109109
}
110110

111111
/**
@@ -220,19 +220,4 @@ public String getJsonSchema() {
220220
return this.jsonSchema;
221221
}
222222

223-
private static class CustomizedTypeReference<T> extends TypeReference<T> {
224-
225-
private final Type type;
226-
227-
CustomizedTypeReference(ParameterizedTypeReference<T> typeRef) {
228-
this.type = typeRef.getType();
229-
}
230-
231-
@Override
232-
public Type getType() {
233-
return this.type;
234-
}
235-
236-
}
237-
238223
}

spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import java.beans.PropertyDescriptor;
2020
import java.lang.reflect.Field;
21+
import java.lang.reflect.Type;
2122
import java.util.ArrayList;
2223
import java.util.Arrays;
2324
import java.util.HashMap;
@@ -50,8 +51,8 @@
5051
import org.springframework.ai.util.JacksonUtils;
5152
import org.springframework.beans.BeanWrapper;
5253
import org.springframework.beans.BeanWrapperImpl;
54+
import org.springframework.core.ParameterizedTypeReference;
5355
import org.springframework.util.Assert;
54-
import org.springframework.util.ClassUtils;
5556
import org.springframework.util.CollectionUtils;
5657
import org.springframework.util.ObjectUtils;
5758

@@ -335,11 +336,11 @@ private static String toGetName(String name) {
335336

336337
/**
337338
* Generates JSON Schema (version 2020_12) for the given class.
338-
* @param clazz the class to generate JSON Schema for.
339+
* @param parameterizedType the class to generate JSON Schema for.
339340
* @param toUpperCaseTypeValues if true, the type values are converted to upper case.
340341
* @return the generated JSON Schema as a String.
341342
*/
342-
public static String getJsonSchema(Class<?> clazz, boolean toUpperCaseTypeValues) {
343+
public static String getJsonSchema(ParameterizedTypeReference<?> parameterizedType, boolean toUpperCaseTypeValues) {
343344

344345
if (SCHEMA_GENERATOR_CACHE.get() == null) {
345346

@@ -358,9 +359,10 @@ public static String getJsonSchema(Class<?> clazz, boolean toUpperCaseTypeValues
358359
SCHEMA_GENERATOR_CACHE.compareAndSet(null, generator);
359360
}
360361

361-
ObjectNode node = SCHEMA_GENERATOR_CACHE.get().generateSchema(clazz);
362+
ObjectNode node = SCHEMA_GENERATOR_CACHE.get()
363+
.generateSchema(CustomizedTypeReference.forType(parameterizedType).getType());
362364

363-
if (ClassUtils.isVoidType(clazz) && node.get("properties") == null) {
365+
if ((parameterizedType.getType() == Void.class) && node.get("properties") == null) {
364366
node.putObject("properties");
365367
}
366368

@@ -411,4 +413,23 @@ public static <T> T mergeOption(T runtimeValue, T defaultValue) {
411413
return ObjectUtils.isEmpty(runtimeValue) ? defaultValue : runtimeValue;
412414
}
413415

416+
public static class CustomizedTypeReference<T> extends TypeReference<T> {
417+
418+
private final Type type;
419+
420+
CustomizedTypeReference(ParameterizedTypeReference<T> typeRef) {
421+
this.type = typeRef.getType();
422+
}
423+
424+
@Override
425+
public Type getType() {
426+
return this.type;
427+
}
428+
429+
public static <T> CustomizedTypeReference<T> forType(ParameterizedTypeReference<T> typeRef) {
430+
return new CustomizedTypeReference<>(typeRef);
431+
}
432+
433+
}
434+
414435
}

spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
import com.fasterxml.jackson.databind.ObjectMapper;
2525

2626
import org.springframework.ai.chat.model.ToolContext;
27+
import org.springframework.ai.model.ModelOptionsUtils.CustomizedTypeReference;
28+
import org.springframework.core.ParameterizedTypeReference;
2729
import org.springframework.util.Assert;
2830

2931
/**
@@ -47,7 +49,7 @@ abstract class AbstractFunctionCallback<I, O> implements BiFunction<I, ToolConte
4749

4850
private final String description;
4951

50-
private final Class<I> inputType;
52+
private final ParameterizedTypeReference<I> inputType;
5153

5254
private final String inputTypeSchema;
5355

@@ -70,8 +72,8 @@ abstract class AbstractFunctionCallback<I, O> implements BiFunction<I, ToolConte
7072
* @param objectMapper Used to convert the function's input and output types to and
7173
* from JSON.
7274
*/
73-
protected AbstractFunctionCallback(String name, String description, String inputTypeSchema, Class<I> inputType,
74-
Function<O, String> responseConverter, ObjectMapper objectMapper) {
75+
protected AbstractFunctionCallback(String name, String description, String inputTypeSchema,
76+
ParameterizedTypeReference<I> inputType, Function<O, String> responseConverter, ObjectMapper objectMapper) {
7577
Assert.notNull(name, "Name must not be null");
7678
Assert.notNull(description, "Description must not be null");
7779
Assert.notNull(inputType, "InputType must not be null");
@@ -116,9 +118,9 @@ public String call(String functionArguments) {
116118
return this.andThen(this.responseConverter).apply(request, null);
117119
}
118120

119-
private <T> T fromJson(String json, Class<T> targetClass) {
121+
private <T> T fromJson(String json, ParameterizedTypeReference<T> targetClass) {
120122
try {
121-
return this.objectMapper.readValue(json, targetClass);
123+
return this.objectMapper.readValue(json, CustomizedTypeReference.forType(targetClass));
122124
}
123125
catch (JsonProcessingException e) {
124126
throw new RuntimeException(e);

spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,35 +124,39 @@ else if (KotlinDelegate.isKotlinBiFunction(functionType.toClass())) {
124124
}
125125
}
126126
if (bean instanceof Function<?, ?> function) {
127+
// ResolvableType.forInstance(function);
127128
return FunctionCallbackWrapper.builder(function)
128129
.withName(beanName)
129130
.withSchemaType(this.schemaType)
130131
.withDescription(functionDescription)
131-
.withInputType(functionInputClass)
132+
.withInputType(functionInputType)
133+
// .withInputType(functionInputClass)
132134
.build();
133135
}
134136
if (bean instanceof Consumer<?> consumer) {
135137
return FunctionCallbackWrapper.builder(consumer)
136138
.withName(beanName)
137139
.withSchemaType(this.schemaType)
138140
.withDescription(functionDescription)
139-
.withInputType(functionInputClass)
141+
.withInputType(functionInputType)
142+
// .withInputType(functionInputClass)
140143
.build();
141144
}
142145
if (bean instanceof Supplier<?> supplier) {
143146
return FunctionCallbackWrapper.builder(supplier)
144147
.withName(beanName)
145148
.withSchemaType(this.schemaType)
146149
.withDescription(functionDescription)
147-
.withInputType(functionInputClass)
150+
.withInputType(functionInputType)
148151
.build();
149152
}
150153
else if (bean instanceof BiFunction<?, ?, ?>) {
151154
return FunctionCallbackWrapper.builder((BiFunction<?, ToolContext, ?>) bean)
152155
.withName(beanName)
153156
.withSchemaType(this.schemaType)
154157
.withDescription(functionDescription)
155-
.withInputType(functionInputClass)
158+
.withInputType(functionInputType)
159+
// .withInputType(functionInputClass)
156160
.build();
157161
}
158162
else {

spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
import org.springframework.ai.model.ModelOptionsUtils;
3131
import org.springframework.ai.model.function.FunctionCallbackContext.SchemaType;
3232
import org.springframework.ai.util.JacksonUtils;
33+
import org.springframework.core.ParameterizedTypeReference;
34+
import org.springframework.core.ResolvableType;
3335
import org.springframework.util.Assert;
3436

3537
/**
@@ -46,8 +48,9 @@ public final class FunctionCallbackWrapper<I, O> extends AbstractFunctionCallbac
4648

4749
private final BiFunction<I, ToolContext, O> biFunction;
4850

49-
private FunctionCallbackWrapper(String name, String description, String inputTypeSchema, Class<I> inputType,
50-
Function<O, String> responseConverter, ObjectMapper objectMapper, BiFunction<I, ToolContext, O> function) {
51+
private FunctionCallbackWrapper(String name, String description, String inputTypeSchema,
52+
ParameterizedTypeReference<I> inputType, Function<O, String> responseConverter, ObjectMapper objectMapper,
53+
BiFunction<I, ToolContext, O> function) {
5154
super(name, description, inputTypeSchema, inputType, responseConverter, objectMapper);
5255
Assert.notNull(function, "Function must not be null");
5356
this.biFunction = function;
@@ -87,7 +90,7 @@ public static class Builder<I, O> {
8790

8891
private String description;
8992

90-
private Class<I> inputType;
93+
private ParameterizedTypeReference<I> inputType;
9194

9295
private SchemaType schemaType = SchemaType.JSON_SCHEMA;
9396

@@ -119,20 +122,20 @@ public Builder(Consumer<I> consumer) {
119122
this.consumer = consumer;
120123
}
121124

122-
@SuppressWarnings("unchecked")
123-
private static <I, O> Class<I> resolveInputType(BiFunction<I, ToolContext, O> biFunction) {
124-
return (Class<I>) TypeResolverHelper
125-
.getBiFunctionInputClass((Class<BiFunction<I, ToolContext, O>>) biFunction.getClass());
125+
private static <I, O> ParameterizedTypeReference<I> resolveInputType(BiFunction<I, ToolContext, O> biFunction) {
126+
127+
ResolvableType rt = TypeResolverHelper.getFunctionArgumentType(ResolvableType.forInstance(biFunction), 0);
128+
return ParameterizedTypeReference.forType(rt.getType());
126129
}
127130

128-
@SuppressWarnings("unchecked")
129-
private static <I, O> Class<I> resolveInputType(Function<I, O> function) {
130-
return (Class<I>) TypeResolverHelper.getFunctionInputClass((Class<Function<I, O>>) function.getClass());
131+
private static <I, O> ParameterizedTypeReference<I> resolveInputType(Function<I, O> function) {
132+
ResolvableType rt = TypeResolverHelper.getFunctionArgumentType(ResolvableType.forInstance(function), 0);
133+
return ParameterizedTypeReference.forType(rt.getType());
131134
}
132135

133-
@SuppressWarnings("unchecked")
134-
private static <I> Class<I> resolveInputType(Consumer<I> consumer) {
135-
return (Class<I>) TypeResolverHelper.getConsumerInputClass((Class<Consumer<I>>) consumer.getClass());
136+
private static <I> ParameterizedTypeReference<I> resolveInputType(Consumer<I> consumer) {
137+
ResolvableType rt = TypeResolverHelper.getFunctionArgumentType(ResolvableType.forInstance(consumer), 0);
138+
return ParameterizedTypeReference.forType(rt.getType());
136139
}
137140

138141
public Builder<I, O> withName(String name) {
@@ -149,7 +152,12 @@ public Builder<I, O> withDescription(String description) {
149152

150153
@SuppressWarnings("unchecked")
151154
public Builder<I, O> withInputType(Class<?> inputType) {
152-
this.inputType = (Class<I>) inputType;
155+
this.inputType = ParameterizedTypeReference.forType((Class<I>) inputType);
156+
return this;
157+
}
158+
159+
public Builder<I, O> withInputType(ResolvableType inputType) {
160+
this.inputType = ParameterizedTypeReference.forType(inputType.getType());
153161
return this;
154162
}
155163

spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,26 @@ void functionCallingSupplier() {
304304
});
305305
}
306306

307+
@Test
308+
void trainScheduler() {
309+
this.contextRunner.run(context -> {
310+
311+
OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class);
312+
313+
// Test weatherFunction
314+
UserMessage userMessage = new UserMessage(
315+
"Please schedule a train from San Francisco to Los Angeles on 2023-12-25");
316+
317+
PortableFunctionCallingOptions functionOptions = FunctionCallingOptions.builder()
318+
.withFunction("trainReservation")
319+
.build();
320+
321+
ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions));
322+
323+
logger.info("Response: {}", response.getResult().getOutput().getContent());
324+
});
325+
}
326+
307327
@Configuration
308328
static class Config {
309329

@@ -375,6 +395,28 @@ public Supplier<String> turnLivingRoomLightOnSupplier() {
375395
};
376396
}
377397

398+
record TrainSearchSchedule(String from, String to, String date) {
399+
}
400+
401+
record TrainSearchScheduleResponse(String from, String to, String date, String trainNumber) {
402+
}
403+
404+
record TrainSearchRequest<T>(T data) {
405+
}
406+
407+
record TrainSearchResponse<T>(T data) {
408+
}
409+
410+
@Bean
411+
@Description("Schedule a train reservation")
412+
public Function<TrainSearchRequest<TrainSearchSchedule>, TrainSearchResponse<TrainSearchScheduleResponse>> trainReservation() {
413+
return (TrainSearchRequest<TrainSearchSchedule> request) -> {
414+
logger.info("Turning light to [" + request.data().from() + "] in " + request.data().to());
415+
return new TrainSearchResponse<>(
416+
new TrainSearchScheduleResponse(request.data().from(), request.data().to(), "", "123"));
417+
};
418+
}
419+
378420
}
379421

380422
public static class MyBiFunction

0 commit comments

Comments
 (0)