Skip to content

Commit 02ee6e4

Browse files
committed
add an optional code_interpreter
1 parent 393bb4e commit 02ee6e4

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

vision_agent/agent/vision_agent.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ def __init__(
197197
local_artifacts_path: Optional[Union[str, Path]] = None,
198198
code_sandbox_runtime: Optional[str] = None,
199199
callback_message: Optional[Callable[[Dict[str, Any]], None]] = None,
200+
code_interpreter: Optional[CodeInterpreter] = None,
200201
) -> None:
201202
"""Initialize the VisionAgent.
202203
@@ -207,12 +208,14 @@ def __init__(
207208
local_artifacts_path (Optional[Union[str, Path]]): The path to the local
208209
artifacts file.
209210
code_sandbox_runtime (Optional[str]): The code sandbox runtime to use.
211+
code_interpreter (Optional[CodeInterpreter]): if not None, use this CodeInterpreter
210212
"""
211213

212214
self.agent = AnthropicLMM(temperature=0.0) if agent is None else agent
213215
self.max_iterations = 12
214216
self.verbosity = verbosity
215217
self.code_sandbox_runtime = code_sandbox_runtime
218+
self.code_interpreter = code_interpreter
216219
self.callback_message = callback_message
217220
if self.verbosity >= 1:
218221
_LOGGER.setLevel(logging.INFO)
@@ -284,9 +287,14 @@ def chat_with_code(
284287
# this is setting remote artifacts path
285288
artifacts = Artifacts(WORKSPACE / "artifacts.pkl")
286289

287-
with CodeInterpreterFactory.new_instance(
288-
code_sandbox_runtime=self.code_sandbox_runtime,
289-
) as code_interpreter:
290+
code_interpreter = (
291+
self.code_interpreter
292+
if self.code_interpreter is not None
293+
else CodeInterpreterFactory.new_instance(
294+
code_sandbox_runtime=self.code_sandbox_runtime,
295+
)
296+
)
297+
with code_interpreter:
290298
orig_chat = copy.deepcopy(chat)
291299
int_chat = copy.deepcopy(chat)
292300
last_user_message = chat[-1]

0 commit comments

Comments
 (0)