Skip to content

Commit

Permalink
Implement Grounding DINO tool
Browse files Browse the repository at this point in the history
  • Loading branch information
AsiaCao committed Mar 11, 2024
1 parent 9716eb7 commit 0c9d74b
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 6 deletions.
28 changes: 28 additions & 0 deletions vision_agent/image_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import base64
from io import BytesIO
from pathlib import Path
from typing import Union

import numpy as np
from PIL import Image


def b64_to_pil(b64_str: str) -> Image:
# , can't be encoded in b64 data so must be part of prefix
if "," in b64_str:
b64_str = b64_str.split(",")[1]
return Image.open(BytesIO(base64.b64decode(b64_str)))


def convert_to_b64(data: Union[str, Path, np.ndarray, Image.Image]) -> str:
if data is None:
raise ValueError(f"Invalid input image: {data}. Input image can't be None.")
if isinstance(data, (str, Path)):
data = Image.open(data)
if isinstance(data, Image.Image):
buffer = BytesIO()
data.save(buffer, format="PNG")
return base64.b64encode(buffer.getvalue()).decode("utf-8")
else:
arr_bytes = data.tobytes()
return base64.b64encode(arr_bytes).decode("utf-8")
6 changes: 3 additions & 3 deletions vision_agent/lmm/lmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,19 @@
import requests

from vision_agent.tools import (
SYSTEM_PROMPT,
CHOOSE_PARAMS,
ImageTool,
CLIP,
SYSTEM_PROMPT,
GroundingDINO,
GroundingSAM,
ImageTool,
)

logging.basicConfig(level=logging.INFO)

_LOGGER = logging.getLogger(__name__)

_LLAVA_ENDPOINT = "https://cpvlqoxw6vhpdro27uhkvceady0kvvqk.lambda-url.us-east-2.on.aws"
_LLAVA_ENDPOINT = "https://svtswgdnleslqcsjvilau4p6u40jwrkn.lambda-url.us-east-2.on.aws"


def encode_image(image: Union[str, Path]) -> str:
Expand Down
25 changes: 22 additions & 3 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from typing import Dict, List, Union
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Union, cast

import requests
from PIL.Image import Image as ImageType

from vision_agent.image_utils import convert_to_b64


class ImageTool(ABC):
@abstractmethod
Expand All @@ -27,6 +31,8 @@ def __call__(self, image: Union[str, ImageType]) -> List[Dict]:


class GroundingDINO(ImageTool):
_ENDPOINT = "https://chnicr4kes5ku77niv2zoytggq0qyqlp.lambda-url.us-east-2.on.aws"

doc = (
"Grounding DINO is a tool that can detect arbitrary objects with inputs such as category names or referring expressions."
"Here are some exmaples of how to use the tool, the examples are in the format of User Question: which will have the user's question in quotes followed by the parameters in JSON format, which is the parameters you need to output to call the API to solve the user's question.\n"
Expand All @@ -38,8 +44,21 @@ class GroundingDINO(ImageTool):
def __init__(self, prompt: str):
self.prompt = prompt

def __call__(self, image: Union[str, ImageType]) -> List[Dict]:
raise NotImplementedError
def __call__(self, image: Union[str, Path, ImageType]) -> List[Dict]:
image_b64 = convert_to_b64(image)
data = {
"prompt": self.prompt,
"images": [image_b64],
}
res = requests.post(
self._ENDPOINT,
headers={"Content-Type": "application/json"},
json=data,
)
resp_json: Dict[str, Any] = res.json()
# if resp_json["statusCode"] != 200:
# _LOGGER.error(f"Request failed: {resp_json['data']}")
return cast(str, resp_json["data"])


class GroundingSAM(ImageTool):
Expand Down

0 comments on commit 0c9d74b

Please sign in to comment.