From 0d5f99922cb40b6eb1b1b2df06eaabb482487ffa Mon Sep 17 00:00:00 2001
From: Shankar <90070882+shankar-landing-ai@users.noreply.github.com>
Date: Tue, 30 Apr 2024 13:19:44 -0700
Subject: [PATCH] Enable/Disable Reflection (#72)

* adding reflect to be optional for cases where LMM might not be able to
  understand the image

* changed the param name to self_reflect

* fixing param name as it overlaps with function call
---
 README.md                          |  3 ++-
 vision_agent/agent/vision_agent.py | 38 ++++++++++++++++++++----------
 2 files changed, 27 insertions(+), 14 deletions(-)

diff --git a/README.md b/README.md
index 76bcc2e1..3eeb19c8 100644
--- a/README.md
+++ b/README.md
@@ -122,7 +122,7 @@ you. For example:
 
 #### Custom Tools
 You can also add your own custom tools for your vision agent to use:
-
+
 ```python
 from vision_agent.tools import Tool, register_tool
 @register_tool
@@ -160,6 +160,7 @@ find an example that creates a custom tool for template matching [here](examples
 | BboxIoU | BboxIoU returns the intersection over union of two bounding boxes normalized to 2 decimal places. |
 | SegIoU | SegIoU returns the intersection over union of two segmentation masks normalized to 2 decimal places. |
 | BoxDistance | BoxDistance returns the minimum distance between two bounding boxes normalized to 2 decimal places. |
+| MaskDistance | MaskDistance returns the minimum distance between two segmentation masks in pixel units |
 | BboxContains | BboxContains returns the intersection of two boxes over the target box area. It is good for check if one box is contained within another box. |
 | ExtractFrames | ExtractFrames extracts frames with motion from a video. |
 | ZeroShotCounting | ZeroShotCounting returns the total number of objects belonging to a single class in a given image. |
diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py
index c22c9983..9e5099de 100644
--- a/vision_agent/agent/vision_agent.py
+++ b/vision_agent/agent/vision_agent.py
@@ -489,6 +489,7 @@ def __call__(
         image: Optional[Union[str, Path]] = None,
         reference_data: Optional[Dict[str, str]] = None,
         visualize_output: Optional[bool] = False,
+        self_reflection: Optional[bool] = True,
     ) -> str:
         """Invoke the vision agent.
 
@@ -501,6 +502,7 @@ def __call__(
                 {"image": "image.jpg", "mask": "mask.jpg", "bbox": [0.1, 0.2, 0.1, 0.2]}
                 where the bounding box coordinates are normalized.
             visualize_output: Whether to visualize the output.
+            self_reflection: boolean to enable and disable self reflection.
 
         Returns:
             The result of the vision agent in text.
@@ -512,6 +514,7 @@ def __call__(
             image=image,
             visualize_output=visualize_output,
             reference_data=reference_data,
+            self_reflection=self_reflection,
         )
 
     def log_progress(self, description: str) -> None:
@@ -538,6 +541,7 @@ def chat_with_workflow(
         image: Optional[Union[str, Path]] = None,
         reference_data: Optional[Dict[str, str]] = None,
         visualize_output: Optional[bool] = False,
+        self_reflection: Optional[bool] = True,
     ) -> Tuple[str, List[Dict]]:
         """Chat with the vision agent and return the final answer and all tool
         results.
@@ -550,6 +554,7 @@ def chat_with_workflow(
                 {"image": "image.jpg", "mask": "mask.jpg", "bbox": [0.1, 0.2, 0.1, 0.2]}
                 where the bounding box coordinates are normalized.
             visualize_output: Whether to visualize the output.
+            self_reflection: boolean to enable and disable self reflection.
 
         Returns:
             A tuple where the first item is the final answer and the second item is a
@@ -625,20 +630,25 @@ def chat_with_workflow(
                 reflection_images = [image]
             else:
                 reflection_images = None
-            reflection = self_reflect(
-                self.reflect_model,
-                question,
-                self.tools,
-                all_tool_results,
-                final_answer,
-                reflection_images,
-            )
-            self.log_progress(f"Reflection: {reflection}")
-            parsed_reflection = parse_reflect(reflection)
-            if parsed_reflection["Finish"]:
-                break
+
+            if self_reflection:
+                reflection = self_reflect(
+                    self.reflect_model,
+                    question,
+                    self.tools,
+                    all_tool_results,
+                    final_answer,
+                    reflection_images,
+                )
+                self.log_progress(f"Reflection: {reflection}")
+                parsed_reflection = parse_reflect(reflection)
+                if parsed_reflection["Finish"]:
+                    break
+                else:
+                    reflections += "\n" + parsed_reflection["Reflection"]
             else:
-                reflections += "\n" + parsed_reflection["Reflection"]
+                self.log_progress("Self Reflection skipped based on user request.")
+                break
         # '' is a symbol to indicate the end of the chat, which is useful for streaming logs.
         self.log_progress(
             f"The Vision Agent has concluded this chat. {final_answer}"
@@ -660,12 +670,14 @@ def chat(
         image: Optional[Union[str, Path]] = None,
         reference_data: Optional[Dict[str, str]] = None,
         visualize_output: Optional[bool] = False,
+        self_reflection: Optional[bool] = True,
     ) -> str:
         answer, _ = self.chat_with_workflow(
             chat,
             image=image,
             visualize_output=visualize_output,
             reference_data=reference_data,
+            self_reflection=self_reflection,
         )
         return answer
 
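Usage note (not part of the patch): a minimal sketch of how the new `self_reflection` flag might be called, assuming the agent class is `VisionAgent` imported from `vision_agent.agent` and that the prompt is passed as the first positional argument of `__call__`; the prompt text and image path below are illustrative only.

```python
# Sketch only: import path, class name, and argument order are assumptions
# based on this diff, not confirmed by it.
from vision_agent.agent import VisionAgent

agent = VisionAgent()

# Default behavior: the agent runs a self-reflection pass on its final answer.
answer = agent("Count the cars in this image.", image="cars.jpg")

# Disable self-reflection, e.g. when the LMM cannot reliably interpret the image
# and the reflection step would only add latency without improving the result.
answer = agent(
    "Count the cars in this image.",
    image="cars.jpg",
    self_reflection=False,
)
print(answer)
```

With `self_reflection=False`, `chat_with_workflow` logs "Self Reflection skipped based on user request." and breaks out of the loop after the first pass instead of calling `self_reflect`.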