From daa48313b645f4b17006eda5ee29fb1f05660033 Mon Sep 17 00:00:00 2001 From: Asia <2736300+humpydonkey@users.noreply.github.com> Date: Wed, 5 Jun 2024 15:50:48 -0700 Subject: [PATCH] Fix missing typing module error and improve load_image() (#116) * Minor improvements * Updates * Update the final code * Update the printed code * Update the printed code * Update the printed code * Update return code --- vision_agent/agent/vision_agent.py | 110 +++++++++++++++++------------ vision_agent/tools/tools.py | 8 ++- 2 files changed, 69 insertions(+), 49 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 142cdf15..ed45b353 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -36,11 +36,25 @@ _LOGGER = logging.getLogger(__name__) _MAX_TABULATE_COL_WIDTH = 80 _CONSOLE = Console() -_DEFAULT_IMPORT = "\n".join(T.__new_tools__) + "\n".join( - [ + + +class DefaultImports: + """Container for default imports used in the code execution.""" + + common_imports = [ "from typing import *", ] -) + + @staticmethod + def to_code_string() -> str: + return "\n".join(DefaultImports.common_imports + T.__new_tools__) + + @staticmethod + def prepend_imports(code: str) -> str: + """Run this method to prepend the default imports to the code. + NOTE: be sure to run this method after the custom tools have been registered. 
+ """ + return DefaultImports.to_code_string() + "\n\n" + code def get_diff(before: str, after: str) -> str: @@ -202,18 +216,20 @@ def write_and_test_code( "type": "code", "status": "running", "payload": { - "code": code, + "code": DefaultImports.prepend_imports(code), "test": test, }, } ) - result = code_interpreter.exec_isolation(f"{_DEFAULT_IMPORT}\n{code}\n{test}") + result = code_interpreter.exec_isolation( + f"{DefaultImports.to_code_string()}\n{code}\n{test}" + ) log_progress( { "type": "code", "status": "completed" if result.success else "failed", "payload": { - "code": code, + "code": DefaultImports.prepend_imports(code), "test": test, "result": result.to_json(), }, @@ -264,19 +280,21 @@ def write_and_test_code( "type": "code", "status": "running", "payload": { - "code": code, + "code": DefaultImports.prepend_imports(code), "test": test, }, } ) - result = code_interpreter.exec_isolation(f"{_DEFAULT_IMPORT}\n{code}\n{test}") + result = code_interpreter.exec_isolation( + f"{DefaultImports.to_code_string()}\n{code}\n{test}" + ) log_progress( { "type": "code", "status": "completed" if result.success else "failed", "payload": { - "code": code, + "code": DefaultImports.prepend_imports(code), "test": test, "result": result.to_json(), }, @@ -307,7 +325,14 @@ def write_and_test_code( def _print_code(title: str, code: str, test: Optional[str] = None) -> None: _CONSOLE.print(title, style=Style(bgcolor="dark_orange3", bold=True)) _CONSOLE.print("=" * 30 + " Code " + "=" * 30) - _CONSOLE.print(Syntax(code, "python", theme="gruvbox-dark", line_numbers=True)) + _CONSOLE.print( + Syntax( + DefaultImports.prepend_imports(code), + "python", + theme="gruvbox-dark", + line_numbers=True, + ) + ) if test: _CONSOLE.print("=" * 30 + " Test " + "=" * 30) _CONSOLE.print(Syntax(test, "python", theme="gruvbox-dark", line_numbers=True)) @@ -464,10 +489,6 @@ def chat_with_workflow( if chat_i["role"] == "user": chat_i["content"] += f" Image name {media}" - # re-grab custom tools - 
global _DEFAULT_IMPORT - _DEFAULT_IMPORT = "\n".join(T.__new_tools__) - code = "" test = "" working_memory: List[Dict[str, str]] = [] @@ -531,38 +552,35 @@ def chat_with_workflow( working_memory.extend(results["working_memory"]) # type: ignore plan.append({"code": code, "test": test, "plan": plan_i}) - if self_reflection: - self.log_progress( - { - "type": "self_reflection", - "status": "started", - } - ) - reflection = reflect( - chat, - FULL_TASK.format( - user_request=chat[0]["content"], subtasks=plan_i_str - ), - code, - self.planner, - ) - if self.verbosity > 0: - _LOGGER.info(f"Reflection: {reflection}") - feedback = cast(str, reflection["feedback"]) - success = cast(bool, reflection["success"]) - self.log_progress( - { - "type": "self_reflection", - "status": "completed" if success else "failed", - "payload": reflection, - } - ) - working_memory.append( - {"code": f"{code}\n{test}", "feedback": feedback} - ) - else: + if not self_reflection: break + self.log_progress( + { + "type": "self_reflection", + "status": "started", + } + ) + reflection = reflect( + chat, + FULL_TASK.format( + user_request=chat[0]["content"], subtasks=plan_i_str + ), + code, + self.planner, + ) + if self.verbosity > 0: + _LOGGER.info(f"Reflection: {reflection}") + feedback = cast(str, reflection["feedback"]) + success = cast(bool, reflection["success"]) + self.log_progress( + { + "type": "self_reflection", + "status": "completed" if success else "failed", + "payload": reflection, + } + ) + working_memory.append({"code": f"{code}\n{test}", "feedback": feedback}) retries += 1 execution_result = cast(Execution, results["test_result"]) @@ -571,7 +589,7 @@ def chat_with_workflow( "type": "final_code", "status": "completed" if success else "failed", "payload": { - "code": code, + "code": DefaultImports.prepend_imports(code), "test": test, "result": execution_result.to_json(), }, @@ -586,7 +604,7 @@ def chat_with_workflow( play_video(res.mp4) return { - "code": code, + "code": 
DefaultImports.prepend_imports(code), "test": test, "test_result": execution_result, "plan": plan, diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 626f93e6..84cb69f4 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -187,7 +187,7 @@ def extract_frames( Returns: List[Tuple[np.ndarray, float]]: A list of tuples containing the extracted frame - and the timestamp in seconds. + as a numpy array and the timestamp in seconds. Example ------- @@ -515,7 +515,7 @@ def default(self, obj: Any): # type: ignore def load_image(image_path: str) -> np.ndarray: - """'load_image' is a utility function that loads an image from the given path. + """'load_image' is a utility function that loads an image from the given file path string. Parameters: image_path (str): The path to the image. @@ -527,7 +527,9 @@ def load_image(image_path: str) -> np.ndarray: ------- >>> load_image("path/to/image.jpg") """ - + # NOTE: sometimes the generated code passes in a NumPy array + if isinstance(image_path, np.ndarray): + return image_path image = Image.open(image_path).convert("RGB") return np.array(image)