Skip to content

Commit

Permalink
Adding countgd as default counting tool
Browse files Browse the repository at this point in the history
  • Loading branch information
shankar-vision-eng committed Aug 28, 2024
1 parent d545395 commit 10b3151
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 10 deletions.
18 changes: 18 additions & 0 deletions tests/integ/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
ixc25_video_vqa,
loca_visual_prompt_counting,
loca_zero_shot_counting,
countgd_counting,
countgd_example_based_counting,
ocr,
owl_v2,
template_match,
Expand Down Expand Up @@ -186,6 +188,22 @@ def test_loca_visual_prompt_counting() -> None:
assert result["count"] == 25


def test_countgd_counting() -> None:
    """Text-prompted CountGD should find every coin in the sample coins image."""
    img = ski.data.coins()

    result = countgd_counting(image=img, prompt="coin")
    # countgd_counting is annotated and documented as returning a list with one
    # dict per detection, so the object count is the length of that list;
    # indexing the list with the string "count" would raise a TypeError.
    assert len(result) == 24


def test_countgd_example_based_counting() -> None:
    """Example-box-prompted CountGD should find every coin in the sample image."""
    img = ski.data.coins()
    # The keyword must match the tool's parameter name, `visual_prompts` (plural);
    # `visual_prompt=` raises TypeError for an unexpected keyword argument.
    # NOTE(review): this box is given in pixel units while the tool's docstring
    # example uses values between 0 and 1 — confirm which the service expects.
    result = countgd_example_based_counting(
        visual_prompts=[[85, 106, 122, 145]],
        image=img,
    )
    # The tool is documented as returning one dict per detection, so the count
    # is the number of returned entries.
    assert len(result) == 24


