diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java index e79a8f760f0..c40539f8c91 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java @@ -18,6 +18,7 @@ import java.util.ArrayList; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Set; import com.fasterxml.jackson.annotation.JsonIgnore; @@ -80,6 +81,10 @@ public class AnthropicChatOptions implements ChatOptions, FunctionCallingOptions @JsonIgnore private Boolean proxyToolCalls; + + @JsonIgnore + private Map toolContext; + // @formatter:on public static Builder builder() { @@ -152,6 +157,16 @@ public Builder withProxyToolCalls(Boolean proxyToolCalls) { return this; } + public Builder withToolContext(Map toolContext) { + if (this.options.toolContext == null) { + this.options.toolContext = toolContext; + } + else { + this.options.toolContext.putAll(toolContext); + } + return this; + } + public AnthropicChatOptions build() { return this.options; } @@ -263,6 +278,16 @@ public void setProxyToolCalls(Boolean proxyToolCalls) { this.proxyToolCalls = proxyToolCalls; } + @Override + public Map getToolContext() { + return this.toolContext; + } + + @Override + public void setToolContext(Map toolContext) { + this.toolContext = toolContext; + } + @Override public AnthropicChatOptions copy() { return fromOptions(this); @@ -279,6 +304,7 @@ public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions) .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) .withFunctions(fromOptions.getFunctions()) .withProxyToolCalls(fromOptions.getProxyToolCalls()) + .withToolContext(fromOptions.getToolContext()) .build(); } diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java index 848e84a91e1..5685fa43ee7 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java @@ -22,18 +22,17 @@ import java.util.Map; import java.util.Set; -import com.azure.ai.openai.models.AzureChatEnhancementConfiguration; -import com.fasterxml.jackson.annotation.JsonIgnore; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonInclude.Include; -import com.fasterxml.jackson.annotation.JsonProperty; - import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallingOptions; import org.springframework.boot.context.properties.NestedConfigurationProperty; import org.springframework.util.Assert; -import org.stringtemplate.v4.compiler.CodeGenerator.primary_return; + +import com.azure.ai.openai.models.AzureChatEnhancementConfiguration; +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; /** * The configuration information for a chat completions request. Completions support a @@ -199,6 +198,10 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptio @JsonIgnore private AzureChatEnhancementConfiguration enhancements; + @NestedConfigurationProperty + @JsonIgnore + private Map toolContext; + public static Builder builder() { return new Builder(); } @@ -312,6 +315,16 @@ public Builder withEnhancements(AzureChatEnhancementConfiguration enhancements) return this; } + public Builder withToolContext(Map toolContext) { + if (this.options.toolContext == null) { + this.options.toolContext = toolContext; + } + else { + this.options.toolContext.putAll(toolContext); + } + return this; + } + public AzureOpenAiChatOptions build() { return this.options; } @@ -498,6 +511,16 @@ public void setProxyToolCalls(Boolean proxyToolCalls) { this.proxyToolCalls = proxyToolCalls; } + @Override + public Map getToolContext() { + return this.toolContext; + } + + @Override + public void setToolContext(Map toolContext) { + this.toolContext = toolContext; + } + @Override public AzureOpenAiChatOptions copy() { return fromOptions(this); @@ -521,6 +544,7 @@ public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOpti .withLogprobs(fromOptions.isLogprobs()) .withTopLogprobs(fromOptions.getTopLogProbs()) .withEnhancements(fromOptions.getEnhancements()) + .withToolContext(fromOptions.getToolContext()) .build(); } diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java index 30426eb6c7c..7762b4a4fe6 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java @@ -29,6 +29,7 @@ import java.util.ArrayList; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Set; /** @@ -145,6 +146,11 @@ public class MiniMaxChatOptions implements FunctionCallingOptions, ChatOptions { @JsonIgnore private Boolean proxyToolCalls; + + @NestedConfigurationProperty + @JsonIgnore + private Map toolContext; + // @formatter:on public static Builder builder() { @@ -250,6 +256,16 @@ public Builder withProxyToolCalls(Boolean proxyToolCalls) { return this; } + public Builder withToolContext(Map toolContext) { + if (this.options.toolContext == null) { + this.options.toolContext = toolContext; + } + else { + this.options.toolContext.putAll(toolContext); + } + return this; + } + public MiniMaxChatOptions build() { return this.options; } @@ -411,6 +427,16 @@ public void setProxyToolCalls(Boolean proxyToolCalls) { this.proxyToolCalls = proxyToolCalls; } + @Override + public Map getToolContext() { + return this.toolContext; + } + + @Override + public void setToolContext(Map toolContext) { + this.toolContext = toolContext; + } + @Override public int hashCode() { final int prime = 31; @@ -429,6 +455,7 @@ public int hashCode() { result = prime * result + ((tools == null) ? 0 : tools.hashCode()); result = prime * result + ((toolChoice == null) ? 0 : toolChoice.hashCode()); result = prime * result + ((proxyToolCalls == null) ? 0 : proxyToolCalls.hashCode()); + result = prime * result + ((toolContext == null) ? 0 : toolContext.hashCode()); return result; } @@ -525,6 +552,14 @@ else if (!toolChoice.equals(other.toolChoice)) } else if (!proxyToolCalls.equals(other.proxyToolCalls)) return false; + + if (this.toolContext == null) { + if (other.toolContext != null) + return false; + } + else if (!toolContext.equals(other.toolContext)) + return false; + return true; } @@ -550,6 +585,7 @@ public static MiniMaxChatOptions fromOptions(MiniMaxChatOptions fromOptions) { .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) .withFunctions(fromOptions.getFunctions()) .withProxyToolCalls(fromOptions.getProxyToolCalls()) + .withToolContext(fromOptions.getToolContext()) .build(); } diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java index 7053f5ab0b0..f2d8185232f 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java @@ -18,6 +18,8 @@ import java.util.ArrayList; import java.util.HashSet; import java.util.List; +import java.util.Map; +import java.util.Objects; import java.util.Set; import com.fasterxml.jackson.annotation.JsonIgnore; @@ -138,6 +140,10 @@ public class MistralAiChatOptions implements FunctionCallingOptions, ChatOptions @JsonIgnore private Boolean proxyToolCalls; + @NestedConfigurationProperty + @JsonIgnore + private Map toolContext; + public static Builder builder() { return new Builder(); } @@ -223,6 +229,16 @@ public Builder withProxyToolCalls(Boolean proxyToolCalls) { return this; } + public Builder withToolContext(Map toolContext) { + if (this.options.toolContext == null) { + this.options.toolContext = toolContext; + } + else { + this.options.toolContext.putAll(toolContext); + } + return this; + } + public MistralAiChatOptions build() { return this.options; } @@ -373,6 +389,16 @@ public void setProxyToolCalls(Boolean proxyToolCalls) { this.proxyToolCalls = proxyToolCalls; } + @Override + public Map getToolContext() { + return this.toolContext; + } + + @Override + public void setToolContext(Map toolContext) { + this.toolContext = toolContext; + } + @Override public MistralAiChatOptions copy() { return fromOptions(this); @@ -392,113 +418,37 @@ public static MistralAiChatOptions fromOptions(MistralAiChatOptions fromOptions) .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) .withFunctions(fromOptions.getFunctions()) .withProxyToolCalls(fromOptions.getProxyToolCalls()) + .withToolContext(fromOptions.getToolContext()) .build(); } @Override public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + ((model == null) ? 0 : model.hashCode()); - result = prime * result + ((temperature == null) ? 0 : temperature.hashCode()); - result = prime * result + ((topP == null) ? 0 : topP.hashCode()); - result = prime * result + ((maxTokens == null) ? 0 : maxTokens.hashCode()); - result = prime * result + ((safePrompt == null) ? 0 : safePrompt.hashCode()); - result = prime * result + ((randomSeed == null) ? 0 : randomSeed.hashCode()); - result = prime * result + ((responseFormat == null) ? 0 : responseFormat.hashCode()); - result = prime * result + ((stop == null) ? 0 : stop.hashCode()); - result = prime * result + ((tools == null) ? 0 : tools.hashCode()); - result = prime * result + ((toolChoice == null) ? 0 : toolChoice.hashCode()); - result = prime * result + ((functionCallbacks == null) ? 0 : functionCallbacks.hashCode()); - result = prime * result + ((functions == null) ? 0 : functions.hashCode()); - result = prime * result + ((proxyToolCalls == null) ? 0 : proxyToolCalls.hashCode()); - return result; + + return Objects.hash(model, temperature, topP, maxTokens, safePrompt, randomSeed, responseFormat, stop, tools, + toolChoice, functionCallbacks, functions, proxyToolCalls, toolContext); } @Override public boolean equals(Object obj) { if (this == obj) return true; - if (obj == null) - return false; - if (getClass() != obj.getClass()) + + if (obj == null || getClass() != obj.getClass()) return false; + MistralAiChatOptions other = (MistralAiChatOptions) obj; - if (model == null) { - if (other.model != null) - return false; - } - else if (!model.equals(other.model)) - return false; - if (temperature == null) { - if (other.temperature != null) - return false; - } - else if (!temperature.equals(other.temperature)) - return false; - if (topP == null) { - if (other.topP != null) - return false; - } - else if (!topP.equals(other.topP)) - return false; - if (maxTokens == null) { - if (other.maxTokens != null) - return false; - } - else if (!maxTokens.equals(other.maxTokens)) - return false; - if (safePrompt == null) { - if (other.safePrompt != null) - return false; - } - else if (!safePrompt.equals(other.safePrompt)) - return false; - if (randomSeed == null) { - if (other.randomSeed != null) - return false; - } - else if (!randomSeed.equals(other.randomSeed)) - return false; - if (responseFormat == null) { - if (other.responseFormat != null) - return false; - } - else if (!responseFormat.equals(other.responseFormat)) - return false; - if (stop == null) { - if (other.stop != null) - return false; - } - else if (!stop.equals(other.stop)) - return false; - if (tools == null) { - if (other.tools != null) - return false; - } - else if (!tools.equals(other.tools)) - return false; - if (toolChoice != other.toolChoice) - return false; - if (functionCallbacks == null) { - if (other.functionCallbacks != null) - return false; - } - else if (!functionCallbacks.equals(other.functionCallbacks)) - return false; - if (functions == null) { - if (other.functions != null) - return false; - } - else if (!functions.equals(other.functions)) - return false; - if (proxyToolCalls == null) { - if (other.proxyToolCalls != null) - return false; - } - else if (!proxyToolCalls.equals(other.proxyToolCalls)) - return false; - return true; + + return Objects.equals(this.model, other.model) && Objects.equals(this.temperature, other.temperature) + && Objects.equals(this.topP, other.topP) && Objects.equals(this.maxTokens, other.maxTokens) + && Objects.equals(this.safePrompt, other.safePrompt) + && Objects.equals(this.randomSeed, other.randomSeed) + && Objects.equals(this.responseFormat, other.responseFormat) && Objects.equals(this.stop, other.stop) + && Objects.equals(this.tools, other.tools) && Objects.equals(this.toolChoice, other.toolChoice) + && Objects.equals(this.functionCallbacks, other.functionCallbacks) + && Objects.equals(this.functions, other.functions) + && Objects.equals(this.proxyToolCalls, other.proxyToolCalls) + && Objects.equals(this.toolContext, other.toolContext); } } diff --git a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatOptions.java b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatOptions.java index 6eedc5ef1fb..b5dd8109795 100644 --- a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatOptions.java +++ b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatOptions.java @@ -28,6 +28,7 @@ import java.util.ArrayList; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Set; /** @@ -140,6 +141,10 @@ public class MoonshotChatOptions implements FunctionCallingOptions, ChatOptions @JsonIgnore private Boolean proxyToolCalls; + @NestedConfigurationProperty + @JsonIgnore + private Map toolContext; + @Override public List getFunctionCallbacks() { return this.functionCallbacks; @@ -252,6 +257,16 @@ public Builder withProxyToolCalls(Boolean proxyToolCalls) { return this; } + public Builder withToolContext(Map toolContext) { + if (this.options.toolContext == null) { + this.options.toolContext = toolContext; + } + else { + this.options.toolContext.putAll(toolContext); + } + return this; + } + public MoonshotChatOptions build() { return this.options; } @@ -362,6 +377,16 @@ public void setProxyToolCalls(Boolean proxyToolCalls) { this.proxyToolCalls = proxyToolCalls; } + @Override + public Map getToolContext() { + return this.toolContext; + } + + @Override + public void setToolContext(Map toolContext) { + this.toolContext = toolContext; + } + @Override public MoonshotChatOptions copy() { return builder().withModel(this.model) @@ -378,6 +403,7 @@ public MoonshotChatOptions copy() { .withFunctionCallbacks(this.functionCallbacks) .withFunctions(this.functions) .withProxyToolCalls(this.proxyToolCalls) + .withToolContext(this.toolContext) .build(); } @@ -395,6 +421,7 @@ public int hashCode() { result = prime * result + ((topP == null) ? 0 : topP.hashCode()); result = prime * result + ((user == null) ? 0 : user.hashCode()); result = prime * result + ((proxyToolCalls == null) ? 0 : proxyToolCalls.hashCode()); + result = prime * result + ((toolContext == null) ? 0 : toolContext.hashCode()); return result; } @@ -465,6 +492,11 @@ else if (!this.user.equals(other.user)) } else if (!this.proxyToolCalls.equals(other.proxyToolCalls)) return false; + if (this.toolContext == null) { + return other.toolContext == null; + } + else if (!this.toolContext.equals(other.toolContext)) + return false; return true; } diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java index 530e4361f18..8adee039fd7 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java @@ -300,6 +300,10 @@ public class OllamaOptions implements FunctionCallingOptions, ChatOptions, Embed @JsonIgnore private Boolean proxyToolCalls; + @NestedConfigurationProperty + @JsonIgnore + private Map toolContext; + public static OllamaOptions builder() { return new OllamaOptions(); } @@ -502,6 +506,16 @@ public OllamaOptions withProxyToolCalls(Boolean proxyToolCalls) { return this; } + public OllamaOptions withToolContext(Map toolContext) { + if (this.toolContext == null) { + this.toolContext = toolContext; + } + else { + this.toolContext.putAll(toolContext); + } + return this; + } + // ------------------- // Getters and Setters // ------------------- @@ -832,6 +846,16 @@ public void setProxyToolCalls(Boolean proxyToolCalls) { this.proxyToolCalls = proxyToolCalls; } + @Override + public Map getToolContext() { + return this.toolContext; + } + + @Override + public void setToolContext(Map toolContext) { + this.toolContext = toolContext; + } + /** * Convert the {@link OllamaOptions} object to a {@link Map} of key/value pairs. * @return The {@link Map} of key/value pairs. @@ -901,7 +925,8 @@ public static OllamaOptions fromOptions(OllamaOptions fromOptions) { .withStop(fromOptions.getStop()) .withFunctions(fromOptions.getFunctions()) .withProxyToolCalls(fromOptions.getProxyToolCalls()) - .withFunctionCallbacks(fromOptions.getFunctionCallbacks()); + .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) + .withToolContext(fromOptions.getToolContext()); } // @formatter:on @@ -930,7 +955,8 @@ public boolean equals(Object o) { && Objects.equals(mirostatTau, that.mirostatTau) && Objects.equals(mirostatEta, that.mirostatEta) && Objects.equals(penalizeNewline, that.penalizeNewline) && Objects.equals(stop, that.stop) && Objects.equals(functionCallbacks, that.functionCallbacks) - && Objects.equals(proxyToolCalls, that.proxyToolCalls) && Objects.equals(functions, that.functions); + && Objects.equals(proxyToolCalls, that.proxyToolCalls) && Objects.equals(functions, that.functions) + && Objects.equals(toolContext, that.toolContext); } @Override @@ -940,7 +966,8 @@ public int hashCode() { this.useMMap, this.useMLock, this.numThread, this.numKeep, this.seed, this.numPredict, this.topK, this.topP, tfsZ, this.typicalP, this.repeatLastN, this.temperature, this.repeatPenalty, this.presencePenalty, this.frequencyPenalty, this.mirostat, this.mirostatTau, this.mirostatEta, - this.penalizeNewline, this.stop, this.functionCallbacks, this.functions, this.proxyToolCalls); + this.penalizeNewline, this.stop, this.functionCallbacks, this.functions, this.proxyToolCalls, + this.toolContext); } } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java index 6303ffe3bc7..3bc8c04ff9e 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java @@ -191,6 +191,11 @@ public class OpenAiChatOptions implements FunctionCallingOptions, ChatOptions { @NestedConfigurationProperty @JsonIgnore private Map httpHeaders = new HashMap<>(); + + @NestedConfigurationProperty + @JsonIgnore + private Map toolContext; + // @formatter:on public static Builder builder() { @@ -336,6 +341,16 @@ public Builder withHttpHeaders(Map httpHeaders) { return this; } + public Builder withToolContext(Map toolContext) { + if (this.options.toolContext == null) { + this.options.toolContext = toolContext; + } + else { + this.options.toolContext.putAll(toolContext); + } + return this; + } + public OpenAiChatOptions build() { return this.options; } @@ -561,6 +576,16 @@ public Integer getTopK() { return null; } + @Override + public Map getToolContext() { + return this.toolContext; + } + + @Override + public void setToolContext(Map toolContext) { + this.toolContext = toolContext; + } + @Override public OpenAiChatOptions copy() { return OpenAiChatOptions.fromOptions(this); @@ -591,6 +616,7 @@ public static OpenAiChatOptions fromOptions(OpenAiChatOptions fromOptions) { .withFunctions(fromOptions.getFunctions()) .withHttpHeaders(fromOptions.getHttpHeaders()) .withProxyToolCalls(fromOptions.getProxyToolCalls()) + .withToolContext(fromOptions.getToolContext()) .build(); } @@ -600,7 +626,7 @@ public int hashCode() { this.maxTokens, this.maxCompletionTokens, this.n, this.presencePenalty, this.responseFormat, this.streamOptions, this.seed, this.stop, this.temperature, this.topP, this.tools, this.toolChoice, this.user, this.parallelToolCalls, this.functionCallbacks, this.functions, this.httpHeaders, - this.proxyToolCalls); + this.proxyToolCalls, this.toolContext); } @Override @@ -625,6 +651,7 @@ public boolean equals(Object o) { && Objects.equals(this.functionCallbacks, other.functionCallbacks) && Objects.equals(this.functions, other.functions) && Objects.equals(this.httpHeaders, other.httpHeaders) + && Objects.equals(this.toolContext, other.toolContext) && Objects.equals(this.proxyToolCalls, other.proxyToolCalls); } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelFunctionCallingIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelFunctionCallingIT.java index 22220566503..6834a74df06 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelFunctionCallingIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelFunctionCallingIT.java @@ -31,6 +31,8 @@ import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.tool.MockWeatherService; +import org.springframework.ai.openai.api.tool.MockWeatherService.Request; +import org.springframework.ai.openai.api.tool.MockWeatherService.Response; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; @@ -39,6 +41,8 @@ import java.util.ArrayList; import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -54,19 +58,58 @@ class OpenAiChatModelFunctionCallingIT { @Test void functionCallTest() { + functionCallTest(OpenAiChatOptions.builder() + .withModel(OpenAiApi.ChatModel.GPT_4_O.getValue()) + .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) + .withName("getCurrentWeather") + .withDescription("Get the weather in location") + .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .build())) + .build()); + } - UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); + @Test + void functionCallWithToolContextTest() { - List messages = new ArrayList<>(List.of(userMessage)); + var biFunction = new BiFunction, MockWeatherService.Response>() { + + @Override + public Response apply(Request request, Map toolContext) { + + assertThat(toolContext).containsEntry("sessionId", "123"); - var promptOptions = OpenAiChatOptions.builder() + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new MockWeatherService.Response(temperature, 15, 20, 2, 53, 45, MockWeatherService.Unit.C); + } + + }; + + functionCallTest(OpenAiChatOptions.builder() .withModel(OpenAiApi.ChatModel.GPT_4_O.getValue()) - .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) + .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(biFunction) .withName("getCurrentWeather") .withDescription("Get the weather in location") .withResponseConverter((response) -> "" + response.temp() + response.unit()) .build())) - .build(); + .withToolContext(Map.of("sessionId", "123")) + .build()); + } + + void functionCallTest(OpenAiChatOptions promptOptions) { + + UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); + + List messages = new ArrayList<>(List.of(userMessage)); ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); @@ -78,19 +121,59 @@ void functionCallTest() { @Test void streamFunctionCallTest() { - UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); + streamFunctionCallTest(OpenAiChatOptions.builder() + .withFunctionCallbacks(List.of((FunctionCallbackWrapper.builder(new MockWeatherService()) + .withName("getCurrentWeather") + .withDescription("Get the weather in location") + .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .build()))) + .build()); + } - List messages = new ArrayList<>(List.of(userMessage)); + @Test + void streamFunctionCallWithToolContextTest() { - var promptOptions = OpenAiChatOptions.builder() - // .withModel(OpenAiApi.ChatModel.GPT_4_TURBO_PREVIEW.getValue()) - .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) + var biFunction = new BiFunction, MockWeatherService.Response>() { + + @Override + public Response apply(Request request, Map toolContext) { + + assertThat(toolContext).containsEntry("sessionId", "123"); + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new MockWeatherService.Response(temperature, 15, 20, 2, 53, 45, MockWeatherService.Unit.C); + } + + }; + + OpenAiChatOptions promptOptions = OpenAiChatOptions.builder() + .withFunctionCallbacks(List.of((FunctionCallbackWrapper.builder(biFunction) .withName("getCurrentWeather") .withDescription("Get the weather in location") .withResponseConverter((response) -> "" + response.temp() + response.unit()) - .build())) + .build()))) + .withToolContext(Map.of("sessionId", "123")) .build(); + streamFunctionCallTest(promptOptions); + } + + void streamFunctionCallTest(OpenAiChatOptions promptOptions) { + + UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); + + List messages = new ArrayList<>(List.of(userMessage)); + Flux response = chatModel.stream(new Prompt(messages, promptOptions)); String content = response.collectList() diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMultipleFunctionCallsIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMultipleFunctionCallsIT.java index 812ba3fc08c..0edbbcf326a 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMultipleFunctionCallsIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMultipleFunctionCallsIT.java @@ -137,11 +137,11 @@ void functionCallWithExplicitInputType() throws NoSuchMethodException { MyFunction myFunction = new MyFunction(); Function function = createFunction(myFunction, currentTemp); - ChatClient.ChatClientRequestSpec chatClientRequestSpec = chatClient.prompt() + String content = chatClient.prompt() .user("What's the weather like in Shanghai?") - .function("currentTemp", "get current temp", MyFunction.Req.class, function); - - String content = chatClientRequestSpec.call().content(); + .function("currentTemp", "get current temp", MyFunction.Req.class, function) + .call() + .content(); assertThat(content).contains("23"); } diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java index ce46d4f97c6..9088916457b 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java @@ -18,6 +18,7 @@ import java.util.ArrayList; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Set; @@ -118,6 +119,10 @@ public enum TransportType { @JsonIgnore private Boolean proxyToolCalls; + @NestedConfigurationProperty + @JsonIgnore + private Map toolContext; + // @formatter:on public static Builder builder() { @@ -201,6 +206,16 @@ public Builder withProxyToolCalls(boolean proxyToolCalls) { return this; } + public Builder withToolContext(Map toolContext) { + if (this.options.toolContext == null) { + this.options.toolContext = toolContext; + } + else { + this.options.toolContext.putAll(toolContext); + } + return this; + } + public VertexAiGeminiChatOptions build() { return this.options; } @@ -337,6 +352,16 @@ public void setProxyToolCalls(Boolean proxyToolCalls) { this.proxyToolCalls = proxyToolCalls; } + @Override + public Map getToolContext() { + return this.toolContext; + } + + @Override + public void setToolContext(Map toolContext) { + this.toolContext = toolContext; + } + @Override public boolean equals(Object o) { if (this == o) @@ -349,13 +374,14 @@ public boolean equals(Object o) { && Objects.equals(maxOutputTokens, that.maxOutputTokens) && Objects.equals(model, that.model) && Objects.equals(responseMimeType, that.responseMimeType) && Objects.equals(functionCallbacks, that.functionCallbacks) - && Objects.equals(functions, that.functions) && Objects.equals(proxyToolCalls, that.proxyToolCalls); + && Objects.equals(functions, that.functions) && Objects.equals(proxyToolCalls, that.proxyToolCalls) + && Objects.equals(toolContext, that.toolContext); } @Override public int hashCode() { return Objects.hash(stopSequences, temperature, topP, topK, candidateCount, maxOutputTokens, model, - responseMimeType, functionCallbacks, functions, googleSearchRetrieval, proxyToolCalls); + responseMimeType, functionCallbacks, functions, googleSearchRetrieval, proxyToolCalls, toolContext); } @Override @@ -387,6 +413,7 @@ public static VertexAiGeminiChatOptions fromOptions(VertexAiGeminiChatOptions fr options.setResponseMimeType(fromOptions.getResponseMimeType()); options.setGoogleSearchRetrieval(fromOptions.getGoogleSearchRetrieval()); options.setProxyToolCalls(fromOptions.getProxyToolCalls()); + options.setToolContext(fromOptions.getToolContext()); return options; } diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiChatModelFunctionCallingIT.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiChatModelFunctionCallingIT.java index 836155f3835..7c97cdca0b0 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiChatModelFunctionCallingIT.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiChatModelFunctionCallingIT.java @@ -32,8 +32,8 @@ import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.function.FunctionCallbackContext.SchemaType; import org.springframework.ai.model.function.FunctionCallbackWrapper; -import org.springframework.ai.model.function.FunctionCallbackWrapper.Builder.SchemaType; import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel; import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatOptions; import org.springframework.beans.factory.annotation.Autowired; diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiPaymentTransactionIT.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiPaymentTransactionIT.java index 0572d4151b9..7f6ec1d3ad2 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiPaymentTransactionIT.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiPaymentTransactionIT.java @@ -34,7 +34,7 @@ import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; import org.springframework.ai.model.function.FunctionCallbackContext; -import org.springframework.ai.model.function.FunctionCallbackWrapper.Builder.SchemaType; +import org.springframework.ai.model.function.FunctionCallbackContext.SchemaType; import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel; import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatOptions; import org.springframework.beans.factory.annotation.Autowired; diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java index d33f04d79f3..e30ab9666d5 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java @@ -30,6 +30,7 @@ import java.util.ArrayList; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Set; /** @@ -126,6 +127,10 @@ public class ZhiPuAiChatOptions implements FunctionCallingOptions, ChatOptions { @JsonIgnore private Boolean proxyToolCalls; + + @NestedConfigurationProperty + @JsonIgnore + private Map toolContext; // @formatter:on public static Builder builder() { @@ -216,6 +221,16 @@ public Builder withProxyToolCalls(Boolean proxyToolCalls) { return this; } + public Builder withToolContext(Map toolContext) { + if (this.options.toolContext == null) { + this.options.toolContext = toolContext; + } + else { + this.options.toolContext.putAll(toolContext); + } + return this; + } + public ZhiPuAiChatOptions build() { return this.options; } @@ -363,6 +378,16 @@ public void setProxyToolCalls(Boolean proxyToolCalls) { this.proxyToolCalls = proxyToolCalls; } + @Override + public Map getToolContext() { + return this.toolContext; + } + + @Override + public void setToolContext(Map toolContext) { + this.toolContext = toolContext; + } + @Override public int hashCode() { final int prime = 31; @@ -376,6 +401,7 @@ public int hashCode() { result = prime * result + ((toolChoice == null) ? 0 : toolChoice.hashCode()); result = prime * result + ((user == null) ? 0 : user.hashCode()); result = prime * result + ((proxyToolCalls == null) ? 0 : proxyToolCalls.hashCode()); + result = prime * result + ((toolContext == null) ? 0 : toolContext.hashCode()); return result; } @@ -454,6 +480,12 @@ else if (!this.doSample.equals(other.doSample)) } else if (!this.proxyToolCalls.equals(other.proxyToolCalls)) return false; + if (this.toolContext == null) { + if (other.toolContext != null) + return false; + } + else if (!this.toolContext.equals(other.toolContext)) + return false; return true; } @@ -477,6 +509,7 @@ public static ZhiPuAiChatOptions fromOptions(ZhiPuAiChatOptions fromOptions) { .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) .withFunctions(fromOptions.getFunctions()) .withProxyToolCalls(fromOptions.getProxyToolCalls()) + .withToolContext(fromOptions.getToolContext()) .build(); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java index d74a41ddbf8..81d96ffd3c4 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java @@ -30,6 +30,7 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.converter.StructuredOutputConverter; import org.springframework.ai.model.Media; +import org.springframework.ai.model.function.FunctionCallback; import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.io.Resource; import org.springframework.util.MimeType; @@ -200,6 +201,11 @@ interface ChatClientRequestSpec { ChatClientRequestSpec function(String name, String description, java.util.function.Function function); + ChatClientRequestSpec function(String name, String description, + java.util.function.BiFunction, O> function); + + ChatClientRequestSpec functions(FunctionCallback... functionCallbacks); + ChatClientRequestSpec function(String name, String description, Class inputType, java.util.function.Function function); @@ -258,8 +264,13 @@ interface Builder { Builder defaultFunction(String name, String description, java.util.function.Function function); + Builder defaultFunction(String name, String description, + java.util.function.BiFunction, O> function); + Builder defaultFunctions(String... functionNames); + Builder defaultFunctions(FunctionCallback... functionCallbacks); + ChatClient build(); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index e470b4b02ac..42c21c5097f 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -481,6 +481,8 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe private final DefaultAroundAdvisorChain.Builder aroundAdvisorChainBuilder; + private final Map toolContext = new HashMap<>(); + private ObservationRegistry getObservationRegistry() { return this.observationRegistry; } @@ -533,6 +535,10 @@ public List getFunctionCallbacks() { return this.functionCallbacks; } + public Map getToolContext() { + return this.toolContext; + } + /* copy constructor */ DefaultChatClientRequestSpec(DefaultChatClientRequestSpec ccr) { this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.systemText, ccr.systemParams, ccr.functionCallbacks, @@ -678,6 +684,22 @@ public ChatClientRequestSpec function(String name, String description, return this.function(name, description, null, function); } + public ChatClientRequestSpec function(String name, String description, + java.util.function.BiFunction, O> biFunction) { + + Assert.hasText(name, "the name must be non-null and non-empty"); + Assert.hasText(description, "the description must be non-null and non-empty"); + Assert.notNull(biFunction, "the biFunction must be non-null"); + + FunctionCallbackWrapper fcw = FunctionCallbackWrapper.builder(biFunction) + .withDescription(description) + .withName(name) + .withResponseConverter(Object::toString) + .build(); + this.functionCallbacks.add(fcw); + return this; + } + public ChatClientRequestSpec function(String name, String description, Class inputType, java.util.function.Function function) { @@ -701,6 +723,12 @@ public ChatClientRequestSpec functions(String... functionBeanNames) { return this; } + public ChatClientRequestSpec functions(FunctionCallback... functionCallbacks) { + Assert.notNull(functionCallbacks, "the functionCallbacks must be non-null"); + this.functionCallbacks.addAll(Arrays.asList(functionCallbacks)); + return this; + } + public ChatClientRequestSpec system(String text) { Assert.notNull(text, "the text must be non-null"); this.systemText = text; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java index 411f19e8988..93af46815c3 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java @@ -30,6 +30,7 @@ import org.springframework.ai.chat.client.observation.ChatClientObservationConvention; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.model.function.FunctionCallback; import org.springframework.core.io.Resource; import org.springframework.util.Assert; @@ -140,9 +141,20 @@ public Builder defaultFunction(String name, String description, java.util return this; } + public Builder defaultFunction(String name, String description, + java.util.function.BiFunction, O> biFunction) { + this.defaultRequest.function(name, description, biFunction); + return this; + } + public Builder defaultFunctions(String... functionNames) { this.defaultRequest.functions(functionNames); return this; } + public Builder defaultFunctions(FunctionCallback... functionCallbacks) { + this.defaultRequest.functions(functionCallbacks); + return this; + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/AbstractToolCallSupport.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/AbstractToolCallSupport.java index 68b6b8abd7a..4cdac2c11ff 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/AbstractToolCallSupport.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/AbstractToolCallSupport.java @@ -136,7 +136,13 @@ protected List handleToolCalls(Prompt prompt, ChatResponse response) { throw new IllegalStateException("No tool call generation found in the response!"); } AssistantMessage assistantMessage = toolCallGeneration.get().getOutput(); - ToolResponseMessage toolMessageResponse = this.executeFunctions(assistantMessage); + + Map toolContext = null; + if (prompt.getOptions() instanceof FunctionCallingOptions functionCallOptions) { + toolContext = functionCallOptions.getToolContext(); + } + ToolResponseMessage toolMessageResponse = this.executeFunctions(assistantMessage, toolContext); + return this.buildToolCallConversation(prompt.getInstructions(), assistantMessage, toolMessageResponse); } @@ -184,7 +190,7 @@ protected List resolveFunctionCallbacks(Set functionNa return retrievedFunctionCallbacks; } - protected ToolResponseMessage executeFunctions(AssistantMessage assistantMessage) { + protected ToolResponseMessage executeFunctions(AssistantMessage assistantMessage, Map toolContext) { List toolResponses = new ArrayList<>(); @@ -197,7 +203,8 @@ protected ToolResponseMessage executeFunctions(AssistantMessage assistantMessage throw new IllegalStateException("No function callback found for function name: " + functionName); } - String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments); + String functionResponse = this.functionCallbackRegister.get(functionName) + .call(functionArguments, toolContext); toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), functionName, functionResponse)); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java index 6bd639c883e..b118f85659e 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java @@ -15,6 +15,9 @@ */ package org.springframework.ai.model.function; +import java.util.Map; +import java.util.Objects; +import java.util.function.BiFunction; import java.util.function.Function; import com.fasterxml.jackson.core.JsonProcessingException; @@ -37,7 +40,7 @@ * @param the 3rd party service output type. * @author Christian Tzolov */ -abstract class AbstractFunctionCallback implements Function, FunctionCallback { +abstract class AbstractFunctionCallback implements BiFunction, O>, FunctionCallback { private final String name; @@ -98,13 +101,18 @@ public String getInputTypeSchema() { } @Override - public String call(String functionArguments) { + public String call(String functionInput, Map toolContext) { + I request = fromJson(functionInput, inputType); + O response = this.apply(request, toolContext); + return this.responseConverter.apply(response); + } + @Override + public String call(String functionArguments) { // Convert the tool calls JSON arguments into a Java function request object. I request = fromJson(functionArguments, inputType); - // extend conversation with function response. - return this.andThen(this.responseConverter).apply(request); + return this.andThen(this.responseConverter).apply(request, null); } private T fromJson(String json, Class targetClass) { @@ -118,42 +126,21 @@ private T fromJson(String json, Class targetClass) { @Override public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + ((name == null) ? 0 : name.hashCode()); - result = prime * result + ((description == null) ? 0 : description.hashCode()); - result = prime * result + ((inputType == null) ? 0 : inputType.hashCode()); - return result; + return Objects.hash(name, description, inputType); } @Override public boolean equals(Object obj) { if (this == obj) return true; - if (obj == null) - return false; - if (getClass() != obj.getClass()) + if (obj == null || getClass() != obj.getClass()) return false; + AbstractFunctionCallback other = (AbstractFunctionCallback) obj; - if (name == null) { - if (other.name != null) - return false; - } - else if (!name.equals(other.name)) - return false; - if (description == null) { - if (other.description != null) - return false; - } - else if (!description.equals(other.description)) - return false; - if (inputType == null) { - if (other.inputType != null) - return false; - } - else if (!inputType.equals(other.inputType)) - return false; - return true; + + return Objects.equals(this.name, other.name) && Objects.equals(this.description, other.description) + && Objects.equals(this.inputType, other.inputType); + } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallback.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallback.java index 6f4e8bac482..a679c3d241a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallback.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallback.java @@ -15,6 +15,8 @@ */ package org.springframework.ai.model.function; +import java.util.Map; + /** * Represents a model function call handler. Implementations are registered with the * Models and called on prompts that trigger the function call. @@ -49,4 +51,25 @@ public interface FunctionCallback { */ public String call(String functionInput); + /** + * Called when a model detects and triggers a function call. The model is responsible + * to pass the function arguments in the pre-configured JSON schema format. + * Additionally the model can pass a context map to the function if available. The + * context is used to pass additional user provided state in addition to the arguments + * provided by the AI model. + * @param functionInput JSON string with the function arguments to be passed to the + * function. The arguments are defined as JSON schema usually registered with the the + * model. Arguments are provided by the AI model. + * @param functionContext Map with the function context. The context is used to pass + * additional user provided state in addition to the arguments provided by the AI + * model. + * @return String containing the function call response. + */ + default String call(String functionInput, Map functionContext) { + if (functionContext != null) { + throw new UnsupportedOperationException("Function context is not supported!"); + } + return call(functionInput); + } + } \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java index 59f43098753..2814de8e0a1 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java @@ -16,11 +16,12 @@ package org.springframework.ai.model.function; import java.lang.reflect.Type; +import java.util.Map; +import java.util.function.BiFunction; import java.util.function.Function; import com.fasterxml.jackson.annotation.JsonClassDescription; -import org.springframework.ai.model.function.FunctionCallbackWrapper.Builder.SchemaType; import org.springframework.beans.BeansException; import org.springframework.cloud.function.context.catalog.FunctionTypeUtils; import org.springframework.cloud.function.context.config.FunctionContextUtils; @@ -52,6 +53,12 @@ public class FunctionCallbackContext implements ApplicationContextAware { private GenericApplicationContext applicationContext; + public enum SchemaType { + + JSON_SCHEMA, OPEN_API_SCHEMA + + } + private SchemaType schemaType = SchemaType.JSON_SCHEMA; public void setSchemaType(SchemaType schemaType) { @@ -73,9 +80,10 @@ public FunctionCallback getFunctionCallback(@NonNull String beanName, @Nullable "Functional bean with name: " + beanName + " does not exist in the context."); } - if (!Function.class.isAssignableFrom(FunctionTypeUtils.getRawType(beanType))) { + if (!Function.class.isAssignableFrom(FunctionTypeUtils.getRawType(beanType)) + && !BiFunction.class.isAssignableFrom(FunctionTypeUtils.getRawType(beanType))) { throw new IllegalArgumentException( - "Function call Bean must be of type Function. Found: " + beanType.getTypeName()); + "Function call Bean must be of type Function or BiFunction. Found: " + beanType.getTypeName()); } Type functionInputType = TypeResolverHelper.getFunctionArgumentType(beanType, 0); @@ -118,6 +126,14 @@ public FunctionCallback getFunctionCallback(@NonNull String beanName, @Nullable .withInputType(functionInputClass) .build(); } + else if (bean instanceof BiFunction biFunction) { + return FunctionCallbackWrapper.builder((BiFunction, ?>) biFunction) + .withName(functionName) + .withSchemaType(this.schemaType) + .withDescription(functionDescription) + .withInputType(functionInputClass) + .build(); + } else { throw new IllegalArgumentException("Bean must be of type Function"); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java index bfa9c9c3c28..1a9a0a96649 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java @@ -15,16 +15,19 @@ */ package org.springframework.ai.model.function; +import java.util.Map; +import java.util.function.BiFunction; import java.util.function.Function; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.model.function.FunctionCallbackContext.SchemaType; +import org.springframework.util.Assert; + import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.SerializationFeature; import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; -import org.springframework.ai.model.ModelOptionsUtils; -import org.springframework.util.Assert; - /** * Note that the underlying function is responsible for converting the output into format * that can be consumed by the Model. The default implementation converts the output into @@ -34,23 +37,23 @@ */ public class FunctionCallbackWrapper extends AbstractFunctionCallback { - private final Function function; + private final BiFunction, O> biFunction; private FunctionCallbackWrapper(String name, String description, String inputTypeSchema, Class inputType, - Function responseConverter, ObjectMapper objectMapper, Function function) { + Function responseConverter, ObjectMapper objectMapper, + BiFunction, O> function) { super(name, description, inputTypeSchema, inputType, responseConverter, objectMapper); Assert.notNull(function, "Function must not be null"); - this.function = function; + this.biFunction = function; } - @SuppressWarnings("unchecked") - private static Class resolveInputType(Function function) { - return (Class) TypeResolverHelper.getFunctionInputClass((Class>) function.getClass()); + @Override + public O apply(I input, Map context) { + return this.biFunction.apply(input, context); } - @Override - public O apply(I input) { - return this.function.apply(input); + public static Builder builder(BiFunction, O> biFunction) { + return new Builder<>(biFunction); } public static Builder builder(Function function) { @@ -59,24 +62,27 @@ public static Builder builder(Function function) { public static class Builder { - public enum SchemaType { - - JSON_SCHEMA, OPEN_API_SCHEMA - - } - private String name; private String description; private Class inputType; + private final BiFunction, O> biFunction; + private final Function function; private SchemaType schemaType = SchemaType.JSON_SCHEMA; + public Builder(BiFunction, O> biFunction) { + Assert.notNull(biFunction, "Function must not be null"); + this.biFunction = biFunction; + this.function = null; + } + public Builder(Function function) { Assert.notNull(function, "Function must not be null"); + this.biFunction = null; this.function = function; } @@ -136,12 +142,16 @@ public FunctionCallbackWrapper build() { Assert.hasText(this.name, "Name must not be empty"); Assert.hasText(this.description, "Description must not be empty"); - Assert.notNull(this.function, "Function must not be null"); Assert.notNull(this.responseConverter, "ResponseConverter must not be null"); Assert.notNull(this.objectMapper, "ObjectMapper must not be null"); if (this.inputType == null) { - this.inputType = resolveInputType(this.function); + if (this.function != null) { + this.inputType = resolveInputType(this.function); + } + else { + this.inputType = resolveInputType(this.biFunction); + } } if (this.inputTypeSchema == null) { @@ -149,8 +159,22 @@ public FunctionCallbackWrapper build() { this.inputTypeSchema = ModelOptionsUtils.getJsonSchema(this.inputType, upperCaseTypeValues); } + BiFunction, O> finalBiFunction = (this.biFunction != null) ? this.biFunction + : (request, context) -> this.function.apply(request); + return new FunctionCallbackWrapper<>(this.name, this.description, this.inputTypeSchema, this.inputType, - this.responseConverter, this.objectMapper, this.function); + this.responseConverter, this.objectMapper, finalBiFunction); + } + + @SuppressWarnings("unchecked") + private static Class resolveInputType(BiFunction, O> biFunction) { + return (Class) TypeResolverHelper + .getBiFunctionInputClass((Class, O>>) biFunction.getClass()); + } + + @SuppressWarnings("unchecked") + private static Class resolveInputType(Function function) { + return (Class) TypeResolverHelper.getFunctionInputClass((Class>) function.getClass()); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java index f953e907d33..0f9244ed35d 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java @@ -16,6 +16,7 @@ package org.springframework.ai.model.function; import java.util.List; +import java.util.Map; import java.util.Set; /** @@ -72,4 +73,8 @@ public static FunctionCallingOptionsBuilder builder() { return new FunctionCallingOptionsBuilder(); } + Map getToolContext(); + + void setToolContext(Map functionContext); + } \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptionsBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptionsBuilder.java index 04c66c13d44..685618da518 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptionsBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptionsBuilder.java @@ -16,8 +16,10 @@ package org.springframework.ai.model.function; import java.util.ArrayList; +import java.util.HashMap; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Set; import org.springframework.ai.chat.prompt.ChatOptions; @@ -107,6 +109,19 @@ public FunctionCallingOptionsBuilder withProxyToolCalls(Boolean proxyToolCalls) return this; } + public FunctionCallingOptionsBuilder withToolContext(Map context) { + Assert.notNull(context, "Tool context must not be null"); + this.options.getToolContext().putAll(context); + return this; + } + + public FunctionCallingOptionsBuilder withToolContext(String key, Object value) { + Assert.notNull(key, "Key must not be null"); + Assert.notNull(value, "Value must not be null"); + this.options.getToolContext().put(key, value); + return this; + } + public PortableFunctionCallingOptions build() { return this.options; } @@ -135,6 +150,8 @@ public static class PortableFunctionCallingOptions implements FunctionCallingOpt private Boolean proxyToolCalls = false; + private Map context = new HashMap<>(); + public static FunctionCallingOptionsBuilder builder() { return new FunctionCallingOptionsBuilder(); } @@ -240,6 +257,14 @@ public void setProxyToolCalls(Boolean proxyToolCalls) { this.proxyToolCalls = proxyToolCalls; } + public Map getToolContext() { + return context; + } + + public void setToolContext(Map context) { + this.context = context; + } + @Override public ChatOptions copy() { return new FunctionCallingOptionsBuilder().withModel(this.model) @@ -253,6 +278,7 @@ public ChatOptions copy() { .withFunctions(this.functions) .withFunctionCallbacks(this.functionCallbacks) .withProxyToolCalls(this.proxyToolCalls) + .withToolContext(this.getToolContext()) .build(); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java index 1fa0736d3eb..656765a7ff1 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java @@ -18,6 +18,7 @@ import java.lang.reflect.GenericArrayType; import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; +import java.util.function.BiFunction; import java.util.function.Function; import net.jodah.typetools.TypeResolver; @@ -32,6 +33,15 @@ */ public abstract class TypeResolverHelper { + /** + * Returns the input class of a given function class. + * @param biFunctionClass The function class. + * @return The input class of the function. + */ + public static Class getBiFunctionInputClass(Class> biFunctionClass) { + return getBiFunctionArgumentClass(biFunctionClass, 0); + } + /** * Returns the input class of a given function class. * @param functionClass The function class. @@ -65,6 +75,22 @@ public static Class getFunctionArgumentClass(Class> return toRawClass(argumentType); } + /** + * Retrieves the class of a specific argument in a given function class. + * @param biFunctionClass The function class. + * @param argumentIndex The index of the argument whose class should be retrieved. + * @return The class of the specified function argument. + */ + public static Class getBiFunctionArgumentClass(Class> biFunctionClass, + int argumentIndex) { + Type type = TypeResolver.reify(BiFunction.class, biFunctionClass); + + Type argumentType = type instanceof ParameterizedType + ? ((ParameterizedType) type).getActualTypeArguments()[argumentIndex] : Object.class; + + return toRawClass(argumentType); + } + /** * Returns the input type of a given function class. * @param functionClass The class of the function. diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiAutoConfiguration.java index b1b531b6629..6b7b8497883 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiAutoConfiguration.java @@ -21,7 +21,7 @@ import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackContext; -import org.springframework.ai.model.function.FunctionCallbackWrapper.Builder.SchemaType; +import org.springframework.ai.model.function.FunctionCallbackContext.SchemaType; import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.ImportAutoConfiguration; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithFunctionBeanIT.java index 799a37623ee..d6b0c6ff082 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithFunctionBeanIT.java @@ -16,6 +16,8 @@ package org.springframework.ai.autoconfigure.azure.tool; import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; import java.util.function.Function; import org.junit.jupiter.api.Test; @@ -89,6 +91,18 @@ public Function weather return new MockWeatherService(); } + @Bean + @Description("Get the weather in location") + public Function, MockWeatherService.Response>> weatherFunctionWithContext() { + return request -> context -> new MockWeatherService().apply(request); + } + + @Bean + @Description("Get the weather in location") + public BiFunction, MockWeatherService.Response> weatherFunctionWithContext2() { + return (request, context) -> new MockWeatherService().apply(request); + } + // Relies on the Request's JsonClassDescription annotation to provide the // function description. @Bean diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/MockWeatherService.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/MockWeatherService.java index 80f855a156c..9333522be2c 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/MockWeatherService.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/MockWeatherService.java @@ -37,6 +37,7 @@ public class MockWeatherService implements Function, MockWeatherService.Response> weatherFunctionWithContext() { + return (request, context) -> { + return new MockWeatherService().apply(request); + }; + } + @Bean @Description("Get the weather in location") public Function weatherFunction() { diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionWrapperIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionWrapperIT.java index 2cf31c1d499..08a095119ae 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionWrapperIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionWrapperIT.java @@ -28,8 +28,8 @@ import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.FunctionCallbackContext.SchemaType; import org.springframework.ai.model.function.FunctionCallbackWrapper; -import org.springframework.ai.model.function.FunctionCallbackWrapper.Builder.SchemaType; import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel; import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatOptions; import org.springframework.boot.autoconfigure.AutoConfigurations; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithPromptFunctionIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithPromptFunctionIT.java index 69831e388ee..e72ac44806d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithPromptFunctionIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithPromptFunctionIT.java @@ -27,8 +27,8 @@ import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.function.FunctionCallbackContext.SchemaType; import org.springframework.ai.model.function.FunctionCallbackWrapper; -import org.springframework.ai.model.function.FunctionCallbackWrapper.Builder.SchemaType; import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel; import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatOptions; import org.springframework.boot.autoconfigure.AutoConfigurations;