Skip to content

chore: Add ToolContextCreator. #2853

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
import io.micrometer.observation.ObservationRegistry;

import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.chat.model.ToolContextCreator;
import org.springframework.ai.model.tool.DefaultToolContextCreator;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.ToolCallbackProvider;
Expand Down Expand Up @@ -72,14 +75,22 @@ ToolExecutionExceptionProcessor toolExecutionExceptionProcessor() {
return new DefaultToolExecutionExceptionProcessor(false);
}

@Bean
@ConditionalOnMissingBean
ToolContextCreator<ToolContext> toolContextCreator() {
return new DefaultToolContextCreator();
}

@Bean
@ConditionalOnMissingBean
ToolCallingManager toolCallingManager(ToolCallbackResolver toolCallbackResolver,
ToolContextCreator<? extends ToolContext> toolContextCreator,
ToolExecutionExceptionProcessor toolExecutionExceptionProcessor,
ObjectProvider<ObservationRegistry> observationRegistry) {
return ToolCallingManager.builder()
.observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
.toolCallbackResolver(toolCallbackResolver)
.toolContextCreator(toolContextCreator)
.toolExecutionExceptionProcessor(toolExecutionExceptionProcessor)
.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
import io.modelcontextprotocol.server.McpSyncServerExchange;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.Role;
import org.springframework.ai.chat.model.ToolContextCreator;
import org.springframework.ai.model.tool.DefaultToolContextCreator;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

Expand Down Expand Up @@ -166,14 +168,40 @@ public static McpServerFeatures.SyncToolSpecification toSyncToolSpecification(To
*/
public static McpServerFeatures.SyncToolSpecification toSyncToolSpecification(ToolCallback toolCallback,
MimeType mimeType) {
return toSyncToolSpecification(toolCallback, new DefaultToolContextCreator(), mimeType);
}

/**
* Converts a Spring AI ToolCallback to an MCP SyncToolSpecification. This enables
* Spring AI functions to be exposed as MCP tools that can be discovered and invoked
* by language models.
*
* <p>
* The conversion process:
* <ul>
* <li>Creates an MCP Tool with the function's name and input schema</li>
* <li>Wraps the function's execution in a SyncToolSpecification that handles the MCP
* protocol</li>
* <li>Provides error handling and result formatting according to MCP
* specifications</li>
* </ul>
* @param toolCallback the Spring AI function callback to convert
* @param toolContextCreator the tool context creator to use for creating the tool
* context
* @param mimeType the MIME type of the output content
* @return an MCP SyncToolRegistration that wraps the function callback
* @throws RuntimeException if there's an error during the function execution
*/
public static McpServerFeatures.SyncToolSpecification toSyncToolSpecification(ToolCallback toolCallback,
ToolContextCreator<? extends ToolContext> toolContextCreator, MimeType mimeType) {

var tool = new McpSchema.Tool(toolCallback.getToolDefinition().name(),
toolCallback.getToolDefinition().description(), toolCallback.getToolDefinition().inputSchema());

return new McpServerFeatures.SyncToolSpecification(tool, (exchange, request) -> {
try {
String callResult = toolCallback.call(ModelOptionsUtils.toJsonString(request),
new ToolContext(Map.of(TOOL_CONTEXT_MCP_EXCHANGE_KEY, exchange)));
toolContextCreator.create(Map.of(TOOL_CONTEXT_MCP_EXCHANGE_KEY, exchange)));
if (mimeType != null && mimeType.toString().startsWith("image")) {
return new McpSchema.CallToolResult(List
.of(new McpSchema.ImageContent(List.of(Role.ASSISTANT), null, callResult, mimeType.toString())),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
* @author Christian Tzolov
* @since 1.0.0
*/
public final class ToolContext {
public class ToolContext {

/**
* The key for the running, tool call history stored in the context map.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package org.springframework.ai.chat.model;

import java.util.Map;

/**
* A functional interface for creating a {@link ToolContext} instance.
*/
public interface ToolContextCreator<Ctx extends ToolContext> {

/**
* Create a new instance of {@link ToolContext} with the provided context map.
*/
Ctx create(final Map<String, Object> context);

}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.chat.model.ToolContextCreator;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.tool.ToolCallback;
Expand Down Expand Up @@ -62,6 +63,9 @@ public class DefaultToolCallingManager implements ToolCallingManager {
private static final ToolCallbackResolver DEFAULT_TOOL_CALLBACK_RESOLVER
= new DelegatingToolCallbackResolver(List.of());

private static final DefaultToolContextCreator DEFAULT_TOOL_CONTEXT_CREATOR
= new DefaultToolContextCreator();

private static final ToolExecutionExceptionProcessor DEFAULT_TOOL_EXECUTION_EXCEPTION_PROCESSOR
= DefaultToolExecutionExceptionProcessor.builder().build();

Expand All @@ -71,16 +75,20 @@ public class DefaultToolCallingManager implements ToolCallingManager {

private final ToolCallbackResolver toolCallbackResolver;

private final ToolContextCreator<? extends ToolContext> toolContextCreator;

private final ToolExecutionExceptionProcessor toolExecutionExceptionProcessor;

public DefaultToolCallingManager(ObservationRegistry observationRegistry, ToolCallbackResolver toolCallbackResolver,
ToolContextCreator<? extends ToolContext> toolContextCreator,
ToolExecutionExceptionProcessor toolExecutionExceptionProcessor) {
Assert.notNull(observationRegistry, "observationRegistry cannot be null");
Assert.notNull(toolCallbackResolver, "toolCallbackResolver cannot be null");
Assert.notNull(toolExecutionExceptionProcessor, "toolCallExceptionConverter cannot be null");

this.observationRegistry = observationRegistry;
this.toolCallbackResolver = toolCallbackResolver;
this.toolContextCreator = toolContextCreator;
this.toolExecutionExceptionProcessor = toolExecutionExceptionProcessor;
}

Expand Down Expand Up @@ -136,7 +144,7 @@ public ToolExecutionResult executeToolCalls(Prompt prompt, ChatResponse chatResp
.build();
}

private static ToolContext buildToolContext(Prompt prompt, AssistantMessage assistantMessage) {
private ToolContext buildToolContext(Prompt prompt, AssistantMessage assistantMessage) {
Map<String, Object> toolContextMap = Map.of();

if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions
Expand All @@ -151,7 +159,7 @@ private static ToolContext buildToolContext(Prompt prompt, AssistantMessage assi
buildConversationHistoryBeforeToolExecution(prompt, assistantMessage));
}

return new ToolContext(toolContextMap);
return toolContextCreator.create(toolContextMap);
}

private static List<Message> buildConversationHistoryBeforeToolExecution(Prompt prompt,
Expand Down Expand Up @@ -236,6 +244,8 @@ public final static class Builder {

private ToolCallbackResolver toolCallbackResolver = DEFAULT_TOOL_CALLBACK_RESOLVER;

private ToolContextCreator<? extends ToolContext> toolContextCreator = DEFAULT_TOOL_CONTEXT_CREATOR;

private ToolExecutionExceptionProcessor toolExecutionExceptionProcessor = DEFAULT_TOOL_EXECUTION_EXCEPTION_PROCESSOR;

private Builder() {
Expand All @@ -251,6 +261,11 @@ public Builder toolCallbackResolver(ToolCallbackResolver toolCallbackResolver) {
return this;
}

public Builder toolContextCreator(ToolContextCreator<? extends ToolContext> toolContextCreator) {
this.toolContextCreator = toolContextCreator;
return this;
}

public Builder toolExecutionExceptionProcessor(
ToolExecutionExceptionProcessor toolExecutionExceptionProcessor) {
this.toolExecutionExceptionProcessor = toolExecutionExceptionProcessor;
Expand All @@ -259,7 +274,7 @@ public Builder toolExecutionExceptionProcessor(

public DefaultToolCallingManager build() {
return new DefaultToolCallingManager(this.observationRegistry, this.toolCallbackResolver,
this.toolExecutionExceptionProcessor);
this.toolContextCreator, this.toolExecutionExceptionProcessor);
}

}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package org.springframework.ai.model.tool;

import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.chat.model.ToolContextCreator;

import java.util.Map;

/**
* A default implementation of {@link ToolContextCreator} that creates a new instance of
* {@link ToolContext}.
*/
public class DefaultToolContextCreator implements ToolContextCreator<ToolContext> {

@Override
public ToolContext create(final Map<String, Object> context) {
return new ToolContext(context);
}

}