Skip to content

Commit

Permalink
Enable/Disable Reflection (#72)
Browse files Browse the repository at this point in the history
* making reflection optional for cases where the LMM might not be able to understand the image

* changed the param name to self_reflect

* fixing param name as it overlaps with function call
  • Loading branch information
shankar-vision-eng authored Apr 30, 2024
1 parent f5943e0 commit 0d5f999
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 14 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ you. For example:

#### Custom Tools
You can also add your own custom tools for your vision agent to use:

```python
from vision_agent.tools import Tool, register_tool
@register_tool
Expand Down Expand Up @@ -160,6 +160,7 @@ find an example that creates a custom tool for template matching [here](examples
| BboxIoU | BboxIoU returns the intersection over union of two bounding boxes normalized to 2 decimal places. |
| SegIoU | SegIoU returns the intersection over union of two segmentation masks normalized to 2 decimal places. |
| BoxDistance | BoxDistance returns the minimum distance between two bounding boxes normalized to 2 decimal places. |
| MaskDistance | MaskDistance returns the minimum distance between two segmentation masks in pixel units. |
| BboxContains | BboxContains returns the intersection of two boxes over the target box area. It is good for checking if one box is contained within another box. |
| ExtractFrames | ExtractFrames extracts frames with motion from a video. |
| ZeroShotCounting | ZeroShotCounting returns the total number of objects belonging to a single class in a given image. |
Expand Down
38 changes: 25 additions & 13 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,7 @@ def __call__(
image: Optional[Union[str, Path]] = None,
reference_data: Optional[Dict[str, str]] = None,
visualize_output: Optional[bool] = False,
self_reflection: Optional[bool] = True,
) -> str:
"""Invoke the vision agent.
Expand All @@ -501,6 +502,7 @@ def __call__(
{"image": "image.jpg", "mask": "mask.jpg", "bbox": [0.1, 0.2, 0.1, 0.2]}
where the bounding box coordinates are normalized.
visualize_output: Whether to visualize the output.
self_reflection: boolean to enable and disable self reflection.
Returns:
The result of the vision agent in text.
Expand All @@ -512,6 +514,7 @@ def __call__(
image=image,
visualize_output=visualize_output,
reference_data=reference_data,
self_reflection=self_reflection,
)

def log_progress(self, description: str) -> None:
Expand All @@ -538,6 +541,7 @@ def chat_with_workflow(
image: Optional[Union[str, Path]] = None,
reference_data: Optional[Dict[str, str]] = None,
visualize_output: Optional[bool] = False,
self_reflection: Optional[bool] = True,
) -> Tuple[str, List[Dict]]:
"""Chat with the vision agent and return the final answer and all tool results.
Expand All @@ -550,6 +554,7 @@ def chat_with_workflow(
{"image": "image.jpg", "mask": "mask.jpg", "bbox": [0.1, 0.2, 0.1, 0.2]}
where the bounding box coordinates are normalized.
visualize_output: Whether to visualize the output.
self_reflection: boolean to enable and disable self reflection.
Returns:
A tuple where the first item is the final answer and the second item is a
Expand Down Expand Up @@ -625,20 +630,25 @@ def chat_with_workflow(
reflection_images = [image]
else:
reflection_images = None
reflection = self_reflect(
self.reflect_model,
question,
self.tools,
all_tool_results,
final_answer,
reflection_images,
)
self.log_progress(f"Reflection: {reflection}")
parsed_reflection = parse_reflect(reflection)
if parsed_reflection["Finish"]:
break

if self_reflection:
reflection = self_reflect(
self.reflect_model,
question,
self.tools,
all_tool_results,
final_answer,
reflection_images,
)
self.log_progress(f"Reflection: {reflection}")
parsed_reflection = parse_reflect(reflection)
if parsed_reflection["Finish"]:
break
else:
reflections += "\n" + parsed_reflection["Reflection"]
else:
reflections += "\n" + parsed_reflection["Reflection"]
self.log_progress("Self Reflection skipped based on user request.")
break
# '<ANSWER>' is a symbol to indicate the end of the chat, which is useful for streaming logs.
self.log_progress(
f"The Vision Agent has concluded this chat. <ANSWER>{final_answer}</ANSWER>"
Expand All @@ -660,12 +670,14 @@ def chat(
image: Optional[Union[str, Path]] = None,
reference_data: Optional[Dict[str, str]] = None,
visualize_output: Optional[bool] = False,
self_reflection: Optional[bool] = True,
) -> str:
answer, _ = self.chat_with_workflow(
chat,
image=image,
visualize_output=visualize_output,
reference_data=reference_data,
self_reflection=self_reflection,
)
return answer

Expand Down

0 comments on commit 0d5f999

Please sign in to comment.