Skip to content

Add MCP async tool callback implementation and tests #3235

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io.modelcontextprotocol.client.McpAsyncClient;
import io.modelcontextprotocol.spec.McpSchema.CallToolRequest;
import io.modelcontextprotocol.spec.McpSchema.Tool;
import reactor.core.publisher.Mono;

import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.model.ModelOptionsUtils;
Expand Down Expand Up @@ -52,6 +53,7 @@
* }</pre>
*
* @author Christian Tzolov
* @author Wenli Tian
* @see ToolCallback
* @see McpAsyncClient
* @see Tool
Expand All @@ -65,7 +67,7 @@ public class AsyncMcpToolCallback implements ToolCallback {
/**
* Creates a new {@code AsyncMcpToolCallback} instance.
* @param mcpClient the MCP client to use for tool execution
* @param tool the MCP tool definition to adapt
* @param tool the MCP tool definition to adapt
*/
public AsyncMcpToolCallback(McpAsyncClient mcpClient, Tool tool) {
this.asyncMcpClient = mcpClient;
Expand Down Expand Up @@ -106,21 +108,49 @@ public ToolDefinition getToolDefinition() {
*/
@Override
public String call(String functionInput) {
Map<String, Object> arguments = ModelOptionsUtils.jsonToMap(functionInput);
// Note that we use the original tool name here, not the adapted one from
// getToolDefinition
return this.asyncMcpClient.callTool(new CallToolRequest(this.tool.name(), arguments)).map(response -> {
if (response.isError() != null && response.isError()) {
throw new IllegalStateException("Error calling tool: " + response.content());
}
return ModelOptionsUtils.toJsonString(response.content());
}).block();
// For backward compatibility, use blocking call but internally use reactive
// method
return callAsync(functionInput).block();
}

@Override
public String call(String toolArguments, ToolContext toolContext) {
// ToolContext is not supported by the MCP tools
return this.call(toolArguments);
return callAsync(toolArguments, toolContext).block();
}

/**
* Asynchronously executes the tool call, returning a Mono containing the
* result.
* <p>
* This method provides a fully non-blocking way to call tools, suitable for use
* in reactive applications.
*
* @param functionInput the tool input as a JSON string
* @return a Mono containing the tool response
*/
public Mono<String> callAsync(String functionInput) {
return callAsync(functionInput, null);
}

/**
* Asynchronously executes the tool call with tool context support, returning a
* Mono containing the result.
*
* @param toolArguments the tool arguments as a JSON string
* @param toolContext the tool execution context
* @return a Mono containing the tool response
*/
public Mono<String> callAsync(String toolArguments, ToolContext toolContext) {
Map<String, Object> arguments = ModelOptionsUtils.jsonToMap(toolArguments);
// Note that we use the original tool name here, not the adapted one from
// getToolDefinition
return this.asyncMcpClient.callTool(new CallToolRequest(this.tool.name(), arguments)).handle((response, sink) -> {
if (response.isError() != null && response.isError()) {
sink.error(new IllegalStateException("Error calling tool: " + response.content()));
return;
}
sink.next(ModelOptionsUtils.toJsonString(response.content()));
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@
import io.modelcontextprotocol.spec.McpSchema.Tool;
import io.modelcontextprotocol.util.Assert;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.ToolCallbackProvider;
import org.springframework.ai.tool.support.ToolUtils;
import org.springframework.util.CollectionUtils;

/**
Expand Down Expand Up @@ -67,6 +67,7 @@
* }</pre>
*
* @author Christian Tzolov
* @author Wenli Tian
* @since 1.0.0
* @see ToolCallbackProvider
* @see AsyncMcpToolCallback
Expand Down Expand Up @@ -96,7 +97,7 @@ public AsyncMcpToolCallbackProvider(BiPredicate<McpAsyncClient, Tool> toolFilter
* clients.
* @param mcpClients the list of MCP clients to use for discovering tools. Each client
* typically connects to a different MCP server, allowing tool discovery from multiple
* sources.
* sources.
* @throws IllegalArgumentException if mcpClients is null
*/
public AsyncMcpToolCallbackProvider(List<McpAsyncClient> mcpClients) {
Expand Down Expand Up @@ -139,41 +140,44 @@ public AsyncMcpToolCallbackProvider(McpAsyncClient... mcpClients) {
*/
@Override
public ToolCallback[] getToolCallbacks() {

List<ToolCallback> toolCallbackList = new ArrayList<>();

for (McpAsyncClient mcpClient : this.mcpClients) {

ToolCallback[] toolCallbacks = mcpClient.listTools()
.map(response -> response.tools()
.stream()
.filter(tool -> this.toolFilter.test(mcpClient, tool))
.map(tool -> new AsyncMcpToolCallback(mcpClient, tool))
.toArray(ToolCallback[]::new))
.block();

validateToolCallbacks(toolCallbacks);

toolCallbackList.addAll(List.of(toolCallbacks));
}

return toolCallbackList.toArray(new ToolCallback[0]);
// Use the new non-blocking method, but block here to comply with interface
// requirements
return getToolCallbacksAsync().block();
}

/**
* Validates that there are no duplicate tool names in the provided callbacks.
* Asynchronously retrieves tool callbacks, returning a Mono containing all tool
* callbacks.
* <p>
* This method ensures that each tool has a unique name, which is required for proper
* tool resolution and execution.
* @param toolCallbacks the tool callbacks to validate
* @throws IllegalStateException if duplicate tool names are found
* This method provides a fully non-blocking way to retrieve tool callbacks,
* suitable for use in reactive applications.
*
* @return a Mono containing all tool callbacks
*/
private void validateToolCallbacks(ToolCallback[] toolCallbacks) {
List<String> duplicateToolNames = ToolUtils.getDuplicateToolNames(toolCallbacks);
if (!duplicateToolNames.isEmpty()) {
throw new IllegalStateException(
"Multiple tools with the same name (%s)".formatted(String.join(", ", duplicateToolNames)));
public Mono<ToolCallback[]> getToolCallbacksAsync() {
List<Mono<ToolCallback[]>> clientToolCallbacks = new ArrayList<>();

for (McpAsyncClient mcpClient : this.mcpClients) {
Mono<ToolCallback[]> toolCallbacksMono = mcpClient.listTools()
.map(response -> response.tools()
.stream()
.filter(tool -> this.toolFilter.test(mcpClient, tool))
.map(tool -> new AsyncMcpToolCallback(mcpClient, tool))
.toArray(ToolCallback[]::new))
.doOnNext(McpToolUtils::validateToolCallbacks);

clientToolCallbacks.add(toolCallbacksMono);
}

return Flux.concat(clientToolCallbacks)
.collectList()
.map(lists -> {
List<ToolCallback> allCallbacks = new ArrayList<>();
for (ToolCallback[] callbacks : lists) {
allCallbacks.addAll(List.of(callbacks));
}
return allCallbacks.toArray(new ToolCallback[0]);
});
}

/**
Expand All @@ -200,7 +204,7 @@ public static Flux<ToolCallback> asyncToolCallbacks(List<McpAsyncClient> mcpClie
return Flux.empty();
}

return Flux.fromArray(new AsyncMcpToolCallbackProvider(mcpClients).getToolCallbacks());
return Flux.from(new AsyncMcpToolCallbackProvider(mcpClients).getToolCallbacksAsync())
.flatMap(Flux::fromArray);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.support.ToolUtils;
import org.springframework.lang.Nullable;
import org.springframework.util.CollectionUtils;
import org.springframework.util.MimeType;
Expand All @@ -59,6 +60,7 @@
* </ul>
*
* @author Christian Tzolov
* @author Wenli Tian
*/
public final class McpToolUtils {

Expand All @@ -71,7 +73,6 @@ private McpToolUtils() {
}

public static String prefixedToolName(String prefix, String toolName) {

if (StringUtils.isEmpty(prefix) || StringUtils.isEmpty(toolName)) {
throw new IllegalArgumentException("Prefix or toolName cannot be null or empty");
}
Expand Down Expand Up @@ -160,7 +161,7 @@ public static McpServerFeatures.SyncToolSpecification toSyncToolSpecification(To
* specifications</li>
* </ul>
* @param toolCallback the Spring AI function callback to convert
* @param mimeType the MIME type of the output content
* @param mimeType the MIME type of the output content
* @return an MCP SyncToolSpecification that wraps the function callback
* @throws RuntimeException if there's an error during the function execution
*/
Expand All @@ -176,12 +177,12 @@ public static McpServerFeatures.SyncToolSpecification toSyncToolSpecification(To
new ToolContext(Map.of(TOOL_CONTEXT_MCP_EXCHANGE_KEY, exchange)));
if (mimeType != null && mimeType.toString().startsWith("image")) {
return new McpSchema.CallToolResult(List
.of(new McpSchema.ImageContent(List.of(Role.ASSISTANT), null, callResult, mimeType.toString())),
.of(new McpSchema.ImageContent(List.of(Role.ASSISTANT), null, callResult,
mimeType.toString())),
false);
}
return new McpSchema.CallToolResult(List.of(new McpSchema.TextContent(callResult)), false);
}
catch (Exception e) {
} catch (Exception e) {
return new McpSchema.CallToolResult(List.of(new McpSchema.TextContent(e.getMessage())), true);
}
});
Expand All @@ -195,7 +196,7 @@ public static McpServerFeatures.SyncToolSpecification toSyncToolSpecification(To
public static Optional<McpSyncServerExchange> getMcpExchange(ToolContext toolContext) {
if (toolContext != null && toolContext.getContext().containsKey(TOOL_CONTEXT_MCP_EXCHANGE_KEY)) {
return Optional
.ofNullable((McpSyncServerExchange) toolContext.getContext().get(TOOL_CONTEXT_MCP_EXCHANGE_KEY));
.ofNullable((McpSyncServerExchange) toolContext.getContext().get(TOOL_CONTEXT_MCP_EXCHANGE_KEY));
}
return Optional.empty();
}
Expand Down Expand Up @@ -277,7 +278,7 @@ public static McpServerFeatures.AsyncToolSpecification toAsyncToolSpecification(
* <li>Provide backpressure through Project Reactor</li>
* </ul>
* @param toolCallback the Spring AI tool callback to convert
* @param mimeType the MIME type of the output content
* @param mimeType the MIME type of the output content
* @return an MCP asynchronous tool specificaiotn that wraps the tool callback
* @see McpServerFeatures.AsyncToolSpecification
* @see Schedulers#boundedElastic()
Expand All @@ -289,8 +290,9 @@ public static McpServerFeatures.AsyncToolSpecification toAsyncToolSpecification(

return new AsyncToolSpecification(syncToolSpecification.tool(),
(exchange, map) -> Mono
.fromCallable(() -> syncToolSpecification.call().apply(new McpSyncServerExchange(exchange), map))
.subscribeOn(Schedulers.boundedElastic()));
.fromCallable(
() -> syncToolSpecification.call().apply(new McpSyncServerExchange(exchange), map))
.subscribeOn(Schedulers.boundedElastic()));
}

/**
Expand Down Expand Up @@ -365,4 +367,19 @@ private record Base64Wrapper(@JsonAlias("mimetype") @Nullable MimeType mimeType,
"base64", "b64", "imageData" }) @Nullable String data) {
}

/**
* Validates that there are no duplicate tool names in the provided callbacks.
* <p>
* This method ensures that each tool has a unique name, which is required for proper
* tool resolution and execution.
* @param toolCallbacks the tool callbacks to validate
* @throws IllegalStateException if duplicate tool names are found
*/
public static void validateToolCallbacks(ToolCallback[] toolCallbacks) {
List<String> duplicateToolNames = ToolUtils.getDuplicateToolNames(toolCallbacks);
if (!duplicateToolNames.isEmpty()) {
throw new IllegalStateException(
"Multiple tools with the same name (%s)".formatted(String.join(", ", duplicateToolNames)));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,11 @@ public class SyncMcpToolCallback implements ToolCallback {
/**
* Creates a new {@code SyncMcpToolCallback} instance.
* @param mcpClient the MCP client to use for tool execution
* @param tool the MCP tool definition to adapt
* @param tool the MCP tool definition to adapt
*/
public SyncMcpToolCallback(McpSyncClient mcpClient, Tool tool) {
this.mcpClient = mcpClient;
this.tool = tool;

}

/**
Expand All @@ -90,10 +89,10 @@ public SyncMcpToolCallback(McpSyncClient mcpClient, Tool tool) {
@Override
public ToolDefinition getToolDefinition() {
return DefaultToolDefinition.builder()
.name(McpToolUtils.prefixedToolName(this.mcpClient.getClientInfo().name(), this.tool.name()))
.description(this.tool.description())
.inputSchema(ModelOptionsUtils.toJsonString(this.tool.inputSchema()))
.build();
.name(McpToolUtils.prefixedToolName(this.mcpClient.getClientInfo().name(), this.tool.name()))
.description(this.tool.description())
.inputSchema(ModelOptionsUtils.toJsonString(this.tool.inputSchema()))
.build();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
* }</pre>
*
* @author Christian Tzolov
* @author Wenli Tian
* @see ToolCallbackProvider
* @see SyncMcpToolCallback
* @see McpSyncClient
Expand Down Expand Up @@ -130,31 +131,16 @@ public SyncMcpToolCallbackProvider(McpSyncClient... mcpClients) {
@Override
public ToolCallback[] getToolCallbacks() {
var array = this.mcpClients.stream()
.flatMap(mcpClient -> mcpClient.listTools()
.tools()
.stream()
.filter(tool -> this.toolFilter.test(mcpClient, tool))
.map(tool -> new SyncMcpToolCallback(mcpClient, tool)))
.toArray(ToolCallback[]::new);
validateToolCallbacks(array);
.flatMap(mcpClient -> mcpClient.listTools()
.tools()
.stream()
.filter(tool -> this.toolFilter.test(mcpClient, tool))
.map(tool -> new SyncMcpToolCallback(mcpClient, tool)))
.toArray(ToolCallback[]::new);
McpToolUtils.validateToolCallbacks(array);
return array;
}

/**
* Validates that there are no duplicate tool names in the provided callbacks.
* <p>
* This method ensures that each tool has a unique name, which is required for proper
* tool resolution and execution.
* @param toolCallbacks the tool callbacks to validate
* @throws IllegalStateException if duplicate tool names are found
*/
private void validateToolCallbacks(ToolCallback[] toolCallbacks) {
List<String> duplicateToolNames = ToolUtils.getDuplicateToolNames(toolCallbacks);
if (!duplicateToolNames.isEmpty()) {
throw new IllegalStateException(
"Multiple tools with the same name (%s)".formatted(String.join(", ", duplicateToolNames)));
}
}


/**
* Creates a consolidated list of tool callbacks from multiple MCP clients.
Expand Down
Loading