From ef65ff2fc2c67704363865c39200f6c29b8aee58 Mon Sep 17 00:00:00 2001 From: shankar_ws3 Date: Mon, 29 Apr 2024 22:57:04 -0700 Subject: [PATCH] adding reflect to be optional for cases where LMM might not be able to understand the image --- README.md | 3 ++- vision_agent/agent/vision_agent.py | 35 +++++++++++++++++++----------- 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 76bcc2e1..3eeb19c8 100644 --- a/README.md +++ b/README.md @@ -122,7 +122,7 @@ you. For example: #### Custom Tools You can also add your own custom tools for your vision agent to use: - + ```python from vision_agent.tools import Tool, register_tool @register_tool @@ -160,6 +160,7 @@ find an example that creates a custom tool for template matching [here](examples | BboxIoU | BboxIoU returns the intersection over union of two bounding boxes normalized to 2 decimal places. | | SegIoU | SegIoU returns the intersection over union of two segmentation masks normalized to 2 decimal places. | | BoxDistance | BoxDistance returns the minimum distance between two bounding boxes normalized to 2 decimal places. | +| MaskDistance | MaskDistance returns the minimum distance between two segmentation masks in pixel units | | BboxContains | BboxContains returns the intersection of two boxes over the target box area. It is good for check if one box is contained within another box. | | ExtractFrames | ExtractFrames extracts frames with motion from a video. | | ZeroShotCounting | ZeroShotCounting returns the total number of objects belonging to a single class in a given image. | diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index c22c9983..b8a0b844 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -489,6 +489,7 @@ def __call__( image: Optional[Union[str, Path]] = None, reference_data: Optional[Dict[str, str]] = None, visualize_output: Optional[bool] = False, + reflect_output: Optional[bool] = True, ) -> str: """Invoke the vision agent. @@ -538,6 +539,7 @@ def chat_with_workflow( image: Optional[Union[str, Path]] = None, reference_data: Optional[Dict[str, str]] = None, visualize_output: Optional[bool] = False, + reflect_output: Optional[bool] = True, ) -> Tuple[str, List[Dict]]: """Chat with the vision agent and return the final answer and all tool results. @@ -625,20 +627,25 @@ def chat_with_workflow( reflection_images = [image] else: reflection_images = None - reflection = self_reflect( - self.reflect_model, - question, - self.tools, - all_tool_results, - final_answer, - reflection_images, - ) - self.log_progress(f"Reflection: {reflection}") - parsed_reflection = parse_reflect(reflection) - if parsed_reflection["Finish"]: - break + + if reflect_output: + reflection = self_reflect( + self.reflect_model, + question, + self.tools, + all_tool_results, + final_answer, + reflection_images, + ) + self.log_progress(f"Reflection: {reflection}") + parsed_reflection = parse_reflect(reflection) + if parsed_reflection["Finish"]: + break + else: + reflections += "\n" + parsed_reflection["Reflection"] else: - reflections += "\n" + parsed_reflection["Reflection"] + self.log_progress("Reflection skipped based on user request.") + break # '' is a symbol to indicate the end of the chat, which is useful for streaming logs. self.log_progress( f"The Vision Agent has concluded this chat. {final_answer}" @@ -660,12 +667,14 @@ def chat( image: Optional[Union[str, Path]] = None, reference_data: Optional[Dict[str, str]] = None, visualize_output: Optional[bool] = False, + reflect_output: Optional[bool] = True, ) -> str: answer, _ = self.chat_with_workflow( chat, image=image, visualize_output=visualize_output, reference_data=reference_data, + reflect_output=reflect_output, ) return answer