diff --git a/vision_agent/image_utils.py b/vision_agent/image_utils.py new file mode 100644 index 00000000..86533972 --- /dev/null +++ b/vision_agent/image_utils.py @@ -0,0 +1,28 @@ +import base64 +from io import BytesIO +from pathlib import Path +from typing import Union + +import numpy as np +from PIL import Image + + +def b64_to_pil(b64_str: str) -> Image.Image: + # , can't be encoded in b64 data so must be part of prefix + if "," in b64_str: + b64_str = b64_str.split(",")[1] + return Image.open(BytesIO(base64.b64decode(b64_str))) + + +def convert_to_b64(data: Union[str, Path, np.ndarray, Image.Image]) -> str: + if data is None: + raise ValueError(f"Invalid input image: {data}. Input image can't be None.") + if isinstance(data, (str, Path)): + data = Image.open(data) + if isinstance(data, Image.Image): + buffer = BytesIO() + data.save(buffer, format="PNG") + return base64.b64encode(buffer.getvalue()).decode("utf-8") + else: + arr_bytes = data.tobytes() + return base64.b64encode(arr_bytes).decode("utf-8") diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py index 488048fc..14472884 100644 --- a/vision_agent/lmm/lmm.py +++ b/vision_agent/lmm/lmm.py @@ -8,19 +8,19 @@ import requests from vision_agent.tools import ( - SYSTEM_PROMPT, CHOOSE_PARAMS, - ImageTool, CLIP, + SYSTEM_PROMPT, GroundingDINO, GroundingSAM, + ImageTool, ) logging.basicConfig(level=logging.INFO) _LOGGER = logging.getLogger(__name__) -_LLAVA_ENDPOINT = "https://cpvlqoxw6vhpdro27uhkvceady0kvvqk.lambda-url.us-east-2.on.aws" +_LLAVA_ENDPOINT = "https://svtswgdnleslqcsjvilau4p6u40jwrkn.lambda-url.us-east-2.on.aws" def encode_image(image: Union[str, Path]) -> str: diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 9ca70452..474eba1e 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -1,8 +1,15 @@ -from typing import Dict, List, Union +import logging from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any, Dict, List, Union, cast +import requests from PIL.Image import Image as ImageType +from vision_agent.image_utils import convert_to_b64 + +_LOGGER = logging.getLogger(__name__) + class ImageTool(ABC): @abstractmethod @@ -27,6 +34,8 @@ def __call__(self, image: Union[str, ImageType]) -> List[Dict]: class GroundingDINO(ImageTool): + _ENDPOINT = "https://chnicr4kes5ku77niv2zoytggq0qyqlp.lambda-url.us-east-2.on.aws" + doc = ( "Grounding DINO is a tool that can detect arbitrary objects with inputs such as category names or referring expressions." "Here are some exmaples of how to use the tool, the examples are in the format of User Question: which will have the user's question in quotes followed by the parameters in JSON format, which is the parameters you need to output to call the API to solve the user's question.\n" @@ -38,8 +47,21 @@ class GroundingDINO(ImageTool): def __init__(self, prompt: str): self.prompt = prompt - def __call__(self, image: Union[str, ImageType]) -> List[Dict]: - raise NotImplementedError + def __call__(self, image: Union[str, Path, ImageType]) -> List[Dict]: + image_b64 = convert_to_b64(image) + data = { + "prompt": self.prompt, + "images": [image_b64], + } + res = requests.post( + self._ENDPOINT, + headers={"Content-Type": "application/json"}, + json=data, + ) + resp_json: Dict[str, Any] = res.json() + if resp_json["statusCode"] != 200: + _LOGGER.error(f"Request failed: {resp_json}") + return cast(List[Dict], resp_json["data"]) class GroundingSAM(ImageTool):