Skip to content

Support Progress Flow for McpClient #389

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 16 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ public class McpAsyncClient {
public static final TypeReference<LoggingMessageNotification> LOGGING_MESSAGE_NOTIFICATION_TYPE_REF = new TypeReference<>() {
};

public static final TypeReference<McpSchema.ProgressNotification> PROGRESS_NOTIFICATION_TYPE_REF = new TypeReference<>() {
};

/**
* Client capabilities.
*/
Expand Down Expand Up @@ -253,6 +256,16 @@ public class McpAsyncClient {
notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_MESSAGE,
asyncLoggingNotificationHandler(loggingConsumersFinal));

// Utility Progress Notification
List<Function<McpSchema.ProgressNotification, Mono<Void>>> progressConsumersFinal = new ArrayList<>();
progressConsumersFinal
.add((notification) -> Mono.fromRunnable(() -> logger.debug("Progress: {}", notification)));
if (!Utils.isEmpty(features.progressConsumers())) {
progressConsumersFinal.addAll(features.progressConsumers());
}
notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_PROGRESS,
asyncProgressNotificationHandler(progressConsumersFinal));

this.initializer = new LifecycleInitializer(clientCapabilities, clientInfo,
List.of(McpSchema.LATEST_PROTOCOL_VERSION), initializationTimeout,
ctx -> new McpClientSession(requestTimeout, transport, requestHandlers, notificationHandlers,
Expand Down Expand Up @@ -828,6 +841,19 @@ private NotificationHandler asyncLoggingNotificationHandler(
};
}

private NotificationHandler asyncProgressNotificationHandler(
List<Function<McpSchema.ProgressNotification, Mono<Void>>> progressConsumers) {

return params -> {
McpSchema.ProgressNotification progressNotification = transport.unmarshalFrom(params,
PROGRESS_NOTIFICATION_TYPE_REF);

return Flux.fromIterable(progressConsumers)
.flatMap(consumer -> consumer.apply(progressNotification))
.then();
};
}

/**
* Sets the minimum logging level for messages received from the server. The client
* will only receive log messages at or above the specified severity level.
Expand Down
41 changes: 38 additions & 3 deletions mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ class SyncSpec {

private final List<Consumer<McpSchema.LoggingMessageNotification>> loggingConsumers = new ArrayList<>();

private final List<Consumer<McpSchema.ProgressNotification>> progressConsumers = new ArrayList<>();

private Function<CreateMessageRequest, CreateMessageResult> samplingHandler;

private Function<ElicitRequest, ElicitResult> elicitationHandler;
Expand Down Expand Up @@ -377,6 +379,36 @@ public SyncSpec loggingConsumers(List<Consumer<McpSchema.LoggingMessageNotificat
return this;
}

/**
* Adds a consumer to be notified of progress notifications from the server. This
* allows the client to track long-running operations and provide feedback to
* users.
* @param progressConsumer A consumer that receives progress notifications. Must
* not be null.
* @return This builder instance for method chaining
* @throws IllegalArgumentException if progressConsumer is null
*/
public SyncSpec progressConsumer(Consumer<McpSchema.ProgressNotification> progressConsumer) {
Assert.notNull(progressConsumer, "Progress consumer must not be null");
this.progressConsumers.add(progressConsumer);
return this;
}

/**
* Adds a multiple consumers to be notified of progress notifications from the
* server. This allows the client to track long-running operations and provide
* feedback to users.
* @param progressConsumers A list of consumers that receives progress
* notifications. Must not be null.
* @return This builder instance for method chaining
* @throws IllegalArgumentException if progressConsumer is null
*/
public SyncSpec progressConsumers(List<Consumer<McpSchema.ProgressNotification>> progressConsumers) {
Assert.notNull(progressConsumers, "Progress consumers must not be null");
this.progressConsumers.addAll(progressConsumers);
return this;
}

/**
* Create an instance of {@link McpSyncClient} with the provided configurations or
* sensible defaults.
Expand All @@ -385,7 +417,8 @@ public SyncSpec loggingConsumers(List<Consumer<McpSchema.LoggingMessageNotificat
public McpSyncClient build() {
McpClientFeatures.Sync syncFeatures = new McpClientFeatures.Sync(this.clientInfo, this.capabilities,
this.roots, this.toolsChangeConsumers, this.resourcesChangeConsumers, this.resourcesUpdateConsumers,
this.promptsChangeConsumers, this.loggingConsumers, this.samplingHandler, this.elicitationHandler);
this.promptsChangeConsumers, this.loggingConsumers, this.progressConsumers, this.samplingHandler,
this.elicitationHandler);

McpClientFeatures.Async asyncFeatures = McpClientFeatures.Async.fromSync(syncFeatures);

Expand Down Expand Up @@ -435,6 +468,8 @@ class AsyncSpec {

private final List<Function<McpSchema.LoggingMessageNotification, Mono<Void>>> loggingConsumers = new ArrayList<>();

private final List<Function<McpSchema.ProgressNotification, Mono<Void>>> progressConsumers = new ArrayList<>();

private Function<CreateMessageRequest, Mono<CreateMessageResult>> samplingHandler;

private Function<ElicitRequest, Mono<ElicitResult>> elicitationHandler;
Expand Down Expand Up @@ -663,8 +698,8 @@ public McpAsyncClient build() {
return new McpAsyncClient(this.transport, this.requestTimeout, this.initializationTimeout,
new McpClientFeatures.Async(this.clientInfo, this.capabilities, this.roots,
this.toolsChangeConsumers, this.resourcesChangeConsumers, this.resourcesUpdateConsumers,
this.promptsChangeConsumers, this.loggingConsumers, this.samplingHandler,
this.elicitationHandler));
this.promptsChangeConsumers, this.loggingConsumers, this.progressConsumers,
this.samplingHandler, this.elicitationHandler));
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class McpClientFeatures {
* @param resourcesChangeConsumers the resources change consumers.
* @param promptsChangeConsumers the prompts change consumers.
* @param loggingConsumers the logging consumers.
* @param progressConsumers the progress consumers.
* @param samplingHandler the sampling handler.
* @param elicitationHandler the elicitation handler.
*/
Expand All @@ -68,6 +69,7 @@ record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c
List<Function<List<McpSchema.ResourceContents>, Mono<Void>>> resourcesUpdateConsumers,
List<Function<List<McpSchema.Prompt>, Mono<Void>>> promptsChangeConsumers,
List<Function<McpSchema.LoggingMessageNotification, Mono<Void>>> loggingConsumers,
List<Function<McpSchema.ProgressNotification, Mono<Void>>> progressConsumers,
Function<McpSchema.CreateMessageRequest, Mono<McpSchema.CreateMessageResult>> samplingHandler,
Function<McpSchema.ElicitRequest, Mono<McpSchema.ElicitResult>> elicitationHandler) {

Expand All @@ -79,6 +81,7 @@ record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c
* @param resourcesChangeConsumers the resources change consumers.
* @param promptsChangeConsumers the prompts change consumers.
* @param loggingConsumers the logging consumers.
* @param progressConsumers the progressconsumers.
* @param samplingHandler the sampling handler.
* @param elicitationHandler the elicitation handler.
*/
Expand All @@ -89,6 +92,7 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c
List<Function<List<McpSchema.ResourceContents>, Mono<Void>>> resourcesUpdateConsumers,
List<Function<List<McpSchema.Prompt>, Mono<Void>>> promptsChangeConsumers,
List<Function<McpSchema.LoggingMessageNotification, Mono<Void>>> loggingConsumers,
List<Function<McpSchema.ProgressNotification, Mono<Void>>> progressConsumers,
Function<McpSchema.CreateMessageRequest, Mono<McpSchema.CreateMessageResult>> samplingHandler,
Function<McpSchema.ElicitRequest, Mono<McpSchema.ElicitResult>> elicitationHandler) {

Expand All @@ -106,6 +110,7 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c
this.resourcesUpdateConsumers = resourcesUpdateConsumers != null ? resourcesUpdateConsumers : List.of();
this.promptsChangeConsumers = promptsChangeConsumers != null ? promptsChangeConsumers : List.of();
this.loggingConsumers = loggingConsumers != null ? loggingConsumers : List.of();
this.progressConsumers = progressConsumers != null ? progressConsumers : List.of();
this.samplingHandler = samplingHandler;
this.elicitationHandler = elicitationHandler;
}
Expand Down Expand Up @@ -149,6 +154,12 @@ public static Async fromSync(Sync syncSpec) {
.subscribeOn(Schedulers.boundedElastic()));
}

