Skip to content

Commit 10b3151

Browse files
Adding countgd as default counting tool
1 parent d545395 commit 10b3151

File tree

5 files changed

+120
-10
lines changed

5 files changed

+120
-10
lines changed

tests/integ/test_tools.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
ixc25_video_vqa,
2525
loca_visual_prompt_counting,
2626
loca_zero_shot_counting,
27+
countgd_counting,
28+
countgd_example_based_counting,
2729
ocr,
2830
owl_v2,
2931
template_match,
@@ -186,6 +188,22 @@ def test_loca_visual_prompt_counting() -> None:
186188
assert result["count"] == 25
187189

188190

191+
def test_countgd_counting() -> None:
192+
img = ski.data.coins()
193+
194+
result = countgd_counting(image=img, prompt="coin")
195+
assert result["count"] == 24
196+
197+
198+
def test_countgd_example_based_counting() -> None:
199+
img = ski.data.coins()
200+
result = countgd_example_based_counting(
201+
visual_prompt=[[85, 106, 122, 145]],
202+
image=img,
203+
)
204+
assert result["count"] == 24
205+
206+
189207
def test_git_vqa_v2() -> None:
190208
img = ski.data.rocket()
191209
result = git_vqa_v2(

vision_agent/agent/vision_agent_coder_prompts.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,20 +81,19 @@
8181
- Count the number of detected objects labeled as 'person'.
8282
plan3:
8383
- Load the image from the provided file path 'image.jpg'.
84-
- Use the 'loca_zero_shot_counting' tool to count the dominant foreground object, which in this case is people.
84+
- Use the 'countgd_counting' tool to count the dominant foreground object, which in this case is people.
8585
8686
```python
87-
from vision_agent.tools import load_image, owl_v2, grounding_sam, loca_zero_shot_counting
87+
from vision_agent.tools import load_image, owl_v2, grounding_sam, countgd_counting
8888
image = load_image("image.jpg")
8989
owl_v2_out = owl_v2("person", image)
9090
9191
gsam_out = grounding_sam("person", image)
9292
gsam_out = [{{k: v for k, v in o.items() if k != "mask"}} for o in gsam_out]
9393
94-
loca_out = loca_zero_shot_counting(image)
95-
loca_out = loca_out["count"]
94+
cgd_out = countgd_counting(image)
9695
97-
final_out = {{"owl_v2": owl_v2_out, "florencev2_object_detection": florencev2_out, "loca_zero_shot_counting": loca_out}}
96+
final_out = {{"owl_v2": owl_v2_out, "florencev2_object_detection": florencev2_out, "countgd_counting": cgd_out}}
9897
print(final_out)
9998
```
10099
"""

vision_agent/lmm/lmm.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -276,9 +276,6 @@ def generate_segmentor(self, question: str) -> Callable:
276276

277277
return lambda x: T.grounding_sam(params["prompt"], x)
278278

279-
def generate_zero_shot_counter(self, question: str) -> Callable:
280-
return T.loca_zero_shot_counting
281-
282279
def generate_image_qa_tool(self, question: str) -> Callable:
283280
return lambda x: T.git_vqa_v2(question, x)
284281

vision_agent/tools/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
load_image,
3838
loca_visual_prompt_counting,
3939
loca_zero_shot_counting,
40+
countgd_counting,
41+
countgd_example_based_counting,
4042
ocr,
4143
overlay_bounding_boxes,
4244
overlay_heat_map,

vision_agent/tools/tools.py

Lines changed: 96 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,8 @@ def loca_visual_prompt_counting(
467467
468468
Parameters:
469469
image (np.ndarray): The image that contains a lot of instances of a single object
470+
visual_prompt (Dict[str, List[float]]): Bounding box of the object in format
471+
[xmin, ymin, xmax, ymax]. Only 1 bounding box can be provided.
470472
471473
Returns:
472474
Dict[str, Any]: A dictionary containing the key 'count' and the count as a
@@ -499,6 +501,99 @@ def loca_visual_prompt_counting(
499501
return resp_data
500502

501503

504+
def countgd_counting(
505+
prompt: str,
506+
image: np.ndarray,
507+
box_threshold: float = 0.23,
508+
) -> List[Dict[str, Any]]:
509+
"""'countgd_counting' is a tool that can precisely count multiple instances of an
510+
object given a text prompt. It returns a list of bounding boxes with normalized
511+
coordinates, label names and associated confidence scores.
512+
513+
Parameters:
514+
prompt (str): The object that needs to be counted.
515+
image (np.ndarray): The image that contains multiple instances of the object.
516+
box_threshold (float, optional): The threshold for detection. Defaults
517+
to 0.23.
518+
519+
Returns:
520+
List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
521+
bounding box of the detected objects with normalized coordinates between 0
522+
and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the
523+
top-left and xmax and ymax are the coordinates of the bottom-right of the
524+
bounding box.
525+
526+
Example
527+
-------
528+
>>> countgd_counting("flower", image)
529+
[
530+
{'score': 0.49, 'label': 'flower', 'bbox': [0.1, 0.11, 0.35, 0.4]},
531+
{'score': 0.68, 'label': 'flower', 'bbox': [0.2, 0.21, 0.45, 0.5]},
532+
{'score': 0.78, 'label': 'flower', 'bbox': [0.3, 0.35, 0.48, 0.52]},
533+
{'score': 0.98, 'label': 'flower', 'bbox': [0.44, 0.24, 0.49, 0.58]},
534+
]
535+
"""
536+
buffer_bytes = numpy_to_bytes(image)
537+
files = [("image", buffer_bytes)]
538+
payload = {
539+
"text": prompt,
540+
"visual_prompts": [],
541+
"box_threshold": box_threshold,
542+
"function_name": "countgd_counting",
543+
}
544+
data: Dict[str, Any] = send_inference_request(
545+
payload, "countgd_counting", files=files, v2=True
546+
)
547+
return data
548+
549+
550+
def countgd_example_based_counting(
551+
visual_prompts: List[List[float]],
552+
image: np.ndarray,
553+
box_threshold: float = 0.23,
554+
) -> List[Dict[str, Any]]:
555+
"""'countgd_example_based_counting' is a tool that can precisely count multiple
556+
instances of an object given a few visual example prompts. It returns a list of bounding
557+
boxes with normalized coordinates, label names and associated confidence scores.
558+
559+
Parameters:
560+
visual_prompts (List[List[float]]): Bounding boxes of the object in format
561+
[xmin, ymin, xmax, ymax]. Up to 3 bounding boxes can be provided.
562+
image (np.ndarray): The image that contains multiple instances of the object.
563+
box_threshold (float, optional): The threshold for detection. Defaults
564+
to 0.23.
565+
566+
Returns:
567+
List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
568+
bounding box of the detected objects with normalized coordinates between 0
569+
and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the
570+
top-left and xmax and ymax are the coordinates of the bottom-right of the
571+
bounding box.
572+
573+
Example
574+
-------
575+
>>> countgd_example_based_counting(visual_prompts=[[0.1, 0.1, 0.4, 0.42], [0.2, 0.3, 0.25, 0.35]], image=image)
576+
[
577+
{'score': 0.49, 'label': 'object', 'bbox': [0.1, 0.11, 0.35, 0.4]},
578+
{'score': 0.68, 'label': 'object', 'bbox': [0.2, 0.21, 0.45, 0.5]},
579+
{'score': 0.78, 'label': 'object', 'bbox': [0.3, 0.35, 0.48, 0.52]},
580+
{'score': 0.98, 'label': 'object', 'bbox': [0.44, 0.24, 0.49, 0.58]},
581+
]
582+
"""
583+
buffer_bytes = numpy_to_bytes(image)
584+
files = [("image", buffer_bytes)]
585+
payload = {
586+
"text": "",
587+
"visual_prompts": visual_prompts,
588+
"box_threshold": box_threshold,
589+
"function_name": "countgd_example_based_counting",
590+
}
591+
data: Dict[str, Any] = send_inference_request(
592+
payload, "countgd_example_based_counting", files=files, v2=True
593+
)
594+
return data
595+
596+
502597
def florence2_roberta_vqa(prompt: str, image: np.ndarray) -> str:
503598
"""'florence2_roberta_vqa' is a tool that takes an image and analyzes
504599
its contents, generates detailed captions and then tries to answer the given
@@ -1657,8 +1752,7 @@ def florencev2_fine_tuned_object_detection(
16571752
clip,
16581753
vit_image_classification,
16591754
vit_nsfw_classification,
1660-
loca_zero_shot_counting,
1661-
loca_visual_prompt_counting,
1755+
countgd_counting,
16621756
florence2_image_caption,
16631757
florence2_ocr,
16641758
florence2_sam2_image,

0 commit comments

Comments
 (0)