From 6ca3a56219f16386dedb50773fa5f4964dc6af32 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Wed, 25 Sep 2024 20:58:48 -0700 Subject: [PATCH] fixed user exec obs --- vision_agent/agent/vision_agent.py | 85 +++++++++++++++++------------- 1 file changed, 47 insertions(+), 38 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index bf35e5e9..14dec56b 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -87,7 +87,7 @@ def run_conversation(orch: LMM, chat: List[Message]) -> Dict[str, Any]: return extract_json(orch([message], stream=False)) # type: ignore -def run_code_action( +def execute_code_action( code: str, code_interpreter: CodeInterpreter, artifact_remote_path: str ) -> Tuple[Execution, str]: result = code_interpreter.exec_isolation( @@ -115,10 +115,33 @@ def parse_execution( return code +def execute_user_code_action( + last_user_message: Message, + code_interpreter: CodeInterpreter, + artifact_remote_path: str, +) -> Tuple[Optional[Execution], Optional[str]]: + user_result = None + user_obs = None + + if last_user_message["role"] != "user": + return user_result, user_obs + + last_user_content = cast(str, last_user_message.get("content", "")) + + user_code_action = parse_execution(last_user_content, False) + if user_code_action is not None: + user_result, user_obs = execute_code_action( + user_code_action, code_interpreter, artifact_remote_path + ) + if user_result.error: + user_obs += f"\n{user_result.error}" + return user_result, user_obs + + class VisionAgent(Agent): """Vision Agent is an agent that can chat with the user and call tools or other agents to generate code for it. Vision Agent uses python code to execute actions - for the user. Vision Agent is inspired by by OpenDev + for the user. 
Vision Agent is inspired by OpenDevin https://github.com/OpenDevin/OpenDevin and CodeAct https://arxiv.org/abs/2402.01030 Example @@ -278,9 +301,23 @@ def chat_with_code( orig_chat.append({"role": "observation", "content": artifacts_loaded}) self.streaming_message({"role": "observation", "content": artifacts_loaded}) - finished = self.execute_user_code_action( - last_user_message, code_interpreter, remote_artifacts_path + user_result, user_obs = execute_user_code_action( + last_user_message, code_interpreter, str(remote_artifacts_path) ) + finished = user_result is not None and user_obs is not None + if user_result is not None and user_obs is not None: + chat_elt: Message = {"role": "observation", "content": user_obs} + int_chat.append(chat_elt) + chat_elt["execution"] = user_result + orig_chat.append(chat_elt) + self.streaming_message( + { + "role": "observation", + "content": user_obs, + "execution": user_result, + "finished": finished, + } + ) while not finished and iterations < self.max_iterations: response = run_conversation(self.agent, int_chat) @@ -322,7 +359,7 @@ def chat_with_code( ) if code_action is not None: - result, obs = run_code_action( + result, obs = execute_code_action( code_action, code_interpreter, str(remote_artifacts_path) ) @@ -331,17 +368,17 @@ def chat_with_code( if self.verbosity >= 1: _LOGGER.info(obs) - chat_elt: Message = {"role": "observation", "content": obs} + obs_chat_elt: Message = {"role": "observation", "content": obs} if media_obs and result.success: - chat_elt["media"] = [ + obs_chat_elt["media"] = [ Path(code_interpreter.remote_path) / media_ob for media_ob in media_obs ] # don't add execution results to internal chat - int_chat.append(chat_elt) - chat_elt["execution"] = result - orig_chat.append(chat_elt) + int_chat.append(obs_chat_elt) + obs_chat_elt["execution"] = result + orig_chat.append(obs_chat_elt) self.streaming_message( { "role": "observation", @@ -362,34 +399,6 @@ def chat_with_code( artifacts.save() return 
orig_chat, artifacts - def execute_user_code_action( - self, - last_user_message: Message, - code_interpreter: CodeInterpreter, - remote_artifacts_path: Path, - ) -> bool: - if last_user_message["role"] != "user": - return False - user_code_action = parse_execution( - cast(str, last_user_message.get("content", "")), False - ) - if user_code_action is not None: - user_result, user_obs = run_code_action( - user_code_action, code_interpreter, str(remote_artifacts_path) - ) - if self.verbosity >= 1: - _LOGGER.info(user_obs) - self.streaming_message( - { - "role": "observation", - "content": user_obs, - "execution": user_result, - "finished": True, - } - ) - return True - return False - def streaming_message(self, message: Dict[str, Any]) -> None: if self.callback_message: self.callback_message(message)