Skip to content

Commit

Permalink
Changes to debugging and working memory (#109)
Browse files Browse the repository at this point in the history
* added LMM planner

* updated OCR tool desc

* isort

* fixed type error

* moved back to llm for planner

* small fixes to prompt

* updated types

* revert tools doc

* updating debugging and working memory

* make box colors consistent per class

* isort

* resolve merge conflicts

* fixed typo
  • Loading branch information
dillonalaird authored Jun 5, 2024
1 parent 7223a3e commit 1ccacf8
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 21 deletions.
4 changes: 2 additions & 2 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import skimage as ski
import numpy as np
import skimage as ski

from vision_agent.tools import (
clip,
closest_mask_distance,
grounding_dino,
grounding_sam,
image_caption,
image_question_answering,
ocr,
visual_prompt_counting,
zero_shot_counting,
closest_mask_distance,
)


Expand Down
61 changes: 45 additions & 16 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
import difflib
import json
import logging
import sys
Expand All @@ -16,7 +17,6 @@
from vision_agent.agent import Agent
from vision_agent.agent.vision_agent_prompts import (
CODE,
FEEDBACK,
FIX_BUG,
FULL_TASK,
PLAN,
Expand All @@ -39,17 +39,27 @@
_DEFAULT_IMPORT = "\n".join(T.__new_tools__)


def format_memory(memory: List[Dict[str, str]]) -> str:
return FEEDBACK.format(
feedback="\n".join(
[
f"### Feedback {i}:\nCode: ```python\n{m['code']}\n```\nFeedback: {m['feedback']}\n"
for i, m in enumerate(memory)
]
def get_diff(before: str, after: str) -> str:
    """Return a unified diff describing how ``before`` was changed into ``after``.

    Both inputs are split into lines without their line endings and the diff
    is generated with ``lineterm=""`` so that every diff line is terminated
    uniformly. Joining lines that kept their original endings fuses the last
    removed and added lines together (for example ``-b+c``) whenever an input
    does not end with a newline.

    Args:
        before: The original text.
        after: The modified text.

    Returns:
        The unified diff as a single newline-terminated string, or an empty
        string when the two texts contain the same lines.
    """
    diff_lines = list(
        difflib.unified_diff(
            before.splitlines(),
            after.splitlines(),
            lineterm="",
        )
    )
    if not diff_lines:
        return ""
    # Keep a trailing newline so the result can be embedded in a larger text.
    return "\n".join(diff_lines) + "\n"


def format_memory(memory: List[Dict[str, str]]) -> str:
    """Render working-memory entries as one text block.

    Each entry contributes a numbered section containing its ``"code"`` in a
    python code fence, its ``"feedback"``, and, when the entry has an
    ``"edits"`` key, the edits text. Sections are separated by a blank line.

    Args:
        memory: Entries with ``"code"`` and ``"feedback"`` keys and an
            optional ``"edits"`` key.

    Returns:
        The concatenated sections, or an empty string for an empty list.
    """
    sections: List[str] = []
    for idx, entry in enumerate(memory):
        pieces = [
            f"### Feedback {idx}:\n",
            f"Code {idx}:\n```python\n{entry['code']}```\n\n",
            f"Feedback {idx}: {entry['feedback']}\n\n",
        ]
        if "edits" in entry:
            pieces.append(f"Edits {idx}:\n{entry['edits']}\n")
        # Blank line terminating this entry's section.
        pieces.append("\n")
        sections.append("".join(pieces))

    return "".join(sections)


def extract_code(code: str) -> str:
if "\n```python" in code:
start = "\n```python"
Expand Down Expand Up @@ -146,7 +156,7 @@ def write_and_test_code(
task: str,
tool_info: str,
tool_utils: str,
working_memory: str,
working_memory: List[Dict[str, str]],
coder: LLM,
tester: LLM,
debugger: LLM,
Expand All @@ -163,7 +173,13 @@ def write_and_test_code(
}
)
code = extract_code(
coder(CODE.format(docstring=tool_info, question=task, feedback=working_memory))
coder(
CODE.format(
docstring=tool_info,
question=task,
feedback=format_memory(working_memory),
)
)
)
test = extract_code(
tester(
Expand Down Expand Up @@ -206,7 +222,7 @@ def write_and_test_code(
)

count = 0
new_working_memory = []
new_working_memory: List[Dict[str, str]] = []
while not result.success and count < max_retries:
log_progress(
{
Expand All @@ -217,14 +233,28 @@ def write_and_test_code(
fixed_code_and_test = extract_json(
debugger(
FIX_BUG.format(
code=code, tests=test, result=result.text(), feedback=working_memory
code=code,
tests=test,
result="\n".join(result.text().splitlines()[-50:]),
feedback=format_memory(working_memory + new_working_memory),
)
)
)
old_code = code
old_test = test

if fixed_code_and_test["code"].strip() != "":
code = extract_code(fixed_code_and_test["code"])
if fixed_code_and_test["test"].strip() != "":
test = extract_code(fixed_code_and_test["test"])

new_working_memory.append(
{
"code": f"{code}\n{test}",
"feedback": fixed_code_and_test["reflections"],
"edits": get_diff(f"{old_code}\n{old_test}", f"{code}\n{test}"),
}
)
log_progress(
{
"type": "code",
Expand All @@ -235,9 +265,6 @@ def write_and_test_code(
},
}
)
new_working_memory.append(
{"code": f"{code}\n{test}", "feedback": fixed_code_and_test["reflections"]}
)

result = code_interpreter.exec_isolation(f"{_DEFAULT_IMPORT}\n{code}\n{test}")
log_progress(
Expand Down Expand Up @@ -485,7 +512,7 @@ def chat_with_workflow(
),
tool_info=tool_info,
tool_utils=T.UTILITIES_DOCSTRING,
working_memory=format_memory(working_memory),
working_memory=working_memory,
coder=self.coder,
tester=self.tester,
debugger=self.debugger,
Expand Down Expand Up @@ -529,6 +556,8 @@ def chat_with_workflow(
working_memory.append(
{"code": f"{code}\n{test}", "feedback": feedback}
)
else:
break

retries += 1

Expand Down
2 changes: 0 additions & 2 deletions vision_agent/agent/vision_agent_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,7 @@ def find_text(image_path: str, text: str) -> str:
```
It raises this error:
```python
{result}
```
This is previous feedback provided on the code:
{feedback}
Expand Down
2 changes: 2 additions & 0 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,7 @@ def overlay_bounding_boxes(
label: COLORS[i % len(COLORS)]
for i, label in enumerate(set([box["label"] for box in bboxes]))
}
bboxes = sorted(bboxes, key=lambda x: x["label"], reverse=True)

width, height = pil_image.size
fontsize = max(12, int(min(width, height) / 40))
Expand Down Expand Up @@ -680,6 +681,7 @@ def overlay_segmentation_masks(
label: COLORS[i % len(COLORS)]
for i, label in enumerate(set([mask["label"] for mask in masks]))
}
masks = sorted(masks, key=lambda x: x["label"], reverse=True)

for elt in masks:
mask = elt["mask"]
Expand Down
2 changes: 1 addition & 1 deletion vision_agent/utils/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import logging
import math
import os
from concurrent.futures import ProcessPoolExecutor, as_completed
import tempfile
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import List, Tuple, cast

import cv2
Expand Down

0 comments on commit 1ccacf8

Please sign in to comment.