List<Function<McpSchema.ProgressNotification, Mono<Void>>> progressConsumers = new ArrayList<>();
for (Consumer<McpSchema.ProgressNotification> consumer : syncSpec.progressConsumers()) {
progressConsumers.add(p -> Mono.<Void>fromRunnable(() -> consumer.accept(p))
.subscribeOn(Schedulers.boundedElastic()));
}

Function<McpSchema.CreateMessageRequest, Mono<McpSchema.CreateMessageResult>> samplingHandler = r -> Mono
.fromCallable(() -> syncSpec.samplingHandler().apply(r))
.subscribeOn(Schedulers.boundedElastic());
Expand All @@ -159,7 +170,7 @@ public static Async fromSync(Sync syncSpec) {

return new Async(syncSpec.clientInfo(), syncSpec.clientCapabilities(), syncSpec.roots(),
toolsChangeConsumers, resourcesChangeConsumers, resourcesUpdateConsumers, promptsChangeConsumers,
loggingConsumers, samplingHandler, elicitationHandler);
loggingConsumers, progressConsumers, samplingHandler, elicitationHandler);
}
}

Expand All @@ -174,6 +185,7 @@ public static Async fromSync(Sync syncSpec) {
* @param resourcesChangeConsumers the resources change consumers.
* @param promptsChangeConsumers the prompts change consumers.
* @param loggingConsumers the logging consumers.
* @param progressConsumers the progress consumers.
* @param samplingHandler the sampling handler.
* @param elicitationHandler the elicitation handler.
*/
Expand All @@ -183,6 +195,7 @@ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabili
List<Consumer<List<McpSchema.ResourceContents>>> resourcesUpdateConsumers,
List<Consumer<List<McpSchema.Prompt>>> promptsChangeConsumers,
List<Consumer<McpSchema.LoggingMessageNotification>> loggingConsumers,
List<Consumer<McpSchema.ProgressNotification>> progressConsumers,
Function<McpSchema.CreateMessageRequest, McpSchema.CreateMessageResult> samplingHandler,
Function<McpSchema.ElicitRequest, McpSchema.ElicitResult> elicitationHandler) {

Expand All @@ -196,6 +209,7 @@ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabili
* @param resourcesUpdateConsumers the resource update consumers.
* @param promptsChangeConsumers the prompts change consumers.
* @param loggingConsumers the logging consumers.
* @param progressConsumers the progress consumers.
* @param samplingHandler the sampling handler.
* @param elicitationHandler the elicitation handler.
*/
Expand All @@ -205,6 +219,7 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl
List<Consumer<List<McpSchema.ResourceContents>>> resourcesUpdateConsumers,
List<Consumer<List<McpSchema.Prompt>>> promptsChangeConsumers,
List<Consumer<McpSchema.LoggingMessageNotification>> loggingConsumers,
List<Consumer<McpSchema.ProgressNotification>> progressConsumers,
Function<McpSchema.CreateMessageRequest, McpSchema.CreateMessageResult> samplingHandler,
Function<McpSchema.ElicitRequest, McpSchema.ElicitResult> elicitationHandler) {

Expand All @@ -222,6 +237,7 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl
this.resourcesUpdateConsumers = resourcesUpdateConsumers != null ? resourcesUpdateConsumers : List.of();
this.promptsChangeConsumers = promptsChangeConsumers != null ? promptsChangeConsumers : List.of();
this.loggingConsumers = loggingConsumers != null ? loggingConsumers : List.of();
this.progressConsumers = progressConsumers != null ? progressConsumers : List.of();
this.samplingHandler = samplingHandler;
this.elicitationHandler = elicitationHandler;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,20 @@ public Mono<Void> loggingNotification(LoggingMessageNotification loggingMessageN
});
}

