diff --git a/tests/integ/test_tools.py b/tests/integ/test_tools.py index c1a779fd..68f5d396 100644 --- a/tests/integ/test_tools.py +++ b/tests/integ/test_tools.py @@ -61,9 +61,10 @@ def test_object_detection(): img = ski.data.coins() result = florencev2_object_detection( image=img, + prompt="coin", ) - assert len(result) == 24 - assert [res["label"] for res in result] == ["coin"] * 24 + assert len(result) == 25 + assert [res["label"] for res in result] == ["coin"] * 25 def test_template_match(): diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 4e0fdf27..06a7b72d 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -606,10 +606,10 @@ def florencev2_image_caption(image: np.ndarray, detail_caption: bool = True) -> return answer[task] # type: ignore -def florencev2_object_detection(image: np.ndarray) -> List[Dict[str, Any]]: - """'florencev2_object_detection' is a tool that can detect common objects in an - image without any text prompt or thresholding. It returns a list of detected objects - as labels and their location as bounding boxes. +def florencev2_object_detection(image: np.ndarray, prompt: str) -> List[Dict[str, Any]]: + """'florencev2_object_detection' is a tool that can detect objects given a text + prompt such as a phrase or class names separated by commas. It returns a list of + detected objects as labels and their location as bounding boxes with score of 1.0. 
Parameters: image (np.ndarray): The image to used to detect objects @@ -623,23 +623,23 @@ def florencev2_object_detection(image: np.ndarray) -> List[Dict[str, Any]]: Example ------- - >>> florencev2_object_detection(image) + >>> florencev2_object_detection(image, 'person looking at a coyote') [ - {'score': 1.0, 'label': 'window', 'bbox': [0.1, 0.11, 0.35, 0.4]}, - {'score': 1.0, 'label': 'car', 'bbox': [0.2, 0.21, 0.45, 0.5}, - {'score': 1.0, 'label': 'person', 'bbox': [0.34, 0.21, 0.85, 0.5}, + {'score': 1.0, 'label': 'person', 'bbox': [0.1, 0.11, 0.35, 0.4]}, + {'score': 1.0, 'label': 'coyote', 'bbox': [0.34, 0.21, 0.85, 0.5]}, ] """ image_size = image.shape[:2] image_b64 = convert_to_b64(image) data = { "image": image_b64, - "task": "<OD>", + "task": "<CAPTION_TO_PHRASE_GROUNDING>", + "prompt": prompt, "function_name": "florencev2_object_detection", } detections = send_inference_request(data, "florence2", v2=True) - detections = detections["<OD>"] + detections = detections["<CAPTION_TO_PHRASE_GROUNDING>"] return_data = [] for i in range(len(detections["bboxes"])): return_data.append( @@ -1249,7 +1249,6 @@ def overlay_heat_map( loca_visual_prompt_counting, florencev2_roberta_vqa, florencev2_image_caption, - florencev2_object_detection, detr_segmentation, depth_anything_v2, generate_soft_edge_image,