Skip to content

Add support for DNS rebinding protections #284

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.modelcontextprotocol.server.transport;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

Expand Down Expand Up @@ -110,6 +111,11 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv
*/
private volatile boolean isClosing = false;

/**
* DNS rebinding protection configuration.
*/
private final DnsRebindingProtectionConfig dnsRebindingProtectionConfig;

/**
* Constructs a new WebFlux SSE server transport provider instance with the default
* SSE endpoint.
Expand All @@ -134,7 +140,7 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messa
* @throws IllegalArgumentException if either parameter is null
*/
public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) {
this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint);
this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint, null);
}

/**
Expand All @@ -149,6 +155,24 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messa
*/
public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
String sseEndpoint) {
this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null);
}

/**
* Constructs a new WebFlux SSE server transport provider instance with optional DNS
* rebinding protection.
* @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
* of MCP messages. Must not be null.
* @param baseUrl webflux message base path
* @param messageEndpoint The endpoint URI where clients should send their JSON-RPC
* messages. This endpoint will be communicated to clients during SSE connection
* setup. Must not be null.
* @param sseEndpoint The endpoint URI where clients establish their SSE connections.
* @param dnsRebindingProtectionConfig The DNS rebinding protection configuration (may be null).
* @throws IllegalArgumentException if required parameters are null
*/
public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
String sseEndpoint, DnsRebindingProtectionConfig dnsRebindingProtectionConfig) {
Assert.notNull(objectMapper, "ObjectMapper must not be null");
Assert.notNull(baseUrl, "Message base path must not be null");
Assert.notNull(messageEndpoint, "Message endpoint must not be null");
Expand All @@ -158,6 +182,7 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseU
this.baseUrl = baseUrl;
this.messageEndpoint = messageEndpoint;
this.sseEndpoint = sseEndpoint;
this.dnsRebindingProtectionConfig = dnsRebindingProtectionConfig;
this.routerFunction = RouterFunctions.route()
.GET(this.sseEndpoint, this::handleSseConnection)
.POST(this.messageEndpoint, this::handleMessage)
Expand Down Expand Up @@ -256,6 +281,16 @@ private Mono<ServerResponse> handleSseConnection(ServerRequest request) {
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
}

// Validate headers
if (dnsRebindingProtectionConfig != null) {
String hostHeader = request.headers().asHttpHeaders().getFirst("Host");
String originHeader = request.headers().asHttpHeaders().getFirst("Origin");
if (!dnsRebindingProtectionConfig.validate(hostHeader, originHeader)) {
logger.warn("DNS rebinding protection validation failed - Host: '{}', Origin: '{}'", hostHeader, originHeader);
return ServerResponse.status(HttpStatus.FORBIDDEN).bodyValue("DNS rebinding protection validation failed");
}
}

return ServerResponse.ok()
.contentType(MediaType.TEXT_EVENT_STREAM)
.body(Flux.<ServerSentEvent<?>>create(sink -> {
Expand Down Expand Up @@ -300,6 +335,25 @@ private Mono<ServerResponse> handleMessage(ServerRequest request) {
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
}

// Always validate Content-Type for POST requests
String contentType = request.headers().contentType()
.map(MediaType::toString)
.orElse(null);
if (contentType == null || !contentType.toLowerCase().startsWith("application/json")) {
logger.warn("Invalid Content-Type header: '{}'", contentType);
return ServerResponse.badRequest().bodyValue(new McpError("Content-Type must be application/json"));
}

// Validate headers for POST requests if DNS rebinding protection is configured
if (dnsRebindingProtectionConfig != null) {
String hostHeader = request.headers().asHttpHeaders().getFirst("Host");
String originHeader = request.headers().asHttpHeaders().getFirst("Origin");
if (!dnsRebindingProtectionConfig.validate(hostHeader, originHeader)) {
logger.warn("DNS rebinding protection validation failed - Host: '{}', Origin: '{}'", hostHeader, originHeader);
return ServerResponse.status(HttpStatus.FORBIDDEN).bodyValue("DNS rebinding protection validation failed");
}
}

if (request.queryParam("sessionId").isEmpty()) {
return ServerResponse.badRequest().bodyValue(new McpError("Session ID missing in message endpoint"));
}
Expand Down Expand Up @@ -397,6 +451,8 @@ public static class Builder {

private String sseEndpoint = DEFAULT_SSE_ENDPOINT;

private DnsRebindingProtectionConfig dnsRebindingProtectionConfig;

/**
* Sets the ObjectMapper to use for JSON serialization/deserialization of MCP
* messages.
Expand Down Expand Up @@ -447,6 +503,23 @@ public Builder sseEndpoint(String sseEndpoint) {
return this;
}


/**
* Sets the DNS rebinding protection configuration.
* <p>
* When set, this configuration will be used to create a header validator that
* enforces DNS rebinding protection rules. This will override any previously set
* header validator.
* @param config The DNS rebinding protection configuration
* @return this builder instance
* @throws IllegalArgumentException if config is null
*/
public Builder dnsRebindingProtectionConfig(DnsRebindingProtectionConfig config) {
Assert.notNull(config, "DNS rebinding protection config must not be null");
this.dnsRebindingProtectionConfig = config;
return this;
}

/**
* Builds a new instance of {@link WebFluxSseServerTransportProvider} with the
* configured settings.
Expand All @@ -457,7 +530,8 @@ public WebFluxSseServerTransportProvider build() {
Assert.notNull(objectMapper, "ObjectMapper must be set");
Assert.notNull(messageEndpoint, "Message endpoint must be set");

return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint);
return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint,
dnsRebindingProtectionConfig);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import java.io.IOException;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
Expand Down Expand Up @@ -107,6 +108,11 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi
*/
private volatile boolean isClosing = false;

/**
* DNS rebinding protection configuration.
*/
private final DnsRebindingProtectionConfig dnsRebindingProtectionConfig;

/**
* Constructs a new WebMvcSseServerTransportProvider instance with the default SSE
* endpoint.
Expand All @@ -132,7 +138,7 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messag
* @throws IllegalArgumentException if any parameter is null
*/
public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) {
this(objectMapper, "", messageEndpoint, sseEndpoint);
this(objectMapper, "", messageEndpoint, sseEndpoint, null);
}

/**
Expand All @@ -149,6 +155,24 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messag
*/
public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
String sseEndpoint) {
this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null);
}

/**
* Constructs a new WebMvcSseServerTransportProvider instance with DNS rebinding protection.
* @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
* of messages.
* @param baseUrl The base URL for the message endpoint, used to construct the full
* endpoint URL for clients.
* @param messageEndpoint The endpoint URI where clients should send their JSON-RPC
* messages via HTTP POST. This endpoint will be communicated to clients through the
* SSE connection's initial endpoint event.
* @param sseEndpoint The endpoint URI where clients establish their SSE connections.
* @param dnsRebindingProtectionConfig The DNS rebinding protection configuration (may be null).
* @throws IllegalArgumentException if any required parameter is null
*/
public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
String sseEndpoint, DnsRebindingProtectionConfig dnsRebindingProtectionConfig) {
Assert.notNull(objectMapper, "ObjectMapper must not be null");
Assert.notNull(baseUrl, "Message base URL must not be null");
Assert.notNull(messageEndpoint, "Message endpoint must not be null");
Expand All @@ -158,6 +182,7 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUr
this.baseUrl = baseUrl;
this.messageEndpoint = messageEndpoint;
this.sseEndpoint = sseEndpoint;
this.dnsRebindingProtectionConfig = dnsRebindingProtectionConfig;
this.routerFunction = RouterFunctions.route()
.GET(this.sseEndpoint, this::handleSseConnection)
.POST(this.messageEndpoint, this::handleMessage)
Expand Down Expand Up @@ -247,6 +272,16 @@ private ServerResponse handleSseConnection(ServerRequest request) {
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down");
}

// Validate headers
if (dnsRebindingProtectionConfig != null) {
String hostHeader = request.headers().asHttpHeaders().getFirst("Host");
String originHeader = request.headers().asHttpHeaders().getFirst("Origin");
if (!dnsRebindingProtectionConfig.validate(hostHeader, originHeader)) {
logger.warn("DNS rebinding protection validation failed - Host: '{}', Origin: '{}'", hostHeader, originHeader);
return ServerResponse.status(HttpStatus.FORBIDDEN).body("DNS rebinding protection validation failed");
}
}

String sessionId = UUID.randomUUID().toString();
logger.debug("Creating new SSE connection for session: {}", sessionId);

Expand Down Expand Up @@ -300,6 +335,23 @@ private ServerResponse handleMessage(ServerRequest request) {
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down");
}

// Always validate Content-Type for POST requests
String contentType = request.headers().asHttpHeaders().getFirst("Content-Type");
if (contentType == null || !contentType.toLowerCase().startsWith("application/json")) {
logger.warn("Invalid Content-Type header: '{}'", contentType);
return ServerResponse.badRequest().body(new McpError("Content-Type must be application/json"));
}

// Validate headers for POST requests if DNS rebinding protection is configured
if (dnsRebindingProtectionConfig != null) {
String hostHeader = request.headers().asHttpHeaders().getFirst("Host");
String originHeader = request.headers().asHttpHeaders().getFirst("Origin");
if (!dnsRebindingProtectionConfig.validate(hostHeader, originHeader)) {
logger.warn("DNS rebinding protection validation failed - Host: '{}', Origin: '{}'", hostHeader, originHeader);
return ServerResponse.status(HttpStatus.FORBIDDEN).body("DNS rebinding protection validation failed");
}
}

if (request.param("sessionId").isEmpty()) {
return ServerResponse.badRequest().body(new McpError("Session ID missing in message endpoint"));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package io.modelcontextprotocol.server.transport;

import java.util.Collections;
import java.util.HashSet;
import java.util.Set;

/**
* Configuration for DNS rebinding protection in SSE server transports. Provides
* validation for Host and Origin headers to prevent DNS rebinding attacks.
*/
public class DnsRebindingProtectionConfig {

private final Set<String> allowedHosts;

private final Set<String> allowedOrigins;

private final boolean enableDnsRebindingProtection;

private DnsRebindingProtectionConfig(Builder builder) {
this.allowedHosts = Collections.unmodifiableSet(new HashSet<>(builder.allowedHosts));
this.allowedOrigins = Collections.unmodifiableSet(new HashSet<>(builder.allowedOrigins));
this.enableDnsRebindingProtection = builder.enableDnsRebindingProtection;
}

/**
* Validates Host and Origin headers for DNS rebinding protection. Returns true if the
* headers are valid, false otherwise.
* @param hostHeader The value of the Host header (may be null)
* @param originHeader The value of the Origin header (may be null)
* @return true if the headers are valid, false otherwise
*/
public boolean validate(String hostHeader, String originHeader) {
// Skip validation if protection is not enabled
if (!enableDnsRebindingProtection) {
return true;
}

// Validate Host header
if (hostHeader != null) {
String lowerHost = hostHeader.toLowerCase();
if (!allowedHosts.contains(lowerHost)) {
return false;
}
}

// Validate Origin header
if (originHeader != null) {
String lowerOrigin = originHeader.toLowerCase();
if (!allowedOrigins.contains(lowerOrigin)) {
return false;
}
}

return true;
}

public static Builder builder() {
return new Builder();
}

public static class Builder {

private final Set<String> allowedHosts = new HashSet<>();

private final Set<String> allowedOrigins = new HashSet<>();

private boolean enableDnsRebindingProtection = true;

public Builder allowedHost(String host) {
if (host != null) {
this.allowedHosts.add(host.toLowerCase());
}
return this;
}

public Builder allowedHosts(Set<String> hosts) {
if (hosts != null) {
hosts.forEach(this::allowedHost);
}
return this;
}

public Builder allowedOrigin(String origin) {
if (origin != null) {
this.allowedOrigins.add(origin.toLowerCase());
}
return this;
}

public Builder allowedOrigins(Set<String> origins) {
if (origins != null) {
origins.forEach(this::allowedOrigin);
}
return this;
}

public Builder enableDnsRebindingProtection(boolean enable) {
this.enableDnsRebindingProtection = enable;
return this;
}

public DnsRebindingProtectionConfig build() {
return new DnsRebindingProtectionConfig(this);
}

}

}
Loading