Skip to content

Commit 814b80d

Browse files
committed
Add MCPServer
1 parent 75bdabc commit 814b80d

File tree

5 files changed

+352
-4
lines changed

5 files changed

+352
-4
lines changed

coagent/agents/mcp_server.py

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
import asyncio
2+
from contextlib import AbstractAsyncContextManager, AsyncExitStack
3+
from typing import Any, Literal
4+
5+
from coagent.core import BaseAgent, Context, handler, logger, Message
6+
from coagent.core.exceptions import InternalError
7+
from mcp import ClientSession, Tool as MCPTool # ruff: noqa: F401
8+
from mcp.client.sse import sse_client
9+
from mcp.client.stdio import stdio_client, StdioServerParameters
10+
from mcp.types import (
11+
CallToolResult as MCPCallToolResult,
12+
ListToolsResult as MCPListToolsResult,
13+
ImageContent as MCPImageContent, # ruff: noqa: F401
14+
TextContent as MCPTextContent, # ruff: noqa: F401
15+
)
16+
from pydantic import BaseModel
17+
18+
19+
# An alias of `mcp.client.stdio.StdioServerParameters`.
20+
MCPServerStdioParams = StdioServerParameters
21+
22+
23+
class MCPServerSSEParams(BaseModel):
24+
"""Core parameters in `mcp.client.sse.sse_client`."""
25+
26+
url: str
27+
"""The URL of the server."""
28+
29+
headers: dict[str, str] | None = None
30+
"""The headers to send to the server."""
31+
32+
33+
class Connect(Message):
34+
"""A message to connect to the server.
35+
36+
To close the server, send a `Cancel` message to close the connection
37+
and delete corresponding server agent.
38+
"""
39+
40+
transport: Literal["sse", "stdio"]
41+
"""The transport to use.
42+
43+
See https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/.
44+
"""
45+
46+
params: MCPServerStdioParams | MCPServerSSEParams
47+
"""The parameters to connect to the server."""
48+
49+
enable_cache: bool = True
50+
"""Whether to cache the list result. Defaults to `True`.
51+
52+
If `True`, the tools list will be cached and only fetched from the server
53+
once. If `False`, the tools list will be fetched from the server on each
54+
`ListTools` message. The cache can be invalidated by sending an
55+
`InvalidateCache` message.
56+
57+
Only set this to `False` if you know the server will change its tools list,
58+
because it can drastically increase latency (by introducing a round-trip
59+
to the server every time).
60+
"""
61+
62+
63+
class InvalidateCache(Message):
64+
"""A message to invalidate the cache of the list result."""
65+
66+
pass
67+
68+
69+
class ListTools(Message):
70+
"""A message to list the tools available on the server."""
71+
72+
pass
73+
74+
75+
class ListToolsResult(Message, MCPListToolsResult):
76+
"""The result of `ListTools`."""
77+
78+
pass
79+
80+
81+
class CallTool(Message):
82+
"""A message to call a tool on the server."""
83+
84+
name: str
85+
"""The name of the tool to call."""
86+
87+
arguments: dict[str, Any] | None = None
88+
"""The arguments to pass to the tool."""
89+
90+
91+
class CallToolResult(Message, MCPCallToolResult):
92+
"""The result of `ListTools`."""
93+
94+
pass
95+
96+
97+
class MCPServer(BaseAgent):
98+
"""An agent that acts as an MCP client to connect to an MCP server."""
99+
100+
def __init__(self, timeout: int = float("inf")) -> None:
101+
super().__init__(timeout=timeout)
102+
103+
self._client_session: ClientSession | None = None
104+
self._exit_stack: AsyncExitStack = AsyncExitStack()
105+
106+
self._list_tools_result_cache: ListToolsResult | None = None
107+
self._cache_enabled: bool = False
108+
self._cache_invalidated: bool = False
109+
110+
# Ongoing tasks that need to be cancelled when the server is stopped.
111+
self._pending_tasks: set[asyncio.Task] = set()
112+
113+
async def stopped(self) -> None:
114+
await self._cleanup()
115+
116+
async def _handle_data(self) -> None:
117+
"""Override the method to handle exceptions properly."""
118+
try:
119+
await super()._handle_data()
120+
finally:
121+
# Ensure the resources are properly cleaned up.
122+
await self._cleanup()
123+
124+
async def _handle_data_custom(self, msg: Message, ctx: Context) -> None:
125+
"""Override to handle `ListTools` and `CallTool` messages concurrently."""
126+
match msg:
127+
case ListTools() | CallTool():
128+
task = asyncio.create_task(super()._handle_data_custom(msg, ctx))
129+
self._pending_tasks.add(task)
130+
task.add_done_callback(self._pending_tasks.discard)
131+
case _:
132+
await super()._handle_data_custom(msg, ctx)
133+
134+
@handler
135+
async def connect(self, msg: Connect, ctx: Context) -> None:
136+
"""Connect to the server."""
137+
if msg.transport == "sse":
138+
ctx_manager: AbstractAsyncContextManager = sse_client(
139+
**msg.params.model_dump()
140+
)
141+
else: # "stdio":
142+
ctx_manager: AbstractAsyncContextManager = stdio_client(msg.params)
143+
144+
try:
145+
transport = await self._exit_stack.enter_async_context(ctx_manager)
146+
read, write = transport
147+
session = await self._exit_stack.enter_async_context(
148+
ClientSession(read, write)
149+
)
150+
await session.initialize()
151+
152+
self._client_session = session
153+
self._cache_enabled = msg.enable_cache
154+
except Exception as exc:
155+
logger.error(f"Error initializing MCP server: {exc}")
156+
await self._cleanup()
157+
raise
158+
159+
@handler
160+
async def invalidate_cache(self, msg: InvalidateCache, ctx: Context) -> None:
161+
self._cache_invalidated = True
162+
163+
@handler
164+
async def list_tools(self, msg: ListTools, ctx: Context) -> ListToolsResult:
165+
if not self._client_session:
166+
raise InternalError(
167+
"Server not initialized. Make sure to send the `Connect` message first."
168+
)
169+
170+
# Return the cached result if the cache is enabled and not invalidated.
171+
if (
172+
self._cache_enabled
173+
and not self._cache_invalidated
174+
and self._list_tools_result_cache
175+
):
176+
return self._list_tools_result_cache
177+
178+
# Reset the cache status.
179+
self._cache_invalidated = False
180+
181+
result = await self._client_session.list_tools()
182+
self._list_tools_result_cache = ListToolsResult(**result.model_dump())
183+
return self._list_tools_result_cache
184+
185+
@handler
186+
async def call_tool(self, msg: CallTool, ctx: Context) -> CallToolResult:
187+
if not self._client_session:
188+
raise InternalError(
189+
"Server not initialized. Make sure to send the `Connect` message first."
190+
)
191+
192+
result = await self._client_session.call_tool(msg.name, arguments=msg.arguments)
193+
return CallToolResult(**result.model_dump())
194+
195+
async def _cleanup(self) -> None:
196+
"""Cleanup the server."""
197+
if self._pending_tasks:
198+
# Cancel all pending tasks.
199+
for task in self._pending_tasks:
200+
task.cancel()
201+
202+
if not self._client_session:
203+
return
204+
205+
try:
206+
await self._exit_stack.aclose()
207+
self._client_session = None
208+
except Exception as exc:
209+
logger.error(f"Error cleaning up server: {exc}")
File renamed without changes.

