diff --git a/tests/integ/test_tools.py b/tests/integ/test_tools.py index 1d99ff69..afa9dcb4 100644 --- a/tests/integ/test_tools.py +++ b/tests/integ/test_tools.py @@ -1,5 +1,6 @@ import numpy as np import skimage as ski +from PIL import Image from vision_agent.tools import ( blip_image_caption, @@ -8,15 +9,19 @@ depth_anything_v2, detr_segmentation, dpt_hybrid_midas, - florencev2_image_caption, - florencev2_object_detection, - florencev2_roberta_vqa, - florencev2_ocr, + florence2_image_caption, + florence2_object_detection, + florence2_ocr, + florence2_roberta_vqa, + florence2_sam2_image, + florence2_sam2_video, generate_pose_image, generate_soft_edge_image, git_vqa_v2, grounding_dino, grounding_sam, + ixc25_image_vqa, + ixc25_video_vqa, loca_visual_prompt_counting, loca_zero_shot_counting, ocr, @@ -60,7 +65,7 @@ def test_owl(): def test_object_detection(): img = ski.data.coins() - result = florencev2_object_detection( + result = florence2_object_detection( image=img, prompt="coin", ) @@ -88,6 +93,30 @@ def test_grounding_sam(): assert len([res["mask"] for res in result]) == 24 +def test_florence2_sam2_image(): + img = ski.data.coins() + result = florence2_sam2_image( + prompt="coin", + image=img, + ) + assert len(result) == 25 + assert [res["label"] for res in result] == ["coin"] * 25 + assert len([res["mask"] for res in result]) == 25 + + +def test_florence2_sam2_video(): + frames = [ + np.array(Image.fromarray(ski.data.coins()).convert("RGB")) for _ in range(10) + ] + result = florence2_sam2_video( + prompt="coin", + frames=frames, + ) + assert len(result) == 10 + assert len([res["label"] for res in result[0]]) == 25 + assert len([res["mask"] for res in result[0]]) == 25 + + def test_segmentation(): img = ski.data.coins() result = detr_segmentation( @@ -133,7 +162,7 @@ def test_image_caption() -> None: def test_florence_image_caption() -> None: img = ski.data.rocket() - result = florencev2_image_caption( + result = florence2_image_caption( image=img, ) assert "The image shows a rocket on a launch pad at night" in result.strip() @@ -168,13 +197,33 @@ def test_git_vqa_v2() -> None: def test_image_qa_with_context() -> None: img = ski.data.rocket() - result = florencev2_roberta_vqa( + result = florence2_roberta_vqa( prompt="Is the scene captured during day or night ?", image=img, ) assert "night" in result.strip() +def test_ixc25_image_vqa() -> None: + img = ski.data.cat() + result = ixc25_image_vqa( + prompt="What animal is in this image?", + image=img, + ) + assert "cat" in result.strip() + + +def test_ixc25_video_vqa() -> None: + frames = [ + np.array(Image.fromarray(ski.data.cat()).convert("RGB")) for _ in range(10) + ] + result = ixc25_video_vqa( + prompt="What animal is in this video?", + frames=frames, + ) + assert "cat" in result.strip() + + def test_ocr() -> None: img = ski.data.page() result = ocr( @@ -183,9 +232,9 @@ def test_ocr() -> None: assert any("Region-based segmentation" in res["label"] for res in result) -def test_florencev2_ocr() -> None: +def test_florence2_ocr() -> None: img = ski.data.page() - result = florencev2_ocr( + result = florence2_ocr( image=img, ) assert any("Region-based segmentation" in res["label"] for res in result) diff --git a/vision_agent/agent/agent_utils.py b/vision_agent/agent/agent_utils.py index e4e678d7..5d55e963 100644 --- a/vision_agent/agent/agent_utils.py +++ b/vision_agent/agent/agent_utils.py @@ -4,14 +4,13 @@ from typing import Any, Dict logging.basicConfig(stream=sys.stdout) -_LOGGER = logging.getLogger(__name__) def 
extract_json(json_str: str) -> Dict[str, Any]: try: + json_str = json_str.replace("\n", " ") json_dict = json.loads(json_str) except json.JSONDecodeError: - input_json_str = json_str if "```json" in json_str: json_str = json_str[json_str.find("```json") + len("```json") :] json_str = json_str[: json_str.find("```")] @@ -19,12 +18,8 @@ def extract_json(json_str: str) -> Dict[str, Any]: json_str = json_str[json_str.find("```") + len("```") :] # get the last ``` not one from an intermediate string json_str = json_str[: json_str.find("}```")] - try: - json_dict = json.loads(json_str) - except json.JSONDecodeError as e: - error_msg = f"Could not extract JSON from the given str: {json_str}.\nFunction input:\n{input_json_str}" - _LOGGER.exception(error_msg) - raise ValueError(error_msg) from e + + json_dict = json.loads(json_str) return json_dict # type: ignore diff --git a/vision_agent/agent/vision_agent_coder.py b/vision_agent/agent/vision_agent_coder.py index 3a370c5e..3f445d80 100644 --- a/vision_agent/agent/vision_agent_coder.py +++ b/vision_agent/agent/vision_agent_coder.py @@ -4,6 +4,7 @@ import os import sys import tempfile +from json import JSONDecodeError from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, cast @@ -86,8 +87,8 @@ def format_memory(memory: List[Dict[str, str]]) -> str: def format_plans(plans: Dict[str, Any]) -> str: plan_str = "" for k, v in plans.items(): - plan_str += f"{k}:\n" - plan_str += "-" + "\n-".join([e["instructions"] for e in v]) + plan_str += "\n" + f"{k}: {v['thoughts']}\n" + plan_str += " -" + "\n -".join([e for e in v["instructions"]]) return plan_str @@ -228,13 +229,11 @@ def pick_plan( "status": "completed" if tool_output.success else "failed", } ) - tool_output_str = "" - if len(tool_output.logs.stdout) > 0: - tool_output_str = tool_output.logs.stdout[0] + tool_output_str = tool_output.text().strip() if verbosity == 2: _print_code("Code and test after attempted fix:", code) - _LOGGER.info(f"Code execution result after attempte {count}") + _LOGGER.info(f"Code execution result after attempt {count}") count += 1 @@ -251,7 +250,21 @@ def pick_plan( tool_output=tool_output_str[:20_000], ) chat[-1]["content"] = prompt - best_plan = extract_json(model(chat, stream=False)) # type: ignore + + count = 0 + best_plan = None + while best_plan is None and count < max_retries: + try: + best_plan = extract_json(model(chat, stream=False)) # type: ignore + except JSONDecodeError as e: + _LOGGER.exception( + f"Error while extracting JSON during picking best plan {str(e)}" + ) + pass + count += 1 + + if best_plan is None: + best_plan = {"best_plan": list(plans.keys())[0]} if verbosity >= 1: _LOGGER.info(f"Best plan:\n{best_plan}") @@ -525,7 +538,7 @@ def _print_code(title: str, code: str, test: Optional[str] = None) -> None: def retrieve_tools( - plans: Dict[str, List[Dict[str, str]]], + plans: Dict[str, Dict[str, Any]], tool_recommender: Sim, log_progress: Callable[[Dict[str, Any]], None], verbosity: int = 0, @@ -542,8 +555,8 @@ def retrieve_tools( tool_lists: Dict[str, List[Dict[str, str]]] = {} for k, plan in plans.items(): tool_lists[k] = [] - for task in plan: - tools = tool_recommender.top_k(task["instructions"], k=2, thresh=0.3) + for task in plan["instructions"]: + tools = tool_recommender.top_k(task, k=2, thresh=0.3) tool_info.extend([e["doc"] for e in tools]) tool_desc.extend([e["desc"] for e in tools]) tool_lists[k].extend( @@ -737,14 +750,7 @@ def chat_with_workflow( if self.verbosity >= 1: for p in plans: # 
tabulate will fail if the keys are not the same for all elements - p_fixed = [ - { - "instructions": ( - e["instructions"] if "instructions" in e else "" - ) - } - for e in plans[p] - ] + p_fixed = [{"instructions": e} for e in plans[p]["instructions"]] _LOGGER.info( f"\n{tabulate(tabular_data=p_fixed, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}" ) @@ -793,13 +799,15 @@ def chat_with_workflow( ) if self.verbosity >= 1: + plan_i_fixed = [{"instructions": e} for e in plan_i["instructions"]] _LOGGER.info( - f"Picked best plan:\n{tabulate(tabular_data=plan_i, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}" + f"Picked best plan:\n{tabulate(tabular_data=plan_i_fixed, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}" ) results = write_and_test_code( chat=[{"role": c["role"], "content": c["content"]} for c in int_chat], - plan="\n-" + "\n-".join([e["instructions"] for e in plan_i]), + plan=f"\n{plan_i['thoughts']}\n-" + + "\n-".join([e for e in plan_i["instructions"]]), tool_info=tool_info, tool_output=tool_output_str, tool_utils=T.UTILITIES_DOCSTRING, diff --git a/vision_agent/agent/vision_agent_coder_prompts.py b/vision_agent/agent/vision_agent_coder_prompts.py index cb4c3eeb..c68f73fe 100644 --- a/vision_agent/agent/vision_agent_coder_prompts.py +++ b/vision_agent/agent/vision_agent_coder_prompts.py @@ -30,18 +30,19 @@ **Instructions**: 1. Based on the context and tools you have available, create a plan of subtasks to achieve the user request. -2. Output three different plans each utilize a different strategy or tool. +2. Output three different plans each utilize a different strategy or set of tools. Output a list of jsons in the following format ```json {{ "plan1": - [ - {{ - "instructions": str # what you should do in this task associated with a tool - }} - ], + {{ + "thoughts": str # your thought process for choosing this plan + "instructions": [ + str # what you should do in this task associated with a tool + ] + }}, "plan2": ..., "plan3": ... }} @@ -127,7 +128,8 @@ **Instructions**: 1. Given the plans, image, and tool outputs, decide which plan is the best to achieve the user request. -2. Output a JSON object with the following format: +2. Try solving the problem yourself given the image and pick the plan that matches your solution the best. +3. Output a JSON object with the following format: {{ "thoughts": str # your thought process for choosing the best plan "best_plan": str # the best plan you have chosen diff --git a/vision_agent/agent/vision_agent_prompts.py b/vision_agent/agent/vision_agent_prompts.py index 6f0fdf74..4774d84d 100644 --- a/vision_agent/agent/vision_agent_prompts.py +++ b/vision_agent/agent/vision_agent_prompts.py @@ -15,7 +15,7 @@ **Examples**: Here is an example of how you can interact with a user and Actions to complete a task: --- START EXAMPLES --- -[Current directory: /workspace/test] +[Current directory: /example/workspace] {examples} --- END EXAMPLES --- @@ -27,16 +27,17 @@ Here is the current conversation so far: --- START CONVERSATION --- [Current directory: {dir}] + {conversation} """ EXAMPLES_CODE1 = """ USER: Can you detect the dogs in this image? 
Media name dog.jpg -AGENT: {"thoughts": "I will use the generate_vision_code to detect the dogs in the image.", "response": "generate_vision_code('/workspace/test/dog_detector.py', 'Can you write code to detect dogs in this image?', media=['/workspace/test/dog.jpg'])", "let_user_respond": false} +AGENT: {"thoughts": "I will use the generate_vision_code to detect the dogs in the image.", "response": "generate_vision_code('/example/workspace/dog_detector.py', 'Can you write code to detect dogs in this image?', media=['/example/workspace/dog.jpg'])", "let_user_respond": false} OBSERVATION: -[File /workspace/test/dog_detector.py] +[File /example/workspace/dog_detector.py] 0|from vision_agent.tools import load_image, owl_v2 1|def detect_dogs(image_path: str): 2| image = load_image(image_path) @@ -44,7 +45,7 @@ 4| return dogs [End of file] -AGENT: {"thoughts": "I have generated the code to detect the dogs in the image, I must now run the code to get the output.", "response": "from dog_detector import detect_dogs\n print(detect_dogs('/workspace/test/dog.jpg'))", "let_user_respond": false} +AGENT: {"thoughts": "I have generated the code to detect the dogs in the image, I must now run the code to get the output.", "response": "from dog_detector import detect_dogs\n print(detect_dogs('/example/workspace/dog.jpg'))", "let_user_respond": false} OBSERVATION: ----- stdout ----- @@ -55,10 +56,10 @@ USER: The the image only has one dog, can you fix this? -AGENT: {"thoughts": "I will use edit_vision_code to fix the code to detect only one dog. The lower score is 0.23 so I will set the threshold above that.", "response": "edit_vision_code('/workspace/test/dog_detector.py', ['Can you write code to detect dogs in this image?', 'Can you increase the threshold to 0.24?'], media=['/workspace/test/dog.jpg'])", "let_user_respond": false} +AGENT: {"thoughts": "I will use edit_vision_code to fix the code to detect only one dog. 
The lower score is 0.23 so I will set the threshold above that.", "response": "edit_vision_code('/example/workspace/dog_detector.py', ['Can you write code to detect dogs in this image?', 'Can you increase the threshold to 0.24?'], media=['/example/workspace/dog.jpg'])", "let_user_respond": false} OBSERVATION: -[File /workspace/test/dog_detector.py] +[File /example/workspace/dog_detector.py] 0|from vision_agent.tools import load_image, owl_v2 1|def detect_dogs(image_path: str): 2| image = load_image(image_path) @@ -66,7 +67,7 @@ 4| return dogs [End of file] -AGENT: {"thoughts": "I have edited the code to detect only one dog, I must now run the code to get the output.", "response": "from dog_detector import detect_dogs\n print(detect_dogs('/workspace/test/dog.jpg'))", "let_user_respond": false} +AGENT: {"thoughts": "I have edited the code to detect only one dog, I must now run the code to get the output.", "response": "from dog_detector import detect_dogs\n print(detect_dogs('/example/workspace/dog.jpg'))", "let_user_respond": false} OBSERVATION: ----- stdout ----- @@ -82,10 +83,10 @@ USER: Yes you can use workers.png -AGENT: {"thoughts": "I will use the generate_vision_code to count the workers with helmets in the image.", "response": "generate_vision_code('/workspace/test/code.py', 'Can you write code to count workers with helmets in this image?', media=['/workspace/test/workers.png'])", "let_user_respond": false} +AGENT: {"thoughts": "I will use the generate_vision_code to count the workers with helmets in the image.", "response": "generate_vision_code('/example/workspace/code.py', 'Can you write code to count workers with helmets in this image?', media=['/example/workspace/workers.png'])", "let_user_respond": false} OBSERVATION: -[File /workspace/test/code.py] +[File /example/workspace/code.py] 0|from vision_agent.tools import load_image, owl_v2, closest_box_distance 1|def count_workers_with_helmets(image_path: str): 2| image = load_image(image_path) @@ -104,7 +105,7 @@ 15| return count [End of file] -AGENT: {"thoughts": "I have generated the code to count the workers with helmets in the image, I must now run the code to get the output.", "response": "from code import count_workers_with_helmets\n print(count_workers_with_helmets('/workspace/test/workers.png'))", "let_user_respond": false} +AGENT: {"thoughts": "I have generated the code to count the workers with helmets in the image, I must now run the code to get the output.", "response": "from code import count_workers_with_helmets\n print(count_workers_with_helmets('/example/workspace/workers.png'))", "let_user_respond": false} OBSERVATION: ----- stdout ----- diff --git a/vision_agent/clients/landing_public_api.py b/vision_agent/clients/landing_public_api.py index 3fd1928e..eec218ad 100644 --- a/vision_agent/clients/landing_public_api.py +++ b/vision_agent/clients/landing_public_api.py @@ -1,6 +1,6 @@ import os -from uuid import UUID from typing import List +from uuid import UUID from requests.exceptions import HTTPError diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py index 2f4ab4d6..1e587ce7 100644 --- a/vision_agent/tools/__init__.py +++ b/vision_agent/tools/__init__.py @@ -19,16 +19,20 @@ detr_segmentation, dpt_hybrid_midas, extract_frames, - florencev2_image_caption, - florencev2_object_detection, - florencev2_roberta_vqa, - florencev2_ocr, + florence2_image_caption, + florence2_object_detection, + florence2_ocr, + florence2_roberta_vqa, + florence2_sam2_image, + florence2_sam2_video, 
generate_pose_image, generate_soft_edge_image, get_tool_documentation, git_vqa_v2, grounding_dino, grounding_sam, + ixc25_image_vqa, + ixc25_video_vqa, load_image, loca_visual_prompt_counting, loca_zero_shot_counting, diff --git a/vision_agent/tools/meta_tools.py b/vision_agent/tools/meta_tools.py index 7c857550..4a82436d 100644 --- a/vision_agent/tools/meta_tools.py +++ b/vision_agent/tools/meta_tools.py @@ -8,7 +8,6 @@ from vision_agent.tools.tool_utils import get_tool_documentation from vision_agent.tools.tools import TOOL_DESCRIPTIONS - # These tools are adapted from SWE-Agent https://github.com/princeton-nlp/SWE-agent CURRENT_FILE = None diff --git a/vision_agent/tools/tool_utils.py b/vision_agent/tools/tool_utils.py index 8d7e3aa9..4f1b1f34 100644 --- a/vision_agent/tools/tool_utils.py +++ b/vision_agent/tools/tool_utils.py @@ -1,7 +1,7 @@ import inspect import logging import os -from typing import Any, Callable, Dict, List, MutableMapping, Optional +from typing import Any, Callable, Dict, List, MutableMapping, Optional, Tuple import pandas as pd from IPython.display import display @@ -31,6 +31,7 @@ class ToolCallTrace(BaseModel): def send_inference_request( payload: Dict[str, Any], endpoint_name: str, + files: Optional[List[Tuple[Any, ...]]] = None, v2: bool = False, metadata_payload: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: @@ -50,7 +51,7 @@ def send_inference_request( response={}, error=None, ) - headers = {"Content-Type": "application/json", "apikey": _LND_API_KEY} + headers = {"apikey": _LND_API_KEY} if "TOOL_ENDPOINT_AUTH" in os.environ: headers["Authorization"] = os.environ["TOOL_ENDPOINT_AUTH"] headers.pop("apikey") @@ -60,7 +61,11 @@ def send_inference_request( num_retry=3, headers=headers, ) - res = session.post(url, json=payload) + + if files is not None: + res = session.post(url, data=payload, files=files) + else: + res = session.post(url, json=payload) if res.status_code != 200: tool_call_trace.error = Error( name="RemoteToolCallFailed", diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index c05369e3..2dade7f7 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -2,47 +2,50 @@ import json import logging import tempfile -from uuid import UUID -from pathlib import Path from importlib import resources +from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Union, cast +from uuid import UUID import cv2 -import requests import numpy as np -from pytube import YouTube # type: ignore +import requests from moviepy.editor import ImageSequenceClip from PIL import Image, ImageDraw, ImageFont from pillow_heif import register_heif_opener # type: ignore +from pytube import YouTube # type: ignore +from vision_agent.clients.landing_public_api import LandingPublicAPI from vision_agent.tools.tool_utils import ( - send_inference_request, get_tool_descriptions, get_tool_documentation, get_tools_df, get_tools_info, + send_inference_request, +) +from vision_agent.tools.tools_types import ( + BboxInput, + BboxInputBase64, + FineTuning, + Florencev2FtRequest, + JobStatus, + PromptTask, ) -from vision_agent.utils.exceptions import FineTuneModelIsNotReady from vision_agent.utils import extract_frames_from_video +from vision_agent.utils.exceptions import FineTuneModelIsNotReady from vision_agent.utils.execute import FileSerializer, MimeType from vision_agent.utils.image_utils import ( b64_to_pil, + convert_quad_box_to_bbox, convert_to_b64, denormalize_bbox, + frames_to_bytes, get_image_size, normalize_bbox, - 
convert_quad_box_to_bbox, + numpy_to_bytes, rle_decode, + rle_decode_array, ) -from vision_agent.tools.tools_types import ( - BboxInput, - BboxInputBase64, - PromptTask, - Florencev2FtRequest, - FineTuning, - JobStatus, -) -from vision_agent.clients.landing_public_api import LandingPublicAPI register_heif_opener() @@ -141,9 +144,9 @@ def owl_v2( box_threshold: float = 0.10, ) -> List[Dict[str, Any]]: """'owl_v2' is a tool that can detect and count multiple objects given a text - prompt such as category names or referring expressions. The categories in text prompt - are separated by commas. It returns a list of bounding boxes with - normalized coordinates, label names and associated probability scores. + prompt such as category names or referring expressions. The categories in text + prompt are separated by commas. It returns a list of bounding boxes with normalized + coordinates, label names and associated probability scores. Parameters: prompt (str): The prompt to ground to the image. @@ -194,10 +197,10 @@ def grounding_sam( box_threshold: float = 0.20, iou_threshold: float = 0.20, ) -> List[Dict[str, Any]]: - """'grounding_sam' is a tool that can segment multiple objects given a - text prompt such as category names or referring expressions. The categories in text - prompt are separated by commas or periods. It returns a list of bounding boxes, - label names, mask file names and associated probability scores. + """'grounding_sam' is a tool that can segment multiple objects given a text prompt + such as category names or referring expressions. The categories in text prompt are + separated by commas or periods. It returns a list of bounding boxes, label names, + mask file names and associated probability scores. Parameters: prompt (str): The prompt to ground to the image. @@ -254,52 +257,114 @@ def grounding_sam( return return_data -def extract_frames( - video_uri: Union[str, Path], fps: float = 0.5 -) -> List[Tuple[np.ndarray, float]]: - """'extract_frames' extracts frames from a video which can be a file path or youtube - link, returns a list of tuples (frame, timestamp), where timestamp is the relative - time in seconds where the frame was captured. The frame is a numpy array. +def florence2_sam2_image(prompt: str, image: np.ndarray) -> List[Dict[str, Any]]: + """'florence2_sam2_image' is a tool that can segment multiple objects given a text + prompt such as category names or referring expressions. The categories in the text + prompt are separated by commas. It returns a list of bounding boxes, label names, + mask file names and associated probability scores of 1.0. Parameters: - video_uri (Union[str, Path]): The path to the video file or youtube link - fps (float, optional): The frame rate per second to extract the frames. Defaults - to 0.5. + prompt (str): The prompt to ground to the image. + image (np.ndarray): The image to ground the prompt to. Returns: - List[Tuple[np.ndarray, float]]: A list of tuples containing the extracted frame - as a numpy array and the timestamp in seconds. + List[Dict[str, Any]]: A list of dictionaries containing the score, label, + bounding box, and mask of the detected objects with normalized coordinates + (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the top-left + and xmax and ymax are the coordinates of the bottom-right of the bounding box. + The mask is binary 2D numpy array where 1 indicates the object and 0 indicates + the background. Example ------- - >>> extract_frames("path/to/video.mp4") - [(frame1, 0.0), (frame2, 0.5), ...] 
+ >>> florence2_sam2_image("car, dinosaur", image) + [ + { + 'score': 1.0, + 'label': 'dinosaur', + 'bbox': [0.1, 0.11, 0.35, 0.4], + 'mask': array([[0, 0, 0, ..., 0, 0, 0], + [0, 0, 0, ..., 0, 0, 0], + ..., + [0, 0, 0, ..., 0, 0, 0], + [0, 0, 0, ..., 0, 0, 0]], dtype=uint8), + }, + ] """ + buffer_bytes = numpy_to_bytes(image) - if str(video_uri).startswith( - ( - "http://www.youtube.com/", - "https://www.youtube.com/", - "http://youtu.be/", - "https://youtu.be/", - ) - ): - with tempfile.TemporaryDirectory() as temp_dir: - yt = YouTube(str(video_uri)) - # Download the highest resolution video - video = ( - yt.streams.filter(progressive=True, file_extension="mp4") - .order_by("resolution") - .desc() - .first() - ) - if not video: - raise Exception("No suitable video stream found") - video_file_path = video.download(output_path=temp_dir) + files = [("image", buffer_bytes)] + payload = { + "prompts": [s.strip() for s in prompt.split(",")], + "function_name": "florence2_sam2_image", + } + data: Dict[str, Any] = send_inference_request( + payload, "florence2-sam2", files=files, v2=True + ) + return_data = [] + for _, data_i in data["0"].items(): + mask = rle_decode_array(data_i["mask"]) + label = data_i["label"] + bbox = normalize_bbox(data_i["bounding_box"], data_i["mask"]["size"]) + return_data.append({"label": label, "bbox": bbox, "mask": mask, "score": 1.0}) + return return_data - return extract_frames_from_video(video_file_path, fps) - return extract_frames_from_video(str(video_uri), fps) +def florence2_sam2_video( + prompt: str, frames: List[np.ndarray] +) -> List[List[Dict[str, Any]]]: + """'florence2_sam2_video' is a tool that can segment and track multiple entities + in a video given a text prompt such as category names or referring expressions. You + can optionally separate the categories in the text with commas. It only tracks + entities present in the first frame and only returns segmentation masks. It is + useful for tracking and counting without duplicating counts. + + Parameters: + prompt (str): The prompt to ground to the video. + frames (List[np.ndarray]): The list of frames to ground the prompt to. + + Returns: + List[List[Dict[str, Any]]]: A list of list of dictionaries containing the label + and segment mask. The outer list represents each frame and the inner list is + the entities per frame. The label contains the object ID followed by the label + name. The objects are only identified in the first framed and tracked + throughout the video. 
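# --- Illustrative sketch (reviewer note, not part of this patch) ---
# Because florence2_sam2_video labels each mask as "<object id>: <label>" and the
# ids are assigned in the first frame and kept stable while tracking, a caller can
# count objects exactly once even though every object reappears in every frame.
# `frames` is assumed to be a list of RGB numpy arrays, e.g. from extract_frames.
detections = florence2_sam2_video("car", frames)
unique_ids = set()
for frame_dets in detections:                  # one list of dicts per frame
    for det in frame_dets:
        obj_id = det["label"].split(":")[0]    # "0: car" -> "0"
        unique_ids.add(obj_id)
print(f"{len(unique_ids)} cars tracked across {len(detections)} frames")
# --- end sketch ---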
+ + Example + ------- + >>> florence2_sam2_video("car, dinosaur", frames) + [ + [ + { + 'label': '0: dinosaur', + 'mask': array([[0, 0, 0, ..., 0, 0, 0], + [0, 0, 0, ..., 0, 0, 0], + ..., + [0, 0, 0, ..., 0, 0, 0], + [0, 0, 0, ..., 0, 0, 0]], dtype=uint8), + }, + ], + ] + """ + + buffer_bytes = frames_to_bytes(frames) + files = [("video", buffer_bytes)] + payload = { + "prompts": prompt.split(","), + "function_name": "florence2_sam2_video", + } + data: Dict[str, Any] = send_inference_request( + payload, "florence2-sam2", files=files, v2=True + ) + return_data = [] + for frame_i in data.keys(): + return_frame_data = [] + for obj_id, data_j in data[frame_i].items(): + mask = rle_decode_array(data_j["mask"]) + label = obj_id + ": " + data_j["label"] + return_frame_data.append({"label": label, "mask": mask, "score": 1.0}) + return_data.append(return_frame_data) + return return_data def ocr(image: np.ndarray) -> List[Dict[str, Any]]: @@ -368,12 +433,19 @@ def loca_zero_shot_counting(image: np.ndarray) -> Dict[str, Any]: Returns: Dict[str, Any]: A dictionary containing the key 'count' and the count as a - value. E.g. {count: 12}. + value, e.g. {count: 12} and a heat map for visaulization purposes. Example ------- >>> loca_zero_shot_counting(image) - {'count': 45}, + {'count': 83, + 'heat_map': array([[ 0, 0, 0, ..., 0, 0, 0], + [ 0, 0, 0, ..., 0, 0, 0], + [ 0, 0, 0, ..., 0, 0, 1], + ..., + [ 0, 0, 0, ..., 30, 35, 41], + [ 0, 0, 0, ..., 41, 47, 53], + [ 0, 0, 0, ..., 53, 59, 64]], dtype=uint8)} """ image_b64 = convert_to_b64(image) @@ -398,12 +470,19 @@ def loca_visual_prompt_counting( Returns: Dict[str, Any]: A dictionary containing the key 'count' and the count as a - value. E.g. {count: 12}. + value, e.g. {count: 12} and a heat map for visaulization purposes. Example ------- >>> loca_visual_prompt_counting(image, {"bbox": [0.1, 0.1, 0.4, 0.42]}) - {'count': 45}, + {'count': 83, + 'heat_map': array([[ 0, 0, 0, ..., 0, 0, 0], + [ 0, 0, 0, ..., 0, 0, 0], + [ 0, 0, 0, ..., 0, 0, 1], + ..., + [ 0, 0, 0, ..., 30, 35, 41], + [ 0, 0, 0, ..., 41, 47, 53], + [ 0, 0, 0, ..., 53, 59, 64]], dtype=uint8)} """ image_size = get_image_size(image) @@ -420,8 +499,8 @@ def loca_visual_prompt_counting( return resp_data -def florencev2_roberta_vqa(prompt: str, image: np.ndarray) -> str: - """'florencev2_roberta_vqa' is a tool that takes an image and analyzes +def florence2_roberta_vqa(prompt: str, image: np.ndarray) -> str: + """'florence2_roberta_vqa' is a tool that takes an image and analyzes its contents, generates detailed captions and then tries to answer the given question using the generated context. It returns text as an answer to the question. @@ -434,7 +513,7 @@ def florencev2_roberta_vqa(prompt: str, image: np.ndarray) -> str: Example ------- - >>> florencev2_roberta_vqa('What is the top left animal in this image ?', image) + >>> florence2_roberta_vqa('What is the top left animal in this image?', image) 'white tiger' """ @@ -442,13 +521,73 @@ def florencev2_roberta_vqa(prompt: str, image: np.ndarray) -> str: data = { "image": image_b64, "question": prompt, - "function_name": "florencev2_roberta_vqa", + "function_name": "florence2_roberta_vqa", } answer = send_inference_request(data, "florence2-qa", v2=True) return answer # type: ignore +def ixc25_image_vqa(prompt: str, image: np.ndarray) -> str: + """'ixc25_image_vqa' is a tool that can answer any questions about arbitrary images + including regular images or images of documents or presentations. It returns text + as an answer to the question. 
+ + Parameters: + prompt (str): The question about the image + image (np.ndarray): The reference image used for the question + + Returns: + str: A string which is the answer to the given prompt. + + Example + ------- + >>> ixc25_image_vqa('What is the cat doing?', image) + 'drinking milk' + """ + + buffer_bytes = numpy_to_bytes(image) + files = [("image", buffer_bytes)] + payload = { + "prompt": prompt, + "function_name": "ixc25_image_vqa", + } + data: Dict[str, Any] = send_inference_request( + payload, "internlm-xcomposer2", files=files, v2=True + ) + return cast(str, data["answer"]) + + +def ixc25_video_vqa(prompt: str, frames: List[np.ndarray]) -> str: + """'ixc25_video_vqa' is a tool that can answer any questions about arbitrary videos + including regular videos or videos of documents or presentations. It returns text + as an answer to the question. + + Parameters: + prompt (str): The question about the video + frames (List[np.ndarray]): The reference frames used for the question + + Returns: + str: A string which is the answer to the given prompt. + + Example + ------- + >>> ixc25_video_vqa('Which football player made the goal?', frames) + 'Lionel Messi' + """ + + buffer_bytes = frames_to_bytes(frames) + files = [("video", buffer_bytes)] + payload = { + "prompt": prompt, + "function_name": "ixc25_video_vqa", + } + data: Dict[str, Any] = send_inference_request( + payload, "internlm-xcomposer2", files=files, v2=True + ) + return cast(str, data["answer"]) + + def git_vqa_v2(prompt: str, image: np.ndarray) -> str: """'git_vqa_v2' is a tool that can answer questions about the visual contents of an image given a question and an image. It returns an answer to the @@ -592,8 +731,8 @@ def blip_image_caption(image: np.ndarray) -> str: return answer["text"][0] # type: ignore -def florencev2_image_caption(image: np.ndarray, detail_caption: bool = True) -> str: - """'florencev2_image_caption' is a tool that can caption or describe an image based +def florence2_image_caption(image: np.ndarray, detail_caption: bool = True) -> str: + """'florence2_image_caption' is a tool that can caption or describe an image based on its contents. It returns a text describing the image. Parameters: @@ -606,7 +745,7 @@ def florencev2_image_caption(image: np.ndarray, detail_caption: bool = True) -> Example ------- - >>> florencev2_image_caption(image, False) + >>> florence2_image_caption(image, False) 'This image contains a cat sitting on a table with a bowl of milk.' """ image_b64 = convert_to_b64(image) @@ -614,17 +753,19 @@ def florencev2_image_caption(image: np.ndarray, detail_caption: bool = True) -> data = { "image": image_b64, "task": task, - "function_name": "florencev2_image_caption", + "function_name": "florence2_image_caption", } answer = send_inference_request(data, "florence2", v2=True) return answer[task] # type: ignore -def florencev2_object_detection(prompt: str, image: np.ndarray) -> List[Dict[str, Any]]: - """'florencev2_object_detection' is a tool that can detect objects given a text - prompt such as a phrase or class names separated by commas. It returns a list of - detected objects as labels and their location as bounding boxes with score of 1.0. +def florence2_object_detection(prompt: str, image: np.ndarray) -> List[Dict[str, Any]]: + """'florencev2_object_detection' is a tool that can detect and count multiple + objects given a text prompt such as category names or referring expressions. You + can optionally separate the categories in the text with commas. 
It returns a list + of bounding boxes with normalized coordinates, label names and associated + probability scores of 1.0. Parameters: prompt (str): The prompt to ground to the image. @@ -639,7 +780,7 @@ def florencev2_object_detection(prompt: str, image: np.ndarray) -> List[Dict[str Example ------- - >>> florencev2_object_detection('person looking at a coyote', image) + >>> florence2_object_detection('person looking at a coyote', image) [ {'score': 1.0, 'label': 'person', 'bbox': [0.1, 0.11, 0.35, 0.4]}, {'score': 1.0, 'label': 'coyote', 'bbox': [0.34, 0.21, 0.85, 0.5}, @@ -651,7 +792,7 @@ def florencev2_object_detection(prompt: str, image: np.ndarray) -> List[Dict[str "image": image_b64, "task": "", "prompt": prompt, - "function_name": "florencev2_object_detection", + "function_name": "florence2_object_detection", } detections = send_inference_request(data, "florence2", v2=True) @@ -668,8 +809,8 @@ def florencev2_object_detection(prompt: str, image: np.ndarray) -> List[Dict[str return return_data -def florencev2_ocr(image: np.ndarray) -> List[Dict[str, Any]]: - """'florencev2_ocr' is a tool that can detect text and text regions in an image. +def florence2_ocr(image: np.ndarray) -> List[Dict[str, Any]]: + """'florence2_ocr' is a tool that can detect text and text regions in an image. Each text region contains one line of text. It returns a list of detected text, the text region as a bounding box with normalized coordinates, and confidence scores. The results are sorted from top-left to bottom right. @@ -683,7 +824,7 @@ def florencev2_ocr(image: np.ndarray) -> List[Dict[str, Any]]: Example ------- - >>> florencev2_ocr(image) + >>> florence2_ocr(image) [ {'label': 'hello world', 'bbox': [0.1, 0.11, 0.35, 0.4], 'score': 0.99}, ] @@ -694,7 +835,7 @@ def florencev2_ocr(image: np.ndarray) -> List[Dict[str, Any]]: data = { "image": image_b64, "task": "", - "function_name": "florencev2_ocr", + "function_name": "florence2_ocr", } detections = send_inference_request(data, "florence2", v2=True) @@ -1035,6 +1176,54 @@ def closest_box_distance( # Utility and visualization functions +def extract_frames( + video_uri: Union[str, Path], fps: float = 1 +) -> List[Tuple[np.ndarray, float]]: + """'extract_frames' extracts frames from a video which can be a file path or youtube + link, returns a list of tuples (frame, timestamp), where timestamp is the relative + time in seconds where the frame was captured. The frame is a numpy array. + + Parameters: + video_uri (Union[str, Path]): The path to the video file or youtube link + fps (float, optional): The frame rate per second to extract the frames. Defaults + to 10. + + Returns: + List[Tuple[np.ndarray, float]]: A list of tuples containing the extracted frame + as a numpy array and the timestamp in seconds. + + Example + ------- + >>> extract_frames("path/to/video.mp4") + [(frame1, 0.0), (frame2, 0.5), ...] 
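# --- Illustrative sketch (reviewer note, not part of this patch) ---
# extract_frames returns (frame, timestamp) tuples, while florence2_sam2_video and
# ixc25_video_vqa expect a plain list of frames, so the timestamps are dropped
# before calling them. The video path and prompts below are made-up examples.
frames_with_ts = extract_frames("path/to/video.mp4", fps=1)
frames = [frame for frame, _ in frames_with_ts]
answer = ixc25_video_vqa("What happens in this video?", frames)
masks_per_frame = florence2_sam2_video("car", frames)
# --- end sketch ---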
+ """ + + if str(video_uri).startswith( + ( + "http://www.youtube.com/", + "https://www.youtube.com/", + "http://youtu.be/", + "https://youtu.be/", + ) + ): + with tempfile.TemporaryDirectory() as temp_dir: + yt = YouTube(str(video_uri)) + # Download the highest resolution video + video = ( + yt.streams.filter(progressive=True, file_extension="mp4") + .order_by("resolution") + .desc() + .first() + ) + if not video: + raise Exception("No suitable video stream found") + video_file_path = video.download(output_path=temp_dir) + + return extract_frames_from_video(video_file_path, fps) + + return extract_frames_from_video(str(video_uri), fps) + + def save_json(data: Any, file_path: str) -> None: """'save_json' is a utility function that saves data as a JSON file. It is helpful for saving data that contains NumPy arrays which are not JSON serializable. @@ -1099,7 +1288,7 @@ def save_image(image: np.ndarray, file_path: str) -> None: def save_video( - frames: List[np.ndarray], output_video_path: Optional[str] = None, fps: float = 4 + frames: List[np.ndarray], output_video_path: Optional[str] = None, fps: float = 1 ) -> str: """'save_video' is a utility function that saves a list of frames as a mp4 video file on disk. @@ -1201,15 +1390,43 @@ def overlay_bounding_boxes( return np.array(pil_image) +def _get_text_coords_from_mask( + mask: np.ndarray, v_gap: int = 10, h_gap: int = 10 +) -> Tuple[int, int]: + mask = mask.astype(np.uint8) + if np.sum(mask) == 0: + return (0, 0) + + rows, cols = np.nonzero(mask) + top = rows.min() + bottom = rows.max() + left = cols.min() + right = cols.max() + + if top - v_gap < 0: + if bottom + v_gap > mask.shape[0]: + top = top + else: + top = bottom + v_gap + else: + top = top - v_gap + + return left + (right - left) // 2 - h_gap, top + + def overlay_segmentation_masks( - image: np.ndarray, masks: List[Dict[str, Any]] -) -> np.ndarray: + medias: Union[np.ndarray, List[np.ndarray]], + masks: Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]], + draw_label: bool = True, +) -> Union[np.ndarray, List[np.ndarray]]: """'overlay_segmentation_masks' is a utility function that displays segmentation masks. Parameters: - image (np.ndarray): The image to display the masks on. - masks (List[Dict[str, Any]]): A list of dictionaries containing the masks. + medias (Union[np.ndarray, List[np.ndarray]]): The image or frames to display + the masks on. + masks (Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]): A list of + dictionaries containing the masks. Returns: np.ndarray: The image with the masks displayed. @@ -1229,27 +1446,50 @@ def overlay_segmentation_masks( }], ) """ - pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGBA") + medias_int: List[np.ndarray] = ( + [medias] if isinstance(medias, np.ndarray) else medias + ) + masks_int = [masks] if isinstance(masks[0], dict) else masks + masks_int = cast(List[List[Dict[str, Any]]], masks_int) - if len(set([mask["label"] for mask in masks])) > len(COLORS): - _LOGGER.warning( - "Number of unique labels exceeds the number of available colors. Some labels may have the same color." 
- ) + labels = set() + for mask_i in masks_int: + for mask_j in mask_i: + labels.add(mask_j["label"]) + color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(labels)} - color = { - label: COLORS[i % len(COLORS)] - for i, label in enumerate(set([mask["label"] for mask in masks])) - } - masks = sorted(masks, key=lambda x: x["label"], reverse=True) + width, height = Image.fromarray(medias_int[0]).size + fontsize = max(12, int(min(width, height) / 40)) + font = ImageFont.truetype( + str(resources.files("vision_agent.fonts").joinpath("default_font_ch_en.ttf")), + fontsize, + ) - for elt in masks: - mask = elt["mask"] - label = elt["label"] - np_mask = np.zeros((pil_image.size[1], pil_image.size[0], 4)) - np_mask[mask > 0, :] = color[label] + (255 * 0.5,) - mask_img = Image.fromarray(np_mask.astype(np.uint8)) - pil_image = Image.alpha_composite(pil_image, mask_img) - return np.array(pil_image) + frame_out = [] + for i, frame in enumerate(medias_int): + pil_image = Image.fromarray(frame.astype(np.uint8)).convert("RGBA") + for elt in masks_int[i]: + mask = elt["mask"] + label = elt["label"] + np_mask = np.zeros((pil_image.size[1], pil_image.size[0], 4)) + np_mask[mask > 0, :] = color[label] + (255 * 0.5,) + mask_img = Image.fromarray(np_mask.astype(np.uint8)) + pil_image = Image.alpha_composite(pil_image, mask_img) + + if draw_label: + draw = ImageDraw.Draw(pil_image) + text_box = draw.textbbox((0, 0), text=label, font=font) + x, y = _get_text_coords_from_mask( + mask, + v_gap=(text_box[3] - text_box[1]) + 10, + h_gap=(text_box[2] - text_box[0]) // 2, + ) + if x != 0 and y != 0: + text_box = draw.textbbox((x, y), text=label, font=font) + draw.rectangle((x, y, text_box[2], text_box[3]), fill=color[label]) + draw.text((x, y), label, fill="black", font=font) + frame_out.append(np.array(pil_image)) + return frame_out[0] if len(frame_out) == 1 else frame_out def overlay_heat_map( @@ -1412,7 +1652,6 @@ def florencev2_fine_tuned_object_detection( TOOLS = [ owl_v2, - grounding_sam, extract_frames, ocr, clip, @@ -1420,13 +1659,15 @@ def florencev2_fine_tuned_object_detection( vit_nsfw_classification, loca_zero_shot_counting, loca_visual_prompt_counting, - florencev2_roberta_vqa, - florencev2_image_caption, - florencev2_ocr, + florence2_image_caption, + florence2_ocr, + florence2_sam2_image, + florence2_sam2_video, + florence2_object_detection, + ixc25_image_vqa, + ixc25_video_vqa, detr_segmentation, depth_anything_v2, - generate_soft_edge_image, - dpt_hybrid_midas, generate_pose_image, closest_mask_distance, closest_box_distance, @@ -1437,7 +1678,6 @@ def florencev2_fine_tuned_object_detection( overlay_bounding_boxes, overlay_segmentation_masks, overlay_heat_map, - template_match, ] TOOLS_DF = get_tools_df(TOOLS) # type: ignore TOOL_DESCRIPTIONS = get_tool_descriptions(TOOLS) # type: ignore diff --git a/vision_agent/utils/execute.py b/vision_agent/utils/execute.py index b157b1df..b62308ff 100644 --- a/vision_agent/utils/execute.py +++ b/vision_agent/utils/execute.py @@ -416,7 +416,6 @@ def download_file(self, file_path: str) -> Path: class E2BCodeInterpreter(CodeInterpreter): - def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) assert os.getenv("E2B_API_KEY"), "E2B_API_KEY environment variable must be set" diff --git a/vision_agent/utils/image_utils.py b/vision_agent/utils/image_utils.py index ddbd14b3..d2bc8a6d 100644 --- a/vision_agent/utils/image_utils.py +++ b/vision_agent/utils/image_utils.py @@ -1,12 +1,15 @@ """Utility functions for image 
processing.""" import base64 +import io +import tempfile from importlib import resources from io import BytesIO from pathlib import Path from typing import Dict, List, Tuple, Union import numpy as np +from moviepy.editor import ImageSequenceClip from PIL import Image, ImageDraw, ImageFont from PIL.Image import Image as ImageType @@ -63,6 +66,46 @@ def rle_decode(mask_rle: str, shape: Tuple[int, int]) -> np.ndarray: return img.reshape(shape) +def rle_decode_array(rle: Dict[str, List[int]]) -> np.ndarray: + r"""Decode a run-length encoded mask. Returns numpy array, 1 - mask, 0 - background. + + Parameters: + mask: The mask in run-length encoded as an array. + """ + size = rle["size"] + counts = rle["counts"] + + total_elements = size[0] * size[1] + flattened_mask = np.zeros(total_elements, dtype=np.uint8) + + current_pos = 0 + for i, count in enumerate(counts): + if i % 2 == 1: + flattened_mask[current_pos : current_pos + count] = 1 + current_pos += count + + binary_mask = flattened_mask.reshape(size, order="F") + return binary_mask + + +def frames_to_bytes( + frames: List[np.ndarray], fps: float = 10, file_ext: str = "mp4" +) -> bytes: + r"""Convert a list of frames to a video file encoded into a byte string. + + Parameters: + frames: the list of frames + fps: the frames per second of the video + file_ext: the file extension of the video file + """ + with tempfile.NamedTemporaryFile(delete=True) as temp_file: + clip = ImageSequenceClip(frames, fps=fps) + clip.write_videofile(temp_file.name + f".{file_ext}", fps=fps) + with open(temp_file.name + f".{file_ext}", "rb") as f: + buffer_bytes = f.read() + return buffer_bytes + + def b64_to_pil(b64_str: str) -> ImageType: r"""Convert a base64 string to a PIL Image. @@ -78,6 +121,15 @@ def b64_to_pil(b64_str: str) -> ImageType: return Image.open(BytesIO(base64.b64decode(b64_str))) +def numpy_to_bytes(image: np.ndarray) -> bytes: + pil_image = Image.fromarray(image).convert("RGB") + image_buffer = io.BytesIO() + pil_image.save(image_buffer, format="PNG") + buffer_bytes = image_buffer.getvalue() + image_buffer.close() + return buffer_bytes + + def get_image_size(data: Union[str, Path, np.ndarray, ImageType]) -> Tuple[int, ...]: r"""Get the size of an image.