Skip to content

Commit

Permalink
adding reflect to be optional for cases where LMM might not be able t…
Browse files Browse the repository at this point in the history
…o understand the image
  • Loading branch information
shankar-vision-eng committed Apr 30, 2024
1 parent 305b343 commit ef65ff2
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 14 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ you. For example:

#### Custom Tools
You can also add your own custom tools for your vision agent to use:

```python
from vision_agent.tools import Tool, register_tool
@register_tool
Expand Down Expand Up @@ -160,6 +160,7 @@ find an example that creates a custom tool for template matching [here](examples
| BboxIoU | BboxIoU returns the intersection over union of two bounding boxes normalized to 2 decimal places. |
| SegIoU | SegIoU returns the intersection over union of two segmentation masks normalized to 2 decimal places. |
| BoxDistance | BoxDistance returns the minimum distance between two bounding boxes normalized to 2 decimal places. |
| MaskDistance | MaskDistance returns the minimum distance between two segmentation masks in pixel units |
| BboxContains | BboxContains returns the intersection of two boxes over the target box area. It is good for check if one box is contained within another box. |
| ExtractFrames | ExtractFrames extracts frames with motion from a video. |
| ZeroShotCounting | ZeroShotCounting returns the total number of objects belonging to a single class in a given image. |
Expand Down
35 changes: 22 additions & 13 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,7 @@ def __call__(
image: Optional[Union[str, Path]] = None,
reference_data: Optional[Dict[str, str]] = None,
visualize_output: Optional[bool] = False,
reflect_output: Optional[bool] = True,
) -> str:
"""Invoke the vision agent.
Expand Down Expand Up @@ -538,6 +539,7 @@ def chat_with_workflow(
image: Optional[Union[str, Path]] = None,
reference_data: Optional[Dict[str, str]] = None,
visualize_output: Optional[bool] = False,
reflect_output: Optional[bool] = True,
) -> Tuple[str, List[Dict]]:
"""Chat with the vision agent and return the final answer and all tool results.
Expand Down Expand Up @@ -625,20 +627,25 @@ def chat_with_workflow(
reflection_images = [image]
else:
reflection_images = None
reflection = self_reflect(
self.reflect_model,
question,
self.tools,
all_tool_results,
final_answer,
reflection_images,
)
self.log_progress(f"Reflection: {reflection}")
parsed_reflection = parse_reflect(reflection)
if parsed_reflection["Finish"]:
break

if reflect_output:
reflection = self_reflect(
self.reflect_model,
question,
self.tools,
all_tool_results,
final_answer,
reflection_images,
)
self.log_progress(f"Reflection: {reflection}")
parsed_reflection = parse_reflect(reflection)
if parsed_reflection["Finish"]:
break
else:
reflections += "\n" + parsed_reflection["Reflection"]
else:
reflections += "\n" + parsed_reflection["Reflection"]
self.log_progress("Reflection skipped based on user request.")
break
# '<ANSWER>' is a symbol to indicate the end of the chat, which is useful for streaming logs.
self.log_progress(
f"The Vision Agent has concluded this chat. <ANSWER>{final_answer}</ANSWER>"
Expand All @@ -660,12 +667,14 @@ def chat(
image: Optional[Union[str, Path]] = None,
reference_data: Optional[Dict[str, str]] = None,
visualize_output: Optional[bool] = False,
reflect_output: Optional[bool] = True,
) -> str:
answer, _ = self.chat_with_workflow(
chat,
image=image,
visualize_output=visualize_output,
reference_data=reference_data,
reflect_output=reflect_output,
)
return answer

Expand Down

0 comments on commit ef65ff2

Please sign in to comment.