Skip to content

Commit 8f3cb34

Browse files
authored
Merge pull request #2 from kalvinnchau/mcp-prompt-spec
2 parents fe1108c + ef9309c commit 8f3cb34

File tree

7 files changed

+159
-51
lines changed

7 files changed

+159
-51
lines changed

crates/mcp-client/examples/stdio_integration.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,5 +82,16 @@ async fn main() -> Result<(), ClientError> {
8282
let resource = client.read_resource("memo://insights").await?;
8383
println!("Resource: {resource:?}\n");
8484

85+
let prompts = client.list_prompts(None).await?;
86+
println!("Prompts: {prompts:?}\n");
87+
88+
let prompt = client
89+
.get_prompt(
90+
"example_prompt",
91+
serde_json::json!({"message": "hello there!"}),
92+
)
93+
.await?;
94+
println!("Prompt: {prompt:?}\n");
95+
8596
Ok(())
8697
}

crates/mcp-client/src/client.rs

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use mcp_core::protocol::{
2-
CallToolResult, Implementation, InitializeResult, JsonRpcError, JsonRpcMessage,
3-
JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, ListResourcesResult, ListToolsResult,
4-
ReadResourceResult, ServerCapabilities, METHOD_NOT_FOUND,
2+
CallToolResult, GetPromptResult, Implementation, InitializeResult, JsonRpcError,
3+
JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, ListPromptsResult,
4+
ListResourcesResult, ListToolsResult, ReadResourceResult, ServerCapabilities, METHOD_NOT_FOUND,
55
};
66
use serde::{Deserialize, Serialize};
77
use serde_json::Value;
@@ -93,6 +93,10 @@ pub trait McpClientTrait: Send + Sync {
9393
async fn list_tools(&self, next_cursor: Option<String>) -> Result<ListToolsResult, Error>;
9494

9595
async fn call_tool(&self, name: &str, arguments: Value) -> Result<CallToolResult, Error>;
96+
97+
async fn list_prompts(&self, next_cursor: Option<String>) -> Result<ListPromptsResult, Error>;
98+
99+
async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult, Error>;
96100
}
97101

98102
/// The MCP client is the interface for MCP operations.
@@ -346,4 +350,42 @@ where
346350
// https://modelcontextprotocol.io/docs/concepts/tools#error-handling-2
347351
self.send_request("tools/call", params).await
348352
}
353+
354+
async fn list_prompts(&self, next_cursor: Option<String>) -> Result<ListPromptsResult, Error> {
355+
if !self.completed_initialization() {
356+
return Err(Error::NotInitialized);
357+
}
358+
359+
// If prompts is not supported, return an error
360+
if self.server_capabilities.as_ref().unwrap().prompts.is_none() {
361+
return Err(Error::RpcError {
362+
code: METHOD_NOT_FOUND,
363+
message: "Server does not support 'prompts' capability".to_string(),
364+
});
365+
}
366+
367+
let payload = next_cursor
368+
.map(|cursor| serde_json::json!({"cursor": cursor}))
369+
.unwrap_or_else(|| serde_json::json!({}));
370+
371+
self.send_request("prompts/list", payload).await
372+
}
373+
374+
async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult, Error> {
375+
if !self.completed_initialization() {
376+
return Err(Error::NotInitialized);
377+
}
378+
379+
// If prompts is not supported, return an error
380+
if self.server_capabilities.as_ref().unwrap().prompts.is_none() {
381+
return Err(Error::RpcError {
382+
code: METHOD_NOT_FOUND,
383+
message: "Server does not support 'prompts' capability".to_string(),
384+
});
385+
}
386+
387+
let params = serde_json::json!({ "name": name, "arguments": arguments });
388+
389+
self.send_request("prompts/get", params).await
390+
}
349391
}

