From 1ccacf8755abf47c2c92b2ae6cc5b9c154abdde9 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Tue, 4 Jun 2024 19:57:06 -0700 Subject: [PATCH] Changes to debugging and working memory (#109) * added LMM planner * updated OCR tool desc * isort * fixed type error * moved back to llm for planner * small fixes to prompt * updated types * revert tools doc * updating debugging and working memory * make box colors consistent per class * isort * resolve merge conflicts * fixed typo --- tests/test_tools.py | 4 +- vision_agent/agent/vision_agent.py | 61 ++++++++++++++++------ vision_agent/agent/vision_agent_prompts.py | 2 - vision_agent/tools/tools.py | 2 + vision_agent/utils/video.py | 2 +- 5 files changed, 50 insertions(+), 21 deletions(-) diff --git a/tests/test_tools.py b/tests/test_tools.py index 410ae561..d893e7aa 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1,8 +1,9 @@ -import skimage as ski import numpy as np +import skimage as ski from vision_agent.tools import ( clip, + closest_mask_distance, grounding_dino, grounding_sam, image_caption, @@ -10,7 +11,6 @@ ocr, visual_prompt_counting, zero_shot_counting, - closest_mask_distance, ) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index e421fe16..27880b7f 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -1,4 +1,5 @@ import copy +import difflib import json import logging import sys @@ -16,7 +17,6 @@ from vision_agent.agent import Agent from vision_agent.agent.vision_agent_prompts import ( CODE, - FEEDBACK, FIX_BUG, FULL_TASK, PLAN, @@ -39,17 +39,27 @@ _DEFAULT_IMPORT = "\n".join(T.__new_tools__) -def format_memory(memory: List[Dict[str, str]]) -> str: - return FEEDBACK.format( - feedback="\n".join( - [ - f"### Feedback {i}:\nCode: ```python\n{m['code']}\n```\nFeedback: {m['feedback']}\n" - for i, m in enumerate(memory) - ] +def get_diff(before: str, after: str) -> str: + return "".join( + difflib.unified_diff( + before.splitlines(keepends=True), after.splitlines(keepends=True) ) ) +def format_memory(memory: List[Dict[str, str]]) -> str: + output_str = "" + for i, m in enumerate(memory): + output_str += f"### Feedback {i}:\n" + output_str += f"Code {i}:\n```python\n{m['code']}```\n\n" + output_str += f"Feedback {i}: {m['feedback']}\n\n" + if "edits" in m: + output_str += f"Edits {i}:\n{m['edits']}\n" + output_str += "\n" + + return output_str + + def extract_code(code: str) -> str: if "\n```python" in code: start = "\n```python" @@ -146,7 +156,7 @@ def write_and_test_code( task: str, tool_info: str, tool_utils: str, - working_memory: str, + working_memory: List[Dict[str, str]], coder: LLM, tester: LLM, debugger: LLM, @@ -163,7 +173,13 @@ def write_and_test_code( } ) code = extract_code( - coder(CODE.format(docstring=tool_info, question=task, feedback=working_memory)) + coder( + CODE.format( + docstring=tool_info, + question=task, + feedback=format_memory(working_memory), + ) + ) ) test = extract_code( tester( @@ -206,7 +222,7 @@ def write_and_test_code( ) count = 0 - new_working_memory = [] + new_working_memory: List[Dict[str, str]] = [] while not result.success and count < max_retries: log_progress( { @@ -217,14 +233,28 @@ def write_and_test_code( fixed_code_and_test = extract_json( debugger( FIX_BUG.format( - code=code, tests=test, result=result.text(), feedback=working_memory + code=code, + tests=test, + result="\n".join(result.text().splitlines()[-50:]), + feedback=format_memory(working_memory + new_working_memory), ) ) ) + old_code = code + old_test = test + if fixed_code_and_test["code"].strip() != "": code = extract_code(fixed_code_and_test["code"]) if fixed_code_and_test["test"].strip() != "": test = extract_code(fixed_code_and_test["test"]) + + new_working_memory.append( + { + "code": f"{code}\n{test}", + "feedback": fixed_code_and_test["reflections"], + "edits": get_diff(f"{old_code}\n{old_test}", f"{code}\n{test}"), + } + ) log_progress( { "type": "code", @@ -235,9 +265,6 @@ def write_and_test_code( }, } ) - new_working_memory.append( - {"code": f"{code}\n{test}", "feedback": fixed_code_and_test["reflections"]} - ) result = code_interpreter.exec_isolation(f"{_DEFAULT_IMPORT}\n{code}\n{test}") log_progress( @@ -485,7 +512,7 @@ def chat_with_workflow( ), tool_info=tool_info, tool_utils=T.UTILITIES_DOCSTRING, - working_memory=format_memory(working_memory), + working_memory=working_memory, coder=self.coder, tester=self.tester, debugger=self.debugger, @@ -529,6 +556,8 @@ def chat_with_workflow( working_memory.append( {"code": f"{code}\n{test}", "feedback": feedback} ) + else: + break retries += 1 diff --git a/vision_agent/agent/vision_agent_prompts.py b/vision_agent/agent/vision_agent_prompts.py index 97f73ebc..d2c6fe4b 100644 --- a/vision_agent/agent/vision_agent_prompts.py +++ b/vision_agent/agent/vision_agent_prompts.py @@ -197,9 +197,7 @@ def find_text(image_path: str, text: str) -> str: ``` It raises this error: -```python {result} -``` This is previous feedback provided on the code: {feedback} diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 061da146..11e219d0 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -608,6 +608,7 @@ def overlay_bounding_boxes( label: COLORS[i % len(COLORS)] for i, label in enumerate(set([box["label"] for box in bboxes])) } + bboxes = sorted(bboxes, key=lambda x: x["label"], reverse=True) width, height = pil_image.size fontsize = max(12, int(min(width, height) / 40)) @@ -680,6 +681,7 @@ def overlay_segmentation_masks( label: COLORS[i % len(COLORS)] for i, label in enumerate(set([mask["label"] for mask in masks])) } + masks = sorted(masks, key=lambda x: x["label"], reverse=True) for elt in masks: mask = elt["mask"] diff --git a/vision_agent/utils/video.py b/vision_agent/utils/video.py index 93c8c7fb..bd04ac8c 100644 --- a/vision_agent/utils/video.py +++ b/vision_agent/utils/video.py @@ -2,8 +2,8 @@ import logging import math import os -from concurrent.futures import ProcessPoolExecutor, as_completed import tempfile +from concurrent.futures import ProcessPoolExecutor, as_completed from typing import List, Tuple, cast import cv2