From 9ae973656e858dd06d7dae6aed514bff126a4208 Mon Sep 17 00:00:00 2001 From: Daniel Garnier-Moiroux Date: Thu, 12 Feb 2026 10:21:36 +0100 Subject: [PATCH] Polish gh-771 Signed-off-by: Daniel Garnier-Moiroux --- ...faultServerTransportSecurityValidator.java | 9 ++--- .../transport/HttpServletRequestUtils.java | 40 +++++++++++++++++++ ...HttpServletSseServerTransportProvider.java | 22 +--------- .../HttpServletStatelessServerTransport.java | 20 +--------- ...vletStreamableServerTransportProvider.java | 24 ++--------- 5 files changed, 49 insertions(+), 66 deletions(-) create mode 100644 mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletRequestUtils.java diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidator.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidator.java index 5321aada7..cae05c01a 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidator.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidator.java @@ -21,13 +21,10 @@ * @see ServerTransportSecurityValidator * @see ServerTransportSecurityException */ -public class DefaultServerTransportSecurityValidator implements ServerTransportSecurityValidator { +public final class DefaultServerTransportSecurityValidator implements ServerTransportSecurityValidator { private static final String ORIGIN_HEADER = "Origin"; - private static final ServerTransportSecurityException INVALID_ORIGIN = new ServerTransportSecurityException(403, - "Invalid Origin header"); - private final List allowedOrigins; /** @@ -35,7 +32,7 @@ public class DefaultServerTransportSecurityValidator implements ServerTransportS * @param allowedOrigins List of allowed origin patterns. Supports exact matches * (e.g., "http://example.com:8080") and wildcard ports (e.g., "http://example.com:*") */ - public DefaultServerTransportSecurityValidator(List allowedOrigins) { + private DefaultServerTransportSecurityValidator(List allowedOrigins) { Assert.notNull(allowedOrigins, "allowedOrigins must not be null"); this.allowedOrigins = allowedOrigins; } @@ -79,7 +76,7 @@ else if (allowed.endsWith(":*")) { } - throw INVALID_ORIGIN; + throw new ServerTransportSecurityException(403, "Invalid Origin header"); } /** diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletRequestUtils.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletRequestUtils.java new file mode 100644 index 000000000..32246948c --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletRequestUtils.java @@ -0,0 +1,40 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.util.Collections; +import java.util.Enumeration; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import jakarta.servlet.http.HttpServletRequest; + +/** + * Utility methods for working with {@link HttpServletRequest}. For internal use only. + * + * @author Daniel Garnier-Moiroux + */ +final class HttpServletRequestUtils { + + private HttpServletRequestUtils() { + } + + /** + * Extracts all headers from the HTTP request into a map. + * @param request The HTTP servlet request + * @return A map of header names to their values + */ + static Map> extractHeaders(HttpServletRequest request) { + Map> headers = new HashMap<>(); + Enumeration names = request.getHeaderNames(); + while (names.hasMoreElements()) { + String name = names.nextElement(); + headers.put(name, Collections.list(request.getHeaders(name))); + } + return headers; + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java index c07906b49..3b31eb949 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -8,9 +8,6 @@ import java.io.IOException; import java.io.PrintWriter; import java.time.Duration; -import java.util.Collections; -import java.util.Enumeration; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.UUID; @@ -258,7 +255,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) } try { - Map> headers = extractHeaders(request); + Map> headers = HttpServletRequestUtils.extractHeaders(request); this.securityValidator.validateHeaders(headers); } catch (ServerTransportSecurityException e) { @@ -332,7 +329,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) } try { - Map> headers = extractHeaders(request); + Map> headers = HttpServletRequestUtils.extractHeaders(request); this.securityValidator.validateHeaders(headers); } catch (ServerTransportSecurityException e) { @@ -440,21 +437,6 @@ private void sendEvent(PrintWriter writer, String eventType, String data) throws } } - /** - * Extracts all headers from the HTTP servlet request into a map. - * @param request The HTTP servlet request - * @return A map of header names to their values - */ - private Map> extractHeaders(HttpServletRequest request) { - Map> headers = new HashMap<>(); - Enumeration names = request.getHeaderNames(); - while (names.hasMoreElements()) { - String name = names.nextElement(); - headers.put(name, Collections.list(request.getHeaders(name))); - } - return headers; - } - /** * Cleans up resources when the servlet is being destroyed. *

diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java index 6431a2cd2..af01c709d 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java @@ -7,9 +7,6 @@ import java.io.BufferedReader; import java.io.IOException; import java.io.PrintWriter; -import java.util.Collections; -import java.util.Enumeration; -import java.util.HashMap; import java.util.List; import java.util.Map; @@ -137,7 +134,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) } try { - Map> headers = extractHeaders(request); + Map> headers = HttpServletRequestUtils.extractHeaders(request); this.securityValidator.validateHeaders(headers); } catch (ServerTransportSecurityException e) { @@ -232,21 +229,6 @@ private void responseError(HttpServletResponse response, int httpCode, McpError writer.flush(); } - /** - * Extracts all headers from the HTTP servlet request into a map. - * @param request The HTTP servlet request - * @return A map of header names to their values - */ - private Map> extractHeaders(HttpServletRequest request) { - Map> headers = new HashMap<>(); - Enumeration names = request.getHeaderNames(); - while (names.hasMoreElements()) { - String name = names.nextElement(); - headers.put(name, Collections.list(request.getHeaders(name))); - } - return headers; - } - /** * Cleans up resources when the servlet is being destroyed. *

diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java index 18cdcff96..07dc3467b 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java @@ -9,9 +9,6 @@ import java.io.PrintWriter; import java.time.Duration; import java.util.ArrayList; -import java.util.Collections; -import java.util.Enumeration; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -262,7 +259,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) } try { - Map> headers = extractHeaders(request); + Map> headers = HttpServletRequestUtils.extractHeaders(request); this.securityValidator.validateHeaders(headers); } catch (ServerTransportSecurityException e) { @@ -398,7 +395,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) } try { - Map> headers = extractHeaders(request); + Map> headers = HttpServletRequestUtils.extractHeaders(request); this.securityValidator.validateHeaders(headers); } catch (ServerTransportSecurityException e) { @@ -570,7 +567,7 @@ protected void doDelete(HttpServletRequest request, HttpServletResponse response } try { - Map> headers = extractHeaders(request); + Map> headers = HttpServletRequestUtils.extractHeaders(request); this.securityValidator.validateHeaders(headers); } catch (ServerTransportSecurityException e) { @@ -628,21 +625,6 @@ public void responseError(HttpServletResponse response, int httpCode, McpError m return; } - /** - * Extracts all headers from the HTTP servlet request into a map. - * @param request The HTTP servlet request - * @return A map of header names to their values - */ - private Map> extractHeaders(HttpServletRequest request) { - Map> headers = new HashMap<>(); - Enumeration names = request.getHeaderNames(); - while (names.hasMoreElements()) { - String name = names.nextElement(); - headers.put(name, Collections.list(request.getHeaders(name))); - } - return headers; - } - /** * Sends an SSE event to a client with a specific ID. * @param writer The writer to send the event through