Skip to content

Commit 4258a39

Browse files
author
zach
authored
refactor: use MCP protocol to access mcp.run via SSE/stdio (#4)
* refactor: only use MCP for mcp.run tool calling
1 parent 51dc4eb commit 4258a39

File tree

6 files changed

+523
-656
lines changed

6 files changed

+523
-656
lines changed

.python-version

Lines changed: 0 additions & 1 deletion
This file was deleted.

example.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,9 @@ async def run(agent, msg):
3131
Send a message to an agent and return the result
3232
"""
3333
global history
34-
async with agent.run_stream(msg, message_history=history) as result:
35-
return await result.get_data()
34+
async with agent.run_mcp_servers():
35+
async with agent.run_stream(msg, message_history=history) as result:
36+
return await result.get_output()
3637

3738

3839
types = ImageList | VowelCount

mcpx_pydantic_ai.py

Lines changed: 33 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,21 @@
44
from pydantic import BaseModel, Field
55
from pydantic_ai.models.openai import OpenAIModel
66
from pydantic_ai.providers.openai import OpenAIProvider
7-
from mcp_run.client import _convert_type
8-
9-
from typing import TypedDict, List, Set, AsyncIterator, Any
10-
import traceback
11-
12-
__all__ = ["BaseModel", "Field", "Agent", "mcp_run", "pydantic_ai", "pydantic"]
7+
from pydantic_ai.mcp import MCPServerHTTP, MCPServerStdio
8+
from datetime import timedelta
9+
from mcp_run import MCPClient, SSEClientConfig, StdioClientConfig
10+
11+
__all__ = [
12+
"BaseModel",
13+
"Field",
14+
"Agent",
15+
"mcp_run",
16+
"pydantic_ai",
17+
"pydantic",
18+
"MCPClient",
19+
"SSEClientConfig",
20+
"StdioClientConfig",
21+
]
1322

1423

1524
def openai_compatible_model(url: str, model: str, api_key: str | None = None):
@@ -26,103 +35,30 @@ class Agent(pydantic_ai.Agent):
2635
"""
2736

2837
client: mcp_run.Client
29-
ignore_tools: Set[str]
30-
_original_tools: list
31-
_registered_tools: List[str]
3238

3339
def __init__(
3440
self,
3541
*args,
3642
client: mcp_run.Client | None = None,
37-
ignore_tools: List[str] | None = None,
43+
mcp_client: MCPClient | None = None,
44+
expires_in: timedelta | None = None,
3845
**kw,
3946
):
4047
self.client = client or mcp_run.Client()
41-
self._original_tools = kw.get("tools", [])
42-
self._registered_tools = []
43-
self.ignore_tools = set(ignore_tools or [])
44-
super().__init__(*args, **kw)
45-
self._update_tools()
46-
47-
for t in self._original_tools:
48-
self._registered_tools.append(t.name)
49-
50-
def set_profile(self, profile: str):
51-
self.client.set_profile(profile)
52-
self._update_tools()
53-
54-
def register_tool(self, tool: mcp_run.Tool, f=None):
55-
if tool.name in self.ignore_tools:
56-
return
57-
58-
def wrap(tool, inner):
59-
if inner is not None:
60-
props = tool.input_schema["properties"]
61-
t = {k: _convert_type(v["type"]) for k, v in props.items()}
62-
InputType = TypedDict("Input", t)
63-
64-
def f(input: InputType):
65-
try:
66-
return inner(input)
67-
except Exception as exc:
68-
return f"ERROR call to tool {tool.name} failed: {traceback.format_exception(exc)}"
69-
70-
return f
71-
else:
72-
return self.client._make_pydantic_function(tool)
73-
74-
self._register_tool(
75-
pydantic_ai.Tool(
76-
wrap(tool, f),
77-
name=tool.name,
78-
description=tool.description,
79-
)
48+
mcp = mcp_client or self.client.mcp_sse(
49+
profile=self.client.config.profile, expires_in=expires_in
8050
)
81-
82-
if f is not None:
83-
self._registered_tools.append(tool.name)
84-
85-
def reset_tools(self):
86-
for k in list(self._function_tools.keys()):
87-
if k not in self._registered_tools:
88-
del self._function_tools[k]
89-
90-
def _update_tools(self):
91-
self.reset_tools()
92-
for tool in self.client.tools.values():
93-
self.register_tool(tool)
94-
95-
async def run(self, *args, update_tools: bool = True, **kw):
96-
if update_tools:
97-
self._update_tools()
98-
return await super().run(*args, **kw)
99-
100-
def run_sync(self, *args, update_tools: bool = True, **kw):
101-
if update_tools:
102-
self._update_tools()
103-
return super().run_sync(*args, **kw)
104-
105-
async def run_async(self, *args, update_tools: bool = True, **kw):
106-
if update_tools:
107-
self._update_tools()
108-
return await super().run_async(*args, **kw)
109-
110-
def run_stream(
111-
self,
112-
*args,
113-
update_tools: bool = True,
114-
**kw,
115-
) -> AsyncIterator[Any]:
116-
if update_tools:
117-
self._update_tools()
118-
return super().run_stream(*args, **kw)
119-
120-
def iter(
121-
self,
122-
*args,
123-
update_tools: bool = True,
124-
**kw,
125-
) -> AsyncIterator[Any]:
126-
if update_tools:
127-
self._update_tools()
128-
return super().iter(*args, **kw)
51+
mcp_servers = kw.get("mcp_servers", [])
52+
if mcp.is_sse:
53+
mcp_servers.append(MCPServerHTTP(url=mcp.config.url))
54+
elif mcp.is_stdio:
55+
mcp_servers.append(
56+
MCPServerStdio(
57+
command=mcp.config.command,
58+
args=mcp.config.args,
59+
env=mcp.config.env,
60+
cwd=mcp.config.cwd,
61+
)
62+
)
63+
kw["mcp_servers"] = mcp_servers
64+
super().__init__(*args, **kw)

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
[project]
22
name = "mcpx-pydantic-ai"
3-
version = "0.6.0"
3+
version = "0.7.0"
44
description = "Pydantic Agent with mcp.run tools"
55
readme = "README.md"
66
requires-python = ">=3.12"
77
dependencies = [
8-
"mcp-run>=0.4.0",
8+
"mcp-run>=0.5.0",
99
"pydantic>=2.10.4",
1010
"pydantic-ai>=0.0.35",
1111
]

tests/test_mcpx.py

Lines changed: 14 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import unittest
2-
from unittest.mock import Mock, patch
3-
from mcpx_pydantic_ai import Agent, _convert_type
2+
from unittest.mock import Mock
43
from typing import Dict, Any
54
import os
65

6+
7+
from mcpx_pydantic_ai import Agent
8+
79
os.environ["ANTHROPIC_API_KEY"] = "something"
810

911

@@ -35,33 +37,28 @@ def __init__(self):
3537
}
3638
self.called_tool = None
3739
self.called_params = None
40+
self.config = Mock(profile="default")
3841

3942
def call_tool(self, tool: str, params: Dict[str, Any]) -> MockResponse:
4043
self.called_tool = tool
4144
self.called_params = params
4245
return MockResponse("mock response")
4346

44-
def set_profile(self, profile: str):
45-
self.profile = profile
46-
4747
def _make_pydantic_function(self, tool):
4848
def test(input: dict):
4949
return self.call_tool(tool.name, input).content[0].text
50-
return test
5150

51+
return test
5252

53-
class TestTypeConversion(unittest.TestCase):
54-
def test_convert_basic_types(self):
55-
self.assertEqual(_convert_type("string"), str)
56-
self.assertEqual(_convert_type("boolean"), bool)
57-
self.assertEqual(_convert_type("number"), float)
58-
self.assertEqual(_convert_type("integer"), int)
59-
self.assertEqual(_convert_type("object"), dict)
60-
self.assertEqual(_convert_type("array"), list)
53+
def mcp_sse(self, profile=None, expires_in=None):
54+
mock_mcp = Mock()
55+
mock_mcp.is_sse = True
56+
mock_mcp.is_stdio = False
57+
mock_mcp.config = Mock(url="http://mock-url.com")
58+
return mock_mcp
6159

62-
def test_convert_invalid_type(self):
63-
with self.assertRaises(TypeError):
64-
_convert_type("invalid_type")
60+
def _fix_profile(self, profile):
61+
return profile
6562

6663

6764
class TestAgent(unittest.TestCase):
@@ -76,102 +73,6 @@ def setUp(self):
7673
def test_init_with_custom_client(self):
7774
"""Test agent initialization with custom client"""
7875
self.assertEqual(self.agent.client, self.mock_client)
79-
self.assertEqual(
80-
len(self.agent._function_tools), 1
81-
) # Should have our mock tool
82-
83-
def test_init_with_ignore_tools(self):
84-
"""Test agent initialization with ignored tools"""
85-
agent = Agent(
86-
model="claude-3-5-sonnet-latest",
87-
client=self.mock_client,
88-
ignore_tools=["test_tool"],
89-
system_prompt="test prompt",
90-
)
91-
self.assertEqual(
92-
len(agent._function_tools), 0
93-
) # Should have no tools due to ignore
94-
95-
def test_set_profile(self):
96-
"""Test setting profile updates client profile"""
97-
self.agent.set_profile("test_profile")
98-
self.assertEqual(self.mock_client.profile, "test_profile")
99-
100-
def test_register_custom_tool(self):
101-
"""Test registering a custom tool with custom function"""
102-
custom_mock = Mock(return_value="custom response")
103-
104-
self.agent.register_tool(
105-
MockTool(
106-
"custom_tool",
107-
"A custom tool",
108-
{"properties": {"param": {"type": "string"}}},
109-
),
110-
custom_mock,
111-
)
112-
113-
# Verify tool was registered
114-
self.assertIn("custom_tool", self.agent._function_tools)
115-
116-
# Test tool execution
117-
tool_func = self.agent._function_tools["custom_tool"].function
118-
result = tool_func({"param": "test"})
119-
120-
custom_mock.assert_called_once_with({"param": "test"})
121-
self.assertEqual(result, "custom response")
122-
123-
def test_tool_execution(self):
124-
"""Test executing a registered tool"""
125-
# Our mock tool should be registered automatically
126-
tool_func = self.agent._function_tools["test_tool"].function
127-
128-
result = tool_func({"param1": "test", "param2": 123})
129-
130-
self.assertEqual(self.mock_client.called_tool, "test_tool")
131-
self.assertEqual(
132-
self.mock_client.called_params, {"param1": "test", "param2": 123}
133-
)
134-
self.assertEqual(result, "mock response")
135-
136-
def test_reset_tools(self):
137-
"""Test resetting tools"""
138-
# Add a custom tool
139-
self.agent.register_tool(
140-
MockTool(
141-
"custom_tool",
142-
"A custom tool",
143-
{"properties": {"param": {"type": "string"}}},
144-
),
145-
Mock(),
146-
)
147-
148-
# Reset tools
149-
self.agent.reset_tools()
150-
151-
# Only custom tool should remain
152-
self.assertEqual(len(self.agent._function_tools), 1)
153-
self.assertIn("custom_tool", self.agent._function_tools)
154-
self.assertNotIn("test_tool", self.agent._function_tools)
155-
156-
@patch("mcpx_pydantic_ai.pydantic_ai.Agent.run_sync")
157-
def test_run_sync_updates_tools(self, mock_run_sync):
158-
"""Test that run_sync updates tools by default"""
159-
mock_run_sync.return_value = "test response"
160-
161-
result = self.agent.run_sync("test prompt")
162-
163-
self.assertEqual(result, "test response")
164-
mock_run_sync.assert_called_once()
165-
166-
@patch("mcpx_pydantic_ai.pydantic_ai.Agent.run_sync")
167-
def test_run_sync_without_tool_update(self, mock_run_sync):
168-
"""Test that run_sync can skip tool updates"""
169-
mock_run_sync.return_value = "test response"
170-
171-
result = self.agent.run_sync("test prompt", update_tools=False)
172-
173-
self.assertEqual(result, "test response")
174-
mock_run_sync.assert_called_once()
17576

17677

17778
if __name__ == "__main__":

0 commit comments

Comments
 (0)