Skip to content

Commit

Permalink
Implement the Grounding DINO tool (#12)
Browse files Browse the repository at this point in the history
* Implement Grounding DINO tool

* error handling

* Fix typing errors

---------

Co-authored-by: Yazhou Cao <[email protected]>
  • Loading branch information
humpydonkey and AsiaCao authored Mar 11, 2024
1 parent 9716eb7 commit 3d16fcf
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 6 deletions.
28 changes: 28 additions & 0 deletions vision_agent/image_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import base64
from io import BytesIO
from pathlib import Path
from typing import Union

import numpy as np
from PIL import Image


def b64_to_pil(b64_str: str) -> Image.Image:
# , can't be encoded in b64 data so must be part of prefix
if "," in b64_str:
b64_str = b64_str.split(",")[1]
return Image.open(BytesIO(base64.b64decode(b64_str)))


def convert_to_b64(data: Union[str, Path, np.ndarray, Image.Image]) -> str:
if data is None:
raise ValueError(f"Invalid input image: {data}. Input image can't be None.")
if isinstance(data, (str, Path)):
data = Image.open(data)
if isinstance(data, Image.Image):
buffer = BytesIO()
data.save(buffer, format="PNG")
return base64.b64encode(buffer.getvalue()).decode("utf-8")
else:
arr_bytes = data.tobytes()
return base64.b64encode(arr_bytes).decode("utf-8")
6 changes: 3 additions & 3 deletions vision_agent/lmm/lmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,19 @@
import requests

from vision_agent.tools import (
SYSTEM_PROMPT,
CHOOSE_PARAMS,
ImageTool,
CLIP,
SYSTEM_PROMPT,
GroundingDINO,
GroundingSAM,
ImageTool,
)

logging.basicConfig(level=logging.INFO)

_LOGGER = logging.getLogger(__name__)

_LLAVA_ENDPOINT = "https://cpvlqoxw6vhpdro27uhkvceady0kvvqk.lambda-url.us-east-2.on.aws"
_LLAVA_ENDPOINT = "https://svtswgdnleslqcsjvilau4p6u40jwrkn.lambda-url.us-east-2.on.aws"


def encode_image(image: Union[str, Path]) -> str:
Expand Down
28 changes: 25 additions & 3 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
from typing import Dict, List, Union
import logging
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Union, cast

import requests
from PIL.Image import Image as ImageType

from vision_agent.image_utils import convert_to_b64

_LOGGER = logging.getLogger(__name__)


class ImageTool(ABC):
@abstractmethod
Expand All @@ -27,6 +34,8 @@ def __call__(self, image: Union[str, ImageType]) -> List[Dict]:


class GroundingDINO(ImageTool):
_ENDPOINT = "https://chnicr4kes5ku77niv2zoytggq0qyqlp.lambda-url.us-east-2.on.aws"

doc = (
"Grounding DINO is a tool that can detect arbitrary objects with inputs such as category names or referring expressions."
"Here are some exmaples of how to use the tool, the examples are in the format of User Question: which will have the user's question in quotes followed by the parameters in JSON format, which is the parameters you need to output to call the API to solve the user's question.\n"
Expand All @@ -38,8 +47,21 @@ class GroundingDINO(ImageTool):
def __init__(self, prompt: str):
self.prompt = prompt

def __call__(self, image: Union[str, ImageType]) -> List[Dict]:
raise NotImplementedError
def __call__(self, image: Union[str, Path, ImageType]) -> List[Dict]:
image_b64 = convert_to_b64(image)
data = {
"prompt": self.prompt,
"images": [image_b64],
}
res = requests.post(
self._ENDPOINT,
headers={"Content-Type": "application/json"},
json=data,
)
resp_json: Dict[str, Any] = res.json()
if resp_json["statusCode"] != 200:
_LOGGER.error(f"Request failed: {resp_json}")
return cast(List[Dict], resp_json["data"])


class GroundingSAM(ImageTool):
Expand Down

0 comments on commit 3d16fcf

Please sign in to comment.