diff --git a/README.md b/README.md index bff0a12a..741b8ff2 100644 --- a/README.md +++ b/README.md @@ -74,7 +74,8 @@ the individual steps and tools to get the answer: } ]], "answer": "The jar is located at [0.58, 0.2, 0.72, 0.45].", -}] +}, +{"visualize_output": "final_output.png"}] ``` ### Tools diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index e78fc886..fbbf1bc7 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -400,7 +400,7 @@ def __call__( """ if isinstance(input, str): input = [{"role": "user", "content": input}] - return self.chat(input, image=image) + return self.chat(input, image=image, visualize_output=visualize_output) def log_progress(self, description: str) -> None: _LOGGER.info(description) @@ -480,9 +480,9 @@ def chat_with_workflow( ) if visualize_output: - visualized_output = all_tool_results[-1]["visualized_images"] + visualized_output = all_tool_results[-1]["visualized_output"] for image in visualized_output: - Image.open(image).show() + Image.open(image).show() # type: ignore return final_answer, all_tool_results @@ -492,7 +492,7 @@ def chat( image: Optional[Union[str, Path]] = None, visualize_output: Optional[bool] = False, ) -> str: - answer, _ = self.chat_with_workflow(chat, image=image) + answer, _ = self.chat_with_workflow(chat, image=image, visualize_output=visualize_output) return answer def retrieval(