From 482ccf4903dbb4a24d089f517d74a8408733ba15 Mon Sep 17 00:00:00 2001
From: Shankar <90070882+shankar-vision-eng@users.noreply.github.com>
Date: Fri, 14 Jun 2024 10:13:31 -0700
Subject: [PATCH] Add OwlVIT, GDINO tiny, NSFW and Generic Image classifier
 (#137)

* added some of the new tools and updated test cases. Renamed tool names to model names

* fix tool names in lmm.py

* minor fixes in owl_v2
---
 tests/test_tools.py            |  55 +++++++++---
 vision_agent/lmm/lmm.py        |   4 +-
 vision_agent/tools/__init__.py |  11 ++-
 vision_agent/tools/tools.py    | 154 +++++++++++++++++++++++++++++----
 4 files changed, 189 insertions(+), 35 deletions(-)

diff --git a/tests/test_tools.py b/tests/test_tools.py
index d893e7aa..ed7670cb 100644
--- a/tests/test_tools.py
+++ b/tests/test_tools.py
@@ -6,11 +6,14 @@
     closest_mask_distance,
     grounding_dino,
     grounding_sam,
-    image_caption,
-    image_question_answering,
+    blip_image_caption,
+    git_vqa_v2,
     ocr,
-    visual_prompt_counting,
-    zero_shot_counting,
+    loca_visual_prompt_counting,
+    loca_zero_shot_counting,
+    vit_nsfw_classification,
+    vit_image_classification,
+    owl_v2,
 )


@@ -24,6 +27,20 @@ def test_grounding_dino():
     assert [res["label"] for res in result] == ["coin"] * 24


+def test_grounding_dino_tiny():
+    img = ski.data.coins()
+    result = grounding_dino(prompt="coin", image=img, model_size="tiny")
+    assert len(result) == 24
+    assert [res["label"] for res in result] == ["coin"] * 24
+
+
+def test_owl():
+    img = ski.data.coins()
+    result = owl_v2(prompt="coin", image=img, box_threshold=0.15)
+    assert len(result) == 25
+    assert [res["label"] for res in result] == ["coin"] * 25
+
+
 def test_grounding_sam():
     img = ski.data.coins()
     result = grounding_sam(
@@ -44,34 +61,50 @@ def test_clip():
     assert result["scores"] == [0.9999, 0.0001]


+def test_vit_classification():
+    img = ski.data.coins()
+    result = vit_image_classification(
+        image=img,
+    )
+    assert "typewriter keyboard" in result["labels"]
+
+
+def test_nsfw_classification():
+    img = ski.data.coins()
+    result = vit_nsfw_classification(
+        image=img,
+    )
+    assert result["labels"] == "normal"
+
+
 def test_image_caption() -> None:
     img = ski.data.rocket()
-    result = image_caption(
+    result = blip_image_caption(
         image=img,
     )
     assert result.strip() == "a rocket on a stand"


-def test_zero_shot_counting() -> None:
+def test_loca_zero_shot_counting() -> None:
     img = ski.data.coins()
-    result = zero_shot_counting(
+    result = loca_zero_shot_counting(
         image=img,
     )
     assert result["count"] == 21


-def test_visual_prompt_counting() -> None:
+def test_loca_visual_prompt_counting() -> None:
     img = ski.data.coins()
-    result = visual_prompt_counting(
+    result = loca_visual_prompt_counting(
         visual_prompt={"bbox": [85, 106, 122, 145]},
         image=img,
     )
     assert result["count"] == 25


-def test_image_question_answering() -> None:
+def test_git_vqa_v2() -> None:
     img = ski.data.rocket()
-    result = image_question_answering(
+    result = git_vqa_v2(
         prompt="Is the scene captured during day or night ?",
         image=img,
     )
diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py
index ac047ec4..693851fa 100644
--- a/vision_agent/lmm/lmm.py
+++ b/vision_agent/lmm/lmm.py
@@ -224,10 +224,10 @@ def generate_segmentor(self, question: str) -> Callable:
         return lambda x: T.grounding_sam(params["prompt"], x)

     def generate_zero_shot_counter(self, question: str) -> Callable:
-        return T.zero_shot_counting
+        return T.loca_zero_shot_counting

     def generate_image_qa_tool(self, question: str) -> Callable:
-        return lambda x: T.image_question_answering(question, x)
+        return lambda x: T.git_vqa_v2(question, x)


 class AzureOpenAILMM(OpenAILMM):
diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py
index 87223043..fac8b87f 100644
--- a/vision_agent/tools/__init__.py
+++ b/vision_agent/tools/__init__.py
@@ -7,25 +7,28 @@
     TOOLS,
     TOOLS_DF,
     UTILITIES_DOCSTRING,
+    blip_image_caption,
     clip,
     closest_box_distance,
     closest_mask_distance,
     extract_frames,
     get_tool_documentation,
+    git_vqa_v2,
     grounding_dino,
     grounding_sam,
-    image_caption,
-    image_question_answering,
     load_image,
     ocr,
     overlay_bounding_boxes,
     overlay_heat_map,
     overlay_segmentation_masks,
+    owl_v2,
     save_image,
     save_json,
     save_video,
-    visual_prompt_counting,
-    zero_shot_counting,
+    loca_visual_prompt_counting,
+    loca_zero_shot_counting,
+    vit_image_classification,
+    vit_nsfw_classification,
 )

 __new_tools__ = [
diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py
index 8bffe5c6..92e888f7 100644
--- a/vision_agent/tools/tools.py
+++ b/vision_agent/tools/tools.py
@@ -59,6 +59,7 @@ def grounding_dino(
     image: np.ndarray,
     box_threshold: float = 0.20,
     iou_threshold: float = 0.20,
+    model_size: str = "large",
 ) -> List[Dict[str, Any]]:
     """'grounding_dino' is a tool that can detect and count multiple objects given a text
     prompt such as category names or referring expressions. The categories in text prompt
@@ -72,6 +73,7 @@ def grounding_dino(
             to 0.20.
         iou_threshold (float, optional): The threshold for the Intersection over Union
             (IoU). Defaults to 0.20.
+        model_size (str, optional): The size of the model to use ("large" or "tiny").

     Returns:
         List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
@@ -90,10 +92,14 @@
     """
     image_size = image.shape[:2]
     image_b64 = convert_to_b64(image)
+    if model_size not in ["large", "tiny"]:
+        raise ValueError("model_size must be either 'large' or 'tiny'")
     request_data = {
         "prompt": prompt,
         "image": image_b64,
-        "tool": "visual_grounding",
+        "tool": (
+            "visual_grounding" if model_size == "large" else "visual_grounding_tiny"
+        ),
         "kwargs": {"box_threshold": box_threshold, "iou_threshold": iou_threshold},
     }
     data: Dict[str, Any] = _send_inference_request(request_data, "tools")
@@ -109,6 +115,61 @@
     return return_data


+def owl_v2(
+    prompt: str,
+    image: np.ndarray,
+    box_threshold: float = 0.10,
+    iou_threshold: float = 0.10,
+) -> List[Dict[str, Any]]:
+    """'owl_v2' is a tool that can detect and count multiple objects given a text
+    prompt such as category names or referring expressions. The categories in text prompt
+    are separated by commas or periods. It returns a list of bounding boxes with
+    normalized coordinates, label names and associated probability scores.
+
+    Parameters:
+        prompt (str): The prompt to ground to the image.
+        image (np.ndarray): The image to ground the prompt to.
+        box_threshold (float, optional): The threshold for the box detection. Defaults
+            to 0.10.
+        iou_threshold (float, optional): The threshold for the Intersection over Union
+            (IoU). Defaults to 0.10.
+
+    Returns:
+        List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
+        bounding box of the detected objects with normalized coordinates between 0
+        and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the
+        top-left and xmax and ymax are the coordinates of the bottom-right of the
+        bounding box.
+
+    Example
+    -------
+    >>> owl_v2("car. dinosaur", image)
+    [
+        {'score': 0.99, 'label': 'dinosaur', 'bbox': [0.1, 0.11, 0.35, 0.4]},
+        {'score': 0.98, 'label': 'car', 'bbox': [0.2, 0.21, 0.45, 0.5]},
+    ]
+    """
+    image_size = image.shape[:2]
+    image_b64 = convert_to_b64(image)
+    request_data = {
+        "prompt": prompt,
+        "image": image_b64,
+        "tool": "open_vocab_detection",
+        "kwargs": {"box_threshold": box_threshold, "iou_threshold": iou_threshold},
+    }
+    data: Dict[str, Any] = _send_inference_request(request_data, "tools")
+    return_data = []
+    for i in range(len(data["bboxes"])):
+        return_data.append(
+            {
+                "score": round(data["scores"][i], 2),
+                "label": data["labels"][i].strip(),
+                "bbox": normalize_bbox(data["bboxes"][i], image_size),
+            }
+        )
+    return return_data
+
+
 def grounding_sam(
     prompt: str,
     image: np.ndarray,
@@ -253,8 +315,8 @@ def ocr(image: np.ndarray) -> List[Dict[str, Any]]:
     return ocr_results


-def zero_shot_counting(image: np.ndarray) -> Dict[str, Any]:
-    """'zero_shot_counting' is a tool that counts the dominant foreground object given
+def loca_zero_shot_counting(image: np.ndarray) -> Dict[str, Any]:
+    """'loca_zero_shot_counting' is a tool that counts the dominant foreground object given
     an image and no other information about the content. It returns only the count of
     the objects in the image.

@@ -267,7 +329,7 @@

     Example
     -------
-    >>> zero_shot_counting(image)
+    >>> loca_zero_shot_counting(image)
     {'count': 45},

     """
@@ -281,10 +343,10 @@
     return resp_data


-def visual_prompt_counting(
+def loca_visual_prompt_counting(
     image: np.ndarray, visual_prompt: Dict[str, List[float]]
 ) -> Dict[str, Any]:
-    """'visual_prompt_counting' is a tool that counts the dominant foreground object
+    """'loca_visual_prompt_counting' is a tool that counts the dominant foreground object
     given an image and a visual prompt which is a bounding box describing the object.
     It returns only the count of the objects in the image.

@@ -297,7 +359,7 @@

     Example
     -------
-    >>> visual_prompt_counting(image, {"bbox": [0.1, 0.1, 0.4, 0.42]})
+    >>> loca_visual_prompt_counting(image, {"bbox": [0.1, 0.1, 0.4, 0.42]})
     {'count': 45},

     """
@@ -316,8 +378,8 @@
     return resp_data


-def image_question_answering(prompt: str, image: np.ndarray) -> str:
-    """'image_question_answering_' is a tool that can answer questions about the visual
+def git_vqa_v2(prompt: str, image: np.ndarray) -> str:
+    """'git_vqa_v2' is a tool that can answer questions about the visual
     contents of an image given a question and an image. It returns an answer to the
     question

@@ -331,7 +393,7 @@

     Example
     -------
-    >>> image_question_answering('What is the cat doing ?', image)
+    >>> git_vqa_v2('What is the cat doing ?', image)
     'drinking milk'


@@ -376,8 +438,62 @@
     return resp_data


-def image_caption(image: np.ndarray) -> str:
-    """'image_caption' is a tool that can caption an image based on its contents. It
+def vit_image_classification(image: np.ndarray) -> Dict[str, Any]:
+    """'vit_image_classification' is a tool that can classify an image. It returns a
+    list of classes and their probability scores based on image content.
+
+    Parameters:
+        image (np.ndarray): The image to classify or tag
+
+    Returns:
+        Dict[str, Any]: A dictionary containing the labels and scores. One entry
+            is the list of labels and the other the corresponding list of scores.
+
+    Example
+    -------
+    >>> vit_image_classification(image)
+    {"labels": ["leopard", "lemur, otter", "bird"], "scores": [0.68, 0.30, 0.02]},
+    """
+
+    image_b64 = convert_to_b64(image)
+    data = {
+        "image": image_b64,
+        "tool": "image_classification",
+    }
+    resp_data = _send_inference_request(data, "tools")
+    resp_data["scores"] = [round(prob, 4) for prob in resp_data["scores"]]
+    return resp_data
+
+
+def vit_nsfw_classification(image: np.ndarray) -> Dict[str, Any]:
+    """'vit_nsfw_classification' is a tool that can classify an image as 'nsfw' or 'normal'.
+    It returns the predicted label and its probability score based on image content.
+
+    Parameters:
+        image (np.ndarray): The image to classify or tag
+
+    Returns:
+        Dict[str, Any]: A dictionary containing the predicted label under 'labels' and
+            its probability score under 'scores'.
+
+    Example
+    -------
+    >>> vit_nsfw_classification(image)
+    {"labels": "normal", "scores": 0.68},
+    """
+
+    image_b64 = convert_to_b64(image)
+    data = {
+        "image": image_b64,
+        "tool": "nsfw_image_classification",
+    }
+    resp_data = _send_inference_request(data, "tools")
+    resp_data["scores"] = round(resp_data["scores"], 4)
+    return resp_data
+
+
+def blip_image_caption(image: np.ndarray) -> str:
+    """'blip_image_caption' is a tool that can caption an image based on its contents. It
     returns a text describing the image.

     Parameters:
@@ -388,7 +504,7 @@

     Example
     -------
-    >>> image_caption(image)
+    >>> blip_image_caption(image)
     'This image contains a cat sitting on a table with a bowl of milk.'

     """
@@ -792,15 +908,17 @@ def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame:


 TOOLS = [
-    grounding_dino,
+    owl_v2,
     grounding_sam,
     extract_frames,
     ocr,
     clip,
-    zero_shot_counting,
-    visual_prompt_counting,
-    image_question_answering,
-    image_caption,
+    vit_image_classification,
+    vit_nsfw_classification,
+    loca_zero_shot_counting,
+    loca_visual_prompt_counting,
+    git_vqa_v2,
+    blip_image_caption,
     closest_mask_distance,
     closest_box_distance,
     save_json,
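
For reviewers who want to exercise the renamed and newly added tools, here is a minimal usage sketch (illustrative only, not part of the diff above). It assumes the API exactly as documented in the docstrings added by this patch; "example.jpg" is a placeholder path, and load_image is the existing helper already exported from vision_agent.tools.

import numpy as np

from vision_agent.tools import (
    blip_image_caption,
    load_image,
    owl_v2,
    vit_nsfw_classification,
)

# Placeholder path; any RGB image loaded as a NumPy array works here.
image: np.ndarray = load_image("example.jpg")

# Open-vocabulary detection: categories are separated by commas or periods,
# and boxes come back with normalized (xmin, ymin, xmax, ymax) coordinates.
detections = owl_v2(prompt="car. dinosaur", image=image, box_threshold=0.10)

# NSFW screening: a single label ("nsfw" or "normal") plus its score.
nsfw = vit_nsfw_classification(image=image)

# Captioning: a short text description of the image contents.
caption = blip_image_caption(image=image)

print(detections, nsfw["labels"], caption)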