crates/mcp-client/src/transport/sse.rs

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -111,13 +111,23 @@ impl SseActor {
111111
// Attempt to parse the SSE data as a JsonRpcMessage
112112
match serde_json::from_str::<JsonRpcMessage>(&e.data) {
113113
Ok(message) => {
114-
// If it's a response, complete the pending request
115-
if let JsonRpcMessage::Response(resp) = &message {
116-
if let Some(id) = &resp.id {
117-
pending_requests.respond(&id.to_string(), Ok(message)).await;
114+
match &message {
115+
JsonRpcMessage::Response(response) => {
116+
if let Some(id) = &response.id {
117+
pending_requests
118+
.respond(&id.to_string(), Ok(message))
119+
.await;
120+
}
118121
}
122+
JsonRpcMessage::Error(error) => {
123+
if let Some(id) = &error.id {
124+
pending_requests
125+
.respond(&id.to_string(), Ok(message))
126+
.await;
127+
}
128+
}
129+
_ => {} // TODO: Handle other variants (Request, etc.)
119130
}
120-
// If it's something else (notification, etc.), handle as needed
121131
}
122132
Err(err) => {
123133
warn!("Failed to parse SSE message: {err}");

crates/mcp-client/src/transport/stdio.rs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,18 @@ impl StdioActor {
8787
"Received incoming message"
8888
);
8989

90-
if let JsonRpcMessage::Response(response) = &message {
91-
if let Some(id) = &response.id {
92-
pending_requests.respond(&id.to_string(), Ok(message)).await;
90+
match &message {
91+
JsonRpcMessage::Response(response) => {
92+
if let Some(id) = &response.id {
93+
pending_requests.respond(&id.to_string(), Ok(message)).await;
94+
}
9395
}
96+
JsonRpcMessage::Error(error) => {
97+
if let Some(id) = &error.id {
98+
pending_requests.respond(&id.to_string(), Ok(message)).await;
99+
}
100+
}
101+
_ => {} // TODO: Handle other variants (Request, etc.)
94102
}
95103
}
96104
line.clear();

crates/mcp-core/src/prompt.rs

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,22 +10,28 @@ use serde::{Deserialize, Serialize};
1010
pub struct Prompt {
1111
/// The name of the prompt
1212
pub name: String,
13-
/// A description of what the prompt does
14-
pub description: String,
15-
/// The arguments that can be passed to customize the prompt
16-
pub arguments: Vec<PromptArgument>,
13+
/// Optional description of what the prompt does
14+
#[serde(skip_serializing_if = "Option::is_none")]
15+
pub description: Option<String>,
16+
/// Optional arguments that can be passed to customize the prompt
17+
#[serde(skip_serializing_if = "Option::is_none")]
18+
pub arguments: Option<Vec<PromptArgument>>,
1719
}
1820

1921
impl Prompt {
2022
/// Create a new prompt with the given name, description and arguments
21-
pub fn new<N, D>(name: N, description: D, arguments: Vec<PromptArgument>) -> Self
23+
pub fn new<N, D>(
24+
name: N,
25+
description: Option<D>,
26+
arguments: Option<Vec<PromptArgument>>,
27+
) -> Self
2228
where
2329
N: Into<String>,
2430
D: Into<String>,
2531
{
2632
Prompt {
2733
name: name.into(),
28-
description: description.into(),
34+
description: description.map(Into::into),
2935
arguments,
3036
}
3137
}
@@ -37,9 +43,11 @@ pub struct PromptArgument {
3743
/// The name of the argument
3844
pub name: String,
3945
/// A description of what the argument is used for
40-
pub description: String,
46+
#[serde(skip_serializing_if = "Option::is_none")]
47+
pub description: Option<String>,
4148
/// Whether this argument is required
42-
pub required: bool,
49+
#[serde(skip_serializing_if = "Option::is_none")]
50+
pub required: Option<bool>,
4351
}
4452

4553
/// Represents the role of a message sender in a prompt conversation
@@ -151,6 +159,6 @@ pub struct PromptTemplate {
151159
#[derive(Debug, Serialize, Deserialize)]
152160
pub struct PromptArgumentTemplate {
153161
pub name: String,
154-
pub description: String,
155-
pub required: bool,
162+
pub description: Option<String>,
163+
pub required: Option<bool>,
156164
}

crates/mcp-server/src/main.rs

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use anyhow::Result;
22
use mcp_core::content::Content;
3-
use mcp_core::handler::ResourceError;
3+
use mcp_core::handler::{PromptError, ResourceError};
4+
use mcp_core::prompt::{Prompt, PromptArgument};
45
use mcp_core::{handler::ToolError, protocol::ServerCapabilities, resource::Resource, tool::Tool};
56
use mcp_server::router::{CapabilitiesBuilder, RouterService};
67
use mcp_server::{ByteTransport, Router, Server};
@@ -61,6 +62,7 @@ impl Router for CounterRouter {
6162
CapabilitiesBuilder::new()
6263
.with_tools(false)
6364
.with_resources(false, false)
65+
.with_prompts(false)
6466
.build()
6567
}
6668

@@ -153,6 +155,37 @@ impl Router for CounterRouter {
153155
}
154156
})
155157
}
158+
159+
fn list_prompts(&self) -> Vec<Prompt> {
160+
vec![Prompt::new(
161+
"example_prompt",
162+
Some("This is an example prompt that takes one required agrument, message"),
163+
Some(vec![PromptArgument {
164+
name: "message".to_string(),
165+
description: Some("A message to put in the prompt".to_string()),
166+
required: Some(true),
167+
}]),
168+
)]
169+
}
170+
171+
fn get_prompt(
172+
&self,
173+
prompt_name: &str,
174+
) -> Pin<Box<dyn Future<Output = Result<String, PromptError>> + Send + 'static>> {
175+
let prompt_name = prompt_name.to_string();
176+
Box::pin(async move {
177+
match prompt_name.as_str() {
178+
"example_prompt" => {
179+
let prompt = "This is an example prompt with your message here: '{message}'";
180+
Ok(prompt.to_string())
181+
}
182+
_ => Err(PromptError::NotFound(format!(
183+
"Prompt {} not found",
184+
prompt_name
185+
))),
186+
}
187+
})
188+
}
156189
}
157190

158191
#[tokio::main]

crates/mcp-server/src/router.rs

Lines changed: 25 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -97,12 +97,8 @@ pub trait Router: Send + Sync + 'static {
9797
&self,
9898
uri: &str,
9999
) -> Pin<Box<dyn Future<Output = Result<String, ResourceError>> + Send + 'static>>;
100-
fn list_prompts(&self) -> Option<Vec<Prompt>> {
101-
None
102-
}
103-
fn get_prompt(&self, _prompt_name: &str) -> Option<PromptFuture> {
104-
None
105-
}
100+
fn list_prompts(&self) -> Vec<Prompt>;
101+
fn get_prompt(&self, prompt_name: &str) -> PromptFuture;
106102

107103
// Helper method to create base response
108104
fn create_response(&self, id: Option<u64>) -> JsonRpcResponse {
@@ -257,7 +253,7 @@ pub trait Router: Send + Sync + 'static {
257253
req: JsonRpcRequest,
258254
) -> impl Future<Output = Result<JsonRpcResponse, RouterError>> + Send {
259255
async move {
260-
let prompts = self.list_prompts().unwrap_or_default();
256+
let prompts = self.list_prompts();
261257

262258
let result = ListPromptsResult { prompts };
263259

@@ -294,36 +290,36 @@ pub trait Router: Send + Sync + 'static {
294290
.ok_or_else(|| RouterError::InvalidParams("Missing arguments object".into()))?;
295291

296292
// Fetch the prompt definition first
297-
let prompt = match self.list_prompts() {
298-
Some(prompts) => prompts
299-
.into_iter()
300-
.find(|p| p.name == prompt_name)
301-
.ok_or_else(|| {
302-
RouterError::PromptNotFound(format!("Prompt '{}' not found", prompt_name))
303-
})?,
304-
None => return Err(RouterError::PromptNotFound("No prompts available".into())),
305-
};
293+
let prompt = self
294+
.list_prompts()
295+
.into_iter()
296+
.find(|p| p.name == prompt_name)
297+
.ok_or_else(|| {
298+
RouterError::PromptNotFound(format!("Prompt '{}' not found", prompt_name))
299+
})?;
306300

307301
// Validate required arguments
308-
for arg in &prompt.arguments {
309-
if arg.required
310-
&& (!arguments.contains_key(&arg.name)
311-
|| arguments
312-
.get(&arg.name)
313-
.and_then(Value::as_str)
314-
.is_none_or(str::is_empty))
315-
{
316-
return Err(RouterError::InvalidParams(format!(
317-
"Missing required argument: '{}'",
318-
arg.name
319-
)));
302+
if let Some(args) = &prompt.arguments {
303+
for arg in args {
304+
if arg.required.is_some()
305+
&& arg.required.unwrap()
306+
&& (!arguments.contains_key(&arg.name)
307+
|| arguments
308+
.get(&arg.name)
309+
.and_then(Value::as_str)
310+
.is_none_or(str::is_empty))
311+
{
312+
return Err(RouterError::InvalidParams(format!(
313+
"Missing required argument: '{}'",
314+
arg.name
315+
)));
316+
}
320317
}
321318
}
322319

323320
// Now get the prompt content
324321
let description = self
325322
.get_prompt(prompt_name)
326-
.ok_or_else(|| RouterError::PromptNotFound("Prompt not found".into()))?
327323
.await
328324
.map_err(|e| RouterError::Internal(e.to_string()))?;
329325

0 commit comments

Comments
 (0)