diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index c64390d5..b54d08b8 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -30,6 +30,12 @@ if str(WORKSPACE) != "": os.environ["PYTHONPATH"] = f"{WORKSPACE}:{os.getenv('PYTHONPATH', '')}" +STUCK_IN_LOOP_ERROR_MESSAGE = { + "name": "Error when running conversation agent", + "value": "Agent is stuck in conversation loop, exited", + "traceback_raw": [], +} + class BoilerplateCode: pre_code = [ @@ -229,7 +235,7 @@ def chat_with_code( ) as code_interpreter: orig_chat = copy.deepcopy(chat) int_chat = copy.deepcopy(chat) - last_user_message_content = chat[-1].get("content") + last_user_message = chat[-1] media_list = [] for chat_i in int_chat: if "media" in chat_i: @@ -278,32 +284,9 @@ def chat_with_code( orig_chat.append({"role": "observation", "content": artifacts_loaded}) self.streaming_message({"role": "observation", "content": artifacts_loaded}) - if int_chat[-1]["role"] == "user": - last_user_message_content = cast(str, int_chat[-1].get("content", "")) - user_code_action = parse_execution(last_user_message_content, False) - if user_code_action is not None: - user_result, user_obs = run_code_action( - user_code_action, code_interpreter, str(remote_artifacts_path) - ) - if self.verbosity >= 1: - _LOGGER.info(user_obs) - int_chat.append({"role": "observation", "content": user_obs}) - orig_chat.append( - { - "role": "observation", - "content": user_obs, - "execution": user_result, - } - ) - self.streaming_message( - { - "role": "observation", - "content": user_obs, - "execution": user_result, - "finished": True, - } - ) - finished = True + finished = self.execute_user_code_action( + last_user_message, code_interpreter, remote_artifacts_path + ) while not finished and iterations < self.max_iterations: response = run_conversation(self.agent, int_chat) @@ -316,10 +299,12 @@ def chat_with_code( if last_response == response: response["let_user_respond"] = True self.streaming_message( - {"role": "assistant", "error": "Stuck in loop"} + { + "role": "assistant", + "content": "{}", + "error": STUCK_IN_LOOP_ERROR_MESSAGE, + } ) - else: - self.streaming_message({"role": "assistant", "content": response}) finished = response["let_user_respond"] @@ -327,6 +312,24 @@ def chat_with_code( response["response"], test_multi_plan, customized_tool_names ) + if last_response == response: + self.streaming_message( + { + "role": "assistant", + "content": "{}", + "error": STUCK_IN_LOOP_ERROR_MESSAGE, + "finished": finished and code_action is None, + } + ) + else: + self.streaming_message( + { + "role": "assistant", + "content": response, + "finished": finished and code_action is None, + } + ) + if code_action is not None: result, obs = run_code_action( code_action, code_interpreter, str(remote_artifacts_path) @@ -353,6 +356,7 @@ def chat_with_code( "role": "observation", "content": obs, "execution": result, + "finished": finished, } ) @@ -367,6 +371,34 @@ def chat_with_code( artifacts.save() return orig_chat, artifacts + def execute_user_code_action( + self, + last_user_message: Message, + code_interpreter: CodeInterpreter, + remote_artifacts_path: Path, + ) -> bool: + if last_user_message["role"] != "user": + return False + user_code_action = parse_execution( + cast(str, last_user_message.get("content", "")), False + ) + if user_code_action is not None: + user_result, user_obs = run_code_action( + user_code_action, code_interpreter, str(remote_artifacts_path) + ) + if self.verbosity >= 1: + _LOGGER.info(user_obs) + self.streaming_message( + { + "role": "observation", + "content": user_obs, + "execution": user_result, + "finished": True, + } + ) + return True + return False + def streaming_message(self, message: Dict[str, Any]) -> None: if self.callback_message: self.callback_message(message) diff --git a/vision_agent/tools/meta_tools.py b/vision_agent/tools/meta_tools.py index 52d732f7..8c32e7fc 100644 --- a/vision_agent/tools/meta_tools.py +++ b/vision_agent/tools/meta_tools.py @@ -425,6 +425,7 @@ def detect_dogs(image_path: str): agent = va.agent.VisionAgentCoder() if name not in artifacts: + print(f"[Artifact {name} does not exist]") return f"[Artifact {name} does not exist]" code = artifacts[name]