From 89d78d8e9725dec03fdfc138693030b3af9e5f12 Mon Sep 17 00:00:00 2001
From: shankar_ws3
Date: Thu, 16 May 2024 12:19:10 -0700
Subject: [PATCH] adding more cv tools to coder agent

---
 vision_agent/tools/tools.py    |   2 +-
 vision_agent/tools/tools_v2.py | 132 +++++++++++++++++++++++++++++++++-
 2 files changed, 132 insertions(+), 2 deletions(-)

diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py
index cad7f4ad..fdbc1fe2 100644
--- a/vision_agent/tools/tools.py
+++ b/vision_agent/tools/tools.py
@@ -53,7 +53,7 @@ def __call__(self) -> None:
 
 
 class CLIP(Tool):
-    r"""CLIP is a tool that can classify or tag any image given a set if input classes
+    r"""CLIP is a tool that can classify or tag any image given a set of input classes
     or tags.
 
     Example
diff --git a/vision_agent/tools/tools_v2.py b/vision_agent/tools/tools_v2.py
index 1f1a3c6d..ada93de2 100644
--- a/vision_agent/tools/tools_v2.py
+++ b/vision_agent/tools/tools_v2.py
@@ -15,7 +15,14 @@
 
 from vision_agent.tools.tool_utils import _send_inference_request
 from vision_agent.utils import extract_frames_from_video
-from vision_agent.utils.image_utils import convert_to_b64, normalize_bbox, rle_decode
+from vision_agent.utils.image_utils import (
+    convert_to_b64,
+    normalize_bbox,
+    rle_decode,
+    b64_to_pil,
+    get_image_size,
+    denormalize_bbox,
+)
 
 COLORS = [
     (158, 218, 229),
@@ -235,6 +242,125 @@ def ocr(image: np.ndarray) -> List[Dict[str, Any]]:
     return output
 
 
+def zero_shot_counting(image: np.ndarray) -> Dict[str, Any]:
+    """'zero_shot_counting' is a tool that counts instances of the dominant foreground object in an image, given no other information about its content.
+    It returns the count of the objects in the image along with a heat map.
+
+    Parameters:
+        image (np.ndarray): The image that contains many instances of a single object
+
+    Returns:
+        Dict[str, Any]: A dictionary containing the key 'count' with the count as its value, e.g. {'count': 12}, and the key 'heat_map' with a heat map of the counted objects.
+
+    Example
+    -------
+    >>> zero_shot_counting(image)
+    {'count': 45}
+
+    """
+
+    image_b64 = convert_to_b64(image)
+    data = {
+        "image": image_b64,
+        "tool": "zero_shot_counting",
+    }
+    resp_data = _send_inference_request(data, "tools")
+    resp_data["heat_map"] = np.array(b64_to_pil(resp_data["heat_map"][0]))
+    return resp_data
+
+
+def visual_prompt_counting(
+    image: np.ndarray, visual_prompt: Dict[str, List[float]]
+) -> Dict[str, Any]:
+    """'visual_prompt_counting' is a tool that counts instances of the dominant foreground object given an image and a visual prompt, a bounding box describing the object.
+    It returns the count of the objects in the image along with a heat map.
+
+    Parameters:
+        image (np.ndarray): The image that contains many instances of a single object
+        visual_prompt (Dict[str, List[float]]): A dictionary with the key 'bbox' mapping to a bounding box in normalized coordinates that encloses one example object
+
+    Returns:
+        Dict[str, Any]: A dictionary containing the key 'count' with the count as its value, e.g. {'count': 12}, and the key 'heat_map' with a heat map of the counted objects.
+
+    Example
+    -------
+    >>> visual_prompt_counting(image, {"bbox": [0.1, 0.1, 0.4, 0.42]})
+    {'count': 45}
+
+    """
+
+    image_size = get_image_size(image)
+    bbox = visual_prompt["bbox"]
+    bbox_str = ", ".join(map(str, denormalize_bbox(bbox, image_size)))
+    image_b64 = convert_to_b64(image)
+
+    data = {
+        "image": image_b64,
+        "prompt": bbox_str,
+        "tool": "few_shot_counting",
+    }
+    resp_data = _send_inference_request(data, "tools")
+    resp_data["heat_map"] = np.array(b64_to_pil(resp_data["heat_map"][0]))
+    return resp_data
+
+
+def image_question_answering(image: np.ndarray, prompt: str) -> str:
+    """'image_question_answering' is a tool that can answer questions about the visual contents of an image given a question and an image.
+    It returns the answer to the question as text.
+
+    Parameters:
+        image (np.ndarray): The reference image used for the question
+        prompt (str): The question about the image
+
+    Returns:
+        str: A string which is the answer to the given prompt. E.g. 'This image contains a cat sitting on a table with a bowl of milk.'
+
+    Example
+    -------
+    >>> image_question_answering(image, 'What is the cat doing?')
+    'This image contains a cat sitting on a table with a bowl of milk.'
+
+    """
+
+    image_b64 = convert_to_b64(image)
+    data = {
+        "image": image_b64,
+        "prompt": prompt,
+        "tool": "image_question_answering",
+    }
+
+    return _send_inference_request(data, "tools")["text"]
+
+
+def clip(image: np.ndarray, classes: List[str]) -> Dict[str, Any]:
+    """'clip' is a tool that can classify an image given a list of input classes or tags.
+    It returns the same list of input classes along with their probability scores based on image content.
+
+    Parameters:
+        image (np.ndarray): The image to classify or tag
+        classes (List[str]): The list of classes or tags that are associated with the image
+
+    Returns:
+        Dict[str, Any]: A dictionary with the key 'labels' mapping to the list of given classes and the key 'scores' mapping to their corresponding probability scores.
+
+    Example
+    -------
+    >>> clip(image, ['dog', 'cat', 'bird'])
+    {"labels": ["dog", "cat", "bird"], "scores": [0.68, 0.30, 0.02]}
+
+    """
+
+    image_b64 = convert_to_b64(image)
+    data = {
+        "prompt": classes,
+        "image": image_b64,
+        "tool": "closed_set_image_classification",
+    }
+    resp_data = _send_inference_request(data, "tools")
+    resp_data["scores"] = [round(prob, 4) for prob in resp_data["scores"]]
+    return resp_data
+
+
 def closest_mask_distance(mask1: np.ndarray, mask2: np.ndarray) -> float:
     """'closest_mask_distance' calculates the closest distance between two masks.
 
@@ -504,6 +630,10 @@ def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame:
     grounding_sam,
     extract_frames,
     ocr,
+    clip,
+    zero_shot_counting,
+    visual_prompt_counting,
+    image_question_answering,
     closest_mask_distance,
     closest_box_distance,
     save_json,
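A minimal usage sketch for the tools this patch adds, assuming the hosted "tools" inference endpoint behind _send_inference_request is reachable; "image.jpg" is a hypothetical local test file, and outputs depend on the deployed models:

    import numpy as np
    from PIL import Image

    from vision_agent.tools.tools_v2 import (
        clip,
        image_question_answering,
        visual_prompt_counting,
        zero_shot_counting,
    )

    # Hypothetical local test image, loaded as an RGB numpy array.
    image = np.array(Image.open("image.jpg").convert("RGB"))

    # Closed-set classification over caller-supplied labels.
    print(clip(image, ["dog", "cat", "bird"]))

    # Count the dominant foreground object with no extra hints;
    # the response also carries a decoded 'heat_map' array.
    print(zero_shot_counting(image)["count"])

    # Count objects matching a normalized bounding-box visual prompt.
    print(visual_prompt_counting(image, {"bbox": [0.1, 0.1, 0.4, 0.42]})["count"])

    # Free-form visual question answering.
    print(image_question_answering(image, "What is the cat doing?"))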