diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index b8a0b844..b4d08cd8 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -489,7 +489,7 @@ def __call__( image: Optional[Union[str, Path]] = None, reference_data: Optional[Dict[str, str]] = None, visualize_output: Optional[bool] = False, - reflect_output: Optional[bool] = True, + self_reflect: Optional[bool] = True, ) -> str: """Invoke the vision agent. @@ -502,6 +502,7 @@ def __call__( {"image": "image.jpg", "mask": "mask.jpg", "bbox": [0.1, 0.2, 0.1, 0.2]} where the bounding box coordinates are normalized. visualize_output: Whether to visualize the output. + self_reflect: Whether to perform self-reflection; if False, the reflection step is skipped. Returns: The result of the vision agent in text. @@ -513,6 +514,7 @@ def __call__( image=image, visualize_output=visualize_output, reference_data=reference_data, + self_reflect=self_reflect, ) def log_progress(self, description: str) -> None: @@ -539,7 +541,7 @@ def chat_with_workflow( image: Optional[Union[str, Path]] = None, reference_data: Optional[Dict[str, str]] = None, visualize_output: Optional[bool] = False, - reflect_output: Optional[bool] = True, + self_reflect: Optional[bool] = True, ) -> Tuple[str, List[Dict]]: """Chat with the vision agent and return the final answer and all tool results. @@ -552,6 +554,7 @@ def chat_with_workflow( {"image": "image.jpg", "mask": "mask.jpg", "bbox": [0.1, 0.2, 0.1, 0.2]} where the bounding box coordinates are normalized. visualize_output: Whether to visualize the output. + self_reflect: Whether to perform self-reflection; if False, the reflection step is skipped. 
Returns: A tuple where the first item is the final answer and the second item is a @@ -628,7 +631,7 @@ def chat_with_workflow( else: reflection_images = None - if reflect_output: + if self_reflect: reflection = self_reflect( self.reflect_model, question, @@ -644,7 +647,7 @@ def chat_with_workflow( else: reflections += "\n" + parsed_reflection["Reflection"] else: - self.log_progress("Reflection skipped based on user request.") + self.log_progress("Self Reflection skipped based on user request.") break # '' is a symbol to indicate the end of the chat, which is useful for streaming logs. self.log_progress( @@ -667,14 +670,14 @@ def chat( image: Optional[Union[str, Path]] = None, reference_data: Optional[Dict[str, str]] = None, visualize_output: Optional[bool] = False, - reflect_output: Optional[bool] = True, + self_reflect: Optional[bool] = True, ) -> str: answer, _ = self.chat_with_workflow( chat, image=image, visualize_output=visualize_output, reference_data=reference_data, - reflect_output=reflect_output, + self_reflect=self_reflect, ) return answer