Skip to content

fix: Resolve URIs #262

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import io.modelcontextprotocol.spec.McpServerTransport;
import io.modelcontextprotocol.spec.McpServerTransportProvider;
import io.modelcontextprotocol.util.Assert;
import io.modelcontextprotocol.util.Utils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.Exceptions;
Expand Down Expand Up @@ -82,6 +83,8 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv
*/
public static final String DEFAULT_SSE_ENDPOINT = "/sse";

public static final String DEFAULT_CONTEXT_PATH = "";

public static final String DEFAULT_BASE_URL = "";

private final ObjectMapper objectMapper;
Expand All @@ -92,6 +95,8 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv
*/
private final String baseUrl;

private final String contextPath;

private final String messageEndpoint;

private final String sseEndpoint;
Expand Down Expand Up @@ -134,33 +139,38 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messa
* @throws IllegalArgumentException if either parameter is null
*/
public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) {
this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint);
this(objectMapper, DEFAULT_CONTEXT_PATH, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint);
}

/**
* Constructs a new WebFlux SSE server transport provider instance.
* @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
* of MCP messages. Must not be null.
* @param contextPath The context path of the server.
* @param baseUrl webflux message base path
* @param messageEndpoint The endpoint URI where clients should send their JSON-RPC
* messages. This endpoint will be communicated to clients during SSE connection
* setup. Must not be null.
* @throws IllegalArgumentException if either parameter is null
*/
public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
String sseEndpoint) {
public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String contextPath, String baseUrl,
String messageEndpoint, String sseEndpoint) {
Assert.notNull(objectMapper, "ObjectMapper must not be null");
Assert.notNull(contextPath, "Context path must not be null");
Assert.notNull(baseUrl, "Message base path must not be null");
Assert.notNull(messageEndpoint, "Message endpoint must not be null");
Assert.hasText(messageEndpoint, "Message endpoint must not be empty");
Assert.notNull(sseEndpoint, "SSE endpoint must not be null");
Assert.hasText(sseEndpoint, "SSE endpoint must not be empty");

this.objectMapper = objectMapper;
this.baseUrl = baseUrl;
this.contextPath = Utils.removeTrailingSlash(contextPath);
this.baseUrl = Utils.removeTrailingSlash(baseUrl);
this.messageEndpoint = messageEndpoint;
this.sseEndpoint = sseEndpoint;
this.routerFunction = RouterFunctions.route()
.GET(this.sseEndpoint, this::handleSseConnection)
.POST(this.messageEndpoint, this::handleMessage)
.GET(this.baseUrl + this.sseEndpoint, this::handleSseConnection)
.POST(this.baseUrl + this.messageEndpoint, this::handleMessage)
.build();
}

