diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/AsyncHttpRequestCustomizer.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/AsyncHttpRequestCustomizer.java new file mode 100644 index 00000000..40f286e6 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/AsyncHttpRequestCustomizer.java @@ -0,0 +1,44 @@ +package io.modelcontextprotocol.client.transport; + +import java.net.URI; +import java.net.http.HttpRequest; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Mono; +import reactor.util.annotation.Nullable; + +/** + * Customize {@link HttpRequest.Builder} before sending out SSE or Streamable HTTP + * transport. + *

+ * When used in a non-blocking context, implementations MUST be non-blocking. + */ +public interface AsyncHttpRequestCustomizer { + + Publisher customize(HttpRequest.Builder builder, String method, URI endpoint, + @Nullable String body); + + AsyncHttpRequestCustomizer NOOP = new Noop(); + + /** + * Wrap a sync implementation in an async wrapper. + *

+ * Do NOT use in a non-blocking context. + */ + static AsyncHttpRequestCustomizer fromSync(SyncHttpRequestCustomizer customizer) { + return (builder, method, uri, body) -> Mono.defer(() -> { + customizer.customize(builder, method, uri, body); + return Mono.just(builder); + }); + } + + class Noop implements AsyncHttpRequestCustomizer { + + @Override + public Publisher customize(HttpRequest.Builder builder, String method, URI endpoint, + String body) { + return Mono.just(builder); + } + + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index 271f3823..28f43f28 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -102,6 +102,9 @@ public class HttpClientSseClientTransport implements McpClientTransport { */ protected final Sinks.One messageEndpointSink = Sinks.one(); + // TODO + private final AsyncHttpRequestCustomizer httpRequestCustomizer; + /** * Creates a new transport instance with default HTTP client and object mapper. * @param baseUri the base URI of the MCP server @@ -172,18 +175,38 @@ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpReques * @param objectMapper the object mapper for JSON serialization/deserialization * @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null */ + @Deprecated(forRemoval = true) HttpClientSseClientTransport(HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri, String sseEndpoint, ObjectMapper objectMapper) { + this(httpClient, requestBuilder, baseUri, sseEndpoint, objectMapper, AsyncHttpRequestCustomizer.NOOP); + } + + /** + * Creates a new transport instance with custom HTTP client builder, object mapper, + * and headers. + * @param httpClient the HTTP client to use + * @param requestBuilder the HTTP request builder to use + * @param baseUri the base URI of the MCP server + * @param sseEndpoint the SSE endpoint path + * @param objectMapper the object mapper for JSON serialization/deserialization + * @param httpRequestCustomizer customizer for the requestBuilder before sending + * requests + * @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null + */ + HttpClientSseClientTransport(HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri, + String sseEndpoint, ObjectMapper objectMapper, AsyncHttpRequestCustomizer httpRequestCustomizer) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); Assert.hasText(baseUri, "baseUri must not be empty"); Assert.hasText(sseEndpoint, "sseEndpoint must not be empty"); Assert.notNull(httpClient, "httpClient must not be null"); Assert.notNull(requestBuilder, "requestBuilder must not be null"); + Assert.notNull(httpRequestCustomizer, "httpRequestCustomizer must not be null"); this.baseUri = URI.create(baseUri); this.sseEndpoint = sseEndpoint; this.objectMapper = objectMapper; this.httpClient = httpClient; this.requestBuilder = requestBuilder; + this.httpRequestCustomizer = httpRequestCustomizer; } /** @@ -213,6 +236,8 @@ public static class Builder { private HttpRequest.Builder requestBuilder = HttpRequest.newBuilder() .header("Content-Type", "application/json"); + private AsyncHttpRequestCustomizer httpRequestCustomizer = AsyncHttpRequestCustomizer.NOOP; + /** * Creates a new builder instance. */ @@ -310,96 +335,111 @@ public Builder objectMapper(ObjectMapper objectMapper) { return this; } + /** + * In reactive, DONT USE THIS. Use AsyncHttpRequestCustomizer. + */ + public Builder httpRequestCustomizer(SyncHttpRequestCustomizer syncHttpRequestCustomizer) { + this.httpRequestCustomizer = AsyncHttpRequestCustomizer.fromSync(syncHttpRequestCustomizer); + return this; + } + + public Builder httpRequestCustomizer(AsyncHttpRequestCustomizer asyncHttpRequestCustomizer) { + this.httpRequestCustomizer = asyncHttpRequestCustomizer; + return this; + } + /** * Builds a new {@link HttpClientSseClientTransport} instance. * @return a new transport instance */ public HttpClientSseClientTransport build() { return new HttpClientSseClientTransport(clientBuilder.build(), requestBuilder, baseUri, sseEndpoint, - objectMapper); + objectMapper, httpRequestCustomizer); } } @Override public Mono connect(Function, Mono> handler) { + var uri = Utils.resolveUri(this.baseUri, this.sseEndpoint); - return Mono.create(sink -> { - - HttpRequest request = requestBuilder.copy() - .uri(Utils.resolveUri(this.baseUri, this.sseEndpoint)) + return Mono + .just(requestBuilder.copy() + .uri(uri) .header("Accept", "text/event-stream") .header("Cache-Control", "no-cache") - .GET() - .build(); - - Disposable connection = Flux.create(sseSink -> this.httpClient - .sendAsync(request, responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, sseSink)) - .exceptionallyCompose(e -> { - sseSink.error(e); - return CompletableFuture.failedFuture(e); - })) - .map(responseEvent -> (ResponseSubscribers.SseResponseEvent) responseEvent) - .flatMap(responseEvent -> { - if (isClosing) { - return Mono.empty(); - } - - int statusCode = responseEvent.responseInfo().statusCode(); + .GET()) + .flatMap(builder -> Mono.from(this.httpRequestCustomizer.customize(builder, "GET", uri, null))) + .map(HttpRequest.Builder::build) + .flatMap(request -> Mono.create(sink -> { + Disposable connection = Flux.create(sseSink -> this.httpClient + .sendAsync(request, responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, sseSink)) + .exceptionallyCompose(e -> { + sseSink.error(e); + return CompletableFuture.failedFuture(e); + })) + .map(responseEvent -> (ResponseSubscribers.SseResponseEvent) responseEvent) + .flatMap(responseEvent -> { + if (isClosing) { + return Mono.empty(); + } - if (statusCode >= 200 && statusCode < 300) { - try { - if (ENDPOINT_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { - String messageEndpointUri = responseEvent.sseEvent().data(); - if (this.messageEndpointSink.tryEmitValue(messageEndpointUri).isSuccess()) { + int statusCode = responseEvent.responseInfo().statusCode(); + + if (statusCode >= 200 && statusCode < 300) { + try { + if (ENDPOINT_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { + String messageEndpointUri = responseEvent.sseEvent().data(); + if (this.messageEndpointSink.tryEmitValue(messageEndpointUri).isSuccess()) { + sink.success(); + return Flux.empty(); // No further processing + // needed + } + else { + sink.error(new McpError("Failed to handle SSE endpoint event")); + } + } + else if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { + JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, + responseEvent.sseEvent().data()); sink.success(); - return Flux.empty(); // No further processing needed + return Flux.just(message); } else { - sink.error(new McpError("Failed to handle SSE endpoint event")); + logger.error("Received unrecognized SSE event type: {}", + responseEvent.sseEvent().event()); + sink.error(new McpError("Received unrecognized SSE event type: " + + responseEvent.sseEvent().event())); } } - else if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { - JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, - responseEvent.sseEvent().data()); - sink.success(); - return Flux.just(message); - } - else { - logger.error("Received unrecognized SSE event type: {}", - responseEvent.sseEvent().event()); - sink.error(new McpError( - "Received unrecognized SSE event type: " + responseEvent.sseEvent().event())); + catch (IOException e) { + logger.error("Error processing SSE event", e); + sink.error(new McpError("Error processing SSE event")); } } - catch (IOException e) { - logger.error("Error processing SSE event", e); - sink.error(new McpError("Error processing SSE event")); + return Flux.error( + new RuntimeException("Failed to send message: " + responseEvent)); + + }) + .flatMap(jsonRpcMessage -> handler.apply(Mono.just(jsonRpcMessage))) + .onErrorComplete(t -> { + if (!isClosing) { + logger.warn("SSE stream observed an error", t); + sink.error(t); } - } - return Flux.error( - new RuntimeException("Failed to send message: " + responseEvent)); - - }) - .flatMap(jsonRpcMessage -> handler.apply(Mono.just(jsonRpcMessage))) - .onErrorComplete(t -> { - if (!isClosing) { - logger.warn("SSE stream observed an error", t); - sink.error(t); - } - return true; - }) - .doFinally(s -> { - Disposable ref = this.sseSubscription.getAndSet(null); - if (ref != null && !ref.isDisposed()) { - ref.dispose(); - } - }) - .contextWrite(sink.contextView()) - .subscribe(); + return true; + }) + .doFinally(s -> { + Disposable ref = this.sseSubscription.getAndSet(null); + if (ref != null && !ref.isDisposed()) { + ref.dispose(); + } + }) + .contextWrite(sink.contextView()) + .subscribe(); - this.sseSubscription.set(connection); - }); + this.sseSubscription.set(connection); + })); } /** @@ -455,13 +495,11 @@ private Mono serializeMessage(final JSONRPCMessage message) { private Mono> sendHttpPost(final String endpoint, final String body) { final URI requestUri = Utils.resolveUri(baseUri, endpoint); - final HttpRequest request = this.requestBuilder.copy() - .uri(requestUri) - .POST(HttpRequest.BodyPublishers.ofString(body)) - .build(); - - // TODO: why discard the body? - return Mono.fromFuture(httpClient.sendAsync(request, HttpResponse.BodyHandlers.discarding())); + return Mono.just(this.requestBuilder.copy().uri(requestUri).POST(HttpRequest.BodyPublishers.ofString(body))) + .flatMap(builder -> Mono.from(this.httpRequestCustomizer.customize(builder, "POST", requestUri, body))) + .map(HttpRequest.Builder::build) + // TODO: why discard the body? + .flatMap(request -> Mono.fromFuture(httpClient.sendAsync(request, HttpResponse.BodyHandlers.discarding()))); } /** diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java index 4cf1690f..969204ce 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java @@ -109,6 +109,8 @@ public class HttpClientStreamableHttpTransport implements McpClientTransport { private final boolean resumableStreams; + private final AsyncHttpRequestCustomizer httpRequestCustomizer; + private final AtomicReference activeSession = new AtomicReference<>(); private final AtomicReference, Mono>> handler = new AtomicReference<>(); @@ -117,7 +119,7 @@ public class HttpClientStreamableHttpTransport implements McpClientTransport { private HttpClientStreamableHttpTransport(ObjectMapper objectMapper, HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri, String endpoint, boolean resumableStreams, - boolean openConnectionOnStartup) { + boolean openConnectionOnStartup, AsyncHttpRequestCustomizer httpRequestCustomizer) { this.objectMapper = objectMapper; this.httpClient = httpClient; this.requestBuilder = requestBuilder; @@ -126,6 +128,7 @@ private HttpClientStreamableHttpTransport(ObjectMapper objectMapper, HttpClient this.resumableStreams = resumableStreams; this.openConnectionOnStartup = openConnectionOnStartup; this.activeSession.set(createTransportSession()); + this.httpRequestCustomizer = httpRequestCustomizer; } public static Builder builder(String baseUri) { @@ -154,14 +157,18 @@ private DefaultMcpTransportSession createTransportSession() { } private Publisher createDelete(String sessionId) { - HttpRequest request = this.requestBuilder.copy() - .uri(Utils.resolveUri(this.baseUri, this.endpoint)) - .header("Cache-Control", "no-cache") - .header("mcp-session-id", sessionId) - .DELETE() - .build(); - - return Mono.fromFuture(() -> this.httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())).then(); + var uri = Utils.resolveUri(this.baseUri, this.endpoint); + return Mono + .just(this.requestBuilder.copy() + .uri(uri) + .header("Cache-Control", "no-cache") + .header("mcp-session-id", sessionId) + .DELETE()) + .flatMap(builder -> Mono.from(this.httpRequestCustomizer.customize(builder, "DELETE", uri, null))) + .map(HttpRequest.Builder::build) + .flatMap(request -> Mono + .fromFuture(() -> this.httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString()))) + .then(); } @Override @@ -208,96 +215,109 @@ private Mono reconnect(McpTransportStream stream) { final AtomicReference disposableRef = new AtomicReference<>(); final McpTransportSession transportSession = this.activeSession.get(); + var uri = Utils.resolveUri(this.baseUri, this.endpoint); - HttpRequest.Builder requestBuilder = this.requestBuilder.copy(); - - if (transportSession != null && transportSession.sessionId().isPresent()) { - requestBuilder = requestBuilder.header("mcp-session-id", transportSession.sessionId().get()); - } - - if (stream != null && stream.lastId().isPresent()) { - requestBuilder = requestBuilder.header("last-event-id", stream.lastId().get()); - } - - HttpRequest request = requestBuilder.uri(Utils.resolveUri(this.baseUri, this.endpoint)) - .header("Accept", TEXT_EVENT_STREAM) - .header("Cache-Control", "no-cache") - .GET() - .build(); - - Disposable connection = Flux.create(sseSink -> this.httpClient - .sendAsync(request, responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, sseSink)) - .whenComplete((response, throwable) -> { - if (throwable != null) { - sseSink.error(throwable); - } - else { - logger.debug("SSE connection established successfully"); - } - })) - .map(responseEvent -> (ResponseSubscribers.SseResponseEvent) responseEvent) - .flatMap(responseEvent -> { - int statusCode = responseEvent.responseInfo().statusCode(); - - if (statusCode >= 200 && statusCode < 300) { - - if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { - try { - // We don't support batching ATM and probably won't since - // the - // next version considers removing it. - McpSchema.JSONRPCMessage message = McpSchema - .deserializeJsonRpcMessage(this.objectMapper, responseEvent.sseEvent().data()); + Disposable connection = Mono.fromCallable(() -> { + HttpRequest.Builder requestBuilder = this.requestBuilder.copy(); - Tuple2, Iterable> idWithMessages = Tuples - .of(Optional.ofNullable(responseEvent.sseEvent().id()), List.of(message)); + if (transportSession != null && transportSession.sessionId().isPresent()) { + requestBuilder = requestBuilder.header("mcp-session-id", transportSession.sessionId().get()); + } - McpTransportStream sessionStream = stream != null ? stream - : new DefaultMcpTransportStream<>(this.resumableStreams, this::reconnect); - logger.debug("Connected stream {}", sessionStream.streamId()); + if (stream != null && stream.lastId().isPresent()) { + requestBuilder = requestBuilder.header("last-event-id", stream.lastId().get()); + } - return Flux.from(sessionStream.consumeSseStream(Flux.just(idWithMessages))); + return requestBuilder.uri(uri) + .header("Accept", TEXT_EVENT_STREAM) + .header("Cache-Control", "no-cache") + .GET(); + }) + .flatMap(builder -> Mono.from(this.httpRequestCustomizer.customize(builder, "GET", uri, null))) + .map(HttpRequest.Builder::build) + .flatMapMany( + request -> Flux.create( + sseSink -> this.httpClient + .sendAsync(request, + responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, + sseSink)) + .whenComplete((response, throwable) -> { + if (throwable != null) { + sseSink.error(throwable); + } + else { + logger.debug("SSE connection established successfully"); + } + })) + .map(responseEvent -> (ResponseSubscribers.SseResponseEvent) responseEvent) + .flatMap(responseEvent -> { + int statusCode = responseEvent.responseInfo().statusCode(); + + if (statusCode >= 200 && statusCode < 300) { + + if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { + try { + // We don't support batching ATM and probably + // won't since + // the + // next version considers removing it. + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage( + this.objectMapper, responseEvent.sseEvent().data()); + + Tuple2, Iterable> idWithMessages = Tuples + .of(Optional.ofNullable(responseEvent.sseEvent().id()), + List.of(message)); + + McpTransportStream sessionStream = stream != null ? stream + : new DefaultMcpTransportStream<>(this.resumableStreams, + this::reconnect); + logger.debug("Connected stream {}", sessionStream.streamId()); + + return Flux.from(sessionStream.consumeSseStream(Flux.just(idWithMessages))); + + } + catch (IOException ioException) { + return Flux.error( + new McpError("Error parsing JSON-RPC message: " + + responseEvent.sseEvent().data())); + } + } + } + else if (statusCode == METHOD_NOT_ALLOWED) { // NotAllowed + logger + .debug("The server does not support SSE streams, using request-response mode."); + return Flux.empty(); + } + else if (statusCode == NOT_FOUND) { + String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); + McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( + "Session not found for session ID: " + sessionIdRepresentation); + return Flux.error(exception); + } + else if (statusCode == BAD_REQUEST) { + String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); + McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( + "Session not found for session ID: " + sessionIdRepresentation); + return Flux.error(exception); + } - } - catch (IOException ioException) { return Flux.error(new McpError( - "Error parsing JSON-RPC message: " + responseEvent.sseEvent().data())); - } - } - } - else if (statusCode == METHOD_NOT_ALLOWED) { // NotAllowed - logger.debug("The server does not support SSE streams, using request-response mode."); - return Flux.empty(); - } - else if (statusCode == NOT_FOUND) { - String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); - McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( - "Session not found for session ID: " + sessionIdRepresentation); - return Flux.error(exception); - } - else if (statusCode == BAD_REQUEST) { - String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); - McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( - "Session not found for session ID: " + sessionIdRepresentation); - return Flux.error(exception); - } - - return Flux.error( - new McpError("Received unrecognized SSE event type: " + responseEvent.sseEvent().event())); - - }).flatMap(jsonrpcMessage -> this.handler.get().apply(Mono.just(jsonrpcMessage))) - .onErrorMap(CompletionException.class, t -> t.getCause()) - .onErrorComplete(t -> { - this.handleException(t); - return true; - }) - .doFinally(s -> { - Disposable ref = disposableRef.getAndSet(null); - if (ref != null) { - transportSession.removeConnection(ref); - } - }) + "Received unrecognized SSE event type: " + responseEvent.sseEvent().event())); + + }).flatMap( + jsonrpcMessage -> this.handler.get().apply(Mono.just(jsonrpcMessage))) + .onErrorMap(CompletionException.class, t -> t.getCause()) + .onErrorComplete(t -> { + this.handleException(t); + return true; + }) + .doFinally(s -> { + Disposable ref = disposableRef.getAndSet(null); + if (ref != null) { + transportSession.removeConnection(ref); + } + })) .contextWrite(ctx) .subscribe(); @@ -348,125 +368,136 @@ public Mono sendMessage(McpSchema.JSONRPCMessage sendMessage) { final AtomicReference disposableRef = new AtomicReference<>(); final McpTransportSession transportSession = this.activeSession.get(); - - HttpRequest.Builder requestBuilder = this.requestBuilder.copy(); - - if (transportSession != null && transportSession.sessionId().isPresent()) { - requestBuilder = requestBuilder.header("mcp-session-id", transportSession.sessionId().get()); - } - + var uri = Utils.resolveUri(this.baseUri, this.endpoint); String jsonBody = this.toString(sendMessage); - HttpRequest request = requestBuilder.uri(Utils.resolveUri(this.baseUri, this.endpoint)) - .header("Accept", TEXT_EVENT_STREAM + ", " + APPLICATION_JSON) - .header("Content-Type", APPLICATION_JSON) - .header("Cache-Control", "no-cache") - .POST(HttpRequest.BodyPublishers.ofString(jsonBody)) - .build(); - - Disposable connection = Flux.create(responseEventSink -> { + Disposable connection = Mono.fromCallable(() -> { + HttpRequest.Builder requestBuilder = this.requestBuilder.copy(); - // Create the async request with proper body subscriber selection - Mono.fromFuture(this.httpClient.sendAsync(request, this.toSendMessageBodySubscriber(responseEventSink)) - .whenComplete((response, throwable) -> { - if (throwable != null) { - responseEventSink.error(throwable); - } - else { - logger.debug("SSE connection established successfully"); - } - })).onErrorMap(CompletionException.class, t -> t.getCause()).onErrorComplete().subscribe(); - - }).flatMap(responseEvent -> { - if (transportSession.markInitialized( - responseEvent.responseInfo().headers().firstValue("mcp-session-id").orElseGet(() -> null))) { - // Once we have a session, we try to open an async stream for - // the server to send notifications and requests out-of-band. - - reconnect(null).contextWrite(messageSink.contextView()).subscribe(); + if (transportSession != null && transportSession.sessionId().isPresent()) { + requestBuilder = requestBuilder.header("mcp-session-id", transportSession.sessionId().get()); } - String sessionRepresentation = sessionIdOrPlaceholder(transportSession); - - int statusCode = responseEvent.responseInfo().statusCode(); - - if (statusCode >= 200 && statusCode < 300) { + return requestBuilder.uri(uri) + .header("Accept", TEXT_EVENT_STREAM + ", " + APPLICATION_JSON) + .header("Content-Type", APPLICATION_JSON) + .header("Cache-Control", "no-cache") + .POST(HttpRequest.BodyPublishers.ofString(jsonBody)); + }) + .flatMap(builder -> Mono.from(this.httpRequestCustomizer.customize(builder, "GET", uri, jsonBody))) + .map(HttpRequest.Builder::build) + .flatMapMany(request -> Flux.create(responseEventSink -> { + + // Create the async request with proper body subscriber selection + Mono.fromFuture( + this.httpClient.sendAsync(request, this.toSendMessageBodySubscriber(responseEventSink)) + .whenComplete((response, throwable) -> { + if (throwable != null) { + responseEventSink.error(throwable); + } + else { + logger.debug("SSE connection established successfully"); + } + })) + .onErrorMap(CompletionException.class, t -> t.getCause()) + .onErrorComplete() + .subscribe(); - String contentType = responseEvent.responseInfo() + })) + .flatMap(responseEvent -> { + if (transportSession.markInitialized(responseEvent.responseInfo() .headers() - .firstValue("Content-Type") - .orElse("") - .toLowerCase(); - - if (contentType.isBlank()) { - logger.debug("No content type returned for POST in session {}", sessionRepresentation); - // No content type means no response body, so we can just return - // an empty stream - messageSink.success(); - return Flux.empty(); - } - else if (contentType.contains(TEXT_EVENT_STREAM)) { - return Flux.just(((ResponseSubscribers.SseResponseEvent) responseEvent).sseEvent()) - .flatMap(sseEvent -> { - try { - // We don't support batching ATM and probably won't - // since the - // next version considers removing it. - McpSchema.JSONRPCMessage message = McpSchema - .deserializeJsonRpcMessage(this.objectMapper, sseEvent.data()); + .firstValue("mcp-session-id") + .orElseGet(() -> null))) { + // Once we have a session, we try to open an async stream for + // the server to send notifications and requests out-of-band. - Tuple2, Iterable> idWithMessages = Tuples - .of(Optional.ofNullable(sseEvent.id()), List.of(message)); + reconnect(null).contextWrite(messageSink.contextView()).subscribe(); + } - McpTransportStream sessionStream = new DefaultMcpTransportStream<>( - this.resumableStreams, this::reconnect); + String sessionRepresentation = sessionIdOrPlaceholder(transportSession); - logger.debug("Connected stream {}", sessionStream.streamId()); + int statusCode = responseEvent.responseInfo().statusCode(); - messageSink.success(); + if (statusCode >= 200 && statusCode < 300) { - return Flux.from(sessionStream.consumeSseStream(Flux.just(idWithMessages))); - } - catch (IOException ioException) { - return Flux.error( - new McpError("Error parsing JSON-RPC message: " + sseEvent.data())); - } - }); - } - else if (contentType.contains(APPLICATION_JSON)) { - messageSink.success(); - String data = ((ResponseSubscribers.AggregateResponseEvent) responseEvent).data(); - try { - return Mono.just(McpSchema.deserializeJsonRpcMessage(objectMapper, data)); + String contentType = responseEvent.responseInfo() + .headers() + .firstValue("Content-Type") + .orElse("") + .toLowerCase(); + + if (contentType.isBlank()) { + logger.debug("No content type returned for POST in session {}", sessionRepresentation); + // No content type means no response body, so we can just + // return + // an empty stream + messageSink.success(); + return Flux.empty(); } - catch (IOException e) { - return Mono.error(e); + else if (contentType.contains(TEXT_EVENT_STREAM)) { + return Flux.just(((ResponseSubscribers.SseResponseEvent) responseEvent).sseEvent()) + .flatMap(sseEvent -> { + try { + // We don't support batching ATM and probably + // won't + // since the + // next version considers removing it. + McpSchema.JSONRPCMessage message = McpSchema + .deserializeJsonRpcMessage(this.objectMapper, sseEvent.data()); + + Tuple2, Iterable> idWithMessages = Tuples + .of(Optional.ofNullable(sseEvent.id()), List.of(message)); + + McpTransportStream sessionStream = new DefaultMcpTransportStream<>( + this.resumableStreams, this::reconnect); + + logger.debug("Connected stream {}", sessionStream.streamId()); + + messageSink.success(); + + return Flux.from(sessionStream.consumeSseStream(Flux.just(idWithMessages))); + } + catch (IOException ioException) { + return Flux.error( + new McpError("Error parsing JSON-RPC message: " + sseEvent.data())); + } + }); } + else if (contentType.contains(APPLICATION_JSON)) { + messageSink.success(); + String data = ((ResponseSubscribers.AggregateResponseEvent) responseEvent).data(); + try { + return Mono.just(McpSchema.deserializeJsonRpcMessage(objectMapper, data)); + } + catch (IOException e) { + return Mono.error(e); + } + } + logger.warn("Unknown media type {} returned for POST in session {}", contentType, + sessionRepresentation); + + return Flux.error( + new RuntimeException("Unknown media type returned: " + contentType)); + } + else if (statusCode == NOT_FOUND) { + McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( + "Session not found for session ID: " + sessionRepresentation); + return Flux.error(exception); + } + // Some implementations can return 400 when presented with a + // session id that it doesn't know about, so we will + // invalidate the session + // https://github.com/modelcontextprotocol/typescript-sdk/issues/389 + else if (statusCode == BAD_REQUEST) { + McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( + "Session not found for session ID: " + sessionRepresentation); + return Flux.error(exception); } - logger.warn("Unknown media type {} returned for POST in session {}", contentType, - sessionRepresentation); return Flux.error( - new RuntimeException("Unknown media type returned: " + contentType)); - } - else if (statusCode == NOT_FOUND) { - McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( - "Session not found for session ID: " + sessionRepresentation); - return Flux.error(exception); - } - // Some implementations can return 400 when presented with a - // session id that it doesn't know about, so we will - // invalidate the session - // https://github.com/modelcontextprotocol/typescript-sdk/issues/389 - else if (statusCode == BAD_REQUEST) { - McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( - "Session not found for session ID: " + sessionRepresentation); - return Flux.error(exception); - } - - return Flux.error( - new RuntimeException("Failed to send message: " + responseEvent)); - }) + new RuntimeException("Failed to send message: " + responseEvent)); + }) .flatMap(jsonRpcMessage -> this.handler.get().apply(Mono.just(jsonRpcMessage))) .onErrorMap(CompletionException.class, t -> t.getCause()) .onErrorComplete(t -> { @@ -521,6 +552,8 @@ public static class Builder { private HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(); + private AsyncHttpRequestCustomizer httpRequestCustomizer = AsyncHttpRequestCustomizer.NOOP; + /** * Creates a new builder with the specified base URI. * @param baseUri the base URI of the MCP server @@ -623,6 +656,19 @@ public Builder openConnectionOnStartup(boolean openConnectionOnStartup) { return this; } + /** + * In reactive, DONT USE THIS. Use AsyncHttpRequestCustomizer. + */ + public Builder httpRequestCustomizer(SyncHttpRequestCustomizer syncHttpRequestCustomizer) { + this.httpRequestCustomizer = AsyncHttpRequestCustomizer.fromSync(syncHttpRequestCustomizer); + return this; + } + + public Builder httpRequestCustomizer(AsyncHttpRequestCustomizer asyncHttpRequestCustomizer) { + this.httpRequestCustomizer = asyncHttpRequestCustomizer; + return this; + } + /** * Construct a fresh instance of {@link HttpClientStreamableHttpTransport} using * the current builder configuration. @@ -632,7 +678,7 @@ public HttpClientStreamableHttpTransport build() { ObjectMapper objectMapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); return new HttpClientStreamableHttpTransport(objectMapper, clientBuilder.build(), requestBuilder, baseUri, - endpoint, resumableStreams, openConnectionOnStartup); + endpoint, resumableStreams, openConnectionOnStartup, httpRequestCustomizer); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/SyncHttpRequestCustomizer.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/SyncHttpRequestCustomizer.java new file mode 100644 index 00000000..97ac6850 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/SyncHttpRequestCustomizer.java @@ -0,0 +1,15 @@ +package io.modelcontextprotocol.client.transport; + +import java.net.URI; +import java.net.http.HttpRequest; +import reactor.util.annotation.Nullable; + +/** + * Customize {@link HttpRequest.Builder} before sending out SSE or Streamable HTTP + * transport. + */ +public interface SyncHttpRequestCustomizer { + + void customize(HttpRequest.Builder builder, String method, URI endpoint, @Nullable String body); + +}