Skip to content

support auth by HTTP basic auth #205

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
package io.modelcontextprotocol.server.transport;

import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.server.McpServerAuthParam;
import io.modelcontextprotocol.server.McpServerAuthProvider;
import io.modelcontextprotocol.spec.McpServerSession;
import io.modelcontextprotocol.spec.McpServerTransport;
import io.modelcontextprotocol.spec.McpServerTransportProvider;
Expand Down Expand Up @@ -60,7 +61,9 @@
* @author Christian Tzolov
* @author Alexandros Pappas
* @author Dariusz Jędrzejczyk
* @author lambochen
* @see McpServerTransport
* @see McpServerAuthProvider
* @see ServerSentEvent
*/
public class WebFluxSseServerTransportProvider implements McpServerTransportProvider {
Expand Down Expand Up @@ -100,6 +103,11 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv

private McpServerSession.Factory sessionFactory;

/**
* auth provider
*/
private final McpServerAuthProvider authProvider;

/**
* Map of active client sessions, keyed by session ID.
*/
Expand Down Expand Up @@ -149,6 +157,22 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messa
*/
public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
String sseEndpoint) {
this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null);
}

/**
* Constructs a new WebFlux SSE server transport provider instance.
* @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
* of MCP messages. Must not be null.
* @param baseUrl webflux message base path
* @param messageEndpoint The endpoint URI where clients should send their JSON-RPC
* messages. This endpoint will be communicated to clients during SSE connection
* setup. Must not be null.
* @param authProvider auth provider
* @throws IllegalArgumentException if either parameter is null
*/
public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
String sseEndpoint, McpServerAuthProvider authProvider) {
Assert.notNull(objectMapper, "ObjectMapper must not be null");
Assert.notNull(baseUrl, "Message base path must not be null");
Assert.notNull(messageEndpoint, "Message endpoint must not be null");
Expand All @@ -158,6 +182,7 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseU
this.baseUrl = baseUrl;
this.messageEndpoint = messageEndpoint;
this.sseEndpoint = sseEndpoint;
this.authProvider = authProvider;
this.routerFunction = RouterFunctions.route()
.GET(this.sseEndpoint, this::handleSseConnection)
.POST(this.messageEndpoint, this::handleMessage)
Expand Down Expand Up @@ -256,6 +281,10 @@ private Mono<ServerResponse> handleSseConnection(ServerRequest request) {
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
}

if (null != authProvider && !authProvider.authenticate(assemblyAuthParam(request))) {
return ServerResponse.status(HttpStatus.UNAUTHORIZED).bodyValue("Unauthorized");
}

return ServerResponse.ok()
.contentType(MediaType.TEXT_EVENT_STREAM)
.body(Flux.<ServerSentEvent<?>>create(sink -> {
Expand All @@ -280,6 +309,14 @@ private Mono<ServerResponse> handleSseConnection(ServerRequest request) {
}), ServerSentEvent.class);
}

private McpServerAuthParam assemblyAuthParam(ServerRequest request) {
return McpServerAuthParam.builder()
.sseEndpoint(this.sseEndpoint)
.uri(request.uri().toString())
.params(request.queryParams().toSingleValueMap())
.build();
}

/**
* Handles incoming JSON-RPC messages from clients. Deserializes the message and
* processes it through the configured message handler.
Expand Down Expand Up @@ -397,6 +434,18 @@ public static class Builder {

private String sseEndpoint = DEFAULT_SSE_ENDPOINT;

private McpServerAuthProvider authProvider = null;

/**
* Sets the authentication provider.
* @param authProvider the authentication provider
* @return this builder instance
*/
public Builder authProvider(McpServerAuthProvider authProvider) {
this.authProvider = authProvider;
return this;
}

/**
* Sets the ObjectMapper to use for JSON serialization/deserialization of MCP
* messages.
Expand Down Expand Up @@ -457,7 +506,8 @@ public WebFluxSseServerTransportProvider build() {
Assert.notNull(objectMapper, "ObjectMapper must be set");
Assert.notNull(messageEndpoint, "Message endpoint must be set");

return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint);
return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint,
authProvider);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@

import java.io.IOException;
import java.time.Duration;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.server.McpServerAuthParam;
import io.modelcontextprotocol.server.McpServerAuthProvider;
import io.modelcontextprotocol.spec.McpServerTransport;
import io.modelcontextprotocol.spec.McpServerTransportProvider;
import io.modelcontextprotocol.spec.McpServerSession;
Expand Down Expand Up @@ -63,7 +64,9 @@
*
* @author Christian Tzolov
* @author Alexandros Pappas
* @author lambochen
* @see McpServerTransportProvider
* @see McpServerAuthProvider
* @see RouterFunction
*/
public class WebMvcSseServerTransportProvider implements McpServerTransportProvider {
Expand Down Expand Up @@ -107,6 +110,8 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi
*/
private volatile boolean isClosing = false;

private final McpServerAuthProvider authProvider;

/**
* Constructs a new WebMvcSseServerTransportProvider instance with the default SSE
* endpoint.
Expand Down Expand Up @@ -149,6 +154,24 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messag
*/
public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
String sseEndpoint) {
this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null);
}

