|
| 1 | +import asyncio |
| 2 | +from contextlib import AbstractAsyncContextManager, AsyncExitStack |
| 3 | +from typing import Any, Literal |
| 4 | + |
| 5 | +from coagent.core import BaseAgent, Context, handler, logger, Message |
| 6 | +from coagent.core.exceptions import InternalError |
| 7 | +from mcp import ClientSession, Tool as MCPTool # ruff: noqa: F401 |
| 8 | +from mcp.client.sse import sse_client |
| 9 | +from mcp.client.stdio import stdio_client, StdioServerParameters |
| 10 | +from mcp.types import ( |
| 11 | + CallToolResult as MCPCallToolResult, |
| 12 | + ListToolsResult as MCPListToolsResult, |
| 13 | + ImageContent as MCPImageContent, # ruff: noqa: F401 |
| 14 | + TextContent as MCPTextContent, # ruff: noqa: F401 |
| 15 | +) |
| 16 | +from pydantic import BaseModel |
| 17 | + |
| 18 | + |
| 19 | +# An alias of `mcp.client.stdio.StdioServerParameters`. |
| 20 | +MCPServerStdioParams = StdioServerParameters |
| 21 | + |
| 22 | + |
| 23 | +class MCPServerSSEParams(BaseModel): |
| 24 | + """Core parameters in `mcp.client.sse.sse_client`.""" |
| 25 | + |
| 26 | + url: str |
| 27 | + """The URL of the server.""" |
| 28 | + |
| 29 | + headers: dict[str, str] | None = None |
| 30 | + """The headers to send to the server.""" |
| 31 | + |
| 32 | + |
| 33 | +class Connect(Message): |
| 34 | + """A message to connect to the server. |
| 35 | +
|
| 36 | + To close the server, send a `Cancel` message to close the connection |
| 37 | + and delete corresponding server agent. |
| 38 | + """ |
| 39 | + |
| 40 | + transport: Literal["sse", "stdio"] |
| 41 | + """The transport to use. |
| 42 | +
|
| 43 | + See https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/. |
| 44 | + """ |
| 45 | + |
| 46 | + params: MCPServerStdioParams | MCPServerSSEParams |
| 47 | + """The parameters to connect to the server.""" |
| 48 | + |
| 49 | + enable_cache: bool = True |
| 50 | + """Whether to cache the list result. Defaults to `True`. |
| 51 | + |
| 52 | + If `True`, the tools list will be cached and only fetched from the server |
| 53 | + once. If `False`, the tools list will be fetched from the server on each |
| 54 | + `ListTools` message. The cache can be invalidated by sending an |
| 55 | + `InvalidateCache` message. |
| 56 | + |
| 57 | + Only set this to `False` if you know the server will change its tools list, |
| 58 | + because it can drastically increase latency (by introducing a round-trip |
| 59 | + to the server every time). |
| 60 | + """ |
| 61 | + |
| 62 | + |
| 63 | +class InvalidateCache(Message): |
| 64 | + """A message to invalidate the cache of the list result.""" |
| 65 | + |
| 66 | + pass |
| 67 | + |
| 68 | + |
| 69 | +class ListTools(Message): |
| 70 | + """A message to list the tools available on the server.""" |
| 71 | + |
| 72 | + pass |
| 73 | + |
| 74 | + |
| 75 | +class ListToolsResult(Message, MCPListToolsResult): |
| 76 | + """The result of `ListTools`.""" |
| 77 | + |
| 78 | + pass |
| 79 | + |
| 80 | + |
| 81 | +class CallTool(Message): |
| 82 | + """A message to call a tool on the server.""" |
| 83 | + |
| 84 | + name: str |
| 85 | + """The name of the tool to call.""" |
| 86 | + |
| 87 | + arguments: dict[str, Any] | None = None |
| 88 | + """The arguments to pass to the tool.""" |
| 89 | + |
| 90 | + |
| 91 | +class CallToolResult(Message, MCPCallToolResult): |
| 92 | + """The result of `ListTools`.""" |
| 93 | + |
| 94 | + pass |
| 95 | + |
| 96 | + |
| 97 | +class MCPServer(BaseAgent): |
| 98 | + """An agent that acts as an MCP client to connect to an MCP server.""" |
| 99 | + |
| 100 | + def __init__(self, timeout: int = float("inf")) -> None: |
| 101 | + super().__init__(timeout=timeout) |
| 102 | + |
| 103 | + self._client_session: ClientSession | None = None |
| 104 | + self._exit_stack: AsyncExitStack = AsyncExitStack() |
| 105 | + |
| 106 | + self._list_tools_result_cache: ListToolsResult | None = None |
| 107 | + self._cache_enabled: bool = False |
| 108 | + self._cache_invalidated: bool = False |
| 109 | + |
| 110 | + # Ongoing tasks that need to be cancelled when the server is stopped. |
| 111 | + self._pending_tasks: set[asyncio.Task] = set() |
| 112 | + |
| 113 | + async def stopped(self) -> None: |
| 114 | + await self._cleanup() |
| 115 | + |
| 116 | + async def _handle_data(self) -> None: |
| 117 | + """Override the method to handle exceptions properly.""" |
| 118 | + try: |
| 119 | + await super()._handle_data() |
| 120 | + finally: |
| 121 | + # Ensure the resources are properly cleaned up. |
| 122 | + await self._cleanup() |
| 123 | + |
| 124 | + async def _handle_data_custom(self, msg: Message, ctx: Context) -> None: |
| 125 | + """Override to handle `ListTools` and `CallTool` messages concurrently.""" |
| 126 | + match msg: |
| 127 | + case ListTools() | CallTool(): |
| 128 | + task = asyncio.create_task(super()._handle_data_custom(msg, ctx)) |
| 129 | + self._pending_tasks.add(task) |
| 130 | + task.add_done_callback(self._pending_tasks.discard) |
| 131 | + case _: |
| 132 | + await super()._handle_data_custom(msg, ctx) |
| 133 | + |
| 134 | + @handler |
| 135 | + async def connect(self, msg: Connect, ctx: Context) -> None: |
| 136 | + """Connect to the server.""" |
| 137 | + if msg.transport == "sse": |
| 138 | + ctx_manager: AbstractAsyncContextManager = sse_client( |
| 139 | + **msg.params.model_dump() |
| 140 | + ) |
| 141 | + else: # "stdio": |
| 142 | + ctx_manager: AbstractAsyncContextManager = stdio_client(msg.params) |
| 143 | + |
| 144 | + try: |
| 145 | + transport = await self._exit_stack.enter_async_context(ctx_manager) |
| 146 | + read, write = transport |
| 147 | + session = await self._exit_stack.enter_async_context( |
| 148 | + ClientSession(read, write) |
| 149 | + ) |
| 150 | + await session.initialize() |
| 151 | + |
| 152 | + self._client_session = session |
| 153 | + self._cache_enabled = msg.enable_cache |
| 154 | + except Exception as exc: |
| 155 | + logger.error(f"Error initializing MCP server: {exc}") |
| 156 | + await self._cleanup() |
| 157 | + raise |
| 158 | + |
| 159 | + @handler |
| 160 | + async def invalidate_cache(self, msg: InvalidateCache, ctx: Context) -> None: |
| 161 | + self._cache_invalidated = True |
| 162 | + |
| 163 | + @handler |
| 164 | + async def list_tools(self, msg: ListTools, ctx: Context) -> ListToolsResult: |
| 165 | + if not self._client_session: |
| 166 | + raise InternalError( |
| 167 | + "Server not initialized. Make sure to send the `Connect` message first." |
| 168 | + ) |
| 169 | + |
| 170 | + # Return the cached result if the cache is enabled and not invalidated. |
| 171 | + if ( |
| 172 | + self._cache_enabled |
| 173 | + and not self._cache_invalidated |
| 174 | + and self._list_tools_result_cache |
| 175 | + ): |
| 176 | + return self._list_tools_result_cache |
| 177 | + |
| 178 | + # Reset the cache status. |
| 179 | + self._cache_invalidated = False |
| 180 | + |
| 181 | + result = await self._client_session.list_tools() |
| 182 | + self._list_tools_result_cache = ListToolsResult(**result.model_dump()) |
| 183 | + return self._list_tools_result_cache |
| 184 | + |
| 185 | + @handler |
| 186 | + async def call_tool(self, msg: CallTool, ctx: Context) -> CallToolResult: |
| 187 | + if not self._client_session: |
| 188 | + raise InternalError( |
| 189 | + "Server not initialized. Make sure to send the `Connect` message first." |
| 190 | + ) |
| 191 | + |
| 192 | + result = await self._client_session.call_tool(msg.name, arguments=msg.arguments) |
| 193 | + return CallToolResult(**result.model_dump()) |
| 194 | + |
| 195 | + async def _cleanup(self) -> None: |
| 196 | + """Cleanup the server.""" |
| 197 | + if self._pending_tasks: |
| 198 | + # Cancel all pending tasks. |
| 199 | + for task in self._pending_tasks: |
| 200 | + task.cancel() |
| 201 | + |
| 202 | + if not self._client_session: |
| 203 | + return |
| 204 | + |
| 205 | + try: |
| 206 | + await self._exit_stack.aclose() |
| 207 | + self._client_session = None |
| 208 | + except Exception as exc: |
| 209 | + logger.error(f"Error cleaning up server: {exc}") |
0 commit comments