Add OwlVIT, GDINO tiny, NSFW and Generic Image classifier (#137)
* Added some of the new tools and updated test cases; renamed tools after their underlying model names

* Fixed tool names in lmm.py

* Minor fixes in owl_v2
shankar-vision-eng authored Jun 14, 2024
1 parent: 7a74b28 · commit: 482ccf4
Showing 4 changed files with 189 additions and 35 deletions.
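Taken together, the diffs below rename several tools after their underlying models (BLIP, GIT, LOCA, ViT) and add OwlViT (owl_v2), a tiny GroundingDINO variant, and two ViT classifiers. A minimal usage sketch mirroring the test cases in this commit (it assumes scikit-image is installed to supply the sample image):

import skimage as ski

from vision_agent.tools import (
    grounding_dino,
    loca_zero_shot_counting,
    owl_v2,
    vit_nsfw_classification,
)

img = ski.data.coins()

# grounding_dino now accepts a model_size argument: "large" (default) or "tiny".
detections = grounding_dino(prompt="coin", image=img, model_size="tiny")

# owl_v2 is the new OwlViT open-vocabulary detector; thresholds are tunable.
boxes = owl_v2(prompt="coin", image=img, box_threshold=0.15)

# loca_zero_shot_counting returns {"count": <int>}.
count = loca_zero_shot_counting(image=img)["count"]

# vit_nsfw_classification returns a single label and score.
label = vit_nsfw_classification(image=img)["labels"]  # e.g. "normal"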
tests/test_tools.py (55 changes: 44 additions & 11 deletions)
@@ -6,11 +6,14 @@
     closest_mask_distance,
     grounding_dino,
     grounding_sam,
-    image_caption,
-    image_question_answering,
+    blip_image_caption,
+    git_vqa_v2,
     ocr,
-    visual_prompt_counting,
-    zero_shot_counting,
+    loca_visual_prompt_counting,
+    loca_zero_shot_counting,
+    vit_nsfw_classification,
+    vit_image_classification,
+    owl_v2,
 )


@@ -24,6 +27,20 @@ def test_grounding_dino():
     assert [res["label"] for res in result] == ["coin"] * 24


+def test_grounding_dino_tiny():
+    img = ski.data.coins()
+    result = grounding_dino(prompt="coin", image=img, model_size="tiny")
+    assert len(result) == 24
+    assert [res["label"] for res in result] == ["coin"] * 24
+
+
+def test_owl():
+    img = ski.data.coins()
+    result = owl_v2(prompt="coin", image=img, box_threshold=0.15)
+    assert len(result) == 25
+    assert [res["label"] for res in result] == ["coin"] * 25
+
+
 def test_grounding_sam():
     img = ski.data.coins()
     result = grounding_sam(
@@ -44,34 +61,50 @@ def test_clip():
     assert result["scores"] == [0.9999, 0.0001]


+def test_vit_classification():
+    img = ski.data.coins()
+    result = vit_image_classification(
+        image=img,
+    )
+    assert "typewriter keyboard" in result["labels"]
+
+
+def test_nsfw_classification():
+    img = ski.data.coins()
+    result = vit_nsfw_classification(
+        image=img,
+    )
+    assert result["labels"] == "normal"
+
+
 def test_image_caption() -> None:
     img = ski.data.rocket()
-    result = image_caption(
+    result = blip_image_caption(
         image=img,
     )
     assert result.strip() == "a rocket on a stand"


-def test_zero_shot_counting() -> None:
+def test_loca_zero_shot_counting() -> None:
     img = ski.data.coins()
-    result = zero_shot_counting(
+    result = loca_zero_shot_counting(
         image=img,
     )
     assert result["count"] == 21


-def test_visual_prompt_counting() -> None:
+def test_loca_visual_prompt_counting() -> None:
     img = ski.data.coins()
-    result = visual_prompt_counting(
+    result = loca_visual_prompt_counting(
         visual_prompt={"bbox": [85, 106, 122, 145]},
         image=img,
     )
     assert result["count"] == 25


-def test_image_question_answering() -> None:
+def test_git_vqa_v2() -> None:
     img = ski.data.rocket()
-    result = image_question_answering(
+    result = git_vqa_v2(
         prompt="Is the scene captured during day or night ?",
         image=img,
     )
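To run just the tests touched by this commit, one option is a sketch like the following (assumes pytest is installed; the -k expression simply substring-matches the new and renamed test names):

import pytest

# Selects test_grounding_dino_tiny, test_owl, test_vit_classification,
# test_nsfw_classification, the test_loca_* tests, and test_git_vqa_v2.
pytest.main(["tests/test_tools.py", "-k", "tiny or owl or vit or nsfw or loca or git_vqa"])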
vision_agent/lmm/lmm.py (4 changes: 2 additions & 2 deletions)
@@ -224,10 +224,10 @@ def generate_segmentor(self, question: str) -> Callable:
         return lambda x: T.grounding_sam(params["prompt"], x)

     def generate_zero_shot_counter(self, question: str) -> Callable:
-        return T.zero_shot_counting
+        return T.loca_zero_shot_counting

     def generate_image_qa_tool(self, question: str) -> Callable:
-        return lambda x: T.image_question_answering(question, x)
+        return lambda x: T.git_vqa_v2(question, x)


 class AzureOpenAILMM(OpenAILMM):
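The two generator methods above now hand out the renamed tools as plain callables. A hypothetical usage sketch (assumes a configured OpenAILMM instance and a loaded image, neither of which is shown in this diff):

# `lmm` is assumed to be a configured OpenAILMM; `img` a numpy image array.
counter = lmm.generate_zero_shot_counter("How many coins are in the image?")
count = counter(img)  # dispatches to T.loca_zero_shot_counting

qa = lmm.generate_image_qa_tool("Is the scene captured during day or night ?")
answer = qa(img)  # dispatches to T.git_vqa_v2(question, img)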
vision_agent/tools/__init__.py (11 changes: 7 additions & 4 deletions)
@@ -7,25 +7,28 @@
     TOOLS,
     TOOLS_DF,
     UTILITIES_DOCSTRING,
+    blip_image_caption,
     clip,
     closest_box_distance,
     closest_mask_distance,
     extract_frames,
     get_tool_documentation,
+    git_vqa_v2,
     grounding_dino,
     grounding_sam,
-    image_caption,
-    image_question_answering,
     load_image,
     ocr,
     overlay_bounding_boxes,
     overlay_heat_map,
     overlay_segmentation_masks,
+    owl_v2,
     save_image,
     save_json,
     save_video,
-    visual_prompt_counting,
-    zero_shot_counting,
+    loca_visual_prompt_counting,
+    loca_zero_shot_counting,
+    vit_image_classification,
+    vit_nsfw_classification,
 )

 __new_tools__ = [
vision_agent/tools/tools.py (154 changes: 136 additions & 18 deletions)
@@ -59,6 +59,7 @@ def grounding_dino(
     image: np.ndarray,
     box_threshold: float = 0.20,
     iou_threshold: float = 0.20,
+    model_size: str = "large",
 ) -> List[Dict[str, Any]]:
     """'grounding_dino' is a tool that can detect and count multiple objects given a text
     prompt such as category names or referring expressions. The categories in text prompt
@@ -72,6 +73,7 @@
             to 0.20.
         iou_threshold (float, optional): The threshold for the Intersection over Union
             (IoU). Defaults to 0.20.
+        model_size (str, optional): The size of the model to use, either "large" or "tiny". Defaults to "large".

     Returns:
         List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
@@ -90,10 +92,14 @@
     """
     image_size = image.shape[:2]
     image_b64 = convert_to_b64(image)
+    if model_size not in ["large", "tiny"]:
+        raise ValueError("model_size must be either 'large' or 'tiny'")
     request_data = {
         "prompt": prompt,
         "image": image_b64,
-        "tool": "visual_grounding",
+        "tool": (
+            "visual_grounding" if model_size == "large" else "visual_grounding_tiny"
+        ),
         "kwargs": {"box_threshold": box_threshold, "iou_threshold": iou_threshold},
     }
     data: Dict[str, Any] = _send_inference_request(request_data, "tools")
@@ -109,6 +115,62 @@
     return return_data


+def owl_v2(
+    prompt: str,
+    image: np.ndarray,
+    box_threshold: float = 0.10,
+    iou_threshold: float = 0.10,
+) -> List[Dict[str, Any]]:
+    """'owl_v2' is a tool that can detect and count multiple objects given a text
+    prompt such as category names or referring expressions. The categories in text prompt
+    are separated by commas or periods. It returns a list of bounding boxes with
+    normalized coordinates, label names and associated probability scores.
+
+    Parameters:
+        prompt (str): The prompt to ground to the image.
+        image (np.ndarray): The image to ground the prompt to.
+        box_threshold (float, optional): The threshold for the box detection. Defaults
+            to 0.10.
+        iou_threshold (float, optional): The threshold for the Intersection over Union
+            (IoU). Defaults to 0.10.
+
+    Returns:
+        List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
+        bounding box of the detected objects with normalized coordinates between 0
+        and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the
+        top-left and xmax and ymax are the coordinates of the bottom-right of the
+        bounding box.
+
+    Example
+    -------
+    >>> owl_v2("car. dinosaur", image)
+    [
+        {'score': 0.99, 'label': 'dinosaur', 'bbox': [0.1, 0.11, 0.35, 0.4]},
+        {'score': 0.98, 'label': 'car', 'bbox': [0.2, 0.21, 0.45, 0.5]},
+    ]
+    """
+    image_size = image.shape[:2]
+    image_b64 = convert_to_b64(image)
+    request_data = {
+        "prompt": prompt,
+        "image": image_b64,
+        "tool": "open_vocab_detection",
+        "kwargs": {"box_threshold": box_threshold, "iou_threshold": iou_threshold},
+    }
+    data: Dict[str, Any] = _send_inference_request(request_data, "tools")
+    return_data = []
+    for i in range(len(data["bboxes"])):
+        return_data.append(
+            {
+                "score": round(data["scores"][i], 2),
+                "label": data["labels"][i].strip(),
+                "bbox": normalize_bbox(data["bboxes"][i], image_size),
+            }
+        )
+    return return_data
+
+
 def grounding_sam(
     prompt: str,
     image: np.ndarray,
@@ -253,8 +315,8 @@ def ocr(image: np.ndarray) -> List[Dict[str, Any]]:
     return ocr_results


-def zero_shot_counting(image: np.ndarray) -> Dict[str, Any]:
-    """'zero_shot_counting' is a tool that counts the dominant foreground object given
+def loca_zero_shot_counting(image: np.ndarray) -> Dict[str, Any]:
+    """'loca_zero_shot_counting' is a tool that counts the dominant foreground object given
     an image and no other information about the content. It returns only the count of
     the objects in the image.
@@ -267,7 +329,7 @@

     Example
     -------
-    >>> zero_shot_counting(image)
+    >>> loca_zero_shot_counting(image)
     {'count': 45},
     """

@@ -281,10 +343,10 @@
     return resp_data


-def visual_prompt_counting(
+def loca_visual_prompt_counting(
     image: np.ndarray, visual_prompt: Dict[str, List[float]]
 ) -> Dict[str, Any]:
-    """'visual_prompt_counting' is a tool that counts the dominant foreground object
+    """'loca_visual_prompt_counting' is a tool that counts the dominant foreground object
     given an image and a visual prompt which is a bounding box describing the object.
     It returns only the count of the objects in the image.
@@ -297,7 +359,7 @@

     Example
     -------
-    >>> visual_prompt_counting(image, {"bbox": [0.1, 0.1, 0.4, 0.42]})
+    >>> loca_visual_prompt_counting(image, {"bbox": [0.1, 0.1, 0.4, 0.42]})
     {'count': 45},
     """

@@ -316,8 +378,8 @@
     return resp_data


-def image_question_answering(prompt: str, image: np.ndarray) -> str:
-    """'image_question_answering_' is a tool that can answer questions about the visual
+def git_vqa_v2(prompt: str, image: np.ndarray) -> str:
+    """'git_vqa_v2' is a tool that can answer questions about the visual
     contents of an image given a question and an image. It returns an answer to the
     question
@@ -331,7 +393,7 @@

     Example
     -------
-    >>> image_question_answering('What is the cat doing ?', image)
+    >>> git_vqa_v2('What is the cat doing ?', image)
     'drinking milk'
     """

@@ -376,8 +438,62 @@ def clip(image: np.ndarray, classes: List[str]) -> Dict[str, Any]:
     return resp_data


-def image_caption(image: np.ndarray) -> str:
-    """'image_caption' is a tool that can caption an image based on its contents. It
+def vit_image_classification(image: np.ndarray) -> Dict[str, Any]:
+    """'vit_image_classification' is a tool that can classify an image. It returns a
+    list of classes and their probability scores based on image content.
+
+    Parameters:
+        image (np.ndarray): The image to classify or tag
+
+    Returns:
+        Dict[str, Any]: A dictionary containing the labels and scores. One key
+        contains a list of labels and the other a list of scores.
+
+    Example
+    -------
+    >>> vit_image_classification(image)
+    {"labels": ["leopard", "lemur, otter", "bird"], "scores": [0.68, 0.30, 0.02]},
+    """
+
+    image_b64 = convert_to_b64(image)
+    data = {
+        "image": image_b64,
+        "tool": "image_classification",
+    }
+    resp_data = _send_inference_request(data, "tools")
+    resp_data["scores"] = [round(prob, 4) for prob in resp_data["scores"]]
+    return resp_data
+
+
+def vit_nsfw_classification(image: np.ndarray) -> Dict[str, Any]:
+    """'vit_nsfw_classification' is a tool that can classify an image as 'nsfw' or 'normal'.
+    It returns the predicted label and its probability score based on image content.
+
+    Parameters:
+        image (np.ndarray): The image to classify or tag
+
+    Returns:
+        Dict[str, Any]: A dictionary containing a single label and a single score.
+
+    Example
+    -------
+    >>> vit_nsfw_classification(image)
+    {"labels": "normal", "scores": 0.68},
+    """
+
+    image_b64 = convert_to_b64(image)
+    data = {
+        "image": image_b64,
+        "tool": "nsfw_image_classification",
+    }
+    resp_data = _send_inference_request(data, "tools")
+    resp_data["scores"] = round(resp_data["scores"], 4)
+    return resp_data
+
+
+def blip_image_caption(image: np.ndarray) -> str:
+    """'blip_image_caption' is a tool that can caption an image based on its contents. It
     returns a text describing the image.

     Parameters:
@@ -388,7 +504,7 @@ def image_caption(image: np.ndarray) -> str:

     Example
     -------
-    >>> image_caption(image)
+    >>> blip_image_caption(image)
     'This image contains a cat sitting on a table with a bowl of milk.'
     """

@@ -792,15 +908,17 @@ def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame:


 TOOLS = [
     grounding_dino,
+    owl_v2,
     grounding_sam,
     extract_frames,
     ocr,
     clip,
-    zero_shot_counting,
-    visual_prompt_counting,
-    image_question_answering,
-    image_caption,
+    vit_image_classification,
+    vit_nsfw_classification,
+    loca_zero_shot_counting,
+    loca_visual_prompt_counting,
+    git_vqa_v2,
+    blip_image_caption,
     closest_mask_distance,
     closest_box_distance,
     save_json,
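All of the wrappers added in this file follow one request pattern, visible in the hunks above: base64-encode the image, name a backend tool, post a single inference request, and round the returned scores. A sketch of that shared shape (convert_to_b64 and _send_inference_request are the module's own helpers; my_classifier and its backend tool name are hypothetical, for illustration only):

from typing import Any, Dict

import numpy as np


def my_classifier(image: np.ndarray) -> Dict[str, Any]:
    # Encode the image the same way the tools above do.
    image_b64 = convert_to_b64(image)
    data = {
        "image": image_b64,
        "tool": "my_backend_classifier",  # hypothetical backend tool name
    }
    # Same private helper and "tools" endpoint used by every wrapper here.
    resp_data = _send_inference_request(data, "tools")
    resp_data["scores"] = [round(prob, 4) for prob in resp_data["scores"]]
    return resp_data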
