diff --git a/poetry.lock b/poetry.lock index d5a8aff9..1acf55d6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "aenum" @@ -2529,6 +2529,17 @@ files = [ [package.extras] cli = ["click (>=5.0)"] +[[package]] +name = "pytube" +version = "15.0.0" +description = "Python 3 library for downloading YouTube Videos." +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytube-15.0.0-py3-none-any.whl", hash = "sha256:07b9904749e213485780d7eb606e5e5b8e4341aa4dccf699160876da00e12d78"}, + {file = "pytube-15.0.0.tar.gz", hash = "sha256:076052efe76f390dfa24b1194ff821d4e86c17d41cb5562f3a276a8bcbfc9d1d"}, +] + [[package]] name = "pytz" version = "2024.1" @@ -3636,4 +3647,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "05cd29e9f780719371f5172060c2262dab3d9f8a6f1577b7292529e2773b97dc" +content-hash = "53a12eb4508f1ed2426ff24462819b1e4272acfef681535e25b564ccae7a99c0" diff --git a/pyproject.toml b/pyproject.toml index b642b888..8dd7377f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ e2b = "^0.17.1" e2b-code-interpreter = "^0.0.9" tenacity = "^8.3.0" pillow-heif = "^0.16.0" +pytube = "15.0.0" [tool.poetry.group.dev.dependencies] autoflake = "1.*" diff --git a/tests/integ/test_tools.py b/tests/integ/test_tools.py index f0c0ec11..57d536bd 100644 --- a/tests/integ/test_tools.py +++ b/tests/integ/test_tools.py @@ -5,13 +5,22 @@ blip_image_caption, clip, closest_mask_distance, + florencev2_image_caption, + depth_anything_v2, + dpt_hybrid_midas, + generate_pose_image, + generate_soft_edge_image, + florencev2_object_detection, + detr_segmentation, git_vqa_v2, grounding_dino, grounding_sam, + florencev2_roberta_vqa, loca_visual_prompt_counting, loca_zero_shot_counting, ocr, owl_v2, + template_match, vit_image_classification, vit_nsfw_classification, ) @@ -48,6 +57,24 @@ def test_owl(): assert [res["label"] for res in result] == ["coin"] * 25 +def test_object_detection(): + img = ski.data.coins() + result = florencev2_object_detection( + image=img, + ) + assert len(result) == 24 + assert [res["label"] for res in result] == ["coin"] * 24 + + +def test_template_match(): + img = ski.data.coins() + result = template_match( + image=img, + template_image=img[32:76, 20:68], + ) + assert len(result) == 2 + + def test_grounding_sam(): img = ski.data.coins() result = grounding_sam( @@ -59,6 +86,16 @@ def test_grounding_sam(): assert len([res["mask"] for res in result]) == 24 +def test_segmentation(): + img = ski.data.coins() + result = detr_segmentation( + image=img, + ) + assert len(result) == 1 + assert [res["label"] for res in result] == ["pizza"] + assert len([res["mask"] for res in result]) == 1 + + def test_clip(): img = ski.data.coins() result = clip( @@ -92,6 +129,14 @@ def test_image_caption() -> None: assert result.strip() == "a rocket on a stand" +def test_florence_image_caption() -> None: + img = ski.data.rocket() + result = florencev2_image_caption( + image=img, + ) + assert "The image shows a rocket on a launch pad at night" in result.strip() + + def test_loca_zero_shot_counting() -> None: img = ski.data.coins() @@ -119,6 +164,15 @@ def test_git_vqa_v2() -> None: assert result.strip() == "night" +def test_image_qa_with_context() -> None: + img = 
ski.data.rocket() + result = florencev2_roberta_vqa( + prompt="Is the scene captured during day or night ?", + image=img, + ) + assert "night" in result.strip() + + def test_ocr() -> None: img = ski.data.page() result = ocr( @@ -144,3 +198,41 @@ def test_mask_distance(): np.sqrt(2) * 81, atol=1e-2, ), f"Expected {np.sqrt(2) * 81}, got {distance}" + + +def test_generate_depth(): + img = ski.data.coins() + result = depth_anything_v2( + image=img, + ) + + assert result.shape == img.shape + + +def test_generate_pose(): + img = ski.data.coins() + result = generate_pose_image( + image=img, + ) + import cv2 + + cv2.imwrite("imag.png", result) + assert result.shape == img.shape + (3,) + + +def test_generate_normal(): + img = ski.data.coins() + result = dpt_hybrid_midas( + image=img, + ) + + assert result.shape == img.shape + (3,) + + +def test_generate_hed(): + img = ski.data.coins() + result = generate_soft_edge_image( + image=img, + ) + + assert result.shape == img.shape diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py index da5ed2b8..61d118da 100644 --- a/vision_agent/tools/__init__.py +++ b/vision_agent/tools/__init__.py @@ -12,10 +12,18 @@ closest_box_distance, closest_mask_distance, extract_frames, + florencev2_image_caption, get_tool_documentation, + florencev2_object_detection, + detr_segmentation, + depth_anything_v2, + generate_soft_edge_image, + dpt_hybrid_midas, + generate_pose_image, git_vqa_v2, grounding_dino, grounding_sam, + florencev2_roberta_vqa, load_image, loca_visual_prompt_counting, loca_zero_shot_counting, @@ -27,6 +35,7 @@ save_image, save_json, save_video, + template_match, vit_image_classification, vit_nsfw_classification, ) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 8acda4b2..a888dcc2 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -14,6 +14,7 @@ from moviepy.editor import ImageSequenceClip from PIL import Image, ImageDraw, ImageFont from pillow_heif import register_heif_opener # type: ignore +from pytube import YouTube # type: ignore from vision_agent.tools.tool_utils import send_inference_request from vision_agent.utils import extract_frames_from_video @@ -126,7 +127,7 @@ def owl_v2( ) -> List[Dict[str, Any]]: """'owl_v2' is a tool that can detect and count multiple objects given a text prompt such as category names or referring expressions. The categories in text prompt - are separated by commas or periods. It returns a list of bounding boxes with + are separated by commas. It returns a list of bounding boxes with normalized coordinates, label names and associated probability scores. Parameters: @@ -136,7 +137,6 @@ def owl_v2( to 0.10. iou_threshold (float, optional): The threshold for the Intersection over Union (IoU). Defaults to 0.10. - model_size (str, optional): The size of the model to use. Returns: List[Dict[str, Any]]: A list of dictionaries containing the score, label, and @@ -180,7 +180,7 @@ def grounding_sam( box_threshold: float = 0.20, iou_threshold: float = 0.20, ) -> List[Dict[str, Any]]: - """'grounding_sam' is a tool that can detect and segment multiple objects given a + """'grounding_sam' is a tool that can segment multiple objects given a text prompt such as category names or referring expressions. The categories in text prompt are separated by commas or periods. It returns a list of bounding boxes, label names, mask file names and associated probability scores. 
@@ -242,12 +242,12 @@ def grounding_sam( def extract_frames( video_uri: Union[str, Path], fps: float = 0.5 ) -> List[Tuple[np.ndarray, float]]: - """'extract_frames' extracts frames from a video, returns a list of tuples (frame, - timestamp), where timestamp is the relative time in seconds where the frame was - captured. The frame is a numpy array. + """'extract_frames' extracts frames from a video, which can be a file path or YouTube + link, and returns a list of tuples (frame, timestamp), where timestamp is the relative + time in seconds where the frame was captured. The frame is a numpy array. Parameters: - video_uri (Union[str, Path]): The path to the video file. + video_uri (Union[str, Path]): The path to the video file or a YouTube link. fps (float, optional): The frame rate per second to extract the frames. Defaults to 0.5. @@ -261,6 +261,29 @@ def extract_frames( [(frame1, 0.0), (frame2, 0.5), ...] """ + if str(video_uri).startswith( + ( + "http://www.youtube.com/", + "https://www.youtube.com/", + "http://youtu.be/", + "https://youtu.be/", + ) + ): + with tempfile.TemporaryDirectory() as temp_dir: + yt = YouTube(str(video_uri)) + # Download the highest resolution video + video = ( + yt.streams.filter(progressive=True, file_extension="mp4") + .order_by("resolution") + .desc() + .first() + ) + if not video: + raise Exception("No suitable video stream found") + video_file_path = video.download(output_path=temp_dir) + + return extract_frames_from_video(video_file_path, fps) + return extract_frames_from_video(str(video_uri), fps) @@ -381,6 +404,35 @@ def loca_visual_prompt_counting( return resp_data + +def florencev2_roberta_vqa(prompt: str, image: np.ndarray) -> str: + """'florencev2_roberta_vqa' is a tool that takes an image and analyzes + its contents, generates detailed captions, and then tries to answer the given + question using the generated context. It returns text as an answer to the question. + + Parameters: + prompt (str): The question about the image + image (np.ndarray): The reference image used for the question + + Returns: + str: A string which is the answer to the given prompt. + + Example + ------- + >>> florencev2_roberta_vqa('What is the top left animal in this image ?', image) + 'white tiger' + """ + + image_b64 = convert_to_b64(image) + data = { + "image": image_b64, + "prompt": prompt, + "tool": "image_question_answering_with_context", + } + + answer = send_inference_request(data, "tools") + return answer["text"][0] # type: ignore + + def git_vqa_v2(prompt: str, image: np.ndarray) -> str: """'git_vqa_v2' is a tool that can answer questions about the visual contents of an image given a question and an image. It returns an answer to the Parameters: image (np.ndarray): The reference image used for the question Returns: - str: A string which is the answer to the given prompt. E.g. {'text': 'This - image contains a cat sitting on a table with a bowl of milk.'}. + str: A string which is the answer to the given prompt. Example ------- @@ -521,6 +572,309 @@ def blip_image_caption(image: np.ndarray) -> str: return answer["text"][0] # type: ignore + +def florencev2_image_caption(image: np.ndarray, detail_caption: bool = True) -> str: + """'florencev2_image_caption' is a tool that can caption or describe an image based + on its contents. It returns text describing the image.
+ + Parameters: + image (np.ndarray): The image to caption + detail_caption (bool): If True, the caption will be as detailed as possible, otherwise + the caption will be a brief description. + + Returns: + str: A string which is the caption for the given image. + + Example + ------- + >>> florencev2_image_caption(image, False) + 'This image contains a cat sitting on a table with a bowl of milk.' + """ + image_b64 = convert_to_b64(image) + data = { + "image": image_b64, + "tool": "florence2_image_captioning", + "detail_caption": detail_caption, + } + + answer = send_inference_request(data, "tools") + return answer["text"][0] # type: ignore + + +def florencev2_object_detection(image: np.ndarray) -> List[Dict[str, Any]]: + """'florencev2_object_detection' is a tool that can detect common objects in an + image without any text prompt or thresholding. It returns a list of detected objects + as labels and their locations as bounding boxes. + + Parameters: + image (np.ndarray): The image used to detect objects + + Returns: + List[Dict[str, Any]]: A list of dictionaries containing the score, label, and + bounding box of the detected objects with normalized coordinates between 0 + and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the + top-left and xmax and ymax are the coordinates of the bottom-right of the + bounding box. The scores are always 1.0 and cannot be thresholded. + + Example + ------- + >>> florencev2_object_detection(image) + [ + {'score': 1.0, 'label': 'window', 'bbox': [0.1, 0.11, 0.35, 0.4]}, + {'score': 1.0, 'label': 'car', 'bbox': [0.2, 0.21, 0.45, 0.5]}, + {'score': 1.0, 'label': 'person', 'bbox': [0.34, 0.21, 0.85, 0.5]}, + ] + """ + image_size = image.shape[:2] + image_b64 = convert_to_b64(image) + data = { + "image": image_b64, + "tool": "object_detection", + } + + answer = send_inference_request(data, "tools") + return_data = [] + for i in range(len(answer["bboxes"])): + return_data.append( + { + "score": round(answer["scores"][i], 2), + "label": answer["labels"][i], + "bbox": normalize_bbox(answer["bboxes"][i], image_size), + } + ) + return return_data + + +def detr_segmentation(image: np.ndarray) -> List[Dict[str, Any]]: + """'detr_segmentation' is a tool that can segment common objects in an + image without any text prompt. It returns a list of detected objects + as labels, their regions as masks and their scores. + + Parameters: + image (np.ndarray): The image used to segment things and objects + + Returns: + List[Dict[str, Any]]: A list of dictionaries containing the score, label + and mask of the detected objects. The mask is a binary 2D numpy array where 1 + indicates the object and 0 indicates the background.
+ + Example + ------- + >>> detr_segmentation(image) + [ + { + 'score': 0.45, + 'label': 'window', + 'mask': array([[0, 0, 0, ..., 0, 0, 0], + [0, 0, 0, ..., 0, 0, 0], + ..., + [0, 0, 0, ..., 0, 0, 0], + [0, 0, 0, ..., 0, 0, 0]], dtype=uint8), + }, + { + 'score': 0.70, + 'label': 'bird', + 'mask': array([[0, 0, 0, ..., 0, 0, 0], + [0, 0, 0, ..., 0, 0, 0], + ..., + [0, 0, 0, ..., 0, 0, 0], + [0, 0, 0, ..., 0, 0, 0]], dtype=uint8), + }, + ] + """ + image_b64 = convert_to_b64(image) + data = { + "image": image_b64, + "tool": "panoptic_segmentation", + } + + answer = send_inference_request(data, "tools") + return_data = [] + + for i in range(len(answer["scores"])): + return_data.append( + { + "score": round(answer["scores"][i], 2), + "label": answer["labels"][i], + "mask": rle_decode( + mask_rle=answer["masks"][i], shape=answer["mask_shape"][0] + ), + } + ) + return return_data + + +def depth_anything_v2(image: np.ndarray) -> np.ndarray: + """'depth_anything_v2' is a tool that runs the Depth Anything V2 model to generate a + depth image from a given RGB image. The returned depth image is monochrome and + represents depth values as pixel intensities, with pixel values ranging from 0 to 255. + + Parameters: + image (np.ndarray): The image used to generate the depth image + + Returns: + np.ndarray: A grayscale depth image with pixel values ranging from 0 to 255. + + Example + ------- + >>> depth_anything_v2(image) + array([[0, 0, 0, ..., 0, 0, 0], + [0, 20, 24, ..., 0, 100, 103], + ..., + [10, 11, 15, ..., 202, 202, 205], + [10, 10, 10, ..., 200, 200, 200]], dtype=uint8) + """ + image_b64 = convert_to_b64(image) + data = { + "image": image_b64, + "tool": "generate_depth", + } + + answer = send_inference_request(data, "tools") + return_data = np.array(b64_to_pil(answer["masks"][0]).convert("L")) + return return_data + + +def generate_soft_edge_image(image: np.ndarray) -> np.ndarray: + """'generate_soft_edge_image' is a tool that runs Holistically-Nested Edge Detection + to generate a soft edge image (HED) from a given RGB image. The returned image is + monochrome and represents object boundaries as soft white edges on a black background. + + Parameters: + image (np.ndarray): The image used to generate the soft edge image + + Returns: + np.ndarray: A soft edge image with pixel values ranging from 0 to 255. + + Example + ------- + >>> generate_soft_edge_image(image) + array([[0, 0, 0, ..., 0, 0, 0], + [0, 20, 24, ..., 0, 100, 103], + ..., + [10, 11, 15, ..., 202, 202, 205], + [10, 10, 10, ..., 200, 200, 200]], dtype=uint8) + """ + image_b64 = convert_to_b64(image) + data = { + "image": image_b64, + "tool": "generate_hed", + } + + answer = send_inference_request(data, "tools") + return_data = np.array(b64_to_pil(answer["masks"][0]).convert("L")) + return return_data + + +def dpt_hybrid_midas(image: np.ndarray) -> np.ndarray: + """'dpt_hybrid_midas' is a tool that generates a normal map from a given RGB + image. The returned RGB image is a texture-mapped image of the surface normals, and the + RGB values represent the surface normals in the x, y, z directions. + + Parameters: + image (np.ndarray): The image used to generate the normal image + + Returns: + np.ndarray: A normal map image with RGB pixel values indicating surface + normals in the x, y, z directions.
+ + Example + ------- + >>> dpt_hybrid_midas(image) + array([[0, 0, 0, ..., 0, 0, 0], + [0, 20, 24, ..., 0, 100, 103], + ..., + [10, 11, 15, ..., 202, 202, 205], + [10, 10, 10, ..., 200, 200, 200]], dtype=uint8) + """ + image_b64 = convert_to_b64(image) + data = { + "image": image_b64, + "tool": "generate_normal", + } + + answer = send_inference_request(data, "tools") + return_data = np.array(b64_to_pil(answer["masks"][0]).convert("RGB")) + return return_data + + +def generate_pose_image(image: np.ndarray) -> np.ndarray: + """'generate_pose_image' is a tool that generates an OpenPose bone/stick image from + a given RGB image. The returned bone image is RGB with the pose and keypoints colored + and the background black. + + Parameters: + image (np.ndarray): The image used to generate the pose image + + Returns: + np.ndarray: A bone or pose image indicating the pose and keypoints. + + Example + ------- + >>> generate_pose_image(image) + array([[0, 0, 0, ..., 0, 0, 0], + [0, 20, 24, ..., 0, 100, 103], + ..., + [10, 11, 15, ..., 202, 202, 205], + [10, 10, 10, ..., 200, 200, 200]], dtype=uint8) + """ + image_b64 = convert_to_b64(image) + data = { + "image": image_b64, + "tool": "generate_pose", + } + + answer = send_inference_request(data, "tools") + return_data = np.array(b64_to_pil(answer["masks"][0]).convert("RGB")) + return return_data + + +def template_match( + image: np.ndarray, template_image: np.ndarray +) -> List[Dict[str, Any]]: + """'template_match' is a tool that can detect all instances of a template in + a given image. It returns the locations of the detected template and a corresponding + similarity score for each match. + + Parameters: + image (np.ndarray): The image used for searching the template + template_image (np.ndarray): The template image or crop to search for in the image + + Returns: + List[Dict[str, Any]]: A list of dictionaries containing the score and + bounding box of the detected template with normalized coordinates between 0 + and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the + top-left and xmax and ymax are the coordinates of the bottom-right of the + bounding box. + + Example + ------- + >>> template_match(image, template) + [ + {'score': 0.79, 'bbox': [0.1, 0.11, 0.35, 0.4]}, + {'score': 0.38, 'bbox': [0.2, 0.21, 0.45, 0.5]}, + ] + """ + image_size = image.shape[:2] + image_b64 = convert_to_b64(image) + template_image_b64 = convert_to_b64(template_image) + data = { + "image": image_b64, + "template": template_image_b64, + "tool": "template_match", + } + + answer = send_inference_request(data, "tools") + return_data = [] + for i in range(len(answer["bboxes"])): + return_data.append( + { + "score": round(answer["scores"][i], 2), + "bbox": normalize_bbox(answer["bboxes"][i], image_size), + } + ) + return return_data + + def closest_mask_distance(mask1: np.ndarray, mask2: np.ndarray) -> float: """'closest_mask_distance' calculates the closest distance between two masks.
@@ -733,7 +1087,7 @@ def overlay_bounding_boxes( image, [{'score': 0.99, 'label': 'dinosaur', 'bbox': [0.1, 0.11, 0.35, 0.4]}], ) """ - pil_image = Image.fromarray(image.astype(np.uint8)) + pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGB") if len(set([box["label"] for box in bboxes])) > len(COLORS): _LOGGER.warning( @@ -920,8 +1274,14 @@ def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame: vit_nsfw_classification, loca_zero_shot_counting, loca_visual_prompt_counting, - git_vqa_v2, - blip_image_caption, + florencev2_roberta_vqa, + florencev2_image_caption, + florencev2_object_detection, + detr_segmentation, + depth_anything_v2, + generate_soft_edge_image, + dpt_hybrid_midas, + generate_pose_image, closest_mask_distance, closest_box_distance, save_json, @@ -931,6 +1291,7 @@ def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame: overlay_bounding_boxes, overlay_segmentation_masks, overlay_heat_map, + template_match, ] TOOLS_DF = get_tools_df(TOOLS) # type: ignore TOOL_DESCRIPTIONS = get_tool_descriptions(TOOLS) # type: ignore