diff --git a/tests/integ/test_tools.py b/tests/integ/test_tools.py
index afa9dcb4..bca1f6ea 100644
--- a/tests/integ/test_tools.py
+++ b/tests/integ/test_tools.py
@@ -10,7 +10,7 @@
     detr_segmentation,
     dpt_hybrid_midas,
     florence2_image_caption,
-    florence2_object_detection,
+    florence2_phrase_grounding,
     florence2_ocr,
     florence2_roberta_vqa,
     florence2_sam2_image,
@@ -65,7 +65,7 @@ def test_owl():
 
 def test_object_detection():
     img = ski.data.coins()
-    result = florence2_object_detection(
+    result = florence2_phrase_grounding(
         image=img,
         prompt="coin",
     )
diff --git a/vision_agent/agent/vision_agent_coder.py b/vision_agent/agent/vision_agent_coder.py
index b10988c6..3b3b5f68 100644
--- a/vision_agent/agent/vision_agent_coder.py
+++ b/vision_agent/agent/vision_agent_coder.py
@@ -744,29 +744,14 @@ def chat_with_workflow(
         results = {"code": "", "test": "", "plan": []}
         plan = []
         success = False
-        self.log_progress(
-            {
-                "type": "log",
-                "log_content": "Creating plans",
-                "status": "started",
-            }
-        )
-        plans = write_plans(
-            int_chat,
-            T.get_tool_descriptions_by_names(
-                customized_tool_names, T.FUNCTION_TOOLS, T.UTIL_TOOLS  # type: ignore
-            ),
-            format_memory(working_memory),
-            self.planner,
+
+        plans = self._create_plans(
+            int_chat, customized_tool_names, working_memory, self.planner
         )
-        if self.verbosity >= 1:
-            for p in plans:
-                # tabulate will fail if the keys are not the same for all elements
-                p_fixed = [{"instructions": e} for e in plans[p]["instructions"]]
-                _LOGGER.info(
-                    f"\n{tabulate(tabular_data=p_fixed, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
-                )
+        if test_multi_plan:
+            self._log_plans(plans, self.verbosity)
+
 
         tool_infos = retrieve_tools(
             plans,
             self.tool_recommender,
@@ -860,6 +845,39 @@ def log_progress(self, data: Dict[str, Any]) -> None:
         if self.report_progress_callback is not None:
             self.report_progress_callback(data)
 
+    def _create_plans(
+        self,
+        int_chat: List[Message],
+        customized_tool_names: Optional[List[str]],
+        working_memory: List[Dict[str, str]],
+        planner: LMM,
+    ) -> Dict[str, Any]:
+        self.log_progress(
+            {
+                "type": "log",
+                "log_content": "Creating plans",
+                "status": "started",
+            }
+        )
+        plans = write_plans(
+            int_chat,
+            T.get_tool_descriptions_by_names(
+                customized_tool_names, T.FUNCTION_TOOLS, T.UTIL_TOOLS  # type: ignore
+            ),
+            format_memory(working_memory),
+            planner,
+        )
+        return plans
+
+    def _log_plans(self, plans: Dict[str, Any], verbosity: int) -> None:
+        if verbosity >= 1:
+            for p in plans:
+                # tabulate will fail if the keys are not the same for all elements
+                p_fixed = [{"instructions": e} for e in plans[p]["instructions"]]
+                _LOGGER.info(
+                    f"\n{tabulate(tabular_data=p_fixed, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
+                )
+
 
 class OllamaVisionAgentCoder(VisionAgentCoder):
     """VisionAgentCoder that uses Ollama models for planning, coding, testing.
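For reviewers, a minimal usage sketch of the renamed tool, mirroring the updated `test_object_detection` above. The `import skimage as ski` alias is assumed to match the test module's imports, and the detection values in the comment are illustrative (the docstring fixes scores at 1.0):

```python
import skimage as ski  # assumed alias, matching the test module's `ski`

from vision_agent.tools import florence2_phrase_grounding

# Same fixture the integration test uses: a grayscale image of coins.
img = ski.data.coins()

# The prompt is a phrase (object names or a caption); multiple object names
# can be separated with commas, per the updated docstring.
result = florence2_phrase_grounding(image=img, prompt="coin")

# Each detection is a dict with a normalized bbox, a label, and a score of 1.0,
# e.g. {'score': 1.0, 'label': 'coin', 'bbox': [0.1, 0.11, 0.35, 0.4]}.
for det in result:
    print(det["label"], det["bbox"], det["score"])
```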
diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py
index a90b7181..3372fcbb 100644
--- a/vision_agent/tools/__init__.py
+++ b/vision_agent/tools/__init__.py
@@ -21,7 +21,7 @@
     dpt_hybrid_midas,
     extract_frames,
     florence2_image_caption,
-    florence2_object_detection,
+    florence2_phrase_grounding,
     florence2_ocr,
     florence2_roberta_vqa,
     florence2_sam2_image,
diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py
index 594fcf6d..4e1a0f40 100644
--- a/vision_agent/tools/tools.py
+++ b/vision_agent/tools/tools.py
@@ -760,10 +760,10 @@ def florence2_image_caption(image: np.ndarray, detail_caption: bool = True) -> str:
     return answer[task]  # type: ignore
 
 
-def florence2_object_detection(prompt: str, image: np.ndarray) -> List[Dict[str, Any]]:
-    """'florencev2_object_detection' is a tool that can detect and count multiple
-    objects given a text prompt such as category names or referring expressions. You
-    can optionally separate the categories in the text with commas. It returns a list
+def florence2_phrase_grounding(prompt: str, image: np.ndarray) -> List[Dict[str, Any]]:
+    """'florence2_phrase_grounding' is a tool that can detect multiple
+    objects given a text prompt, which can be object names or a caption. You
+    can optionally separate the object names in the text with commas. It returns a list
     of bounding boxes with normalized coordinates, label names and associated
     probability scores of 1.0.
 
@@ -780,7 +780,7 @@ def florence2_object_detection(prompt: str, image: np.ndarray) -> List[Dict[str, Any]]:
 
     Example
     -------
-    >>> florence2_object_detection('person looking at a coyote', image)
+    >>> florence2_phrase_grounding('person looking at a coyote', image)
     [
         {'score': 1.0, 'label': 'person', 'bbox': [0.1, 0.11, 0.35, 0.4]},
         {'score': 1.0, 'label': 'coyote', 'bbox': [0.34, 0.21, 0.85, 0.5]},
@@ -792,7 +792,7 @@ def florence2_object_detection(prompt: str, image: np.ndarray) -> List[Dict[str, Any]]:
         "image": image_b64,
         "task": "<CAPTION_TO_PHRASE_GROUNDING>",
         "prompt": prompt,
-        "function_name": "florence2_object_detection",
+        "function_name": "florence2_phrase_grounding",
     }
 
     detections = send_inference_request(data, "florence2", v2=True)
@@ -1418,6 +1418,7 @@ def overlay_segmentation_masks(
     medias: Union[np.ndarray, List[np.ndarray]],
     masks: Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]],
     draw_label: bool = True,
+    secondary_label_key: str = "tracking_label",
 ) -> Union[np.ndarray, List[np.ndarray]]:
     """'overlay_segmentation_masks' is a utility function that displays segmentation
     masks.
@@ -1426,7 +1427,10 @@ def overlay_segmentation_masks(
         medias (Union[np.ndarray, List[np.ndarray]]): The image or frames to display
             the masks on.
         masks (Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]): A list of
-            dictionaries containing the masks.
+            dictionaries containing the masks, labels and scores.
+        draw_label (bool, optional): If True, the labels will be displayed on the image.
+        secondary_label_key (str, optional): The key to use for the secondary
+            tracking label, which is needed in videos to display tracking information.
 
     Returns:
         np.ndarray: The image with the masks displayed.
@@ -1471,6 +1475,7 @@ def overlay_segmentation_masks(
         for elt in masks_int[i]:
             mask = elt["mask"]
             label = elt["label"]
+            tracking_lbl = elt.get(secondary_label_key, None)
             np_mask = np.zeros((pil_image.size[1], pil_image.size[0], 4))
             np_mask[mask > 0, :] = color[label] + (255 * 0.5,)
             mask_img = Image.fromarray(np_mask.astype(np.uint8))
@@ -1478,16 +1483,17 @@ def overlay_segmentation_masks(
 
             if draw_label:
                 draw = ImageDraw.Draw(pil_image)
-                text_box = draw.textbbox((0, 0), text=label, font=font)
+                text = tracking_lbl if tracking_lbl else label
+                text_box = draw.textbbox((0, 0), text=text, font=font)
                 x, y = _get_text_coords_from_mask(
                     mask,
                     v_gap=(text_box[3] - text_box[1]) + 10,
                     h_gap=(text_box[2] - text_box[0]) // 2,
                 )
                 if x != 0 and y != 0:
-                    text_box = draw.textbbox((x, y), text=label, font=font)
+                    text_box = draw.textbbox((x, y), text=text, font=font)
                     draw.rectangle((x, y, text_box[2], text_box[3]), fill=color[label])
-                    draw.text((x, y), label, fill="black", font=font)
+                    draw.text((x, y), text, fill="black", font=font)
 
         frame_out.append(np.array(pil_image))
     return frame_out[0] if len(frame_out) == 1 else frame_out
@@ -1663,7 +1669,7 @@ def florencev2_fine_tuned_object_detection(
     florence2_ocr,
     florence2_sam2_image,
     florence2_sam2_video,
-    florence2_object_detection,
+    florence2_phrase_grounding,
     ixc25_image_vqa,
     ixc25_video_vqa,
     detr_segmentation,
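A hedged sketch of the new `secondary_label_key` parameter in use. The frame shapes, mask contents, and the `person_1` tracking label are made up for illustration; in practice the per-frame mask dicts would come from a tracking tool such as `florence2_sam2_video`:

```python
import numpy as np

from vision_agent.tools import overlay_segmentation_masks

# Two dummy frames with one tracked object each; real frames would come from
# a video, and real masks from a tool like florence2_sam2_video.
frames = [np.zeros((240, 320, 3), dtype=np.uint8) for _ in range(2)]
mask = np.zeros((240, 320), dtype=np.uint8)
mask[60:120, 80:160] = 1  # a single rectangular object region

# Each mask dict carries the usual "mask"/"label"/"score" keys plus the
# secondary tracking label keyed by secondary_label_key.
masks_per_frame = [
    [{"mask": mask, "label": "person", "score": 1.0, "tracking_label": "person_1"}],
    [{"mask": mask, "label": "person", "score": 1.0, "tracking_label": "person_1"}],
]

# With the default secondary_label_key="tracking_label", the drawn text is the
# tracking label when present and falls back to "label" otherwise.
out_frames = overlay_segmentation_masks(frames, masks_per_frame, draw_label=True)
```

The fallback keeps the single-image call sites unchanged: mask dicts without the tracking key render exactly as before.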