From 0c9d74bad4b5bf8b3ca5ebeac08ba8bf4119ed8b Mon Sep 17 00:00:00 2001 From: Yazhou Cao Date: Mon, 11 Mar 2024 13:12:20 +0800 Subject: [PATCH 1/3] Implement Grounding DINO tool --- vision_agent/image_utils.py | 28 ++++++++++++++++++++++++++++ vision_agent/lmm/lmm.py | 6 +++--- vision_agent/tools/tools.py | 25 ++++++++++++++++++++++--- 3 files changed, 53 insertions(+), 6 deletions(-) create mode 100644 vision_agent/image_utils.py diff --git a/vision_agent/image_utils.py b/vision_agent/image_utils.py new file mode 100644 index 00000000..78022342 --- /dev/null +++ b/vision_agent/image_utils.py @@ -0,0 +1,28 @@ +import base64 +from io import BytesIO +from pathlib import Path +from typing import Union + +import numpy as np +from PIL import Image + + +def b64_to_pil(b64_str: str) -> Image: + # , can't be encoded in b64 data so must be part of prefix + if "," in b64_str: + b64_str = b64_str.split(",")[1] + return Image.open(BytesIO(base64.b64decode(b64_str))) + + +def convert_to_b64(data: Union[str, Path, np.ndarray, Image.Image]) -> str: + if data is None: + raise ValueError(f"Invalid input image: {data}. Input image can't be None.") + if isinstance(data, (str, Path)): + data = Image.open(data) + if isinstance(data, Image.Image): + buffer = BytesIO() + data.save(buffer, format="PNG") + return base64.b64encode(buffer.getvalue()).decode("utf-8") + else: + arr_bytes = data.tobytes() + return base64.b64encode(arr_bytes).decode("utf-8") diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py index 488048fc..14472884 100644 --- a/vision_agent/lmm/lmm.py +++ b/vision_agent/lmm/lmm.py @@ -8,19 +8,19 @@ import requests from vision_agent.tools import ( - SYSTEM_PROMPT, CHOOSE_PARAMS, - ImageTool, CLIP, + SYSTEM_PROMPT, GroundingDINO, GroundingSAM, + ImageTool, ) logging.basicConfig(level=logging.INFO) _LOGGER = logging.getLogger(__name__) -_LLAVA_ENDPOINT = "https://cpvlqoxw6vhpdro27uhkvceady0kvvqk.lambda-url.us-east-2.on.aws" +_LLAVA_ENDPOINT = "https://svtswgdnleslqcsjvilau4p6u40jwrkn.lambda-url.us-east-2.on.aws" def encode_image(image: Union[str, Path]) -> str: diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 9ca70452..fc5013da 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -1,8 +1,12 @@ -from typing import Dict, List, Union from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any, Dict, List, Union, cast +import requests from PIL.Image import Image as ImageType +from vision_agent.image_utils import convert_to_b64 + class ImageTool(ABC): @abstractmethod @@ -27,6 +31,8 @@ def __call__(self, image: Union[str, ImageType]) -> List[Dict]: class GroundingDINO(ImageTool): + _ENDPOINT = "https://chnicr4kes5ku77niv2zoytggq0qyqlp.lambda-url.us-east-2.on.aws" + doc = ( "Grounding DINO is a tool that can detect arbitrary objects with inputs such as category names or referring expressions." "Here are some exmaples of how to use the tool, the examples are in the format of User Question: which will have the user's question in quotes followed by the parameters in JSON format, which is the parameters you need to output to call the API to solve the user's question.\n" @@ -38,8 +44,21 @@ class GroundingDINO(ImageTool): def __init__(self, prompt: str): self.prompt = prompt - def __call__(self, image: Union[str, ImageType]) -> List[Dict]: - raise NotImplementedError + def __call__(self, image: Union[str, Path, ImageType]) -> List[Dict]: + image_b64 = convert_to_b64(image) + data = { + "prompt": self.prompt, + "images": [image_b64], + } + res = requests.post( + self._ENDPOINT, + headers={"Content-Type": "application/json"}, + json=data, + ) + resp_json: Dict[str, Any] = res.json() + # if resp_json["statusCode"] != 200: + # _LOGGER.error(f"Request failed: {resp_json['data']}") + return cast(str, resp_json["data"]) class GroundingSAM(ImageTool): From e4f3b7e0bf0535e0b2908797153e8c1727f7b639 Mon Sep 17 00:00:00 2001 From: Yazhou Cao Date: Mon, 11 Mar 2024 13:13:36 +0800 Subject: [PATCH 2/3] error handling --- vision_agent/tools/tools.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index fc5013da..4d41871e 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -1,3 +1,4 @@ +import logging from abc import ABC, abstractmethod from pathlib import Path from typing import Any, Dict, List, Union, cast @@ -7,6 +8,8 @@ from vision_agent.image_utils import convert_to_b64 +_LOGGER = logging.getLogger(__name__) + class ImageTool(ABC): @abstractmethod @@ -56,8 +59,8 @@ def __call__(self, image: Union[str, Path, ImageType]) -> List[Dict]: json=data, ) resp_json: Dict[str, Any] = res.json() - # if resp_json["statusCode"] != 200: - # _LOGGER.error(f"Request failed: {resp_json['data']}") + if resp_json["statusCode"] != 200: + _LOGGER.error(f"Request failed: {resp_json['data']}") return cast(str, resp_json["data"]) From dfdb7922a56358980b0c2c7509bb1bb273e0fcc4 Mon Sep 17 00:00:00 2001 From: Yazhou Cao Date: Mon, 11 Mar 2024 13:21:22 +0800 Subject: [PATCH 3/3] Fix typing errors --- vision_agent/image_utils.py | 2 +- vision_agent/tools/tools.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vision_agent/image_utils.py b/vision_agent/image_utils.py index 78022342..86533972 100644 --- a/vision_agent/image_utils.py +++ b/vision_agent/image_utils.py @@ -7,7 +7,7 @@ from PIL import Image -def b64_to_pil(b64_str: str) -> Image: +def b64_to_pil(b64_str: str) -> Image.Image: # , can't be encoded in b64 data so must be part of prefix if "," in b64_str: b64_str = b64_str.split(",")[1] diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 4d41871e..474eba1e 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -60,8 +60,8 @@ def __call__(self, image: Union[str, Path, ImageType]) -> List[Dict]: ) resp_json: Dict[str, Any] = res.json() if resp_json["statusCode"] != 200: - _LOGGER.error(f"Request failed: {resp_json['data']}") - return cast(str, resp_json["data"]) + _LOGGER.error(f"Request failed: {resp_json}") + return cast(List[Dict], resp_json["data"]) class GroundingSAM(ImageTool):