Fix Baby Cam Use Case (#51)
* fix visualization error

* added font and score to viz

* changed to smaller font file

* Support streaming chat logs of an agent (#47)

Add a callback for reporting the chat progress of an agent

* fix visualize score issue

* updated descriptions, fixed counter bug

* added visualize_output

* make feedback more concrete

* made naming more consistent

* replaced individual calc ops with calculator tool

* fix random colors

* fix prompts for tools

* update reflection prompt

* update readme

* formatting fix

* fixed mypy errors

* fix merge issue

---------

Co-authored-by: Asia <[email protected]>
dillonalaird and humpydonkey committed Apr 17, 2024
1 parent e062992 commit b7cdbee
Showing 7 changed files with 101 additions and 119 deletions.
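One item in the commit message, "Support streaming chat logs of an agent (#47)", adds a callback for reporting the chat progress of an agent. Below is a minimal sketch of how such a callback might be wired up, assuming `VisionAgent` is importable from `vision_agent.agent` and using a hypothetical image path; the handler body is purely illustrative:

```python
from vision_agent.agent import VisionAgent

# Illustrative handler: stream each progress message as it is logged.
def stream_progress(message: str) -> None:
    print(f"[agent] {message}")

# report_progress_callback is the constructor parameter shown in the diff below.
agent = VisionAgent(report_progress_callback=stream_progress)
answer = agent("Is the baby still in the crib?", image="baby_cam.jpg")  # hypothetical image
print(answer)
```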
3 changes: 2 additions & 1 deletion README.md
@@ -74,7 +74,8 @@ the individual steps and tools to get the answer:
}
]],
"answer": "The jar is located at [0.58, 0.2, 0.72, 0.45].",
}]
},
{"visualize_output": "final_output.png"}]
```
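The README excerpt above now ends the tool-result log with a `visualize_output` entry pointing at the rendered image. A minimal sketch of how a caller might pull that path out of the results and open it; the `results` list and its final entry are assumed from the excerpt rather than a documented API:

```python
from PIL import Image

# `results` is assumed to be the list of tool-result dicts shown above,
# whose last entry looks like {"visualize_output": "final_output.png"}.
def show_visualized_output(results: list) -> None:
    for entry in reversed(results):
        if "visualize_output" in entry:
            Image.open(entry["visualize_output"]).show()
            return
```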

### Tools
51 changes: 36 additions & 15 deletions vision_agent/agent/vision_agent.py
@@ -5,6 +5,7 @@
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from PIL import Image
from tabulate import tabulate

from vision_agent.image_utils import overlay_bboxes, overlay_masks
@@ -288,9 +289,8 @@ def visualize_result(all_tool_results: List[Dict]) -> List[str]:
continue
parameters = [parameters]
elif isinstance(tool_result["parameters"], list):
if (
len(tool_result["parameters"]) < 1
and "image" not in tool_result["parameters"][0]
if len(tool_result["parameters"]) < 1 or (
"image" not in tool_result["parameters"][0]
):
continue

@@ -304,10 +304,16 @@ def visualize_result(all_tool_results: List[Dict]) -> List[str]:
# if the call was successful, then we can add the image data
image = param["image"]
if image not in image_to_data:
image_to_data[image] = {"bboxes": [], "masks": [], "labels": []}
image_to_data[image] = {
"bboxes": [],
"masks": [],
"labels": [],
"scores": [],
}

image_to_data[image]["bboxes"].extend(call_result["bboxes"])
image_to_data[image]["labels"].extend(call_result["labels"])
image_to_data[image]["scores"].extend(call_result["scores"])
if "masks" in call_result:
image_to_data[image]["masks"].extend(call_result["masks"])

@@ -345,7 +351,7 @@ def __init__(
task_model: Optional[Union[LLM, LMM]] = None,
answer_model: Optional[Union[LLM, LMM]] = None,
reflect_model: Optional[Union[LLM, LMM]] = None,
max_retries: int = 2,
max_retries: int = 3,
verbose: bool = False,
report_progress_callback: Optional[Callable[[str], None]] = None,
):
@@ -380,6 +386,7 @@ def __call__(
self,
input: Union[List[Dict[str, str]], str],
image: Optional[Union[str, Path]] = None,
visualize_output: Optional[bool] = False,
) -> str:
"""Invoke the vision agent.
@@ -393,15 +400,18 @@
"""
if isinstance(input, str):
input = [{"role": "user", "content": input}]
return self.chat(input, image=image)
return self.chat(input, image=image, visualize_output=visualize_output)

def log_progress(self, description: str) -> None:
_LOGGER.info(description)
if self.report_progress_callback:
self.report_progress_callback(description)

def chat_with_workflow(
self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None
self,
chat: List[Dict[str, str]],
image: Optional[Union[str, Path]] = None,
visualize_output: Optional[bool] = False,
) -> Tuple[str, List[Dict]]:
question = chat[0]["content"]
if image:
@@ -449,31 +459,42 @@ def chat_with_workflow(
self.answer_model, question, answers, reflections
)

visualized_images = visualize_result(all_tool_results)
all_tool_results.append({"visualized_images": visualized_images})
visualized_output = visualize_result(all_tool_results)
all_tool_results.append({"visualized_output": visualized_output})
reflection = self_reflect(
self.reflect_model,
question,
self.tools,
all_tool_results,
final_answer,
visualized_images[0] if len(visualized_images) > 0 else image,
visualized_output[0] if len(visualized_output) > 0 else image,
)
self.log_progress(f"Reflection: {reflection}")
if parse_reflect(reflection):
break
else:
reflections += reflection
# '<ANSWER>' is a symbol to indicate the end of the chat, which is useful for streaming logs.
reflections += "\n" + reflection
# '<END>' is a symbol to indicate the end of the chat, which is useful for streaming logs.
self.log_progress(
f"The Vision Agent has concluded this chat. <ANSWER>{final_answer}</ANSWER>"
f"The Vision Agent has concluded this chat. <ANSWER>{final_answer}</<ANSWER>"
)

if visualize_output:
visualized_output = all_tool_results[-1]["visualized_output"]
for image in visualized_output:
Image.open(image).show()

return final_answer, all_tool_results

def chat(
self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None
self,
chat: List[Dict[str, str]],
image: Optional[Union[str, Path]] = None,
visualize_output: Optional[bool] = False,
) -> str:
answer, _ = self.chat_with_workflow(chat, image=image)
answer, _ = self.chat_with_workflow(
chat, image=image, visualize_output=visualize_output
)
return answer

def retrieval(
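Taken together, the changes to `__call__`, `chat`, and `chat_with_workflow` thread a `visualize_output` flag through the agent so the overlaid result images are opened at the end of a run. A hedged usage sketch, again assuming `VisionAgent` is importable from `vision_agent.agent` and using a hypothetical image path:

```python
from vision_agent.agent import VisionAgent

agent = VisionAgent(verbose=True)

# visualize_output=True asks the agent to open the images produced by
# visualize_result once the reflection loop finishes.
answer, tool_results = agent.chat_with_workflow(
    [{"role": "user", "content": "Where is the jar?"}],
    image="kitchen_shelf.jpg",  # hypothetical image path
    visualize_output=True,
)
print(answer)
```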
4 changes: 1 addition & 3 deletions vision_agent/agent/vision_agent_prompts.py
@@ -1,4 +1,4 @@
VISION_AGENT_REFLECTION = """You are an advanced reasoning agent that can improve based on self-refection. You will be given a previous reasoning trial in which you were given the user's question, the available tools that the agent has, the decomposed tasks and tools that the agent used to answer the question and the final answer the agent provided. You must determine if the agent's answer was correct or incorrect. If the agent's answer was correct, respond with Finish. If the agent's answer was incorrect, you must diagnose a possible reason for failure or phrasing discrepancy and devise a new, concise, high level plan that aims to mitigate the same failure with the tools available. Use complete sentences.
VISION_AGENT_REFLECTION = """You are an advanced reasoning agent that can improve based on self-refection. You will be given a previous reasoning trial in which you were given the user's question, the available tools that the agent has, the decomposed tasks and tools that the agent used to answer the question and the final answer the agent provided. You may also receive an image with the visualized bounding boxes or masks with their associated labels and scores from the tools used. You must determine if the agent's answer was correct or incorrect. If the agent's answer was correct, respond with Finish. If the agent's answer was incorrect, you must diagnose a possible reason for failure or phrasing discrepancy and devise a new, concise, concrete plan that aims to mitigate the same failure with the tools available. Do not make vague steps like re-evaluate the threshold, instead make concrete steps like use a threshold of 0.5 or whatever threshold you think would fix this issue. If the task cannot be completed with the existing tools, respond with Finish. Use complete sentences.
User's question: {question}
@@ -49,7 +49,6 @@

CHOOSE_TOOL = """This is the user's question: {question}
These are the tools you can select to solve the question:
{tools}
Please note that:
@@ -63,7 +62,6 @@

CHOOSE_TOOL_DEPENDS = """This is the user's question: {question}
These are the tools you can select to solve the question:
{tools}
This is a reflection from a previous failed attempt:
Empty file added vision_agent/fonts/__init__.py
Empty file.
Binary file added vision_agent/fonts/default_font_ch_en.ttf
Binary file not shown.
32 changes: 22 additions & 10 deletions vision_agent/image_utils.py
@@ -1,6 +1,7 @@
"""Utility functions for image processing."""

import base64
from importlib import resources
from io import BytesIO
from pathlib import Path
from typing import Dict, Tuple, Union
@@ -104,19 +105,28 @@ def overlay_bboxes(

color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(bboxes["labels"])}

draw = ImageDraw.Draw(image)
font = ImageFont.load_default()
width, height = image.size
fontsize = max(12, int(min(width, height) / 40))
draw = ImageDraw.Draw(image)
font = ImageFont.truetype(
str(resources.files("vision_agent.fonts").joinpath("default_font_ch_en.ttf")),
fontsize,
)
if "bboxes" not in bboxes:
return image.convert("RGB")

for label, box in zip(bboxes["labels"], bboxes["bboxes"]):
box = [box[0] * width, box[1] * height, box[2] * width, box[3] * height]
draw.rectangle(box, outline=color[label], width=3)
label = f"{label}"
text_box = draw.textbbox((box[0], box[1]), text=label, font=font)
draw.rectangle(text_box, fill=color[label])
draw.text((text_box[0], text_box[1]), label, fill="black", font=font)
for label, box, scores in zip(bboxes["labels"], bboxes["bboxes"], bboxes["scores"]):
box = [
int(box[0] * width),
int(box[1] * height),
int(box[2] * width),
int(box[3] * height),
]
draw.rectangle(box, outline=color[label], width=4)
text = f"{label}: {scores:.2f}"
text_box = draw.textbbox((box[0], box[1]), text=text, font=font)
draw.rectangle((box[0], box[1], text_box[2], text_box[3]), fill=color[label])
draw.text((box[0], box[1]), text, fill="black", font=font)
return image.convert("RGB")
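
The reworked `overlay_bboxes` now expects a `scores` list alongside `bboxes` and `labels` and draws `label: score` text in the bundled TTF font. A small sketch of calling it directly, with a detection dict invented to match the keys the function reads:

```python
from PIL import Image

from vision_agent.image_utils import overlay_bboxes

image = Image.open("kitchen_shelf.jpg")  # hypothetical input image
detections = {
    "labels": ["jar"],
    "bboxes": [[0.58, 0.2, 0.72, 0.45]],  # normalized xyxy, as in the README example
    "scores": [0.87],
}
overlay_bboxes(image, detections).save("final_output.png")
```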


@@ -138,7 +148,9 @@ def overlay_masks(
elif isinstance(image, np.ndarray):
image = Image.fromarray(image)

color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(masks["labels"])}
color = {
label: COLORS[i % len(COLORS)] for i, label in enumerate(set(masks["labels"]))
}
if "masks" not in masks:
return image.convert("RGB")

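The `overlay_masks` change builds the label-to-color map over `set(masks["labels"])`, so a repeated label keeps a single palette entry instead of being reassigned by its position in the list (the "fix random colors" item in the commit message). A tiny sketch of the difference, with a made-up palette and label list:

```python
COLORS = [(255, 0, 0), (0, 255, 0), (0, 0, 255)]  # stand-in palette
labels = ["baby", "crib", "baby"]

# Old behavior: the duplicate "baby" overwrites its color with a later palette
# index, so a label's color depends on how many detections precede it.
old = {label: COLORS[i % len(COLORS)] for i, label in enumerate(labels)}

# New behavior: deduplicating first gives each unique label exactly one entry.
new = {label: COLORS[i % len(COLORS)] for i, label in enumerate(set(labels))}
```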