From 2126b06714f1048acdc2288746ef1589631ae3e3 Mon Sep 17 00:00:00 2001 From: wuyiqunLu <132986242+wuyiqunLu@users.noreply.github.com> Date: Sun, 22 Sep 2024 12:20:54 +0800 Subject: [PATCH] feat: when code edited by code executed directly (#242) * feat: when code edited by code executed directly * fix lint * fix lint * address comment --- vision_agent/agent/vision_agent.py | 27 +++++++++++++++++++++++++++ vision_agent/tools/meta_tools.py | 1 + 2 files changed, 28 insertions(+) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 1e1abbe6..62682524 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -218,6 +218,7 @@ def chat_with_code( ) as code_interpreter: orig_chat = copy.deepcopy(chat) int_chat = copy.deepcopy(chat) + last_user_message_content = chat[-1].get("content") media_list = [] for chat_i in int_chat: if "media" in chat_i: @@ -266,6 +267,32 @@ def chat_with_code( orig_chat.append({"role": "observation", "content": artifacts_loaded}) self.streaming_message({"role": "observation", "content": artifacts_loaded}) + if isinstance(last_user_message_content, str): + user_code_action = parse_execution(last_user_message_content, False) + if user_code_action is not None: + user_result, user_obs = run_code_action( + user_code_action, code_interpreter, str(remote_artifacts_path) + ) + if self.verbosity >= 1: + _LOGGER.info(user_obs) + int_chat.append({"role": "observation", "content": user_obs}) + orig_chat.append( + { + "role": "observation", + "content": user_obs, + "execution": user_result, + } + ) + self.streaming_message( + { + "role": "observation", + "content": user_obs, + "execution": user_result, + "finished": True, + } + ) + finished = True + while not finished and iterations < self.max_iterations: response = run_conversation(self.agent, int_chat) if self.verbosity >= 1: diff --git a/vision_agent/tools/meta_tools.py b/vision_agent/tools/meta_tools.py index 9012e9d4..0d20cb28 100644 --- a/vision_agent/tools/meta_tools.py +++ b/vision_agent/tools/meta_tools.py @@ -250,6 +250,7 @@ def edit_code_artifact( total_lines = len(artifacts[name].splitlines()) if start < 0 or end < 0 or start > end or end > total_lines: + print("[Invalid line range]") return "[Invalid line range]" if start == end: end += 1