diff --git a/.github/workflows/ci_cd.yml b/.github/workflows/ci_cd.yml
index d41e7592..3576e10c 100644
--- a/.github/workflows/ci_cd.yml
+++ b/.github/workflows/ci_cd.yml
@@ -1,10 +1,14 @@
 name: CI
+
 on:
   push:
     branches: [ main ]
   pull_request:
     branches: [ main ]
+env:
+  LANDINGAI_DEV_API_KEY: ${{ secrets.LANDINGAI_DEV_API_KEY }}
+
 jobs:
   unit_test:
     name: Test
@@ -79,6 +83,9 @@ jobs:
       - name: Test with pytest
        run: |
          poetry run pytest -v tests/integ
+      - name: Test with pytest, dev env
+        run: |
+          LANDINGAI_API_KEY=$LANDINGAI_DEV_API_KEY LANDINGAI_URL=https://api.dev.landing.ai poetry run pytest -v tests/integration_dev

   release:
     name: Release
diff --git a/tests/integration_dev/__init__.py b/tests/integration_dev/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/integration_dev/test_tools.py b/tests/integration_dev/test_tools.py
new file mode 100644
index 00000000..29262245
--- /dev/null
+++ b/tests/integration_dev/test_tools.py
@@ -0,0 +1,21 @@
+import skimage as ski
+
+from vision_agent.tools import (
+    countgd_counting,
+    countgd_example_based_counting,
+)
+
+
+def test_countgd_counting() -> None:
+    img = ski.data.coins()
+    result = countgd_counting(image=img, prompt="coin")
+    assert len(result) == 24
+
+
+def test_countgd_example_based_counting() -> None:
+    img = ski.data.coins()
+    result = countgd_example_based_counting(
+        visual_prompts=[[85, 106, 122, 145]],
+        image=img,
+    )
+    assert len(result) == 24
diff --git a/vision_agent/agent/vision_agent_coder_prompts.py b/vision_agent/agent/vision_agent_coder_prompts.py
index c68f73fe..b4c8a9bf 100644
--- a/vision_agent/agent/vision_agent_coder_prompts.py
+++ b/vision_agent/agent/vision_agent_coder_prompts.py
@@ -81,20 +81,19 @@
 - Count the number of detected objects labeled as 'person'.
 plan3:
 - Load the image from the provided file path 'image.jpg'.
-- Use the 'loca_zero_shot_counting' tool to count the dominant foreground object, which in this case is people.
+- Use the 'countgd_counting' tool to count the dominant foreground object, which in this case is people.
 ```python
-from vision_agent.tools import load_image, owl_v2, grounding_sam, loca_zero_shot_counting
+from vision_agent.tools import load_image, owl_v2, grounding_sam, countgd_counting

 image = load_image("image.jpg")
 owl_v2_out = owl_v2("person", image)

 gsam_out = grounding_sam("person", image)
 gsam_out = [{{k: v for k, v in o.items() if k != "mask"}} for o in gsam_out]

-loca_out = loca_zero_shot_counting(image)
-loca_out = loca_out["count"]
+cgd_out = countgd_counting(image)

-final_out = {{"owl_v2": owl_v2_out, "florencev2_object_detection": florencev2_out, "loca_zero_shot_counting": loca_out}}
+final_out = {{"owl_v2": owl_v2_out, "grounding_sam": gsam_out, "countgd_counting": cgd_out}}
 print(final_out)
 ```
"""
diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py
index 76481f3f..e68666d2 100644
--- a/vision_agent/lmm/lmm.py
+++ b/vision_agent/lmm/lmm.py
@@ -286,9 +286,6 @@ def generate_segmentor(self, question: str) -> Callable:

         return lambda x: T.grounding_sam(params["prompt"], x)

-    def generate_zero_shot_counter(self, question: str) -> Callable:
-        return T.loca_zero_shot_counting
-
     def generate_image_qa_tool(self, question: str) -> Callable:
         return lambda x: T.git_vqa_v2(question, x)
diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py
index e82d7553..43460fbd 100644
--- a/vision_agent/tools/__init__.py
+++ b/vision_agent/tools/__init__.py
@@ -37,10 +37,13 @@
     load_image,
     loca_visual_prompt_counting,
     loca_zero_shot_counting,
+    countgd_counting,
+    countgd_example_based_counting,
     ocr,
     overlay_bounding_boxes,
     overlay_heat_map,
     overlay_segmentation_masks,
+    overlay_counting_results,
     owl_v2,
     save_image,
     save_json,
diff --git a/vision_agent/tools/tool_utils.py b/vision_agent/tools/tool_utils.py
index 185563a4..a14443bd 100644
--- a/vision_agent/tools/tool_utils.py
+++ b/vision_agent/tools/tool_utils.py
@@ -1,6 +1,6 @@
+import os
 import inspect
 import logging
-import os
 from typing import Any, Callable, Dict, List, MutableMapping, Optional, Tuple

 import pandas as pd
@@ -13,6 +13,7 @@
 from vision_agent.utils.exceptions import RemoteToolCallFailed
 from vision_agent.utils.execute import Error, MimeType
 from vision_agent.utils.type_defs import LandingaiAPIKey
+from vision_agent.tools.tools_types import BoundingBoxes

 _LOGGER = logging.getLogger(__name__)
 _LND_API_KEY = os.environ.get("LANDINGAI_API_KEY", LandingaiAPIKey().api_key)
@@ -34,61 +35,58 @@ def send_inference_request(
     files: Optional[List[Tuple[Any, ...]]] = None,
     v2: bool = False,
     metadata_payload: Optional[Dict[str, Any]] = None,
-) -> Dict[str, Any]:
+) -> Any:
     # TODO: runtime_tag and function_name should be metadata_payload and now included
     # in the service payload
-    try:
-        if runtime_tag := os.environ.get("RUNTIME_TAG", ""):
-            payload["runtime_tag"] = runtime_tag
+    if runtime_tag := os.environ.get("RUNTIME_TAG", ""):
+        payload["runtime_tag"] = runtime_tag
+
+    url = f"{_LND_API_URL_v2 if v2 else _LND_API_URL}/{endpoint_name}"
+    if "TOOL_ENDPOINT_URL" in os.environ:
+        url = os.environ["TOOL_ENDPOINT_URL"]
+
+    headers = {"apikey": _LND_API_KEY}
+    if "TOOL_ENDPOINT_AUTH" in os.environ:
+        headers["Authorization"] = os.environ["TOOL_ENDPOINT_AUTH"]
+        headers.pop("apikey")
+
+    session = _create_requests_session(
+        url=url,
+        num_retry=3,
+        headers=headers,
+    )

-        url = f"{_LND_API_URL_v2 if v2 else _LND_API_URL}/{endpoint_name}"
-        if "TOOL_ENDPOINT_URL" in os.environ:
-            url = os.environ["TOOL_ENDPOINT_URL"]
+    function_name = "unknown"
+    if "function_name" in payload:
+        function_name = payload["function_name"]
payload["function_name"] + elif metadata_payload is not None and "function_name" in metadata_payload: + function_name = metadata_payload["function_name"] - tool_call_trace = ToolCallTrace( - endpoint_url=url, - request=payload, - response={}, - error=None, - ) - headers = {"apikey": _LND_API_KEY} - if "TOOL_ENDPOINT_AUTH" in os.environ: - headers["Authorization"] = os.environ["TOOL_ENDPOINT_AUTH"] - headers.pop("apikey") - - session = _create_requests_session( - url=url, - num_retry=3, - headers=headers, - ) + response = _call_post(url, payload, session, files, function_name) - if files is not None: - res = session.post(url, data=payload, files=files) - else: - res = session.post(url, json=payload) - if res.status_code != 200: - tool_call_trace.error = Error( - name="RemoteToolCallFailed", - value=f"{res.status_code} - {res.text}", - traceback_raw=[], - ) - _LOGGER.error(f"Request failed: {res.status_code} {res.text}") - # TODO: function_name should be in metadata_payload - function_name = "unknown" - if "function_name" in payload: - function_name = payload["function_name"] - elif metadata_payload is not None and "function_name" in metadata_payload: - function_name = metadata_payload["function_name"] - raise RemoteToolCallFailed(function_name, res.status_code, res.text) - - resp = res.json() - tool_call_trace.response = resp - # TODO: consider making the response schema the same between below two sources - return resp if "TOOL_ENDPOINT_AUTH" in os.environ else resp["data"] # type: ignore - finally: - trace = tool_call_trace.model_dump() - trace["type"] = "tool_call" - display({MimeType.APPLICATION_JSON: trace}, raw=True) + # TODO: consider making the response schema the same between below two sources + return response if "TOOL_ENDPOINT_AUTH" in os.environ else response["data"] + + +def send_task_inference_request( + payload: Dict[str, Any], + task_name: str, + files: Optional[List[Tuple[Any, ...]]] = None, + metadata: Optional[Dict[str, Any]] = None, +) -> Any: + url = f"{_LND_API_URL_v2}/{task_name}" + headers = {"apikey": _LND_API_KEY} + session = _create_requests_session( + url=url, + num_retry=3, + headers=headers, + ) + + function_name = "unknown" + if metadata is not None and "function_name" in metadata: + function_name = metadata["function_name"] + response = _call_post(url, payload, session, files, function_name) + return response["data"] def _create_requests_session( @@ -195,3 +193,49 @@ def get_tools_info(funcs: List[Callable[..., Any]]) -> Dict[str, str]: data[func.__name__] = f"{func.__name__}{inspect.signature(func)}:\n{desc}" return data + + +def _call_post( + url: str, + payload: dict[str, Any], + session: Session, + files: Optional[List[Tuple[Any, ...]]] = None, + function_name: str = "unknown", +) -> Any: + try: + tool_call_trace = ToolCallTrace( + endpoint_url=url, + request=payload, + response={}, + error=None, + ) + + if files is not None: + response = session.post(url, data=payload, files=files) + else: + response = session.post(url, json=payload) + + if response.status_code != 200: + tool_call_trace.error = Error( + name="RemoteToolCallFailed", + value=f"{response.status_code} - {response.text}", + traceback_raw=[], + ) + _LOGGER.error(f"Request failed: {response.status_code} {response.text}") + raise RemoteToolCallFailed( + function_name, response.status_code, response.text + ) + + result = response.json() + tool_call_trace.response = result + return result + finally: + trace = tool_call_trace.model_dump() + trace["type"] = "tool_call" + 
+        display({MimeType.APPLICATION_JSON: trace}, raw=True)
+
+
+def filter_bboxes_by_threshold(
+    bboxes: BoundingBoxes, threshold: float
+) -> BoundingBoxes:
+    return list(filter(lambda bbox: bbox.score >= threshold, bboxes))
diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py
index 0695b547..8012e60d 100644
--- a/vision_agent/tools/tools.py
+++ b/vision_agent/tools/tools.py
@@ -13,7 +13,7 @@
 import numpy as np
 import requests
 from moviepy.editor import ImageSequenceClip
-from PIL import Image, ImageDraw, ImageFont
+from PIL import Image, ImageDraw, ImageFont, ImageEnhance
 from pillow_heif import register_heif_opener  # type: ignore
 from pytube import YouTube  # type: ignore

@@ -24,6 +24,8 @@
     get_tools_df,
     get_tools_info,
     send_inference_request,
+    send_task_inference_request,
+    filter_bboxes_by_threshold,
 )
 from vision_agent.tools.tools_types import (
     BboxInput,
@@ -32,6 +34,7 @@
     Florencev2FtRequest,
     JobStatus,
     PromptTask,
+    ODResponseData,
 )
 from vision_agent.utils import extract_frames_from_video
 from vision_agent.utils.exceptions import FineTuneModelIsNotReady
@@ -455,7 +458,7 @@ def loca_zero_shot_counting(image: np.ndarray) -> Dict[str, Any]:
         "image": image_b64,
         "function_name": "loca_zero_shot_counting",
     }
-    resp_data = send_inference_request(data, "loca", v2=True)
+    resp_data: dict[str, Any] = send_inference_request(data, "loca", v2=True)
     resp_data["heat_map"] = np.array(resp_data["heat_map"][0]).astype(np.uint8)
     return resp_data

@@ -469,6 +472,8 @@ def loca_visual_prompt_counting(

     Parameters:
         image (np.ndarray): The image that contains lot of instances of a single object
+        visual_prompt (Dict[str, List[float]]): Bounding box of the object in format
+            [xmin, ymin, xmax, ymax]. Only 1 bounding box can be provided.

     Returns:
         Dict[str, Any]: A dictionary containing the key 'count' and the count as a
@@ -496,11 +501,109 @@ def loca_visual_prompt_counting(
         "bbox": list(map(int, denormalize_bbox(bbox, image_size))),
         "function_name": "loca_visual_prompt_counting",
     }
-    resp_data = send_inference_request(data, "loca", v2=True)
+    resp_data: dict[str, Any] = send_inference_request(data, "loca", v2=True)
     resp_data["heat_map"] = np.array(resp_data["heat_map"][0]).astype(np.uint8)
     return resp_data


+def countgd_counting(
+    prompt: str,
+    image: np.ndarray,
+    box_threshold: float = 0.23,
+) -> List[Dict[str, Any]]:
+    """'countgd_counting' is a tool that can precisely count multiple instances of an
+    object given a text prompt. It returns a list of bounding boxes with normalized
+    coordinates, label names and associated confidence scores.
+
+    Parameters:
+        prompt (str): The object that needs to be counted.
+        image (np.ndarray): The image that contains multiple instances of the object.
+        box_threshold (float, optional): The threshold for detection. Defaults
+            to 0.23.
+
+    Returns:
+        List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
+            bounding box of the detected objects with normalized coordinates between 0
+            and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the
+            top-left and xmax and ymax are the coordinates of the bottom-right of the
+            bounding box.
+
+    Example
+    -------
+    >>> countgd_counting("flower", image)
+    [
+        {'score': 0.49, 'label': 'flower', 'bbox': [0.1, 0.11, 0.35, 0.4]},
+        {'score': 0.68, 'label': 'flower', 'bbox': [0.2, 0.21, 0.45, 0.5]},
+        {'score': 0.78, 'label': 'flower', 'bbox': [0.3, 0.35, 0.48, 0.52]},
+        {'score': 0.98, 'label': 'flower', 'bbox': [0.44, 0.24, 0.49, 0.58]},
+    ]
+    """
+    buffer_bytes = numpy_to_bytes(image)
+    files = [("image", buffer_bytes)]
+    prompt = prompt.replace(", ", " .")
+    payload = {"prompts": [prompt], "model": "countgd"}
+    metadata = {"function_name": "countgd_counting"}
+    resp_data = send_task_inference_request(
+        payload, "text-to-object-detection", files=files, metadata=metadata
+    )
+    bboxes_per_frame = resp_data[0]
+    bboxes_formatted = [ODResponseData(**bbox) for bbox in bboxes_per_frame]
+    filtered_bboxes = filter_bboxes_by_threshold(bboxes_formatted, box_threshold)
+    return [bbox.model_dump() for bbox in filtered_bboxes]
+
+
+def countgd_example_based_counting(
+    visual_prompts: List[List[float]],
+    image: np.ndarray,
+    box_threshold: float = 0.23,
+) -> List[Dict[str, Any]]:
+    """'countgd_example_based_counting' is a tool that can precisely count multiple
+    instances of an object given a few visual example prompts. It returns a list of
+    bounding boxes with normalized coordinates, label names and associated confidence
+    scores.
+
+    Parameters:
+        visual_prompts (List[List[float]]): Bounding boxes of the object in format
+            [xmin, ymin, xmax, ymax]. Up to 3 bounding boxes can be provided.
+        image (np.ndarray): The image that contains multiple instances of the object.
+        box_threshold (float, optional): The threshold for detection. Defaults
+            to 0.23.
+
+    Returns:
+        List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
+            bounding box of the detected objects with normalized coordinates between 0
+            and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the
+            top-left and xmax and ymax are the coordinates of the bottom-right of the
+            bounding box.
+
+    Example
+    -------
+    >>> countgd_example_based_counting(
+            visual_prompts=[[0.1, 0.1, 0.4, 0.42], [0.2, 0.3, 0.25, 0.35]],
+            image=image
+        )
+    [
+        {'score': 0.49, 'label': 'object', 'bbox': [0.1, 0.11, 0.35, 0.4]},
+        {'score': 0.68, 'label': 'object', 'bbox': [0.2, 0.21, 0.45, 0.5]},
+        {'score': 0.78, 'label': 'object', 'bbox': [0.3, 0.35, 0.48, 0.52]},
+        {'score': 0.98, 'label': 'object', 'bbox': [0.44, 0.24, 0.49, 0.58]},
+    ]
+    """
+    buffer_bytes = numpy_to_bytes(image)
+    files = [("image", buffer_bytes)]
+    visual_prompts = [
+        denormalize_bbox(bbox, image.shape[:2]) for bbox in visual_prompts
+    ]
+    payload = {"visual_prompts": json.dumps(visual_prompts), "model": "countgd"}
+    metadata = {"function_name": "countgd_example_based_counting"}
+    resp_data = send_task_inference_request(
+        payload, "visual-prompts-to-object-detection", files=files, metadata=metadata
+    )
+    bboxes_per_frame = resp_data[0]
+    bboxes_formatted = [ODResponseData(**bbox) for bbox in bboxes_per_frame]
+    filtered_bboxes = filter_bboxes_by_threshold(bboxes_formatted, box_threshold)
+    return [bbox.model_dump() for bbox in filtered_bboxes]
+
+
 def florence2_roberta_vqa(prompt: str, image: np.ndarray) -> str:
     """'florence2_roberta_vqa' is a tool that takes an image and analyzes
     its contents, generates detailed captions and then tries to answer the given
@@ -646,7 +749,7 @@ def clip(image: np.ndarray, classes: List[str]) -> Dict[str, Any]:
         "tool": "closed_set_image_classification",
         "function_name": "clip",
     }
-    resp_data = send_inference_request(data, "tools")
+    resp_data: dict[str, Any] = send_inference_request(data, "tools")
     resp_data["scores"] = [round(prob, 4) for prob in resp_data["scores"]]
     return resp_data

@@ -674,7 +777,7 @@ def vit_image_classification(image: np.ndarray) -> Dict[str, Any]:
         "tool": "image_classification",
         "function_name": "vit_image_classification",
     }
-    resp_data = send_inference_request(data, "tools")
+    resp_data: dict[str, Any] = send_inference_request(data, "tools")
     resp_data["scores"] = [round(prob, 4) for prob in resp_data["scores"]]
     return resp_data

@@ -701,7 +804,9 @@ def vit_nsfw_classification(image: np.ndarray) -> Dict[str, Any]:
         "image": image_b64,
         "function_name": "vit_nsfw_classification",
     }
-    resp_data = send_inference_request(data, "nsfw-classification", v2=True)
+    resp_data: dict[str, Any] = send_inference_request(
+        data, "nsfw-classification", v2=True
+    )
     resp_data["score"] = round(resp_data["score"], 4)
     return resp_data

@@ -1559,6 +1664,74 @@ def overlay_heat_map(
     return np.array(combined)


+def overlay_counting_results(
+    image: np.ndarray, instances: List[Dict[str, Any]]
+) -> np.ndarray:
+    """'overlay_counting_results' is a utility function that displays counting results on
+    an image.
+
+    Parameters:
+        image (np.ndarray): The image to display the bounding boxes on.
+        instances (List[Dict[str, Any]]): A list of dictionaries containing the bounding
+            box information of each instance.
+
+    Returns:
+        np.ndarray: The image with an instance ID displayed at the center of each
+            bounding box.
+
+    Example
+    -------
+    >>> image_with_bboxes = overlay_counting_results(
+        image, [{'score': 0.99, 'label': 'object', 'bbox': [0.1, 0.11, 0.35, 0.4]}],
+    )
+    """
+    pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGB")
+    color = (158, 218, 229)
+
+    width, height = pil_image.size
+    fontsize = max(10, int(min(width, height) / 80))
+    pil_image = ImageEnhance.Brightness(pil_image).enhance(0.5)
+    draw = ImageDraw.Draw(pil_image)
+    font = ImageFont.truetype(
+        str(resources.files("vision_agent.fonts").joinpath("default_font_ch_en.ttf")),
+        fontsize,
+    )
+
+    for i, elt in enumerate(instances):
+        label = f"{i}"
+        box = elt["bbox"]
+
+        # denormalize the box if it is normalized
+        box = denormalize_bbox(box, (height, width))
+        x0, y0, x1, y1 = box
+        cx, cy = (x0 + x1) / 2, (y0 + y1) / 2
+
+        text_box = draw.textbbox(
+            (cx, cy), text=label, font=font, align="center", anchor="mm"
+        )
+
+        # Calculate the offset to center the text within the bounding box
+        text_width = text_box[2] - text_box[0]
+        text_height = text_box[3] - text_box[1]
+        text_x0 = cx - text_width / 2
+        text_y0 = cy - text_height / 2
+        text_x1 = cx + text_width / 2
+        text_y1 = cy + text_height / 2
+
+        # Draw the rectangle encapsulating the text
+        draw.rectangle((text_x0, text_y0, text_x1, text_y1), fill=color)
+
+        # Draw the text at the center of the bounding box
+        draw.text(
+            (text_x0, text_y0),
+            label,
+            fill="black",
+            font=font,
+            anchor="lt",
+        )
+
+    return np.array(pil_image)
+
+
 # TODO: add this function to the imports so that is picked in the agent
 def florencev2_fine_tuning(bboxes: List[Dict[str, Any]], task: str) -> UUID:
     """'florencev2_fine_tuning' is a tool that fine-tune florencev2 to be able
@@ -1679,8 +1852,7 @@ def florencev2_fine_tuned_object_detection(
     clip,
     vit_image_classification,
     vit_nsfw_classification,
-    loca_zero_shot_counting,
-    loca_visual_prompt_counting,
+    countgd_counting,
     florence2_image_caption,
     florence2_ocr,
     florence2_sam2_image,
@@ -1703,6 +1875,7 @@ def florencev2_fine_tuned_object_detection(
     overlay_bounding_boxes,
     overlay_segmentation_masks,
     overlay_heat_map,
+    overlay_counting_results,
 ]

 TOOLS = FUNCTION_TOOLS + UTIL_TOOLS
@@ -1720,5 +1893,6 @@
         overlay_bounding_boxes,
         overlay_segmentation_masks,
         overlay_heat_map,
+        overlay_counting_results,
     ]
 )
diff --git a/vision_agent/tools/tools_types.py b/vision_agent/tools/tools_types.py
index 7b640adb..af1e8ee9 100644
--- a/vision_agent/tools/tools_types.py
+++ b/vision_agent/tools/tools_types.py
@@ -1,8 +1,8 @@
 from enum import Enum
-from typing import List, Optional, Tuple
 from uuid import UUID
+from typing import List, Tuple, Optional, Union

-from pydantic import BaseModel, ConfigDict, Field, SerializationInfo, field_serializer
+from pydantic import BaseModel, ConfigDict, Field, field_serializer, SerializationInfo


 class BboxInput(BaseModel):
@@ -82,3 +82,16 @@ class JobStatus(str, Enum):
     SUCCEEDED = "SUCCEEDED"
     FAILED = "FAILED"
     STOPPED = "STOPPED"
+
+
+class ODResponseData(BaseModel):
+    label: str
+    score: float
+    bbox: Union[list[int], list[float]] = Field(alias="bounding_box")
+
+    model_config = ConfigDict(
+        populate_by_name=True,
+    )
+
+
+BoundingBoxes = list[ODResponseData]
diff --git a/vision_agent/utils/image_utils.py b/vision_agent/utils/image_utils.py
index c1cc8eb6..f0113c9f 100644
--- a/vision_agent/utils/image_utils.py
+++ b/vision_agent/utils/image_utils.py
@@ -181,7 +181,7 @@ def denormalize_bbox(
         raise ValueError("Bounding box must be of length 4.")

     arr = np.array(bbox)
-    if np.all((arr >= 0) & (arr <= 1)):
+    if np.all((arr[:2] >= 0) & (arr[:2] <= 1)):
         x1, y1, x2, y2 = bbox
         x1 = round(x1 * image_size[1])
         y1 = round(y1 * image_size[0])
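
For a quick end-to-end check of the tools this diff adds, the snippet below mirrors the new integration tests: it counts coins in the scikit-image sample image, then overlays a per-instance index. This is a minimal sketch, assuming the package is installed, a valid `LANDINGAI_API_KEY` is exported (both calls hit the remote inference service), and the `save_image` helper exported from `vision_agent/tools/__init__.py` takes an array and a file path; the output filename is illustrative.

```python
import skimage as ski

from vision_agent.tools import countgd_counting, overlay_counting_results, save_image

# Grayscale sample with many similar objects; the same image that the new
# tests/integration_dev/test_tools.py cases count (they expect 24 coins).
image = ski.data.coins()

# Text-prompted counting: each detection is a dict with 'score', 'label', and
# a normalized 'bbox' in (xmin, ymin, xmax, ymax) order.
detections = countgd_counting(prompt="coin", image=image)
print(f"counted {len(detections)} coins")

# Draw an index at the center of every detection and save the result.
annotated = overlay_counting_results(image, detections)
save_image(annotated, "coins_annotated.png")  # illustrative output path
```

Note that `overlay_counting_results` denormalizes each box itself (via the `denormalize_bbox` change above, which now inspects only xmin/ymin when deciding whether a box is normalized), so the raw output of `countgd_counting` can be passed through unchanged.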