Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement the Grounding DINO tool #12

Merged
merged 3 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions vision_agent/image_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import base64
from io import BytesIO
from pathlib import Path
from typing import Union

import numpy as np
from PIL import Image


def b64_to_pil(b64_str: str) -> Image.Image:
    """Decode a base64 string into a PIL image.

    Args:
        b64_str: Base64-encoded image bytes, either bare or preceded by a
            comma-terminated prefix such as ``data:image/png;base64,``.

    Returns:
        The image produced by ``Image.open`` on the decoded bytes.
    """
    # "," is not part of the base64 alphabet, so anything up to and including
    # the LAST comma must be prefix. rpartition yields the whole string as the
    # tail when there is no comma, and (unlike ``split(",")[1]``) still selects
    # the payload when the prefix itself happens to contain a comma.
    _, _, payload = b64_str.rpartition(",")
    return Image.open(BytesIO(base64.b64decode(payload)))


def convert_to_b64(data: Union[str, Path, np.ndarray, Image.Image]) -> str:
    """Encode an image as the base64 text of its PNG bytes.

    Every supported input is normalised to a PIL image and serialised as PNG,
    so the result can always be decoded back into an image.

    Args:
        data: A path to an image file (``str`` or ``Path``), a numpy array of
            pixel data, or a PIL image. Arrays are converted with
            ``Image.fromarray`` and must have a dtype/shape it accepts.

    Returns:
        The PNG-encoded image as a UTF-8 base64 string.

    Raises:
        ValueError: If ``data`` is ``None`` or not one of the supported types.
    """
    if data is None:
        raise ValueError(f"Invalid input image: {data}. Input image can't be None.")

    buffer = BytesIO()
    if isinstance(data, (str, Path)):
        # Image.open keeps the file handle open until the image is closed;
        # the context manager releases it as soon as the PNG is written.
        with Image.open(data) as opened:
            opened.save(buffer, format="PNG")
    else:
        if isinstance(data, np.ndarray):
            # Raw ``tobytes()`` would drop shape and dtype and is not a
            # decodable image, so go through PIL like every other input.
            data = Image.fromarray(data)
        if not isinstance(data, Image.Image):
            raise ValueError(
                f"Invalid input image: {data}. Input image must be a path, "
                "a numpy array or a PIL Image."
            )
        data.save(buffer, format="PNG")

    return base64.b64encode(buffer.getvalue()).decode("utf-8")
6 changes: 3 additions & 3 deletions vision_agent/lmm/lmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,19 @@
import requests

from vision_agent.tools import (
SYSTEM_PROMPT,
CHOOSE_PARAMS,
ImageTool,
CLIP,
SYSTEM_PROMPT,
GroundingDINO,
GroundingSAM,
ImageTool,
)

# Configures the root logger at INFO level as a side effect of importing this module.
logging.basicConfig(level=logging.INFO)

# Module-level logger named after this module.
_LOGGER = logging.getLogger(__name__)

# Presumably the URL of the hosted LLaVA service used by this module — confirm against callers.
# NOTE(review): the name is assigned twice in a row, so the second value
# overrides the first. This looks like the old/new sides of a diff; confirm
# and keep only the current URL.
_LLAVA_ENDPOINT = "https://cpvlqoxw6vhpdro27uhkvceady0kvvqk.lambda-url.us-east-2.on.aws"
_LLAVA_ENDPOINT = "https://svtswgdnleslqcsjvilau4p6u40jwrkn.lambda-url.us-east-2.on.aws"


def encode_image(image: Union[str, Path]) -> str:
Expand Down
28 changes: 25 additions & 3 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
from typing import Dict, List, Union
import logging
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Union, cast

import requests
from PIL.Image import Image as ImageType

from vision_agent.image_utils import convert_to_b64

# Module-level logger named after this module; used to report failed tool requests.
_LOGGER = logging.getLogger(__name__)


class ImageTool(ABC):
@abstractmethod
Expand All @@ -27,6 +34,8 @@ def __call__(self, image: Union[str, ImageType]) -> List[Dict]:


class GroundingDINO(ImageTool):
_ENDPOINT = "https://chnicr4kes5ku77niv2zoytggq0qyqlp.lambda-url.us-east-2.on.aws"

doc = (
"Grounding DINO is a tool that can detect arbitrary objects with inputs such as category names or referring expressions."
"Here are some exmaples of how to use the tool, the examples are in the format of User Question: which will have the user's question in quotes followed by the parameters in JSON format, which is the parameters you need to output to call the API to solve the user's question.\n"
Expand All @@ -38,8 +47,21 @@ class GroundingDINO(ImageTool):
def __init__(self, prompt: str):
    """Create a detector bound to a text prompt.

    Args:
        prompt: Text sent as the "prompt" field of every request made by
            ``__call__``.
    """
    self.prompt = prompt

def __call__(self, image: Union[str, ImageType]) -> List[Dict]:
raise NotImplementedError
def __call__(self, image: Union[str, Path, ImageType]) -> List[Dict]:
    """Send one image and this tool's prompt to the Grounding DINO endpoint.

    Args:
        image: The image to analyse, in any form accepted by
            ``convert_to_b64``.

    Returns:
        The "data" field of the endpoint's JSON response.

    Raises:
        ValueError: If the response does not report ``statusCode`` 200.
    """
    payload = {
        "prompt": self.prompt,
        "images": [convert_to_b64(image)],
    }
    res = requests.post(
        self._ENDPOINT,
        headers={"Content-Type": "application/json"},
        json=payload,
    )
    resp_json: Dict[str, Any] = res.json()
    # Fail explicitly instead of logging and then indexing "data" on an error
    # payload; ``.get`` also covers responses that carry no statusCode at all.
    if resp_json.get("statusCode") != 200:
        _LOGGER.error(f"Request failed: {resp_json}")
        raise ValueError(f"Request failed: {resp_json}")
    return cast(List[Dict], resp_json["data"])


class GroundingSAM(ImageTool):
Expand Down
Loading