Skip to content

feat: implement support for elicitation #271

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import reactor.core.publisher.Mono;
import reactor.netty.DisposableServer;
import reactor.netty.http.server.HttpServer;

Expand All @@ -41,6 +42,7 @@
import org.springframework.web.client.RestClient;
import org.springframework.web.reactive.function.client.WebClient;
import org.springframework.web.reactive.function.server.RouterFunctions;
import reactor.test.StepVerifier;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
Expand Down Expand Up @@ -331,6 +333,229 @@ void testCreateMessageWithRequestTimeoutFail(String clientType) throws Interrupt
mcpServer.closeGracefully().block();
}

// ---------------------------------------
// Elicitation Tests
// ---------------------------------------
@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient", "webflux" })
void testCreateElicitationWithoutElicitationCapabilities(String clientType) {

var clientBuilder = clientBuilders.get(clientType);

McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification(
new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> {

exchange.createElicitation(mock(ElicitRequest.class)).block();

return Mono.just(mock(CallToolResult.class));
});

var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build();

try (
// Create client without elicitation capabilities
var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) {

assertThat(client.initialize()).isNotNull();

try {
client.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
}
catch (McpError e) {
assertThat(e).isInstanceOf(McpError.class)
.hasMessage("Client must be configured with elicitation capabilities");
}
}
server.closeGracefully().block();
}

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient", "webflux" })
void testCreateElicitationSuccess(String clientType) {

var clientBuilder = clientBuilders.get(clientType);

Function<ElicitRequest, ElicitResult> elicitationHandler = request -> {
assertThat(request.message()).isNotEmpty();
assertThat(request.requestedSchema()).isNotNull();

return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message()));
};

CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")),
null);

McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification(
new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> {

var elicitationRequest = ElicitRequest.builder()
.message("Test message")
.requestedSchema(McpSchema.PrimitiveSchemaDefinition.builder()
.properties(Map.of("message", McpSchema.StringSchema.builder().build()))
.build())
.build();

StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> {
assertThat(result).isNotNull();
assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT);
assertThat(result.content().get("message")).isEqualTo("Test message");
}).verifyComplete();

return Mono.just(callResponse);
});

var mcpServer = McpServer.async(mcpServerTransportProvider)
.serverInfo("test-server", "1.0.0")
.tools(tool)
.build();

try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0"))
.capabilities(ClientCapabilities.builder().elicitation().build())
.elicitation(elicitationHandler)
.build()) {

InitializeResult initResult = mcpClient.initialize();
assertThat(initResult).isNotNull();

CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));

assertThat(response).isNotNull();
assertThat(response).isEqualTo(callResponse);
}
mcpServer.closeGracefully().block();
}

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient", "webflux" })
void testCreateElicitationWithRequestTimeoutSuccess(String clientType) {

// Client
var clientBuilder = clientBuilders.get(clientType);

Function<ElicitRequest, ElicitResult> elicitationHandler = request -> {
assertThat(request.message()).isNotEmpty();
assertThat(request.requestedSchema()).isNotNull();
try {
TimeUnit.SECONDS.sleep(2);
}
catch (InterruptedException e) {
throw new RuntimeException(e);
}
return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message()));
};

var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0"))
.capabilities(ClientCapabilities.builder().elicitation().build())
.elicitation(elicitationHandler)
.build();

// Server

CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")),
null);

McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification(
new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> {

var elicitationRequest = ElicitRequest.builder()
.message("Test message")
.requestedSchema(McpSchema.PrimitiveSchemaDefinition.builder()
.properties(Map.of("message", McpSchema.StringSchema.builder().build()))
.build())
.build();

StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> {
assertThat(result).isNotNull();
assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT);
assertThat(result.content().get("message")).isEqualTo("Test message");
}).verifyComplete();

return Mono.just(callResponse);
});

var mcpServer = McpServer.async(mcpServerTransportProvider)
.serverInfo("test-server", "1.0.0")
.requestTimeout(Duration.ofSeconds(3))
.tools(tool)
.build();

InitializeResult initResult = mcpClient.initialize();
assertThat(initResult).isNotNull();

CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));

assertThat(response).isNotNull();
assertThat(response).isEqualTo(callResponse);

mcpClient.closeGracefully();
mcpServer.closeGracefully().block();
}

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient", "webflux" })
void testCreateElicitationWithRequestTimeoutFail(String clientType) {

// Client
var clientBuilder = clientBuilders.get(clientType);

Function<ElicitRequest, ElicitResult> elicitationHandler = request -> {
assertThat(request.message()).isNotEmpty();
assertThat(request.requestedSchema()).isNotNull();
try {
TimeUnit.SECONDS.sleep(2);
}
catch (InterruptedException e) {
throw new RuntimeException(e);
}
return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message()));
};

var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0"))
.capabilities(ClientCapabilities.builder().elicitation().build())
.elicitation(elicitationHandler)
.build();

// Server

CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")),
null);

McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification(
new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> {

var elicitationRequest = ElicitRequest.builder()
.message("Test message")
.requestedSchema(McpSchema.PrimitiveSchemaDefinition.builder()
.properties(Map.of("message", McpSchema.StringSchema.builder().build()))
.build())
.build();

StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> {
assertThat(result).isNotNull();
assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT);
assertThat(result.content().get("message")).isEqualTo("Test message");
}).verifyComplete();

return Mono.just(callResponse);
});

var mcpServer = McpServer.async(mcpServerTransportProvider)
.serverInfo("test-server", "1.0.0")
.requestTimeout(Duration.ofSeconds(1))
.tools(tool)
.build();

InitializeResult initResult = mcpClient.initialize();
assertThat(initResult).isNotNull();

assertThatExceptionOfType(McpError.class).isThrownBy(() -> {
mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
}).withMessageContaining("within 1000ms");

mcpClient.closeGracefully();
mcpServer.closeGracefully().block();
}

// ---------------------------------------
// Roots Tests
// ---------------------------------------
Expand Down
Loading