diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index b02cdf72..240b47a1 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -8,7 +8,7 @@ from PIL import Image from tabulate import tabulate -from vision_agent.image_utils import overlay_bboxes, overlay_masks, overlay_heat_map +from vision_agent.image_utils import overlay_bboxes, overlay_heat_map, overlay_masks from vision_agent.llm import LLM, OpenAILLM from vision_agent.lmm import LMM, OpenAILMM from vision_agent.tools import TOOLS @@ -492,20 +492,20 @@ def chat_with_workflow( if image: question += f" Image name: {image}" if reference_data: - if not ( - "image" in reference_data - and ("mask" in reference_data or "bbox" in reference_data) - ): - raise ValueError( - f"Reference data must contain 'image' and a visual prompt which can be 'mask' or 'bbox'. but got {reference_data}" - ) - visual_prompt_data = ( - f"Reference mask: {reference_data['mask']}" + question += ( + f" Reference image: {reference_data['image']}" + if "image" in reference_data + else "" + ) + question += ( + f" Reference mask: {reference_data['mask']}" if "mask" in reference_data - else f"Reference bbox: {reference_data['bbox']}" + else "" ) question += ( - f" Reference image: {reference_data['image']}, {visual_prompt_data}" + f" Reference bbox: {reference_data['bbox']}" + if "bbox" in reference_data + else "" ) reflections = ""