tests/agents/test_mcp_agent.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ class TestMCPAgent:
1010
@pytest.mark.skipif(sys.platform == "win32", reason="Does not run on Windows.")
1111
@pytest.mark.asyncio
1212
async def test_get_prompt(self):
13-
agent = MCPAgent(mcp_server_base_url="python tests/agents/mcp_server.py")
13+
agent = MCPAgent(
14+
mcp_server_base_url="python tests/agents/example_mcp_server.py"
15+
)
1416
await agent.started()
1517

1618
# String
@@ -34,7 +36,9 @@ async def test_get_prompt(self):
3436
@pytest.mark.skipif(sys.platform == "win32", reason="Does not run on Windows.")
3537
@pytest.mark.asyncio
3638
async def test_get_tools(self):
37-
agent = MCPAgent(mcp_server_base_url="python tests/agents/mcp_server.py")
39+
agent = MCPAgent(
40+
mcp_server_base_url="python tests/agents/example_mcp_server.py"
41+
)
3842
await agent.started()
3943

4044
tools = await agent._get_tools(None)
@@ -101,7 +105,7 @@ async def test_get_tools(self):
101105
async def test_get_tools_with_selection(self):
102106
selected_tools = ["query_weather"]
103107
agent = MCPAgent(
104-
mcp_server_base_url="python tests/agents/mcp_server.py",
108+
mcp_server_base_url="python tests/agents/example_mcp_server.py",
105109
tools=selected_tools,
106110
)
107111
await agent.started()

