diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 1300ca2d..b54d08b8 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -30,6 +30,12 @@ if str(WORKSPACE) != "": os.environ["PYTHONPATH"] = f"{WORKSPACE}:{os.getenv('PYTHONPATH', '')}" +STUCK_IN_LOOP_ERROR_MESSAGE = { + "name": "Error when running conversation agent", + "value": "Agent is stuck in conversation loop, exited", + "traceback_raw": [], +} + class BoilerplateCode: pre_code = [ @@ -278,33 +284,9 @@ def chat_with_code( orig_chat.append({"role": "observation", "content": artifacts_loaded}) self.streaming_message({"role": "observation", "content": artifacts_loaded}) - if last_user_message["role"] == "user": - user_code_action = parse_execution( - cast(str, last_user_message.get("content", "")), False - ) - if user_code_action is not None: - user_result, user_obs = run_code_action( - user_code_action, code_interpreter, str(remote_artifacts_path) - ) - if self.verbosity >= 1: - _LOGGER.info(user_obs) - int_chat.append({"role": "observation", "content": user_obs}) - orig_chat.append( - { - "role": "observation", - "content": user_obs, - "execution": user_result, - } - ) - self.streaming_message( - { - "role": "observation", - "content": user_obs, - "execution": user_result, - "finished": True, - } - ) - finished = True + finished = self.execute_user_code_action( + last_user_message, code_interpreter, remote_artifacts_path + ) while not finished and iterations < self.max_iterations: response = run_conversation(self.agent, int_chat) @@ -320,11 +302,7 @@ def chat_with_code( { "role": "assistant", "content": "{}", - "error": { - "name": "Error when running conversation agent", - "value": "Agent is stuck in conversation loop, exited", - "traceback_raw": [], - }, + "error": STUCK_IN_LOOP_ERROR_MESSAGE, } ) @@ -339,11 +317,7 @@ def chat_with_code( { "role": "assistant", "content": "{}", - "error": { - "name": "Error when running conversation agent", - "value": "Agent is stuck in conversation loop, exited", - "traceback_raw": [], - }, + "error": STUCK_IN_LOOP_ERROR_MESSAGE, "finished": finished and code_action is None, } ) @@ -397,6 +371,34 @@ def chat_with_code( artifacts.save() return orig_chat, artifacts + def execute_user_code_action( + self, + last_user_message: Message, + code_interpreter: CodeInterpreter, + remote_artifacts_path: Path, + ) -> bool: + if last_user_message["role"] != "user": + return False + user_code_action = parse_execution( + cast(str, last_user_message.get("content", "")), False + ) + if user_code_action is not None: + user_result, user_obs = run_code_action( + user_code_action, code_interpreter, str(remote_artifacts_path) + ) + if self.verbosity >= 1: + _LOGGER.info(user_obs) + self.streaming_message( + { + "role": "observation", + "content": user_obs, + "execution": user_result, + "finished": True, + } + ) + return True + return False + def streaming_message(self, message: Dict[str, Any]) -> None: if self.callback_message: self.callback_message(message)