/**
* Constructs a new WebMvcSseServerTransportProvider instance.
* @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
* of messages.
* @param baseUrl The base URL for the message endpoint, used to construct the full
* endpoint URL for clients.
* @param messageEndpoint The endpoint URI where clients should send their JSON-RPC
* messages via HTTP POST. This endpoint will be communicated to clients through the
* SSE connection's initial endpoint event.
* @param sseEndpoint The endpoint URI where clients establish their SSE connections.
* @param authProvider The authentication provider to use for authentication.
* @throws IllegalArgumentException if any parameter is null
*/
public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
String sseEndpoint, McpServerAuthProvider authProvider) {
Assert.notNull(objectMapper, "ObjectMapper must not be null");
Assert.notNull(baseUrl, "Message base URL must not be null");
Assert.notNull(messageEndpoint, "Message endpoint must not be null");
Expand All @@ -158,6 +181,7 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUr
this.baseUrl = baseUrl;
this.messageEndpoint = messageEndpoint;
this.sseEndpoint = sseEndpoint;
this.authProvider = authProvider;
this.routerFunction = RouterFunctions.route()
.GET(this.sseEndpoint, this::handleSseConnection)
.POST(this.messageEndpoint, this::handleMessage)
Expand Down Expand Up @@ -247,6 +271,10 @@ private ServerResponse handleSseConnection(ServerRequest request) {
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down");
}

if (null != authProvider && !authProvider.authenticate(assemblyAuthParam(request))) {
return ServerResponse.status(HttpStatus.UNAUTHORIZED).body("Unauthorized");
}

String sessionId = UUID.randomUUID().toString();
logger.debug("Creating new SSE connection for session: {}", sessionId);

Expand Down Expand Up @@ -284,6 +312,14 @@ private ServerResponse handleSseConnection(ServerRequest request) {
}
}

private McpServerAuthParam assemblyAuthParam(ServerRequest request) {
return McpServerAuthParam.builder()
.sseEndpoint(this.sseEndpoint)
.uri(request.uri().toString())
.params(request.params().toSingleValueMap())
.build();
}

/**
* Handles incoming JSON-RPC messages from clients. This method:
* <ul>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.time.Duration;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -75,6 +76,11 @@ public class HttpClientSseClientTransport implements McpClientTransport {
/** SSE endpoint path */
private final String sseEndpoint;

/**
* Additional parameters for the connect( client to server).
*/
private final Map<String, String> params;

/** SSE client for handling server-sent events. Uses the /sse endpoint */
private final FlowSseClient sseClient;

Expand Down Expand Up @@ -174,13 +180,19 @@ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpReques
*/
HttpClientSseClientTransport(HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri,
String sseEndpoint, ObjectMapper objectMapper) {
this(httpClient, requestBuilder, baseUri, sseEndpoint, objectMapper, null);
}

HttpClientSseClientTransport(HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri,
String sseEndpoint, ObjectMapper objectMapper, Map<String, String> params) {
Assert.notNull(objectMapper, "ObjectMapper must not be null");
Assert.hasText(baseUri, "baseUri must not be empty");
Assert.hasText(sseEndpoint, "sseEndpoint must not be empty");
Assert.notNull(httpClient, "httpClient must not be null");
Assert.notNull(requestBuilder, "requestBuilder must not be null");
this.baseUri = URI.create(baseUri);
this.sseEndpoint = sseEndpoint;
this.params = params;
this.objectMapper = objectMapper;
this.httpClient = httpClient;
this.requestBuilder = requestBuilder;
Expand All @@ -206,6 +218,8 @@ public static class Builder {

private String sseEndpoint = DEFAULT_SSE_ENDPOINT;

private Map<String, String> params;

private HttpClient.Builder clientBuilder = HttpClient.newBuilder()
.version(HttpClient.Version.HTTP_1_1)
.connectTimeout(Duration.ofSeconds(10));
Expand Down Expand Up @@ -257,6 +271,16 @@ public Builder sseEndpoint(String sseEndpoint) {
return this;
}

/**
* Sets the request params.
* @param params the request params
* @return this builder
*/
public Builder params(Map<String, String> params) {
this.params = params;
return this;
}

/**
* Sets the HTTP client builder.
* @param clientBuilder the HTTP client builder
Expand Down Expand Up @@ -318,7 +342,7 @@ public Builder objectMapper(ObjectMapper objectMapper) {
*/
public HttpClientSseClientTransport build() {
return new HttpClientSseClientTransport(clientBuilder.build(), requestBuilder, baseUri, sseEndpoint,
objectMapper);
objectMapper, params);
}

}
Expand All @@ -342,7 +366,7 @@ public Mono<Void> connect(Function<Mono<JSONRPCMessage>, Mono<JSONRPCMessage>> h
connectionFuture.set(future);

URI clientUri = Utils.resolveUri(this.baseUri, this.sseEndpoint);
sseClient.subscribe(clientUri.toString(), new FlowSseClient.SseEventHandler() {
sseClient.subscribe(assemblyUri(clientUri.toString(), this.params), new FlowSseClient.SseEventHandler() {
@Override
public void onEvent(SseEvent event) {
if (isClosing) {
Expand Down Expand Up @@ -382,6 +406,23 @@ public void onError(Throwable error) {
return Mono.fromFuture(future);
}

/**
* Assembles the full URI with the base URI and the params.
* @param baseUri baseUri + sseEndpoint
* @param params additional params
* @return full uri: baseUri + sseEndpoint + params
*/
private String assemblyUri(String baseUri, Map<String, String> params) {
if (null == params || params.isEmpty()) {
return baseUri;
}
StringBuilder uri = new StringBuilder(baseUri);
uri.append("?");
params.forEach((k, v) -> uri.append(k).append("=").append(v).append("&"));
uri.deleteCharAt(uri.length() - 1);
return uri.toString();
}

/**
* Sends a JSON-RPC message to the server.
*
Expand Down
Loading