tests/agents/test_mcp_server.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
import sys
2+
3+
import pytest
4+
5+
from coagent.agents.mcp_server import (
6+
CallTool,
7+
Connect,
8+
ListTools,
9+
MCPServer,
10+
MCPServerStdioParams,
11+
)
12+
from coagent.core import Context
13+
14+
15+
class TestMCPServer:
16+
@pytest.mark.skipif(sys.platform == "win32", reason="Does not run on Windows.")
17+
@pytest.mark.asyncio
18+
async def test_connect(self):
19+
agent = MCPServer()
20+
ctx = Context()
21+
22+
# Connect successfully
23+
await agent.connect(
24+
Connect(
25+
transport="stdio",
26+
params=MCPServerStdioParams(
27+
command="python",
28+
args=["tests/agents/example_mcp_server.py"],
29+
),
30+
),
31+
ctx,
32+
)
33+
34+
# Connect error
35+
with pytest.raises(Exception) as exc:
36+
await agent.connect(
37+
Connect(
38+
transport="stdio",
39+
params=MCPServerStdioParams(
40+
command="pythonx",
41+
args=["tests/agents/example_mcp_server.py"],
42+
),
43+
),
44+
ctx,
45+
)
46+
assert str(exc.value).startswith(
47+
"[Errno 2] No such file or directory: 'pythonx'"
48+
)
49+
50+
@pytest.mark.skipif(sys.platform == "win32", reason="Does not run on Windows.")
51+
@pytest.mark.asyncio
52+
async def test_list_tools(self):
53+
agent = MCPServer()
54+
ctx = Context()
55+
56+
# Connect to the server
57+
await agent.connect(
58+
Connect(
59+
transport="stdio",
60+
params=MCPServerStdioParams(
61+
command="python",
62+
args=["tests/agents/example_mcp_server.py"],
63+
),
64+
),
65+
ctx,
66+
)
67+
68+
result = await agent.list_tools(ListTools(), ctx)
69+
assert len(result.tools) == 2
70+
71+
# Validate tool query_weather
72+
tool = result.tools[0]
73+
assert tool.name == "query_weather"
74+
assert tool.description == "Query the weather in the given city."
75+
assert tool.inputSchema == {
76+
"properties": {
77+
"city": {
78+
"title": "City",
79+
"type": "string",
80+
}
81+
},
82+
"required": ["city"],
83+
"title": "query_weatherArguments",
84+
"type": "object",
85+
}
86+
87+
# Validate tool book_flight
88+
tool = result.tools[1]
89+
assert tool.name == "book_flight"
90+
assert tool.description == "Book a flight from departure to arrival."
91+
assert tool.inputSchema == {
92+
"properties": {
93+
"arrival": {
94+
"title": "Arrival",
95+
"type": "string",
96+
},
97+
"departure": {
98+
"title": "Departure",
99+
"type": "string",
100+
},
101+
},
102+
"required": ["departure", "arrival"],
103+
"title": "book_flightArguments",
104+
"type": "object",
105+
}
106+
107+
await agent.stopped()
108+
109+
@pytest.mark.skipif(sys.platform == "win32", reason="Does not run on Windows.")
110+
@pytest.mark.asyncio
111+
async def test_call_tool(self):
112+
agent = MCPServer()
113+
ctx = Context()
114+
115+
# Connect to the server
116+
await agent.connect(
117+
Connect(
118+
transport="stdio",
119+
params=MCPServerStdioParams(
120+
command="python",
121+
args=["tests/agents/example_mcp_server.py"],
122+
),
123+
),
124+
ctx,
125+
)
126+
127+
result = await agent.call_tool(
128+
CallTool(
129+
name="query_weather",
130+
arguments={"city": "Beijing"},
131+
),
132+
ctx,
133+
)
134+
assert result.isError is False
135+
assert result.content[0].text == "The weather in Beijing is sunny."

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)