feat: allow disable motion detection in frame extraction function (#55)
* Tweak frame extraction function

* remove default motion detection, extract at 0.5 fps

* lmm now takes multiple images

* removed counter

* tweaked prompt

* updated vision agent to reflect on multiple images

* fix test case

* added box distance

* adjusted prompts

---------

Co-authored-by: Yazhou Cao <[email protected]>
Co-authored-by: Dillon Laird <[email protected]>
3 people authored Apr 17, 2024
1 parent 0f16897 commit 8415cb3
Showing 11 changed files with 176 additions and 111 deletions.
2 changes: 1 addition & 1 deletion tests/test_lmm.py
@@ -27,7 +27,7 @@ def create_temp_image(image_format="jpeg"):
 def test_generate_with_mock(openai_lmm_mock):  # noqa: F811
     temp_image = create_temp_image()
     lmm = OpenAILMM()
-    response = lmm.generate("test prompt", image=temp_image)
+    response = lmm.generate("test prompt", images=[temp_image])
     assert response == "mocked response"
     assert (
         "image_url"
25 changes: 24 additions & 1 deletion tests/tools/test_tools.py
@@ -4,7 +4,7 @@
 import numpy as np
 from PIL import Image

-from vision_agent.tools.tools import BboxIoU, SegArea, SegIoU
+from vision_agent.tools.tools import BboxIoU, BoxDistance, SegArea, SegIoU


 def test_bbox_iou():
@@ -42,3 +42,26 @@ def test_seg_area_2():
     mask_path = os.path.join(tmpdir, "mask.png")
     Image.fromarray(mask).save(mask_path)
     assert SegArea()(mask_path) == 4.0
+
+
+def test_box_distance():
+    box_dist = BoxDistance()
+    # horizontal dist
+    box1 = [0, 0, 2, 2]
+    box2 = [4, 1, 6, 3]
+    assert box_dist(box1, box2) == 2.0
+
+    # vertical dist
+    box1 = [0, 0, 2, 2]
+    box2 = [1, 4, 3, 6]
+    assert box_dist(box1, box2) == 2.0
+
+    # vertical and horizontal
+    box1 = [0, 0, 2, 2]
+    box2 = [3, 3, 5, 5]
+    assert box_dist(box1, box2) == 1.41
+
+    # overlap
+    box1 = [0, 0, 2, 2]
+    box2 = [1, 1, 3, 3]
+    assert box_dist(box1, box2) == 0.0
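The expected values above pin down the metric: the shortest edge-to-edge gap between two axis-aligned [x1, y1, x2, y2] boxes, rounded to two decimals, so the diagonal case is sqrt(1^2 + 1^2) ≈ 1.41 and overlapping boxes score 0.0. A minimal sketch consistent with these tests (a hypothetical stand-in, not the repository's actual BoxDistance code, which is not shown in this diff) could look like:

```python
import math
from typing import List


def box_distance(box1: List[float], box2: List[float]) -> float:
    """Shortest edge-to-edge distance between two [x1, y1, x2, y2] boxes."""
    # Gap along each axis; zero whenever the boxes overlap on that axis.
    dx = max(0.0, max(box1[0], box2[0]) - min(box1[2], box2[2]))
    dy = max(0.0, max(box1[1], box2[1]) - min(box1[3], box2[3]))
    return round(math.sqrt(dx * dx + dy * dy), 2)
```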
10 changes: 9 additions & 1 deletion vision_agent/agent/easytool_prompts.py
@@ -56,6 +56,7 @@
 These are logs of previous questions and answers:
 {previous_log}
 This is the current user's question: {question}
+This is the API tool documentation: {tool_usage}
 Output: """
@@ -67,15 +67,22 @@
 2. We will not show the API response to the user, thus you need to make full use of the response and give the information in the response that can satisfy the user's question in as much detail as possible.
 3. If the API tool does not provide useful information in the response, please answer with your knowledge.
 4. The question may have dependencies on answers of other questions, so we will provide logs of previous questions and answers.
+These are logs of previous questions and answers:
+{previous_log}
 This is the user's question: {question}
 This is the response output by the API tool:
 {call_results}
+We will not show the API response to the user, thus you need to make full use of the response and give the information in the response that can satisfy the user's question in as much detail as possible.
 Output: """

 ANSWER_SUMMARIZE = """We break down a complex user's problems into simple subtasks and provide answers to each simple subtask. You need to organize these answers to each subtask and form a self-consistent final answer to the user's question.
 This is the user's question: {question}
-These are subtasks and their answers: {answers}
+These are subtasks and their answers:
+{answers}
 Final answer: """
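These prompt templates are plain Python format strings, so downstream code fills the named fields before calling the model. A quick illustration, with invented filler values:

```python
from vision_agent.agent.easytool_prompts import ANSWER_SUMMARIZE

prompt = ANSWER_SUMMARIZE.format(
    question="How many cars are in the image?",
    answers="Subtask 1: detect cars -> found 3 boxes. Subtask 2: count boxes -> 3.",
)
# The filled-in prompt is then sent to the answer model.
```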
20 changes: 14 additions & 6 deletions vision_agent/agent/reflexion.py
@@ -238,12 +238,20 @@ def prompt_agent(
                 self._build_agent_prompt(question, reflections, scratchpad)
             )
         )
-        return format_step(
-            self.action_agent(
-                self._build_agent_prompt(question, reflections, scratchpad),
-                image=image,
-            )
-        )
+        elif isinstance(self.action_agent, LMM):
+            return format_step(
+                self.action_agent(
+                    self._build_agent_prompt(question, reflections, scratchpad),
+                    images=[image] if image is not None else None,
+                )
+            )
+        elif isinstance(self.action_agent, Agent):
+            return format_step(
+                self.action_agent(
+                    self._build_agent_prompt(question, reflections, scratchpad),
+                    image=image,
+                )
+            )

     def prompt_reflection(
         self,
@@ -261,7 +269,7 @@ def prompt_reflection(
         return format_step(
             self.self_reflect_model(
                 self._build_reflect_prompt(question, context, scratchpad),
-                image=image,
+                images=[image] if image is not None else None,
             )
         )
22 changes: 14 additions & 8 deletions vision_agent/agent/vision_agent.py
@@ -3,7 +3,7 @@
 import sys
 import tempfile
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

 from PIL import Image
 from tabulate import tabulate
@@ -264,7 +264,7 @@ def self_reflect(
     tools: Dict[int, Any],
     tool_result: List[Dict],
     final_answer: str,
-    image: Optional[Union[str, Path]] = None,
+    images: Optional[Sequence[Union[str, Path]]] = None,
 ) -> str:
     prompt = VISION_AGENT_REFLECTION.format(
         question=question,
@@ -275,10 +275,10 @@
     )
     if (
         issubclass(type(reflect_model), LMM)
-        and image is not None
-        and Path(image).suffix in [".jpg", ".jpeg", ".png"]
+        and images is not None
+        and all([Path(image).suffix in [".jpg", ".jpeg", ".png"] for image in images])
     ):
-        return reflect_model(prompt, image=image)  # type: ignore
+        return reflect_model(prompt, images=images)  # type: ignore
     return reflect_model(prompt)
@@ -357,7 +357,7 @@ def _handle_viz_tools(
     return image_to_data


-def visualize_result(all_tool_results: List[Dict]) -> List[str]:
+def visualize_result(all_tool_results: List[Dict]) -> Sequence[Union[str, Path]]:
     image_to_data: Dict[str, Dict] = {}
     for tool_result in all_tool_results:
         # only handle bbox/mask tools or frame extraction
@@ -407,7 +407,7 @@ def __init__(
         task_model: Optional[Union[LLM, LMM]] = None,
         answer_model: Optional[Union[LLM, LMM]] = None,
         reflect_model: Optional[Union[LLM, LMM]] = None,
-        max_retries: int = 2,
+        max_retries: int = 3,
         verbose: bool = False,
         report_progress_callback: Optional[Callable[[str], None]] = None,
     ):
@@ -519,13 +519,19 @@ def chat_with_workflow(
         visualized_output = visualize_result(all_tool_results)
         all_tool_results.append({"visualized_output": visualized_output})
+        if len(visualized_output) > 0:
+            reflection_images = visualized_output
+        elif image is not None:
+            reflection_images = [image]
+        else:
+            reflection_images = None
         reflection = self_reflect(
             self.reflect_model,
             question,
             self.tools,
             all_tool_results,
             final_answer,
-            visualized_output[0] if len(visualized_output) > 0 else image,
+            reflection_images,
         )
         self.log_progress(f"Reflection: {reflection}")
         parsed_reflection = parse_reflect(reflection)
9 changes: 5 additions & 4 deletions vision_agent/agent/vision_agent_prompts.py
@@ -1,11 +1,11 @@
-VISION_AGENT_REFLECTION = """You are an advanced reasoning agent that can improve based on self-refection. You will be given a previous reasoning trial in which you were given the user's question, the available tools that the agent has, the decomposed tasks and tools that the agent used to answer the question and the final answer the agent provided. You may also receive an image with the visualized bounding boxes or masks with their associated labels and scores from the tools used.
+VISION_AGENT_REFLECTION = """You are an advanced reasoning agent that can improve based on self-refection. You will be given a previous reasoning trial in which you were given the user's question, the available tools that the agent has, the decomposed tasks and tools that the agent used to answer the question, the tool usage for each of the tools used and the final answer the agent provided. You may also receive an image with the visualized bounding boxes or masks with their associated labels and scores from the tools used.
 Please note that:
 1. You must ONLY output parsible JSON format. If the agents output was correct set "Finish" to true, else set "Finish" to false. An example output looks like:
 {{"Finish": true, "Reflection": "The agent's answer was correct."}}
-2. You must utilize the image with the visualized bounding boxes or masks and determine if the tools were used correctly or, using your own judgement, utilized incorrectly.
-3. If the agent's answer was incorrect, you must diagnose a possible reason for failure or phrasing discrepancy and devise a new, concise, concrete plan that aims to mitigate the same failure with the tools available. An example output looks like:
-{{"Finish": false, "Reflection": "I can see from teh visualized bounding boxes that the agent's answer was incorrect because the grounding_dino_ tool produced false positive predictions. The agent should use the following tools with the following parameters:
+2. You must utilize the image with the visualized bounding boxes or masks and determine if the tools were used correctly or if the tools were used incorrectly or the wrong tools were used.
+3. If the agent's answer was incorrect, you must diagnose the reason for failure and devise a new concise and concrete plan that aims to mitigate the same failure with the tools available. An example output looks like:
+{{"Finish": false, "Reflection": "I can see from the visualized bounding boxes that the agent's answer was incorrect because the grounding_dino_ tool produced false positive predictions. The agent should use the following tools with the following parameters:
 Step 1: Use 'grounding_dino_' with a 'prompt' of 'baby. bed' and a 'box_threshold' of 0.7 to reduce the false positives.
 Step 2: Use 'box_iou_' with the baby bounding box and the bed bounding box to determine if the baby is on the bed or not."}}
 4. If the task cannot be completed with the existing tools or by adjusting the parameters, set "Finish" to true.
@@ -140,4 +140,5 @@
+This is a reflection from a previous failed attempt:
 {reflections}
 Final answer: """
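parse_reflect, called in chat_with_workflow above, is not part of this diff. Given that the reflection prompt demands parsable JSON, a minimal hypothetical sketch could be:

```python
import json
from typing import Any, Dict


def parse_reflect(reflection: str) -> Dict[str, Any]:
    # Hypothetical sketch; the real parse_reflect is not shown in this commit.
    # The prompt requires output such as:
    # {"Finish": true, "Reflection": "The agent's answer was correct."}
    return json.loads(reflection)
```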
4 changes: 2 additions & 2 deletions vision_agent/data/data.py
@@ -63,9 +63,9 @@ def add_column(

         self.df[name] = self.df["image_paths"].progress_apply(  # type: ignore
             lambda x: (
-                func(self.lmm.generate(prompt, image=x))
+                func(self.lmm.generate(prompt, images=[x]))
                 if func
-                else self.lmm.generate(prompt, image=x)
+                else self.lmm.generate(prompt, images=[x])
             )
         )
         return self
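The func hook post-processes each LMM response before it lands in the new column. A tiny self-contained illustration of that role (values invented):

```python
def func(response: str) -> str:
    # Example post-processor: normalize whitespace in the model's reply.
    return " ".join(response.split())


mock_response = "  A red car parked\non a street.  "
print(func(mock_response))  # "A red car parked on a street."
```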
102 changes: 58 additions & 44 deletions vision_agent/lmm/lmm.py
@@ -30,20 +30,24 @@ def encode_image(image: Union[str, Path]) -> str:

 class LMM(ABC):
     @abstractmethod
-    def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str:
+    def generate(
+        self, prompt: str, images: Optional[List[Union[str, Path]]] = None
+    ) -> str:
         pass

     @abstractmethod
     def chat(
-        self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None
+        self,
+        chat: List[Dict[str, str]],
+        images: Optional[List[Union[str, Path]]] = None,
     ) -> str:
         pass

     @abstractmethod
     def __call__(
         self,
         input: Union[str, List[Dict[str, str]]],
-        image: Optional[Union[str, Path]] = None,
+        images: Optional[List[Union[str, Path]]] = None,
     ) -> str:
         pass
@@ -57,27 +61,29 @@ def __init__(self, model_name: str):
     def __call__(
         self,
         input: Union[str, List[Dict[str, str]]],
-        image: Optional[Union[str, Path]] = None,
+        images: Optional[List[Union[str, Path]]] = None,
     ) -> str:
         if isinstance(input, str):
-            return self.generate(input, image)
-        return self.chat(input, image)
+            return self.generate(input, images)
+        return self.chat(input, images)

     def chat(
-        self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None
+        self,
+        chat: List[Dict[str, str]],
+        images: Optional[List[Union[str, Path]]] = None,
     ) -> str:
         raise NotImplementedError("Chat not supported for LLaVA")

     def generate(
         self,
         prompt: str,
-        image: Optional[Union[str, Path]] = None,
+        images: Optional[List[Union[str, Path]]] = None,
         temperature: float = 0.1,
         max_new_tokens: int = 1500,
     ) -> str:
         data = {"prompt": prompt}
-        if image:
-            data["image"] = encode_image(image)
+        if images and len(images) > 0:
+            data["image"] = encode_image(images[0])
         data["temperature"] = temperature  # type: ignore
         data["max_new_tokens"] = max_new_tokens  # type: ignore
         res = requests.post(
@@ -121,48 +127,55 @@ def __init__(
     def __call__(
         self,
         input: Union[str, List[Dict[str, str]]],
-        image: Optional[Union[str, Path]] = None,
+        images: Optional[List[Union[str, Path]]] = None,
     ) -> str:
         if isinstance(input, str):
-            return self.generate(input, image)
-        return self.chat(input, image)
+            return self.generate(input, images)
+        return self.chat(input, images)

     def chat(
-        self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None
+        self,
+        chat: List[Dict[str, str]],
+        images: Optional[List[Union[str, Path]]] = None,
     ) -> str:
         fixed_chat = []
         for c in chat:
             fixed_c = {"role": c["role"]}
             fixed_c["content"] = [{"type": "text", "text": c["content"]}]  # type: ignore
             fixed_chat.append(fixed_c)

-        if image:
-            extension = Path(image).suffix
-            if extension.lower() == ".jpeg" or extension.lower() == ".jpg":
-                extension = "jpg"
-            elif extension.lower() == ".png":
-                extension = "png"
-            else:
-                raise ValueError(f"Unsupported image extension: {extension}")
-
-            encoded_image = encode_image(image)
-            fixed_chat[0]["content"].append(  # type: ignore
-                {
-                    "type": "image_url",
-                    "image_url": {
-                        "url": f"data:image/{extension};base64,{encoded_image}",
-                        "detail": "low",
-                    },
-                },
-            )
+        if images and len(images) > 0:
+            for image in images:
+                extension = Path(image).suffix
+                if extension.lower() == ".jpeg" or extension.lower() == ".jpg":
+                    extension = "jpg"
+                elif extension.lower() == ".png":
+                    extension = "png"
+                else:
+                    raise ValueError(f"Unsupported image extension: {extension}")
+
+                encoded_image = encode_image(image)
+                fixed_chat[0]["content"].append(  # type: ignore
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/{extension};base64,{encoded_image}",
+                            "detail": "low",
+                        },
+                    },
+                )

         response = self.client.chat.completions.create(
             model=self.model_name, messages=fixed_chat, **self.kwargs  # type: ignore
         )

         return cast(str, response.choices[0].message.content)

-    def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str:
+    def generate(
+        self,
+        prompt: str,
+        images: Optional[List[Union[str, Path]]] = None,
+    ) -> str:
         message: List[Dict[str, Any]] = [
             {
                 "role": "user",
@@ -171,18 +184,19 @@ def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str
                 ],
             }
         ]
-        if image:
-            extension = Path(image).suffix
-            encoded_image = encode_image(image)
-            message[0]["content"].append(
-                {
-                    "type": "image_url",
-                    "image_url": {
-                        "url": f"data:image/{extension};base64,{encoded_image}",
-                        "detail": "low",
-                    },
-                },
-            )
+        if images and len(images) > 0:
+            for image in images:
+                extension = Path(image).suffix
+                encoded_image = encode_image(image)
+                message[0]["content"].append(
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/{extension};base64,{encoded_image}",
+                            "detail": "low",
+                        },
+                    },
+                )

         response = self.client.chat.completions.create(
             model=self.model_name, messages=message, **self.kwargs  # type: ignore
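Taken together, this file's changes move every entry point from a single image to an images list: OpenAILMM fans the list out into one image_url part per image, while LLaVALMM still encodes only the first. A hedged usage sketch (file names invented; assumes OpenAILMM is importable from vision_agent.lmm and OPENAI_API_KEY is set):

```python
from vision_agent.lmm import OpenAILMM

lmm = OpenAILMM()
answer = lmm.generate(
    "What changed between these two frames?",
    images=["frame_001.png", "frame_002.png"],  # both become image_url entries
)
print(answer)
```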