Skip to content

Commit a04d261

Browse files
committed
Implement for client
BREAKING : Client Specification record is braking
1 parent 74cfafa commit a04d261

File tree

3 files changed

+77
-4
lines changed

3 files changed

+77
-4
lines changed

mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,8 +234,17 @@ public class McpAsyncClient {
234234
notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_MESSAGE,
235235
asyncLoggingNotificationHandler(loggingConsumersFinal));
236236

237-
this.mcpSession = new McpClientSession(requestTimeout, transport, requestHandlers, notificationHandlers);
237+
// Utility Progress Notification
238+
List<Function<McpSchema.ProgressNotification, Mono<Void>>> progressConsumersFinal = new ArrayList<>();
239+
progressConsumersFinal
240+
.add((notification) -> Mono.fromRunnable(() -> logger.debug("Progress: {}", notification)));
241+
if (!Utils.isEmpty(features.progressConsumers())) {
242+
progressConsumersFinal.addAll(features.progressConsumers());
243+
}
244+
notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_PROGRESS,
245+
asyncProgressNotificationHandler(progressConsumersFinal));
238246

247+
this.mcpSession = new McpClientSession(requestTimeout, transport, requestHandlers, notificationHandlers);
239248
}
240249

241250
/**
@@ -789,6 +798,20 @@ private NotificationHandler asyncLoggingNotificationHandler(
789798
};
790799
}
791800

801+
private NotificationHandler asyncProgressNotificationHandler(
802+
List<Function<McpSchema.ProgressNotification, Mono<Void>>> progressConsumers) {
803+
804+
return params -> {
805+
McpSchema.ProgressNotification progressNotification = transport.unmarshalFrom(params,
806+
new TypeReference<McpSchema.ProgressNotification>() {
807+
});
808+
809+
return Flux.fromIterable(progressConsumers)
810+
.flatMap(consumer -> consumer.apply(progressNotification))
811+
.then();
812+
};
813+
}
814+
792815
/**
793816
* Sets the minimum logging level for messages received from the server. The client
794817
* will only receive log messages at or above the specified severity level.

mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,8 @@ class SyncSpec {
173173

174174
private final List<Consumer<McpSchema.LoggingMessageNotification>> loggingConsumers = new ArrayList<>();
175175

176+
private final List<Consumer<McpSchema.ProgressNotification>> progressConsumers = new ArrayList<>();
177+
176178
private Function<CreateMessageRequest, CreateMessageResult> samplingHandler;
177179

178180
private SyncSpec(McpClientTransport transport) {
@@ -356,6 +358,36 @@ public SyncSpec loggingConsumers(List<Consumer<McpSchema.LoggingMessageNotificat
356358
return this;
357359
}
358360

361+
/**
362+
* Adds a consumer to be notified of progress notifications from the server. This
363+
* allows the client to track long-running operations and provide feedback to
364+
* users.
365+
* @param progressConsumer A consumer that receives progress notifications. Must
366+
* not be null.
367+
* @return This builder instance for method chaining
368+
* @throws IllegalArgumentException if progressConsumer is null
369+
*/
370+
public SyncSpec progressConsumer(Consumer<McpSchema.ProgressNotification> progressConsumer) {
371+
Assert.notNull(progressConsumer, "Progress consumer must not be null");
372+
this.progressConsumers.add(progressConsumer);
373+
return this;
374+
}
375+
376+
/**
377+
* Adds a multiple consumers to be notified of progress notifications from the
378+
* server. This allows the client to track long-running operations and provide
379+
* feedback to users.
380+
* @param progressConsumers A list of consumers that receives progress
381+
* notifications. Must not be null.
382+
* @return This builder instance for method chaining
383+
* @throws IllegalArgumentException if progressConsumer is null
384+
*/
385+
public SyncSpec progressConsumers(List<Consumer<McpSchema.ProgressNotification>> progressConsumers) {
386+
Assert.notNull(progressConsumers, "Progress consumers must not be null");
387+
this.progressConsumers.addAll(progressConsumers);
388+
return this;
389+
}
390+
359391
/**
360392
* Create an instance of {@link McpSyncClient} with the provided configurations or
361393
* sensible defaults.
@@ -364,7 +396,7 @@ public SyncSpec loggingConsumers(List<Consumer<McpSchema.LoggingMessageNotificat
364396
public McpSyncClient build() {
365397
McpClientFeatures.Sync syncFeatures = new McpClientFeatures.Sync(this.clientInfo, this.capabilities,
366398
this.roots, this.toolsChangeConsumers, this.resourcesChangeConsumers, this.promptsChangeConsumers,
367-
this.loggingConsumers, this.samplingHandler);
399+
this.loggingConsumers, this.progressConsumers, this.samplingHandler);
368400

369401
McpClientFeatures.Async asyncFeatures = McpClientFeatures.Async.fromSync(syncFeatures);
370402

@@ -412,6 +444,8 @@ class AsyncSpec {
412444

413445
private final List<Function<McpSchema.LoggingMessageNotification, Mono<Void>>> loggingConsumers = new ArrayList<>();
414446

447+
private final List<Function<McpSchema.ProgressNotification, Mono<Void>>> progressConsumers = new ArrayList<>();
448+
415449
private Function<CreateMessageRequest, Mono<CreateMessageResult>> samplingHandler;
416450

417451
private AsyncSpec(McpClientTransport transport) {
@@ -606,7 +640,7 @@ public McpAsyncClient build() {
606640
return new McpAsyncClient(this.transport, this.requestTimeout, this.initializationTimeout,
607641
new McpClientFeatures.Async(this.clientInfo, this.capabilities, this.roots,
608642
this.toolsChangeConsumers, this.resourcesChangeConsumers, this.promptsChangeConsumers,
609-
this.loggingConsumers, this.samplingHandler));
643+
this.loggingConsumers, this.progressConsumers, this.samplingHandler));
610644
}
611645

612646
}

mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,15 @@ class McpClientFeatures {
5959
* @param resourcesChangeConsumers the resources change consumers.
6060
* @param promptsChangeConsumers the prompts change consumers.
6161
* @param loggingConsumers the logging consumers.
62+
* @param progressConsumers the progress consumers.
6263
* @param samplingHandler the sampling handler.
6364
*/
6465
record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities,
6566
Map<String, McpSchema.Root> roots, List<Function<List<McpSchema.Tool>, Mono<Void>>> toolsChangeConsumers,
6667
List<Function<List<McpSchema.Resource>, Mono<Void>>> resourcesChangeConsumers,
6768
List<Function<List<McpSchema.Prompt>, Mono<Void>>> promptsChangeConsumers,
6869
List<Function<McpSchema.LoggingMessageNotification, Mono<Void>>> loggingConsumers,
70+
List<Function<McpSchema.ProgressNotification, Mono<Void>>> progressConsumers,
6971
Function<McpSchema.CreateMessageRequest, Mono<McpSchema.CreateMessageResult>> samplingHandler) {
7072

7173
/**
@@ -76,6 +78,7 @@ record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c
7678
* @param resourcesChangeConsumers the resources change consumers.
7779
* @param promptsChangeConsumers the prompts change consumers.
7880
* @param loggingConsumers the logging consumers.
81+
* @param progressConsumers the progressconsumers.
7982
* @param samplingHandler the sampling handler.
8083
*/
8184
public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities,
@@ -84,6 +87,7 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c
8487
List<Function<List<McpSchema.Resource>, Mono<Void>>> resourcesChangeConsumers,
8588
List<Function<List<McpSchema.Prompt>, Mono<Void>>> promptsChangeConsumers,
8689
List<Function<McpSchema.LoggingMessageNotification, Mono<Void>>> loggingConsumers,
90+
List<Function<McpSchema.ProgressNotification, Mono<Void>>> progressConsumers,
8791
Function<McpSchema.CreateMessageRequest, Mono<McpSchema.CreateMessageResult>> samplingHandler) {
8892

8993
Assert.notNull(clientInfo, "Client info must not be null");
@@ -98,6 +102,7 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c
98102
this.resourcesChangeConsumers = resourcesChangeConsumers != null ? resourcesChangeConsumers : List.of();
99103
this.promptsChangeConsumers = promptsChangeConsumers != null ? promptsChangeConsumers : List.of();
100104
this.loggingConsumers = loggingConsumers != null ? loggingConsumers : List.of();
105+
this.progressConsumers = progressConsumers != null ? progressConsumers : List.of();
101106
this.samplingHandler = samplingHandler;
102107
}
103108

