diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py
index 48cd6590..b302c366 100644
--- a/vision_agent/agent/vision_agent.py
+++ b/vision_agent/agent/vision_agent.py
@@ -12,6 +12,7 @@
     EXAMPLES_CODE1,
     EXAMPLES_CODE2,
     EXAMPLES_CODE3,
+    EXAMPLES_CODE3_EXTRA2,
     VA_CODE,
 )
 from vision_agent.lmm import LMM, AnthropicLMM, Message, OpenAILMM
@@ -110,7 +111,7 @@ def run_conversation(orch: LMM, chat: List[Message]) -> Dict[str, Any]:
 
     prompt = VA_CODE.format(
         documentation=META_TOOL_DOCSTRING,
-        examples=f"{EXAMPLES_CODE1}\n{EXAMPLES_CODE2}\n{EXAMPLES_CODE3}",
+        examples=f"{EXAMPLES_CODE1}\n{EXAMPLES_CODE2}\n{EXAMPLES_CODE3}\n{EXAMPLES_CODE3_EXTRA2}",
         conversation=conversation,
     )
     message: Message = {"role": "user", "content": prompt}
@@ -182,10 +183,46 @@ def execute_user_code_action(
         )
         if user_result.error:
             user_obs += f"\n{user_result.error}"
-        extract_and_save_files_to_artifacts(artifacts, user_code_action, user_obs)
+        extract_and_save_files_to_artifacts(
+            artifacts, user_code_action, user_obs, user_result
+        )
     return user_result, user_obs
 
 
+def _add_media_obs(
+    code_action: str,
+    artifacts: Artifacts,
+    result: Execution,
+    obs: str,
+    code_interpreter: CodeInterpreter,
+    remote_artifacts_path: Path,
+    local_artifacts_path: Path,
+) -> Dict[str, Any]:
+    obs_chat_elt: Message = {"role": "observation", "content": obs}
+    media_obs = check_and_load_image(code_action)
+    if media_obs and result.success:
+        # for view_media_artifact, we need to ensure the media is loaded
+        # locally so the conversation agent can actually see it
+        code_interpreter.download_file(
+            str(remote_artifacts_path.name),
+            str(local_artifacts_path),
+        )
+        artifacts.load(
+            local_artifacts_path,
+            local_artifacts_path.parent,
+        )
+
+        # check if the media is actually in the artifacts
+        media_obs_chat = []
+        for media_ob in media_obs:
+            if media_ob in artifacts.artifacts:
+                media_obs_chat.append(local_artifacts_path.parent / media_ob)
+        if len(media_obs_chat) > 0:
+            obs_chat_elt["media"] = media_obs_chat
+
+    return obs_chat_elt
+
+
 def add_step_descriptions(response: Dict[str, Any]) -> Dict[str, Any]:
     response = copy.deepcopy(response)
 
@@ -544,35 +581,19 @@ def chat_with_artifacts(
                         code_interpreter,
                         str(remote_artifacts_path),
                     )
-
-                    media_obs = check_and_load_image(code_action)
+                    obs_chat_elt = _add_media_obs(
+                        code_action,
+                        artifacts,
+                        result,
+                        obs,
+                        code_interpreter,
+                        Path(remote_artifacts_path),
+                        Path(self.local_artifacts_path),
+                    )
 
                 if self.verbosity >= 1:
                     _LOGGER.info(obs)
 
-                obs_chat_elt: Message = {"role": "observation", "content": obs}
-                if media_obs and result.success:
-                    # for view_media_artifact, we need to ensure the media is loaded
-                    # locally so the conversation agent can actually see it
-                    code_interpreter.download_file(
-                        str(remote_artifacts_path.name),
-                        str(self.local_artifacts_path),
-                    )
-                    artifacts.load(
-                        self.local_artifacts_path,
-                        Path(self.local_artifacts_path).parent,
-                    )
-
-                    # check if the media is actually in the artifacts
-                    media_obs_chat = []
-                    for media_ob in media_obs:
-                        if media_ob in artifacts.artifacts:
-                            media_obs_chat.append(
-                                Path(self.local_artifacts_path).parent / media_ob
-                            )
-                    if len(media_obs_chat) > 0:
-                        obs_chat_elt["media"] = media_obs_chat
-
                 # don't add execution results to internal chat
                 int_chat.append(obs_chat_elt)
                 obs_chat_elt["execution"] = result