def test_git_vqa_v2() -> None:
img = ski.data.rocket()
result = git_vqa_v2(
Expand Down
9 changes: 4 additions & 5 deletions vision_agent/agent/vision_agent_coder_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,20 +81,19 @@
- Count the number of detected objects labeled as 'person'.
plan3:
- Load the image from the provided file path 'image.jpg'.
- Use the 'loca_zero_shot_counting' tool to count the dominant foreground object, which in this case is people.
- Use the 'countgd_counting' tool to count the dominant foreground object, which in this case is people.
```python
from vision_agent.tools import load_image, owl_v2, grounding_sam, loca_zero_shot_counting
from vision_agent.tools import load_image, owl_v2, grounding_sam, countgd_counting
image = load_image("image.jpg")
owl_v2_out = owl_v2("person", image)
gsam_out = grounding_sam("person", image)
gsam_out = [{{k: v for k, v in o.items() if k != "mask"}} for o in gsam_out]
loca_out = loca_zero_shot_counting(image)
loca_out = loca_out["count"]
cgd_out = countgd_counting(image)
final_out = {{"owl_v2": owl_v2_out, "florencev2_object_detection": florencev2_out, "loca_zero_shot_counting": loca_out}}
final_out = {{"owl_v2": owl_v2_out, "florencev2_object_detection": florencev2_out, "countgd_counting": cgd_out}}
print(final_out)
```
"""
Expand Down
3 changes: 0 additions & 3 deletions vision_agent/lmm/lmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,9 +276,6 @@ def generate_segmentor(self, question: str) -> Callable:

return lambda x: T.grounding_sam(params["prompt"], x)

def generate_zero_shot_counter(self, question: str) -> Callable:
return T.loca_zero_shot_counting

def generate_image_qa_tool(self, question: str) -> Callable:
return lambda x: T.git_vqa_v2(question, x)

Expand Down
2 changes: 2 additions & 0 deletions vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
load_image,
loca_visual_prompt_counting,
loca_zero_shot_counting,
countgd_counting,
countgd_example_based_counting,
ocr,
overlay_bounding_boxes,
overlay_heat_map,
Expand Down
98 changes: 96 additions & 2 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,8 @@ def loca_visual_prompt_counting(
Parameters:
image (np.ndarray): The image that contains a lot of instances of a single object
visual_prompt (Dict[str, List[float]]): Bounding box of the object in format
[xmin, ymin, xmax, ymax]. Only 1 bounding box can be provided.
Returns:
Dict[str, Any]: A dictionary containing the key 'count' and the count as a
Expand Down Expand Up @@ -499,6 +501,99 @@ def loca_visual_prompt_counting(
return resp_data


def countgd_counting(
    prompt: str,
    image: np.ndarray,
    box_threshold: float = 0.23,
) -> List[Dict[str, Any]]:
    """'countgd_counting' is a tool that can precisely count multiple instances of an
    object given a text prompt. It returns a list of bounding boxes with normalized
    coordinates, label names and associated confidence scores.

    Parameters:
        prompt (str): The object that needs to be counted.
        image (np.ndarray): The image that contains multiple instances of the object.
        box_threshold (float, optional): The threshold for detection. Defaults
            to 0.23.

    Returns:
        List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
            bounding box of the detected objects with normalized coordinates between 0
            and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the
            top-left and xmax and ymax are the coordinates of the bottom-right of the
            bounding box.

    Example
    -------
        >>> countgd_counting("flower", image)
        [
            {'score': 0.49, 'label': 'flower', 'bbox': [0.1, 0.11, 0.35, 0.4]},
            {'score': 0.68, 'label': 'flower', 'bbox': [0.2, 0.21, 0.45, 0.5]},
            {'score': 0.78, 'label': 'flower', 'bbox': [0.3, 0.35, 0.48, 0.52]},
            {'score': 0.98, 'label': 'flower', 'bbox': [0.44, 0.24, 0.49, 0.58]},
        ]
    """
    # Text-only mode: the request carries the prompt and an empty list of
    # visual prompts.
    payload = {
        "text": prompt,
        "visual_prompts": [],
        "box_threshold": box_threshold,
        "function_name": "countgd_counting",
    }
    # The encoded image is handed to the request helper through `files`.
    files = [("image", numpy_to_bytes(image))]
    # Annotated to match the declared return type (one dict per detection).
    detections: List[Dict[str, Any]] = send_inference_request(
        payload, "countgd_counting", files=files, v2=True
    )
    return detections


def countgd_example_based_counting(
    visual_prompts: List[List[float]],
    image: np.ndarray,
    box_threshold: float = 0.23,
) -> List[Dict[str, Any]]:
    """'countgd_example_based_counting' is a tool that can precisely count multiple
    instances of an object given few visual example prompts. It returns a list of bounding
    boxes with normalized coordinates, label names and associated confidence scores.

    Parameters:
        visual_prompts (List[List[float]]): Bounding boxes of the object in format
            [xmin, ymin, xmax, ymax]. Up to 3 bounding boxes can be provided.
        image (np.ndarray): The image that contains multiple instances of the object.
        box_threshold (float, optional): The threshold for detection. Defaults
            to 0.23.

    Returns:
        List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
            bounding box of the detected objects with normalized coordinates between 0
            and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the
            top-left and xmax and ymax are the coordinates of the bottom-right of the
            bounding box.

    Example
    -------
        >>> countgd_example_based_counting(visual_prompts=[[0.1, 0.1, 0.4, 0.42], [0.2, 0.3, 0.25, 0.35]], image=image)
        [
            {'score': 0.49, 'label': 'object', 'bbox': [0.1, 0.11, 0.35, 0.4]},
            {'score': 0.68, 'label': 'object', 'bbox': [0.2, 0.21, 0.45, 0.5]},
            {'score': 0.78, 'label': 'object', 'bbox': [0.3, 0.35, 0.48, 0.52]},
            {'score': 0.98, 'label': 'object', 'bbox': [0.44, 0.24, 0.49, 0.58]},
        ]
    """
    # Example-box mode: the text prompt is left empty and the example boxes
    # are forwarded unchanged.
    payload = {
        "text": "",
        "visual_prompts": visual_prompts,
        "box_threshold": box_threshold,
        "function_name": "countgd_example_based_counting",
    }
    # The encoded image is handed to the request helper through `files`.
    files = [("image", numpy_to_bytes(image))]
    # Annotated to match the declared return type (one dict per detection).
    detections: List[Dict[str, Any]] = send_inference_request(
        payload, "countgd_example_based_counting", files=files, v2=True
    )
    return detections


def florence2_roberta_vqa(prompt: str, image: np.ndarray) -> str:
"""'florence2_roberta_vqa' is a tool that takes an image and analyzes
its contents, generates detailed captions and then tries to answer the given
Expand Down Expand Up @@ -1657,8 +1752,7 @@ def florencev2_fine_tuned_object_detection(
clip,
vit_image_classification,
vit_nsfw_classification,
loca_zero_shot_counting,
loca_visual_prompt_counting,
countgd_counting,
florence2_image_caption,
florence2_ocr,
florence2_sam2_image,
Expand Down

0 comments on commit 10b3151

Please sign in to comment.