Expand Down Expand Up @@ -271,7 +281,7 @@ private Mono<ServerResponse> handleSseConnection(ServerRequest request) {
logger.debug("Sending initial endpoint event to session: {}", sessionId);
sink.next(ServerSentEvent.builder()
.event(ENDPOINT_EVENT_TYPE)
.data(this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId)
.data(this.contextPath + this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId)
.build());
sink.onCancel(() -> {
logger.debug("Session {} cancelled", sessionId);
Expand Down Expand Up @@ -391,6 +401,8 @@ public static class Builder {

private ObjectMapper objectMapper;

private String contextPath = DEFAULT_CONTEXT_PATH;

private String baseUrl = DEFAULT_BASE_URL;

private String messageEndpoint;
Expand Down Expand Up @@ -423,6 +435,18 @@ public Builder basePath(String baseUrl) {
return this;
}

/**
* Sets the context path under which the server is running.
* @param contextPath the context path.
* @return this builder instance.
* @throws IllegalArgumentException if contextPath is null
*/
public Builder contextPath(String contextPath) {
Assert.notNull(contextPath, "contextPath must not be null");
this.contextPath = contextPath;
return this;
}

/**
* Sets the endpoint URI where clients should send their JSON-RPC messages.
* @param messageEndpoint The message endpoint URI. Must not be null.
Expand Down Expand Up @@ -457,7 +481,8 @@ public WebFluxSseServerTransportProvider build() {
Assert.notNull(objectMapper, "ObjectMapper must be set");
Assert.notNull(messageEndpoint, "Message endpoint must be set");

return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint);
return new WebFluxSseServerTransportProvider(objectMapper, contextPath, baseUrl, messageEndpoint,
sseEndpoint);
}

}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
package io.modelcontextprotocol.server;

import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.client.McpClient;
import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport;
import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider;
import io.modelcontextprotocol.spec.McpSchema;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.springframework.http.server.reactive.HttpHandler;
import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter;
import org.springframework.web.reactive.function.client.WebClient;
import org.springframework.web.reactive.function.server.RouterFunction;
import org.springframework.web.reactive.function.server.RouterFunctions;
import org.springframework.web.reactive.function.server.ServerResponse;
import reactor.core.publisher.Mono;
import reactor.netty.DisposableServer;
import reactor.netty.http.server.HttpServer;

import java.util.List;
import java.util.Map;
import java.util.stream.Stream;

import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.web.reactive.function.server.RequestPredicates.path;
import static org.springframework.web.reactive.function.server.RouterFunctions.nest;

/**
* Tests the {@link WebFluxSseServerTransportProvider} with different values for the
* endpoint.
*/
public class WebFluxSseCustomPathIntegrationTests {

private static final int PORT = TestUtil.findAvailablePort();

private DisposableServer httpServer;

private WebFluxSseServerTransportProvider mcpServerTransportProvider;

String emptyJsonSchema = """
{
"$schema": "http://json-schema.org/draft-07/schema#",
"type": "object",
"properties": {}
}
""";

@ParameterizedTest(
name = "baseUrl = \"{0}\" messageEndpoint = \"{1}\" sseEndpoint = \"{2}\" contextPath = \"{3}\" : {displayName} ")
@MethodSource("provideCustomEndpoints")
public void testCustomizedEndpoints(String baseUrl, String messageEndpoint, String sseEndpoint,
String contextPath) {

this.mcpServerTransportProvider = new WebFluxSseServerTransportProvider(new ObjectMapper(), contextPath,
baseUrl, messageEndpoint, sseEndpoint);

RouterFunction<?> router = this.mcpServerTransportProvider.getRouterFunction();
// wrap the context path around the router function
RouterFunction<ServerResponse> nestedRouter = (RouterFunction<ServerResponse>) nest(path(contextPath), router);
HttpHandler httpHandler = RouterFunctions.toHttpHandler(nestedRouter);
ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler);

this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow();

var endpoint = buildSseEndpoint(contextPath, baseUrl, sseEndpoint);

var clientBuilder = McpClient
.sync(WebFluxSseClientTransport.builder(WebClient.builder().baseUrl("http://localhost:" + PORT))
.sseEndpoint(endpoint)
.build());

McpSchema.CallToolResult callResponse = new McpSchema.CallToolResult(
List.of(new McpSchema.TextContent("CALL RESPONSE")), null);

McpServerFeatures.AsyncToolSpecification tool1 = new McpServerFeatures.AsyncToolSpecification(
new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema),
(exchange, request) -> Mono.just(callResponse));

var server = McpServer.async(mcpServerTransportProvider)
.serverInfo("test-server", "1.0.0")
.capabilities(McpSchema.ServerCapabilities.builder().tools(true).build())
.tools(tool1)
.build();

try (var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) {
assertThat(client.initialize()).isNotNull();
assertThat(client.listTools().tools()).contains(tool1.tool());

McpSchema.CallToolResult response = client.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
assertThat(response).isNotNull().isEqualTo(callResponse);
}

server.close();

}

/**
* This is a helper function for the tests which builds the SSE endpoint to pass to
* the client transport.
* @param contextPath context path of the server.
* @param baseUrl base url of the sse endpoint.
* @param sseEndpoint the sse endpoint.
* @return the created sse endpoint.
*/
private String buildSseEndpoint(String contextPath, String baseUrl, String sseEndpoint) {
if (baseUrl.endsWith("/")) {
baseUrl = baseUrl.substring(0, baseUrl.length() - 1);
}
if (contextPath.endsWith("/")) {
contextPath = contextPath.substring(0, contextPath.length() - 1);
}

return contextPath + baseUrl + sseEndpoint;
}

@AfterEach
public void after() {
if (mcpServerTransportProvider != null) {
mcpServerTransportProvider.closeGracefully().block();
}
if (httpServer != null) {
httpServer.disposeNow();
}
}

/**
* Provides a stream of custom endpoints. This generates all possible combinations for
* allowed endpoint values.
*
* <p>
* Each combination is returned as an {@link Arguments} object containing four
* parameters in the following order:
* </p>
* <ol>
* <li>Base URL (String)</li>
* <li>Message endpoint (String)</li>
* <li>SSE endpoint (String)</li>
* <li>Context path (String)</li>
* </ol>
* @return a {@link Stream} of {@link Arguments} objects, each containing four String
* parameters representing different endpoint combinations for parameterized testing
*/
private static Stream<Arguments> provideCustomEndpoints() {
String[] baseUrls = { "", "/", "/v1", "/v1/" };
String[] messageEndpoints = { "/", "/message", "/message/" };
String[] sseEndpoints = { "/", "/sse", "/sse/" };
String[] contextPaths = { "", "/", "/mcp", "/mcp/" };

return Stream.of(baseUrls)
.flatMap(baseUrl -> Stream.of(messageEndpoints)
.flatMap(messageEndpoint -> Stream.of(sseEndpoints)
.flatMap(sseEndpoint -> Stream.of(contextPaths)
.map(contextPath -> Arguments.of(baseUrl, messageEndpoint, sseEndpoint, contextPath)))));
}

}
8 changes: 7 additions & 1 deletion mcp-spring/mcp-spring-webmvc/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@
<version>${junit.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-params</artifactId>
<version>${junit.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
Expand Down Expand Up @@ -128,7 +134,7 @@
<scope>test</scope>
</dependency>

</dependencies>
</dependencies>


</project>
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import java.io.IOException;
import java.time.Duration;
import java.util.Collections;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
Expand All @@ -18,6 +19,7 @@
import io.modelcontextprotocol.spec.McpServerTransportProvider;
import io.modelcontextprotocol.spec.McpServerSession;
import io.modelcontextprotocol.util.Assert;
import io.modelcontextprotocol.util.Utils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
Expand Down Expand Up @@ -93,6 +95,8 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi

private final String baseUrl;

private final String contextPath;

private final RouterFunction<ServerResponse> routerFunction;

private McpServerSession.Factory sessionFactory;
Expand Down Expand Up @@ -132,13 +136,14 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messag
* @throws IllegalArgumentException if any parameter is null
*/
public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) {
this(objectMapper, "", messageEndpoint, sseEndpoint);
this(objectMapper, "", "", messageEndpoint, sseEndpoint);
}

/**
* Constructs a new WebMvcSseServerTransportProvider instance.
* @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
* of messages.
* @param contextPath The context path under which the server runs.
* @param baseUrl The base URL for the message endpoint, used to construct the full
* endpoint URL for clients.
* @param messageEndpoint The endpoint URI where clients should send their JSON-RPC
Expand All @@ -147,20 +152,24 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messag
* @param sseEndpoint The endpoint URI where clients establish their SSE connections.
* @throws IllegalArgumentException if any parameter is null
*/
public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
String sseEndpoint) {
public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String contextPath, String baseUrl,
String messageEndpoint, String sseEndpoint) {
Assert.notNull(objectMapper, "ObjectMapper must not be null");
Assert.notNull(contextPath, "Context path must not be null");
Assert.notNull(baseUrl, "Message base URL must not be null");
Assert.notNull(messageEndpoint, "Message endpoint must not be null");
Assert.hasText(messageEndpoint, "Message endpoint must not be empty");
Assert.notNull(sseEndpoint, "SSE endpoint must not be null");
Assert.hasText(sseEndpoint, "SSE endpoint must not be empty");

this.objectMapper = objectMapper;
this.baseUrl = baseUrl;
this.contextPath = Utils.removeTrailingSlash(contextPath);
this.baseUrl = Utils.removeTrailingSlash(baseUrl);
this.messageEndpoint = messageEndpoint;
this.sseEndpoint = sseEndpoint;
this.routerFunction = RouterFunctions.route()
.GET(this.sseEndpoint, this::handleSseConnection)
.POST(this.messageEndpoint, this::handleMessage)
.GET(this.baseUrl + this.sseEndpoint, this::handleSseConnection)
.POST(this.baseUrl + this.messageEndpoint, this::handleMessage)
.build();
}

Expand Down Expand Up @@ -269,7 +278,7 @@ private ServerResponse handleSseConnection(ServerRequest request) {
try {
sseBuilder.id(sessionId)
.event(ENDPOINT_EVENT_TYPE)
.data(this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId);
.data(this.contextPath + this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId);
}
catch (Exception e) {
logger.error("Failed to send initial endpoint event: {}", e.getMessage());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ static class TestConfig {
@Bean
public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() {

return new WebMvcSseServerTransportProvider(new ObjectMapper(), CUSTOM_CONTEXT_PATH, MESSAGE_ENDPOINT,
return new WebMvcSseServerTransportProvider(new ObjectMapper(), CUSTOM_CONTEXT_PATH, "", MESSAGE_ENDPOINT,
WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT);
}

Expand Down
Loading