From dea87566965859603b6447fe44afb02b69da465b Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Wed, 16 Oct 2024 09:13:39 -0700 Subject: [PATCH] reduced code complexity --- vision_agent/agent/vision_agent.py | 127 +++++++++++------------------ vision_agent/tools/meta_tools.py | 6 +- 2 files changed, 51 insertions(+), 82 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 42204190..8cec02db 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -2,6 +2,7 @@ import json import logging import os +import pickle as pkl import tempfile from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast @@ -122,7 +123,9 @@ def run_conversation(orch: LMM, chat: List[Message]) -> Dict[str, Any]: and "media" in chat[-1] and len(chat[-1]["media"]) > 0 # type: ignore ): - message["media"] = chat[-1]["media"] + media_obs = [media for media in chat[-1]["media"] if Path(media).exists()] # type: ignore + if len(media_obs) > 0: + message["media"] = media_obs # type: ignore conv_resp = cast(str, orch([message], stream=False)) # clean the response first, if we are executing code, do not resond or end @@ -146,10 +149,11 @@ def execute_code_action( artifacts: Artifacts, code: str, code_interpreter: CodeInterpreter, - artifact_remote_path: str, ) -> Tuple[Execution, str]: result = code_interpreter.exec_isolation( - BoilerplateCode.add_boilerplate(code, remote_path=artifact_remote_path) + BoilerplateCode.add_boilerplate( + code, remote_path=str(artifacts.remote_save_path) + ) ) obs = str(result.logs) @@ -163,7 +167,6 @@ def execute_user_code_action( artifacts: Artifacts, last_user_message: Message, code_interpreter: CodeInterpreter, - artifact_remote_path: str, ) -> Tuple[Optional[Execution], Optional[str]]: user_result = None user_obs = None @@ -180,50 +183,28 @@ def execute_user_code_action( if user_code_action is not None: user_code_action = 
use_extra_vision_agent_args(user_code_action, False) user_result, user_obs = execute_code_action( - artifacts, user_code_action, code_interpreter, artifact_remote_path + artifacts, user_code_action, code_interpreter ) if user_result.error: user_obs += f"\n{user_result.error}" - extract_and_save_files_to_artifacts( - artifacts, user_code_action, user_obs, user_result - ) return user_result, user_obs -def _add_media_obs( - code_action: str, - artifacts: Artifacts, - result: Execution, - obs: str, - code_interpreter: CodeInterpreter, - remote_artifacts_path: Path, - local_artifacts_path: Path, -) -> Dict[str, Any]: - obs_chat_elt: Message = {"role": "observation", "content": obs} - media_obs = check_and_load_image(code_action) - if media_obs and result.success: - # for view_media_artifact, we need to ensure the media is loaded locally so - # the conversation agent can actually see it. We also download it here so we - # can check if it contains the actual media (note this is in addition to - # downloading it per turn). 
+def download_and_merge_artifacts( + code_interpreter: CodeInterpreter, artifacts: Artifacts +) -> None: + with tempfile.NamedTemporaryFile() as temp_file: code_interpreter.download_file( - str(remote_artifacts_path.name), - str(local_artifacts_path), - ) - artifacts.load( - local_artifacts_path, - local_artifacts_path.parent, + str(artifacts.remote_save_path), + temp_file.name, ) - - # check if the media is actually in the artifacts - media_obs_chat = [] - for media_ob in media_obs: - if media_ob in artifacts.artifacts: - media_obs_chat.append(local_artifacts_path.parent / media_ob) - if len(media_obs_chat) > 0: - obs_chat_elt["media"] = media_obs_chat - - return obs_chat_elt + temp_file.seek(0) + with open(temp_file.name, "rb") as f: # reopen by path + remote_artifacts = pkl.load(f) # NOTE: trusts sandbox output + merged_artifacts = {**artifacts.artifacts, **remote_artifacts} + artifacts.artifacts = merged_artifacts + artifacts.save() + artifacts.load(artifacts.local_save_path, artifacts.local_save_path.parent) def add_step_descriptions(response: Dict[str, Any]) -> Dict[str, Any]: @@ -354,21 +335,15 @@ def __init__( self.callback_message = callback_message if self.verbosity >= 1: _LOGGER.setLevel(logging.INFO) - self.local_artifacts_path = cast( - str, - ( - Path(local_artifacts_path) - if local_artifacts_path is not None - else Path(tempfile.NamedTemporaryFile(delete=False).name) - ), + self.local_artifacts_path = ( + Path(local_artifacts_path) + if local_artifacts_path is not None + else Path(tempfile.NamedTemporaryFile(delete=False).name) ) - self.remote_artifacts_path = cast( - str, - ( - Path(remote_artifacts_path) - if remote_artifacts_path is not None - else Path(WORKSPACE / "artifacts.pkl") - ), + self.remote_artifacts_path = ( + Path(remote_artifacts_path) + if remote_artifacts_path is not None + else Path(WORKSPACE / "artifacts.pkl") ) def __call__( @@ -455,8 +430,15 @@ def chat_with_artifacts( and not isinstance(self.code_interpreter, str) else CodeInterpreterFactory.new_instance( 
code_sandbox_runtime=self.code_interpreter, + remote_path=self.remote_artifacts_path.parent, ) ) + + if code_interpreter.remote_path != self.remote_artifacts_path.parent: + raise ValueError( + f"Code interpreter remote path {code_interpreter.remote_path} does not match {self.remote_artifacts_path.parent}" + ) + with code_interpreter: orig_chat = copy.deepcopy(chat) int_chat = copy.deepcopy(chat) @@ -501,9 +483,7 @@ def chat_with_artifacts( # Upload artifacts to remote location and show where they are going # to be loaded to. The actual loading happens in BoilerplateCode as # part of the pre_code. - remote_artifacts_path = code_interpreter.upload_file( - self.local_artifacts_path - ) + code_interpreter.upload_file(self.local_artifacts_path) artifacts_loaded = artifacts.show(code_interpreter.remote_path) int_chat.append({"role": "observation", "content": artifacts_loaded}) orig_chat.append({"role": "observation", "content": artifacts_loaded}) @@ -513,7 +493,6 @@ def chat_with_artifacts( artifacts, last_user_message, code_interpreter, - str(remote_artifacts_path), ) finished = user_result is not None and user_obs is not None if user_result is not None and user_obs is not None: @@ -537,6 +516,11 @@ def chat_with_artifacts( code_interpreter.upload_file(self.local_artifacts_path) response = run_conversation(self.agent, int_chat) + code_action = use_extra_vision_agent_args( + response.get("execute_python", None), + test_multi_plan, + custom_tool_names, + ) if self.verbosity >= 1: _LOGGER.info(response) int_chat.append( @@ -562,12 +546,6 @@ def chat_with_artifacts( finished = response.get("let_user_respond", False) - code_action = response.get("execute_python", None) - if code_action is not None: - code_action = use_extra_vision_agent_args( - code_action, test_multi_plan, custom_tool_names - ) - if last_response == response: self.streaming_message( { @@ -597,17 +575,11 @@ def chat_with_artifacts( artifacts, code_action, code_interpreter, - str(remote_artifacts_path), - ) - 
obs_chat_elt = _add_media_obs( - code_action, - artifacts, - result, - obs, - code_interpreter, - Path(remote_artifacts_path), - Path(self.local_artifacts_path), ) + obs_chat_elt: Message = {"role": "observation", "content": obs} + media_obs = check_and_load_image(code_action) + if media_obs and result.success: + obs_chat_elt["media"] = media_obs if self.verbosity >= 1: _LOGGER.info(obs) @@ -629,12 +601,7 @@ def chat_with_artifacts( last_response = response # after each turn, download the artifacts locally - code_interpreter.download_file( - str(remote_artifacts_path.name), str(self.local_artifacts_path) - ) - artifacts.load( - self.local_artifacts_path, Path(self.local_artifacts_path).parent - ) + download_and_merge_artifacts(code_interpreter, artifacts) return orig_chat, artifacts diff --git a/vision_agent/tools/meta_tools.py b/vision_agent/tools/meta_tools.py index ffbfc204..e8185daf 100644 --- a/vision_agent/tools/meta_tools.py +++ b/vision_agent/tools/meta_tools.py @@ -662,10 +662,10 @@ def get_diff_with_prompts(name: str, before: str, after: str) -> str: def use_extra_vision_agent_args( - code: str, + code: Optional[str], test_multi_plan: bool = True, custom_tool_names: Optional[List[str]] = None, -) -> str: +) -> Optional[str]: """This is for forcing arguments passed by the user to VisionAgent into the VisionAgentCoder call. @@ -677,6 +677,8 @@ def use_extra_vision_agent_args( Returns: str: The edited code. """ + if code is None: + return None red = RedBaron(code) for node in red: # seems to always be atomtrailers not call type