/**
* Sends a notification to the client that the current progress status has changed for
* long-running operations.
* @param progressNotification The progress notification to send
* @return A Mono that completes when the notification has been sent
*/
public Mono<Void> progressNotification(McpSchema.ProgressNotification progressNotification) {
if (progressNotification == null) {
return Mono.error(new McpError("Progress notification must not be null"));
}

return this.session.sendNotification(McpSchema.METHOD_NOTIFICATION_PROGRESS, progressNotification);
}

/**
* Sends a ping request to the client.
* @return A Mono that completes with clients's ping response
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se
* represents a specific capability.
*
* @param tool The tool definition including name, description, and parameter schema
* @param call Deprecated. Uset he {@link AsyncToolSpecification#callHandler} instead.
* @param call Deprecated. Use the {@link AsyncToolSpecification#callHandler} instead.
* @param callHandler The function that implements the tool's logic, receiving a
* {@link McpAsyncServerExchange} and a
* {@link io.modelcontextprotocol.spec.McpSchema.CallToolRequest} and returning
Expand Down
20 changes: 13 additions & 7 deletions mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import io.modelcontextprotocol.util.Assert;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.util.annotation.Nullable;

/**
* Based on the <a href="http://www.jsonrpc.org/specification">JSON-RPC 2.0
Expand Down Expand Up @@ -58,6 +59,8 @@ private McpSchema() {

public static final String METHOD_PING = "ping";

public static final String METHOD_NOTIFICATION_PROGRESS = "notifications/progress";

// Tool Methods
public static final String METHOD_TOOLS_LIST = "tools/list";

Expand Down Expand Up @@ -249,7 +252,7 @@ public record InitializeRequest( // @formatter:off
@JsonProperty("capabilities") ClientCapabilities capabilities,
@JsonProperty("clientInfo") Implementation clientInfo,
@JsonProperty("_meta") Map<String, Object> meta) implements Request {

public InitializeRequest(String protocolVersion, ClientCapabilities capabilities, Implementation clientInfo) {
this(protocolVersion, capabilities, clientInfo, null);
}
Expand Down Expand Up @@ -462,7 +465,7 @@ public ServerCapabilities build() {
public record Implementation(// @formatter:off
@JsonProperty("name") String name,
@JsonProperty("title") String title,
@JsonProperty("version") String version) implements BaseMetadata {// @formatter:on
@JsonProperty("version") String version) implements BaseMetadata {// @formatter:on

public Implementation(String name, String version) {
this(name, null, version);
Expand Down Expand Up @@ -1053,6 +1056,8 @@ private static JsonSchema parseSchema(String schema) {
* tools/list.
* @param arguments Arguments to pass to the tool. These must conform to the tool's
* input schema.
* @param meta Optional metadata about the request. This can include additional
* information like `progressToken`
*/
@JsonInclude(JsonInclude.Include.NON_ABSENT)
@JsonIgnoreProperties(ignoreUnknown = true)
Expand All @@ -1063,9 +1068,10 @@ public record CallToolRequest(// @formatter:off

public CallToolRequest(String name, String jsonArguments) {
this(name, parseJsonArguments(jsonArguments), null);
}
public CallToolRequest(String name, Map<String, Object> arguments) {
this(name, arguments, null);
}

public CallToolRequest(String name, Map<String, Object> arguments) {
this(name, arguments, null);
}

private static Map<String, Object> parseJsonArguments(String jsonArguments) {
Expand Down Expand Up @@ -1317,7 +1323,7 @@ public record CreateMessageRequest(// @formatter:off
@JsonProperty("metadata") Map<String, Object> metadata,
@JsonProperty("_meta") Map<String, Object> meta) implements Request {


// backwards compatibility constructor
public CreateMessageRequest(List<SamplingMessage> messages, ModelPreferences modelPreferences,
String systemPrompt, ContextInclusionStrategy includeContext,
Expand Down Expand Up @@ -1771,7 +1777,7 @@ public CompleteRequest(McpSchema.CompleteReference ref, CompleteArgument argumen
public CompleteRequest(McpSchema.CompleteReference ref, CompleteArgument argument) {
this(ref, argument, null, null);
}

public record CompleteArgument(
@JsonProperty("name") String name,
@JsonProperty("value") String value) {
Expand Down
Loading