@@ -135,12 +140,18 @@ public static Async fromSync(Sync syncSpec) {
135140
.subscribeOn(Schedulers.boundedElastic()));
136141
}
137142

143+
List<Function<McpSchema.ProgressNotification, Mono<Void>>> progressConsumers = new ArrayList<>();
144+
for (Consumer<McpSchema.ProgressNotification> consumer : syncSpec.progressConsumers()) {
145+
progressConsumers.add(p -> Mono.<Void>fromRunnable(() -> consumer.accept(p))
146+
.subscribeOn(Schedulers.boundedElastic()));
147+
}
148+
138149
Function<McpSchema.CreateMessageRequest, Mono<McpSchema.CreateMessageResult>> samplingHandler = r -> Mono
139150
.fromCallable(() -> syncSpec.samplingHandler().apply(r))
140151
.subscribeOn(Schedulers.boundedElastic());
141152
return new Async(syncSpec.clientInfo(), syncSpec.clientCapabilities(), syncSpec.roots(),
142153
toolsChangeConsumers, resourcesChangeConsumers, promptsChangeConsumers, loggingConsumers,
143-
samplingHandler);
154+
progressConsumers, samplingHandler);
144155
}
145156
}
146157

@@ -155,13 +166,15 @@ public static Async fromSync(Sync syncSpec) {
155166
* @param resourcesChangeConsumers the resources change consumers.
156167
* @param promptsChangeConsumers the prompts change consumers.
157168
* @param loggingConsumers the logging consumers.
169+
* @param progressConsumers the progress consumers.
158170
* @param samplingHandler the sampling handler.
159171
*/
160172
public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities,
161173
Map<String, McpSchema.Root> roots, List<Consumer<List<McpSchema.Tool>>> toolsChangeConsumers,
162174
List<Consumer<List<McpSchema.Resource>>> resourcesChangeConsumers,
163175
List<Consumer<List<McpSchema.Prompt>>> promptsChangeConsumers,
164176
List<Consumer<McpSchema.LoggingMessageNotification>> loggingConsumers,
177+
List<Consumer<McpSchema.ProgressNotification>> progressConsumers,
165178
Function<McpSchema.CreateMessageRequest, McpSchema.CreateMessageResult> samplingHandler) {
166179

167180
/**
@@ -173,13 +186,15 @@ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabili
173186
* @param resourcesChangeConsumers the resources change consumers.
174187
* @param promptsChangeConsumers the prompts change consumers.
175188
* @param loggingConsumers the logging consumers.
189+
* @param progressConsumers the progress consumers.
176190
* @param samplingHandler the sampling handler.
177191
*/
178192
public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities,
179193
Map<String, McpSchema.Root> roots, List<Consumer<List<McpSchema.Tool>>> toolsChangeConsumers,
180194
List<Consumer<List<McpSchema.Resource>>> resourcesChangeConsumers,
181195
List<Consumer<List<McpSchema.Prompt>>> promptsChangeConsumers,
182196
List<Consumer<McpSchema.LoggingMessageNotification>> loggingConsumers,
197+
List<Consumer<McpSchema.ProgressNotification>> progressConsumers,
183198
Function<McpSchema.CreateMessageRequest, McpSchema.CreateMessageResult> samplingHandler) {
184199

185200
Assert.notNull(clientInfo, "Client info must not be null");
@@ -194,6 +209,7 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl
194209
this.resourcesChangeConsumers = resourcesChangeConsumers != null ? resourcesChangeConsumers : List.of();
195210
this.promptsChangeConsumers = promptsChangeConsumers != null ? promptsChangeConsumers : List.of();
196211
this.loggingConsumers = loggingConsumers != null ? loggingConsumers : List.of();
212+
this.progressConsumers = progressConsumers != null ? progressConsumers : List.of();
197213
this.samplingHandler = samplingHandler;
198214
}
199215
}

0 commit comments

Comments
 (0)