55
66from typing import TypedDict , List , Set , AsyncIterator , Any
77import traceback
8- from contextlib import asynccontextmanager
98
109__all__ = ["BaseModel" , "Field" , "Agent" , "mcp_run" , "pydantic_ai" , "pydantic" ]
1110
@@ -34,6 +33,7 @@ class Agent(pydantic_ai.Agent):
3433 client : mcp_run .Client
3534 ignore_tools : Set [str ]
3635 _original_tools : list
36+ _registered_tools : List [str ]
3737
3838 def __init__ (
3939 self ,
@@ -44,10 +44,14 @@ def __init__(
4444 ):
4545 self .client = client or mcp_run .Client ()
4646 self ._original_tools = kw .get ("tools" , [])
47+ self ._registered_tools = []
4748 self .ignore_tools = set (ignore_tools or [])
4849 super ().__init__ (* args , ** kw )
4950 self ._update_tools ()
5051
52+ for t in self ._original_tools :
53+ self ._registered_tools .append (t .name )
54+
5155 def set_profile (self , profile : str ):
5256 self .client .set_profile (profile )
5357 self ._update_tools ()
@@ -79,6 +83,8 @@ def f(input: InputType):
7983
8084 return f
8185
86+ self ._registered_tools .append (tool .name )
87+
8288 self ._register_tool (
8389 pydantic_ai .Tool (
8490 wrap (tool , f ),
@@ -89,38 +95,46 @@ def f(input: InputType):
8995
9096 def reset_tools (self ):
9197 self ._function_tools = {}
92- for t in self ._original_tools .copy ():
93- self ._register_tool (t )
98+ for k in self ._function_tools .keys ():
99+ if k not in self ._registered_tools :
100+ del self ._function_tools [k ]
94101
95102 def _update_tools (self ):
96103 self .reset_tools ()
97104 for tool in self .client .tools .values ():
98105 self .register_tool (tool )
99106
100- async def run (self , * args , ** kw ):
101- self ._update_tools ()
107+ async def run (self , * args , update_tools : bool = True , ** kw ):
108+ if update_tools :
109+ self ._update_tools ()
102110 return await super ().run (* args , ** kw )
103111
104- def run_sync (self , * args , ** kw ):
105- self ._update_tools ()
112+ def run_sync (self , * args , update_tools : bool = True , ** kw ):
113+ if update_tools :
114+ self ._update_tools ()
106115 return super ().run_sync (* args , ** kw )
107116
108- async def run_async (self , * args , ** kw ):
109- self ._update_tools ()
117+ async def run_async (self , * args , update_tools : bool = True , ** kw ):
118+ if update_tools :
119+ self ._update_tools ()
110120 return await super ().run_async (* args , ** kw )
111121
112122 def run_stream (
113123 self ,
114124 * args ,
125+ update_tools : bool = True ,
115126 ** kw ,
116127 ) -> AsyncIterator [Any ]:
117- self ._update_tools ()
128+ if update_tools :
129+ self ._update_tools ()
118130 return super ().run_stream (* args , ** kw )
119131
120132 def iter (
121133 self ,
122134 * args ,
135+ update_tools : bool = True ,
123136 ** kw ,
124137 ) -> AsyncIterator [Any ]:
125- self ._update_tools ()
138+ if update_tools :
139+ self ._update_tools ()
126140 return super ().iter (* args , ** kw )
0 commit comments